adefossez commited on
Commit
56d7528
1 Parent(s): 5591dfc

no progress on batch

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -49,6 +49,7 @@ def interrupt():
49
  global INTERRUPTING
50
  INTERRUPTING = True
51
 
 
52
  def make_waveform(*args, **kwargs):
53
  # Further remove some warnings.
54
  be = time.time()
@@ -66,7 +67,7 @@ def load_model(version='melody'):
66
  MODEL = MusicGen.get_pretrained(version)
67
 
68
 
69
- def _do_predictions(texts, melodies, duration, **gen_kwargs):
70
  MODEL.set_generation_params(duration=duration, **gen_kwargs)
71
  print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
72
  be = time.time()
@@ -89,10 +90,10 @@ def _do_predictions(texts, melodies, duration, **gen_kwargs):
89
  descriptions=texts,
90
  melody_wavs=processed_melodies,
91
  melody_sample_rate=target_sr,
92
- progress=True
93
  )
94
  else:
95
- outputs = MODEL.generate(texts, progress=True)
96
 
97
  outputs = outputs.detach().cpu().float()
98
  out_files = []
@@ -128,7 +129,7 @@ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coe
128
  MODEL.set_custom_progress_callback(_progress)
129
 
130
  outs = _do_predictions(
131
- [text], [melody], duration,
132
  top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
133
  return outs[0]
134
 
@@ -324,6 +325,8 @@ if __name__ == "__main__":
324
  args = parser.parse_args()
325
 
326
  launch_kwargs = {}
 
 
327
  if args.username and args.password:
328
  launch_kwargs['auth'] = (args.username, args.password)
329
  if args.server_port:
 
49
  global INTERRUPTING
50
  INTERRUPTING = True
51
 
52
+
53
  def make_waveform(*args, **kwargs):
54
  # Further remove some warnings.
55
  be = time.time()
 
67
  MODEL = MusicGen.get_pretrained(version)
68
 
69
 
70
+ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
71
  MODEL.set_generation_params(duration=duration, **gen_kwargs)
72
  print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
73
  be = time.time()
 
90
  descriptions=texts,
91
  melody_wavs=processed_melodies,
92
  melody_sample_rate=target_sr,
93
+ progress=progress,
94
  )
95
  else:
96
+ outputs = MODEL.generate(texts, progress=progress)
97
 
98
  outputs = outputs.detach().cpu().float()
99
  out_files = []
 
129
  MODEL.set_custom_progress_callback(_progress)
130
 
131
  outs = _do_predictions(
132
+ [text], [melody], duration, progress=True,
133
  top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
134
  return outs[0]
135
 
 
325
  args = parser.parse_args()
326
 
327
  launch_kwargs = {}
328
+ launch_kwargs['server_name'] = args.listen
329
+
330
  if args.username and args.password:
331
  launch_kwargs['auth'] = (args.username, args.password)
332
  if args.server_port: