tempor.models.clairvoyance2.treatment_effects.crn module
-
class tempor.models.clairvoyance2.treatment_effects.crn.RecurrentFFNet_ConcatTreatment(rnn_type: str, input_size: int, hidden_size: int, nonlinearity: str | None, num_layers: int, bias: bool, dropout: float, bidirectional: bool, proj_size: int | None, ff_out_size: int, ff_in_size_adjust: int = 0, ff_hidden_dims: Sequence[int] = (), ff_out_activation: str | None = 'ReLU', ff_hidden_activations: str | None = 'ReLU') [source]
Bases: RecurrentFFNet
-
class tempor.models.clairvoyance2.treatment_effects.crn.TreatBalancerNet(in_dim: int, out_dim: int, hidden_dims: Sequence[int] = (), out_activation: str | None = 'ReLU', hidden_activations: str | None = 'ReLU') [source]
Bases: FeedForwardNet
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) → Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
class tempor.models.clairvoyance2.treatment_effects.crn.CRNTreatmentEffectsModelBase(loss_fn: Module, params: dict[str, Any] | None = None) [source]
Bases: TreatmentEffectsModel, Seq2SeqCRNStylePredictorBase, OrganizedTreatmentEffectsModuleMixin
-
class tempor.models.clairvoyance2.treatment_effects.crn.CRNRegressor(params: dict[str, Any] | None = None) [source]
Bases: CRNTreatmentEffectsModelBase
-
requirements : Requirements =
Requirements(dataset_requirements=DatasetRequirements(requires_static_covariates_present=False, requires_no_missing_data=True, static_covariates_value_type=<DataValueOpts.NUMERIC: 2>, temporal_covariates_value_type=<DataValueOpts.NUMERIC: 2>, temporal_targets_value_type=<DataValueOpts.NUMERIC_CATEGORICAL: 3>, temporal_treatments_value_type=<DataValueOpts.NUMERIC_BINARY: 4>, event_covariates_value_type=<DataValueOpts.ANY: 1>, event_targets_value_type=<DataValueOpts.ANY: 1>, event_treatments_value_type=<DataValueOpts.ANY: 1>, requires_all_temporal_data_samples_aligned=False, requires_all_temporal_data_regular=False, requires_all_temporal_data_index_numeric=False, requires_all_temporal_containers_shares_index=True), prediction_requirements=PredictionRequirements(target_data_structure=<DataStructureOpts.TIME_SERIES: 1>, horizon_type=<HorizonOpts.TIME_INDEX: 2>, min_timesteps_target_when_fit=6, min_timesteps_target_when_predict=3), treatment_effects_requirements=TreatmentEffectsRequirements(treatment_data_structure=<DataStructureOpts.TIME_SERIES: 1>, min_timesteps_treatment_when_fit=6, min_timesteps_treatment_when_predict=3, min_timesteps_treatment_when_predict_counterfactual=3))¶
-
DEFAULT_PARAMS : dict[str, Any] | NamedTuple =
_DefaultParams(encoder_rnn_type='LSTM', encoder_hidden_size=100, encoder_num_layers=1, encoder_bias=True, encoder_dropout=0.0, encoder_bidirectional=False, encoder_nonlinearity=None, encoder_proj_size=None, decoder_rnn_type='LSTM', decoder_hidden_size=100, decoder_num_layers=1, decoder_bias=True, decoder_dropout=0.0, decoder_bidirectional=False, decoder_nonlinearity=None, decoder_proj_size=None, adapter_hidden_dims=[50], adapter_out_activation='Tanh', predictor_hidden_dims=[], predictor_out_activation=None, treat_net_hidden_dims=[], treat_net_out_activation=None, max_len=None, optimizer_str='Adam', optimizer_kwargs={'lr': 0.01, 'weight_decay': 1e-05}, batch_size=32, epochs=100, padding_indicator=-999.0)¶
- encoder_treat_net : FeedForwardNet | None
- decoder_treat_net : FeedForwardNet | None
- encoder : RecurrentFFNet | None
- decoder : RecurrentFFNet | None
- adapter : FeedForwardNet | None
- loss_fn : nn.Module
- params : DotMap
- inferred_params : DotMap
-
requirements : Requirements =
-
class tempor.models.clairvoyance2.treatment_effects.crn.CRNClassifier(params: dict[str, Any] | None = None) [source]
Bases: CRNTreatmentEffectsModelBase
-
requirements : Requirements =
Requirements(dataset_requirements=DatasetRequirements(requires_static_covariates_present=False, requires_no_missing_data=True, static_covariates_value_type=<DataValueOpts.NUMERIC: 2>, temporal_covariates_value_type=<DataValueOpts.NUMERIC: 2>, temporal_targets_value_type=<DataValueOpts.NUMERIC_CATEGORICAL: 3>, temporal_treatments_value_type=<DataValueOpts.NUMERIC_BINARY: 4>, event_covariates_value_type=<DataValueOpts.ANY: 1>, event_targets_value_type=<DataValueOpts.ANY: 1>, event_treatments_value_type=<DataValueOpts.ANY: 1>, requires_all_temporal_data_samples_aligned=False, requires_all_temporal_data_regular=False, requires_all_temporal_data_index_numeric=False, requires_all_temporal_containers_shares_index=True), prediction_requirements=PredictionRequirements(target_data_structure=<DataStructureOpts.TIME_SERIES: 1>, horizon_type=<HorizonOpts.TIME_INDEX: 2>, min_timesteps_target_when_fit=6, min_timesteps_target_when_predict=3), treatment_effects_requirements=TreatmentEffectsRequirements(treatment_data_structure=<DataStructureOpts.TIME_SERIES: 1>, min_timesteps_treatment_when_fit=6, min_timesteps_treatment_when_predict=3, min_timesteps_treatment_when_predict_counterfactual=3))¶
-
DEFAULT_PARAMS : dict[str, Any] | NamedTuple =
_DefaultParams(encoder_rnn_type='LSTM', encoder_hidden_size=100, encoder_num_layers=1, encoder_bias=True, encoder_dropout=0.0, encoder_bidirectional=False, encoder_nonlinearity=None, encoder_proj_size=None, decoder_rnn_type='LSTM', decoder_hidden_size=100, decoder_num_layers=1, decoder_bias=True, decoder_dropout=0.0, decoder_bidirectional=False, decoder_nonlinearity=None, decoder_proj_size=None, adapter_hidden_dims=[50], adapter_out_activation='Tanh', predictor_hidden_dims=[], predictor_out_activation=None, treat_net_hidden_dims=[], treat_net_out_activation=None, max_len=None, optimizer_str='Adam', optimizer_kwargs={'lr': 0.01, 'weight_decay': 1e-05}, batch_size=32, epochs=100, padding_indicator=-999.0)¶
- encoder_treat_net : FeedForwardNet | None
- decoder_treat_net : FeedForwardNet | None
- encoder : RecurrentFFNet | None
- decoder : RecurrentFFNet | None
- adapter : FeedForwardNet | None
- loss_fn : nn.Module
- params : DotMap
- inferred_params : DotMap
-
requirements : Requirements =