Source code for tempor.methods.prediction.one_off.classification
"""One-off classification estimators."""

import abc
from typing import Any, Tuple

import numpy as np
import pydantic
from typing_extensions import Self

import tempor.methods.core as methods_core
from tempor.core import plugins, pydantic_utils
from tempor.data import dataset, samples


def check_data_class(data: Any) -> None:
    """Validate that ``data`` is a one-off prediction dataset.

    Args:
        data (Any): The object to validate.

    Raises:
        TypeError: If ``data`` is not an instance of ``dataset.OneOffPredictionDataset``.
    """
    expected_cls = dataset.OneOffPredictionDataset
    if isinstance(data, expected_cls):
        # Accepted type: nothing further to check.
        return
    raise TypeError(
        "Expected `data` passed to a one-off classification estimator to be "
        f"`{expected_cls.__name__}` but was {type(data)}"
    )
class BaseOneOffClassifier(methods_core.BasePredictor):
    def __init__(self, **params: Any) -> None:  # pylint: disable=useless-super-delegation
        """Base class for one-off classification estimators.

        All keyword arguments are forwarded unchanged to the parent predictor.

        Args:
            **params (Any): Parameters as defined in :class:`BasePredictorParams`.
        """
        super().__init__(**params)

    @abc.abstractmethod
    def _predict(
        self,
        data: dataset.PredictiveDataset,
        *args: Any,
        **kwargs: Any,
    ) -> samples.StaticSamplesBase:  # pragma: no cover
        """Return class predictions for ``data``; must be implemented by subclasses."""

    @abc.abstractmethod
    def _predict_proba(
        self,
        data: dataset.PredictiveDataset,
        *args: Any,
        **kwargs: Any,
    ) -> samples.StaticSamplesBase:  # pragma: no cover
        """Return class probabilities for ``data``; must be implemented by subclasses."""

    def _unpack_dataset(self, data: dataset.BaseDataset) -> Tuple:
        """Convert a dataset into the numpy arrays used by concrete classifiers.

        Args:
            data (dataset.BaseDataset): The dataset to unpack.

        Returns:
            Tuple: ``(static, temporal, observation_times, outcome)``.

            - ``static`` comes from ``data.static.numpy()``. When the dataset has no
              static data, it is an empty ``(len(temporal), 0)`` array.
            - ``outcome`` comes from ``data.predictive.targets.numpy()``. When there is
              no predictive data or no targets, it is an empty ``(len(temporal), 0)``
              array.
            - A one-dimensional ``outcome`` is reshaped into a single column.
        """
        series = data.time_series
        temporal = series.numpy()
        observation_times = np.asarray(series.time_indexes())
        n_rows = len(temporal)

        has_targets = data.predictive is not None and data.predictive.targets is not None
        outcome = data.predictive.targets.numpy() if has_targets else np.zeros((n_rows, 0))

        static = data.static.numpy() if data.static is not None else np.zeros((n_rows, 0))

        # Downstream code expects a 2-D outcome; promote a flat vector to one column.
        if outcome.ndim == 1:
            outcome = outcome.reshape(-1, 1)

        return static, temporal, observation_times, outcome