tempor.methods.treatments.one_off.regression.plugin_synctwin_regressor module

SyncTwin treatment effects estimation.
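For intuition, the paper's core estimate can be sketched in NumPy. SyncTwin encodes each unit's covariate history into a learned representation. It then approximates a treated unit's representation by a convex combination of control-unit representations and applies the same weights to the control outcomes. The result is a synthetic (untreated) twin, and the effect estimate is the observed outcome minus that twin. The sketch makes several simplifying assumptions: the representations are already given (no encoder, no `tau`), and the weights are fitted with a softmax parameterization and plain gradient descent. The function names are hypothetical and are not part of this module's API.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    # Numerically stable softmax: maps unconstrained logits onto the simplex.
    e = np.exp(z - z.max())
    return e / e.sum()


def synthetic_twin_weights(c_treated: np.ndarray, c_controls: np.ndarray,
                           n_steps: int = 2000, lr: float = 1.0) -> np.ndarray:
    """Fit convex weights b (b >= 0, sum(b) == 1) so that c_controls.T @ b ~= c_treated.

    b = softmax(logits); the logits are fitted by plain gradient descent on
    the squared representation-matching error ||c_controls.T @ b - c_treated||^2.
    """
    logits = np.zeros(c_controls.shape[0])
    for _ in range(n_steps):
        b = softmax(logits)
        residual = c_controls.T @ b - c_treated   # shape (d,)
        grad_b = 2.0 * c_controls @ residual      # dL/db, shape (n_controls,)
        # Chain rule through softmax: dL/dlogits = b * (grad_b - b @ grad_b).
        logits -= lr * b * (grad_b - b @ grad_b)
    return softmax(logits)


def estimate_effect(y_treated: np.ndarray, y_controls: np.ndarray, b: np.ndarray) -> np.ndarray:
    # The synthetic twin's (untreated) outcome path reuses the same weights on control outcomes.
    y0_hat = y_controls.T @ b      # y_controls has shape (n_controls, T)
    return y_treated - y0_hat      # estimated effect at each time step


# Toy data: three control units with 2-d representations and 3-step outcomes.
c_controls = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
y_controls = np.array([[1.0, 2.0, 3.0], [2.0, 2.0, 2.0], [4.0, 4.0, 4.0]])
b = synthetic_twin_weights(np.array([0.25, 0.25]), c_controls)
effect = estimate_effect(np.array([2.0, 3.5, 5.0]), y_controls, b)
print(np.round(b, 3), np.round(effect, 3))
```

The softmax parameterization keeps the weights on the simplex without an explicit projection step. The plugin itself learns the representations jointly with the matching, which this sketch omits.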

class tempor.methods.treatments.one_off.regression.plugin_synctwin_regressor.SyncTwinParams(hidden_size: int = 20, tau: float = 1.0, lambda_prognostic: float = 1.0, lambda_reconstruction: float = 1.0, batch_size: int = 32, pretraining_iterations: int = 5000, matching_iterations: int = 20000, inference_iterations: int = 20000, use_validation_set_in_training: bool = True, treatment_status_is_treated: int = 1)

Bases: object

Parameters for the SyncTwin model. See the paper “SyncTwin: Treatment Effect Estimation with Longitudinal Outcomes”.

hidden_size : int = 20
tau : float = 1.0
lambda_prognostic : float = 1.0
lambda_reconstruction : float = 1.0
batch_size : int = 32
pretraining_iterations : int = 5000
matching_iterations : int = 20000
inference_iterations : int = 20000
use_validation_set_in_training : bool = True
treatment_status_is_treated : int = 1
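The signature above reads like a dataclass. The sketch below assumes it behaves like a standard `dataclasses.dataclass`, which is an assumption because this page only lists `Bases: object`. Under that assumption, a stand-in with the same fields and defaults shows how defaults and overrides combine. `SyncTwinParamsSketch` is a hypothetical name, not imported from tempor.

```python
import dataclasses


@dataclasses.dataclass
class SyncTwinParamsSketch:
    # Hypothetical mirror of the documented fields and defaults (not tempor's class).
    hidden_size: int = 20
    tau: float = 1.0
    lambda_prognostic: float = 1.0
    lambda_reconstruction: float = 1.0
    batch_size: int = 32
    pretraining_iterations: int = 5000
    matching_iterations: int = 20000
    inference_iterations: int = 20000
    use_validation_set_in_training: bool = True
    treatment_status_is_treated: int = 1


# Fields not passed explicitly keep their documented defaults.
quick = SyncTwinParamsSketch(pretraining_iterations=50, matching_iterations=50, inference_iterations=50)

# dataclasses.replace returns a modified copy and leaves the original untouched.
bigger = dataclasses.replace(quick, hidden_size=64)

# A key that is not a declared field (e.g. n_iter) is rejected by the generated __init__.
try:
    SyncTwinParamsSketch(n_iter=50)
except TypeError as err:
    print("rejected:", err)
```

Lowering the three iteration counts together, as above, is a convenient way to get a fast smoke-test configuration.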

class tempor.methods.treatments.one_off.regression.plugin_synctwin_regressor.SyncTwinTreatmentsRegressor(**params: Any)

Bases: BaseOneOffTreatmentEffects

SyncTwin treatment effects estimation.

Parameters:
**params : Any

Parameters for the model. Accepted keys are the fields of SyncTwinParams (see ParamsDefinition).

Example

>>> from tempor import plugin_loader
>>>
>>> # Load the model:
>>> model = plugin_loader.get(
...     "treatments.one_off.regression.synctwin_regressor",
...     pretraining_iterations=50,
...     matching_iterations=50,
...     inference_iterations=50,
... )
>>>
>>> # Train:
>>> # model.fit(dataset)
>>>
>>> # Predict:
>>> # assert model.predict(dataset, n_future_steps=10).numpy().shape == (len(dataset), 10, 5)

References

SyncTwin: Treatment Effect Estimation with Longitudinal Outcomes, Zhaozhi Qian, Yao Zhang, Ioana Bica, Angela Wood, Mihaela van der Schaar.

ParamsDefinition

alias of SyncTwinParams

params : SyncTwinParams
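The relationship between the class-level ParamsDefinition alias and the instance-level params attribute can be illustrated with a small hypothetical plugin. This is a sketch of the general pattern under the assumption that keyword arguments are used to construct the params class. It is not TemporAI's actual base-class code, and `TinyParams` and `TinyPlugin` are made-up names.

```python
import dataclasses
from typing import Any, ClassVar


@dataclasses.dataclass
class TinyParams:
    # Hypothetical two-field params class standing in for SyncTwinParams.
    hidden_size: int = 20
    batch_size: int = 32


class TinyPlugin:
    # Class-level alias naming the params class, analogous to ParamsDefinition.
    ParamsDefinition: ClassVar[type] = TinyParams

    def __init__(self, **params: Any) -> None:
        # Instance-level params object, analogous to the `params` attribute.
        self.params = self.ParamsDefinition(**params)


plugin = TinyPlugin(hidden_size=8)
print(plugin.params)  # TinyParams(hidden_size=8, batch_size=32)
```

Keeping the alias on the class lets generic code discover the accepted parameters without instantiating the plugin.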
static hyperparameter_space(*args: Any, **kwargs: Any) → list[Params]

The hyperparameter search domain, used for tuning.

Variadic *args and **kwargs may be provided; these are passed through from sample_hyperparameters.
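The forwarding contract between sample_hyperparameters and hyperparameter_space can be sketched in plain Python. `IntRange`, `FloatRange`, the `seed` argument, and the ranges are hypothetical placeholders chosen for illustration. They are not the plugin's actual search space, and TemporAI's own Params types may differ.

```python
import random
from dataclasses import dataclass
from typing import Any, Union


@dataclass
class IntRange:
    # Hypothetical integer search dimension.
    name: str
    low: int
    high: int

    def sample(self, rng: random.Random) -> int:
        return rng.randint(self.low, self.high)  # inclusive on both ends


@dataclass
class FloatRange:
    # Hypothetical continuous search dimension.
    name: str
    low: float
    high: float

    def sample(self, rng: random.Random) -> float:
        return rng.uniform(self.low, self.high)


Dimension = Union[IntRange, FloatRange]


def hyperparameter_space(*args: Any, **kwargs: Any) -> list[Dimension]:
    # Placeholder ranges for illustration only; not the plugin's real search domain.
    return [
        IntRange("hidden_size", 10, 100),
        FloatRange("lambda_prognostic", 0.1, 10.0),
    ]


def sample_hyperparameters(*args: Any, seed: int = 0, **kwargs: Any) -> dict[str, Any]:
    # Forwards the variadics to hyperparameter_space, then draws one value per dimension.
    rng = random.Random(seed)
    return {dim.name: dim.sample(rng) for dim in hyperparameter_space(*args, **kwargs)}


print(sample_hyperparameters(seed=1))
```

A tuner would typically pass each sampled dict as keyword arguments when constructing the plugin.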

category : ClassVar[plugin_typing.PluginCategory] = 'treatments.one_off.regression'

Plugin category, such as 'prediction.one_off.classification'. Must be set by the plugin class using @register_plugin.

name : ClassVar[plugin_typing.PluginName] = 'synctwin_regressor'

Plugin name, such as 'my_nn_classifier'. Must be set by the plugin class using @register_plugin.

plugin_type : ClassVar[plugin_typing.PluginTypeArg] = 'method'

Plugin type, such as 'method'. May optionally be set by the plugin class using @register_plugin; otherwise the default plugin type is used.