# Listing metadata (not part of the program): file size 13,828 bytes, revision bf8981a.
import torch
import numpy as np
import pretty_midi as pm
# Default device for tensors built in this module: CUDA when torch reports it
# as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Root names in ascending semitone order starting from C (index 0 of a template).
_MAJOR_ROOTS = ("C", "C#", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B")
_MINOR_ROOTS = ("c", "c#", "d", "eb", "e", "f", "f#", "g", "g#", "a", "bb", "b")

# 12-bin pitch-class templates of the triads rooted on C: root, third, fifth.
_C_MAJOR_TRIAD = np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0])
_C_MINOR_TRIAD = np.array([1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0])

# Maps "<root>:major" / "<root>:minor" to a 12-element 0/1 array marking the
# pitch classes of that triad. Every entry is the C-rooted template rotated
# upwards by the root's distance from C in semitones. Majors come first, then
# minors, each in ascending root order.
CHORD_DICTIONARY = {
    **{
        f"{root}:major": np.roll(_C_MAJOR_TRIAD, semitones)
        for semitones, root in enumerate(_MAJOR_ROOTS)
    },
    **{
        f"{root}:minor": np.roll(_C_MINOR_TRIAD, semitones)
        for semitones, root in enumerate(_MINOR_ROOTS)
    },
}
def edit_rhythm(piano_roll_full, num_notes_onset, mask_full, reduce_extra_notes=True):
    """Make each time step of the onset channel carry the requested note count.

    Only channel 0 (onset) of ``piano_roll_full`` is edited. Each (batch, step)
    row is processed independently in two phases:

    1. If ``reduce_extra_notes`` is true, allowed entries above 0.499 are
       ranked; the ``num_notes`` largest keep their value and the rest are set
       to 0.499.
    2. If a row then has fewer than ``num_notes`` allowed entries at or above
       0.501, the largest remaining allowed entries below pitch index 51 are
       raised to 0.501 until the count is met (or candidates run out).

    Args:
        piano_roll_full: tensor of shape (batch_size, 2, length, h); h is the
            number of pitches. Modified in place.
        num_notes_onset: tensor of shape (batch_size, length) with the wanted
            number of onsets per step.
        mask_full: tensor shaped like ``piano_roll_full``; a position is
            editable only where channel 0 of the mask equals 1.
        reduce_extra_notes: whether to run phase 1.

    Returns:
        ``piano_roll_full`` itself, with its onset channel overwritten.
    """
    reduce_note_threshold = 0.499
    increase_note_threshold = 0.501
    # Pitch indices from here upwards are never promoted to an onset. The
    # original code calls this "pitch less than 84", which matches a roll
    # whose index 0 is MIDI pitch 33 -- TODO confirm against the caller.
    add_pitch_ceiling = 51

    onset_view = piano_roll_full[:, 0, :, :]
    if onset_view.numel() == 0:
        # Nothing to edit; also avoids calling max() on an empty tensor below.
        return piano_roll_full

    shape = onset_view.shape
    n_pitch = shape[-1]
    onset_roll = onset_view.reshape(-1, n_pitch)
    allowed = mask_full[:, 0, :, :].reshape(-1, n_pitch) == 1
    num_notes = num_notes_onset.reshape(-1)

    n_rows = onset_roll.size(0)
    dev = onset_roll.device
    row_ids = torch.arange(n_rows, device=dev).unsqueeze(1)
    # Sentinel that keeps non-candidates out of every top-k selection.
    neg_inf = torch.tensor(-float("inf"), dtype=onset_roll.dtype, device=dev)

    # All edits go to a copy; the input is only overwritten at the very end.
    final_onset_roll = onset_roll.clone()

    # ---- Phase 1: too many notes -> keep the strongest, demote the rest ----
    if reduce_extra_notes:
        active = (onset_roll > reduce_note_threshold) & allowed
        candidates = torch.where(active, onset_roll, neg_inf)
        # One top-k call serves every row, so k is the largest request,
        # clamped to what a row can actually hold.
        k_keep = max(0, min(int(num_notes.max().item()), n_pitch))
        topk_values, topk_indices = torch.topk(candidates, k_keep, dim=1)
        rank = torch.arange(k_keep, device=dev).expand(n_rows, k_keep)
        # Row i keeps only its first num_notes[i] real (non-sentinel) hits.
        keep = (rank < num_notes.unsqueeze(1)) & (topk_values > -float("inf"))
        # Demote every active entry first, then restore the kept ones.
        final_onset_roll[active] = reduce_note_threshold
        kept_rows = row_ids.expand_as(topk_indices)[keep]
        final_onset_roll.index_put_((kept_rows, topk_indices[keep]), topk_values[keep])

    # ---- Phase 2: too few notes -> promote the strongest sub-threshold ones ----
    below_ceiling = torch.ones_like(allowed)
    below_ceiling[:, add_pitch_ceiling:] = False

    sounding = (final_onset_roll >= increase_note_threshold) & allowed
    remaining_needed = num_notes - sounding.sum(dim=1)
    k_add = min(int(remaining_needed.max().item()), n_pitch)
    if k_add > 0:
        promotable = (final_onset_roll < increase_note_threshold) & allowed & below_ceiling
        candidates = torch.where(promotable, final_onset_roll, neg_inf)
        top_values, top_indices = torch.topk(candidates, k_add, dim=1)
        rank = torch.arange(k_add, device=dev).expand(n_rows, k_add)
        # Rows that already have enough notes have remaining_needed <= 0 and
        # therefore select nothing here.
        adjust = (rank < remaining_needed.unsqueeze(1)) & (top_values > -float("inf"))
        promoted_rows = row_ids.expand_as(top_indices)[adjust]
        # index_put_ needs the value to share the destination dtype.
        fill = torch.tensor(increase_note_threshold, dtype=final_onset_roll.dtype, device=dev)
        final_onset_roll.index_put_((promoted_rows, top_indices[adjust]), fill)

    piano_roll_full[:, 0, :, :] = final_onset_roll.reshape(shape)
    return piano_roll_full
def X0EditFunc(x0, background_condition, sampler_device=device, reduce_extra_notes=True, rhythm_control="Yes"):
    """Constrain a predicted roll ``x0`` to the scale (and rhythm) of the condition.

    The first two channels of ``background_condition`` are read as the chord
    condition. Their element-wise maximum, clamped to [0, 1], is compared
    against the 24 rotated major/minor triad templates; every template that
    contains the step's active pitches contributes its paired seven-note scale.
    Positive entries of ``x0`` that fall outside the resulting scale are zeroed.
    Steps that match no template are left unconstrained.

    When the chord channels contain no negative value and ``rhythm_control`` is
    ``"Yes"``, the per-step maximum of channel 0 is taken as the required
    number of onsets and ``edit_rhythm`` is applied to the result.

    Args:
        x0: tensor to edit; must broadcast against a
            (batch, 2, length, pitch) mask -- presumably shaped exactly so,
            TODO confirm with the caller.
        background_condition: tensor indexed as (batch, channel, length, pitch)
            with at least two channels.
        sampler_device: device on which the chord templates are created; it
            has to match the device of ``background_condition``.
        reduce_extra_notes: forwarded to ``edit_rhythm``.
        rhythm_control: rhythm editing happens only for the exact string "Yes".

    Returns:
        The edited tensor.
    """
    n_pitch = background_condition.shape[-1]
    n_tiles = n_pitch // 12 + 1

    # Precompute every rotation of the major and minor chord templates.
    # Row 0 is the triad chroma, row 1 the seven-note scale chroma paired with it.
    maj_chd = torch.tensor([[1.,0,0,0,1,0,0,1,0,0,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
    min_chd = torch.tensor([[1.,0,0,0,1,0,0,0,0,1,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
    # Tile the 12-bin patterns past the roll width, rotate, then crop.
    maj_chd = torch.tile(maj_chd, (1, n_tiles))
    min_chd = torch.tile(min_chd, (1, n_tiles))
    maj_chd_rotations = torch.stack([torch.roll(maj_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:, :, :n_pitch]
    min_chd_rotations = torch.stack([torch.roll(min_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:, :, :n_pitch]

    # chd_scale_map has shape (N, 2, n_pitch): N chord types,
    # 2 = (chord chroma, corresponding scale chroma).
    chd_scale_map = torch.concat([maj_chd_rotations, min_chd_rotations], axis=0)

    chord_channels = background_condition[:, :2, :, :]
    # Computed once; decides both the sign fix below and the rhythm edit later.
    condition_min = chord_channels.min().item()

    # With the null rhythm condition the channels are stored negated:
    # convert -2 to 1 and -1 to 0.
    if condition_min < 0:
        correct_chord_condition = -chord_channels - 1
    else:
        correct_chord_condition = chord_channels

    merged_chd_roll = torch.max(correct_chord_condition[:, 0, :, :], correct_chord_condition[:, 1, :, :])
    chd_chroma_ours = torch.clamp(merged_chd_roll, min=0.0, max=1.0)
    shape = chd_chroma_ours.shape
    chd_chroma_ours = chd_chroma_ours.reshape(-1, n_pitch)

    # A template matches a step when it covers every active pitch of the step.
    matches = (chd_scale_map[:, 0, :].unsqueeze(0) - chd_chroma_ours.unsqueeze(1) >= 0).all(dim=-1)
    # Sum the scales of all matching templates.
    # NOTE(review): entries exceed 1 when several templates match, while
    # edit_rhythm tests `mask == 1` -- confirm that this is intended.
    seven_notes_chroma_ours = torch.einsum('ij,jk->ik', matches.float(), chd_scale_map[:, 1, :]).reshape(shape)
    seven_notes_chroma_ours = seven_notes_chroma_ours.unsqueeze(1).repeat((1, 2, 1, 1))
    # Steps without any matching chord allow every pitch.
    no_chd_match = torch.all(seven_notes_chroma_ours == 0, dim=-1)
    seven_notes_chroma_ours[no_chd_match] = 1.

    # Remove notes that lie outside the allowed scale.
    x0 = torch.where((seven_notes_chroma_ours == 0) & (x0 > 0), 0.0, x0)

    # Rhythm is edited only when a real (non-negative) rhythm condition is given.
    if (condition_min >= 0) and (rhythm_control == "Yes"):
        num_onset_notes, _ = torch.max(background_condition[:, 0, :, :], axis=-1)
        x0 = edit_rhythm(x0, num_onset_notes, seven_notes_chroma_ours, reduce_extra_notes)
    return x0
def expand_roll(roll, unit=4, contain_onset=False):
    """Stretch a roll along time by repeating every step ``unit`` times.

    roll: (Channel, T, H) -> (Channel, T * unit, H)

    With ``contain_onset`` set, channels are treated as alternating pairs
    (even index, odd index). Within each group of ``unit`` repeated frames,
    the frames after the first have the even channel merged into the odd one
    (element-wise maximum) and the even channel cleared.
    """
    n_channel, length, height = roll.shape
    stretched = roll.repeat(unit, axis=1)
    if not contain_onset:
        return stretched

    # View the time axis as (original step, repeat index).
    frames = stretched.reshape((n_channel, length, unit, height))
    even_tail = frames[::2, :, 1:]
    odd_tail = frames[1::2, :, 1:]
    odd_tail[...] = np.maximum(even_tail, odd_tail)
    even_tail[...] = 0
    return frames.reshape((n_channel, length * unit, height))
def cut_piano_roll(piano_roll, resolution=16, lowest=33, highest=96):
    """Keep only pitch indices ``lowest``..``highest`` (inclusive) of the last axis.

    ``resolution`` is accepted but not used by this function.
    """
    pitch_window = slice(lowest, highest + 1)
    return piano_roll[:, :, pitch_window]
def circular_extend(chd_roll, lowest=33, highest=96):
    """Place a 12-bin chord roll into two adjacent octaves of a wider roll.

    chd_roll: (C, L, 12) -> (C, L, highest + 1 - lowest), e.g. 6*L*12 -> 6*L*64.
    The 12 bins are written at the offset of MIDI pitch 60 (relative to
    ``lowest``) and again one octave below; all other bins stay zero.
    """
    n_channel, length = chd_roll.shape[0], chd_roll.shape[1]
    width = highest + 1 - lowest
    extended = np.zeros((n_channel, length, width))

    upper_start = 60 - lowest
    for start in (upper_start, upper_start - 12):
        extended[:, :, start:start + 12] = chd_roll
    return extended
def default_quantization(v):
    """Binarise a roll value: 1 when ``v`` is strictly above 0.5, otherwise 0."""
    if v > 0.5:
        return 1
    return 0
def extend_piano_roll(piano_roll: np.ndarray, lowest=33, highest=96):
    """Zero-pad the pitch axis of a cut roll back to the full 128 pitches.

    ``lowest`` zeros are added before and ``127 - highest`` zeros after the
    last axis, so a (2, L, 64) roll cut at 33..96 becomes (2, L, 128).
    The first two axes are left untouched.
    """
    pad_below, pad_above = lowest, 127 - highest
    pad_width = ((0, 0), (0, 0), (pad_below, pad_above))
    return np.pad(piano_roll, pad_width, mode="constant", constant_values=0)
def piano_roll_to_note_mat(piano_roll: np.ndarray, quantization_func=None):
    """Convert an onset/sustain piano roll into a list of ``[onset, pitch, dur]``.

    piano_roll: (2, L, 128); channel 0 is onset, channel 1 is sustain.
    quantization_func: maps a roll value to a truthy/falsy state; defaults to
        ``default_quantization``.

    Notes are emitted pitch by pitch (0..127), in time order within a pitch.
    An onset opens a note of duration 1; each directly following sustain-only
    step extends it by one. A step with neither state closes the note, and a
    sustain that does not follow an open note is ignored.
    """
    quantize = default_quantization if quantization_func is None else quantization_func
    assert len(piano_roll.shape) == 3 and piano_roll.shape[0] == 2 and piano_roll.shape[2] == 128, f"{piano_roll.shape}"

    n_steps = piano_roll.shape[1]
    notes = []
    for pitch in range(128):
        note_open = False  # True while the last appended note can still grow
        for step in range(n_steps):
            onset_on = bool(quantize(piano_roll[0, step, pitch]))
            sustain_on = bool(quantize(piano_roll[1, step, pitch]))
            if onset_on:
                notes.append([step, pitch, 1])
                note_open = True
            elif not sustain_on:
                note_open = False
            elif note_open:
                notes[-1][-1] += 1
    return notes
def note_mat_to_notes(note_mat, bpm, unit=1/4, shift_beat=0., shift_sec=0., vel=100):
    """Turn ``(onset, pitch, dur)`` step triples into ``pm.Note`` objects.

    One step lasts ``unit`` beats and one beat lasts ``60 / bpm`` seconds.
    Every note is offset by ``shift_beat`` beats; ``shift_sec`` (seconds) is
    used only when ``shift_beat`` is None. ``vel`` is the velocity given to
    every note.
    """
    seconds_per_beat = 60 / bpm
    seconds_per_step = unit * seconds_per_beat
    offset = shift_sec if shift_beat is None else shift_beat * seconds_per_beat
    return [
        pm.Note(
            vel,
            int(pitch),
            onset * seconds_per_step + offset,
            (onset + dur) * seconds_per_step + offset,
        )
        for onset, pitch, dur in note_mat
    ]
def create_pm_object(bpm, piano_notes_list, chd_notes_list, lsh_notes_list=None):
    """Build a ``pm.PrettyMIDI`` object holding the given notes.

    The piano notes go on an 'Acoustic Grand Piano' track. When
    ``lsh_notes_list`` is not None, its notes go on a second track with the
    same program. ``chd_notes_list`` is accepted but not written to the
    result (the chord track is disabled).
    """
    grand_piano = 'Acoustic Grand Piano'
    midi = pm.PrettyMIDI(initial_tempo=bpm)

    tracks = [piano_notes_list]
    if lsh_notes_list is not None:
        tracks.append(lsh_notes_list)

    for track_notes in tracks:
        instrument = pm.Instrument(program=pm.instrument_name_to_program(grand_piano))
        instrument.notes += track_notes
        midi.instruments.append(instrument)
    return midi
def piano_roll_to_midi(piano_roll: np.ndarray, chd_roll: np.ndarray, lsh_roll=None, bpm=80):
    """Convert piano/chord (and optional lead-sheet) rolls into a MIDI object.

    Each roll is converted with ``piano_roll_to_note_mat`` and timed with
    ``note_mat_to_notes`` at ``bpm``. The resulting note lists are handed to
    ``create_pm_object``, which now receives the same ``bpm`` instead of a
    fixed 80, so the tempo of the MIDI object agrees with the note timing.

    Args:
        piano_roll: roll accepted by ``piano_roll_to_note_mat``.
        chd_roll: chord roll in the same format.
        lsh_roll: optional third roll in the same format; skipped when None.
        bpm: tempo in beats per minute.

    Returns:
        The object returned by ``create_pm_object``.
    """
    piano_notes = note_mat_to_notes(piano_roll_to_note_mat(piano_roll), bpm)
    chd_notes = note_mat_to_notes(piano_roll_to_note_mat(chd_roll), bpm)

    lsh_notes = None
    if lsh_roll is not None:
        lsh_notes = note_mat_to_notes(piano_roll_to_note_mat(lsh_roll), bpm)

    return create_pm_object(
        bpm=bpm,
        piano_notes_list=piano_notes,
        chd_notes_list=chd_notes,
        lsh_notes_list=lsh_notes,
    )
def save_midi(pm, filename):
    """Write a MIDI object to ``filename`` by calling its ``write`` method.

    Note that the parameter ``pm`` is the object to save; inside this function
    it shadows the module-level ``pm`` alias.
    """
    pm.write(filename)