
Extending TemporAI Tutorial 01: Writing a Custom Method Plugin

This tutorial shows how to extend TemporAI by writing a custom method plugin (a "method" here meaning an algorithm or model).

Writing a Custom Plugin 101

In order to write a custom plugin for TemporAI, you need to do the following:

1. Inherit from the appropriate base class for the category of plugin you are writing.
2. Implement the methods (as in, functions of the class) that the plugin needs.
3. Register the plugin with TemporAI.

We will go through an example in this tutorial.

1. Inherit from the appropriate base class for the category of the method plugin you are writing.

You need to find which category of method plugin you are writing.

A summary of different plugin categories is available in the README.

You can also list all the plugin categories programmatically:

from tempor import plugin_loader

plugin_categories = plugin_loader.list_categories(plugin_type="method")

list(plugin_categories.keys())
['prediction.one_off.classification',
 'prediction.one_off.regression',
 'prediction.temporal.classification',
 'prediction.temporal.regression',
 'preprocessing.encoding.static',
 'preprocessing.encoding.temporal',
 'preprocessing.imputation.static',
 'preprocessing.imputation.temporal',
 'preprocessing.nop',
 'preprocessing.scaling.static',
 'preprocessing.scaling.temporal',
 'time_to_event',
 'treatments.one_off.regression',
 'treatments.temporal.classification',
 'treatments.temporal.regression']

Remember that you can also list the existing method plugins and see which category each one belongs to:

from rich.pretty import pprint  # For prettifying the print output only.

all_plugins = plugin_loader.list(plugin_type="method")

pprint(all_plugins, indent_guides=True)
{
│   'prediction': {
│   │   'one_off': {
│   │   │   'classification': ['cde_classifier', 'ode_classifier', 'nn_classifier', 'laplace_ode_classifier'],
│   │   │   'regression': ['laplace_ode_regressor', 'nn_regressor', 'ode_regressor', 'cde_regressor']
│   │   },
│   │   'temporal': {'classification': ['seq2seq_classifier'], 'regression': ['seq2seq_regressor']}
│   },
│   'preprocessing': {
│   │   'encoding': {'static': ['static_onehot_encoder'], 'temporal': ['ts_onehot_encoder']},
│   │   'imputation': {
│   │   │   'static': ['static_tabular_imputer'],
│   │   │   'temporal': ['ffill', 'ts_tabular_imputer', 'bfill']
│   │   },
│   │   'nop': ['nop_transformer'],
│   │   'scaling': {
│   │   │   'static': ['static_minmax_scaler', 'static_standard_scaler'],
│   │   │   'temporal': ['ts_minmax_scaler', 'ts_standard_scaler']
│   │   }
│   },
│   'time_to_event': ['ts_coxph', 'ts_xgb', 'dynamic_deephit'],
│   'treatments': {
│   │   'one_off': {'regression': ['synctwin_regressor']},
│   │   'temporal': {'classification': ['crn_classifier'], 'regression': ['crn_regressor']}
│   }
}
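The nested dictionary mirrors the dotted category names listed earlier: each dot-separated part of a category such as "prediction.one_off.classification" is one level of nesting. As a rough illustration, the hypothetical helper `plugins_in_category` below (not part of TemporAI) walks a plain dict with that shape, here a trimmed copy of the output above, so it runs without TemporAI installed.

```python
from typing import Any, Dict, List, Union

NestedPlugins = Dict[str, Any]

# Trimmed copy of the nested structure printed above (plain data only).
example_plugins: NestedPlugins = {
    "prediction": {
        "one_off": {
            "classification": ["cde_classifier", "ode_classifier", "nn_classifier", "laplace_ode_classifier"],
            "regression": ["laplace_ode_regressor", "nn_regressor", "ode_regressor", "cde_regressor"],
        },
        "temporal": {"classification": ["seq2seq_classifier"], "regression": ["seq2seq_regressor"]},
    },
    "preprocessing": {"nop": ["nop_transformer"]},
    "time_to_event": ["ts_coxph", "ts_xgb", "dynamic_deephit"],
}


def plugins_in_category(plugins: NestedPlugins, category: str) -> List[str]:
    """Hypothetical helper: follow a dotted category name down the nested dict."""
    node: Union[NestedPlugins, List[str]] = plugins
    for part in category.split("."):
        if not isinstance(node, dict) or part not in node:
            raise KeyError(f"Unknown category: {category!r}")
        node = node[part]
    if not isinstance(node, list):
        # e.g. "prediction.one_off" still points at a dict of sub-categories.
        raise KeyError(f"{category!r} is not a leaf category")
    return node


print(plugins_in_category(example_plugins, "prediction.one_off.classification"))
print(plugins_in_category(example_plugins, "time_to_event"))
```

The membership check performed later with `all_plugins["prediction"]["one_off"]["classification"]` is the same walk, written out by hand.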

Let’s say you would like to write a plugin of category "prediction.one_off.classification".

You can find which base class you need to inherit from as follows.

plugin_categories = plugin_loader.list_categories(plugin_type="method")

print("Base classes for all categories:")
pprint(plugin_categories, indent_guides=False)

print("Base class you need:")
print(plugin_categories["prediction.one_off.classification"])
Base classes for all categories:
{
    'prediction.one_off.classification': <class 'tempor.methods.prediction.one_off.classification.BaseOneOffClassifier'>,
    'prediction.one_off.regression': <class 'tempor.methods.prediction.one_off.regression.BaseOneOffRegressor'>,
    'prediction.temporal.classification': <class 'tempor.methods.prediction.temporal.classification.BaseTemporalClassifier'>,
    'prediction.temporal.regression': <class 'tempor.methods.prediction.temporal.regression.BaseTemporalRegressor'>,
    'preprocessing.encoding.static': <class 'tempor.methods.preprocessing.encoding._base.BaseEncoder'>,
    'preprocessing.encoding.temporal': <class 'tempor.methods.preprocessing.encoding._base.BaseEncoder'>,
    'preprocessing.imputation.static': <class 'tempor.methods.preprocessing.imputation._base.BaseImputer'>,
    'preprocessing.imputation.temporal': <class 'tempor.methods.preprocessing.imputation._base.BaseImputer'>,
    'preprocessing.nop': <class 'tempor.methods.core._base_transformer.BaseTransformer'>,
    'preprocessing.scaling.static': <class 'tempor.methods.preprocessing.scaling._base.BaseScaler'>,
    'preprocessing.scaling.temporal': <class 'tempor.methods.preprocessing.scaling._base.BaseScaler'>,
    'time_to_event': <class 'tempor.methods.time_to_event.BaseTimeToEventAnalysis'>,
    'treatments.one_off.regression': <class 'tempor.methods.treatments.one_off._base.BaseOneOffTreatmentEffects'>,
    'treatments.temporal.classification': <class 'tempor.methods.treatments.temporal._base.BaseTemporalTreatmentEffects'>,
    'treatments.temporal.regression': <class 'tempor.methods.treatments.temporal._base.BaseTemporalTreatmentEffects'>
}
Base class you need:
<class 'tempor.methods.prediction.one_off.classification.BaseOneOffClassifier'>

You can then look up this class in the TemporAI source code to see its method signatures and other details.
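As a complement to reading the source, Python's standard `inspect` module and the `__abstractmethods__` attribute that `abc` computes for abstract classes can list what a subclass still has to implement. The sketch below uses a hypothetical stand-in, `ToyBaseClassifier`, so it runs without TemporAI; passing `BaseOneOffClassifier` to the same `required_methods` helper should work similarly, assuming that class is built on `abc`.

```python
import abc
import inspect
from typing import Any, Dict


class ToyBaseClassifier(abc.ABC):
    """Hypothetical stand-in for a TemporAI base class (not the real one)."""

    @abc.abstractmethod
    def _fit(self, data: Any, *args: Any, **kwargs: Any) -> "ToyBaseClassifier":
        ...

    @abc.abstractmethod
    def _predict(self, data: Any, *args: Any, **kwargs: Any) -> Any:
        ...

    def fit(self, data: Any) -> "ToyBaseClassifier":
        # Concrete (non-abstract) method, so it is not reported as required.
        return self._fit(data)


def required_methods(cls: type) -> Dict[str, str]:
    """Map each still-abstract method name on `cls` to its signature string."""
    # Classes that are not ABCs have no __abstractmethods__, hence the default.
    abstract_names = getattr(cls, "__abstractmethods__", frozenset())
    return {name: str(inspect.signature(getattr(cls, name))) for name in sorted(abstract_names)}


for name, signature in required_methods(ToyBaseClassifier).items():
    print(f"{name}{signature}")
```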

2. Implement the methods the plugin needs.

Plugins of different categories need different methods (functions) implemented, but the key methods are:

* _fit(), where you provide your implementation of fitting (training).
* _predict(), where you provide your implementation of prediction (inference).
* _transform(), where you provide your implementation of data transformation (for preprocessing plugins).

Classification-related plugins also have _predict_proba() and treatment effects plugins have _predict_counterfactuals().

Note that these methods have a leading underscore (_) and are different from the corresponding “public” methods without the underscore (e.g. fit()). When extending TemporAI, you implement the _<...> method; the corresponding “public” method is what the user of your plugin calls. The “public” methods also perform the necessary validation and other checks behind the scenes.
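This public/underscore split is the classic "template method" pattern. Below is a simplified, hypothetical sketch (`ToyBase`, `AlwaysOne` and their checks are illustrative, not TemporAI's actual validation logic) of a public `fit()`/`predict()` that runs checks and then delegates to the `_fit()`/`_predict()` hooks a plugin author supplies.

```python
from typing import Any, List


class ToyBase:
    """Hypothetical base class: public methods validate, then call the private hooks."""

    def __init__(self) -> None:
        self.is_fitted = False

    def fit(self, data: List[float]) -> "ToyBase":
        if len(data) == 0:  # Example check, written once in the base class.
            raise ValueError("Cannot fit on empty data.")
        self._fit(data)
        self.is_fitted = True
        return self

    def predict(self, data: List[float]) -> List[Any]:
        if not self.is_fitted:
            raise RuntimeError("Call fit() before predict().")
        return self._predict(data)

    # Hooks that subclasses (plugins) override.
    def _fit(self, data: List[float]) -> None:
        raise NotImplementedError

    def _predict(self, data: List[float]) -> List[Any]:
        raise NotImplementedError


class AlwaysOne(ToyBase):
    """A 'plugin' implements only the underscore methods."""

    def _fit(self, data: List[float]) -> None:
        pass  # Nothing to learn.

    def _predict(self, data: List[float]) -> List[Any]:
        return [1 for _ in data]


toy_model = AlwaysOne()
# The user only ever calls the public fit()/predict(); the checks run automatically.
print(toy_model.fit([0.1, 0.2, 0.3]).predict([0.5, 0.6]))
```

The benefit of this design is that every plugin gets the same checks without re-implementing them.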

If you haven’t implemented a required method, Python will raise a TypeError when you attempt to instantiate your plugin (see the Python abc module documentation: https://docs.python.org/3/library/abc.html).
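The `abc` behaviour can be seen in isolation with a minimal standard-library sketch. `ToyBasePlugin`, `IncompletePlugin` and `CompletePlugin` are hypothetical names; the point is that the `TypeError` is raised when the class is instantiated, not later when the missing method would first be called.

```python
import abc


class ToyBasePlugin(abc.ABC):
    """Hypothetical base class declaring the methods a plugin must implement."""

    @abc.abstractmethod
    def _fit(self, data):
        ...

    @abc.abstractmethod
    def _predict(self, data):
        ...


class IncompletePlugin(ToyBasePlugin):
    def _fit(self, data):  # _predict is missing.
        return self


class CompletePlugin(IncompletePlugin):
    def _predict(self, data):
        return [1] * len(data)


try:
    IncompletePlugin()
except TypeError as err:
    # The message names the missing method(s); the exact wording varies by Python version.
    print(f"TypeError: {err}")

print(CompletePlugin()._predict([10, 20]))
```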

In our example case, you will need to implement the following methods for BaseOneOffClassifier:

from typing import Any, List

from tempor.data import dataset, samples
from tempor.methods.core.params import Params
from tempor.methods.prediction.one_off.classification import BaseOneOffClassifier

class MyPlugin(BaseOneOffClassifier):
    # The initializer:
    def __init__(self, **params: Any) -> None:
        super().__init__(**params)

    # The _fit implementation (training); it should return `self`.
    def _fit(self, data: dataset.BaseDataset, *args: Any, **kwargs: Any) -> "MyPlugin":
        ...

    # The _predict implementation (class predictions).
    def _predict(self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any) -> samples.StaticSamplesBase:
        ...

    # The _predict_proba implementation (class probabilities).
    def _predict_proba(self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any) -> samples.StaticSamplesBase:
        ...

    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]:
        # This method is not currently used in TemporAI (it will be used once the AutoML component is implemented).
        # For now, you may just return an empty list.
        return []

3. Register the plugin with TemporAI.

Registering your plugin with TemporAI is simple: apply the register_plugin decorator to your plugin class, as shown in the example below.

You will need to specify the name of your plugin and its category in the decorator.

Note: You may omit plugin_type="method" below, as "method" is the default plugin type.

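from tempor.core.plugins import register_plugin
from tempor.methods.prediction.one_off.classification import BaseOneOffClassifier

@register_plugin(name="my_plugin", category="prediction.one_off.classification", plugin_type="method")
class MyPlugin(BaseOneOffClassifier):
    ...  # Method implementations as in the previous step.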

Note on __init__ parameters (arguments)

You also need to declare the input parameters (arguments) that will be passed to your plugin’s __init__, as follows:

import dataclasses
from typing import Any

from tempor.methods.prediction.one_off.classification import BaseOneOffClassifier

# 1. Write a dataclass with your __init__ parameters:
@dataclasses.dataclass
class MyPluginParams:
    # Specify each parameter's name, data type and default value as below:
    lr: float = 0.001
    batch_size: int = 100

class MyPlugin(BaseOneOffClassifier):
    # 2. Set the `ParamsDefinition` class variable in your plugin to this dataclass.
    ParamsDefinition = MyPluginParams

    def __init__(self, **params: Any) -> None:
        # 3. Call the parent __init__ like so.
        super().__init__(**params)

        # 4. You can now access the parameters in your class like so:
        print(self.params.lr)
        print(self.params.batch_size)

    # (_fit, _predict, _predict_proba and hyperparameter_space are omitted here for brevity;
    # they must be implemented before the class can be instantiated.)


# 5. The user can then override the arguments as needed when initializing your plugin:
model = MyPlugin(batch_size=22)
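One plausible way for a base class to turn `**params` into `self.params` is to pass them through the `ParamsDefinition` dataclass, which fills in defaults and rejects unknown names. This is an assumption about the mechanism, sketched with hypothetical classes (`ToyParamsBase`, `ToyParamsPlugin`); TemporAI's real base class may validate differently.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class ToyPluginParams:
    lr: float = 0.001
    batch_size: int = 100


class ToyParamsBase:
    """Hypothetical base: builds self.params from the subclass's ParamsDefinition."""

    ParamsDefinition: Any = None

    def __init__(self, **params: Any) -> None:
        # The dataclass constructor applies defaults and raises TypeError on unexpected names.
        self.params = self.ParamsDefinition(**params)


class ToyParamsPlugin(ToyParamsBase):
    ParamsDefinition = ToyPluginParams


toy = ToyParamsPlugin(batch_size=22)
print(toy.params)  # lr keeps its default; batch_size is overridden.
print(dataclasses.asdict(toy.params))  # Plain-dict view of the parameters.
```

Note that a plain dataclass does not check value types by itself, so `ToyParamsPlugin(batch_size="22")` would be accepted by this sketch.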

Putting it all together

Now let’s put this all together in an example: a one-off classifier plugin that always predicts 1.

import dataclasses
from typing import Any, List

import numpy as np

from tempor.core.plugins import register_plugin
from tempor.methods.core.params import Params
from tempor.data import dataset, samples
from tempor.methods.prediction.one_off.classification import BaseOneOffClassifier


@dataclasses.dataclass
class MyClassifierParams:
    some_parameter: int = 1
    other_parameter: float = 0.5


@register_plugin(name="my_classifier", category="prediction.one_off.classification", plugin_type="method")
class MyClassifierClassifier(BaseOneOffClassifier):
    ParamsDefinition = MyClassifierParams

    def __init__(self, **params: Any) -> None:
        super().__init__(**params)

    def _fit(self, data: dataset.BaseDataset, *args, **kwargs):
        """Does nothing."""
        return self  # Fit method needs to return `self`.

    def _predict(self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any) -> samples.StaticSamplesBase:
        """Always returns 1"""

        assert data.predictive.targets is not None
        preds = np.ones_like(data.predictive.targets.numpy())

        return samples.StaticSamples.from_numpy(preds, dtype=int)

    def _predict_proba(self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any) -> samples.StaticSamplesBase:
        """Always returns 1.0"""

        assert data.predictive.targets is not None
        preds = np.ones_like(data.predictive.targets.numpy())

        return samples.StaticSamples.from_numpy(preds, dtype=float)

    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]:
        return []
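Both prediction methods rely on `np.ones_like`, which returns an array of ones with the same shape as its input. The NumPy-only sketch below uses a made-up `targets` array of shape `(n_samples, 1)`; that this matches what `data.predictive.targets.numpy()` returns for a one-off target is an assumption. It shows the values `_predict` and `_predict_proba` end up wrapping.

```python
import numpy as np

# Hypothetical stand-in for data.predictive.targets.numpy(): 5 samples, 1 binary target column.
targets = np.array([[0], [1], [1], [0], [1]])

# ones_like returns ones with the same shape (and, by default, the same dtype) as its input.
preds = np.ones_like(targets)

# Roughly the conversions requested via dtype=int / dtype=float in _predict / _predict_proba.
class_preds = preds.astype(int)
proba_preds = preds.astype(float)

print(class_preds.shape)    # One prediction per sample and target column.
print(proba_preds.ravel())  # Every "probability" is 1.0.
```

Because the output shape is taken from the targets, this toy plugin needs targets to be present at prediction time, which is why the example asserts `data.predictive.targets is not None`.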

Our plugin is now registered and listed in TemporAI:

from tempor import plugin_loader

all_plugins = plugin_loader.list(plugin_type="method")

pprint(all_plugins, indent_guides=True)

my_classifier_found = "my_classifier" in all_plugins["prediction"]["one_off"]["classification"]
print(f"`my_classifier` plugin found in the category 'prediction.one_off.classification': {my_classifier_found}")
assert my_classifier_found
{
│   'prediction': {
│   │   'one_off': {
│   │   │   'classification': [
│   │   │   │   'cde_classifier',
│   │   │   │   'ode_classifier',
│   │   │   │   'nn_classifier',
│   │   │   │   'laplace_ode_classifier',
│   │   │   │   'my_classifier'
│   │   │   ],
│   │   │   'regression': ['laplace_ode_regressor', 'nn_regressor', 'ode_regressor', 'cde_regressor']
│   │   },
│   │   'temporal': {'classification': ['seq2seq_classifier'], 'regression': ['seq2seq_regressor']}
│   },
│   'preprocessing': {
│   │   'encoding': {'static': ['static_onehot_encoder'], 'temporal': ['ts_onehot_encoder']},
│   │   'imputation': {
│   │   │   'static': ['static_tabular_imputer'],
│   │   │   'temporal': ['ffill', 'ts_tabular_imputer', 'bfill']
│   │   },
│   │   'nop': ['nop_transformer'],
│   │   'scaling': {
│   │   │   'static': ['static_minmax_scaler', 'static_standard_scaler'],
│   │   │   'temporal': ['ts_minmax_scaler', 'ts_standard_scaler']
│   │   }
│   },
│   'time_to_event': ['ts_coxph', 'ts_xgb', 'dynamic_deephit'],
│   'treatments': {
│   │   'one_off': {'regression': ['synctwin_regressor']},
│   │   'temporal': {'classification': ['crn_classifier'], 'regression': ['crn_regressor']}
│   }
}
`my_classifier` plugin found in the category 'prediction.one_off.classification': True

The plugin can now be used like any other TemporAI plugin.

# Get the plugin.

my_classifier = plugin_loader.get("prediction.one_off.classification.my_classifier", plugin_type="method")

print(my_classifier)
MyClassifierClassifier(
    name='my_classifier',
    category='prediction.one_off.classification',
    plugin_type='method',
    params={'some_parameter': 1, 'other_parameter': 0.5}
)
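# Fit and predict on some data.
# The variable is named `data` (not `dataset`) so it does not shadow the `tempor.data.dataset` module imported above.
data = plugin_loader.get("prediction.one_off.sine", plugin_type="datasource", random_state=42).load()

my_classifier.fit(data)

print("Prediction:")
my_classifier.predict(data)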
Prediction:

StaticSamples with data:

            feat_0
sample_idx
0                1
1                1
2                1
3                1
4                1
...            ...
95               1
96               1
97               1
98               1
99               1

100 rows × 1 columns