Source code for tempor.models.clairvoyance2.components.torch.common

# mypy: ignore-errors

from typing import Dict, Mapping, Optional, Type

import torch
import torch.nn as nn

# Activation module classes that can be constructed by name via init_activation.
ACTIVATION_MAP: Mapping[str, Type[nn.Module]] = {
    "ReLU": nn.ReLU,
    "Softmax": nn.Softmax,
    "Sigmoid": nn.Sigmoid,
    "Tanh": nn.Tanh,
}
# Default constructor keyword arguments for each ACTIVATION_MAP entry; used by
# init_activation when the caller passes no kwargs. Softmax defaults to the
# last dimension (dim=-1).
ACTIVATION_KWARGS: Mapping[str, Mapping[str, object]] = {
    "ReLU": {},
    "Softmax": {"dim": -1},
    "Sigmoid": {},
    "Tanh": {},
}
# TODO: ^ Add more.


def init_activation(
    activation: str,
    kwargs: Optional[Mapping[str, object]] = None,
) -> nn.Module:
    """Instantiate the activation module registered under ``activation``.

    Args:
        activation: A key of ``ACTIVATION_MAP`` (e.g. ``"ReLU"``, ``"Softmax"``).
        kwargs: Keyword arguments forwarded to the activation's constructor.
            When ``None``, the defaults from ``ACTIVATION_KWARGS`` are used.
            Supplied kwargs replace the defaults entirely; they are not merged.

    Returns:
        A new ``nn.Module`` instance of the requested activation.

    Raises:
        KeyError: If ``activation`` is not a key of ``ACTIVATION_MAP``.
    """
    if activation not in ACTIVATION_MAP:
        raise KeyError(f"Unknown activation '{activation}'. Available: {sorted(ACTIVATION_MAP)}")
    if kwargs is None:
        # Fall back to an empty mapping so an entry registered in ACTIVATION_MAP
        # but missing from ACTIVATION_KWARGS still constructs with no arguments.
        kwargs = ACTIVATION_KWARGS.get(activation, {})
    # Use the resolved kwargs (previously the defaults were always used,
    # silently discarding caller-provided kwargs).
    return ACTIVATION_MAP[activation](**kwargs)
# Maps an optimizer name to its ``torch.optim`` optimizer class.
OPTIM_MAP: Mapping[str, Type[torch.optim.Optimizer]] = {
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SGD": torch.optim.SGD,
    "RMSprop": torch.optim.RMSprop,
    "Adagrad": torch.optim.Adagrad,
    # TODO: Allow more.
}