tempor.methods.time_to_event.plugin_ddh module

Dynamic DeepHit survival analysis model.

class tempor.methods.time_to_event.plugin_ddh.DynamicDeepHitTimeToEventAnalysisParams(n_iter: int = 1000, batch_size: int = 100, lr: float = 0.001, n_layers_hidden: int = 1, n_units_hidden: int = 40, split: int = 100, rnn_mode: 'GRU' | 'LSTM' | 'RNN' | 'Transformer' = 'GRU', alpha: float = 0.34, beta: float = 0.27, sigma: float = 0.21, dropout: float = 0.06, device: str = 'cpu', val_size: float = 0.1, patience: int = 20, output_mode: 'MLP' | 'LSTM' | 'GRU' | 'RNN' | 'Transformer' | 'TCN' | 'InceptionTime' | 'InceptionTimePlus' | 'ResCNN' | 'XCM' = 'MLP', random_state: int = 0)

Bases: object

n_iter : int = 1000

Number of training epochs.

batch_size : int = 100

Training batch size.

lr : float = 0.001

Training learning rate.

n_layers_hidden : int = 1

Number of hidden layers in the network.

n_units_hidden : int = 40

Number of units for each hidden layer.

split : int = 100

Number of discrete buckets.
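
To illustrate what discretizing time into `split` buckets means, the sketch below maps continuous event times onto `split` equally spaced grid edges running from 0 to the largest observed time. The equal-width grid and the helpers `make_grid` and `discretize_times` are assumptions for illustration only; the plugin's actual binning scheme is not documented here and may differ.

```python
from bisect import bisect_right


def make_grid(max_time: float, split: int) -> list[float]:
    """Return `split` equally spaced bucket edges from 0 to max_time (assumed scheme)."""
    step = max_time / (split - 1)
    return [i * step for i in range(split)]


def discretize_times(times: list[float], split: int) -> list[int]:
    """Map each continuous time to the index of the last grid edge <= that time."""
    grid = make_grid(max(times), split)
    # bisect_right returns the insertion point to the right of equal edges,
    # so subtracting 1 gives the bucket whose left edge is <= t.
    return [bisect_right(grid, t) - 1 for t in times]


print(discretize_times([0.0, 2.5, 5.0, 10.0], split=5))  # [0, 1, 2, 4]
```

A larger `split` gives a finer time resolution at the cost of more output units per subject.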

rnn_mode : Literal[GRU] | Literal[LSTM] | Literal[RNN] | Literal[Transformer] = 'GRU'

Internal temporal architecture, one of RnnMode.
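
Because the accepted values form a `Literal` type, they can be inspected programmatically. The sketch below defines a stand-in `RnnMode` alias from the values listed in the signature and validates a user-supplied string against it. `check_rnn_mode` is a hypothetical helper, not part of tempor, and the library's own `RnnMode` definition may differ.

```python
from typing import Literal, get_args

# Stand-in alias: values copied from the rnn_mode annotation in the signature.
RnnMode = Literal["GRU", "LSTM", "RNN", "Transformer"]


def check_rnn_mode(value: str) -> str:
    """Raise ValueError if `value` is not one of the RnnMode literals."""
    allowed = get_args(RnnMode)
    if value not in allowed:
        raise ValueError(f"rnn_mode must be one of {allowed}, got {value!r}")
    return value


check_rnn_mode("LSTM")  # accepted
```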

alpha : float = 0.34

Weighting, in (0, 1), between the likelihood loss and the rank loss (L2 in the paper). A value of 1 uses only the likelihood loss, and 0 uses only the rank loss.
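
Read literally, this description corresponds to a convex combination of the two loss terms. A minimal sketch, assuming the blend is linear; the exact form used in the implementation is not stated here, and `combined_loss` is a hypothetical helper.

```python
def combined_loss(nll: float, rank_loss: float, alpha: float = 0.34) -> float:
    """Blend the two losses as alpha * nll + (1 - alpha) * rank_loss (assumed form)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return alpha * nll + (1.0 - alpha) * rank_loss


print(combined_loss(2.0, 4.0, alpha=1.0))  # 2.0: likelihood loss only
print(combined_loss(2.0, 4.0, alpha=0.0))  # 4.0: rank loss only
```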

beta : float = 0.27

Beta weighting hyperparameter (see the paper).

sigma : float = 0.21

The σ (sigma) parameter of the function η used in the rank loss (L2 in the paper).
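
In the paper, the rank loss penalizes pairs of subjects whose predicted risks are out of order through η(x, y) = exp(-(x - y)/σ), so a smaller σ makes the penalty sharper. The sketch below is a simplified, single-risk re-implementation of that idea, not tempor's code. It assumes event times are already discretized to integer bucket indices and that `cif[i][k]` holds subject i's predicted cumulative incidence at bucket k; the helper names are hypothetical.

```python
import math


def eta(x: float, y: float, sigma: float = 0.21) -> float:
    """Pairwise penalty exp(-(x - y) / sigma): small when x exceeds y."""
    return math.exp(-(x - y) / sigma)


def rank_loss(times, events, cif, sigma=0.21):
    """Sum eta over acceptable pairs (i had an event and t_i < t_j).

    A well-ordered model assigns subject i a higher cumulative incidence
    at its own event time than subject j, who survived longer.
    """
    total = 0.0
    for i, t_i in enumerate(times):
        if not events[i]:
            continue  # censored subjects cannot anchor a comparable pair
        for j, t_j in enumerate(times):
            if t_i < t_j:
                total += eta(cif[i][t_i], cif[j][t_i], sigma)
    return total


print(rank_loss(times=[0, 1], events=[1, 0], cif=[[0.8, 0.9], [0.2, 0.5]]))
```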

dropout : float = 0.06

Network dropout value.

device : str = 'cpu'

PyTorch device for the model, e.g. 'cpu' or 'cuda'.

val_size : float = 0.1

Early stopping: fraction of the data held out as a validation set.
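
A minimal sketch of holding out such a fraction, assuming a simple random split over samples seeded by `random_state`. The helper `train_val_split` is hypothetical, and the plugin may split its data differently.

```python
import random


def train_val_split(n_samples: int, val_size: float = 0.1, random_state: int = 0):
    """Shuffle sample indices with a fixed seed and hold out a val_size fraction."""
    indices = list(range(n_samples))
    random.Random(random_state).shuffle(indices)
    n_val = max(1, int(round(n_samples * val_size)))  # keep at least one sample
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = train_val_split(100)
print(len(train_idx), len(val_idx))  # 90 10
```

Seeding the shuffle makes the split reproducible across runs with the same `random_state`.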

patience : int = 20

Early stopping: how long training may continue without any improvement before it is stopped.
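
The patience rule can be sketched as a counter that resets whenever the validation loss reaches a new best. This is a generic illustration, not tempor's training loop; measuring patience in epochs and the helper `early_stopping_epoch` are assumptions.

```python
def early_stopping_epoch(val_losses, patience: int = 20) -> int:
    """Return the epoch index at which training would stop.

    Training stops once `patience` consecutive epochs pass without the
    validation loss improving on the best value seen so far.
    """
    best = float("inf")
    since_best = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            since_best = 0
        else:
            since_best += 1
            if since_best >= patience:
                return epoch
    return len(val_losses) - 1  # ran out of epochs without triggering


print(early_stopping_epoch([5, 4, 3, 3, 3, 3], patience=2))  # 4
```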

output_mode : Literal[MLP] | Literal[LSTM] | Literal[GRU] | Literal[RNN] | Literal[Transformer] | Literal[TCN] | Literal[InceptionTime] | Literal[InceptionTimePlus] | Literal[ResCNN] | Literal[XCM] = 'MLP'

Output network, one of OutputMode.

random_state : int = 0

Random seed.

class tempor.methods.time_to_event.plugin_ddh.DynamicDeepHitTimeToEventAnalysis(**params: Any)

Bases: BaseTimeToEventAnalysis, DDHEmbedding

Dynamic DeepHit survival analysis model.

Note

The current implementation has the following limitations:
  • Only one output feature is supported (no competing risks).

  • Risk prediction for time points beyond the last event time in the dataset may raise errors.
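
Given the second limitation, one defensive pattern is to filter the requested horizons before asking for risk predictions. A minimal sketch; `clip_horizons` is a hypothetical helper, and the observed event times are assumed to be available as a plain list.

```python
def clip_horizons(horizons, event_times):
    """Split horizons into those within the last observed event time and those beyond it."""
    last = max(event_times)
    kept = [h for h in horizons if h <= last]
    dropped = [h for h in horizons if h > last]
    return kept, dropped


kept, dropped = clip_horizons([1, 5, 10], event_times=[2, 7, 3])
print(kept, dropped)  # [1, 5] [10]
```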

Parameters:
**params : Any

Parameters and defaults as defined in DynamicDeepHitTimeToEventAnalysisParams.
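
Since the constructor takes keyword parameters, any subset of the fields listed in DynamicDeepHitTimeToEventAnalysisParams can be overridden while the rest keep their defaults. The sketch below imitates that behaviour with a hypothetical `ParamsSketch` dataclass holding a few of the documented defaults; it does not import tempor and is not the library's actual params class.

```python
from dataclasses import dataclass, fields


@dataclass
class ParamsSketch:
    """Hypothetical stand-in holding a few of the documented defaults."""

    n_iter: int = 1000
    batch_size: int = 100
    lr: float = 0.001
    rnn_mode: str = "GRU"
    random_state: int = 0


def build_params(**params) -> ParamsSketch:
    """Merge keyword overrides into the defaults, rejecting unknown names."""
    known = {f.name for f in fields(ParamsSketch)}
    unknown = set(params) - known
    if unknown:
        raise TypeError(f"unexpected parameters: {sorted(unknown)}")
    return ParamsSketch(**params)


print(build_params(n_iter=50, rnn_mode="LSTM"))
```

With the actual plugin, the equivalent call passes the same keywords to the constructor, e.g. `DynamicDeepHitTimeToEventAnalysis(n_iter=50, rnn_mode="LSTM")`.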

References

“Dynamic-DeepHit: A Deep Learning Approach for Dynamic Survival Analysis With Competing Risks Based on Longitudinal Data”, Changhee Lee, Jinsung Yoon, Mihaela van der Schaar.

category : ClassVar[plugin_typing.PluginCategory] = 'time_to_event'

Plugin category, such as 'prediction.one_off.classification'. Must be set by the plugin class using @register_plugin.

name : ClassVar[plugin_typing.PluginName] = 'dynamic_deephit'

Plugin name, such as 'my_nn_classifier'. Must be set by the plugin class using @register_plugin.

plugin_type : ClassVar[plugin_typing.PluginTypeArg] = 'method'

Plugin type, such as 'method'. May optionally be set by the plugin class using @register_plugin; otherwise the default plugin type is used.
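
The three class attributes above are what a registration decorator needs in order to index a plugin. The sketch below shows the general pattern with a hypothetical `register_plugin_sketch` decorator and a registry keyed by "<category>.<name>". Its signature is invented for illustration; the real @register_plugin may take different arguments and store plugins differently.

```python
from typing import Callable, Dict

# Hypothetical registry keyed by "<category>.<name>"; not tempor's implementation.
REGISTRY: Dict[str, type] = {}


def register_plugin_sketch(
    name: str, category: str, plugin_type: str = "method"
) -> Callable[[type], type]:
    """Class decorator that sets name/category/plugin_type and registers the class."""

    def decorator(cls: type) -> type:
        cls.name = name
        cls.category = category
        cls.plugin_type = plugin_type  # falls back to the default when omitted
        REGISTRY[f"{category}.{name}"] = cls
        return cls

    return decorator


@register_plugin_sketch(name="dynamic_deephit", category="time_to_event")
class FakeDDH:
    """Placeholder standing in for the real plugin class."""


print(REGISTRY["time_to_event.dynamic_deephit"].plugin_type)  # method
```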

ParamsDefinition

alias of DynamicDeepHitTimeToEventAnalysisParams

params : DynamicDeepHitTimeToEventAnalysisParams

static hyperparameter_space(*args: Any, **kwargs: Any) → list[Params]

The hyperparameter search domain, used for tuning.

Accepts variadic *args and **kwargs, which are passed through from sample_hyperparameters.
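
The relationship between a search domain and sampling from it can be sketched as follows. The classes `IntRange` and `Categorical`, both `*_sketch` functions, and the ranges themselves are hypothetical and chosen arbitrarily; they are not tempor's Params types or the plugin's actual search space. The sketch only shows how variadic arguments flow from the sampler into the space definition.

```python
import random
from dataclasses import dataclass


@dataclass
class IntRange:
    name: str
    low: int
    high: int

    def sample(self, rng: random.Random) -> int:
        return rng.randint(self.low, self.high)  # inclusive bounds


@dataclass
class Categorical:
    name: str
    choices: list[str]

    def sample(self, rng: random.Random) -> str:
        return rng.choice(self.choices)


def hyperparameter_space_sketch(*args, **kwargs):
    """Illustrative search domain; the ranges here are made up."""
    return [
        IntRange("n_units_hidden", 10, 100),
        Categorical("rnn_mode", ["GRU", "LSTM", "RNN", "Transformer"]),
    ]


def sample_hyperparameters_sketch(*args, seed: int = 0, **kwargs):
    """Draw one configuration, forwarding *args/**kwargs to the space definition."""
    rng = random.Random(seed)
    space = hyperparameter_space_sketch(*args, **kwargs)
    return {p.name: p.sample(rng) for p in space}


print(sample_hyperparameters_sketch(seed=1))
```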