Source code for tempor.methods.prediction.one_off.classification.plugin_laplace_classifier

"""One-off classification plugin based on
"Neural Laplace: Learning diverse classes of differential equations in the Laplace domain".
"""

import dataclasses
from typing import Any, List, Optional

import numpy as np
from typing_extensions import Self

from tempor.core import plugins
from tempor.data import dataset, samples
from tempor.methods.core.params import CategoricalParams, FloatParams, IntegerParams, Params
from tempor.methods.prediction.one_off.classification import BaseOneOffClassifier
from tempor.models import utils as model_utils
from tempor.models.constants import Nonlin, Samp
from tempor.models.ts_ode import ILTAlgorithm, NeuralODE


[docs]@dataclasses.dataclass class LaplaceODEClassifierParams: """Initialization parameters for :class:`LaplaceODEClassifier`.""" n_units_hidden: int = 100 """Number of hidden units.""" n_layers_hidden: int = 1 """Number of hidden layers.""" nonlin: Nonlin = "relu" """Activation for hidden layers. Available options: :obj:`~tempor.models.constants.Nonlin`.""" dropout: float = 0 """Dropout value.""" # CDE specific: ilt_reconstruction_terms: int = 33 """Number of ILT reconstruction terms, i.e. the number of complex :math:`s` points in ``laplace_rep_func`` to reconstruct a single time point.""" ilt_algorithm: ILTAlgorithm = "fourier" """Inverse Laplace transform algorithm to use. Available are {``fourier``, ``dehoog``, ``cme``, ``fixed_tablot``, ``stehfest``}.""" # Training: lr: float = 1e-3 """Learning rate for optimizer.""" weight_decay: float = 1e-3 """l2 (ridge) penalty for the weights.""" n_iter: int = 1000 """Maximum number of iterations.""" batch_size: int = 500 """Batch size.""" n_iter_print: int = 100 """Number of iterations after which to print updates and check the validation loss.""" random_state: int = 0 """Random_state used.""" patience: int = 10 """Number of iterations to wait before early stopping after decrease in validation loss.""" clipping_value: int = 1 """Gradients clipping value.""" train_ratio: float = 0.8 """Train/test split ratio.""" device: Optional[str] = None """String representing PyTorch device. If `None`, `~tempor.models.constants.DEVICE`.""" dataloader_sampler: Optional[Samp] = None """Custom data sampler for training."""
[docs]@plugins.register_plugin(name="laplace_ode_classifier", category="prediction.one_off.classification") class LaplaceODEClassifier(BaseOneOffClassifier): ParamsDefinition = LaplaceODEClassifierParams params: LaplaceODEClassifierParams # type: ignore def __init__(self, **params: Any) -> None: """Inverse Laplace Transform (ILT) algorithms implemented in PyTorch. Backpropagation through differential equation (DE) solutions in the Laplace domain is supported using the Riemann stereographic projection for better global representation of the complex Laplace domain. Args: **params (Any): Parameters and defaults as defined in :class:`LaplaceODEClassifierParams`. Example: >>> from tempor import plugin_loader >>> >>> dataset = plugin_loader.get("prediction.one_off.google_stocks", plugin_type="datasource").load() >>> >>> # Load the model: >>> model = plugin_loader.get("prediction.one_off.classification.laplace_ode_classifier", n_iter=50) >>> >>> # Train: >>> model.fit(dataset) LaplaceODEClassifier(...) >>> >>> # Predict: >>> assert model.predict(dataset).numpy().shape == (len(dataset), 1) References: "Neural Laplace: Learning diverse classes of differential equations in the Laplace domain", Holt, Samuel I and Qian, Zhaozhi and van der Schaar, Mihaela. """ # TODO: Model currently fails to run with SineDataSource data. Investigate and resolve. super().__init__(**params) self.device = model_utils.get_device(self.params.device) self.dataloader_sampler = model_utils.get_sampler(self.params.dataloader_sampler) self.model: Optional[NeuralODE] = None def _fit( self, data: dataset.BaseDataset, *args: Any, **kwargs: Any, ) -> Self: static, temporal, observation_times, outcome = self._unpack_dataset(data) outcome = outcome.squeeze() n_classes = len(np.unique(outcome)) self.model = NeuralODE( task_type="classification", n_static_units_in=static.shape[-1], n_temporal_units_in=temporal.shape[-1], output_shape=[n_classes], n_units_hidden=self.params.n_units_hidden, n_layers_hidden=self.params.n_layers_hidden, # Laplace backend="laplace", ilt_algorithm=self.params.ilt_algorithm, ilt_reconstruction_terms=self.params.ilt_reconstruction_terms, # training n_iter=self.params.n_iter, n_iter_print=self.params.n_iter_print, batch_size=self.params.batch_size, lr=self.params.lr, weight_decay=self.params.weight_decay, device=self.device, dataloader_sampler=self.dataloader_sampler, dropout=self.params.dropout, nonlin=self.params.nonlin, random_state=self.params.random_state, clipping_value=self.params.clipping_value, patience=self.params.patience, train_ratio=self.params.train_ratio, ) self.model.fit(static, temporal, observation_times, outcome) return self def _predict( self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any, ) -> samples.StaticSamplesBase: if self.model is None: raise RuntimeError("Fit the model first") static, temporal, observation_times, _ = self._unpack_dataset(data) preds = self.model.predict(static, temporal, observation_times) preds = preds.astype(float) preds = preds.reshape(-1, 1) return samples.StaticSamples.from_numpy(preds) def _predict_proba( self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any, ) -> samples.StaticSamplesBase: if self.model is None: raise RuntimeError("Fit the model first") static, temporal, observation_times, _ = self._unpack_dataset(data) preds = self.model.predict_proba(static, temporal, observation_times) preds = preds.astype(float) return samples.StaticSamples.from_numpy(preds)
[docs] @staticmethod def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]: # noqa: D102 return [ IntegerParams(name="n_units_hidden", low=100, high=1000), IntegerParams(name="n_layers_hidden", low=1, high=5), CategoricalParams(name="ilt_algorithm", choices=["fourier", "dehoog", "cme", "fixed_tablot", "stehfest"]), CategoricalParams(name="batch_size", choices=[64, 128, 256, 512]), CategoricalParams(name="lr", choices=[1e-3, 1e-4, 2e-4]), FloatParams(name="dropout", low=0, high=0.2), CategoricalParams(name="nonlin", choices=["relu", "elu", "leaky_relu", "selu"]), ]