import os
import pickle
import random
import json

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset

from utilities.constants import *
from utilities.device import cpu_device, get_device

SEQUENCE_START = 0


class VevoDataset(Dataset):
    def __init__(self, dataset_root="./dataset/", split="train", split_ver="v1",
                 vis_models="2d/clip_l14p", emo_model="6c_l14p",
                 max_seq_chord=300, max_seq_video=300, random_seq=True, is_video=True):

        self.dataset_root = dataset_root
        self.vevo_chord_root = os.path.join(dataset_root, "vevo_chord", "lab_v2_norm", "all")
        self.vevo_emotion_root = os.path.join(dataset_root, "vevo_emotion", emo_model, "all")
        self.vevo_motion_root = os.path.join(dataset_root, "vevo_motion", "all")
        self.vevo_scene_offset_root = os.path.join(dataset_root, "vevo_scene_offset", "all")
        self.vevo_meta_split_path = os.path.join(dataset_root, "vevo_meta", "split", split_ver, split + ".txt")
        self.vevo_loudness_root = os.path.join(dataset_root, "vevo_loudness", "all")
        self.vevo_note_density_root = os.path.join(dataset_root, "vevo_note_density", "all")

        self.max_seq_video = max_seq_video
        self.max_seq_chord = max_seq_chord
        self.random_seq = random_seq
        self.is_video = is_video

        self.vis_models_arr = vis_models.split(" ")
        self.vevo_semantic_root_list = []
        self.id_list = []
        self.emo_model = emo_model

        if IS_VIDEO:
            for i in range(len(self.vis_models_arr)):
                p1 = self.vis_models_arr[i].split("/")[0]
                p2 = self.vis_models_arr[i].split("/")[1]
                vevo_semantic_root = os.path.join(dataset_root, "vevo_semantic", "all", p1, p2)
                self.vevo_semantic_root_list.append(vevo_semantic_root)

        with open(self.vevo_meta_split_path) as f:
            for line in f:
                self.id_list.append(line.strip())

        self.data_files_chord = []
        self.data_files_emotion = []
        self.data_files_motion = []
        self.data_files_scene_offset = []
        self.data_files_semantic_list = []
        self.data_files_loudness = []
        self.data_files_note_density = []

        for i in range(len(self.vis_models_arr)):
            self.data_files_semantic_list.append([])

        for fid in self.id_list:
            fpath_chord = os.path.join(self.vevo_chord_root, fid + ".lab")
            fpath_emotion = os.path.join(self.vevo_emotion_root, fid + ".lab")
            fpath_motion = os.path.join(self.vevo_motion_root, fid + ".lab")
            fpath_scene_offset = os.path.join(self.vevo_scene_offset_root, fid + ".lab")
            fpath_loudness = os.path.join(self.vevo_loudness_root, fid + ".lab")
            fpath_note_density = os.path.join(self.vevo_note_density_root, fid + ".lab")

            fpath_semantic_list = []
            for vevo_semantic_root in self.vevo_semantic_root_list:
                fpath_semantic = os.path.join(vevo_semantic_root, fid + ".npy")
                fpath_semantic_list.append(fpath_semantic)

            checkFile_semantic = True
            for fpath_semantic in fpath_semantic_list:
                if not os.path.exists(fpath_semantic):
                    checkFile_semantic = False

            checkFile_chord = os.path.exists(fpath_chord)
            checkFile_emotion = os.path.exists(fpath_emotion)
            checkFile_motion = os.path.exists(fpath_motion)
            checkFile_scene_offset = os.path.exists(fpath_scene_offset)
            checkFile_loudness = os.path.exists(fpath_loudness)
            checkFile_note_density = os.path.exists(fpath_note_density)

            # Keep a clip only if every modality file exists for it.
            if checkFile_chord and checkFile_emotion and checkFile_motion \
                    and checkFile_scene_offset and checkFile_semantic \
                    and checkFile_loudness and checkFile_note_density:
                self.data_files_chord.append(fpath_chord)
                self.data_files_emotion.append(fpath_emotion)
                self.data_files_motion.append(fpath_motion)
                self.data_files_scene_offset.append(fpath_scene_offset)
                self.data_files_loudness.append(fpath_loudness)
                self.data_files_note_density.append(fpath_note_density)
                if IS_VIDEO:
                    for i in range(len(self.vis_models_arr)):
                        self.data_files_semantic_list[i].append(fpath_semantic_list[i])

        chordDicPath = os.path.join(dataset_root, "vevo_meta/chord.json")
        chordRootDicPath = os.path.join(dataset_root, "vevo_meta/chord_root.json")
        chordAttrDicPath = os.path.join(dataset_root, "vevo_meta/chord_attr.json")
        with open(chordDicPath) as json_file:
            self.chordDic = json.load(json_file)
        with open(chordRootDicPath) as json_file:
            self.chordRootDic = json.load(json_file)
        with open(chordAttrDicPath) as json_file:
            self.chordAttrDic = json.load(json_file)

    def __len__(self):
        return len(self.data_files_chord)

    def __getitem__(self, idx):
        #### ---- CHORD ----- ####
        feature_chord = np.empty(self.max_seq_chord)
        feature_chord.fill(CHORD_PAD)
        feature_chordRoot = np.empty(self.max_seq_chord)
        feature_chordRoot.fill(CHORD_ROOT_PAD)
        feature_chordAttr = np.empty(self.max_seq_chord)
        feature_chordAttr.fill(CHORD_ATTR_PAD)

        key = ""
        # Each .lab line is "<index> <value>"; the chord file may also carry a
        # leading "key <tonic> <mode>" line.
        with open(self.data_files_chord[idx], encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                line_arr = line.split(" ")
                if line_arr[0] == "key":
                    key = line_arr[1] + " " + line_arr[2]
                    continue
                time = int(line_arr[0])
                if time >= self.max_seq_chord:
                    break
                chord = line_arr[1]
                chordID = self.chordDic[chord]
                feature_chord[time] = chordID

                chord_arr = chord.split(":")
                if len(chord_arr) == 1:
                    if chord_arr[0] == "N":
                        chordRootID = self.chordRootDic["N"]
                        chordAttrID = self.chordAttrDic["N"]
                        feature_chordRoot[time] = chordRootID
                        feature_chordAttr[time] = chordAttrID
                    else:
                        chordRootID = self.chordRootDic[chord_arr[0]]
                        feature_chordRoot[time] = chordRootID
                        feature_chordAttr[time] = 1
                elif len(chord_arr) == 2:
                    chordRootID = self.chordRootDic[chord_arr[0]]
                    chordAttrID = self.chordAttrDic[chord_arr[1]]
                    feature_chordRoot[time] = chordRootID
                    feature_chordAttr[time] = chordAttrID

        if "major" in key:
            feature_key = torch.tensor([0])
        else:
            feature_key = torch.tensor([1])

        feature_chord = torch.from_numpy(feature_chord).to(torch.long)
        feature_chordRoot = torch.from_numpy(feature_chordRoot).to(torch.long)
        feature_chordAttr = torch.from_numpy(feature_chordAttr).to(torch.long)
        feature_key = feature_key.float()

        # Teacher-forcing shift: x drops the last step, tgt drops the first.
        x = feature_chord[:self.max_seq_chord - 1]
        tgt = feature_chord[1:self.max_seq_chord]
        x_root = feature_chordRoot[:self.max_seq_chord - 1]
        tgt_root = feature_chordRoot[1:self.max_seq_chord]
        x_attr = feature_chordAttr[:self.max_seq_chord - 1]
        tgt_attr = feature_chordAttr[1:self.max_seq_chord]

        if time < self.max_seq_chord:
            tgt[time] = CHORD_END
            tgt_root[time] = CHORD_ROOT_END
            tgt_attr[time] = CHORD_ATTR_END

        #### ---- SCENE OFFSET ----- ####
        feature_scene_offset = np.empty(self.max_seq_video)
        feature_scene_offset.fill(SCENE_OFFSET_PAD)
        with open(self.data_files_scene_offset[idx], encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                line_arr = line.split(" ")
                time = int(line_arr[0])
                if time >= self.max_seq_chord:
                    break
                sceneID = line_arr[1]
                feature_scene_offset[time] = int(sceneID) + 1
        feature_scene_offset = torch.from_numpy(feature_scene_offset).to(torch.float32)

        #### ---- MOTION ----- ####
        feature_motion = np.empty(self.max_seq_video)
        feature_motion.fill(MOTION_PAD)
        with open(self.data_files_motion[idx], encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                line_arr = line.split(" ")
                time = int(line_arr[0])
                if time >= self.max_seq_chord:
                    break
                motion = line_arr[1]
                feature_motion[time] = float(motion)
        feature_motion = torch.from_numpy(feature_motion).to(torch.float32)

        #### ---- NOTE_DENSITY ----- ####
        feature_note_density = np.empty(self.max_seq_video)
        feature_note_density.fill(NOTE_DENSITY_PAD)
        with open(self.data_files_note_density[idx], encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                line_arr = line.split(" ")
                time = int(line_arr[0])
                if time >= self.max_seq_chord:
                    break
                note_density = line_arr[1]
                feature_note_density[time] = float(note_density)
        feature_note_density = torch.from_numpy(feature_note_density).to(torch.float32)

        #### ---- LOUDNESS ----- ####
        feature_loudness = np.empty(self.max_seq_video)
        feature_loudness.fill(LOUDNESS_PAD)
        with open(self.data_files_loudness[idx], encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                line_arr = line.split(" ")
                time = int(line_arr[0])
                if time >= self.max_seq_chord:
                    break
                loudness = line_arr[1]
                feature_loudness[time] = float(loudness)
        feature_loudness = torch.from_numpy(feature_loudness).to(torch.float32)

        #### ---- EMOTION ----- ####
        if self.emo_model.startswith("6c"):
            feature_emotion = np.empty((self.max_seq_video, 6))
        else:
            feature_emotion = np.empty((self.max_seq_video, 5))
        feature_emotion.fill(EMOTION_PAD)
        with open(self.data_files_emotion[idx], encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                line_arr = line.split(" ")
                if line_arr[0] == "time":
                    continue
                time = int(line_arr[0])
                if time >= self.max_seq_chord:
                    break
                if len(line_arr) == 7:
                    emoList = [float(line_arr[1]), float(line_arr[2]), float(line_arr[3]),
                               float(line_arr[4]), float(line_arr[5]), float(line_arr[6])]
                elif len(line_arr) == 6:
                    emoList = [float(line_arr[1]), float(line_arr[2]), float(line_arr[3]),
                               float(line_arr[4]), float(line_arr[5])]
                emoList = np.array(emoList)
                feature_emotion[time] = emoList
        feature_emotion = torch.from_numpy(feature_emotion).to(torch.float32)

        feature_emotion_argmax = torch.argmax(feature_emotion, dim=1)
        _, max_prob_indices = torch.max(feature_emotion, dim=1)
        max_prob_values = torch.gather(feature_emotion, dim=1, index=max_prob_indices.unsqueeze(1))
        max_prob_values = max_prob_values.squeeze()
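
        # NOTE (added comment; layout inferred from the mapping vectors below and
        # from compute_hits_k_root_attr): the 159-way chord vocabulary appears to be
        #   0         -> "N" (no chord)
        #   1..156    -> 12 roots x 13 qualities (maj, dim, sus4, min7, min, sus2,
        #                aug, dim7, maj6, hdim7, 7, min6, maj7)
        #   157 / 158 -> CHORD_END / CHORD_PAD
        # Each a* vector below is a binary mask over that vocabulary marking which
        # chord qualities are treated as compatible with one emotion class.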

        # -- emotion to chord
        #               maj dim sus4 min7 min sus2 aug dim7 maj6 hdim7 7 min6 maj7
        # 0. exciting : [1,0,1,0,0,0,0,0,0,0,1,0,0]
        # 1. fearful  : [0,1,0,1,0,0,0,1,0,1,0,0,0]
        # 2. tense    : [0,1,1,1,0,0,0,0,0,0,1,0,0]
        # 3. sad      : [0,0,0,1,1,1,0,0,0,0,0,0,0]
        # 4. relaxing : [1,0,0,0,0,0,0,0,1,0,0,0,1]
        # 5. neutral  : [0,0,0,0,0,0,0,0,0,0,0,0,0]
        a0 = [0] + [1,0,1,0,0,0,0,0,0,0,1,0,0] * 12 + [0,0]
        a1 = [0] + [0,1,0,1,0,0,0,1,0,1,0,0,0] * 12 + [0,0]
        a2 = [0] + [0,1,1,1,0,0,0,0,0,0,1,0,0] * 12 + [0,0]
        a3 = [0] + [0,0,0,1,1,1,0,0,0,0,0,0,0] * 12 + [0,0]
        a4 = [0] + [1,0,0,0,0,0,0,0,1,0,0,0,1] * 12 + [0,0]
        a5 = [0] + [0,0,0,0,0,0,0,0,0,0,0,0,0] * 12 + [0,0]
        aend = [0] + [0,0,0,0,0,0,0,0,0,0,0,0,0] * 12 + [1,0]
        apad = [0] + [0,0,0,0,0,0,0,0,0,0,0,0,0] * 12 + [0,1]

        a0_tensor = torch.tensor(a0)
        a1_tensor = torch.tensor(a1)
        a2_tensor = torch.tensor(a2)
        a3_tensor = torch.tensor(a3)
        a4_tensor = torch.tensor(a4)
        a5_tensor = torch.tensor(a5)
        aend_tensor = torch.tensor(aend)
        apad_tensor = torch.tensor(apad)

        mapped_tensor = torch.zeros((300, 159))
        for i, val in enumerate(feature_emotion_argmax):
            if feature_chord[i] == CHORD_PAD:
                mapped_tensor[i] = apad_tensor
            elif feature_chord[i] == CHORD_END:
                mapped_tensor[i] = aend_tensor
            elif val == 0:
                mapped_tensor[i] = a0_tensor
            elif val == 1:
                mapped_tensor[i] = a1_tensor
            elif val == 2:
                mapped_tensor[i] = a2_tensor
            elif val == 3:
                mapped_tensor[i] = a3_tensor
            elif val == 4:
                mapped_tensor[i] = a4_tensor
            elif val == 5:
                mapped_tensor[i] = a5_tensor

        # feature emotion : [1, 300, 6]
        # y : [299, 159]
        # tgt : [299]
        # tgt_emo : [299, 159]
        # tgt_emo_prob : [299]
        tgt_emotion = mapped_tensor[1:]
        tgt_emotion_prob = max_prob_values[1:]

        feature_semantic_list = []
        if self.is_video:
            for i in range(len(self.vis_models_arr)):
                video_feature = np.load(self.data_files_semantic_list[i][idx])
                dim_vf = video_feature.shape[1]  # e.g. 2048
                video_feature_tensor = torch.from_numpy(video_feature)
                feature_semantic = torch.full((self.max_seq_video, dim_vf),
                                              SEMANTIC_PAD, dtype=torch.float32, device=cpu_device())
                if video_feature_tensor.shape[0] < self.max_seq_video:
                    feature_semantic[:video_feature_tensor.shape[0]] = video_feature_tensor
                else:
                    feature_semantic = video_feature_tensor[:self.max_seq_video]
                feature_semantic_list.append(feature_semantic)

        return {
            "x": x,
            "tgt": tgt,
            "x_root": x_root,
            "tgt_root": tgt_root,
            "x_attr": x_attr,
            "tgt_attr": tgt_attr,
            "semanticList": feature_semantic_list,
            "key": feature_key,
            "scene_offset": feature_scene_offset,
            "motion": feature_motion,
            "emotion": feature_emotion,
            "tgt_emotion": tgt_emotion,
            "tgt_emotion_prob": tgt_emotion_prob,
            "note_density": feature_note_density,
            "loudness": feature_loudness,
        }


def create_vevo_datasets(dataset_root="./dataset", max_seq_chord=300, max_seq_video=300,
                         vis_models="2d/clip_l14p", emo_model="6c_l14p", split_ver="v1",
                         random_seq=True, is_video=True):
    train_dataset = VevoDataset(
        dataset_root=dataset_root, split="train", split_ver=split_ver,
        vis_models=vis_models, emo_model=emo_model,
        max_seq_chord=max_seq_chord, max_seq_video=max_seq_video,
        random_seq=random_seq, is_video=is_video)

    val_dataset = VevoDataset(
        dataset_root=dataset_root, split="val", split_ver=split_ver,
        vis_models=vis_models, emo_model=emo_model,
        max_seq_chord=max_seq_chord, max_seq_video=max_seq_video,
        random_seq=random_seq, is_video=is_video)

    test_dataset = VevoDataset(
        dataset_root=dataset_root, split="test", split_ver=split_ver,
        vis_models=vis_models, emo_model=emo_model,
        max_seq_chord=max_seq_chord, max_seq_video=max_seq_video,
        random_seq=random_seq, is_video=is_video)

    return train_dataset, val_dataset, test_dataset
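

# --- Illustrative usage (not part of the original module) -------------------
# A minimal sketch of how the three splits might be wrapped in DataLoaders.
# `_example_build_dataloaders` and its batch size are hypothetical; only
# create_vevo_datasets above is real.
def _example_build_dataloaders(dataset_root="./dataset"):
    from torch.utils.data import DataLoader  # local import keeps the sketch self-contained

    train_ds, val_ds, test_ds = create_vevo_datasets(
        dataset_root=dataset_root,
        max_seq_chord=300,
        max_seq_video=300,
        vis_models="2d/clip_l14p",
        emo_model="6c_l14p",
        split_ver="v1",
        is_video=True)

    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)
    return train_loader, val_loader, test_loader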


def compute_vevo_accuracy(out, tgt):
    softmax = nn.Softmax(dim=-1)
    out = torch.argmax(softmax(out), dim=-1)

    out = out.flatten()
    tgt = tgt.flatten()

    mask = (tgt != CHORD_PAD)
    out = out[mask]
    tgt = tgt[mask]

    if len(tgt) == 0:
        return 1.0

    num_right = (out == tgt)
    num_right = torch.sum(num_right).type(TORCH_FLOAT)
    acc = num_right / len(tgt)
    return acc


def compute_hits_k(out, tgt, k):
    softmax = nn.Softmax(dim=-1)
    out = softmax(out)
    _, topk_indices = torch.topk(out, k, dim=-1)  # indices of the top-k values

    tgt = tgt.flatten()
    topk_indices = torch.squeeze(topk_indices, dim=0)

    num_right = 0
    pt = 0
    for i, tlist in enumerate(topk_indices):
        if tgt[i] == CHORD_PAD:
            num_right += 0
        else:
            pt += 1
            if tgt[i].item() in tlist:
                num_right += 1

    # Empty
    if len(tgt) == 0:
        return 1.0

    num_right = torch.tensor(num_right, dtype=torch.float32)
    hitk = num_right / pt
    return hitk


def compute_hits_k_root_attr(out_root, out_attr, tgt, k):
    softmax = nn.Softmax(dim=-1)
    out_root = softmax(out_root)
    out_attr = softmax(out_attr)

    tensor_shape = torch.Size([1, 299, 159])
    out = torch.zeros(tensor_shape)
    for i in range(out.shape[-1]):
        if i == 0:
            out[0, :, i] = out_root[0, :, 0] * out_attr[0, :, 0]
        elif i == 157:
            out[0, :, i] = out_root[0, :, 13] * out_attr[0, :, 14]
        elif i == 158:
            out[0, :, i] = out_root[0, :, 14] * out_attr[0, :, 15]
        else:
            rootindex = int((i - 1) / 13) + 1
            attrindex = (i - 1) % 13 + 1
            out[0, :, i] = out_root[0, :, rootindex] * out_attr[0, :, attrindex]

    out = softmax(out)
    _, topk_indices = torch.topk(out, k, dim=-1)  # indices of the top-k values

    tgt = tgt.flatten()
    topk_indices = torch.squeeze(topk_indices, dim=0)

    num_right = 0
    pt = 0
    for i, tlist in enumerate(topk_indices):
        if tgt[i] == CHORD_PAD:
            num_right += 0
        else:
            pt += 1
            if tgt[i].item() in tlist:
                num_right += 1

    if len(tgt) == 0:
        return 1.0

    num_right = torch.tensor(num_right, dtype=torch.float32)
    hitk = num_right / pt
    return hitk
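

# --- Illustrative note (not part of the original module) --------------------
# compute_hits_k_root_attr above combines per-root and per-quality distributions
# into the flat 159-way chord distribution via rootindex = (i - 1) // 13 + 1 and
# attrindex = (i - 1) % 13 + 1. The hypothetical helper below is the inverse of
# that arithmetic, shown only to document the assumed id layout.
def _chord_id_from_root_attr(root_id, attr_id):
    """Hypothetical helper: flat chord id for (root_id, attr_id), 13 qualities per root."""
    if root_id == 0 or attr_id == 0:
        return 0                          # "N" (no chord)
    return (root_id - 1) * 13 + attr_id   # ids 1..156; 157/158 are END/PAD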


def compute_vevo_correspondence(out, tgt, tgt_emotion, tgt_emotion_prob, emotion_threshold):
    tgt_emotion = tgt_emotion.squeeze()
    tgt_emotion_prob = tgt_emotion_prob.squeeze()

    dataset_root = "./dataset/"
    chordRootInvDicPath = os.path.join(dataset_root, "vevo_meta/chord_root_inv.json")
    chordAttrInvDicPath = os.path.join(dataset_root, "vevo_meta/chord_attr_inv.json")
    chordAttrDicPath = os.path.join(dataset_root, "vevo_meta/chord_attr.json")
    chordDicPath = os.path.join(dataset_root, "vevo_meta/chord.json")
    chordInvDicPath = os.path.join(dataset_root, "vevo_meta/chord_inv.json")

    with open(chordRootInvDicPath) as json_file:
        chordRootInvDic = json.load(json_file)
    with open(chordAttrDicPath) as json_file:
        chordAttrDic = json.load(json_file)
    with open(chordAttrInvDicPath) as json_file:
        chordAttrInvDic = json.load(json_file)
    with open(chordDicPath) as json_file:
        chordDic = json.load(json_file)
    with open(chordInvDicPath) as json_file:
        chordInvDic = json.load(json_file)

    softmax = nn.Softmax(dim=-1)
    out = torch.argmax(softmax(out), dim=-1)

    out = out.flatten()
    tgt = tgt.flatten()

    num_right = 0
    tgt_emotion_quality = tgt_emotion[:, 0:14]
    pt = 0
    for i, out_element in enumerate(out):
        all_zeros = torch.all(tgt_emotion_quality[i] == 0)
        if tgt_emotion[i][-1] == 1 or all_zeros or tgt_emotion_prob[i] < emotion_threshold:
            num_right += 0
        else:
            pt += 1
            if out_element.item() != CHORD_END and out_element.item() != CHORD_PAD:
                gen_chord = chordInvDic[str(out_element.item())]
                chord_arr = gen_chord.split(":")
                if len(chord_arr) == 1:
                    out_quality = 1
                elif len(chord_arr) == 2:
                    chordAttrID = chordAttrDic[chord_arr[1]]
                    out_quality = chordAttrID  # 0:N, 1:maj, ..., 13:maj7
                if tgt_emotion_quality[i][out_quality] == 1:
                    num_right += 1

    if len(tgt_emotion) == 0:
        return 1.0
    if pt == 0:
        return -1

    num_right = torch.tensor(num_right, dtype=torch.float32)
    acc = num_right / pt
    return acc


def compute_vevo_correspondence_root_attr(y_root, y_attr, tgt, tgt_emotion, tgt_emotion_prob, emotion_threshold):
    tgt_emotion = tgt_emotion.squeeze()
    tgt_emotion_prob = tgt_emotion_prob.squeeze()

    dataset_root = "./dataset/"
    chordRootInvDicPath = os.path.join(dataset_root, "vevo_meta/chord_root_inv.json")
    chordAttrInvDicPath = os.path.join(dataset_root, "vevo_meta/chord_attr_inv.json")
    chordAttrDicPath = os.path.join(dataset_root, "vevo_meta/chord_attr.json")
    chordDicPath = os.path.join(dataset_root, "vevo_meta/chord.json")
    chordInvDicPath = os.path.join(dataset_root, "vevo_meta/chord_inv.json")

    with open(chordRootInvDicPath) as json_file:
        chordRootInvDic = json.load(json_file)
    with open(chordAttrDicPath) as json_file:
        chordAttrDic = json.load(json_file)
    with open(chordAttrInvDicPath) as json_file:
        chordAttrInvDic = json.load(json_file)
    with open(chordDicPath) as json_file:
        chordDic = json.load(json_file)
    with open(chordInvDicPath) as json_file:
        chordInvDic = json.load(json_file)

    softmax = nn.Softmax(dim=-1)
    y_root = torch.argmax(softmax(y_root), dim=-1)
    y_attr = torch.argmax(softmax(y_attr), dim=-1)

    y_root = y_root.flatten()
    y_attr = y_attr.flatten()
    tgt = tgt.flatten()

    y = np.empty(len(tgt))
    y.fill(CHORD_PAD)
    for i in range(len(tgt)):
        if y_root[i].item() == CHORD_ROOT_PAD or y_attr[i].item() == CHORD_ATTR_PAD:
            y[i] = CHORD_PAD
        elif y_root[i].item() == CHORD_ROOT_END or y_attr[i].item() == CHORD_ATTR_END:
            y[i] = CHORD_END
        else:
            chordRoot = chordRootInvDic[str(y_root[i].item())]
            chordAttr = chordAttrInvDic[str(y_attr[i].item())]
            if chordRoot == "N":
                y[i] = 0
            else:
                if chordAttr == "N" or chordAttr == "maj":
                    y[i] = chordDic[chordRoot]
                else:
                    chord = chordRoot + ":" + chordAttr
                    y[i] = chordDic[chord]

    y = torch.from_numpy(y)
    y = y.to(torch.long)
    y = y.to(get_device())
    y = y.flatten()

    num_right = 0
    tgt_emotion_quality = tgt_emotion[:, 0:14]
    pt = 0
    for i, y_element in enumerate(y):
        all_zeros = torch.all(tgt_emotion_quality[i] == 0)
        if tgt_emotion[i][-1] == 1 or all_zeros or tgt_emotion_prob[i] < emotion_threshold:
            num_right += 0
        else:
            pt += 1
            if y_element.item() != CHORD_END and y_element.item() != CHORD_PAD:
                gen_chord = chordInvDic[str(y_element.item())]
                chord_arr = gen_chord.split(":")
                if len(chord_arr) == 1:
                    y_quality = 1
                elif len(chord_arr) == 2:
                    chordAttrID = chordAttrDic[chord_arr[1]]
                    y_quality = chordAttrID  # 0:N, 1:maj, ..., 13:maj7
                if tgt_emotion_quality[i][y_quality] == 1:
                    num_right += 1

    if len(tgt_emotion) == 0:
        return 1.0
    if pt == 0:
        return -1

    num_right = torch.tensor(num_right, dtype=torch.float32)
    acc = num_right / pt
    return acc


def compute_vevo_accuracy_root_attr(y_root, y_attr, tgt):
    dataset_root = "./dataset/"
    chordRootInvDicPath = os.path.join(dataset_root, "vevo_meta/chord_root_inv.json")
    chordAttrInvDicPath = os.path.join(dataset_root, "vevo_meta/chord_attr_inv.json")
    chordDicPath = os.path.join(dataset_root, "vevo_meta/chord.json")

    with open(chordRootInvDicPath) as json_file:
        chordRootInvDic = json.load(json_file)
    with open(chordAttrInvDicPath) as json_file:
        chordAttrInvDic = json.load(json_file)
    with open(chordDicPath) as json_file:
        chordDic = json.load(json_file)

    softmax = nn.Softmax(dim=-1)
    y_root = torch.argmax(softmax(y_root), dim=-1)
    y_attr = torch.argmax(softmax(y_attr), dim=-1)

    y_root = y_root.flatten()
    y_attr = y_attr.flatten()
    tgt = tgt.flatten()

    mask = (tgt != CHORD_PAD)

    y = np.empty(len(tgt))
    y.fill(CHORD_PAD)
    for i in range(len(tgt)):
        if y_root[i].item() == CHORD_ROOT_PAD or y_attr[i].item() == CHORD_ATTR_PAD:
            y[i] = CHORD_PAD
        elif y_root[i].item() == CHORD_ROOT_END or y_attr[i].item() == CHORD_ATTR_END:
            y[i] = CHORD_END
        else:
            chordRoot = chordRootInvDic[str(y_root[i].item())]
            chordAttr = chordAttrInvDic[str(y_attr[i].item())]
            if chordRoot == "N":
                y[i] = 0
            else:
                if chordAttr == "N" or chordAttr == "maj":
                    y[i] = chordDic[chordRoot]
                else:
                    chord = chordRoot + ":" + chordAttr
                    y[i] = chordDic[chord]

    y = torch.from_numpy(y)
    y = y.to(torch.long)
    y = y.to(get_device())

    y = y[mask]
    tgt = tgt[mask]

    # Empty
    if len(tgt) == 0:
        return 1.0

    num_right = (y == tgt)
    num_right = torch.sum(num_right).type(TORCH_FLOAT)
    acc = num_right / len(tgt)
    return acc
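

# --- Illustrative usage (not part of the original module) -------------------
# A minimal sketch of how the metrics above might be applied to one batch.
# `model` and its call signature are hypothetical placeholders; the assumed
# shapes are out: [1, 299, num_chords] and tgt: [1, 299], matching the
# flatten/argmax handling in compute_vevo_accuracy and compute_hits_k.
def _example_evaluate_batch(model, batch):
    out = model(batch)  # hypothetical forward pass producing chord logits
    acc = compute_vevo_accuracy(out, batch["tgt"])
    hits3 = compute_hits_k(out, batch["tgt"], k=3)
    return acc, hits3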