adefossez committed
Commit a31f19b
2 Parent(s): eaf8326 c16da55

Merge branch 'longgen' into our_hf2

CHANGELOG.md CHANGED
@@ -13,6 +13,8 @@ Now repeating the conditioning periodically if it is too short.
 
 More options when launching Gradio app locally (thanks @ashleykleynhans).
 
+Testing out PyTorch 2.0 memory efficient attention.
+
 ## [0.0.1] - 2023-06-09
 
 Initial release, with model evaluation only.
app.py CHANGED
@@ -15,7 +15,7 @@ from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write
 
 MODEL = None
-IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ['SPACE_ID']
+IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
 
 
 def load_model(version):
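The app.py fix swaps a hard dictionary lookup for `os.environ.get`, so the check no longer raises `KeyError` when the script runs outside a Hugging Face Space, where `SPACE_ID` is typically undefined. A minimal sketch of the difference (the environment manipulation below is illustrative only):

```python
import os

# Illustrative setup: outside a Hugging Face Space, SPACE_ID is usually unset.
os.environ.pop('SPACE_ID', None)

# Old behaviour: indexing a missing variable raises KeyError.
try:
    flag = "musicgen/MusicGen" in os.environ['SPACE_ID']
except KeyError:
    flag = None
print(flag)  # None

# New behaviour: fall back to an empty string, so the check is simply False.
print("musicgen/MusicGen" in os.environ.get('SPACE_ID', ''))  # False
```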
app_batched.py CHANGED
@@ -6,7 +6,12 @@ This source code is licensed under the license found in the
 LICENSE file in the root directory of this source tree.
 """
 
+import argparse
+from concurrent.futures import ProcessPoolExecutor
+import subprocess as sp
 from tempfile import NamedTemporaryFile
+import time
+import warnings
 import torch
 import gradio as gr
 from audiocraft.data.audio_utils import convert_audio
@@ -16,6 +21,29 @@ from audiocraft.models import MusicGen
 
 MODEL = None
 
+_old_call = sp.call
+
+
+def _call_nostderr(*args, **kwargs):
+    # Avoid ffmpeg vomitting on the logs.
+    kwargs['stderr'] = sp.DEVNULL
+    kwargs['stdout'] = sp.DEVNULL
+    _old_call(*args, **kwargs)
+
+
+sp.call = _call_nostderr
+pool = ProcessPoolExecutor(3)
+pool.__enter__()
+
+
+def make_waveform(*args, **kwargs):
+    be = time.time()
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        out = gr.make_waveform(*args, **kwargs)
+        print("Make a video took", time.time() - be)
+        return out
+
 
 def load_model():
     print("Loading model")
@@ -28,11 +56,13 @@ def predict(texts, melodies):
         MODEL = load_model()
 
     duration = 12
+    max_text_length = 512
+    texts = [text[:max_text_length] for text in texts]
     MODEL.set_generation_params(duration=duration)
 
-    print(texts, melodies)
+    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
+    be = time.time()
     processed_melodies = []
-
     target_sr = 32000
     target_ac = 1
    for melody in melodies:
@@ -60,73 +90,133 @@ def predict(texts, melodies):
             audio_write(
                 file.name, output, MODEL.sample_rate, strategy="loudness",
                 loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
-            waveform_video = gr.make_waveform(file.name)
-            out_files.append(waveform_video)
-    return [out_files]
-
-
-with gr.Blocks() as demo:
-    gr.Markdown(
-        """
-        # MusicGen
-
-        This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
-        presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
-        <br/>
-        <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
-        <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-        for longer sequences, more control and no queue.</p>
-        """
-    )
-    with gr.Row():
-        with gr.Column():
-            with gr.Row():
-                text = gr.Text(label="Describe your music", lines=2, interactive=True)
-                melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
-            with gr.Row():
-                submit = gr.Button("Generate")
-        with gr.Column():
-            output = gr.Video(label="Generated Music")
-    submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12)
-    gr.Examples(
-        fn=predict,
-        examples=[
-            [
-                "An 80s driving pop song with heavy drums and synth pads in the background",
-                "./assets/bach.mp3",
-            ],
-            [
-                "A cheerful country song with acoustic guitars",
-                "./assets/bolero_ravel.mp3",
-            ],
-            [
-                "90s rock song with electric guitar and heavy drums",
-                None,
-            ],
-            [
-                "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
-                "./assets/bach.mp3",
-            ],
-            [
-                "lofi slow bpm electro chill with organic samples",
-                None,
-            ],
-        ],
-        inputs=[text, melody],
-        outputs=[output]
-    )
-    gr.Markdown("""
-    ### More details
-
-    The model will generate 12 seconds of audio based on the description you provided.
-    You can optionaly provide a reference audio from which a broad melody will be extracted.
-    The model will then try to follow both the description and melody provided.
-    All samples are generated with the `melody` model.
-
-    You can also use your own GPU or a Google Colab by following the instructions on our repo.
-
-    See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
-    for more details.
-    """)
-
-demo.queue(max_size=60).launch()
+            out_files.append(pool.submit(make_waveform, file.name))
+    res = [[out_file.result() for out_file in out_files]]
+    print("batch finished", len(texts), time.time() - be)
+    return res
+
+
+def ui(**kwargs):
+    with gr.Blocks() as demo:
+        gr.Markdown(
+            """
+            # MusicGen
+
+            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
+            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
+            <br/>
+            <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+            for longer sequences, more control and no queue.</p>
+            """
+        )
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
+                    melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
+                with gr.Row():
+                    submit = gr.Button("Generate")
+            with gr.Column():
+                output = gr.Video(label="Generated Music")
+        submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=8)
+        gr.Examples(
+            fn=predict,
+            examples=[
+                [
+                    "An 80s driving pop song with heavy drums and synth pads in the background",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "A cheerful country song with acoustic guitars",
+                    "./assets/bolero_ravel.mp3",
+                ],
+                [
+                    "90s rock song with electric guitar and heavy drums",
+                    None,
+                ],
+                [
+                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "lofi slow bpm electro chill with organic samples",
+                    None,
+                ],
+            ],
+            inputs=[text, melody],
+            outputs=[output]
+        )
+        gr.Markdown("""
+        ### More details
+
+        The model will generate 12 seconds of audio based on the description you provided.
+        You can optionaly provide a reference audio from which a broad melody will be extracted.
+        The model will then try to follow both the description and melody provided.
+        All samples are generated with the `melody` model.
+
+        You can also use your own GPU or a Google Colab by following the instructions on our repo.
+
+        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+        for more details.
+        """)
+
+    # Show the interface
+    launch_kwargs = {}
+    username = kwargs.get('username')
+    password = kwargs.get('password')
+    server_port = kwargs.get('server_port', 0)
+    inbrowser = kwargs.get('inbrowser', False)
+    share = kwargs.get('share', False)
+    server_name = kwargs.get('listen')
+
+    launch_kwargs['server_name'] = server_name
+
+    if username and password:
+        launch_kwargs['auth'] = (username, password)
+    if server_port > 0:
+        launch_kwargs['server_port'] = server_port
+    if inbrowser:
+        launch_kwargs['inbrowser'] = inbrowser
+    if share:
+        launch_kwargs['share'] = share
+    demo.queue(max_size=8 * 4).launch(**launch_kwargs)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--listen',
+        type=str,
+        default='127.0.0.1',
+        help='IP to listen on for connections to Gradio',
+    )
+    parser.add_argument(
+        '--username', type=str, default='', help='Username for authentication'
+    )
+    parser.add_argument(
+        '--password', type=str, default='', help='Password for authentication'
+    )
+    parser.add_argument(
+        '--server_port',
+        type=int,
+        default=0,
+        help='Port to run the server listener on',
+    )
+    parser.add_argument(
+        '--inbrowser', action='store_true', help='Open in browser'
+    )
+    parser.add_argument(
+        '--share', action='store_true', help='Share the gradio UI'
+    )
+
+    args = parser.parse_args()
+
+    ui(
+        username=args.username,
+        password=args.password,
+        inbrowser=args.inbrowser,
+        server_port=args.server_port,
+        share=args.share,
+        listen=args.listen
+    )
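The batched app now truncates each prompt to 512 characters and offloads `gr.make_waveform` video rendering to a three-worker `ProcessPoolExecutor`, submitting all jobs for the batch first and collecting the futures afterwards (with `subprocess.call` patched to silence ffmpeg). A minimal sketch of that submit-then-collect pattern, using a hypothetical `render` stand-in rather than the real Gradio call:

```python
from concurrent.futures import ProcessPoolExecutor
import time


def render(name: str) -> str:
    # Stand-in for gr.make_waveform: pretend to encode a waveform video.
    time.sleep(0.1)
    return f"{name}.mp4"


if __name__ == "__main__":
    with ProcessPoolExecutor(3) as pool:
        # Queue every rendering job before waiting on any of them.
        futures = [pool.submit(render, f"clip{i}") for i in range(4)]
        videos = [f.result() for f in futures]  # blocks until all jobs finish
        print(videos)
```

Rendering in worker processes keeps the GPU-bound `predict` call from waiting on video encoding one clip at a time.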
audiocraft/models/musicgen.py CHANGED
@@ -96,7 +96,7 @@ class MusicGen:
     def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                               top_p: float = 0.0, temperature: float = 1.0,
                               duration: float = 30.0, cfg_coef: float = 3.0,
-                              two_step_cfg: bool = False):
+                              two_step_cfg: bool = False, extend_stride: float = 15):
         """Set the generation parameters for MusicGen.
 
         Args:
@@ -109,8 +109,13 @@ class MusicGen:
             two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
                 instead of batching together the two. This has some impact on how things
                 are padded but seems to have little impact in practice.
+            extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
+                should we extend the audio each time. Larger values will mean less context is
+                preserved, and shorter value will require extra computations.
         """
-        assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
+        # assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
+        assert extend_stride <= 25, "Keep at least 5 seconds of overlap!"
+        self.extend_stride = extend_stride
        self.generation_params = {
             'max_gen_len': int(duration * self.frame_rate),
             'use_sampling': use_sampling,
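The new `extend_stride` parameter governs extended generation: once the requested duration exceeds the 30-second window, generation proceeds in steps that each advance by `extend_stride` seconds while the remaining audio in the window is kept as context, which is why the assert insists on at least 5 seconds of overlap. A back-of-the-envelope sketch of that trade-off, assuming a fixed 30-second window (this is not the AudioCraft implementation, just the arithmetic it implies):

```python
import math


def num_windows(duration: float, extend_stride: float, window: float = 30.0) -> int:
    # Count how many generation windows are needed for `duration` seconds when
    # each extension step advances by `extend_stride` seconds of new audio.
    assert extend_stride <= window - 5, "Keep at least 5 seconds of overlap!"
    if duration <= window:
        return 1
    return 1 + math.ceil((duration - window) / extend_stride)


print(num_windows(60, extend_stride=15))  # 3: one full window, then two 15 s extensions
print(num_windows(60, extend_stride=10))  # 4: more context preserved, more compute spent
```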
audiocraft/modules/transformer.py CHANGED
@@ -247,20 +247,20 @@ class StreamingMultiheadAttention(StreamingModule):
         # Complete the key/value pair using the streaming state.
         if self._streaming_state:
             pk = self._streaming_state['past_keys']
-            nk = torch.cat([pk, k], dim=1)
+            nk = torch.cat([pk, k], dim=2)
             if v is k:
                 nv = nk
             else:
                 pv = self._streaming_state['past_values']
-                nv = torch.cat([pv, v], dim=1)
+                nv = torch.cat([pv, v], dim=2)
         else:
             nk = k
             nv = v
 
-        assert nk.shape[1] == nv.shape[1]
+        assert nk.shape[2] == nv.shape[2]
         offset = 0
         if self.past_context is not None:
-            offset = max(0, nk.shape[1] - self.past_context)
+            offset = max(0, nk.shape[2] - self.past_context)
         if self._is_streaming:
             self._streaming_state['past_keys'] = nk[:, offset:]
             if v is not k:
@@ -271,6 +271,7 @@ class StreamingMultiheadAttention(StreamingModule):
             self._streaming_state['offset'] = torch.tensor(0)
         return nk, nv
 
+
     def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
         # Apply rope embeddings to query and key tensors.
         assert self.rope is not None
@@ -325,7 +326,7 @@ class StreamingMultiheadAttention(StreamingModule):
                     q = self.q_layer_norm(q)
                     k = self.k_layer_norm(k)
                 # q, k, v = [rearrange(x, "b t (h d) -> (b h) t d", h=self.num_heads) for x in [q, k, v]]
-                q, k, v = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k, v]]
+                q, k, v = [rearrange(x, "b t (h d) -> b h t d", h=self.num_heads) for x in [q, k, v]]
             else:
                 if not _is_profiled():
                     # profiling breaks that propertysomehow.
@@ -333,7 +334,7 @@ class StreamingMultiheadAttention(StreamingModule):
                     assert value is key, "specialized implementation"
                 projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                 if self.kv_repeat == 1:
-                    packed = rearrange(projected, "b t (p h d) -> b t p h d", p=3, h=self.num_heads)
+                    packed = rearrange(projected, "b t (p h d) -> b h p t d", p=3, h=self.num_heads)
                     q, k, v = ops.unbind(packed, dim=2)
                 else:
                     embed_dim = self.embed_dim
@@ -355,6 +356,7 @@ class StreamingMultiheadAttention(StreamingModule):
                     k = self.k_layer_norm(k)
                     q, k = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k]]
                 if self.rope:
+                    assert False, "Not supported for now"
                     q, k = self._apply_rope(q, k)
                 k, v = self._complete_kv(k, v)
                 if self.kv_repeat > 1:
@@ -364,7 +366,8 @@ class StreamingMultiheadAttention(StreamingModule):
                 q, k, v = [x.float() for x in [q, k, v]]
             if self.memory_efficient:
                 p = self.dropout if self.training else 0
-                x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
+                x = torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, is_causal=attn_mask is not None, dropout_p=p)
             else:
                 # We include the dot product as float32, for consistency
                 # with the other implementations that include that step
@@ -385,7 +388,7 @@ class StreamingMultiheadAttention(StreamingModule):
                 w = F.dropout(w, self.dropout, training=self.training).to(v)
                 x = torch.einsum("bhqk,bkhc->bqhc", w, v)
             x = x.to(dtype)
-            x = rearrange(x, "b t h d -> b t (h d)", h=self.num_heads)
+            x = rearrange(x, "b h t d -> b t (h d)", h=self.num_heads)
             x = self.out_proj(x)
         else:
             key, value = self._complete_kv(key, value)
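The attention changes all follow from one substitution: xformers' `ops.memory_efficient_attention` consumes tensors shaped `[batch, time, heads, dim]`, while PyTorch 2.0's `torch.nn.functional.scaled_dot_product_attention` expects `[batch, heads, time, dim]`. That is why the `rearrange` patterns flip to `b h t d`, the streaming key/value cache now concatenates and slices along `dim=2`, and the output is folded back with `b h t d -> b t (h d)`. A self-contained sketch of the new layout with toy shapes (not the streaming module itself):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

B, T, H, D = 2, 8, 4, 16           # toy batch, time, heads, head dim
x = torch.randn(B, T, H * D)

# PyTorch 2.0 SDPA wants [batch, heads, time, dim].
q = k = v = rearrange(x, "b t (h d) -> b h t d", h=H)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=0.0)

# Fold heads back before the output projection, as in the diff above.
out = rearrange(out, "b h t d -> b t (h d)")
print(out.shape)  # torch.Size([2, 8, 64])
```

With the key/value cache stored in this layout, the time axis sits at index 2, which matches the `dim=2` concatenation and `shape[2]` assertions in `_complete_kv`.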