
Extending TemporAI Tutorial 03: Writing a Custom Data Source Plugin

This tutorial shows how to extend TemporAI by writing a custom data source plugin.

Note

See also the “Writing a Custom Plugin 101” section in the “Writing a Custom Method Plugin” tutorial.

Inherit from the appropriate base class for the category of the data source plugin you are writing.

First, determine which category your data source plugin belongs to.

You can list all the data source plugin categories as follows:

[ ]:
from tempor import plugin_loader

plugin_categories = plugin_loader.list_categories(plugin_type="datasource")

list(plugin_categories.keys())
['prediction.one_off',
 'prediction.temporal',
 'time_to_event',
 'treatments.one_off',
 'treatments.temporal']

Remember that you can also list the existing data source plugins and see which category each one belongs to:

[ ]:
all_plugins = plugin_loader.list(plugin_type="datasource")

from rich.pretty import pprint  # For prettifying the print output only.

pprint(all_plugins, indent_guides=True)
{
    'prediction': {'one_off': ['sine', 'google_stocks'], 'temporal': ['uci_diabetes', 'dummy_prediction']},
    'time_to_event': ['pbc'],
    'treatments': {'one_off': ['pkpd'], 'temporal': ['dummy_treatments']}
}
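
The dotted category names (such as "prediction.one_off") correspond to nested keys in this listing. As a minimal sketch, the hypothetical helper below (plugins_in_category is not part of TemporAI) walks a copy of the listing shown above one dotted part at a time:

```python
from typing import Any, Dict, List

# Listing copied from the output above, so the sketch runs without TemporAI.
all_plugins: Dict[str, Any] = {
    "prediction": {
        "one_off": ["sine", "google_stocks"],
        "temporal": ["uci_diabetes", "dummy_prediction"],
    },
    "time_to_event": ["pbc"],
    "treatments": {"one_off": ["pkpd"], "temporal": ["dummy_treatments"]},
}


def plugins_in_category(listing: Dict[str, Any], category: str) -> List[str]:
    """Hypothetical helper: follow each dotted part of `category` into the listing."""
    node: Any = listing
    for part in category.split("."):
        node = node[part]
    return list(node)


one_off_prediction = plugins_in_category(all_plugins, "prediction.one_off")
time_to_event = plugins_in_category(all_plugins, "time_to_event")
print(one_off_prediction)
print(time_to_event)
```

Categories without a dot, such as "time_to_event", map directly to a top-level list, which is why the helper handles both depths with the same loop.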

Let’s say you would like to write a plugin of category "prediction.one_off".

You can find which base class you need to inherit from as follows.

[ ]:
plugin_categories = plugin_loader.list_categories(plugin_type="datasource")

print("Base classes for all categories:")
pprint(plugin_categories, indent_guides=False)

print("Base class you need:")
print(plugin_categories["prediction.one_off"])
Base classes for all categories:
{
    'prediction.one_off': <class 'tempor.datasources.datasource.OneOffPredictionDataSource'>,
    'prediction.temporal': <class 'tempor.datasources.datasource.TemporalPredictionDataSource'>,
    'time_to_event': <class 'tempor.datasources.datasource.TimeToEventAnalysisDataSource'>,
    'treatments.one_off': <class 'tempor.datasources.datasource.OneOffTreatmentEffectsDataSource'>,
    'treatments.temporal': <class 'tempor.datasources.datasource.TemporalTreatmentEffectsDataSource'>
}
Base class you need:
<class 'tempor.datasources.datasource.OneOffPredictionDataSource'>

You can then look up this class in the TemporAI source code to see its method signatures and other details.
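
As an alternative to reading the source files, the method signatures can also be inspected from Python. The sketch below uses only the standard library; abstract_method_signatures is a hypothetical helper, and StandInDataSource is a made-up stand-in class (not TemporAI's own) so that the example runs without TemporAI installed.

```python
import abc
import inspect
from typing import Dict, Optional


def abstract_method_signatures(cls: type) -> Dict[str, str]:
    """Hypothetical helper: map each abstract member of `cls` to its signature."""
    result: Dict[str, str] = {}
    for name in sorted(getattr(cls, "__abstractmethods__", ())):
        member = getattr(cls, name)
        try:
            result[name] = str(inspect.signature(member))
        except (TypeError, ValueError):
            # Properties and similar members have no call signature.
            result[name] = "<not a plain callable>"
    return result


# Made-up stand-in for a data source base class (an assumption for illustration).
class StandInDataSource(abc.ABC):
    @abc.abstractmethod
    def load(self) -> object:
        ...

    @abc.abstractmethod
    def url(self) -> Optional[str]:
        ...


signatures = abstract_method_signatures(StandInDataSource)
for name, signature in signatures.items():
    print(f"{name}{signature}")
```

With TemporAI installed, passing plugin_categories["prediction.one_off"] from the cell above in place of StandInDataSource should list that base class's abstract members, and inspect.getsource() on the same class prints its source.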

Implement the methods the plugin needs.

DataSource plugins require the following methods to be implemented:

* load(), which returns the appropriate DataSet.
* dataset_dir(), which returns a string with the subdirectory where any data files will be stored. If there are no data files, return None.
* url(), which returns the data URL if relevant. If not applicable, return None.

The initializer __init__() can take keyword arguments related to initialization of the dataset, e.g. number of samples, random seed, etc.

If you have not implemented a required method, Python will notify you by raising an exception when you attempt to instantiate your plugin (see the Python abc module documentation: https://docs.python.org/3/library/abc.html).
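
The behaviour described above comes from Python's abc module and can be reproduced without TemporAI. In this minimal sketch, BaseSource and IncompleteSource are made-up stand-ins (assumptions for illustration, not TemporAI classes); the subclass deliberately omits url(), so instantiating it raises a TypeError.

```python
import abc
from typing import Optional


# Made-up stand-in for a data source base class (not TemporAI's own).
class BaseSource(abc.ABC):
    @abc.abstractmethod
    def load(self) -> list:
        ...

    @abc.abstractmethod
    def url(self) -> Optional[str]:
        ...


class IncompleteSource(BaseSource):
    # `url()` is deliberately left unimplemented.
    def load(self) -> list:
        return [1, 2, 3]


try:
    IncompleteSource()
    error_message = None
except TypeError as exc:
    # The message names the missing abstract method(s).
    error_message = str(exc)

print(error_message)
```

The exact wording of the message varies between Python versions, but it names the methods that still need to be implemented.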

Register the plugin with TemporAI.

Registering your plugin with TemporAI is simple: apply the register_plugin decorator to your plugin class, as shown in the example below.

You will need to specify the name of your plugin and its category in the decorator.

The plugin_type needs to be set to "datasource".

from tempor.core.plugins import register_plugin
from tempor.datasources.datasource import OneOffPredictionDataSource

@register_plugin(name="my_plugin", category="prediction.one_off", plugin_type="datasource")
class MyPlugin(OneOffPredictionDataSource):
    ...

Example

The following minimal example puts these steps together.

[ ]:
from typing import Any

import numpy as np

from tempor.data.dataset import OneOffPredictionDataset
from tempor.core.plugins import register_plugin
from tempor.datasources.datasource import OneOffPredictionDataSource


@register_plugin(name="my_datasource", category="prediction.one_off", plugin_type="datasource")
class MyDataSource(OneOffPredictionDataSource):
    def __init__(self, random_seed: int = 123, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.random_seed = random_seed

    def url(self):
        return None

    def dataset_dir(self):
        return None

    def load(self) -> OneOffPredictionDataset:
        np.random.seed(self.random_seed)
        return OneOffPredictionDataset(
            time_series=np.random.normal(size=(100, 30, 10)),
            targets=np.random.normal(size=(100, 1)),
        )

The plugin now appears in the TemporAI plugin listing:

[ ]:
from tempor import plugin_loader

all_plugins = plugin_loader.list(plugin_type="datasource")

pprint(all_plugins, indent_guides=True)

my_datasource_found = "my_datasource" in all_plugins["prediction"]["one_off"]
print(f"`my_datasource` plugin found in the category 'prediction.one_off': {my_datasource_found}")
assert my_datasource_found
{
    'prediction': {
        'one_off': ['sine', 'google_stocks', 'my_datasource'],
        'temporal': ['uci_diabetes', 'dummy_prediction']
    },
    'time_to_event': ['pbc'],
    'treatments': {'one_off': ['pkpd'], 'temporal': ['dummy_treatments']}
}
`my_datasource` plugin found in the category 'prediction.one_off': True

The plugin can now be used like any other data source plugin.

[ ]:
# Get the plugin.

my_datasource = plugin_loader.get("prediction.one_off.my_datasource", plugin_type="datasource")

print(my_datasource)
<__main__.MyDataSource object at 0x7f5f36a6cbb0>
[ ]:
# Load data.

dataset = my_datasource.load()

dataset
OneOffPredictionDataset(
    time_series=TimeSeriesSamples([100, *, 10]),
    predictive=OneOffPredictionTaskData(targets=StaticSamples([100, 1]))
)
[ ]:
# Preview covariates.

dataset.time_series

TimeSeriesSamples with data:

sample_idx  time_idx     feat_0     feat_1     feat_2     feat_3     feat_4     feat_5     feat_6     feat_7     feat_8     feat_9
0                  0  -1.085631   0.997345   0.282978  -1.506295  -0.578600   1.651437  -2.426679  -0.428913   1.265936  -0.866740
                   1  -0.678886  -0.094709   1.491390  -0.638902  -0.443982  -0.434351   2.205930   2.186786   1.004054   0.386186
                   2   0.737369   1.490732  -0.935834   1.175829  -1.253881  -0.637752   0.907105  -1.428681  -0.140069  -0.861755
                   3  -0.255619  -2.798589  -1.771533  -0.699877   0.927462  -0.173636   0.002846   0.688223  -0.879536   0.283627
                   4  -0.805367  -1.727669  -0.390900   0.573806   0.338589  -0.011830   2.392365   0.412912   0.978736   2.238143
...              ...        ...        ...        ...        ...        ...        ...        ...        ...        ...        ...
99                25  -0.567276  -1.011354  -0.263128   0.281661   0.850365   0.675597   0.518956   1.458113   0.514021  -0.845099
                  26  -0.074948   2.889178  -0.055376  -1.284538  -0.215400  -0.002616  -0.406990  -0.089739   0.264811   1.060700
                  27   0.167216  -0.226127   1.517813   2.083333  -1.053875  -0.212461   1.006044  -0.253001   0.298598  -1.256375
                  28   1.212878  -1.656727   0.702245   0.047495  -0.736849  -0.050498   0.285193   0.735459  -0.384255  -0.262967
                  29   0.511946   0.672145   0.709544  -1.208061  -0.158659  -1.428280   0.430501  -1.144726  -0.473682   1.659917

3000 rows × 10 columns

[ ]:
# Preview targets.

dataset.predictive.targets

StaticSamples with data:

sample_idx     feat_0
0           -1.054170
1           -0.783011
2            1.827901
3            1.746807
4            1.328258
...               ...
95          -0.766137
96           1.112182
97           0.076831
98          -1.566442
99          -1.267637

100 rows × 1 columns