Spaces:
Build error
Build error
Commit
•
0517e91
1
Parent(s):
7c82007
Update requirements.txt to be git repo
Browse files
- app.py +15 -12
- requirements.txt +1 -1
app.py
CHANGED
@@ -102,7 +102,7 @@ def load_model(version='melody'):
|
|
102 |
MODEL = MusicGen.get_pretrained(version, device=device)
|
103 |
|
104 |
|
105 |
-
def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
|
106 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
107 |
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
108 |
be = time.time()
|
@@ -135,14 +135,15 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
|
|
135 |
out_files = []
|
136 |
for output in outputs:
|
137 |
# Demucs
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
146 |
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
147 |
audio_write(
|
148 |
file.name, output, MODEL.sample_rate, strategy="loudness",
|
@@ -168,7 +169,7 @@ def predict_batched(texts, melodies):
|
|
168 |
return [res]
|
169 |
|
170 |
|
171 |
-
def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
|
172 |
global INTERRUPTING
|
173 |
INTERRUPTING = False
|
174 |
if temperature < 0:
|
@@ -189,7 +190,7 @@ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coe
|
|
189 |
|
190 |
outs = _do_predictions(
|
191 |
[text], [melody], duration, progress=True,
|
192 |
-
top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
|
193 |
|
194 |
|
195 |
|
@@ -228,10 +229,12 @@ def ui_full(launch_kwargs):
|
|
228 |
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
229 |
temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
|
230 |
cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
|
|
|
|
|
231 |
with gr.Column():
|
232 |
output = gr.Video(label="Generated Music")
|
233 |
submit.click(predict_full,
|
234 |
-
inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef],
|
235 |
outputs=[output])
|
236 |
radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
|
237 |
gr.Markdown(
|
|
|
102 |
MODEL = MusicGen.get_pretrained(version, device=device)
|
103 |
|
104 |
|
105 |
+
def _do_predictions(texts, melodies, duration, progress=False, drums=True, **gen_kwargs):
|
106 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
107 |
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
108 |
be = time.time()
|
|
|
135 |
out_files = []
|
136 |
for output in outputs:
|
137 |
# Demucs
|
138 |
+
if not drums:
|
139 |
+
print("Running demucs")
|
140 |
+
wav = convert_audio(output, MODEL.sample_rate, demucs_model.samplerate, demucs_model.audio_channels)
|
141 |
+
wav = wav.unsqueeze(0)
|
142 |
+
stems = apply_model(demucs_model, wav)
|
143 |
+
stems = stems[:, stem_idx] # extract stem
|
144 |
+
stems = stems.sum(1) # merge extracted stems
|
145 |
+
stems = convert_audio(stems, demucs_model.samplerate, MODEL.sample_rate, 1)
|
146 |
+
output = stems[0]
|
147 |
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
148 |
audio_write(
|
149 |
file.name, output, MODEL.sample_rate, strategy="loudness",
|
|
|
169 |
return [res]
|
170 |
|
171 |
|
172 |
+
def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, drums=drums, progress=gr.Progress()):
|
173 |
global INTERRUPTING
|
174 |
INTERRUPTING = False
|
175 |
if temperature < 0:
|
|
|
190 |
|
191 |
outs = _do_predictions(
|
192 |
[text], [melody], duration, progress=True,
|
193 |
+
top_k=topk, top_p=topp, temperature=temperature, drums=drums, cfg_coef=cfg_coef)
|
194 |
|
195 |
|
196 |
|
|
|
229 |
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
230 |
temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
|
231 |
cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
|
232 |
+
with gr.Row():
|
233 |
+
drums = gr.Checkbox(label="Drums", value=True, interactive=True)
|
234 |
with gr.Column():
|
235 |
output = gr.Video(label="Generated Music")
|
236 |
submit.click(predict_full,
|
237 |
+
inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef, drums],
|
238 |
outputs=[output])
|
239 |
radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
|
240 |
gr.Markdown(
|
requirements.txt
CHANGED
@@ -12,7 +12,7 @@ arrow==1.2.3
|
|
12 |
asttokens==2.2.1
|
13 |
async-timeout==4.0.2
|
14 |
attrs==23.1.0
|
15 |
-
audiocraft
|
16 |
audioread==3.0.0
|
17 |
av==10.0.0
|
18 |
backcall==0.2.0
|
|
|
12 |
asttokens==2.2.1
|
13 |
async-timeout==4.0.2
|
14 |
attrs==23.1.0
|
15 |
+
audiocraft @ git+https://github.com/facebookresearch/audiocraft@0.0.2a2
|
16 |
audioread==3.0.0
|
17 |
av==10.0.0
|
18 |
backcall==0.2.0
|