
User Guide Tutorial 06: Treatment Effects

This tutorial shows how to use TemporAI treatments plugins.

All treatments plugins

⚠️ The treatments API is preliminary and likely to change.

In the treatment effects estimation task, the goal is to predict the counterfactual outcome that would be observed under an alternative treatment.
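The sketch below makes the idea concrete with a known outcome model. The model is a made-up linear recurrence, y[t] = A * y[t-1] + B * a[t]. It is hypothetical and is neither part of TemporAI nor what CRN learns.

When the outcome model is known, you get a counterfactual by rolling the outcome forward under an alternative treatment sequence. The treatment effect is then the difference between two such trajectories. A learned treatment effects model has to estimate this relationship from observational data instead.

```python
# Toy, hypothetical outcome model (not TemporAI, not CRN): y[t] = A * y[t-1] + B * a[t]
A, B = 0.8, 1.5


def roll_forward(y_last, treatments):
    """Simulate outcomes step by step under a given treatment sequence."""
    outcomes, y = [], y_last
    for a in treatments:
        y = A * y + B * a
        outcomes.append(y)
    return outcomes


y_last = 2.0  # last observed outcome before the prediction horizon
treated = roll_forward(y_last, [1, 1, 1])    # alternative scenario: treat at every step
untreated = roll_forward(y_last, [0, 0, 0])  # alternative scenario: never treat

# Per-step effect of "always treat" versus "never treat" under the toy model.
effect = [t - u for t, u in zip(treated, untreated)]
print([round(e, 3) for e in effect])  # [1.5, 2.7, 3.66]
```

In this toy model the effect accumulates over time because each outcome depends on the previous one.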

To see all the relevant plugins:

[ ]:
from tempor import plugin_loader
from rich.pretty import pprint

all_treatments_plugins = plugin_loader.list()["treatments"]

pprint(all_treatments_plugins, indent_guides=False)
{
    'one_off': {'regression': ['synctwin_regressor']},
    'temporal': {'classification': ['crn_classifier'], 'regression': ['crn_regressor']}
}
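The nested dictionary maps treatment setting (`one_off`, `temporal`) and task (`classification`, `regression`) to plugin names. Joining the keys with dots appears to give the fully qualified name accepted by `plugin_loader.get`, as in the `"treatments.temporal.regression.crn_regressor"` string used below. The helper `flatten_plugin_names` is hypothetical and not part of TemporAI. It works on a copy of the printed dictionary and assumes this naming pattern holds for every plugin.

```python
# Copy of the dictionary printed above for plugin_loader.list()["treatments"].
all_treatments_plugins = {
    "one_off": {"regression": ["synctwin_regressor"]},
    "temporal": {"classification": ["crn_classifier"], "regression": ["crn_regressor"]},
}


def flatten_plugin_names(tree, prefix):
    """Hypothetical helper: recursively join nested keys and plugin names with dots."""
    names = []
    for key, value in tree.items():
        path = f"{prefix}.{key}"
        if isinstance(value, dict):
            names.extend(flatten_plugin_names(value, path))
        else:  # leaf level: a list of plugin names
            names.extend(f"{path}.{name}" for name in value)
    return names


names = flatten_plugin_names(all_treatments_plugins, "treatments")
for name in names:
    print(name)
```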

Next, load the data source we will use:

[ ]:
DummyTemporalTreatmentEffectsDataSource = plugin_loader.get_class(
    "treatments.temporal.dummy_treatments", plugin_type="datasource"
)

Using a temporal treatment effects plugin

In this setting, both the treatments and the outcomes are time series.

[ ]:
from tempor import plugin_loader

dataset = DummyTemporalTreatmentEffectsDataSource(
    random_state=42,
    temporal_covariates_missing_prob=0.0,
    temporal_treatments_n_features=1,
    temporal_treatments_n_categories=2,
).load()
print(dataset)

model = plugin_loader.get("treatments.temporal.regression.crn_regressor", epochs=20)
print(model)
TemporalTreatmentEffectsDataset(
    time_series=TimeSeriesSamples([100, *, 5]),
    static=StaticSamples([100, 3]),
    predictive=TemporalTreatmentEffectsTaskData(
        targets=TimeSeriesSamples([100, *, 2]),
        treatments=TimeSeriesSamples([100, *, 1])
    )
)
CRNTreatmentsRegressor(
    name='crn_regressor',
    category='treatments.temporal.regression',
    plugin_type='method',
    params={
        'encoder_rnn_type': 'LSTM',
        'encoder_hidden_size': 100,
        'encoder_num_layers': 1,
        'encoder_bias': True,
        'encoder_dropout': 0.0,
        'encoder_bidirectional': False,
        'encoder_nonlinearity': None,
        'encoder_proj_size': None,
        'decoder_rnn_type': 'LSTM',
        'decoder_hidden_size': 100,
        'decoder_num_layers': 1,
        'decoder_bias': True,
        'decoder_dropout': 0.0,
        'decoder_bidirectional': False,
        'decoder_nonlinearity': None,
        'decoder_proj_size': None,
        'adapter_hidden_dims': [50],
        'adapter_out_activation': 'Tanh',
        'predictor_hidden_dims': [],
        'predictor_out_activation': None,
        'max_len': None,
        'optimizer_str': 'Adam',
        'optimizer_kwargs': {'lr': 0.01, 'weight_decay': 1e-05},
        'batch_size': 32,
        'epochs': 20,
        'padding_indicator': -999.0
    }
)
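The shapes printed above use `*` for the time dimension because samples have different numbers of time steps. The targets table below has 1573 rows for 100 samples, which cannot correspond to one fixed length per sample.

Recurrent models usually batch such data by padding every sample to a common length. The `padding_indicator` of -999.0 in the parameters above suggests that CRN marks padded positions with a sentinel value. The sketch below shows generic sentinel padding with NumPy. It illustrates the idea only and is not TemporAI's internal code; `pad_samples` and the example sample lengths are hypothetical.

```python
import numpy as np

PADDING = -999.0  # same value as the padding_indicator parameter printed above


def pad_samples(samples, padding=PADDING):
    """Stack variable-length (n_timesteps_i, n_features) arrays into one
    (n_samples, max_timesteps, n_features) array, filling unused steps with `padding`."""
    n_features = samples[0].shape[1]
    max_len = max(s.shape[0] for s in samples)
    out = np.full((len(samples), max_len, n_features), padding)
    for i, s in enumerate(samples):
        out[i, : s.shape[0], :] = s
    return out


rng = np.random.default_rng(0)
# Three hypothetical samples with 11, 12 and 9 time steps and 2 features each.
samples = [rng.normal(size=(n, 2)) for n in (11, 12, 9)]
padded = pad_samples(samples)
print(padded.shape)  # (3, 12, 2)

# Boolean mask of real (non-padded) time steps, recovered from the sentinel.
mask = padded[..., 0] != PADDING
print(mask.sum(axis=1))
```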
[ ]:
# Targets:
dataset.predictive.targets

TimeSeriesSamples with data:

                              0          1
sample_idx time_idx
0          0          -3.110475  -3.566948
           1           1.528495  -0.653673
           2           2.275307  -0.695371
           3           4.844060   3.469371
           4           4.420301   5.147500
...        ...              ...        ...
99         7           5.994185   6.225290
           8          10.913662   5.346697
           9           9.558824   7.585175
           10         10.194430   5.795619
           11         13.774189   8.457336

1573 rows × 2 columns
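Each row of the table is one time step of one sample, indexed by the pair (`sample_idx`, `time_idx`). The `...` row hides the rows in between.

To recover one sequence per sample from long-format rows like these, group by `sample_idx` and order by `time_idx`. The sketch below does this in plain Python on a few rows copied from the table. `group_by_sample` is a hypothetical helper, not a TemporAI function. The library's own containers already give per-sample access, as the iteration over `dataset.time_series` later in this tutorial shows.

```python
from collections import defaultdict

# (sample_idx, time_idx, target 0, target 1) rows copied from the table above.
rows = [
    (0, 0, -3.110475, -3.566948),
    (0, 1, 1.528495, -0.653673),
    (0, 2, 2.275307, -0.695371),
    (99, 10, 10.194430, 5.795619),
    (99, 11, 13.774189, 8.457336),
]


def group_by_sample(rows):
    """Turn long-format rows into {sample_idx: [(time_idx, [values...]), ...]}, sorted by time."""
    grouped = defaultdict(list)
    for sample_idx, time_idx, *values in rows:
        grouped[sample_idx].append((time_idx, values))
    return {idx: sorted(steps, key=lambda step: step[0]) for idx, steps in grouped.items()}


series = group_by_sample(rows)
for sample_idx, steps in series.items():
    print(f"sample {sample_idx}: time steps {[t for t, _ in steps]}")
```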

[ ]:
# Treatments:
dataset.predictive.treatments

TimeSeriesSamples with data:

                       0
sample_idx time_idx
0          0           0
           1           1
           2           1
           3           0
           4           0
...        ...       ...
99         7           1
           8           1
           9           0
           10          0
           11          1

1573 rows × 1 columns

[ ]:
# Train.
model.fit(dataset);
Preparing data for decoder training...
Preparing data for decoder training DONE.
=== Training stage: 1. Train encoder ===
Epoch: 0, Prediction Loss: 75.212, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 75.212
Epoch: 1, Prediction Loss: 32.453, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 32.453
Epoch: 2, Prediction Loss: 19.389, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 19.389
Epoch: 3, Prediction Loss: 19.679, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 19.679
Epoch: 4, Prediction Loss: 19.588, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 19.588
Epoch: 5, Prediction Loss: 14.915, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 14.915
Epoch: 6, Prediction Loss: 10.608, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 10.608
Epoch: 7, Prediction Loss: 8.684, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 8.684
Epoch: 8, Prediction Loss: 6.953, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 6.953
Epoch: 9, Prediction Loss: 5.645, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 5.645
Epoch: 10, Prediction Loss: 5.060, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 5.060
Epoch: 11, Prediction Loss: 4.658, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.658
Epoch: 12, Prediction Loss: 4.398, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.398
Epoch: 13, Prediction Loss: 4.265, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.265
Epoch: 14, Prediction Loss: 4.170, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.170
Epoch: 15, Prediction Loss: 4.054, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 4.054
Epoch: 16, Prediction Loss: 3.949, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.949
Epoch: 17, Prediction Loss: 3.940, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.940
Epoch: 18, Prediction Loss: 3.918, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.918
Epoch: 19, Prediction Loss: 3.851, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.851
=== Training stage: 2. Train decoder ===
Epoch: 0, Prediction Loss: 34.622, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 34.622
Epoch: 1, Prediction Loss: 5.329, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 5.329
Epoch: 2, Prediction Loss: 3.826, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.826
Epoch: 3, Prediction Loss: 3.767, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.767
Epoch: 4, Prediction Loss: 3.746, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.746
Epoch: 5, Prediction Loss: 3.741, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.741
Epoch: 6, Prediction Loss: 3.726, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.726
Epoch: 7, Prediction Loss: 3.694, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.694
Epoch: 8, Prediction Loss: 3.723, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.723
Epoch: 9, Prediction Loss: 3.730, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.730
Epoch: 10, Prediction Loss: 3.670, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.670
Epoch: 11, Prediction Loss: 3.691, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.691
Epoch: 12, Prediction Loss: 3.726, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.726
Epoch: 13, Prediction Loss: 3.701, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.701
Epoch: 14, Prediction Loss: 3.696, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.696
Epoch: 15, Prediction Loss: 3.668, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.668
Epoch: 16, Prediction Loss: 3.612, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.612
Epoch: 17, Prediction Loss: 3.718, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.718
Epoch: 18, Prediction Loss: 3.717, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.717
Epoch: 19, Prediction Loss: 3.658, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.658
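Both training stages print one line per epoch. In this run the Treatment BR Loss stays at 0.000, so Loss equals Prediction Loss throughout.

To check convergence programmatically, you can parse the log text. The sketch below assumes the exact line format shown above, which may change because the treatments API is preliminary. It uses a few lines copied from the log, and `parse_losses` is a hypothetical helper.

```python
import re

# A few lines copied from the training log above.
log = """\
=== Training stage: 1. Train encoder ===
Epoch: 0, Prediction Loss: 75.212, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 75.212
Epoch: 19, Prediction Loss: 3.851, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.851
=== Training stage: 2. Train decoder ===
Epoch: 0, Prediction Loss: 34.622, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 34.622
Epoch: 19, Prediction Loss: 3.658, Lambda: 1.000, Treatment BR Loss: 0.000, Loss: 3.658
"""

STAGE_RE = re.compile(r"^=== Training stage: (.+) ===$")
# Greedy ".*" makes ", Loss: " match the final (total) loss field on the line.
EPOCH_RE = re.compile(r"^Epoch: (\d+),.*, Loss: ([\d.]+)$")


def parse_losses(text):
    """Return {stage_name: [(epoch, total_loss), ...]} from the printed training log."""
    losses, stage = {}, None
    for line in text.splitlines():
        if m := STAGE_RE.match(line):
            stage = m.group(1)
            losses[stage] = []
        elif (m := EPOCH_RE.match(line)) and stage is not None:
            losses[stage].append((int(m.group(1)), float(m.group(2))))
    return losses


losses = parse_losses(log)
for stage, values in losses.items():
    print(f"{stage}: loss {values[0][1]:.3f} -> {values[-1][1]:.3f}")
```

For this run, the encoder loss falls from 75.212 to 3.851 and the decoder loss from 34.622 to 3.658.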
[ ]:
# Predict counterfactuals:

import numpy as np

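# Use only the first 5 samples.
dataset = dataset[:5]

# Define a horizon for each sample: the second half of its time steps.
horizons = [tc.time_indexes()[0][len(tc.time_indexes()[0]) // 2 :] for tc in dataset.time_series]
print("Horizons for sample 0:\n", horizons[0], end="\n\n")

# Define two alternative treatment scenarios for each sample, with one value per horizon step:
# treated at every step (all 1s) and never treated (all 0s).
treatment_scenarios = [[np.asarray([1] * len(h)), np.asarray([0] * len(h))] for h in horizons]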
print("Alternative treatment scenarios for sample 0:\n", treatment_scenarios[0], end="\n\n")

# Call predict_counterfactuals.
counterfactuals = model.predict_counterfactuals(dataset, horizons=horizons, treatment_scenarios=treatment_scenarios)
print("Counterfactual outcomes for sample 0, given the alternative treatment scenarios:\n")
for idx, c in enumerate(counterfactuals[0]):
    print(f"Treatment scenario {idx}, {treatment_scenarios[0][idx]}")
    print(c, end="\n\n")
Horizons for sample 0:
 [5, 6, 7, 8, 9, 10]

Alternative treatment scenarios for sample 0:
 [array([1, 1, 1, 1, 1, 1]), array([0, 0, 0, 0, 0, 0])]

Counterfactual outcomes for sample 0, given the alternative treatment scenarios:

Treatment scenario 0, [1 1 1 1 1 1]
TimeSeries() with data:
                 0         1
time_idx
5         5.966007  4.997900
6         6.120805  5.138746
7         6.138720  5.155504
8         6.140833  5.157483
9         6.141082  5.157716
10        6.141112  5.157744

Treatment scenario 1, [0 0 0 0 0 0]
TimeSeries() with data:
                 0         1
time_idx
5         6.303439  4.996541
6         6.467038  5.143897
7         6.486365  5.162077
8         6.488708  5.164285
9         6.488993  5.164552
10        6.489028  5.164585
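You can compare the two counterfactual trajectories directly. Subtracting the never-treated prediction from the always-treated one gives the model's estimated treatment effect on each target at each horizon step.

The sketch copies the printed values for sample 0 by hand instead of extracting them from the returned `TimeSeries` objects. These numbers are estimates from a model trained on dummy data, not ground-truth effects.

```python
import numpy as np

# Counterfactual outcomes for sample 0, copied from the printed output above.
# Rows: time_idx 5..10; columns: target features 0 and 1.
always_treated = np.array([
    [5.966007, 4.997900],
    [6.120805, 5.138746],
    [6.138720, 5.155504],
    [6.140833, 5.157483],
    [6.141082, 5.157716],
    [6.141112, 5.157744],
])
never_treated = np.array([
    [6.303439, 4.996541],
    [6.467038, 5.143897],
    [6.486365, 5.162077],
    [6.488708, 5.164285],
    [6.488993, 5.164552],
    [6.489028, 5.164585],
])

# Estimated per-step effect of "always treat" versus "never treat".
effect = always_treated - never_treated
for time_idx, row in zip(range(5, 11), effect):
    print(time_idx, np.round(row, 4))
```

For sample 0, the estimated effect on target 0 is roughly -0.34 at every step. The effect on target 1 stays within 0.01 of zero.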