# Source code for tempor.models.clairvoyance2.interface.requirements

# mypy: ignore-errors

from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, NoReturn, Optional, Sequence, Union

from ..data import Dataset, EventSamples, StaticSamples, TimeSeries, TimeSeriesSamples
from ..data.constants import T_NumericDtype_AsTuple, T_SamplesIndexDtype
from ..utils.common import python_type_from_np_pd_dtype
from .horizon import Horizon, HorizonOpts, NStepAheadHorizon

if TYPE_CHECKING:
    from .model import TTreatmentScenarios


class DataStructureOpts(Enum):
    """Kinds of data structure a prediction target or a treatment can be declared as.

    Used by ``PredictionRequirements.target_data_structure`` and
    ``TreatmentEffectsRequirements.treatment_data_structure``.
    """

    TIME_SERIES = auto()
    STATIC = auto()
    EVENT = auto()
class DataValueOpts(Enum):
    """Constraint on the values that a data container may hold.

    As enforced by ``RequirementsChecker._check_data_value_type``:

    * ``ANY``: no check is performed.
    * ``NUMERIC``: the container must report ``all_features_numeric``.
    * ``NUMERIC_CATEGORICAL``: as ``NUMERIC``, and additionally ``all_features_categorical``.
    * ``NUMERIC_BINARY``: as ``NUMERIC``, and additionally ``all_features_binary``.
    """

    ANY = auto()
    NUMERIC = auto()
    NUMERIC_CATEGORICAL = auto()
    NUMERIC_BINARY = auto()
@dataclass(frozen=True)
class DatasetRequirements:
    """Immutable declaration of what a model requires of the ``Dataset`` it receives.

    The fields are read by ``RequirementsChecker.check_data_requirements_general``.
    """

    # Miscellaneous:
    # When True, the dataset's `static_covariates` must not be None.
    requires_static_covariates_present: bool = False
    # When True, no data container of the dataset may report `has_missing`.
    requires_no_missing_data: bool = False

    # Value types (each one is checked only if the corresponding container is present):
    static_covariates_value_type: DataValueOpts = DataValueOpts.ANY
    temporal_covariates_value_type: DataValueOpts = DataValueOpts.ANY
    temporal_targets_value_type: DataValueOpts = DataValueOpts.ANY
    temporal_treatments_value_type: DataValueOpts = DataValueOpts.ANY
    event_covariates_value_type: DataValueOpts = DataValueOpts.ANY
    event_targets_value_type: DataValueOpts = DataValueOpts.ANY
    event_treatments_value_type: DataValueOpts = DataValueOpts.ANY

    # Special temporal requirements:
    # When True, every temporal container must report `all_samples_aligned`.
    requires_all_temporal_data_samples_aligned: bool = False
    # When True, every temporal container's `is_regular()` must report a regular time index.
    requires_all_temporal_data_regular: bool = False
    # When True, the time index dtype (taken from the first sample of each container) must be numeric.
    requires_all_temporal_data_index_numeric: bool = False
    # When True (the default), all temporal containers must have the same time index for each sample.
    requires_all_temporal_containers_shares_index: bool = True
@dataclass(frozen=True)
class PredictionRequirements:
    """Immutable declaration of the requirements specific to a prediction model.

    The fields are read by ``RequirementsChecker._check_data_requirements_predict``.
    """

    # Data structure of the prediction target.
    target_data_structure: DataStructureOpts = DataStructureOpts.TIME_SERIES
    # Kind of horizon object the model expects to be given.
    horizon_type: HorizonOpts = HorizonOpts.N_STEP_AHEAD
    # Minimum number of timesteps every temporal targets sample must have at fit-time.
    min_timesteps_target_when_fit: int = 1
    # Minimum number of timesteps every temporal targets sample must have at predict-time.
    min_timesteps_target_when_predict: int = 1
@dataclass(frozen=True)
class TreatmentEffectsRequirements:
    """Immutable declaration of the requirements specific to a treatment effects model."""

    # NOTE: target, horizon are expected to be the same as for PredictionRequirements.

    # Data structure of the treatment.
    treatment_data_structure: DataStructureOpts = DataStructureOpts.TIME_SERIES
    # Minimum number of timesteps required at fit-time.
    min_timesteps_treatment_when_fit: int = 1
    # Minimum number of timesteps required at predict-time.
    min_timesteps_treatment_when_predict: int = 1
    # Minimum number of timesteps required at predict-counterfactuals-time.
    min_timesteps_treatment_when_predict_counterfactual: int = 1
@dataclass(frozen=True)
class Requirements:
    """Immutable bundle of all the requirements that a model declares.

    ``prediction_requirements`` and ``treatment_effects_requirements`` are optional; the corresponding
    checks in ``RequirementsChecker.check_data_requirements_general`` run only for those that are set.
    """

    # `DatasetRequirements` is frozen, so sharing this single default instance between objects is safe.
    dataset_requirements: DatasetRequirements = DatasetRequirements()
    prediction_requirements: Optional[PredictionRequirements] = None
    treatment_effects_requirements: Optional[TreatmentEffectsRequirements] = None
def raise_requirements_mismatch_error(requirement_name: str, explanation_text: str) -> NoReturn:
    """Raise the standard "requirements mismatch" error; this function never returns.

    Args:
        requirement_name: Name of the requirement that was violated, quoted in the message.
        explanation_text: Explanation of the violation. A trailing full stop is appended to it.

    Raises:
        RuntimeError: Always.
    """
    raise RuntimeError(f"Requirements mismatch occurred. Requirement: '{requirement_name}'. {explanation_text}.")
def get_container_friendly_name(container_name: str) -> str:
    """Return ``container_name`` with every underscore replaced by a space, for use in messages.

    Example: ``"temporal_targets"`` becomes ``"temporal targets"``.
    """
    return container_name.replace("_", " ")
# NOTE: Needs more testing.
class RequirementsChecker:
    """Collection of static checks validating a ``Dataset`` (and call arguments) against ``Requirements``.

    Failed checks raise ``RuntimeError``, in most cases through ``raise_requirements_mismatch_error``.
    The provisional treatment scenario checks use ``assert`` and so raise ``AssertionError``.
    """

    @staticmethod
    def _check_min_timesteps(min_timesteps: int, container: TimeSeriesSamples, container_name: str, preface: str):
        """Require every sample of ``container`` to have at least ``min_timesteps`` timesteps.

        Args:
            min_timesteps: Minimum acceptable number of timesteps per sample.
            container: Container whose ``n_timesteps_per_sample`` is inspected.
            container_name: Name of the container (underscores allowed), used in the error message.
            preface: Text placed before the container name in the error message.

        Raises:
            RuntimeError: If the shortest sample has fewer than ``min_timesteps`` timesteps.
        """
        # Evaluate once. An empty container has no sample that could violate the minimum, so it passes
        # (rather than failing with an unrelated ValueError from `min()` on an empty sequence).
        fewest_timesteps = min(container.n_timesteps_per_sample, default=None)
        if fewest_timesteps is None:
            return
        if fewest_timesteps < min_timesteps:
            raise_requirements_mismatch_error(
                f"{preface} {get_container_friendly_name(container_name)}",
                f"Requires at least {min_timesteps} but a sample with {fewest_timesteps} timesteps was found",
            )

    @staticmethod
    def _check_data_value_type(
        requirement: DataValueOpts,
        container: Union[TimeSeriesSamples, StaticSamples, EventSamples],
        preface: str,
    ):
        """Require the features of ``container`` to satisfy the value type ``requirement``.

        ``DataValueOpts.ANY`` imposes nothing. Every other option requires all features to be numeric;
        ``NUMERIC_CATEGORICAL`` and ``NUMERIC_BINARY`` additionally require all features to be
        categorical or binary, respectively.

        Args:
            requirement: The value type required.
            container: The container to inspect.
            preface: Text placed before the requirement name in the error message.

        Raises:
            RuntimeError: If the container does not satisfy the requirement.
        """
        numeric_requirements = (
            DataValueOpts.NUMERIC,
            DataValueOpts.NUMERIC_CATEGORICAL,
            DataValueOpts.NUMERIC_BINARY,
        )
        if requirement not in numeric_requirements:
            return

        # The stricter property is only consulted once the numeric check has passed.
        compatible = container.all_features_numeric
        if compatible and requirement == DataValueOpts.NUMERIC_CATEGORICAL:
            compatible = container.all_features_categorical
        if compatible and requirement == DataValueOpts.NUMERIC_BINARY:
            compatible = container.all_features_binary

        if not compatible:
            raise_requirements_mismatch_error(
                f"{preface} `{requirement.name}`",
                f"Incompatible data found. Preview:\n{container.df.head()}",
            )

    @staticmethod
    def _check_data_requirements_predict(  # pylint: disable=unused-argument
        called_at_fit_time: bool,
        requirements: Requirements,
        data: Dataset,
        horizon: Optional[Horizon],
        **kwargs,
    ):
        """Check the prediction-specific requirements against ``data`` and ``horizon``.

        Args:
            called_at_fit_time: Whether the fit-time (True) or predict-time (False) minimums apply.
            requirements: The model requirements; ``prediction_requirements`` must be set.
            data: The dataset to validate.
            horizon: The prediction horizon. Horizon checks are skipped when it is None.
            **kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If ``prediction_requirements`` is None, or if any requirement is violated.
        """
        prediction_reqs = requirements.prediction_requirements
        if prediction_reqs is None:
            raise RuntimeError("Prediction requirements need to be set on a predictor model, but None found")

        if prediction_reqs.target_data_structure != DataStructureOpts.TIME_SERIES:
            # TODO: Any requirements checks for DataStructureOpts.EVENT and DataStructureOpts.STATIC targets.
            return

        if data.temporal_targets is None:
            raise_requirements_mismatch_error(
                f"Prediction requirement: prediction target `{DataStructureOpts.TIME_SERIES}`",
                "Dataset must contain temporal targets in this case but did not",
            )

        if called_at_fit_time:
            if prediction_reqs.min_timesteps_target_when_fit is not None:
                RequirementsChecker._check_min_timesteps(
                    min_timesteps=prediction_reqs.min_timesteps_target_when_fit,
                    container=data.temporal_targets,
                    container_name="temporal_targets",
                    preface="Prediction requirement: minimum number of timesteps at fit-time,",
                )
        else:
            RequirementsChecker._check_min_timesteps(
                min_timesteps=prediction_reqs.min_timesteps_target_when_predict,
                container=data.temporal_targets,
                container_name="temporal_targets",
                preface="Prediction requirement: minimum number of timesteps at predict-time,",
            )
            treatment_reqs = requirements.treatment_effects_requirements
            if treatment_reqs is not None and data.temporal_treatments is not None:
                RequirementsChecker._check_min_timesteps(
                    min_timesteps=treatment_reqs.min_timesteps_treatment_when_predict,
                    container=data.temporal_treatments,
                    container_name="temporal_treatments",
                    preface="Treatment effects requirement: minimum number of timesteps at predict-time,",
                )

        if horizon is None:
            return

        if prediction_reqs.horizon_type == HorizonOpts.N_STEP_AHEAD:
            if not isinstance(horizon, NStepAheadHorizon):
                raise_requirements_mismatch_error(
                    f"Prediction requirement: prediction horizon `{HorizonOpts.N_STEP_AHEAD}`",
                    f"A prediction horizon of type {NStepAheadHorizon} is expected, but found {type(horizon)}",
                )
            for container_name, container in data.temporal_data_containers.items():
                max_timesteps = max(container.n_timesteps_per_sample)
                if horizon.n_step >= max_timesteps:
                    raise_requirements_mismatch_error(
                        f"Prediction requirement: prediction horizon `{HorizonOpts.N_STEP_AHEAD}`",
                        "N step ahead horizon must be < max timesteps in "
                        f"{get_container_friendly_name(container_name)}, but was "
                        f"{horizon.n_step} >= {max_timesteps}",
                    )

        # TODO: Implement any data requirements for HorizonOpts.TIME_INDEX.

    @staticmethod
    def _check_data_requirements_predict_counterfactuals(  # pylint: disable=unused-argument
        called_at_fit_time: bool,
        requirements: Requirements,
        data: Dataset,
        sample_index: Optional[T_SamplesIndexDtype],
        treatment_scenarios: Optional["TTreatmentScenarios"],
        horizon: Optional[Horizon],
        **kwargs,
    ):
        """Check the treatment-effects-specific requirements against ``data`` and ``treatment_scenarios``.

        Args:
            called_at_fit_time: Whether the fit-time (True) or predict-counterfactuals-time (False)
                minimum applies.
            requirements: The model requirements; ``treatment_effects_requirements`` must be set.
            data: The dataset to validate.
            sample_index: Accepted and not used by any check.
            treatment_scenarios: Treatment scenarios to validate. Skipped when None.
            horizon: Accepted and not used by any check.
            **kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If ``treatment_effects_requirements`` is None, or if a requirement is violated.
            AssertionError: If ``treatment_scenarios`` fails the provisional scenario checks.
        """
        treatment_reqs = requirements.treatment_effects_requirements
        if treatment_reqs is None:
            raise RuntimeError(
                "Treatment effects requirements need to be set on a treatment effects model, but None found"
            )

        treatments_are_time_series = treatment_reqs.treatment_data_structure == DataStructureOpts.TIME_SERIES

        if treatments_are_time_series:
            if data.temporal_targets is None:
                raise_requirements_mismatch_error(
                    f"Treatment effects requirements: treatment type `{DataStructureOpts.TIME_SERIES}`",
                    "Dataset must contain temporal targets in this case but did not",
                )
            if data.temporal_treatments is None:
                raise_requirements_mismatch_error(
                    f"Treatment effects requirements: treatment type `{DataStructureOpts.TIME_SERIES}`",
                    "Dataset must contain temporal treatments in this case but did not",
                )
            if called_at_fit_time:
                RequirementsChecker._check_min_timesteps(
                    min_timesteps=treatment_reqs.min_timesteps_treatment_when_fit,
                    container=data.temporal_treatments,
                    container_name="temporal_treatments",
                    preface="Treatment effects requirement: minimum number of timesteps at fit-time,",
                )
            else:
                # NOTE(review): a *treatment* minimum is checked against `temporal_targets` here, whereas the
                # fit-time branch checks `temporal_treatments`. Possibly a copy-paste slip; confirm intent.
                RequirementsChecker._check_min_timesteps(
                    min_timesteps=treatment_reqs.min_timesteps_treatment_when_predict_counterfactual,
                    container=data.temporal_targets,
                    container_name="temporal_targets",
                    preface="Treatment effects requirement: minimum number of timesteps at "
                    "predict-counterfactual-time,",
                )
        # TODO: Any requirements checks for DataStructureOpts.EVENT and DataStructureOpts.STATIC treatments.

        # TODO: The below is temporary. Interface is not settled and may change.
        # NOTE(review): these are `assert` statements, so they are skipped when Python runs with -O.
        if treatment_scenarios is not None and treatments_are_time_series:
            assert isinstance(treatment_scenarios, Sequence)
            for scenario in treatment_scenarios:
                assert isinstance(scenario, TimeSeries)
            if requirements.dataset_requirements.temporal_treatments_value_type == DataValueOpts.NUMERIC_BINARY:
                for scenario in treatment_scenarios:
                    assert scenario.all_features_numeric and scenario.all_features_binary

    @staticmethod
    def check_data_requirements_general(called_at_fit_time: bool, requirements: Requirements, data: Dataset, **kwargs):
        """Check ``data`` against the dataset requirements, then run the model-type-specific checks.

        The prediction and treatment effects checks run only when the corresponding requirements are set.

        Args:
            called_at_fit_time: Forwarded to the prediction / treatment effects checks.
            requirements: The model requirements.
            data: The dataset to validate.
            **kwargs: May contain ``horizon``, ``sample_index`` and ``treatment_scenarios``, which are
                extracted (defaulting to None) and forwarded; anything else is forwarded untouched.

        Raises:
            RuntimeError: If any requirement is violated.
        """
        dataset_reqs = requirements.dataset_requirements

        # Miscellaneous.
        if dataset_reqs.requires_static_covariates_present and data.static_covariates is None:
            raise_requirements_mismatch_error(
                "Dataset requirement: requires static samples", "Dataset did not have static samples"
            )
        if dataset_reqs.requires_no_missing_data:
            for container_name, container in data.all_data_containers.items():
                if container.has_missing:
                    raise_requirements_mismatch_error(
                        "Dataset requirement: requires no missing data",
                        f"Dataset {get_container_friendly_name(container_name)} had missing data",
                    )

        # Check data value types. A container that is absent (None) is skipped.
        value_type_checks = (
            (
                data.static_covariates,
                dataset_reqs.static_covariates_value_type,
                "Dataset requirement: static covariates data type",
            ),
            (
                data.temporal_covariates,
                dataset_reqs.temporal_covariates_value_type,
                "Dataset requirement: temporal covariates data type",
            ),
            (
                data.temporal_targets,
                dataset_reqs.temporal_targets_value_type,
                "Dataset requirement: temporal target data type",
            ),
            (
                data.temporal_treatments,
                dataset_reqs.temporal_treatments_value_type,
                "Dataset requirement: temporal treatment data type",
            ),
            (
                data.event_covariates,
                dataset_reqs.event_covariates_value_type,
                "Dataset requirement: event covariates data type",
            ),
            (
                data.event_targets,
                dataset_reqs.event_targets_value_type,
                "Dataset requirement: event target data type",
            ),
            (
                data.event_treatments,
                dataset_reqs.event_treatments_value_type,
                "Dataset requirement: event treatment data type",
            ),
        )
        for container, value_type, preface in value_type_checks:
            if container is not None and value_type:
                RequirementsChecker._check_data_value_type(
                    requirement=value_type, container=container, preface=preface
                )

        # Special temporal requirements.
        if dataset_reqs.requires_all_temporal_data_regular:
            for container_name, container in data.temporal_data_containers.items():
                is_regular, _ = container.is_regular()  # TODO: Compare the diff. and ensure they are the same?
                if not is_regular:
                    raise_requirements_mismatch_error(
                        "Dataset requirement: requires regular timeseries",
                        f"Dataset {get_container_friendly_name(container_name)} did not have a regular time index",
                    )
        if dataset_reqs.requires_all_temporal_data_samples_aligned:
            for container_name, container in data.temporal_data_containers.items():
                if not container.all_samples_aligned:
                    raise_requirements_mismatch_error(
                        "Dataset requirement: requires aligned timeseries",
                        f"Dataset {get_container_friendly_name(container_name)} were not all aligned by their index",
                    )
        if dataset_reqs.requires_all_temporal_data_index_numeric:
            acceptable_types = T_NumericDtype_AsTuple
            for container_name, container in data.temporal_data_containers.items():
                if len(container) > 0:
                    # Only the first sample's time index dtype is inspected.
                    first_sample = container[0]
                    if TYPE_CHECKING:
                        assert isinstance(first_sample, TimeSeries)
                    dtype = python_type_from_np_pd_dtype(first_sample.time_index.dtype)  # type: ignore
                    if dtype not in acceptable_types:
                        raise_requirements_mismatch_error(
                            "Dataset requirement: requires numeric timeseries index",
                            f"Dataset {get_container_friendly_name(container_name)} had index of dtype {dtype}",
                        )
        if dataset_reqs.requires_all_temporal_containers_shares_index:
            check_outcome, names = data.check_temporal_containers_have_same_time_index()
            if check_outcome is False:
                assert names is not None
                a_name, b_name = names
                raise_requirements_mismatch_error(
                    "Dataset requirement: requires all temporal containers have same time index (for each sample)",
                    f"The containers {a_name} and {b_name} did not have the same time index for all samples",
                )

        # Try to get additional kwargs if provided.
        horizon = kwargs.pop("horizon", None)
        sample_index = kwargs.pop("sample_index", None)
        treatment_scenarios = kwargs.pop("treatment_scenarios", None)

        # Prediction-specific data requirements:
        if requirements.prediction_requirements is not None:
            RequirementsChecker._check_data_requirements_predict(
                called_at_fit_time=called_at_fit_time, requirements=requirements, data=data, horizon=horizon, **kwargs
            )

        # Treatment effects -specific data requirements:
        if requirements.treatment_effects_requirements is not None:
            RequirementsChecker._check_data_requirements_predict_counterfactuals(
                called_at_fit_time=called_at_fit_time,
                requirements=requirements,
                data=data,
                sample_index=sample_index,
                treatment_scenarios=treatment_scenarios,
                horizon=horizon,
                **kwargs,
            )

    @staticmethod
    def check_data_requirements_transform(requirements: Requirements, data: Dataset, **kwargs):
        """Placeholder for transform-time data checks; currently performs no checks."""
        # Currently no checks.

    @staticmethod
    def check_data_requirements_predict(
        requirements: Requirements, data: Dataset, horizon: Optional[Horizon], **kwargs
    ):
        """Run the predict-time checks of a prediction model.

        Args:
            requirements: The model requirements.
            data: The dataset to validate.
            horizon: The prediction horizon; must not be None.
            **kwargs: Forwarded to ``_check_data_requirements_predict``.

        Raises:
            RuntimeError: If ``horizon`` is None, or if any prediction requirement is violated.
        """
        if horizon is None:
            raise RuntimeError("Prediction model must receive a horizon object at predict-time")
        RequirementsChecker._check_data_requirements_predict(
            called_at_fit_time=False, requirements=requirements, data=data, horizon=horizon, **kwargs
        )

    @staticmethod
    def check_data_requirements_predict_counterfactuals(
        requirements: Requirements,
        data: Dataset,
        sample_index: T_SamplesIndexDtype,
        treatment_scenarios: "TTreatmentScenarios",
        horizon: Optional[Horizon],
        **kwargs,
    ):
        """Run the predict-counterfactuals-time checks of a treatment effects model.

        Args:
            requirements: The model requirements.
            data: The dataset to validate.
            sample_index: Forwarded to ``_check_data_requirements_predict_counterfactuals``.
            treatment_scenarios: The treatment scenarios to validate.
            horizon: The prediction horizon; must not be None.
            **kwargs: Forwarded to ``_check_data_requirements_predict_counterfactuals``.

        Raises:
            RuntimeError: If ``horizon`` is None, or if any treatment effects requirement is violated.
        """
        if horizon is None:
            raise RuntimeError("Treatment effects model must receive a horizon object at predict-time")
        RequirementsChecker._check_data_requirements_predict_counterfactuals(
            called_at_fit_time=False,
            requirements=requirements,
            data=data,
            sample_index=sample_index,
            treatment_scenarios=treatment_scenarios,
            horizon=horizon,
            **kwargs,
        )

    @staticmethod
    def check_predictor_model_requirements(predictor):
        """Require ``predictor.requirements`` to define prediction requirements.

        Raises:
            RuntimeError: If ``predictor.requirements.prediction_requirements`` is None.
        """
        requirements: Requirements = predictor.requirements
        if requirements.prediction_requirements is None:
            raise_requirements_mismatch_error(
                "Prediction requirements",
                f"Prediction model {predictor.__class__.__name__} must have prediction requirements defined, "
                "but found None",
            )

    @staticmethod
    def check_treatment_effects_model_requirements(treatment_effects_model):
        """Require ``treatment_effects_model.requirements`` to define treatment effects requirements.

        Raises:
            RuntimeError: If ``treatment_effects_model.requirements.treatment_effects_requirements`` is None.
        """
        requirements: Requirements = treatment_effects_model.requirements
        if requirements.treatment_effects_requirements is None:
            raise_requirements_mismatch_error(
                "Treatment effects requirements",
                f"Treatment effects model {treatment_effects_model.__class__.__name__} must have treatment effects "
                "requirements defined, but found None",
            )