File size: 13,828 Bytes
bf8981a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import torch
import numpy as np
import pretty_midi as pm

# Default torch device for this module: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Chord name -> 12-entry pitch-class (chroma) vector, with index 0 standing for C.
# Every entry is the root-position triad template rotated by the root's semitone
# offset; major chords come first, followed by minor chords (lower-case roots).
CHORD_DICTIONARY = {
    **{
        f"{root}:major": np.roll(np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]), semitones)
        for semitones, root in enumerate(
            ("C", "C#", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B")
        )
    },
    **{
        f"{root}:minor": np.roll(np.array([1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]), semitones)
        for semitones, root in enumerate(
            ("c", "c#", "d", "eb", "e", "f", "f#", "g", "g#", "a", "bb", "b")
        )
    },
}


def edit_rhythm(piano_roll_full, num_notes_onset, mask_full, reduce_extra_notes=True):
    '''
    Adjust the onset channel so each time step carries the requested number of onsets.

    Every (batch, step) row of the onset channel is processed independently:
      1. (optional) if more editable cells exceed 0.499 than requested, only the
         largest ones are kept and the others are lowered to 0.499;
      2. if fewer editable cells reach 0.501 than requested, the largest remaining
         editable cells with pitch index below 51 are raised to 0.501.
    A cell is editable only where the mask equals 1; other cells are never modified.

    piano_roll_full: a tensor with shape (batch_size, 2, length, h); channel 0 (onset)
        is overwritten in place, channel 1 is left untouched
    num_notes_onset: a tensor with shape (batch_size, length), target onset count per step
    mask_full: a tensor with the same shape as piano_roll_full; only channel 0 is read
    reduce_extra_notes: True to run step 1 (removal of surplus onsets)

    Returns piano_roll_full itself (the same, mutated, tensor object).
    '''
    # Only the onset channel is edited.
    onset_roll = piano_roll_full[:, 0, :, :]
    mask = mask_full[:, 0, :, :]
    shape = onset_roll.shape
    n_pitch = shape[-1]

    # Nothing to edit; also avoids calling .max() on an empty tensor below.
    if onset_roll.numel() == 0:
        return piano_roll_full

    onset_roll = onset_roll.reshape(-1, n_pitch)
    editable = mask.reshape(-1, n_pitch) == 1
    num_notes = num_notes_onset.reshape(-1)

    reduce_note_threshold = 0.499
    increase_note_threshold = 0.501
    neg_inf = float("-inf")

    # Column vector of row numbers, expanded against the top-k index tensors below.
    row_ids = torch.arange(onset_roll.size(0), device=onset_roll.device).unsqueeze(1)

    # Work on a copy so the selections below always read the unmodified values.
    final_onset_roll = onset_roll.clone()

    ########### if number of notes > required, remove the extra notes ###############
    if reduce_extra_notes:
        active = (onset_roll > reduce_note_threshold) & editable
        # Inactive cells become -inf so top-k never selects them.
        candidates = onset_roll.masked_fill(~active, neg_inf)

        # k is bounded by the row width: torch.topk raises for k > size of the dim.
        k_keep = min(max(int(num_notes.max().item()), 0), n_pitch)
        topk_values, topk_indices = torch.topk(candidates, k_keep, dim=1)

        # Per row, keep only the first num_notes[i] real (non -inf) candidates.
        rank = torch.arange(k_keep, device=onset_roll.device).unsqueeze(0)
        keep = (rank < num_notes.unsqueeze(1)) & (topk_values > neg_inf)

        # Lower every active cell to the threshold, then restore the kept ones.
        final_onset_roll[active] = reduce_note_threshold
        final_onset_roll.index_put_(
            (row_ids.expand_as(topk_indices)[keep], topk_indices[keep]),
            topk_values[keep],
        )

    ########### if number of notes < required, add some notes ###############
    # New notes are only added at pitch indices below 51
    # (presumably MIDI pitch 84 with a lowest pitch of 33 -- TODO confirm).
    pitch_allowed = torch.ones_like(editable)
    pitch_allowed[:, 51:] = False

    # How many notes each row is still missing (negative or zero means none).
    present = (final_onset_roll >= increase_note_threshold) & editable
    remaining_needed = num_notes - present.sum(dim=1)
    k_add = min(int(remaining_needed.max().item()), n_pitch)

    if k_add > 0:  # at least one row needs more notes
        addable = (final_onset_roll < increase_note_threshold) & editable & pitch_allowed
        candidates = final_onset_roll.masked_fill(~addable, neg_inf)
        add_values, add_indices = torch.topk(candidates, k_add, dim=1)

        # Per row, promote only as many candidates as that row is missing.
        rank = torch.arange(k_add, device=onset_roll.device).unsqueeze(0)
        promote = (rank < remaining_needed.unsqueeze(1)) & (add_values > neg_inf)

        # index_put_ requires the value tensor to share the destination dtype.
        final_onset_roll.index_put_(
            (row_ids.expand_as(add_indices)[promote], add_indices[promote]),
            torch.tensor(
                increase_note_threshold,
                dtype=final_onset_roll.dtype,
                device=onset_roll.device,
            ),
        )

    piano_roll_full[:, 0, :, :] = final_onset_roll.reshape(shape)
    return piano_roll_full

def X0EditFunc(x0, background_condition, sampler_device=device, reduce_extra_notes=True, rhythm_control="Yes"):
    """
    Edit a predicted roll x0 so it agrees with the chord (and optionally rhythm) condition.

    The first two channels of background_condition are merged into a chord chroma
    per time step, which is matched against all 24 rotated major/minor chord
    templates. Positive x0 values at pitches outside the scale(s) paired with the
    matching chord(s) are set to 0. If the condition carries rhythm information
    (no negative values) and rhythm_control is "Yes", the onset channel is further
    adjusted by edit_rhythm.

    x0: tensor to edit; must broadcast against a (batch, 2, length, 64) mask and is
        passed to edit_rhythm as its piano roll
    background_condition: tensor whose channels [:2] have a last dimension of 64;
        negative values mark the "null rhythm" encoding (-2 -> 1, -1 -> 0)
    sampler_device: device on which the chord templates are created
    reduce_extra_notes: forwarded to edit_rhythm
    rhythm_control: rhythm editing is applied only when this equals "Yes"

    Returns the edited roll.
    """
    # Precompute every rotation of the major and minor chord templates.
    # Row 0 of each template is the chord chroma, row 1 the scale chroma paired with it.
    maj_chd = torch.tensor([[1.,0,0,0,1,0,0,1,0,0,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
    maj_chd = torch.tile(maj_chd, (1, 64 // maj_chd.size(1) + 1))
    min_chd = torch.tensor([[1.,0,0,0,1,0,0,0,0,1,0,0],[1,0,1,0,1,1,0,1,0,1,0,1]], device=sampler_device)
    min_chd = torch.tile(min_chd, (1, 64 // min_chd.size(1) + 1))

    # all chords, with rotation (tiled past 64 columns, rolled, then cropped to 64)
    maj_chd_rotations = torch.stack([torch.roll(maj_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64]
    min_chd_rotations = torch.stack([torch.roll(min_chd, shifts=-i, dims=1) for i in range(12)], dim=0)[:,:,:64]

    # combine all chords
    # chd_scale_map is a tensor with shape (N, 2, 64), N is total number of chord types,
    # 2 is (chord_chroma, corresponding_scale_chroma), 64 is number of possible notes
    chd_scale_map = torch.cat([maj_chd_rotations, min_chd_rotations], dim=0)

    chord_condition = background_condition[:,:2,:,:]
    # Computed once; it drives both the null-rhythm decoding and the rhythm-edit gate.
    chord_condition_min = chord_condition.min()

    # if using null rhythm condition, have to convert -2 to 1 and -1 to 0
    if chord_condition_min<0:
        correct_chord_condition = -chord_condition-1
    else:
        correct_chord_condition = chord_condition
    merged_chd_roll = torch.max(correct_chord_condition[:,0,:,:], correct_chord_condition[:,1,:,:]) # chd roll of our bg_cond
    chd_chroma_ours = torch.clamp(merged_chd_roll, min=0.0, max=1.0) # chd chroma of our bg_cond
    shape = chd_chroma_ours.shape
    chd_chroma_ours = chd_chroma_ours.reshape(-1,64)

    # A template matches a row when it covers every pitch present in that row.
    matches = (chd_scale_map[:, 0, :].unsqueeze(0) - chd_chroma_ours.unsqueeze(1)>=0).all(dim=-1)
    # Sum of the scale chromas of all matching templates.
    # NOTE(review): when several templates match a row the sum can exceed 1, while
    # edit_rhythm only treats mask cells equal to 1 as editable -- confirm intended.
    seven_notes_chroma_ours = torch.einsum('ij,jk->ik', matches.float(), chd_scale_map[:, 1, :]).reshape(shape)
    seven_notes_chroma_ours = seven_notes_chroma_ours.unsqueeze(1).repeat((1,2,1,1))

    # Rows with no matching chord place no restriction on pitches.
    no_chd_match = torch.all(seven_notes_chroma_ours == 0, dim=-1)
    seven_notes_chroma_ours[no_chd_match] = 1.

    # edit notes based on chroma
    x0 = torch.where((seven_notes_chroma_ours==0)&(x0>0), 0.0 , x0)

    # edit rhythm
    if (chord_condition_min>=0) and (rhythm_control=="Yes"): # only edit if rhythm is provided
        # The per-step maximum of channel 0 is used as the target onset count.
        num_onset_notes, _ = torch.max(background_condition[:,0,:,:], dim=-1)
        x0 = edit_rhythm(x0, num_onset_notes, seven_notes_chroma_ours, reduce_extra_notes)

    return x0

def expand_roll(roll, unit=4, contain_onset=False):
    """
    Stretch a roll along time by repeating every step `unit` times.

    roll: array of shape (Channel, T, H); the input itself is not modified.
    unit: number of output steps produced per input step.
    contain_onset: when True, channels are handled in (even, odd) pairs: inside
        each group of `unit` repeated steps, every step after the first has the
        odd channel replaced by the element-wise maximum of the pair and the
        even channel cleared to 0.

    Returns an array of shape (Channel, T * unit, H).
    """
    n_channel, length, height = roll.shape
    stretched = roll.repeat(unit, axis=1)
    if not contain_onset:
        return stretched

    # View the time axis as (original step, repeat index) to address the repeats.
    frames = stretched.reshape((n_channel, length, unit, height))
    even_tail = frames[::2, :, 1:]
    odd_tail = frames[1::2, :, 1:]

    # Both slices are views into `frames`, so these writes update it in place.
    odd_tail[...] = np.maximum(even_tail, odd_tail)
    even_tail[...] = 0

    return frames.reshape((n_channel, length * unit, height))

def cut_piano_roll(piano_roll, resolution=16, lowest=33, highest=96):
    """
    Crop the third axis of a piano roll to the pitch range [lowest, highest].

    Both bounds are inclusive. `resolution` is accepted but not used.
    """
    pitch_window = slice(lowest, highest + 1)
    return piano_roll[:, :, pitch_window]

def circular_extend(chd_roll, lowest=33, highest=96):
    """
    Place a 12-wide chord chroma roll into a zero roll spanning [lowest, highest].

    chd_roll: array whose last axis has 12 entries, e.g. 6*L*12 -> 6*L*64.
    The chroma is written twice: into the 12 slots starting at MIDI pitch 60 and
    into the 12 slots one octave below. Every other slot stays 0.

    Returns a new float array of shape (chd_roll.shape[0], chd_roll.shape[1],
    highest + 1 - lowest).
    """
    n_rows, n_steps = chd_roll.shape[0], chd_roll.shape[1]
    c4_index = 60 - lowest  # position of MIDI pitch 60 inside the cropped range
    ext_chd = np.zeros((n_rows, n_steps, highest + 1 - lowest))
    for octave_start in (c4_index, c4_index - 12):
        ext_chd[:, :, octave_start:octave_start + 12] = chd_roll
    return ext_chd


def default_quantization(v):
    """Binarize a value: 1 when v is strictly greater than 0.5, otherwise 0."""
    if v > 0.5:
        return 1
    return 0

def extend_piano_roll(piano_roll: np.ndarray, lowest=33, highest=96):
    """
    Zero-pad a cropped piano roll back to the full 128-pitch range.

    piano_roll: array of shape (2, L, highest - lowest + 1), e.g. (2, L, 64).
    The last axis receives `lowest` zeros in front and `127 - highest` zeros at
    the end; the first two axes are not padded.

    Returns the padded array, e.g. (2, L, 128).
    """
    no_padding = (0, 0)
    pitch_padding = (lowest, 127 - highest)
    return np.pad(
        piano_roll,
        (no_padding, no_padding, pitch_padding),
        mode='constant',
        constant_values=0,
    )



def piano_roll_to_note_mat(piano_roll: np.ndarray, quantization_func=None):
    """
    Convert a two-channel piano roll into a list of [onset_step, pitch, duration] notes.

    piano_roll: (2, L, 128), onset and sustain channel.
    quantization_func: maps a single roll value to a truthy/falsy state; falls
        back to default_quantization when None.

    Pitches are scanned from 0 to 127, each one over all time steps. An onset
    opens a new note of duration 1; a sustain without onset lengthens the note
    currently open on that pitch; a step with neither closes it. Sustain cells
    with no open note are ignored.

    Returns the notes ordered by pitch, then by onset step.
    """
    quantize = default_quantization if quantization_func is None else quantization_func
    shape = piano_roll.shape
    assert len(shape) == 3 and shape[0] == 2 and shape[2] == 128, f"{piano_roll.shape}" 

    n_step = shape[1]
    notes = []
    for pitch in range(128):
        note_open = False
        for step in range(n_step):
            onset_on = bool(quantize(piano_roll[0, step, pitch]))
            sustain_on = bool(quantize(piano_roll[1, step, pitch]))

            if onset_on:
                # An onset always starts a fresh note, even while one is open.
                note_open = True
                notes.append([step, pitch, 1])
            elif sustain_on:
                if note_open:
                    # The open note is the most recently appended one.
                    notes[-1][-1] += 1
            else:
                note_open = False

    return notes


def note_mat_to_notes(note_mat, bpm, unit=1/4, shift_beat=0., shift_sec=0., vel=100):
    """
    Turn (onset, pitch, duration) rows measured in steps into pretty_midi notes.

    note_mat: iterable of 3-item rows (onset step, pitch, duration in steps).
    bpm: tempo used to convert beats to seconds.
    unit: length of one step in beats.
    shift_beat: time offset in beats; takes precedence over shift_sec unless None.
    shift_sec: time offset in seconds, used only when shift_beat is None.
    vel: velocity given to every note.

    Returns a list of pm.Note objects in the order of note_mat.
    """
    seconds_per_beat = 60 / bpm
    seconds_per_step = unit * seconds_per_beat

    # Default use shift beat
    if shift_beat is not None:
        shift_sec = shift_beat * seconds_per_beat

    def build_note(onset, pitch, dur):
        start = onset * seconds_per_step + shift_sec
        end = (onset + dur) * seconds_per_step + shift_sec
        return pm.Note(vel, int(pitch), start, end)

    return [build_note(onset, pitch, dur) for onset, pitch, dur in note_mat]


def create_pm_object(bpm, piano_notes_list, chd_notes_list, lsh_notes_list=None):
    """
    Build a PrettyMIDI object holding the piano notes and, optionally, a second track.

    bpm: initial tempo of the MIDI object.
    piano_notes_list: notes of the first 'Acoustic Grand Piano' instrument.
    chd_notes_list: accepted but not used; no chord instrument is added.
    lsh_notes_list: when not None, notes of a second 'Acoustic Grand Piano'
        instrument appended after the first.

    Returns the pm.PrettyMIDI object.
    """
    midi = pm.PrettyMIDI(initial_tempo=bpm)

    def add_piano_track(track_notes):
        # Create one grand-piano instrument, fill it and attach it to `midi`.
        program = pm.instrument_name_to_program('Acoustic Grand Piano')
        track = pm.Instrument(program=program)
        track.notes += track_notes
        midi.instruments.append(track)

    add_piano_track(piano_notes_list)

    if lsh_notes_list is not None:
        add_piano_track(lsh_notes_list)

    return midi

def piano_roll_to_midi(piano_roll: np.ndarray, chd_roll: np.ndarray, lsh_roll=None, bpm=80):
    """
    Convert piano-roll arrays into a PrettyMIDI object.

    piano_roll: (2, L, 128) onset/sustain roll of the piano part.
    chd_roll: (2, L, 128) onset/sustain roll of the chords; it is converted to
        notes and handed to create_pm_object.
    lsh_roll: optional (2, L, 128) roll for an additional track; skipped when None.
    bpm: tempo used both for note timing and as the initial tempo of the MIDI
        object, so that the two always agree.

    Returns the object produced by create_pm_object.
    """
    piano_mat = piano_roll_to_note_mat(piano_roll)
    piano_notes = note_mat_to_notes(piano_mat, bpm)

    chd_mat = piano_roll_to_note_mat(chd_roll)
    chd_notes = note_mat_to_notes(chd_mat, bpm)

    if lsh_roll is not None:
        lsh_mat = piano_roll_to_note_mat(lsh_roll)
        lsh_notes = note_mat_to_notes(lsh_mat, bpm)
    else:
        lsh_notes = None

    # Pass the caller's bpm through instead of a fixed 80, which made the file
    # tempo disagree with the note timings for any non-default bpm.
    piano_pm = create_pm_object(bpm=bpm, piano_notes_list=piano_notes,
                                chd_notes_list=chd_notes, lsh_notes_list=lsh_notes)
    return piano_pm

def save_midi(pm, filename):
    """
    Write a MIDI object to `filename` by calling its `write` method.

    Note: inside this function the parameter `pm` shadows the module-level
    `pm` alias; here it is the object to save, not the pretty_midi module.
    """
    write_midi = pm.write
    write_midi(filename)