Tonic committed
Commit 266a885 · unverified · 1 Parent(s): 90a4e68

add application, spaces, requirements, readme

Files changed (3)
  1. README.md +1 -1
  2. app.py +239 -0
  3. requirements.txt +4 -0
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: Audiocraft
- emoji: 🐠
+ emoji: 🎷🎸🎹🎺🎙️🎚️🎛️🎧
  colorFrom: blue
  colorTo: red
  sdk: gradio
app.py ADDED
@@ -0,0 +1,239 @@
+ import spaces
+ import logging
+ import os
+ from concurrent.futures import ProcessPoolExecutor
+ from pathlib import Path
+ from tempfile import NamedTemporaryFile
+ import time
+ import typing as tp
+ import subprocess as sp
+ import torch
+ import gradio as gr
+ from audiocraft.data.audio_utils import f32_pcm, normalize_audio
+ from audiocraft.data.audio import audio_write
+ from audiocraft.models import JASCO
+
+ MODEL = None
+ MAX_BATCH_SIZE = 12
+ INTERRUPTING = False
+
+ # Wrap subprocess call to clean logs
+ _old_call = sp.call
+
+ def _call_nostderr(*args, **kwargs):
+     kwargs['stderr'] = sp.DEVNULL
+     kwargs['stdout'] = sp.DEVNULL
+     _old_call(*args, **kwargs)
+
+ sp.call = _call_nostderr
+
+ # Preallocate process pool
+ pool = ProcessPoolExecutor(4)
+ pool.__enter__()
+
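+ # Set a flag that the generation progress callback checks in order to abort early.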
+ def interrupt():
+     global INTERRUPTING
+     INTERRUPTING = True
+
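+ # Track generated files and delete them once they are older than file_lifetime seconds.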
+ class FileCleaner:
+     def __init__(self, file_lifetime: float = 3600):
+         self.file_lifetime = file_lifetime
+         self.files = []
+
+     def add(self, path: tp.Union[str, Path]):
+         self._cleanup()
+         self.files.append((time.time(), Path(path)))
+
+     def _cleanup(self):
+         now = time.time()
+         for time_added, path in list(self.files):
+             if now - time_added > self.file_lifetime:
+                 if path.exists():
+                     path.unlink()
+                 self.files.pop(0)
+             else:
+                 break
+
+ file_cleaner = FileCleaner()
+
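+ # Parse a chord string such as "(C, 0.0), (D, 2.0)" into a list of (chord, start_time) tuples.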
+ def chords_string_to_list(chords: str):
+     if chords == '':
+         return []
+     chords = chords.replace('[', '').replace(']', '').replace(' ', '')
+     chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
+     return [(x[0], float(x[1])) for x in chrd_times]
+
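+ # Load the requested JASCO checkpoint into the global MODEL, reloading only when the version changes.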
+ def load_model(version='facebook/jasco-chords-drums-400M'):
+     global MODEL
+     print("Loading model", version)
+     if MODEL is None or MODEL.name != version:
+         MODEL = None
+         MODEL = JASCO.get_pretrained(version)
+
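+ # Parse the conditioning inputs, run MODEL.generate_music on the GPU and write each output to a temporary wav file.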
+ @spaces.GPU
+ def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
+     MODEL.set_generation_params(**gen_kwargs)
+     be = time.time()
+
+     chords = chords_string_to_list(chords)
+
+     if melody_matrix is not None:
+         melody_matrix = torch.load(melody_matrix.name, weights_only=True)
+         if len(melody_matrix.shape) != 2:
+             raise gr.Error(f"Melody matrix should be a torch tensor of shape [n_melody_bins, T]; got: {melody_matrix.shape}")
+         if melody_matrix.shape[0] > melody_matrix.shape[1]:
+             melody_matrix = melody_matrix.permute(1, 0)
+
+     if drum_prompt is None:
+         preprocessed_drums_wav = None
+         drums_sr = 32000
+     else:
+         drums_sr, drums = drum_prompt[0], f32_pcm(torch.from_numpy(drum_prompt[1])).t()
+         if drums.dim() == 1:
+             drums = drums[None]
+         drums = normalize_audio(drums, strategy="loudness", loudness_headroom_db=16, sample_rate=drums_sr)
+         preprocessed_drums_wav = drums
+
+     try:
+         outputs = MODEL.generate_music(descriptions=texts, chords=chords,
+                                        drums_wav=preprocessed_drums_wav,
+                                        melody_salience_matrix=melody_matrix,
+                                        drums_sample_rate=drums_sr, progress=progress)
+     except RuntimeError as e:
+         raise gr.Error("Error while generating " + e.args[0])
+
+     outputs = outputs.detach().cpu().float()
+     out_wavs = []
+     for output in outputs:
+         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+             audio_write(
+                 file.name, output, MODEL.sample_rate, strategy="loudness",
+                 loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+             out_wavs.append(file.name)
+             file_cleaner.add(file.name)
+     return out_wavs
+
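+ # Gradio callback: load the selected model, hook up progress reporting and generate two clips for the prompt.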
+ @spaces.GPU
+ def predict_full(model, text, chords_sym, melody_file,
+                  drums_file, drums_mic, drum_input_src,
+                  cfg_coef_all, cfg_coef_txt,
+                  ode_rtol, ode_atol,
+                  ode_solver, ode_steps,
+                  progress=gr.Progress()):
+     global INTERRUPTING
+     INTERRUPTING = False
+     progress(0, desc="Loading model...")
+     load_model(model)
+
+     max_generated = 0
+
+     def _progress(generated, to_generate):
+         nonlocal max_generated
+         max_generated = max(generated, max_generated)
+         progress((min(max_generated, to_generate), to_generate))
+         if INTERRUPTING:
+             raise gr.Error("Interrupted.")
+
+     MODEL.set_custom_progress_callback(_progress)
+
+     drums = drums_mic if drum_input_src == "mic" else drums_file
+     wavs = _do_predictions(
+         texts=[text] * 2,
+         chords=chords_sym,
+         drum_prompt=drums,
+         melody_matrix=melody_file,
+         progress=True,
+         gradio_progress=progress,
+         cfg_coef_all=cfg_coef_all,
+         cfg_coef_txt=cfg_coef_txt,
+         ode_rtol=ode_rtol,
+         ode_atol=ode_atol,
+         euler=ode_solver == 'euler',
+         euler_steps=ode_steps)
+
+     return wavs
+
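+ # Build the Gradio interface.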
+ with gr.Blocks() as demo:
+     gr.Markdown("""
+ # JASCO - Text-to-Music Generation with Temporal Control
+ Generate 10-second music clips using text descriptions and temporal controls (chords, drums, melody).
+ """)
+
+     with gr.Row():
+         with gr.Column():
+             submit = gr.Button("Generate")
+             interrupt_btn = gr.Button("Interrupt")
+
+         with gr.Column():
+             audio_output_0 = gr.Audio(label="Generated Audio 1", type='filepath')
+             audio_output_1 = gr.Audio(label="Generated Audio 2", type='filepath')
+
+     with gr.Row():
+         with gr.Column():
+             text = gr.Text(label="Input Text",
+                            value="Strings, woodwind, orchestral, symphony.",
+                            interactive=True)
+         with gr.Column():
+             model = gr.Radio([
+                 'facebook/jasco-chords-drums-400M',
+                 'facebook/jasco-chords-drums-1B',
+                 'facebook/jasco-chords-drums-melody-400M',
+                 'facebook/jasco-chords-drums-melody-1B'
+             ], label="Model", value='facebook/jasco-chords-drums-melody-400M')
+
+     gr.Markdown("### Chords Conditions")
+     chords_sym = gr.Text(
+         label="Chord Progression",
+         value="(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
+         interactive=True
+     )
+
+     gr.Markdown("### Drums Conditions")
+     with gr.Row():
+         drum_input_src = gr.Radio(["file", "mic"], value="file", label="Drums Input Source")
+         drums_file = gr.Audio(sources=["upload"], type="numpy", label="Drums File")
+         drums_mic = gr.Audio(sources=["microphone"], type="numpy", label="Drums Mic")
+
+     gr.Markdown("### Melody Conditions")
+     melody_file = gr.File(label="Melody File")
+
+     with gr.Row():
+         cfg_coef_all = gr.Number(label="CFG ALL", value=1.25, step=0.25)
+         cfg_coef_txt = gr.Number(label="CFG TEXT", value=2.5, step=0.25)
+         ode_tol = gr.Number(label="ODE Tolerance", value=1e-4, step=1e-5)
+         ode_solver = gr.Radio(['euler', 'dopri5'], label="ODE Solver", value='euler')
+         ode_steps = gr.Number(label="Euler Steps", value=10, step=1)
+
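+     # Wire the Generate button; the single ODE tolerance field is passed as both ode_rtol and ode_atol.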
+     submit.click(
+         fn=predict_full,
+         inputs=[
+             model, text, chords_sym, melody_file,
+             drums_file, drums_mic, drum_input_src,
+             cfg_coef_all, cfg_coef_txt,
+             ode_tol, ode_tol, ode_solver, ode_steps
+         ],
+         outputs=[audio_output_0, audio_output_1]
+     )
+
+     interrupt_btn.click(fn=interrupt, queue=False)
+
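+     # Preset examples: text prompt and chord progression only, with no melody or drums conditioning.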
+     gr.Examples(
+         examples=[
+             [
+                 "80s pop with groovy synth bass and electric piano",
+                 "(N, 0.0), (C, 0.32), (Dm7, 3.456), (Am, 4.608), (F, 8.32), (C, 9.216)",
+                 None,
+                 None,
+             ],
+             [
+                 "Strings, woodwind, orchestral, symphony.",
+                 "(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
+                 None,
+                 None,
+             ],
+         ],
+         inputs=[text, chords_sym, melody_file, drums_file],
+         outputs=[audio_output_0, audio_output_1]
+     )
+
+ demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ accelerate
+ audiocraft