tempor.methods.treatments.temporal.classification.plugin_crn_classifier module
Counterfactual Recurrent Network treatment effects model for classification on the outcomes (targets).
- class tempor.methods.treatments.temporal.classification.plugin_crn_classifier.CRNTreatmentsClassifier(**params: Any)
Bases: BaseTemporalTreatmentEffects
Counterfactual Recurrent Network treatment effects model for classification on the outcomes (targets).
- Parameters:
- **params : Any
Keyword parameters for the model (see ParamsDefinition below), for example n_iter.
Example
>>> from tempor import plugin_loader
>>>
>>> # Load the model:
>>> model = plugin_loader.get("treatments.temporal.classification.crn_classifier", n_iter=50)
>>>
>>> # Train:
>>> # model.fit(dataset)
>>>
>>> # Predict:
>>> # assert model.predict(dataset, n_future_steps=10).numpy().shape == (len(dataset), 10, 5)
References
Estimating counterfactual treatment outcomes over time through adversarially balanced representations, Ioana Bica, Ahmed M. Alaa, James Jordon, Mihaela van der Schaar.
- ParamsDefinition
alias of Seq2seqParams
- params : Seq2seqParams
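The **params constructor argument and the ParamsDefinition alias suggest a common pattern: the keyword arguments populate an instance of the params class, which is then exposed as params. The sketch below assumes that pattern and uses a hypothetical ToySeq2seqParams dataclass in place of Seq2seqParams; only n_iter comes from the example above, while learning_rate and all defaults are invented for illustration.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class ToySeq2seqParams:
    # Hypothetical stand-in for Seq2seqParams; fields and defaults are illustrative only.
    n_iter: int = 100
    learning_rate: float = 1e-3


class ToyPlugin:
    # ParamsDefinition names the params class; the constructor's **params populate it.
    ParamsDefinition = ToySeq2seqParams

    def __init__(self, **params: Any) -> None:
        # Unknown keys raise TypeError, because the dataclass __init__ rejects them.
        self.params = self.ParamsDefinition(**params)


model = ToyPlugin(n_iter=50)
print(model.params)  # ToySeq2seqParams(n_iter=50, learning_rate=0.001)
```

Using a dataclass here means unspecified fields keep their defaults and misspelled keyword arguments fail loudly at construction time rather than being silently ignored.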
- static hyperparameter_space(*args: Any, **kwargs: Any) → list[Params]
The hyperparameter search domain, used for tuning.
Accepts variadic *args and **kwargs, which are received from sample_hyperparameters.
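To illustrate how variadics can flow from sample_hyperparameters into hyperparameter_space, here is a minimal sketch built on hypothetical IntParam and CategoricalParam classes (not TemporAI's own parameter types). The max_iter keyword, the hidden_size dimension, and the seed argument are likewise invented for illustration.

```python
import dataclasses
import random
from typing import Any, Union


@dataclasses.dataclass
class IntParam:
    # Hypothetical integer search dimension with inclusive bounds.
    name: str
    low: int
    high: int

    def sample(self, rng: random.Random) -> int:
        return rng.randint(self.low, self.high)


@dataclasses.dataclass
class CategoricalParam:
    # Hypothetical categorical search dimension.
    name: str
    choices: list

    def sample(self, rng: random.Random) -> Any:
        return rng.choice(self.choices)


Param = Union[IntParam, CategoricalParam]


class ToyPlugin:
    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> list[Param]:
        # The search domain; a hypothetical max_iter keyword narrows the n_iter range.
        max_iter = kwargs.get("max_iter", 100)
        return [
            IntParam("n_iter", 10, max_iter),
            CategoricalParam("hidden_size", [16, 32, 64]),
        ]

    @classmethod
    def sample_hyperparameters(cls, *args: Any, seed: int = 0, **kwargs: Any) -> dict[str, Any]:
        # Variadics are forwarded unchanged to hyperparameter_space.
        rng = random.Random(seed)
        return {p.name: p.sample(rng) for p in cls.hyperparameter_space(*args, **kwargs)}


print(ToyPlugin.sample_hyperparameters(max_iter=20))
```

Keeping hyperparameter_space static and data-only lets a tuner inspect the domain without instantiating the model, while the sampler owns the randomness.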
- category : ClassVar[plugin_typing.PluginCategory] = 'treatments.temporal.classification'
Plugin category, such as 'prediction.one_off.classification'. Must be set by the plugin class using @register_plugin.
- name : ClassVar[plugin_typing.PluginName] = 'crn_classifier'
Plugin name, such as 'my_nn_classifier'. Must be set by the plugin class using @register_plugin.
- plugin_type : ClassVar[plugin_typing.PluginTypeArg] = 'method'
Plugin type, such as 'method'. May optionally be set by the plugin class using @register_plugin; otherwise the default plugin type is used.