Source code for tempor.methods.time_to_event.helper_embedding

"""Helper class for embedding time-series data for time-to-event analysis using Dynamic DeepHit embeddings."""

import abc
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast

import numpy as np
import pandas as pd
from typing_extensions import Self

import tempor.exc
from tempor.data import data_typing, dataset, samples
from tempor.methods.core.params import CategoricalParams, FloatParams, IntegerParams, Params
from tempor.models import utils
from tempor.models.ddh import DynamicDeepHitModel, output_modes, rnn_modes


[docs]class OutputTimeToEventAnalysis: """Helper base class for time-to-event analysis models."""
[docs] @abc.abstractmethod def fit(self, X: pd.DataFrame, T: pd.Series, Y: pd.Series) -> Self: # pragma: no cover """Fit the model. Args: X (pd.DataFrame): Input covariates of fixed shape - time-series data should be embedded first. T (pd.Series): Event times. Y (pd.Series): Event values. Returns: Self: Fitted model. """ ... # pylint: disable=unnecessary-ellipsis
[docs] @abc.abstractmethod def predict_risk(self, X: pd.DataFrame, time_horizons: List) -> pd.DataFrame: # pragma: no cover """Predict risk scores. Args: X (pd.DataFrame): Input covariates of fixed shape - time-series data should be embedded first. time_horizons (List): Time horizons to predict risk at. Returns: pd.DataFrame: Risk scores. """ ... # pylint: disable=unnecessary-ellipsis
[docs]class DDHEmbedding: def __init__(self, emb_model: DynamicDeepHitModel) -> None: """Survival analysis embedding creation for time-series with :class:`tempor.models.ddh.DynamicDeepHitModel`. Args: emb_model (DynamicDeepHitModel): :class:`tempor.models.ddh.DynamicDeepHitModel` to use for temporal feature embedding. """ self.emb_model = emb_model def _merge_data( self, static: Optional[np.ndarray], temporal: List[np.ndarray], observation_times: List[np.ndarray], ) -> np.ndarray: if static is None: static = np.zeros((len(temporal), 0)) merged = [] for idx, item in enumerate(temporal): # pylint: disable=unused-variable local_static = static[idx].reshape(1, -1) local_static = np.repeat(local_static, len(temporal[idx]), axis=0) tst = np.concatenate( [ temporal[idx], local_static, np.asarray(observation_times[idx]).reshape(-1, 1), ], axis=1, ) merged.append(tst) return np.array(merged, dtype=object) def _validate_data(self, data: dataset.TimeToEventAnalysisDataset) -> None: if data.predictive.targets is not None and data.predictive.targets.num_features > 1: raise tempor.exc.UnsupportedSetupException( f"{self.__class__.__name__} does not currently support more than one event feature, " f"but features found were: {data.predictive.targets.dataframe().columns}" ) # TODO: This needs investigating - likely different length sequences aren't handled properly. # if not data.time_series.num_timesteps_equal(): # raise tempor.exc.UnsupportedSetupException( # f"{self.__class__.__name__} currently requires all samples to have the same number of timesteps, " # f"but found timesteps of varying lengths {np.unique(data.time_series.num_timesteps()).tolist()}" # ) def _convert_data( self, data: dataset.TimeToEventAnalysisDataset ) -> Tuple[Optional[np.ndarray], List[np.ndarray], List[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: if data.has_static: static = data.static.numpy() if data.static is not None else None else: static = np.zeros((data.time_series.num_samples, 0)) temporal = [df.to_numpy() for df in data.time_series.list_of_dataframes()] observation_times = data.time_series.time_indexes_float() if data.predictive is not None and data.predictive.targets is not None: event_times, event_values = ( df.to_numpy().reshape((-1,)) for df in data.predictive.targets.split_as_two_dataframes() ) else: event_times, event_values = None, None return (static, temporal, observation_times, event_times, event_values)
[docs] def prepare_fit( self, data: dataset.BaseDataset, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Prepare data for fitting. Args: data (dataset.BaseDataset): Input dataset. Returns: Tuple[np.ndarray, np.ndarray, np.ndarray]: Processed covariate data, event times, event values. """ utils.enable_reproducibility(self.emb_model.random_state) data = cast(dataset.TimeToEventAnalysisDataset, data) self._validate_data(data) (static, temporal, observation_times, event_times, event_values) = self._convert_data(data) processed_data = self._merge_data(static, temporal, observation_times) if TYPE_CHECKING: # pragma: no cover assert event_times is not None and event_values is not None # nosec B101 return processed_data, event_times, event_values
[docs] def prepare_predict( # pylint: disable=unused-argument self, data: dataset.PredictiveDataset, horizons: data_typing.TimeIndex, *args: Any, **kwargs: Any, ) -> np.ndarray: """Prepare data for prediction. Args: data (dataset.PredictiveDataset): Input dataset. horizons (data_typing.TimeIndex): Time horizons to predict risk at. *args (Any): Additional arguments. **kwargs (Any): Additional keyword arguments. Returns: np.ndarray: Processed data. """ data = cast(dataset.TimeToEventAnalysisDataset, data) self._validate_data(data) (static, temporal, observation_times, _, _) = self._convert_data(data) processed_data = self._merge_data(static, temporal, observation_times) return processed_data
[docs] @staticmethod def hyperparameter_space( # pylint: disable=unused-argument *args: Any, **kwargs: Any, ) -> List[Params]: # noqa: D102 return [ IntegerParams(name="n_units_hidden", low=10, high=100, step=10), IntegerParams(name="n_layers_hidden", low=1, high=4), CategoricalParams(name="batch_size", choices=[100, 200, 500]), CategoricalParams(name="lr", choices=[1e-2, 1e-3, 1e-4]), CategoricalParams(name="rnn_mode", choices=list(rnn_modes)), CategoricalParams(name="output_mode", choices=list(output_modes)), FloatParams(name="alpha", low=0.0, high=0.5), FloatParams(name="sigma", low=0.0, high=0.5), FloatParams(name="beta", low=0.0, high=0.5), FloatParams(name="dropout", low=0.0, high=0.2), ]
[docs]class DDHEmbeddingTimeToEventAnalysis(DDHEmbedding): def __init__( self, output_model: OutputTimeToEventAnalysis, emb_model: DynamicDeepHitModel, ) -> None: """Survival analysis embedding creation for time-series with :class:`tempor.models.ddh.DynamicDeepHitModel` followed by ``output_model`` :class:`OutputTimeToEventAnalysis` survival analysis estimator. Args: output_model (OutputTimeToEventAnalysis): Output model to use for predicting risk. emb_model (DynamicDeepHitModel): :class:`tempor.models.ddh.DynamicDeepHitModel` to use for temporal feature embedding. """ DDHEmbedding.__init__(self, emb_model=emb_model) self.output_model = output_model
[docs] def fit( # pylint: disable=unused-argument self, data: dataset.BaseDataset, *args: Any, **kwargs: Any, ) -> Self: """Fit the model. Args: data (dataset.BaseDataset): Input dataset. *args (Any): Additional arguments. **kwargs (Any): Additional keyword arguments. Returns: Self: Fitted model. """ processed_data, event_times, event_values = self.prepare_fit(data) self.emb_model.fit(processed_data, event_times, event_values) embeddings = self.emb_model.predict_emb(processed_data) self.output_model.fit( pd.DataFrame(embeddings), # pyright: ignore pd.Series(event_times), pd.Series(event_values), ) return self
[docs] def predict( self, data: dataset.PredictiveDataset, horizons: data_typing.TimeIndex, *args: Any, **kwargs: Any, ) -> samples.TimeSeriesSamplesBase: """Predict risk scores. ``*args`` and ``**kwargs`` will be passed to ``self.emb_model.predict_emb()``. E.g. ``batch_size`` batch size parameter can be provided this way. Args: data (dataset.PredictiveDataset): Input dataset. horizons (data_typing.TimeIndex): Time horizons to predict risk at. *args (Any): Additional arguments. Passed to ``self.emb_model.predict_emb()`` **kwargs (Any): Additional keyword arguments. Passed to ``self.emb_model.predict_emb()`` Returns: samples.TimeSeriesSamplesBase: Predicted risk scores. """ processed_data = self.prepare_predict(data, horizons, *args, **kwargs) embeddings = self.emb_model.predict_emb(processed_data) risk = self.output_model.predict_risk( pd.DataFrame(embeddings), # pyright: ignore horizons, ) risk = np.asarray(risk) return samples.TimeSeriesSamples( risk.reshape((risk.shape[0], risk.shape[1], 1)), sample_index=data.time_series.sample_index(), time_indexes=[horizons] * data.time_series.num_samples, # pyright: ignore feature_index=["risk_score"], )
[docs] @staticmethod def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]: # noqa: D102 return DDHEmbedding.hyperparameter_space(*args, **kwargs)