Spaces: Running on Zero
add application, spaces, requirements, readme
- README.md +1 -1
- app.py +239 -0
- requirements.txt +4 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Audiocraft
-emoji:
+emoji: 🎷🎸🎹🎺🎙️🎚️🎛️🎧
 colorFrom: blue
 colorTo: red
 sdk: gradio
app.py
ADDED
@@ -0,0 +1,239 @@
+import spaces
+import logging
+import os
+from concurrent.futures import ProcessPoolExecutor
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+import time
+import typing as tp
+import subprocess as sp
+import torch
+import gradio as gr
+from audiocraft.data.audio_utils import f32_pcm, normalize_audio
+from audiocraft.data.audio import audio_write
+from audiocraft.models import JASCO
+
+MODEL = None
+MAX_BATCH_SIZE = 12
+INTERRUPTING = False
+
+# Wrap subprocess call to clean logs
+_old_call = sp.call
+
+def _call_nostderr(*args, **kwargs):
+    kwargs['stderr'] = sp.DEVNULL
+    kwargs['stdout'] = sp.DEVNULL
+    _old_call(*args, **kwargs)
+
+sp.call = _call_nostderr
+
+# Preallocate process pool
+pool = ProcessPoolExecutor(4)
+pool.__enter__()
+
+def interrupt():
+    global INTERRUPTING
+    INTERRUPTING = True
+
+class FileCleaner:
+    def __init__(self, file_lifetime: float = 3600):
+        self.file_lifetime = file_lifetime
+        self.files = []
+
+    def add(self, path: tp.Union[str, Path]):
+        self._cleanup()
+        self.files.append((time.time(), Path(path)))
+
+    def _cleanup(self):
+        now = time.time()
+        for time_added, path in list(self.files):
+            if now - time_added > self.file_lifetime:
+                if path.exists():
+                    path.unlink()
+                self.files.pop(0)
+            else:
+                break
+
+file_cleaner = FileCleaner()
+
+def chords_string_to_list(chords: str):
+    if chords == '':
+        return []
+    chords = chords.replace('[', '').replace(']', '').replace(' ', '')
+    chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
+    return [(x[0], float(x[1])) for x in chrd_times]
+
+def load_model(version='facebook/jasco-chords-drums-400M'):
+    global MODEL
+    print("Loading model", version)
+    if MODEL is None or MODEL.name != version:
+        MODEL = None
+        MODEL = JASCO.get_pretrained(version)
+
+@spaces.GPU
+def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
+    MODEL.set_generation_params(**gen_kwargs)
+    be = time.time()
+
+    chords = chords_string_to_list(chords)
+
+    if melody_matrix is not None:
+        melody_matrix = torch.load(melody_matrix.name, weights_only=True)
+        if len(melody_matrix.shape) != 2:
+            raise gr.Error(f"Melody matrix should be a torch tensor of shape [n_melody_bins, T]; got: {melody_matrix.shape}")
+        if melody_matrix.shape[0] > melody_matrix.shape[1]:
+            melody_matrix = melody_matrix.permute(1, 0)
+
+    if drum_prompt is None:
+        preprocessed_drums_wav = None
+        drums_sr = 32000
+    else:
+        drums_sr, drums = drum_prompt[0], f32_pcm(torch.from_numpy(drum_prompt[1])).t()
+        if drums.dim() == 1:
+            drums = drums[None]
+        drums = normalize_audio(drums, strategy="loudness", loudness_headroom_db=16, sample_rate=drums_sr)
+        preprocessed_drums_wav = drums
+
+    try:
+        outputs = MODEL.generate_music(descriptions=texts, chords=chords,
+                                       drums_wav=preprocessed_drums_wav,
+                                       melody_salience_matrix=melody_matrix,
+                                       drums_sample_rate=drums_sr, progress=progress)
+    except RuntimeError as e:
+        raise gr.Error("Error while generating " + e.args[0])
+
+    outputs = outputs.detach().cpu().float()
+    out_wavs = []
+    for output in outputs:
+        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+            audio_write(
+                file.name, output, MODEL.sample_rate, strategy="loudness",
+                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+            out_wavs.append(file.name)
+            file_cleaner.add(file.name)
+    return out_wavs
+
+@spaces.GPU
+def predict_full(model, text, chords_sym, melody_file,
+                 drums_file, drums_mic, drum_input_src,
+                 cfg_coef_all, cfg_coef_txt,
+                 ode_rtol, ode_atol,
+                 ode_solver, ode_steps,
+                 progress=gr.Progress()):
+    global INTERRUPTING
+    INTERRUPTING = False
+    progress(0, desc="Loading model...")
+    load_model(model)
+
+    max_generated = 0
+
+    def _progress(generated, to_generate):
+        nonlocal max_generated
+        max_generated = max(generated, max_generated)
+        progress((min(max_generated, to_generate), to_generate))
+        if INTERRUPTING:
+            raise gr.Error("Interrupted.")
+
+    MODEL.set_custom_progress_callback(_progress)
+
+    drums = drums_mic if drum_input_src == "mic" else drums_file
+    wavs = _do_predictions(
+        texts=[text] * 2,
+        chords=chords_sym,
+        drum_prompt=drums,
+        melody_matrix=melody_file,
+        progress=True,
+        gradio_progress=progress,
+        cfg_coef_all=cfg_coef_all,
+        cfg_coef_txt=cfg_coef_txt,
+        ode_rtol=ode_rtol,
+        ode_atol=ode_atol,
+        euler=ode_solver == 'euler',
+        euler_steps=ode_steps)
+
+    return wavs
+
+with gr.Blocks() as demo:
+    gr.Markdown("""
+    # JASCO - Text-to-Music Generation with Temporal Control
+    Generate 10-second music clips using text descriptions and temporal controls (chords, drums, melody).
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            submit = gr.Button("Generate")
+            interrupt_btn = gr.Button("Interrupt")
+
+        with gr.Column():
+            audio_output_0 = gr.Audio(label="Generated Audio 1", type='filepath')
+            audio_output_1 = gr.Audio(label="Generated Audio 2", type='filepath')
+
+    with gr.Row():
+        with gr.Column():
+            text = gr.Text(label="Input Text",
+                           value="Strings, woodwind, orchestral, symphony.",
+                           interactive=True)
+        with gr.Column():
+            model = gr.Radio([
+                'facebook/jasco-chords-drums-400M',
+                'facebook/jasco-chords-drums-1B',
+                'facebook/jasco-chords-drums-melody-400M',
+                'facebook/jasco-chords-drums-melody-1B'
+            ], label="Model", value='facebook/jasco-chords-drums-melody-400M')
+
+    gr.Markdown("### Chords Conditions")
+    chords_sym = gr.Text(
+        label="Chord Progression",
+        value="(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
+        interactive=True
+    )
+
+    gr.Markdown("### Drums Conditions")
+    with gr.Row():
+        drum_input_src = gr.Radio(["file", "mic"], value="file", label="Drums Input Source")
+        drums_file = gr.Audio(sources=["upload"], type="numpy", label="Drums File")
+        drums_mic = gr.Audio(sources=["microphone"], type="numpy", label="Drums Mic")
+
+    gr.Markdown("### Melody Conditions")
+    melody_file = gr.File(label="Melody File")
+
+    with gr.Row():
+        cfg_coef_all = gr.Number(label="CFG ALL", value=1.25, step=0.25)
+        cfg_coef_txt = gr.Number(label="CFG TEXT", value=2.5, step=0.25)
+        ode_tol = gr.Number(label="ODE Tolerance", value=1e-4, step=1e-5)
+        ode_solver = gr.Radio(['euler', 'dopri5'], label="ODE Solver", value='euler')
+        ode_steps = gr.Number(label="Euler Steps", value=10, step=1)
+
+    submit.click(
+        fn=predict_full,
+        inputs=[
+            model, text, chords_sym, melody_file,
+            drums_file, drums_mic, drum_input_src,
+            cfg_coef_all, cfg_coef_txt,
+            ode_tol, ode_tol, ode_solver, ode_steps
+        ],
+        outputs=[audio_output_0, audio_output_1]
+    )
+
+    interrupt_btn.click(fn=interrupt, queue=False)
+
+    gr.Examples(
+        examples=[
+            [
+                "80s pop with groovy synth bass and electric piano",
+                "(N, 0.0), (C, 0.32), (Dm7, 3.456), (Am, 4.608), (F, 8.32), (C, 9.216)",
+                None,
+                None,
+            ],
+            [
+                "Strings, woodwind, orchestral, symphony.",
+                "(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
+                None,
+                None,
+            ],
+        ],
+        inputs=[text, chords_sym, melody_file, drums_file],
+        outputs=[audio_output_0, audio_output_1]
+    )
+
+demo.queue().launch()
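
For reference, the "Chord Progression" text box expects a flat string of (chord, onset-in-seconds) pairs; chords_string_to_list turns it into the list of tuples that generate_music consumes, and the melody file, when provided, must be a torch-saved 2D salience matrix of shape [n_melody_bins, T]. A small self-contained check of the chord parsing, copying the helper verbatim from app.py (the print call is an illustrative addition, not part of the Space):

def chords_string_to_list(chords: str):
    # Verbatim copy of the helper defined in app.py above.
    if chords == '':
        return []
    chords = chords.replace('[', '').replace(']', '').replace(' ', '')
    chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
    return [(x[0], float(x[1])) for x in chrd_times]

# Default value shown in the "Chord Progression" text box:
print(chords_string_to_list("(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)"))
# -> [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]
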
requirements.txt
ADDED
@@ -0,0 +1,4 @@
+torch
+transformers
+accelerate
+audiocraft
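
Once these requirements are installed, the same audiocraft calls that app.py wires into Gradio can also be driven from a plain script. A minimal sketch that mirrors the demo's defaults (checkpoint name, CFG coefficients, Euler solver settings); it is an illustrative outline based on the calls above, not an official recipe:

from audiocraft.models import JASCO
from audiocraft.data.audio import audio_write

# Smallest chords+drums checkpoint offered in the demo's model selector.
model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M')

# Same generation parameters the UI exposes, with its default values.
model.set_generation_params(cfg_coef_all=1.25, cfg_coef_txt=2.5,
                            ode_rtol=1e-4, ode_atol=1e-4,
                            euler=True, euler_steps=10)

chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]
outputs = model.generate_music(descriptions=['Strings, woodwind, orchestral, symphony.'],
                               chords=chords,
                               drums_wav=None,  # no drum prompt, as when the demo's audio inputs are left empty
                               melody_salience_matrix=None,
                               drums_sample_rate=32000,
                               progress=True)

for idx, one_wav in enumerate(outputs.detach().cpu().float()):
    # Same loudness-normalized write that app.py uses for its temporary files.
    audio_write(f'sample_{idx}', one_wav, model.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True)
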