Source code for tempor.data.predictive

"""Module defining the `PredictiveTaskData` class and its subclasses, which are used to store the data components
relevant for different predictive tasks (e.g. prediction, time-to-event analysis, treatment effects).
"""

import abc
from typing import TYPE_CHECKING, Any, Generator, Optional

import rich.pretty

from tempor.core.utils import RichReprStrPassthrough

from . import data_typing, samples

if TYPE_CHECKING:  # pragma: no cover
    from .dataset import PredictiveDataset  # For typing only, no circular import.


[docs]class PredictiveTaskData(abc.ABC): _targets: Optional[samples.DataSamples] _treatments: Optional[samples.DataSamples] @property @abc.abstractmethod def predictive_task(self) -> data_typing.PredictiveTask: # pragma: no cover """Return the predictive task enum value corresponding to the class. Returns: data_typing.PredictiveTask: The predictive task enum value. """ ... # pylint: disable=unnecessary-ellipsis def __init__( # pylint: disable=unused-argument self, parent_dataset: "PredictiveDataset", targets: Any, treatments: Optional[Any], **kwargs: Any, ) -> None: """The predictive task data abstract base class. Args: parent_dataset (PredictiveDataset): The parent predictive dataset object. targets (Any): The targets data. treatments (Optional[Any]): The treatments data. **kwargs (Any): Additional keyword arguments. Currently unused. """ self.parent_dataset = parent_dataset # ^ In order to be able to call parent dataset's `validate` method in the targets/treatments property setters. self._targets = targets self._treatments = treatments def __rich_repr__(self) -> Generator: """A `rich` representation of the class. Yields: Generator: The fields and their values fed to `rich`. """ if self.targets is not None: yield "targets", RichReprStrPassthrough(self.targets.short_repr()) else: yield "targets", None if self.treatments is not None: yield "treatments", RichReprStrPassthrough(self.treatments.short_repr()) def __repr__(self) -> str: """The `repr()` representation of the class. Returns: str: The representation. """ return rich.pretty.pretty_repr(self) @property def targets(self) -> Optional[samples.DataSamples]: """The property containing the targets data. Returns: Optional[samples.DataSamples]: The targets data. """ return self._targets @targets.setter def targets(self, value: Optional[samples.DataSamples]) -> None: self._targets = value self.parent_dataset.validate() @property def treatments(self) -> Optional[samples.DataSamples]: """The property containing the treatments data. Returns: Optional[samples.DataSamples]: The treatments data. """ return self._treatments @treatments.setter def treatments(self, value: Optional[samples.DataSamples]) -> None: self._treatments = value self.parent_dataset.validate()
# Predictive task data classes corresponding to different tasks follow. More can be added to handle new tasks. # --- Prediction tasks: ---
[docs]class OneOffPredictionTaskData(PredictiveTaskData): # One-off prediction (e.g., one-off classification with a target like patient death). targets: Optional[samples.StaticSamplesBase] treatments: None @property def predictive_task(self) -> data_typing.PredictiveTask: """Return the predictive task enum value corresponding to the class. Here, ``ONE_OFF_PREDICTION``. Returns: data_typing.PredictiveTask: The predictive task enum value. Here, ``ONE_OFF_PREDICTION``. """ return data_typing.PredictiveTask.ONE_OFF_PREDICTION def __init__( self, parent_dataset: "PredictiveDataset", targets: Optional[data_typing.DataContainer], **kwargs: Any ) -> None: """The one-off prediction task data class. Args: parent_dataset (PredictiveDataset): The parent predictive dataset object. targets (Optional[data_typing.DataContainer]): The targets data. **kwargs (Any): Additional keyword arguments. Currently unused. """ super().__init__(parent_dataset=parent_dataset, targets=targets, treatments=None) self._targets = samples.StaticSamples(targets) if targets is not None else None self._treatments = None
[docs]class TemporalPredictionTaskData(PredictiveTaskData): # Temporal prediction (e.g., predicting a patient's temperature real valued time series). targets: Optional[samples.TimeSeriesSamplesBase] treatments: None @property def predictive_task(self) -> data_typing.PredictiveTask: """Return the predictive task enum value corresponding to the class. Here, ``TEMPORAL_PREDICTION``. Returns: data_typing.PredictiveTask: The predictive task enum value. Here, ``TEMPORAL_PREDICTION``. """ return data_typing.PredictiveTask.TEMPORAL_PREDICTION def __init__( self, parent_dataset: "PredictiveDataset", targets: Optional[data_typing.DataContainer], **kwargs: Any ) -> None: """The temporal prediction task data class. Args: parent_dataset (PredictiveDataset): The parent predictive dataset object. targets (Optional[data_typing.DataContainer]): The targets data. **kwargs (Any): Additional keyword arguments. Currently unused. """ super().__init__(parent_dataset=parent_dataset, targets=targets, treatments=None) self._targets = samples.TimeSeriesSamples(targets) if targets is not None else None self._treatments = None
# --- Time-to-event tasks: ---
[docs]class TimeToEventAnalysisTaskData(PredictiveTaskData): # Time-to-event (survival) analysis (e.g. Dynamic DeepHit). targets: Optional[samples.EventSamplesBase] treatments: None @property def predictive_task(self) -> data_typing.PredictiveTask: """Return the predictive task enum value corresponding to the class. Here, ``TIME_TO_EVENT_ANALYSIS``. Returns: data_typing.PredictiveTask: The predictive task enum value. Here, ``TIME_TO_EVENT_ANALYSIS``. """ return data_typing.PredictiveTask.TIME_TO_EVENT_ANALYSIS def __init__( self, parent_dataset: "PredictiveDataset", targets: Optional[data_typing.DataContainer], **kwargs: Any ) -> None: """The time-to-event analysis task data class. Args: parent_dataset (PredictiveDataset): The parent predictive dataset object. targets (Optional[data_typing.DataContainer]): The targets data. **kwargs (Any): Additional keyword arguments. Currently unused. """ super().__init__(parent_dataset=parent_dataset, targets=targets, treatments=None) self._targets = samples.EventSamples(targets) if targets is not None else None self._treatments = None
# --- Treatment Effects tasks: ---
[docs]class OneOffTreatmentEffectsTaskData(PredictiveTaskData): # Treatment effects with time series outcomes but one-off treatment event(s) (e.g. SyncTwin) targets: Optional[samples.TimeSeriesSamplesBase] treatments: samples.EventSamplesBase @property def predictive_task(self) -> data_typing.PredictiveTask: """Return the predictive task enum value corresponding to the class. Here, ``ONE_OFF_TREATMENT_EFFECTS``. Returns: data_typing.PredictiveTask: The predictive task enum value. Here, ``ONE_OFF_TREATMENT_EFFECTS``. """ return data_typing.PredictiveTask.ONE_OFF_TREATMENT_EFFECTS def __init__( self, parent_dataset: "PredictiveDataset", targets: Optional[data_typing.DataContainer], treatments: data_typing.DataContainer, **kwargs: Any, ) -> None: """The one-off treatment effects task data class. Args: parent_dataset (PredictiveDataset): The parent predictive dataset object. targets (Optional[data_typing.DataContainer]): The targets data. treatments (data_typing.DataContainer): The treatments data. **kwargs (Any): Additional keyword arguments. Currently unused. """ super().__init__(parent_dataset=parent_dataset, targets=targets, treatments=treatments) self._targets = samples.TimeSeriesSamples(targets) if targets is not None else None self._treatments = samples.EventSamples(treatments)
[docs]class TemporalTreatmentEffectsTaskData(PredictiveTaskData): # Temporal treatment effects (i.e. outcomes are time series and treatments are also time series, e.g. RMSN, CRN). targets: Optional[samples.TimeSeriesSamplesBase] treatments: samples.TimeSeriesSamplesBase @property def predictive_task(self) -> data_typing.PredictiveTask: """Return the predictive task enum value corresponding to the class. Here, ``TEMPORAL_TREATMENT_EFFECTS``. Returns: data_typing.PredictiveTask: The predictive task enum value. Here, ``TEMPORAL_TREATMENT_EFFECTS``. """ return data_typing.PredictiveTask.TEMPORAL_TREATMENT_EFFECTS def __init__( self, parent_dataset: "PredictiveDataset", targets: Optional[data_typing.DataContainer], treatments: data_typing.DataContainer, **kwargs: Any, ) -> None: """The temporal treatment effects task data class. Args: parent_dataset (PredictiveDataset): The parent predictive dataset object. targets (Optional[data_typing.DataContainer]): The targets data. treatments (data_typing.DataContainer): The treatments data. **kwargs (Any): Additional keyword arguments. Currently unused. """ super().__init__(parent_dataset=parent_dataset, targets=targets, treatments=treatments) self._targets = samples.TimeSeriesSamples(targets) if targets is not None else None self._treatments = samples.TimeSeriesSamples(treatments)