Source code for tempor.models.clairvoyance2.treatment_effects.crn

# mypy: ignore-errors

from typing import TYPE_CHECKING, Any, List, Mapping, NamedTuple, Optional, Sequence, Tuple, cast

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ..components.torch.common import OPTIM_MAP
from ..components.torch.ffnn import FeedForwardNet
from ..components.torch.gradient_reversal import GradientReversalModule
from ..components.torch.interfaces import OrganizedTreatmentEffectsModuleMixin
from ..components.torch.rnn import RecurrentFFNet, apply_to_each_timestep, mask_and_reshape
from ..data import DEFAULT_PADDING_INDICATOR, Dataset, TimeSeries, TimeSeriesSamples
from ..data.constants import T_SamplesIndexDtype
from ..data.utils import split_time_series, time_index_utils, to_counterfactual_predictions
from ..interface import (
    Horizon,
    TCounterfactualPredictions,
    TDefaultParams,
    TimeIndexHorizon,
    TParams,
    TreatmentEffectsModel,
    TTreatmentScenarios,
)
from ..interface import requirements as r
from ..prediction.seq2seq import Seq2SeqCRNStylePredictorBase
from ..utils import tensor_like as tl
from ..utils.array_manipulation import n_step_shift_back, n_step_shift_forward
from ..utils.dev import NEEDED

_DEBUG = False


# TODO: For clarity, get rid of TEncodedRepresentation and always use RNNHidden?
TEncodedRepresentation = Tuple[torch.Tensor, Optional[torch.Tensor]]


class RecurrentFFNet_ConcatTreatment(RecurrentFFNet):
    """``RecurrentFFNet`` whose RNN output is extended with a treatment tensor.

    The treatment tensor is supplied by callers through the ``concat_treatment``
    keyword argument and is appended to the RNN output along the last dimension.
    """

    def rnn_out_postprocess(self, rnn_out: torch.Tensor, **kwargs) -> torch.Tensor:
        """Append ``kwargs["concat_treatment"]`` to ``rnn_out`` on the last dimension.

        Raises:
            KeyError: If ``concat_treatment`` was not passed as a keyword argument.
        """
        treatment = kwargs["concat_treatment"]
        return torch.cat((rnn_out, treatment), dim=-1)

    def _forward_for_autoregress(self, x: torch.Tensor, timestep_idx: int, **kwargs) -> torch.Tensor:
        """Run ``forward`` for one autoregressive step and return its first output.

        Only the slice of ``concat_treatment`` at ``timestep_idx`` (dimension 1) is
        forwarded; indexing with a one-element list keeps that dimension in place.
        """
        full_treatment = kwargs.pop("concat_treatment")
        step_treatment = full_treatment[:, [timestep_idx], :]
        # ``forward`` returns several values; only the first one is needed here.
        step_out, *_ = self.forward(x, concat_treatment=step_treatment, **kwargs)
        return step_out
class TreatBalancerNet(FeedForwardNet):
    """Feed-forward treatment balancer with gradient reversal on its input.

    The input goes through a ``GradientReversalModule`` (``alpha=1.0``), then the
    parent ``FeedForwardNet``, and finally a softmax over the last dimension.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dims: Sequence[int] = (),
        out_activation: Optional[str] = "ReLU",
        hidden_activations: Optional[str] = "ReLU",
    ) -> None:
        """Build the parent network, then attach the softmax and gradient-reversal modules."""
        super().__init__(in_dim, out_dim, hidden_dims, out_activation, hidden_activations)
        self.softmax = nn.Softmax(dim=-1)
        self.revgrad = GradientReversalModule(alpha=1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax(FeedForwardNet(gradient_reversal(x)))."""
        reversed_x = self.revgrad(x)
        ff_out = super().forward(reversed_x)
        return self.softmax(ff_out)
class _DefaultParams(NamedTuple):
    """Default hyperparameters shared by the CRN regressor and classifier."""

    # --- Encoder RNN ---
    encoder_rnn_type: str = "LSTM"
    encoder_hidden_size: int = 100
    encoder_num_layers: int = 1
    encoder_bias: bool = True
    encoder_dropout: float = 0.0
    encoder_bidirectional: bool = False
    encoder_nonlinearity: Optional[str] = None
    encoder_proj_size: Optional[int] = None

    # --- Decoder RNN ---
    decoder_rnn_type: str = "LSTM"
    decoder_hidden_size: int = 100
    decoder_num_layers: int = 1
    decoder_bias: bool = True
    decoder_dropout: float = 0.0
    decoder_bidirectional: bool = False
    decoder_nonlinearity: Optional[str] = None
    decoder_proj_size: Optional[int] = None

    # --- Adapter feed-forward network ---
    adapter_hidden_dims: Sequence[int] = [50]
    adapter_out_activation: Optional[str] = "Tanh"

    # --- Predictor feed-forward network ---
    predictor_hidden_dims: Sequence[int] = []
    predictor_out_activation: Optional[str] = None

    # --- Treatment balancer feed-forward network ---
    treat_net_hidden_dims: Sequence[int] = []
    treat_net_out_activation: Optional[str] = None

    # --- Miscellaneous / training ---
    max_len: Optional[int] = None
    optimizer_str: str = "Adam"  # Key into OPTIM_MAP.
    optimizer_kwargs: Mapping[str, Any] = {"lr": 0.01, "weight_decay": 1e-5}
    batch_size: int = 32
    epochs: int = 100
    padding_indicator: float = DEFAULT_PADDING_INDICATOR


# TODO: Test this with various sets of params and make sure it doesn't fail.
class CRNTreatmentEffectsModelBase(
    TreatmentEffectsModel, Seq2SeqCRNStylePredictorBase, OrganizedTreatmentEffectsModuleMixin
):
    """Shared implementation of the CRN-style treatment effects models.

    The model is trained in two stages (see ``_fit``): first an encoder, then a
    decoder (plus adapter) that is initialised from the encoder's hidden state.
    Each stage also trains a ``TreatBalancerNet`` on the RNN output; its loss is
    added to the prediction loss, weighted by ``_compute_lambda``.

    Subclasses provide ``requirements``, ``DEFAULT_PARAMS``, the prediction loss
    function and ``process_output_for_loss``.
    """

    requirements: r.Requirements
    DEFAULT_PARAMS: TDefaultParams

    def __init__(self, loss_fn: nn.Module, params: Optional[TParams] = None) -> None:
        """Initialise both parent classes and the treatment balancer placeholders.

        Args:
            loss_fn: Loss module used for the target prediction loss.
            params: Optional parameter overrides passed to both parent initialisers.
        """
        TreatmentEffectsModel.__init__(self, params)
        Seq2SeqCRNStylePredictorBase.__init__(self, loss_fn=loss_fn, params=params)

        # Treatment balancer.
        self.encoder_treat_net: Optional[FeedForwardNet] = NEEDED
        self.decoder_treat_net: Optional[FeedForwardNet] = NEEDED
        self.treat_loss = nn.CrossEntropyLoss()

    def _init_submodules_encoder_decoder(self) -> None:
        """Create the encoder and decoder ``RecurrentFFNet_ConcatTreatment`` networks."""
        # Initialize Encoder models:
        self.encoder = RecurrentFFNet_ConcatTreatment(
            rnn_type=self.params.encoder_rnn_type,
            input_size=self.inferred_params.encoder_input_size,
            hidden_size=self.params.encoder_hidden_size,
            nonlinearity=self.params.encoder_nonlinearity,
            num_layers=self.params.encoder_num_layers,
            bias=self.params.encoder_bias,
            dropout=self.params.encoder_dropout,
            bidirectional=self.params.encoder_bidirectional,
            proj_size=self.params.encoder_proj_size,
            ff_out_size=self.inferred_params.encoder_predictor_output_size,
            ff_in_size_adjust=self.inferred_params.predictor_input_size_adjust,  # Note this.
            ff_hidden_dims=self.params.predictor_hidden_dims,
            ff_out_activation=self.params.predictor_out_activation,
            ff_hidden_activations="ReLU",
        )

        # Initialize Decoder models:
        self.decoder = RecurrentFFNet_ConcatTreatment(
            rnn_type=self.params.decoder_rnn_type,
            input_size=self.inferred_params.decoder_input_size,
            hidden_size=self.params.decoder_hidden_size,
            nonlinearity=self.params.decoder_nonlinearity,
            num_layers=self.params.decoder_num_layers,
            bias=self.params.decoder_bias,
            dropout=self.params.decoder_dropout,
            bidirectional=self.params.decoder_bidirectional,
            proj_size=self.params.decoder_proj_size,
            ff_out_size=self.inferred_params.decoder_predictor_output_size,
            ff_in_size_adjust=self.inferred_params.predictor_input_size_adjust,  # Note this.
            ff_hidden_dims=self.params.predictor_hidden_dims,
            ff_out_activation=self.params.predictor_out_activation,
            ff_hidden_activations="ReLU",
        )

    def _init_submodules_treat_net(self) -> None:
        """Create the encoder and decoder treatment balancers.

        Their input sizes are taken from the first value returned by each RNN's
        ``get_output_and_h_dim()`` and recorded in ``inferred_params``.
        """
        # Initialize Treatment Balancers:
        if TYPE_CHECKING:
            assert self.encoder is not None and self.decoder is not None
        encoder_out_dim, *_ = self.encoder.rnn.get_output_and_h_dim()
        decoder_out_dim, *_ = self.decoder.rnn.get_output_and_h_dim()
        self.inferred_params.encoder_treat_net_input_size = encoder_out_dim
        self.inferred_params.decoder_treat_net_input_size = decoder_out_dim
        self.encoder_treat_net = TreatBalancerNet(
            in_dim=self.inferred_params.encoder_treat_net_input_size,
            out_dim=self.inferred_params.treat_net_output_size,
            hidden_dims=self.params.treat_net_hidden_dims,
            out_activation=self.params.treat_net_out_activation,
            hidden_activations="ReLU",
        )
        self.decoder_treat_net = TreatBalancerNet(
            in_dim=self.inferred_params.decoder_treat_net_input_size,
            out_dim=self.inferred_params.treat_net_output_size,
            hidden_dims=self.params.treat_net_hidden_dims,
            out_activation=self.params.treat_net_out_activation,
            hidden_activations="ReLU",
        )

    def _init_submodules(self) -> None:
        """Create encoder/decoder, the adapter (from the parent class), then the balancers."""
        self._init_submodules_encoder_decoder()
        super()._init_submodules_adapter()
        # Must come after the encoder/decoder exist: their RNN output dims are read here.
        self._init_submodules_treat_net()

    def _init_inferred_params(self, data: Dataset, **kwargs) -> None:
        """Derive network input/output sizes and the encoder batch size from ``data``.

        Requires temporal covariates, temporal targets and temporal treatments.
        """
        assert data.temporal_covariates is not None
        assert data.temporal_targets is not None
        assert data.temporal_treatments is not None

        # Initialize the helper attributes.
        # + 1 below are for time deltas.
        self.inferred_params.encoder_input_size = (
            data.temporal_covariates.n_features + data.temporal_treatments.n_features + 1
        )
        if data.static_covariates is not None:
            self.inferred_params.encoder_input_size += data.static_covariates.n_features
        self.inferred_params.decoder_input_size = (
            data.temporal_targets.n_features + data.temporal_treatments.n_features + 1
        )
        if data.static_covariates is not None:
            self.inferred_params.decoder_input_size += data.static_covariates.n_features
        # The encoder input additionally includes the target features.
        self.inferred_params.encoder_input_size += data.temporal_targets.n_features
        self.inferred_params.encoder_predictor_output_size = data.temporal_targets.n_features
        self.inferred_params.decoder_predictor_output_size = data.temporal_targets.n_features
        self.inferred_params.treat_net_output_size = data.temporal_treatments.n_features
        # The predictor FF net input is widened by the treatment tensor concatenated
        # in ``RecurrentFFNet_ConcatTreatment.rnn_out_postprocess``.
        self.inferred_params.predictor_input_size_adjust = data.temporal_treatments.n_features

        # Inferred batch size:
        self.inferred_params.encoder_batch_size = min(self.params.batch_size, data.n_samples)
        self.inferred_params.decoder_batch_size = NEEDED  # This is set later.

    def _init_optimizers(self):
        """Create one optimizer for the encoder stage and one for the decoder stage.

        The encoder optimizer covers the encoder and its treatment balancer; the
        decoder optimizer covers the adapter, the decoder and its treatment balancer.
        """
        self.encoder = cast(RecurrentFFNet, self.encoder)
        self.encoder_treat_net = cast(FeedForwardNet, self.encoder_treat_net)
        self.adapter = cast(FeedForwardNet, self.adapter)
        self.decoder = cast(RecurrentFFNet, self.decoder)
        self.decoder_treat_net = cast(FeedForwardNet, self.decoder_treat_net)

        # Initialize optimizers.
        self.encoder_optim = OPTIM_MAP[self.params.optimizer_str](
            params=[*self.encoder.parameters(), *self.encoder_treat_net.parameters()],
            **self.params.optimizer_kwargs,
        )
        self.decoder_optim = OPTIM_MAP[self.params.optimizer_str](
            params=[*self.adapter.parameters(), *self.decoder.parameters(), *self.decoder_treat_net.parameters()],
            **self.params.optimizer_kwargs,
        )

    def _prep_treat_tensors(self, data: Dataset, t_cov: torch.Tensor, t_targ: torch.Tensor):
        """Align covariate, target and treatment tensors and append treatments to the covariates.

        Returns:
            ``(t_cov, t_targ, t_treat_out)`` where ``t_cov`` has the (one step earlier)
            treatments concatenated on the last dimension. See the NOTE below for the
            resulting time alignment.
        """
        if TYPE_CHECKING:
            assert data.temporal_treatments is not None
        t_treat = data.temporal_treatments.to_torch_tensor(
            padding_indicator=self.params.padding_indicator,
            max_len=self.params.max_len,
            dtype=self.dtype,
            device=self.device,
        )

        t_cov = n_step_shift_back(t_cov, n_step=1)
        t_targ = n_step_shift_back(t_targ, n_step=1)
        # Both treatment tensors are truncated to the (shifted) covariate length.
        t_treat_out = n_step_shift_back(t_treat, n_step=1)[:, : t_cov.shape[1], :]
        t_treat = n_step_shift_forward(t_treat, n_step=1)[:, : t_cov.shape[1], :]
        t_cov = torch.cat([t_cov, t_treat], dim=-1)
        # NOTE:
        # If time indexes originally are:
        #   t_targ like [1, 2, 3, 4]
        #   t_cov like [0, 1, 2, 3]
        #   t_treat like [0, 1, 2, 3, (...)]
        #
        # Then at the end:
        #   t_targ like [2, 3, 4]
        #   t_treat_out like [1, 2, 3]
        #   New t_cov combines:
        #       original t_cov [1, 2, 3]
        #       original t_treat [0, 1, 2]

        return t_cov, t_targ, t_treat_out

    def _prep_torch_tensors_encoder(self, data: Dataset, shift_targ_cov: bool):
        """Build encoder tensors via the parent class, then add the treatment alignment."""
        t_cov, t_targ = super()._prep_torch_tensors_encoder(data, shift_targ_cov)
        t_cov, t_targ, t_treat_out = self._prep_treat_tensors(data, t_cov=t_cov, t_targ=t_targ)
        return t_cov, t_targ, t_treat_out

    def _prep_torch_tensors_decoder(self, data: Dataset):
        """Build decoder tensors via the parent class, then add the treatment alignment.

        Returns:
            ``(t_targ, decoder_input, t_treat_out)`` -- note the order differs from
            ``_prep_torch_tensors_encoder``.
        """
        t_targ, decoder_input = super()._prep_torch_tensors_decoder(data)
        decoder_input, t_targ, t_treat_out = self._prep_treat_tensors(data, t_cov=decoder_input, t_targ=t_targ)
        return t_targ, decoder_input, t_treat_out

    def _prep_data_for_fit(self, data: Dataset, **kwargs) -> Tuple[torch.Tensor, ...]:
        """Prepare all tensors needed by both training stages.

        Expects ``min_pre_len``, ``min_post_len`` and ``repeat_last_pre_step`` in
        ``kwargs`` (a ``KeyError`` is raised if any is missing).

        Returns:
            ``(encoder_t_cov, encoder_t_targ, encoder_t_treat_out, t_cov_to_encode,
            t_treat_to_encode, decoder_t_targ, decoder_input, decoder_t_treat_out)``.
        """
        min_pre_len = kwargs.pop("min_pre_len")
        min_post_len = kwargs.pop("min_post_len")
        repeat_last_pre_step = kwargs.pop("repeat_last_pre_step")

        encoder_tensors = self._prep_torch_tensors_encoder(data, shift_targ_cov=True)

        print("Preparing data for decoder training...")
        data_pre, data_post, _ = split_time_series.split_at_each_step(
            data, min_pre_len=min_pre_len, min_post_len=min_post_len, repeat_last_pre_step=repeat_last_pre_step
        )
        # The decoder batch size can only be known once the split has been made.
        self.inferred_params.decoder_batch_size = min(self.params.batch_size, data_post.n_samples)
        print("Preparing data for decoder training DONE.")

        t_cov_to_encode, _, t_treat_to_encode = self._prep_torch_tensors_encoder(data_pre, shift_targ_cov=False)
        decoder_tensors = self._prep_torch_tensors_decoder(data_post)

        return (*encoder_tensors, t_cov_to_encode, t_treat_to_encode, *decoder_tensors)

    def _prep_torch_tensors_decoder_inference(
        self,
        data: Dataset,
        horizon: TimeIndexHorizon,
    ):
        """Build the decoder input for prediction using the treatments found in ``data``.

        Treatments are taken from one step before the horizon start onwards; all but
        the last step are concatenated to the decoder input, and all but the first
        step are returned as ``t_treat_out``.
        """
        decoder_input = super()._prep_torch_tensors_decoder_inference(data, horizon)

        if TYPE_CHECKING:
            assert data.temporal_treatments is not None
        ts_treat = time_index_utils.time_series_samples.take_all_from_one_before_start(
            time_series_samples_=data.temporal_treatments, time_indexes=horizon, inplace=False
        )
        if TYPE_CHECKING:
            assert ts_treat is not None
        # One extra step so that the [:-1] / [1:] slices both match the decoder input length.
        t_treat = ts_treat.to_torch_tensor(
            padding_indicator=self.params.padding_indicator,
            max_len=decoder_input.shape[1] + 1,
            dtype=self.dtype,
            device=self.device,
        )

        decoder_input = torch.cat([decoder_input, t_treat[:, :-1, :]], dim=-1)
        t_treat_out = t_treat[:, 1:, :]

        return decoder_input, t_treat_out

    def _prep_torch_tensors_decoder_inference_counterfactuals(
        self,
        data: Dataset,
        treatment_scenario: TimeSeries,
        horizon: TimeIndexHorizon,
    ):
        """Build the decoder input for one counterfactual ``treatment_scenario``.

        The last observed treatment step before the horizon is prepended to the
        scenario; the combined tensor is then sliced the same way as in
        ``_prep_torch_tensors_decoder_inference``.
        """
        decoder_input = super()._prep_torch_tensors_decoder_inference(data, horizon)

        if TYPE_CHECKING:
            assert data.temporal_treatments is not None
        t_treat_last = time_index_utils.time_series_samples.take_one_before_start(data.temporal_treatments, horizon)
        if TYPE_CHECKING:
            assert t_treat_last is not None
        t_treat_last = t_treat_last.to_torch_tensor(
            padding_indicator=self.params.padding_indicator,
            max_len=1,
            dtype=self.dtype,
            device=self.device,
        )
        # The scenario is a single time series; add a leading dimension so it can be
        # concatenated with ``t_treat_last`` along dimension 1.
        t_treat = treatment_scenario.to_torch_tensor(
            padding_indicator=self.params.padding_indicator,
            max_len=decoder_input.shape[1],
            dtype=self.dtype,
            device=self.device,
        ).unsqueeze(dim=0)
        t_treat = torch.cat([t_treat_last, t_treat], dim=1)

        decoder_input = torch.cat([decoder_input, t_treat[:, :-1, :]], dim=-1)
        t_treat_out = t_treat[:, 1:, :]

        return decoder_input, t_treat_out

    def _prep_data_for_predict(self, data: Dataset, horizon: Optional[Horizon], **kwargs) -> Tuple[torch.Tensor, ...]:
        """Encode the pre-horizon data and build the decoder input for prediction.

        Returns:
            ``(h, c, decoder_input, decoder_t_treat_out)``.
        """
        assert data.temporal_covariates is not None
        assert data.temporal_targets is not None
        assert isinstance(horizon, TimeIndexHorizon)

        # Make sure to not use "future" values for prediction.
        data_encode = time_index_utils.dataset.take_temporal_data_before_start(data, horizon, inplace=False)
        if TYPE_CHECKING:
            assert data_encode is not None

        t_cov_to_encode, _, t_treat_out = self._prep_torch_tensors_encoder(data_encode, shift_targ_cov=False)
        encoded_representations = self._get_encoder_representation(t_cov_to_encode, t_treat_out=t_treat_out)
        h, c = super()._reshape_h_sample_dim_0(encoded_representations)

        decoder_input, decoder_t_treat_out = self._prep_torch_tensors_decoder_inference(data, horizon)

        assert h.shape[0] == decoder_input.shape[0]

        return h, c, decoder_input, decoder_t_treat_out

    def _prep_data_for_predict_counterfactuals(
        self,
        data: Dataset,
        sample_index: T_SamplesIndexDtype,
        treatment_scenarios: TTreatmentScenarios,
        horizon: Optional[Horizon],
        **kwargs,
    ) -> Tuple[Any, ...]:
        """Encode the pre-horizon data and build one decoder input per treatment scenario.

        Returns:
            ``(h, c, decoder_input_list, decoder_t_treat_out_list)`` where both lists
            follow the order of ``treatment_scenarios``.
        """
        assert data.temporal_covariates is not None
        assert data.temporal_targets is not None
        assert isinstance(horizon, TimeIndexHorizon)

        # Make sure to not use "future" values for prediction.
        data_encode = time_index_utils.dataset.take_temporal_data_before_start(data, horizon, inplace=False)
        if TYPE_CHECKING:
            assert data_encode is not None

        t_cov_to_encode, _, t_treat_out = self._prep_torch_tensors_encoder(data_encode, shift_targ_cov=False)
        encoded_representations = self._get_encoder_representation(t_cov_to_encode, t_treat_out=t_treat_out)
        h, c = super()._reshape_h_sample_dim_0(encoded_representations)

        decoder_input_list = []
        decoder_t_treat_out_list = []
        for treatment_scenario in treatment_scenarios:
            assert isinstance(treatment_scenario, TimeSeries)
            decoder_input, decoder_t_treat_out = self._prep_torch_tensors_decoder_inference_counterfactuals(
                data, treatment_scenario, horizon
            )
            assert h.shape[0] == decoder_input.shape[0]
            decoder_input_list.append(decoder_input)
            decoder_t_treat_out_list.append(decoder_t_treat_out)

        return h, c, decoder_input_list, decoder_t_treat_out_list

    def _prep_submodules_for_fit(self) -> None:
        """Move the treatment balancers to the model device/dtype and set them to train mode."""
        assert self.encoder_treat_net is not None and self.decoder_treat_net is not None
        super()._prep_submodules_for_fit()
        self.encoder_treat_net.to(self.device, dtype=self.dtype)
        self.decoder_treat_net.to(self.device, dtype=self.dtype)
        self.encoder_treat_net.train()
        self.decoder_treat_net.train()

    def _prep_submodules_for_predict(self) -> None:
        """Move the treatment balancers to the model device/dtype and set them to eval mode."""
        assert self.encoder_treat_net is not None and self.decoder_treat_net is not None
        super()._prep_submodules_for_predict()
        self.encoder_treat_net.to(self.device, dtype=self.dtype)
        self.decoder_treat_net.to(self.device, dtype=self.dtype)
        self.encoder_treat_net.eval()
        self.decoder_treat_net.eval()

    def _compute_lambda(self, epoch_idx: int) -> torch.Tensor:
        """Return the weight of the treatment balancing loss for ``epoch_idx``.

        Computes ``2 / (1 + exp(-10 * (epoch_idx + 1))) - 1`` as a 0-dim tensor.

        NOTE(review): the exponent uses the raw epoch number rather than a
        normalised training-progress fraction, so the value is already ~0.9999 at
        ``epoch_idx == 0`` -- confirm this schedule is intended.
        """
        return 2.0 / (1.0 + torch.exp(-10.0 * torch.tensor(epoch_idx + 1))) - 1.0

    def _masked_target_and_treat_losses(
        self,
        out: torch.Tensor,
        out_treat_net: torch.Tensor,
        t_targ: torch.Tensor,
        t_treat_out: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Drop padded entries, then compute the prediction loss and the treatment balancer loss.

        Shared by ``_train_encoder`` and ``_train_decoder``.

        Returns:
            ``(loss_target, loss_treat)``.
        """
        padding_indicator = self.params.padding_indicator

        # Prediction branch: mask by where the targets are not padding.
        not_padding_targ = ~tl.eq_indicator(t_targ, padding_indicator)
        if TYPE_CHECKING:
            assert isinstance(not_padding_targ, torch.BoolTensor)
        out = mask_and_reshape(mask_selector=not_padding_targ, tensor=out)
        t_targ = mask_and_reshape(mask_selector=not_padding_targ, tensor=t_targ)

        # Treatment branch: mask by where the output treatments are not padding.
        not_padding_treat = ~tl.eq_indicator(t_treat_out, padding_indicator)
        out_treat_net = mask_and_reshape(mask_selector=not_padding_treat, tensor=out_treat_net)
        t_treat_out = mask_and_reshape(mask_selector=not_padding_treat, tensor=t_treat_out)

        out = self.process_output_for_loss(out)
        loss_target = self.loss_fn(out, t_targ)
        loss_treat = self.treat_loss(out_treat_net, t_treat_out)
        return loss_target, loss_treat

    def _train_encoder(self, encoder_tensors: Tuple) -> None:
        """Training stage 1: fit the encoder together with its treatment balancer.

        Args:
            encoder_tensors: ``(t_cov, t_targ, t_treat_out)`` as produced by
                ``_prep_torch_tensors_encoder``.
        """
        if TYPE_CHECKING:
            assert self.encoder is not None
            assert self.encoder_optim is not None
            assert self.encoder_treat_net is not None

        dataloader = DataLoader(
            TensorDataset(*encoder_tensors), batch_size=self.inferred_params.encoder_batch_size, shuffle=True
        )

        for epoch_idx in range(self.params.epochs):
            n_samples_cumul = 0
            epoch_loss = 0.0
            epoch_loss_target = 0.0
            epoch_loss_treat = 0.0
            # Depends only on the epoch, so compute it once rather than per batch.
            lambda_ = self._compute_lambda(epoch_idx)
            for t_cov, t_targ, t_treat_out in dataloader:
                current_batch_size = t_cov.shape[0]
                n_samples_cumul += current_batch_size

                out, rnn_out, _ = self.encoder(
                    t_cov, h=None, padding_indicator=self.params.padding_indicator, concat_treatment=t_treat_out
                )
                out_treat_net = apply_to_each_timestep(
                    self.encoder_treat_net,
                    input_tensor=rnn_out,
                    output_size=self.inferred_params.treat_net_output_size,
                    concat_tensors=[],
                    padding_indicator=self.params.padding_indicator,
                    expected_module_input_size=self.inferred_params.encoder_treat_net_input_size,
                )

                loss_target, loss_treat = self._masked_target_and_treat_losses(
                    out=out, out_treat_net=out_treat_net, t_targ=t_targ, t_treat_out=t_treat_out
                )
                loss = loss_target + lambda_ * loss_treat

                # Optimization:
                self.encoder_optim.zero_grad()
                loss.backward()
                self.encoder_optim.step()

                # Weight by batch size so the epoch figures are per-sample averages.
                epoch_loss_target += loss_target.item() * current_batch_size
                epoch_loss_treat += loss_treat.item() * current_batch_size
                epoch_loss += loss.item() * current_batch_size

            epoch_loss_target /= n_samples_cumul
            epoch_loss_treat /= n_samples_cumul
            epoch_loss /= n_samples_cumul
            print(
                f"Epoch: {epoch_idx}, Prediction Loss: {epoch_loss_target:.3f}, "
                f"Lambda: {lambda_:.3f}, Treatment BR Loss: {epoch_loss_treat:.3f}, Loss: {epoch_loss:.3f}"
            )

    def _get_encoder_representation(self, t_cov: torch.Tensor, t_treat_out=NEEDED, **kwargs) -> TEncodedRepresentation:
        """Run the encoder without gradients and return its final hidden state.

        Returns:
            ``(h, c)``; ``c`` is ``None`` unless ``encoder_rnn_type`` is ``"LSTM"``.
        """
        # 2. Get the encoded representations.

        if TYPE_CHECKING:
            assert self.encoder is not None
            assert isinstance(t_treat_out, torch.Tensor)

        # Not sure this is needed here, but just in case:
        self.encoder.eval()

        is_lstm = self.params.encoder_rnn_type == "LSTM"

        with torch.no_grad():
            _, _, h = self.encoder(
                t_cov, h=None, padding_indicator=self.params.padding_indicator, concat_treatment=t_treat_out
            )
            h, c = h if is_lstm else (h, None)

        return h, c

    def _train_decoder(self, encoded_representations: TEncodedRepresentation, decoder_tensors: Tuple) -> None:
        """Training stage 2: fit the adapter, decoder and decoder treatment balancer.

        Args:
            encoded_representations: ``(h, c)`` from ``_get_encoder_representation``.
            decoder_tensors: ``(t_targ, decoder_input, t_treat_out)`` as produced by
                ``_prep_torch_tensors_decoder``.
        """
        if TYPE_CHECKING:
            assert self.encoder is not None and self.decoder is not None
            assert self.decoder_optim is not None
            assert self.decoder_treat_net is not None
            assert self.adapter is not None

        (t_targ, decoder_input, t_treat_out) = decoder_tensors

        h, c = super()._reshape_h_sample_dim_0(encoded_representations)
        assert h.shape[0] == decoder_input.shape[0]

        # NOTE(review): ``c`` is None for non-LSTM encoders (see
        # ``_get_encoder_representation``); unless ``_reshape_h_sample_dim_0`` replaces
        # it with a tensor, ``TensorDataset`` cannot accept it -- confirm.
        dataloader = DataLoader(
            TensorDataset(h, c, t_targ, decoder_input, t_treat_out),
            batch_size=self.inferred_params.decoder_batch_size,
            shuffle=True,
        )

        for epoch_idx in range(self.params.epochs):
            n_samples_cumul = 0
            epoch_loss = 0.0
            epoch_loss_target = 0.0
            epoch_loss_treat = 0.0
            # Depends only on the epoch, so compute it once rather than per batch.
            lambda_ = self._compute_lambda(epoch_idx)
            for h_batch, c_batch, t_targ_batch, decoder_input_batch, t_treat_out_batch in dataloader:
                current_batch_size = t_targ_batch.shape[0]
                n_samples_cumul += current_batch_size

                # Pass encoded representations through the adapter.
                h_adapter_out = self._pass_h_through_adapter(h_batch, c_batch)

                out, rnn_out, _ = self.decoder(
                    decoder_input_batch,
                    h=h_adapter_out,
                    padding_indicator=self.params.padding_indicator,
                    concat_treatment=t_treat_out_batch,
                )
                out_treat_net = apply_to_each_timestep(
                    self.decoder_treat_net,
                    input_tensor=rnn_out,
                    output_size=self.inferred_params.treat_net_output_size,
                    concat_tensors=[],
                    padding_indicator=self.params.padding_indicator,
                    expected_module_input_size=self.inferred_params.decoder_treat_net_input_size,
                )

                loss_target, loss_treat = self._masked_target_and_treat_losses(
                    out=out, out_treat_net=out_treat_net, t_targ=t_targ_batch, t_treat_out=t_treat_out_batch
                )
                loss = loss_target + lambda_ * loss_treat

                # Optimization:
                self.decoder_optim.zero_grad()
                loss.backward()
                self.decoder_optim.step()

                # Weight by batch size so the epoch figures are per-sample averages.
                epoch_loss_target += loss_target.item() * current_batch_size
                epoch_loss_treat += loss_treat.item() * current_batch_size
                epoch_loss += loss.item() * current_batch_size

            epoch_loss_target /= n_samples_cumul
            epoch_loss_treat /= n_samples_cumul
            epoch_loss /= n_samples_cumul
            print(
                f"Epoch: {epoch_idx}, Prediction Loss: {epoch_loss_target:.3f}, "
                f"Lambda: {lambda_:.3f}, Treatment BR Loss: {epoch_loss_treat:.3f}, Loss: {epoch_loss:.3f}"
            )

    def _fit(
        self,
        data: Dataset,
        horizon: Horizon = None,  # type: ignore
        **kwargs,
    ) -> "CRNTreatmentEffectsModelBase":
        """Train the encoder, then the decoder, and return ``self``.

        ``horizon`` is accepted for interface compatibility but is not used here.
        """
        self.set_attributes_from_kwargs(**kwargs)

        # Ensure there are at least 3 timesteps in the "post" part of TimeSeries after the split and
        # at least 3 timesteps in the "pre" part. This is due to the cov./targ./treat. shifts that are needed.
        (
            encoder_t_cov,
            encoder_t_targ,
            encoder_t_treat_out,
            t_cov_to_encode,
            t_treat_to_encode,
            decoder_t_targ,
            decoder_input,
            decoder_t_treat_out,
        ) = self.prep_fit(data=data, min_pre_len=3, min_post_len=3, repeat_last_pre_step=True)

        # Run the training stages.
        print("=== Training stage: 1. Train encoder ===")
        self._train_encoder(encoder_tensors=(encoder_t_cov, encoder_t_targ, encoder_t_treat_out))

        print("=== Training stage: 2. Train decoder ===")
        if TYPE_CHECKING:
            assert isinstance(t_cov_to_encode, torch.Tensor)
        encoded_representations = self._get_encoder_representation(t_cov_to_encode, t_treat_out=t_treat_to_encode)
        self._train_decoder(
            encoded_representations, decoder_tensors=(decoder_t_targ, decoder_input, decoder_t_treat_out)
        )

        return self

    def _predict(self, data: Dataset, horizon: Optional[Horizon], **kwargs) -> TimeSeriesSamples:
        """Autoregressively predict the targets over ``horizon`` using the observed treatments.

        Positions where the raw decoder output equals the padding indicator are
        reset to the padding indicator after ``process_output_for_loss``.
        """
        self.set_attributes_from_kwargs(**kwargs)

        # Work on a copy so the caller's dataset is left untouched.
        data = data.copy()
        if TYPE_CHECKING:
            assert self.decoder is not None
            assert data.temporal_targets is not None
            assert isinstance(horizon, TimeIndexHorizon)

        h, c, decoder_input, t_treat_out = self.prep_predict(data, horizon=horizon)
        if TYPE_CHECKING:
            assert isinstance(h, torch.Tensor)
            assert isinstance(decoder_input, torch.Tensor)
            assert isinstance(t_treat_out, torch.Tensor)

        with torch.no_grad():
            h_adapter_out = self._pass_h_through_adapter(h, c)
            out = self.decoder.autoregress(
                decoder_input,
                h=h_adapter_out,
                padding_indicator=self.params.padding_indicator,
                concat_treatment=t_treat_out,
            )
            out_final: Any = self.process_output_for_loss(out)
            out_final[tl.eq_indicator(out, self.params.padding_indicator)] = self.params.padding_indicator

        prediction = TimeSeriesSamples.new_empty_like(like=data.temporal_targets)
        prediction.update_from_sequence_of_arrays(
            out_final, time_index_sequence=horizon.time_index_sequence, padding_indicator=self.params.padding_indicator
        )

        return prediction

    def _predict_counterfactuals(
        self,
        data: Dataset,
        sample_index: T_SamplesIndexDtype,
        treatment_scenarios: TTreatmentScenarios,
        horizon: Optional[Horizon],
        **kwargs,
    ) -> TCounterfactualPredictions:
        """Predict the targets of one sample under each of ``treatment_scenarios``.

        Each scenario's prediction must be single-sample and free of padding; this
        is enforced with assertions.
        """
        self.set_attributes_from_kwargs(**kwargs)

        # Restrict to the requested sample and work on a copy.
        data = data[sample_index].copy()
        if TYPE_CHECKING:
            assert self.decoder is not None
            assert data.temporal_targets is not None
            assert isinstance(horizon, TimeIndexHorizon)

        h, c, decoder_input_list, decoder_t_treat_out_list = self.prep_predict_counterfactuals(
            data, sample_index=sample_index, treatment_scenarios=treatment_scenarios, horizon=horizon
        )
        if TYPE_CHECKING:
            assert isinstance(h, torch.Tensor)
            assert isinstance(decoder_input_list, list)
            assert isinstance(decoder_t_treat_out_list, list)

        list_counterfactual_predictions: List[torch.Tensor] = []
        for decoder_input, decoder_t_treat_out in zip(decoder_input_list, decoder_t_treat_out_list):
            assert isinstance(decoder_input, torch.Tensor)
            assert isinstance(decoder_t_treat_out, torch.Tensor)

            with torch.no_grad():
                h_adapter_out = self._pass_h_through_adapter(h, c)
                out = self.decoder.autoregress(
                    decoder_input,
                    h=h_adapter_out,
                    padding_indicator=self.params.padding_indicator,
                    concat_treatment=decoder_t_treat_out,
                )
                out_final = self.process_output_for_loss(out)

            # The output should be single-sample and shouldn't have any padding.
            assert out_final.shape[0] == 1
            assert tl.eq_indicator(out, self.params.padding_indicator).sum().item() == 0

            list_counterfactual_predictions.append(out_final)

        data_historic_temporal_targets = data.temporal_targets[sample_index]
        if TYPE_CHECKING:
            assert isinstance(data_historic_temporal_targets, TimeSeries)
        list_ts = to_counterfactual_predictions(
            list_counterfactual_predictions, data_historic_temporal_targets, horizon
        )

        return list_ts
class CRNRegressor(CRNTreatmentEffectsModelBase):
    """CRN treatment effects model trained with a mean squared error prediction loss."""

    requirements: r.Requirements = r.Requirements(
        dataset_requirements=r.DatasetRequirements(
            temporal_covariates_value_type=r.DataValueOpts.NUMERIC,
            temporal_targets_value_type=r.DataValueOpts.NUMERIC_CATEGORICAL,
            temporal_treatments_value_type=r.DataValueOpts.NUMERIC_BINARY,
            static_covariates_value_type=r.DataValueOpts.NUMERIC,
            requires_no_missing_data=True,
        ),
        prediction_requirements=r.PredictionRequirements(
            target_data_structure=r.DataStructureOpts.TIME_SERIES,
            horizon_type=r.HorizonOpts.TIME_INDEX,
            min_timesteps_target_when_fit=6,
            min_timesteps_target_when_predict=3,
        ),
        treatment_effects_requirements=r.TreatmentEffectsRequirements(
            treatment_data_structure=r.DataStructureOpts.TIME_SERIES,
            min_timesteps_treatment_when_fit=6,
            min_timesteps_treatment_when_predict=3,
            min_timesteps_treatment_when_predict_counterfactual=3,
        ),
    )
    DEFAULT_PARAMS: TDefaultParams = _DefaultParams()

    def __init__(self, params: Optional[TParams] = None) -> None:
        """Initialise the base model with ``nn.MSELoss`` as the prediction loss."""
        mse_loss = nn.MSELoss()
        super().__init__(loss_fn=mse_loss, params=params)

    def process_output_for_loss(self, output: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return ``output`` unchanged; regression outputs need no post-processing."""
        return output
class CRNClassifier(CRNTreatmentEffectsModelBase):
    """CRN treatment effects model trained with a cross-entropy prediction loss."""

    requirements: r.Requirements = r.Requirements(
        dataset_requirements=r.DatasetRequirements(
            temporal_covariates_value_type=r.DataValueOpts.NUMERIC,
            temporal_targets_value_type=r.DataValueOpts.NUMERIC_CATEGORICAL,
            temporal_treatments_value_type=r.DataValueOpts.NUMERIC_BINARY,
            static_covariates_value_type=r.DataValueOpts.NUMERIC,
            requires_no_missing_data=True,
        ),
        prediction_requirements=r.PredictionRequirements(
            target_data_structure=r.DataStructureOpts.TIME_SERIES,
            horizon_type=r.HorizonOpts.TIME_INDEX,
            min_timesteps_target_when_fit=6,
            min_timesteps_target_when_predict=3,
        ),
        treatment_effects_requirements=r.TreatmentEffectsRequirements(
            treatment_data_structure=r.DataStructureOpts.TIME_SERIES,
            min_timesteps_treatment_when_fit=6,
            min_timesteps_treatment_when_predict=3,
            min_timesteps_treatment_when_predict_counterfactual=3,
        ),
    )
    DEFAULT_PARAMS: TDefaultParams = _DefaultParams()

    def __init__(self, params: Optional[TParams] = None) -> None:
        """Initialise the base model with ``nn.CrossEntropyLoss`` and create the softmax layer."""
        cross_entropy = nn.CrossEntropyLoss()
        super().__init__(loss_fn=cross_entropy, params=params)
        # Assigned after the parent initialiser has run, as in the original ordering.
        self.softmax = nn.Softmax(dim=-1)

    def process_output_for_loss(self, output: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply a softmax over the last dimension of ``output``.

        NOTE(review): the result is also what gets passed to ``nn.CrossEntropyLoss``
        during training, which itself expects unnormalised logits -- confirm the
        double normalisation is intended.
        """
        probabilities = self.softmax(output)
        return probabilities