Spaces:
Running
on
Zero
Running
on
Zero
asigalov61
commited on
Commit
•
65d99ea
1
Parent(s):
28c7a7a
Update app.py
Browse files
app.py
CHANGED
@@ -23,7 +23,7 @@ in_space = os.getenv("SYSTEM") == "spaces"
|
|
23 |
# =================================================================================================
|
24 |
|
25 |
@spaces.GPU
|
26 |
-
def
|
27 |
print('=' * 70)
|
28 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
29 |
start_time = reqtime.time()
|
@@ -31,7 +31,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
31 |
print('Loading model...')
|
32 |
|
33 |
SEQ_LEN = 8192 # Models seq len
|
34 |
-
PAD_IDX =
|
35 |
DEVICE = 'cuda' # 'cuda'
|
36 |
|
37 |
# instantiate the model
|
@@ -39,7 +39,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
39 |
model = TransformerWrapper(
|
40 |
num_tokens = PAD_IDX+1,
|
41 |
max_seq_len = SEQ_LEN,
|
42 |
-
attn_layers = Decoder(dim =
|
43 |
)
|
44 |
|
45 |
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
@@ -50,7 +50,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
50 |
print('Loading model checkpoint...')
|
51 |
|
52 |
model.load_state_dict(
|
53 |
-
torch.load('
|
54 |
map_location=DEVICE))
|
55 |
print('=' * 70)
|
56 |
|
@@ -59,7 +59,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
59 |
if DEVICE == 'cpu':
|
60 |
dtype = torch.bfloat16
|
61 |
else:
|
62 |
-
dtype = torch.
|
63 |
|
64 |
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
|
65 |
|
@@ -69,13 +69,12 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
69 |
fn = os.path.basename(input_midi.name)
|
70 |
fn1 = fn.split('.')[0]
|
71 |
|
72 |
-
|
73 |
|
74 |
print('-' * 70)
|
75 |
print('Input file name:', fn)
|
76 |
-
print('Req num
|
77 |
-
print('
|
78 |
-
print('Strip notes:', input_strip_notes)
|
79 |
print('-' * 70)
|
80 |
|
81 |
#===============================================================================
|
@@ -84,124 +83,121 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
84 |
#===============================================================================
|
85 |
# Enhanced score notes
|
86 |
|
87 |
-
|
88 |
|
89 |
-
|
|
|
90 |
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
-
#=======================================================
|
94 |
-
# PRE-PROCESSING
|
95 |
-
|
96 |
-
#===============================================================================
|
97 |
-
# Augmented enhanced score notes
|
98 |
-
|
99 |
-
no_drums_escore_notes = TMIDIX.augment_enhanced_score_notes(no_drums_escore_notes)
|
100 |
-
|
101 |
-
cscore = TMIDIX.chordify_score([1000, no_drums_escore_notes])
|
102 |
-
|
103 |
-
clean_cscore = []
|
104 |
-
|
105 |
-
for c in cscore:
|
106 |
-
pitches = []
|
107 |
-
cho = []
|
108 |
-
for cc in c:
|
109 |
-
if cc[4] not in pitches:
|
110 |
-
cho.append(cc)
|
111 |
-
pitches.append(cc[4])
|
112 |
-
|
113 |
-
clean_cscore.append(cho)
|
114 |
-
|
115 |
-
#=======================================================
|
116 |
-
# FINAL PROCESSING
|
117 |
-
|
118 |
-
melody_chords = []
|
119 |
-
chords = []
|
120 |
-
times = [0]
|
121 |
-
durs = []
|
122 |
-
|
123 |
-
#=======================================================
|
124 |
-
# MAIN PROCESSING CYCLE
|
125 |
-
#=======================================================
|
126 |
-
|
127 |
-
pe = clean_cscore[0][0]
|
128 |
-
|
129 |
-
first_chord = True
|
130 |
-
|
131 |
-
for c in clean_cscore:
|
132 |
-
|
133 |
-
# Chords
|
134 |
-
|
135 |
-
c.sort(key=lambda x: x[4], reverse=True)
|
136 |
-
|
137 |
-
tones_chord = sorted(set([cc[4] % 12 for cc in c]))
|
138 |
-
|
139 |
-
try:
|
140 |
-
chord_token = TMIDIX.ALL_CHORDS_SORTED.index(tones_chord)
|
141 |
-
except:
|
142 |
-
checked_tones_chord = TMIDIX.check_and_fix_tones_chord(tones_chord)
|
143 |
-
chord_token = TMIDIX.ALL_CHORDS_SORTED.index(checked_tones_chord)
|
144 |
-
|
145 |
-
melody_chords.extend([chord_token+384])
|
146 |
-
|
147 |
-
if input_strip_notes:
|
148 |
-
if len(tones_chord) > 1:
|
149 |
-
chords.extend([chord_token+384])
|
150 |
-
|
151 |
-
else:
|
152 |
-
chords.extend([chord_token+384])
|
153 |
-
|
154 |
-
if first_chord:
|
155 |
-
melody_chords.extend([0])
|
156 |
-
first_chord = False
|
157 |
-
|
158 |
-
for e in c:
|
159 |
-
|
160 |
-
#=======================================================
|
161 |
-
# Timings...
|
162 |
-
|
163 |
-
time = e[1]-pe[1]
|
164 |
-
|
165 |
-
dur = e[2]
|
166 |
-
|
167 |
-
if time != 0 and time % 2 != 0:
|
168 |
-
time += 1
|
169 |
-
if dur % 2 != 0:
|
170 |
-
dur += 1
|
171 |
-
|
172 |
-
delta_time = int(max(0, min(255, time)) / 2)
|
173 |
-
|
174 |
-
# Durations
|
175 |
-
|
176 |
-
dur = int(max(0, min(255, dur)) / 2)
|
177 |
-
|
178 |
-
# Pitches
|
179 |
-
|
180 |
-
ptc = max(1, min(127, e[4]))
|
181 |
-
|
182 |
-
#=======================================================
|
183 |
-
# FINAL NOTE SEQ
|
184 |
-
|
185 |
-
# Writing final note asynchronously
|
186 |
-
|
187 |
-
if delta_time != 0:
|
188 |
-
melody_chords.extend([delta_time, dur+128, ptc+256])
|
189 |
-
if input_strip_notes:
|
190 |
-
if len(c) > 1:
|
191 |
-
times.append(delta_time)
|
192 |
-
durs.append(dur+128)
|
193 |
-
else:
|
194 |
-
times.append(delta_time)
|
195 |
-
durs.append(dur+128)
|
196 |
-
else:
|
197 |
-
melody_chords.extend([dur+128, ptc+256])
|
198 |
-
|
199 |
-
pe = e
|
200 |
|
201 |
#==================================================================
|
202 |
|
203 |
print('=' * 70)
|
204 |
-
|
|
|
205 |
print('Sample output events', melody_chords[:5])
|
206 |
print('=' * 70)
|
207 |
print('Generating...')
|
@@ -226,7 +222,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
226 |
if input_conditioning_type == 'Chords-Times-Durations':
|
227 |
output.append(durs[idx])
|
228 |
|
229 |
-
x = torch.tensor([output] * 1, dtype=torch.long, device=
|
230 |
|
231 |
o = 0
|
232 |
|
@@ -376,9 +372,8 @@ if __name__ == "__main__":
|
|
376 |
gr.Markdown("## Upload your MIDI or select a sample example MIDI")
|
377 |
|
378 |
input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
|
379 |
-
|
380 |
-
|
381 |
-
input_strip_notes = gr.Checkbox(label="Strip notes from the composition")
|
382 |
|
383 |
run_btn = gr.Button("generate", variant="primary")
|
384 |
|
@@ -391,7 +386,7 @@ if __name__ == "__main__":
|
|
391 |
output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
|
392 |
|
393 |
|
394 |
-
run_event = run_btn.click(
|
395 |
[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
|
396 |
|
397 |
gr.Examples(
|
|
|
23 |
# =================================================================================================
|
24 |
|
25 |
@spaces.GPU
|
26 |
+
def InpaintPitches(input_midi, input_num_of_notes, input_patch_number):
|
27 |
print('=' * 70)
|
28 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
29 |
start_time = reqtime.time()
|
|
|
31 |
print('Loading model...')
|
32 |
|
33 |
SEQ_LEN = 8192 # Models seq len
|
34 |
+
PAD_IDX = 19463 # Models pad index
|
35 |
DEVICE = 'cuda' # 'cuda'
|
36 |
|
37 |
# instantiate the model
|
|
|
39 |
model = TransformerWrapper(
|
40 |
num_tokens = PAD_IDX+1,
|
41 |
max_seq_len = SEQ_LEN,
|
42 |
+
attn_layers = Decoder(dim = 1024, depth = 32, heads = 32, attn_flash = True)
|
43 |
)
|
44 |
|
45 |
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
|
|
50 |
print('Loading model checkpoint...')
|
51 |
|
52 |
model.load_state_dict(
|
53 |
+
torch.load('Giant_Music_Transformer_Large_Trained_Model_36074_steps_0.3067_loss_0.927_acc.pth',
|
54 |
map_location=DEVICE))
|
55 |
print('=' * 70)
|
56 |
|
|
|
59 |
if DEVICE == 'cpu':
|
60 |
dtype = torch.bfloat16
|
61 |
else:
|
62 |
+
dtype = torch.bfloat16
|
63 |
|
64 |
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
|
65 |
|
|
|
69 |
fn = os.path.basename(input_midi.name)
|
70 |
fn1 = fn.split('.')[0]
|
71 |
|
72 |
+
input_num_of_notes = max(8, min(2048, input_num_of_notes))
|
73 |
|
74 |
print('-' * 70)
|
75 |
print('Input file name:', fn)
|
76 |
+
print('Req num of notes:', input_num_of_notes)
|
77 |
+
print('Req patch number:', input_patch_number)
|
|
|
78 |
print('-' * 70)
|
79 |
|
80 |
#===============================================================================
|
|
|
83 |
#===============================================================================
|
84 |
# Enhanced score notes
|
85 |
|
86 |
+
events_matrix1 = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
|
87 |
|
88 |
+
#=======================================================
|
89 |
+
# PRE-PROCESSING
|
90 |
|
91 |
+
# checking number of instruments in a composition
|
92 |
+
instruments_list_without_drums = list(set([y[3] for y in events_matrix1 if y[3] != 9]))
|
93 |
+
instruments_list = list(set([y[3] for y in events_matrix1]))
|
94 |
+
|
95 |
+
if len(events_matrix1) > 0 and len(instruments_list_without_drums) > 0:
|
96 |
+
|
97 |
+
#======================================
|
98 |
+
|
99 |
+
events_matrix2 = []
|
100 |
+
|
101 |
+
# Recalculating timings
|
102 |
+
for e in events_matrix1:
|
103 |
+
|
104 |
+
# Original timings
|
105 |
+
e[1] = int(e[1] / 16)
|
106 |
+
e[2] = int(e[2] / 16)
|
107 |
+
|
108 |
+
#===================================
|
109 |
+
# ORIGINAL COMPOSITION
|
110 |
+
#===================================
|
111 |
+
|
112 |
+
# Sorting by patch, pitch, then by start-time
|
113 |
+
|
114 |
+
events_matrix1.sort(key=lambda x: x[6])
|
115 |
+
events_matrix1.sort(key=lambda x: x[4], reverse=True)
|
116 |
+
events_matrix1.sort(key=lambda x: x[1])
|
117 |
+
|
118 |
+
#=======================================================
|
119 |
+
# FINAL PROCESSING
|
120 |
+
|
121 |
+
melody_chords = []
|
122 |
+
melody_chords2 = []
|
123 |
+
|
124 |
+
# Break between compositions / Intro seq
|
125 |
+
|
126 |
+
if 9 in instruments_list:
|
127 |
+
drums_present = 19331 # Yes
|
128 |
+
else:
|
129 |
+
drums_present = 19330 # No
|
130 |
+
|
131 |
+
if events_matrix1[0][3] != 9:
|
132 |
+
pat = events_matrix1[0][6]
|
133 |
+
else:
|
134 |
+
pat = 128
|
135 |
+
|
136 |
+
melody_chords.extend([19461, drums_present, 19332+pat]) # Intro seq
|
137 |
+
|
138 |
+
#=======================================================
|
139 |
+
# MAIN PROCESSING CYCLE
|
140 |
+
#=======================================================
|
141 |
+
|
142 |
+
abs_time = 0
|
143 |
+
|
144 |
+
pbar_time = 0
|
145 |
+
|
146 |
+
pe = events_matrix1[0]
|
147 |
+
|
148 |
+
chords_counter = 1
|
149 |
+
|
150 |
+
comp_chords_len = len(list(set([y[1] for y in events_matrix1])))
|
151 |
+
|
152 |
+
for e in events_matrix1:
|
153 |
+
|
154 |
+
#=======================================================
|
155 |
+
# Timings...
|
156 |
+
|
157 |
+
# Cliping all values...
|
158 |
+
delta_time = max(0, min(255, e[1]-pe[1]))
|
159 |
+
|
160 |
+
# Durations and channels
|
161 |
+
|
162 |
+
dur = max(0, min(255, e[2]))
|
163 |
+
cha = max(0, min(15, e[3]))
|
164 |
+
|
165 |
+
# Patches
|
166 |
+
if cha == 9: # Drums patch will be == 128
|
167 |
+
pat = 128
|
168 |
+
|
169 |
+
else:
|
170 |
+
pat = e[6]
|
171 |
+
|
172 |
+
# Pitches
|
173 |
+
|
174 |
+
ptc = max(1, min(127, e[4]))
|
175 |
+
|
176 |
+
# Velocities
|
177 |
+
|
178 |
+
# Calculating octo-velocity
|
179 |
+
vel = max(8, min(127, e[5]))
|
180 |
+
velocity = round(vel / 15)-1
|
181 |
+
|
182 |
+
#=======================================================
|
183 |
+
# FINAL NOTE SEQ
|
184 |
+
|
185 |
+
# Writing final note asynchronously
|
186 |
+
|
187 |
+
dur_vel = (8 * dur) + velocity
|
188 |
+
pat_ptc = (129 * pat) + ptc
|
189 |
+
|
190 |
+
melody_chords.extend([delta_time, dur_vel+256, pat_ptc+2304])
|
191 |
+
melody_chords2.append([delta_time, dur_vel+256, pat_ptc+2304])
|
192 |
+
|
193 |
+
pe = e
|
194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
#==================================================================
|
197 |
|
198 |
print('=' * 70)
|
199 |
+
print('Number of tokens:', len(melody_chords))
|
200 |
+
print('Number of notes:', len(melody_chords2))
|
201 |
print('Sample output events', melody_chords[:5])
|
202 |
print('=' * 70)
|
203 |
print('Generating...')
|
|
|
222 |
if input_conditioning_type == 'Chords-Times-Durations':
|
223 |
output.append(durs[idx])
|
224 |
|
225 |
+
x = torch.tensor([output] * 1, dtype=torch.long, device=DEVICE)
|
226 |
|
227 |
o = 0
|
228 |
|
|
|
372 |
gr.Markdown("## Upload your MIDI or select a sample example MIDI")
|
373 |
|
374 |
input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
|
375 |
+
input_num_of_notes = gr.Slider(8, 2048, value=128, step=8, label="Number of composition notes to inpaint")
|
376 |
+
input_patch_number = gr.Slider(0, 127, value=0, step=1, label="Composition MIDI patch to inpaint")
|
|
|
377 |
|
378 |
run_btn = gr.Button("generate", variant="primary")
|
379 |
|
|
|
386 |
output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
|
387 |
|
388 |
|
389 |
+
run_event = run_btn.click(InpaintPitches, [input_midi, input_num_of_notes, input_patch_number],
|
390 |
[output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
|
391 |
|
392 |
gr.Examples(
|