tempor.models.clairvoyance2.components.torch.interfaces module

Useful reusable interfaces for PyTorch models.

class tempor.models.clairvoyance2.components.torch.interfaces.OrganizedModule[source]

Bases: Module, ABC

Initializes internal Module state, shared by both nn.Module and ScriptModule.

prep_fit(data: Dataset, **kwargs) → tuple[Tensor | DataLoader, ...][source]
set_attributes_from_kwargs(check_unknown_kwargs: bool = True, **kwargs)[source]
training : bool
class tempor.models.clairvoyance2.components.torch.interfaces.OrganizedPredictorModuleMixin[source]

Bases: ABC

prep_predict(data: Dataset, horizon: Horizon | None, **kwargs) → tuple[Tensor | DataLoader, ...][source]
class tempor.models.clairvoyance2.components.torch.interfaces.OrganizedTreatmentEffectsModuleMixin[source]

Bases: ABC

prep_predict_counterfactuals(data: Dataset, sample_index: int, treatment_scenarios: Sequence[TimeSeries | EventSamples], horizon: Horizon | None, **kwargs) → tuple[Tensor | DataLoader, ...][source]
class tempor.models.clairvoyance2.components.torch.interfaces.CustomizableLossMixin(loss_fn)[source]

Bases: ABC

abstract process_output_for_loss(output: Tensor, **kwargs) → Tensor[source]
class tempor.models.clairvoyance2.components.torch.interfaces.SavableTorchModelMixin[source]

Bases: SavableModelMixin

state_dict : Callable
load_state_dict : Callable
save(path: str) → None[source]
classmethod load(path: str)[source]