Source code for tempor.models.transformer

"""Model implementations for Transformers."""

from typing import Any

import torch
from torch import nn
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer

from .constants import DEVICE


class Permute(nn.Module):
    def __init__(self, *dims: Any) -> None:
        """Permute dimensions of a tensor with `torch.Tensor.permute`.

        Args:
            *dims (Any):
                Dimensions to permute, given in the desired output order.
        """
        super().__init__()
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``x`` with its dimensions reordered according to ``self.dims``.
        """
        # Unpack the stored dims (same calling convention as ``Transpose.forward``), so both
        # ``Permute(1, 0, 2)`` and ``Permute((1, 0, 2))`` are accepted.
        return x.permute(*self.dims)
class Transpose(nn.Module):
    def __init__(self, *dims: Any, contiguous: bool = False) -> None:
        """Transpose dimensions of a tensor with `torch.Tensor.transpose`.

        Args:
            *dims (Any):
                Dimensions to transpose. These are forwarded unchanged to
                `torch.Tensor.transpose`.
            contiguous (bool, optional):
                Whether to call `.contiguous()` on the output. Defaults to `False`.
        """
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``x`` with the dimensions in ``self.dims`` swapped; additionally passed
            through ``.contiguous()`` when ``self.contiguous`` is true.
        """
        transposed = x.transpose(*self.dims)
        if self.contiguous:
            return transposed.contiguous()
        return transposed
class TransformerModel(nn.Module):
    def __init__(
        self,
        n_units_in: int,
        n_units_hidden: int = 64,
        n_head: int = 1,
        d_ffn: int = 128,
        dropout: float = 0.1,
        activation: str = "relu",
        n_layers_hidden: int = 1,
        device: Any = DEVICE,
    ) -> None:
        """Transformer model.

        Args:
            n_units_in (int):
                The number of features (a.k.a. variables, dimensions, channels) in the time series dataset.
            n_units_hidden (int, optional):
                Total dimension of the model. Defaults to ``64``.
            n_head (int, optional):
                Parallel attention heads. Defaults to ``1``.
            d_ffn (int, optional):
                The dimension of the feedforward network model. Defaults to ``128``.
            dropout (float, optional):
                Dropout value passed to `~torch.nn.modules.transformer.TransformerEncoderLayer` s.
                Defaults to ``0.1``.
            activation (str, optional):
                The activation function of intermediate layer, ``"relu"`` or ``"gelu"``. Defaults to ``"relu"``.
            n_layers_hidden (int, optional):
                The number of sub-encoder-layers in the encoder. Defaults to ``1``.
            device (Any, optional):
                PyTorch device. Defaults to `~tempor.models.constants.DEVICE`.
        """
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            n_units_hidden,
            n_head,
            dim_feedforward=d_ffn,
            dropout=dropout,
            activation=activation,
        )
        encoder_norm = nn.LayerNorm(n_units_hidden)
        # NOTE: the encoder is stored as its own attribute *and* placed inside ``self.model``
        # below. Both names refer to the same module object, so its parameters are shared.
        self.transformer_encoder = TransformerEncoder(  # type: ignore [no-untyped-call]
            encoder_layer,
            n_layers_hidden,
            norm=encoder_norm,
        )

        # The encoder layer is built without ``batch_first``, so the input is rearranged to
        # sequence-first before the encoder and back to batch-first afterwards.
        self.model = nn.Sequential(
            Permute(1, 0, 2),  # bs x seq_len x nvars -> seq_len x bs x nvars
            nn.Linear(n_units_in, n_units_hidden),  # seq_len x bs x nvars -> seq_len x bs x n_units_hidden
            nn.ReLU(),
            self.transformer_encoder,
            Transpose(1, 0),  # seq_len x bs x n_units_hidden -> bs x seq_len x n_units_hidden
            nn.ReLU(),
        ).to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): Input of shape ``bs x seq_len x nvars`` (per the layer comments in ``__init__``).

        Returns:
            torch.Tensor: Output of ``self.model``, shape ``bs x seq_len x n_units_hidden``.
        """
        return self.model(x)  # pylint: disable=not-callable