"""Time series ODE model implementations."""
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pydantic
import torch
import torchcde
import torchdiffeq
import torchlaplace
import torchlaplace.inverse_laplace
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, sampler
from typing_extensions import Literal
from tempor.core import pydantic_utils
from tempor.log import logger as log
from tempor.models.constants import DEVICE, ModelTaskType, Nonlin, ODEBackend
from tempor.models.mlp import MLP
from tempor.models.samplers import ImbalancedDatasetSampler
from tempor.models.utils import enable_reproducibility
# Interpolation schemes accepted by the ``interpolation`` argument of ``NeuralODE``.
Interpolation = Literal["cubic", "linear"]
# Inverse Laplace transform algorithm names accepted by the ``ilt_algorithm`` argument of ``NeuralODE``; the value is
# forwarded unchanged to ``torchlaplace.laplace_reconstruct``.
# NOTE(review): "fixed_tablot" is presumably spelled this way to match the option name `torchlaplace` expects —
# confirm against the library before "correcting" it.
ILTAlgorithm = Literal["fourier", "dehoog", "cme", "fixed_tablot", "stehfest"]
class CDEFunc(torch.nn.Module):
    def __init__(
        self,
        n_units_in: int,
        n_units_hidden: int,
        n_layers_hidden: int = 1,
        nonlin: Nonlin = "relu",
        dropout: float = 0,
        device: Any = DEVICE,
    ):
        r"""Vector field :math:`f_\theta` of the CDE model :math:`z_t = z_0 + \int_0^t f_\theta(z_s) dX_s`.

        The field is an MLP that maps a hidden state of size ``n_units_hidden`` to a flat vector of size
        ``n_units_in * n_units_hidden`` (with a ``tanh`` output activation); :meth:`forward` reshapes that vector
        into a ``(n_units_hidden, n_units_in)`` matrix.

        Args:
            n_units_in (int):
                Number of input units.
            n_units_hidden (int):
                Number of hidden units.
            n_layers_hidden (int, optional):
                Number of hidden layers. Defaults to ``1``.
            nonlin (Nonlin, optional):
                Nonlinearity to use in NN. Available options: :obj:`~tempor.models.constants.Nonlin`.
                Defaults to ``"relu"``.
            dropout (float, optional):
                Dropout value. If ``0``, the dropout is not used. Defaults to ``0``.
            device (Any, optional):
                PyTorch device to use. Defaults to :obj:`~tempor.models.constants.DEVICE`.
        """
        super().__init__()

        self.n_units_in = n_units_in
        self.n_units_hidden = n_units_hidden

        # The MLP emits one value per (hidden unit, input unit) pair; `forward` folds them into a matrix.
        flat_size = n_units_in * n_units_hidden
        self.model = MLP(
            task_type="regression",
            n_units_in=n_units_hidden,
            n_units_out=flat_size,
            n_layers_hidden=n_layers_hidden,
            n_units_hidden=n_units_hidden,
            dropout=dropout,
            nonlin=nonlin,
            device=device,
            nonlin_out=[("tanh", flat_size)],
        )

    def forward(self, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:  # pylint: disable=unused-argument
        """Evaluate the vector field at hidden state ``z``.

        Args:
            t (torch.Tensor):
                Evaluation time; accepted for solver compatibility but not used.
            z (torch.Tensor):
                Hidden state whose last dimension is fed to the MLP.

        Returns:
            torch.Tensor:
                The MLP output with its last dimension reshaped to ``(n_units_hidden, n_units_in)``; all leading
                dimensions are preserved.
        """
        flat = self.model(z)
        leading_dims = flat.shape[:-1]
        return flat.view(*leading_dims, self.n_units_hidden, self.n_units_in)
class ODEFunc(torch.nn.Module):
    def __init__(
        self,
        n_units_hidden: int,
        n_layers_hidden: int = 1,
        nonlin: Nonlin = "relu",
        dropout: float = 0,
        device: Any = DEVICE,
    ):
        r"""Vector field :math:`f_\theta` of the ODE model :math:`z_t = z_0 + \int_0^t f_\theta(z_s) ds`.

        The field is an MLP mapping a hidden state of size ``n_units_hidden`` to a vector of the same size, with a
        ``tanh`` output activation.

        Args:
            n_units_hidden (int):
                Number of hidden units.
            n_layers_hidden (int, optional):
                Number of hidden layers. Defaults to ``1``.
            nonlin (Nonlin, optional):
                Nonlinearity to use in NN. Available options: :obj:`~tempor.models.constants.Nonlin`.
                Defaults to ``"relu"``.
            dropout (float, optional):
                Dropout value. If ``0``, the dropout is not used. Defaults to ``0``.
            device (Any, optional):
                PyTorch device to use. Defaults to :obj:`~tempor.models.constants.DEVICE`.
        """
        super().__init__()

        self.n_units_hidden = n_units_hidden

        # Input and output widths are identical: the field maps a hidden state to its derivative.
        width = n_units_hidden
        self.model = MLP(
            task_type="regression",
            n_units_in=width,
            n_units_out=width,
            n_layers_hidden=n_layers_hidden,
            n_units_hidden=width,
            dropout=dropout,
            nonlin=nonlin,
            device=device,
            nonlin_out=[("tanh", width)],
        )

    def forward(self, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:  # pylint: disable=unused-argument
        """Evaluate the vector field at hidden state ``z``.

        Args:
            t (torch.Tensor):
                Evaluation time; accepted for solver compatibility but not used.
            z (torch.Tensor):
                Hidden state fed to the MLP.

        Returns:
            torch.Tensor:
                The MLP output for ``z``.
        """
        derivative = self.model(z)
        return derivative
class ReverseGRUEncoder(nn.Module):
    def __init__(
        self,
        n_units_in: int,
        n_units_latent: int,
        n_units_hidden: int,
        device: Any = DEVICE,
    ):
        """Model (encoder and Laplace representation func). Encodes observed trajectory into latent vector.

        The trajectory is reversed along its time axis, run through a 2-layer GRU, and the GRU output at the last
        processed step is projected to the latent size and squashed with ``tanh``.

        Args:
            n_units_in (int):
                Number of features per time step of the observed trajectory.
            n_units_latent (int):
                Size of the latent vector produced by the encoder.
            n_units_hidden (int):
                Hidden size of the GRU.
            device (Any, optional):
                PyTorch device to use. Defaults to :obj:`~tempor.models.constants.DEVICE`.
        """
        super().__init__()
        self.gru = nn.GRU(n_units_in, n_units_hidden, 2, batch_first=True)
        # The projection is created on `device` before initialisation, so the Xavier init draws from the same
        # random generator as before.
        self.linear_out = nn.Linear(n_units_hidden, n_units_latent).to(device)
        nn.init.xavier_uniform_(self.linear_out.weight)
        # Move the remaining submodules (the GRU) too; previously only `linear_out` honoured `device`, so the
        # encoder was split across devices unless the caller moved it again.
        self.to(device)

    def forward(self, observed_data: torch.Tensor) -> torch.Tensor:
        """Encode a batch of observed trajectories.

        Args:
            observed_data (torch.Tensor):
                Trajectories of shape ``(batch_size, t_observed_dim, observed_dim)``.

        Returns:
            torch.Tensor:
                Latent vectors of shape ``(batch_size, n_units_latent)`` with values in ``(-1, 1)``.
        """
        # Flip the time axis so the GRU consumes the most recent observation first.
        reversed_trajs = torch.flip(observed_data, (1,))
        gru_out, _ = self.gru(reversed_trajs)
        # Use the output at the last processed step (i.e. the earliest observation) as the summary.
        return torch.tanh(self.linear_out(gru_out[:, -1, :]))
class LaplaceFunc(nn.Module):
    def __init__(
        self,
        s_dim: int,
        n_units_out: int,
        n_units_latent: int,
        n_units_hidden: int = 64,
        device: Any = DEVICE,
    ) -> None:
        """SphereSurfaceModel : ``C^{b+k} -> C^{bxd}`` - in Riemann Sphere Coords : ``b dim s`` reconstruction terms,
        `k` is latent encoding dimension, `d` is output dimension.

        Args:
            s_dim (int):
                Number of ``s`` reconstruction terms (``b`` above).
            n_units_out (int):
                Output dimension (``d`` above).
            n_units_latent (int):
                Latent encoding dimension (``k`` above).
            n_units_hidden (int, optional):
                Width of the two hidden layers. Defaults to ``64``.
            device (Any, optional):
                PyTorch device to use. Defaults to :obj:`~tempor.models.constants.DEVICE`.
        """
        super().__init__()

        self.s_dim = s_dim
        self.n_units_out = n_units_out
        self.n_units_latent = n_units_latent

        in_features = s_dim * 2 + n_units_latent
        out_features = (s_dim) * 2 * n_units_out
        self.linear_tanh_stack = nn.Sequential(
            nn.Linear(in_features, n_units_hidden),
            nn.Tanh(),
            nn.Linear(n_units_hidden, n_units_hidden),
            nn.Tanh(),
            nn.Linear(n_units_hidden, out_features),
        )
        # Xavier-initialise every linear layer, in layer order.
        linear_layers = [layer for layer in self.linear_tanh_stack.modules() if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)

        # Width of the phi interval [-pi / 2, +pi / 2].
        phi_max = torch.pi / 2.0
        self.phi_scale = phi_max - -torch.pi / 2.0

        self.to(device)

    def forward(self, i: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map the input to the two angle tensors ``(theta, phi)``.

        Args:
            i (torch.Tensor):
                Input that is flattened to rows of ``s_dim * 2 + n_units_latent`` features.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                ``theta`` in ``(-pi, +pi)`` and ``phi`` in ``(-pi / 2, +pi / 2)``, each of shape
                ``(-1, n_units_out, s_dim)``.
        """
        flat_in = i.view(-1, self.s_dim * 2 + self.n_units_latent)
        raw = self.linear_tanh_stack(flat_in).view(-1, 2 * self.n_units_out, self.s_dim)

        # First half of the channel axis parameterises theta, second half phi.
        theta_raw = raw[:, : self.n_units_out, :]
        phi_raw = raw[:, self.n_units_out :, :]

        theta = torch.tanh(theta_raw) * torch.pi  # From - pi to + pi
        phi = (
            torch.tanh(phi_raw) * self.phi_scale / 2.0 - torch.pi / 2.0 + self.phi_scale / 2.0
        )  # From -pi / 2 to + pi / 2
        return theta, phi
class NeuralODE(torch.nn.Module):
    def __init__(
        self,
        task_type: ModelTaskType,
        n_static_units_in: int,
        n_temporal_units_in: int,
        output_shape: List[int],
        n_units_hidden: int = 100,
        n_layers_hidden: int = 1,
        nonlin: Nonlin = "relu",
        nonlin_out: Optional[List[Tuple[Nonlin, int]]] = None,
        dropout: float = 0,
        backend: ODEBackend = "cde",
        # CDE/ODE specific:
        atol: float = 1e-2,
        rtol: float = 1e-2,
        interpolation: Interpolation = "cubic",
        # Laplace specific:
        ilt_reconstruction_terms: int = 33,
        ilt_algorithm: ILTAlgorithm = "fourier",
        # Training:
        lr: float = 1e-3,
        weight_decay: float = 1e-3,
        n_iter: int = 1000,
        batch_size: int = 500,
        n_iter_print: int = 100,
        random_state: int = 0,
        patience: int = 10,
        n_iter_min: int = 100,
        clipping_value: int = 1,
        train_ratio: float = 0.8,
        device: Any = DEVICE,
        dataloader_sampler: Optional[sampler.Sampler] = None,
    ):
        r"""The model that computes the integral in: :math:`z_t = z_0 + \int_0^t f_\theta(z_s) dX_s`.

        Neural ODEs are a new family of deep neural network models. Instead of specifying a discrete sequence of
        hidden layers, we parameterize the derivative of the hidden state using a neural network.
        The output of the network is computed using a blackbox differential equation solver.
        These are continuous-depth models that have constant memory cost, adapt their evaluation strategy to each input,
        and can explicitly trade numerical precision for speed.

        Args:
            task_type (ModelTaskType):
                The type of the problem. Available options: :obj:`~tempor.models.constants.ModelTaskType`.
            n_static_units_in (int):
                Number of features in the static tensor.
            n_temporal_units_in (int):
                Number of features in the temporal tensor.
            output_shape (List[int]):
                Shape of the output tensor.
            n_units_hidden (int, optional):
                Number of hidden units in each layer. Defaults to ``100``.
            n_layers_hidden (int, optional):
                Number of hidden layers. Defaults to ``1``.
            nonlin (Nonlin, optional):
                Nonlinearity to use in NN. Available options: :obj:`~tempor.models.constants.Nonlin`.
                Defaults to ``"relu"``.
            nonlin_out (Optional[List[Tuple[Nonlin, int]]], optional):
                List of activations for the output. Example ``[("tanh", 1), ("softmax", 3)]`` - means the output layer
                will apply ``"tanh"`` for the first unit, and ``"softmax"`` for the following 3 units in the output.
                Defaults to `None`.
            dropout (float, optional):
                Dropout value. If ``0``, the dropout is not used. Defaults to ``0``.
            backend (ODEBackend, optional):
                Which solver to use: ``"cde"``, ``"ode"``, ``"laplace"``. Defaults to ``"cde"``.
            atol (float, optional):
                Specific to ``"ode"`` and ``"cde"`` backends. Absolute tolerance for solution. Defaults to ``1e-2``.
            rtol (float, optional):
                Specific to ``"ode"`` and ``"cde"`` backends. Relative tolerance for solution. Defaults to ``1e-2``.
            interpolation (Interpolation, optional):
                Specific to ``"ode"`` and ``"cde"`` backends. ``"cubic"`` or ``"linear"``. Defaults to ``"cubic"``.
            ilt_reconstruction_terms (int, optional):
                Specific to ``"laplace"`` backend. Number of ILT reconstruction terms, i.e. the number of complex
                :math:`s` points in ``laplace_rep_func`` to reconstruct a single time point. Defaults to ``33``.
            ilt_algorithm (ILTAlgorithm, optional):
                Specific to ``"laplace"`` backend. Inverse Laplace transform algorithm to use. Available are
                {``fourier``, ``dehoog``, ``cme``, ``fixed_tablot``, ``stehfest``}. Defaults to ``"fourier"``.
            lr (float, optional):
                Learning rate for optimizer. Defaults to ``1e-3``.
            weight_decay (float, optional):
                l2 (ridge) penalty for the weights. Defaults to ``1e-3``.
            n_iter (int, optional):
                Maximum number of iterations. Defaults to ``1000``.
            batch_size (int, optional):
                Batch size. Defaults to ``500``.
            n_iter_print (int, optional):
                Number of iterations after which to print updates and check the validation loss. Defaults to ``100``.
            random_state (int, optional):
                Random_state used. Defaults to ``0``.
            patience (int, optional):
                Number of consecutive validation checks (one every ``n_iter_print`` iterations) without improvement
                in validation loss that are tolerated before early stopping. Defaults to ``10``.
            n_iter_min (int, optional):
                Minimum number of iterations to go through before starting early stopping. Defaults to ``100``.
            clipping_value (int, optional):
                Gradients clipping value. Defaults to ``1``.
            train_ratio (float, optional):
                Train/test split ratio. Defaults to ``0.8``.
            device (Any, optional):
                PyTorch device to use. Defaults to `~tempor.models.constants.DEVICE`.
            dataloader_sampler (Optional[sampler.Sampler], optional):
                Custom data sampler for training. Defaults to `None`.

        Raises:
            ValueError: If ``output_shape`` is empty.
            RuntimeError: If ``backend`` is not one of the supported backends.
        """
        super().__init__()

        enable_reproducibility(random_state)

        if len(output_shape) == 0:
            raise ValueError("Invalid output shape")

        self.task_type = task_type
        self.backend = backend

        # NOTE: submodules are created in a fixed order (func, temporal encoder, static encoder, output head) so
        # that weight initialisation consumes the seeded random stream deterministically.
        self.func: Union[CDEFunc, ODEFunc, LaplaceFunc]
        if self.backend == "cde":
            self.func = CDEFunc(
                n_temporal_units_in + 1,  # we add the observation times
                n_units_hidden,
                n_layers_hidden=n_layers_hidden,
                nonlin=nonlin,
                dropout=dropout,
                device=device,
            )
        elif self.backend == "ode":
            self.func = ODEFunc(
                n_units_hidden,
                n_layers_hidden=n_layers_hidden,
                nonlin=nonlin,
                dropout=dropout,
                device=device,
            )
        elif self.backend == "laplace":
            self.func = LaplaceFunc(
                ilt_reconstruction_terms,
                n_units_out=n_units_hidden,
                n_units_latent=n_units_hidden,
                device=device,
            )
        else:
            raise RuntimeError(f"Invalid ODE backend {self.backend}")

        self.initial_temporal: Union[MLP, ReverseGRUEncoder]
        if self.backend in ["ode", "cde"]:
            self.initial_temporal = MLP(
                task_type="regression",
                n_units_in=n_temporal_units_in + 1,  # we add the observation times
                n_units_out=n_units_hidden,
                n_layers_hidden=n_layers_hidden,
                n_units_hidden=n_units_hidden,
                dropout=dropout,
                nonlin=nonlin,
                device=device,
            )
        else:  # self.backend == "laplace":
            self.initial_temporal = ReverseGRUEncoder(
                n_temporal_units_in + 1,
                n_units_latent=n_units_hidden,
                n_units_hidden=n_units_hidden,
                device=device,
            ).to(device)

        self.output_shape = output_shape
        self.n_units_out = int(np.prod(self.output_shape))
        self.n_units_hidden = n_units_hidden

        # The output head consumes the temporal embedding, plus the static embedding when static features exist.
        output_input_size = n_units_hidden
        self.initial_static: Optional[MLP] = None
        if n_static_units_in > 0:
            self.initial_static = MLP(
                task_type="regression",
                n_units_in=n_static_units_in,
                n_units_out=n_units_hidden,
                n_layers_hidden=n_layers_hidden,
                n_units_hidden=n_units_hidden,
                dropout=dropout,
                nonlin=nonlin,
                device=device,
            )
            output_input_size += n_units_hidden

        self.output = MLP(
            task_type=task_type,
            n_units_in=output_input_size,
            n_units_out=self.n_units_out,
            n_layers_hidden=n_layers_hidden,
            n_units_hidden=n_units_hidden,
            dropout=dropout,
            nonlin=nonlin,
            device=device,
            nonlin_out=nonlin_out,
        )

        # ODE specific
        self.atol = atol
        self.rtol = rtol
        self.interpolation = interpolation
        self.ilt_reconstruction_terms = ilt_reconstruction_terms
        self.ilt_algorithm = ilt_algorithm

        # training
        self.n_iter = n_iter
        self.n_iter_print = n_iter_print
        self.n_iter_min = n_iter_min
        self.batch_size = batch_size
        self.patience = patience
        self.clipping_value = clipping_value
        self.device = device
        self.train_ratio = train_ratio
        self.random_state = random_state
        self.dataloader_sampler = dataloader_sampler

        if self.backend == "laplace":
            # Kludge to make sure `torchlaplace` uses the correct device.
            # TODO: If `torchlaplace` is updated to allow passing `device` argument to `laplace_reconstruct`,
            # remove this kludge and update accordingly.
            torchlaplace.inverse_laplace.device = str(self.device)

        self.loss: nn.Module
        if task_type == "classification":
            self.loss = nn.CrossEntropyLoss()
        else:
            self.loss = nn.MSELoss()

        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=lr,
            weight_decay=weight_decay,
        )  # optimize all model parameters

    def forward(
        self,
        static_data: torch.Tensor,
        temporal_data: torch.Tensor,
        observation_times: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            static_data (torch.Tensor):
                Static features for the batch. May be `None`, in which case no static embedding is appended.
            temporal_data (torch.Tensor):
                Temporal features for the batch; the observation times are appended to it as an extra channel.
            observation_times (torch.Tensor):
                Observation times, one per time step of ``temporal_data``.

        Returns:
            torch.Tensor:
                Model output reshaped to ``(-1, *output_shape)``.

        Raises:
            ValueError: If any input contains NaNs.
            RuntimeError: If the configured interpolation or backend is not supported.
        """
        # sanity
        # `static_data` is optional (see the static embedding step below), so only inspect it when provided.
        if static_data is not None and torch.isnan(static_data).sum() != 0:
            raise ValueError("NaNs detected in the static data")
        if torch.isnan(temporal_data).sum() != 0:
            raise ValueError("NaNs detected in the temporal data")
        if torch.isnan(observation_times).sum() != 0:
            raise ValueError("NaNs detected in the temporal horizons")

        # Include the observation times as a channel in the dataset
        temporal_data_ext = torch.cat([temporal_data, observation_times.unsqueeze(-1)], dim=-1)

        # Convert the dataset into a continuous path. Each interpolation class must be given coefficients produced
        # by its matching helper: Hermite cubic coefficients have a different layout from the linear ones and must
        # not be handed to `LinearInterpolation`.
        # NOTE: the path is only consumed by the "cde" and "ode" backends, but it is still built (and the
        # interpolation name validated) for every backend, as before.
        if self.interpolation == "linear":
            coeffs = torchcde.linear_interpolation_coeffs(temporal_data_ext)
            spline = torchcde.LinearInterpolation(coeffs)
        elif self.interpolation == "cubic":
            coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(temporal_data_ext)
            spline = torchcde.CubicSpline(coeffs)
        else:
            raise RuntimeError(f"Invalid interpolation {self.interpolation}")

        # Solve the ODE using a solver
        if self.backend == "cde":
            # Initial hidden state should be a function of the first observation.
            X0 = spline.evaluate(spline.interval[0])
            z0 = self.initial_temporal(X0)
            z_T = torchcde.cdeint(X=spline, func=self.func, z0=z0, t=spline.interval, atol=self.atol, rtol=self.rtol)
            z_T = z_T[:, 1]  # pyright: ignore  # State at the end of the interval.
        elif self.backend == "ode":
            X_emb = self.initial_temporal(temporal_data_ext)
            z_T = torchdiffeq.odeint_adjoint(
                self.func,
                X_emb,
                spline.interval,
                atol=self.atol,
                rtol=self.rtol,
            )
            z_T = z_T[1]  # pyright: ignore  # State at the end of the interval.
            z_T = z_T[:, -1, :]  # pyright: ignore  # Last time point.
        elif self.backend == "laplace":
            X_emb = self.initial_temporal(temporal_data_ext)
            z_T = torchlaplace.laplace_reconstruct(
                laplace_rep_func=self.func,
                p=X_emb,
                t=observation_times,
                recon_dim=self.n_units_hidden,
                ilt_reconstruction_terms=self.ilt_reconstruction_terms,
                ilt_algorithm=self.ilt_algorithm,
            )
            z_T = z_T[:, -1, :]  # Last time point.
        else:
            raise RuntimeError(f"Invalid solver {self.backend}")

        # Compute static embedding
        if static_data is not None and self.initial_static is not None:
            static_emb = self.initial_static(static_data)
            z_T = torch.cat([z_T, static_emb], dim=-1)

        out = self.output(z_T)
        return out.reshape(-1, *self.output_shape)

    def _forward_windows(
        self,
        static_data: Union[List, np.ndarray, torch.Tensor],
        temporal_data: Union[List, np.ndarray, torch.Tensor],
        observation_times: Union[List, np.ndarray, torch.Tensor],
    ) -> torch.Tensor:
        """Run the model without gradients on every same-length window group; return outputs in input order."""
        with torch.no_grad():
            (
                static_data_t,
                temporal_data_t,
                observation_times_t,
                _,
                window_batches,
            ) = self._prepare_input(static_data, temporal_data, observation_times)

            yt = torch.zeros(len(temporal_data), *self.output_shape).to(self.device)
            # `_prepare_input` builds the per-window tensors in the iteration order of `window_batches`, so the
            # sample indices of each group can be paired with its tensors directly.
            for indices, static_mb, temporal_mb, times_mb in zip(
                window_batches.values(), static_data_t, temporal_data_t, observation_times_t
            ):
                yt[indices] = self(static_mb, temporal_mb, times_mb)
        return yt

    @pydantic_utils.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))
    def predict(
        self,
        static_data: Union[List, np.ndarray, torch.Tensor],
        temporal_data: Union[List, np.ndarray, torch.Tensor],
        observation_times: Union[List, np.ndarray, torch.Tensor],
    ) -> np.ndarray:
        """Make predictions.

        Returns:
            np.ndarray:
                For classification, the index of the largest output along the last axis; otherwise the raw model
                outputs.
        """
        self.eval()
        yt = self._forward_windows(static_data, temporal_data, observation_times).cpu().numpy()
        if self.task_type == "classification":
            return np.argmax(yt, -1)
        return yt

    @pydantic_utils.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))
    def predict_proba(
        self,
        static_data: Union[List, np.ndarray, torch.Tensor],
        temporal_data: Union[List, np.ndarray, torch.Tensor],
        observation_times: Union[List, np.ndarray, torch.Tensor],
    ) -> np.ndarray:
        """Predict probabilities.

        Returns:
            np.ndarray:
                The raw model outputs for each sample.

        Raises:
            RuntimeError: If the model's task type is not classification.
        """
        self.eval()
        if self.task_type != "classification":
            raise RuntimeError("Task valid only for classification")
        return self._forward_windows(static_data, temporal_data, observation_times).cpu().numpy()

    def score(
        self,
        static_data: Union[List, np.ndarray, torch.Tensor],
        temporal_data: Union[List, np.ndarray, torch.Tensor],
        observation_times: Union[List, np.ndarray, torch.Tensor],
        outcome: Union[List, np.ndarray],
    ) -> float:
        """Compute default score.

        Returns:
            float:
                For classification, the fraction of predictions equal to ``outcome``; otherwise the mean of
                ``np.inner(residual, residual) / 2`` where ``residual = outcome - prediction``.
        """
        y_pred = self.predict(static_data, temporal_data, observation_times)
        outcome = np.asarray(outcome)
        if self.task_type == "classification":
            return float(np.mean(y_pred.astype(int) == outcome.astype(int)))
        residual = outcome - y_pred
        return float(np.mean(np.inner(residual, residual) / 2.0))

    @pydantic_utils.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))
    def fit(
        self,
        static_data: Union[List, np.ndarray, torch.Tensor],
        temporal_data: Union[List, np.ndarray, torch.Tensor],
        observation_times: Union[List, np.ndarray, torch.Tensor],
        outcome: Union[List, np.ndarray, torch.Tensor],
    ) -> Any:
        """Fit (train) the model.

        Returns:
            Any: The fitted model (``self``).
        """
        (
            static_data_t,
            temporal_data_t,
            observation_times_t,
            outcome_t,
            _,
        ) = self._prepare_input(static_data, temporal_data, observation_times, outcome)
        return self._train(static_data_t, temporal_data_t, observation_times_t, outcome_t)

    @pydantic_utils.validate_arguments(config=pydantic.ConfigDict(arbitrary_types_allowed=True))
    def _train(
        self,
        static_data: List[torch.Tensor],
        temporal_data: List[torch.Tensor],
        observation_times: List[torch.Tensor],
        outcome: List[torch.Tensor],
    ) -> Any:
        """Train on per-window tensors, validating every ``n_iter_print`` iterations with early stopping."""
        patience = 0
        prev_error = np.inf

        # One train/validation loader pair per group of same-length windows.
        train_dataloaders = []
        test_dataloaders = []
        for static_w, temporal_w, times_w, outcome_w in zip(static_data, temporal_data, observation_times, outcome):
            train_dl, test_dl = self.dataloader(static_w, temporal_w, times_w, outcome_w)
            train_dataloaders.append(train_dl)
            test_dataloaders.append(test_dl)

        # training and testing
        for it in range(self.n_iter):
            train_loss = self._train_epoch(train_dataloaders)
            if (it + 1) % self.n_iter_print != 0:
                continue

            val_loss = self._test_epoch(test_dataloaders)
            log.info(f"Epoch:{it}| train loss: {train_loss}, validation loss: {val_loss}")

            if val_loss < prev_error:
                patience = 0
                prev_error = val_loss
            else:
                patience += 1

            # Early stopping only becomes active once the minimum number of iterations has been completed.
            if patience > self.patience and (it + 1) >= self.n_iter_min:
                break

        return self

    def _train_epoch(self, loaders: List[DataLoader]) -> float:
        """Run one optimisation pass over all loaders and return the mean mini-batch loss."""
        self.train()

        losses = []
        for loader in loaders:
            for static_mb, temporal_mb, horizons_mb, y_mb in loader:
                self.optimizer.zero_grad()  # clear gradients for this training step

                pred = self(static_mb, temporal_mb, horizons_mb)  # model output
                if torch.isnan(pred).sum() > 0:  # pragma: no cover
                    raise RuntimeError("NaNs in the training prediction")

                loss = self.loss(pred.squeeze(), y_mb.squeeze())
                # Fail before backprop/step so that a NaN loss never gets applied to the parameters.
                if torch.isnan(loss):  # pragma: no cover
                    raise RuntimeError("NaNs in the loss")

                loss.backward()  # backpropagation, compute gradients
                if self.clipping_value > 0:
                    torch.nn.utils.clip_grad_norm_(  # type: ignore [attr-defined]  # pyright: ignore
                        self.parameters(),
                        self.clipping_value,
                    )
                self.optimizer.step()  # apply gradients

                losses.append(loss.detach().cpu())

        return float(np.mean(losses))

    def _test_epoch(self, loaders: List[DataLoader]) -> float:
        """Evaluate all loaders in eval mode and return the mean mini-batch loss."""
        self.eval()

        losses = []
        # Validation never backpropagates, so skip building autograd graphs.
        with torch.no_grad():
            for loader in loaders:
                for static_mb, temporal_mb, horizons_mb, y_mb in loader:
                    pred = self(static_mb, temporal_mb, horizons_mb)  # ODE output
                    if torch.isnan(pred).sum() > 0:  # pragma: no cover
                        raise RuntimeError("NaNs in the test prediction")
                    loss = self.loss(pred.squeeze(), y_mb.squeeze())
                    losses.append(loss.detach().cpu())

        return float(np.mean(losses))

    def dataloader(
        self,
        static_data: torch.Tensor,
        temporal_data: torch.Tensor,
        observation_times: torch.Tensor,
        outcome: torch.Tensor,
    ) -> Tuple[DataLoader, DataLoader]:
        """Get the train and test `torch` dataloaders.

        The data is split with ``train_ratio`` and ``random_state``. The split is stratified on ``outcome`` only
        when every distinct outcome value occurs more than once. For classification, an
        ``ImbalancedDatasetSampler`` is used for the training loader unless a custom sampler was provided.

        Returns:
            Tuple[DataLoader, DataLoader]: ``(train_loader, test_loader)``.
        """
        # Stratification needs at least two samples of every distinct outcome value.
        stratify = None
        _, out_counts = torch.unique(outcome, return_counts=True)
        if out_counts.min() > 1:
            stratify = outcome.cpu()

        split: Tuple[torch.Tensor, ...] = train_test_split(  # pyright: ignore
            static_data.cpu(),
            temporal_data.cpu(),
            observation_times.cpu(),
            outcome.cpu(),
            train_size=self.train_ratio,
            random_state=self.random_state,
            stratify=stratify,
        )
        (
            static_data_train,
            static_data_test,
            temporal_data_train,
            temporal_data_test,
            observation_times_train,
            observation_times_test,
            outcome_train,
            outcome_test,
        ) = split

        train_dataset = TensorDataset(
            static_data_train.to(self.device),
            temporal_data_train.to(self.device),
            observation_times_train.to(self.device),
            outcome_train.to(self.device),
        )
        test_dataset = TensorDataset(
            static_data_test.to(self.device),
            temporal_data_test.to(self.device),
            observation_times_test.to(self.device),
            outcome_test.to(self.device),
        )

        sampler_ = self.dataloader_sampler
        if sampler_ is None and self.task_type == "classification":
            sampler_ = ImbalancedDatasetSampler(outcome_train.squeeze().cpu().numpy().tolist())

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            sampler=sampler_,
            pin_memory=False,
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            pin_memory=False,
        )
        return train_loader, test_loader

    def _check_tensor(self, X: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Return ``X`` as a tensor on the model's device, converting from array-like input when needed."""
        if isinstance(X, torch.Tensor):
            return X.to(self.device)
        return torch.from_numpy(np.asarray(X)).to(self.device)

    def _prepare_input(
        self,
        static_data: Union[List, np.ndarray, torch.Tensor],
        temporal_data: Union[List, np.ndarray, torch.Tensor],
        observation_times: Union[List, np.ndarray, torch.Tensor],
        outcome: Optional[Union[List, np.ndarray, torch.Tensor]] = None,
    ) -> Tuple:
        """Group samples by number of observation times and convert each group to float tensors on the device.

        Returns:
            Tuple:
                ``(static, temporal, observation_times, outcome, window_batches)``. The first four are lists with
                one tensor per window length (``outcome`` is an empty list when no outcome is given; it is cast to
                ``long`` for classification). ``window_batches`` maps each window length to the original sample
                indices in that group, in the same order as the lists.
        """
        static_data = np.asarray(static_data)
        temporal_data = np.asarray(temporal_data)
        observation_times = np.asarray(observation_times)
        if outcome is not None:
            outcome = np.asarray(outcome)

        # Samples can only be batched together when they share the same number of time steps.
        window_batches: Dict[int, List[int]] = {}
        for idx, item in enumerate(observation_times):
            window_batches.setdefault(len(item), []).append(idx)

        static_data_mb = []
        temporal_data_mb = []
        observation_times_mb = []
        outcome_mb = []

        for indices in window_batches.values():
            static_data_mb.append(self._check_tensor(static_data[indices]).float())

            # Round-trip through `tolist` to turn a group of per-sample sequences into one dense float array.
            local_temporal_data = np.array(temporal_data[indices].tolist()).astype(float)
            temporal_data_mb.append(self._check_tensor(local_temporal_data).float())

            local_observation_times = np.array(observation_times[indices].tolist()).astype(float)
            observation_times_mb.append(self._check_tensor(local_observation_times).float())

            if outcome is not None:
                outcome_t = self._check_tensor(outcome[indices]).float()
                if self.task_type == "classification":
                    outcome_t = outcome_t.long()
                outcome_mb.append(outcome_t)

        return (
            static_data_mb,
            temporal_data_mb,
            observation_times_mb,
            outcome_mb,
            window_batches,
        )