Spaces:
Sleeping
Sleeping
fix streaming
Browse files- app.py +14 -12
- javascript/app.js +3 -6
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import argparse
|
2 |
import glob
|
3 |
import os.path
|
|
|
4 |
import uuid
|
5 |
|
6 |
import gradio as gr
|
@@ -113,21 +114,15 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
113 |
|
114 |
|
115 |
def create_msg(name, data):
|
116 |
-
return {"name": name, "data": data
|
117 |
|
118 |
|
119 |
-
def send_msgs(msgs
|
120 |
-
|
121 |
-
msgs_history = []
|
122 |
-
msgs_history.append(msgs)
|
123 |
-
if len(msgs_history) > 25:
|
124 |
-
msgs_history= msgs_history[1:]
|
125 |
-
return json.dumps(msgs_history)
|
126 |
|
127 |
|
128 |
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, seed_rand,
|
129 |
gen_events, temp, top_p, top_k, allow_cc):
|
130 |
-
msgs_history = []
|
131 |
mid_seq = []
|
132 |
bpm = int(bpm)
|
133 |
gen_events = int(gen_events)
|
@@ -167,16 +162,23 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, see
|
|
167 |
init_msgs = [create_msg("visualizer_clear", False)]
|
168 |
for tokens in mid_seq:
|
169 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
170 |
-
yield mid_seq, None, None, seed, send_msgs(init_msgs
|
171 |
model = models[model_name]
|
172 |
midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
173 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
174 |
disable_channels=disable_channels, generator=generator)
|
|
|
|
|
175 |
for i, token_seq in enumerate(midi_generator):
|
176 |
token_seq = token_seq.tolist()
|
177 |
mid_seq.append(token_seq)
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
180 |
mid = tokenizer.detokenize(mid_seq)
|
181 |
with open(f"output.mid", 'wb') as f:
|
182 |
f.write(MIDI.score2midi(mid))
|
|
|
1 |
import argparse
|
2 |
import glob
|
3 |
import os.path
|
4 |
+
import time
|
5 |
import uuid
|
6 |
|
7 |
import gradio as gr
|
|
|
114 |
|
115 |
|
116 |
def create_msg(name, data):
|
117 |
+
return {"name": name, "data": data}
|
118 |
|
119 |
|
120 |
+
def send_msgs(msgs):
|
121 |
+
return json.dumps(msgs)
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
|
124 |
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, seed_rand,
|
125 |
gen_events, temp, top_p, top_k, allow_cc):
|
|
|
126 |
mid_seq = []
|
127 |
bpm = int(bpm)
|
128 |
gen_events = int(gen_events)
|
|
|
162 |
init_msgs = [create_msg("visualizer_clear", False)]
|
163 |
for tokens in mid_seq:
|
164 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
165 |
+
yield mid_seq, None, None, seed, send_msgs(init_msgs)
|
166 |
model = models[model_name]
|
167 |
midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
168 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
169 |
disable_channels=disable_channels, generator=generator)
|
170 |
+
t = time.time()
|
171 |
+
events = []
|
172 |
for i, token_seq in enumerate(midi_generator):
|
173 |
token_seq = token_seq.tolist()
|
174 |
mid_seq.append(token_seq)
|
175 |
+
events.append(tokenizer.tokens2event(token_seq))
|
176 |
+
ct = time.time()
|
177 |
+
if ct - t > 0.2:
|
178 |
+
yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", events), create_msg("progress", [i + 1, gen_events])])
|
179 |
+
t = ct
|
180 |
+
events = []
|
181 |
+
|
182 |
mid = tokenizer.detokenize(mid_seq)
|
183 |
with open(f"output.mid", 'wb') as f:
|
184 |
f.write(MIDI.score2midi(mid))
|
javascript/app.js
CHANGED
@@ -420,18 +420,16 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
420 |
}
|
421 |
}
|
422 |
})
|
423 |
-
let handled_msgs = [];
|
424 |
function handleMsg(msg){
|
425 |
-
if(handled_msgs.indexOf(msg.uuid)!== -1)
|
426 |
-
return;
|
427 |
-
handled_msgs.push(msg.uuid);
|
428 |
switch (msg.name) {
|
429 |
case "visualizer_clear":
|
430 |
midi_visualizer.clearMidiEvents(false);
|
431 |
createProgressBar(midi_visualizer_container_inited)
|
432 |
break;
|
433 |
case "visualizer_append":
|
434 |
-
|
|
|
|
|
435 |
break;
|
436 |
case "progress":
|
437 |
let progress = msg.data[0]
|
@@ -446,7 +444,6 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
446 |
midi_visualizer.finishAppendMidiEvent()
|
447 |
midi_visualizer.setPlayTime(0);
|
448 |
removeProgressBar(midi_visualizer_container_inited);
|
449 |
-
handled_msgs = []
|
450 |
break;
|
451 |
default:
|
452 |
}
|
|
|
420 |
}
|
421 |
}
|
422 |
})
|
|
|
423 |
function handleMsg(msg){
|
|
|
|
|
|
|
424 |
switch (msg.name) {
|
425 |
case "visualizer_clear":
|
426 |
midi_visualizer.clearMidiEvents(false);
|
427 |
createProgressBar(midi_visualizer_container_inited)
|
428 |
break;
|
429 |
case "visualizer_append":
|
430 |
+
msg.data.forEach( value => {
|
431 |
+
midi_visualizer.appendMidiEvent(value);
|
432 |
+
})
|
433 |
break;
|
434 |
case "progress":
|
435 |
let progress = msg.data[0]
|
|
|
444 |
midi_visualizer.finishAppendMidiEvent()
|
445 |
midi_visualizer.setPlayTime(0);
|
446 |
removeProgressBar(midi_visualizer_container_inited);
|
|
|
447 |
break;
|
448 |
default:
|
449 |
}
|