"""Model components for the Dynamic DeepHit implementation."""
import warnings
from copy import deepcopy
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from typing_extensions import Literal, Self, get_args
from tempor.models import constants
from .mlp import MLP
from .transformer import TransformerModel
from .ts_model import TimeSeriesLayer
# Sequence encoders understood by ``DynamicDeepHitLayers`` (its ``rnn_type`` argument).
RnnMode = Literal["GRU", "LSTM", "RNN", "Transformer"]

# Attention heads understood by ``DynamicDeepHitLayers`` (its ``output_type`` argument):
# "MLP" builds a plain MLP, any other value is forwarded as ``mode`` to ``TimeSeriesLayer``.
OutputMode = Literal[
    "MLP",
    "LSTM", "GRU", "RNN",
    "Transformer",
    "TCN",
    "InceptionTime", "InceptionTimePlus",
    "ResCNN",
    "XCM",
]

# Runtime tuples holding the literal values declared above, in declaration order.
rnn_modes = get_args(RnnMode)
output_modes = get_args(OutputMode)
def get_padded_features(
    x: Union[np.ndarray, List[np.ndarray]], pad_size: Optional[int] = None, fill: float = np.nan
) -> np.ndarray:
    """Pad (or truncate) variable-length sequences to a common length.

    Each element of ``x`` is a sequence whose first axis is time. Sequences shorter than ``pad_size`` are
    extended along that axis with rows filled with ``fill``; longer ones are truncated, keeping their first
    ``pad_size`` steps. Every sequence is cast to ``float``.

    Args:
        x (Union[np.ndarray, List[np.ndarray]]): Sequences to pad, one per sample.
        pad_size (Optional[int], optional): Target length. Defaults to the length of the longest sequence
            in ``x`` (``0`` when ``x`` is empty).
        fill (float, optional): Value written into padded rows. Defaults to ``np.nan``.

    Returns:
        np.ndarray: Array of shape ``(len(x), pad_size, *feature_shape)``; an empty array when ``x`` is empty.
    """
    if pad_size is None:
        # ``default=0`` so an empty input returns an empty array instead of raising from ``max``.
        pad_size = max((len(seq) for seq in x), default=0)

    x_padded = []
    for seq in x:
        n_missing = pad_size - len(seq)
        if n_missing > 0:
            pads = np.full((n_missing,) + seq.shape[1:], fill, dtype=float)
            seq = np.concatenate([seq, pads])
        else:
            # Truncates longer sequences; a no-op slice when the length already matches.
            seq = seq[:pad_size]
        x_padded.append(seq.astype(float))
    return np.asarray(x_padded)
class DynamicDeepHitModel:
    def __init__(
        self,
        split: int = 100,
        n_layers_hidden: int = 2,
        n_units_hidden: int = 100,
        rnn_mode: str = "LSTM",
        dropout: float = 0.1,
        alpha: float = 0.1,
        beta: float = 0.1,
        sigma: float = 0.1,
        patience: int = 20,
        lr: float = 1e-3,
        batch_size: int = 100,
        n_iter: int = 1000,
        device: Any = constants.DEVICE,
        val_size: float = 0.1,
        random_state: int = 0,
        clipping_value: int = 1,
        output_mode: str = "MLP",
    ) -> None:
        """Dynamic DeepHit model implementation.

        This implementation considers that the last event happens at the same time for each patient.
        The CIF is therefore simplified.

        Args:
            split (int, optional): Number of bins used to discretize the survival horizon, which is also the
                number of time outputs per risk. Defaults to ``100``.
            n_layers_hidden (int, optional): Number of layers of the sequence encoder and of the MLPs.
                Defaults to ``2``.
            n_units_hidden (int, optional): Hidden size of the sequence encoder and of the MLPs.
                Defaults to ``100``.
            rnn_mode (str, optional): Sequence encoder: ``"GRU"``, ``"LSTM"``, ``"RNN"`` or ``"Transformer"``.
                Defaults to ``"LSTM"``.
            dropout (float, optional): Dropout rate. Defaults to ``0.1``.
            alpha (float, optional): Weight of the ranking loss. Defaults to ``0.1``.
            beta (float, optional): Weight of the negative log-likelihood. The longitudinal loss is weighted
                by ``1 - alpha - beta``. Defaults to ``0.1``.
            sigma (float, optional): Temperature of the exponential ranking penalty. Defaults to ``0.1``.
            patience (int, optional): Number of epochs without validation improvement before stopping.
                Defaults to ``20``.
            lr (float, optional): Adam learning rate. Defaults to ``1e-3``.
            batch_size (int, optional): Mini-batch size. Defaults to ``100``.
            n_iter (int, optional): Maximum number of epochs. Defaults to ``1000``.
            device (Any, optional): Torch device; strings are converted with ``torch.device``.
                Defaults to ``constants.DEVICE``.
            val_size (float, optional): Fraction of samples held out for early stopping. Defaults to ``0.1``.
            random_state (int, optional): Seed of the train/validation shuffle. Defaults to ``0``.
            clipping_value (int, optional): Maximum gradient norm; ``<= 0`` disables clipping. Defaults to ``1``.
            output_mode (str, optional): Attention head: ``"MLP"`` or a ``TimeSeriesLayer`` mode.
                Defaults to ``"MLP"``.
        """
        self.split = split
        self.split_time = None  # Bin edges of the discretized horizon; set by the first ``fit``.
        self.pad_size = 0  # Length every input sequence is padded/truncated to; set by ``fit``.
        self.layers_rnn = n_layers_hidden
        self.hidden_rnn = n_units_hidden
        self.rnn_type = rnn_mode
        self.alpha = alpha
        self.beta = beta
        self.sigma = sigma
        self.device = torch.device(device) if isinstance(device, str) else device
        self.dropout = dropout
        self.lr = lr
        self.n_iter = n_iter
        self.batch_size = batch_size
        self.val_size = val_size
        self.clipping_value = clipping_value
        self.patience = patience
        self.random_state = random_state
        self.output_type = output_mode
        self.model: Optional[DynamicDeepHitLayers] = None

    def _setup_model(self, inputdim: int, seqlen: int, risks: int) -> "DynamicDeepHitLayers":
        """Build the network in float32 on ``self.device``."""
        return (
            DynamicDeepHitLayers(
                inputdim,
                seqlen,
                self.split,
                self.layers_rnn,
                self.hidden_rnn,
                rnn_type=self.rnn_type,
                dropout=self.dropout,
                risks=risks,
                device=self.device,
                output_type=self.output_type,
            )
            .float()
            .to(self.device)
        )

    def fit(
        self,
        x: np.ndarray,
        t: np.ndarray,
        e: np.ndarray,
    ) -> Self:
        """Fit the model to the data.

        Event times are discretized into ``split`` bins. A shuffled ``val_size`` fraction of the samples is
        held out, and training runs for up to ``n_iter`` epochs with early stopping on the validation loss.
        The parameters with the best validation loss are restored at the end.

        Args:
            x (np.ndarray): Covariates, one (possibly variable-length) sequence per sample.
            t (np.ndarray): Event times.
            e (np.ndarray): Event values (``0`` for censored, ``k`` for an event of risk ``k``).

        Returns:
            Self: Trained model.

        Raises:
            RuntimeError: If the validation split is empty, if the training split contains no event, or if the
                validation loss becomes NaN.
        """
        # ``self.split_time`` is reused when already set (e.g. by a previous call), keeping the bins fixed.
        discretized_t, self.split_time = self.discretize(t, self.split, self.split_time)
        processed_data = self._preprocess_training_data(x, discretized_t, e)
        x_train, t_train, e_train, x_val, t_val, e_val = processed_data

        inputdim = x_train.shape[-1]
        seqlen = x_train.shape[-2]
        maxrisk = int(np.nanmax(e_train.cpu().numpy()))
        if maxrisk < 1:
            # Without any event there would be no cause-specific head; fail clearly instead of inside torch.cat.
            raise RuntimeError(
                "No event (e > 0) found in the training set; at least one uncensored sample is required."
            )
        self.model = self._setup_model(inputdim, seqlen, risks=maxrisk)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        patience, old_loss = 0, np.inf
        best_param = deepcopy(self.model.state_dict())

        for _ in range(self.n_iter):
            self.model.train()
            for start in range(0, x_train.shape[0], self.batch_size):
                stop = start + self.batch_size
                xb, tb, eb = x_train[start:stop], t_train[start:stop], e_train[start:stop]
                optimizer.zero_grad()
                loss = self.total_loss(xb, tb, eb)
                loss.backward()  # type: ignore [no-untyped-call]
                if self.clipping_value > 0:
                    torch.nn.utils.clip_grad_norm_(  # type: ignore [attr-defined] # pyright: ignore
                        self.model.parameters(),
                        self.clipping_value,
                    )
                optimizer.step()

            self.model.eval()
            valid_loss: Any = 0.0
            # Validation needs no gradients; this also avoids keeping every batch's autograd graph alive.
            with torch.no_grad():
                for start in range(0, x_val.shape[0], self.batch_size):
                    stop = start + self.batch_size
                    valid_loss += self.total_loss(x_val[start:stop], t_val[start:stop], e_val[start:stop])
            if torch.isnan(valid_loss):  # pragma: no cover
                raise RuntimeError("NaNs detected in the total loss")
            valid_loss = valid_loss.item()

            if valid_loss < old_loss:
                patience = 0
                old_loss = valid_loss
                best_param = deepcopy(self.model.state_dict())
            else:
                patience += 1

            if patience == self.patience:
                break

        self.model.load_state_dict(best_param)
        self.model.eval()
        return self

    def discretize(
        self, t: Union[np.ndarray, List[np.ndarray]], split: int, split_time: Optional[List[float]] = None
    ) -> Tuple:
        """Discretize the survival horizon.

        A time in ``(split_time[i], split_time[i + 1]]`` maps to bin ``i``. Times at or below
        ``split_time[0]`` map to bin ``0``, and times above ``split_time[-1]`` map to bin
        ``len(split_time) - 1``.

        Args:
            t (Union[np.ndarray, List[np.ndarray]]):
                Times of events.
            split (int):
                Number of bins.
            split_time (Optional[List[float]], optional):
                List of bin edges (must be the same length as ``split``). Computed from ``t`` with
                ``np.histogram`` when not provided. Defaults to `None`.

        Returns:
            Tuple: ``(t_discretized, split_time)`` Discretized events time; split time.
        """
        if split_time is None:
            _, split_time = np.histogram(t, split - 1)  # type: ignore
        # ``digitize(..., right=True) - 1`` gives -1 for times <= split_time[0] (including the minimum training
        # time). Used as a tensor index, -1 would silently select the LAST bin, so clamp it to bin 0.
        t_discretized = np.array(
            [np.maximum(np.digitize(t_, split_time, right=True) - 1, 0) for t_ in t], dtype=object  # type: ignore
        )
        return t_discretized, split_time

    def _preprocess_test_data(self, x: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
        """Pad/truncate ``x`` to the training length and convert it to a float tensor on ``self.device``."""
        data = torch.from_numpy(get_padded_features(x, pad_size=self.pad_size)).float().to(self.device)
        return data

    def _preprocess_training_data(
        self,
        x: np.ndarray,
        t: np.ndarray,
        e: np.ndarray,
    ) -> Tuple:
        """Shuffle, pad and split the data into training and validation tensors.

        RNNs require different preprocessing for variable length sequences.

        Returns:
            Tuple: ``(x_train, t_train, e_train, x_val, t_val, e_val)``; the validation part is the last
            ``int(val_size * n_samples)`` shuffled samples.

        Raises:
            RuntimeError: If the validation split would be empty.
        """
        idx = list(range(x.shape[0]))
        # Same permutation as seeding the global RNG and calling ``np.random.shuffle``, without mutating the
        # caller's global numpy random state.
        np.random.RandomState(self.random_state).shuffle(idx)
        x = get_padded_features(x)
        self.pad_size = x.shape[1]

        x_train_np, t_train_np, e_train_np = x[idx], t[idx], e[idx]
        x_train = torch.from_numpy(x_train_np.astype(float)).float().to(self.device)
        t_train = torch.from_numpy(t_train_np.astype(float)).float().to(self.device)
        e_train = torch.from_numpy(e_train_np.astype(int)).float().to(self.device)

        val_size = int(self.val_size * x_train.shape[0])
        if val_size == 0:
            raise RuntimeError(
                f"Not enough samples to create a validation set, please increase `val_size` or use a larger dataset. "
                f"`val_size` was {self.val_size} and total number of samples in dataset was {x_train.shape[0]}."
            )
        if val_size < 10:
            # Raise a RuntimeWarning if the validation set is very small.
            warnings.warn(
                f"Validation set is very small ({val_size} samples). "
                "Consider increasing `val_size` or using a larger dataset.",
                RuntimeWarning,
            )

        x_val, t_val, e_val = x_train[-val_size:], t_train[-val_size:], e_train[-val_size:]
        x_train = x_train[:-val_size]
        t_train = t_train[:-val_size]
        e_train = e_train[:-val_size]
        return (x_train, t_train, e_train, x_val, t_val, e_val)

    def predict_emb(
        self,
        x: np.ndarray,
    ) -> torch.Tensor:
        """Predict the embedding of the data.

        Args:
            x (np.ndarray): Covariates.

        Returns:
            torch.Tensor: Attention-pooled hidden state concatenated with the last observation, per sample.

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if self.model is None:
            raise RuntimeError(
                "The model has not been fitted yet. Please fit the "
                + "model using the `fit` method on some training data "
                + "before calling `predict_emb`."
            )
        x_in: torch.Tensor = self._preprocess_test_data(x)
        _, emb = self.model.forward_emb(x_in)
        return emb

    def predict_survival(
        self,
        x: np.ndarray,
        t: List,
        risk: int = 1,
        all_step: bool = False,
        batch_size: int = 100,
    ) -> np.ndarray:
        """Predict the survival function.

        Args:
            x (np.ndarray): Covariates.
            t (List): Times to predict the survival function.
            risk (int, optional): Risk value. Defaults to ``1``.
            all_step (bool, optional): Predict for every prefix ``x_[:1], ..., x_[:len(x_)]`` of every sequence
                instead of only the full sequences. Defaults to `False`.
            batch_size (int, optional): Batch size. Defaults to ``100``.

        Returns:
            np.ndarray: Survival values of shape ``(n_inputs, len(t))``, where ``n_inputs`` is ``len(x)``, or the
            total number of prefixes when ``all_step`` is set.

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if self.model is None:
            raise RuntimeError(
                "The model has not been fitted yet. Please fit the "
                + "model using the `fit` method on some training data "
                + "before calling `predict_survival`."
            )
        x_in: Any = x
        if all_step:
            x_in = [x_[: step + 1] for x_ in x for step in range(len(x_))]

        # ``discretize`` iterates over samples, so the horizons are wrapped as a single "sample".
        horizon_bins = [int(b) for b in self.discretize([t], self.split, self.split_time)[0][0]]
        x_in_tensor: torch.Tensor = self._preprocess_test_data(x_in)
        # Batch over the actual model inputs (this includes all prefixes when ``all_step`` is set).
        n_rows = len(x_in_tensor)

        # One column per requested horizon, kept by position so that horizons sharing a bin stay distinct.
        columns: List[List[float]] = [[] for _ in horizon_bins]
        with torch.no_grad():
            for start in range(0, n_rows, batch_size):
                xb = x_in_tensor[start : start + batch_size]
                _, f = self.model(xb)  # pylint: disable=not-callable
                cif = torch.cumsum(f[int(risk) - 1], dim=1)
                for column, bin_idx in zip(columns, horizon_bins):
                    column.extend(cif[:, bin_idx].reshape(-1).cpu().numpy().tolist())
        return 1 - np.asarray(columns).T

    def predict_risk(self, x: np.ndarray, t: List, **kwargs: Any) -> np.ndarray:
        """Predict the risk.

        Args:
            x (np.ndarray): Covariates.
            t (List): Times to predict the risk.
            **kwargs (Any): Additional arguments passed to ``predict_survival``.

        Returns:
            np.ndarray: Array of risk values (``1 - survival``).
        """
        return 1 - self.predict_survival(x, t, **kwargs)

    @staticmethod
    def _select_at_times(values: torch.Tensor, times: torch.Tensor) -> torch.Tensor:
        """Return ``values[i, times[i]]`` for every row ``i`` of the 2D tensor ``values``."""
        rows = torch.arange(values.shape[0], device=values.device)
        return values[rows, times]

    def negative_log_likelihood(
        self,
        outcomes: List[torch.Tensor],
        cif: List[torch.Tensor],
        t: torch.Tensor,
        e: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the log likelihood loss.

        This function is used to compute the survival loss. For a sample with an event of risk ``k + 1``, the
        term is the log of ``outcomes[k]`` at its own event bin. For a censored sample, it is the log of one
        minus the summed CIFs at its own censoring bin. The result is averaged over risks.

        Args:
            outcomes (List[torch.Tensor]): Per-risk probability mass over time bins, each ``(batch, bins)``.
            cif (List[torch.Tensor]): Per-risk cumulative incidence (cumsum of ``outcomes``).
            t (torch.Tensor): Discretized times as long indices.
            e (torch.Tensor): Event values (``0`` censored, ``k`` risk ``k``).

        Returns:
            torch.Tensor: Negative log-likelihood.
        """
        loss: torch.Tensor = 0.0  # type: ignore
        censored_cif: torch.Tensor = 0.0  # type: ignore
        censored = e == 0
        for k, ok in enumerate(outcomes):
            # Censored cif: each censored sample at its OWN censoring time (not at every other sample's time).
            censored_cif += self._select_at_times(cif[k][censored], t[censored])

            # Uncensored: probability mass of each event of risk k + 1 at its own event time.
            selection = e == (k + 1)
            loss += torch.sum(torch.log(self._select_at_times(ok[selection], t[selection]) + constants.EPS))

        # Censored loss
        loss += torch.sum(torch.log(nn.ReLU()(1 - censored_cif) + constants.EPS))
        return -loss / len(outcomes)  # type: ignore [return-value]

    def ranking_loss(
        self,
        cif: List[torch.Tensor],
        t: torch.Tensor,
        e: torch.Tensor,
    ) -> torch.Tensor:
        """Penalize wrong ordering of probability.

        Equivalent to a C Index. This function is used to penalize wrong ordering in the survival prediction.
        For every event of risk ``k`` at time ``ti``, each sample still at risk after ``ti`` adds
        ``exp((F_k(ti | x_j) - F_k(ti | x_i)) / sigma)``.
        """
        loss: torch.Tensor = 0.0  # type: ignore
        for k, cif_k in enumerate(cif):
            for ci, ti in zip(cif_k[e - 1 == k], t[e - 1 == k]):
                # For all events: all patients that didn't experience event before
                # must have a lower risk for that cause
                later = t > ti
                if torch.sum(later) > 0:
                    # ``sigma`` scales the difference inside the exponential (DeepHit's eta function);
                    # dividing outside would only rescale the loss.
                    loss += torch.mean(  # type: ignore [call-overload]
                        torch.exp((cif_k[later][:, ti] - ci[ti]) / self.sigma)
                    )
        return loss / len(cif)  # type: ignore [return-value]

    def longitudinal_loss(self, longitudinal_prediction: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Penalize error in the longitudinal predictions. This function is used to compute the error made by the RNN.

        The prediction at step ``s`` is compared with the observation at step ``s + 1`` for every step up to the
        last observed one; padded steps are those whose first feature is NaN.

        NB: In the paper, they seem to use different losses for continuous and categorical,
        but this was not reflected in the code associated (therefore we compute MSE for all).
        NB: Original paper mentions possibility of different alphas for each risk,
        but takes same for all (for ranking loss).
        """
        # Index of the last observed step of each sample.
        length = (~torch.isnan(x[:, :, 0])).sum(dim=1) - 1
        # Row of step indices, broadcast against the per-sample lengths below.
        index = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        # Select all predictions until the last observed
        prediction_mask = index <= (length - 1).unsqueeze(1)
        # Select all observations that can be predicted
        observation_mask = index <= length.unsqueeze(1)
        observation_mask[:, 0] = False  # Remove first observation
        if not bool(prediction_mask.any()):
            # Only single-step sequences: nothing to predict. Return a graph-attached zero instead of the NaN
            # produced by the mean over an empty selection (which would abort training).
            return longitudinal_prediction.sum() * 0.0
        return torch.nn.MSELoss(reduction="mean")(longitudinal_prediction[prediction_mask], x[observation_mask])

    def total_loss(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        e: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the total loss.

        ``(1 - alpha - beta) * longitudinal + alpha * ranking + beta * negative_log_likelihood``.

        Raises:
            RuntimeError: If the model is not set up or produces NaN longitudinal predictions.
        """
        if self.model is None:
            raise RuntimeError("Invalid model for loss")
        longitudinal_prediction, outcomes = self.model(x.float())  # pylint: disable=not-callable
        if torch.isnan(longitudinal_prediction).sum() != 0:
            raise RuntimeError("NaNs detected in the longitudinal_prediction")
        # Times are used as indices into the time-bin axis.
        t, e = t.long(), e.int()
        # Compute cumulative function from predicted outcomes
        cif = [torch.cumsum(ok, 1) for ok in outcomes]
        return (
            (1 - self.alpha - self.beta)
            * self.longitudinal_loss(longitudinal_prediction, x)  # type: ignore [return-value]
            + self.alpha * self.ranking_loss(cif, t, e)
            + self.beta * self.negative_log_likelihood(outcomes, cif, t, e)
        )
class DynamicDeepHitLayers(nn.Module):
    def __init__(
        self,
        input_dim: int,
        seq_len: int,
        output_dim: int,
        layers_rnn: int,
        hidden_rnn: int,
        rnn_type: str = "LSTM",
        dropout: float = 0.1,
        risks: int = 1,
        output_type: str = "MLP",
        device: Any = constants.DEVICE,
    ) -> None:
        """Dynamic DeepHit layers component.

        Args:
            input_dim (int): Number of input features per time step.
            seq_len (int): Padded sequence length (window and output size of a ``TimeSeriesLayer`` head).
            output_dim (int): Number of time bins predicted per risk.
            layers_rnn (int): Number of layers of the encoder and of every MLP.
            hidden_rnn (int): Hidden size of the encoder and of every MLP.
            rnn_type (str, optional): ``"LSTM"``, ``"RNN"``, ``"GRU"`` or ``"Transformer"``. Defaults to ``"LSTM"``.
            dropout (float, optional): Dropout rate. Defaults to ``0.1``.
            risks (int, optional): Number of competing risks (one cause-specific MLP each). Defaults to ``1``.
            output_type (str, optional): ``"MLP"`` for an MLP attention head, otherwise the ``TimeSeriesLayer``
                mode. Defaults to ``"MLP"``.
            device (Any, optional): Device the module is moved to. Defaults to ``constants.DEVICE``.

        Raises:
            RuntimeError: If ``rnn_type`` is not supported.
        """
        super().__init__()
        self.input_dim = input_dim
        self.seq_len = seq_len
        self.output_dim = output_dim
        self.risks = risks
        self.rnn_type = rnn_type
        self.device = device
        self.dropout = dropout

        # RNN model for longitudinal data
        self.embedding: nn.Module
        if self.rnn_type == "LSTM":
            self.embedding = nn.LSTM(
                input_dim,
                hidden_rnn,
                layers_rnn,
                bias=False,
                batch_first=True,
            )
        elif self.rnn_type == "RNN":
            self.embedding = nn.RNN(
                input_dim,
                hidden_rnn,
                layers_rnn,
                bias=False,
                batch_first=True,
                nonlinearity="relu",
            )
        elif self.rnn_type == "GRU":
            self.embedding = nn.GRU(
                input_dim,
                hidden_rnn,
                layers_rnn,
                bias=False,
                batch_first=True,
            )
        elif self.rnn_type == "Transformer":
            self.embedding = TransformerModel(
                input_dim,
                hidden_rnn,
                n_layers_hidden=layers_rnn,
                dropout=dropout,
            )
        else:
            raise RuntimeError(f"Unknown rnn_type {rnn_type}")

        # Longitudinal network: predicts the next observation from each hidden state.
        self.longitudinal = MLP(
            task_type="regression",
            n_units_in=hidden_rnn,
            n_units_out=input_dim,
            n_layers_hidden=layers_rnn,
            n_units_hidden=hidden_rnn,
            dropout=self.dropout,
            device=device,
        )

        # Attention mechanism
        self.attention: Union[MLP, TimeSeriesLayer]
        if output_type == "MLP":
            self.attention = MLP(
                task_type="regression",
                n_units_in=input_dim + hidden_rnn,
                n_units_out=1,
                dropout=self.dropout,
                n_layers_hidden=layers_rnn,
                n_units_hidden=hidden_rnn,
                device=device,
            )
        else:
            self.attention = TimeSeriesLayer(
                n_static_units_in=0,
                n_temporal_units_in=input_dim + hidden_rnn,
                n_temporal_window=seq_len,
                n_units_out=seq_len,
                n_temporal_units_hidden=hidden_rnn,
                n_temporal_layers_hidden=layers_rnn,
                mode=output_type,
                dropout=self.dropout,
                device=device,
            )
        self.attention_soft = nn.Softmax(1)  # On temporal dimension
        self.output_type = output_type

        # Cause specific network: one head per risk, built in risk order.
        self.cause_specific = nn.ModuleList(
            [
                MLP(
                    task_type="regression",
                    n_units_in=input_dim + hidden_rnn,
                    n_units_out=output_dim,
                    dropout=self.dropout,
                    n_layers_hidden=layers_rnn,
                    n_units_hidden=hidden_rnn,
                    device=device,
                )
                for _ in range(self.risks)
            ]
        )

        # Probability
        self.soft = nn.Softmax(dim=-1)  # On all observed output
        self.to(self.device)

    def forward_attention(self, x: torch.Tensor, inputmask: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        """Forward attention implementation.

        The last observation of each sequence is used to weight all previously observed hidden states.

        Args:
            x (torch.Tensor): Inputs ``(batch, seq, features)`` with NaNs already replaced.
            inputmask (torch.Tensor): Boolean ``(batch, seq)`` mask, ``True`` on unobserved (padded) steps.
            hidden (torch.Tensor): Encoder outputs ``(batch, seq, hidden)``.

        Returns:
            torch.Tensor: Attention-weighted sum of the hidden states concatenated with the last observation.
        """
        # Attention using last observation to predict weight of all previously observed
        # Extract last observation (the one used for predictions)
        last_observations = (~inputmask).sum(dim=1) - 1
        last_observations_idx = last_observations.unsqueeze(1).repeat(1, x.size(1))
        index = torch.arange(x.size(1), device=x.device).repeat(x.size(0), 1)
        last = index == last_observations_idx
        x_last = x[last]

        # Concatenate all previous with new to measure attention
        concatenation = torch.cat([hidden, x_last.unsqueeze(1).repeat(1, x.size(1), 1)], -1)

        # Compute attention and normalize
        if self.output_type == "MLP":
            attention = self.attention(concatenation).squeeze(-1)
        else:
            # ``TimeSeriesLayer`` takes static features first; there are none here (zero columns).
            attention = self.attention(torch.zeros(len(concatenation), 0).to(self.device), concatenation).squeeze(-1)
        attention[index >= last_observations_idx] = -1e10  # Want soft max to be zero as values not observed
        attention[last_observations > 0] = self.attention_soft(
            attention[last_observations > 0]
        )  # Weight previous observation
        attention[last_observations == 0] = 0  # No context for only one observation

        # Risk networks
        # The original paper is not clear on how the last observation is
        # combined with the temporal sum, other code was concatenating them
        attention = attention.unsqueeze(2).repeat(1, 1, hidden.size(2))
        hidden_attentive = torch.sum(attention * hidden, dim=1)
        return torch.cat([hidden_attentive, x_last], 1)

    def forward_emb(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` and return ``(longitudinal_prediction, attentive_embedding)``.

        Steps whose first feature is NaN are treated as unobserved; all NaNs are replaced by ``-1`` on a copy
        of ``x`` before encoding.

        Raises:
            RuntimeError: If NaNs remain in the input or appear in the embeddings/predictions.
        """
        # RNN representation - Nan values for not observed data
        x = x.clone()
        inputmask = torch.isnan(x[:, :, 0])
        x[torch.isnan(x)] = -1
        if torch.isnan(x).sum() != 0:  # pragma: no cover
            raise RuntimeError("NaNs detected in the input")

        if self.rnn_type in ["GRU", "LSTM", "RNN"]:
            # torch recurrent modules return ``(output, hidden_state)``; only the per-step output is used.
            hidden, _ = self.embedding(x)
        else:
            hidden = self.embedding(x)
        if torch.isnan(hidden).sum() != 0:  # pragma: no cover
            raise RuntimeError("NaNs detected in the embeddings")

        # Longitudinal modelling
        longitudinal_prediction = self.longitudinal(hidden)
        if torch.isnan(longitudinal_prediction).sum() != 0:  # pragma: no cover
            raise RuntimeError("NaNs detected in the longitudinal_prediction")

        hidden_attentive = self.forward_attention(x, inputmask, hidden)
        return longitudinal_prediction, hidden_attentive

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]:
        """Run the network.

        Returns:
            Tuple[torch.Tensor, List]: ``(longitudinal_prediction, outcomes)``, where ``outcomes`` holds one
            ``(batch, output_dim)`` slice per risk of a single softmax taken jointly over all risks and bins.

        Raises:
            RuntimeError: If NaNs appear in the outcome.
        """
        x = x.to(self.device)
        longitudinal_prediction, hidden_attentive = self.forward_emb(x)

        outcomes = [cs_nn(hidden_attentive) for cs_nn in self.cause_specific]

        # Soft max for probability distribution
        outcomes_t = torch.cat(outcomes, dim=1)
        outcomes_t = self.soft(outcomes_t)
        if torch.isnan(outcomes_t).sum() != 0:  # pragma: no cover
            raise RuntimeError("NaNs detected in the outcome")

        outcomes = [outcomes_t[:, i * self.output_dim : (i + 1) * self.output_dim] for i in range(self.risks)]
        return longitudinal_prediction, outcomes