User Guide Tutorial 06: Treatment Effects
This tutorial shows how to use TemporAI treatments plugins.
All treatments plugins
⚠️ The `treatments` API is preliminary and likely to change.
In the treatment effects estimation task, the goal is to predict a counterfactual outcome given an alternative treatment.
To see all the relevant plugins:
[ ]:
from tempor import plugin_loader
from rich.pretty import pprint
all_treatments_plugins = plugin_loader.list()["treatments"]
pprint(all_treatments_plugins, indent_guides=False)
{ 'one_off': {'regression': ['synctwin_regressor']}, 'temporal': {'classification': ['crn_classifier'], 'regression': ['crn_regressor']} }
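The nested listing above can be flattened into dotted plugin names. The sketch below is only an illustration on a copy of the printed dictionary. It assumes every plugin follows the same `treatments.<setting>.<task>.<name>` pattern as the name `treatments.temporal.regression.crn_regressor` used later in this tutorial. The helper `qualified_names` is hypothetical and not part of TemporAI.

```python
# Nested listing copied from the printed output above.
all_treatments_plugins = {
    "one_off": {"regression": ["synctwin_regressor"]},
    "temporal": {
        "classification": ["crn_classifier"],
        "regression": ["crn_regressor"],
    },
}


def qualified_names(listing, prefix="treatments"):
    """Flatten a nested plugin listing into dotted names (hypothetical helper)."""
    names = []
    for key, value in listing.items():
        path = f"{prefix}.{key}"
        if isinstance(value, dict):
            # Descend into the next level (e.g. "temporal" -> "regression").
            names.extend(qualified_names(value, prefix=path))
        else:
            # Leaf level: a list of plugin names.
            names.extend(f"{path}.{plugin}" for plugin in value)
    return names


plugin_names = qualified_names(all_treatments_plugins)
for name in plugin_names:
    print(name)
```

Names built this way are candidates to pass to `plugin_loader.get(...)`, under the naming assumption stated above.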
Next, load the data source class we will use:
[ ]:
DummyTemporalTreatmentEffectsDataSource = plugin_loader.get_class(
    "treatments.temporal.dummy_treatments", plugin_type="datasource"
)
Using a temporal treatment effects plugin
In this setting, both the treatments and the outcomes are time series.
[ ]:
from tempor import plugin_loader
dataset = DummyTemporalTreatmentEffectsDataSource(
    random_state=42,
    temporal_covariates_missing_prob=0.0,
    temporal_treatments_n_features=1,
    temporal_treatments_n_categories=2,
).load()
print(dataset)
model = plugin_loader.get("treatments.temporal.regression.crn_regressor", epochs=20)
print(model)
TemporalTreatmentEffectsDataset(
    time_series=TimeSeriesSamples([100, *, 5]),
    static=StaticSamples([100, 3]),
    predictive=TemporalTreatmentEffectsTaskData(
        targets=TimeSeriesSamples([100, *, 2]),
        treatments=TimeSeriesSamples([100, *, 1])
    )
)
CRNTreatmentsRegressor(
    name='crn_regressor',
    category='treatments.temporal.regression',
    plugin_type='method',
    params={
        'encoder_rnn_type': 'LSTM',
        'encoder_hidden_size': 100,
        'encoder_num_layers': 1,
        'encoder_bias': True,
        'encoder_dropout': 0.0,
        'encoder_bidirectional': False,
        'encoder_nonlinearity': None,
        'encoder_proj_size': None,
        'decoder_rnn_type': 'LSTM',
        'decoder_hidden_size': 100,
        'decoder_num_layers': 1,
        'decoder_bias': True,
        'decoder_dropout': 0.0,
        'decoder_bidirectional': False,
        'decoder_nonlinearity': None,
        'decoder_proj_size': None,
        'adapter_hidden_dims': [50],
        'adapter_out_activation': 'Tanh',
        'predictor_hidden_dims': [],
        'predictor_out_activation': None,
        'max_len': None,
        'optimizer_str': 'Adam',
        'optimizer_kwargs': {'lr': 0.01, 'weight_decay': 1e-05},
        'batch_size': 32,
        'epochs': 20,
        'padding_indicator': -999.0
    }
)
[ ]:
# Targets:
dataset.predictive.targets
TimeSeriesSamples with data:
| sample_idx | time_idx | 0 | 1 |
|---|---|---|---|
| 0 | 0 | -3.110475 | -3.566948 |
| 0 | 1 | 1.528495 | -0.653673 |
| 0 | 2 | 2.275307 | -0.695371 |
| 0 | 3 | 4.844060 | 3.469371 |
| 0 | 4 | 4.420301 | 5.147500 |
| ... | ... | ... | ... |
| 99 | 7 | 5.994185 | 6.225290 |
| 99 | 8 | 10.913662 | 5.346697 |
| 99 | 9 | 9.558824 | 7.585175 |
| 99 | 10 | 10.194430 | 5.795619 |
| 99 | 11 | 13.774189 | 8.457336 |
1573 rows × 2 columns
[ ]:
# Treatments:
dataset.predictive.treatments
TimeSeriesSamples with data:
| sample_idx | time_idx | 0 |
|---|---|---|
| 0 | 0 | 0 |
| 0 | 1 | 1 |
| 0 | 2 | 1 |
| 0 | 3 | 0 |
| 0 | 4 | 0 |
| ... | ... | ... |
| 99 | 7 | 1 |
| 99 | 8 | 1 |
| 99 | 9 | 0 |
| 99 | 10 | 0 |
| 99 | 11 | 1 |
1573 rows × 1 columns
[ ]:
# Train.
model.fit(dataset);
Preparing data for decoder training...
Preparing data for decoder training DONE.
=== Training stage: 1. Train encoder ===
Epoch: 0, Prediction Loss: 75.212, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 75.212
Epoch: 1, Prediction Loss: 32.453, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 32.453
Epoch: 2, Prediction Loss: 19.389, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 19.389
Epoch: 3, Prediction Loss: 19.679, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 19.679
Epoch: 4, Prediction Loss: 19.588, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 19.588
Epoch: 5, Prediction Loss: 14.915, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 14.915
Epoch: 6, Prediction Loss: 10.608, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 10.608
Epoch: 7, Prediction Loss: 8.684, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 8.684
Epoch: 8, Prediction Loss: 6.953, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 6.953
Epoch: 9, Prediction Loss: 5.645, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 5.645
Epoch: 10, Prediction Loss: 5.060, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 5.060
Epoch: 11, Prediction Loss: 4.658, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.658
Epoch: 12, Prediction Loss: 4.398, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.398
Epoch: 13, Prediction Loss: 4.265, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.265
Epoch: 14, Prediction Loss: 4.170, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.170
Epoch: 15, Prediction Loss: 4.054, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.054
Epoch: 16, Prediction Loss: 3.949, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.949
Epoch: 17, Prediction Loss: 3.940, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.940
Epoch: 18, Prediction Loss: 3.918, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.918
Epoch: 19, Prediction Loss: 3.851, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.851
=== Training stage: 2. Train decoder ===
Epoch: 0, Prediction Loss: 34.622, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 34.622
Epoch: 1, Prediction Loss: 5.329, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 5.329
Epoch: 2, Prediction Loss: 3.826, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.826
Epoch: 3, Prediction Loss: 3.767, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.767
Epoch: 4, Prediction Loss: 3.746, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.746
Epoch: 5, Prediction Loss: 3.741, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.741
Epoch: 6, Prediction Loss: 3.726, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.726
Epoch: 7, Prediction Loss: 3.694, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.694
Epoch: 8, Prediction Loss: 3.723, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.723
Epoch: 9, Prediction Loss: 3.730, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.730
Epoch: 10, Prediction Loss: 3.670, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.670
Epoch: 11, Prediction Loss: 3.691, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.691
Epoch: 12, Prediction Loss: 3.726, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.726
Epoch: 13, Prediction Loss: 3.701, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.701
Epoch: 14, Prediction Loss: 3.696, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.696
Epoch: 15, Prediction Loss: 3.668, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.668
Epoch: 16, Prediction Loss: 3.612, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.612
Epoch: 17, Prediction Loss: 3.718, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.718
Epoch: 18, Prediction Loss: 3.717, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.717
Epoch: 19, Prediction Loss: 3.658, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.658
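The counterfactual prediction step that follows takes, for each sample, a horizon (the time steps to predict) and a list of alternative treatment scenarios over that horizon. As a minimal sketch with plain Python lists rather than TemporAI objects, the construction can be illustrated as below. Here `time_index` is a hypothetical stand-in for sample 0's time index, assumed to run from 0 to 10. Reading the value 1 as "treated" and 0 as "untreated" is also an assumption.

```python
# Hypothetical stand-in for sample 0's time index (assumed to be 0..10).
time_index = list(range(11))

# Horizon: the second half of the observed time steps.
horizon = time_index[len(time_index) // 2 :]

# Two alternative scenarios for a single binary treatment over the horizon:
# the treatment value is 1 at every step, or 0 at every step.
always_one = [1] * len(horizon)
always_zero = [0] * len(horizon)
treatment_scenarios = [always_one, always_zero]

print("Horizon:", horizon)
print("Scenarios:", treatment_scenarios)
```

Each scenario has exactly one treatment value per horizon step, which is why both lists are built from `len(horizon)`.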
[ ]:
# Predict counterfactuals:
import numpy as np
dataset = dataset[:5]
# Define horizons for each sample.
horizons = [tc.time_indexes()[0][len(tc.time_indexes()[0]) // 2 :] for tc in dataset.time_series]
print("Horizons for sample 0:\n", horizons[0], end="\n\n")
# Define treatment scenarios for each sample.
treatment_scenarios = [[np.asarray([1] * len(h)), np.asarray([0] * len(h))] for h in horizons]
print("Alternative treatment scenarios for sample 0:\n", treatment_scenarios[0], end="\n\n")
# Call predict_counterfactuals.
counterfactuals = model.predict_counterfactuals(dataset, horizons=horizons, treatment_scenarios=treatment_scenarios)
print("Counterfactual outcomes for sample 0, given the alternative treatment scenarios:\n")
for idx, c in enumerate(counterfactuals[0]):
    print(f"Treatment scenario {idx}, {treatment_scenarios[0][idx]}")
    print(c, end="\n\n")
Horizons for sample 0:
[5, 6, 7, 8, 9, 10]
Alternative treatment scenarios for sample 0:
[array([1, 1, 1, 1, 1, 1]), array([0, 0, 0, 0, 0, 0])]
Counterfactual outcomes for sample 0, given the alternative treatment scenarios:
Treatment scenario 0, [1 1 1 1 1 1]
TimeSeries() with data:
0 1
time_idx
5 5.966007 4.997900
6 6.120805 5.138746
7 6.138720 5.155504
8 6.140833 5.157483
9 6.141082 5.157716
10 6.141112 5.157744
Treatment scenario 1, [0 0 0 0 0 0]
TimeSeries() with data:
0 1
time_idx
5 6.303439 4.996541
6 6.467038 5.143897
7 6.486365 5.162077
8 6.488708 5.164285
9 6.488993 5.164552
10 6.489028 5.164585
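One way to read the two counterfactual series above is to subtract them step by step, which gives a rough per-step estimate of how the predicted outcome differs between the two scenarios. The sketch below does this with plain Python on numbers copied from the printed output for sample 0. It is an illustration only, not a TemporAI API, and the variable names are hypothetical.

```python
# Values copied from the printed output for sample 0 (outcome features 0 and 1).
horizon = [5, 6, 7, 8, 9, 10]
scenario_0 = [  # treatment value 1 at every horizon step
    (5.966007, 4.997900),
    (6.120805, 5.138746),
    (6.138720, 5.155504),
    (6.140833, 5.157483),
    (6.141082, 5.157716),
    (6.141112, 5.157744),
]
scenario_1 = [  # treatment value 0 at every horizon step
    (6.303439, 4.996541),
    (6.467038, 5.143897),
    (6.486365, 5.162077),
    (6.488708, 5.164285),
    (6.488993, 5.164552),
    (6.489028, 5.164585),
]

# Per-step difference (scenario 0 minus scenario 1) for each outcome feature.
differences = [
    (a0 - b0, a1 - b1)
    for (a0, a1), (b0, b1) in zip(scenario_0, scenario_1)
]

for t, (d0, d1) in zip(horizon, differences):
    print(f"time_idx={t}: feature 0 diff={d0:+.6f}, feature 1 diff={d1:+.6f}")
```

In this run, predicted feature 0 is roughly 0.34 to 0.35 lower under scenario 0 at every horizon step, while feature 1 differs by less than 0.01. These numbers come from a 20-epoch model trained on dummy data, so they illustrate the mechanics rather than a meaningful effect estimate.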