# mustango/layers/layers.py
# Header residue from the Hugging Face file view: author deepanway,
# commit 9e0eee2 ("update files for device agnostic inference"), 14.5 kB.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import json
import yaml
class Fundamental_Music_Embedding(nn.Module):
    """Sinusoidal embedding of scalar musical values, with an optional translation bias.

    Each scalar ``x`` is mapped to a ``d_model``-dim vector. Even channels are
    ``sin(x * w_k)`` and odd channels are ``cos(x * w_k)``, where
    ``w_k = base ** (-2 * (k // 2) / d_model)``. If ``translation_bias_type``
    is not None, a bias vector initialised uniformly on [0, 1) is added.

    Args:
        d_model: Embedding dimension.
        base: Base of the geometric frequency progression.
        if_trainable: If True, the frequencies are an ``nn.Parameter``;
            otherwise they are kept as a plain (unregistered) tensor.
        if_translation_bias_trainable: Whether the translation bias receives
            gradients (``requires_grad``).
        device: Stored as ``self.device``; not otherwise used by this class.
        type: Unused; kept for config compatibility.
        emb_nn: Unused; kept for config compatibility.
        translation_bias_type: ``"2d"`` (a 2-vector tiled across the
            channels), ``"nd"`` (a full ``d_model`` vector) or ``None``
            (no bias).

    Raises:
        ValueError: If ``translation_bias_type`` is not ``"2d"``, ``"nd"`` or None.
    """

    def __init__(self, d_model, base, if_trainable = False, if_translation_bias_trainable = True, device='cpu', type = "se",emb_nn=None,translation_bias_type = "nd"):
        super().__init__()
        self.d_model = d_model
        self.device = device
        self.base = base
        self.if_trainable = if_trainable  # whether the sinusoid frequencies are trainable
        if translation_bias_type is not None:
            self.if_translation_bias = True
            self.if_translation_bias_trainable = if_translation_bias_trainable
            if translation_bias_type == "2d":
                # uniform on [0, 1); tiled over the channels at call time
                translation_bias = torch.rand((1, 2), dtype=torch.float32)
            elif translation_bias_type == "nd":
                translation_bias = torch.rand((1, self.d_model), dtype=torch.float32)
            else:
                raise ValueError(
                    "translation_bias_type must be '2d', 'nd' or None, "
                    f"got {translation_bias_type!r}"
                )
            # Honour the trainability flag (previously ignored: always trainable).
            translation_bias = nn.Parameter(
                translation_bias, requires_grad=if_translation_bias_trainable
            )
            self.register_parameter("translation_bias", translation_bias)
        else:
            self.if_translation_bias = False
            # Defined here so the attribute exists even before the first call.
            self.translation_bias = None
        i = torch.arange(d_model)
        angle_rates = 1 / torch.pow(self.base, (2 * (i // 2)) / d_model)
        angle_rates = angle_rates[None, ...]  # (1, d_model)
        if self.if_trainable:
            angles = nn.Parameter(angle_rates, requires_grad=True)
            self.register_parameter("angles", angles)
        else:
            self.angles = angle_rates

    def forward(self, inp, device=None):
        """Embed ``inp`` and return a float32 tensor of shape (batch, len, d_model).

        Args:
            inp: Values to embed. A 1-D input gets a batch dimension of 1 and a
                2-D input is treated as (batch, len). A 3-D input is used as
                given; presumably it is (batch, len, 1), but confirm with the caller.
            device: Device to compute on. Defaults to ``inp.device``.
        """
        if device is None:
            device = inp.device
        if inp.dim() == 2:
            inp = inp[..., None]  # (batch, num_pitch, 1)
        elif inp.dim() == 1:
            inp = inp[None, ..., None]  # (1, num_pitch, 1)
        angle_rads = inp * self.angles.to(device)  # (batch, num_pitch, d_model)
        # sin on even channels (2i), cos on odd channels (2i+1); the clone()
        # keeps autograd valid across the in-place slice assignment
        angle_rads[:, :, 0::2] = torch.sin(angle_rads.clone()[:, :, 0::2])
        angle_rads[:, :, 1::2] = torch.cos(angle_rads.clone()[:, :, 1::2])
        pos_encoding = angle_rads.to(torch.float32)
        if self.if_translation_bias:
            translation_bias = self.translation_bias
            if translation_bias.size()[-1] != self.d_model:
                # "2d" bias: tile the pair across channels, trimming for odd d_model
                n_rep = (self.d_model + 1) // 2
                translation_bias = translation_bias.repeat(1, 1, n_rep)[..., :self.d_model]
            pos_encoding = pos_encoding + translation_bias.to(device)
        return pos_encoding
class Music_PositionalEncoding(nn.Module):
    """Add index, global-onset and modulo-onset encodings to a sequence, then dropout.

    - The index term comes from a fixed sinusoidal ``pe`` buffer.
    - The global term embeds ``dur_onset_cumsum`` with ``global_time_embedding``.
    - The modulo term embeds ``dur_onset_cumsum % 4`` with
      ``modulo_time_embedding`` (presumably the position inside a 4-beat bar;
      confirm with the data pipeline).

    Each term can be switched off. ``index_embedding`` is constructed but is
    not used by ``forward``.

    Args:
        d_model: Embedding dimension (odd values are supported).
        dropout: Dropout probability applied to the result.
        max_len: Number of rows in the ``pe`` buffer (longest indexable sequence).
        if_index: Add the index term.
        if_global_timing: Add the global onset term.
        if_modulo_timing: Add the onset-modulo-4 term.
        device: Unused; kept for config compatibility.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, if_index = True, if_global_timing = True, if_modulo_timing = True, device = 'cuda:0'):
        super().__init__()
        self.if_index = if_index
        self.if_global_timing = if_global_timing
        self.if_modulo_timing = if_modulo_timing
        self.dropout = nn.Dropout(p=dropout)
        self.index_embedding = Fundamental_Music_Embedding(
            d_model=d_model, base=10000, if_trainable=False, translation_bias_type=None,
            if_translation_bias_trainable=False, type="se"
        )
        self.global_time_embedding = Fundamental_Music_Embedding(
            d_model=d_model, base=10001, if_trainable=False, translation_bias_type=None,
            if_translation_bias_trainable=False, type="se"
        )
        self.modulo_time_embedding = Fundamental_Music_Embedding(
            d_model=d_model, base=10001, if_trainable=False, translation_bias_type=None,
            if_translation_bias_trainable=False, type="se"
        )
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # odd d_model has one fewer cos (odd) channel than sin (even) channel
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, inp, dur_onset_cumsum = None):
        """Return ``dropout(inp + enabled encodings)``; ``inp`` is not modified.

        Args:
            inp: Sequence tensor; indexed as (batch, seq_len, d_model) here.
            dur_onset_cumsum: Onset times; required when global or modulo
                timing is enabled.

        Raises:
            ValueError: If timing is enabled but ``dur_onset_cumsum`` is None.
        """
        if (self.if_global_timing or self.if_modulo_timing) and dur_onset_cumsum is None:
            raise ValueError(
                "dur_onset_cumsum is required when global or modulo timing is enabled"
            )
        # Out-of-place additions: never mutate the caller's tensor. Casting to
        # inp.dtype keeps the dtype that in-place addition used to preserve.
        if self.if_index:
            pe_index = self.pe[:inp.size(1)]  # [seq_len, 1, d_model]
            pe_index = torch.swapaxes(pe_index, 0, 1)  # [1, seq_len, d_model]
            inp = inp + pe_index.to(inp.dtype)
        if self.if_global_timing or self.if_modulo_timing:
            timing = dur_onset_cumsum.to(inp.device)
            if self.if_global_timing:
                # device passed explicitly: the embedding's call requires it
                global_timing_embedding = self.global_time_embedding(timing, inp.device)
                inp = inp + global_timing_embedding.to(inp.dtype)
            if self.if_modulo_timing:
                modulo_timing_embedding = self.modulo_time_embedding(timing % 4, inp.device)
                inp = inp + modulo_timing_embedding.to(inp.dtype)
        return self.dropout(inp)
class PositionalEncoding(nn.Module):
    """Standard Transformer sinusoidal positional encoding for batch-first input.

    Args:
        d_model: Embedding dimension (odd values are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence that can be encoded.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd (cos) channel than even (sin) channel.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding to ``x`` of shape (batch, seq_len, d_model) and apply dropout.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}")
        pos = torch.swapaxes(self.pe[:seq_len], 0, 1)  # (1, seq_len, d_model)
        return self.dropout(x + pos)
class chord_tokenizer():
    """Convert chord symbols such as ``"Gm"`` or ``"Am7/E"`` into integer token ids.

    Each symbol is split into a root (``"A"``, ``"Bb"``, ...), a chord type
    (the remaining suffix, ``"maj"`` when empty) and an inversion flag (set
    when a ``/bass`` part is present). Each part is looked up in its own
    vocabulary. Id 0 is padding and id 1 is "no chord" (``"N"``/``"None"``).

    Args:
        seq_len_chord: Target length when padding.
        if_pad: Pad sequences shorter than ``seq_len_chord``. Longer ones are
            not truncated.
    """

    def __init__(self,seq_len_chord=88,if_pad = True):
        # root -> id; enharmonic spellings share an id
        self.pitch_dict = {'pad': 0, "None":1, "N":1, "A": 2, "A#": 3, "Bb":3, "B":4, "Cb": 4, "B#":5, "C":5, "C#":6, "Db":6, "D": 7, "D#":8, "Eb":8, "E": 9 , "Fb": 9, "E#": 10, "F":10, "F#":11, "Gb":11, "G":12, "G#":13, "Ab":13}
        self.chord_type_dict = {'pad': 0, "None": 1,"N": 1, "maj": 2, "maj7": 3, "m": 4, "m6": 5, "m7": 6, "m7b5": 7, "6": 8, "7": 9, "aug": 10, "dim":11}
        self.chord_inversion_dict = {'pad': 0, "None":1, "N":1,"inv": 2, "no_inv":3}
        self.seq_len_chord = seq_len_chord
        self.if_pad = if_pad

    def __call__(self, chord, chord_time):
        """Tokenize a chord sequence and its onset times.

        The caller's ``chord`` and ``chord_time`` are not modified. An empty
        ``chord`` is replaced by a single ``"N"`` at time 0. Padded time
        slots repeat the last onset.

        Returns:
            ``(roots, types, inversions, times, mask)`` lists; ``mask`` is True
            for real chords and False for padding.

        Raises:
            KeyError: If a root or chord type is not in the vocabulary.
        """
        if len(chord) == 0:
            chord, chord_time = ["N"], [0.]
        # Work on copies: padding in place corrupted the caller's lists, so a
        # second call on them produced an all-True mask over "pad" entries.
        chord = list(chord)
        chord_time = list(chord_time)
        if self.if_pad:
            pad_len_chord = self.seq_len_chord - len(chord)
            chord_mask = [True] * len(chord) + [False] * pad_len_chord
            chord = chord + ["pad"] * pad_len_chord
            chord_time = chord_time + [chord_time[-1]] * pad_len_chord
        else:
            chord_mask = [True] * len(chord)
        self.chord_root, self.chord_type, self.chord_inv = self.tokenize_chord_lst(chord)
        self.chord_time = chord_time
        self.chord_mask = chord_mask
        return self.chord_root, self.chord_type, self.chord_inv, self.chord_time, self.chord_mask

    def get_chord_root_type_inversion_timestamp(self, chord):
        """Split a chord symbol into ``(root, type, inversion)`` vocabulary keys."""
        if chord == "pad":
            return "pad", "pad", "pad"
        if chord == "N":
            return "N", "N", "N"
        chord_inv = "inv" if len(chord.split('/')) > 1 else "no_inv"
        chord_wo_inv = chord.split('/')[0]
        # A root is one letter, optionally followed by '#' or 'b'.
        if len(chord_wo_inv) > 1 and chord_wo_inv[1] in ('#', 'b'):
            chord_root = chord_wo_inv[0:2]
        else:
            chord_root = chord_wo_inv[0]
        # Whatever follows the root is the chord type; a bare root means major.
        chord_type = chord_wo_inv[len(chord_root):] or 'maj'
        return chord_root, chord_type, chord_inv

    def tokenize_chord_lst(self, chord_lst):
        """Map chord symbols to parallel lists of root, type and inversion ids."""
        out_root = []
        out_type = []
        out_inv = []
        for chord in chord_lst:
            chord_root, chord_type, chord_inversion = self.get_chord_root_type_inversion_timestamp(chord)
            out_root.append(self.pitch_dict[chord_root])
            out_type.append(self.chord_type_dict[chord_type])
            out_inv.append(self.chord_inversion_dict[chord_inversion])
        return out_root, out_type, out_inv
class beat_tokenizer():
    """Turn a beat annotation ``[timestamps, beat_types]`` into token ids and a mask.

    Beat types 1.0-7.0 map to ids 2-8. Id 0 is padding and id 1 is
    ``"None"``. The beat type is presumably the beat's position within the
    bar; confirm with the annotation source.

    Args:
        seq_len_beat: Target length when padding.
        if_pad: Pad to ``seq_len_beat``. Padded timestamps repeat the last one,
            or are all 0.0 when the annotation is empty. Longer inputs are
            not truncated.
    """

    def __init__(self,seq_len_beat=88,if_pad = True):
        self.beat_dict = {'pad': 0, "None": 1, 1.: 2, 2.: 3, 3.: 4, 4.: 5, 5.: 6, 6.: 7, 7.: 8}
        self.if_pad = if_pad
        self.seq_len_beat = seq_len_beat

    def __call__(self, beat_lst):
        """Return ``(beat_ids, beat_timings, beat_mask)`` for ``beat_lst``.

        ``beat_lst[0]`` holds timestamps and ``beat_lst[1]`` holds beat types,
        e.g. ``[[0.56, 1.1, ...], [3.0, 1.0, ...]]``.
        """
        timings = beat_lst[0]
        n_beats = len(timings)
        if not self.if_pad:
            mask = [True] * n_beats
            types = beat_lst[1]
        elif n_beats == 0:
            mask = [False] * self.seq_len_beat
            timings = [0.] * self.seq_len_beat
            types = ["pad"] * self.seq_len_beat
        else:
            n_missing = self.seq_len_beat - n_beats
            mask = [True] * n_beats + [False] * n_missing
            timings = timings + [timings[-1]] * n_missing
            types = beat_lst[1] + ["pad"] * n_missing
        self.beat = [self.beat_dict[t] for t in types]
        self.beat_timing = timings
        return self.beat, self.beat_timing, mask
# class beat_tokenizer_by_frame():
# def __init__(self, frame_resolution = 0.01, max_len = 10):
# def __call__(self, beat_lst):
# def timestamp2frame(,frame_resolution, max_len):
# def frame2timestamp(frame_resolution, man_len)
def l2_norm(a, b):
    """Return the Euclidean (L2) distance between ``a`` and ``b`` along the last axis."""
    difference = a - b
    return torch.linalg.norm(difference, ord=2, dim=-1)
def rounding(x):
    """Return ``x - sin(2*pi*x) / (2*pi)``, a smooth, differentiable rounding-like map.

    Integers (and half-integers) are fixed points, and the derivative
    ``1 - cos(2*pi*x)`` vanishes at integers. ``x`` must be a tensor because
    ``torch.sin`` is used.
    """
    two_pi = 2. * math.pi
    return x - torch.sin(two_pi * x) / two_pi
class Chord_Embedding(nn.Module):
    """Embed chord tokens (root, type, inversion) together with their onset times.

    - The root is embedded with ``FME``.
    - The chord type and inversion are one-hot encoded.
    - The onset time is embedded with ``PE.global_time_embedding``.

    The four parts are concatenated on the last axis and projected to
    ``d_model`` by a linear layer.

    Args:
        FME: Callable ``FME(tokens, device)`` whose output has ``d_model``
            channels (presumably a Fundamental_Music_Embedding).
        PE: Object with ``global_time_embedding(times, device)`` returning
            ``d_model`` channels.
        d_model: Width of the root/time embeddings and of the output.
        d_oh_type: Number of chord-type classes; type ids must be below it.
        d_oh_inv: Number of inversion classes; inversion ids must be below it.
    """

    def __init__(self, FME, PE, d_model = 256, d_oh_type = 12, d_oh_inv = 4):
        super().__init__()
        self.FME = FME
        self.PE = PE
        self.d_model = d_model
        self.d_oh_type = d_oh_type
        self.d_oh_inv = d_oh_inv
        self.chord_ffn = nn.Linear(d_oh_type + d_oh_inv + d_model + d_model, d_model)

    # Implemented as forward (not an overridden __call__) so nn.Module hooks run.
    def forward(self, chord_root, chord_type, chord_inv, chord_timing, device):
        """Return chord embeddings of width ``d_model`` computed on ``device``.

        The token/timing tensors are presumably (batch, num_chords); confirm
        with the caller.
        """
        chord_root_emb = self.FME(chord_root, device)
        # The one-hot parts must already be on `device`: torch.cat rejects
        # tensors that live on different devices.
        chord_type_emb = F.one_hot(chord_type.to(torch.int64), num_classes=self.d_oh_type).to(
            device=device, dtype=torch.float32
        )
        chord_inv_emb = F.one_hot(chord_inv.to(torch.int64), num_classes=self.d_oh_inv).to(
            device=device, dtype=torch.float32
        )
        chord_time_emb = self.PE.global_time_embedding(chord_timing, device)
        merged = torch.cat((chord_root_emb, chord_type_emb, chord_inv_emb, chord_time_emb), dim=-1)
        return self.chord_ffn(merged.to(device))
class Beat_Embedding(nn.Module):
    """Embed beat-type tokens together with their timestamps.

    The beat type is one-hot encoded and the timestamp is embedded with
    ``PE.global_time_embedding``. Both are concatenated and projected to
    ``d_model`` by a linear layer.

    Args:
        PE: Object with ``global_time_embedding(times, device)`` returning
            ``d_model`` channels.
        d_model: Width of the time embedding and of the output.
        d_oh_beat_type: Number of beat-type classes; beat ids must be below it
            (``F.one_hot`` raises otherwise).
    """

    def __init__(self, PE, d_model = 256, d_oh_beat_type = 4):
        super().__init__()
        self.PE = PE
        self.d_model = d_model
        self.d_oh_beat_type = d_oh_beat_type
        self.beat_ffn = nn.Linear(d_oh_beat_type + d_model, d_model)

    # Implemented as forward (not an overridden __call__) so nn.Module hooks run.
    def forward(self, beats, beats_timing, device):
        """Return beat embeddings of width ``d_model`` computed on ``device``."""
        beat_type_emb = F.one_hot(beats.to(torch.int64), num_classes=self.d_oh_beat_type).to(
            device=device, dtype=torch.float32
        )
        beat_time_emb = self.PE.global_time_embedding(beats_timing, device)
        merged_beat = torch.cat((beat_type_emb, beat_time_emb), dim=-1)
        return self.beat_ffn(merged_beat)
if __name__ == "__main__":
    # Smoke test: tokenize a hand-written chord/beat annotation, then embed it
    # with layers configured from a YAML file.
    import sys

    # An optional first command-line argument overrides the default config path.
    config_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/data/nicolas/TANGO/config/model_embedding_config.yaml"
    )
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)
    device = "cpu"

    # beat_tokenizer expects [timestamps, beat types] as two parallel lists, so
    # the annotation is used as-is (transposing it into pairs broke the call).
    beats = [[0.56, 1.1, 1.66, 2.24, 2.8, 3.36, 3.92, 4.48, 5.04, 5.6, 6.16, 6.74, 7.32, 7.9, 8.46, 9.0, 9.58], [3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0]]
    # (chord symbol, onset) pairs; chord_tokenizer takes them as two lists.
    chords = [["Gm", 0.464399092], ["Eb", 1.393197278], ["F", 3.157913832], ["Bb", 4.736870748], ["F7", 5.758548752], ["Gm", 6.501587301], ["Eb", 8.173424036], ["F7", 9.938140589]]
    chord_names = [name for name, _ in chords]
    chord_onsets = [onset for _, onset in chords]

    # Instance names must not shadow the tokenizer classes themselves.
    chord_tok = chord_tokenizer(seq_len_chord=30, if_pad=True)
    beat_tok = beat_tokenizer(seq_len_beat=17, if_pad=True)

    # TOKENIZE CHORDS AND BEATS AT DATALOADING PART
    chord_root, chord_type, chord_inv, chord_time, chord_mask = chord_tok(chord_names, chord_onsets)
    beat_type, beat_time, beat_mask = beat_tok(beats)

    def _batched(values, dtype=None):
        """Convert a token/timing list to a tensor with a leading batch dim of 1."""
        return torch.tensor(values, dtype=dtype)[None, ...]

    chord_root, chord_type, chord_inv, chord_mask = (
        _batched(v) for v in (chord_root, chord_type, chord_inv, chord_mask)
    )
    chord_time = _batched(chord_time, torch.float32)
    beat_type, beat_mask = _batched(beat_type), _batched(beat_mask)
    beat_time = _batched(beat_time, torch.float32)
    print("tokeninzing chords and beats", chord_root.shape, beat_type.shape)

    # EMBEDDING CHORDS AND BEATS WITHIN THE MODEL
    FME = Fundamental_Music_Embedding(**cfg["FME_embedding_conf"])
    PE = Music_PositionalEncoding(**cfg["Position_encoding_conf"])
    chord_embedding_layer = Chord_Embedding(FME, PE, **cfg["Chord_Embedding_conf"])
    chord_embedded = chord_embedding_layer(chord_root, chord_type, chord_inv, chord_time, device)
    beat_embedding_layer = Beat_Embedding(PE, **cfg["Beat_Embedding_conf"])
    beat_embedded = beat_embedding_layer(beat_type, beat_time, device)
    print("embedding tokenized chords and beats", chord_embedded.shape, beat_embedded.shape)