"""Module defining the TemporAI dataset concept in :class:`BaseDataset` and its derived classes."""
# pylint: disable=unnecessary-ellipsis
import abc
import dataclasses
from typing import Any, ClassVar, Generator, Optional, Tuple, Union
import numpy as np
import rich.pretty
import sklearn.model_selection
from typing_extensions import Self
from tempor.core.utils import RichReprStrPassthrough
from tempor.log import log_helpers, logger
from . import data_typing
from . import predictive as pred
from . import samples, utils
# Type alias for the scikit-learn splitters accepted by `BaseDataset.split`.
# NOTE: Other scikit-learn splitters exposing a compatible `split(X, ...)` method could probably be added here.
Splitter = Union[sklearn.model_selection.KFold, sklearn.model_selection.StratifiedKFold]
@dataclasses.dataclass(frozen=True)
class _SampleIndexMismatchMsg:
    """Error messages for a dataset component whose ``sample_index`` differs from the time series samples'.

    Every message is a ``ClassVar`` constant, so the dataclass itself declares no instance fields.
    """

    # Static covariates vs. time series covariates.
    static: ClassVar[str] = (
        "`sample_index` of static samples did not match `sample_index` of time series samples. "
        "Note that the samples need to be in the same order."
    )

    # Targets vs. time series covariates.
    targets: ClassVar[str] = (
        "`sample_index` of targets did not match `sample_index` of time series samples. "
        "Note that the samples need to be in the same order."
    )

    # Treatments vs. time series covariates.
    treatments: ClassVar[str] = (
        "`sample_index` of treatments did not match `sample_index` of time series samples. "
        "Note that the samples need to be in the same order."
    )
@dataclasses.dataclass(frozen=True)
class _TimeIndexesMismatchMsg:
    """Error messages for a dataset component whose ``time_indexes`` differ from the time series covariates'.

    Every message is a ``ClassVar`` constant, so the dataclass itself declares no instance fields.
    """

    # Time series targets vs. time series covariates.
    targets: ClassVar[str] = (
        "`time_indexes` of targets did not match `time_indexes` of time series covariates."
    )

    # Time series treatments vs. time series covariates.
    treatments: ClassVar[str] = (
        "`time_indexes` of treatments did not match `time_indexes` of time series covariates."
    )
@dataclasses.dataclass(frozen=True)
class _ExceptionMessages:
    """Namespace grouping the module's reusable error messages by the kind of mismatch they describe."""

    # `sample_index` mismatches: static / targets / treatments vs. time series samples.
    sample_index_mismatch: ClassVar[_SampleIndexMismatchMsg] = _SampleIndexMismatchMsg()

    # `time_indexes` mismatches: targets / treatments vs. time series covariates.
    time_indexes_mismatch: ClassVar[_TimeIndexesMismatchMsg] = _TimeIndexesMismatchMsg()
# Single shared instance; the messages are accessed as e.g. `EXCEPTION_MESSAGES.sample_index_mismatch.static`.
EXCEPTION_MESSAGES: _ExceptionMessages = _ExceptionMessages()
"""Reusable error messages for the module."""
class BaseDataset(abc.ABC):
    # Time series covariates of the samples (always present).
    _time_series: samples.TimeSeriesSamplesBase
    # Static covariates of the samples, or `None` if the dataset has none.
    _static: Optional[samples.StaticSamplesBase]
    # Predictive data (targets / treatments); assigned by the derived class's `_init_predictive`, may be `None`.
    predictive: Optional[pred.PredictiveTaskData]

    def __init__(
        self,
        time_series: data_typing.DataContainer,
        *,
        static: Optional[data_typing.DataContainer] = None,
        targets: Optional[data_typing.DataContainer] = None,
        treatments: Optional[data_typing.DataContainer] = None,
        **kwargs: Any,
    ) -> None:
        """Abstract base class representing a dataset used by TemporAI.

        Initialize one of its derived classes (e.g. :class:`OneOffPredictionDataset`,
        :class:`TimeToEventAnalysisDataset` etc.) depending on the type of task.

        See also tutorial ``tutorials/tutorial01_data_format.ipynb`` for examples of use.

        Args:
            time_series (data_typing.DataContainer):
                Data representing time series covariates of the samples. Will be initialized as `TimeSeriesSamples`.
            static (Optional[data_typing.DataContainer], optional):
                Data representing static covariates of the samples. Will be initialized as `StaticSamples`.
                Defaults to `None`.
            targets (Optional[data_typing.DataContainer], optional):
                Data representing target (outcome) feature(s) of the samples. Will be initialized as
                ``{TimeSeries,Static,Event}Samples`` depending on problem setting in the derived class.
                Defaults to `None`.
            treatments (Optional[data_typing.DataContainer], optional):
                Data representing treatment (intervention) feature(s) of the samples. Will be initialized as
                ``{TimeSeries,Static,Event}Samples`` depending on problem setting in the derived class.
                Defaults to `None`.
            **kwargs (Any):
                Additional keyword arguments to be passed to the derived class's ``_init_predictive`` method.
        """
        self._time_series = samples.TimeSeriesSamples(time_series)
        self._static = samples.StaticSamples(static) if static is not None else None
        # The derived class sets `self.predictive` here (and may reject inputs its task does not support).
        self._init_predictive(targets=targets, treatments=treatments, **kwargs)
        self.validate()

    def __rich_repr__(self) -> Generator:
        """A `rich` representation of the class.

        Yields:
            Generator: The fields and their values fed to `rich`.
        """
        yield "time_series", RichReprStrPassthrough(self.time_series.short_repr())
        if self.static is not None:
            yield "static", RichReprStrPassthrough(self.static.short_repr())
        if self.predictive is not None:
            yield "predictive", self.predictive

    def __repr__(self) -> str:
        """The `repr()` representation of the class.

        Returns:
            str: The representation.
        """
        return rich.pretty.pretty_repr(self)

    @abc.abstractmethod
    def _init_predictive(
        self,
        targets: Optional[data_typing.DataContainer],
        treatments: Optional[data_typing.DataContainer],
        **kwargs: Any,
    ) -> None:  # pragma: no cover
        """A method to initialize ``self.predictive`` in derived classes."""
        ...

    @property
    def has_static(self) -> bool:
        """A property returning whether the dataset has static data.

        Returns:
            bool: Whether the dataset has static data.
        """
        return self.static is not None

    @property
    def has_predictive_data(self) -> bool:
        """A property returning whether the dataset has predictive data (``targets`` or ``treatments``).

        Returns:
            bool: Whether the dataset has predictive data.
        """
        return self.predictive is not None

    @property
    def predictive_task(self) -> Union[data_typing.PredictiveTask, None]:
        """A property returning the predictive task of the dataset (or `None`).

        Returns:
            Union[data_typing.PredictiveTask, None]: The predictive task of the dataset.
        """
        if self.predictive is not None:
            return self.predictive.predictive_task
        else:
            return None

    def validate(self) -> None:
        """Validate integrity of the dataset.

        Checks that the static covariates (if any) have the same ``sample_index`` as the time series covariates,
        then runs the derived class's task-specific checks (``_validate``).

        Raises:
            ValueError: If the static covariates' ``sample_index`` does not match. Derived-class checks may raise
                further errors.
        """
        # Failures inside this context are reported via `log_helpers.exc_to_log`
        # (NOTE(review): assumed to log and re-raise — confirm in `tempor.log.log_helpers`).
        with log_helpers.exc_to_log("Dataset validation failed"):
            if self.static is not None:
                if self.static.sample_index() != self.time_series.sample_index():
                    raise ValueError(EXCEPTION_MESSAGES.sample_index_mismatch.static)
            self._validate()

    @abc.abstractmethod
    def _validate(self) -> None:  # pragma: no cover
        """Task-specific integrity checks, implemented by derived classes and called from :meth:`validate`."""
        ...

    @property
    def time_series(self) -> samples.TimeSeriesSamplesBase:
        """The property containing the time series covariates of the dataset.

        Returns:
            samples.TimeSeriesSamplesBase: The time series covariates of the dataset.
        """
        return self._time_series

    @time_series.setter
    def time_series(self, value: samples.TimeSeriesSamplesBase) -> None:
        previous = self._time_series
        self._time_series = value
        try:
            self.validate()
        except Exception:
            # Roll back so that a rejected assignment does not leave the dataset in an inconsistent state.
            self._time_series = previous
            raise

    @property
    def static(self) -> Optional[samples.StaticSamplesBase]:
        """The property containing the static covariates of the dataset.

        Returns:
            Optional[samples.StaticSamplesBase]: The static covariates of the dataset.
        """
        return self._static

    @static.setter
    def static(self, value: Optional[samples.StaticSamplesBase]) -> None:
        previous = self._static
        self._static = value
        try:
            self.validate()
        except Exception:
            # Roll back so that a rejected assignment does not leave the dataset in an inconsistent state.
            self._static = previous
            raise

    @property
    @abc.abstractmethod
    def fit_ready(self) -> bool:  # pragma: no cover
        """Returns whether the :class:`BaseDataset` is in a state ready to be ``fit`` on."""
        ...

    def __len__(self) -> int:
        """The dataset length, which is the number of samples.

        Returns:
            int: Dataset length.
        """
        return self.time_series.num_samples

    def __getitem__(self, key: data_typing.GetItemKey) -> Self:
        """Return a subset of the dataset.

        A new dataset of the same class is constructed from the indexed time series, static, targets and
        treatments data (each only if present).

        Note:
            Any extra ``**kwargs`` the original dataset was constructed with are not forwarded to the new dataset.

        Args:
            key (data_typing.GetItemKey): The key to index the dataset with.

        Returns:
            Self: The subset of the dataset.
        """
        # NOTE(review): the key is normalized by `utils.ensure_pd_iloc_key_returns_df`, presumably so that
        # indexing always yields DataFrames (e.g. for a scalar key) — confirm in `tempor.data.utils`.
        key_ = utils.ensure_pd_iloc_key_returns_df(key)
        new_dataset = self.__class__(
            time_series=self.time_series[key_].dataframe(),  # pyright: ignore
            static=self.static[key_].dataframe() if self.has_static else None,  # type: ignore[union-attr,index]
            targets=(
                self.predictive.targets[key_].dataframe()  # type: ignore[union-attr]
                if (self.has_predictive_data and self.predictive.targets is not None)  # type: ignore[union-attr]
                else None
            ),
            treatments=(
                self.predictive.treatments[key_].dataframe()  # type: ignore[union-attr]
                if (self.has_predictive_data and self.predictive.treatments is not None)  # type: ignore[union-attr]
                else None
            ),
        )
        return new_dataset

    def train_test_split(
        self,
        *,
        test_size: Union[float, int, None] = None,
        train_size: Union[float, int, None] = None,
        random_state: Union[int, np.random.RandomState, None] = None,  # pylint: disable=no-member
        shuffle: bool = True,
        stratify: Any = None,
    ) -> Tuple[Self, Self]:
        """Split `Dataset` into train and test sets.

        The arguments ``test_size`` ... ``stratify`` are passed to `sklearn.model_selection.train_test_split` to
        generate the split.

        Args:
            test_size (Union[float, int, None], optional):
                Passed to `sklearn.model_selection.train_test_split`: a float is the proportion of samples in the
                test split, an int is the absolute number of test samples. Defaults to `None`.
            train_size (Union[float, int, None], optional):
                Passed to `sklearn.model_selection.train_test_split`: a float is the proportion of samples in the
                train split, an int is the absolute number of train samples. Defaults to `None`.
            random_state (Union[int, np.random.RandomState, None], optional):
                Passed to `sklearn.model_selection.train_test_split`. Defaults to `None`.
            shuffle (bool, optional):
                Passed to `sklearn.model_selection.train_test_split`. Defaults to `True`.
            stratify (Any, optional):
                Passed to `sklearn.model_selection.train_test_split`. Defaults to `None`.

        Returns:
            Tuple[Self, Self]: The split tuple ``(dataset_train, dataset_test)``.
        """
        # Split positional sample indices, then subset the dataset with them via `__getitem__`.
        sample_ilocs = list(range(len(self)))
        sample_ilocs_train, sample_ilocs_test = sklearn.model_selection.train_test_split(
            sample_ilocs,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state,
            shuffle=shuffle,
            stratify=stratify,
        )
        return self[sample_ilocs_train], self[sample_ilocs_test]

    def split(
        self,
        splitter: Splitter,
        **kwargs: Any,
    ) -> Generator[Tuple[Self, Self], None, None]:
        """Generate dataset splits according to the scikit-learn ``splitter`` (`~tempor.data.dataset.Splitter`).

        The ``kwargs`` are passed to the underlying splitter's ``split`` method.

        Example:
            >>> from sklearn.model_selection import KFold
            >>> from tempor import plugin_loader
            >>> data = plugin_loader.get("prediction.one_off.sine", plugin_type="datasource").load()
            >>> kfold = KFold(n_splits=5)
            >>> len([(data_train, data_test) for (data_train, data_test) in data.split(splitter=kfold)])
            5

        Args:
            splitter (Splitter): A `sklearn` splitter.
            **kwargs (Any): Additional keyword arguments to be passed to the ``splitter``'s ``split`` method.

        Yields:
            Tuple[Self, Self]: ``(dataset_train, dataset_test)`` for each split.
        """
        # The splitter only sees positional sample indices; any `y`/`groups` must be supplied via `kwargs`.
        sample_ilocs: Any = list(range(len(self)))
        for sample_ilocs_train, sample_ilocs_test in splitter.split(X=sample_ilocs, **kwargs):
            yield self[sample_ilocs_train], self[sample_ilocs_test]
# Task-specific `Dataset` classes follow, one per problem setting. More can be added to handle new tasks.
class CovariatesDataset(BaseDataset):
    def __init__(
        self,
        time_series: data_typing.DataContainer,
        *,
        static: Optional[data_typing.DataContainer] = None,
        targets: Optional[data_typing.DataContainer] = None,
        treatments: Optional[data_typing.DataContainer] = None,
        **kwargs: Any,
    ) -> None:
        """A :class:`BaseDataset` subclass for a dataset that does not contain any predictive data
        (``targets`` or ``treatments``).
        """
        super().__init__(
            time_series=time_series,
            static=static,
            targets=targets,
            treatments=treatments,
            **kwargs,
        )

    def _init_predictive(
        self,
        targets: Optional[data_typing.DataContainer],
        treatments: Optional[data_typing.DataContainer],
        **kwargs: Any,
    ) -> None:
        """Reject any predictive data and leave ``self.predictive`` unset (`None`).

        Raises:
            ValueError: If ``targets`` or ``treatments`` is provided (``targets`` is checked first).
        """
        # NOTE(review): any extra `kwargs` are accepted but not used by this dataset type.
        if targets is not None:
            raise ValueError(f"`targets` must not be set for a {self.__class__.__name__}.")
        if treatments is not None:
            raise ValueError(f"`treatments` must not be set for a {self.__class__.__name__}.")
        self.predictive = None

    def _validate(self) -> None:
        """No checks beyond the shared ones in :meth:`BaseDataset.validate` are needed."""

    @property
    def fit_ready(self) -> bool:
        """Check if the dataset is ready to be fit on.

        A covariates-only dataset needs nothing further, so this is always `True`.

        Returns:
            bool: Whether the dataset is ready to be fit on.
        """
        return True
class PredictiveDataset(BaseDataset):
    # Narrows the base-class annotation: the visible subclasses always assign task data in `_init_predictive`.
    predictive: pred.PredictiveTaskData

    def __init__(
        self,
        time_series: data_typing.DataContainer,
        *,
        targets: Optional[data_typing.DataContainer],
        static: Optional[data_typing.DataContainer] = None,
        treatments: Optional[data_typing.DataContainer] = None,
        **kwargs: Any,
    ) -> None:
        """A :class:`BaseDataset` subclass for a dataset that can contain predictive data
        (``targets`` or ``treatments``).

        This is an abstract class, to be derived from for different predictive task -specific ``Dataset`` s.
        """
        super().__init__(
            time_series=time_series,
            static=static,
            targets=targets,
            treatments=treatments,
            **kwargs,
        )

    @property
    @abc.abstractmethod
    def predict_ready(self) -> bool:  # pragma: no cover
        """Returns whether the :class:`PredictiveDataset` is in a state ready to be ``predict`` ed on."""
        ...
class OneOffPredictionDataset(PredictiveDataset):
    predictive: pred.OneOffPredictionTaskData

    def __init__(
        self,
        time_series: data_typing.DataContainer,
        *,
        targets: Optional[data_typing.DataContainer],
        static: Optional[data_typing.DataContainer] = None,
        treatments: Optional[data_typing.DataContainer] = None,
        **kwargs: Any,
    ) -> None:
        """A :class:`PredictiveDataset` subclass for the one-off prediction problem setting,
        see :class:`BaseDataset` docs.

        In this setting: ``targets`` are required for fitting, will be initialized as `StaticSamples`.
        ``treatments`` are not used by this problem setting; if provided, they are discarded with a warning.
        """
        super().__init__(time_series=time_series, static=static, targets=targets, treatments=treatments, **kwargs)

    def _init_predictive(
        self,
        targets: Optional[data_typing.DataContainer],
        treatments: Optional[data_typing.DataContainer],
        **kwargs: Any,
    ) -> None:
        """Create ``self.predictive`` as `pred.OneOffPredictionTaskData` from ``targets`` and ``kwargs``."""
        if targets is None:
            logger.debug(
                f"`targets` provided was None for {self.__class__.__name__}, "
                "this Dataset can only be used for prediction not fitting"
            )
        if treatments is not None:
            # The task data below takes no treatments, so they would otherwise be dropped silently.
            logger.warning(
                f"`treatments` provided to {self.__class__.__name__} will be ignored, "
                "as this problem setting does not use treatments"
            )
        self.predictive = pred.OneOffPredictionTaskData(parent_dataset=self, targets=targets, **kwargs)

    def _validate(self) -> None:
        """Check that the targets (if present) share the time series' ``sample_index``.

        Raises:
            ValueError: If the targets' ``sample_index`` does not match.
        """
        if self.predictive.targets is not None:
            if self.predictive.targets.sample_index() != self.time_series.sample_index():
                raise ValueError(EXCEPTION_MESSAGES.sample_index_mismatch.targets)

    @property
    def fit_ready(self) -> bool:
        """Check if the dataset is ready to be fit on.

        Returns:
            bool: Whether the dataset is ready to be fit on (i.e. ``targets`` are present).
        """
        return self.predictive.targets is not None

    @property
    def predict_ready(self) -> bool:
        """Check if the dataset is ready to be predicted on.

        Returns:
            bool: Whether the dataset is ready to be predicted on (always `True` here).
        """
        return True
class TemporalPredictionDataset(PredictiveDataset):
    predictive: pred.TemporalPredictionTaskData

    def __init__(
        self,
        time_series: data_typing.DataContainer,
        *,
        targets: Optional[data_typing.DataContainer],
        static: Optional[data_typing.DataContainer] = None,
        treatments: Optional[data_typing.DataContainer] = None,
        **kwargs: Any,
    ) -> None:
        """A :class:`PredictiveDataset` subclass for the temporal prediction problem setting,
        see :class:`BaseDataset` docs.

        In this setting: ``targets`` are required for fitting, will be initialized as `TimeSeriesSamples`.
        ``treatments`` are not used by this problem setting; if provided, they are discarded with a warning.
        """
        super().__init__(time_series=time_series, static=static, targets=targets, treatments=treatments, **kwargs)

    def _init_predictive(
        self,
        targets: Optional[data_typing.DataContainer],
        treatments: Optional[data_typing.DataContainer],
        **kwargs: Any,
    ) -> None:
        """Create ``self.predictive`` as `pred.TemporalPredictionTaskData` from ``targets`` and ``kwargs``."""
        if targets is None:
            logger.debug(
                f"`targets` provided was None for {self.__class__.__name__}, "
                "this Dataset can only be used for prediction not fitting"
            )
        if treatments is not None:
            # The task data below takes no treatments, so they would otherwise be dropped silently.
            logger.warning(
                f"`treatments` provided to {self.__class__.__name__} will be ignored, "
                "as this problem setting does not use treatments"
            )
        self.predictive = pred.TemporalPredictionTaskData(parent_dataset=self, targets=targets, **kwargs)

    def _validate(self) -> None:
        """Check that the targets (if present) share the time series' ``sample_index`` and ``time_indexes``.

        Raises:
            ValueError: If the targets' ``sample_index`` or ``time_indexes`` do not match.
        """
        if self.predictive.targets is not None:
            if self.predictive.targets.sample_index() != self.time_series.sample_index():
                raise ValueError(EXCEPTION_MESSAGES.sample_index_mismatch.targets)
            if self.predictive.targets.time_indexes() != self.time_series.time_indexes():
                raise ValueError(EXCEPTION_MESSAGES.time_indexes_mismatch.targets)

    @property
    def fit_ready(self) -> bool:
        """Check if the dataset is ready to be fit on.

        Returns:
            bool: Whether the dataset is ready to be fit on (i.e. ``targets`` are present).
        """
        return self.predictive.targets is not None

    @property
    def predict_ready(self) -> bool:
        """Check if the dataset is ready to be predicted on.

        Returns:
            bool: Whether the dataset is ready to be predicted on (always `True` here).
        """
        return True
class TimeToEventAnalysisDataset(PredictiveDataset):
    predictive: pred.TimeToEventAnalysisTaskData

    def __init__(
        self,
        time_series: data_typing.DataContainer,
        *,
        targets: Optional[data_typing.DataContainer],
        static: Optional[data_typing.DataContainer] = None,
        treatments: Optional[data_typing.DataContainer] = None,
        **kwargs: Any,
    ) -> None:
        """A :class:`PredictiveDataset` subclass for the time-to-event analysis problem setting,
        see :class:`BaseDataset` docs.

        In this setting: ``targets`` are required for fitting, will be initialized as `EventSamples`.
        ``treatments`` are not used by this problem setting; if provided, they are discarded with a warning.
        """
        super().__init__(time_series=time_series, static=static, targets=targets, treatments=treatments, **kwargs)

    def _init_predictive(
        self,
        targets: Optional[data_typing.DataContainer],
        treatments: Optional[data_typing.DataContainer],
        **kwargs: Any,
    ) -> None:
        """Create ``self.predictive`` as `pred.TimeToEventAnalysisTaskData` from ``targets`` and ``kwargs``."""
        if targets is None:
            logger.debug(
                f"`targets` provided was None for {self.__class__.__name__}, "
                "this Dataset can only be used for prediction not fitting"
            )
        if treatments is not None:
            # The task data below takes no treatments, so they would otherwise be dropped silently.
            logger.warning(
                f"`treatments` provided to {self.__class__.__name__} will be ignored, "
                "as this problem setting does not use treatments"
            )
        self.predictive = pred.TimeToEventAnalysisTaskData(parent_dataset=self, targets=targets, **kwargs)

    def _validate(self) -> None:
        """Check that the targets (if present) share the time series' ``sample_index``.

        Raises:
            ValueError: If the targets' ``sample_index`` does not match.
        """
        if self.predictive.targets is not None:
            if self.predictive.targets.sample_index() != self.time_series.sample_index():
                raise ValueError(EXCEPTION_MESSAGES.sample_index_mismatch.targets)
            # TODO: Possible checks - some checks on .time_series and .predictive.targets in terms of
            # their relative position in time?

    @property
    def fit_ready(self) -> bool:
        """Check if the dataset is ready to be fit on.

        Returns:
            bool: Whether the dataset is ready to be fit on (i.e. ``targets`` are present).
        """
        return self.predictive.targets is not None

    @property
    def predict_ready(self) -> bool:
        """Check if the dataset is ready to be predicted on.

        Returns:
            bool: Whether the dataset is ready to be predicted on (always `True` here).
        """
        return True
class OneOffTreatmentEffectsDataset(PredictiveDataset):
    predictive: pred.OneOffTreatmentEffectsTaskData

    def __init__(
        self,
        time_series: data_typing.DataContainer,
        *,
        targets: Optional[data_typing.DataContainer],
        treatments: data_typing.DataContainer,
        static: Optional[data_typing.DataContainer] = None,
        **kwargs: Any,
    ) -> None:
        """A :class:`PredictiveDataset` subclass for the one-off treatment effects problem setting,
        see :class:`BaseDataset` docs.

        In this setting: ``targets`` are required for fitting, will be initialized as `TimeSeriesSamples`;
        ``treatments`` are required for both fitting and prediction, will be initialized as `EventSamples`.
        """
        super().__init__(
            time_series=time_series,
            static=static,
            targets=targets,
            treatments=treatments,
            **kwargs,
        )

    def _init_predictive(
        self,
        targets: Optional[data_typing.DataContainer],
        treatments: Optional[data_typing.DataContainer],
        **kwargs: Any,
    ) -> None:
        """Create ``self.predictive`` as `pred.OneOffTreatmentEffectsTaskData`.

        Raises:
            ValueError: If ``treatments`` is `None`.
        """
        if targets is None:
            logger.debug(
                f"`targets` provided was None for {self.__class__.__name__}, "
                "this Dataset can only be used for prediction not fitting"
            )
        if treatments is None:
            raise ValueError("One-off treatment effects task requires `treatments`")
        self.predictive = pred.OneOffTreatmentEffectsTaskData(
            parent_dataset=self,
            targets=targets,
            treatments=treatments,
            **kwargs,
        )

    def _validate(self) -> None:
        """Check targets (if present) and treatments against the time series covariates.

        Raises:
            ValueError: If the targets' ``sample_index``/``time_indexes`` or the treatments' ``sample_index``
                do not match those of the time series covariates.
        """
        time_series = self.time_series
        targets = self.predictive.targets
        if targets is not None:
            if targets.sample_index() != time_series.sample_index():
                raise ValueError(EXCEPTION_MESSAGES.sample_index_mismatch.targets)
            if targets.time_indexes() != time_series.time_indexes():
                raise ValueError(EXCEPTION_MESSAGES.time_indexes_mismatch.targets)
        # `treatments` is guaranteed non-`None` by `_init_predictive`.
        if self.predictive.treatments.sample_index() != time_series.sample_index():
            raise ValueError(EXCEPTION_MESSAGES.sample_index_mismatch.treatments)
        # TODO: Possible checks - some checks on .time_series and .predictive.treatments in terms of
        # their relative position in time?

    @property
    def fit_ready(self) -> bool:
        """Check if the dataset is ready to be fit on.

        Returns:
            bool: Whether the dataset is ready to be fit on (both ``targets`` and ``treatments`` present).
        """
        has_targets = self.predictive.targets is not None
        return has_targets and self.predictive.treatments is not None

    @property
    def predict_ready(self) -> bool:
        """Check if the dataset is ready to be predicted on.

        Returns:
            bool: Whether the dataset is ready to be predicted on (``treatments`` present).
        """
        return self.predictive.treatments is not None
class TemporalTreatmentEffectsDataset(PredictiveDataset):
    predictive: pred.TemporalTreatmentEffectsTaskData

    def __init__(
        self,
        time_series: data_typing.DataContainer,
        *,
        targets: Optional[data_typing.DataContainer],
        treatments: data_typing.DataContainer,
        static: Optional[data_typing.DataContainer] = None,
        **kwargs: Any,
    ) -> None:
        """A :class:`PredictiveDataset` subclass for the temporal treatment effects problem setting,
        see :class:`BaseDataset` docs.

        In this setting: ``targets`` are required for fitting, will be initialized as `TimeSeriesSamples`;
        ``treatments`` are required for both fitting and prediction, will be initialized as `TimeSeriesSamples`.
        """
        super().__init__(
            time_series=time_series,
            static=static,
            targets=targets,
            treatments=treatments,
            **kwargs,
        )

    def _init_predictive(
        self,
        targets: Optional[data_typing.DataContainer],
        treatments: Optional[data_typing.DataContainer],
        **kwargs: Any,
    ) -> None:
        """Create ``self.predictive`` as `pred.TemporalTreatmentEffectsTaskData`.

        Raises:
            ValueError: If ``treatments`` is `None`.
        """
        if targets is None:
            logger.debug(
                f"`targets` provided was None for {self.__class__.__name__}, "
                "this Dataset can only be used for prediction not fitting"
            )
        if treatments is None:
            raise ValueError("Temporal treatment effects task requires `treatments`")
        self.predictive = pred.TemporalTreatmentEffectsTaskData(
            parent_dataset=self,
            targets=targets,
            treatments=treatments,
            **kwargs,
        )

    def _validate(self) -> None:
        """Check targets (if present) and treatments against the time series covariates.

        Raises:
            ValueError: If the ``sample_index`` or ``time_indexes`` of the targets or treatments do not match
                those of the time series covariates.
        """
        time_series = self.time_series
        targets = self.predictive.targets
        if targets is not None:
            if targets.sample_index() != time_series.sample_index():
                raise ValueError(EXCEPTION_MESSAGES.sample_index_mismatch.targets)
            if targets.time_indexes() != time_series.time_indexes():
                raise ValueError(EXCEPTION_MESSAGES.time_indexes_mismatch.targets)
        # `treatments` is guaranteed non-`None` by `_init_predictive`.
        treatments = self.predictive.treatments
        if treatments.sample_index() != time_series.sample_index():
            raise ValueError(EXCEPTION_MESSAGES.sample_index_mismatch.treatments)
        if treatments.time_indexes() != time_series.time_indexes():
            raise ValueError(EXCEPTION_MESSAGES.time_indexes_mismatch.treatments)

    @property
    def fit_ready(self) -> bool:
        """Check if the dataset is ready to be fit on.

        Returns:
            bool: Whether the dataset is ready to be fit on (both ``targets`` and ``treatments`` present).
        """
        has_targets = self.predictive.targets is not None
        return has_targets and self.predictive.treatments is not None

    @property
    def predict_ready(self) -> bool:
        """Check if the dataset is ready to be predicted on.

        Returns:
            bool: Whether the dataset is ready to be predicted on (``treatments`` present).
        """
        return self.predictive.treatments is not None