Spaces:
Sleeping
Sleeping
update tokenizer
Browse files
- app.py +15 -6
- midi_tokenizer.py +111 -9
app.py
CHANGED
@@ -121,7 +121,8 @@ def send_msgs(msgs):
|
|
121 |
return json.dumps(msgs)
|
122 |
|
123 |
|
124 |
-
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
|
|
|
125 |
gen_events, temp, top_p, top_k, allow_cc):
|
126 |
mid_seq = []
|
127 |
bpm = int(bpm)
|
@@ -153,8 +154,11 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, midi_opt,
|
|
153 |
disable_patch_change = True
|
154 |
disable_channels = [i for i in range(16) if i not in patches]
|
155 |
elif mid is not None:
|
156 |
-
eps = 4 if
|
157 |
-
mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps
|
|
|
|
|
|
|
158 |
mid = np.asarray(mid, dtype=np.int64)
|
159 |
mid = mid[:int(midi_events)]
|
160 |
for token_seq in mid:
|
@@ -306,7 +310,10 @@ if __name__ == "__main__":
|
|
306 |
input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
|
307 |
step=1,
|
308 |
value=128)
|
309 |
-
|
|
|
|
|
|
|
310 |
example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
|
311 |
[input_midi, input_midi_events])
|
312 |
|
@@ -330,8 +337,10 @@ if __name__ == "__main__":
|
|
330 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
331 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
332 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
|
333 |
-
input_midi, input_midi_events,
|
334 |
-
|
|
|
|
|
335 |
[output_midi_seq, output_midi, output_audio, input_seed, js_msg],
|
336 |
concurrency_limit=3)
|
337 |
stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
|
|
121 |
return json.dumps(msgs)
|
122 |
|
123 |
|
124 |
+
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
|
125 |
+
reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
|
126 |
gen_events, temp, top_p, top_k, allow_cc):
|
127 |
mid_seq = []
|
128 |
bpm = int(bpm)
|
|
|
154 |
disable_patch_change = True
|
155 |
disable_channels = [i for i in range(16) if i not in patches]
|
156 |
elif mid is not None:
|
157 |
+
eps = 4 if reduce_cc_st else 0
|
158 |
+
mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
|
159 |
+
remap_track_channel=remap_track_channel,
|
160 |
+
add_default_instr=add_default_instr,
|
161 |
+
remove_empty_channels=remove_empty_channels)
|
162 |
mid = np.asarray(mid, dtype=np.int64)
|
163 |
mid = mid[:int(midi_events)]
|
164 |
for token_seq in mid:
|
|
|
310 |
input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
|
311 |
step=1,
|
312 |
value=128)
|
313 |
+
input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
|
314 |
+
input_remap_track_channel = gr.Checkbox(label="remap tracks and channels to have only one channel per track", value=True)
|
315 |
+
input_add_default_instr = gr.Checkbox(label="add a default instrument to channels that don't have an instrument", value=True)
|
316 |
+
input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
|
317 |
example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
|
318 |
[input_midi, input_midi_events])
|
319 |
|
|
|
337 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
338 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
339 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
|
340 |
+
input_midi, input_midi_events, input_reduce_cc_st, input_remap_track_channel,
|
341 |
+
input_add_default_instr, input_remove_empty_channels, input_seed,
|
342 |
+
input_seed_rand, input_gen_events, input_temp, input_top_p, input_top_k,
|
343 |
+
input_allow_cc],
|
344 |
[output_midi_seq, output_midi, output_audio, input_seed, js_msg],
|
345 |
concurrency_limit=3)
|
346 |
stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
midi_tokenizer.py
CHANGED
@@ -42,9 +42,16 @@ class MIDITokenizer:
|
|
42 |
tempo = int((60 / bpm) * 10 ** 6)
|
43 |
return tempo
|
44 |
|
45 |
-
def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4
|
|
|
46 |
ticks_per_beat = midi_score[0]
|
47 |
event_list = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
for track_idx, track in enumerate(midi_score[1:129]):
|
49 |
last_notes = {}
|
50 |
patch_dict = {}
|
@@ -53,9 +60,18 @@ class MIDITokenizer:
|
|
53 |
for event in track:
|
54 |
if event[0] not in self.events:
|
55 |
continue
|
|
|
56 |
t = round(16 * event[1] / ticks_per_beat) # quantization
|
57 |
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
58 |
if event[0] == "note":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
60 |
elif event[0] == "set_tempo":
|
61 |
if new_event[4] == 0: # invalid tempo
|
@@ -68,12 +84,18 @@ class MIDITokenizer:
|
|
68 |
key = tuple(new_event[:-1])
|
69 |
if event[0] == "patch_change":
|
70 |
c, p = event[2:]
|
|
|
|
|
71 |
last_p = patch_dict.setdefault(c, None)
|
72 |
if last_p == p:
|
73 |
continue
|
74 |
patch_dict[c] = p
|
|
|
|
|
75 |
elif event[0] == "control_change":
|
76 |
c, cc, v = event[2:]
|
|
|
|
|
77 |
last_v = control_dict.setdefault((c, cc), 0)
|
78 |
if abs(last_v - v) < cc_eps:
|
79 |
continue
|
@@ -84,6 +106,13 @@ class MIDITokenizer:
|
|
84 |
continue
|
85 |
last_tempo = tempo
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
if event[0] == "note": # to eliminate note overlap due to quantization
|
88 |
cp = tuple(new_event[5:7])
|
89 |
if cp in last_notes:
|
@@ -95,8 +124,79 @@ class MIDITokenizer:
|
|
95 |
last_notes[cp] = (key, new_event)
|
96 |
event_list[key] = new_event
|
97 |
event_list = list(event_list.values())
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
setup_events = {}
|
101 |
notes_in_setup = False
|
102 |
for i, event in enumerate(event_list): # optimise setup
|
@@ -113,7 +213,7 @@ class MIDITokenizer:
|
|
113 |
pre_event = event_list[i - 1]
|
114 |
has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
|
115 |
if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre) :
|
116 |
-
event_list = sorted(setup_events.values(), key=
|
117 |
break
|
118 |
else:
|
119 |
if event[0] == "note":
|
@@ -122,7 +222,10 @@ class MIDITokenizer:
|
|
122 |
setup_events[key] = new_event
|
123 |
|
124 |
last_t1 = 0
|
|
|
125 |
for event in event_list:
|
|
|
|
|
126 |
cur_t1 = event[1]
|
127 |
event[1] = event[1] - last_t1
|
128 |
tokens = self.event2tokens(event)
|
@@ -181,7 +284,7 @@ class MIDITokenizer:
|
|
181 |
if track_idx not in tracks_dict:
|
182 |
tracks_dict[track_idx] = []
|
183 |
tracks_dict[track_idx].append([event[0], t] + event[4:])
|
184 |
-
tracks = list(tracks_dict.
|
185 |
|
186 |
for i in range(len(tracks)): # to eliminate note overlap
|
187 |
track = tracks[i]
|
@@ -292,7 +395,6 @@ class MIDITokenizer:
|
|
292 |
notes_bandwidth_list = []
|
293 |
instruments = {}
|
294 |
piano_channels = []
|
295 |
-
undef_instrument = False
|
296 |
abs_t1 = 0
|
297 |
last_t = 0
|
298 |
for tsi, tokens in enumerate(midi_seq):
|
@@ -309,7 +411,9 @@ class MIDITokenizer:
|
|
309 |
time_hist[t2] += 1
|
310 |
if c != 9: # ignore drum channel
|
311 |
if c not in instruments:
|
312 |
-
|
|
|
|
|
313 |
note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
|
314 |
if last_t != t:
|
315 |
notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
|
@@ -330,8 +434,6 @@ class MIDITokenizer:
|
|
330 |
reasons.append("total_min")
|
331 |
if total_notes > total_notes_max:
|
332 |
reasons.append("total_max")
|
333 |
-
if undef_instrument:
|
334 |
-
reasons.append("undef_instr")
|
335 |
if len(note_windows) == 0 and total_notes > 0:
|
336 |
reasons.append("drum_only")
|
337 |
if reasons:
|
|
|
42 |
tempo = int((60 / bpm) * 10 ** 6)
|
43 |
return tempo
|
44 |
|
45 |
+
def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
|
46 |
+
remap_track_channel=False, add_default_instr=False, remove_empty_channels=False):
|
47 |
ticks_per_beat = midi_score[0]
|
48 |
event_list = {}
|
49 |
+
track_idx_map = {i: dict() for i in range(16)}
|
50 |
+
track_idx_dict = {}
|
51 |
+
channels = []
|
52 |
+
patch_channels = []
|
53 |
+
empty_channels = [True]*16
|
54 |
+
channel_note_tracks = {i: list() for i in range(16)}
|
55 |
for track_idx, track in enumerate(midi_score[1:129]):
|
56 |
last_notes = {}
|
57 |
patch_dict = {}
|
|
|
60 |
for event in track:
|
61 |
if event[0] not in self.events:
|
62 |
continue
|
63 |
+
c = -1
|
64 |
t = round(16 * event[1] / ticks_per_beat) # quantization
|
65 |
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
66 |
if event[0] == "note":
|
67 |
+
c = event[3]
|
68 |
+
if c > 15 or c < 0:
|
69 |
+
continue
|
70 |
+
empty_channels[c] = False
|
71 |
+
track_idx_dict.setdefault(c, track_idx)
|
72 |
+
note_tracks = channel_note_tracks[c]
|
73 |
+
if track_idx not in note_tracks:
|
74 |
+
note_tracks.append(track_idx)
|
75 |
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
76 |
elif event[0] == "set_tempo":
|
77 |
if new_event[4] == 0: # invalid tempo
|
|
|
84 |
key = tuple(new_event[:-1])
|
85 |
if event[0] == "patch_change":
|
86 |
c, p = event[2:]
|
87 |
+
if c > 15 or c < 0:
|
88 |
+
continue
|
89 |
last_p = patch_dict.setdefault(c, None)
|
90 |
if last_p == p:
|
91 |
continue
|
92 |
patch_dict[c] = p
|
93 |
+
if c not in patch_channels:
|
94 |
+
patch_channels.append(c)
|
95 |
elif event[0] == "control_change":
|
96 |
c, cc, v = event[2:]
|
97 |
+
if c > 15 or c < 0:
|
98 |
+
continue
|
99 |
last_v = control_dict.setdefault((c, cc), 0)
|
100 |
if abs(last_v - v) < cc_eps:
|
101 |
continue
|
|
|
106 |
continue
|
107 |
last_tempo = tempo
|
108 |
|
109 |
+
if c != -1:
|
110 |
+
if c not in channels:
|
111 |
+
channels.append(c)
|
112 |
+
tr_map = track_idx_map[c]
|
113 |
+
if track_idx not in tr_map:
|
114 |
+
tr_map[track_idx] = 0
|
115 |
+
|
116 |
if event[0] == "note": # to eliminate note overlap due to quantization
|
117 |
cp = tuple(new_event[5:7])
|
118 |
if cp in last_notes:
|
|
|
124 |
last_notes[cp] = (key, new_event)
|
125 |
event_list[key] = new_event
|
126 |
event_list = list(event_list.values())
|
127 |
+
|
128 |
+
empty_channels = [c for c in channels if empty_channels[c]]
|
129 |
+
|
130 |
+
if remap_track_channel:
|
131 |
+
patch_channels = []
|
132 |
+
channels_count = 0
|
133 |
+
channels_map = {9: 9} if 9 in channels else {}
|
134 |
+
for c in channels:
|
135 |
+
if c == 9:
|
136 |
+
continue
|
137 |
+
channels_map[c] = channels_count
|
138 |
+
channels_count += 1
|
139 |
+
if channels_count == 9:
|
140 |
+
channels_count = 10
|
141 |
+
channels = list(channels_map.values())
|
142 |
+
|
143 |
+
track_count = 0
|
144 |
+
track_idx_map_order = [k for k,v in sorted(list(channels_map.items()), key=lambda x: x[1])]
|
145 |
+
for c in track_idx_map_order: # tracks not to remove
|
146 |
+
if remove_empty_channels and c in empty_channels:
|
147 |
+
continue
|
148 |
+
tr_map = track_idx_map[c]
|
149 |
+
for track_idx in tr_map:
|
150 |
+
note_tracks = channel_note_tracks[c]
|
151 |
+
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
152 |
+
continue
|
153 |
+
track_count += 1
|
154 |
+
tr_map[track_idx] = track_count
|
155 |
+
for c in track_idx_map_order: # tracks to remove
|
156 |
+
if not (remove_empty_channels and c in empty_channels):
|
157 |
+
continue
|
158 |
+
tr_map = track_idx_map[c]
|
159 |
+
for track_idx in tr_map:
|
160 |
+
note_tracks = channel_note_tracks[c]
|
161 |
+
if not (len(note_tracks) != 0 and track_idx not in note_tracks):
|
162 |
+
continue
|
163 |
+
track_count += 1
|
164 |
+
tr_map[track_idx] = track_count
|
165 |
+
|
166 |
+
empty_channels = [channels_map[c] for c in empty_channels]
|
167 |
+
|
168 |
+
for event in event_list:
|
169 |
+
name = event[0]
|
170 |
+
track_idx = event[3]
|
171 |
+
if name == "note":
|
172 |
+
c = event[5]
|
173 |
+
event[5] = channels_map[c]
|
174 |
+
event[3] = track_idx_map[c][track_idx]
|
175 |
+
track_idx_dict[event[5]] = event[3]
|
176 |
+
elif name == "set_tempo":
|
177 |
+
event[3] = 0
|
178 |
+
elif name == "control_change" or name == "patch_change":
|
179 |
+
c = event[4]
|
180 |
+
event[4] = channels_map[c]
|
181 |
+
tr_map = track_idx_map[c]
|
182 |
+
# move the event to the first track of the channel if its original track is empty
|
183 |
+
note_tracks = channel_note_tracks[c]
|
184 |
+
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
185 |
+
track_idx = channel_note_tracks[c][0]
|
186 |
+
new_track_idx = tr_map.setdefault(track_idx, next(iter(tr_map.values())))
|
187 |
+
event[3] = new_track_idx
|
188 |
+
if name == "patch_change" and event[4] not in patch_channels:
|
189 |
+
patch_channels.append(event[4])
|
190 |
+
|
191 |
+
if add_default_instr:
|
192 |
+
for c in channels:
|
193 |
+
if c not in patch_channels:
|
194 |
+
event_list.append(["patch_change", 0,0, track_idx_dict[c], c, 0])
|
195 |
+
|
196 |
+
events_name_order = {"set_tempo":0, "patch_change":1, "control_change":2, "note":3}
|
197 |
+
events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
|
198 |
+
event_list = sorted(event_list, key=events_order)
|
199 |
+
|
200 |
setup_events = {}
|
201 |
notes_in_setup = False
|
202 |
for i, event in enumerate(event_list): # optimise setup
|
|
|
213 |
pre_event = event_list[i - 1]
|
214 |
has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
|
215 |
if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre) :
|
216 |
+
event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
|
217 |
break
|
218 |
else:
|
219 |
if event[0] == "note":
|
|
|
222 |
setup_events[key] = new_event
|
223 |
|
224 |
last_t1 = 0
|
225 |
+
midi_seq = []
|
226 |
for event in event_list:
|
227 |
+
if remove_empty_channels and event[0] in ["control_change", "patch_change"] and event[4] in empty_channels:
|
228 |
+
continue
|
229 |
cur_t1 = event[1]
|
230 |
event[1] = event[1] - last_t1
|
231 |
tokens = self.event2tokens(event)
|
|
|
284 |
if track_idx not in tracks_dict:
|
285 |
tracks_dict[track_idx] = []
|
286 |
tracks_dict[track_idx].append([event[0], t] + event[4:])
|
287 |
+
tracks = [tr for idx, tr in sorted(list(tracks_dict.items()), key=lambda it: it[0])]
|
288 |
|
289 |
for i in range(len(tracks)): # to eliminate note overlap
|
290 |
track = tracks[i]
|
|
|
395 |
notes_bandwidth_list = []
|
396 |
instruments = {}
|
397 |
piano_channels = []
|
|
|
398 |
abs_t1 = 0
|
399 |
last_t = 0
|
400 |
for tsi, tokens in enumerate(midi_seq):
|
|
|
411 |
time_hist[t2] += 1
|
412 |
if c != 9: # ignore drum channel
|
413 |
if c not in instruments:
|
414 |
+
instruments[c] = 0
|
415 |
+
if c not in piano_channels:
|
416 |
+
piano_channels.append(c)
|
417 |
note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
|
418 |
if last_t != t:
|
419 |
notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
|
|
|
434 |
reasons.append("total_min")
|
435 |
if total_notes > total_notes_max:
|
436 |
reasons.append("total_max")
|
|
|
|
|
437 |
if len(note_windows) == 0 and total_notes > 0:
|
438 |
reasons.append("drum_only")
|
439 |
if reasons:
|