Spaces:
Runtime error
Runtime error
fix midi visualizer
Browse files- app.py +25 -9
- javascript/app.js +21 -31
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import argparse
|
| 2 |
import glob
|
| 3 |
import os.path
|
|
|
|
| 4 |
|
| 5 |
import gradio as gr
|
| 6 |
import numpy as np
|
|
@@ -107,7 +108,7 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
| 107 |
|
| 108 |
|
| 109 |
def create_msg(name, data):
|
| 110 |
-
return {"name": name, "data": data}
|
| 111 |
|
| 112 |
|
| 113 |
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
|
@@ -164,7 +165,7 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
| 164 |
|
| 165 |
def cancel_run(mid_seq):
|
| 166 |
if mid_seq is None:
|
| 167 |
-
return None, None
|
| 168 |
mid = tokenizer.detokenize(mid_seq)
|
| 169 |
with open(f"output.mid", 'wb') as f:
|
| 170 |
f.write(MIDI.score2midi(mid))
|
|
@@ -189,17 +190,25 @@ def load_javascript(dir="javascript"):
|
|
| 189 |
|
| 190 |
gr.routes.templates.TemplateResponse = template_response
|
| 191 |
|
|
|
|
| 192 |
# JSMsgReceiver
|
| 193 |
-
|
|
|
|
|
|
|
| 194 |
|
| 195 |
|
|
|
|
| 196 |
def JSMsgReceiver_postprocess(self, y):
|
|
|
|
| 197 |
if self.elem_id == "msg_receiver" and y:
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
|
| 202 |
-
gr.
|
| 203 |
|
| 204 |
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
| 205 |
40: "Blush", 48: "Orchestra"}
|
|
@@ -214,8 +223,8 @@ if __name__ == "__main__":
|
|
| 214 |
opt = parser.parse_args()
|
| 215 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 216 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
| 217 |
-
"j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
| 218 |
-
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
| 219 |
}
|
| 220 |
models = {}
|
| 221 |
tokenizer = MIDITokenizer()
|
|
@@ -238,7 +247,14 @@ if __name__ == "__main__":
|
|
| 238 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
| 239 |
" for faster running and longer generation"
|
| 240 |
)
|
| 241 |
-
js_msg = gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
|
| 243 |
type="value", value=list(models.keys())[0])
|
| 244 |
tab_select = gr.State(value=0)
|
|
|
|
| 1 |
import argparse
|
| 2 |
import glob
|
| 3 |
import os.path
|
| 4 |
+
import uuid
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
import numpy as np
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
def create_msg(name, data):
|
| 111 |
+
return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
|
| 112 |
|
| 113 |
|
| 114 |
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
|
|
|
| 165 |
|
| 166 |
def cancel_run(mid_seq):
|
| 167 |
if mid_seq is None:
|
| 168 |
+
return None, None, []
|
| 169 |
mid = tokenizer.detokenize(mid_seq)
|
| 170 |
with open(f"output.mid", 'wb') as f:
|
| 171 |
f.write(MIDI.score2midi(mid))
|
|
|
|
| 190 |
|
| 191 |
gr.routes.templates.TemplateResponse = template_response
|
| 192 |
|
| 193 |
+
|
| 194 |
# JSMsgReceiver
|
| 195 |
+
Textbox_postprocess_ori = gr.Textbox.postprocess
|
| 196 |
+
|
| 197 |
+
msg_history = []
|
| 198 |
|
| 199 |
|
| 200 |
+
# the change event may not trigger every time, so send msg history to avoid msg missing.
|
| 201 |
def JSMsgReceiver_postprocess(self, y):
|
| 202 |
+
global msg_history
|
| 203 |
if self.elem_id == "msg_receiver" and y:
|
| 204 |
+
msg_history.append(y)
|
| 205 |
+
if len(msg_history) > 50:
|
| 206 |
+
msg_history = msg_history[1:]
|
| 207 |
+
y = json.dumps(msg_history)
|
| 208 |
+
return Textbox_postprocess_ori(self, y)
|
| 209 |
|
| 210 |
|
| 211 |
+
gr.Textbox.postprocess = JSMsgReceiver_postprocess
|
| 212 |
|
| 213 |
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
| 214 |
40: "Blush", 48: "Orchestra"}
|
|
|
|
| 223 |
opt = parser.parse_args()
|
| 224 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 225 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
| 226 |
+
# "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
| 227 |
+
# "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
| 228 |
}
|
| 229 |
models = {}
|
| 230 |
tokenizer = MIDITokenizer()
|
|
|
|
| 247 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
| 248 |
" for faster running and longer generation"
|
| 249 |
)
|
| 250 |
+
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
| 251 |
+
js_msg.change(None, [js_msg], [], js="""
|
| 252 |
+
(msg_json) =>{
|
| 253 |
+
let msgs = JSON.parse(msg_json);
|
| 254 |
+
executeCallbacks(msgReceiveCallbacks, msgs);
|
| 255 |
+
return [];
|
| 256 |
+
}
|
| 257 |
+
""")
|
| 258 |
input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
|
| 259 |
type="value", value=list(models.keys())[0])
|
| 260 |
tab_select = gr.State(value=0)
|
javascript/app.js
CHANGED
|
@@ -76,33 +76,6 @@ document.addEventListener("DOMContentLoaded", function() {
|
|
| 76 |
mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
|
| 77 |
});
|
| 78 |
|
| 79 |
-
(()=>{
|
| 80 |
-
let mse_receiver_inited = null
|
| 81 |
-
onUiUpdate(()=>{
|
| 82 |
-
let app = gradioApp()
|
| 83 |
-
let msg_receiver = app.querySelector("#msg_receiver");
|
| 84 |
-
if(!!msg_receiver && mse_receiver_inited !== msg_receiver){
|
| 85 |
-
let mutationObserver = new MutationObserver(function(ms){
|
| 86 |
-
ms.forEach((m)=>{
|
| 87 |
-
m.addedNodes.forEach((node)=>{
|
| 88 |
-
if(node.nodeName === "P"){
|
| 89 |
-
let obj = JSON.parse(node.innerText);
|
| 90 |
-
if(obj instanceof Array){
|
| 91 |
-
obj.forEach((o)=>{executeCallbacks(msgReceiveCallbacks, o);});
|
| 92 |
-
}else{
|
| 93 |
-
executeCallbacks(msgReceiveCallbacks, obj);
|
| 94 |
-
}
|
| 95 |
-
}
|
| 96 |
-
})
|
| 97 |
-
})
|
| 98 |
-
});
|
| 99 |
-
mutationObserver.observe( msg_receiver, {childList:true, subtree:true, characterData:true})
|
| 100 |
-
console.log("receiver init");
|
| 101 |
-
mse_receiver_inited = msg_receiver;
|
| 102 |
-
}
|
| 103 |
-
})
|
| 104 |
-
})()
|
| 105 |
-
|
| 106 |
function HSVtoRGB(h, s, v) {
|
| 107 |
let r, g, b, i, f, p, q, t;
|
| 108 |
i = Math.floor(h * 6);
|
|
@@ -261,9 +234,11 @@ class MidiVisualizer extends HTMLElement{
|
|
| 261 |
this.midiTimes.push({ms:ms, t: t, tempo: tempo})
|
| 262 |
}
|
| 263 |
if(midiEvent[0]==="note"){
|
| 264 |
-
this.totalTimeMs = ms + (midiEvent[3]/ this.timePreBeat)*tempo
|
|
|
|
|
|
|
| 265 |
}
|
| 266 |
-
lastT = t
|
| 267 |
})
|
| 268 |
}
|
| 269 |
|
|
@@ -431,7 +406,22 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
| 431 |
divInner.textContent = `${progress}/${total}`;
|
| 432 |
}
|
| 433 |
|
| 434 |
-
onMsgReceive((
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
switch (msg.name) {
|
| 436 |
case "visualizer_clear":
|
| 437 |
midi_visualizer.clearMidiEvents();
|
|
@@ -452,5 +442,5 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
| 452 |
break;
|
| 453 |
default:
|
| 454 |
}
|
| 455 |
-
}
|
| 456 |
})();
|
|
|
|
| 76 |
mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
|
| 77 |
});
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
function HSVtoRGB(h, s, v) {
|
| 80 |
let r, g, b, i, f, p, q, t;
|
| 81 |
i = Math.floor(h * 6);
|
|
|
|
| 234 |
this.midiTimes.push({ms:ms, t: t, tempo: tempo})
|
| 235 |
}
|
| 236 |
if(midiEvent[0]==="note"){
|
| 237 |
+
this.totalTimeMs = Math.max(this.totalTimeMs, ms + (midiEvent[3]/ this.timePreBeat)*tempo)
|
| 238 |
+
}else{
|
| 239 |
+
this.totalTimeMs = Math.max(this.totalTimeMs, ms);
|
| 240 |
}
|
| 241 |
+
lastT = t;
|
| 242 |
})
|
| 243 |
}
|
| 244 |
|
|
|
|
| 406 |
divInner.textContent = `${progress}/${total}`;
|
| 407 |
}
|
| 408 |
|
| 409 |
+
onMsgReceive((msgs)=>{
|
| 410 |
+
for(let msg of msgs){
|
| 411 |
+
if(msg instanceof Array){
|
| 412 |
+
msg.forEach((o)=>{handleMsg(o)});
|
| 413 |
+
}else{
|
| 414 |
+
handleMsg(msg);
|
| 415 |
+
}
|
| 416 |
+
}
|
| 417 |
+
})
|
| 418 |
+
let handled_msgs = [];
|
| 419 |
+
function handleMsg(msg){
|
| 420 |
+
if(handled_msgs.indexOf(msg.uuid)!== -1)
|
| 421 |
+
return;
|
| 422 |
+
handled_msgs.push(msg.uuid);
|
| 423 |
+
if(handled_msgs.length > 200)
|
| 424 |
+
handled_msgs = handled_msgs.slice(1);
|
| 425 |
switch (msg.name) {
|
| 426 |
case "visualizer_clear":
|
| 427 |
midi_visualizer.clearMidiEvents();
|
|
|
|
| 442 |
break;
|
| 443 |
default:
|
| 444 |
}
|
| 445 |
+
}
|
| 446 |
})();
|