Extending TemporAI Tutorial 03: Writing a Custom Data Source Plugin
This tutorial shows how to extend TemporAI by writing a custom data source plugin.
Note
See also the “Writing a Custom Plugin 101” section in the “Writing a Custom Method Plugin” tutorial.
Inherit from the appropriate base class for the category of the data source plugin you are writing.
First, determine which category of data source plugin you are writing.
You can list all the data source plugin categories as follows:
[ ]:
from tempor import plugin_loader
plugin_categories = plugin_loader.list_categories(plugin_type="datasource")
list(plugin_categories.keys())
['prediction.one_off',
'prediction.temporal',
'time_to_event',
'treatments.one_off',
'treatments.temporal']
You can also list the existing data source plugins, grouped by the categories they belong to:
[ ]:
from rich.pretty import pprint  # For prettifying the print output only.
all_plugins = plugin_loader.list(plugin_type="datasource")
pprint(all_plugins, indent_guides=True)
{
│   'prediction': {'one_off': ['sine', 'google_stocks'], 'temporal': ['uci_diabetes', 'dummy_prediction']},
│   'time_to_event': ['pbc'],
│   'treatments': {'one_off': ['pkpd'], 'temporal': ['dummy_treatments']}
}
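The nested listing mirrors the dotted category names such as "prediction.one_off", and a single plugin is later addressed by its category plus its name (e.g. "prediction.one_off.my_datasource"). The minimal sketch below flattens the listing into those full dotted names. The helper full_plugin_names is hypothetical and not part of TemporAI; the listing is copied from the output above.

```python
from typing import Dict, List, Union

# A leaf is a list of plugin names; otherwise a dict of sub-categories.
PluginTree = Union[List[str], Dict[str, "PluginTree"]]


def full_plugin_names(tree: PluginTree, prefix: str = "") -> List[str]:
    """Hypothetical helper: flatten the nested listing into "category.name" strings."""
    if isinstance(tree, list):
        return [f"{prefix}.{name}" if prefix else name for name in tree]
    names: List[str] = []
    for key, subtree in tree.items():
        new_prefix = f"{prefix}.{key}" if prefix else key
        names.extend(full_plugin_names(subtree, new_prefix))
    return names


# Copied from the pprint output above.
listing = {
    "prediction": {"one_off": ["sine", "google_stocks"], "temporal": ["uci_diabetes", "dummy_prediction"]},
    "time_to_event": ["pbc"],
    "treatments": {"one_off": ["pkpd"], "temporal": ["dummy_treatments"]},
}

for full_name in full_plugin_names(listing):
    print(full_name)  # e.g. prediction.one_off.sine
```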
Let’s say you would like to write a plugin of category "prediction.one_off".
You can find which base class you need to inherit from as follows.
[ ]:
plugin_categories = plugin_loader.list_categories(plugin_type="datasource")
print("Base classes for all categories:")
pprint(plugin_categories, indent_guides=False)
print("Base class you need:")
print(plugin_categories["prediction.one_off"])
Base classes for all categories:
{
    'prediction.one_off': <class 'tempor.datasources.datasource.OneOffPredictionDataSource'>,
    'prediction.temporal': <class 'tempor.datasources.datasource.TemporalPredictionDataSource'>,
    'time_to_event': <class 'tempor.datasources.datasource.TimeToEventAnalysisDataSource'>,
    'treatments.one_off': <class 'tempor.datasources.datasource.OneOffTreatmentEffectsDataSource'>,
    'treatments.temporal': <class 'tempor.datasources.datasource.TemporalTreatmentEffectsDataSource'>
}
Base class you need:
<class 'tempor.datasources.datasource.OneOffPredictionDataSource'>
You can then look up this class in the TemporAI source code to see its method signatures, etc.
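Besides reading the source, you can ask Python which methods a base class declares as abstract, assuming the base class uses abc.abstractmethod. The sketch below shows this on a hypothetical stand-in class (StandInBase, not TemporAI's class); applying the same __abstractmethods__ and inspect.signature calls to the class printed above should work the same way, but that is an assumption, not something this tutorial shows.

```python
import abc
import inspect
from typing import Any, Optional


class StandInBase(abc.ABC):
    """Hypothetical stand-in for a data source base class (not TemporAI's)."""

    @abc.abstractmethod
    def load(self) -> Any: ...

    @abc.abstractmethod
    def dataset_dir(self) -> Optional[str]: ...

    @abc.abstractmethod
    def url(self) -> Optional[str]: ...


# Names of methods a subclass must override before it can be instantiated.
required = sorted(StandInBase.__abstractmethods__)
print(required)  # ['dataset_dir', 'load', 'url']

# Signatures of those methods, as reported by the inspect module.
for method_name in required:
    print(method_name, inspect.signature(getattr(StandInBase, method_name)))
```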
Implement the methods the plugin needs.
DataSource plugins require the following methods to be implemented:
* load(), which returns the appropriate Dataset object.
* dataset_dir(), which returns, as a string, the subdirectory where any data files will be stored. If the data source has no data files, return None.
* url(), which returns the data URL, if relevant. If not applicable, return None.
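The initializer __init__() can take keyword arguments related to the initialization of the dataset, e.g. the number of samples or a random seed. Pass any remaining keyword arguments on to the base class initializer, as the example below does with super().__init__(**kwargs).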
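To make the contract above concrete without depending on TemporAI, here is a minimal stand-in sketch. StandInBase and SyntheticSource are hypothetical classes that only mimic the described methods: __init__ takes dataset-related keyword arguments and forwards the rest, load() builds the data, and dataset_dir() and url() return None because nothing is stored on disk or downloaded.

```python
import random
from typing import Any, Dict, List, Optional


class StandInBase:
    """Hypothetical stand-in for a data source base class (not TemporAI's)."""

    def __init__(self, **kwargs: Any) -> None:
        # Keep any extra keyword arguments the base class receives.
        self.base_kwargs = kwargs


class SyntheticSource(StandInBase):
    """Hypothetical in-memory data source: no data files and no URL."""

    def __init__(self, n_samples: int = 5, random_seed: int = 123, **kwargs: Any) -> None:
        super().__init__(**kwargs)  # Forward remaining kwargs to the base class.
        self.n_samples = n_samples
        self.random_seed = random_seed

    def url(self) -> Optional[str]:
        return None  # Nothing to download.

    def dataset_dir(self) -> Optional[str]:
        return None  # No data files on disk.

    def load(self) -> Dict[str, List[float]]:
        # A local generator seeded on each call: same seed, same data.
        rng = random.Random(self.random_seed)
        return {"targets": [rng.gauss(0.0, 1.0) for _ in range(self.n_samples)]}


first = SyntheticSource(n_samples=3, random_seed=0).load()
second = SyntheticSource(n_samples=3, random_seed=0).load()
print(first == second)  # True
```

The sketch uses a local random.Random instance rather than seeding a global generator, so loading data does not change the random state seen by other code; numpy.random.default_rng(seed) would be the NumPy counterpart of that choice.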
If you have not implemented a required method, Python will raise a TypeError when you attempt to instantiate your plugin (see the Python abc module documentation: https://docs.python.org/3/library/abc.html).
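The failure mode can be reproduced with plain Python. In this sketch, SketchBase and IncompleteSource are hypothetical stand-ins, not TemporAI classes: the subclass forgets url(), and defining it succeeds, but creating an instance raises TypeError.

```python
import abc
from typing import Any, Optional


class SketchBase(abc.ABC):
    """Hypothetical stand-in base class with the three required methods."""

    @abc.abstractmethod
    def load(self) -> Any: ...

    @abc.abstractmethod
    def dataset_dir(self) -> Optional[str]: ...

    @abc.abstractmethod
    def url(self) -> Optional[str]: ...


class IncompleteSource(SketchBase):
    # url() is deliberately missing; the class definition itself still succeeds.
    def load(self) -> Any:
        return []

    def dataset_dir(self) -> Optional[str]:
        return None


error_message = ""
try:
    IncompleteSource()  # Instantiation is where the abstract-method check happens.
except TypeError as exc:
    error_message = str(exc)

print(error_message)  # The exact wording varies between Python versions.
```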
Register the plugin with TemporAI.
Registering your plugin with TemporAI is simple: use the register_plugin decorator, as shown in the example below.
In the decorator, specify the name of your plugin and its category.
Set plugin_type to "datasource".
from tempor.core.plugins import register_plugin
from tempor.datasources.datasource import OneOffPredictionDataSource

@register_plugin(name="my_plugin", category="prediction.one_off", plugin_type="datasource")
class MyPlugin(OneOffPredictionDataSource):
    ...
Example
Now let's put this together in a minimal example.
[ ]:
from typing import Any

import numpy as np

from tempor.core.plugins import register_plugin
from tempor.data.dataset import OneOffPredictionDataset
from tempor.datasources.datasource import OneOffPredictionDataSource


@register_plugin(name="my_datasource", category="prediction.one_off", plugin_type="datasource")
class MyDataSource(OneOffPredictionDataSource):
    def __init__(self, random_seed: int = 123, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.random_seed = random_seed

    def url(self):
        # The data is generated in memory, so there is no URL.
        return None

    def dataset_dir(self):
        # No data files are stored on disk.
        return None

    def load(self) -> OneOffPredictionDataset:
        np.random.seed(self.random_seed)
        return OneOffPredictionDataset(
            # 100 samples, 30 time steps, 10 features.
            time_series=np.random.normal(size=(100, 30, 10)),
            # One static target per sample.
            targets=np.random.normal(size=(100, 1)),
        )
Our plugin now appears in the TemporAI plugin listing:
[ ]:
from tempor import plugin_loader
all_plugins = plugin_loader.list(plugin_type="datasource")
pprint(all_plugins, indent_guides=True)
my_datasource_found = "my_datasource" in all_plugins["prediction"]["one_off"]
print(f"`my_datasource` plugin found in the category 'prediction.one_off': {my_datasource_found}")
assert my_datasource_found
{
│   'prediction': {
│   │   'one_off': ['sine', 'google_stocks', 'my_datasource'],
│   │   'temporal': ['uci_diabetes', 'dummy_prediction']
│   },
│   'time_to_event': ['pbc'],
│   'treatments': {'one_off': ['pkpd'], 'temporal': ['dummy_treatments']}
}
`my_datasource` plugin found in the category 'prediction.one_off': True
The plugin can now be used like any other TemporAI data source plugin.
[ ]:
# Get the plugin.
my_datasource = plugin_loader.get("prediction.one_off.my_datasource", plugin_type="datasource")
print(my_datasource)
<__main__.MyDataSource object at 0x7f5f36a6cbb0>
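The string passed to plugin_loader.get above is the category given in the decorator ("prediction.one_off") joined by a dot to the plugin name ("my_datasource"). Assuming plugin names themselves contain no dots, splitting at the last dot recovers both parts, as this minimal sketch shows; split_plugin_name is a hypothetical helper, not a TemporAI function.

```python
from typing import Tuple


def split_plugin_name(full_name: str) -> Tuple[str, str]:
    """Hypothetical helper: split a dotted plugin name into (category, name) at the last dot."""
    category, _, name = full_name.rpartition(".")
    return category, name


print(split_plugin_name("prediction.one_off.my_datasource"))
# ('prediction.one_off', 'my_datasource')
```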
[ ]:
# Load data.
dataset = my_datasource.load()
dataset
OneOffPredictionDataset(
time_series=TimeSeriesSamples([100, *, 10]),
predictive=OneOffPredictionTaskData(targets=StaticSamples([100, 1]))
)
[ ]:
# Preview covariates.
dataset.time_series
TimeSeriesSamples with data:
| sample_idx | time_idx | feat_0 | feat_1 | feat_2 | feat_3 | feat_4 | feat_5 | feat_6 | feat_7 | feat_8 | feat_9 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | -1.085631 | 0.997345 | 0.282978 | -1.506295 | -0.578600 | 1.651437 | -2.426679 | -0.428913 | 1.265936 | -0.866740 |
|  | 1 | -0.678886 | -0.094709 | 1.491390 | -0.638902 | -0.443982 | -0.434351 | 2.205930 | 2.186786 | 1.004054 | 0.386186 |
|  | 2 | 0.737369 | 1.490732 | -0.935834 | 1.175829 | -1.253881 | -0.637752 | 0.907105 | -1.428681 | -0.140069 | -0.861755 |
|  | 3 | -0.255619 | -2.798589 | -1.771533 | -0.699877 | 0.927462 | -0.173636 | 0.002846 | 0.688223 | -0.879536 | 0.283627 |
|  | 4 | -0.805367 | -1.727669 | -0.390900 | 0.573806 | 0.338589 | -0.011830 | 2.392365 | 0.412912 | 0.978736 | 2.238143 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 99 | 25 | -0.567276 | -1.011354 | -0.263128 | 0.281661 | 0.850365 | 0.675597 | 0.518956 | 1.458113 | 0.514021 | -0.845099 |
|  | 26 | -0.074948 | 2.889178 | -0.055376 | -1.284538 | -0.215400 | -0.002616 | -0.406990 | -0.089739 | 0.264811 | 1.060700 |
|  | 27 | 0.167216 | -0.226127 | 1.517813 | 2.083333 | -1.053875 | -0.212461 | 1.006044 | -0.253001 | 0.298598 | -1.256375 |
|  | 28 | 1.212878 | -1.656727 | 0.702245 | 0.047495 | -0.736849 | -0.050498 | 0.285193 | 0.735459 | -0.384255 | -0.262967 |
|  | 29 | 0.511946 | 0.672145 | 0.709544 | -1.208061 | -0.158659 | -1.428280 | 0.430501 | -1.144726 | -0.473682 | 1.659917 |
3000 rows × 10 columns
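The "3000 rows × 10 columns" figure follows from the array shape used in load(): 100 samples × 30 time steps give 100 × 30 = 3000 rows, each indexed by a (sample_idx, time_idx) pair, with one column per feature. The pure-Python sketch below mirrors that layout on a zero-filled array of the same shape; to_long_rows is a hypothetical helper and does not reproduce TemporAI's internal conversion.

```python
from typing import List, Tuple


def to_long_rows(array_3d: List[List[List[float]]]) -> List[Tuple[int, int, List[float]]]:
    """Flatten a (samples, time steps, features) nested list into
    (sample_idx, time_idx, feature values) rows, one row per time step."""
    rows = []
    for sample_idx, sample in enumerate(array_3d):
        for time_idx, features in enumerate(sample):
            rows.append((sample_idx, time_idx, features))
    return rows


n_samples, n_steps, n_features = 100, 30, 10
array_3d = [[[0.0] * n_features for _ in range(n_steps)] for _ in range(n_samples)]
rows = to_long_rows(array_3d)
print(len(rows), "rows x", len(rows[0][2]), "columns")  # 3000 rows x 10 columns
```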
[ ]:
# Preview targets.
dataset.predictive.targets
StaticSamples with data:
| sample_idx | feat_0 |
|---|---|
| 0 | -1.054170 |
| 1 | -0.783011 |
| 2 | 1.827901 |
| 3 | 1.746807 |
| 4 | 1.328258 |
| ... | ... |
| 95 | -0.766137 |
| 96 | 1.112182 |
| 97 | 0.076831 |
| 98 | -1.566442 |
| 99 | -1.267637 |
100 rows × 1 columns