Source code for tempor.models.clairvoyance2.data.utils.counterfactual_utils
# mypy: ignore-errorsfromtypingimportList,Sequence,Unionimportnumpyasnpimportpandasaspdimporttorchfrom...interfaceimportTCounterfactualPredictionsfrom...interface.horizonimportTimeIndexHorizonfrom..importTimeSeries# TODO: Test
def _counterfactual_array_to_2d(arr: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Convert a single counterfactual prediction array into a 2D float NumPy array.

    Args:
        arr: A NumPy array or torch tensor, either 2D, or 3D with a 0th dimension of
            size 1 (a single sample).

    Returns:
        np.ndarray: The 2D array (the 0th dimension is dropped for 3D input), cast to ``float``.

    Raises:
        TypeError: If ``arr`` is neither a ``np.ndarray`` nor a ``torch.Tensor``.
        ValueError: If ``arr`` is not 2D or 3D, or is 3D with a 0th dimension other than 1.
    """
    if isinstance(arr, torch.Tensor):
        # Detach from the autograd graph and move to CPU so ``.numpy()`` is allowed.
        arr = arr.detach().cpu().numpy()
    if not isinstance(arr, np.ndarray):
        # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``.
        raise TypeError(
            "Counterfactual predictions must be np.ndarray or torch.Tensor, "
            f"got {type(arr).__name__}"
        )
    if arr.ndim not in (2, 3):
        raise ValueError("Arrays for counterfactual predictions must be either 2D or 3D (with 0th dimension size 1)")
    if arr.ndim == 3:
        if arr.shape[0] != 1:
            raise ValueError(
                "Arrays for counterfactual predictions must have 0th dimension size 1 if 3D (i.e. single sample)"
            )
        arr = arr[0, :, :]
    return arr.astype(float)


def to_counterfactual_predictions(
    list_counterfactual_predictions: Sequence[Union[np.ndarray, torch.Tensor]],
    data_historic_temporal_targets: TimeSeries,
    horizon: TimeIndexHorizon,
) -> TCounterfactualPredictions:
    """Wrap raw counterfactual prediction arrays into ``TimeSeries`` objects.

    Each array becomes one ``TimeSeries`` built with ``TimeSeries.new_like`` from
    ``data_historic_temporal_targets``. The array values form the data, the horizon's
    single time index forms the row index, and the template's columns are reused.

    Args:
        list_counterfactual_predictions: Prediction arrays (NumPy or torch), each either
            2D, or 3D with a 0th dimension of size 1.
        data_historic_temporal_targets: Template ``TimeSeries`` whose columns (and
            whatever else ``TimeSeries.new_like`` copies) are reused for every output.
        horizon: Horizon whose ``time_index_sequence`` must contain exactly one time
            index; that index is used as the row index of every output.

    Returns:
        A list with one ``TimeSeries`` per input array, in input order.

    Raises:
        ValueError: If the horizon does not hold exactly one time index, or an array
            has an unsupported number of dimensions / 0th-dimension size.
        TypeError: If the template is not a ``TimeSeries``, or an array is neither a
            ``np.ndarray`` nor a ``torch.Tensor``.
    """
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    n_time_indexes = len(horizon.time_index_sequence)
    if n_time_indexes != 1:
        raise ValueError(
            "Counterfactual predictions require a horizon with exactly one time index, "
            f"got {n_time_indexes}"
        )
    time_index = horizon.time_index_sequence[0]

    template_ts = data_historic_temporal_targets
    if not isinstance(template_ts, TimeSeries):
        raise TypeError(
            "data_historic_temporal_targets must be a TimeSeries, "
            f"got {type(template_ts).__name__}"
        )
    # Loop-invariant: every output reuses the template's columns.
    columns = template_ts.df.columns

    list_ts: List[TimeSeries] = []
    for arr in list_counterfactual_predictions:
        arr_2d = _counterfactual_array_to_2d(arr)
        list_ts.append(
            TimeSeries.new_like(
                like=template_ts,
                data=pd.DataFrame(data=arr_2d, index=time_index, columns=columns),
            )
        )
    return list_ts