Spaces:
Paused
Paused
v1.2
Browse files- app.py +48 -27
- javascript/app.js +4 -3
- midi_tokenizer.py +146 -35
app.py
CHANGED
@@ -111,16 +111,19 @@ def create_msg(name, data):
|
|
111 |
return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
|
112 |
|
113 |
|
114 |
-
def send_msgs(msgs, msgs_history):
|
|
|
|
|
115 |
msgs_history.append(msgs)
|
116 |
-
if len(msgs_history) >
|
117 |
-
msgs_history
|
118 |
return json.dumps(msgs_history)
|
119 |
|
120 |
|
121 |
-
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
122 |
msgs_history = []
|
123 |
mid_seq = []
|
|
|
124 |
gen_events = int(gen_events)
|
125 |
max_len = gen_events
|
126 |
|
@@ -129,6 +132,8 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
129 |
if tab == 0:
|
130 |
i = 0
|
131 |
mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
|
|
|
|
|
132 |
patches = {}
|
133 |
if instruments is None:
|
134 |
instruments = []
|
@@ -151,10 +156,10 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
151 |
max_len += len(mid)
|
152 |
for token_seq in mid:
|
153 |
mid_seq.append(token_seq.tolist())
|
154 |
-
init_msgs = [create_msg("visualizer_clear",
|
155 |
for tokens in mid_seq:
|
156 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
157 |
-
yield mid_seq, None, None, send_msgs(init_msgs, msgs_history)
|
158 |
model = models[model_name]
|
159 |
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
160 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
@@ -163,22 +168,31 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
163 |
token_seq = token_seq.tolist()
|
164 |
mid_seq.append(token_seq)
|
165 |
event = tokenizer.tokens2event(token_seq)
|
166 |
-
yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
|
167 |
mid = tokenizer.detokenize(mid_seq)
|
168 |
with open(f"output.mid", 'wb') as f:
|
169 |
f.write(MIDI.score2midi(mid))
|
170 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
|
174 |
-
def cancel_run(mid_seq
|
175 |
if mid_seq is None:
|
176 |
return None, None, []
|
177 |
mid = tokenizer.detokenize(mid_seq)
|
178 |
with open(f"output.mid", 'wb') as f:
|
179 |
f.write(MIDI.score2midi(mid))
|
180 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
181 |
-
|
|
|
|
|
|
|
|
|
182 |
|
183 |
|
184 |
def load_javascript(dir="javascript"):
|
@@ -200,6 +214,7 @@ def load_javascript(dir="javascript"):
|
|
200 |
|
201 |
|
202 |
def hf_hub_download_retry(repo_id, filename):
|
|
|
203 |
retry = 0
|
204 |
err = None
|
205 |
while retry < 30:
|
@@ -246,9 +261,9 @@ if __name__ == "__main__":
|
|
246 |
"Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
|
247 |
"[Open In Colab]"
|
248 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
249 |
-
" for faster running and longer generation"
|
|
|
250 |
)
|
251 |
-
js_msg_history_state = gr.State(value=[])
|
252 |
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
253 |
js_msg.change(None, [js_msg], [], js="""
|
254 |
(msg_json) =>{
|
@@ -262,19 +277,25 @@ if __name__ == "__main__":
|
|
262 |
tab_select = gr.State(value=0)
|
263 |
with gr.Tabs():
|
264 |
with gr.TabItem("instrument prompt") as tab1:
|
265 |
-
input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
|
266 |
multiselect=True, max_choices=15, type="value")
|
267 |
-
input_drum_kit = gr.Dropdown(label="drum kit", choices=list(drum_kits2number.keys()), type="value",
|
268 |
value="None")
|
|
|
|
|
|
|
269 |
example1 = gr.Examples([
|
270 |
[[], "None"],
|
271 |
[["Acoustic Grand"], "None"],
|
272 |
-
[[
|
273 |
-
|
274 |
-
[[
|
275 |
-
|
276 |
-
[[
|
277 |
-
|
|
|
|
|
|
|
278 |
"Electric Bass(finger)"], "Standard"]
|
279 |
], [input_instruments, input_drum_kit])
|
280 |
with gr.TabItem("midi prompt") as tab2:
|
@@ -292,19 +313,19 @@ if __name__ == "__main__":
|
|
292 |
with gr.Accordion("options", open=False):
|
293 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
294 |
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
|
295 |
-
input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=
|
296 |
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
297 |
-
example3 = gr.Examples([[1, 0.98,
|
298 |
run_btn = gr.Button("generate", variant="primary")
|
299 |
stop_btn = gr.Button("stop and output")
|
300 |
output_midi_seq = gr.State()
|
301 |
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
302 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
303 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
304 |
-
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit,
|
305 |
-
input_midi_events, input_gen_events, input_temp,
|
306 |
-
input_allow_cc],
|
307 |
-
[output_midi_seq, output_midi, output_audio, js_msg
|
308 |
concurrency_limit=3)
|
309 |
-
stop_btn.click(cancel_run, [output_midi_seq
|
310 |
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
|
|
111 |
return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
|
112 |
|
113 |
|
114 |
+
def send_msgs(msgs, msgs_history=None):
    """Append one batch of messages to the history and serialize it as JSON.

    Args:
        msgs: List of message dicts (as built by ``create_msg``) that form
            one batch.
        msgs_history: Optional list of previously sent batches. It is
            updated in place; when omitted, a fresh history containing only
            ``msgs`` is used.

    Returns:
        JSON string of the retained batches, oldest first (at most 25).
    """
    max_batches = 25
    if msgs_history is None:
        msgs_history = []
    msgs_history.append(msgs)
    # Trim the caller's list in place. Rebinding the name to a slice (as the
    # previous version did) left the caller's list growing without bound, so
    # the cap never took effect and each payload was longer than the last.
    overflow = len(msgs_history) - max_batches
    if overflow > 0:
        del msgs_history[:overflow]
    return json.dumps(msgs_history)
|
121 |
|
122 |
|
123 |
+
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
124 |
msgs_history = []
|
125 |
mid_seq = []
|
126 |
+
bpm = int(bpm)
|
127 |
gen_events = int(gen_events)
|
128 |
max_len = gen_events
|
129 |
|
|
|
132 |
if tab == 0:
|
133 |
i = 0
|
134 |
mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
|
135 |
+
if bpm != 0:
|
136 |
+
mid.append(tokenizer.event2tokens(["set_tempo",0,0,0, bpm]))
|
137 |
patches = {}
|
138 |
if instruments is None:
|
139 |
instruments = []
|
|
|
156 |
max_len += len(mid)
|
157 |
for token_seq in mid:
|
158 |
mid_seq.append(token_seq.tolist())
|
159 |
+
init_msgs = [create_msg("visualizer_clear", False)]
|
160 |
for tokens in mid_seq:
|
161 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
162 |
+
yield mid_seq, None, None, send_msgs(init_msgs, msgs_history)
|
163 |
model = models[model_name]
|
164 |
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
165 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
|
|
168 |
token_seq = token_seq.tolist()
|
169 |
mid_seq.append(token_seq)
|
170 |
event = tokenizer.tokens2event(token_seq)
|
171 |
+
yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
|
172 |
mid = tokenizer.detokenize(mid_seq)
|
173 |
with open(f"output.mid", 'wb') as f:
|
174 |
f.write(MIDI.score2midi(mid))
|
175 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
176 |
+
# resend all msgs
|
177 |
+
msgs = [create_msg("visualizer_end", None), create_msg("visualizer_clear", True)]
|
178 |
+
for tokens in mid_seq:
|
179 |
+
msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
180 |
+
msgs.append(create_msg("visualizer_end", None))
|
181 |
+
yield mid_seq, "output.mid", (44100, audio), send_msgs(msgs)
|
182 |
|
183 |
|
184 |
+
def cancel_run(mid_seq):
    """Render whatever has been generated so far when the user stops a run.

    Args:
        mid_seq: Token sequences accumulated by ``run``, or ``None`` when
            nothing has been generated yet.

    Returns:
        A 3-tuple ``(midi_path, audio, js_msg)``: the path of the MIDI file
        that was written, a ``(44100, audio)`` pair (presumably sample rate
        and samples -- confirm against ``synthesis``), and the JSON payload
        produced by ``send_msgs``. When ``mid_seq`` is ``None`` the result
        is ``(None, None, [])`` and nothing is written.
    """
    if mid_seq is None:
        return None, None, []
    mid = tokenizer.detokenize(mid_seq)
    with open("output.mid", "wb") as f:
        f.write(MIDI.score2midi(mid))
    audio = synthesis(MIDI.score2opus(mid), soundfont_path)
    # Resend the whole piece in a single batch: end the current drawing,
    # clear the visualizer (the True payload is passed through as the
    # message data), replay every event, then signal the end again.
    msgs = [
        create_msg("visualizer_end", None),
        create_msg("visualizer_clear", True),
        *(create_msg("visualizer_append", tokenizer.tokens2event(tokens))
          for tokens in mid_seq),
        create_msg("visualizer_end", None),
    ]
    return "output.mid", (44100, audio), send_msgs(msgs)
|
196 |
|
197 |
|
198 |
def load_javascript(dir="javascript"):
|
|
|
214 |
|
215 |
|
216 |
def hf_hub_download_retry(repo_id, filename):
|
217 |
+
print(f"downloading {repo_id} {filename}")
|
218 |
retry = 0
|
219 |
err = None
|
220 |
while retry < 30:
|
|
|
261 |
"Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
|
262 |
"[Open In Colab]"
|
263 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
264 |
+
" for faster running and longer generation\n\n"
|
265 |
+
"**Update v1.2**: Optimise the tokenizer and dataset"
|
266 |
)
|
|
|
267 |
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
268 |
js_msg.change(None, [js_msg], [], js="""
|
269 |
(msg_json) =>{
|
|
|
277 |
tab_select = gr.State(value=0)
|
278 |
with gr.Tabs():
|
279 |
with gr.TabItem("instrument prompt") as tab1:
|
280 |
+
input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
|
281 |
multiselect=True, max_choices=15, type="value")
|
282 |
+
input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
|
283 |
value="None")
|
284 |
+
input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
|
285 |
+
step=1,
|
286 |
+
value=0)
|
287 |
example1 = gr.Examples([
|
288 |
[[], "None"],
|
289 |
[["Acoustic Grand"], "None"],
|
290 |
+
[['Acoustic Grand', 'SynthStrings 2', 'SynthStrings 1', 'Pizzicato Strings',
|
291 |
+
'Pad 2 (warm)', 'Tremolo Strings', 'String Ensemble 1'], "Orchestra"],
|
292 |
+
[['Trumpet', 'Oboe', 'Trombone', 'String Ensemble 1', 'Clarinet',
|
293 |
+
'French Horn', 'Pad 4 (choir)', 'Bassoon', 'Flute'], "None"],
|
294 |
+
[['Flute', 'French Horn', 'Clarinet', 'String Ensemble 2', 'English Horn', 'Bassoon',
|
295 |
+
'Oboe', 'Pizzicato Strings'], "Orchestra"],
|
296 |
+
[['Electric Piano 2', 'Lead 5 (charang)', 'Electric Bass(pick)', 'Lead 2 (sawtooth)',
|
297 |
+
'Pad 1 (new age)', 'Orchestra Hit', 'Cello', 'Electric Guitar(clean)'], "Standard"],
|
298 |
+
[["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
|
299 |
"Electric Bass(finger)"], "Standard"]
|
300 |
], [input_instruments, input_drum_kit])
|
301 |
with gr.TabItem("midi prompt") as tab2:
|
|
|
313 |
with gr.Accordion("options", open=False):
|
314 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
315 |
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
|
316 |
+
input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
|
317 |
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
318 |
+
example3 = gr.Examples([[1, 0.98, 20], [1, 0.98, 12]], [input_temp, input_top_p, input_top_k])
|
319 |
run_btn = gr.Button("generate", variant="primary")
|
320 |
stop_btn = gr.Button("stop and output")
|
321 |
output_midi_seq = gr.State()
|
322 |
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
323 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
324 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
325 |
+
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
|
326 |
+
input_midi, input_midi_events, input_gen_events, input_temp,
|
327 |
+
input_top_p, input_top_k, input_allow_cc],
|
328 |
+
[output_midi_seq, output_midi, output_audio, js_msg],
|
329 |
concurrency_limit=3)
|
330 |
+
stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
331 |
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
javascript/app.js
CHANGED
@@ -146,13 +146,14 @@ class MidiVisualizer extends HTMLElement{
|
|
146 |
this.setPlayTime(0);
|
147 |
}
|
148 |
|
149 |
-
clearMidiEvents(){
|
150 |
this.pause()
|
151 |
this.midiEvents = [];
|
152 |
this.activeNotes = [];
|
153 |
this.midiTimes = [];
|
154 |
this.t1 = 0
|
155 |
-
|
|
|
156 |
this.setPlayTime(0);
|
157 |
this.totalTimeMs = 0;
|
158 |
this.playTimeMs = 0
|
@@ -426,7 +427,7 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
426 |
handled_msgs.push(msg.uuid);
|
427 |
switch (msg.name) {
|
428 |
case "visualizer_clear":
|
429 |
-
midi_visualizer.clearMidiEvents();
|
430 |
createProgressBar(midi_visualizer_container_inited)
|
431 |
break;
|
432 |
case "visualizer_append":
|
|
|
146 |
this.setPlayTime(0);
|
147 |
}
|
148 |
|
149 |
+
clearMidiEvents(keepColor=false){
|
150 |
this.pause()
|
151 |
this.midiEvents = [];
|
152 |
this.activeNotes = [];
|
153 |
this.midiTimes = [];
|
154 |
this.t1 = 0
|
155 |
+
if (!keepColor)
|
156 |
+
this.colorMap.clear()
|
157 |
this.setPlayTime(0);
|
158 |
this.totalTimeMs = 0;
|
159 |
this.playTimeMs = 0
|
|
|
427 |
handled_msgs.push(msg.uuid);
|
428 |
switch (msg.name) {
|
429 |
case "visualizer_clear":
|
430 |
+
midi_visualizer.clearMidiEvents(msg.data);
|
431 |
createProgressBar(midi_visualizer_container_inited)
|
432 |
break;
|
433 |
case "visualizer_append":
|
midi_tokenizer.py
CHANGED
@@ -42,22 +42,48 @@ class MIDITokenizer:
|
|
42 |
tempo = int((60 / bpm) * 10 ** 6)
|
43 |
return tempo
|
44 |
|
45 |
-
def tokenize(self, midi_score, add_bos_eos=True):
|
46 |
ticks_per_beat = midi_score[0]
|
47 |
event_list = {}
|
48 |
for track_idx, track in enumerate(midi_score[1:129]):
|
49 |
last_notes = {}
|
|
|
|
|
|
|
50 |
for event in track:
|
|
|
|
|
51 |
t = round(16 * event[1] / ticks_per_beat) # quantization
|
52 |
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
53 |
if event[0] == "note":
|
54 |
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
55 |
elif event[0] == "set_tempo":
|
56 |
-
new_event[4]
|
|
|
|
|
|
|
57 |
if event[0] == "note":
|
58 |
key = tuple(new_event[:4] + new_event[5:-1])
|
59 |
else:
|
60 |
key = tuple(new_event[:-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
if event[0] == "note": # to eliminate note overlap due to quantization
|
62 |
cp = tuple(new_event[5:7])
|
63 |
if cp in last_notes:
|
@@ -71,21 +97,39 @@ class MIDITokenizer:
|
|
71 |
event_list = list(event_list.values())
|
72 |
event_list = sorted(event_list, key=lambda e: e[1:4])
|
73 |
midi_seq = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
last_t1 = 0
|
76 |
for event in event_list:
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
|
85 |
-
for i, p in enumerate(self.events[name])]
|
86 |
-
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
|
87 |
-
midi_seq.append(tokens)
|
88 |
-
last_t1 = cur_t1
|
89 |
|
90 |
if add_bos_eos:
|
91 |
bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
|
@@ -96,6 +140,8 @@ class MIDITokenizer:
|
|
96 |
def event2tokens(self, event):
|
97 |
name = event[0]
|
98 |
params = event[1:]
|
|
|
|
|
99 |
tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
|
100 |
for i, p in enumerate(self.events[name])]
|
101 |
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
|
@@ -120,14 +166,10 @@ class MIDITokenizer:
|
|
120 |
t1 = 0
|
121 |
for tokens in midi_seq:
|
122 |
if tokens[0] in self.id_events:
|
123 |
-
|
124 |
-
if
|
125 |
continue
|
126 |
-
|
127 |
-
params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
|
128 |
-
if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
|
129 |
-
continue
|
130 |
-
event = [name] + params
|
131 |
if name == "set_tempo":
|
132 |
event[4] = self.bpm2tempo(event[4])
|
133 |
if event[0] == "note":
|
@@ -183,7 +225,7 @@ class MIDITokenizer:
|
|
183 |
return img
|
184 |
|
185 |
def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
|
186 |
-
max_track_shift=
|
187 |
pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
|
188 |
vel_shift = random.randint(-max_vel_shift, max_vel_shift)
|
189 |
cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
|
@@ -239,16 +281,85 @@ class MIDITokenizer:
|
|
239 |
midi_seq_new.append(tokens_new)
|
240 |
return midi_seq_new
|
241 |
|
242 |
-
def
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
tempo = int((60 / bpm) * 10 ** 6)
|
43 |
return tempo
|
44 |
|
45 |
+
def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4):
|
46 |
ticks_per_beat = midi_score[0]
|
47 |
event_list = {}
|
48 |
for track_idx, track in enumerate(midi_score[1:129]):
|
49 |
last_notes = {}
|
50 |
+
patch_dict = {}
|
51 |
+
control_dict = {}
|
52 |
+
last_tempo = 0
|
53 |
for event in track:
|
54 |
+
if event[0] not in self.events:
|
55 |
+
continue
|
56 |
t = round(16 * event[1] / ticks_per_beat) # quantization
|
57 |
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
58 |
if event[0] == "note":
|
59 |
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
60 |
elif event[0] == "set_tempo":
|
61 |
+
if new_event[4] == 0: # invalid tempo
|
62 |
+
continue
|
63 |
+
bpm = int(self.tempo2bpm(new_event[4]))
|
64 |
+
new_event[4] = min(bpm, 255)
|
65 |
if event[0] == "note":
|
66 |
key = tuple(new_event[:4] + new_event[5:-1])
|
67 |
else:
|
68 |
key = tuple(new_event[:-1])
|
69 |
+
if event[0] == "patch_change":
|
70 |
+
c, p = event[2:]
|
71 |
+
last_p = patch_dict.setdefault(c, None)
|
72 |
+
if last_p == p:
|
73 |
+
continue
|
74 |
+
patch_dict[c] = p
|
75 |
+
elif event[0] == "control_change":
|
76 |
+
c, cc, v = event[2:]
|
77 |
+
last_v = control_dict.setdefault((c, cc), 0)
|
78 |
+
if abs(last_v - v) < cc_eps:
|
79 |
+
continue
|
80 |
+
control_dict[(c, cc)] = v
|
81 |
+
elif event[0] == "set_tempo":
|
82 |
+
tempo = new_event[-1]
|
83 |
+
if abs(last_tempo - tempo) < tempo_eps:
|
84 |
+
continue
|
85 |
+
last_tempo = tempo
|
86 |
+
|
87 |
if event[0] == "note": # to eliminate note overlap due to quantization
|
88 |
cp = tuple(new_event[5:7])
|
89 |
if cp in last_notes:
|
|
|
97 |
event_list = list(event_list.values())
|
98 |
event_list = sorted(event_list, key=lambda e: e[1:4])
|
99 |
midi_seq = []
|
100 |
+
setup_events = {}
|
101 |
+
notes_in_setup = False
|
102 |
+
for i, event in enumerate(event_list): # optimise setup
|
103 |
+
new_event = [*event]
|
104 |
+
if event[0] != "note":
|
105 |
+
new_event[1] = 0
|
106 |
+
new_event[2] = 0
|
107 |
+
has_next = False
|
108 |
+
has_pre = False
|
109 |
+
if i < len(event_list) - 1:
|
110 |
+
next_event = event_list[i + 1]
|
111 |
+
has_next = event[1] + event[2] == next_event[1] + next_event[2]
|
112 |
+
if notes_in_setup and i > 0:
|
113 |
+
pre_event = event_list[i - 1]
|
114 |
+
has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
|
115 |
+
if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre) :
|
116 |
+
event_list = sorted(setup_events.values(), key=lambda e: 1 if e[0] == "note" else 0) + event_list[i:]
|
117 |
+
break
|
118 |
+
else:
|
119 |
+
if event[0] == "note":
|
120 |
+
notes_in_setup = True
|
121 |
+
key = tuple(event[3:-1])
|
122 |
+
setup_events[key] = new_event
|
123 |
|
124 |
last_t1 = 0
|
125 |
for event in event_list:
|
126 |
+
cur_t1 = event[1]
|
127 |
+
event[1] = event[1] - last_t1
|
128 |
+
tokens = self.event2tokens(event)
|
129 |
+
if not tokens:
|
130 |
+
continue
|
131 |
+
midi_seq.append(tokens)
|
132 |
+
last_t1 = cur_t1
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
if add_bos_eos:
|
135 |
bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
|
|
|
140 |
def event2tokens(self, event):
|
141 |
name = event[0]
|
142 |
params = event[1:]
|
143 |
+
if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
|
144 |
+
return []
|
145 |
tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
|
146 |
for i, p in enumerate(self.events[name])]
|
147 |
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
|
|
|
166 |
t1 = 0
|
167 |
for tokens in midi_seq:
|
168 |
if tokens[0] in self.id_events:
|
169 |
+
event = self.tokens2event(tokens)
|
170 |
+
if not event:
|
171 |
continue
|
172 |
+
name = event[0]
|
|
|
|
|
|
|
|
|
173 |
if name == "set_tempo":
|
174 |
event[4] = self.bpm2tempo(event[4])
|
175 |
if event[0] == "note":
|
|
|
225 |
return img
|
226 |
|
227 |
def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
|
228 |
+
max_track_shift=0, max_channel_shift=16):
|
229 |
pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
|
230 |
vel_shift = random.randint(-max_vel_shift, max_vel_shift)
|
231 |
cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
|
|
|
281 |
midi_seq_new.append(tokens_new)
|
282 |
return midi_seq_new
|
283 |
|
284 |
+
def check_quality(self, midi_seq, alignment_min=0.4, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3, notes_density_max=30, notes_density_min=2.5, total_notes_max=10000, total_notes_min=500, note_window_size=16):
|
285 |
+
total_notes = 0
|
286 |
+
channels = []
|
287 |
+
time_hist = [0] * 16
|
288 |
+
note_windows = {}
|
289 |
+
notes_sametime = []
|
290 |
+
notes_density_list = []
|
291 |
+
tonality_list = []
|
292 |
+
notes_bandwidth_list = []
|
293 |
+
instruments = {}
|
294 |
+
piano_channels = []
|
295 |
+
undef_instrument = False
|
296 |
+
abs_t1 = 0
|
297 |
+
last_t = 0
|
298 |
+
for tsi, tokens in enumerate(midi_seq):
|
299 |
+
event = self.tokens2event(tokens)
|
300 |
+
if not event:
|
301 |
+
continue
|
302 |
+
t1, t2, tr = event[1:4]
|
303 |
+
abs_t1 += t1
|
304 |
+
t = abs_t1 * 16 + t2
|
305 |
+
c = None
|
306 |
+
if event[0] == "note":
|
307 |
+
d, c, p, v = event[4:]
|
308 |
+
total_notes += 1
|
309 |
+
time_hist[t2] += 1
|
310 |
+
if c != 9: # ignore drum channel
|
311 |
+
if c not in instruments:
|
312 |
+
undef_instrument = True
|
313 |
+
note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
|
314 |
+
if last_t != t:
|
315 |
+
notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
|
316 |
+
notes_sametime_p = [p_ for _, p_ in notes_sametime]
|
317 |
+
if len(notes_sametime) > 0:
|
318 |
+
notes_bandwidth_list.append(max(notes_sametime_p) - min(notes_sametime_p))
|
319 |
+
notes_sametime.append((t + d - 1, p))
|
320 |
+
elif event[0] == "patch_change":
|
321 |
+
c, p = event[4:]
|
322 |
+
instruments[c] = p
|
323 |
+
if p == 0 and c not in piano_channels:
|
324 |
+
piano_channels.append(c)
|
325 |
+
if c is not None and c not in channels:
|
326 |
+
channels.append(c)
|
327 |
+
last_t = t
|
328 |
+
reasons = []
|
329 |
+
if total_notes < total_notes_min:
|
330 |
+
reasons.append("total_min")
|
331 |
+
if total_notes > total_notes_max:
|
332 |
+
reasons.append("total_max")
|
333 |
+
if undef_instrument:
|
334 |
+
reasons.append("undef_instr")
|
335 |
+
if len(note_windows) == 0 and total_notes > 0:
|
336 |
+
reasons.append("drum_only")
|
337 |
+
if reasons:
|
338 |
+
return False, reasons
|
339 |
+
time_hist = sorted(time_hist, reverse=True)
|
340 |
+
alignment = sum(time_hist[:2]) / total_notes
|
341 |
+
for notes in note_windows.values():
|
342 |
+
key_hist = [0] * 12
|
343 |
+
for p in notes:
|
344 |
+
key_hist[p % 12] += 1
|
345 |
+
key_hist = sorted(key_hist, reverse=True)
|
346 |
+
tonality_list.append(sum(key_hist[:7]) / len(notes))
|
347 |
+
notes_density_list.append(len(notes) / note_window_size)
|
348 |
+
tonality_list = sorted(tonality_list)
|
349 |
+
tonality = sum(tonality_list)/len(tonality_list)
|
350 |
+
notes_bandwidth = sum(notes_bandwidth_list)/len(notes_bandwidth_list) if notes_bandwidth_list else 0
|
351 |
+
notes_density = max(notes_density_list) if notes_density_list else 0
|
352 |
+
piano_ratio = len(piano_channels) / len(channels)
|
353 |
+
if len(channels) <=3: # ignore piano threshold if it is a piano solo midi
|
354 |
+
piano_max = 1
|
355 |
+
if alignment < alignment_min: # check whether the notes align to the bars (because some midi files are recorded)
|
356 |
+
reasons.append("alignment")
|
357 |
+
if tonality < tonality_min: # check whether the music is tonal
|
358 |
+
reasons.append("tonality")
|
359 |
+
if notes_bandwidth < notes_bandwidth_min: # check whether music is melodic line only
|
360 |
+
reasons.append("bandwidth")
|
361 |
+
if not notes_density_min < notes_density < notes_density_max:
|
362 |
+
reasons.append("density")
|
363 |
+
if piano_ratio > piano_max: # check whether most instruments are piano (because some midi files don't have instruments assigned correctly)
|
364 |
+
reasons.append("piano")
|
365 |
+
return not reasons, reasons
|