# Source code for tempor.models.clairvoyance2.treatment_effects.synctwin

# mypy: ignore-errors

# NOTE:
# Experimental, minimally tested, will be significantly changed and improved.

from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

from ..components.torch import interfaces as ti
from ..components.torch import synctwin_train_utils
from ..components.torch.synctwin_models import LinearDecoder, RegularDecoder, RegularEncoder, SyncTwinModule
from ..data import Dataset, TimeSeries, TimeSeriesSamples
from ..data.constants import T_SamplesIndexDtype
from ..data.utils import time_index_equal, time_index_utils, to_counterfactual_predictions
from ..interface import (
    Horizon,
    TCounterfactualPredictions,
    TDefaultParams,
    TimeIndexHorizon,
    TParams,
    TPredictOutput,
    TreatmentEffectsModel,
    TTreatmentScenarios,
)
from ..interface import requirements as r
from ..utils.array_manipulation import compute_deltas
from ..utils.dev import NEEDED


class _DefaultParams(NamedTuple):
    """Default hyperparameters of ``SyncTwinRegressor`` (exposed as ``SyncTwinRegressor.DEFAULT_PARAMS``).

    Field names, order and default values form the model's parameter interface.
    """

    # Main hyperparameters:
    hidden_size: int = 20  # Passed as ``hidden_dim`` to ``RegularEncoder``.
    tau: float = 1.0  # Passed as ``tau`` to ``SyncTwinModule``.
    lambda_prognostic: float = 1.0  # Passed as ``lam_prognostic`` to ``SyncTwinModule``.
    lambda_reconstruction: float = 1.0  # Passed as ``lam_recon`` to ``SyncTwinModule``.
    batch_size: int = 32  # Batch size of the pretraining stage (``pre_train_reconstruction_prognostic_loss``).
    pretraining_iterations: int = 5_000  # ``niters`` of the pretraining stage during fit.
    matching_iterations: int = 20_000  # ``niters`` of ``train_B_self_expressive`` during fit.
    inference_iterations: int = 20_000  # ``niters`` of ``train_B_self_expressive`` in counterfactual prediction.
    # Misc:
    use_validation_set_in_training: bool = True  # If True, fit splits the data into train / validation halves.
    treatment_status_is_treated: int = 1  # Event-treatment value (0 or 1) that denotes a treated sample.


class SyncTwinTensors(NamedTuple):
    """Tensors consumed by the SyncTwin training / inference utilities.

    As built by ``SyncTwinRegressor``, the 3D tensors are laid out as (timesteps, samples, features).
    """

    x_full: torch.Tensor  # Pre-treatment temporal covariates concatenated with pre-treatment temporal targets.
    t_full: torch.Tensor  # Time deltas between timesteps, repeated over the features of ``x_full``.
    mask_full: torch.Tensor  # Observation mask for ``x_full`` (all ones, as missing data is not supported).
    batch_ind_full: torch.Tensor  # Long tensor with one index per sample.
    y_full: torch.Tensor  # Post-treatment temporal targets.
    y_control: torch.Tensor  # ``y_full`` restricted to the control samples.
    y_mask_full: torch.Tensor  # 1D per-sample indicator: 1.0 for control samples, 0.0 for treated samples.
# TODO: When returning the result array, MUST ASSIGN TO APPROPRIATE SAMPLE INDICES!
class SyncTwinRegressor(
    TreatmentEffectsModel,
    ti.OrganizedTreatmentEffectsModuleMixin,
    ti.OrganizedPredictorModuleMixin,
    ti.OrganizedModule,
):
    """SyncTwin treatment effects model (experimental).

    Requires exactly one binary event treatment feature and exactly one temporal target feature. ``fit()`` runs a
    pretraining stage followed by the self-expressive matching stage. ``predict_counterfactuals()`` re-runs the
    matching on the data passed in and returns the predicted untreated outcome of a sample that was factually
    treated. ``predict()`` is not implemented.
    """

    requirements: r.Requirements = r.Requirements(
        dataset_requirements=r.DatasetRequirements(
            temporal_covariates_value_type=r.DataValueOpts.NUMERIC,
            temporal_targets_value_type=r.DataValueOpts.NUMERIC,
            static_covariates_value_type=r.DataValueOpts.NUMERIC,
            event_treatments_value_type=r.DataValueOpts.NUMERIC_BINARY,
            requires_no_missing_data=True,  # TODO: For now, eventually allow.
        ),
        prediction_requirements=r.PredictionRequirements(
            target_data_structure=r.DataStructureOpts.TIME_SERIES,
            horizon_type=r.HorizonOpts.TIME_INDEX,
        ),
        treatment_effects_requirements=r.TreatmentEffectsRequirements(
            treatment_data_structure=r.DataStructureOpts.EVENT,
        ),
    )
    DEFAULT_PARAMS: TDefaultParams = _DefaultParams()

    # Fixed params:
    _lambda_express: float = 1.0
    _reg_B: float = 0.0
    _self_expressive_lr: float = 0.001
    _validation_set_frac = 0.5
    _prediction_compute_iters = 500

    expected_treatment_statuses = (0, 1)

    def _get_other_treatment_status(self, treatment_status_indicator) -> int:
        """Return the opposite binary treatment status (0 <-> 1)."""
        assert treatment_status_indicator in self.expected_treatment_statuses
        return int(1 - treatment_status_indicator)

    def __init__(self, params: Optional[TParams] = None) -> None:
        TreatmentEffectsModel.__init__(self, params)
        ti.OrganizedModule.__init__(self)

        # Quick validation:
        if self.params.treatment_status_is_treated not in self.expected_treatment_statuses:
            raise ValueError(
                f"`treatment_status_is_treated` must be one of: {self.expected_treatment_statuses} "
                f"but was {self.params.treatment_status_is_treated}"
            )
        self._treated_indicator = self.params.treatment_status_is_treated
        self._control_indicator = self._get_other_treatment_status(self.params.treatment_status_is_treated)

        # Components:
        self.encoder: Optional[nn.Module] = NEEDED
        self.decoder: Optional[nn.Module] = NEEDED
        self.decoder_y: Optional[nn.Module] = NEEDED
        self.synctwin: Optional[SyncTwinModule] = NEEDED

        # Helpers:
        self._predict_synctwin_n_unit = None
        self._predict_synctwin_n_treated = None
        # Evaluate roughly 10 times per stage (at least every iteration for tiny iteration counts).
        self._pretraining_test_freq = max(self.params.pretraining_iterations // 10, 1)
        self._matching_test_freq = max(self.params.matching_iterations // 10, 1)
        self._inference_iterations_test_freq = max(self.params.inference_iterations // 10, 1)

    @property
    def treated_indicator(self) -> int:
        """Treatment feature value that denotes a treated sample."""
        return self._treated_indicator

    @property
    def control_indicator(self) -> int:
        """Treatment feature value that denotes a control sample."""
        return self._control_indicator

    @staticmethod
    def _extract_pre_treatment(
        time_series_samples: TimeSeriesSamples, event_time_indexes, name: str
    ) -> TimeSeriesSamples:
        """Take the part of each sample preceding its treatment event (via ``take_all_before_start``).

        Raises:
            RuntimeError: If the resulting samples do not all have the same number of timesteps.
        """
        time_series_samples = time_index_utils.time_series_samples.take_all_before_start(  # type: ignore
            time_series_samples, event_time_indexes
        )
        if not time_series_samples.all_samples_same_n_timesteps:
            raise RuntimeError(
                f"{SyncTwinRegressor.__name__} requires that {name} up to the treatment event time "
                "have the same number of timesteps but this was not the case"
            )
        return time_series_samples

    @staticmethod
    def _extract_post_treatment(
        time_series_samples: TimeSeriesSamples, event_time_indexes, name: str
    ) -> TimeSeriesSamples:
        """Take the part of each sample from its treatment event on (via ``take_all_from_start``).

        Raises:
            RuntimeError: If the resulting samples do not all have the same number of timesteps.
        """
        time_series_samples = time_index_utils.time_series_samples.take_all_from_start(  # type: ignore
            time_series_samples, event_time_indexes
        )
        if not time_series_samples.all_samples_same_n_timesteps:
            raise RuntimeError(
                f"{SyncTwinRegressor.__name__} requires that {name} from treatment event time on "
                "have the same number of timesteps but this was not the case"
            )
        return time_series_samples

    def _convert_data_to_synctwin_format(self, data: Dataset, check_only: bool = False):
        """Validate ``data`` and convert it into ``SyncTwinTensors``.

        Args:
            data: Dataset with temporal covariates, one temporal target feature and one binary event treatment.
            check_only: If True, only run the validation and return ``True``.

        Returns:
            ``True`` if ``check_only``; otherwise a tuple
            ``(SyncTwinTensors, (n_treated_samples, n_control_samples))``. The 3D tensors are laid out as
            (timesteps, samples, features).

        Raises:
            RuntimeError: If ``data`` does not satisfy the structural requirements of the model.
        """
        assert data.temporal_targets is not None
        assert data.event_treatments is not None

        if data.event_treatments.n_features != 1:
            raise RuntimeError(
                f"{SyncTwinRegressor.__name__} requires exactly one event treatments feature but "
                f"{data.event_treatments.n_features} found"
            )
        # TODO: Shouldn't be limited to 1D targets, generalize.
        if data.temporal_targets.n_features != 1:
            raise RuntimeError(
                f"{SyncTwinRegressor.__name__} requires exactly one temporal targets feature but "
                f"{data.temporal_targets.n_features} found"
            )
        treatment_feature = list(data.event_treatments.features.values())[0]
        treatment_categories = tuple(treatment_feature.categories)
        if treatment_categories != self.expected_treatment_statuses:
            raise RuntimeError(
                f"{SyncTwinRegressor.__name__} requires the treatment feature to have the categories: "
                f"{self.expected_treatment_statuses} but {treatment_categories} found."
            )

        # Count treated and control samples.
        # NOTE(review): samples are NOT reordered (e.g. control first, then treated) anywhere in this class;
        # confirm whether SyncTwinModule relies on such an ordering.
        treatment_feature_name = list(data.event_treatments.features.keys())[0]
        df = data.event_treatments.df
        treated_sample_indices = (df[df[treatment_feature_name] == self.treated_indicator]).index.get_level_values(0)
        control_sample_indices = (df[df[treatment_feature_name] == self.control_indicator]).index.get_level_values(0)
        n_treated_samples = len(treated_sample_indices)
        n_control_samples = len(control_sample_indices)
        if n_treated_samples > n_control_samples:
            raise RuntimeError(
                f"{SyncTwinRegressor.__name__} requires the num. treated samples <= n control samples: "
                f"but these were {n_treated_samples} and {n_control_samples} respectively"
            )

        # Get the time index of each event.
        event_time_indexes: List = [event.df.index.get_level_values(1) for event in data.event_treatments]

        # Extract the pre-treatment part of the covariates and targets.
        pre_treat_temporal_covariates = self._extract_pre_treatment(
            data.temporal_covariates, event_time_indexes, "temporal covariates"
        )
        pre_treat_temporal_targets = self._extract_pre_treatment(
            data.temporal_targets, event_time_indexes, "temporal targets"
        )
        if not time_index_equal(pre_treat_temporal_covariates, pre_treat_temporal_targets):
            raise RuntimeError(
                f"{SyncTwinRegressor.__name__} requires pre-treatment covariates and targets have "
                "the same time indexes (for each sample) but this was not the case"
            )

        # Extract the post-treatment targets.
        post_treat_temporal_targets = self._extract_post_treatment(
            data.temporal_targets, event_time_indexes, "temporal targets"
        )

        if check_only:
            return True

        # (samples, time, features) -> (time, samples, features).
        x_full_cov = pre_treat_temporal_covariates.to_torch_tensor(dtype=self.dtype, device=self.device)
        x_full_targ = pre_treat_temporal_targets.to_torch_tensor(dtype=self.dtype, device=self.device)
        x_full = torch.cat([x_full_cov, x_full_targ], dim=-1).permute(1, 0, 2)
        y_full = post_treat_temporal_targets.to_torch_tensor(dtype=self.dtype, device=self.device).permute(1, 0, 2)

        # Get time deltas; the first timestep has no predecessor, so its delta is set to 1.0.
        time_index = pre_treat_temporal_targets.to_torch_tensor_time_index(dtype=self.dtype, device=self.device)
        time_deltas = compute_deltas(time_index)
        time_deltas[:, 0, :] = 1.0
        assert time_deltas.shape[-1] == 1
        t_full = torch.cat([time_deltas] * x_full.shape[-1], dim=-1).permute(1, 0, 2)

        # Get the mask (no missing data is allowed, see `requirements`).
        mask_full = torch.ones_like(x_full)

        # Treatment value filtering: y_mask_full is 1.0 for control samples and 0.0 for treated samples.
        tr = torch.tensor(data.event_treatments.df.to_numpy()[:, 0], dtype=self.dtype, device=self.device)
        y_mask_full = torch.zeros_like(tr)
        y_mask_full[tr == self.treated_indicator] = 0.0
        y_mask_full[tr == self.control_indicator] = 1.0
        y_control = y_full[:, y_mask_full == 1.0, :]

        # Positional sample indices (0..n_samples-1) along dim 1 of the tensors. The sample index *keys* of `data`
        # are not used, as they need not be 0..n_samples-1 (consistent with `_split_data`).
        batch_ind_full = torch.arange(x_full.shape[1], dtype=torch.long, device=self.device)

        return (
            SyncTwinTensors(x_full, t_full, mask_full, batch_ind_full, y_full, y_control, y_mask_full),
            (n_treated_samples, n_control_samples),
        )

    def _split_data(
        self, synctwin_tensors: SyncTwinTensors, test_size: float
    ) -> Tuple[SyncTwinTensors, SyncTwinTensors]:
        """Split ``synctwin_tensors`` into (train, validation) along the sample dimension.

        The split is stratified by ``y_mask_full`` (control vs treated). ``y_control`` is re-derived from each
        split's ``y_full`` and ``batch_ind_full`` is regenerated as positional indices within each split.
        """

        def _to_sample_first_numpy(tensor: torch.Tensor):
            # (time, samples, features) -> (samples, time, features); `.cpu()` so GPU tensors work too.
            return tensor.permute(1, 0, 2).detach().cpu().numpy()

        def _to_time_first_tensor(array) -> torch.Tensor:
            # (samples, time, features) -> (time, samples, features) on the model's device / dtype.
            return torch.tensor(array, dtype=self.dtype, device=self.device).permute(1, 0, 2)

        def _build(x_full, t_full, mask_full, y_full, y_mask_full) -> SyncTwinTensors:
            # Assemble one split back into SyncTwinTensors.
            y_full_t = _to_time_first_tensor(y_full)
            y_mask_full_t = torch.tensor(y_mask_full, dtype=self.dtype, device=self.device)
            return SyncTwinTensors(
                x_full=_to_time_first_tensor(x_full),
                t_full=_to_time_first_tensor(t_full),
                mask_full=_to_time_first_tensor(mask_full),
                # Positional indices within this split.
                batch_ind_full=torch.arange(y_mask_full_t.shape[0], dtype=torch.long, device=self.device),
                y_full=y_full_t,
                y_control=y_full_t[:, y_mask_full_t == 1.0, :],
                y_mask_full=y_mask_full_t,
            )

        y_mask_full = synctwin_tensors.y_mask_full.detach().cpu().numpy()
        (
            x_full_train,
            x_full_val,
            t_full_train,
            t_full_val,
            mask_full_train,
            mask_full_val,
            y_full_train,
            y_full_val,
            y_mask_full_train,
            y_mask_full_val,
        ) = train_test_split(
            _to_sample_first_numpy(synctwin_tensors.x_full),
            _to_sample_first_numpy(synctwin_tensors.t_full),
            _to_sample_first_numpy(synctwin_tensors.mask_full),
            _to_sample_first_numpy(synctwin_tensors.y_full),
            y_mask_full,
            test_size=test_size,
            stratify=y_mask_full,
        )

        return (
            _build(x_full_train, t_full_train, mask_full_train, y_full_train, y_mask_full_train),
            _build(x_full_val, t_full_val, mask_full_val, y_full_val, y_mask_full_val),
        )

    def _make_synctwin_module(self, n_unit: int, n_treated: int) -> SyncTwinModule:
        """Build a ``SyncTwinModule`` around the current encoder / decoders using this model's hyperparameters."""
        return SyncTwinModule(
            n_unit=n_unit,
            n_treated=n_treated,
            device=self.device,
            dtype=self.dtype,
            reg_B=self._reg_B,
            lam_express=self._lambda_express,
            lam_recon=self.params.lambda_reconstruction,
            lam_prognostic=self.params.lambda_prognostic,
            tau=self.params.tau,
            encoder=self.encoder,
            decoder=self.decoder,
            decoder_Y=self.decoder_y,
        )

    def _move_submodules_to_device_in_train_mode(self) -> None:
        """Move all submodules to ``self.device`` / ``self.dtype``, then put them in training mode."""
        if TYPE_CHECKING:
            assert self.encoder is not None and self.decoder is not None
            assert self.decoder_y is not None
            assert self.synctwin is not None
        submodules = (self.encoder, self.decoder, self.decoder_y, self.synctwin)
        for module in submodules:
            module.to(self.device, dtype=self.dtype)
        for module in submodules:
            module.train()

    def _init_submodules(self) -> None:
        self.encoder = RegularEncoder(
            input_dim=self.inferred_params.encoder_input_size, hidden_dim=self.params.hidden_size, device=self.device
        )
        self.decoder = RegularDecoder(
            hidden_dim=self.encoder.hidden_dim,
            output_dim=self.encoder.input_dim,
            max_seq_len=self.inferred_params.pre_treat_len,
            device=self.device,
        )
        self.decoder_y = LinearDecoder(
            hidden_dim=self.encoder.hidden_dim,
            output_dim=self.inferred_params.decoder_y_output_size,
            max_seq_len=self.inferred_params.post_treat_len,
            device=self.device,
        )
        self.synctwin = self._make_synctwin_module(
            n_unit=self.inferred_params.synctwin_n_unit,
            n_treated=self.inferred_params.synctwin_n_treated,
        )

    def _init_inferred_params(self, data: Dataset, **kwargs) -> None:
        # All have already been set in _prep_data_for_fit().
        pass

    def _init_optimizers(self) -> None:
        # Handled elsewhere.
        pass

    def _prep_data_for_fit(self, data: Dataset, **kwargs):
        """Convert ``data`` to ``SyncTwinTensors``, optionally split it, and set the inferred params.

        Returns:
            ``(train_tensors, validation_tensors)``; both are the full data if no validation split is used.

        Raises:
            RuntimeError: If the train / validation split does not yield equally sized halves.
        """
        synctwin_tensors, _ = self._convert_data_to_synctwin_format(data)  # type: ignore
        if self.params.use_validation_set_in_training:
            synctwin_tensors_train, synctwin_tensors_val = self._split_data(
                synctwin_tensors, self._validation_set_frac
            )
        else:
            synctwin_tensors_train = synctwin_tensors
            synctwin_tensors_val = synctwin_tensors
        # The SyncTwin module is sized from the training split (see below), but the matching stage in `_fit` runs
        # on the validation split, so both must have the same number of samples.
        if synctwin_tensors_train.x_full.shape[1] != synctwin_tensors_val.x_full.shape[1]:
            raise RuntimeError("Was not possible to split data into test and validation set 50:50")

        self.inferred_params.encoder_input_size = synctwin_tensors_train.x_full.shape[-1]
        self.inferred_params.pre_treat_len = synctwin_tensors_train.x_full.shape[0]
        self.inferred_params.decoder_y_output_size = synctwin_tensors_train.y_full.shape[-1]
        self.inferred_params.post_treat_len = synctwin_tensors_train.y_full.shape[0]
        self.inferred_params.synctwin_n_unit = (synctwin_tensors_train.y_mask_full == 1.0).sum().item()
        self.inferred_params.synctwin_n_treated = (synctwin_tensors_train.y_mask_full == 0.0).sum().item()

        return synctwin_tensors_train, synctwin_tensors_val

    def _prep_submodules_for_fit(self) -> None:
        self._move_submodules_to_device_in_train_mode()

    def _prep_data_for_predict(self, data: Dataset, horizon: Optional[Horizon], **kwargs) -> Tuple[torch.Tensor, ...]:
        raise NotImplementedError

    def _prep_submodules_for_predict(self) -> None:
        raise NotImplementedError

    def _prep_data_for_predict_counterfactuals(
        self,
        data: Dataset,
        sample_index: T_SamplesIndexDtype,
        treatment_scenarios: TTreatmentScenarios,
        horizon: Optional[Horizon],
        **kwargs,
    ) -> Tuple[torch.Tensor, ...]:
        """Validate the requested scenario and convert ``data`` to ``SyncTwinTensors``.

        Also records the number of control / treated samples in ``data`` for the SyncTwin module that is rebuilt
        in ``_prep_submodules_for_predict_counterfactuals``.

        Raises:
            ValueError: If the scenario equals the factual treatment, or the sample is a control sample.
        """
        assert data.event_treatments is not None
        assert len(treatment_scenarios) == 1
        assert treatment_scenarios[0].df.shape == (1, 1)

        treatment_indicator_scenario = treatment_scenarios[0].df.values[0, 0]
        treatment_indicator_actual = data.event_treatments[sample_index].df.values[0, 0]
        if treatment_indicator_scenario == treatment_indicator_actual:
            raise ValueError(
                f"Factual treatment indicator ({treatment_indicator_scenario}) for "
                f"sample {sample_index} cannot be provided as a treatment scenario"
            )
        if treatment_indicator_actual == self.control_indicator:
            raise ValueError(
                "Currently can only predict counterfactuals for *treated* samples "
                "(i.e. can predict the untreated outcome for a sample that factually received treatment). "
                f"However sample {sample_index} was a control sample in the data.\n{treatment_scenarios[0]}"
            )

        synctwin_tensors, _ = self._convert_data_to_synctwin_format(data)  # type: ignore

        # The data must be shape-compatible with what the submodules were fitted on.
        assert self.inferred_params.encoder_input_size == synctwin_tensors.x_full.shape[-1]
        assert self.inferred_params.pre_treat_len == synctwin_tensors.x_full.shape[0]
        assert self.inferred_params.decoder_y_output_size == synctwin_tensors.y_full.shape[-1]
        assert self.inferred_params.post_treat_len == synctwin_tensors.y_full.shape[0]

        self._predict_synctwin_n_unit = (synctwin_tensors.y_mask_full == 1.0).sum().item()
        self._predict_synctwin_n_treated = (synctwin_tensors.y_mask_full == 0.0).sum().item()

        return synctwin_tensors

    def _prep_submodules_for_predict_counterfactuals(self) -> None:
        # NOTE: The SyncTwin module is rebuilt for the prediction data (its sample counts differ from fit time),
        # reusing the fitted encoder / decoders.
        assert self.encoder is not None and self.decoder is not None and self.decoder_y is not None
        assert self._predict_synctwin_n_unit is not None and self._predict_synctwin_n_treated is not None
        self.synctwin = self._make_synctwin_module(
            n_unit=self._predict_synctwin_n_unit,
            n_treated=self._predict_synctwin_n_treated,
        )
        self._move_submodules_to_device_in_train_mode()

    def _fit(self, data: Dataset, horizon: Optional[Horizon] = None, **kwargs) -> "SyncTwinRegressor":
        """Stage 1: pretrain on the training split. Stage 2: learn the matching on the validation split."""
        self.set_attributes_from_kwargs(**kwargs)

        synctwin_tensors_train, synctwin_tensors_val = self.prep_fit(data)
        assert isinstance(synctwin_tensors_train, SyncTwinTensors)
        assert isinstance(synctwin_tensors_val, SyncTwinTensors)

        print("=== Training Stage 1: Pretraining ===")
        synctwin_train_utils.pre_train_reconstruction_prognostic_loss(
            self.synctwin,
            x_full=synctwin_tensors_train.x_full,
            t_full=synctwin_tensors_train.t_full,
            mask_full=synctwin_tensors_train.mask_full,
            y_full=synctwin_tensors_train.y_full,
            y_mask_full=synctwin_tensors_train.y_mask_full,
            x_full_val=synctwin_tensors_val.x_full,
            t_full_val=synctwin_tensors_val.t_full,
            mask_full_val=synctwin_tensors_val.mask_full,
            y_full_val=synctwin_tensors_val.y_full,
            y_mask_full_val=synctwin_tensors_val.y_mask_full,
            niters=self.params.pretraining_iterations,
            batch_size=self.params.batch_size,
            test_freq=self._pretraining_test_freq,
        )

        print("=== Training Stage 2: Matching ===")
        synctwin_train_utils.update_representations(
            self.synctwin,
            x_full=synctwin_tensors_val.x_full,
            t_full=synctwin_tensors_val.t_full,
            mask_full=synctwin_tensors_val.mask_full,
            batch_ind_full=synctwin_tensors_val.batch_ind_full,
        )
        synctwin_train_utils.train_B_self_expressive(
            self.synctwin,
            x_full=synctwin_tensors_val.x_full,
            t_full=synctwin_tensors_val.t_full,
            mask_full=synctwin_tensors_val.mask_full,
            batch_ind_full=synctwin_tensors_val.batch_ind_full,
            niters=self.params.matching_iterations,
            batch_size=None,  # NOTE: Batched training not implemented.
            lr=self._self_expressive_lr,
            test_freq=self._matching_test_freq,
        )

        return self

    def _predict(self, data: Dataset, horizon: Optional[Horizon], **kwargs) -> TPredictOutput:
        raise NotImplementedError(
            "predict() method of SyncTwin is not implemented. To get counterfactual "
            "predictions, call predict_counterfactuals()"
        )

    def get_possible_prediction_horizon(self, sample_index: T_SamplesIndexDtype, data: Dataset):
        """Return the sample's target time index from its treatment event time onwards, as a horizon."""
        self._convert_data_to_synctwin_format(data, check_only=True)
        # If the above validates fine, get the horizon:
        assert data.temporal_targets is not None
        assert data.event_treatments is not None
        sample_targets_timeseries = data.temporal_targets[sample_index]
        sample_event_time = data.event_treatments.df.loc[[sample_index], :, :].index.get_level_values(1)[0]  # type: ignore
        future_horizon = sample_targets_timeseries.df.loc[sample_event_time:, :].index
        return TimeIndexHorizon(time_index_sequence=[future_horizon])

    def get_possible_treatment_scenarios(self, sample_index: T_SamplesIndexDtype, data: Dataset):
        """Return a single scenario: a copy of the sample's event with the treatment indicator flipped."""
        self._convert_data_to_synctwin_format(data, check_only=True)
        assert data.event_treatments is not None
        sample_event = data.event_treatments[sample_index]
        sample_event_treatment_indicator = sample_event.df.values[0, 0]
        new_indicator = self._get_other_treatment_status(sample_event_treatment_indicator)
        sample_event = sample_event.copy()
        sample_event.df[:] = new_indicator
        return (sample_event,)

    def _predict_counterfactuals(
        self,
        data: Dataset,
        sample_index: T_SamplesIndexDtype,
        treatment_scenarios: TTreatmentScenarios,
        horizon: Optional[Horizon],
        **kwargs,
    ) -> TCounterfactualPredictions:
        """Re-run matching on ``data`` and return the predicted untreated outcome of treated ``sample_index``."""
        self.set_attributes_from_kwargs(**kwargs)
        assert data.temporal_targets is not None
        assert isinstance(horizon, TimeIndexHorizon)
        assert self.synctwin is not None

        synctwin_tensors = self.prep_predict_counterfactuals(data, sample_index, treatment_scenarios, horizon)
        assert isinstance(synctwin_tensors, SyncTwinTensors)

        print("=== Running Inference Stage 1: Matching ===")
        synctwin_train_utils.update_representations(
            self.synctwin,
            x_full=synctwin_tensors.x_full,
            t_full=synctwin_tensors.t_full,
            mask_full=synctwin_tensors.mask_full,
            batch_ind_full=synctwin_tensors.batch_ind_full,
        )
        synctwin_train_utils.train_B_self_expressive(
            self.synctwin,
            x_full=synctwin_tensors.x_full,
            t_full=synctwin_tensors.t_full,
            mask_full=synctwin_tensors.mask_full,
            batch_ind_full=synctwin_tensors.batch_ind_full,
            niters=self.params.inference_iterations,
            batch_size=None,  # NOTE: Batched training not implemented.
            lr=self._self_expressive_lr,
            test_freq=self._inference_iterations_test_freq,
        )
        synctwin_train_utils.update_representations(
            self.synctwin,
            x_full=synctwin_tensors.x_full,
            t_full=synctwin_tensors.t_full,
            mask_full=synctwin_tensors.mask_full,
            batch_ind_full=synctwin_tensors.batch_ind_full,
        )
        self.synctwin.eval()

        print("=== Running Inference Stage 2: Computing Counterfactuals ===")
        y_hat = synctwin_train_utils.get_prediction(
            self.synctwin,
            batch_ind_full=synctwin_tensors.batch_ind_full,
            y_control=synctwin_tensors.y_control,
            itr=self._prediction_compute_iters,
        )
        print("Done")

        # `batch_ind_full` is positional, so dim 1 of `y_hat` follows the order of `data.sample_indices`: map the
        # sample index *key* to its position rather than indexing with the key itself.
        # NOTE(review): assumes the tensors built from `data` order samples as `data.sample_indices` - confirm.
        sample_position = list(data.sample_indices).index(sample_index)
        y_hat_sample = y_hat[:, sample_position, :].detach().cpu().numpy()

        data_historic_temporal_targets = data.temporal_targets[sample_index]
        if TYPE_CHECKING:
            assert isinstance(data_historic_temporal_targets, TimeSeries)
        list_ts = to_counterfactual_predictions([y_hat_sample], data_historic_temporal_targets, horizon)
        return list_ts