# Source code for tempor.models.clairvoyance2.components.torch.rnn

# mypy: ignore-errors

import contextlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from ...utils import tensor_like as tl
from ...utils.common import safe_init_dotmap
from ...utils.dev import raise_not_implemented
from .ffnn import FeedForwardNet

# Type of a concrete torch recurrent layer class (a subclass of ``nn.RNNBase``).
RNNClass = Type[nn.RNNBase]
# Hidden state of a torch recurrent layer: a single tensor, or a pair of tensors (LSTM's ``(h_n, c_n)``).
RNNHidden = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

# Maps the ``rnn_type`` string accepted by ``RecurrentNet`` to the torch layer class it instantiates.
RNN_CLASS_MAP: Mapping[str, Type[nn.RNNBase]] = dict(RNN=nn.RNN, LSTM=nn.LSTM, GRU=nn.GRU)

# Development switch: when ``True``, functions in this module print intermediate values/shapes.
_DEBUG = False


class RecurrentNet(nn.Module):
    """Wrapper around a torch recurrent layer chosen by name from ``RNN_CLASS_MAP``.

    The wrapped layer is always constructed with ``batch_first=True``.

    Args:
        rnn_type (str): Key into ``RNN_CLASS_MAP`` (``"RNN"``, ``"LSTM"`` or ``"GRU"``).
        input_size (int): Forwarded to the torch layer constructor.
        hidden_size (int): Forwarded to the torch layer constructor.
        nonlinearity (Optional[str]): Only passed on for ``"RNN"``, where ``None`` becomes ``"tanh"``;
            ignored for other layer types.
        num_layers (int): Forwarded to the torch layer constructor.
        bias (bool): Forwarded to the torch layer constructor.
        dropout (float): Forwarded to the torch layer constructor.
        bidirectional (bool): Forwarded to the torch layer constructor.
        proj_size (Optional[int]): Only passed on for ``"LSTM"``, where ``None`` becomes ``0``;
            ignored for other layer types.
    """

    def __init__(
        self,
        rnn_type: str,
        input_size: int,
        hidden_size: int,
        nonlinearity: Optional[str],
        num_layers: int,
        bias: bool,
        dropout: float,
        bidirectional: bool,
        proj_size: Optional[int],
    ) -> None:
        super().__init__()
        self.rnn_type = rnn_type
        rnn_class = RNN_CLASS_MAP[rnn_type]

        kwargs: Dict[str, Any] = {
            "input_size": input_size,
            "hidden_size": hidden_size,
            "num_layers": num_layers,
            "bias": bias,
            "batch_first": True,  # NOTE: We adopt batch first convention.
            "dropout": dropout,
            "bidirectional": bidirectional,
        }
        # Constructor options that only one of the layer classes accepts:
        if rnn_class == nn.RNN:
            kwargs["nonlinearity"] = "tanh" if nonlinearity is None else nonlinearity
        if rnn_class == nn.LSTM:
            kwargs["proj_size"] = 0 if proj_size is None else proj_size

        # Keep the exact constructor arguments; `get_output_and_h_dim()` reads them back.
        self.params = safe_init_dotmap(kwargs)
        self.rnn = rnn_class(**kwargs)

    def forward(self, x: torch.Tensor, h: Optional[RNNHidden]) -> Tuple[torch.Tensor, RNNHidden]:
        """Run the wrapped recurrent layer on ``x``.

        Args:
            x (torch.Tensor): Batch-first input (a packed sequence is also accepted, it is passed straight through).
            h (Optional[RNNHidden]): Initial hidden state. When ``None`` it is not passed, so the layer
                falls back to its own default initial state.

        Returns:
            Tuple[torch.Tensor, RNNHidden]: ``(output, final hidden state)`` as returned by the torch layer.
        """
        layer_args = (x,) if h is None else (x, h)
        rnn_out, h_out = self.rnn(*layer_args)
        return rnn_out, h_out

    def get_output_and_h_dim(self) -> Tuple[int, Tuple[int, int, int]]:
        """Compute the per-timestep feature size of `forward()` output, plus hidden-state sizes.

        Handy for sizing a downstream module applied at every timestep. The formulas follow the torch docs:

        * https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
        * https://pytorch.org/docs/stable/generated/torch.nn.GRU.html
        * https://pytorch.org/docs/stable/generated/torch.nn.RNN.html

        Note:
            This class always sets `batch_first=True`, so timesteps are the second dimension of the output.

        Returns:
            Tuple[int, Tuple[int, int, int]]: (`output` feature dimension, (`D * num_layers`, `H_out`, `H_cell`)),
            where `D` is 2 if bidirectional else 1, `H_out` is `proj_size` when a positive projection is
            configured (otherwise `hidden_size`), and `H_cell` is `hidden_size` for LSTM and 0 otherwise.
        """
        params = self.params
        uses_projection = "proj_size" in params and params.proj_size > 0
        h_out = params.proj_size if uses_projection else params.hidden_size
        d = int(params.bidirectional) + 1  # number of directions
        out_dim = d * h_out
        d_num_layers = d * params.num_layers
        h_cell = params.hidden_size if isinstance(self.rnn, nn.LSTM) else 0

        if _DEBUG is True:  # pragma: no cover
            print("------ compute_rnn_output_dim_per_timestep() ------")
            for label, value in (
                ("h_out", h_out),
                ("d", d),
                ("out_dim", out_dim),
                ("d_num_layers", d_num_layers),
                ("h_cell", h_cell),
            ):
                print(label, value)
            print("--- compute_rnn_output_dim_per_timestep() [END] ---")

        return out_dim, (d_num_layers, h_out, h_cell)
class AutoregressiveMixin(ABC):
    """Mixin adding step-by-step autoregressive inference on top of a per-timestep forward.

    Subclasses implement `_forward_for_autoregress()`, which receives one timestep of input
    (shape ``[batch, 1, features]``) and returns one timestep of output. `autoregress()` walks the time
    dimension and writes each step's output into the input of the following step.

    Args:
        feed_first_n (Optional[int]): If ``None``, the whole output is fed back, and it must have the same
            number of features as the input. Otherwise the first ``feed_first_n`` output features overwrite
            the first ``feed_first_n`` input features of the next timestep. The remaining input features
            are kept as given.
    """

    def __init__(self, feed_first_n: Optional[int] = None) -> None:
        self.feed_first_n = feed_first_n
        # Working copy of the input from the most recent `autoregress()` call, including fed-back values.
        self.x_used_in_autoregress: Any = None

    @abstractmethod
    def _forward_for_autoregress(
        self, x: torch.Tensor, timestep_idx: int, **kwargs
    ) -> torch.Tensor:  # pragma: no cover
        ...

    @staticmethod
    def _validate_shape(t: torch.Tensor, t_name: str, feed_first_n: Optional[int]) -> None:
        """Check ``t`` is 3D and, when ``feed_first_n`` is set, its last dimension has at least that many entries.

        Raises:
            RuntimeError: If either check fails.
        """
        if t.ndim != 3:
            raise RuntimeError(f"{t_name} expected to have 3 dimensions but {t.ndim} found")
        if feed_first_n is not None and feed_first_n > t.shape[-1]:
            raise RuntimeError(
                f"`feed_first_n` ({feed_first_n}) must be < or = the size "
                f"of the last dimension of {t_name} ({t.shape[-1]})"
            )

    def autoregress(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run `_forward_for_autoregress()` one timestep at a time, feeding each output into the next input.

        Args:
            x (torch.Tensor): 3D input whose dimension 1 is time. Timestep 0 is used as given. Every later
                timestep has its first ``feed_first_n`` features (all features if ``None``) replaced by the
                previous step's output. ``x`` itself is not modified (a clone is used).
            **kwargs: Forwarded unchanged to every `_forward_for_autoregress()` call.

        Returns:
            torch.Tensor: The per-timestep outputs concatenated along dimension 1.

        Raises:
            RuntimeError: If ``x`` or a step output has an invalid shape.
        """
        self._validate_shape(x, "`x`", self.feed_first_n)
        self.x_used_in_autoregress = x.clone()
        out_list: List[torch.Tensor] = []
        n_timesteps = x.shape[1]
        for time_idx in range(n_timesteps):
            # Indexing with a list keeps the time dimension. Being advanced indexing, it also yields a copy rather
            # than a view of the working tensor, which is modified in place further down.
            out = self._forward_for_autoregress(self.x_used_in_autoregress[:, [time_idx], :], time_idx, **kwargs)
            self._validate_shape(out, "`forward_for_autoregress()` output", self.feed_first_n)
            if out.shape[1] != 1:
                raise RuntimeError(
                    f"`forward_for_autoregress()` output expected to have 1 timestep but {out.shape[1]} found"
                )
            if self.feed_first_n is None and out.shape[-1] != x.shape[-1]:
                raise RuntimeError(
                    "`forward_for_autoregress()` output and `x` last dimension must be the same size "
                    f"but were {out.shape[-1]} and {x.shape[-1]} respectively"
                )
            out_list.append(out)
            if time_idx < n_timesteps - 1:
                # Slice `out` as well as the destination, so an output wider than `feed_first_n` (which
                # `_validate_shape` permits) feeds back only its first `feed_first_n` features.
                self.x_used_in_autoregress[:, [time_idx + 1], : self.feed_first_n] = out[:, :, : self.feed_first_n]
        return torch.cat(out_list, dim=1)
class RecurrentFFNet(AutoregressiveMixin, nn.Module):
    """A `RecurrentNet` followed by a `FeedForwardNet` applied separately at every timestep.

    The recurrent-layer arguments are forwarded to `RecurrentNet`. The ``ff_*`` arguments configure the
    feed-forward head, whose input size is the RNN per-timestep output size plus ``ff_in_size_adjust``.
    For `autoregress()`, ``feed_first_n`` is set to ``ff_out_size``: the head's output is written into the
    first ``ff_out_size`` input features of the next timestep.
    """

    def __init__(
        self,
        rnn_type: str,
        input_size: int,
        hidden_size: int,
        nonlinearity: Optional[str],
        num_layers: int,
        bias: bool,
        dropout: float,
        bidirectional: bool,
        proj_size: Optional[int],
        # ---
        ff_out_size: int,
        ff_in_size_adjust: int = 0,
        ff_hidden_dims: Sequence[int] = tuple(),
        ff_out_activation: Optional[str] = "ReLU",
        ff_hidden_activations: Optional[str] = "ReLU",
    ) -> None:
        nn.Module.__init__(self)
        AutoregressiveMixin.__init__(self, feed_first_n=ff_out_size)
        self.rnn_type = rnn_type
        self.rnn = RecurrentNet(
            rnn_type=rnn_type,
            input_size=input_size,
            hidden_size=hidden_size,
            nonlinearity=nonlinearity,
            num_layers=num_layers,
            bias=bias,
            dropout=dropout,
            bidirectional=bidirectional,
            proj_size=proj_size,
        )
        # Head input = RNN per-timestep output size, widened by `ff_in_size_adjust`. Presumably this is for
        # subclasses whose `rnn_out_postprocess()` changes the feature size (TODO confirm).
        self.ff_in_size, *_ = self.rnn.get_output_and_h_dim()
        self.ff_in_size += ff_in_size_adjust
        self.ff_out_size = ff_out_size
        self.ffnn = FeedForwardNet(
            in_dim=self.ff_in_size,
            out_dim=self.ff_out_size,
            hidden_dims=ff_hidden_dims,
            out_activation=ff_out_activation,
            hidden_activations=ff_hidden_activations,
        )

    def rnn_out_postprocess(self, rnn_out: torch.Tensor, **kwargs) -> torch.Tensor:
        """Hook applied to the RNN output before the feed-forward head. Identity here; override in subclasses."""
        # pylint: disable=unused-argument
        return rnn_out

    def forward(
        self,
        x: torch.Tensor,
        h: Optional[RNNHidden],
        padding_indicator: Optional[float] = None,
        **kwargs_rnn_out_callback,
    ) -> Tuple[torch.Tensor, torch.Tensor, RNNHidden]:
        """Run the RNN over ``x``, then apply the feed-forward head at every timestep.

        Args:
            x (torch.Tensor): Batch-first input.
            h (Optional[RNNHidden]): Initial hidden state, or ``None`` for the layer's default.
            padding_indicator (Optional[float]): If given, ``x`` is packed via `packed()` so that padding
                timesteps are skipped by the RNN. It is also passed to `apply_to_each_timestep()`.
            **kwargs_rnn_out_callback: Forwarded to `rnn_out_postprocess()`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, RNNHidden]: ``(head output, raw RNN output, final hidden state)``.
        """
        if padding_indicator is not None:
            with packed(x, padding_indicator) as p:
                p.packed, h = self.rnn(p.packed, h=h)
            # `p.unpacked` is only populated when the `packed()` context exits, so read it after the block.
            # NOTE(review): `packed()` drops samples that are entirely padding before the RNN, so `h` may then
            # have fewer batch entries than `x`. Confirm callers handle this.
            rnn_out = p.unpacked
        else:
            rnn_out, h = self.rnn(x, h=h)
        if TYPE_CHECKING:
            assert h is not None

        if _DEBUG is True:  # pragma: no cover
            print("rnn_out.shape", rnn_out.shape)  # type: ignore
            # LSTM returns a (h_n, c_n) tuple; other layers return a single tensor.
            h_shape = tuple(t.shape for t in h) if isinstance(h, tuple) else h.shape  # type: ignore
            print("h (or h concat c) shape", h_shape)  # type: ignore

        # TODO: Possibly an option to concatenate *non* last layer's h_n[, c_n], but may be needlessly complex.

        rnn_out_postprocessed = self.rnn_out_postprocess(rnn_out, **kwargs_rnn_out_callback)  # type: ignore
        out = apply_to_each_timestep(
            self.ffnn,
            input_tensor=rnn_out_postprocessed,
            output_size=self.ff_out_size,
            concat_tensors=[],
            padding_indicator=padding_indicator,
            expected_module_input_size=self.ff_in_size,
        )
        return out, rnn_out, h  # type: ignore

    def _forward_for_autoregress(self, x: torch.Tensor, timestep_idx: int, **kwargs) -> torch.Tensor:
        # `h` is a required parameter of `forward()`, so `kwargs` must supply it (e.g. `h=None`).
        # NOTE(review): every step reuses the same `kwargs` (including `h`) and the returned hidden state is
        # discarded, so recurrent state is NOT carried from one autoregressive step to the next. Confirm this is
        # intended.
        out, *_ = self.forward(x, **kwargs)
        return out
@dataclass
class PackedContainer:
    """Mutable holder yielded by `packed()`.

    Attributes:
        packed: The packed input. Code inside the `packed()` context replaces it with the packed output
            that should be unpacked on exit.
        unpacked: Padded tensor rebuilt from ``packed`` when the context exits; ``None`` until then.
        unpacked_lens: Per-sample lengths matching ``unpacked``; ``None`` until the context exits.
    """

    # NOTE: field order defines the generated __init__ signature; keep it.
    packed: torch.nn.utils.rnn.PackedSequence
    unpacked: Optional[torch.Tensor] = None
    unpacked_lens: Optional[torch.Tensor] = None
@contextlib.contextmanager
def packed(x: torch.Tensor, padding_indicator: float, batch_first: bool = True, enforce_sorted: bool = False):
    """Pack a padded batch for a torch RNN, and unpack the (replaced) packed sequence on exit.

    A timestep counts as padding if *any* of its features matches ``padding_indicator`` (via
    ``tl.eq_indicator``). Padding may only appear at the end of a sequence. Samples that are entirely
    padding are excluded from packing. On exit they are re-inserted as rows filled with ``padding_indicator``
    and length 0.

    Usage::

        with packed(x, padding_indicator) as p:
            p.packed, h = rnn(p.packed, h)
        out = p.unpacked  # populated only after the ``with`` block exits normally

    Args:
        x (torch.Tensor): 3D batch-first padded input; dimension 1 is time.
        padding_indicator (float): Value that marks padding. It is also used to fill padding in the output.
        batch_first (bool): Only ``True`` is implemented.
        enforce_sorted (bool): Forwarded to ``pack_padded_sequence``.

    Yields:
        PackedContainer: ``packed`` holds the packed input. Replace it with the packed output to unpack.

    Raises:
        RuntimeError: If ``x`` is not 3D, or a padding timestep is followed by a non-padding one.
    """
    if x.ndim != 3:
        raise RuntimeError(f"Input to `packed` must be a 3 dimensional tensor but {x.ndim} dimensions found")
    if batch_first is False:
        raise_not_implemented("`packed()` with batch_first = False")
    max_len = x.shape[1]

    # Treat as padding if *any* of the features has a padding value -> bool per (sample, timestep):
    padding_bools = tl.eq_indicator(x.detach(), padding_indicator).any(dim=-1)

    # All padding must be at the end. Equivalently, no padded timestep is directly followed by a non-padded one.
    if (padding_bools[:, :-1] & ~padding_bools[:, 1:]).any():
        raise RuntimeError("Found padding values not at the end of sequences")

    where_all_padding = padding_bools.all(dim=1)
    lengths_type = torch.int64
    device = where_all_padding.device
    out_lens_template = torch.zeros(size=where_all_padding.shape, dtype=lengths_type, device=device)

    # Fully-padded samples would have length 0, which `pack_padded_sequence` does not accept, so drop them here.
    # NOTE(review): if *every* sample is fully padded, the packed batch is empty; not handled specially here.
    x_exclude_all_padding_samples = x[~where_all_padding, :, :]
    x_seq_lens = (~padding_bools).sum(dim=1)[~where_all_padding]
    # A lengths tensor passed to `pack_padded_sequence` must live on the CPU.
    x_seq_lens = x_seq_lens.to(device="cpu", dtype=lengths_type)
    x_packed = pack_padded_sequence(
        x_exclude_all_padding_samples, x_seq_lens, batch_first=batch_first, enforce_sorted=enforce_sorted
    )
    packed_container = PackedContainer(x_packed)

    # If the `with` body raises, the exception surfaces here and propagates unchanged. Unpacking is only
    # attempted after a normal exit, so a half-updated `packed_container.packed` cannot obscure the error.
    yield packed_container

    x_unpacked, x_unpacked_lens = pad_packed_sequence(
        packed_container.packed, batch_first=batch_first, padding_value=padding_indicator, total_length=max_len
    )
    # Re-insert the excluded fully-padded samples as rows of `padding_indicator` (and length 0).
    out_template = torch.full(
        size=(x.shape[0], x.shape[1], x_unpacked.shape[2]),
        fill_value=padding_indicator,
        dtype=x_unpacked.dtype,
        device=x_unpacked.device,
    )
    out_template[~where_all_padding, :, :] = x_unpacked
    out_lens_template[~where_all_padding] = x_unpacked_lens.to(device=device)
    packed_container.unpacked = out_template
    packed_container.unpacked_lens = out_lens_template
def apply_to_each_timestep(
    module: nn.Module,
    input_tensor: torch.Tensor,
    output_size: int,
    expected_module_input_size: int,
    padding_indicator: Optional[float],
    concat_tensors: Iterable[torch.Tensor] = tuple(),
) -> torch.Tensor:
    """Applies `module` forward to each timestep of `input_tensor`. Timestep dimension is expected to be dimension 1.

    Args:
        module (`nn.Module`): Module to apply at each timestep.
        input_tensor (`torch.Tensor`): Tensor to apply module to. Shape: `[n_samples, n_timesteps, n_features]`.
        output_size (`int`): The size of the feature (last) dimension of the `module` output.
        expected_module_input_size (`int`): Will check that module input dimension is this value.
        padding_indicator (`Optional[float]`): If `None`, assume no padding in `input_tensor` timestep dimension.
            If a float value (or `nan`), samples whose *last* feature equals it at a given timestep are treated as
            padding. They are not passed through `module`, and their output is filled with this value.
        concat_tensors (`Iterable[torch.Tensor]`): Optionally provide tensors to concatenate (along the last
            dimension) to the input at every timestep, before passing to `module`. Any iterable is accepted; it
            is materialized once. Defaults to `tuple()`.

    Raises:
        `RuntimeError`: If `expected_module_input_size` input size check fails.

    Returns:
        `torch.Tensor`: Output tensor of shape `[n_samples, n_timesteps, output_size]`, in torch's default float
        dtype.
    """
    assert module is not None
    # Materialize once. The same tensors are concatenated at every timestep, and a one-shot iterable
    # (e.g. a generator) would otherwise be exhausted after the first timestep.
    concat_tensors = tuple(concat_tensors)
    # Pin the output dtype. Without it, `torch.full` infers an integer dtype from an integer `padding_indicator`,
    # and the module outputs assigned into the template would be silently truncated.
    out_dtype = torch.get_default_dtype()
    module_out_list = []
    for timestep_idx in range(input_tensor.shape[1]):
        input_timestep = input_tensor[:, timestep_idx, :]
        module_in = torch.cat([input_timestep, *concat_tensors], dim=-1)
        if _DEBUG is True:  # pragma: no cover
            print("input_timestep.shape", input_timestep.shape)
            print("module_in.shape", module_in.shape)

        fill_val = 0.0
        is_padding_selector: Optional[torch.Tensor] = None
        if padding_indicator is not None:
            # Padding is detected from the last input feature only.
            is_padding_selector = tl.eq_indicator(input_timestep[:, -1], padding_indicator)
            assert isinstance(is_padding_selector, torch.Tensor)
            module_in = module_in[~is_padding_selector]
            fill_val = padding_indicator

        module_out_template = torch.full(
            size=(input_timestep.shape[0], output_size),
            fill_value=fill_val,
            dtype=out_dtype,
            device=input_tensor.device,
        )
        if _DEBUG is True:  # pragma: no cover
            print("module_out_template.shape", module_out_template.shape)

        if module_in.shape[-1] != expected_module_input_size:
            raise RuntimeError(
                f"Module input wasn't of expected size, expected {expected_module_input_size}, "
                f"was {module_in.shape[-1]}"
            )
        module_out_timestep = module(module_in)
        if _DEBUG is True:  # pragma: no cover
            print("module_out_timestep.shape", module_out_timestep.shape)

        # Padded rows keep `fill_val`; all other rows receive the module output.
        # TODO: better way of masking?
        if is_padding_selector is not None:
            module_out_template[~is_padding_selector] = module_out_timestep
        else:
            module_out_template[:] = module_out_timestep
        module_out_list.append(module_out_template)

    # Stack along a new time dimension, to get shape (n_samples, n_timesteps, output_size).
    final_output = torch.stack(module_out_list, dim=1)
    if _DEBUG is True:  # pragma: no cover
        print("final_output.shape", final_output.shape)
    return final_output
def mask_and_reshape(mask_selector: torch.BoolTensor, tensor: torch.Tensor) -> torch.Tensor:
    """Select the elements of ``tensor`` where ``mask_selector`` is True and regroup them into rows.

    ``masked_select`` returns a flat 1D tensor of the selected values (the mask may broadcast against
    ``tensor``). That tensor is then reshaped to ``[-1, tensor.shape[-1]]``, so each row has as many values
    as the original last dimension. The result is only meaningful when the mask selects whole rows along the
    last dimension; otherwise the reshape fails or mixes values from different rows.

    Args:
        mask_selector (torch.BoolTensor): Boolean mask, broadcastable to ``tensor``.
        tensor (torch.Tensor): Values to select from.

    Returns:
        torch.Tensor: 2D tensor of the selected values, with ``tensor.shape[-1]`` columns.
    """
    n_features = tensor.shape[-1]
    return tensor.masked_select(mask_selector).reshape(-1, n_features)