"""core_diff.py |
Differentiable metrics that can be used to train VAE |
Assuming piano roll is from [-1, 1], -1 is background |
Input size from data loader: 1x1x128xLENGTH |
""" |
import torch |
import numpy as np |
import matplotlib as mpl |
import matplotlib.pyplot as plt |
import torch.nn.functional as F |
from .piano_roll_to_chord import piano_roll_to_chords, piano_roll_to_chords_save_midi |
# Bucket edges for vertical note density. note_density_class feeds these seven
# edges to torch.bucketize, which yields class indices 0..7.
VERTICAL_ND_BOUNDS = [
    1.29,
    2.7578125,
    3.61,
    4.4921875,
    5.28125,
    6.1171875,
    7.22,
]
# NOTE(review): presumably one representative density value per vertical class
# (eight entries for the eight buckets); not used in the visible code - confirm.
VERTICAL_ND_CENTER = [
    0.56,
    2.0239,
    3.1839,
    4.0511,
    4.8867,
    5.6992,
    6.6686,
    7.77,
]
# Bucket edges for horizontal note density. note_density_class divides them by
# its horizontal_scale argument before bucketizing.
HORIZONTAL_ND_BOUNDS = [
    1.8,
    2.6,
    3.2,
    3.6,
    4.4,
    4.8,
    5.8,
]
# NOTE(review): presumably one representative density value per horizontal
# class; not used in the visible code - confirm.
HORIZONTAL_ND_CENTER = [
    1.4,
    2.2,
    2.9,
    3.4,
    4.0,
    4.6,
    5.3,
    6.3,
]
# Lowest and highest pitch index kept by piano_like, and the value written into
# every masked-out cell (presumably MIDI note numbers of the 88-key piano
# range, with -1 being the background value named in the module docstring).
MIN_PIANO, MAX_PIANO, OFF = 21, 108, -1


def piano_like(x):
    """Silence every pitch row outside [MIN_PIANO, MAX_PIANO], in place.

    x is indexed as ``x[:, :, pitch, :]``. Rows below MIN_PIANO and rows above
    MAX_PIANO are overwritten with OFF.

    Returns the very same tensor object that was passed in; the caller's data
    is modified.
    """
    below_range = slice(None, MIN_PIANO)
    above_range = slice(MAX_PIANO + 1, None)
    for pitch_rows in (below_range, above_range):
        x[:, :, pitch_rows, :] = OFF
    return x
def total_pitch_class_histogram(piano_roll):
    """Return the normalised 12-bin pitch-class histogram of a piano roll.

    Only the first channel of ``piano_roll`` (indexed B x C x pitch x time,
    values in [-1, 1]) is used. Pitches outside the range kept by piano_like
    are silenced, values are rescaled to [0, 1] and summed over time, and the
    per-pitch totals are folded into 12 pitch classes (pitch index modulo 12).

    Returns a tensor of shape (B, 12), or (12,) when B == 1. Each histogram
    sums to ~1; an all-silent roll yields zeros because of the 1e-12 guard.

    The input tensor is left untouched.
    """
    # Slicing gives a view of the caller's tensor and piano_like writes in
    # place, so mask a copy instead of the caller's data. clone() keeps the
    # autograd graph, and avoids the in-place-on-leaf-view error.
    roll = piano_like(piano_roll[:, :1, :, :].clone())
    roll = ((roll + 1) / 2.).squeeze(dim=1)  # [-1, 1] -> [0, 1]
    per_pitch = torch.sum(roll, dim=-1)  # B x pitch
    batch_size, n_pitch = per_pitch.shape
    # Pad the pitch axis up to a whole number of octaves (128 -> 132).
    n_pad = (-n_pitch) % 12
    zeros = torch.zeros(batch_size, n_pad, device=per_pitch.device)
    padded = torch.concat((per_pitch, zeros), dim=-1)
    # B x octave x pitch-class, then add the octaves together.
    histogram = padded.reshape(batch_size, -1, 12).sum(dim=1)
    histogram = histogram / (torch.sum(histogram, dim=-1, keepdim=True) + 1e-12)
    if histogram.shape[0] == 1:
        return histogram.squeeze(dim=0)
    return histogram
def note_density(piano_roll, interval=128, quantize_factor=1, horizontal_scale=5):
    """Return vertical and horizontal note density per time interval.

    Only the first channel of ``piano_roll`` (B x C x pitch x time, values in
    [-1, 1]) is used. The roll is masked by piano_like and binarised.

    * vertical density: number of active pitches per column, averaged over
      each interval.
    * horizontal density: number of columns in each interval that contain at
      least one note onset, divided by ``horizontal_scale``.

    interval: number of time columns per interval (before quantisation); the
        (quantised) length must be a multiple of it.
    quantize_factor: downsample the time axis by this factor (nearest
        neighbour) to tolerate slight timing mismatch.
    horizontal_scale: rescales horizontal density so it is on the same order
        of magnitude as vertical density.

    Returns a tensor whose last dimension holds all vertical densities
    followed by all horizontal densities: shape (B, 2 * n_intervals), or
    (2 * n_intervals,) when B == 1.

    The input tensor is left untouched.
    """
    piano_roll = piano_roll[:, :1, :, :]
    batch_size = piano_roll.shape[0]
    orig_size = piano_roll.shape[-1]
    if quantize_factor != 1:
        piano_roll = F.interpolate(piano_roll, size=(128, orig_size // quantize_factor), mode='nearest')
        interval = interval // quantize_factor
    else:
        # The slice is a view of the caller's tensor and the masking below
        # writes in place; copy so the caller's data is never modified (the
        # interpolated path above already produced a fresh tensor).
        piano_roll = piano_roll.clone()
    piano_roll = piano_like(piano_roll)
    piano_roll[piano_roll < -0.95] = -1.  # snap near-background values
    piano_roll = (piano_roll + 1) / 2.  # [-1, 1] -> [0, 1]
    # Binarise: anything audibly above background counts as a note.
    piano_roll[piano_roll >= 1e-2] = 1.
    piano_roll[piano_roll < 1e-2] = 0.

    vertical_nd_per_col = piano_roll.sum(dim=2)  # B x 1 x T

    # Zero-pad the time axis so a note active in the first column counts as
    # an onset, then keep only rising edges (0 -> 1).
    padded_roll = F.pad(piano_roll, (1, 1), 'constant')
    onsets = torch.diff(padded_roll)
    onsets[onsets < 0] = 0
    horizontal_nd_per_col = onsets.sum(dim=2)[:, :, :-1]  # B x 1 x T
    # Several simultaneous onsets in one column count once.
    horizontal_nd_per_col[horizontal_nd_per_col != 0.] = 1

    vertical_nd = vertical_nd_per_col.reshape(batch_size, 1, -1, interval).mean(dim=-1)
    horizontal_nd = horizontal_nd_per_col.reshape(batch_size, 1, -1, interval).sum(dim=-1) / horizontal_scale
    nd = torch.concat((vertical_nd, horizontal_nd), dim=-1)
    if batch_size == 1:
        return nd.squeeze()
    return nd.squeeze(dim=1)
def note_density_class(piano_roll, interval=128, quantize_factor=1, horizontal_scale=1):
    """Return note density quantised into class indices.

    Computes note_density with the given arguments, then buckets the vertical
    half of the result with VERTICAL_ND_BOUNDS and the horizontal half with
    HORIZONTAL_ND_BOUNDS / horizontal_scale (note_density divides horizontal
    density by the same factor, so the bounds are rescaled to match).

    Returns an integer tensor with the same shape as the note_density output:
    vertical classes followed by horizontal classes along the last dimension.
    """
    device = piano_roll.device
    vt_bounds = torch.tensor(VERTICAL_ND_BOUNDS, device=device)
    hr_bounds = torch.tensor(HORIZONTAL_ND_BOUNDS, device=device) / horizontal_scale
    orig_rule = note_density(piano_roll, interval=interval, quantize_factor=quantize_factor,
                             horizontal_scale=horizontal_scale)
    half = orig_rule.shape[-1] // 2
    # note_density returns a 1-D tensor for a batch of one, so index the last
    # dimension with an ellipsis; `[:, :half]` raises IndexError in that case.
    vt_nd_classes = torch.bucketize(orig_rule[..., :half], vt_bounds)
    hr_nd_classes = torch.bucketize(orig_rule[..., half:], hr_bounds)
    return torch.concat((vt_nd_classes, hr_nd_classes), dim=-1)
def get_chords(piano_roll_batch, given_key=None, fs=100, window_size=1.28, return_key=False):
    """Run chord (and optionally key) analysis on every roll in a batch.

    Only the first channel of ``piano_roll_batch`` (B x C x pitch x time,
    values in [-1, 1]) is used. Each roll is masked by piano_like, rescaled
    to [0, 127], converted to an integer NumPy array and handed to
    piano_roll_to_chords together with given_key, fs and window_size.

    Returns the per-roll ``"chords"`` results concatenated along a new leading
    batch dimension (dropped again when B == 1). When ``return_key`` is true,
    returns ``(chords, keys, key_correlations)`` where the two lists hold the
    ``"key"`` and ``"correlationCoefficient"`` entries for each roll.

    The input tensor is left untouched.
    """
    # The slice is a view of the caller's tensor and the masking below writes
    # in place, so work on a copy. detach() is needed because .numpy() fails
    # on tensors that require grad; this analysis is not differentiable anyway.
    piano_roll_batch = piano_roll_batch[:, :1, :, :].detach().clone()
    piano_roll_batch = piano_like(piano_roll_batch)
    piano_roll_batch[piano_roll_batch < -0.95] = -1.  # snap near-background values
    piano_roll_batch = (piano_roll_batch + 1) / 2 * 127  # [-1, 1] -> [0, 127]
    piano_roll_batch = torch.clamp(piano_roll_batch, min=0, max=127)

    chords_all = []
    keys_all = []
    key_corrs_all = []
    for roll in piano_roll_batch[:, 0]:
        roll_np = roll.cpu().numpy().astype(np.intc)
        out = piano_roll_to_chords(roll_np, given_key=given_key, fs=fs,
                                   window_size=window_size, return_key=return_key)
        # NOTE(review): out["chords"] is presumably a torch tensor (it is
        # unsqueezed and concatenated with torch.concat) - confirm.
        chords_all.append(out["chords"].unsqueeze(dim=0))
        if return_key:
            keys_all.append(out["key"])
            key_corrs_all.append(out["correlationCoefficient"])

    chords = torch.concat(chords_all, dim=0)
    if chords.shape[0] == 1:
        chords = chords.squeeze(dim=0)
    if return_key:
        return chords, keys_all, key_corrs_all
    return chords