tempor.models.ddh module

Model components for the Dynamic DeepHit implementation.

tempor.models.ddh.get_padded_features(x: ndarray | list[ndarray], pad_size: int | None = None, fill: float = nan) → ndarray

Helper function to pad variable-length RNN inputs to a common length. The padding value is given by fill and defaults to NaN.
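The padding step can be sketched as follows. This is a minimal, hypothetical re-implementation: pad_sequences is not part of tempor, and it assumes each input is a (seq_len, n_features) array, that padding is applied along the time axis, and that pad_size defaults to the longest sequence.

```python
import numpy as np


def pad_sequences(x, pad_size=None, fill=np.nan):
    """Hypothetical stand-in for get_padded_features (not tempor's code).

    Pads (or truncates) each (seq_len_i, n_features) array along the time
    axis so that all sequences share the same length.
    """
    if pad_size is None:
        # Assumption: default to the longest sequence in the batch.
        pad_size = max(len(seq) for seq in x)
    n_features = np.asarray(x[0]).shape[1]
    # Start from an array filled entirely with the padding value.
    padded = np.full((len(x), pad_size, n_features), fill, dtype=float)
    for i, seq in enumerate(x):
        seq = np.asarray(seq, dtype=float)
        length = min(len(seq), pad_size)
        # Copy the observed time steps; the remaining rows keep the fill value.
        padded[i, :length] = seq[:length]
    return padded


# Two patients with 2 and 3 visits, each with 2 covariates.
sequences = [np.ones((2, 2)), np.arange(6.0).reshape(3, 2)]
batch = pad_sequences(sequences)
print(batch.shape)  # (2, 3, 2)
```

Padding with NaN rather than zero keeps padded steps distinguishable from genuine zero-valued covariates, so downstream code can build a mask with np.isnan.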

class tempor.models.ddh.DynamicDeepHitModel(split: int = 100, n_layers_hidden: int = 2, n_units_hidden: int = 100, rnn_mode: str = 'LSTM', dropout: float = 0.1, alpha: float = 0.1, beta: float = 0.1, sigma: float = 0.1, patience: int = 20, lr: float = 0.001, batch_size: int = 100, n_iter: int = 1000, device: Any = device(type='cpu'), val_size: float = 0.1, random_state: int = 0, clipping_value: int = 1, output_mode: str = 'MLP')

Bases: object

Dynamic DeepHit model implementation.

This implementation assumes that the last event happens at the same time for every patient. The cumulative incidence function (CIF) is simplified accordingly.
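As a hedged illustration, assume the network outputs a probability o_{k,s} for each risk k and discrete time bin s, and that no conditioning on a patient-specific last-observation time is needed (our reading of the simplification above). The CIF of risk k and the corresponding survival then reduce to cumulative sums:

```latex
F_k(\tau \mid \mathbf{x}) = \sum_{s \le \tau} o_{k,s},
\qquad
S_k(\tau \mid \mathbf{x}) = 1 - F_k(\tau \mid \mathbf{x})
```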

fit(x: ndarray, t: ndarray, e: ndarray) → Self

Fit the model to the data.

Parameters:
x : np.ndarray

Covariates.

t : np.ndarray

Event times.

e : np.ndarray

Event indicators.

Returns:

Trained model.

Return type:

Self

discretize(t: ndarray | list[ndarray], split: int, split_time: list[float] | None = None) → tuple

Discretize the survival horizon.

Parameters:
t : Union[np.ndarray, List[np.ndarray]]

Times of events.

split : int

Number of bins.

split_time : Optional[List[float]], optional

List of bins; its length must equal split. Provide it if split is not provided. Defaults to None.

Returns:

(t_discretized, split_time): the discretized event times and the split times of the bins.

Return type:

Tuple
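A minimal sketch of horizon discretization, assuming equally spaced bins between 0 and the largest observed time when split_time is not given. The function discretize_times is hypothetical; the actual binning rule used by tempor (for example, quantile-based splits) may differ.

```python
import numpy as np


def discretize_times(t, split, split_time=None):
    """Hypothetical sketch of horizon discretization (not tempor's code).

    Assumption: bins are equally spaced on (0, max(t)] unless explicit
    split_time upper boundaries are given; len(split_time) == split.
    """
    t = np.asarray(t, dtype=float)
    if split_time is None:
        # `split` equally spaced upper boundaries covering (0, max(t)].
        split_time = np.linspace(0.0, t.max(), split + 1)[1:]
    split_time = np.asarray(split_time, dtype=float)
    # 0-based index of the first boundary that is >= each time.
    t_discretized = np.searchsorted(split_time, t, side="left")
    return t_discretized, split_time


t_disc, edges = discretize_times([0.5, 2.0, 3.9, 4.0], split=4)
print(edges)   # [1. 2. 3. 4.]
print(t_disc)  # [0 1 3 3]
```

Returning the boundaries alongside the bin indices lets the same split be reused later when mapping query times to bins at prediction time.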

predict_emb(x: ndarray) → Tensor

Predict the embedding of the data.

Parameters:
x : np.ndarray

Covariates.

Returns:

Embedding of the data.

Return type:

torch.Tensor

predict_survival(x: ndarray, t: list, risk: int = 1, all_step: bool = False, batch_size: int = 100) → ndarray

Predict the survival function.

Parameters:
x : np.ndarray

Covariates.

t : List

Times at which to predict the survival function.

risk : int, optional

Index of the risk (event type) to predict for. Defaults to 1.

all_step : bool, optional

Whether to predict at every time step of the input sequences rather than only the last one. Defaults to False.

batch_size : int, optional

Batch size. Defaults to 100.

Returns:

Array of survival function values.

Return type:

np.ndarray
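The conversion from per-bin outputs to survival values at requested times can be sketched as below. This assumes the model yields per-bin event probabilities of shape (n_samples, n_risks, n_bins), that risk is 1-indexed, and that each query time maps to the first bin whose upper boundary reaches it; survival_at is a hypothetical helper, not tempor's API.

```python
import numpy as np


def survival_at(probs, split_time, times, risk=1):
    """Hypothetical sketch: per-risk survival from per-bin event probabilities.

    probs: (n_samples, n_risks, n_bins); probs[i, k, s] is assumed to be the
    probability that risk k + 1 occurs in time bin s.
    split_time: (n_bins,) upper boundary of each bin.
    """
    # CIF per risk: cumulative probability mass up to and including each bin.
    cif = np.cumsum(probs, axis=2)
    # Map each query time to the first bin whose upper boundary reaches it.
    idx = np.searchsorted(split_time, times, side="left")
    idx = np.clip(idx, 0, probs.shape[2] - 1)
    # Survival with respect to the chosen (1-indexed) risk.
    return 1.0 - cif[:, risk - 1, idx]


probs = np.array([[[0.1, 0.2, 0.3, 0.1]]])  # one sample, one risk, four bins
surv = survival_at(probs, split_time=[1.0, 2.0, 3.0, 4.0], times=[1.0, 2.5])
print(np.round(surv, 3))  # [[0.9 0.4]]
```

If predict_risk follows the common convention risk = 1 - survival (an assumption, not confirmed here), the risk at the same times would simply be 1.0 - surv.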

predict_risk(x: ndarray, t: list, **kwargs: Any) → ndarray

Predict the risk.

Parameters:
x : np.ndarray

Covariates.

t : List

Times at which to predict the risk.

**kwargs : Any

Additional arguments passed to predict_survival.

Returns:

Array of risk values.

Return type:

np.ndarray

negative_log_likelihood(outcomes: Tensor, cif: list[Tensor], t: Tensor, e: Tensor) → Tensor

Compute the negative log-likelihood loss.

This is the survival term of the training loss.
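A NumPy sketch of a DeepHit-style discrete-time negative log-likelihood, for illustration only. It assumes event types are coded 1..K with 0 meaning censored, and that a censored sample contributes the probability of having experienced no event up to and including its censoring bin; discrete_nll is hypothetical and the exact censoring convention in tempor may differ.

```python
import numpy as np


def discrete_nll(probs, t_bin, e, eps=1e-8):
    """Hypothetical DeepHit-style negative log-likelihood (NumPy sketch).

    probs: (n_samples, n_risks, n_bins) per-bin event probabilities.
    t_bin: (n_samples,) discretized event/censoring bin index.
    e:     (n_samples,) event type, assumed 1..K for events, 0 for censoring.
    """
    rows = np.arange(probs.shape[0])
    cif = np.cumsum(probs, axis=2)
    observed = e > 0
    # Uncensored: probability of the observed risk in the observed bin.
    p_event = probs[rows, np.maximum(e - 1, 0), t_bin]
    # Censored: probability of no event from any risk up to the censoring bin.
    p_surv = 1.0 - cif[rows, :, t_bin].sum(axis=1)
    likelihood = np.where(observed, p_event, p_surv)
    # eps guards against log(0).
    return -np.mean(np.log(likelihood + eps))


probs = np.array([[[0.5, 0.3]], [[0.2, 0.2]]])  # 2 samples, 1 risk, 2 bins
loss = discrete_nll(probs, t_bin=np.array([0, 1]), e=np.array([1, 0]))
print(round(float(loss), 4))
```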

ranking_loss(cif: list[Tensor], t: Tensor, e: Tensor) → Tensor

Penalize incorrect ordering of predicted probabilities.

Similar in spirit to the concordance index (C-index), this function penalizes pairs of samples whose survival predictions are ordered inconsistently with their observed event times.
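A sketch of a DeepHit-style pairwise ranking penalty, under stated assumptions: a pair (i, j) is comparable when sample i has an observed event of risk k at bin t_i and sample j is still event-free after t_i; the penalty uses the exp(-diff / sigma) kernel from the DeepHit formulation, with sigma matching the model's sigma hyperparameter. ranking_penalty is hypothetical and written with explicit loops for readability rather than speed.

```python
import numpy as np


def ranking_penalty(cif, t_bin, e, sigma=0.1):
    """Hypothetical DeepHit-style ranking penalty (NumPy sketch).

    cif:   (n_samples, n_risks, n_bins) cumulative incidence per risk.
    t_bin: (n_samples,) discretized times; e: event type (0 = censored).
    """
    total = 0.0
    n = cif.shape[0]
    for i in range(n):
        if e[i] == 0:
            continue  # only samples with an observed event anchor a pair
        k, ti = e[i] - 1, t_bin[i]
        for j in range(n):
            if t_bin[j] > ti:
                # Sample i had the event first, so its CIF at t_i should be
                # higher; the penalty grows when that ordering is violated.
                diff = cif[i, k, ti] - cif[j, k, ti]
                total += np.exp(-diff / sigma)
    return total


cif_good = np.array([[[0.6, 0.8]], [[0.2, 0.3]]])  # correctly ordered pair
t_bin, e = np.array([0, 1]), np.array([1, 0])
good = ranking_penalty(cif_good, t_bin, e)
bad = ranking_penalty(cif_good[::-1], t_bin, e)  # swapped predictions
print(good < bad)  # True
```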

longitudinal_loss(longitudinal_prediction: Tensor, x: Tensor) → Tensor

Penalize errors in the longitudinal predictions, i.e., the error the RNN makes when predicting the covariates.

NB: The paper appears to use different losses for continuous and categorical covariates, but this is not reflected in the associated code, so MSE is computed for all covariates.

NB: The original paper mentions the possibility of using a different alpha for each risk in the ranking loss, but uses the same value for all risks.
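A NumPy sketch of an MSE longitudinal loss that ignores padding. It assumes the RNN output at step j predicts the covariates at step j + 1 and that padded entries are NaN (as produced by NaN padding); masked_next_step_mse is a hypothetical name, not tempor's code.

```python
import numpy as np


def masked_next_step_mse(pred, x):
    """Hypothetical sketch of a longitudinal MSE (not tempor's code).

    pred: (n, seq_len, n_features) RNN outputs, where pred[:, j] is assumed
    to predict x[:, j + 1]. NaN entries in x mark padding and are ignored.
    """
    target = x[:, 1:]
    estimate = pred[:, :-1]
    mask = ~np.isnan(target)
    # Average the squared error over observed (non-padded) entries only.
    return np.mean((estimate[mask] - target[mask]) ** 2)


x = np.array([[[1.0], [2.0], [np.nan]]])   # last step is padding
pred = np.array([[[2.5], [0.0], [0.0]]])
print(masked_next_step_mse(pred, x))  # 0.25
```

Using a single MSE for every covariate mirrors the NB above; a per-type loss (e.g. cross-entropy for categorical covariates) would need extra bookkeeping about which columns are categorical.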

total_loss(x: Tensor, t: Tensor, e: Tensor) → Tensor

Compute the total training loss from the negative log-likelihood, ranking, and longitudinal losses.
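A hedged sketch of how the terms may be weighted, assuming (as in the Dynamic-DeepHit paper) that the constructor's alpha scales the ranking loss and beta scales the longitudinal loss:

```latex
\mathcal{L}_{\text{total}}
  = \mathcal{L}_{\text{NLL}}
  + \alpha \, \mathcal{L}_{\text{rank}}
  + \beta \, \mathcal{L}_{\text{long}}
```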

class tempor.models.ddh.DynamicDeepHitLayers(input_dim: int, seq_len: int, output_dim: int, layers_rnn: int, hidden_rnn: int, rnn_type: str = 'LSTM', dropout: float = 0.1, risks: int = 1, output_type: str = 'MLP', device: Any = device(type='cpu'))

Bases: Module

Dynamic DeepHit layers component.

training : bool

forward_attention(x: Tensor, inputmask: Tensor, hidden: Tensor) → Tensor

Compute attention over the RNN hidden states, using inputmask to exclude padded time steps.
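Masked attention pooling can be sketched in NumPy as follows. The unnormalized scores are taken as given, since how forward_attention computes them is not documented here; masked_attention is a hypothetical helper that only shows how masking, softmax, and the weighted sum fit together. At least one time step per sample must be observed, otherwise the softmax is undefined.

```python
import numpy as np


def masked_attention(hidden_states, scores, mask):
    """Hypothetical masked attention pooling (NumPy sketch, not tempor's code).

    hidden_states: (n, seq_len, hidden) RNN outputs.
    scores:        (n, seq_len) unnormalized attention scores.
    mask:          (n, seq_len) True where the time step is observed.
    """
    # Padded steps get -inf so they receive zero weight after the softmax.
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)
    # Context vector: attention-weighted sum of hidden states over time.
    context = np.einsum("nt,nth->nh", weights, hidden_states)
    return context, weights


hidden = np.array([[[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]])  # last step padded
context, weights = masked_attention(
    hidden, np.zeros((1, 3)), np.array([[True, True, False]])
)
print(context)  # [[0.5 0.5]]
```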

forward_emb(x: Tensor) → tuple[Tensor, Tensor]

Forward pass through DynamicDeepHit that computes the embedding of the input; returns a tuple of two tensors.

forward(x: Tensor) → tuple[Tensor, list]

Full forward pass through DynamicDeepHit; this is the function called when data is passed through the module. Returns a tuple of a tensor and a list.