import torch
import time

from .constants import *
from utilities.device import get_device
from .lr_scheduling import get_lr

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import json

from dataset.vevo_dataset import (
    compute_vevo_accuracy,
    compute_vevo_correspondence,
    compute_hits_k,
    compute_hits_k_root_attr,
    compute_vevo_accuracy_root_attr,
    compute_vevo_correspondence_root_attr,
)


def train_epoch(cur_epoch, model, dataloader, train_loss_func, train_loss_emotion_func,
                opt, lr_scheduler=None, print_modulus=1, isVideo=True):
    """Run one training epoch; prints loss statistics every `print_modulus` batches."""
    loss_chord = -1
    loss_emotion = -1

    model.train()
    for batch_num, batch in enumerate(dataloader):
        time_before = time.time()
        opt.zero_grad()

        x = batch["x"].to(get_device())
        tgt = batch["tgt"].to(get_device())

        x_root = batch["x_root"].to(get_device())
        tgt_root = batch["tgt_root"].to(get_device())
        x_attr = batch["x_attr"].to(get_device())
        tgt_attr = batch["tgt_attr"].to(get_device())

        tgt_emotion = batch["tgt_emotion"].to(get_device())
        tgt_emotion_prob = batch["tgt_emotion_prob"].to(get_device())

        feature_semantic_list = []
        for feature_semantic in batch["semanticList"]:
            feature_semantic_list.append(feature_semantic.to(get_device()))

        feature_key = batch["key"].to(get_device())
        feature_scene_offset = batch["scene_offset"].to(get_device())
        feature_motion = batch["motion"].to(get_device())
        feature_emotion = batch["emotion"].to(get_device())

        if isVideo:
            # VideoMusicTransformer
            if IS_SEPERATED:
                # Separated chord heads: one for the root, one for the attribute (quality).
                y_root, y_attr = model(x, x_root, x_attr, feature_semantic_list, feature_key,
                                       feature_scene_offset, feature_motion, feature_emotion)

                y_root = y_root.reshape(y_root.shape[0] * y_root.shape[1], -1)
                y_attr = y_attr.reshape(y_attr.shape[0] * y_attr.shape[1], -1)

                tgt_root = tgt_root.flatten()
                tgt_attr = tgt_attr.flatten()
                tgt_emotion = tgt_emotion.squeeze()

                loss_chord_root = train_loss_func.forward(y_root, tgt_root)
                loss_chord_attr = train_loss_func.forward(y_attr, tgt_attr)
                loss_chord = loss_chord_root + loss_chord_attr

                # Emotion target for the attribute head: first 14 and last 2 columns of tgt_emotion.
                first_14 = tgt_emotion[:, :14]
                last_2 = tgt_emotion[:, -2:]
                tgt_emotion_attr = torch.cat((first_14, last_2), dim=1)
                loss_emotion = train_loss_emotion_func.forward(y_attr, tgt_emotion_attr)

                total_loss = LOSS_LAMBDA * loss_chord + (1 - LOSS_LAMBDA) * loss_emotion
                total_loss.backward()
                opt.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()
            else:
                # VideoMusicTransformer, single (non-separated) chord head
                y = model(x, x_root, x_attr, feature_semantic_list, feature_key,
                          feature_scene_offset, feature_motion, feature_emotion)

                y = y.reshape(y.shape[0] * y.shape[1], -1)
                tgt = tgt.flatten()
                tgt_emotion = tgt_emotion.squeeze()

                loss_chord = train_loss_func.forward(y, tgt)
                loss_emotion = train_loss_emotion_func.forward(y, tgt_emotion)

                total_loss = LOSS_LAMBDA * loss_chord + (1 - LOSS_LAMBDA) * loss_emotion
                total_loss.backward()
                opt.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()
        else:
            # MusicTransformer (no video features)
            if IS_SEPERATED:
                y_root, y_attr = model(x, x_root, x_attr, feature_key)

                y_root = y_root.reshape(y_root.shape[0] * y_root.shape[1], -1)
                y_attr = y_attr.reshape(y_attr.shape[0] * y_attr.shape[1], -1)

                tgt_root = tgt_root.flatten()
                tgt_attr = tgt_attr.flatten()
                tgt_emotion = tgt_emotion.squeeze()

                loss_chord_root = train_loss_func.forward(y_root, tgt_root)
                loss_chord_attr = train_loss_func.forward(y_attr, tgt_attr)
                loss_chord = loss_chord_root + loss_chord_attr
                loss_emotion = -1

                total_loss = loss_chord
                total_loss.backward()
                opt.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()
            else:
                # MusicTransformer, single (non-separated) chord head
                y = model(x, x_root, x_attr, feature_key)

                y = y.reshape(y.shape[0] * y.shape[1], -1)
                tgt = tgt.flatten()

                loss_chord = train_loss_func.forward(y, tgt)
                loss_emotion = -1

                total_loss = loss_chord
                total_loss.backward()
                opt.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()

        time_after = time.time()
        time_took = time_after - time_before

        if (batch_num + 1) % print_modulus == 0:
            print(SEPERATOR)
            print("Epoch", cur_epoch, " Batch", batch_num + 1, "/", len(dataloader))
            print("LR:", get_lr(opt))
            print("Train loss (total):", float(total_loss))
            print("Train loss (chord):", float(loss_chord))
            print("Train loss (emotion):", float(loss_emotion))
            print("")
            print("Time (s):", time_took)
            print(SEPERATOR)
            print("")
    return
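

# Illustrative sketch (not called by the functions in this file): the non-separated
# chord head predicts over a flat chord vocabulary, which eval_model decodes inline
# into (root, attribute) indices when isGenConfusionMatrix is set. The helper below
# restates that inline arithmetic so the mapping is easier to follow; the special
# indices 157/158 and the CHORD_*_END / CHORD_*_PAD values mirror the literals used
# there and come from .constants. Treat this as documentation, not as training API.
def _split_chord_index_example(i):
    """Map a flat chord token index to a (root, attr) pair, as done in eval_model."""
    if i == 0:
        # Index 0 is handled separately in eval_model and maps to (0, 0).
        return 0, 0
    if i == 157:
        return CHORD_ROOT_END, CHORD_ATTR_END
    if i == 158:
        return CHORD_ROOT_PAD, CHORD_ATTR_PAD
    # 13 chord attributes (qualities) per root; both outputs are 1-based.
    root_index = (i - 1) // 13 + 1
    attr_index = (i - 1) % 13 + 1
    return root_index, attr_index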


def eval_model(model, dataloader, eval_loss_func, eval_loss_emotion_func,
               isVideo=True, isGenConfusionMatrix=False):
    """Evaluate the model; returns averaged losses and metrics, optionally plotting confusion matrices."""
    model.eval()

    avg_acc = -1
    avg_cor = -1
    avg_acc_cor = -1

    avg_h1 = -1
    avg_h3 = -1
    avg_h5 = -1

    avg_loss_chord = -1
    avg_loss_emotion = -1
    avg_total_loss = -1

    true_labels = []
    true_root_labels = []
    true_attr_labels = []

    pred_labels = []
    pred_root_labels = []
    pred_attr_labels = []

    with torch.set_grad_enabled(False):
        n_test = len(dataloader)
        n_test_cor = 0

        sum_loss_chord = 0.0
        sum_loss_emotion = 0.0
        sum_total_loss = 0.0

        sum_acc = 0.0
        sum_cor = 0.0

        sum_h1 = 0.0
        sum_h3 = 0.0
        sum_h5 = 0.0

        for batch in dataloader:
            x = batch["x"].to(get_device())
            tgt = batch["tgt"].to(get_device())

            x_root = batch["x_root"].to(get_device())
            tgt_root = batch["tgt_root"].to(get_device())
            x_attr = batch["x_attr"].to(get_device())
            tgt_attr = batch["tgt_attr"].to(get_device())

            tgt_emotion = batch["tgt_emotion"].to(get_device())
            tgt_emotion_prob = batch["tgt_emotion_prob"].to(get_device())

            feature_semantic_list = []
            for feature_semantic in batch["semanticList"]:
                feature_semantic_list.append(feature_semantic.to(get_device()))

            feature_key = batch["key"].to(get_device())
            feature_scene_offset = batch["scene_offset"].to(get_device())
            feature_motion = batch["motion"].to(get_device())
            feature_emotion = batch["emotion"].to(get_device())
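
            # Branching mirrors train_epoch: VideoMusicTransformer (isVideo) vs.
            # MusicTransformer, each with either separated root/attr chord heads
            # (IS_SEPERATED) or a single flat chord head. For each batch we accumulate
            # chord accuracy, Hits@1/3/5, the chord and emotion losses, and the
            # emotion correspondence, which is counted only when it is defined
            # (non-negative) for the batch.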
            if isVideo:
                if IS_SEPERATED:
                    y_root, y_attr = model(x, x_root, x_attr, feature_semantic_list, feature_key,
                                           feature_scene_offset, feature_motion, feature_emotion)

                    sum_acc += float(compute_vevo_accuracy_root_attr(y_root, y_attr, tgt))
                    cor = float(compute_vevo_correspondence_root_attr(y_root, y_attr, tgt, tgt_emotion, tgt_emotion_prob, EMOTION_THRESHOLD))
                    if cor >= 0:
                        n_test_cor += 1
                        sum_cor += cor

                    sum_h1 += float(compute_hits_k_root_attr(y_root, y_attr, tgt, 1))
                    sum_h3 += float(compute_hits_k_root_attr(y_root, y_attr, tgt, 3))
                    sum_h5 += float(compute_hits_k_root_attr(y_root, y_attr, tgt, 5))

                    y_root = y_root.reshape(y_root.shape[0] * y_root.shape[1], -1)
                    y_attr = y_attr.reshape(y_attr.shape[0] * y_attr.shape[1], -1)

                    tgt_root = tgt_root.flatten()
                    tgt_attr = tgt_attr.flatten()
                    tgt_emotion = tgt_emotion.squeeze()

                    loss_chord_root = eval_loss_func.forward(y_root, tgt_root)
                    loss_chord_attr = eval_loss_func.forward(y_attr, tgt_attr)
                    loss_chord = loss_chord_root + loss_chord_attr

                    # Emotion target for the attribute head: first 14 and last 2 columns of tgt_emotion.
                    first_14 = tgt_emotion[:, :14]
                    last_2 = tgt_emotion[:, -2:]
                    tgt_emotion_attr = torch.cat((first_14, last_2), dim=1)
                    loss_emotion = eval_loss_emotion_func.forward(y_attr, tgt_emotion_attr)

                    total_loss = LOSS_LAMBDA * loss_chord + (1 - LOSS_LAMBDA) * loss_emotion

                    sum_loss_chord += float(loss_chord)
                    sum_loss_emotion += float(loss_emotion)
                    sum_total_loss += float(total_loss)
                else:
                    y = model(x, x_root, x_attr, feature_semantic_list, feature_key,
                              feature_scene_offset, feature_motion, feature_emotion)

                    sum_acc += float(compute_vevo_accuracy(y, tgt))
                    cor = float(compute_vevo_correspondence(y, tgt, tgt_emotion, tgt_emotion_prob, EMOTION_THRESHOLD))
                    if cor >= 0:
                        n_test_cor += 1
                        sum_cor += cor

                    sum_h1 += float(compute_hits_k(y, tgt, 1))
                    sum_h3 += float(compute_hits_k(y, tgt, 3))
                    sum_h5 += float(compute_hits_k(y, tgt, 5))

                    y = y.reshape(y.shape[0] * y.shape[1], -1)
                    tgt = tgt.flatten()
                    tgt_root = tgt_root.flatten()
                    tgt_attr = tgt_attr.flatten()
                    tgt_emotion = tgt_emotion.squeeze()

                    loss_chord = eval_loss_func.forward(y, tgt)
                    loss_emotion = eval_loss_emotion_func.forward(y, tgt_emotion)

                    total_loss = LOSS_LAMBDA * loss_chord + (1 - LOSS_LAMBDA) * loss_emotion

                    sum_loss_chord += float(loss_chord)
                    sum_loss_emotion += float(loss_emotion)
                    sum_total_loss += float(total_loss)

                    if isGenConfusionMatrix:
                        # Decode flat chord predictions into root / attribute indices
                        # for the confusion matrices plotted after the loop.
                        pred = y.argmax(dim=1).detach().cpu().numpy()
                        pred_root = []
                        pred_attr = []
                        for i in pred:
                            if i == 0:
                                pred_root.append(0)
                                pred_attr.append(0)
                            elif i == 157:
                                pred_root.append(CHORD_ROOT_END)
                                pred_attr.append(CHORD_ATTR_END)
                            elif i == 158:
                                pred_root.append(CHORD_ROOT_PAD)
                                pred_attr.append(CHORD_ATTR_PAD)
                            else:
                                rootindex = int((i - 1) / 13) + 1
                                attrindex = (i - 1) % 13 + 1
                                pred_root.append(rootindex)
                                pred_attr.append(attrindex)

                        pred_root = np.array(pred_root)
                        pred_attr = np.array(pred_attr)

                        true = tgt.detach().cpu().numpy()
                        true_root = tgt_root.detach().cpu().numpy()
                        true_attr = tgt_attr.detach().cpu().numpy()

                        pred_labels.extend(pred)
                        pred_root_labels.extend(pred_root)
                        pred_attr_labels.extend(pred_attr)

                        true_labels.extend(true)
                        true_root_labels.extend(true_root)
                        true_attr_labels.extend(true_attr)
            else:
                if IS_SEPERATED:
                    y_root, y_attr = model(x, x_root, x_attr, feature_key)

                    sum_acc += float(compute_vevo_accuracy_root_attr(y_root, y_attr, tgt))
                    cor = float(compute_vevo_correspondence_root_attr(y_root, y_attr, tgt, tgt_emotion, tgt_emotion_prob, EMOTION_THRESHOLD))
                    if cor >= 0:
                        n_test_cor += 1
                        sum_cor += cor

                    sum_h1 += float(compute_hits_k_root_attr(y_root, y_attr, tgt, 1))
                    sum_h3 += float(compute_hits_k_root_attr(y_root, y_attr, tgt, 3))
                    sum_h5 += float(compute_hits_k_root_attr(y_root, y_attr, tgt, 5))

                    y_root = y_root.reshape(y_root.shape[0] * y_root.shape[1], -1)
                    y_attr = y_attr.reshape(y_attr.shape[0] * y_attr.shape[1], -1)

                    tgt_root = tgt_root.flatten()
                    tgt_attr = tgt_attr.flatten()
                    tgt_emotion = tgt_emotion.squeeze()

                    loss_chord_root = eval_loss_func.forward(y_root, tgt_root)
                    loss_chord_attr = eval_loss_func.forward(y_attr, tgt_attr)
                    loss_chord = loss_chord_root + loss_chord_attr

                    # Emotion target for the attribute head: first 14 and last 2 columns of tgt_emotion.
                    first_14 = tgt_emotion[:, :14]
                    last_2 = tgt_emotion[:, -2:]
                    tgt_emotion_attr = torch.cat((first_14, last_2), dim=1)
                    loss_emotion = eval_loss_emotion_func.forward(y_attr, tgt_emotion_attr)

                    total_loss = LOSS_LAMBDA * loss_chord + (1 - LOSS_LAMBDA) * loss_emotion

                    sum_loss_chord += float(loss_chord)
                    sum_loss_emotion += float(loss_emotion)
                    sum_total_loss += float(total_loss)
                else:
                    # MusicTransformer, single (non-separated) chord head
                    y = model(x, x_root, x_attr, feature_key)

                    sum_acc += float(compute_vevo_accuracy(y, tgt))
                    cor = float(compute_vevo_correspondence(y, tgt, tgt_emotion, tgt_emotion_prob, EMOTION_THRESHOLD))
                    if cor >= 0:
                        n_test_cor += 1
                        sum_cor += cor

                    sum_h1 += float(compute_hits_k(y, tgt, 1))
                    sum_h3 += float(compute_hits_k(y, tgt, 3))
                    sum_h5 += float(compute_hits_k(y, tgt, 5))

                    tgt_emotion = tgt_emotion.squeeze()

                    y = y.reshape(y.shape[0] * y.shape[1], -1)
                    tgt = tgt.flatten()

                    loss_chord = eval_loss_func.forward(y, tgt)
                    loss_emotion = eval_loss_emotion_func.forward(y, tgt_emotion)
                    total_loss = loss_chord

                    sum_loss_chord += float(loss_chord)
                    sum_loss_emotion += float(loss_emotion)
                    sum_total_loss += float(total_loss)
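
        # Per-dataset averages follow; avg_cor is averaged over n_test_cor (the number
        # of batches where correspondence was defined) rather than over n_test.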
        avg_loss_chord = sum_loss_chord / n_test
        avg_loss_emotion = sum_loss_emotion / n_test
        avg_total_loss = sum_total_loss / n_test

        avg_acc = sum_acc / n_test
        # Guard against division by zero when no batch had a defined correspondence;
        # fall back to the -1 sentinel used for the other "not computed" values.
        avg_cor = sum_cor / n_test_cor if n_test_cor > 0 else -1

        avg_h1 = sum_h1 / n_test
        avg_h3 = sum_h3 / n_test
        avg_h5 = sum_h5 / n_test

        avg_acc_cor = (avg_acc + avg_cor) / 2.0

    if isGenConfusionMatrix:
        chordInvDicPath = "./dataset/vevo_meta/chord_inv.json"
        chordRootInvDicPath = "./dataset/vevo_meta/chord_root_inv.json"
        chordAttrInvDicPath = "./dataset/vevo_meta/chord_attr_inv.json"

        with open(chordInvDicPath) as json_file:
            chordInvDic = json.load(json_file)
        with open(chordRootInvDicPath) as json_file:
            chordRootInvDic = json.load(json_file)
        with open(chordAttrInvDicPath) as json_file:
            chordAttrInvDic = json.load(json_file)

        # Confusion matrix (CHORD)
        topChordList = []
        with open("./dataset/vevo_meta/top_chord.txt", encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                line_arr = line.split(" ")
                if len(line_arr) == 3:
                    chordID = line_arr[1]
                    topChordList.append(int(chordID))

        topChordList = np.array(topChordList)
        topChordList = topChordList[:10]

        mask = np.isin(true_labels, topChordList)
        true_labels = np.array(true_labels)[mask]
        pred_labels = np.array(pred_labels)[mask]

        conf_matrix = confusion_matrix(true_labels, pred_labels, labels=topChordList)
        label_names = [chordInvDic[str(label_id)] for label_id in topChordList]

        plt.figure(figsize=(8, 6))
        plt.imshow(conf_matrix, cmap=plt.cm.Blues)
        plt.title("Confusion Matrix")
        plt.colorbar()
        tick_marks = np.arange(len(topChordList))
        plt.xticks(tick_marks, label_names, rotation=45)
        plt.yticks(tick_marks, label_names)
        thresh = conf_matrix.max() / 2.0
        for i in range(conf_matrix.shape[0]):
            for j in range(conf_matrix.shape[1]):
                plt.text(j, i, format(conf_matrix[i, j], 'd'),
                         ha="center", va="center",
                         color="white" if conf_matrix[i, j] > thresh else "black")
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()
        plt.savefig("confusion_matrix.png")
        plt.show()
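
        # The root and attribute (quality) matrices below follow the same layout as
        # the chord matrix above; labels 1..12 are the chord roots and 1..13 the chord
        # qualities, with names taken from chord_root_inv.json / chord_attr_inv.json.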
        # Confusion matrix (CHORD ROOT)
        chordRootList = np.arange(1, 13)
        conf_matrix = confusion_matrix(true_root_labels, pred_root_labels, labels=chordRootList)
        label_names = [chordRootInvDic[str(label_id)] for label_id in chordRootList]

        plt.figure(figsize=(8, 6))
        plt.imshow(conf_matrix, cmap=plt.cm.Blues)
        plt.title("Confusion Matrix (Chord root)")
        plt.colorbar()
        tick_marks = np.arange(len(chordRootList))
        plt.xticks(tick_marks, label_names, rotation=45)
        plt.yticks(tick_marks, label_names)
        thresh = conf_matrix.max() / 2.0
        for i in range(conf_matrix.shape[0]):
            for j in range(conf_matrix.shape[1]):
                plt.text(j, i, format(conf_matrix[i, j], 'd'),
                         ha="center", va="center",
                         color="white" if conf_matrix[i, j] > thresh else "black")
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()
        plt.savefig("confusion_matrix_root.png")
        plt.show()

        # Confusion matrix (CHORD ATTR)
        chordAttrList = np.arange(1, 14)
        conf_matrix = confusion_matrix(true_attr_labels, pred_attr_labels, labels=chordAttrList)
        label_names = [chordAttrInvDic[str(label_id)] for label_id in chordAttrList]

        plt.figure(figsize=(8, 6))
        plt.imshow(conf_matrix, cmap=plt.cm.Blues)
        plt.title("Confusion Matrix (Chord quality)")
        plt.colorbar()
        tick_marks = np.arange(len(chordAttrList))
        plt.xticks(tick_marks, label_names, rotation=45)
        plt.yticks(tick_marks, label_names)
        thresh = conf_matrix.max() / 2.0
        for i in range(conf_matrix.shape[0]):
            for j in range(conf_matrix.shape[1]):
                plt.text(j, i, format(conf_matrix[i, j], 'd'),
                         ha="center", va="center",
                         color="white" if conf_matrix[i, j] > thresh else "black")
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()
        plt.savefig("confusion_matrix_quality.png")
        plt.show()

    return {
        "avg_total_loss": avg_total_loss,
        "avg_loss_chord": avg_loss_chord,
        "avg_loss_emotion": avg_loss_emotion,
        "avg_acc": avg_acc,
        "avg_cor": avg_cor,
        "avg_acc_cor": avg_acc_cor,
        "avg_h1": avg_h1,
        "avg_h3": avg_h3,
        "avg_h5": avg_h5,
    }
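

# Minimal usage sketch (hypothetical, kept as comments so importing this module has no
# side effects). Names such as `my_model`, `train_loader`, `val_loader` and `num_epochs`
# are placeholders: the actual model, dataset and loss classes live elsewhere in the
# repository, and the torch.nn losses below are only assumptions about what
# train_epoch / eval_model expect for the chord and emotion objectives.
#
#   import torch.nn as nn
#   import torch.optim as optim
#
#   chord_loss = nn.CrossEntropyLoss()        # assumed chord loss
#   emotion_loss = nn.BCEWithLogitsLoss()     # assumed emotion loss
#   opt = optim.Adam(my_model.parameters(), lr=1e-4)
#
#   for epoch in range(1, num_epochs + 1):
#       train_epoch(epoch, my_model, train_loader, chord_loss, emotion_loss,
#                   opt, lr_scheduler=None, print_modulus=10, isVideo=True)
#       metrics = eval_model(my_model, val_loader, chord_loss, emotion_loss,
#                            isVideo=True, isGenConfusionMatrix=False)
#       print("Epoch", epoch, "eval loss:", metrics["avg_total_loss"],
#             "acc:", metrics["avg_acc"])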