class BaseSampler(torch.utils.data.sampler.Sampler):
    """DataSampler samples the conditional vector and corresponding data.

    Every hook below is a neutral default (``None`` or ``0``), except
    ``train_test``, which raises ``NotImplementedError`` in this base class.
    """

    def get_dataset_conditionals(
        self,
    ) -> Optional[np.ndarray]:  # pragma: no cover # noqa: D102
        """Return the dataset conditionals; the base class has none (``None``)."""
        return None

    @pydantic_utils.validate_arguments(
        config=pydantic.ConfigDict(arbitrary_types_allowed=True)
    )
    def sample_conditional(
        self, batch: int, **kwargs: Any  # pylint: disable=unused-argument
    ) -> Optional[Tuple]:  # pragma: no cover # noqa: D102
        """Sample conditionals for ``batch``; the base class returns ``None``.

        Both ``batch`` and ``kwargs`` are accepted but ignored here.
        """
        return None

    @pydantic_utils.validate_arguments(
        config=pydantic.ConfigDict(arbitrary_types_allowed=True)
    )
    def sample_conditional_for_class(
        self, batch: int, c: int  # pylint: disable=unused-argument
    ) -> Optional[np.ndarray]:  # pragma: no cover # noqa: D102
        """Sample conditionals for class ``c``; the base class returns ``None``.

        Both ``batch`` and ``c`` are accepted but ignored here.
        """
        return None

    def conditional_dimension(self) -> int:  # pragma: no cover # noqa: D102
        """Return the total number of categories (``0`` in the base class)."""
        return 0

    def conditional_probs(
        self,
    ) -> Optional[np.ndarray]:  # pragma: no cover # noqa: D102
        """Return the conditional probabilities; the base class has none (``None``)."""
        return None

    def train_test(self) -> Tuple:  # pragma: no cover # noqa: D102
        """Return the train/test split.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
class ImbalancedDatasetSampler(BaseSampler):
    """Weighted sampler that rebalances classes inside a random train split.

    The label list is split into train and test indices. Each training
    element gets a weight of ``1 / (count of its label in the train split)``,
    and iteration draws training positions with replacement using those
    weights.
    """

    @pydantic_utils.validate_arguments(
        config=pydantic.ConfigDict(arbitrary_types_allowed=True)
    )
    def __init__(self, labels: List, train_size: float = 0.8) -> None:
        """Samples elements randomly from a given list of indices for imbalanced dataset.

        Args:
            labels: One label per dataset element; element ``i`` has label
                ``labels[i]``.
            train_size: Forwarded to ``train_test_split`` as ``train_size``.
        """
        super().__init__(None)

        # All elements of the dataset are considered.
        indices = list(range(len(labels)))
        self.train_idx, self.test_idx = train_test_split(
            indices, train_size=train_size
        )
        # Inverse of ``train_idx``: original index -> position in ``train_idx``.
        self.train_mapping = {
            old_idx: new_idx for new_idx, old_idx in enumerate(self.train_idx)
        }

        # One draw per training element in each iteration.
        self.num_train_samples = len(self.train_idx)

        # Labels of the training elements, kept in ``train_idx`` order.
        # ``__iter__`` yields a position ``k`` into ``train_idx`` with
        # probability proportional to ``weights[k]``, so ``weights[k]`` must
        # describe ``train_idx[k]``. Sorting by original index here would pair
        # each weight with the wrong element.
        train_labels = pd.Series(labels).iloc[self.train_idx]

        # Inverse class frequency. ``map`` looks each label up by value, which
        # also works for boolean labels (plain ``[]`` indexing would treat a
        # boolean Series as a mask).
        label_to_count = train_labels.value_counts()
        weights = 1.0 / train_labels.map(label_to_count)
        self.weights = torch.DoubleTensor(weights.to_list())

    def __iter__(self) -> Generator:  # noqa: D105
        """Yield ``num_train_samples`` weighted positions into ``train_idx``.

        Positions are drawn with replacement. The multinomial draw happens
        once, when the iterator is created.
        """
        drawn = torch.multinomial(
            self.weights, self.num_train_samples, replacement=True
        )
        # ``train_mapping`` inverts ``train_idx``, so each yielded value is a
        # position in ``range(len(self.train_idx))``.
        return (self.train_mapping[self.train_idx[pos]] for pos in drawn.tolist())

    def __len__(self) -> int:  # noqa: D105
        """Return the number of training elements."""
        return len(self.train_idx)

    def train_test(self) -> Tuple:
        """Return the train and test indices.

        Returns:
            Tuple: ``(self.train_idx, self.test_idx)``
        """
        return self.train_idx, self.test_idx