import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import json

# Optional third-party packages.  ``yaml`` is only used by the ``__main__`` demo
# and ``plt`` is not used by the classes below, so a missing package must not
# prevent the embedding modules themselves from being imported.
try:
    import matplotlib.pyplot as plt
except ImportError:  # pragma: no cover - optional dependency
    plt = None
try:
    import yaml
except ImportError:  # pragma: no cover - optional dependency
    yaml = None


def _sinusoidal_table(max_len, d_model):
    """Return the fixed sin/cos position table, shape (max_len, 1, d_model).

    Even feature indices hold ``sin(pos * w_k)``, odd indices ``cos(pos * w_k)``,
    with ``w_k = exp(-k * log(10000) / d_model)`` for even ``k``.
    """
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, 1, d_model)
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    pe[:, 0, 1::2] = torch.cos(position * div_term)
    return pe


class Fundamental_Music_Embedding(nn.Module):
    """Sinusoidal embedding of scalar values (e.g. pitch tokens or timestamps).

    A value ``x`` is mapped to a ``d_model`` vector whose even entries are
    ``sin(x * r_k)`` and odd entries ``cos(x * r_k)``, where
    ``r_k = 1 / base ** (2 * (k // 2) / d_model)``.  An optional learned
    translation bias is added to the result.

    Args:
        d_model: embedding size.
        base: frequency base of the sinusoids.
        if_trainable: if True the angle rates are an ``nn.Parameter``.
        if_translation_bias_trainable: whether the translation bias receives
            gradients (only meaningful when a bias is used).
        device: stored as ``self.device``; not otherwise used by this class.
        type: unused; kept so existing configs can still pass it.
        emb_nn: unused; kept so existing configs can still pass it.
        translation_bias_type: ``"2d"`` (a 2-vector tiled to ``d_model``),
            ``"nd"`` (a full ``d_model`` vector) or ``None`` (no bias).

    Raises:
        ValueError: if ``translation_bias_type`` is not ``"2d"``, ``"nd"`` or None.
    """

    def __init__(self, d_model, base, if_trainable=False, if_translation_bias_trainable=True,
                 device='cpu', type="se", emb_nn=None, translation_bias_type="nd"):
        super().__init__()
        self.d_model = d_model
        self.device = device
        self.base = base
        self.if_trainable = if_trainable  # whether the angle rates are trainable

        if translation_bias_type is None:
            self.if_translation_bias = False
            # Defined up front so the attribute exists before the first call.
            self.translation_bias = None
        else:
            if translation_bias_type == "2d":
                bias_width = 2
            elif translation_bias_type == "nd":
                bias_width = self.d_model
            else:
                raise ValueError(
                    f"translation_bias_type must be '2d', 'nd' or None, got {translation_bias_type!r}"
                )
            self.if_translation_bias = True
            self.if_translation_bias_trainable = if_translation_bias_trainable
            # Uniform initialisation on [0, 1); honour the trainability flag.
            translation_bias = nn.Parameter(
                torch.rand((1, bias_width), dtype=torch.float32),
                requires_grad=if_translation_bias_trainable,
            )
            self.register_parameter("translation_bias", translation_bias)

        i = torch.arange(d_model)
        angle_rates = 1 / torch.pow(self.base, (2 * (i // 2)) / d_model)
        angle_rates = angle_rates[None, ...]  # (1, d_model)
        if self.if_trainable:
            self.register_parameter("angles", nn.Parameter(angle_rates, requires_grad=True))
        else:
            # Plain tensor (not a buffer); moved to the target device on each call.
            self.angles = angle_rates

    def __call__(self, inp, device=None):
        """Embed ``inp`` (1-D ``(seq,)`` or 2-D ``(batch, seq)``) to ``(batch, seq, d_model)``.

        Args:
            inp: tensor of scalar values to embed.
            device: device for the angle rates; defaults to ``inp.device``.
        """
        if device is None:
            device = inp.device
        if inp.dim() == 2:
            inp = inp[..., None]  # (batch, seq, 1)
        elif inp.dim() == 1:
            inp = inp[None, ..., None]  # (1, seq, 1)
        angle_rads = inp * self.angles.to(device)  # broadcast against (1, d_model)

        # Write into a copy so autograd still sees the untouched angle_rads.
        pos_encoding = angle_rads.clone()
        pos_encoding[:, :, 0::2] = torch.sin(angle_rads[:, :, 0::2])
        pos_encoding[:, :, 1::2] = torch.cos(angle_rads[:, :, 1::2])
        pos_encoding = pos_encoding.to(torch.float32)

        if self.if_translation_bias:
            translation_bias = self.translation_bias
            if translation_bias.size()[-1] != self.d_model:
                # "2d" bias: tile the 2-vector along the feature axis.
                translation_bias = translation_bias.repeat(1, 1, self.d_model // 2)
            pos_encoding = pos_encoding + translation_bias
        return pos_encoding


class Music_PositionalEncoding(nn.Module):
    """Sum of index, global-time and modulo-time sinusoidal encodings, then dropout.

    Args:
        d_model: embedding size.
        dropout: dropout probability applied to the output.
        max_len: maximum sequence length for the index encoding.
        if_index: add the fixed index (position) encoding.
        if_global_timing: add an embedding of ``dur_onset_cumsum``.
        if_modulo_timing: add an embedding of ``dur_onset_cumsum % 4``.
        device: unused; kept for config compatibility.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, if_index=True,
                 if_global_timing=True, if_modulo_timing=True, device='cuda:0'):
        super().__init__()
        self.if_index = if_index
        self.if_global_timing = if_global_timing
        self.if_modulo_timing = if_modulo_timing
        self.dropout = nn.Dropout(p=dropout)
        self.index_embedding = Fundamental_Music_Embedding(
            d_model=d_model, base=10000, if_trainable=False, translation_bias_type=None,
            if_translation_bias_trainable=False, type="se",
        )
        self.global_time_embedding = Fundamental_Music_Embedding(
            d_model=d_model, base=10001, if_trainable=False, translation_bias_type=None,
            if_translation_bias_trainable=False, type="se",
        )
        self.modulo_time_embedding = Fundamental_Music_Embedding(
            d_model=d_model, base=10001, if_trainable=False, translation_bias_type=None,
            if_translation_bias_trainable=False, type="se",
        )
        self.register_buffer('pe', _sinusoidal_table(max_len, d_model))

    def forward(self, inp, dur_onset_cumsum=None):
        """Add the enabled encodings to ``inp`` (``(batch, seq, d_model)``).

        ``inp`` is not modified in place.

        Raises:
            ValueError: if timing encodings are enabled but ``dur_onset_cumsum`` is None.
        """
        if (self.if_global_timing or self.if_modulo_timing) and dur_onset_cumsum is None:
            raise ValueError("dur_onset_cumsum is required when global or modulo timing is enabled")
        if self.if_index:
            pe_index = self.pe[:inp.size(1)]  # (seq_len, 1, d_model)
            pe_index = torch.swapaxes(pe_index, 0, 1)  # (1, seq_len, d_model)
            inp = inp + pe_index
        if self.if_global_timing:
            inp = inp + self.global_time_embedding(dur_onset_cumsum, inp.device)
        if self.if_modulo_timing:
            inp = inp + self.modulo_time_embedding(dur_onset_cumsum % 4, inp.device)
        return self.dropout(inp)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal index encoding added to a batch-first input, then dropout."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.register_buffer('pe', _sinusoidal_table(max_len, d_model))

    def forward(self, x):
        pos = self.pe[:x.size(1)]  # (seq_len, 1, d_model)
        pos = torch.swapaxes(pos, 0, 1)  # (1, seq_len, d_model)
        x = x + pos
        return self.dropout(x)


class chord_tokenizer():
    """Turn chord-symbol strings (e.g. ``"Gm"``, ``"Eb/G"``) into integer ids.

    Each chord is split into root, type and inversion flag; sequences are
    optionally padded to ``seq_len_chord``.
    """

    def __init__(self, seq_len_chord=88, if_pad=True):
        self.pitch_dict = {'pad': 0, "None": 1, "N": 1, "A": 2, "A#": 3, "Bb": 3, "B": 4, "Cb": 4, "B#": 5, "C": 5, "C#": 6, "Db": 6, "D": 7, "D#": 8, "Eb": 8, "E": 9, "Fb": 9, "E#": 10, "F": 10, "F#": 11, "Gb": 11, "G": 12, "G#": 13, "Ab": 13}
        self.chord_type_dict = {'pad': 0, "None": 1, "N": 1, "maj": 2, "maj7": 3, "m": 4, "m6": 5, "m7": 6, "m7b5": 7, "6": 8, "7": 9, "aug": 10, "dim": 11}
        self.chord_inversion_dict = {'pad': 0, "None": 1, "N": 1, "inv": 2, "no_inv": 3}
        self.seq_len_chord = seq_len_chord
        self.if_pad = if_pad

    def __call__(self, chord, chord_time):
        """Tokenize a chord sequence and its onset times.

        Args:
            chord: sequence of chord symbols.
            chord_time: sequence of onset times, same length as ``chord``.

        Returns:
            ``(root_ids, type_ids, inversion_ids, times, mask)`` as lists; padded
            positions get id 0, repeat the last time and have mask False.
            The caller's sequences are never modified.
        """
        if len(chord) == 0:
            chord, chord_time = ["N"], [0.]
        # Copy so padding below never mutates the caller's lists.
        chord, chord_time = list(chord), list(chord_time)
        if self.if_pad:
            pad_len_chord = self.seq_len_chord - len(chord)
            chord_mask = [True] * len(chord) + [False] * pad_len_chord
            chord += ["pad"] * pad_len_chord
            chord_time += [chord_time[-1]] * pad_len_chord
        else:
            chord_mask = [True] * len(chord)

        self.chord_root, self.chord_type, self.chord_inv = self.tokenize_chord_lst(chord)
        self.chord_time = chord_time
        self.chord_mask = chord_mask
        return self.chord_root, self.chord_type, self.chord_inv, self.chord_time, self.chord_mask

    def get_chord_root_type_inversion_timestamp(self, chord):
        """Split a chord symbol into ``(root, type, inversion)`` strings.

        ``"pad"`` and ``"N"`` map to themselves. Any ``/bass`` suffix marks an
        inversion. A bare root defaults to type ``"maj"``.
        """
        if chord == "pad":
            return "pad", "pad", "pad"
        if chord == "N":
            return "N", "N", "N"
        chord_inv = "inv" if len(chord.split('/')) > 1 else "no_inv"
        chord_wo_inv = chord.split('/')[0]
        # The root may carry one accidental ('#' or 'b') as its second char.
        if len(chord_wo_inv) > 1 and chord_wo_inv[1] in ('#', 'b'):
            chord_root = chord_wo_inv[0:2]
        else:
            chord_root = chord_wo_inv[0]
        chord_type = chord_wo_inv[len(chord_root):] or 'maj'
        return chord_root, chord_type, chord_inv

    def tokenize_chord_lst(self, chord_lst):
        """Map chord symbols to parallel lists of root, type and inversion ids."""
        out_root = []
        out_type = []
        out_inv = []
        for chord in chord_lst:
            chord_root, chord_type, chord_inversion = self.get_chord_root_type_inversion_timestamp(chord)
            out_root.append(self.pitch_dict[chord_root])
            out_type.append(self.chord_type_dict[chord_type])
            out_inv.append(self.chord_inversion_dict[chord_inversion])
        return out_root, out_type, out_inv


class beat_tokenizer():
    """Turn ``[beat_times, beat_positions]`` into beat ids, times and a mask.

    Beat positions 1.0 .. 7.0 map to ids 2 .. 8; padding maps to 0.
    """

    def __init__(self, seq_len_beat=88, if_pad=True):
        self.beat_dict = {'pad': 0, "None": 1, 1.: 2, 2.: 3, 3.: 4, 4.: 5, 5.: 6, 6.: 7, 7.: 8}
        self.if_pad = if_pad
        self.seq_len_beat = seq_len_beat

    def __call__(self, beat_lst):
        """Tokenize ``beat_lst = [times, positions]``.

        Returns:
            ``(beat_ids, beat_times, mask)``; padded positions repeat the last
            time (or 0.0 if there are no beats) and have mask False.
        """
        times, positions = beat_lst[0], beat_lst[1]
        if self.if_pad:
            if len(times) == 0:
                beat_mask = [False] * self.seq_len_beat
                times = [0.] * self.seq_len_beat
                positions = ["pad"] * self.seq_len_beat
            else:
                pad_len_beat = self.seq_len_beat - len(times)
                beat_mask = [True] * len(times) + [False] * pad_len_beat
                times = times + [times[-1]] * pad_len_beat
                positions = positions + ["pad"] * pad_len_beat
        else:
            beat_mask = [True] * len(times)
        self.beat = [self.beat_dict[x] for x in positions]
        self.beat_timing = times
        return self.beat, self.beat_timing, beat_mask


def l2_norm(a, b):
    """Euclidean distance between ``a`` and ``b`` along the last axis."""
    return torch.linalg.norm(a - b, ord=2, dim=-1)


def rounding(x):
    """Smooth approximation of rounding: ``x - sin(2*pi*x) / (2*pi)``."""
    return x - torch.sin(2. * math.pi * x) / (2. * math.pi)


class Chord_Embedding(nn.Module):
    """Embed tokenized chords.

    The root is embedded with ``FME``, type and inversion are one-hot encoded,
    and the time is embedded with ``PE.global_time_embedding``. The four parts
    are concatenated and projected to ``d_model``.
    """

    def __init__(self, FME, PE, d_model=256, d_oh_type=12, d_oh_inv=4):
        super().__init__()
        self.FME = FME
        self.PE = PE
        self.d_model = d_model
        self.d_oh_type = d_oh_type
        self.d_oh_inv = d_oh_inv
        self.chord_ffn = nn.Linear(d_oh_type + d_oh_inv + d_model + d_model, d_model)

    def __call__(self, chord_root, chord_type, chord_inv, chord_timing, device):
        """Return chord embeddings of shape ``(*chord_root.shape, d_model)``.

        Assumes FME and PE both produce ``d_model``-wide features (TODO confirm
        against config).
        """
        chord_root_emb = self.FME(chord_root, device)
        chord_type_emb = F.one_hot(chord_type.to(torch.int64), num_classes=self.d_oh_type).to(torch.float32)
        chord_inv_emb = F.one_hot(chord_inv.to(torch.int64), num_classes=self.d_oh_inv).to(torch.float32)
        chord_time_emb = self.PE.global_time_embedding(chord_timing, device)
        merged = torch.cat((chord_root_emb, chord_type_emb, chord_inv_emb, chord_time_emb), dim=-1)
        return self.chord_ffn(merged.to(device))


class Beat_Embedding(nn.Module):
    """Embed tokenized beats.

    The beat id is one-hot encoded, the time is embedded with
    ``PE.global_time_embedding``, and the two are concatenated and projected
    to ``d_model``.
    """

    def __init__(self, PE, d_model=256, d_oh_beat_type=4):
        super().__init__()
        # NOTE(review): beat_tokenizer emits ids up to 8, so d_oh_beat_type must
        # be >= 9 for those; the default of 4 presumably gets overridden by config.
        self.PE = PE
        self.d_model = d_model
        self.d_oh_beat_type = d_oh_beat_type
        self.beat_ffn = nn.Linear(d_oh_beat_type + d_model, d_model)

    def __call__(self, beats, beats_timing, device):
        """Return beat embeddings of shape ``(*beats.shape, d_model)``."""
        beat_type_emb = F.one_hot(beats.to(torch.int64), num_classes=self.d_oh_beat_type).to(torch.float32).to(device)
        beat_time_emb = self.PE.global_time_embedding(beats_timing, device)
        merged_beat = torch.cat((beat_type_emb, beat_time_emb), dim=-1)
        return self.beat_ffn(merged_beat)


if __name__ == "__main__":
    if yaml is None:
        raise SystemExit("PyYAML is required to run this demo")
    config_path = "/data/nicolas/TANGO/config/model_embedding_config.yaml"
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    # beat_tokenizer expects [beat_times, beat_positions] (not transposed).
    beats = [[0.56, 1.1, 1.66, 2.24, 2.8, 3.36, 3.92, 4.48, 5.04, 5.6, 6.16, 6.74, 7.32, 7.9, 8.46, 9.0, 9.58],
             [3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0]]
    chords = [["Gm", 0.464399092], ["Eb", 1.393197278], ["F", 3.157913832], ["Bb", 4.736870748],
              ["F7", 5.758548752], ["Gm", 6.501587301], ["Eb", 8.173424036], ["F7", 9.938140589]]
    chord_names = [name for name, _ in chords]
    chord_times = [time for _, time in chords]

    chord_tok = chord_tokenizer(seq_len_chord=30, if_pad=True)
    beat_tok = beat_tokenizer(seq_len_beat=17, if_pad=True)

    # TOKENIZE CHORDS AND BEATS AT DATALOADING PART
    chord_root, chord_type, chord_inv, chord_time, chord_mask = chord_tok(chord_names, chord_times)
    beat, beat_timing, beat_mask = beat_tok(beats)

    def _batched(values):
        """Convert a token list to a tensor with a leading batch dimension."""
        return torch.tensor(values)[None, ...]

    chord_root, chord_type, chord_inv = _batched(chord_root), _batched(chord_type), _batched(chord_inv)
    chord_time, chord_mask = _batched(chord_time), _batched(chord_mask)
    beat, beat_timing, beat_mask = _batched(beat), _batched(beat_timing), _batched(beat_mask)
    print("tokeninzing chords and beats", chord_root.shape, beat.shape)

    # EMBEDDING CHORDS AND BEATS WITHIN THE MODEL
    device = "cpu"
    FME = Fundamental_Music_Embedding(**cfg["FME_embedding_conf"])
    PE = Music_PositionalEncoding(**cfg["Position_encoding_conf"])
    chord_embedding_layer = Chord_Embedding(FME, PE, **cfg["Chord_Embedding_conf"])
    chord_embedded = chord_embedding_layer(chord_root, chord_type, chord_inv, chord_time, device)
    beat_embedding_layer = Beat_Embedding(PE, **cfg["Beat_Embedding_conf"])
    beat_embedded = beat_embedding_layer(beat, beat_timing, device)
    print("embedding tokenized chords and beats", chord_embedded.shape, beat_embedded.shape)