""" FactorizePhys: Matrix Factorization for Multidimensional Attention in Remote Physiological Sensing NeurIPS 2024 Jitesh Joshi, Sos S. Agaian, and Youngjun Cho """ import os import numpy as np import torch import torch.optim as optim from evaluation.metrics import calculate_metrics from neural_methods.loss.NegPearsonLoss import Neg_Pearson from neural_methods.model.FactorizePhys.FactorizePhys import FactorizePhys from neural_methods.model.FactorizePhys.FactorizePhysBig import FactorizePhysBig from neural_methods.trainer.BaseTrainer import BaseTrainer from tqdm import tqdm class FactorizePhysTrainer(BaseTrainer): def __init__(self, config, data_loader): """Inits parameters from args and the writer for TensorboardX.""" super().__init__() self.max_epoch_num = config.TRAIN.EPOCHS self.model_dir = config.MODEL.MODEL_DIR self.model_file_name = config.TRAIN.MODEL_FILE_NAME self.batch_size = config.TRAIN.BATCH_SIZE self.num_of_gpu = config.NUM_OF_GPU_TRAIN self.dropout_rate = config.MODEL.DROP_RATE self.base_len = self.num_of_gpu self.config = config self.min_valid_loss = None self.best_epoch = 0 if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0: dev_list = [int(d) for d in config.DEVICE.replace("cuda:", "").split(",")] self.device = torch.device(dev_list[0]) #currently toolbox only supports 1 GPU self.num_of_gpu = 1 #config.NUM_OF_GPU_TRAIN # set number of used GPUs else: self.device = torch.device("cpu") # if no GPUs set device is CPU self.num_of_gpu = 0 # no GPUs used frames = self.config.MODEL.FactorizePhys.FRAME_NUM in_channels = self.config.MODEL.FactorizePhys.CHANNELS model_type = self.config.MODEL.FactorizePhys.TYPE model_type = model_type.lower() md_config = {} md_config["FRAME_NUM"] = self.config.MODEL.FactorizePhys.FRAME_NUM md_config["MD_TYPE"] = self.config.MODEL.FactorizePhys.MD_TYPE md_config["MD_FSAM"] = self.config.MODEL.FactorizePhys.MD_FSAM md_config["MD_TRANSFORM"] = self.config.MODEL.FactorizePhys.MD_TRANSFORM md_config["MD_S"] = self.config.MODEL.FactorizePhys.MD_S md_config["MD_R"] = self.config.MODEL.FactorizePhys.MD_R md_config["MD_STEPS"] = self.config.MODEL.FactorizePhys.MD_STEPS md_config["MD_INFERENCE"] = self.config.MODEL.FactorizePhys.MD_INFERENCE md_config["MD_RESIDUAL"] = self.config.MODEL.FactorizePhys.MD_RESIDUAL self.md_infer = self.config.MODEL.FactorizePhys.MD_INFERENCE self.use_fsam = self.config.MODEL.FactorizePhys.MD_FSAM if model_type == "standard": self.model = FactorizePhys(frames=frames, md_config=md_config, in_channels=in_channels, dropout=self.dropout_rate, device=self.device) # [3, T, 72,72] elif model_type == "big": self.model = FactorizePhysBig(frames=frames, md_config=md_config, in_channels=in_channels, dropout=self.dropout_rate, device=self.device) # [3, T, 144,144] else: print("Unexpected model type specified. Should be standard or big, but specified:", model_type) exit() if torch.cuda.device_count() > 0 and self.num_of_gpu > 0: # distribute model across GPUs self.model = torch.nn.DataParallel(self.model, device_ids=[self.device]) # data parallel model else: self.model = torch.nn.DataParallel(self.model).to(self.device) if self.config.TOOLBOX_MODE == "train_and_test" or self.config.TOOLBOX_MODE == "only_train": self.num_train_batches = len(data_loader["train"]) self.criterion = Neg_Pearson() self.optimizer = optim.Adam( self.model.parameters(), lr=self.config.TRAIN.LR) # See more details on the OneCycleLR scheduler here: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html self.scheduler = torch.optim.lr_scheduler.OneCycleLR( self.optimizer, max_lr=self.config.TRAIN.LR, epochs=self.config.TRAIN.EPOCHS, steps_per_epoch=self.num_train_batches) elif self.config.TOOLBOX_MODE == "only_test": pass else: raise ValueError("FactorizePhys trainer initialized in incorrect toolbox mode!") def train(self, data_loader): """Training routine for model""" if data_loader["train"] is None: raise ValueError("No data for train") mean_training_losses = [] mean_valid_losses = [] mean_appx_error = [] lrs = [] for epoch in range(self.max_epoch_num): print('') print(f"====Training Epoch: {epoch}====") running_loss = 0.0 train_loss = [] appx_error_list = [] self.model.train() tbar = tqdm(data_loader["train"], ncols=80) for idx, batch in enumerate(tbar): tbar.set_description("Train epoch %s" % epoch) data = batch[0].to(self.device) labels = batch[1].to(self.device) if len(labels.shape) > 2: labels = labels[..., 0] # Compatibility wigth multi-signal labelled data labels = (labels - torch.mean(labels)) / torch.std(labels) # normalize last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1) data = torch.cat((data, last_frame), 2) # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1) # labels = torch.cat((labels, last_sample), 0) # labels = torch.diff(labels, dim=0) # labels = labels/ torch.std(labels) # normalize # labels[torch.isnan(labels)] = 0 self.optimizer.zero_grad() if self.model.training and self.use_fsam: pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data) else: pred_ppg, vox_embed = self.model(data) pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg) # normalize loss = self.criterion(pred_ppg, labels) loss.backward() running_loss += loss.item() if idx % 100 == 99: # print every 100 mini-batches print( f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') running_loss = 0.0 train_loss.append(loss.item()) if self.use_fsam: appx_error_list.append(appx_error.item()) # Append the current learning rate to the list lrs.append(self.scheduler.get_last_lr()) self.optimizer.step() self.scheduler.step() if self.use_fsam: tbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item()) else: tbar.set_postfix(loss=loss.item()) # Append the mean training loss for the epoch mean_training_losses.append(np.mean(train_loss)) if self.use_fsam: mean_appx_error.append(np.mean(appx_error_list)) print("Mean train loss: {}, Mean appx error: {}".format( np.mean(train_loss), np.mean(appx_error_list))) else: print("Mean train loss: {}".format(np.mean(train_loss))) self.save_model(epoch) if not self.config.TEST.USE_LAST_EPOCH: valid_loss = self.valid(data_loader) mean_valid_losses.append(valid_loss) print('validation loss: ', valid_loss) if self.min_valid_loss is None: self.min_valid_loss = valid_loss self.best_epoch = epoch print("Update best model! Best epoch: {}".format(self.best_epoch)) elif (valid_loss < self.min_valid_loss): self.min_valid_loss = valid_loss self.best_epoch = epoch print("Update best model! Best epoch: {}".format(self.best_epoch)) if not self.config.TEST.USE_LAST_EPOCH: print("best trained epoch: {}, min_val_loss: {}".format( self.best_epoch, self.min_valid_loss)) if self.config.TRAIN.PLOT_LOSSES_AND_LR: self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) def valid(self, data_loader): """ Runs the model on valid sets.""" if data_loader["valid"] is None: raise ValueError("No data for valid") print('') print(" ====Validing===") valid_loss = [] self.model.eval() valid_step = 0 with torch.no_grad(): vbar = tqdm(data_loader["valid"], ncols=80) for valid_idx, valid_batch in enumerate(vbar): vbar.set_description("Validation") data, labels = valid_batch[0].to(self.device), valid_batch[1].to(self.device) if len(labels.shape) > 2: labels = labels[..., 0] # Compatibility wigth multi-signal labelled data labels = (labels - torch.mean(labels)) / torch.std(labels) # normalize last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1) data = torch.cat((data, last_frame), 2) # last_sample = torch.unsqueeze(labels[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1) # labels = torch.cat((labels, last_sample), 0) # labels = torch.diff(labels, dim=0) # labels = labels/ torch.std(labels) # normalize # labels[torch.isnan(labels)] = 0 if self.md_infer and self.use_fsam: pred_ppg, vox_embed, factorized_embed, appx_error = self.model(data) else: pred_ppg, vox_embed = self.model(data) pred_ppg = (pred_ppg - torch.mean(pred_ppg)) / torch.std(pred_ppg) # normalize loss = self.criterion(pred_ppg, labels) valid_loss.append(loss.item()) valid_step += 1 # vbar.set_postfix(loss=loss.item()) if self.md_infer and self.use_fsam: vbar.set_postfix({"appx_error": appx_error.item()}, loss=loss.item()) else: vbar.set_postfix(loss=loss.item()) valid_loss = np.asarray(valid_loss) return np.mean(valid_loss) def test(self, data_loader): """ Runs the model on test sets.""" if data_loader["test"] is None: raise ValueError("No data for test") print('') print("===Testing===") predictions = dict() labels = dict() if self.config.TOOLBOX_MODE == "only_test": if not os.path.exists(self.config.INFERENCE.MODEL_PATH): raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.") self.model.load_state_dict(torch.load(self.config.INFERENCE.MODEL_PATH, map_location=self.device), strict=False) print("Testing uses pretrained model!") print(self.config.INFERENCE.MODEL_PATH) else: if self.config.TEST.USE_LAST_EPOCH: last_epoch_model_path = os.path.join( self.model_dir, self.model_file_name + '_Epoch' + str(self.max_epoch_num - 1) + '.pth') print("Testing uses last epoch as non-pretrained model!") print(last_epoch_model_path) self.model.load_state_dict(torch.load(last_epoch_model_path, map_location=self.device), strict=False) else: best_model_path = os.path.join( self.model_dir, self.model_file_name + '_Epoch' + str(self.best_epoch) + '.pth') print("Testing uses best epoch selected using model selection as non-pretrained model!") print(best_model_path) self.model.load_state_dict(torch.load(best_model_path, map_location=self.device), strict=False) self.model = self.model.to(self.device) self.model.eval() print("Running model evaluation on the testing dataset!") with torch.no_grad(): for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): batch_size = test_batch[0].shape[0] data, labels_test = test_batch[0].to(self.device), test_batch[1].to(self.device) if len(labels_test.shape) > 2: labels_test = labels_test[..., 0] # Compatibility wigth multi-signal labelled data labels_test = (labels_test - torch.mean(labels_test)) / torch.std(labels_test) # normalize last_frame = torch.unsqueeze(data[:, :, -1, :, :], 2).repeat(1, 1, max(self.num_of_gpu, 1), 1, 1) data = torch.cat((data, last_frame), 2) # last_sample = torch.unsqueeze(labels_test[-1, :], 0).repeat(max(self.num_of_gpu, 1), 1) # labels_test = torch.cat((labels_test, last_sample), 0) # labels_test = torch.diff(labels_test, dim=0) # labels_test = labels_test/ torch.std(labels_test) # normalize # labels_test[torch.isnan(labels_test)] = 0 if self.md_infer and self.use_fsam: pred_ppg_test, vox_embed, factorized_embed, appx_error = self.model(data) else: pred_ppg_test, vox_embed = self.model(data) pred_ppg_test = (pred_ppg_test - torch.mean(pred_ppg_test)) / torch.std(pred_ppg_test) # normalize if self.config.TEST.OUTPUT_SAVE_DIR: labels_test = labels_test.cpu() pred_ppg_test = pred_ppg_test.cpu() for idx in range(batch_size): subj_index = test_batch[2][idx] sort_index = int(test_batch[3][idx]) if subj_index not in predictions.keys(): predictions[subj_index] = dict() labels[subj_index] = dict() predictions[subj_index][sort_index] = pred_ppg_test[idx] labels[subj_index][sort_index] = labels_test[idx] print('') calculate_metrics(predictions, labels, self.config) if self.config.TEST.OUTPUT_SAVE_DIR: # saving test outputs self.save_test_outputs(predictions, labels, self.config) def save_model(self, index): if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) model_path = os.path.join( self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') torch.save(self.model.state_dict(), model_path) print('Saved Model Path: ', model_path)