Source code for tempor.methods.treatments.temporal._base
import abc
from typing import Any, List

import pydantic
from typing_extensions import Self

import tempor.methods.core as methods_core
from tempor.core import pydantic_utils
from tempor.data import dataset, samples


def check_data_class(data: Any) -> None:
    """Check that the passed data is of the correct class (`dataset.TemporalTreatmentEffectsDataset`).

    Temporal treatment effects estimators only accept
    `dataset.TemporalTreatmentEffectsDataset` instances (or subclasses, since the
    check uses ``isinstance``); anything else is rejected up front.

    Args:
        data (Any): Data to check.

    Raises:
        TypeError: If the data is not a `dataset.TemporalTreatmentEffectsDataset`.
    """
    if not isinstance(data, dataset.TemporalTreatmentEffectsDataset):
        raise TypeError(
            "Expected `data` passed to a temporal treatment effects estimator to be "
            f"`{dataset.TemporalTreatmentEffectsDataset.__name__}` but was {type(data)}"
        )
@abc.abstractmethod
def _predict(self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any) -> samples.TimeSeriesSamplesBase:  # pragma: no cover
    """Implementation hook for producing predictions on ``data``.

    Declared abstract via ``abc.abstractmethod``; concrete estimators are
    expected to override it. This declaration has no behavior of its own.

    Args:
        data (dataset.PredictiveDataset): Data to predict on.
        *args (Any): Additional positional arguments.
        **kwargs (Any): Additional keyword arguments.

    Returns:
        samples.TimeSeriesSamplesBase: The predictions.
    """
@pydantic_utils.validate_arguments(
    config=pydantic.ConfigDict(arbitrary_types_allowed=True),
)
def predict_counterfactuals(
    self,
    data: dataset.PredictiveDataset,
    *args: Any,
    **kwargs: Any,
) -> List:
    """Estimate counterfactual outcomes for ``data`` after checking its type.

    The dataset type is verified with ``check_data_class`` before the call is
    forwarded, unchanged, to the parent class's ``predict_counterfactuals``.

    Args:
        data (dataset.PredictiveDataset): Data to predict counterfactuals for.
        *args (Any): Extra positional arguments forwarded to the parent implementation.
        **kwargs (Any): Extra keyword arguments forwarded to the parent implementation.

    Returns:
        List: Whatever the parent implementation returns (counterfactual predictions).

    Raises:
        TypeError: If ``data`` is not a ``dataset.TemporalTreatmentEffectsDataset``
            (raised by ``check_data_class``).
    """
    # Reject non-temporal datasets before any parent-class work happens.
    check_data_class(data)
    counterfactuals = super().predict_counterfactuals(data, *args, **kwargs)
    return counterfactuals
@abc.abstractmethod
def _predict_counterfactuals(self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any) -> List:  # pragma: no cover
    """Implementation hook for computing counterfactual predictions on ``data``.

    Declared abstract via ``abc.abstractmethod``; concrete estimators are
    expected to override it. This declaration has no behavior of its own.

    Args:
        data (dataset.PredictiveDataset): Data to predict counterfactuals for.
        *args (Any): Additional positional arguments.
        **kwargs (Any): Additional keyword arguments.

    Returns:
        List: Counterfactual predictions.
    """