Spaces:
Build error
Build error
mattricesound
committed on
Commit
•
e316ba7
1
Parent(s):
c2a47f9
Remove unnecessary components. Add the examples back in. Add the ability to swap the local model for the MusicGen Gradio API.
Browse files
app.py
CHANGED
@@ -16,29 +16,35 @@ from tempfile import NamedTemporaryFile
|
|
16 |
import time
|
17 |
import typing as tp
|
18 |
import warnings
|
|
|
19 |
|
20 |
import torch
|
21 |
import gradio as gr
|
22 |
-
|
23 |
from audiocraft.data.audio_utils import convert_audio
|
24 |
-
from audiocraft.data.audio import audio_write
|
25 |
from audiocraft.models import MusicGen
|
26 |
|
27 |
from demucs import pretrained
|
28 |
from demucs.apply import apply_model
|
29 |
from demucs.audio import convert_audio
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
MODEL = None # Last used model
|
33 |
DEMUCS_MODEL = None
|
34 |
MAX_BATCH_SIZE = 12
|
35 |
INTERRUPTING = False
|
|
|
36 |
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
|
37 |
_old_call = sp.call
|
38 |
|
39 |
stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
|
40 |
stem_idx = torch.LongTensor([stem2idx['vocal'], stem2idx['other'], stem2idx['bass']])
|
41 |
|
|
|
42 |
|
43 |
|
44 |
def _call_nostderr(*args, **kwargs):
|
@@ -94,14 +100,19 @@ def make_waveform(*args, **kwargs):
|
|
94 |
|
95 |
def load_model(version='melody'):
|
96 |
global MODEL, DEMUCS_MODEL
|
97 |
-
print("Loading model", version)
|
98 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
99 |
-
if
|
100 |
-
|
101 |
-
|
|
|
|
|
102 |
if DEMUCS_MODEL is None:
|
103 |
DEMUCS_MODEL = pretrained.get_model('htdemucs').to(device)
|
104 |
|
|
|
|
|
|
|
|
|
105 |
|
106 |
def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
|
107 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
@@ -149,31 +160,19 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
|
|
149 |
demucs_output = demucs_output.cpu()
|
150 |
|
151 |
# Naming
|
152 |
-
|
153 |
-
d_filename = f"temp/{texts[0][:10]}_no_drums.wav"
|
154 |
|
155 |
# If path exists, add number. If number exists, update number.
|
156 |
i = 1
|
157 |
-
while Path(
|
158 |
-
|
159 |
-
d_filename = f"{texts[0][:10]}_{i}_no_drums.wav"
|
160 |
i += 1
|
161 |
|
162 |
-
# with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
163 |
-
audio_write(
|
164 |
-
filename, output, MODEL.sample_rate, strategy="loudness",
|
165 |
-
loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
|
166 |
-
# out_files.append(pool.submit(make_waveform, filename))
|
167 |
-
out_files.append(filename)
|
168 |
-
file_cleaner.add(filename)
|
169 |
-
# with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
170 |
audio_write(
|
171 |
d_filename, demucs_output, MODEL.sample_rate, strategy="loudness",
|
172 |
loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
|
173 |
out_files.append(d_filename)
|
174 |
-
# out_files.append(pool.submit(make_waveform, d_filename))
|
175 |
file_cleaner.add(d_filename)
|
176 |
-
# res = [out_file.result() for out_file in out_files]
|
177 |
res = [out_file for out_file in out_files]
|
178 |
for file in res:
|
179 |
file_cleaner.add(file)
|
@@ -183,18 +182,10 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
|
|
183 |
|
184 |
|
185 |
|
186 |
-
def predict_full(text, melody,
|
187 |
global INTERRUPTING
|
188 |
INTERRUPTING = False
|
189 |
-
|
190 |
-
raise gr.Error("Temperature must be >= 0.")
|
191 |
-
if topk < 0:
|
192 |
-
raise gr.Error("Topk must be non-negative.")
|
193 |
-
if topp < 0:
|
194 |
-
raise gr.Error("Topp must be non-negative.")
|
195 |
-
|
196 |
-
topk = int(topk)
|
197 |
-
|
198 |
def _progress(generated, to_generate):
|
199 |
progress((generated, to_generate))
|
200 |
if INTERRUPTING:
|
@@ -202,77 +193,114 @@ def predict_full(text, melody, duration, topk, topp, temperature, cfg_coef, prog
|
|
202 |
MODEL.set_custom_progress_callback(_progress)
|
203 |
|
204 |
outs = _do_predictions(
|
205 |
-
[text], [melody], duration, progress=True
|
206 |
-
top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
|
207 |
|
208 |
-
return outs[0], outs[
|
209 |
|
210 |
|
211 |
-
def toggle_audio_src(choice):
|
212 |
-
if choice == "mic":
|
213 |
-
return gr.update(source="microphone", value=None, label="Microphone")
|
214 |
-
else:
|
215 |
-
return gr.update(source="upload", value=None, label="File")
|
216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
def ui_full(launch_kwargs):
|
219 |
with gr.Blocks() as interface:
|
|
|
|
|
|
|
|
|
|
|
220 |
with gr.Row():
|
221 |
with gr.Column():
|
222 |
with gr.Row():
|
223 |
text = gr.Text(label="Input Text", interactive=True)
|
224 |
with gr.Column():
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
|
|
|
|
|
|
|
|
229 |
with gr.Row():
|
230 |
submit = gr.Button("Submit")
|
231 |
# Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
|
232 |
-
_ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
|
233 |
-
|
234 |
-
duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
|
235 |
-
with gr.Row():
|
236 |
-
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
237 |
-
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
238 |
-
temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
|
239 |
-
cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
|
240 |
with gr.Column():
|
241 |
-
|
242 |
-
|
243 |
-
output_normal = gr.Audio(label="Generated Music")
|
244 |
-
with gr.Row():
|
245 |
-
file_download = gr.File(label="Download")
|
246 |
-
with gr.Row():
|
247 |
-
# output_without_drum = gr.Video(label="Removed drums")
|
248 |
-
output_without_drum = gr.Audio(label="Removed drums")
|
249 |
-
with gr.Row():
|
250 |
-
file_download_no_drum = gr.File(label="Download")
|
251 |
-
with gr.Row():
|
252 |
gr.Markdown(
|
253 |
"""
|
254 |
Note that the files will be deleted after 10 minutes, so make sure to download!
|
255 |
"""
|
256 |
)
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
|
|
|
|
|
|
|
|
276 |
)
|
277 |
|
278 |
interface.queue().launch(**launch_kwargs)
|
@@ -286,41 +314,18 @@ if __name__ == "__main__":
|
|
286 |
default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
|
287 |
help='IP to listen on for connections to Gradio',
|
288 |
)
|
289 |
-
parser.add_argument(
|
290 |
-
'--username', type=str, default='', help='Username for authentication'
|
291 |
-
)
|
292 |
-
parser.add_argument(
|
293 |
-
'--password', type=str, default='', help='Password for authentication'
|
294 |
-
)
|
295 |
-
parser.add_argument(
|
296 |
-
'--server_port',
|
297 |
-
type=int,
|
298 |
-
default=0,
|
299 |
-
help='Port to run the server listener on',
|
300 |
-
)
|
301 |
-
parser.add_argument(
|
302 |
-
'--inbrowser', action='store_true', help='Open in browser'
|
303 |
-
)
|
304 |
-
parser.add_argument(
|
305 |
-
'--share', action='store_true', help='Share the gradio UI'
|
306 |
-
)
|
307 |
|
308 |
args = parser.parse_args()
|
309 |
|
310 |
launch_kwargs = {}
|
311 |
launch_kwargs['server_name'] = args.listen
|
312 |
|
313 |
-
|
314 |
-
launch_kwargs['auth'] = (args.username, args.password)
|
315 |
-
if args.server_port:
|
316 |
-
launch_kwargs['server_port'] = args.server_port
|
317 |
-
if args.inbrowser:
|
318 |
-
launch_kwargs['inbrowser'] = args.inbrowser
|
319 |
-
if args.share:
|
320 |
-
launch_kwargs['share'] = args.share
|
321 |
-
|
322 |
# Load melody model
|
323 |
load_model()
|
|
|
|
|
324 |
if not os.path.exists("temp"):
|
325 |
os.mkdir("temp")
|
326 |
# Show the interface
|
|
|
16 |
import time
|
17 |
import typing as tp
|
18 |
import warnings
|
19 |
+
import glob
|
20 |
|
21 |
import torch
|
22 |
import gradio as gr
|
23 |
+
import numpy as np
|
24 |
from audiocraft.data.audio_utils import convert_audio
|
25 |
+
from audiocraft.data.audio import audio_write, audio_read
|
26 |
from audiocraft.models import MusicGen
|
27 |
|
28 |
from demucs import pretrained
|
29 |
from demucs.apply import apply_model
|
30 |
from demucs.audio import convert_audio
|
31 |
+
from gradio_client import Client
|
32 |
+
|
33 |
+
# Backend switch. The script entry point overwrites it from the --local flag:
# when true, load_model() loads MusicGen in-process; otherwise the app
# connects to a remote Gradio endpoint via connect_to_endpoint().
LOCAL = False


MODEL = None  # Last used model
DEMUCS_MODEL = None  # Populated on first use by load_model()
MAX_BATCH_SIZE = 12
INTERRUPTING = False  # Polled by the progress callback in predict_full()
client = None  # Remote API client, assigned by connect_to_endpoint()
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
_old_call = sp.call

# Name -> index lookup used to pick stems out of the Demucs output.
stem2idx = dict(drums=0, bass=1, other=2, vocal=3)
# Stems that are kept and summed after separation (everything except drums).
stem_idx = torch.LongTensor([stem2idx[name] for name in ('vocal', 'other', 'bass')])

# Candidate clips for select_new_melody(); may be empty if nothing matches.
melody_files = glob.glob('clips/**/*.mp3', recursive=True)
|
48 |
|
49 |
|
50 |
def _call_nostderr(*args, **kwargs):
|
|
|
100 |
|
101 |
def load_model(version='melody'):
    """Make sure the models required by the selected backend are loaded.

    MusicGen is (re)loaded only when running locally and the cached ``MODEL``
    is missing or was loaded for a different ``version``. Demucs is loaded
    once, regardless of backend.

    Args:
        version: Name of the pretrained MusicGen checkpoint to load.
    """
    global MODEL, DEMUCS_MODEL
    # If gpu is not available, we'll use cpu.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Short-circuit on LOCAL first so MODEL.name is never touched remotely.
    musicgen_is_stale = LOCAL and (MODEL is None or MODEL.name != version)
    if musicgen_is_stale:
        print("Loading model", version)
        MODEL = MusicGen.get_pretrained(version, device=device)

    if DEMUCS_MODEL is None:
        DEMUCS_MODEL = pretrained.get_model('htdemucs').to(device)
|
111 |
|
112 |
+
def connect_to_endpoint(endpoint="https://facebook-musicgen--44zzp.hf.space/"):
    """Create the Gradio API client used for remote generation.

    The client is stored in the module-level ``client`` so that
    ``run_remote_model`` can call it.

    Args:
        endpoint: Address of the hosted MusicGen Gradio app. The default is
            the address this app has always connected to; pass another one
            to target a different deployment without editing the code.
    """
    global client
    client = Client(endpoint)
|
115 |
+
|
116 |
|
117 |
def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
|
118 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
|
|
160 |
demucs_output = demucs_output.cpu()
|
161 |
|
162 |
# Naming
|
163 |
+
d_filename = f"temp/{texts[0][:10]}.wav"
|
|
|
164 |
|
165 |
# If path exists, add number. If number exists, update number.
|
166 |
i = 1
|
167 |
+
while Path(d_filename).exists():
|
168 |
+
d_filename = f"temp/{texts[0][:10]}_{i}.wav"
|
|
|
169 |
i += 1
|
170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
audio_write(
|
172 |
d_filename, demucs_output, MODEL.sample_rate, strategy="loudness",
|
173 |
loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
|
174 |
out_files.append(d_filename)
|
|
|
175 |
file_cleaner.add(d_filename)
|
|
|
176 |
res = [out_file for out_file in out_files]
|
177 |
for file in res:
|
178 |
file_cleaner.add(file)
|
|
|
182 |
|
183 |
|
184 |
|
185 |
+
def predict_full(text, melody, progress=gr.Progress()):
|
186 |
global INTERRUPTING
|
187 |
INTERRUPTING = False
|
188 |
+
print("Running local model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
def _progress(generated, to_generate):
|
190 |
progress((generated, to_generate))
|
191 |
if INTERRUPTING:
|
|
|
193 |
MODEL.set_custom_progress_callback(_progress)
|
194 |
|
195 |
outs = _do_predictions(
|
196 |
+
[text], [melody], duration=10, progress=True)
|
|
|
197 |
|
198 |
+
return outs[0], gr.File.update(value=outs[0], visible=True)
|
199 |
|
200 |
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
+
def select_new_melody():
    """Pick a random clip from ``melody_files`` for the melody input.

    Returns:
        A Gradio update that sets the audio component's source to upload and
        its value to the chosen file path.

    Raises:
        gr.Error: If ``melody_files`` is empty, i.e. the clip glob matched
            nothing.
    """
    if not melody_files:
        # np.random.choice raises a bare ValueError on an empty sequence;
        # surface a readable message in the UI instead.
        raise gr.Error("No melody clips are available.")
    # np.random.choice returns a numpy string scalar; hand Gradio a plain str.
    new_melody_file = str(np.random.choice(melody_files))
    return gr.update(source="upload", value=new_melody_file)
|
205 |
+
|
206 |
+
def run_remote_model(text, melody):
    """Generate music through the remote endpoint and remove its drums.

    The endpoint's result is converted with ffmpeg to a 32 kHz mono WAV under
    ``temp/``, separated with Demucs, and the stems selected by ``stem_idx``
    are summed and written back over that same WAV file.

    Args:
        text: Prompt forwarded to the endpoint. Its first 10 characters also
            seed the output file name.
        melody: File path or URL of the conditioning audio, forwarded as-is.

    Returns:
        A ``(path, update)`` pair: the WAV path for the audio output, and a
        ``gr.File`` update that makes the download component visible with the
        same file.

    Raises:
        gr.Error: If ffmpeg cannot be started or exits with an error.
    """
    print("Running Audiocraft API model with text", text, "and melody", melody)
    result = client.predict(
        text,  # str in 'Describe your music' Textbox component
        melody,  # str (filepath or URL to file) in 'File' Audio component
        fn_index=0
    )

    # Naming. The prompt is user input, so anything that is not a letter,
    # digit, space, dash or underscore (notably path separators and dots) is
    # replaced to keep the file inside temp/.
    stem = "".join(ch if ch.isalnum() or ch in " -_" else "_" for ch in text[:10])
    stem = stem or "output"
    d_filename = os.path.join("temp", f"{stem}.wav")
    # If path exists, add number. If number exists, update number.
    i = 1
    while Path(d_filename).exists():
        d_filename = os.path.join("temp", f"{stem}_{i}.wav")
        i += 1

    # Convert the endpoint's media file to wav, using ffmpeg:
    # ffmpeg -i input.mp4 -vn -acodec pcm_s16le -ar 32000 -ac 1 output.wav
    try:
        sp.run(
            ["ffmpeg", "-i", result, "-vn", "-acodec", "pcm_s16le",
             "-ar", "32000", "-ac", "1", d_filename],
            check=True,
        )
    except (OSError, sp.CalledProcessError) as err:
        # Fail here with a clear message rather than later in audio_read.
        raise gr.Error("ffmpeg failed to convert the generated audio.") from err

    # Load wav file
    output, sr = audio_read(d_filename)

    # Demucs
    print("Running demucs")
    wav = convert_audio(output, sr, DEMUCS_MODEL.samplerate, DEMUCS_MODEL.audio_channels)
    wav = wav.unsqueeze(0)
    stems = apply_model(DEMUCS_MODEL, wav)
    stems = stems[:, stem_idx]  # extract stem
    stems = stems.sum(1)  # merge extracted stems
    stems = convert_audio(stems, DEMUCS_MODEL.samplerate, 32000, 1)
    demucs_output = stems[0].cpu()

    # Overwrites the intermediate ffmpeg output with the drum-free mix.
    audio_write(
        d_filename, demucs_output, 32000, strategy="loudness",
        loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
    file_cleaner.add(d_filename)
    print("Finished", text)
    print("Tempfiles currently stored: ", len(file_cleaner.files))
    return d_filename, gr.File.update(value=d_filename, visible=True)
|
247 |
|
248 |
def ui_full(launch_kwargs):
    """Build the Gradio Blocks interface and launch it.

    Submit is wired to ``predict_full`` when ``LOCAL`` is set and to
    ``run_remote_model`` otherwise. The example prompts only pre-fill the
    text box.

    Args:
        launch_kwargs: Keyword arguments forwarded to ``launch()``.

    NOTE(review): the nesting of the Row/Column containers was reconstructed
    from a view without indentation -- confirm against the rendered layout.
    """
    # previously, type="numpy"; the remote endpoint is given a file path.
    audio_type = "numpy" if LOCAL else "filepath"
    # Backend that the Submit button triggers.
    handler = predict_full if LOCAL else run_remote_model

    with gr.Blocks() as interface:
        gr.Markdown(
            """
            # Soundsauce Melody Playground
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Input Text", interactive=True)
                    with gr.Column():
                        melody = gr.Audio(type=audio_type, label="File",
                                          interactive=True, elem_id="melody-input",
                                          value="clips/chipmunk.wav")
                        new_melody = gr.Button("New Melody", interactive=True)
                with gr.Row():
                    submit = gr.Button("Submit")
                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
                    # _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
            with gr.Column():
                output_without_drum = gr.Audio(label="Output")
                file_download_no_drum = gr.File(label="Download", visible=False)
                gr.Markdown(
                    """
                    Note that the files will be deleted after 10 minutes, so make sure to download!
                    """
                )

        submit.click(handler,
                     inputs=[text, melody],
                     outputs=[output_without_drum, file_download_no_drum])
        new_melody.click(select_new_melody, outputs=[melody])

        # The examples carry a prompt only. Both handlers also need a melody
        # (and predict_full needs the local model), so no fn is attached and
        # caching is disabled: clicking an example just fills the text box.
        gr.Examples(
            examples=[
                ["Enchanting Flute Trills amidst Misty String Section"],
                ["Gliding Mellotron Strings over Vibrant Phrases"],
                ["Synth Brass Melody Floating over Airy Wind Chimes"],
                ["Echoing Electric Guitar Licks with Ethereal Vocal Chops"],
                ["Rhythmic Acoustic Guitar Licks with Echoing Layers"],
                ["Whimsical Flute Flourishes in a Mystical Forest Glade"],
                ["Airy Piccolo Trills accompanied by Floating Harp Arpeggios"],
                ["Dreamy Harp Glissandos accompanied by Distant Celesta"],
                ["Hypnotic Synth Pads layered with Enigmatic Guitar Progressions"],
                ["Enchanting Kalimba Melodies atop Mystical Atmosphere"],
            ],
            inputs=[text],
            cache_examples=False,
        )

        interface.queue().launch(**launch_kwargs)
|
|
|
314 |
default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
|
315 |
help='IP to listen on for connections to Gradio',
|
316 |
)
|
317 |
+
parser.add_argument("--local", action="store_true", help="Run locally instead of using API")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
|
319 |
args = parser.parse_args()
|
320 |
|
321 |
launch_kwargs = {}
|
322 |
launch_kwargs['server_name'] = args.listen
|
323 |
|
324 |
+
LOCAL = args.local
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
# Load melody model
|
326 |
load_model()
|
327 |
+
if not LOCAL:
|
328 |
+
connect_to_endpoint()
|
329 |
if not os.path.exists("temp"):
|
330 |
os.mkdir("temp")
|
331 |
# Show the interface
|