# Utilities for chord/scale-aware piano-roll editing and piano-roll -> MIDI export (pretty_midi).
import torch
import numpy as np
import pretty_midi as pm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 12-dimensional chroma vectors (index 0 = C) marking the pitch classes of each triad.
CHORD_DICTIONARY = {
    "C:major":  np.array([1,0,0,0,1,0,0,1,0,0,0,0]),
    "C#:major": np.array([0,1,0,0,0,1,0,0,1,0,0,0]),
    "D:major":  np.array([0,0,1,0,0,0,1,0,0,1,0,0]),
    "Eb:major": np.array([0,0,0,1,0,0,0,1,0,0,1,0]),
    "E:major":  np.array([0,0,0,0,1,0,0,0,1,0,0,1]),
    "F:major":  np.array([1,0,0,0,0,1,0,0,0,1,0,0]),
    "F#:major": np.array([0,1,0,0,0,0,1,0,0,0,1,0]),
    "G:major":  np.array([0,0,1,0,0,0,0,1,0,0,0,1]),
    "Ab:major": np.array([1,0,0,1,0,0,0,0,1,0,0,0]),
    "A:major":  np.array([0,1,0,0,1,0,0,0,0,1,0,0]),
    "Bb:major": np.array([0,0,1,0,0,1,0,0,0,0,1,0]),
    "B:major":  np.array([0,0,0,1,0,0,1,0,0,0,0,1]),

    "c:minor":  np.array([1,0,0,1,0,0,0,1,0,0,0,0]),
    "c#:minor": np.array([0,1,0,0,1,0,0,0,1,0,0,0]),
    "d:minor":  np.array([0,0,1,0,0,1,0,0,0,1,0,0]),
    "eb:minor": np.array([0,0,0,1,0,0,1,0,0,0,1,0]),
    "e:minor":  np.array([0,0,0,0,1,0,0,1,0,0,0,1]),
    "f:minor":  np.array([1,0,0,0,0,1,0,0,1,0,0,0]),
    "f#:minor": np.array([0,1,0,0,0,0,1,0,0,1,0,0]),
    "g:minor":  np.array([0,0,1,0,0,0,0,1,0,0,1,0]),
    "g#:minor": np.array([0,0,0,1,0,0,0,0,1,0,0,1]),
    "a:minor":  np.array([1,0,0,0,1,0,0,0,0,1,0,0]),
    "bb:minor": np.array([0,1,0,0,0,1,0,0,0,0,1,0]),
    "b:minor":  np.array([0,0,1,0,0,0,1,0,0,0,0,1]),
}
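
# Illustrative sketch (not used by the rest of this module): every chroma above is a
# rotation of its C-rooted counterpart, so the dictionary can be sanity-checked with
# np.roll. The helper name `_example_chord_dictionary_rotations` is ours.
def _example_chord_dictionary_rotations():
    assert np.array_equal(np.roll(CHORD_DICTIONARY["C:major"], 2), CHORD_DICTIONARY["D:major"])
    assert np.array_equal(np.roll(CHORD_DICTIONARY["a:minor"], 2), CHORD_DICTIONARY["b:minor"])
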
def edit_rhythm(piano_roll_full, num_notes_onset, mask_full, reduce_extra_notes=True):
    '''
    piano_roll_full: tensor of shape (batch_size, 2, length, h); length=64 is the roll length,
        h is the number of possible pitches, channel 0 is onset and channel 1 is sustain.
    num_notes_onset: tensor of shape (batch_size, length), the desired number of onsets per time step.
    mask_full: tensor with the same shape as piano_roll_full; the 7-note (scale) chroma mask.
    reduce_extra_notes: set True to also suppress extra onsets beyond the desired count.
    '''
    onset_roll = piano_roll_full[:, 0, :, :]
    mask = mask_full[:, 0, :, :]
    shape = onset_roll.shape

    # Flatten (batch, length) into rows so every time step is handled independently.
    onset_roll = onset_roll.reshape(-1, shape[-1])
    mask = mask.reshape(-1, shape[-1])
    num_notes = num_notes_onset.reshape(-1)

    # Values are pushed just below / above 0.5 so that a later binarization at 0.5
    # removes or keeps them as intended.
    reduce_note_threshold = 0.499
    increase_note_threshold = 0.501

    final_onset_roll = onset_roll.clone()

    if reduce_extra_notes:
        threshold_mask = onset_roll > reduce_note_threshold

        # Candidate onsets: in-scale pitches currently above the reduce threshold.
        values_above_threshold = torch.where(threshold_mask & (mask == 1), onset_roll,
                                             torch.tensor(-float('inf')).to(onset_roll.device))

        # Keep only the strongest num_notes onsets per time step.
        num_notes_max = int(num_notes.max().item())
        topk_values, topk_indices = torch.topk(values_above_threshold, num_notes_max, dim=1)

        col_indices = torch.arange(num_notes_max, device=onset_roll.device).expand(len(onset_roll), num_notes_max)
        topk_mask = (col_indices < num_notes.unsqueeze(1)) & (topk_values > -float("inf"))

        # First demote every in-scale onset to just below the binarization threshold...
        final_onset_roll[threshold_mask & (mask == 1)] = reduce_note_threshold

        # ...then restore the original values of the top-k onsets that should survive.
        flat_row_indices = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_indices)
        flat_row_indices = flat_row_indices[topk_mask]

        valid_topk_indices = topk_indices[topk_mask]
        valid_topk_values = topk_values[topk_mask]

        final_onset_roll = final_onset_roll.index_put_((flat_row_indices, valid_topk_indices), valid_topk_values)

    # Only add new onsets below MIDI pitch 84 (index 51 when the roll starts at pitch 33).
    pitch_less_84_mask = torch.ones_like(mask)
    pitch_less_84_mask[:, 51:] = 0

    # Count how many onsets per time step are already above the increase threshold.
    threshold_mask_2 = (final_onset_roll >= increase_note_threshold) & (mask == 1)
    greater_than_threshold2_count = threshold_mask_2.sum(dim=1)

    # How many additional onsets each time step still needs.
    remaining_needed = num_notes - greater_than_threshold2_count
    remaining_needed_max = int(remaining_needed.max().item())
    if remaining_needed_max >= 0:
        # Among in-scale, in-range pitches still below the threshold, promote the
        # strongest candidates until the desired count is reached.
        values_below_threshold2 = torch.where(
            (final_onset_roll < increase_note_threshold) & (mask == 1) & (pitch_less_84_mask == 1),
            final_onset_roll, torch.tensor(-float('inf')).to(onset_roll.device))
        topk_below_threshold2_values, topk_below_threshold2_indices = torch.topk(
            values_below_threshold2, remaining_needed_max, dim=1)

        col_indices_below_threshold2 = torch.arange(remaining_needed_max, device=onset_roll.device).expand(len(onset_roll), remaining_needed_max)
        adjust_mask = (col_indices_below_threshold2 < remaining_needed.unsqueeze(1)) & (topk_below_threshold2_values > -float("inf"))

        flat_row_indices_below_threshold2 = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1).expand_as(topk_below_threshold2_indices)
        flat_row_indices_below_threshold2 = flat_row_indices_below_threshold2[adjust_mask]

        valid_below_threshold2_indices = topk_below_threshold2_indices[adjust_mask]

        final_onset_roll = final_onset_roll.index_put_(
            (flat_row_indices_below_threshold2, valid_below_threshold2_indices),
            torch.tensor(increase_note_threshold, device=onset_roll.device))

    final_onset_roll = final_onset_roll.reshape(shape)
    piano_roll_full[:, 0, :, :] = final_onset_roll
    return piano_roll_full
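
# Illustrative sketch of calling edit_rhythm on random tensors (shapes only; the data,
# the all-ones scale mask, and the target of 3 onsets per step are made up for the example).
def _example_edit_rhythm():
    batch, length, h = 2, 64, 64
    roll = torch.rand(batch, 2, length, h, device=device)
    scale_mask = torch.ones(batch, 2, length, h, device=device)   # allow every pitch
    num_onsets = torch.full((batch, length), 3, device=device)    # target 3 onsets per step
    edited = edit_rhythm(roll, num_onsets, scale_mask, reduce_extra_notes=True)
    return edited.shape  # same shape as the input roll
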
def X0EditFunc(x0, background_condition, sampler_device=device, reduce_extra_notes=True, rhythm_control="Yes"):
    # For each key template, row 0 is the triad chroma and row 1 is the corresponding
    # 7-note (diatonic scale) chroma; both are tiled and later cropped to the 64-pitch roll width.
    maj_chd = torch.tensor([[1., 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0],
                            [1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1]], device=sampler_device)
    maj_chd = torch.tile(maj_chd, (1, 64 // maj_chd.size(1) + 1))
    min_chd = torch.tensor([[1., 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
                            [1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1]], device=sampler_device)
    min_chd = torch.tile(min_chd, (1, 64 // min_chd.size(1) + 1))

    # All 12 transpositions of the major and minor templates, cropped to 64 columns.
    maj_chd_rotations = torch.stack([torch.roll(maj_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:, :, :64]
    min_chd_rotations = torch.stack([torch.roll(min_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:, :, :64]

    chd_scale_map = torch.concat([maj_chd_rotations, min_chd_rotations], axis=0)

    # The chord condition may be stored in a flipped encoding (x -> -x - 1); undo it if present.
    if background_condition[:, :2, :, :].min() < 0:
        correct_chord_condition = -background_condition[:, :2, :, :] - 1
    else:
        correct_chord_condition = background_condition[:, :2, :, :]
    merged_chd_roll = torch.max(correct_chord_condition[:, 0, :, :], correct_chord_condition[:, 1, :, :])
    chd_chroma_ours = torch.clamp(merged_chd_roll, min=0.0, max=1.0)
    shape = chd_chroma_ours.shape
    chd_chroma_ours = chd_chroma_ours.reshape(-1, 64)
    # A template matches when it contains every active pitch of the conditioned chord;
    # the scale chromas of all matched templates are summed per time step.
    matches = (chd_scale_map[:, 0, :].unsqueeze(0) - chd_chroma_ours.unsqueeze(1) >= 0).all(dim=-1)
    seven_notes_chroma_ours = torch.einsum('ij,jk->ik', matches.float(), chd_scale_map[:, 1, :]).reshape(shape)
    seven_notes_chroma_ours = seven_notes_chroma_ours.unsqueeze(1).repeat((1, 2, 1, 1))

    # If no template matched, leave that time step unconstrained.
    no_chd_match = torch.all(seven_notes_chroma_ours == 0, dim=-1)
    seven_notes_chroma_ours[no_chd_match] = 1.

    # Zero out generated activations that fall outside the allowed scale.
    x0 = torch.where((seven_notes_chroma_ours == 0) & (x0 > 0), 0.0, x0)

    # Optionally also control the number of onsets per step, derived from the condition's first channel.
    if (background_condition[:, :2, :, :].min() >= 0) and (rhythm_control == "Yes"):
        num_onset_notes, _ = torch.max(background_condition[:, 0, :, :], axis=-1)
        x0 = edit_rhythm(x0, num_onset_notes, seven_notes_chroma_ours, reduce_extra_notes)

    return x0
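
# Illustrative sketch of calling X0EditFunc. x0 stands in for a denoised piano-roll
# estimate and background_condition for the chord condition; both are random/toy
# values here, and the helper name is ours.
def _example_x0_edit():
    batch, length, h = 2, 64, 64
    x0 = torch.rand(batch, 2, length, h, device=device)
    background_condition = torch.zeros(batch, 2, length, h, device=device)
    background_condition[:, :, :, [0, 4, 7]] = 1.0  # toy chord chroma on every step
    edited = X0EditFunc(x0, background_condition, sampler_device=device, rhythm_control="No")
    return edited.shape
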
def expand_roll(roll, unit=4, contain_onset=False):
    # roll: (n_channel, length, height) numpy array; stretch the time axis by `unit`.
    n_channel, length, height = roll.shape

    expanded_roll = roll.repeat(unit, axis=1)
    if contain_onset:
        # For onset channels (even indices), keep the onset only on the first of the
        # `unit` repeated sub-steps and merge the rest into the sustain channels (odd indices).
        expanded_roll = expanded_roll.reshape((n_channel, length, unit, height))
        expanded_roll[1::2, :, 1:] = np.maximum(expanded_roll[::2, :, 1:], expanded_roll[1::2, :, 1:])

        expanded_roll[::2, :, 1:] = 0
        expanded_roll = expanded_roll.reshape((n_channel, length * unit, height))
    return expanded_roll
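
# Illustrative sketch: expanding a 2-channel (onset, sustain) roll from 8 to 32 steps.
# The toy roll below contains a single onset.
def _example_expand_roll():
    roll = np.zeros((2, 8, 12))
    roll[0, 0, 5] = 1  # one onset at step 0, pitch index 5
    expanded = expand_roll(roll, unit=4, contain_onset=True)
    return expanded.shape  # (2, 32, 12); the onset survives only on the first sub-step
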
def cut_piano_roll(piano_roll, resolution=16, lowest=33, highest=96):
    # Keep only MIDI pitches lowest..highest (inclusive); `resolution` is unused here.
    piano_roll_cut = piano_roll[:, :, lowest:highest + 1]
    return piano_roll_cut
def circular_extend(chd_roll, lowest=33, highest=96):
    # Place a 12-dim chord chroma roll into the cut piano-roll layout (pitches lowest..highest),
    # duplicating it in both the C3 and C4 octaves.
    C4 = 60 - lowest
    C3 = C4 - 12
    shape = chd_roll.shape
    ext_chd = np.zeros((shape[0], shape[1], highest + 1 - lowest))
    ext_chd[:, :, C4:C4 + 12] = chd_roll
    ext_chd[:, :, C3:C3 + 12] = chd_roll
    return ext_chd
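
# Illustrative sketch: lifting a 12-dim chroma roll (e.g. built from CHORD_DICTIONARY)
# into the 64-pitch cut-roll layout used elsewhere in this module.
def _example_circular_extend():
    chroma = np.tile(CHORD_DICTIONARY["C:major"], (2, 16, 1))  # (2 channels, 16 steps, 12)
    return circular_extend(chroma).shape  # (2, 16, 64)
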
def default_quantization(v):
    return 1 if v > 0.5 else 0


def extend_piano_roll(piano_roll: np.ndarray, lowest=33, highest=96):
    # Inverse of cut_piano_roll: pad the pitch axis back to the full 128 MIDI pitches.
    padded_roll = np.pad(piano_roll, ((0, 0), (0, 0), (lowest, 127 - highest)), mode='constant', constant_values=0)
    return padded_roll
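
# Illustrative sketch: cut_piano_roll and extend_piano_roll undo each other shape-wise.
def _example_cut_extend_roundtrip():
    full = np.random.rand(2, 16, 128)
    cut = cut_piano_roll(full)         # (2, 16, 64), pitches 33..96
    restored = extend_piano_roll(cut)  # (2, 16, 128), zeros outside pitches 33..96
    return restored.shape
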
def piano_roll_to_note_mat(piano_roll: np.ndarray, quantization_func=None):
    """
    piano_roll: (2, L, 128) array with an onset and a sustain channel.
    quantization_func: maps a roll value to 0/1; defaults to thresholding at 0.5.
    Returns a list of [onset_step, pitch, duration_in_steps] entries.
    """
    def convert_p(p_, note_list):
        edit_note_flag = False
        for t in range(n_step):
            onset_state = quantization_func(piano_roll[0, t, p_])
            sustain_state = quantization_func(piano_roll[1, t, p_])

            is_onset = bool(onset_state)
            is_sustain = bool(sustain_state) and not is_onset

            pitch = p_

            if is_onset:
                # Start a new note of (initial) length 1.
                edit_note_flag = True
                note_list.append([t, pitch, 1])
            elif is_sustain:
                # Extend the most recent note if one is open at this pitch.
                if edit_note_flag:
                    note_list[-1][-1] += 1
            else:
                edit_note_flag = False
        return note_list

    quantization_func = default_quantization if quantization_func is None else quantization_func
    assert len(piano_roll.shape) == 3 and piano_roll.shape[0] == 2 and piano_roll.shape[2] == 128, f"{piano_roll.shape}"

    n_step = piano_roll.shape[1]

    notes = []
    for p in range(128):
        convert_p(p, notes)

    return notes
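
# Illustrative sketch: a note with an onset at step 0 followed by two sustain steps
# becomes a single [onset, pitch, duration] entry.
def _example_piano_roll_to_note_mat():
    roll = np.zeros((2, 8, 128))
    roll[0, 0, 60] = 1.0    # onset of middle C
    roll[1, 1:3, 60] = 1.0  # sustained for two more steps
    return piano_roll_to_note_mat(roll)  # [[0, 60, 3]]
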
def note_mat_to_notes(note_mat, bpm, unit=1/4, shift_beat=0., shift_sec=0., vel=100):
    """Convert [onset, pitch, duration] rows to pretty_midi Notes.
    If shift_beat is given it takes precedence over shift_sec."""
    beat_alpha = 60 / bpm           # seconds per beat
    step_alpha = unit * beat_alpha  # seconds per roll step

    notes = []

    shift_sec = shift_sec if shift_beat is None else shift_beat * beat_alpha

    for note in note_mat:
        onset, pitch, dur = note
        start = onset * step_alpha + shift_sec
        end = (onset + dur) * step_alpha + shift_sec

        notes.append(pm.Note(vel, int(pitch), start, end))

    return notes
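
# Illustrative sketch of the timing math: at 80 bpm with unit=1/4, one roll step lasts
# 60/80 * 1/4 = 0.1875 s, so a 4-step note starting at step 0 spans 0.0 s to 0.75 s.
def _example_note_mat_to_notes():
    notes = note_mat_to_notes([[0, 60, 4]], bpm=80)
    return notes[0].start, notes[0].end  # (0.0, 0.75)
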
def create_pm_object(bpm, piano_notes_list, chd_notes_list, lsh_notes_list=None):
    midi = pm.PrettyMIDI(initial_tempo=bpm)

    piano_program = pm.instrument_name_to_program('Acoustic Grand Piano')
    piano = pm.Instrument(program=piano_program)
    piano.notes += piano_notes_list
    midi.instruments.append(piano)

    # NOTE: chd_notes_list is accepted for API symmetry, but no chord track is
    # currently written to the MIDI object.

    if lsh_notes_list is not None:
        lsh_program = pm.instrument_name_to_program('Acoustic Grand Piano')
        lsh = pm.Instrument(program=lsh_program)
        lsh.notes += lsh_notes_list
        midi.instruments.append(lsh)

    return midi
def piano_roll_to_midi(piano_roll: np.ndarray, chd_roll: np.ndarray, lsh_roll=None, bpm=80):
    piano_mat = piano_roll_to_note_mat(piano_roll)
    piano_notes = note_mat_to_notes(piano_mat, bpm)

    chd_mat = piano_roll_to_note_mat(chd_roll)
    chd_notes = note_mat_to_notes(chd_mat, bpm)

    if lsh_roll is not None:
        lsh_mat = piano_roll_to_note_mat(lsh_roll)
        lsh_notes = note_mat_to_notes(lsh_mat, bpm)
    else:
        lsh_notes = None

    piano_pm = create_pm_object(bpm=bpm, piano_notes_list=piano_notes,
                                chd_notes_list=chd_notes, lsh_notes_list=lsh_notes)
    return piano_pm
def save_midi(midi, filename):
    # Write a pretty_midi.PrettyMIDI object to disk.
    midi.write(filename)
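
# Illustrative end-to-end sketch: turn a random piano roll plus a C-major chord roll into
# a MIDI file. The output path and all values here are arbitrary examples.
def _example_piano_roll_to_midi(path="example_output.mid"):
    piano = (np.random.rand(2, 64, 128) > 0.95).astype(float)
    chroma = np.tile(CHORD_DICTIONARY["C:major"], (2, 64, 1))
    chord = extend_piano_roll(circular_extend(chroma))  # 12-dim chroma -> 128-pitch roll
    midi = piano_roll_to_midi(piano, chord, bpm=90)
    save_midi(midi, path)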