Spaces:
Running
on
A10G
Running
on
A10G
no progress on batch
Browse files
app.py
CHANGED
@@ -49,6 +49,7 @@ def interrupt():
|
|
49 |
global INTERRUPTING
|
50 |
INTERRUPTING = True
|
51 |
|
|
|
52 |
def make_waveform(*args, **kwargs):
|
53 |
# Further remove some warnings.
|
54 |
be = time.time()
|
@@ -66,7 +67,7 @@ def load_model(version='melody'):
|
|
66 |
MODEL = MusicGen.get_pretrained(version)
|
67 |
|
68 |
|
69 |
-
def _do_predictions(texts, melodies, duration, **gen_kwargs):
|
70 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
71 |
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
72 |
be = time.time()
|
@@ -89,10 +90,10 @@ def _do_predictions(texts, melodies, duration, **gen_kwargs):
|
|
89 |
descriptions=texts,
|
90 |
melody_wavs=processed_melodies,
|
91 |
melody_sample_rate=target_sr,
|
92 |
-
progress=
|
93 |
)
|
94 |
else:
|
95 |
-
outputs = MODEL.generate(texts, progress=
|
96 |
|
97 |
outputs = outputs.detach().cpu().float()
|
98 |
out_files = []
|
@@ -128,7 +129,7 @@ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coe
|
|
128 |
MODEL.set_custom_progress_callback(_progress)
|
129 |
|
130 |
outs = _do_predictions(
|
131 |
-
[text], [melody], duration,
|
132 |
top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
|
133 |
return outs[0]
|
134 |
|
@@ -324,6 +325,8 @@ if __name__ == "__main__":
|
|
324 |
args = parser.parse_args()
|
325 |
|
326 |
launch_kwargs = {}
|
|
|
|
|
327 |
if args.username and args.password:
|
328 |
launch_kwargs['auth'] = (args.username, args.password)
|
329 |
if args.server_port:
|
|
|
49 |
global INTERRUPTING
|
50 |
INTERRUPTING = True
|
51 |
|
52 |
+
|
53 |
def make_waveform(*args, **kwargs):
|
54 |
# Further remove some warnings.
|
55 |
be = time.time()
|
|
|
67 |
MODEL = MusicGen.get_pretrained(version)
|
68 |
|
69 |
|
70 |
+
def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
|
71 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
72 |
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
73 |
be = time.time()
|
|
|
90 |
descriptions=texts,
|
91 |
melody_wavs=processed_melodies,
|
92 |
melody_sample_rate=target_sr,
|
93 |
+
progress=progress,
|
94 |
)
|
95 |
else:
|
96 |
+
outputs = MODEL.generate(texts, progress=progress)
|
97 |
|
98 |
outputs = outputs.detach().cpu().float()
|
99 |
out_files = []
|
|
|
129 |
MODEL.set_custom_progress_callback(_progress)
|
130 |
|
131 |
outs = _do_predictions(
|
132 |
+
[text], [melody], duration, progress=True,
|
133 |
top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
|
134 |
return outs[0]
|
135 |
|
|
|
325 |
args = parser.parse_args()
|
326 |
|
327 |
launch_kwargs = {}
|
328 |
+
launch_kwargs['server_name'] = args.listen
|
329 |
+
|
330 |
if args.username and args.password:
|
331 |
launch_kwargs['auth'] = (args.username, args.password)
|
332 |
if args.server_port:
|