adefossez committed
Commit 62cbdc0
2 parents: df20ddc 13ef076

Merge branch 'longgen' into our_hf2

README.md CHANGED
@@ -5,7 +5,7 @@ tags:
 - "music generation"
 - "language models"
 - "LLMs"
-app_file: "app_batched.py"
+app_file: "app.py"
 emoji: 🎵
 colorFrom: white
 colorTo: blue
@@ -54,11 +54,12 @@ pip install -e . # or if you cloned the repo locally
 
 ## Usage
 We offer a number of ways to interact with MusicGen:
-1. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally, or use the provided [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing).
-2. You can use the gradio demo locally by running `python app.py`.
-3. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support).
-4. Finally, you can run the [Gradio demo with a Colab GPU](https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing),
-   as adapted from [@camenduru Colab](https://github.com/camenduru/MusicGen-colab).
+1. A demo is available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to the HF team for their support).
+2. You can run the extended demo on a Colab: [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing).
+3. You can use the gradio demo locally by running `python app.py`.
+4. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally (if you have a GPU).
+5. Finally, check out the [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab), which is regularly
+   updated with contributions from @camenduru and the community.
 
 ## API
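For context on the `## API` section referenced above, here is a minimal usage sketch assembled only from calls that appear in the code changed in this commit (`MusicGen.get_pretrained`, `set_generation_params`, `generate`, `audio_write`); the prompt text and output file names are illustrative only.

```python
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

# Load the 'melody' checkpoint; with this commit, the device defaults to CUDA when available.
model = MusicGen.get_pretrained('melody')
# Durations above 30 seconds are now generated in overlapping chunks (see musicgen.py below).
model.set_generation_params(duration=10)

wavs = model.generate(['An 80s driving pop song with heavy drums and synth pads'])
for idx, one_wav in enumerate(wavs.cpu()):
    # Same loudness-normalization strategy used by app.py in this commit.
    audio_write(f'output_{idx}', one_wav, model.sample_rate, strategy='loudness')
```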
app.py CHANGED
@@ -1,70 +1,139 @@
-"""
-Copyright (c) Meta Platforms, Inc. and affiliates.
-All rights reserved.
-
-This source code is licensed under the license found in the
-LICENSE file in the root directory of this source tree.
-"""
-
-from tempfile import NamedTemporaryFile
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py,
+# also released under the MIT license.
+
 import argparse
+from concurrent.futures import ProcessPoolExecutor
+import os
+import subprocess as sp
+from tempfile import NamedTemporaryFile
+import time
+import warnings
+
 import torch
 import gradio as gr
-import os
-from audiocraft.models import MusicGen
+
+from audiocraft.data.audio_utils import convert_audio
 from audiocraft.data.audio import audio_write
+from audiocraft.models import MusicGen
 
-MODEL = None
-IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
+MODEL = None  # Last used model
+IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
+MAX_BATCH_SIZE = 12
+BATCHED_DURATION = 15
+INTERRUPTING = False
+# We have to wrap the subprocess call to clean up the logs a bit when using gr.make_waveform.
+_old_call = sp.call
 
-def load_model(version):
-    print("Loading model", version)
-    return MusicGen.get_pretrained(version)
+def _call_nostderr(*args, **kwargs):
+    # Avoid ffmpeg vomiting on the logs.
+    kwargs['stderr'] = sp.DEVNULL
+    kwargs['stdout'] = sp.DEVNULL
+    _old_call(*args, **kwargs)
 
-def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
-    global MODEL
-    topk = int(topk)
-    if MODEL is None or MODEL.name != model:
-        MODEL = load_model(model)
-
-    if duration > MODEL.lm.cfg.dataset.segment_duration:
-        raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
-    MODEL.set_generation_params(
-        use_sampling=True,
-        top_k=topk,
-        top_p=topp,
-        temperature=temperature,
-        cfg_coef=cfg_coef,
-        duration=duration,
-    )
+sp.call = _call_nostderr
+# Preallocating the pool of processes.
+pool = ProcessPoolExecutor(4)
+pool.__enter__()
 
-    if melody:
-        sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
-        print(melody.shape)
-        if melody.dim() == 2:
-            melody = melody[None]
-        melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
-        output = MODEL.generate_with_chroma(
-            descriptions=[text],
-            melody_wavs=melody,
-            melody_sample_rate=sr,
-            progress=False
-        )
-    else:
-        output = MODEL.generate(descriptions=[text], progress=False)
 
-    output = output.detach().cpu().float()[0]
-    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
-        audio_write(
-            file.name, output, MODEL.sample_rate, strategy="loudness",
-            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
-        waveform_video = gr.make_waveform(file.name)
-    return waveform_video
+def interrupt():
+    global INTERRUPTING
+    INTERRUPTING = True
 
-def ui(**kwargs):
+
+def make_waveform(*args, **kwargs):
+    # Further remove some warnings.
+    be = time.time()
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        out = gr.make_waveform(*args, **kwargs)
+        print("Make a video took", time.time() - be)
+        return out
+
+
+def load_model(version='melody'):
+    global MODEL
+    print("Loading model", version)
+    if MODEL is None or MODEL.name != version:
+        MODEL = MusicGen.get_pretrained(version)
+
+
+def _do_predictions(texts, melodies, duration, **gen_kwargs):
+    MODEL.set_generation_params(duration=duration, **gen_kwargs)
+    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
+    be = time.time()
+    processed_melodies = []
+    target_sr = 32000
+    target_ac = 1
+    for melody in melodies:
+        if melody is None:
+            processed_melodies.append(None)
+        else:
+            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
+            if melody.dim() == 1:
+                melody = melody[None]
+            melody = melody[..., :int(sr * duration)]
+            melody = convert_audio(melody, sr, target_sr, target_ac)
+            processed_melodies.append(melody)
+
+    if any(m is not None for m in processed_melodies):
+        outputs = MODEL.generate_with_chroma(
+            descriptions=texts,
+            melody_wavs=processed_melodies,
+            melody_sample_rate=target_sr,
+            progress=True
+        )
+    else:
+        outputs = MODEL.generate(texts, progress=True)
+
+    outputs = outputs.detach().cpu().float()
+    out_files = []
+    for output in outputs:
+        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+            audio_write(
+                file.name, output, MODEL.sample_rate, strategy="loudness",
+                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+            out_files.append(pool.submit(make_waveform, file.name))
+    res = [out_file.result() for out_file in out_files]
+    print("batch finished", len(texts), time.time() - be)
+    return res
+
+
+def predict_batched(texts, melodies):
+    max_text_length = 512
+    texts = [text[:max_text_length] for text in texts]
+    load_model('melody')
+    res = _do_predictions(texts, melodies, BATCHED_DURATION)
+    return [res]
+
+
+def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
+    global INTERRUPTING
+    INTERRUPTING = False
+    topk = int(topk)
+    load_model(model)
+
+    def _progress(generated, to_generate):
+        progress((generated, to_generate))
+        if INTERRUPTING:
+            raise gr.Error("Interrupted.")
+    MODEL.set_custom_progress_callback(_progress)
+
+    outs = _do_predictions(
+        [text], [melody], duration,
+        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
+    return outs[0]
+
+
+def ui_full(launch_kwargs):
     with gr.Blocks() as interface:
         gr.Markdown(
             """
@@ -73,14 +142,6 @@ def ui(**kwargs):
             presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
             """
         )
-        if IS_SHARED_SPACE:
-            gr.Markdown("""
-                ⚠ This Space doesn't work in this shared UI ⚠
-
-                <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
-                <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-                to use it privately, or use the <a href="https://huggingface.co/spaces/facebook/MusicGen">public demo</a>
-                """)
         with gr.Row():
            with gr.Column():
                 with gr.Row():
@@ -88,10 +149,12 @@ def ui(**kwargs):
                     melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
                 with gr.Row():
                     submit = gr.Button("Submit")
+                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
+                    _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
                 with gr.Row():
                     model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
                 with gr.Row():
-                    duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True)
+                    duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
                 with gr.Row():
                     topk = gr.Number(label="Top-k", value=250, interactive=True)
                     topp = gr.Number(label="Top-p", value=0, interactive=True)
@@ -99,9 +162,9 @@ def ui(**kwargs):
                     cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
             with gr.Column():
                 output = gr.Video(label="Generated Music")
-        submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
+        submit.click(predict_full, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
         gr.Examples(
-            fn=predict,
+            fn=predict_full,
             examples=[
                 [
                     "An 80s driving pop song with heavy drums and synth pads in the background",
@@ -137,7 +200,13 @@ def ui(**kwargs):
            ### More details
 
            The model will generate a short music extract based on the description you provided.
-           You can generate up to 30 seconds of audio.
+           The model can generate up to 30 seconds of audio in one pass. It is now possible
+           to extend the generation by feeding back the end of the previous chunk of audio.
+           This can take a long time, and the model might lose consistency. The model might also
+           decide at arbitrary positions that the song ends.
+
+           **WARNING:** Choosing long durations will take a long time to generate (2 min might take ~10 min). An overlap of 12 seconds
+           is kept with the previously generated chunk, and 18 "new" seconds are generated each time.
 
            We present 4 model variations:
            1. Melody -- a music generation model capable of generating music conditioned on text and melody inputs. **Note**, you can also use text only.
@@ -154,27 +223,75 @@ def ui(**kwargs):
            """
         )
 
-    # Show the interface
-    launch_kwargs = {}
-    username = kwargs.get('username')
-    password = kwargs.get('password')
-    server_port = kwargs.get('server_port', 0)
-    inbrowser = kwargs.get('inbrowser', False)
-    share = kwargs.get('share', False)
-    server_name = kwargs.get('listen')
-
-    launch_kwargs['server_name'] = server_name
-
-    if username and password:
-        launch_kwargs['auth'] = (username, password)
-    if server_port > 0:
-        launch_kwargs['server_port'] = server_port
-    if inbrowser:
-        launch_kwargs['inbrowser'] = inbrowser
-    if share:
-        launch_kwargs['share'] = share
-
-    interface.queue().launch(**launch_kwargs, max_threads=1)
+    interface.queue().launch(**launch_kwargs)
+
+
+def ui_batched(launch_kwargs):
+    with gr.Blocks() as demo:
+        gr.Markdown(
+            """
+            # MusicGen
+
+            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
+            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
+            <br/>
+            <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+            for longer sequences, more control and no queue.</p>
+            """
+        )
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
+                    melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
+                with gr.Row():
+                    submit = gr.Button("Generate")
+            with gr.Column():
+                output = gr.Video(label="Generated Music")
+        submit.click(predict_batched, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE)
+        gr.Examples(
+            fn=predict_batched,
+            examples=[
+                [
+                    "An 80s driving pop song with heavy drums and synth pads in the background",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "A cheerful country song with acoustic guitars",
+                    "./assets/bolero_ravel.mp3",
+                ],
+                [
+                    "90s rock song with electric guitar and heavy drums",
+                    None,
+                ],
+                [
+                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "lofi slow bpm electro chill with organic samples",
+                    None,
+                ],
+            ],
+            inputs=[text, melody],
+            outputs=[output]
+        )
+        gr.Markdown("""
+        ### More details
+
+        The model will generate 12 seconds of audio based on the description you provided.
+        You can optionally provide a reference audio from which a broad melody will be extracted.
+        The model will then try to follow both the description and melody provided.
+        All samples are generated with the `melody` model.
+
+        You can also use your own GPU or a Google Colab by following the instructions on our repo.
+
+        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+        for more details.
+        """)
+
+    demo.queue(max_size=8 * 4).launch(**launch_kwargs)
 
 
 if __name__ == "__main__":
@@ -182,7 +299,11 @@ if __name__ == "__main__":
     parser.add_argument(
         '--listen',
         type=str,
-        default='0.0.0.0',
+        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
         help='IP to listen on for connections to Gradio',
     )
     parser.add_argument(
@@ -206,11 +327,18 @@
 
     args = parser.parse_args()
 
-    ui(
-        username=args.username,
-        password=args.password,
-        inbrowser=args.inbrowser,
-        server_port=args.server_port,
-        share=args.share,
-        listen=args.listen
-    )
+    launch_kwargs = {}
+    if args.username and args.password:
+        launch_kwargs['auth'] = (args.username, args.password)
+    if args.server_port:
+        launch_kwargs['server_port'] = args.server_port
+    if args.inbrowser:
+        launch_kwargs['inbrowser'] = args.inbrowser
+    if args.share:
+        launch_kwargs['share'] = args.share
+
+    # Show the interface
+    if IS_BATCHED:
+        ui_batched(launch_kwargs)
+    else:
+        ui_full(launch_kwargs)
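The WARNING in the `ui_full` markdown above reflects the long-generation defaults introduced by this merge: chunks of at most 30 seconds, advanced by `extend_stride=18` seconds, hence a 12-second overlap between chunks. A rough sketch of that arithmetic (the helper name is illustrative, not part of the codebase):

```python
def lm_passes(duration: float, max_duration: float = 30.0, extend_stride: float = 18.0) -> int:
    """Rough number of LM passes needed for `duration` seconds of audio.

    The first pass produces up to `max_duration` seconds; every later pass keeps
    `max_duration - extend_stride` seconds of overlap and adds `extend_stride` new seconds.
    """
    if duration <= max_duration:
        return 1
    passes = 1
    generated = max_duration
    while generated < duration:
        generated += extend_stride
        passes += 1
    return passes

# A 2 minute request: one pass for the first 30 s, then 5 more passes of 18 "new" seconds each.
assert lm_passes(120) == 6
```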
audiocraft/models/loaders.py CHANGED
@@ -80,8 +80,6 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_di
     cfg = OmegaConf.create(pkg['xp.cfg'])
     cfg.device = str(device)
     if cfg.device == 'cpu':
-        cfg.transformer_lm.memory_efficient = False
-        cfg.transformer_lm.custom = True
         cfg.dtype = 'float32'
     else:
         cfg.dtype = 'float16'
audiocraft/models/musicgen.py CHANGED
@@ -36,13 +36,16 @@ class MusicGen:
             used to map audio to invertible discrete representations.
         lm (LMModel): Language model over discrete representations.
     """
-    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel):
+    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+                 max_duration: float = 30):
         self.name = name
         self.compression_model = compression_model
         self.lm = lm
+        self.max_duration = max_duration
         self.device = next(iter(lm.parameters())).device
         self.generation_params: dict = {}
         self.set_generation_params(duration=15)  # 15 seconds by default
+        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
         if self.device.type == 'cpu':
             self.autocast = TorchAutocast(enabled=False)
         else:
@@ -65,7 +68,7 @@
         return self.compression_model.channels
 
     @staticmethod
-    def get_pretrained(name: str = 'melody', device='cuda'):
+    def get_pretrained(name: str = 'melody', device=None):
         """Return pretrained model, we provide four models:
         - small (300M), text to music,  # see: https://huggingface.co/facebook/musicgen-small
         - medium (1.5B), text to music,  # see: https://huggingface.co/facebook/musicgen-medium
@@ -73,6 +76,12 @@
         - large (3.3B), text to music,  # see: https://huggingface.co/facebook/musicgen-large
         """
 
+        if device is None:
+            if torch.cuda.device_count():
+                device = 'cuda'
+            else:
+                device = 'cpu'
+
         if name == 'debug':
             # used only for unit tests
             compression_model = get_debug_compression_model(device)
@@ -96,7 +105,7 @@
     def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                               top_p: float = 0.0, temperature: float = 1.0,
                               duration: float = 30.0, cfg_coef: float = 3.0,
-                              two_step_cfg: bool = False, extend_stride: float = 15):
+                              two_step_cfg: bool = False, extend_stride: float = 18):
         """Set the generation parameters for MusicGen.
 
         Args:
@@ -113,11 +122,10 @@
                 should we extend the audio each time. Larger values will mean less context is
                 preserved, and shorter values will require extra computation.
         """
-        # assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
-        assert extend_stride <= 25, "Keep at least 5 seconds of overlap!"
+        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        self.extend_stride = extend_stride
+        self.duration = duration
         self.generation_params = {
-            'max_gen_len': int(duration * self.frame_rate),
             'use_sampling': use_sampling,
             'temp': temperature,
             'top_k': top_k,
@@ -126,6 +134,10 @@
             'two_step_cfg': two_step_cfg,
         }
 
+    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+        """Override the default progress callback."""
+        self._progress_callback = progress_callback
+
     def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
         """Generate samples in an unconditional manner.
 
@@ -268,20 +280,79 @@
         Returns:
             torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
         """
+        total_gen_len = int(self.duration * self.frame_rate)
+        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
+        current_gen_offset: int = 0
+
         def _progress_callback(generated_tokens: int, tokens_to_generate: int):
-            print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
+            generated_tokens += current_gen_offset
+            if self._progress_callback is not None:
+                # Note that total_gen_len might be quite wrong depending on the
+                # codebook pattern used, but with delay it is almost accurate.
+                self._progress_callback(generated_tokens, total_gen_len)
+            else:
+                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
 
         if prompt_tokens is not None:
-            assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
+            assert max_prompt_len >= prompt_tokens.shape[-1], \
                 "Prompt is longer than audio to generate"
 
         callback = None
         if progress:
             callback = _progress_callback
 
-        # generate by sampling from LM
-        with self.autocast:
-            gen_tokens = self.lm.generate(prompt_tokens, attributes, callback=callback, **self.generation_params)
+        if self.duration <= self.max_duration:
+            # generate by sampling from LM, simple case.
+            with self.autocast:
+                gen_tokens = self.lm.generate(
+                    prompt_tokens, attributes,
+                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
+
+        else:
+            # now this gets a bit messier, we need to handle prompts,
+            # melody conditioning etc.
+            ref_wavs = [attr.wav['self_wav'] for attr in attributes]
+            all_tokens = []
+            if prompt_tokens is None:
+                prompt_length = 0
+            else:
+                all_tokens.append(prompt_tokens)
+                prompt_length = prompt_tokens.shape[-1]
+
+            stride_tokens = int(self.frame_rate * self.extend_stride)
+
+            while current_gen_offset + prompt_length < total_gen_len:
+                time_offset = current_gen_offset / self.frame_rate
+                chunk_duration = min(self.duration - time_offset, self.max_duration)
+                max_gen_len = int(chunk_duration * self.frame_rate)
+                for attr, ref_wav in zip(attributes, ref_wavs):
+                    wav_length = ref_wav.length.item()
+                    if wav_length == 0:
+                        continue
+                    # We will extend the wav periodically if it is not long enough.
+                    # We have to do it here rather than in conditioners.py as otherwise
+                    # we wouldn't have the full wav.
+                    initial_position = int(time_offset * self.sample_rate)
+                    wav_target_length = int(self.max_duration * self.sample_rate)
+                    print(initial_position / self.sample_rate, wav_target_length / self.sample_rate)
+                    positions = torch.arange(initial_position,
+                                             initial_position + wav_target_length, device=self.device)
+                    attr.wav['self_wav'] = WavCondition(
+                        ref_wav[0][:, positions % wav_length],
+                        torch.full_like(ref_wav[1], wav_target_length))
+                with self.autocast:
+                    gen_tokens = self.lm.generate(
+                        prompt_tokens, attributes,
+                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
+                if prompt_tokens is None:
+                    all_tokens.append(gen_tokens)
+                else:
+                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
+                prompt_tokens = gen_tokens[:, :, stride_tokens:]
+                prompt_length = prompt_tokens.shape[-1]
+                current_gen_offset += stride_tokens
+
+            gen_tokens = torch.cat(all_tokens, dim=-1)
 
         # generate audio
         assert gen_tokens.dim() == 3
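A short usage sketch for the new long-generation parameters added to `MusicGen` above (`max_duration`, `extend_stride`, `set_custom_progress_callback`); the prompt text and parameter values are examples only:

```python
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('melody')
# 60 s > max_duration (30 s), so generation proceeds in chunks with a
# 30 - 18 = 12 s overlap between consecutive chunks.
model.set_generation_params(duration=60, extend_stride=18, top_k=250, cfg_coef=3.0)

def on_progress(generated_tokens: int, total_tokens: int) -> None:
    # Called through the internal _progress_callback wrapper added in this commit.
    print(f'{generated_tokens}/{total_tokens} tokens', end='\r')

model.set_custom_progress_callback(on_progress)
wavs = model.generate(['lofi slow bpm electro chill with organic samples'], progress=True)
```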
audiocraft/modules/transformer.py CHANGED
@@ -25,6 +25,22 @@ from xformers import ops
 from .rope import RotaryEmbedding
 from .streaming import StreamingModule
 
+_efficient_attention_backend: str = 'torch'
+
+
+def set_efficient_attention_backend(backend: str = 'torch'):
+    # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
+    global _efficient_attention_backend
+    assert _efficient_attention_backend in ['xformers', 'torch']
+    _efficient_attention_backend = backend
+
+
+def _get_attention_time_dimension() -> int:
+    if _efficient_attention_backend == 'torch':
+        return 2
+    else:
+        return 1
+
 
 def _is_profiled() -> bool:
     # Return true if we are currently running with a xformers profiler activated.
@@ -75,14 +91,22 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
 
 def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
-    bs, slen, n_kv_heads, head_dim = x.shape
     if n_rep == 1:
         return x
-    return (
-        x[:, :, :, None, :]
-        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
-        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-    )
+    if _efficient_attention_backend == 'torch':
+        bs, n_kv_heads, slen, head_dim = x.shape
+        return (
+            x[:, :, None, :, :]
+            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+        )
+    else:
+        bs, slen, n_kv_heads, head_dim = x.shape
+        return (
+            x[:, :, :, None, :]
+            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+        )
 
 
 class LayerScale(nn.Module):
@@ -210,6 +234,7 @@ class StreamingMultiheadAttention(StreamingModule):
         # Return a causal mask, accounting for potentially stored past keys/values
         # We actually return a bias for the attention score, as this has the same
         # convention both in the builtin MHA in Pytorch, and Xformers functions.
+        time_dim = _get_attention_time_dimension()
         if self.memory_efficient:
             from xformers.ops import LowerTriangularMask
             if current_steps == 1:
@@ -222,7 +247,7 @@
                 return LowerTriangularMask()
         if self._streaming_state:
             past_keys = self._streaming_state['past_keys']
-            past_steps = past_keys.shape[1]
+            past_steps = past_keys.shape[time_dim]
         else:
             past_steps = 0
 
@@ -239,6 +264,7 @@
             torch.full([], float('-inf'), device=device, dtype=dtype))
 
     def _complete_kv(self, k, v):
+        time_dim = _get_attention_time_dimension()
         if self.cross_attention:
             # With cross attention we assume all keys and values
             # are already available, and streaming is with respect
@@ -247,20 +273,20 @@
         # Complete the key/value pair using the streaming state.
         if self._streaming_state:
             pk = self._streaming_state['past_keys']
-            nk = torch.cat([pk, k], dim=2)
+            nk = torch.cat([pk, k], dim=time_dim)
             if v is k:
                 nv = nk
             else:
                 pv = self._streaming_state['past_values']
-                nv = torch.cat([pv, v], dim=2)
+                nv = torch.cat([pv, v], dim=time_dim)
         else:
             nk = k
             nv = v
 
-        assert nk.shape[2] == nv.shape[2]
+        assert nk.shape[time_dim] == nv.shape[time_dim]
         offset = 0
         if self.past_context is not None:
-            offset = max(0, nk.shape[2] - self.past_context)
+            offset = max(0, nk.shape[time_dim] - self.past_context)
         if self._is_streaming:
             self._streaming_state['past_keys'] = nk[:, offset:]
             if v is not k:
@@ -271,8 +297,9 @@
             self._streaming_state['offset'] = torch.tensor(0)
         return nk, nv
 
-
     def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
+        # TODO: fix and verify layout.
+        assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
         # Apply rope embeddings to query and key tensors.
         assert self.rope is not None
         if 'past_keys' in self._streaming_state:
@@ -293,6 +320,11 @@
         assert not is_causal, ("new param added in torch 2.0.1 not supported, "
                                "use the causal args in the constructor.")
 
+        time_dim = _get_attention_time_dimension()
+        if time_dim == 2:
+            layout = "b h t d"
+        else:
+            layout = "b t h d"
         dtype = query.dtype
         if self._is_streaming:
             assert self.causal or self.cross_attention, \
@@ -325,8 +357,7 @@
             if self.qk_layer_norm is True:
                 q = self.q_layer_norm(q)
                 k = self.k_layer_norm(k)
-            # q, k, v = [rearrange(x, "b t (h d) -> (b h) t d", h=self.num_heads) for x in [q, k, v]]
-            q, k, v = [rearrange(x, "b t (h d) -> b h t d", h=self.num_heads) for x in [q, k, v]]
+            q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
         else:
             if not _is_profiled():
                 # profiling breaks that property somehow.
@@ -334,7 +365,11 @@
                 assert value is key, "specialized implementation"
             projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
             if self.kv_repeat == 1:
-                packed = rearrange(projected, "b t (p h d) -> b h p t d", p=3, h=self.num_heads)
+                if time_dim == 2:
+                    bound_layout = "b h p t d"
+                else:
+                    bound_layout = "b t p h d"
+                packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
                 q, k, v = ops.unbind(packed, dim=2)
             else:
                 embed_dim = self.embed_dim
@@ -345,18 +380,17 @@
                 end = start + per_head_dim * kv_heads
                 k = projected[:, :, start: end]
                 v = projected[:, :, end:]
-                q = rearrange(q, "b t (h d) -> b t h d", h=self.num_heads)
-                k = rearrange(k, "b t (h d) -> b t h d", h=kv_heads)
-                v = rearrange(v, "b t (h d) -> b t h d", h=kv_heads)
+                q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
+                k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
+                v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
 
             if self.qk_layer_norm is True:
                 assert self.kv_repeat == 1
-                q, k = [rearrange(x, "b t h d -> b t (h d)") for x in [q, k]]
+                q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
                 q = self.q_layer_norm(q)
                 k = self.k_layer_norm(k)
-                q, k = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k]]
+                q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
             if self.rope:
-                assert False, "Not supported for now"
                 q, k = self._apply_rope(q, k)
             k, v = self._complete_kv(k, v)
             if self.kv_repeat > 1:
@@ -366,8 +400,11 @@
                 q, k, v = [x.float() for x in [q, k, v]]
             if self.memory_efficient:
                 p = self.dropout if self.training else 0
-                x = torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, is_causal=attn_mask is not None, dropout_p=p)
+                if _efficient_attention_backend == 'torch':
+                    x = torch.nn.functional.scaled_dot_product_attention(
+                        q, k, v, is_causal=attn_mask is not None, dropout_p=p)
+                else:
+                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
             else:
                 # We include the dot product as float32, for consistency
                 # with the other implementations that include that step
@@ -377,18 +414,21 @@
                 # extend a bit the range of operations done in float32,
                 # although this should make no difference.
                 q = q / q.shape[-1] ** 0.5
+                key_layout = layout.replace('t', 'k')
+                query_layout = layout
                 if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
                     with torch.autocast(device_type=q.device.type, dtype=torch.float32):
-                        pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
+                        pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                 else:
-                    pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
+                    pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                 if attn_mask is not None:
                     pre_w = pre_w + attn_mask
                 w = torch.softmax(pre_w, dim=-1)
                 w = F.dropout(w, self.dropout, training=self.training).to(v)
-                x = torch.einsum("bhqk,bkhc->bqhc", w, v)
+                # Key and value have the same format.
+                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
                 x = x.to(dtype)
-                x = rearrange(x, "b h t d -> b t (h d)", h=self.num_heads)
+                x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
                 x = self.out_proj(x)
             else:
                 key, value = self._complete_kv(key, value)
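The new `set_efficient_attention_backend` switch above selects between PyTorch's `scaled_dot_product_attention` (time dimension 2, `b h t d` layout) and xformers' `memory_efficient_attention` (time dimension 1, `b t h d` layout). A minimal sketch of flipping the backend before building a transformer; note that, as asserted in `_apply_rope`, RoPE currently requires the `xformers` backend:

```python
import torch
from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend

# 'torch' is the new default backend; 'xformers' remains available (and required for RoPE).
set_efficient_attention_backend('xformers')

tr = StreamingTransformer(16, 4, 2, dropout=0., memory_efficient=True, causal=True)
tr.eval()
with torch.no_grad():
    y = tr(torch.randn(3, 12, 16))  # [batch, time, dim], mirroring the unit tests below
print(y.shape)
```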
tests/models/test_musicgen.py CHANGED
@@ -13,7 +13,7 @@ from audiocraft.models import MusicGen
 class TestSEANetModel:
     def get_musicgen(self):
         mg = MusicGen.get_pretrained(name='debug', device='cpu')
-        mg.set_generation_params(duration=2.0)
+        mg.set_generation_params(duration=2.0, extend_stride=2.)
         return mg
 
     def test_base(self):
@@ -48,3 +48,11 @@ class TestSEANetModel:
         wav = mg.generate(
             ['youpi', 'lapin dort'])
         assert list(wav.shape) == [2, 1, 64000]
+
+    def test_generate_long(self):
+        mg = self.get_musicgen()
+        mg.max_duration = 3.
+        mg.set_generation_params(duration=4., extend_stride=2.)
+        wav = mg.generate(
+            ['youpi', 'lapin dort'])
+        assert list(wav.shape) == [2, 1, 32000 * 4]
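On the expected shape in `test_generate_long` above: the debug model runs at 32 kHz (2 seconds yield 64000 samples in the existing tests), and with `max_duration=3.` and `extend_stride=2.` a 4-second request needs one full 3-second pass plus one extension pass. A sketch of that bookkeeping, mirroring the extension loop in `musicgen.py` (the values are the test's; the token rate is an assumed placeholder and the loop is a simplification):

```python
frame_rate = 50  # assumed token rate for the sketch; the exact value does not change the pass count

total_gen_len = int(4.0 * frame_rate)   # duration=4.
max_len = int(3.0 * frame_rate)         # max_duration=3.
stride = int(2.0 * frame_rate)          # extend_stride=2.

offset = kept_prompt = passes = new_tokens = 0
while offset + kept_prompt < total_gen_len:
    chunk = min(total_gen_len - offset, max_len)  # tokens produced this pass (including the prompt)
    new_tokens += chunk - kept_prompt
    kept_prompt = chunk - stride                  # overlap carried over as the next prompt
    offset += stride
    passes += 1

assert passes == 2 and new_tokens == total_gen_len  # two LM passes, 4 s of audio -> 32000 * 4 samples
```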
tests/modules/test_rope.py CHANGED
@@ -7,10 +7,11 @@
 import torch
 
 from audiocraft.modules.rope import RotaryEmbedding
-from audiocraft.modules.transformer import StreamingTransformer
+from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend
 
 
 def test_rope():
+    set_efficient_attention_backend('xformers')
     B, T, H, C = 8, 75, 16, 128
 
     rope = RotaryEmbedding(dim=C)
@@ -23,6 +24,7 @@ def test_rope():
 
 
 def test_rope_io_dtypes():
+    set_efficient_attention_backend('xformers')
     B, T, H, C = 8, 75, 16, 128
 
     rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
@@ -46,6 +48,7 @@ def test_rope_io_dtypes():
 
 
 def test_transformer_with_rope():
+    set_efficient_attention_backend('xformers')
     torch.manual_seed(1234)
     for pos in ['rope', 'sin_rope']:
         tr = StreamingTransformer(
@@ -61,6 +64,7 @@ def test_transformer_with_rope():
 
 @torch.no_grad()
 def test_rope_streaming():
+    set_efficient_attention_backend('xformers')
     torch.manual_seed(1234)
     tr = StreamingTransformer(
         16, 4, 2, causal=True, dropout=0.,
@@ -88,6 +92,7 @@ def test_rope_streaming():
 
 @torch.no_grad()
 def test_rope_streaming_past_context():
+    set_efficient_attention_backend('xformers')
     torch.manual_seed(1234)
 
     for context in [None, 10]:
@@ -117,6 +122,7 @@ def test_rope_streaming_past_context():
 
 
 def test_rope_memory_efficient():
+    set_efficient_attention_backend('xformers')
     torch.manual_seed(1234)
     tr = StreamingTransformer(
         16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
@@ -137,6 +143,7 @@ def test_rope_memory_efficient():
 
 
 def test_rope_with_xpos():
+    set_efficient_attention_backend('xformers')
     B, T, H, C = 8, 75, 16, 128
 
     rope = RotaryEmbedding(dim=C, xpos=True)
@@ -149,6 +156,7 @@ def test_rope_with_xpos():
 
 
 def test_positional_scale():
+    set_efficient_attention_backend('xformers')
     B, T, H, C = 8, 75, 16, 128
 
     rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
tests/modules/test_transformer.py CHANGED
@@ -9,7 +9,8 @@ from itertools import product
 import pytest
 import torch
 
-from audiocraft.modules.transformer import StreamingMultiheadAttention, StreamingTransformer
+from audiocraft.modules.transformer import (
+    StreamingMultiheadAttention, StreamingTransformer, set_efficient_attention_backend)
 
 
 def test_transformer_causal_streaming():
@@ -86,19 +87,22 @@ def test_streaming_api():
 
 def test_memory_efficient():
     torch.manual_seed(1234)
-    tr = StreamingTransformer(
-        16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
-    tr_mem_efficient = StreamingTransformer(
-        16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
-    tr_mem_efficient.load_state_dict(tr.state_dict())
-    tr.eval()
-    steps = 12
-    x = torch.randn(3, steps, 16)
-
-    with torch.no_grad():
-        y = tr(x)
-        y2 = tr_mem_efficient(x)
-        assert torch.allclose(y, y2), (y - y2).norm()
+    for backend in ['torch', 'xformers']:
+        set_efficient_attention_backend(backend)
+        tr = StreamingTransformer(
+            16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
+        tr_mem_efficient = StreamingTransformer(
+            16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
+        tr_mem_efficient.load_state_dict(tr.state_dict())
+        tr.eval()
+        steps = 12
+        x = torch.randn(3, steps, 16)
+
+        with torch.no_grad():
+            y = tr(x)
+            y2 = tr_mem_efficient(x)
+            assert torch.allclose(y, y2), ((y - y2).norm(), backend)
 
 
 def test_attention_as_float32():
@@ -129,30 +133,32 @@ def test_attention_as_float32():
 @torch.no_grad()
 def test_streaming_memory_efficient():
     torch.manual_seed(1234)
-    tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
-    tr_mem_efficient = StreamingTransformer(
-        16, 4, 2, dropout=0., memory_efficient=True, causal=True)
-    tr.load_state_dict(tr_mem_efficient.state_dict())
-    tr.eval()
-    tr_mem_efficient.eval()
-    steps = 12
-    x = torch.randn(3, steps, 16)
-
-    ref = tr(x)
-
-    with tr_mem_efficient.streaming():
-        outs = []
-        # frame_sizes = [2] + [1] * (steps - 2)
-        frame_sizes = [1] * steps
-
-        for frame_size in frame_sizes:
-            frame = x[:, :frame_size]
-            x = x[:, frame_size:]
-            outs.append(tr_mem_efficient(frame))
-
-        out = torch.cat(outs, dim=1)
-        delta = torch.norm(out - ref) / torch.norm(out)
-        assert delta < 1e-6, delta
+    for backend in ['torch', 'xformers']:
+        set_efficient_attention_backend(backend)
+        tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
+        tr_mem_efficient = StreamingTransformer(
+            16, 4, 2, dropout=0., memory_efficient=True, causal=True)
+        tr.load_state_dict(tr_mem_efficient.state_dict())
+        tr.eval()
+        tr_mem_efficient.eval()
+        steps = 12
+        x = torch.randn(3, steps, 16)
+
+        ref = tr(x)
+
+        with tr_mem_efficient.streaming():
+            outs = []
+            # frame_sizes = [2] + [1] * (steps - 2)
+            frame_sizes = [1] * steps
+
+            for frame_size in frame_sizes:
+                frame = x[:, :frame_size]
+                x = x[:, frame_size:]
+                outs.append(tr_mem_efficient(frame))
+
+            out = torch.cat(outs, dim=1)
+            delta = torch.norm(out - ref) / torch.norm(out)
+            assert delta < 1e-6, delta
 
 
 def test_cross_attention():
@@ -204,7 +210,7 @@ def test_cross_attention_compat():
 
     y = cross_attn(queries, keys, values)[0]
     y_ref = ref_attn(queries, keys, values)[0]
-    assert torch.allclose(y, y_ref, atol=1e-7)
+    assert torch.allclose(y, y_ref, atol=1e-7), (y - y_ref).norm() / y_ref.norm()
 
     # Now let's check that streaming is working properly.
     with cross_attn.streaming():
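To exercise this merge locally, the touched test modules can be run with pytest once the package is installed as described in the README (`pip install -e .`); xformers must be available for the RoPE tests. A sketch using `pytest.main`:

```python
import pytest

# Run the test modules updated in this commit; add -k to narrow to specific cases.
pytest.main([
    "-q",
    "tests/models/test_musicgen.py",
    "tests/modules/test_rope.py",
    "tests/modules/test_transformer.py",
])
```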