Source code for tempor.models.clairvoyance2.data.dataset

# mypy: ignore-errors

import copy
from collections.abc import Sequence as SequenceABC
from typing import Callable, Dict, Iterator, Optional, Sequence, Tuple, Union

import numpy as np

from ..utils.dev import raise_not_implemented
from ._utils import time_index_equal
from .constants import T_SamplesIndexDtype
from .dataformat import (
    EventSamples,
    StaticSamples,
    T_ContainerInitializable,
    T_SampleIndex_Compatible,
    T_SampleIndexClass,
    T_SamplesIndexDtype_AsTuple,
    T_TSS_ContainerInitializable,
    TimeSeriesSamples,
    TMissingIndicator,
)
from .dataformat_base import Copyable, SupportsNewLike

# Accepted inputs for a temporal container: an existing `TimeSeriesSamples`, or a sequence of
# `T_TSS_ContainerInitializable` items (presumably one item per sample - confirm in `dataformat`).
T_TimeSeriesSamplesInitializable = Union[TimeSeriesSamples, Sequence[T_TSS_ContainerInitializable]]
# Accepted inputs for a static container: an existing `StaticSamples`, or anything covered by
# `T_ContainerInitializable`.
T_StaticSamplesInitializable = Union[StaticSamples, T_ContainerInitializable]
# Accepted inputs for an event container: the alias currently lists `EventSamples` only, written as a
# single-member `Union` so that it mirrors the two aliases above.
T_EventSamplesInitializable = Union[EventSamples]  # pyright: ignore


# TODO: Should probably implement a number of mixins that TimeSeriesSamples etc. have.
class Dataset(Copyable, SupportsNewLike, SequenceABC):
    """Group of static, temporal and event data containers that describe the same samples.

    ``temporal_covariates`` is the only mandatory container and acts as the reference:
    ``n_samples``, ``sample_index``, ``sample_indices`` and ``len()`` are read from it, and
    `validate` compares every other container that is set against it.

    Indexing and iterating produce new `Dataset` objects built from the indexed containers.
    """

    temporal_covariates: TimeSeriesSamples
    static_covariates: Optional[StaticSamples] = None
    event_covariates: Optional[EventSamples] = None
    temporal_targets: Optional[TimeSeriesSamples] = None
    temporal_treatments: Optional[TimeSeriesSamples] = None
    event_targets: Optional[EventSamples] = None
    event_treatments: Optional[EventSamples] = None

    _container_names = [
        "temporal_covariates",
        "static_covariates",
        "event_covariates",
        "temporal_targets",
        "temporal_treatments",
        "event_targets",
        "event_treatments",
    ]
    # NOTE: ^ Keep the above in the same order as these are initialized in the constructor.
    # Maps each container name to its position in `_container_names`; used as the sort key.
    _container_order = dict(zip(_container_names, range(len(_container_names))))

    def _init_time_series_samples(self, data: T_TimeSeriesSamplesInitializable, **kwargs) -> TimeSeriesSamples:
        """Return ``data`` unchanged if it is a `TimeSeriesSamples`, else build one from it with ``kwargs``."""
        if not isinstance(data, TimeSeriesSamples):
            return TimeSeriesSamples(data, **kwargs)
        return data

    def _init_static_samples(self, data: T_StaticSamplesInitializable, **kwargs) -> StaticSamples:
        """Return ``data`` unchanged if it is a `StaticSamples`, else build one from it with ``kwargs``."""
        if not isinstance(data, StaticSamples):
            return StaticSamples(data, **kwargs)
        return data

    def _init_event_samples(self, data: T_EventSamplesInitializable, **kwargs) -> EventSamples:
        """Return ``data`` unchanged if it is an `EventSamples`, else build one from it with ``kwargs``."""
        if not isinstance(data, EventSamples):
            return EventSamples(data, **kwargs)
        return data

    def __init__(
        self,
        temporal_covariates: T_TimeSeriesSamplesInitializable,
        static_covariates: Optional[T_StaticSamplesInitializable] = None,
        event_covariates: Optional[T_EventSamplesInitializable] = None,
        temporal_targets: Optional[T_TimeSeriesSamplesInitializable] = None,
        temporal_treatments: Optional[T_TimeSeriesSamplesInitializable] = None,
        event_targets: Optional[T_EventSamplesInitializable] = None,
        event_treatments: Optional[T_EventSamplesInitializable] = None,
        # NOTE: The below arguments are only applicable when *initializing* the above containers from DF, array
        # etc., not applicable when passing as TimeSeriesSamples/StaticSamples objects.
        # TODO: Maybe this is not user-friendly, could redo such that it will always set these.
        sample_indices: Optional[T_SampleIndex_Compatible] = None,
        missing_indicator: TMissingIndicator = np.nan,
    ) -> None:
        """Set up the containers that were provided, then validate them against each other.

        Args:
            temporal_covariates: Mandatory temporal container (or data to build it from).
            static_covariates: Optional static container (or data to build it from).
            event_covariates: Optional event container.
            temporal_targets: Optional temporal container (or data to build it from).
            temporal_treatments: Optional temporal container (or data to build it from).
            event_targets: Optional event container.
            event_treatments: Optional event container.
            sample_indices: Forwarded to the temporal/static container constructors only when the
                corresponding argument is not already a container object.
            missing_indicator: Forwarded under the same condition as ``sample_indices``.

        Raises:
            ValueError: From `validate`, if the containers disagree on sample count or sample indices.
        """
        init_options = dict(sample_indices=sample_indices, missing_indicator=missing_indicator)

        self.temporal_covariates = self._init_time_series_samples(temporal_covariates, **init_options)

        # Optional containers are only assigned when given; otherwise the class-level `None` applies.
        if static_covariates is not None:
            self.static_covariates = self._init_static_samples(static_covariates, **init_options)

        # Event containers do not receive `sample_indices` / `missing_indicator`.
        if event_covariates is not None:
            self.event_covariates = self._init_event_samples(data=event_covariates)

        if temporal_targets is not None:
            self.temporal_targets = self._init_time_series_samples(temporal_targets, **init_options)

        if temporal_treatments is not None:
            self.temporal_treatments = self._init_time_series_samples(temporal_treatments, **init_options)

        if event_targets is not None:
            self.event_targets = self._init_event_samples(data=event_targets)

        if event_treatments is not None:
            self.event_treatments = self._init_event_samples(data=event_treatments)

        self._refresh_container_dicts()
        self.validate()

    def _sort_container_dict(self, container_dict):
        """Return a new dict with the entries of ``container_dict`` ordered as in `_container_names`."""
        ordered_names = sorted(container_dict, key=self._container_order.__getitem__)
        return {name: container_dict[name] for name in ordered_names}

    def _refresh_container_dicts(self) -> None:
        """Rebuild the name -> container lookup dicts from the current attribute values.

        Entries whose container is not set hold ``None``; the public properties filter those out.
        """
        static_containers: Dict[str, Optional[StaticSamples]] = {
            "static_covariates": self.static_covariates,
        }
        temporal_containers: Dict[str, Optional[TimeSeriesSamples]] = {
            "temporal_covariates": self.temporal_covariates,
            "temporal_targets": self.temporal_targets,
            "temporal_treatments": self.temporal_treatments,
        }
        event_containers: Dict[str, Optional[EventSamples]] = {
            "event_covariates": self.event_covariates,
            "event_targets": self.event_targets,
            "event_treatments": self.event_treatments,
        }
        every_container: Dict[str, Union[StaticSamples, TimeSeriesSamples, EventSamples, None]] = {
            **static_containers,
            **temporal_containers,
            **event_containers,
        }

        self._static_data_containers = self._sort_container_dict(static_containers)
        self._temporal_data_containers = self._sort_container_dict(temporal_containers)
        self._event_data_containers = self._sort_container_dict(event_containers)
        self._all_data_containers = self._sort_container_dict(every_container)

    @staticmethod
    def _time_series_samples_repr(time_series_samples: TimeSeriesSamples) -> str:
        """Summarize a temporal container as ``ClassName([n_samples,n_timesteps,n_features])``.

        ``n_timesteps`` is shown as ``*`` when the samples do not all have the same number of timesteps.
        """
        if time_series_samples.all_samples_same_n_timesteps:
            timesteps_text = str(time_series_samples.n_timesteps_per_sample[0])
        else:
            timesteps_text = "*"
        dims = f"[{time_series_samples.n_samples},{timesteps_text},{len(time_series_samples.features)}]"
        return f"{time_series_samples.__class__.__name__}({dims})"

    @staticmethod
    def _static_samples_repr(static_samples: StaticSamples) -> str:
        """Summarize a static container as ``ClassName([n_samples,n_features])``."""
        dims = f"[{static_samples.n_samples},{len(static_samples.features)}]"
        return f"{static_samples.__class__.__name__}({dims})"

    @staticmethod
    def _event_samples_repr(event_samples: EventSamples) -> str:
        """Summarize an event container as ``ClassName([n_samples,n_features])``."""
        dims = f"[{event_samples.n_samples},{len(event_samples.features)}]"
        return f"{event_samples.__class__.__name__}({dims})"

    def __repr__(self) -> str:
        """Return a multi-line summary listing each container that is set, with its shape."""
        # NOTE(review): a single space as rendered in this view; confirm against the repository in case
        # whitespace inside the literal was collapsed.
        indent = " "
        line_prefix = f"\n{indent}"

        rendered_lines = []
        for name, container in self.all_data_containers.items():
            if isinstance(container, StaticSamples):
                summary = self._static_samples_repr(container)
            elif isinstance(container, TimeSeriesSamples):
                summary = self._time_series_samples_repr(container)
            else:
                # Anything that is neither static nor temporal is summarized as an event container.
                summary = self._event_samples_repr(container)  # type: ignore
            rendered_lines.append(f"{line_prefix}{name}={summary},")

        attributes_text = "".join(rendered_lines)
        return f"{self.__class__.__name__}({attributes_text}\n)"

    @staticmethod
    def _sample_indices_eq(
        sample_indices_a: Sequence[T_SamplesIndexDtype], sample_indices_b: Sequence[T_SamplesIndexDtype]
    ) -> bool:
        """Return True if both sequences have the same length and are equal element by element."""
        same_length = len(sample_indices_a) == len(sample_indices_b)
        return same_length and all(left == right for left, right in zip(sample_indices_a, sample_indices_b))

    @property
    def n_samples(self) -> int:
        """Number of samples, as reported by ``temporal_covariates``."""
        return self.temporal_covariates.n_samples

    @property
    def sample_index(self) -> T_SampleIndexClass:  # type: ignore
        """The ``sample_index`` of ``temporal_covariates``."""
        return self.temporal_covariates.sample_index

    @property
    def sample_indices(self) -> Sequence[T_SamplesIndexDtype]:
        """The ``sample_indices`` of ``temporal_covariates``."""
        return self.temporal_covariates.sample_indices

    @property
    def static_data_containers(self) -> Dict[str, StaticSamples]:
        """Static containers that are set, keyed by attribute name."""
        self._refresh_container_dicts()
        return {name: box for name, box in self._static_data_containers.items() if box is not None}

    @property
    def temporal_data_containers(self) -> Dict[str, TimeSeriesSamples]:
        """Temporal containers that are set, keyed by attribute name."""
        self._refresh_container_dicts()
        return {name: box for name, box in self._temporal_data_containers.items() if box is not None}

    @property
    def event_data_containers(self) -> Dict[str, EventSamples]:
        """Event containers that are set, keyed by attribute name."""
        self._refresh_container_dicts()
        return {name: box for name, box in self._event_data_containers.items() if box is not None}

    @property
    def all_data_containers(self) -> Dict[str, Union[StaticSamples, TimeSeriesSamples, EventSamples]]:
        """All containers that are set, keyed by attribute name, ordered as in `_container_names`."""
        self._refresh_container_dicts()
        return {name: box for name, box in self._all_data_containers.items() if box is not None}

    def validate(self) -> None:
        """Check that every container agrees with ``temporal_covariates`` on its samples.

        Raises:
            ValueError: If a container has a different number of samples, or different sample
                indices, than ``temporal_covariates``.
        """
        reference = self.temporal_covariates

        # Validate number of samples.
        # NOTE: Strictly this isn't necessary as checked below, but here for convenience.
        expected_count = reference.n_samples
        for name, box in self.all_data_containers.items():
            if box.n_samples != expected_count:
                raise ValueError(f"Expected {expected_count} samples " f"but found {box.n_samples} in `{name}`")

        # Validate that the sample indices match across all containers.
        # Kept as a second pass so that all count mismatches are reported before any index mismatch.
        expected_indices = reference.sample_indices
        for name, box in self.all_data_containers.items():
            if self._sample_indices_eq(box.sample_indices, expected_indices):
                continue
            raise ValueError(f"The sample indices in `{name}` didn't match those in `temporal_covariates`")

    def check_temporal_containers_have_same_time_index(self) -> Tuple[bool, Optional[Tuple[str, str]]]:
        """Compare the time index of each temporal container with that of the next one.

        Returns:
            ``(True, None)`` if every consecutive pair passes `time_index_equal`; otherwise
            ``(False, (name_a, name_b))`` with the names of the first pair that does not.
        """
        named_containers = list(self.temporal_data_containers.items())
        for (first_name, first), (second_name, second) in zip(named_containers, named_containers[1:]):
            if not time_index_equal(first, second):
                return False, (first_name, second_name)
        return True, None

    @staticmethod
    def new_like(like: "Dataset", **kwargs) -> "Dataset":
        """Create a `Dataset` using ``like``'s containers and settings as defaults.

        ``kwargs`` and the defaults are combined by `SupportsNewLike.process_kwargs` (presumably the
        explicitly passed ``kwargs`` take precedence - confirm in `dataformat_base`).
        """
        defaults = {name: getattr(like, name) for name in Dataset._container_names}
        defaults["sample_indices"] = like.temporal_covariates.sample_indices
        defaults["missing_indicator"] = like.temporal_covariates.missing_indicator

        merged_kwargs = SupportsNewLike.process_kwargs(kwargs, defaults)
        return Dataset(**merged_kwargs)  # type: ignore  # Mypy complains about kwargs but it's fine.

    @staticmethod
    def new_empty_like(like: "Dataset", **kwargs) -> "Dataset":
        """Placeholder: only calls `raise_not_implemented` for this method."""
        raise_not_implemented("Dataset.new_empty_like method")

    # --- Sequence Interface ---

    def __getitem__(self, key) -> "Dataset":
        """Return a new `Dataset` whose containers are each indexed with ``key``.

        A key that is neither a slice nor an object with ``__len__`` is treated as a single sample
        index and wrapped in a one-element tuple before being applied.

        Raises:
            KeyError: If such a single key is not an instance of ``T_SamplesIndexDtype_AsTuple``.
        """
        is_single_key = not (isinstance(key, slice) or hasattr(key, "__len__"))
        if is_single_key:
            if not isinstance(key, T_SamplesIndexDtype_AsTuple):
                raise KeyError(f"Expected `key` to be one of types {T_SamplesIndexDtype_AsTuple} but {type(key)} found")
            key = (key,)

        indexed = {name: box[key] for name, box in self.all_data_containers.items()}
        return self.new_like(like=self, **indexed)

    def __len__(self) -> int:
        """Return ``len(temporal_covariates)``."""
        return len(self.temporal_covariates)

    def __iter__(self) -> Iterator["Dataset"]:
        """Yield one single-sample `Dataset` per entry of ``sample_indices``."""
        for sample_id in self.sample_indices:
            yield self[sample_id]

    def __contains__(self, value) -> bool:
        """Return whether ``value`` is one of the sample indices."""
        return value in self.sample_indices

    def __reversed__(self) -> Iterator["Dataset"]:
        """Yield one single-sample `Dataset` per entry of ``sample_indices``, in reverse order."""
        for sample_id in reversed(self.sample_indices):
            yield self[sample_id]

    def index(self, value, start=0, stop=None):
        """Not supported; always raises `NotImplementedError`."""
        raise NotImplementedError

    def count(self, value):
        """Not supported; always raises `NotImplementedError`."""
        raise NotImplementedError

    # --- Sequence Interface (End) ---