"""Helpers for chord/rhythm-constrained editing of piano-roll tensors and for
converting piano rolls to MIDI via pretty_midi."""

import numpy as np
import pretty_midi as pm
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 12-bin pitch-class (chroma) templates for the 24 major/minor triads.
# Index 0 is C, index 11 is B.
CHORD_DICTIONARY = {
    "C:major": np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]),
    "C#:major": np.array([0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0]),
    "D:major": np.array([0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0]),
    "Eb:major": np.array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0]),
    "E:major": np.array([0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1]),
    "F:major": np.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]),
    "F#:major": np.array([0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0]),
    "G:major": np.array([0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1]),
    "Ab:major": np.array([1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]),
    "A:major": np.array([0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0]),
    "Bb:major": np.array([0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0]),
    "B:major": np.array([0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]),
    "c:minor": np.array([1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]),
    "c#:minor": np.array([0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]),
    "d:minor": np.array([0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0]),
    "eb:minor": np.array([0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]),
    "e:minor": np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1]),
    "f:minor": np.array([1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0]),
    "f#:minor": np.array([0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0]),
    "g:minor": np.array([0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0]),
    "g#:minor": np.array([0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1]),
    "a:minor": np.array([1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0]),
    "bb:minor": np.array([0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]),
    "b:minor": np.array([0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]),
}


def edit_rhythm(piano_roll_full, num_notes_onset, mask_full, reduce_extra_notes=True):
    """Push the per-step onset count of a piano roll towards a target count.

    Only the onset channel (index 0 of dim 1) is edited, and only at positions
    where ``mask_full`` equals 1. An onset counts as "on" when its value is
    above 0.5; values are nudged to 0.499 (off) or 0.501 (on) rather than to
    0/1.

    Args:
        piano_roll_full: tensor of shape (batch_size, 2, length, h), where h is
            the number of pitch bins. It is modified in place.
        num_notes_onset: tensor of shape (batch_size, length) with the desired
            number of onsets for each time step.
        mask_full: tensor with the same shape as ``piano_roll_full``; entries
            equal to 1 mark the pitches that may be edited.
        reduce_extra_notes: when True, steps with more active onsets than
            requested keep only their highest-valued ones.

    Returns:
        ``piano_roll_full`` (the same tensor object) with its onset channel
        overwritten by the edited values.
    """
    # Only the onset channel is edited; flatten (batch, length) into rows.
    onset_roll = piano_roll_full[:, 0, :, :]
    mask = mask_full[:, 0, :, :]
    shape = onset_roll.shape
    onset_roll = onset_roll.reshape(-1, shape[-1])
    mask = mask.reshape(-1, shape[-1])
    num_notes = num_notes_onset.reshape(-1)

    # Nothing to edit (and .max() below would fail on an empty tensor).
    if onset_roll.numel() == 0 or num_notes.numel() == 0:
        return piano_roll_full

    n_rows, n_pitch = onset_roll.shape
    reduce_note_threshold = 0.499
    increase_note_threshold = 0.501
    neg_inf = -float("inf")

    editable = mask == 1
    row_ids = torch.arange(n_rows, device=onset_roll.device).unsqueeze(1)

    # Work on a copy so that the candidate selection below always sees the
    # untouched input values.
    final_onset_roll = onset_roll.clone()

    # ---- Too many notes: keep only the strongest num_notes[i] per row ----
    if reduce_extra_notes:
        active = (onset_roll > reduce_note_threshold) & editable
        # Non-candidates become -inf so that top-k never selects them.
        candidates = onset_roll.masked_fill(~active, neg_inf)

        # top-k cannot return more entries than there are columns.
        k_keep = max(0, min(int(num_notes.max().item()), n_pitch))
        topk_values, topk_indices = torch.topk(candidates, k_keep, dim=1)

        # Within each row keep only the first num_notes[i] real candidates.
        rank = torch.arange(k_keep, device=onset_roll.device).expand(n_rows, k_keep)
        keep = (rank < num_notes.unsqueeze(1)) & (topk_values > neg_inf)

        # First switch every active editable onset off, then restore the kept
        # ones with their original values.
        final_onset_roll[active] = reduce_note_threshold
        kept_rows = row_ids.expand_as(topk_indices)[keep]
        final_onset_roll[kept_rows, topk_indices[keep]] = topk_values[keep]

    # ---- Too few notes: switch on the strongest inactive candidates ----
    # New onsets are only added in columns 0..50. The original variable name
    # (pitch_less_84_mask) suggests this is "pitch < 84" given the lowest
    # pitch of 33 used as a default elsewhere in this file -- TODO confirm.
    low_pitch = torch.zeros_like(editable)
    low_pitch[:, :51] = True

    already_on = (final_onset_roll >= increase_note_threshold) & editable
    remaining_needed = num_notes - already_on.sum(dim=1)
    k_add = min(int(remaining_needed.max().item()), n_pitch)

    if k_add > 0:
        addable = (final_onset_roll < increase_note_threshold) & editable & low_pitch
        candidates = final_onset_roll.masked_fill(~addable, neg_inf)
        top_values, top_indices = torch.topk(candidates, k_add, dim=1)

        # Rows needing fewer than k_add notes (or none) are limited by rank;
        # rows that run out of candidates are limited by the -inf check, so a
        # row may end up with fewer than num_notes[i] onsets.
        rank = torch.arange(k_add, device=onset_roll.device).expand(n_rows, k_add)
        adjust = (rank < remaining_needed.unsqueeze(1)) & (top_values > neg_inf)

        added_rows = row_ids.expand_as(top_indices)[adjust]
        final_onset_roll[added_rows, top_indices[adjust]] = increase_note_threshold

    piano_roll_full[:, 0, :, :] = final_onset_roll.reshape(shape)
    return piano_roll_full


def X0EditFunc(x0, background_condition, sampler_device=device, reduce_extra_notes=True, rhythm_control="Yes"):
    """Constrain a predicted piano roll to the scale (and rhythm) of its condition.

    The chord chroma found in the first two channels of
    ``background_condition`` is matched against all 12 transpositions of a
    major and a minor triad template. Positive entries of ``x0`` that fall
    outside the scale(s) of the matching chords are set to 0. When the
    condition carries rhythm information (no negative values) and
    ``rhythm_control == "Yes"``, the onset counts are additionally adjusted
    with ``edit_rhythm``.

    Args:
        x0: piano-roll tensor to edit. It is passed to ``edit_rhythm`` as a
            (batch_size, 2, length, 64) tensor.
        background_condition: condition tensor whose first two channels hold
            the chord roll over 64 pitch bins. Negative values are read as the
            "null rhythm" encoding, in which -2 stands for 1 and -1 for 0.
        sampler_device: device on which the chord templates are created.
        reduce_extra_notes: forwarded to ``edit_rhythm``.
        rhythm_control: rhythm editing is applied only when this is "Yes".

    Returns:
        The edited piano roll.
    """
    n_pitch = 64  # number of pitch bins in the rolls handled here
    n_tiles = n_pitch // 12 + 1  # enough octaves to cover n_pitch after rolling

    # Precompute every rotation (transposition) of the major and minor chords.
    # Row 0 of each template is the triad chroma, row 1 the 7-note scale used
    # with it.
    scale_chroma = [1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1]
    maj_chd = torch.tensor([[1., 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], scale_chroma], device=sampler_device)
    maj_chd = torch.tile(maj_chd, (1, n_tiles))
    min_chd = torch.tensor([[1., 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0], scale_chroma], device=sampler_device)
    min_chd = torch.tile(min_chd, (1, n_tiles))

    maj_chd_rotations = torch.stack(
        [torch.roll(maj_chd, shifts=-i, dims=1) for i in range(12)], dim=0
    )[:, :, :n_pitch]
    min_chd_rotations = torch.stack(
        [torch.roll(min_chd, shifts=-i, dims=1) for i in range(12)], dim=0
    )[:, :, :n_pitch]

    # chd_scale_map has shape (N, 2, 64): N chord types, 2 is
    # (chord_chroma, corresponding_scale_chroma), 64 possible notes.
    chd_scale_map = torch.cat([maj_chd_rotations, min_chd_rotations], dim=0)

    chord_channels = background_condition[:, :2, :, :]
    condition_min = chord_channels.min()

    # With the null rhythm condition, convert -2 to 1 and -1 to 0.
    if condition_min < 0:
        correct_chord_condition = -chord_channels - 1
    else:
        correct_chord_condition = chord_channels

    # Chord roll of the condition, reduced to a 0/1 chroma per time step.
    merged_chd_roll = torch.max(correct_chord_condition[:, 0, :, :], correct_chord_condition[:, 1, :, :])
    chd_chroma_ours = torch.clamp(merged_chd_roll, min=0.0, max=1.0)
    shape = chd_chroma_ours.shape
    chd_chroma_ours = chd_chroma_ours.reshape(-1, n_pitch)

    # A template matches when it contains every note of our chroma. An
    # all-zero chroma therefore matches every template.
    matches = ((chd_scale_map[:, 0, :].unsqueeze(0) - chd_chroma_ours.unsqueeze(1)) >= 0).all(dim=-1)

    # Union (as a sum) of the scales of all matching chords, per time step.
    seven_notes_chroma_ours = torch.einsum('ij,jk->ik', matches.float(), chd_scale_map[:, 1, :]).reshape(shape)
    seven_notes_chroma_ours = seven_notes_chroma_ours.unsqueeze(1).repeat((1, 2, 1, 1))

    # Steps whose chroma matches no template are left unconstrained.
    no_chd_match = torch.all(seven_notes_chroma_ours == 0, dim=-1)
    seven_notes_chroma_ours[no_chd_match] = 1.

    # Edit notes based on chroma: silence positive entries outside the scale.
    x0 = torch.where((seven_notes_chroma_ours == 0) & (x0 > 0), 0.0, x0)

    # Edit rhythm only if a rhythm condition is actually provided.
    if (condition_min >= 0) and (rhythm_control == "Yes"):
        num_onset_notes, _ = torch.max(background_condition[:, 0, :, :], dim=-1)
        x0 = edit_rhythm(x0, num_onset_notes, seven_notes_chroma_ours, reduce_extra_notes)

    return x0


def expand_roll(roll, unit=4, contain_onset=False):
    """Repeat every time step of a roll ``unit`` times.

    Args:
        roll: numpy array of shape (Channel, T, H).
        unit: number of output steps per input step.
        contain_onset: when True, channels are treated as (onset, sustain)
            pairs at even/odd indices. Within each group of ``unit`` repeated
            steps only the first keeps the onset; for the rest the onset is
            merged into the sustain channel and then cleared.

    Returns:
        Array of shape (Channel, T * unit, H).
    """
    n_channel, length, height = roll.shape
    expanded_roll = roll.repeat(unit, axis=1)
    if contain_onset:
        # View as (Channel, T, unit, H) so index 2 addresses the repeats.
        expanded_roll = expanded_roll.reshape((n_channel, length, unit, height))
        expanded_roll[1::2, :, 1:] = np.maximum(expanded_roll[::2, :, 1:], expanded_roll[1::2, :, 1:])
        expanded_roll[::2, :, 1:] = 0
        expanded_roll = expanded_roll.reshape((n_channel, length * unit, height))
    return expanded_roll


def cut_piano_roll(piano_roll, resolution=16, lowest=33, highest=96):
    """Keep only the pitch bins ``lowest..highest`` (inclusive) of the last axis.

    ``resolution`` is accepted but not used.
    """
    piano_roll_cut = piano_roll[:, :, lowest:highest + 1]
    return piano_roll_cut


def circular_extend(chd_roll, lowest=33, highest=96):
    """Place a 12-bin chord chroma into a ``highest + 1 - lowest`` wide roll.

    The chroma is written into the two octaves starting at MIDI pitches 48
    and 60 (expressed relative to ``lowest``); all other bins stay zero.
    With the defaults this maps 6*L*12 -> 6*L*64.
    """
    C4 = 60 - lowest
    C3 = C4 - 12
    shape = chd_roll.shape
    ext_chd = np.zeros((shape[0], shape[1], highest + 1 - lowest))
    ext_chd[:, :, C4:C4 + 12] = chd_roll
    ext_chd[:, :, C3:C3 + 12] = chd_roll
    return ext_chd


def default_quantization(v):
    """Return 1 when ``v`` is strictly greater than 0.5, otherwise 0."""
    return 1 if v > 0.5 else 0


def extend_piano_roll(piano_roll: np.ndarray, lowest=33, highest=96):
    """Zero-pad a cut piano roll back to the full 128-pitch range.

    The last axis is padded with ``lowest`` zeros in front and
    ``127 - highest`` zeros behind, e.g. (2, L, 64) -> (2, L, 128) with the
    defaults.
    """
    padded_roll = np.pad(
        piano_roll,
        ((0, 0), (0, 0), (lowest, 127 - highest)),
        mode='constant',
        constant_values=0,
    )
    return padded_roll


def piano_roll_to_note_mat(piano_roll: np.ndarray, quantization_func=None):
    """Convert an onset/sustain piano roll into a list of notes.

    Args:
        piano_roll: array of shape (2, L, 128); channel 0 is onset, channel 1
            is sustain.
        quantization_func: maps a roll value to a truthy/falsy state.
            Defaults to ``default_quantization``.

    Returns:
        A list of ``[onset_step, pitch, duration_in_steps]`` entries, grouped
        by ascending pitch and ordered by onset step within each pitch.
    """
    quantization_func = default_quantization if quantization_func is None else quantization_func
    assert len(piano_roll.shape) == 3 and piano_roll.shape[0] == 2 and piano_roll.shape[2] == 128, f"{piano_roll.shape}"

    n_step = piano_roll.shape[1]
    notes = []
    for pitch in range(128):
        # True while the most recently appended note (of this pitch) may
        # still be extended by sustain steps.
        note_open = False
        for t in range(n_step):
            is_onset = bool(quantization_func(piano_roll[0, t, pitch]))
            is_sustain = bool(quantization_func(piano_roll[1, t, pitch])) and not is_onset
            if is_onset:
                note_open = True
                notes.append([t, pitch, 1])
            elif is_sustain:
                # Sustain without a preceding onset is ignored.
                if note_open:
                    notes[-1][-1] += 1
            else:
                note_open = False
    return notes


def note_mat_to_notes(note_mat, bpm, unit=1/4, shift_beat=0., shift_sec=0., vel=100):
    """Turn ``[onset, pitch, duration]`` rows into ``pretty_midi.Note`` objects.

    Args:
        note_mat: iterable of ``(onset, pitch, duration)`` measured in steps.
        bpm: tempo used to convert steps to seconds.
        unit: length of one step in beats.
        shift_beat: time offset in beats. Used by default; pass ``None`` to
            use ``shift_sec`` instead.
        shift_sec: time offset in seconds, only used when ``shift_beat`` is
            ``None``.
        vel: velocity assigned to every note.

    Returns:
        A list of ``pretty_midi.Note``.
    """
    beat_alpha = 60 / bpm  # seconds per beat
    step_alpha = unit * beat_alpha  # seconds per step
    shift_sec = shift_sec if shift_beat is None else shift_beat * beat_alpha

    notes = []
    for onset, pitch, dur in note_mat:
        start = onset * step_alpha + shift_sec
        end = (onset + dur) * step_alpha + shift_sec
        notes.append(pm.Note(vel, int(pitch), start, end))
    return notes


def create_pm_object(bpm, piano_notes_list, chd_notes_list, lsh_notes_list=None):
    """Build a ``pretty_midi.PrettyMIDI`` object from note lists.

    Args:
        bpm: initial tempo of the MIDI object.
        piano_notes_list: notes of the main piano track.
        chd_notes_list: accepted for interface compatibility; no chord track
            is currently written.
        lsh_notes_list: optional notes for a second piano track.

    Returns:
        The assembled ``pretty_midi.PrettyMIDI`` object.
    """
    midi = pm.PrettyMIDI(initial_tempo=bpm)

    piano_program = pm.instrument_name_to_program('Acoustic Grand Piano')
    piano = pm.Instrument(program=piano_program)
    piano.notes += piano_notes_list
    midi.instruments.append(piano)

    if lsh_notes_list is not None:
        lsh_program = pm.instrument_name_to_program('Acoustic Grand Piano')
        lsh = pm.Instrument(program=lsh_program)
        lsh.notes += lsh_notes_list
        midi.instruments.append(lsh)

    return midi


def piano_roll_to_midi(piano_roll: np.ndarray, chd_roll: np.ndarray, lsh_roll=None, bpm=80):
    """Convert piano / chord / optional extra rolls into a PrettyMIDI object.

    Every roll must have shape (2, L, 128) as required by
    ``piano_roll_to_note_mat``. ``bpm`` is used both for the note timing and
    for the initial tempo of the MIDI object.
    """
    piano_mat = piano_roll_to_note_mat(piano_roll)
    piano_notes = note_mat_to_notes(piano_mat, bpm)

    chd_mat = piano_roll_to_note_mat(chd_roll)
    chd_notes = note_mat_to_notes(chd_mat, bpm)

    if lsh_roll is not None:
        lsh_mat = piano_roll_to_note_mat(lsh_roll)
        lsh_notes = note_mat_to_notes(lsh_mat, bpm)
    else:
        lsh_notes = None

    # Use the caller's bpm so the tempo stored in the file agrees with the
    # tempo the note times were computed with.
    piano_pm = create_pm_object(
        bpm=bpm,
        piano_notes_list=piano_notes,
        chd_notes_list=chd_notes,
        lsh_notes_list=lsh_notes,
    )
    return piano_pm


def save_midi(pm, filename):
    """Write a PrettyMIDI object to ``filename``.

    Note: inside this function the parameter ``pm`` shadows the module-level
    ``pretty_midi`` alias; the name is kept for keyword-argument callers.
    """
    pm.write(filename)