"""Trainer for BigSmall Multitask Models""" # Training / Eval Imports import torch import torch.optim as optim from neural_methods.trainer.BaseTrainer import BaseTrainer from neural_methods import loss from neural_methods.model.BigSmall import BigSmall from evaluation.bigsmall_multitask_metrics import (calculate_bvp_metrics, calculate_resp_metrics, calculate_bp4d_au_metrics) # Other Imports from collections import OrderedDict import numpy as np import os from tqdm import tqdm class BigSmallTrainer(BaseTrainer): def define_model(self, config): # BigSmall Model model = BigSmall(n_segment=3) if self.using_TSM: self.frame_depth = config.MODEL.BIGSMALL.FRAME_DEPTH self.base_len = self.num_of_gpu * self.frame_depth return model def format_data_shape(self, data, labels): # reshape big data data_big = data[0] N, D, C, H, W = data_big.shape data_big = data_big.view(N * D, C, H, W) # reshape small data data_small = data[1] N, D, C, H, W = data_small.shape data_small = data_small.view(N * D, C, H, W) # reshape labels if len(labels.shape) != 3: # this training format requires labels that are of shape N_label, D_label, C_label labels = torch.unsqueeze(labels, dim=-1) N_label, D_label, C_label = labels.shape labels = labels.view(N_label * D_label, C_label) # If using temporal shift module if self.using_TSM: data_big = data_big[:(N * D) // self.base_len * self.base_len] data_small = data_small[:(N * D) // self.base_len * self.base_len] labels = labels[:(N * D) // self.base_len * self.base_len] data[0] = data_big data[1] = data_small labels = torch.unsqueeze(labels, dim=-1) return data, labels def send_data_to_device(self, data, labels): big_data = data[0].to(self.device) small_data = data[1].to(self.device) labels = labels.to(self.device) data = (big_data, small_data) return data, labels def get_label_idxs(self, label_list, used_labels): label_idxs = [] for l in used_labels: idx = label_list.index(l) label_idxs.append(idx) return label_idxs def remove_data_parallel(self, old_state_dict): new_state_dict = OrderedDict() for k, v in old_state_dict.items(): name = k[7:] # remove `module.` new_state_dict[name] = v return new_state_dict def save_model(self, index): if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) model_path = os.path.join(self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth') torch.save(self.model.state_dict(), model_path) print('Saved Model Path: ', model_path) print('') def __init__(self, config, data_loader): print('') print('Init BigSmall Multitask Trainer\n\n') self.config = config # save config file # Set up GPU/CPU compute device if torch.cuda.is_available() and config.NUM_OF_GPU_TRAIN > 0: self.device = torch.device(config.DEVICE) # set device to primary GPU self.num_of_gpu = config.NUM_OF_GPU_TRAIN # set number of used GPUs else: self.device = "cpu" # if no GPUs set device is CPU self.num_of_gpu = 0 # no GPUs used # Defining model self.using_TSM = True self.model = self.define_model(config) # define the model if torch.cuda.device_count() > 1 and config.NUM_OF_GPU_TRAIN > 1: # distribute model across GPUs self.model = torch.nn.DataParallel(self.model, device_ids=list(range(config.NUM_OF_GPU_TRAIN))) # data parallel model self.model = self.model.to(self.device) # send model to primary GPU # Training parameters self.batch_size = config.TRAIN.BATCH_SIZE self.max_epoch_num = config.TRAIN.EPOCHS self.LR = config.TRAIN.LR # Set Loss and Optimizer AU_weights = torch.as_tensor([9.64, 11.74, 16.77, 1.05, 0.53, 0.56, 0.75, 0.69, 8.51, 6.94, 5.03, 25.00]).to(self.device) self.criterionAU = torch.nn.BCEWithLogitsLoss(pos_weight=AU_weights).to(self.device) self.criterionBVP = torch.nn.MSELoss().to(self.device) self.criterionRESP = torch.nn.MSELoss().to(self.device) self.optimizer = optim.AdamW(self.model.parameters(), lr=self.LR, weight_decay=0) # self.scaler = torch.cuda.amp.GradScaler() # Loss scalar # Model info (saved more dir, chunk len, best epoch, etc.) self.model_dir = config.MODEL.MODEL_DIR self.model_file_name = config.TRAIN.MODEL_FILE_NAME self.chunk_len = config.TRAIN.DATA.PREPROCESS.CHUNK_LENGTH # Epoch To Use For Test self.used_epoch = 0 # Indicies corresponding to used labels label_list = ['bp_wave', 'HR_bpm', 'systolic_bp', 'diastolic_bp', 'mean_bp', 'resp_wave', 'resp_bpm', 'eda', 'AU01', 'AU02', 'AU04', 'AU05', 'AU06', 'AU06int', 'AU07', 'AU09', 'AU10', 'AU10int', 'AU11', 'AU12', 'AU12int', 'AU13', 'AU14', 'AU14int', 'AU15', 'AU16', 'AU17', 'AU17int', 'AU18', 'AU19', 'AU20', 'AU22', 'AU23', 'AU24', 'AU27', 'AU28', 'AU29', 'AU30', 'AU31', 'AU32', 'AU33', 'AU34', 'AU35', 'AU36', 'AU37', 'AU38', 'AU39', 'pos_bvp','pos_env_norm_bvp'] used_labels = ['bp_wave', 'AU01', 'AU02', 'AU04', 'AU06', 'AU07', 'AU10', 'AU12', 'AU14', 'AU15', 'AU17', 'AU23', 'AU24', 'pos_env_norm_bvp', 'resp_wave'] # Get indicies for labels from npy array au_label_list = [label for label in used_labels if 'AU' in label] bvp_label_list_train = [label for label in used_labels if 'bvp' in label] bvp_label_list_test = [label for label in used_labels if 'bp_wave' in label] resp_label_list = [label for label in used_labels if 'resp' in label] self.label_idx_train_au = self.get_label_idxs(label_list, au_label_list) self.label_idx_valid_au = self.get_label_idxs(label_list, au_label_list) self.label_idx_test_au = self.get_label_idxs(label_list, au_label_list) self.label_idx_train_bvp = self.get_label_idxs(label_list, bvp_label_list_train) self.label_idx_valid_bvp = self.get_label_idxs(label_list, bvp_label_list_train) self.label_idx_test_bvp = self.get_label_idxs(label_list, bvp_label_list_test) self.label_idx_train_resp = self.get_label_idxs(label_list, resp_label_list) self.label_idx_valid_resp = self.get_label_idxs(label_list, resp_label_list) self.label_idx_test_resp = self.get_label_idxs(label_list, resp_label_list) def train(self, data_loader): """Model Training""" if data_loader["train"] is None: raise ValueError("No data for train") print('Starting Training Routine') print('') # Init min validation loss as infinity min_valid_loss = np.inf # minimum validation loss # ARRAYS TO SAVE (LOSS ARRAYS) train_loss_dict = dict() train_au_loss_dict = dict() train_bvp_loss_dict = dict() train_resp_loss_dict = dict() val_loss_dict = dict() val_au_loss_dict = dict() val_bvp_loss_dict = dict() val_resp_loss_dict = dict() # TODO: Expand tracking and subsequent plotting of these losses for BigSmall mean_training_losses = [] mean_valid_losses = [] lrs = [] # ITERATE THROUGH EPOCHS for epoch in range(self.max_epoch_num): print(f"====Training Epoch: {epoch}====") # INIT PARAMS FOR TRAINING running_loss = 0.0 # tracks avg loss over mini batches of 100 train_loss = [] train_au_loss = [] train_bvp_loss = [] train_resp_loss = [] self.model.train() # put model in train mode # MODEL TRAINING tbar = tqdm(data_loader["train"], ncols=80) for idx, batch in enumerate(tbar): tbar.set_description("Train epoch %s" % epoch) # GATHER AND FORMAT BATCH DATA data, labels = batch[0], batch[1] data, labels = self.format_data_shape(data, labels) data, labels = self.send_data_to_device(data, labels) # FOWARD AND BACK PROPOGATE THROUGH MODEL self.optimizer.zero_grad() au_out, bvp_out, resp_out = self.model(data) au_loss = self.criterionAU(au_out, labels[:, self.label_idx_train_au, 0]) # au loss bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_train_bvp, 0]) # bvp loss resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_train_resp, 0]) # resp loss loss = au_loss + bvp_loss + resp_loss # sum losses loss.backward() # Append the current learning rate to the list lrs.append(self.scheduler.get_last_lr()) self.optimizer.step() # self.scaler.scale(loss).backward() # Loss scaling # self.scaler.step(self.optimizer) # self.scaler.update() # UPDATE RUNNING LOSS AND PRINTED TERMINAL OUTPUT AND SAVED LOSSES train_loss.append(loss.item()) train_au_loss.append(au_loss.item()) train_bvp_loss.append(bvp_loss.item()) train_resp_loss.append(resp_loss.item()) running_loss += loss.item() if idx % 100 == 99: # print every 100 mini-batches print(f'[{epoch}, {idx + 1:5d}] loss: {running_loss / 100:.3f}') running_loss = 0.0 tbar.set_postfix({"loss:": loss.item(), "lr:": self.optimizer.param_groups[0]["lr"]}) # APPEND EPOCH LOSS LIST TO TRAINING LOSS DICTIONARY train_loss_dict[epoch] = train_loss train_au_loss_dict[epoch] = train_au_loss train_bvp_loss_dict[epoch] = train_bvp_loss train_resp_loss_dict[epoch] = train_resp_loss print('') # Append the mean training loss for the epoch mean_training_losses.append(np.mean(train_loss)) # SAVE MODEL FOR THIS EPOCH self.save_model(epoch) # VALIDATION (IF ENABLED) if not self.config.TEST.USE_LAST_EPOCH: # Get validation losses valid_loss, valid_au_loss, valid_bvp_loss, valid_resp_loss = self.valid(data_loader) mean_valid_losses.append(valid_loss) val_loss_dict[epoch] = valid_loss val_au_loss_dict[epoch] = valid_au_loss val_bvp_loss_dict[epoch] = valid_bvp_loss val_resp_loss_dict[epoch] = valid_resp_loss print('validation loss: ', valid_loss) # Update used model if self.model_to_use == 'best_epoch' and (valid_loss < min_valid_loss): min_valid_loss = valid_loss self.used_epoch = epoch print("Update best model! Best epoch: {}".format(self.used_epoch)) elif self.model_to_use == 'last_epoch': self.used_epoch = epoch # VALIDATION (NOT ENABLED) else: self.used_epoch = epoch print('') if self.config.TRAIN.PLOT_LOSSES_AND_LR: self.plot_losses_and_lrs(mean_training_losses, mean_valid_losses, lrs, self.config) # PRINT MODEL TO BE USED FOR TESTING print("Used model trained epoch:{}, val_loss:{}".format(self.used_epoch, min_valid_loss)) print('') def valid(self, data_loader): """ Model evaluation on the validation dataset.""" if data_loader["valid"] is None: raise ValueError("No data for valid") print("===Validating===") # INIT PARAMS FOR VALIDATION valid_loss = [] valid_au_loss = [] valid_bvp_loss = [] valid_resp_loss = [] self.model.eval() # MODEL VALIDATION with torch.no_grad(): vbar = tqdm(data_loader["valid"], ncols=80) for valid_idx, valid_batch in enumerate(vbar): vbar.set_description("Validation") # GATHER AND FORMAT BATCH DATA data, labels = valid_batch[0], valid_batch[1] data, labels = self.format_data_shape(data, labels) data, labels = self.send_data_to_device(data, labels) au_out, bvp_out, resp_out = self.model(data) au_loss = self.criterionAU(au_out, labels[:, self.label_idx_valid_au, 0]) # au loss bvp_loss = self.criterionBVP(bvp_out, labels[:, self.label_idx_valid_bvp, 0]) # bvp loss resp_loss = self.criterionRESP(resp_out, labels[:, self.label_idx_valid_resp, 0]) # resp loss loss = au_loss + bvp_loss + resp_loss # sum losses # APPEND VAL LOSS valid_loss.append(loss.item()) valid_au_loss.append(au_loss.item()) valid_bvp_loss.append(bvp_loss.item()) valid_resp_loss.append(resp_loss.item()) vbar.set_postfix(loss=loss.item()) valid_loss = np.asarray(valid_loss) valid_au_loss = np.asarray(valid_au_loss) valid_bvp_loss = np.asarray(valid_bvp_loss) valid_resp_loss = np.asarray(valid_resp_loss) return np.mean(valid_loss), np.mean(valid_au_loss), np.mean(valid_bvp_loss), np.mean(valid_resp_loss) def test(self, data_loader): """ Model evaluation on the testing dataset.""" print("===Testing===") print('') # SETUP if data_loader["test"] is None: raise ValueError("No data for test") # Change chunk length to be test chunk length self.chunk_len = self.config.TEST.DATA.PREPROCESS.CHUNK_LENGTH # ARRAYS TO SAVE (PREDICTIONS AND METRICS ARRAYS) preds_dict_au = dict() labels_dict_au = dict() preds_dict_bvp = dict() labels_dict_bvp = dict() preds_dict_resp = dict() labels_dict_resp = dict() # IF ONLY_TEST MODE LOAD PRETRAINED MODEL if self.config.TOOLBOX_MODE == "only_test": model_path = self.config.INFERENCE.MODEL_PATH print("Testing uses pretrained model!") print('Model path:', model_path) if not os.path.exists(model_path): raise ValueError("Inference model path error! Please check INFERENCE.MODEL_PATH in your yaml.") # IF USING MODEL FROM TRAINING else: model_path = os.path.join(self.model_dir, self.model_file_name + '_Epoch' + str(self.used_epoch) + '.pth') print("Testing uses non-pretrained model!") print('Model path:', model_path) if not os.path.exists(model_path): raise ValueError("Something went wrong... cant find trained model...") print('') # LOAD ABOVED SPECIFIED MODEL FOR TESTING self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) self.model = self.model.to(self.device) self.model.eval() # MODEL TESTING print("Running model evaluation on the testing dataset!") with torch.no_grad(): for _, test_batch in enumerate(tqdm(data_loader["test"], ncols=80)): # PROCESSING - ANALYSIS, METRICS, SAVING OUT DATA batch_size = test_batch[1].shape[0] # get batch size # GATHER AND FORMAT BATCH DATA data, labels = test_batch[0], test_batch[1] data, labels = self.format_data_shape(data, labels) data, labels = self.send_data_to_device(data, labels) # Weird dataloader bug is causing the final training batch to be of size 0... if labels.shape[0] == 0: continue # GET MODEL PREDICTIONS au_out, bvp_out, resp_out = self.model(data) au_out = torch.sigmoid(au_out) # GATHER AND SLICE LABELS USED FOR TEST DATASET TEST_AU = False if len(self.label_idx_test_au) > 0: # if test dataset has AU TEST_AU = True labels_au = labels[:, self.label_idx_test_au] else: # if not set whole AU labels array to -1 labels_au = np.ones((batch_size, len(self.label_idx_train_au))) labels_au = -1 * labels_au # labels_au = torch.from_numpy(labels_au) TEST_BVP = False if len(self.label_idx_test_bvp) > 0: # if test dataset has BVP TEST_BVP = True labels_bvp = labels[:, self.label_idx_test_bvp] else: # if not set whole BVP labels array to -1 labels_bvp = np.ones((batch_size, len(self.label_idx_train_bvp))) labels_bvp = -1 * labels_bvp # labels_bvp = torch.from_numpy(labels_bvp) TEST_RESP = False if len(self.label_idx_test_resp) > 0: # if test dataset has BVP TEST_RESP = True labels_resp = labels[:, self.label_idx_test_resp] else: # if not set whole BVP labels array to -1 labels_resp = np.ones((batch_size, len(self.label_idx_train_resp))) labels_resp = -1 * labels_resp # labels_resp = torch.from_numpy(labels_resp) # ITERATE THROUGH BATCH, SORT, AND ADD TO CORRECT DICTIONARY for idx in range(batch_size): # if the labels are cut off due to TSM dataformating if idx * self.chunk_len >= labels.shape[0] and self.using_TSM: continue subj_index = test_batch[2][idx] sort_index = int(test_batch[3][idx]) # add subject to prediction / label arrays if subj_index not in preds_dict_bvp.keys(): preds_dict_au[subj_index] = dict() labels_dict_au[subj_index] = dict() preds_dict_bvp[subj_index] = dict() labels_dict_bvp[subj_index] = dict() preds_dict_resp[subj_index] = dict() labels_dict_resp[subj_index] = dict() # append predictions and labels to subject dict preds_dict_au[subj_index][sort_index] = au_out[idx * self.chunk_len:(idx + 1) * self.chunk_len] labels_dict_au[subj_index][sort_index] = labels_au[idx * self.chunk_len:(idx + 1) * self.chunk_len] preds_dict_bvp[subj_index][sort_index] = bvp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len] labels_dict_bvp[subj_index][sort_index] = labels_bvp[idx * self.chunk_len:(idx + 1) * self.chunk_len] preds_dict_resp[subj_index][sort_index] = resp_out[idx * self.chunk_len:(idx + 1) * self.chunk_len] labels_dict_resp[subj_index][sort_index] = labels_resp[idx * self.chunk_len:(idx + 1) * self.chunk_len] # Calculate Eval Metrics bvp_metric_dict = calculate_bvp_metrics(preds_dict_bvp, labels_dict_bvp, self.config) resp_metric_dict = calculate_resp_metrics(preds_dict_resp, labels_dict_resp, self.config) au_metric_dict = calculate_bp4d_au_metrics(preds_dict_au, labels_dict_au, self.config)