haoyuliu00's picture
Initial commit with cleaned history
bf8981a
import torch
import numpy as np
import pretty_midi as pm
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 12-bin pitch-class templates of the two triad qualities, with C at index 0.
_MAJOR_TRIAD = np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0])
_MINOR_TRIAD = np.array([1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0])

# Root spellings in ascending semitone order; position i means "i semitones above C".
_MAJOR_ROOTS = ("C", "C#", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B")
_MINOR_ROOTS = ("c", "c#", "d", "eb", "e", "f", "f#", "g", "g#", "a", "bb", "b")

# Maps "<root>:<quality>" to the 12-bin chroma vector of that triad.
# Each entry is the quality template rotated upwards by the root's semitone offset.
CHORD_DICTIONARY = {
    **{
        f"{root}:major": np.roll(_MAJOR_TRIAD, semitones)
        for semitones, root in enumerate(_MAJOR_ROOTS)
    },
    **{
        f"{root}:minor": np.roll(_MINOR_TRIAD, semitones)
        for semitones, root in enumerate(_MINOR_ROOTS)
    },
}
def edit_rhythm(piano_roll_full, num_notes_onset, mask_full, reduce_extra_notes=True):
    """Push the number of active onsets at every time step towards a target count.

    Only the onset channel (channel 0) is edited, and it is written back into
    ``piano_roll_full`` in place. A pitch counts as active when its value is at
    or above ``0.501``; removed notes are set to ``0.499`` and added notes to
    ``0.501``, i.e. just below / just above 0.5.

    Args:
        piano_roll_full: tensor of shape (batch_size, 2, length, h); channel 0
            is the onset roll, ``h`` is the number of pitches.
        num_notes_onset: tensor of shape (batch_size, length) holding the
            wanted number of onsets for every time step.
        mask_full: tensor shaped like ``piano_roll_full``; only pitches whose
            channel-0 mask value equals 1 may be edited (per the original
            notes this is the 7-note scale chroma).
        reduce_extra_notes: when True, steps with more active onsets than
            requested keep only their strongest ones.

    Returns:
        ``piano_roll_full`` itself, with channel 0 modified.
    """
    reduce_note_threshold = 0.499
    increase_note_threshold = 0.501
    # New notes are only ever added on pitch indices below this bound
    # (index 51 is MIDI pitch 84 if the roll starts at pitch 33 -- TODO confirm).
    max_added_pitch_index = 51

    onset_view = piano_roll_full[:, 0, :, :]
    if onset_view.numel() == 0:
        # Nothing to edit; also avoids calling .max() on an empty tensor.
        return piano_roll_full

    shape = onset_view.shape
    n_pitch = shape[-1]
    onset_roll = onset_view.reshape(-1, n_pitch)
    allowed = mask_full[:, 0, :, :].reshape(-1, n_pitch) == 1
    num_notes = num_notes_onset.reshape(-1)

    dev = onset_roll.device
    # Same dtype/device as the roll so torch.where never has to promote.
    neg_inf = onset_roll.new_tensor(float("-inf"))
    row_ids = torch.arange(onset_roll.size(0), device=dev).unsqueeze(1)

    final_onset_roll = onset_roll.clone()

    # ---- too many notes: keep the strongest num_notes[i] of each row ----
    if reduce_extra_notes:
        active = (onset_roll > reduce_note_threshold) & allowed
        # Switch every active note off first, then restore the winners below.
        final_onset_roll[active] = reduce_note_threshold

        # torch.topk needs 0 <= k <= row width.
        k_keep = min(max(int(num_notes.max().item()), 0), n_pitch)
        if k_keep > 0:
            candidates = torch.where(active, onset_roll, neg_inf)
            topk_values, topk_indices = torch.topk(candidates, k_keep, dim=1)
            rank = torch.arange(k_keep, device=dev).unsqueeze(0)
            # Row i keeps its first num_notes[i] candidates; -inf entries are padding.
            keep = (rank < num_notes.unsqueeze(1)) & (topk_values > float("-inf"))
            kept_rows = row_ids.expand_as(topk_indices)[keep]
            final_onset_roll[kept_rows, topk_indices[keep]] = topk_values[keep]

    # ---- too few notes: raise the strongest inactive candidates ----
    sounding = (final_onset_roll >= increase_note_threshold) & allowed
    remaining_needed = num_notes - sounding.sum(dim=1)
    k_add = min(int(remaining_needed.max().item()), n_pitch)
    if k_add > 0:
        low_pitch = torch.arange(n_pitch, device=dev) < max_added_pitch_index
        addable = (final_onset_roll < increase_note_threshold) & allowed & low_pitch
        candidates = torch.where(addable, final_onset_roll, neg_inf)
        topk_values, topk_indices = torch.topk(candidates, k_add, dim=1)
        rank = torch.arange(k_add, device=dev).unsqueeze(0)
        # Rows that already have enough notes have remaining_needed <= 0 and select nothing.
        adjust = (rank < remaining_needed.unsqueeze(1)) & (topk_values > float("-inf"))
        added_rows = row_ids.expand_as(topk_indices)[adjust]
        # Assigning a Python scalar works for every floating dtype of the roll.
        final_onset_roll[added_rows, topk_indices[adjust]] = increase_note_threshold

    piano_roll_full[:, 0, :, :] = final_onset_roll.reshape(shape)
    return piano_roll_full
def X0EditFunc(x0, background_condition, sampler_device=device, reduce_extra_notes=True, rhythm_control="Yes"):
    """Constrain a predicted piano roll to the scale (and optionally rhythm) of the condition.

    The chord chroma found in the first two channels of ``background_condition``
    is matched against all 12 rotations of a major and a minor chord template.
    Positive entries of ``x0`` on pitches outside the scale(s) paired with the
    matching templates are set to 0. When a rhythm condition is present and
    ``rhythm_control == "Yes"``, the onset channel is additionally passed
    through ``edit_rhythm``.

    Args:
        x0: predicted roll; it is indexed as (batch, 2, length, pitch) by
            ``edit_rhythm`` and must broadcast against the condition's chord
            channels.
        background_condition: 4-D tensor whose channels 0 and 1 hold the chord
            condition. Negative values mark a null rhythm condition in which
            -2 encodes 1 and -1 encodes 0.
        sampler_device: device on which the chord templates are created.
        reduce_extra_notes: forwarded to ``edit_rhythm``.
        rhythm_control: rhythm editing only happens when this equals "Yes".

    Returns:
        The edited roll (a new tensor; the caller's ``x0`` is not modified).
    """
    # Pitch width of the roll (64 in the original code).
    n_pitch = background_condition.shape[-1]
    n_tiles = n_pitch // 12 + 1

    # Precompute every rotation of the major and minor chord templates.
    # Row 0 is the chord chroma, row 1 the scale chroma paired with it.
    maj_chd = torch.tensor([[1., 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], [1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1]], device=sampler_device)
    min_chd = torch.tensor([[1., 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0], [1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1]], device=sampler_device)
    # Tile to a multiple of 12 that covers the roll so that rolling stays
    # octave-periodic, then crop to the roll width after rotating.
    maj_chd = torch.tile(maj_chd, (1, n_tiles))
    min_chd = torch.tile(min_chd, (1, n_tiles))
    maj_chd_rotations = torch.stack([torch.roll(maj_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:, :, :n_pitch]
    min_chd_rotations = torch.stack([torch.roll(min_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:, :, :n_pitch]

    # chd_scale_map has shape (N, 2, n_pitch): N chord types, each a pair
    # (chord_chroma, corresponding_scale_chroma).
    chd_scale_map = torch.concat([maj_chd_rotations, min_chd_rotations], dim=0)

    chord_channels = background_condition[:, :2, :, :]
    chord_min = chord_channels.min()

    # With a null rhythm condition the values are negative: map -2 -> 1 and -1 -> 0.
    if chord_min < 0:
        correct_chord_condition = -chord_channels - 1
    else:
        correct_chord_condition = chord_channels

    # Chord roll of the condition, then its chroma clipped to [0, 1].
    merged_chd_roll = torch.max(correct_chord_condition[:, 0, :, :], correct_chord_condition[:, 1, :, :])
    chd_chroma_ours = torch.clamp(merged_chd_roll, min=0.0, max=1.0)
    shape = chd_chroma_ours.shape
    chd_chroma_ours = chd_chroma_ours.reshape(-1, n_pitch)

    # A template matches a step when it covers every pitch present in that step.
    matches = (chd_scale_map[:, 0, :].unsqueeze(0) - chd_chroma_ours.unsqueeze(1) >= 0).all(dim=-1)
    # Sum the scales of all matching templates; non-zero means "pitch allowed".
    seven_notes_chroma_ours = torch.einsum('ij,jk->ik', matches.float(), chd_scale_map[:, 1, :]).reshape(shape)
    seven_notes_chroma_ours = seven_notes_chroma_ours.unsqueeze(1).repeat((1, 2, 1, 1))

    # Steps without any matching chord impose no restriction at all.
    no_chd_match = torch.all(seven_notes_chroma_ours == 0, dim=-1)
    seven_notes_chroma_ours[no_chd_match] = 1.

    # Silence positive predictions that fall outside the allowed pitches.
    x0 = torch.where((seven_notes_chroma_ours == 0) & (x0 > 0), 0.0, x0)

    # Only edit the rhythm when a rhythm condition was actually provided.
    if (chord_min >= 0) and (rhythm_control == "Yes"):
        num_onset_notes, _ = torch.max(background_condition[:, 0, :, :], dim=-1)
        x0 = edit_rhythm(x0, num_onset_notes, seven_notes_chroma_ours, reduce_extra_notes)
    return x0
def expand_roll(roll, unit=4, contain_onset=False):
    """Stretch a roll along time by repeating every step ``unit`` times.

    roll: (Channel, T, H) -> (Channel, T * unit, H)

    With ``contain_onset`` the channels are treated as consecutive
    (even, odd) pairs: inside every stretched step only the first frame of
    an even channel is kept, and the remaining frames are merged (by maximum)
    into the odd channel that follows it.
    """
    n_channel, length, height = roll.shape
    stretched = roll.repeat(unit, axis=1)
    if not contain_onset:
        return stretched

    # View the stretched roll as (channel, step, frame-within-step, pitch).
    frames = stretched.reshape((n_channel, length, unit, height))
    even_tail = frames[::2, :, 1:]
    odd_tail = frames[1::2, :, 1:]
    # Both tails are views, so these assignments edit `frames` in place.
    odd_tail[...] = np.maximum(even_tail, odd_tail)
    even_tail[...] = 0
    return frames.reshape((n_channel, length * unit, height))
def cut_piano_roll(piano_roll, resolution=16, lowest=33, highest=96):
    """Return the pitch range ``lowest..highest`` (inclusive) of a 3-D roll.

    The third axis is sliced; ``resolution`` is accepted but not used.
    """
    pitch_window = slice(lowest, highest + 1)
    return piano_roll[:, :, pitch_window]
def circular_extend(chd_roll, lowest=33, highest=96):
    """Place a 12-bin chord chroma into a pitch roll covering ``lowest..highest``.

    chd_roll: 6*L*12 -> 6*L*64 (with the default range).

    The chroma is written twice: into the octave starting at pitch 60 and into
    the octave directly below it. All other pitches stay zero.
    """
    n_channel, n_step = chd_roll.shape[0], chd_roll.shape[1]
    ext_chd = np.zeros((n_channel, n_step, highest + 1 - lowest))
    c4_index = 60 - lowest
    for octave_start in (c4_index, c4_index - 12):
        ext_chd[:, :, octave_start:octave_start + 12] = chd_roll
    return ext_chd
def default_quantization(v):
    """Binarise a roll value: 1 when it is strictly above 0.5, otherwise 0."""
    if v > 0.5:
        return 1
    return 0
def extend_piano_roll(piano_roll: np.ndarray, lowest=33, highest=96):
    """Zero-pad a cut piano roll back to the full 128-pitch range.

    A roll of shape (2, L, highest - lowest + 1) -- (2, L, 64) with the
    defaults -- gets ``lowest`` zero columns in front of its pitch axis and
    ``127 - highest`` zero columns behind it, giving (2, L, 128).
    """
    # Only the last (pitch) axis is padded.
    pad_width = [(0, 0), (0, 0), (lowest, 127 - highest)]
    return np.pad(piano_roll, pad_width, mode='constant', constant_values=0)
def piano_roll_to_note_mat(piano_roll: np.ndarray, quantization_func=None):
    """Convert an onset/sustain piano roll into a list of ``[start, pitch, duration]``.

    piano_roll: (2, L, 128), onset and sustain channel.
    quantization_func: maps one roll value to a truthy/falsy state; falls back
        to ``default_quantization`` when None.

    Notes are emitted pitch by pitch (ascending) and, within a pitch, in time
    order. An onset starts a note of length 1; each directly following
    sustain-only step extends it by 1. Sustain without a preceding onset is
    ignored. Start and duration are counted in roll steps.
    """
    if quantization_func is None:
        quantization_func = default_quantization

    assert len(piano_roll.shape) == 3 and piano_roll.shape[0] == 2 and piano_roll.shape[2] == 128, f"{piano_roll.shape}"

    n_step = piano_roll.shape[1]
    notes = []
    for pitch in range(128):
        # True while the most recently appended note can still be extended.
        note_open = False
        for step in range(n_step):
            onset_state = quantization_func(piano_roll[0, step, pitch])
            sustain_state = quantization_func(piano_roll[1, step, pitch])
            if bool(onset_state):
                note_open = True
                notes.append([step, pitch, 1])
            elif bool(sustain_state):
                if note_open:
                    notes[-1][-1] += 1
            else:
                note_open = False
    return notes
def note_mat_to_notes(note_mat, bpm, unit=1/4, shift_beat=0., shift_sec=0., vel=100):
    """Turn ``(onset, pitch, duration)`` rows into ``pm.Note`` objects.

    Onset and duration are counted in steps of ``unit`` beats and converted to
    seconds using ``bpm``. Every note is shifted by ``shift_beat`` beats;
    ``shift_sec`` is only used when ``shift_beat`` is None. All notes get
    velocity ``vel``.
    """
    sec_per_beat = 60 / bpm
    sec_per_step = unit * sec_per_beat
    if shift_beat is not None:
        # The beat-based shift takes precedence over the seconds-based one.
        shift_sec = shift_beat * sec_per_beat

    return [
        pm.Note(
            vel,
            int(pitch),
            onset * sec_per_step + shift_sec,
            (onset + dur) * sec_per_step + shift_sec,
        )
        for onset, pitch, dur in note_mat
    ]
def create_pm_object(bpm, piano_notes_list, chd_notes_list, lsh_notes_list=None):
    """Build a ``pm.PrettyMIDI`` object with one piano instrument per note list.

    The piano notes always get an instrument; ``lsh_notes_list`` gets a second
    'Acoustic Grand Piano' instrument when it is not None.
    ``chd_notes_list`` is accepted but not written to the MIDI object (the
    chord track is disabled).
    """
    midi = pm.PrettyMIDI(initial_tempo=bpm)

    tracks = [piano_notes_list]
    if lsh_notes_list is not None:
        tracks.append(lsh_notes_list)

    for track_notes in tracks:
        program = pm.instrument_name_to_program('Acoustic Grand Piano')
        instrument = pm.Instrument(program=program)
        instrument.notes += track_notes
        midi.instruments.append(instrument)
    return midi
def piano_roll_to_midi(piano_roll: np.ndarray, chd_roll: np.ndarray, lsh_roll=None, bpm=80):
    """Convert piano / chord / optional lead-sheet rolls into a MIDI object.

    Each roll is turned into a note matrix with ``piano_roll_to_note_mat`` and
    then into timed notes with ``note_mat_to_notes``; the result is assembled
    by ``create_pm_object``.

    Args:
        piano_roll: roll of the piano part.
        chd_roll: roll of the chord part.
        lsh_roll: optional third roll; skipped when None.
        bpm: tempo used both for the note timings and for the initial tempo
            of the MIDI object.

    Returns:
        The object returned by ``create_pm_object``.
    """
    piano_notes = note_mat_to_notes(piano_roll_to_note_mat(piano_roll), bpm)
    chd_notes = note_mat_to_notes(piano_roll_to_note_mat(chd_roll), bpm)

    lsh_notes = None
    if lsh_roll is not None:
        lsh_notes = note_mat_to_notes(piano_roll_to_note_mat(lsh_roll), bpm)

    # Pass the caller's bpm on; a fixed tempo of 80 here would disagree with
    # the note timings computed above whenever bpm != 80.
    piano_pm = create_pm_object(bpm=bpm, piano_notes_list=piano_notes,
                                chd_notes_list=chd_notes, lsh_notes_list=lsh_notes)
    return piano_pm
def save_midi(pm, filename):
    """Write a MIDI object to ``filename`` by calling its ``write`` method.

    Note that the ``pm`` parameter is the object to save; inside this function
    it shadows the module-level ``pm`` alias.
    """
    midi_object = pm
    midi_object.write(filename)