asigalov61 committed
Commit df2f0e4 · 1 parent: 4d766d5

Update app.py

Files changed (1):
  app.py +155 -111
app.py CHANGED
@@ -25,82 +25,8 @@ in_space = os.getenv("SYSTEM") == "spaces"
 
# =================================================================================================
 
@spaces.GPU
-def classify_GPU(input_data):
-
-    print('Loading model...')
-
-    SEQ_LEN = 1026
-    PAD_IDX = 940
-    DEVICE = 'cuda' # 'cuda'
-
-    # instantiate the model
-
-    model = TransformerWrapper(
-        num_tokens = PAD_IDX+1,
-        max_seq_len = SEQ_LEN,
-        attn_layers = Decoder(dim = 1024, depth = 24, heads = 32, attn_flash = True)
-        )
-
-    model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
-
-    model = torch.nn.DataParallel(model)
-
-    model.to(DEVICE)
-
-    print('=' * 70)
-
-    print('Loading model checkpoint...')
-
-    model.load_state_dict(
-        torch.load('Ultimate_MIDI_Classifier_Trained_Model_29886_steps_0.556_loss_0.8339_acc.pth',
-                   map_location=DEVICE))
-    print('=' * 70)
-
-    model.eval()
-
-    if DEVICE == 'cpu':
-        dtype = torch.bfloat16
-    else:
-        dtype = torch.bfloat16
-
-    ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
-
-    print('Done!')
-    print('=' * 70)
-
-    #==================================================================
-
-    print('=' * 70)
-    print('Ultimate MIDI Classifier')
-    print('=' * 70)
-    print('Classifying...')
-
-    torch.cuda.empty_cache()
-
-    model.eval()
-
-    x = torch.tensor(input_data[:1022], dtype=torch.long, device=DEVICE)
-
-    with ctx:
-        out = model.module.generate(x,
-                                    2,
-                                    filter_logits_fn=top_k,
-                                    filter_kwargs={'k': 1},
-                                    temperature=0.9,
-                                    return_prime=False,
-                                    verbose=False)
-
-    result = tuple(out[0].tolist())
-
-    return result
-
-# =================================================================================================
-
def ClassifyMIDI(input_midi):
 
-    SEQ_LEN = 1024
-    PAD_IDX = 14627
-
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
@@ -122,60 +48,108 @@ def ClassifyMIDI(input_midi):
 
    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
 
-    escore = [e for e in TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32) if e[6] < 80]
-
-    cscore = TMIDIX.chordify_score([1000, escore])
-
+    #===============================================================================
+    # Augmented enhanced score notes
+
+    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
+
+    escore_notes = [e for e in escore_notes if e[6] < 80 or e[6] == 128]
+
    #=======================================================
-    # MAIN PROCESSING CYCLE
+    # Augmentation
+
    #=======================================================
-
+    # FINAL PROCESSING
+
    melody_chords = []
-
-    pe = cscore[0][0]
-
-    for c in cscore:
-
-      pitches = []
-
-      for e in c:
-
-        if e[4] not in pitches:
-
-            dtime = max(0, min(127, e[1]-pe[1]))
-
-            dur = max(1, min(127, e[2]))
+
+    #=======================================================
+    # MAIN PROCESSING CYCLE
+    #=======================================================
+
+    pe = escore_notes[0]
+
+    pitches = []
+
+    notes_counter = 0
+
+    for e in escore_notes:
+
+        #=======================================================
+        # Timings...
+
+        delta_time = max(0, min(127, e[1]-pe[1]))
+
+        if delta_time != 0:
+            pitches = []
+
+        # Durations and channels
+
+        dur = max(1, min(127, e[2]))
+
+        # Patches
+        pat = max(0, min(128, e[6]))
+
+        # Pitches
+
+        if pat == 128:
+            ptc = max(1, min(127, e[4]))+128
+        else:
            ptc = max(1, min(127, e[4]))
-
-            melody_chords.append([dtime, dur, ptc])
-
-            pitches.append(ptc)
-
-      pe = e
-
-    #==============================================================
-
-    seq = []
+
+        #=======================================================
+        # FINAL NOTE SEQ
+
+        # Writing final note synchronously
+
+        if ptc not in pitches:
+            melody_chords.extend([delta_time, dur+128, ptc+256])
+            pitches.append(ptc)
+            notes_counter += 1
+
+        pe = e
+
+    #==============================================================
+
+    print('Done!')
+    print('=' * 70)
+    print('Composition has', notes_counter, 'notes')
+    print('=' * 70)
+
+    print('=' * 70)
+    print('Ultimate MIDI Classifier')
+    print('=' * 70)
+
+    print('Input MIDI file name:', midi_name)
+    print('=' * 70)
+    print('Sampling score...')
+
+    chunk_size = 1020
+
+    score = melody_chords
+
    input_data = []
 
-    notes_counter = 0
-
-    for mm in melody_chords:
-
-      time = mm[0]
-      dur = mm[1]
-      ptc = mm[2]
-
-      seq.extend([time, dur+128, ptc+256])
-      notes_counter += 1
-
-    for i in range(0, len(seq)-SEQ_LEN-4, (SEQ_LEN-4) // 4):
-      schunk = seq[i:i+SEQ_LEN-4]
-      input_data.append([14624] + schunk + [14625])
-
+    for i in range(0, len(score)-chunk_size, chunk_size // classification_sampling_resolution):
+        schunk = score[i:i+chunk_size]
+
+        if len(schunk) == chunk_size:
+
+            td = [937]
+
+            td.extend(schunk)
+
+            td.extend([938])
+
+            input_data.append(td)
+
    print('Done!')
    print('=' * 70)
-
+    print('Composition was split into' , len(input_data), 'samples', 'of 340 notes each with', 340 - chunk_size // classification_sampling_resolution // 3, 'notes overlap')
+    print('=' * 70)
+    print('Number of notes in all composition samples:', len(input_data) * 340)
+    print('=' * 70)
+
    #==============================================================
 
    classification_summary_string = '=' * 70
@@ -194,7 +168,77 @@ def ClassifyMIDI(input_midi):
    classification_summary_string += '=' * 70
    classification_summary_string += '\n'
 
-    output, results = classify_GPU(input_data)
+    print('Loading model...')
+
+    SEQ_LEN = 1026
+    PAD_IDX = 940
+    DEVICE = 'cuda' # 'cuda'
+
+    # instantiate the model
+
+    model = TransformerWrapper(
+        num_tokens = PAD_IDX+1,
+        max_seq_len = SEQ_LEN,
+        attn_layers = Decoder(dim = 1024, depth = 24, heads = 32, attn_flash = True)
+        )
+
+    model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
+
+    model = torch.nn.DataParallel(model)
+
+    model.to(DEVICE)
+
+    print('=' * 70)
+
+    print('Loading model checkpoint...')
+
+    model.load_state_dict(
+        torch.load('Ultimate_MIDI_Classifier_Trained_Model_29886_steps_0.556_loss_0.8339_acc.pth',
+                   map_location=DEVICE))
+    print('=' * 70)
+
+    if DEVICE == 'cpu':
+        dtype = torch.bfloat16
+    else:
+        dtype = torch.bfloat16
+
+    ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
+
+    print('Done!')
+    print('=' * 70)
+
+    #==================================================================
+
+    print('=' * 70)
+    print('Ultimate MIDI Classifier')
+    print('=' * 70)
+    print('Classifying...')
+
+    torch.cuda.empty_cache()
+
+    model.eval()
+
+    artist_results = []
+    song_results = []
+
+    results = []
+
+    for input in input_data:
+
+        x = torch.tensor(input[:1022], dtype=torch.long, device='cuda')
+
+        with ctx:
+            out = model.module.generate(x,
+                                        2,
+                                        filter_logits_fn=top_k,
+                                        filter_kwargs={'k': 1},
+                                        temperature=0.9,
+                                        return_prime=False,
+                                        verbose=False)
+
+        result = tuple(out[0].tolist())
+
+        results.append(result)
 
    all_results_labels = [classifier_labels[0][r-384] for r in results]
    final_result = mode(results)
 
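With classify_GPU() gone, the preprocessing now lives inline in ClassifyMIDI() and encodes each note as three tokens: a delta-time in [0, 127], a duration shifted by +128, and a pitch shifted by +256, with drum notes (patch 128) pushed up a further +128 so they occupy their own pitch range. A minimal, self-contained sketch of that encoding (not part of the commit), assuming TMIDIX enhanced-score events where e[1] is start time, e[2] duration, e[4] pitch and e[6] patch:

def encode_notes(escore_notes):

    # One [delta_time, dur+128, ptc+256] triplet per unique pitch per time step
    tokens = []
    pitches = []            # pitches already written at the current time step
    pe = escore_notes[0]    # previous event, for delta-time computation

    for e in escore_notes:

        delta_time = max(0, min(127, e[1] - pe[1]))

        if delta_time != 0:
            pitches = []    # new time step -> reset the de-duplication list

        dur = max(1, min(127, e[2]))
        pat = max(0, min(128, e[6]))

        # Drums (patch 128) get an extra +128 pitch offset
        ptc = max(1, min(127, e[4])) + (128 if pat == 128 else 0)

        if ptc not in pitches:
            tokens.extend([delta_time, dur + 128, ptc + 256])
            pitches.append(ptc)

        pe = e

    return tokens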
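The sampling stage then slides a 1020-token window over the encoded score; every full window holds 340 notes (3 tokens each) and is framed by the start/end tokens 937 and 938, which yields the 1022-token model inputs. A sketch of the arithmetic, with the global classification_sampling_resolution (defined elsewhere in app.py) assumed to be 4 for illustration:

classification_sampling_resolution = 4                     # assumed value
chunk_size = 1020                                          # 340 notes * 3 tokens
step = chunk_size // classification_sampling_resolution    # 255 tokens

def sample_score(score):

    input_data = []

    for i in range(0, len(score) - chunk_size, step):
        schunk = score[i:i + chunk_size]
        if len(schunk) == chunk_size:
            input_data.append([937] + schunk + [938])      # 1022 tokens

    return input_data

# Under these assumptions consecutive samples share 340 - step // 3 = 255 notes,
# which is the "notes overlap" figure the script prints.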
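The model-construction block that moved into ClassifyMIDI() hinges on two PyTorch details: torch.nn.DataParallel prefixes every parameter name with "module.", so a checkpoint saved from a DataParallel model only loads cleanly after wrapping, and the wrapped AutoregressiveWrapper has to be reached through model.module. A sketch of the load path; the import is an assumption standing in for the X Transformer module app.py actually bundles:

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

def load_classifier(ckpt_path, device='cuda'):

    SEQ_LEN = 1026   # model's maximum sequence length
    PAD_IDX = 940

    model = TransformerWrapper(
        num_tokens=PAD_IDX + 1,
        max_seq_len=SEQ_LEN,
        attn_layers=Decoder(dim=1024, depth=24, heads=32, attn_flash=True)
    )

    model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)

    # Wrap BEFORE load_state_dict: the checkpoint keys carry the "module." prefix
    model = torch.nn.DataParallel(model)
    model.to(device)

    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    return model   # reach generation via model.module.generate(...)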
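Classification itself is framed as generation: each 1022-token sample is extended by exactly two tokens, and top_k with k=1 makes the decoding greedy despite the 0.9 temperature, so the per-chunk predictions can be reduced to one verdict with statistics.mode(). A sketch of that loop; the top_k import path mirrors the pip x_transformers package (the app bundles its own copy, whose generate() also accepts return_prime and verbose), and indexing classifier_labels with the first generated token minus 384 is an assumption based on the surrounding code:

import torch
from statistics import mode
from x_transformers.autoregressive_wrapper import top_k   # assumed import path

def classify_chunks(model, ctx, input_data, classifier_labels, device='cuda'):

    results = []

    for chunk in input_data:

        x = torch.tensor(chunk[:1022], dtype=torch.long, device=device)

        with ctx:
            # Ask the decoder for exactly 2 continuation tokens; k=1 makes
            # the sampling deterministic regardless of temperature
            out = model.module.generate(x, 2,
                                        filter_logits_fn=top_k,
                                        filter_kwargs={'k': 1},
                                        temperature=0.9,
                                        return_prime=False,
                                        verbose=False)

        results.append(tuple(out[0].tolist()))

    final_result = mode(results)                                # most frequent 2-token prediction
    final_label = classifier_labels[0][final_result[0] - 384]   # assumed label mapping

    return final_label, results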