tempor.methods.prediction.one_off.classification.plugin_ode_classifier module

One-off classification plugin based on “Neural Ordinary Differential Equations”.

class tempor.methods.prediction.one_off.classification.plugin_ode_classifier.ODEClassifierParams(n_units_hidden: int = 100, n_layers_hidden: int = 1, nonlin: Literal[none] | Literal[elu] | Literal[relu] | Literal[leaky_relu] | Literal[selu] | Literal[tanh] | Literal[sigmoid] | Literal[softmax] | Literal[gumbel_softmax] = 'relu', dropout: float = 0, atol: float = 0.01, rtol: float = 0.01, interpolation: Literal[cubic] | Literal[linear] = 'cubic', lr: float = 0.001, weight_decay: float = 0.001, n_iter: int = 1000, batch_size: int = 500, n_iter_print: int = 100, random_state: int = 0, patience: int = 10, clipping_value: int = 1, train_ratio: float = 0.8, device: str | None = None, dataloader_sampler: Literal[BatchSampler] | Literal[RandomSampler] | Literal[Sampler] | Literal[SequentialSampler] | Literal[SubsetRandomSampler] | Literal[WeightedRandomSampler] | None = None)

Bases: object

Initialization parameters for ODEClassifier.

n_units_hidden : int = 100

Number of hidden units.

n_layers_hidden : int = 1

Number of hidden layers.

nonlin : Literal[none] | Literal[elu] | Literal[relu] | Literal[leaky_relu] | Literal[selu] | Literal[tanh] | Literal[sigmoid] | Literal[softmax] | Literal[gumbel_softmax] = 'relu'

Activation for hidden layers. The available options are the literals listed in the type annotation above.

dropout : float = 0

Dropout value.

atol : float = 0.01

Absolute tolerance for solution.

rtol : float = 0.01

Relative tolerance for solution.
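The two tolerances usually work together. A minimal sketch, assuming the common adaptive-solver convention that a step is accepted when its local error estimate is at most atol + rtol * |y|; whether the solver behind this plugin applies exactly this rule is an assumption, and within_tolerance is a hypothetical helper, not part of tempor:

```python
def within_tolerance(err: float, y: float, atol: float = 0.01, rtol: float = 0.01) -> bool:
    """Hypothetical helper: accept a step if the error estimate fits the mixed tolerance.

    Assumes the common convention |err| <= atol + rtol * |y|.
    """
    return abs(err) <= atol + rtol * abs(y)


# For a state of magnitude 1.0 the allowed error is 0.01 + 0.01 * 1.0 = 0.02.
accepted_large_state = within_tolerance(0.015, 1.0)
# For a state of magnitude 0.1 the allowed error shrinks to 0.011.
accepted_small_state = within_tolerance(0.015, 0.1)
```

Under this convention atol dominates when the state is near zero and rtol dominates when it is large.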

interpolation : Literal[cubic] | Literal[linear] = 'cubic'

Interpolation method: "cubic" or "linear".

lr : float = 0.001

Learning rate for optimizer.

weight_decay : float = 0.001

L2 (ridge) penalty for the weights.

n_iter : int = 1000

Maximum number of iterations.

batch_size : int = 500

Batch size.

n_iter_print : int = 100

Number of iterations after which to print updates and check the validation loss.

random_state : int = 0

Random state (seed) used.

patience : int = 10

Number of iterations to wait before early stopping after decrease in validation loss.
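As an illustration of patience-based early stopping, here is a minimal sketch. It assumes patience is counted in consecutive validation checks without improvement; the plugin may count differently, and stopping_check is a hypothetical helper, not part of tempor:

```python
from typing import Optional, Sequence


def stopping_check(val_losses: Sequence[float], patience: int = 10) -> Optional[int]:
    """Hypothetical helper: index of the validation check at which training would stop.

    Returns None if the patience budget is never exhausted.
    """
    best = float("inf")
    waited = 0
    for i, loss in enumerate(val_losses):
        if loss < best:
            # New best validation loss: reset the counter.
            best = loss
            waited = 0
        else:
            # No improvement at this check.
            waited += 1
            if waited >= patience:
                return i
    return None


losses = [1.0, 0.8, 0.9, 0.85, 0.95]
stop_at = stopping_check(losses, patience=3)
never_stops = stopping_check(losses, patience=10)
```

With these losses the best value is reached at check 1, and three checks without improvement follow, so the sketch stops at check 4 when patience is 3.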

clipping_value : int = 1

Gradient clipping value.

train_ratio : float = 0.8

Train/test split ratio.

device : str | None = None

String representing the PyTorch device. If None, the default DEVICE is used.

dataloader_sampler : Literal[BatchSampler] | Literal[RandomSampler] | Literal[Sampler] | Literal[SequentialSampler] | Literal[SubsetRandomSampler] | Literal[WeightedRandomSampler] | None = None

Custom data sampler for training.

class tempor.methods.prediction.one_off.classification.plugin_ode_classifier.ODEClassifier(**params: Any)

Bases: BaseOneOffClassifier

Classifier based on ordinary differential equation (ODE) solvers.
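For orientation, the general formulation from the cited paper treats a hidden state as the solution of an ODE whose dynamics are given by a neural network. The block below states that general form only; how this plugin builds the initial state from the input time series and maps the final state to class predictions is not specified on this page, so the symbols f, theta, h and T are used here in the paper's sense rather than as names from the tempor code:

```latex
\frac{d h(t)}{d t} = f\bigl(h(t), t, \theta\bigr),
\qquad
h(T) = h(0) + \int_{0}^{T} f\bigl(h(t), t, \theta\bigr)\, dt
```

The integral is evaluated numerically by an ODE solver, which is where tolerances such as atol and rtol come into play.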

Parameters:
**params : Any

Parameters and defaults as defined in ODEClassifierParams.

Example

>>> from tempor import plugin_loader
>>>
>>> dataset = plugin_loader.get("prediction.one_off.sine", plugin_type="datasource").load()
>>>
>>> # Load the model:
>>> model = plugin_loader.get("prediction.one_off.classification.ode_classifier", n_iter=50)
>>>
>>> # Train:
>>> model.fit(dataset)
ODEClassifier(...)
>>>
>>> # Predict:
>>> assert model.predict(dataset).numpy().shape == (len(dataset), 1)

References

“Neural Ordinary Differential Equations”, Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud.

category : ClassVar[plugin_typing.PluginCategory] = 'prediction.one_off.classification'

Plugin category, such as 'prediction.one_off.classification'. Must be set by the plugin class using @register_plugin.

name : ClassVar[plugin_typing.PluginName] = 'ode_classifier'

Plugin name, such as 'my_nn_classifier'. Must be set by the plugin class using @register_plugin.

plugin_type : ClassVar[plugin_typing.PluginTypeArg] = 'method'

Plugin type, such as 'method'. May optionally be set by the plugin class using @register_plugin; otherwise the default plugin type is used.

ParamsDefinition

alias of ODEClassifierParams

params : ODEClassifierParams

static hyperparameter_space(*args: Any, **kwargs: Any) -> list[Params]

The hyperparameter search domain, used for tuning.

Variadic *args and **kwargs may be provided; these are received from sample_hyperparameters.