tempor.methods.treatments.temporal.classification.plugin_crn_classifier module

Counterfactual Recurrent Network treatment effects model for classification on the outcomes (targets).

class tempor.methods.treatments.temporal.classification.plugin_crn_classifier.CRNTreatmentsClassifier(**params: Any)

Bases: BaseTemporalTreatmentEffects

Counterfactual Recurrent Network treatment effects model for classification on the outcomes (targets).

Parameters:
**params : Any

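Keyword arguments for the model. The accepted names are the fields of ParamsDefinition (an alias of Seq2seqParams), e.g. n_iter.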

Example

>>> from tempor import plugin_loader
>>>
>>> # Load the model:
>>> model = plugin_loader.get("treatments.temporal.classification.crn_classifier", n_iter=50)
>>>
>>> # Train (`dataset` is a placeholder for a suitable treatment-effects dataset, not defined here):
>>> # model.fit(dataset)
>>>
>>> # Predict:
>>> # assert model.predict(dataset, n_future_steps=10).numpy().shape == (len(dataset), 10, 5)

References

Estimating counterfactual treatment outcomes over time through adversarially balanced representations, Ioana Bica, Ahmed M. Alaa, James Jordon, Mihaela van der Schaar.

ParamsDefinition

alias of Seq2seqParams

params : Seq2seqParams
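The ParamsDefinition alias and the params attribute suggest that the **params keyword arguments are collected into a Seq2seqParams instance. The standalone sketch below illustrates that pattern under this assumption. HypotheticalSeq2seqParams and HypotheticalPlugin are invented stand-ins, not tempor classes; only n_iter appears in the example above, and the batch_size field and all defaults are illustrative.

```python
from dataclasses import dataclass
from typing import Any, ClassVar, Type


@dataclass
class HypotheticalSeq2seqParams:
    # Hypothetical stand-in for Seq2seqParams: n_iter comes from the example
    # above; batch_size and both defaults are made up for illustration.
    n_iter: int = 100
    batch_size: int = 32


class HypotheticalPlugin:
    # Class-level alias to the parameter dataclass, mirroring
    # ParamsDefinition = Seq2seqParams (assumed pattern, not tempor code).
    ParamsDefinition: ClassVar[Type[HypotheticalSeq2seqParams]] = HypotheticalSeq2seqParams

    def __init__(self, **params: Any) -> None:
        # The dataclass constructor validates the keyword arguments:
        # unknown names raise TypeError, omitted ones fall back to defaults.
        self.params = self.ParamsDefinition(**params)


model = HypotheticalPlugin(n_iter=50)
print(model.params)  # HypotheticalSeq2seqParams(n_iter=50, batch_size=32)
```

Using a dataclass as the parameter container gives typo detection for free: a misspelled keyword such as n_itr fails at construction time instead of being silently ignored.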

static hyperparameter_space(*args: Any, **kwargs: Any) → list[Params]

The hyperparameter search domain, used for tuning.

Accepts variadic *args and **kwargs, which are forwarded from sample_hyperparameters.
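The forwarding of variadic arguments from sample_hyperparameters into hyperparameter_space can be illustrated with a self-contained sketch. Everything here is hypothetical: IntParam, FloatParam, the max_iter keyword, the lr dimension and the seed argument are invented for illustration and are not tempor's Params classes or signatures. The sketch only shows how caller-supplied arguments can reshape the search domain before one value is drawn per dimension.

```python
import random
from dataclasses import dataclass
from typing import Any


@dataclass
class IntParam:
    # Hypothetical integer search dimension (not tempor's Params type).
    name: str
    low: int
    high: int

    def sample(self, rng: random.Random) -> int:
        return rng.randint(self.low, self.high)  # inclusive on both ends


@dataclass
class FloatParam:
    # Hypothetical continuous search dimension.
    name: str
    low: float
    high: float

    def sample(self, rng: random.Random) -> float:
        return rng.uniform(self.low, self.high)


class HypotheticalPlugin:
    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> list:
        # The variadics let the caller adjust the domain; here a hypothetical
        # max_iter keyword caps the n_iter range.
        max_iter = kwargs.get("max_iter", 200)
        return [IntParam("n_iter", 10, max_iter), FloatParam("lr", 1e-4, 1e-2)]

    @classmethod
    def sample_hyperparameters(cls, *args: Any, seed: int = 0, **kwargs: Any) -> dict:
        # Forward the variadics unchanged, then draw one value per dimension.
        rng = random.Random(seed)
        space = cls.hyperparameter_space(*args, **kwargs)
        return {p.name: p.sample(rng) for p in space}


print(HypotheticalPlugin.sample_hyperparameters(max_iter=50))
```

Keeping hyperparameter_space static means a tuner can inspect the domain without instantiating the model.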

category : ClassVar[plugin_typing.PluginCategory] = 'treatments.temporal.classification'

Plugin category, such as 'prediction.one_off.classification'. Must be set by the plugin class using @register_plugin.

name : ClassVar[plugin_typing.PluginName] = 'crn_classifier'

Plugin name, such as 'my_nn_classifier'. Must be set by the plugin class using @register_plugin.

plugin_type : ClassVar[plugin_typing.PluginTypeArg] = 'method'

Plugin type, such as 'method'. May optionally be set by the plugin class using @register_plugin; otherwise the default plugin type is used.