asigalov61 committed on
Commit 2dbabd8 · verified · 1 Parent(s): eed841c

Update app.py

Files changed (1)
  1. app.py +28 -292
app.py CHANGED
@@ -30,256 +30,48 @@ def TranscribePianoAudio(input_audio):
     print('=' * 70)
     print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
     start_time = reqtime.time()
-
-    print('Loading model...')
-
-    SEQ_LEN = 8192 # Models seq len
-    PAD_IDX = 19463 # Models pad index
-    DEVICE = 'cuda' # 'cuda'
-
-    # instantiate the model
-
-    model = TransformerWrapper(
-        num_tokens = PAD_IDX+1,
-        max_seq_len = SEQ_LEN,
-        attn_layers = Decoder(dim = 1024, depth = 32, heads = 32, attn_flash = True)
-        )
-
-    model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
-
-    model.to(DEVICE)
-    print('=' * 70)
-
-    print('Loading model checkpoint...')
-
-    model.load_state_dict(
-        torch.load('Giant_Music_Transformer_Large_Trained_Model_36074_steps_0.3067_loss_0.927_acc.pth',
-                   map_location=DEVICE))
-    print('=' * 70)
-
-    model.eval()
-
-    if DEVICE == 'cpu':
-        dtype = torch.bfloat16
-    else:
-        dtype = torch.bfloat16
-
-    ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
-
-    print('Done!')
     print('=' * 70)
 
-    fn = os.path.basename(input_midi.name)
+    f = input_midi.name
+    fn = os.path.basename(f)
     fn1 = fn.split('.')[0]
 
     input_num_of_notes = max(8, min(2048, input_num_of_notes))
 
     print('-' * 70)
     print('Input file name:', fn)
-    print('Req num of notes:', input_num_of_notes)
-    print('Req patch number:', input_patch_number)
+    print('-' * 70)
+    print('Loading audio...')
+
+    # Load audio
+    (audio, _) = load_audio(f, sr=sample_rate, mono=True)
+    print('Done!')
+    print('-' * 70)
+    print('Loading transcriptor..')
+
+    # Transcriptor
+    transcriptor = PianoTranscription(device='cuda') # 'cuda' | 'cpu'
+    print('Done!')
+    print('-' * 70)
+    print('Transcribing...')
+
+    transcribed_dict = transcriptor.transcribe(audio, fn+'.mid')
+    print('Done!')
     print('-' * 70)
 
     #===============================================================================
-    raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
+    raw_score = TMIDIX.midi2single_track_ms_score(fn+'.mid')
 
     #===============================================================================
     # Enhanced score notes
 
-    events_matrix1 = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
-
-    #=======================================================
-    # PRE-PROCESSING
-
-    # checking number of instruments in a composition
-    instruments_list_without_drums = list(set([y[3] for y in events_matrix1 if y[3] != 9]))
-    instruments_list = list(set([y[3] for y in events_matrix1]))
-
-    if len(events_matrix1) > 0 and len(instruments_list_without_drums) > 0:
-
-        #======================================
-
-        events_matrix2 = []
-
-        # Recalculating timings
-        for e in events_matrix1:
-
-            # Original timings
-            e[1] = int(e[1] / 16)
-            e[2] = int(e[2] / 16)
-
-        #===================================
-        # ORIGINAL COMPOSITION
-        #===================================
-
-        # Sorting by patch, pitch, then by start-time
-
-        events_matrix1.sort(key=lambda x: x[6])
-        events_matrix1.sort(key=lambda x: x[4], reverse=True)
-        events_matrix1.sort(key=lambda x: x[1])
-
-        #=======================================================
-        # FINAL PROCESSING
-
-        melody_chords = []
-        melody_chords2 = []
-
-        # Break between compositions / Intro seq
-
-        if 9 in instruments_list:
-            drums_present = 19331 # Yes
-        else:
-            drums_present = 19330 # No
-
-        if events_matrix1[0][3] != 9:
-            pat = events_matrix1[0][6]
-        else:
-            pat = 128
-
-        melody_chords.extend([19461, drums_present, 19332+pat]) # Intro seq
-
-        #=======================================================
-        # MAIN PROCESSING CYCLE
-        #=======================================================
-
-        abs_time = 0
-
-        pbar_time = 0
-
-        pe = events_matrix1[0]
-
-        chords_counter = 1
-
-        comp_chords_len = len(list(set([y[1] for y in events_matrix1])))
-
-        for e in events_matrix1:
-
-            #=======================================================
-            # Timings...
-
-            # Cliping all values...
-            delta_time = max(0, min(255, e[1]-pe[1]))
-
-            # Durations and channels
-
-            dur = max(0, min(255, e[2]))
-            cha = max(0, min(15, e[3]))
-
-            # Patches
-            if cha == 9: # Drums patch will be == 128
-                pat = 128
-
-            else:
-                pat = e[6]
-
-            # Pitches
-
-            ptc = max(1, min(127, e[4]))
-
-            # Velocities
-
-            # Calculating octo-velocity
-            vel = max(8, min(127, e[5]))
-            velocity = round(vel / 15)-1
-
-            #=======================================================
-            # FINAL NOTE SEQ
-
-            # Writing final note asynchronously
-
-            dur_vel = (8 * dur) + velocity
-            pat_ptc = (129 * pat) + ptc
-
-            melody_chords.extend([delta_time, dur_vel+256, pat_ptc+2304])
-            melody_chords2.append([delta_time, dur_vel+256, pat_ptc+2304])
-
-            pe = e
-
+    escore = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
 
     #==================================================================
 
     print('=' * 70)
-    print('Number of tokens:', len(melody_chords))
-    print('Number of notes:', len(melody_chords2))
-    print('Sample output events', melody_chords[:5])
-    print('=' * 70)
-    print('Generating...')
-
-    #@title Pitches/Instruments Inpainting
-
-    #@markdown You can stop the inpainting at any time to render partial results
-
-    #@markdown Inpainting settings
-
-    #@markdown Select MIDI patch present in the composition to inpaint
-
-    inpaint_MIDI_patch = input_patch_number
-
-    #@markdown Generation settings
-
-    number_of_prime_tokens = 90 # @param {type:"slider", min:3, max:8190, step:3}
-    number_of_memory_tokens = 1024 # @param {type:"slider", min:3, max:8190, step:3}
-    number_of_samples_per_inpainted_note = 1 #@param {type:"slider", min:1, max:16, step:1}
-    temperature = 0.85
-
-    print('=' * 70)
-    print('Giant Music Transformer Inpainting Model Generator')
-    print('=' * 70)
-
-    nidx = 0
-
-    for i, m in enumerate(melody_chords):
-
-        cpatch = (melody_chords[i]-2304) // 129
-
-        if 2304 <= melody_chords[i] < 18945 and (cpatch) == inpaint_MIDI_patch:
-            nidx += 1
-
-        if nidx == input_num_of_notes+(number_of_prime_tokens // 3):
-            break
-
-        nidx = i
-
-    out2 = []
-
-    for m in melody_chords[:number_of_prime_tokens]:
-        out2.append(m)
-
-    for i in range(number_of_prime_tokens, len(melody_chords[:nidx])):
-
-        cpatch = (melody_chords[i]-2304) // 129
-
-        if 2304 <= melody_chords[i] < 18945 and (cpatch) == inpaint_MIDI_patch:
-
-            samples = []
-
-            for j in range(number_of_samples_per_inpainted_note):
-
-                inp = torch.LongTensor(out2[-number_of_memory_tokens:]).cuda()
-
-                with ctx:
-                    out1 = model.generate(inp,
-                                          1,
-                                          temperature=temperature,
-                                          return_prime=True,
-                                          verbose=False)
-
-                with torch.no_grad():
-                    test_loss, test_acc = model(out1)
-
-                samples.append([out1.tolist()[0][-1], test_acc.tolist()])
-
-            accs = [y[1] for y in samples]
-            max_acc = max(accs)
-            max_acc_sample = samples[accs.index(max_acc)][0]
-
-            cpitch = (max_acc_sample-2304) % 129
-
-            out2.extend([((cpatch * 129) + cpitch)+2304])
-
-        else:
-            out2.append(melody_chords[i])
-
+    print('Number of transcribed notes:', len(escore))
+    print('Sample trascribed MIDI events', escore[:5])
     print('=' * 70)
     print('Done!')
     print('=' * 70)
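
The hunk above swaps the Giant Music Transformer inpainting pipeline for ByteDance's solo piano transcription model. Below is a minimal standalone sketch of the new flow, assuming the piano_transcription_inference package (which exports the PianoTranscription, sample_rate and load_audio names used above) and a placeholder input path. Note that the committed code still reads the path from input_midi.name, a leftover of the old InpaintPitches signature; the sketch takes a plain path string, which is what the rewired click handler at the bottom of this diff passes in.

    import os
    from piano_transcription_inference import PianoTranscription, sample_rate, load_audio

    audio_path = 'solo_piano.wav'  # placeholder; the app receives this path from Gradio
    fn = os.path.basename(audio_path)
    fn1 = fn.split('.')[0]

    # Load mono audio at the model's expected sample rate (16 kHz)
    (audio, _) = load_audio(audio_path, sr=sample_rate, mono=True)

    # Transcribe: writes the MIDI file and returns the note/pedal events as a dict
    transcriptor = PianoTranscription(device='cuda')  # 'cuda' | 'cpu'
    transcribed_dict = transcriptor.transcribe(audio, fn1 + '.mid')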
@@ -287,71 +79,15 @@ def TranscribePianoAudio(input_audio):
     #===============================================================================
     print('Rendering results...')
 
-    print('=' * 70)
-    print('Sample INTs', out2[:12])
-    print('=' * 70)
-
-    if len(out2) != 0:
-
-        song = out2
-        song_f = []
-
-        time = 0
-        dur = 0
-        vel = 90
-        pitch = 0
-        channel = 0
-
-        patches = [-1] * 16
-
-        channels = [0] * 16
-        channels[9] = 1
-
-        for ss in song:
-
-            if 0 <= ss < 256:
-
-                time += ss * 16
-
-            if 256 <= ss < 2304:
-
-                dur = ((ss-256) // 8) * 16
-                vel = (((ss-256) % 8)+1) * 15
-
-            if 2304 <= ss < 18945:
-
-                patch = (ss-2304) // 129
-
-                if patch < 128:
-
-                    if patch not in patches:
-                        if 0 in channels:
-                            cha = channels.index(0)
-                            channels[cha] = 1
-                        else:
-                            cha = 15
-
-                        patches[cha] = patch
-                        channel = patches.index(patch)
-                    else:
-                        channel = patches.index(patch)
-
-                if patch == 128:
-                    channel = 9
-
-                pitch = (ss-2304) % 129
-
-                song_f.append(['note', time, dur, channel, pitch, vel, patch ])
-
-        patches = [0 if x==-1 else x for x in patches]
+    patches = [0] * 16
 
-    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
-                                                              output_signature = 'Giant Music Transformer',
+    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(escore,
+                                                              output_signature = 'ByteDance Solo Piano Audio to MIDI Transcription',
                                                               output_file_name = fn1,
                                                               track_name='Project Los Angeles',
                                                               list_of_MIDI_patches=patches
                                                               )
-
+    print('=' * 70)
     new_fn = fn1+'.mid'
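
With a solo piano score there are no tokens to decode and no channel allocation to do, so rendering collapses to a single TMIDIX call over the transcribed notes. A sketch of just this step, assuming escore and fn1 as produced in the first hunk and the repo's TMIDIX module on the path:

    import TMIDIX

    patches = [0] * 16  # General MIDI program 0 (Acoustic Grand Piano) on all 16 channels

    # escore holds the enhanced score notes from TMIDIX.advanced_score_processor(...);
    # the converter renders them to fn1 + '.mid' and returns detailed statistics
    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(escore,
                                                              output_signature = 'ByteDance Solo Piano Audio to MIDI Transcription',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches)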
 
@@ -428,7 +164,7 @@ if __name__ == "__main__":
     output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
 
 
-    run_event = run_btn.click(InpaintPitches, [input_midi, input_num_of_notes, input_patch_number],
+    run_event = run_btn.click(TranscribePianoAudio, [input_audio],
                               [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
 
     gr.Examples(
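
For context, a minimal sketch of how the rewired event sits in the Gradio app. Everything outside the hunk is an assumption here: the Blocks context, the input_audio component (declared with type='filepath' so the handler receives a plain path string), and a stub standing in for the real TranscribePianoAudio; the actual app also wires the title, summary, audio and plot outputs shown in the hunk.

    import gradio as gr

    def TranscribePianoAudio(input_audio):
        # Stub for the app's handler, which transcribes the audio
        # and returns the path of the rendered MIDI file
        return input_audio

    with gr.Blocks() as demo:
        input_audio = gr.Audio(type='filepath', label='Input audio')  # assumed definition
        run_btn = gr.Button('Transcribe')
        output_midi = gr.File(label='Output MIDI file', file_types=['.mid'])

        # One input now, instead of the three that InpaintPitches took
        run_event = run_btn.click(TranscribePianoAudio, [input_audio], [output_midi])

    demo.launch()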
 