asigalov61 commited on
Commit
403a65b
1 Parent(s): efe9672

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -105
app.py CHANGED
@@ -33,8 +33,7 @@ import matplotlib.pyplot as plt
33
  @spaces.GPU
34
  def Harmonize_Melody(input_src_midi,
35
  source_melody_transpose_value,
36
- harmonizer_melody_chunk_size,
37
- harmonizer_max_matches_count,
38
  melody_MIDI_patch_number,
39
  harmonized_accompaniment_MIDI_patch_number,
40
  base_MIDI_patch_number
@@ -48,13 +47,13 @@ def Harmonize_Melody(input_src_midi,
48
 
49
  sfn = os.path.basename(input_src_midi.name)
50
  sfn1 = sfn.split('.')[0]
 
51
  print('Input src MIDI name:', sfn)
52
 
53
  print('=' * 70)
54
  print('Requested settings:')
55
  print('Source melody transpose value:', source_melody_transpose_value)
56
- print('Harmonizer melody chunk size:', harmonizer_melody_chunk_size)
57
- print('Harmonizer max matrches count:', harmonizer_max_matches_count)
58
  print('Melody MIDI patch number:', melody_MIDI_patch_number)
59
  print('Harmonized accompaniment MIDI patch number:', harmonized_accompaniment_MIDI_patch_number)
60
  print('Base MIDI patch number:', base_MIDI_patch_number)
@@ -95,32 +94,6 @@ def Harmonize_Melody(input_src_midi,
95
  print('Melody has', len(mel_pitches), 'notes')
96
  print('=' * 70)
97
 
98
- #==================================================================
99
-
100
- print('=' * 70)
101
- print('Creating chords dict...')
102
- print('=' * 70)
103
-
104
- chords_groups = []
105
-
106
- for i in range(12):
107
- grp = []
108
- for c in TMIDIX.ALL_CHORDS_FILTERED:
109
- if i in c:
110
- grp.append(c)
111
-
112
- if grp:
113
- chords_groups.append(grp)
114
-
115
- max_grp_len = len(max(chords_groups, key=len))
116
-
117
- chords_groups_padded = []
118
-
119
- for c in chords_groups:
120
- grp = c + [[-1]] * (max_grp_len-len(c))
121
-
122
- chords_groups_padded.extend(grp)
123
-
124
  #===============================================================================
125
 
126
  print('=' * 70)
@@ -129,87 +102,80 @@ def Harmonize_Melody(input_src_midi,
129
 
130
  print('Loading Melody Harmonizer Transformer Model...')
131
 
 
 
132
 
133
-
134
- print('=' * 70)
135
- print('Harmonizing...')
136
- print('=' * 70)
137
-
138
- #===============================================================================
139
 
140
- song = []
 
 
 
 
141
 
142
- csize = harmonizer_melody_chunk_size
143
- matches_mem_size = harmonizer_max_matches_count
144
 
145
- i = 0
146
- dev = 0
147
- dchunk = []
148
 
149
- #===============================================================================
150
-
151
- def find_best_match(matches):
152
 
153
- mlens = []
154
 
155
- for sidx in matches:
156
- mlen = len(TMIDIX.flatten(long_chords_chunks_mult[sidx[0]][sidx[1]:sidx[1]+(csize // 2)]))
157
- mlens.append(mlen)
158
 
159
- max_len = max(mlens)
160
- max_len_idx = mlens.index(max_len)
161
 
162
- return matches[max_len_idx]
 
 
 
 
 
 
163
 
164
  #===============================================================================
165
 
166
- while i < len(mel_pitches):
167
 
168
- matches = []
169
 
170
- for midx, mel in enumerate(long_mels_chunks_mult):
171
- if len(mel) >= csize:
172
- schunk = mel_pitches[i:i+csize]
173
- idx = HaystackSearch.HaystackSearch(schunk, mel)
174
 
175
- if idx != -1:
176
- matches.append([midx, idx])
177
- if matches_mem_size > -1:
178
- if len(matches) > matches_mem_size:
179
- break
180
 
181
- if matches:
182
 
183
- sidx = find_best_match(matches)
184
 
185
- fchunk = long_chords_chunks_mult[sidx[0]][sidx[1]:sidx[1]+csize]
186
 
187
- song.extend(fchunk[:(csize // 2)])
188
- i += (csize // 2)
189
- dchunk = fchunk
190
- dev = 0
191
- print('step', i)
192
-
193
- else:
194
 
195
- if dchunk:
 
 
 
 
 
 
 
196
 
197
- song.append(dchunk[(csize // 2)+dev])
198
- dev += 1
199
- i += 1
200
- print('dead chord', i, dev)
201
- else:
202
- print('DEAD END!!!')
203
- song.append([mel_pitches[0]+48])
204
- break
205
 
 
206
 
207
- if dev == csize // 2:
208
- print('DEAD END!!!')
209
- break
210
 
211
- song = song[:len(mel_pitches)]
 
 
212
 
 
 
 
 
213
  print('Harmonized', len(song), 'out of', len(mel_pitches), 'notes')
214
 
215
  print('Done!')
@@ -233,22 +199,28 @@ def Harmonize_Melody(input_src_midi,
233
  patches[3] = melody_MIDI_patch_number
234
 
235
  for i, s in enumerate(song):
236
-
237
- time = mel_score[i][1] * 16
238
- dur = mel_score[i][2] * 16
239
 
240
- output_score.append(['note', time, dur, 3, mel_score[i][4], 115+(mel_score[i][4] % 12), 40])
241
 
242
- for p in s:
243
- output_score.append(['note', time, dur, 0, p, max(40, p), harmonized_accompaniment_MIDI_patch_number])
244
-
245
- if base_MIDI_patch_number > -1:
246
- output_score.append(['note', time, dur, 2, (s[-1] % 12)+24, 120-(s[-1] % 12), base_MIDI_patch_number])
 
 
 
 
 
 
 
 
 
247
 
248
- fn1 = "Monophonic-MIDI-Melody-Harmonizer-Composition"
249
 
250
  detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
251
- output_signature = 'Monophonic MIDI Melody Harmonizer',
252
  output_file_name = fn1,
253
  track_name='Project Los Angeles',
254
  list_of_MIDI_patches=patches
@@ -346,8 +318,7 @@ if __name__ == "__main__":
346
  gr.Markdown("## Select harmonization options")
347
 
348
  source_melody_transpose_value = gr.Slider(-6, 6, value=0, step=1, label="Source melody transpose value", info="You can transpose source melody by specified number of semitones if the original melody key does not harmonize well")
349
- harmonizer_melody_chunk_size = gr.Slider(4, 16, value=8, step=2, label="Hamonizer melody chunk size", info="Larger chunk sizes result in better harmonization at the cost of speed and harminzation length")
350
- harmonizer_max_matches_count = gr.Slider(-1, 20, value=0, step=1, label="Harmonizer max matches count", info="Maximum number of harmonized chords per melody note to collect and to select from")
351
 
352
  melody_MIDI_patch_number = gr.Slider(0, 127, value=40, step=1, label="Source melody MIDI patch number")
353
  harmonized_accompaniment_MIDI_patch_number = gr.Slider(0, 127, value=0, step=1, label="Harmonized accompaniment MIDI patch number")
@@ -370,8 +341,7 @@ if __name__ == "__main__":
370
  run_event = run_btn.click(Harmonize_Melody,
371
  [input_src_midi,
372
  source_melody_transpose_value,
373
- harmonizer_melody_chunk_size,
374
- harmonizer_max_matches_count,
375
  melody_MIDI_patch_number,
376
  harmonized_accompaniment_MIDI_patch_number,
377
  base_MIDI_patch_number],
@@ -380,18 +350,17 @@ if __name__ == "__main__":
380
 
381
  gr.Examples(
382
  [
383
- ["USSR Anthem Seed Melody.mid", 0, 12, -1, 40, 0, 35],
384
  ],
385
  [input_src_midi,
386
  source_melody_transpose_value,
387
- harmonizer_melody_chunk_size,
388
- harmonizer_max_matches_count,
389
  melody_MIDI_patch_number,
390
  harmonized_accompaniment_MIDI_patch_number,
391
  base_MIDI_patch_number],
392
  [output_audio, output_plot, output_midi, output_summary],
393
  Harmonize_Melody,
394
- cache_examples=False,
395
  )
396
 
397
  app.queue().launch()
 
33
  @spaces.GPU
34
  def Harmonize_Melody(input_src_midi,
35
  source_melody_transpose_value,
36
+ model_top_k_sampling_value,
 
37
  melody_MIDI_patch_number,
38
  harmonized_accompaniment_MIDI_patch_number,
39
  base_MIDI_patch_number
 
47
 
48
  sfn = os.path.basename(input_src_midi.name)
49
  sfn1 = sfn.split('.')[0]
50
+
51
  print('Input src MIDI name:', sfn)
52
 
53
  print('=' * 70)
54
  print('Requested settings:')
55
  print('Source melody transpose value:', source_melody_transpose_value)
56
+ print('Model top_k sampling value:', model_top_k_sampling_value)
 
57
  print('Melody MIDI patch number:', melody_MIDI_patch_number)
58
  print('Harmonized accompaniment MIDI patch number:', harmonized_accompaniment_MIDI_patch_number)
59
  print('Base MIDI patch number:', base_MIDI_patch_number)
 
94
  print('Melody has', len(mel_pitches), 'notes')
95
  print('=' * 70)
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  #===============================================================================
98
 
99
  print('=' * 70)
 
102
 
103
  print('Loading Melody Harmonizer Transformer Model...')
104
 
105
+ SEQ_LEN = 75
106
+ PAD_IDX = 144
107
 
108
+ # instantiate the model
 
 
 
 
 
109
 
110
+ model = TransformerWrapper(
111
+ num_tokens = PAD_IDX+1,
112
+ max_seq_len = SEQ_LEN,
113
+ attn_layers = Decoder(dim = 1024, depth = 12, heads = 16, attn_flash = True)
114
+ )
115
 
116
+ model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)
 
117
 
118
+ model_path = 'Melody_Harmonizer_Transformer_Trained_Model_7522_steps_0.6545_loss_0.7906_acc.pth'
 
 
119
 
120
+ model.load_state_dict(torch.load(model_path))
 
 
121
 
122
+ model.cuda()
123
 
124
+ dtype = torch.bfloat16
 
 
125
 
126
+ ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
 
127
 
128
+ model.eval()
129
+
130
+ print('Done!')
131
+
132
+ print('=' * 70)
133
+ print('Harmonizing...')
134
+ print('=' * 70)
135
 
136
  #===============================================================================
137
 
138
+ mel_remainder_value = (((len(mel_pitches) // 24)+1) * 24) - len(mel_pitches)
139
 
140
+ mel_pitches_ext = mel_pitches + mel_pitches[:mel_remainder_value]
141
 
142
+ song = []
 
 
 
143
 
144
+ for i in range(0, len(mel_pitches_ext)-12, 12):
 
 
 
 
145
 
146
+ mel_chunk = mel_pitches_ext[i:i+24]
147
 
148
+ data = [141] + mel_chunk + [142]
149
 
150
+ for j in range(24):
151
 
152
+ data.append(mel_chunk[j])
153
+
154
+ x = torch.tensor([data], dtype=torch.long, device='cuda')
 
 
 
 
155
 
156
+ with ctx:
157
+ out = model.generate(x,
158
+ 1,
159
+ filter_logits_fn=top_k,
160
+ filter_kwargs={'k': model_top_k_sampling_value},
161
+ temperature=0.9,
162
+ return_prime=False,
163
+ verbose=False)
164
 
165
+ outy = out.tolist()[0]
 
 
 
 
 
 
 
166
 
167
+ data.append(outy[0])
168
 
169
+ if i != len(mel_pitches_ext)-24:
 
 
170
 
171
+ song.extend(data[26:50])
172
+ else:
173
+ song.extend(data[26:])
174
 
175
+ song = song[:len(mel_pitches) * 2]
176
+
177
+ #===============================================================================
178
+
179
  print('Harmonized', len(song), 'out of', len(mel_pitches), 'notes')
180
 
181
  print('Done!')
 
199
  patches[3] = melody_MIDI_patch_number
200
 
201
  for i, s in enumerate(song):
 
 
 
202
 
203
+ if 11 < s < 141:
204
 
205
+ time = mel_score[i][1] * 16
206
+ dur = mel_score[i][2] * 16
207
+
208
+ output_score.append(['note', time, dur, 3, mel_score[i][4], 115+(mel_score[i][4] % 12), 40])
209
+
210
+ chord = TMIDIX.ALL_CHORDS_FILTERED[s-12]
211
+
212
+ for c in chord:
213
+
214
+ pitch = 48+c
215
+ output_score.append(['note', time, dur, 0, pitch, max(40, pitch), harmonized_accompaniment_MIDI_patch_number])
216
+
217
+ if base_MIDI_patch_number > -1:
218
+ output_score.append(['note', time, dur, 2, chord[-1]+24, 120-chord[-1], base_MIDI_patch_number])
219
 
220
+ fn1 = "Melody-Harmonizer-Transformer-Composition"
221
 
222
  detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
223
+ output_signature = 'Melody Harmonizer Transformer',
224
  output_file_name = fn1,
225
  track_name='Project Los Angeles',
226
  list_of_MIDI_patches=patches
 
318
  gr.Markdown("## Select harmonization options")
319
 
320
  source_melody_transpose_value = gr.Slider(-6, 6, value=0, step=1, label="Source melody transpose value", info="You can transpose source melody by specified number of semitones if the original melody key does not harmonize well")
321
+ model_top_k_sampling_value = gr.Slider(1, 50, value=15, step=1, label="Model sampling top_k value", info="Decreasing this value may produce better harmonization results in some cases")
 
322
 
323
  melody_MIDI_patch_number = gr.Slider(0, 127, value=40, step=1, label="Source melody MIDI patch number")
324
  harmonized_accompaniment_MIDI_patch_number = gr.Slider(0, 127, value=0, step=1, label="Harmonized accompaniment MIDI patch number")
 
341
  run_event = run_btn.click(Harmonize_Melody,
342
  [input_src_midi,
343
  source_melody_transpose_value,
344
+ model_top_k_sampling_value,
 
345
  melody_MIDI_patch_number,
346
  harmonized_accompaniment_MIDI_patch_number,
347
  base_MIDI_patch_number],
 
350
 
351
  gr.Examples(
352
  [
353
+ ["USSR Anthem Seed Melody.mid", 0, 15, 40, 0, 35],
354
  ],
355
  [input_src_midi,
356
  source_melody_transpose_value,
357
+ model_top_k_sampling_value,
 
358
  melody_MIDI_patch_number,
359
  harmonized_accompaniment_MIDI_patch_number,
360
  base_MIDI_patch_number],
361
  [output_audio, output_plot, output_midi, output_summary],
362
  Harmonize_Melody,
363
+ cache_examples=True,
364
  )
365
 
366
  app.queue().launch()