skytnt commited on
Commit
9958d06
·
1 Parent(s): 5ac6133

fix streaming

Browse files
Files changed (2) hide show
  1. app.py +14 -12
  2. javascript/app.js +3 -6
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import argparse
2
  import glob
3
  import os.path
 
4
  import uuid
5
 
6
  import gradio as gr
@@ -113,21 +114,15 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
113
 
114
 
115
  def create_msg(name, data):
116
- return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
117
 
118
 
119
- def send_msgs(msgs, msgs_history=None):
120
- if msgs_history is None:
121
- msgs_history = []
122
- msgs_history.append(msgs)
123
- if len(msgs_history) > 25:
124
- msgs_history= msgs_history[1:]
125
- return json.dumps(msgs_history)
126
 
127
 
128
  def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, seed_rand,
129
  gen_events, temp, top_p, top_k, allow_cc):
130
- msgs_history = []
131
  mid_seq = []
132
  bpm = int(bpm)
133
  gen_events = int(gen_events)
@@ -167,16 +162,23 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, see
167
  init_msgs = [create_msg("visualizer_clear", False)]
168
  for tokens in mid_seq:
169
  init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
170
- yield mid_seq, None, None, seed, send_msgs(init_msgs, msgs_history)
171
  model = models[model_name]
172
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
173
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
174
  disable_channels=disable_channels, generator=generator)
 
 
175
  for i, token_seq in enumerate(midi_generator):
176
  token_seq = token_seq.tolist()
177
  mid_seq.append(token_seq)
178
- event = tokenizer.tokens2event(token_seq)
179
- yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
 
 
 
 
 
180
  mid = tokenizer.detokenize(mid_seq)
181
  with open(f"output.mid", 'wb') as f:
182
  f.write(MIDI.score2midi(mid))
 
1
  import argparse
2
  import glob
3
  import os.path
4
+ import time
5
  import uuid
6
 
7
  import gradio as gr
 
114
 
115
 
116
  def create_msg(name, data):
117
+ return {"name": name, "data": data}
118
 
119
 
120
+ def send_msgs(msgs):
121
+ return json.dumps(msgs)
 
 
 
 
 
122
 
123
 
124
  def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, seed_rand,
125
  gen_events, temp, top_p, top_k, allow_cc):
 
126
  mid_seq = []
127
  bpm = int(bpm)
128
  gen_events = int(gen_events)
 
162
  init_msgs = [create_msg("visualizer_clear", False)]
163
  for tokens in mid_seq:
164
  init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
165
+ yield mid_seq, None, None, seed, send_msgs(init_msgs)
166
  model = models[model_name]
167
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
168
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
169
  disable_channels=disable_channels, generator=generator)
170
+ t = time.time()
171
+ events = []
172
  for i, token_seq in enumerate(midi_generator):
173
  token_seq = token_seq.tolist()
174
  mid_seq.append(token_seq)
175
+ events.append(tokenizer.tokens2event(token_seq))
176
+ ct = time.time()
177
+ if ct - t > 0.2:
178
+ yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", events), create_msg("progress", [i + 1, gen_events])])
179
+ t = ct
180
+ events = []
181
+
182
  mid = tokenizer.detokenize(mid_seq)
183
  with open(f"output.mid", 'wb') as f:
184
  f.write(MIDI.score2midi(mid))
javascript/app.js CHANGED
@@ -420,18 +420,16 @@ customElements.define('midi-visualizer', MidiVisualizer);
420
  }
421
  }
422
  })
423
- let handled_msgs = [];
424
  function handleMsg(msg){
425
- if(handled_msgs.indexOf(msg.uuid)!== -1)
426
- return;
427
- handled_msgs.push(msg.uuid);
428
  switch (msg.name) {
429
  case "visualizer_clear":
430
  midi_visualizer.clearMidiEvents(false);
431
  createProgressBar(midi_visualizer_container_inited)
432
  break;
433
  case "visualizer_append":
434
- midi_visualizer.appendMidiEvent(msg.data);
 
 
435
  break;
436
  case "progress":
437
  let progress = msg.data[0]
@@ -446,7 +444,6 @@ customElements.define('midi-visualizer', MidiVisualizer);
446
  midi_visualizer.finishAppendMidiEvent()
447
  midi_visualizer.setPlayTime(0);
448
  removeProgressBar(midi_visualizer_container_inited);
449
- handled_msgs = []
450
  break;
451
  default:
452
  }
 
420
  }
421
  }
422
  })
 
423
  function handleMsg(msg){
 
 
 
424
  switch (msg.name) {
425
  case "visualizer_clear":
426
  midi_visualizer.clearMidiEvents(false);
427
  createProgressBar(midi_visualizer_container_inited)
428
  break;
429
  case "visualizer_append":
430
+ msg.data.forEach( value => {
431
+ midi_visualizer.appendMidiEvent(value);
432
+ })
433
  break;
434
  case "progress":
435
  let progress = msg.data[0]
 
444
  midi_visualizer.finishAppendMidiEvent()
445
  midi_visualizer.setPlayTime(0);
446
  removeProgressBar(midi_visualizer_container_inited);
 
447
  break;
448
  default:
449
  }