Source code for tempor.methods.time_to_event.plugin_ddh

"""Dynamic DeepHit survival analysis model."""

import dataclasses
from typing import Any, List

from typing_extensions import Self

from tempor.core import plugins
from tempor.data import data_typing, dataset, samples
from tempor.methods.core.params import Params
from tempor.methods.time_to_event import BaseTimeToEventAnalysis
from tempor.models.ddh import DynamicDeepHitModel, OutputMode, RnnMode

from .helper_embedding import DDHEmbedding


[docs]@dataclasses.dataclass class DynamicDeepHitTimeToEventAnalysisParams: n_iter: int = 1000 """Number of training epochs.""" batch_size: int = 100 """Training batch size.""" lr: float = 1e-3 """Training learning rate.""" n_layers_hidden: int = 1 """Number of hidden layers in the network.""" n_units_hidden: int = 40 """Number of units for each hidden layer.""" split: int = 100 """Number of discrete buckets.""" rnn_mode: RnnMode = "GRU" """Internal temporal architecture, one of `RnnMode`.""" alpha: float = 0.34 """Weighting (0, 1) likelihood and rank loss (L2 in paper). 1 gives only likelihood, and 0 gives only rank loss.""" beta: float = 0.27 """Beta, see paper.""" sigma: float = 0.21 """From eta in rank loss (L2 in paper).""" dropout: float = 0.06 """Network dropout value.""" device: str = "cpu" """PyTorch Device.""" val_size: float = 0.1 """Early stopping: size of validation set.""" patience: int = 20 """Early stopping: training patience without any improvement.""" output_mode: OutputMode = "MLP" """Output network, on of `OutputMode`.""" random_state: int = 0 """Random seed."""
[docs]@plugins.register_plugin(name="dynamic_deephit", category="time_to_event") class DynamicDeepHitTimeToEventAnalysis(BaseTimeToEventAnalysis, DDHEmbedding): ParamsDefinition = DynamicDeepHitTimeToEventAnalysisParams params: DynamicDeepHitTimeToEventAnalysisParams # type: ignore def __init__(self, **params: Any) -> None: """Dynamic DeepHit survival analysis model. Note: Current implementation has the following limitations: - Only one output feature is supported (no competing risks). - Risk prediction for time points beyond the last event time in the dataset may throw errors. Args: **params (Any): Parameters and defaults as defined in :class:`DynamicDeepHitTimeToEventAnalysisParams`. References: "Dynamic-DeepHit: A Deep Learning Approach for Dynamic Survival Analysis With Competing Risks Based on Longitudinal Data", Changhee Lee, Jinsung Yoon, Mihaela van der Schaar. """ super().__init__(**params) self.model = DynamicDeepHitModel( split=self.params.split, n_layers_hidden=self.params.n_layers_hidden, n_units_hidden=self.params.n_units_hidden, rnn_mode=self.params.rnn_mode, alpha=self.params.alpha, beta=self.params.beta, sigma=self.params.sigma, dropout=self.params.dropout, val_size=self.params.val_size, patience=self.params.patience, lr=self.params.lr, batch_size=self.params.batch_size, n_iter=self.params.n_iter, output_mode=self.params.output_mode, device=self.params.device, ) DDHEmbedding.__init__(self, emb_model=self.model) def _fit( self, data: dataset.BaseDataset, *args: Any, **kwargs: Any, ) -> Self: processed_data, event_times, event_values = self.prepare_fit(data) self.model.fit(processed_data, event_times, event_values) return self def _predict( self, data: dataset.PredictiveDataset, horizons: data_typing.TimeIndex, *args: Any, **kwargs: Any, ) -> samples.TimeSeriesSamplesBase: # NOTE: kwargs will be passed to DynamicDeepHitModel.predict_risk(). # E.g. `batch_size` batch size parameter can be provided this way. processed_data = self.prepare_predict(data, horizons, *args, **kwargs) risk = self.model.predict_risk(processed_data, horizons, **kwargs) return samples.TimeSeriesSamples( risk.reshape((risk.shape[0], risk.shape[1], 1)), sample_index=data.time_series.sample_index(), time_indexes=[horizons] * data.time_series.num_samples, # pyright: ignore feature_index=["risk_score"], )
[docs] @staticmethod def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]: # noqa: D102 return DDHEmbedding.hyperparameter_space(*args, **kwargs)