tempor.models.samplers module

Custom torch samplers.

class tempor.models.samplers.BaseSampler(data_source: Sized | None = None)

Bases: Sampler

Base sampler that samples the conditional vector and the corresponding data.

get_dataset_conditionals() → ndarray | None
sample_conditional(batch: int, **kwargs: Any) → tuple | None
sample_conditional_for_class(batch: int, c: int) → ndarray | None
conditional_dimension() → int

Return the total number of categories.

conditional_probs() → ndarray | None

Return the probability of each conditional category (None if unavailable).

train_test() → tuple
class tempor.models.samplers.ImbalancedDatasetSampler(labels: list, train_size: float = 0.8)
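The conditional methods above suggest a CTGAN-style conditional vector: a one-hot vector whose length is conditional_dimension(), with the active category drawn according to conditional_probs(). The sketch below is a hypothetical, stdlib-only stand-in, not tempor's implementation. ToyConditionalSampler, category_counts and seed are illustrative names, it assumes a single discrete column with known category counts, and it returns plain lists where BaseSampler returns arrays or tuples.

```python
import random
from typing import List, Optional, Sequence


class ToyConditionalSampler:
    """Hypothetical stand-in mirroring BaseSampler's conditional methods.

    Assumes one discrete column with known category counts; illustration only.
    """

    def __init__(self, category_counts: Sequence[int], seed: int = 0) -> None:
        total = sum(category_counts)
        # Empirical probability of each category.
        self._probs = [c / total for c in category_counts]
        self._rng = random.Random(seed)

    def conditional_dimension(self) -> int:
        # Total number of categories, i.e. the length of the one-hot vector.
        return len(self._probs)

    def conditional_probs(self) -> Optional[List[float]]:
        # Probability of drawing each category (a copy, so callers cannot mutate state).
        return list(self._probs)

    def sample_conditional_for_class(self, batch: int, c: int) -> List[List[int]]:
        # Every row is the one-hot vector for the requested category c.
        dim = self.conditional_dimension()
        return [[1 if j == c else 0 for j in range(dim)] for _ in range(batch)]

    def sample_conditional(self, batch: int) -> List[List[int]]:
        # Draw one category per row according to the probabilities, then one-hot encode it.
        dim = self.conditional_dimension()
        cats = self._rng.choices(range(dim), weights=self._probs, k=batch)
        return [[1 if j == c else 0 for j in range(dim)] for c in cats]
```

For example, ToyConditionalSampler([6, 3, 1]) has a conditional dimension of 3, and sample_conditional_for_class(2, 1) returns two copies of [0, 1, 0].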

Bases: BaseSampler

Samples elements randomly from a given list of indices, for use with an imbalanced dataset.

train_test() → tuple

Return the train and test indices.

Returns:

(self.train_idx, self.test_idx)

Return type:

Tuple
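This reference does not state how ImbalancedDatasetSampler weights samples. The sketch below assumes one common approach: split the indices into train and test sets, then draw training indices with replacement. Each index is weighted by the inverse frequency of its label, so in expectation every class is drawn equally often. ToyImbalancedSampler and its seed parameter are hypothetical. Only labels, train_size and train_test() mirror the documented signature.

```python
import random
from collections import Counter
from typing import Hashable, Iterator, List, Sequence, Tuple


class ToyImbalancedSampler:
    """Hypothetical sketch: train/test split plus inverse-class-frequency sampling.

    Assumes inverse-frequency weights; this is not tempor's actual code.
    """

    def __init__(self, labels: Sequence[Hashable], train_size: float = 0.8, seed: int = 0) -> None:
        self._rng = random.Random(seed)
        indices = list(range(len(labels)))
        self._rng.shuffle(indices)
        n_train = int(round(train_size * len(indices)))
        self.train_idx = sorted(indices[:n_train])
        self.test_idx = sorted(indices[n_train:])
        # Weight each training index by 1 / (count of its label in the training split),
        # so every label class carries the same total weight.
        counts = Counter(labels[i] for i in self.train_idx)
        self._weights = [1.0 / counts[labels[i]] for i in self.train_idx]

    def __len__(self) -> int:
        return len(self.train_idx)

    def __iter__(self) -> Iterator[int]:
        # Draw len(self) training indices with replacement, using the class-balancing weights.
        return iter(self._rng.choices(self.train_idx, weights=self._weights, k=len(self)))

    def train_test(self) -> Tuple[List[int], List[int]]:
        # Return the train and test indices, as in the documented train_test().
        return self.train_idx, self.test_idx
```

Because __iter__ yields dataset indices, an object like this could be passed as the sampler argument of torch.utils.data.DataLoader, while test_idx is kept aside for evaluation.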