asigalov61 committed
Commit 65d99ea
Parent: 28c7a7a

Update app.py

Files changed (1)
  1. app.py +120 -125
app.py CHANGED
@@ -23,7 +23,7 @@ in_space = os.getenv("SYSTEM") == "spaces"
 # =================================================================================================
 
 @spaces.GPU
-def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type, input_strip_notes):
+def InpaintPitches(input_midi, input_num_of_notes, input_patch_number):
     print('=' * 70)
     print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
     start_time = reqtime.time()
@@ -31,7 +31,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
     print('Loading model...')
 
     SEQ_LEN = 8192 # Models seq len
-    PAD_IDX = 707 # Models pad index
+    PAD_IDX = 19463 # Models pad index
     DEVICE = 'cuda' # 'cuda'
 
     # instantiate the model
@@ -39,7 +39,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
     model = TransformerWrapper(
         num_tokens = PAD_IDX+1,
         max_seq_len = SEQ_LEN,
-        attn_layers = Decoder(dim = 2048, depth = 4, heads = 16, attn_flash = True)
+        attn_layers = Decoder(dim = 1024, depth = 32, heads = 32, attn_flash = True)
     )
 
     model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
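
The replacement model trades width for depth: dim drops from 2048 to 1024 while depth grows from 4 to 32 layers with 32 attention heads, and the vocabulary grows from 708 to 19464 tokens. A minimal self-contained sketch of the new configuration, assuming the x_transformers imports app.py already uses:

    from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

    SEQ_LEN = 8192   # model sequence length
    PAD_IDX = 19463  # pad token index; vocabulary size is PAD_IDX + 1

    # deeper, narrower decoder than the previous 2048-dim, 4-layer model
    model = TransformerWrapper(
        num_tokens=PAD_IDX + 1,
        max_seq_len=SEQ_LEN,
        attn_layers=Decoder(dim=1024, depth=32, heads=32, attn_flash=True),
    )
    model = AutoregressiveWrapper(model, ignore_index=PAD_IDX)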
@@ -50,7 +50,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
     print('Loading model checkpoint...')
 
     model.load_state_dict(
-        torch.load('Chords_Progressions_Transformer_Small_2048_Trained_Model_12947_steps_0.9316_loss_0.7386_acc.pth',
+        torch.load('Giant_Music_Transformer_Large_Trained_Model_36074_steps_0.3067_loss_0.927_acc.pth',
                    map_location=DEVICE))
     print('=' * 70)
 
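The checkpoint swaps accordingly, from the small Chords Progressions Transformer weights to the large Giant Music Transformer weights. A short loading sketch, assuming the .pth file sits next to app.py and model is the wrapper built above:

    import torch

    DEVICE = 'cuda'

    # map_location keeps the load working on CPU-only machines as well
    model.load_state_dict(
        torch.load('Giant_Music_Transformer_Large_Trained_Model_36074_steps_0.3067_loss_0.927_acc.pth',
                   map_location=DEVICE))
    model.to(DEVICE)
    model.eval()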
@@ -59,7 +59,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
     if DEVICE == 'cpu':
         dtype = torch.bfloat16
     else:
-        dtype = torch.float16
+        dtype = torch.bfloat16
 
     ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
 
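Inference precision changes with it: the CUDA branch now uses bfloat16 rather than float16, matching the CPU branch. bfloat16 keeps float32's exponent range, which makes deep-model inference less prone to overflow than float16 at the same memory cost. A usage sketch of the autocast context, assuming the model and prompt tensor from the surrounding code:

    import torch

    DEVICE = 'cuda'
    dtype = torch.bfloat16  # now used on CUDA as well as on CPU

    ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

    with ctx:
        # forward passes inside the block run in bfloat16
        out = model.generate(x, 10)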
@@ -69,13 +69,12 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
     fn = os.path.basename(input_midi.name)
     fn1 = fn.split('.')[0]
 
-    input_num_tokens = max(4, min(128, input_num_tokens))
+    input_num_of_notes = max(8, min(2048, input_num_of_notes))
 
     print('-' * 70)
     print('Input file name:', fn)
-    print('Req num toks:', input_num_tokens)
-    print('Conditioning type:', input_conditioning_type)
-    print('Strip notes:', input_strip_notes)
+    print('Req num of notes:', input_num_of_notes)
+    print('Req patch number:', input_patch_number)
     print('-' * 70)
 
     #===============================================================================
@@ -84,124 +83,121 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
     #===============================================================================
     # Enhanced score notes
 
-    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
-
-    no_drums_escore_notes = [e for e in escore_notes if e[6] < 80]
-
-    if len(no_drums_escore_notes) > 0:
-
-        #=======================================================
-        # PRE-PROCESSING
-
-        #===============================================================================
-        # Augmented enhanced score notes
-
-        no_drums_escore_notes = TMIDIX.augment_enhanced_score_notes(no_drums_escore_notes)
-
-        cscore = TMIDIX.chordify_score([1000, no_drums_escore_notes])
-
-        clean_cscore = []
-
-        for c in cscore:
-            pitches = []
-            cho = []
-            for cc in c:
-                if cc[4] not in pitches:
-                    cho.append(cc)
-                    pitches.append(cc[4])
-
-            clean_cscore.append(cho)
-
-        #=======================================================
-        # FINAL PROCESSING
-
-        melody_chords = []
-        chords = []
-        times = [0]
-        durs = []
-
-        #=======================================================
-        # MAIN PROCESSING CYCLE
-        #=======================================================
-
-        pe = clean_cscore[0][0]
-
-        first_chord = True
-
-        for c in clean_cscore:
-
-            # Chords
-
-            c.sort(key=lambda x: x[4], reverse=True)
-
-            tones_chord = sorted(set([cc[4] % 12 for cc in c]))
-
-            try:
-                chord_token = TMIDIX.ALL_CHORDS_SORTED.index(tones_chord)
-            except:
-                checked_tones_chord = TMIDIX.check_and_fix_tones_chord(tones_chord)
-                chord_token = TMIDIX.ALL_CHORDS_SORTED.index(checked_tones_chord)
-
-            melody_chords.extend([chord_token+384])
-
-            if input_strip_notes:
-                if len(tones_chord) > 1:
-                    chords.extend([chord_token+384])
-
-            else:
-                chords.extend([chord_token+384])
-
-            if first_chord:
-                melody_chords.extend([0])
-                first_chord = False
-
-            for e in c:
-
-                #=======================================================
-                # Timings...
-
-                time = e[1]-pe[1]
-
-                dur = e[2]
-
-                if time != 0 and time % 2 != 0:
-                    time += 1
-                if dur % 2 != 0:
-                    dur += 1
-
-                delta_time = int(max(0, min(255, time)) / 2)
-
-                # Durations
-
-                dur = int(max(0, min(255, dur)) / 2)
-
-                # Pitches
-
-                ptc = max(1, min(127, e[4]))
-
-                #=======================================================
-                # FINAL NOTE SEQ
-
-                # Writing final note asynchronously
-
-                if delta_time != 0:
-                    melody_chords.extend([delta_time, dur+128, ptc+256])
-                    if input_strip_notes:
-                        if len(c) > 1:
-                            times.append(delta_time)
-                            durs.append(dur+128)
-                    else:
-                        times.append(delta_time)
-                        durs.append(dur+128)
-                else:
-                    melody_chords.extend([dur+128, ptc+256])
-
-                pe = e
+    events_matrix1 = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
+
+    #=======================================================
+    # PRE-PROCESSING
+
+    # checking number of instruments in a composition
+    instruments_list_without_drums = list(set([y[3] for y in events_matrix1 if y[3] != 9]))
+    instruments_list = list(set([y[3] for y in events_matrix1]))
+
+    if len(events_matrix1) > 0 and len(instruments_list_without_drums) > 0:
+
+        #======================================
+
+        events_matrix2 = []
+
+        # Recalculating timings
+        for e in events_matrix1:
+
+            # Original timings
+            e[1] = int(e[1] / 16)
+            e[2] = int(e[2] / 16)
+
+        #===================================
+        # ORIGINAL COMPOSITION
+        #===================================
+
+        # Sorting by patch, pitch, then by start-time
+
+        events_matrix1.sort(key=lambda x: x[6])
+        events_matrix1.sort(key=lambda x: x[4], reverse=True)
+        events_matrix1.sort(key=lambda x: x[1])
+
+        #=======================================================
+        # FINAL PROCESSING
+
+        melody_chords = []
+        melody_chords2 = []
+
+        # Break between compositions / Intro seq
+
+        if 9 in instruments_list:
+            drums_present = 19331 # Yes
+        else:
+            drums_present = 19330 # No
+
+        if events_matrix1[0][3] != 9:
+            pat = events_matrix1[0][6]
+        else:
+            pat = 128
+
+        melody_chords.extend([19461, drums_present, 19332+pat]) # Intro seq
+
+        #=======================================================
+        # MAIN PROCESSING CYCLE
+        #=======================================================
+
+        abs_time = 0
+
+        pbar_time = 0
+
+        pe = events_matrix1[0]
+
+        chords_counter = 1
+
+        comp_chords_len = len(list(set([y[1] for y in events_matrix1])))
+
+        for e in events_matrix1:
+
+            #=======================================================
+            # Timings...
+
+            # Cliping all values...
+            delta_time = max(0, min(255, e[1]-pe[1]))
+
+            # Durations and channels
+
+            dur = max(0, min(255, e[2]))
+            cha = max(0, min(15, e[3]))
+
+            # Patches
+            if cha == 9: # Drums patch will be == 128
+                pat = 128
+
+            else:
+                pat = e[6]
+
+            # Pitches
+
+            ptc = max(1, min(127, e[4]))
+
+            # Velocities
+
+            # Calculating octo-velocity
+            vel = max(8, min(127, e[5]))
+            velocity = round(vel / 15)-1
+
+            #=======================================================
+            # FINAL NOTE SEQ
+
+            # Writing final note asynchronously
+
+            dur_vel = (8 * dur) + velocity
+            pat_ptc = (129 * pat) + ptc
+
+            melody_chords.extend([delta_time, dur_vel+256, pat_ptc+2304])
+            melody_chords2.append([delta_time, dur_vel+256, pat_ptc+2304])
+
+            pe = e
 
     #==================================================================
 
     print('=' * 70)
-
+    print('Number of tokens:', len(melody_chords))
+    print('Number of notes:', len(melody_chords2))
     print('Sample output events', melody_chords[:5])
     print('=' * 70)
     print('Generating...')
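
The new pre-processing replaces the chord-based encoding (chord tokens at chord_token+384, with time, duration, and pitch tokens below 384) with the Giant Music Transformer note encoding: each note becomes the triplet [delta_time, dur_vel+256, pat_ptc+2304], where dur_vel = (8 * dur) + velocity and pat_ptc = (129 * pat) + ptc. A hypothetical helper, not part of the commit, that inverts this packing and makes the value ranges explicit:

    # Hypothetical decoder for the note triplets built above.
    # Token layout: delta_time 0..255, dur_vel tokens 256..2303, pat_ptc tokens 2304..18943.
    def decode_note(delta_time_tok, dur_vel_tok, pat_ptc_tok):
        delta_time = delta_time_tok                   # start-time delta in 16 ms steps
        dur, velocity = divmod(dur_vel_tok - 256, 8)  # dur 0..255, velocity 0..7
        pat, ptc = divmod(pat_ptc_tok - 2304, 129)    # patch 0..128 (128 = drums), pitch 1..127
        return delta_time, dur, velocity, pat, ptc

    # Example: delta 12, duration 40, velocity 5, patch 0 (piano), pitch 60 (middle C)
    assert decode_note(12, 256 + (8 * 40) + 5, 2304 + (129 * 0) + 60) == (12, 40, 5, 0, 60)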
@@ -226,7 +222,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
     if input_conditioning_type == 'Chords-Times-Durations':
         output.append(durs[idx])
 
-    x = torch.tensor([output] * 1, dtype=torch.long, device='cuda')
+    x = torch.tensor([output] * 1, dtype=torch.long, device=DEVICE)
 
     o = 0
 
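The prompt tensor is now created on DEVICE instead of the hard-coded 'cuda' string, so the handler keeps working if DEVICE is ever set to 'cpu'. Since [output] * 1 just wraps the token list into a batch of one, an equivalent form would be:

    x = torch.tensor([output], dtype=torch.long, device=DEVICE)  # shape (1, len(output))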
@@ -376,9 +372,8 @@ if __name__ == "__main__":
         gr.Markdown("## Upload your MIDI or select a sample example MIDI")
 
         input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
-        input_num_tokens = gr.Slider(4, 128, value=32, step=1, label="Number of composition chords to generate progression for")
-        input_conditioning_type = gr.Radio(["Chords", "Chords-Times", "Chords-Times-Durations"], label="Conditioning type")
-        input_strip_notes = gr.Checkbox(label="Strip notes from the composition")
+        input_num_of_notes = gr.Slider(8, 2048, value=128, step=8, label="Number of composition notes to inpaint")
+        input_patch_number = gr.Slider(0, 127, value=0, step=1, label="Composition MIDI patch to inpaint")
 
         run_btn = gr.Button("generate", variant="primary")
 
@@ -391,7 +386,7 @@ if __name__ == "__main__":
         output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
 
 
-        run_event = run_btn.click(GenerateAccompaniment, [input_midi, input_num_tokens, input_conditioning_type, input_strip_notes],
+        run_event = run_btn.click(InpaintPitches, [input_midi, input_num_of_notes, input_patch_number],
                                   [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
 
         gr.Examples(
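
The click wiring passes the two new controls to InpaintPitches. A compact sketch of the updated Gradio layout, with a hypothetical stub standing in for the real handler:

    import gradio as gr

    def InpaintPitches(input_midi, input_num_of_notes, input_patch_number):
        # stub for illustration; the real handler returns
        # (title, summary, MIDI file, audio, plot)
        ...

    with gr.Blocks() as demo:
        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
        input_num_of_notes = gr.Slider(8, 2048, value=128, step=8,
                                       label="Number of composition notes to inpaint")
        input_patch_number = gr.Slider(0, 127, value=0, step=1,
                                       label="Composition MIDI patch to inpaint")
        run_btn = gr.Button("generate", variant="primary")
        # run_btn.click(InpaintPitches, inputs=[...], outputs=[...]) as in the diff above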