tempor.models.clairvoyance2.components.torch.synctwin_train_utils module
-
tempor.models.clairvoyance2.components.torch.synctwin_train_utils.get_batch_standard(batch_size, *args)[source]
-
tempor.models.clairvoyance2.components.torch.synctwin_train_utils.get_folds(start, split, *args)[source]
-
tempor.models.clairvoyance2.components.torch.synctwin_train_utils.create_paths(*args)[source]
-
tempor.models.clairvoyance2.components.torch.synctwin_train_utils.pre_train_reconstruction_prognostic_loss(nsc, x_full, t_full, mask_full, y_full, y_mask_full, x_full_val=None, t_full_val=None, mask_full_val=None, y_full_val=None, y_mask_full_val=None, niters=5000, test_freq=1000, batch_size=None)[source]
-
tempor.models.clairvoyance2.components.torch.synctwin_train_utils.update_representations(nsc, x_full, t_full, mask_full, batch_ind_full)[source]
-
tempor.models.clairvoyance2.components.torch.synctwin_train_utils.load_pre_train_and_init(nsc, x_full, t_full, mask_full, batch_ind_full, model_path='models/sync/{}.pth', init_decoder_Y=False)[source]
-
tempor.models.clairvoyance2.components.torch.synctwin_train_utils.train_B_self_expressive(nsc, x_full, t_full, mask_full, batch_ind_full, niters=20000, batch_size=None, lr=0.001, test_freq=1000)[source]
-
tempor.models.clairvoyance2.components.torch.synctwin_train_utils.get_prediction(nsc, batch_ind_full, y_control, itr=500)[source]
-
tempor.models.clairvoyance2.components.torch.synctwin_train_utils.get_treatment_effect(nsc, batch_ind_full, y_full, y_control, itr=500)[source]
Back to top