adefossez committed on
Commit
6ec60d5
1 Parent(s): 6a458f2

support both torch and xformers + merge apps

app.py CHANGED
@@ -1,70 +1,125 @@
-"""
-Copyright (c) Meta Platforms, Inc. and affiliates.
-All rights reserved.
-
-This source code is licensed under the license found in the
-LICENSE file in the root directory of this source tree.
-"""
-
-from tempfile import NamedTemporaryFile
 import argparse
 import torch
 import gradio as gr
-import os
-from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write


-MODEL = None
-IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')


-def load_model(version):
-    print("Loading model", version)
-    return MusicGen.get_pretrained(version)


-def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
     global MODEL
-    topk = int(topk)
-    if MODEL is None or MODEL.name != model:
-        MODEL = load_model(model)
-
-    if duration > MODEL.lm.cfg.dataset.segment_duration:
-        raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
-    MODEL.set_generation_params(
-        use_sampling=True,
-        top_k=topk,
-        top_p=topp,
-        temperature=temperature,
-        cfg_coef=cfg_coef,
-        duration=duration,
-    )
-
-    if melody:
-        sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
-        print(melody.shape)
-        if melody.dim() == 2:
-            melody = melody[None]
-        melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
-        output = MODEL.generate_with_chroma(
-            descriptions=[text],
-            melody_wavs=melody,
-            melody_sample_rate=sr,
             progress=False
         )
     else:
-        output = MODEL.generate(descriptions=[text], progress=False)
-
-    output = output.detach().cpu().float()[0]
-    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
-        audio_write(
-            file.name, output, MODEL.sample_rate, strategy="loudness",
-            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
-        waveform_video = gr.make_waveform(file.name)
-    return waveform_video


-def ui(**kwargs):
     with gr.Blocks() as interface:
         gr.Markdown(
             """
@@ -73,14 +128,6 @@ def ui(**kwargs):
             presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
             """
         )
-        if IS_SHARED_SPACE:
-            gr.Markdown("""
-                ⚠ This Space doesn't work in this shared UI ⚠
-
-                <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
-                <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-                to use it privately, or use the <a href="https://huggingface.co/spaces/facebook/MusicGen">public demo</a>
-                """)
         with gr.Row():
             with gr.Column():
                 with gr.Row():
@@ -99,9 +146,9 @@ def ui(**kwargs):
                     cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
             with gr.Column():
                 output = gr.Video(label="Generated Music")
-        submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
         gr.Examples(
-            fn=predict,
             examples=[
                 [
                     "An 80s driving pop song with heavy drums and synth pads in the background",
@@ -154,35 +201,83 @@ def ui(**kwargs):
             """
         )

-    # Show the interface
-    launch_kwargs = {}
-    username = kwargs.get('username')
-    password = kwargs.get('password')
-    server_port = kwargs.get('server_port', 0)
-    inbrowser = kwargs.get('inbrowser', False)
-    share = kwargs.get('share', False)
-    server_name = kwargs.get('listen')
-
-    launch_kwargs['server_name'] = server_name
-
-    if username and password:
-        launch_kwargs['auth'] = (username, password)
-    if server_port > 0:
-        launch_kwargs['server_port'] = server_port
-    if inbrowser:
-        launch_kwargs['inbrowser'] = inbrowser
-    if share:
-        launch_kwargs['share'] = share
-
     interface.queue().launch(**launch_kwargs, max_threads=1)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
         '--listen',
         type=str,
-        default='0.0.0.0',
         help='IP to listen on for connections to Gradio',
     )
     parser.add_argument(
@@ -206,11 +301,18 @@ if __name__ == "__main__":

     args = parser.parse_args()

-    ui(
-        username=args.username,
-        password=args.password,
-        inbrowser=args.inbrowser,
-        server_port=args.server_port,
-        share=args.share,
-        listen=args.listen
-    )
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
+from concurrent.futures import ProcessPoolExecutor
+import os
+import subprocess as sp
+from tempfile import NamedTemporaryFile
+import time
+import warnings
+
 import torch
 import gradio as gr
+
+from audiocraft.data.audio_utils import convert_audio
 from audiocraft.data.audio import audio_write
+from audiocraft.models import MusicGen


+MODEL = None  # Last used model
+IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
+MAX_BATCH_SIZE = 12
+BATCHED_DURATION = 15
+# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
+_old_call = sp.call
+
+
+def _call_nostderr(*args, **kwargs):
+    # Avoid ffmpeg vomitting on the logs.
+    kwargs['stderr'] = sp.DEVNULL
+    kwargs['stdout'] = sp.DEVNULL
+    _old_call(*args, **kwargs)
+
+
+sp.call = _call_nostderr
+# Preallocating the pool of processes.
+pool = ProcessPoolExecutor(3)
+pool.__enter__()
+
+
+def make_waveform(*args, **kwargs):
+    # Further remove some warnings.
+    be = time.time()
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        out = gr.make_waveform(*args, **kwargs)
+        print("Make a video took", time.time() - be)
+        return out
+
+
+def load_model(version='melody'):
     global MODEL
+    print("Loading model", version)
+    if MODEL is None or MODEL.name != version:
+        MODEL = MusicGen.get_pretrained(version)
+
+
+def _do_predictions(texts, melodies, duration, **gen_kwargs):
+    MODEL.set_generation_params(duration=duration, **gen_kwargs)
+    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
+    be = time.time()
+    processed_melodies = []
+    target_sr = 32000
+    target_ac = 1
+    for melody in melodies:
+        if melody is None:
+            processed_melodies.append(None)
+        else:
+            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
+            if melody.dim() == 1:
+                melody = melody[None]
+            melody = melody[..., :int(sr * duration)]
+            melody = convert_audio(melody, sr, target_sr, target_ac)
+            processed_melodies.append(melody)
+
+    if processed_melodies.any():
+        outputs = MODEL.generate_with_chroma(
+            descriptions=texts,
+            melody_wavs=processed_melodies,
+            melody_sample_rate=target_sr,
             progress=False
         )
     else:
+        outputs = MODEL.generate(texts, progress=False)
+
+    outputs = outputs.detach().cpu().float()
+    out_files = []
+    for output in outputs:
+        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+            audio_write(
+                file.name, output, MODEL.sample_rate, strategy="loudness",
+                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+            out_files.append(pool.submit(make_waveform, file.name))
+    res = [out_file.result() for out_file in out_files]
+    print("batch finished", len(texts), time.time() - be)
+    return res
+
+
+def predict_batched(texts, melodies):
+    max_text_length = 512
+    texts = [text[:max_text_length] for text in texts]
+    load_model('melody')
+    res = _do_predictions(texts, melodies, BATCHED_DURATION)
+    return [res]
+
+
+def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef):
+    topk = int(topk)
+    load_model(model)
+    if duration > MODEL.lm.cfg.dataset.segment_duration:
+        raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
+
+    outs = _do_predictions(
+        [text], [melody], duration,
+        topk=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
+    return outs[0]


+def ui_full(launch_kwargs):
     with gr.Blocks() as interface:
         gr.Markdown(
             """
             presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
             """
         )
         with gr.Row():
             with gr.Column():
                 with gr.Row():
                     cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
             with gr.Column():
                 output = gr.Video(label="Generated Music")
+        submit.click(predict_full, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
         gr.Examples(
+            fn=predict_full,
             examples=[
                 [
                     "An 80s driving pop song with heavy drums and synth pads in the background",
             """
         )

     interface.queue().launch(**launch_kwargs, max_threads=1)


+def ui_batched(launch_kwargs):
+    with gr.Blocks() as demo:
+        gr.Markdown(
+            """
+            # MusicGen
+
+            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
+            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
+            <br/>
+            <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+            for longer sequences, more control and no queue.</p>
+            """
+        )
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
+                    melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
+                with gr.Row():
+                    submit = gr.Button("Generate")
+            with gr.Column():
+                output = gr.Video(label="Generated Music")
+        submit.click(predict_batched, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE)
+        gr.Examples(
+            fn=predict_batched,
+            examples=[
+                [
+                    "An 80s driving pop song with heavy drums and synth pads in the background",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "A cheerful country song with acoustic guitars",
+                    "./assets/bolero_ravel.mp3",
+                ],
+                [
+                    "90s rock song with electric guitar and heavy drums",
+                    None,
+                ],
+                [
+                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "lofi slow bpm electro chill with organic samples",
+                    None,
+                ],
+            ],
+            inputs=[text, melody],
+            outputs=[output]
+        )
+        gr.Markdown("""
+        ### More details
+
+        The model will generate 12 seconds of audio based on the description you provided.
+        You can optionaly provide a reference audio from which a broad melody will be extracted.
+        The model will then try to follow both the description and melody provided.
+        All samples are generated with the `melody` model.
+
+        You can also use your own GPU or a Google Colab by following the instructions on our repo.
+
+        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+        for more details.
+        """)
+
+    demo.queue(max_size=8 * 4).launch(**launch_kwargs)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
         '--listen',
         type=str,
+        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
         help='IP to listen on for connections to Gradio',
     )
     parser.add_argument(

     args = parser.parse_args()

+    launch_kwargs = {}
+    if args.username and args.password:
+        launch_kwargs['auth'] = (args.username, args.password)
+    if args.server_port:
+        launch_kwargs['server_port'] = args.server_port
+    if args.inbrowser:
+        launch_kwargs['inbrowser'] = args.inbrowser
+    if args.share:
+        launch_kwargs['share'] = args.share
+
+    # Show the interface
+    if IS_BATCHED:
+        ui_batched(launch_kwargs)
+    else:
+        ui_full(launch_kwargs)
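For orientation, here is a minimal sketch of how the merged entry point behaves. `ui_full`, `ui_batched` and `IS_BATCHED` are the names defined in the new app.py above; the import path and the launch option values are illustrative assumptions, not part of the commit.

# Sketch only (assumptions: the merged app.py is importable as `app`; host/port are examples).
import os
os.environ.setdefault('SPACE_ID', '')  # IS_BATCHED is derived from SPACE_ID at import time
from app import IS_BATCHED, ui_full, ui_batched

launch_kwargs = {'server_name': '127.0.0.1', 'server_port': 7860}
if IS_BATCHED:                      # True only when SPACE_ID contains "facebook/MusicGen"
    ui_batched(launch_kwargs)       # 15 s batched demo, queue(max_size=8 * 4)
else:
    ui_full(launch_kwargs)          # full UI with model choice and sampling parameters

Launching this way mirrors what the new `__main__` block does after parsing the command-line flags.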
app_batched.py DELETED
@@ -1,222 +0,0 @@
-"""
-Copyright (c) Meta Platforms, Inc. and affiliates.
-All rights reserved.
-
-This source code is licensed under the license found in the
-LICENSE file in the root directory of this source tree.
-"""
-
-import argparse
-from concurrent.futures import ProcessPoolExecutor
-import subprocess as sp
-from tempfile import NamedTemporaryFile
-import time
-import warnings
-import torch
-import gradio as gr
-from audiocraft.data.audio_utils import convert_audio
-from audiocraft.data.audio import audio_write
-from audiocraft.models import MusicGen
-
-
-MODEL = None
-
-_old_call = sp.call
-
-
-def _call_nostderr(*args, **kwargs):
-    # Avoid ffmpeg vomitting on the logs.
-    kwargs['stderr'] = sp.DEVNULL
-    kwargs['stdout'] = sp.DEVNULL
-    _old_call(*args, **kwargs)
-
-
-sp.call = _call_nostderr
-pool = ProcessPoolExecutor(3)
-pool.__enter__()
-
-
-def make_waveform(*args, **kwargs):
-    be = time.time()
-    with warnings.catch_warnings():
-        warnings.simplefilter('ignore')
-        out = gr.make_waveform(*args, **kwargs)
-        print("Make a video took", time.time() - be)
-        return out
-
-
-def load_model():
-    print("Loading model")
-    return MusicGen.get_pretrained("melody")
-
-
-def predict(texts, melodies):
-    global MODEL
-    if MODEL is None:
-        MODEL = load_model()
-
-    duration = 12
-    max_text_length = 512
-    texts = [text[:max_text_length] for text in texts]
-    MODEL.set_generation_params(duration=duration)
-
-    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
-    be = time.time()
-    processed_melodies = []
-    target_sr = 32000
-    target_ac = 1
-    for melody in melodies:
-        if melody is None:
-            processed_melodies.append(None)
-        else:
-            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
-            if melody.dim() == 1:
-                melody = melody[None]
-            melody = melody[..., :int(sr * duration)]
-            melody = convert_audio(melody, sr, target_sr, target_ac)
-            processed_melodies.append(melody)
-
-    outputs = MODEL.generate_with_chroma(
-        descriptions=texts,
-        melody_wavs=processed_melodies,
-        melody_sample_rate=target_sr,
-        progress=False
-    )
-
-    outputs = outputs.detach().cpu().float()
-    out_files = []
-    for output in outputs:
-        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
-            audio_write(
-                file.name, output, MODEL.sample_rate, strategy="loudness",
-                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
-            out_files.append(pool.submit(make_waveform, file.name))
-    res = [[out_file.result() for out_file in out_files]]
-    print("batch finished", len(texts), time.time() - be)
-    return res
-
-
-def ui(**kwargs):
-    with gr.Blocks() as demo:
-        gr.Markdown(
-            """
-            # MusicGen
-
-            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
-            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
-            <br/>
-            <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
-            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-            for longer sequences, more control and no queue.</p>
-            """
-        )
-        with gr.Row():
-            with gr.Column():
-                with gr.Row():
-                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
-                    melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
-                with gr.Row():
-                    submit = gr.Button("Generate")
-            with gr.Column():
-                output = gr.Video(label="Generated Music")
-        submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=8)
-        gr.Examples(
-            fn=predict,
-            examples=[
-                [
-                    "An 80s driving pop song with heavy drums and synth pads in the background",
-                    "./assets/bach.mp3",
-                ],
-                [
-                    "A cheerful country song with acoustic guitars",
-                    "./assets/bolero_ravel.mp3",
-                ],
-                [
-                    "90s rock song with electric guitar and heavy drums",
-                    None,
-                ],
-                [
-                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
-                    "./assets/bach.mp3",
-                ],
-                [
-                    "lofi slow bpm electro chill with organic samples",
-                    None,
-                ],
-            ],
-            inputs=[text, melody],
-            outputs=[output]
-        )
-        gr.Markdown("""
-        ### More details
-
-        The model will generate 12 seconds of audio based on the description you provided.
-        You can optionaly provide a reference audio from which a broad melody will be extracted.
-        The model will then try to follow both the description and melody provided.
-        All samples are generated with the `melody` model.
-
-        You can also use your own GPU or a Google Colab by following the instructions on our repo.
-
-        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
-        for more details.
-        """)
-
-    # Show the interface
-    launch_kwargs = {}
-    username = kwargs.get('username')
-    password = kwargs.get('password')
-    server_port = kwargs.get('server_port', 0)
-    inbrowser = kwargs.get('inbrowser', False)
-    share = kwargs.get('share', False)
-    server_name = kwargs.get('listen')
-
-    launch_kwargs['server_name'] = server_name
-
-    if username and password:
-        launch_kwargs['auth'] = (username, password)
-    if server_port > 0:
-        launch_kwargs['server_port'] = server_port
-    if inbrowser:
-        launch_kwargs['inbrowser'] = inbrowser
-    if share:
-        launch_kwargs['share'] = share
-    demo.queue(max_size=8 * 4).launch(**launch_kwargs)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--listen',
-        type=str,
-        default='0.0.0.0',
-        help='IP to listen on for connections to Gradio',
-    )
-    parser.add_argument(
-        '--username', type=str, default='', help='Username for authentication'
-    )
-    parser.add_argument(
-        '--password', type=str, default='', help='Password for authentication'
-    )
-    parser.add_argument(
-        '--server_port',
-        type=int,
-        default=0,
-        help='Port to run the server listener on',
-    )
-    parser.add_argument(
-        '--inbrowser', action='store_true', help='Open in browser'
-    )
-    parser.add_argument(
-        '--share', action='store_true', help='Share the gradio UI'
-    )
-
-    args = parser.parse_args()
-
-    ui(
-        username=args.username,
-        password=args.password,
-        inbrowser=args.inbrowser,
-        server_port=args.server_port,
-        share=args.share,
-        listen=args.listen
-    )
audiocraft/modules/transformer.py CHANGED
@@ -25,6 +25,22 @@ from xformers import ops
 from .rope import RotaryEmbedding
 from .streaming import StreamingModule


 def _is_profiled() -> bool:
     # Return true if we are currently running with a xformers profiler activated.
@@ -75,14 +91,22 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =

 def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
-    bs, slen, n_kv_heads, head_dim = x.shape
     if n_rep == 1:
         return x
-    return (
-        x[:, :, :, None, :]
-        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
-        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-    )


 class LayerScale(nn.Module):
@@ -210,6 +234,7 @@ class StreamingMultiheadAttention(StreamingModule):
         # Return a causal mask, accounting for potentially stored past keys/values
         # We actually return a bias for the attention score, as this has the same
         # convention both in the builtin MHA in Pytorch, and Xformers functions.
         if self.memory_efficient:
             from xformers.ops import LowerTriangularMask
             if current_steps == 1:
@@ -222,7 +247,7 @@ class StreamingMultiheadAttention(StreamingModule):
                 return LowerTriangularMask()
         if self._streaming_state:
             past_keys = self._streaming_state['past_keys']
-            past_steps = past_keys.shape[1]
         else:
             past_steps = 0

@@ -239,6 +264,7 @@ class StreamingMultiheadAttention(StreamingModule):
             torch.full([], float('-inf'), device=device, dtype=dtype))

     def _complete_kv(self, k, v):
         if self.cross_attention:
             # With cross attention we assume all keys and values
             # are already available, and streaming is with respect
@@ -247,20 +273,20 @@ class StreamingMultiheadAttention(StreamingModule):
         # Complete the key/value pair using the streaming state.
         if self._streaming_state:
             pk = self._streaming_state['past_keys']
-            nk = torch.cat([pk, k], dim=2)
             if v is k:
                 nv = nk
             else:
                 pv = self._streaming_state['past_values']
-                nv = torch.cat([pv, v], dim=2)
         else:
             nk = k
             nv = v

-        assert nk.shape[2] == nv.shape[2]
         offset = 0
         if self.past_context is not None:
-            offset = max(0, nk.shape[2] - self.past_context)
         if self._is_streaming:
             self._streaming_state['past_keys'] = nk[:, offset:]
             if v is not k:
@@ -271,8 +297,9 @@ class StreamingMultiheadAttention(StreamingModule):
             self._streaming_state['offset'] = torch.tensor(0)
         return nk, nv

-
     def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
         # Apply rope embeddings to query and key tensors.
         assert self.rope is not None
         if 'past_keys' in self._streaming_state:
@@ -293,6 +320,11 @@ class StreamingMultiheadAttention(StreamingModule):
         assert not is_causal, ("new param added in torch 2.0.1 not supported, "
                                "use the causal args in the constructor.")

         dtype = query.dtype
         if self._is_streaming:
             assert self.causal or self.cross_attention, \
@@ -325,8 +357,7 @@ class StreamingMultiheadAttention(StreamingModule):
                 if self.qk_layer_norm is True:
                     q = self.q_layer_norm(q)
                     k = self.k_layer_norm(k)
-                # q, k, v = [rearrange(x, "b t (h d) -> (b h) t d", h=self.num_heads) for x in [q, k, v]]
-                q, k, v = [rearrange(x, "b t (h d) -> b h t d", h=self.num_heads) for x in [q, k, v]]
             else:
                 if not _is_profiled():
                     # profiling breaks that propertysomehow.
@@ -334,7 +365,11 @@ class StreamingMultiheadAttention(StreamingModule):
                     assert value is key, "specialized implementation"
                 projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                 if self.kv_repeat == 1:
-                    packed = rearrange(projected, "b t (p h d) -> b h p t d", p=3, h=self.num_heads)
                     q, k, v = ops.unbind(packed, dim=2)
                 else:
                     embed_dim = self.embed_dim
@@ -345,18 +380,17 @@ class StreamingMultiheadAttention(StreamingModule):
                     end = start + per_head_dim * kv_heads
                     k = projected[:, :, start: end]
                     v = projected[:, :, end:]
-                    q = rearrange(q, "b t (h d) -> b t h d", h=self.num_heads)
-                    k = rearrange(k, "b t (h d) -> b t h d", h=kv_heads)
-                    v = rearrange(v, "b t (h d) -> b t h d", h=kv_heads)

                 if self.qk_layer_norm is True:
                     assert self.kv_repeat == 1
-                    q, k = [rearrange(x, "b t h d -> b t (h d)") for x in [q, k]]
                     q = self.q_layer_norm(q)
                     k = self.k_layer_norm(k)
-                    q, k = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k]]
                 if self.rope:
-                    assert False, "Not supported for now"
                     q, k = self._apply_rope(q, k)
                 k, v = self._complete_kv(k, v)
                 if self.kv_repeat > 1:
@@ -366,8 +400,11 @@ class StreamingMultiheadAttention(StreamingModule):
                 q, k, v = [x.float() for x in [q, k, v]]
             if self.memory_efficient:
                 p = self.dropout if self.training else 0
-                x = torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, is_causal=attn_mask is not None, dropout_p=p)
             else:
                 # We include the dot product as float32, for consistency
                 # with the other implementations that include that step
@@ -377,18 +414,21 @@ class StreamingMultiheadAttention(StreamingModule):
                 # extend a bit the range of operations done in float32,
                 # although this should make no difference.
                 q = q / q.shape[-1] ** 0.5
                 if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
                     with torch.autocast(device_type=q.device.type, dtype=torch.float32):
-                        pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
                 else:
-                    pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
                 if attn_mask is not None:
                     pre_w = pre_w + attn_mask
                 w = torch.softmax(pre_w, dim=-1)
                 w = F.dropout(w, self.dropout, training=self.training).to(v)
-                x = torch.einsum("bhqk,bkhc->bqhc", w, v)
                 x = x.to(dtype)
-                x = rearrange(x, "b h t d -> b t (h d)", h=self.num_heads)
                 x = self.out_proj(x)
             else:
                 key, value = self._complete_kv(key, value)
 from .rope import RotaryEmbedding
 from .streaming import StreamingModule

+_efficient_attention_backend: str = 'torch'
+
+
+def set_efficient_attention_backend(backend: str = 'torch'):
+    # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
+    global _efficient_attention_backend
+    assert _efficient_attention_backend in ['xformers', 'torch']
+    _efficient_attention_backend = backend
+
+
+def _get_attention_time_dimension() -> int:
+    if _efficient_attention_backend == 'torch':
+        return 2
+    else:
+        return 1
+

 def _is_profiled() -> bool:
     # Return true if we are currently running with a xformers profiler activated.


 def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
     if n_rep == 1:
         return x
+    if _efficient_attention_backend == 'torch':
+        bs, n_kv_heads, slen, head_dim = x.shape
+        return (
+            x[:, :, None, :, :]
+            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+        )
+    else:
+        bs, slen, n_kv_heads, head_dim = x.shape
+        return (
+            x[:, :, :, None, :]
+            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+        )


 class LayerScale(nn.Module):

         # Return a causal mask, accounting for potentially stored past keys/values
         # We actually return a bias for the attention score, as this has the same
         # convention both in the builtin MHA in Pytorch, and Xformers functions.
+        time_dim = _get_attention_time_dimension()
         if self.memory_efficient:
             from xformers.ops import LowerTriangularMask
             if current_steps == 1:

                 return LowerTriangularMask()
         if self._streaming_state:
             past_keys = self._streaming_state['past_keys']
+            past_steps = past_keys.shape[time_dim]
         else:
             past_steps = 0

             torch.full([], float('-inf'), device=device, dtype=dtype))

     def _complete_kv(self, k, v):
+        time_dim = _get_attention_time_dimension()
         if self.cross_attention:
             # With cross attention we assume all keys and values
             # are already available, and streaming is with respect

         # Complete the key/value pair using the streaming state.
         if self._streaming_state:
             pk = self._streaming_state['past_keys']
+            nk = torch.cat([pk, k], dim=time_dim)
             if v is k:
                 nv = nk
             else:
                 pv = self._streaming_state['past_values']
+                nv = torch.cat([pv, v], dim=time_dim)
         else:
             nk = k
             nv = v

+        assert nk.shape[time_dim] == nv.shape[time_dim]
         offset = 0
         if self.past_context is not None:
+            offset = max(0, nk.shape[time_dim] - self.past_context)
         if self._is_streaming:
             self._streaming_state['past_keys'] = nk[:, offset:]
             if v is not k:

             self._streaming_state['offset'] = torch.tensor(0)
         return nk, nv

     def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
+        # TODO: fix and verify layout.
+        assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
         # Apply rope embeddings to query and key tensors.
         assert self.rope is not None
         if 'past_keys' in self._streaming_state:

         assert not is_causal, ("new param added in torch 2.0.1 not supported, "
                                "use the causal args in the constructor.")

+        time_dim = _get_attention_time_dimension()
+        if time_dim == 2:
+            layout = "b h t d"
+        else:
+            layout = "b t h d"
         dtype = query.dtype
         if self._is_streaming:
             assert self.causal or self.cross_attention, \

                 if self.qk_layer_norm is True:
                     q = self.q_layer_norm(q)
                     k = self.k_layer_norm(k)
+                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
             else:
                 if not _is_profiled():
                     # profiling breaks that propertysomehow.

                     assert value is key, "specialized implementation"
                 projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                 if self.kv_repeat == 1:
+                    if time_dim == 2:
+                        bound_layout = "b h p t d"
+                    else:
+                        bound_layout = "b t p h d"
+                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
                     q, k, v = ops.unbind(packed, dim=2)
                 else:
                     embed_dim = self.embed_dim

                     end = start + per_head_dim * kv_heads
                     k = projected[:, :, start: end]
                     v = projected[:, :, end:]
+                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
+                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
+                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)

                 if self.qk_layer_norm is True:
                     assert self.kv_repeat == 1
+                    q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
                     q = self.q_layer_norm(q)
                     k = self.k_layer_norm(k)
+                    q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
                 if self.rope:
                     q, k = self._apply_rope(q, k)
                 k, v = self._complete_kv(k, v)
                 if self.kv_repeat > 1:

                 q, k, v = [x.float() for x in [q, k, v]]
             if self.memory_efficient:
                 p = self.dropout if self.training else 0
+                if _efficient_attention_backend == 'torch':
+                    x = torch.nn.functional.scaled_dot_product_attention(
+                        q, k, v, is_causal=attn_mask is not None, dropout_p=p)
+                else:
+                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
             else:
                 # We include the dot product as float32, for consistency
                 # with the other implementations that include that step

                 # extend a bit the range of operations done in float32,
                 # although this should make no difference.
                 q = q / q.shape[-1] ** 0.5
+                key_layout = layout.replace('t', 'k')
+                query_layout = layout
                 if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
                     with torch.autocast(device_type=q.device.type, dtype=torch.float32):
+                        pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                 else:
+                    pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                 if attn_mask is not None:
                     pre_w = pre_w + attn_mask
                 w = torch.softmax(pre_w, dim=-1)
                 w = F.dropout(w, self.dropout, training=self.training).to(v)
+                # Key and value have the same format.
+                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
                 x = x.to(dtype)
+                x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
                 x = self.out_proj(x)
             else:
                 key, value = self._complete_kv(key, value)
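For context, a short sketch of how the new backend switch is meant to be used. `set_efficient_attention_backend` and `StreamingTransformer` are the names introduced or touched by this diff and the dimensions mirror the tests below; everything else is illustrative rather than part of the commit.

# Sketch only: select the attention backend before running the model, since the module-level
# flag is read inside StreamingMultiheadAttention.forward on every call.
import torch
from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend

set_efficient_attention_backend('torch')       # default; uses torch.nn.functional.scaled_dot_product_attention
# set_efficient_attention_backend('xformers')  # uses xformers.ops.memory_efficient_attention; required when rope is enabled

tr = StreamingTransformer(16, 4, 2, memory_efficient=True, dropout=0.)
tr.eval()
with torch.no_grad():
    y = tr(torch.randn(3, 12, 16))             # (batch, time, dim), matching the test shapes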
tests/modules/test_rope.py CHANGED
@@ -7,10 +7,11 @@
 import torch

 from audiocraft.modules.rope import RotaryEmbedding
-from audiocraft.modules.transformer import StreamingTransformer


 def test_rope():
     B, T, H, C = 8, 75, 16, 128

     rope = RotaryEmbedding(dim=C)
@@ -23,6 +24,7 @@ def test_rope():


 def test_rope_io_dtypes():
     B, T, H, C = 8, 75, 16, 128

     rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
@@ -46,6 +48,7 @@ def test_rope_io_dtypes():


 def test_transformer_with_rope():
     torch.manual_seed(1234)
     for pos in ['rope', 'sin_rope']:
         tr = StreamingTransformer(
@@ -61,6 +64,7 @@ def test_transformer_with_rope():

 @torch.no_grad()
 def test_rope_streaming():
     torch.manual_seed(1234)
     tr = StreamingTransformer(
         16, 4, 2, causal=True, dropout=0.,
@@ -88,6 +92,7 @@ def test_rope_streaming():

 @torch.no_grad()
 def test_rope_streaming_past_context():
     torch.manual_seed(1234)

     for context in [None, 10]:
@@ -117,6 +122,7 @@ def test_rope_streaming_past_context():


 def test_rope_memory_efficient():
     torch.manual_seed(1234)
     tr = StreamingTransformer(
         16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
@@ -137,6 +143,7 @@ def test_rope_memory_efficient():


 def test_rope_with_xpos():
     B, T, H, C = 8, 75, 16, 128

     rope = RotaryEmbedding(dim=C, xpos=True)
@@ -149,6 +156,7 @@ def test_rope_with_xpos():


 def test_positional_scale():
     B, T, H, C = 8, 75, 16, 128

     rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
 import torch

 from audiocraft.modules.rope import RotaryEmbedding
+from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend


 def test_rope():
+    set_efficient_attention_backend('xformers')
     B, T, H, C = 8, 75, 16, 128

     rope = RotaryEmbedding(dim=C)


 def test_rope_io_dtypes():
+    set_efficient_attention_backend('xformers')
     B, T, H, C = 8, 75, 16, 128

     rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)


 def test_transformer_with_rope():
+    set_efficient_attention_backend('xformers')
     torch.manual_seed(1234)
     for pos in ['rope', 'sin_rope']:
         tr = StreamingTransformer(


 @torch.no_grad()
 def test_rope_streaming():
+    set_efficient_attention_backend('xformers')
     torch.manual_seed(1234)
     tr = StreamingTransformer(
         16, 4, 2, causal=True, dropout=0.,


 @torch.no_grad()
 def test_rope_streaming_past_context():
+    set_efficient_attention_backend('xformers')
     torch.manual_seed(1234)

     for context in [None, 10]:


 def test_rope_memory_efficient():
+    set_efficient_attention_backend('xformers')
     torch.manual_seed(1234)
     tr = StreamingTransformer(
         16, 4, 2, custom=True, dropout=0., layer_scale=0.1,


 def test_rope_with_xpos():
+    set_efficient_attention_backend('xformers')
     B, T, H, C = 8, 75, 16, 128

     rope = RotaryEmbedding(dim=C, xpos=True)


 def test_positional_scale():
+    set_efficient_attention_backend('xformers')
     B, T, H, C = 8, 75, 16, 128

     rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
tests/modules/test_transformer.py CHANGED
@@ -9,7 +9,8 @@ from itertools import product
 import pytest
 import torch

-from audiocraft.modules.transformer import StreamingMultiheadAttention, StreamingTransformer


 def test_transformer_causal_streaming():
@@ -86,19 +87,22 @@ def test_streaming_api():

 def test_memory_efficient():
     torch.manual_seed(1234)
-    tr = StreamingTransformer(
-        16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
-    tr_mem_efficient = StreamingTransformer(
-        16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
-    tr_mem_efficient.load_state_dict(tr.state_dict())
-    tr.eval()
-    steps = 12
-    x = torch.randn(3, steps, 16)
-
-    with torch.no_grad():
-        y = tr(x)
-        y2 = tr_mem_efficient(x)
-        assert torch.allclose(y, y2), (y - y2).norm()


 def test_attention_as_float32():
@@ -129,30 +133,32 @@ def test_attention_as_float32():
 @torch.no_grad()
 def test_streaming_memory_efficient():
     torch.manual_seed(1234)
-    tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
-    tr_mem_efficient = StreamingTransformer(
-        16, 4, 2, dropout=0., memory_efficient=True, causal=True)
-    tr.load_state_dict(tr_mem_efficient.state_dict())
-    tr.eval()
-    tr_mem_efficient.eval()
-    steps = 12
-    x = torch.randn(3, steps, 16)
-
-    ref = tr(x)
-
-    with tr_mem_efficient.streaming():
-        outs = []
-        # frame_sizes = [2] + [1] * (steps - 2)
-        frame_sizes = [1] * steps
-
-        for frame_size in frame_sizes:
-            frame = x[:, :frame_size]
-            x = x[:, frame_size:]
-            outs.append(tr_mem_efficient(frame))
-
-        out = torch.cat(outs, dim=1)
-        delta = torch.norm(out - ref) / torch.norm(out)
-        assert delta < 1e-6, delta


 def test_cross_attention():
@@ -204,7 +210,7 @@ def test_cross_attention_compat():

     y = cross_attn(queries, keys, values)[0]
     y_ref = ref_attn(queries, keys, values)[0]
-    assert torch.allclose(y, y_ref, atol=1e-7)

     # Now let's check that streaming is working properly.
     with cross_attn.streaming():
 import pytest
 import torch

+from audiocraft.modules.transformer import (
+    StreamingMultiheadAttention, StreamingTransformer, set_efficient_attention_backend)


 def test_transformer_causal_streaming():

 def test_memory_efficient():
     torch.manual_seed(1234)
+    for backend in ['torch', 'xformers']:
+        set_efficient_attention_backend(backend)

+        tr = StreamingTransformer(
+            16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
+        tr_mem_efficient = StreamingTransformer(
+            16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
+        tr_mem_efficient.load_state_dict(tr.state_dict())
+        tr.eval()
+        steps = 12
+        x = torch.randn(3, steps, 16)
+
+        with torch.no_grad():
+            y = tr(x)
+            y2 = tr_mem_efficient(x)
+            assert torch.allclose(y, y2), ((y - y2).norm(), backend)


 def test_attention_as_float32():

 @torch.no_grad()
 def test_streaming_memory_efficient():
     torch.manual_seed(1234)
+    for backend in ['torch', 'xformers']:
+        set_efficient_attention_backend(backend)
+        tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
+        tr_mem_efficient = StreamingTransformer(
+            16, 4, 2, dropout=0., memory_efficient=True, causal=True)
+        tr.load_state_dict(tr_mem_efficient.state_dict())
+        tr.eval()
+        tr_mem_efficient.eval()
+        steps = 12
+        x = torch.randn(3, steps, 16)

+        ref = tr(x)

+        with tr_mem_efficient.streaming():
+            outs = []
+            # frame_sizes = [2] + [1] * (steps - 2)
+            frame_sizes = [1] * steps

+            for frame_size in frame_sizes:
+                frame = x[:, :frame_size]
+                x = x[:, frame_size:]
+                outs.append(tr_mem_efficient(frame))

+            out = torch.cat(outs, dim=1)
+            delta = torch.norm(out - ref) / torch.norm(out)
+            assert delta < 1e-6, delta


 def test_cross_attention():

     y = cross_attn(queries, keys, values)[0]
     y_ref = ref_attn(queries, keys, values)[0]
+    assert torch.allclose(y, y_ref, atol=1e-7), (y - y_ref).norm() / y_ref.norm()

     # Now let's check that streaming is working properly.
     with cross_attn.streaming():
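As a hedged illustration of what the backend loop in the updated tests checks, the two attention paths are expected to agree numerically when built from the same weights. The snippet below mirrors test_memory_efficient above; the tolerance and the direct torch-vs-xformers comparison are illustrative assumptions, not part of the commit.

# Sketch only: compare the two memory-efficient backends against each other.
import torch
from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend

x = torch.randn(3, 12, 16)
outs = {}
for backend in ['torch', 'xformers']:
    set_efficient_attention_backend(backend)
    torch.manual_seed(1234)                     # same seed -> same weights for both runs
    tr = StreamingTransformer(16, 4, 2, dropout=0., memory_efficient=True)
    tr.eval()
    with torch.no_grad():
        outs[backend] = tr(x)
assert torch.allclose(outs['torch'], outs['xformers'], atol=1e-6)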