"""Demo script: ground truth, reconstruction and melody-controlled decoding.

Loads a pretrained VAE and a validation split of the Nottingham data, then
writes MIDI files (ground truth, reconstructions, controlled generations)
into WRITE_PATH.
"""
import sys
sys.path.append('code')  # lets the project modules below (model, utils, ...) be imported
import numpy as np
import os
import pretty_midi as pyd
import torch
from model import VAE
from util_tools.format_converter import melody_data2matrix, melody_matrix2data, chord_data2matrix, chord_matrix2data
from torch.distributions import kl_divergence, Normal
from nottingham_dataset import Nottingham
import copy
import utils


def chord_grid2data(est_pitch, bpm=60., start=0., max_simu_note=6, pitch_eos=129,
                    num_step=32, min_pitch=0):
    """Convert a chord grid of note indices into a piano ``pretty_midi.Instrument``.

    Args:
        est_pitch: integer array indexed as ``est_pitch[t, n, 0]`` (time step,
            note slot, singleton axis); no batch axis. If the note axis has
            ``max_simu_note`` entries, its first slot is dropped (presumably a
            start-of-sequence slot -- TODO confirm).
        bpm: tempo used to size a step; one step lasts ``0.25 * 60 / bpm`` s.
        start: time offset in seconds added to every note.
        max_simu_note: note slots per step, counting the dropped first slot.
        pitch_eos: index that ends the list of notes for a step.
        num_step: number of time steps read from ``est_pitch``.
        min_pitch: unused; kept for interface compatibility.

    Returns:
        ``pretty_midi.Instrument`` (Acoustic Grand Piano). Each note index ``n``
        becomes MIDI pitch ``n + 48`` and is held until the next step whose
        first slot is not ``pitch_eos`` (or the end of the grid).
    """
    est_pitch = est_pitch[:, :, 0]
    if est_pitch.shape[1] == max_simu_note:
        est_pitch = est_pitch[:, 1:]
    # 1.0 where a step lists at least one note (a new chord starts), else 0.0.
    harmonic_rhythm = 1. - (est_pitch[:, 0] == pitch_eos) * 1.
    step_sec = 0.25 * 60 / bpm

    notes = []
    for t in range(num_step):
        # Notes starting at step t sound until the next step that starts a chord.
        duration = 1
        for j in range(t + 1, num_step):
            if harmonic_rhythm[j] == 1:
                break
            duration += 1
        for n in range(max_simu_note - 1):
            note = est_pitch[t, n]
            if note == pitch_eos:
                break
            notes.append(pyd.Note(100, int(note + 12 * 4),
                                  start + t * step_sec,
                                  start + (t + duration) * step_sec))

    chord = pyd.Instrument(program=pyd.instrument_name_to_program('Acoustic Grand Piano'))
    chord.notes = notes
    return chord


def melody_matrix2data(melody_matrix, tempo=120, start_time=0.0, get_list=False):
    """Decode a melody matrix (one row per step, no batch axis) into MIDI notes.

    NOTE: this local definition shadows the ``melody_matrix2data`` imported from
    ``util_tools.format_converter``; every call in this script uses this one.

    Columns read: ``[:, :12]`` pitch class, ``[:, 15]`` hold, ``[:, 16]`` rest,
    ``[:, -10:]`` register one-hot. Each row lasts a quarter of a beat at
    ``tempo`` BPM.

    Args:
        melody_matrix: 2-D array, one row per step.
        tempo: tempo in BPM used to size a step.
        start_time: offset in seconds added to every note.
        get_list: if True, return the list of notes instead of an Instrument.

    Returns:
        ``list[pretty_midi.Note]`` if ``get_list`` else a piano Instrument.
        Pitch is ``pitch_class + 12 * register_index``. A note lasts until the
        next non-hold step (or the end of the matrix).
    """
    HOLD_PITCH = 12  # position of the hold column inside `chroma` below
    REST_PITCH = 13  # position of the rest column inside `chroma` below
    chroma = np.concatenate((melody_matrix[:, :12], melody_matrix[:, 15:17]), axis=-1)
    register = melody_matrix[:, -10:]
    tokens = np.argmax(chroma, axis=-1)
    step_sec = 60 / tempo / 4

    # Steps that begin something (a note or a rest), plus an end sentinel.
    boundaries = [i for i, tok in enumerate(tokens) if tok != HOLD_PITCH]
    boundaries.append(len(tokens))

    melody_notes = []
    for onset, nxt in zip(boundaries[:-1], boundaries[1:]):
        if tokens[onset] == REST_PITCH:
            continue
        pitch = int(tokens[onset] + 12 * np.argmax(register[onset]))
        melody_notes.append(pyd.Note(velocity=100, pitch=pitch,
                                     start=start_time + onset * step_sec,
                                     end=start_time + nxt * step_sec))
    if get_list:
        return melody_notes
    melody = pyd.Instrument(program=pyd.instrument_name_to_program('Acoustic Grand Piano'))
    melody.notes = melody_notes
    return melody


def get_gt(chord, melody):
    """Render one dataset sample (numpy arrays, no batch axis) as a two-track PrettyMIDI.

    Per the original notes, ``chord`` is (num_step, max_simu_note, 1) and
    ``melody`` is (steps, 28) -- TODO confirm. The melody is decoded at 120 BPM.
    Chord steps use bpm=30, i.e. 0.5 s per step (one beat at 120 BPM).
    """
    chord_recon = chord_grid2data(chord, 30, pitch_eos=13)
    melody_recon = melody_matrix2data(melody, 120)
    music = pyd.PrettyMIDI(initial_tempo=120)
    music.instruments.append(melody_recon)
    music.instruments.append(chord_recon)
    return music


def shift(original_melody, p_shift):
    """Transpose every onset row of a batch-of-one melody tensor by ``p_shift`` semitones.

    A row with any non-zero pitch-class entry (columns 0-11) is an onset. Its
    absolute pitch ``pitch_class + 12 * register`` (both read from that same
    row) is shifted. The row is then rewritten as a pitch-class one-hot plus a
    register one-hot; every other column in it is cleared. Non-onset rows are
    left untouched.

    Args:
        original_melody: tensor (1, T, D); only batch item 0 is used. Not modified.
        p_shift: semitone offset (may be negative).

    Returns:
        CPU float tensor of shape (1, T, D).

    Raises:
        ValueError: if a shifted pitch falls outside the register range.
    """
    NUM_OCTAVES = 10
    melody = original_melody.detach().cpu().numpy()[0].copy()
    # The register one-hot is the last NUM_OCTAVES columns, exactly as
    # melody_matrix2data reads it. Writing back must use the same offset (the
    # previous hard-coded 17 only matched that slice when D == 27).
    register_start = melody.shape[-1] - NUM_OCTAVES

    onsets = np.nonzero(melody[:, :12].any(axis=-1))[0]
    absolute = (np.argmax(melody[onsets, :12], axis=-1)
                + 12 * np.argmax(melody[onsets, register_start:], axis=-1)
                + p_shift)
    if absolute.size and (absolute.min() < 0 or absolute.max() >= 12 * NUM_OCTAVES):
        raise ValueError(f'shift by {p_shift} moves a note outside the '
                         f'{NUM_OCTAVES}-octave register range')

    melody[onsets, :] = 0
    melody[onsets, absolute % 12] = 1.
    melody[onsets, register_start + absolute // 12] = 1.
    return torch.from_numpy(melody).float().unsqueeze(0)


def _model_device():
    """Return the device holding the global ``model``'s parameters."""
    return next(model.parameters()).device


def _beat_summary(melody):
    """Embed ``melody`` (B, T, D) and sum every 4 consecutive steps.

    T must be a multiple of 4 so the four strided slices have equal length.
    """
    mel_ebd = model.enc_note_embedding(melody)
    return mel_ebd[:, ::4, :] + mel_ebd[:, 1::4, :] + mel_ebd[:, 2::4, :] + mel_ebd[:, 3::4, :]


def _encode_chord(chord, beat_summary):
    """Encode a chord index tensor given a melody beat summary; return the encoder distribution's mean."""
    lengths = model.get_len_index_tensor(chord)
    chord_emb = model.enc_note_embedding(model.index_tensor_to_multihot_tensor(chord))
    dist, _ = model.encoder(chord_emb, lengths, beat_summary)
    return dist.mean


def _decode_to_music(z, beat_summary, melody):
    """Decode chords from ``z`` under ``beat_summary`` and pair batch item 0 with ``melody``."""
    pitch_outs = model.decoder(z, beat_summary, inference=True, x=None, lengths=None,
                               teacher_forcing_ratio1=0., teacher_forcing_ratio2=0.)
    # Arg-max index per output, keeping a trailing singleton axis for chord_grid2data.
    pitch_outs = pitch_outs.max(-1, keepdim=True)[1].detach().cpu().numpy()
    chord_track = chord_grid2data(pitch_outs[0], bpm=120 // 4, start=0, pitch_eos=13)
    melody_track = melody_matrix2data(melody.detach().cpu().numpy()[0], tempo=120)
    music = pyd.PrettyMIDI()
    music.instruments.append(melody_track)
    music.instruments.append(chord_track)
    return music


def reconstruct(chord, melody):
    """Encode (chord, melody) and decode the chord again under the same melody.

    Args:
        chord: LongTensor chord grid with a batch axis (original notes:
            (1, num_step, max_simu_note, 1)).
        melody: FloatTensor (1, num_step * 4, 28) per the original notes.
        Both are moved to the model's device, so CPU tensors are accepted.

    Returns:
        PrettyMIDI with the input melody and the decoded chord track (batch item 0).
    """
    dev = _model_device()
    chord, melody = chord.to(dev), melody.to(dev)
    with torch.no_grad():
        summary = _beat_summary(melody)
        z = _encode_chord(chord, summary)
        return _decode_to_music(z, summary, melody)


def melody_control(chord, melody, new_melody):
    """Encode (chord, melody), then decode chords conditioned on ``new_melody`` instead.

    Inputs are moved to the model's device. Returns a PrettyMIDI with the
    ``new_melody`` track and the decoded chord track (batch item 0).
    """
    dev = _model_device()
    chord, melody, new_melody = chord.to(dev), melody.to(dev), new_melody.to(dev)
    with torch.no_grad():
        z = _encode_chord(chord, _beat_summary(melody))
        return _decode_to_music(z, _beat_summary(new_melody), new_melody)


def melody_prior_control(new_melody, z_size=None):
    """Decode chords for ``new_melody`` from a latent sampled from N(0, I).

    Args:
        new_melody: FloatTensor (B, T, D); only batch item 0 is rendered.
        z_size: latent size; defaults to ``model_params['z_size']``, the size
            the model was built with (previously hard-coded to 128).

    Returns:
        PrettyMIDI with the ``new_melody`` track and the decoded chord track.
    """
    if z_size is None:
        z_size = model_params['z_size']
    dev = _model_device()
    new_melody = new_melody.to(dev)
    batch = new_melody.shape[0]
    # Sample on the CPU as before, then move to the model's device.
    z = Normal(torch.zeros(batch, z_size), torch.ones(batch, z_size)).rsample().to(dev)
    with torch.no_grad():
        return _decode_to_music(z, _beat_summary(new_melody), new_melody)


# --- Model configuration and weights -------------------------------------
config_fn = './code/model_config.json'
model_params = utils.load_params_dict('model_params', config_fn)
data_repr_params = utils.load_params_dict('data_repr', config_fn)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(max_simu_note=data_repr_params['max_simu_note'],
            max_pitch=data_repr_params['max_pitch'],
            min_pitch=data_repr_params['min_pitch'],
            pitch_sos=data_repr_params['pitch_sos'],
            pitch_eos=data_repr_params['pitch_eos'],
            pitch_pad=data_repr_params['pitch_pad'],
            num_step=data_repr_params['num_time_step'],
            note_emb_size=model_params['note_emb_size'],
            enc_notes_hid_size=model_params['enc_notes_hid_size'],
            enc_time_hid_size=model_params['enc_time_hid_size'],
            z_size=model_params['z_size'],
            dec_emb_hid_size=model_params['dec_emb_hid_size'],
            dec_time_hid_size=model_params['dec_time_hid_size'],
            dec_notes_hid_size=model_params['dec_notes_hid_size'],
            discr_nhead=model_params["discr_nhead"],
            discr_hid_size=model_params["discr_hid_size"],
            discr_dropout=model_params["discr_dropout"],
            discr_nlayer=model_params["discr_nlayer"],
            device=device)

weight_path = './code/ad-ptvae_param.pt'
params = torch.load(weight_path, map_location=device)
if 'model_state_dict' in params:  # the checkpoint may wrap the weights under this key
    params = params['model_state_dict']
model.load_state_dict(params)
model.to(device)
model.eval()
print('-' * 100)
print(f'Loaded {weight_path}')
print('-' * 100)

# --- Validation data -----------------------------------------------------
data_path = './code/data.npy'
# allow_pickle unpickles arbitrary objects: only load trusted data files.
dataset = np.load(data_path, allow_pickle=True).T
print('-' * 100)
print(f'Loaded {data_path}')
print('-' * 100)

# Fixed-seed shuffle, then keep the last 5% as the validation split.
np.random.seed(0)
np.random.shuffle(dataset)
anchor = int(dataset.shape[0] * 0.95)
val_data = dataset[anchor:, :]
val_set = Nottingham(dataset=val_data.T, length=128, step_size=16, chord_fomat='pr',
                     shift_high=0, shift_low=0)
print(len(val_set))

WRITE_PATH = './code/demo_generate'
os.makedirs(WRITE_PATH, exist_ok=True)

# --- Sample 1: ground truth and reconstruction ---------------------------
chord_1, _, melody_1, _ = val_set[338]
music = get_gt(chord_1, melody_1)
music.write(os.path.join(WRITE_PATH, 'gt_1.mid'))
chord_1 = torch.from_numpy(chord_1).long().unsqueeze(0)
melody_1 = torch.from_numpy(melody_1).float().unsqueeze(0)
music = reconstruct(chord_1, melody_1)
music.write(os.path.join(WRITE_PATH, 'recon_1.mid'))
print(f'Saved to {WRITE_PATH}/recon_1.mid')

# --- Sample 2: ground truth (reconstruction follows) ---------------------
chord_2, _, melody_2, _ = val_set[2749]
music = get_gt(chord_2, melody_2)
music.write(os.path.join(WRITE_PATH, 'gt_2.mid'))
import itertools

# --- Finish sample 2's reconstruction ------------------------------------
chord_2 = torch.from_numpy(chord_2).long().unsqueeze(0)
melody_2 = torch.from_numpy(melody_2).float().unsqueeze(0)
music = reconstruct(chord_2, melody_2)
music.write(os.path.join(WRITE_PATH, 'recon_2.mid'))
print(f'Saved to {WRITE_PATH}/recon_2.mid')


def _gt_and_reconstruct(index, tag):
    """Write gt_<tag>.mid and recon_<tag>.mid for ``val_set[index]``.

    Returns the sample's chord and melody as batch-of-one tensors.
    """
    chord, _, melody, _ = val_set[index]
    get_gt(chord, melody).write(os.path.join(WRITE_PATH, f'gt_{tag}.mid'))
    chord = torch.from_numpy(chord).long().unsqueeze(0)
    melody = torch.from_numpy(melody).float().unsqueeze(0)
    reconstruct(chord, melody).write(os.path.join(WRITE_PATH, f'recon_{tag}.mid'))
    # Every reconstruction is logged, including sample 4 (previously silent).
    print(f'Saved to {WRITE_PATH}/recon_{tag}.mid')
    return chord, melody


chord_3, melody_3 = _gt_and_reconstruct(3413, 3)
chord_4, melody_4 = _gt_and_reconstruct(5126, 4)


def _melody_from_midi(path):
    """Read the first instrument of a MIDI file back as a truncated batch-of-one melody tensor."""
    midi = pyd.PrettyMIDI(path)
    melody = melody_data2matrix(midi.instruments[0], midi.get_downbeats())
    melody = val_set.truncate_melody(melody)
    return torch.from_numpy(melody).float().unsqueeze(0)


melody_1_modal_change = _melody_from_midi(f'{WRITE_PATH}/recon_1.mid')
melody_2_modal_change = _melody_from_midi(f'{WRITE_PATH}/recon_2.mid')
melody_3_modal_change = _melody_from_midi(f'{WRITE_PATH}/recon_3.mid')
melody_4_modal_change = _melody_from_midi(f'{WRITE_PATH}/recon_4.mid')

samples = {
    1: (chord_1, melody_1),
    2: (chord_2, melody_2),
    3: (chord_3, melody_3),
    4: (chord_4, melody_4),
}
modal_changes = {
    1: melody_1_modal_change,
    2: melody_2_modal_change,
    3: melody_3_modal_change,
    4: melody_4_modal_change,
}


def _save(music, name):
    """Write ``music`` to WRITE_PATH/name and log exactly that path."""
    music.write(os.path.join(WRITE_PATH, name))
    print(f'Saved to {WRITE_PATH}/{name}')


# Transpose each melody by 6 semitones and decode chords for it. Sample 4
# used to be written as control_4_transpose1.mid while the log named
# control_4_transpose.mid; name and log now agree.
for k, (chord_k, melody_k) in samples.items():
    music = melody_control(chord_k, melody_k, shift(melody_k, 6))
    _save(music, f'control_{k}_transpose.mid')

# Decode chords for the melodies read back from the recon_<k>.mid files.
for k, (chord_k, melody_k) in samples.items():
    music = melody_control(chord_k, melody_k, modal_changes[k])
    _save(music, f'control_{k}_modal_change.mid')

# Decode chords from a prior sample (random latent) for each melody.
for k, (_, melody_k) in samples.items():
    music = melody_prior_control(melody_k)
    _save(music, f'control_{k}_prior.mid')

# Swap pairs: chord latent from sample c, melody from sample m (both orders).
for a, b in itertools.combinations(samples, 2):
    for c, m in ((a, b), (b, a)):
        chord_c, melody_c = samples[c]
        music = melody_control(chord_c, melody_c, samples[m][1])
        _save(music, f'control_{c}c+{m}m.mid')