# Source code for tempor.methods.preprocessing.imputation.static.plugin_static_tabular_imputer
"""Impute static covariates using any tabular imputer from the `hyperimpute` library."""

import dataclasses
from typing import Any, Dict, List

from typing_extensions import Self, get_args

from tempor.core import plugins
from tempor.data import dataset
from tempor.data.samples import StaticSamples
from tempor.methods.core.params import CategoricalParams, Params
from tempor.methods.preprocessing.imputation._base import BaseImputer, TabularImputerType

from ..hyperimpute_utils import monkeypatch_hyperimpute_logger

# `hyperimpute` is imported inside this context manager; presumably it adjusts hyperimpute's logging setup
# during import (see `monkeypatch_hyperimpute_logger` for the exact behavior).
with monkeypatch_hyperimpute_logger():
    from hyperimpute.plugins.imputers import Imputers
@dataclasses.dataclass
class StaticTabularImputerParams:
    """Initialization parameters for :class:`StaticTabularImputer`."""

    imputer: TabularImputerType = "ice"
    """Which imputer to use for static covariate imputation."""
    random_state: int = 0
    """Random seed. Will be passed on to the underlying imputer."""
    imputer_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
    """Parameters to pass to the underlying imputer as a keyword arguments dictionary. Defaults to ``{}``.
    Must not contain ``random_state``; set :attr:`random_state` instead."""
@plugins.register_plugin(name="static_tabular_imputer", category="preprocessing.imputation.static")
class StaticTabularImputer(BaseImputer):
    ParamsDefinition = StaticTabularImputerParams
    params: StaticTabularImputerParams  # type: ignore

    def __init__(self, **params: Any) -> None:
        """Impute the static covariates using any tabular imputer from the `hyperimpute` library.

        Args:
            **params (Any):
                Parameters and defaults as defined in :class:`StaticTabularImputerParams`.

        Raises:
            ValueError:
                If ``random_state`` is given as a key of ``imputer_params`` rather than directly as ``random_state``.

        Example:
            >>> from tempor import plugin_loader
            >>>
            >>> dataset = plugin_loader.get(
            ...     "prediction.one_off.sine",
            ...     plugin_type="datasource",
            ...     with_missing=True,
            ... ).load()
            >>> assert dataset.static.dataframe().isna().sum().sum() != 0
            >>>
            >>> # Load the model:
            >>> model = plugin_loader.get("preprocessing.imputation.static.static_tabular_imputer")
            >>>
            >>> # Train:
            >>> model.fit(dataset)
            StaticTabularImputer(...)
            >>>
            >>> # Impute:
            >>> imputed = model.transform(dataset)
            >>> assert imputed.static.dataframe().isna().sum().sum() == 0
        """
        # The seed has a single source of truth (`random_state`); reject the ambiguous duplicate.
        if "imputer_params" in params and "random_state" in params["imputer_params"]:
            raise ValueError(
                "Do not pass `random_state` as a key in `imputer_params`, pass it directly as `random_state`"
            )
        super().__init__(**params)
        # Build a new mapping instead of inserting into the existing one: depending on how the base class stores
        # params, `imputer_params` may alias the dict the caller passed in. Mutating it in place would leak
        # `random_state` into that dict, and reusing it for another instance would then hit the ValueError above.
        self.params.imputer_params = {**self.params.imputer_params, "random_state": self.params.random_state}
        self.imputer = Imputers().get(self.params.imputer, **self.params.imputer_params)

    def _fit(self, data: dataset.BaseDataset, *args: Any, **kwargs: Any) -> Self:
        """Fit the underlying tabular imputer on the static covariates of ``data``.

        If ``data`` has no static covariates (``data.static is None``), nothing is fitted.

        Returns:
            Self: This estimator.
        """
        if data.static is not None:
            self.imputer.fit(data.static.dataframe())
        return self

    def _transform(self, data: dataset.BaseDataset, *args: Any, **kwargs: Any) -> dataset.BaseDataset:
        """Impute missing values in the static covariates of ``data``.

        The ``static`` attribute of the passed-in ``data`` object is replaced with the imputed samples, and that
        same object is returned. Datasets without static covariates are returned unchanged.

        Returns:
            dataset.BaseDataset: ``data`` with imputed static covariates.
        """
        if data.static is None:
            return data

        static_data = data.static.dataframe()
        imputed_static_data = self.imputer.transform(static_data)
        # Re-apply the original labels so the imputed frame lines up with the input; the underlying imputer's
        # output labels are not relied upon.
        imputed_static_data.columns = static_data.columns
        imputed_static_data.index = static_data.index
        data.static = StaticSamples.from_dataframe(imputed_static_data)
        return data