import os
import math
import torch
import numpy as np
import pandas as pd
import pretty_midi
import matplotlib as mpl
import matplotlib.pyplot as plt
from . import dist_util
import yaml
from types import SimpleNamespace
from music_rule_guidance.piano_roll_to_chord import piano_roll_to_pretty_midi, KEY_DICT, IND2KEY
from music_rule_guidance.rule_maps import FUNC_DICT, LOSS_DICT
from music_rule_guidance.music_rules import MAX_PIANO, MIN_PIANO

plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

# bounds to compute classes for note-density (nd) editing
VERTICAL_ND_BOUNDS = [1.29, 2.7578125, 3.61, 4.4921875, 5.28125, 6.1171875, 7.22]
VERTICAL_ND_CENTER = [0.56, 2.0239, 3.1839, 4.0511, 4.8867, 5.6992, 6.6686, 7.77]
HORIZONTAL_ND_BOUNDS = [1.8, 2.6, 3.2, 3.6, 4.4, 4.8, 5.8]
HORIZONTAL_ND_CENTER = [1.4, 2.2000, 2.9, 3.4, 4.0, 4.6, 5.3, 6.3]


def dict_to_obj(d):
    if isinstance(d, list):
        d = [dict_to_obj(x) if isinstance(x, dict) else x for x in d]
    if not isinstance(d, dict):
        return d
    return SimpleNamespace(**{k: dict_to_obj(v) for k, v in d.items()})


def load_config(filename):
    with open(filename, 'r') as file:
        data = yaml.safe_load(file)
    # Convert the dictionary to an attribute-style object
    return dict_to_obj(data)


@torch.no_grad()
def decode_sample_for_midi(sample, embed_model=None, scale_factor=1., threshold=-0.95):
    # decode latent samples into a long piano roll with values in [0, 127]
    sample = sample / scale_factor
    if embed_model is not None:
        image_size_h = sample.shape[-2]
        image_size_w = sample.shape[-1]
        if image_size_h > image_size_w:
            # transposed for raster col, don't need to permute for pixel space
            sample = sample.permute(0, 1, 3, 2)  # vertical axis means pitch after transpose
        num_latents = sample.shape[-1] // sample.shape[-2]
        if image_size_h >= image_size_w:
            sample = torch.chunk(sample, num_latents, dim=-1)  # B x C x H x W
            sample = torch.concat(sample, dim=0)  # 1st second for all batch, 2nd second for all batch, ...
        sample = embed_model.decode(sample)
        if image_size_h >= image_size_w:
            sample = torch.concat(torch.chunk(sample, num_latents, dim=0), dim=-1)
    sample[sample <= threshold] = -1.  # heuristic thresholding of the background
    sample = ((sample + 1) * 63.5).clamp(0, 127).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous()
    return sample
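

# Illustrative sketch (not called anywhere): decoding a pixel-space batch with no
# latent decoder. The shapes below are assumptions chosen to match the comments
# in decode_sample_for_midi; with embed_model=None only the thresholding and
# [0, 127] rescaling are applied.
def _example_decode_pixel_space():
    sample = torch.rand(2, 1, 128, 1024) * 2 - 1  # B x C x 128 pitches x T, in [-1, 1]
    rolls = decode_sample_for_midi(sample, embed_model=None)
    # rolls: B x 128 x 1024 x 1 uint8 piano rolls with velocities in [0, 127]
    return rolls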


def save_piano_roll_midi(sample, save_dir, fs=100, y=None, save_piano_roll=False, save_ind=0):
    # input shape: B x 128 (pitch) x time (no pedal),
    # or B x C x 128 x time (with pedal; C == 3 adds an onset channel)
    fig_size = sample.shape[-1] // 128 * 3
    plt.rcParams["figure.figsize"] = (fig_size, 3)
    pedal = len(sample.shape) == 4
    onset = sample.shape[1] == 3
    for i in range(sample.shape[0]):
        cur_sample = sample[i]
        if cur_sample.shape[-1] < 5000 and save_piano_roll:
            # do not save piano rolls that are too long
            if pedal:
                plt.imshow(cur_sample[0, ::-1], vmin=0, vmax=127)
            else:
                plt.imshow(cur_sample[::-1], vmin=0, vmax=127)
            plt.savefig(os.path.join(save_dir, "prsample_" + str(i) + ".png"))
        if onset:
            # add onsets for the notes sounding in the first column
            first_column = cur_sample[0, :, 0]
            first_onset_pitch = first_column.nonzero()[0]
            cur_sample[1, first_onset_pitch, 0] = 127
        cur_sample = cur_sample.astype(np.float32)
        pm = piano_roll_to_pretty_midi(cur_sample, fs=fs)
        if y is not None:
            save_name = 'sample_' + str(i + save_ind) + '_y_' + str(y[i].item()) + '.midi'
        else:
            save_name = 'sample_' + str(i + save_ind) + '.midi'
        pm.write(os.path.join(save_dir, save_name))
    return


def eval_rule_loss(generated_samples, target_rules):
    results = {}
    batch_size = generated_samples.shape[0]
    for rule_name, rule_target in target_rules.items():
        rule_target_list = rule_target.tolist()
        if batch_size == 1:
            rule_target_list = [rule_target_list]
        results[rule_name + '.target_rule'] = rule_target_list
        rule_target = rule_target.to(generated_samples.device)
        if 'chord' in rule_name:
            gen_rule, key, corr = FUNC_DICT[rule_name](generated_samples, return_key=True)
            key_strings = [IND2KEY[key_ind] for key_ind in key]
            loss = LOSS_DICT[rule_name](gen_rule, rule_target)
            mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
            if batch_size == 1:
                gen_rule = [gen_rule]
            results[rule_name + '.gen_rule'] = gen_rule
            results[rule_name + '.key_str'] = key_strings
            results[rule_name + '.key_corr'] = corr
            results[rule_name + '.loss'] = loss
        else:
            gen_rule = FUNC_DICT[rule_name](generated_samples)
            loss = LOSS_DICT[rule_name](gen_rule, rule_target)
            mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
            if batch_size == 1:
                gen_rule = [gen_rule]
            results[rule_name + '.gen_rule'] = gen_rule
            results[rule_name + '.loss'] = loss
    return pd.DataFrame(results)


def compute_rule(generated_samples, orig_samples, target_rules):
    results = {}
    batch_size = generated_samples.shape[0]
    for rule_name in target_rules:
        rule_target = FUNC_DICT[rule_name](orig_samples)
        rule_target_list = rule_target.tolist()
        if batch_size == 1:
            rule_target_list = [rule_target_list]
        results[rule_name + '.target_rule'] = rule_target_list
        rule_target = rule_target.to(generated_samples.device)
        if rule_name == 'chord_progression':
            gen_rule, key, corr = FUNC_DICT[rule_name](generated_samples, return_key=True)
            key_strings = [IND2KEY[key_ind] for key_ind in key]
            loss = LOSS_DICT[rule_name](gen_rule, rule_target)
            mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
            if batch_size == 1:
                gen_rule = [gen_rule]
            results[rule_name + '.gen_rule'] = gen_rule
            results[rule_name + '.key_str'] = key_strings
            results[rule_name + '.key_corr'] = corr
            results[rule_name + '.loss'] = loss
        else:
            gen_rule = FUNC_DICT[rule_name](generated_samples)
            loss = LOSS_DICT[rule_name](gen_rule, rule_target)
            mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
            if batch_size == 1:
                gen_rule = [gen_rule]
            results[rule_name + '.gen_rule'] = gen_rule
            results[rule_name + '.loss'] = loss
    return pd.DataFrame(results)
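

# Illustrative sketch (hypothetical tensors, not called anywhere): comparing a
# generated batch against rules extracted from the source batch. 'chord_progression'
# is the only rule name this module references explicitly; any other key passed in
# must exist in both FUNC_DICT and LOSS_DICT.
def _example_compute_rule(generated_samples, orig_samples):
    df = compute_rule(generated_samples, orig_samples, target_rules=['chord_progression'])
    # per-sample losses plus the detected key of each generated sample
    print(df[['chord_progression.loss', 'chord_progression.key_str']])
    return df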


def visualize_piano_roll(piano_roll):
    """
    Visualize the first 256 and last 256 columns of a piano roll, with a gap in between.
    Assumes piano_roll has shape B x 1 x 128 x 1024 with values in [-1, 1], on GPU.
    """
    piano_roll = torch.flip(piano_roll, [2])
    piano_roll = (piano_roll + 1) / 2.
    vis_length = 256
    gap = 80
    plt.rcParams["figure.figsize"] = (12, 3)
    data = torch.zeros(128, vis_length * 2 + gap)
    data[:, :vis_length] = piano_roll[0, 0, :, :vis_length]
    data[:, -vis_length:] = piano_roll[0, 0, :, -vis_length:]
    data_clone = data.clone()
    # make the notes look thicker
    data[1:, :] = data[1:, :] + data_clone[:-1, :]
    data[2:, :] = data[2:, :] + data_clone[:-2, :]
    data = data.cpu().numpy()
    plt.imshow(data, cmap=mpl.colormaps['Blues'])
    ax = plt.gca()  # gca stands for 'get current axis'
    for edge, spine in ax.spines.items():
        spine.set_linewidth(2)  # Adjust the value as per your requirement
    plt.grid(color='gray', linestyle='-', linewidth=2., alpha=0.5, which='both', axis='x')
    plt.xticks(
        np.concatenate((np.arange(0, vis_length + 1, 128),
                        np.arange(vis_length + gap, vis_length * 2 + gap, 128))))
    # plt.savefig('piano_roll_example.png', bbox_inches='tight', pad_inches=0.1, dpi=300)
    plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
    plt.tight_layout()
    plt.show()

    plt.rcParams["figure.figsize"] = (3, 3)
    for i in range(2):
        plt.imshow(data[:, i * 128: (i + 1) * 128], cmap=mpl.colormaps['Blues'])
        ax = plt.gca()
        for edge, spine in ax.spines.items():
            spine.set_linewidth(2)
        plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
        plt.tight_layout()
        plt.show()
    for i in range(-2, 0):
        if (i + 1) * 128 < 0:
            plt.imshow(data[:, i * 128: (i + 1) * 128], cmap=mpl.colormaps['Blues'])
        else:
            plt.imshow(data[:, i * 128:], cmap=mpl.colormaps['Blues'])
        ax = plt.gca()
        for edge, spine in ax.spines.items():
            spine.set_linewidth(2)
        plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
        plt.tight_layout()
        plt.show()
    return


def visualize_full_piano_roll(midi_file_name, fs=100):
    """Visualize the full piano roll of a MIDI file."""
    midi_data = pretty_midi.PrettyMIDI(midi_file_name)
    # do not process sustain pedal
    piano_roll = torch.tensor(midi_data.get_piano_roll(fs=fs, pedal_threshold=None))
    data = torch.flip(piano_roll, [0])
    plt.rcParams["figure.figsize"] = (12, 3)
    # data_clone = data.clone()
    # # make the notes look thicker
    # data[1:, :] = data[1:, :] + data_clone[:-1, :]
    # data[2:, :] = data[2:, :] + data_clone[:-2, :]
    data = data.cpu().numpy()
    plt.imshow(data, cmap=mpl.colormaps['Blues'])
    ax = plt.gca()  # gca stands for 'get current axis'
    for edge, spine in ax.spines.items():
        spine.set_linewidth(2)  # Adjust the value as per your requirement
    plt.grid(color='gray', linestyle='-', linewidth=2., alpha=0.5, which='both', axis='x')
    plt.xticks(np.arange(0, piano_roll.shape[1], 128))
    # plt.savefig('piano_roll_example.png', bbox_inches='tight', pad_inches=0.1, dpi=300)
    plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
    plt.tight_layout()
    plt.show()
    return


def plot_record(vals, title, save_dir):
    # vals is a list of (timestep, value) pairs; the time axis is inverted
    ts = [item[0] for item in vals]
    log_probs = [item[1] for item in vals]
    plt.plot(ts, log_probs)
    plt.gca().invert_xaxis()
    plt.title(title)
    plt.savefig(os.path.join(save_dir, title + '.png'))
    plt.show()
    return
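

# Illustrative sketch: plot_record expects (timestep, value) pairs, e.g. a guidance
# value traced over a reverse-diffusion run, plotted with time running from t = T
# down to t = 0. The trace below is synthetic and the save directory is hypothetical.
def _example_plot_record():
    vals = [(t, -0.01 * (1000 - t)) for t in range(1000, 0, -100)]  # synthetic trace
    plot_record(vals, title='guidance_record', save_dir='.')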


def quantize_pedal(value, num_bins=8):
    """Quantize an integer value from 0 to 127 into num_bins bins and return the center value of the bin."""
    if value < 0 or value > 127:
        raise ValueError("Value should be between 0 and 127")
    # Determine bin size
    bin_size = 128 // num_bins  # 16 for the default of 8 bins
    # Quantize the value
    bin_index = value // bin_size
    bin_center = bin_size * bin_index + bin_size // 2
    # Handle edge case for the last bin
    if bin_center > 127:
        bin_center = 127
    return bin_center


CC_SUSTAIN_PEDAL = 64  # MIDI control-change number for the sustain pedal


def get_full_piano_roll(midi_data, fs, show=False):
    # do not process sustain pedal when extracting notes; note that onset=True
    # assumes a patched pretty_midi whose get_piano_roll also returns an onset roll
    piano_roll, onset_roll = midi_data.get_piano_roll(fs=fs, pedal_threshold=None, onset=True)
    # save the pedal roll explicitly
    pedal_roll = np.zeros_like(piano_roll)
    # process pedal
    for instru in midi_data.instruments:
        pedal_changes = [_e for _e in instru.control_changes if _e.number == CC_SUSTAIN_PEDAL]
        for cc in pedal_changes:
            time_now = int(cc.time * fs)
            if time_now < pedal_roll.shape[-1]:
                # need to distinguish a control-change value of 0 from the background 0;
                # after quantization, values 0-15 map to 8.
                # in MuseScore files a 0 is immediately followed by a 127, so shift by one column
                if pedal_roll[MIN_PIANO, time_now] != 0. and abs(pedal_roll[MIN_PIANO, time_now] - cc.value) > 64:
                    # use a shift of 2 here to prevent missing the change when using interpolation augmentation
                    pedal_roll[MIN_PIANO:MAX_PIANO + 1, min(time_now + 2, pedal_roll.shape[-1] - 1)] = \
                        quantize_pedal(cc.value)
                else:
                    pedal_roll[MIN_PIANO:MAX_PIANO + 1, time_now] = quantize_pedal(cc.value)
    full_roll = np.concatenate((piano_roll[None], onset_roll[None], pedal_roll[None]), axis=0)
    if show:
        plt.imshow(piano_roll[::-1, :1024], vmin=0, vmax=127)
        plt.show()
        plt.imshow(pedal_roll[::-1, :1024], vmin=0, vmax=127)
        plt.show()
    return full_roll
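

# Illustrative sketch: building the 3-channel (notes / onsets / pedal) roll from a
# MIDI file. 'example.mid' is a hypothetical path, and the call assumes the patched
# pretty_midi noted in get_full_piano_roll above.
def _example_full_roll():
    midi_data = pretty_midi.PrettyMIDI('example.mid')  # hypothetical path
    # channels: 0 = note velocities, 1 = onsets, 2 = quantized sustain pedal
    return get_full_piano_roll(midi_data, fs=100, show=False)  # 3 x 128 x T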