Source code for tempor.models.clairvoyance2.components.torch.ffnn

# mypy: ignore-errors

from typing import Optional, OrderedDict, Sequence

import torch
import torch.nn as nn

from .common import init_activation


class FeedForwardNet(nn.Module):
    """Fully connected feed-forward network wrapped in one ``nn.Sequential``.

    The layer sizes run ``in_dim -> *hidden_dims -> out_dim``, with one
    ``nn.Linear`` per consecutive pair of sizes. Submodules are registered in
    ``self.seq`` under the names ``linear_<i>`` and ``activation_<i>``, where
    ``<i>`` is the zero-based index of the linear layer the entry belongs to.

    Args:
        in_dim: Number of input features of the first linear layer.
        out_dim: Number of output features of the last linear layer.
        hidden_dims: Sizes of the hidden layers, in order. When empty, the
            network consists of a single linear layer (plus the optional
            output activation).
        out_activation: Value handed to ``init_activation`` to build the
            module placed after the last linear layer; ``None`` omits it.
            Presumably the name of an activation class such as ``"ReLU"`` --
            confirm against ``init_activation``.
        hidden_activations: Same as ``out_activation``, but applied after
            every hidden linear layer; ``None`` omits them.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dims: Sequence[int] = tuple(),
        out_activation: Optional[str] = "ReLU",
        hidden_activations: Optional[str] = "ReLU",
    ) -> None:
        super().__init__()

        layer_dims = [in_dim, *hidden_dims, out_dim]
        # Consecutive (in_features, out_features) pairs, one per linear layer.
        dim_pairs = list(zip(layer_dims[:-1], layer_dims[1:]))
        last_index = len(dim_pairs) - 1

        components: OrderedDict[str, nn.Module] = OrderedDict()
        for index, (in_feat, out_feat) in enumerate(dim_pairs):
            # Modules are created strictly in network order (linear first,
            # then its activation) so parameter initialisation consumes the
            # random number generator in a stable, reproducible sequence.
            components[f"linear_{index}"] = nn.Linear(in_feat, out_feat)
            # Only the final layer uses the output activation.
            activation = out_activation if index == last_index else hidden_activations
            if activation is not None:
                components[f"activation_{index}"] = init_activation(activation)

        self.seq = nn.Sequential(components)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the sequential stack and return its output."""
        return self.seq(x)