Source code for tempor.models.clairvoyance2.components.torch.synctwin_train_utils

# mypy: ignore-errors

import copy
import os

import numpy as np
import numpy.random
import torch
import torch.optim as optim


[docs]def get_batch_standard(batch_size, *args): array_list = args n_total = args[0].shape[1] # batching on dim=1 batch_ind = numpy.random.choice(np.arange(n_total), batch_size) batch_ind = torch.tensor(batch_ind, dtype=torch.long).to(args[0].device) mini_batch = [] for a in array_list: mini_batch.append(a[:, batch_ind, ...]) return mini_batch
[docs]def get_folds(start, split, *args): a = args ret = [] for x in a: if x.dim() == 3: y = x[:, start::split, :] else: y = x[start::split] ret.append(y) return ret
[docs]def create_paths(*args): for base_path in args: if not os.path.exists(base_path): os.makedirs(base_path)
[docs]def pre_train_reconstruction_prognostic_loss( nsc, x_full, t_full, mask_full, y_full, y_mask_full, x_full_val=None, t_full_val=None, mask_full_val=None, y_full_val=None, y_mask_full_val=None, niters=5000, test_freq=1000, batch_size=None, ): if x_full_val is None: x_full_val = x_full t_full_val = t_full mask_full_val = mask_full y_full_val = y_full y_mask_full_val = y_mask_full enc = nsc.encoder dec = nsc.decoder assert nsc.decoder_Y is not None dec_Y = nsc.decoder_Y optimizer = optim.Adam(list(dec.parameters()) + list(enc.parameters()) + list(dec_Y.parameters())) y_mask_full = torch.stack([y_mask_full] * dec_Y.max_seq_len, dim=0).unsqueeze(-1) test_freq = test_freq if test_freq > 0 else 1 best_loss = 1e9 enc_sd = None dec_sd = None dec_Y_sd = None for itr in range(1, niters + 1): optimizer.zero_grad() if batch_size is not None: x, t, mask, y, y_mask = get_batch_standard( # pylint: disable=unbalanced-tuple-unpacking batch_size, x_full, t_full, mask_full, y_full, y_mask_full ) else: x, t, mask, y, y_mask = x_full, t_full, mask_full, y_full, y_mask_full C = nsc.get_representation(x, t, mask) x_hat = nsc.get_reconstruction(C, t, mask) loss_X = nsc.reconstruction_loss(x, x_hat, mask) y_hat = nsc.get_prognostics(C, t, mask) loss_Y = nsc.prognostic_loss2(y, y_hat, y_mask) loss = loss_X + loss_Y loss.backward() optimizer.step() if itr % test_freq == 0: with torch.no_grad(): if x_full_val.shape[1] < 5000: C = nsc.get_representation(x_full_val, t_full_val, mask_full_val) x_hat = nsc.get_reconstruction(C, t_full_val, mask_full_val) loss_X = nsc.reconstruction_loss(x_full_val, x_hat, mask_full_val) y_hat = nsc.get_prognostics(C, t_full_val, mask_full_val) loss_Y = nsc.prognostic_loss2(y_full_val, y_hat, y_mask_full_val) loss = loss_X + loss_Y else: loss = 0 n_fold = x_full_val.shape[1] // 500 for fold in range(n_fold): ( # pylint: disable=unbalanced-tuple-unpacking x_full_vb, t_full_vb, mask_full_vb, y_full_vb, y_mask_full_vb, ) = get_folds(fold, n_fold, x_full_val, t_full_val, mask_full_val, y_full_val, y_mask_full_val) C = nsc.get_representation(x_full_vb, t_full_vb, mask_full_vb) x_hat = nsc.get_reconstruction(C, t_full_vb, mask_full_vb) loss_X = nsc.reconstruction_loss(x_full_vb, x_hat, mask_full_vb) y_hat = nsc.get_prognostics(C, t_full_vb, mask_full_vb) loss_Y = nsc.prognostic_loss2(y_full_vb, y_hat, y_mask_full_vb) loss = loss_X + loss_Y + loss print("Iter {:04d} | Total Loss {:.6f}".format(itr, loss.item())) # type: ignore if loss < best_loss: best_loss = loss enc_sd = copy.deepcopy(enc.state_dict()) dec_sd = copy.deepcopy(dec.state_dict()) dec_Y_sd = copy.deepcopy(dec_Y.state_dict()) enc.load_state_dict(enc_sd) dec.load_state_dict(dec_sd) dec_Y.load_state_dict(dec_Y_sd) return best_loss
[docs]def update_representations(nsc, x_full, t_full, mask_full, batch_ind_full): with torch.no_grad(): C = nsc.get_representation(x_full, t_full, mask_full) nsc.update_C0(C, batch_ind_full)
[docs]def load_pre_train_and_init( nsc, x_full, t_full, mask_full, batch_ind_full, model_path="models/sync/{}.pth", init_decoder_Y=False ): enc = nsc.encoder dec = nsc.decoder enc.load_state_dict(torch.load(model_path.format("encoder.pth"))) dec.load_state_dict(torch.load(model_path.format("decoder.pth"))) if init_decoder_Y: dec_Y = nsc.decoder_Y dec_Y.load_state_dict(torch.load(model_path.format("decoder_Y.pth"))) with torch.no_grad(): C = nsc.get_representation(x_full, t_full, mask_full) nsc.update_C0(C, batch_ind_full)
[docs]def train_B_self_expressive( nsc, x_full, t_full, mask_full, batch_ind_full, niters=20000, batch_size=None, lr=1e-3, test_freq=1000, ): # mini-batch training not implemented assert batch_size is None optimizer = optim.Adam([nsc.B], lr=lr) test_freq = test_freq if test_freq > 0 else 1 best_loss = 1e9 with torch.no_grad(): C = nsc.get_representation(x_full, t_full, mask_full) nsc_sd = None for itr in range(1, niters + 1): optimizer.zero_grad() B_reduced = nsc.get_B_reduced(batch_ind_full) loss = nsc.self_expressive_loss(C, B_reduced) loss.backward() optimizer.step() if itr % test_freq == 0: with torch.no_grad(): print("Iter {:04d} | Total Loss {:.6f}".format(itr, loss.item())) if np.isnan(loss.item()): raise RuntimeError("NaN loss encountered") if loss < best_loss: best_loss = loss nsc_sd = copy.deepcopy(nsc.state_dict()) nsc.load_state_dict(nsc_sd) return 0
[docs]def get_prediction(nsc, batch_ind_full, y_control, itr=500): batch_ind_full = batch_ind_full.to(nsc.B.device) y_control = y_control.to(nsc.B.device) y_hat_list = list() for _ in range(itr): with torch.no_grad(): B_reduced = nsc.get_B_reduced(batch_ind_full) y_hat = torch.matmul(B_reduced, y_control) if torch.sum(torch.isinf(y_hat)).item() == 0: y_hat_list.append(y_hat) y_hat_mat = torch.stack(y_hat_list, dim=-1) y_hat_mat[torch.isinf(y_hat_mat)] = 0.0 y_hat2 = torch.mean(y_hat_mat, dim=-1) return y_hat2
[docs]def get_treatment_effect(nsc, batch_ind_full, y_full, y_control, itr=500): batch_ind_full = batch_ind_full.to(nsc.B.device) y_control = y_control.to(nsc.B.device) y_full = y_full.to(nsc.B.device) y_hat_list = list() for _ in range(itr): with torch.no_grad(): B_reduced = nsc.get_B_reduced(batch_ind_full) y_hat = torch.matmul(B_reduced, y_control) if torch.sum(torch.isinf(y_hat)).item() == 0: y_hat_list.append(y_hat) y_hat_mat = torch.stack(y_hat_list, dim=-1) y_hat_mat[torch.isinf(y_hat_mat)] = 0.0 y_hat2 = torch.mean(y_hat_mat, dim=-1) return (y_full - y_hat2)[:, nsc.n_unit :, :], y_hat2