koichi12 commited on
Commit
5ef2986
·
verified ·
1 Parent(s): c1012a5

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/xformers/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/xformers/__pycache__/_cpp_lib.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/xformers/__pycache__/_deprecation_warning.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/xformers/__pycache__/attn_bias_utils.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/xformers/__pycache__/checkpoint.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/xformers/__pycache__/info.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/xformers/__pycache__/test.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/xformers/__pycache__/utils.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/xformers/__pycache__/version.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_mem_eff_attention.py +373 -0
  11. .venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sp24.py +178 -0
  12. .venv/lib/python3.11/site-packages/xformers/components/attention/__init__.py +124 -0
  13. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/__init__.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/_sputnik_sparse.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_mask.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_patterns.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/base.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/compositional.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/core.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/favor.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/fourier_mix.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/global_tokens.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/lambda_layer.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/linformer.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/local.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/nystrom.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/ortho.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/pooling.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/random.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/scaled_dot_product.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/sparsity_config.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/utils.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/visual.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/xformers/components/attention/_sputnik_sparse.py +121 -0
  35. .venv/lib/python3.11/site-packages/xformers/components/attention/attention_mask.py +143 -0
  36. .venv/lib/python3.11/site-packages/xformers/components/attention/base.py +95 -0
  37. .venv/lib/python3.11/site-packages/xformers/components/attention/compositional.py +341 -0
  38. .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__init__.py +26 -0
  39. .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/base.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/softmax.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/base.py +61 -0
  43. .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/softmax.py +288 -0
  44. .venv/lib/python3.11/site-packages/xformers/components/attention/global_tokens.py +122 -0
  45. .venv/lib/python3.11/site-packages/xformers/components/attention/linformer.py +74 -0
  46. .venv/lib/python3.11/site-packages/xformers/components/attention/ortho.py +324 -0
  47. .venv/lib/python3.11/site-packages/xformers/components/attention/pooling.py +82 -0
  48. .venv/lib/python3.11/site-packages/xformers/components/attention/sparsity_config.py +812 -0
  49. .venv/lib/python3.11/site-packages/xformers/components/attention/utils.py +108 -0
  50. .venv/lib/python3.11/site-packages/xformers/components/feedforward/__init__.py +78 -0
.venv/lib/python3.11/site-packages/xformers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.66 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/__pycache__/_cpp_lib.cpython-311.pyc ADDED
Binary file (8.38 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/__pycache__/_deprecation_warning.cpython-311.pyc ADDED
Binary file (661 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/__pycache__/attn_bias_utils.cpython-311.pyc ADDED
Binary file (22.6 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/__pycache__/checkpoint.cpython-311.pyc ADDED
Binary file (27.3 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/__pycache__/info.cpython-311.pyc ADDED
Binary file (4.55 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/__pycache__/test.cpython-311.pyc ADDED
Binary file (177 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/__pycache__/utils.cpython-311.pyc ADDED
Binary file (8.17 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/__pycache__/version.cpython-311.pyc ADDED
Binary file (207 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_mem_eff_attention.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import itertools
8
+ import random
9
+ from functools import partial
10
+
11
+ import torch
12
+ from torch.utils import benchmark
13
+
14
+ import xformers.ops
15
+ import xformers.ops.fmha as fmha
16
+ from xformers.attn_bias_utils import create_attn_bias, ref_attention
17
+ from xformers.benchmarks.utils import benchmark_main_helper, create_argparser
18
+
19
+ torch.backends.cuda.matmul.allow_tf32 = False
20
+
21
+ min_run_time = 0.5
22
+ device = torch.device("cuda")
23
+
24
+ NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
25
+ VISION_SHAPES = [
26
+ # ViT
27
+ (384, 197, 1, 88),
28
+ (384, 197, 1, 80),
29
+ (384, 197, 1, 64),
30
+ (1024, 197, 1, 88),
31
+ (1024, 197, 1, 80),
32
+ (1024, 197, 1, 64),
33
+ # ViT-Huge
34
+ (32 * 16, 197, 1, 80),
35
+ (32, 197, 16, 80),
36
+ (32, 197, 16, 64),
37
+ (32, 197, 16, 128),
38
+ # ViT-Giant
39
+ (16 * 16, 197, 1, 88),
40
+ (16, 197, 16, 88),
41
+ (16, 197, 16, 64),
42
+ (16, 197, 16, 128),
43
+ # FB models
44
+ (1024, 82, 8, 64),
45
+ (150, 256, 16, 64),
46
+ (64, 256, 12, 64),
47
+ # Stable diffusion (https://github.com/huggingface/diffusers/pull/532)
48
+ (1, 4096, 16, 40), # 512x512
49
+ (1, 16384, 16, 40), # 1024x1024
50
+ (1, 4096, 16, 80),
51
+ (1, 16384, 16, 80),
52
+ # + bs4
53
+ (4, 4096, 16, 40),
54
+ (4, 16384, 16, 40),
55
+ (4, 4096, 16, 80),
56
+ (4, 16384, 16, 80),
57
+ # ParlAI model
58
+ (256, 4096, 16, 64),
59
+ # Zetta B M H K
60
+ (8, 2048, 20, 128),
61
+ ]
62
+
63
+ LLM_SHAPES = [
64
+ # LLaMa 70b - mp=8/16
65
+ *sorted(itertools.product([1, 2], [2048, 4096, 8192], [4, 8], [128])),
66
+ *sorted(
67
+ itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128, 160, 256])
68
+ ),
69
+ ]
70
+
71
+
72
+ OPS = [
73
+ (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp),
74
+ (xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp),
75
+ (xformers.ops.fmha.flash3.FwOp, xformers.ops.fmha.flash3.BwOp),
76
+ (xformers.ops.fmha.ck.FwOp, xformers.ops.fmha.ck.BwOp),
77
+ ]
78
+
79
+
80
+ def product_dict(**kwargs):
81
+ keys = kwargs.keys()
82
+ vals = kwargs.values()
83
+ for instance in itertools.product(*vals):
84
+ yield dict(zip(keys, instance))
85
+
86
+
87
+ VISION_CASES, LLM_CASES = [
88
+ list(
89
+ product_dict(
90
+ shape_q=SHAPES,
91
+ num_threads=NUM_THREADS,
92
+ dropout_p=[0.0],
93
+ attn_bias_cfg=[(type(None), False)],
94
+ dtype=[torch.half],
95
+ )
96
+ )
97
+ for SHAPES in (VISION_SHAPES, LLM_SHAPES)
98
+ ]
99
+
100
+ # Add more cases with some variations
101
+ for c in VISION_CASES.copy():
102
+ c = c.copy()
103
+ c.update(
104
+ random.Random(str(c["shape_q"])).choice(
105
+ [
106
+ {"dropout_p": 0.3},
107
+ {"attn_bias_cfg": (torch.Tensor, False)},
108
+ {"attn_bias_cfg": (torch.Tensor, True)},
109
+ {"dtype": torch.bfloat16},
110
+ {"dtype": torch.float},
111
+ ]
112
+ )
113
+ )
114
+ VISION_CASES.append(c)
115
+
116
+
117
+ LLM_CASE_UPDATES = [
118
+ {"attn_bias_cfg": (torch.Tensor, True)},
119
+ {"attn_bias_cfg": (xformers.ops.LowerTriangularMask, False)},
120
+ *[
121
+ {
122
+ "attn_bias_cfg": (
123
+ xformers.ops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
124
+ False,
125
+ ),
126
+ "Hkv": Hkv,
127
+ "dtype": torch.bfloat16,
128
+ }
129
+ for Hkv in [1, 2]
130
+ ],
131
+ ]
132
+
133
+ for c in LLM_CASES.copy():
134
+ for update in LLM_CASE_UPDATES:
135
+ c = c.copy()
136
+ c.update(update)
137
+ LLM_CASES.append(c)
138
+
139
+ CASES = VISION_CASES + LLM_CASES
140
+
141
+
142
+ def create_tensors(shape_q, Hkv, dtype, requires_grad=False, packed=True):
143
+ stacked_shape = list(shape_q) # B, M, H, K
144
+ Hq = shape_q[2]
145
+ stacked_dim = 2 if packed else 0
146
+ stacked_shape.insert(stacked_dim, 3)
147
+ qkv = torch.rand(
148
+ stacked_shape, device=device, dtype=dtype, requires_grad=requires_grad
149
+ )
150
+ q = torch.rand(shape_q, device=device, dtype=dtype, requires_grad=requires_grad)
151
+ shape_kv = (shape_q[0], shape_q[1], Hkv, shape_q[3])
152
+ k = (
153
+ torch.rand(shape_kv, device=device, dtype=dtype, requires_grad=requires_grad)
154
+ .reshape(shape_q[0], shape_q[1], 1, Hkv, shape_q[3])
155
+ .expand(shape_q[0], shape_q[1], Hq // Hkv, Hkv, shape_q[3])
156
+ .reshape(shape_q)
157
+ )
158
+ v = (
159
+ torch.rand(shape_kv, device=device, dtype=dtype, requires_grad=requires_grad)
160
+ .reshape(shape_q[0], shape_q[1], 1, Hkv, shape_q[3])
161
+ .expand(shape_q[0], shape_q[1], Hq // Hkv, Hkv, shape_q[3])
162
+ .reshape(shape_q)
163
+ )
164
+
165
+ return qkv, q, k, v
166
+
167
+
168
+ def mem_eff_attention_fw(
169
+ shape_q,
170
+ num_threads: int,
171
+ attn_bias_cfg,
172
+ dropout_p,
173
+ dtype,
174
+ packed=True,
175
+ Hkv=None,
176
+ ):
177
+ B, M, Hq, K = shape_q
178
+ Hkv = Hkv or Hq
179
+ _, q, k, v = create_tensors(
180
+ shape_q,
181
+ Hkv,
182
+ dtype,
183
+ requires_grad=False,
184
+ packed=packed,
185
+ )
186
+ attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
187
+ if attn_bias_requires_grad:
188
+ return
189
+
190
+ dtype_str = {
191
+ torch.bfloat16: "b16",
192
+ torch.half: "f16",
193
+ torch.float: "f32",
194
+ }[dtype]
195
+ sub_label = (
196
+ f"{dtype_str} {B}-{M}-{Hq}-{Hkv}-{K}, p={dropout_p}, "
197
+ f"BiasT={attn_bias_type.__name__}"
198
+ )
199
+
200
+ has_run = False
201
+ for fw_op, bw_op in OPS:
202
+ bias = create_attn_bias(
203
+ attn_bias_type,
204
+ batch_size=B,
205
+ num_heads=Hq,
206
+ num_heads_groups=Hq // Hkv,
207
+ q_len=M,
208
+ kv_len=M,
209
+ dtype=dtype,
210
+ device=device,
211
+ requires_grad=attn_bias_requires_grad,
212
+ fmt="BMHK",
213
+ op=fw_op,
214
+ )
215
+ inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p)
216
+ if isinstance(
217
+ bias,
218
+ (
219
+ fmha.attn_bias.BlockDiagonalMask,
220
+ fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
221
+ ),
222
+ ):
223
+ q, k, v = [x.reshape([1, -1, *x.shape[2:]]) for x in [q, k, v]]
224
+ if not fw_op.supports(inp):
225
+ continue
226
+
227
+ yield benchmark.Timer(
228
+ stmt="fn(q, k, v, attn_bias, p)",
229
+ globals={
230
+ "q": q,
231
+ "k": k,
232
+ "v": v,
233
+ "attn_bias": inp.attn_bias,
234
+ "p": dropout_p,
235
+ "fn": partial(
236
+ xformers.ops.memory_efficient_attention, op=(fw_op, bw_op)
237
+ ),
238
+ },
239
+ label=f"attention (attn_bias={attn_bias_type})",
240
+ description=fw_op.NAME,
241
+ sub_label=sub_label,
242
+ num_threads=num_threads,
243
+ )
244
+ has_run = True
245
+
246
+ if not has_run:
247
+ return
248
+
249
+ yield benchmark.Timer(
250
+ stmt="fn(q, k, v, attn_bias, p)",
251
+ globals={
252
+ "q": q,
253
+ "k": k,
254
+ "v": v,
255
+ "attn_bias": inp.attn_bias,
256
+ "p": dropout_p,
257
+ "fn": ref_attention,
258
+ },
259
+ label=f"attention (attn_bias={attn_bias_type})",
260
+ description="eager",
261
+ sub_label=sub_label,
262
+ num_threads=num_threads,
263
+ )
264
+
265
+
266
+ def mem_eff_attention_bw(
267
+ shape_q, num_threads: int, attn_bias_cfg, dropout_p, dtype, Hkv=None
268
+ ):
269
+ B, M, Hq, K = shape_q
270
+ Hkv = Hkv or Hq
271
+ _, q, k, v = create_tensors(
272
+ shape_q,
273
+ Hkv,
274
+ dtype,
275
+ requires_grad=True,
276
+ )
277
+
278
+ attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
279
+
280
+ dtype_str = {
281
+ torch.bfloat16: "b16",
282
+ torch.half: "f16",
283
+ torch.float: "f32",
284
+ }[dtype]
285
+ sub_label = (
286
+ f"{dtype_str} {B}-{M}-{Hq}-{Hkv}-{K}, p={dropout_p}, "
287
+ f"BiasT={attn_bias_type.__name__}, BiasGrad={attn_bias_requires_grad}"
288
+ )
289
+
290
+ has_run = False
291
+ for fw_op, bw_op in OPS:
292
+ bias = create_attn_bias(
293
+ attn_bias_type,
294
+ batch_size=B,
295
+ num_heads=Hq,
296
+ num_heads_groups=Hq // Hkv,
297
+ q_len=M,
298
+ kv_len=M,
299
+ dtype=dtype,
300
+ device=device,
301
+ requires_grad=attn_bias_requires_grad,
302
+ fmt="BMHK",
303
+ op=bw_op,
304
+ )
305
+ inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p)
306
+
307
+ if not fw_op.supports(inp) or not bw_op.supports(inp):
308
+ continue
309
+ has_run = True
310
+ out = xformers.ops.memory_efficient_attention(
311
+ inp.query, inp.key, inp.value, inp.attn_bias, inp.p, op=(fw_op, bw_op)
312
+ )
313
+ grad_benchmark = torch.ones_like(q)
314
+
315
+ yield benchmark.Timer(
316
+ stmt="out.backward(grad, retain_graph=True)",
317
+ globals={
318
+ "out": out,
319
+ "grad": grad_benchmark,
320
+ },
321
+ label=f"attention backward (attn_bias={attn_bias_type})",
322
+ description=bw_op.NAME,
323
+ sub_label=sub_label,
324
+ num_threads=num_threads,
325
+ )
326
+ del out
327
+
328
+ if not has_run:
329
+ return
330
+ yield benchmark.Timer(
331
+ stmt="out.backward(grad, retain_graph=True)",
332
+ globals={
333
+ "out": ref_attention(q, k, v, inp.attn_bias, dropout_p),
334
+ "grad": grad_benchmark,
335
+ },
336
+ label=f"attention backward (attn_bias={attn_bias_type})",
337
+ description="vanilla",
338
+ sub_label=sub_label,
339
+ num_threads=num_threads,
340
+ )
341
+
342
+
343
+ def main():
344
+ arg_parser = create_argparser()
345
+ arg_parser.add_argument(
346
+ "--omit-forward",
347
+ action="store_true",
348
+ help="Do not run forward benchmarks",
349
+ )
350
+ arg_parser.add_argument(
351
+ "--omit-backward",
352
+ action="store_true",
353
+ help="Do not run backward benchmarks",
354
+ )
355
+ args = arg_parser.parse_args()
356
+ if not args.omit_forward:
357
+ benchmark_main_helper(
358
+ mem_eff_attention_fw,
359
+ CASES,
360
+ arg_parser=arg_parser,
361
+ min_run_time=min_run_time,
362
+ )
363
+ if not args.omit_backward:
364
+ benchmark_main_helper(
365
+ mem_eff_attention_bw,
366
+ CASES,
367
+ arg_parser=arg_parser,
368
+ min_run_time=min_run_time,
369
+ )
370
+
371
+
372
+ if __name__ == "__main__":
373
+ main()
.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sp24.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from typing import Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from utils import DTYPE2STR, benchmark_main_helper2, product_dict
13
+
14
+ import xformers.ops as xops
15
+
16
+ min_run_time = 0.5
17
+ device = torch.device("cuda")
18
+
19
+ CASES = list(
20
+ product_dict(
21
+ B_in_hidden_out_ft=[
22
+ (2048 * 8, 2048, 2048 * 3, 2048),
23
+ (2048, 5120, 5120 * 3, 5120), # 13b
24
+ (1024, 8192, 8192 * 3, 8192), # 30b
25
+ (2048, 8192, 8192 * 3, 8192), # 30b
26
+ (2048 * 2, 8192, 8192 * 3, 8192), # 30b
27
+ # DINO ViT-L: lg + sm crops (patch16)
28
+ (64 * 2 * (14 * 14 + 1) + 64 * 8 * (6 * 6 + 1), 1024, 1024 * 4, 1024),
29
+ # DINO ViT-g: lg + sm crops (patch16)
30
+ (
31
+ 12 * 2 * (16 * 16 + 1 + 11) + 12 * 8 * (7 * 7 + 1 + 11),
32
+ 1536,
33
+ 1536 * 4,
34
+ 1536,
35
+ ),
36
+ ],
37
+ dtype=[torch.half],
38
+ bias=[False],
39
+ )
40
+ )
41
+
42
+
43
+ class Mlp(nn.Module):
44
+ LINEAR_CLS = nn.Linear
45
+
46
+ def __init__(
47
+ self, B_in_hidden_out_ft: Tuple[int, int, int, int], dtype, bias: bool, bw: bool
48
+ ) -> None:
49
+ B, in_ft, hid_ft, out_ft = B_in_hidden_out_ft
50
+ super().__init__()
51
+ self.label = "mlp"
52
+ self.sub_label = (
53
+ f"{DTYPE2STR[dtype]} ({B},{in_ft},{hid_ft},{out_ft}){' b' if bias else ''}"
54
+ )
55
+ self.fc1 = self.LINEAR_CLS(in_ft, hid_ft, bias=bias)
56
+ self.act = nn.GELU()
57
+ self.fc2 = self.LINEAR_CLS(hid_ft, out_ft, bias=bias)
58
+ self.grad = torch.randn([B, out_ft], device="cuda", dtype=dtype)
59
+ self.input = torch.randn(
60
+ [B, in_ft], device="cuda", dtype=dtype, requires_grad=True
61
+ )
62
+ self.out = self.input
63
+ self.to("cuda").to(dtype)
64
+
65
+ def fw(self):
66
+ x = self.input
67
+ x = self.fc1(x)
68
+ x = self.act(x)
69
+ x = self.fc2(x)
70
+ self.out = x
71
+
72
+ def bw(self):
73
+ self.out.backward(self.grad, retain_graph=True)
74
+
75
+
76
+ class MlpDenseMask(Mlp):
77
+ def fw(self):
78
+ x = self.input
79
+ x = self.fc1(x)
80
+
81
+ mask = torch.ops.xformers.sparse24_largest_mask_2d(x)
82
+ x = mask * x
83
+
84
+ x = self.act(x)
85
+ x = self.fc2(x)
86
+ self.out = x
87
+
88
+
89
+ class MlpAct24(Mlp):
90
+ def fw(self):
91
+ x = self.input
92
+ x = self.fc1(x)
93
+
94
+ x = xops.sparsify24(x)
95
+
96
+ x = self.act(x)
97
+ x = self.fc2(x)
98
+ self.out = x
99
+
100
+
101
+ class LinearW24(torch.nn.Linear):
102
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
103
+ w_sparse = xops.sparsify24(
104
+ self.weight,
105
+ gradient="24dense",
106
+ backend="cusparselt",
107
+ )
108
+ return F.linear(input, w_sparse, self.bias)
109
+
110
+
111
+ class MlpW24(Mlp):
112
+ LINEAR_CLS = LinearW24
113
+
114
+
115
+ class MicrobenchmarkBase:
116
+ def __init__(
117
+ self, B_in_hidden_out_ft: Tuple[int, int, int, int], dtype, bias: bool, bw: bool
118
+ ) -> None:
119
+ B, in_ft, hid_ft, out_ft = B_in_hidden_out_ft
120
+ super().__init__()
121
+ self.label = "mlp"
122
+ self.sub_label = (
123
+ f"{DTYPE2STR[dtype]} ({B},{in_ft},{hid_ft},{out_ft}){' b' if bias else ''}"
124
+ )
125
+ self.input = torch.randn(
126
+ [B, in_ft], device="cuda", dtype=dtype, requires_grad=True
127
+ )
128
+ self.input_colMajor = self.input.t().contiguous().t()
129
+ self.input_sp = xops.sparsify24(self.input)
130
+
131
+ def bw(self) -> None:
132
+ return None
133
+
134
+
135
+ class MicrobenchmarkSparsify24(MicrobenchmarkBase):
136
+ def fw(self) -> torch.Tensor:
137
+ xops.sparsify24(self.input)
138
+ return self.input
139
+
140
+
141
+ class MicrobenchmarkSp24ApplyDense(MicrobenchmarkBase):
142
+ def fw(self) -> torch.Tensor:
143
+ xops.sparsify24_like(self.input, pattern=self.input_sp, out_dense=True)
144
+ return self.input
145
+
146
+
147
+ class MicrobenchmarkSp24ApplyDenseT(MicrobenchmarkBase):
148
+ def fw(self) -> torch.Tensor:
149
+ xops.sparsify24_like(self.input_colMajor, pattern=self.input_sp, out_dense=True)
150
+ return self.input
151
+
152
+
153
+ class MicrobenchmarkInputClone(MicrobenchmarkBase):
154
+ def fw(self) -> torch.Tensor:
155
+ self.input.clone()
156
+ return self.input
157
+
158
+
159
+ functions = {
160
+ "act24": MlpAct24,
161
+ "dense": Mlp,
162
+ "w24": MlpW24,
163
+ "s24_inp_sparsify24": MicrobenchmarkSparsify24,
164
+ "s24_inp_apply_dense": MicrobenchmarkSp24ApplyDense,
165
+ "s24_inp_apply_dense_t": MicrobenchmarkSp24ApplyDenseT,
166
+ "s24_inp_clone": MicrobenchmarkInputClone,
167
+ }
168
+ benchmark_main_helper2(
169
+ "sp24_fw", fw=True, cases=CASES, functions=functions, min_run_time=min_run_time
170
+ )
171
+ benchmark_main_helper2(
172
+ "sp24_fwbw",
173
+ fw=True,
174
+ bw=True,
175
+ cases=CASES,
176
+ functions=functions,
177
+ min_run_time=min_run_time,
178
+ )
.venv/lib/python3.11/site-packages/xformers/components/attention/__init__.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Dict, Set, Union
9
+
10
+ import torch
11
+
12
+ from xformers.utils import (
13
+ generate_matching_config,
14
+ get_registry_decorator,
15
+ import_all_modules,
16
+ )
17
+
18
+ from ._sputnik_sparse import SparseCS
19
+ from .attention_mask import AttentionMask
20
+ from .base import Attention, AttentionConfig # noqa
21
+
22
+ logger = logging.getLogger("xformers")
23
+
24
+
25
+ # CREDITS: Classy Vision registry mechanism
26
+
27
+ ATTENTION_REGISTRY: Dict[str, Any] = {}
28
+ ATTENTION_CLASS_NAMES: Set[str] = set()
29
+
30
+ # Arbitrary threshold for now,
31
+ # in between dense and sparse matrix algorithms for the attention mechanism
32
+ _DENSITY_THRESHOLD = 0.30 # noqa # from the sputnik paper, vs.
33
+ _USE_SPUTNIK = True
34
+
35
+
36
+ def build_attention(config: Union[Dict[str, Any], AttentionConfig]):
37
+ """Builds an attention from a config.
38
+
39
+ This assumes a 'name' key in the config which is used to determine what
40
+ attention class to instantiate. For instance, a config `{"name": "my_attention",
41
+ "foo": "bar"}` will find a class that was registered as "my_attention"
42
+ (see :func:`register_attention`) and call .from_config on it."""
43
+
44
+ if not isinstance(config, AttentionConfig):
45
+ try:
46
+ config_instance = generate_matching_config(
47
+ config, ATTENTION_REGISTRY[config["name"]].config
48
+ )
49
+ except KeyError as e:
50
+ name = config["name"]
51
+ logger.warning(f"{name} not available among {ATTENTION_REGISTRY.keys()}")
52
+ raise e
53
+ else:
54
+ config_instance = config
55
+
56
+ return ATTENTION_REGISTRY[config_instance.name].constructor.from_config(
57
+ config_instance
58
+ )
59
+
60
+
61
+ """Registers an Attention subclass.
62
+
63
+ This decorator allows xFormers to instantiate a subclass of Attention
64
+ from a configuration file, even if the class itself is not part of the
65
+ xFormers library. To use it, apply this decorator to an Attention
66
+ subclass, like this:
67
+
68
+ .. code-block:: python
69
+
70
+ @dataclass
71
+ class MyConfig:
72
+ ...
73
+
74
+ @register_attention('my_attention', MyConfig)
75
+ class MyAttention(Attention):
76
+ ...
77
+
78
+ To instantiate an attention from a configuration file, see :func:`build_attention`."""
79
+ register_attention: Callable[[str, Any], Callable[[Any], Any]] = get_registry_decorator(
80
+ ATTENTION_REGISTRY, ATTENTION_CLASS_NAMES, Attention, AttentionConfig
81
+ )
82
+
83
+
84
+ def maybe_sparsify(matrix) -> Any:
85
+ # Sparsify if that makes sense
86
+ if torch.count_nonzero(matrix).item() / matrix.numel() > _DENSITY_THRESHOLD:
87
+ # If not sparse, then AttentionMask is the reference type
88
+ return AttentionMask.from_bool(matrix)
89
+
90
+ return sparsify(matrix)
91
+
92
+
93
+ def sparsify(matrix):
94
+ if _USE_SPUTNIK:
95
+ return SparseCS(matrix)
96
+ return matrix.to_sparse()
97
+
98
+
99
+ from .favor import FavorAttention # noqa
100
+ from .global_tokens import GlobalAttention # noqa
101
+ from .linformer import LinformerAttention # noqa
102
+ from .local import LocalAttention # noqa
103
+ from .nystrom import NystromAttention # noqa
104
+ from .ortho import OrthoFormerAttention # noqa
105
+ from .random import RandomAttention # noqa
106
+ from .scaled_dot_product import ScaledDotProduct # noqa
107
+
108
+ __all__ = [
109
+ "ScaledDotProduct",
110
+ "LocalAttention",
111
+ "LinformerAttention",
112
+ "NystromAttention",
113
+ "RandomAttention",
114
+ "OrthoFormerAttention",
115
+ "GlobalAttention",
116
+ "FavorAttention",
117
+ "Attention",
118
+ "AttentionMask",
119
+ "build_attention",
120
+ "register_attention",
121
+ ]
122
+
123
+ # automatically import any Python files in the directory
124
+ import_all_modules(str(Path(__file__).parent), "xformers.components.attention")
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.18 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/_sputnik_sparse.cpython-311.pyc ADDED
Binary file (7.3 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_mask.cpython-311.pyc ADDED
Binary file (7.48 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_patterns.cpython-311.pyc ADDED
Binary file (15.5 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/base.cpython-311.pyc ADDED
Binary file (4.5 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/compositional.cpython-311.pyc ADDED
Binary file (14.2 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/core.cpython-311.pyc ADDED
Binary file (11.2 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/favor.cpython-311.pyc ADDED
Binary file (7.42 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/fourier_mix.cpython-311.pyc ADDED
Binary file (2.14 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/global_tokens.cpython-311.pyc ADDED
Binary file (5.47 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/lambda_layer.cpython-311.pyc ADDED
Binary file (3.85 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/linformer.cpython-311.pyc ADDED
Binary file (4 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/local.cpython-311.pyc ADDED
Binary file (5.22 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/nystrom.cpython-311.pyc ADDED
Binary file (12.9 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/ortho.cpython-311.pyc ADDED
Binary file (15.8 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/pooling.cpython-311.pyc ADDED
Binary file (3.26 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/random.cpython-311.pyc ADDED
Binary file (5.43 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/scaled_dot_product.cpython-311.pyc ADDED
Binary file (5.46 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/sparsity_config.cpython-311.pyc ADDED
Binary file (41.9 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.46 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/visual.cpython-311.pyc ADDED
Binary file (4.68 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/_sputnik_sparse.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import torch
8
+
9
+ from xformers.ops import masked_matmul
10
+ from xformers.sparse import SparseCSRTensor
11
+
12
+ # TODO: this is here for BC
13
+ from xformers.sparse.utils import _csr_to_coo, _dense_to_sparse # noqa: F401
14
+
15
+
16
+ class SparseCS:
17
+ def __init__(self, matrix, device=None):
18
+ if device is None:
19
+ device = torch.device("cpu")
20
+ if matrix.ndim == 2:
21
+ matrix = matrix[None]
22
+ assert matrix.ndim == 3
23
+ self._mat = SparseCSRTensor.from_dense(matrix).to(device)
24
+
25
+ @property
26
+ def device(self):
27
+ return self._mat.device
28
+
29
+ @property
30
+ def ndim(self):
31
+ return self._mat.ndim
32
+
33
+ @property
34
+ def dtype(self):
35
+ return self._mat.dtype
36
+
37
+ @property
38
+ def is_sparse(self):
39
+ return True
40
+
41
+ @property
42
+ def shape(self):
43
+ return self._mat.shape[1:]
44
+
45
+ @property
46
+ def values(self):
47
+ return self._mat.values()
48
+
49
+ @property
50
+ def row_indices(self):
51
+ return self._mat._csr_row_indices
52
+
53
+ @property
54
+ def column_indices(self):
55
+ return self._mat._csr_column_indices
56
+
57
+ @property
58
+ def row_offsets(self):
59
+ return self._mat._csr_row_offsets
60
+
61
+ @property
62
+ def _transp_info(self):
63
+ return self._mat._csr_transp_info
64
+
65
+ @classmethod
66
+ def wrap(
67
+ cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
68
+ ):
69
+ matrix = cls.__new__(cls)
70
+ _shape = (values.shape[0],) + shape
71
+ csr_matrix = SparseCSRTensor._wrap(
72
+ _shape, values, row_indices, row_offsets, column_indices, _transp_info
73
+ )
74
+ matrix._mat = csr_matrix
75
+ return matrix
76
+
77
+ @classmethod
78
+ def _wrap(cls, csr_matrix):
79
+ assert isinstance(csr_matrix, SparseCSRTensor)
80
+ matrix = cls.__new__(cls)
81
+ matrix._mat = csr_matrix
82
+ return matrix
83
+
84
+ def __mul__(self, other):
85
+ assert isinstance(other, (int, float))
86
+ return type(self)._wrap(self._mat * other)
87
+
88
+ def __add__(self, other):
89
+ assert isinstance(other, type(self))
90
+ return type(self)._wrap(self._mat + other._mat)
91
+
92
+ def matmul_with_mask(self, a, b):
93
+ return type(self)._wrap(masked_matmul(a, b, self._mat))
94
+
95
+ def softmax(self):
96
+ out = torch.nn.functional.softmax(self._mat, -1)
97
+ return type(self)._wrap(out)
98
+
99
+ def spmm(self, b):
100
+ out = torch.bmm(self._mat, b)
101
+ return out
102
+
103
+ def transpose(self):
104
+ out = torch.transpose(self._mat, -2, -1)
105
+ return type(self)._wrap(out)
106
+
107
+ def to(self, device):
108
+ assert isinstance(device, torch.device)
109
+ out = self._mat.to(device)
110
+ return type(self)._wrap(out)
111
+
112
+ def to_dense(self):
113
+ return self._mat.to_dense()
114
+
115
+ def logical_and(self, other: torch.Tensor):
116
+ assert not isinstance(other, SparseCS)
117
+ out = torch.logical_and(self._mat, other)
118
+ return type(self)._wrap(out)
119
+
120
+ def __and__(self, other):
121
+ return self.logical_and(other)
.venv/lib/python3.11/site-packages/xformers/components/attention/attention_mask.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from typing import Optional, Type, TypeVar
8
+
9
+ import torch
10
+
11
+ Self = TypeVar("Self", bound="AttentionMask")
12
+
13
+
14
+ class AttentionMask:
15
+ """
16
+ Holds an attention mask, along with a couple of helpers and attributes.
17
+
18
+ .. note: this is an additive mask, meaning that coefficients which should be computed hold the '0.' value,
19
+ and coefficients which should be skipped hold the '-inf' value. Any other value is possible if the purpose
20
+ is to bias the attention computation for instance
21
+
22
+ .. note: the attention mask dimensions are expected to be `[batch, to_sequence, from_sequence]`,
23
+ `[to_sequence, from_sequence]`, or anything broadcastable in between
24
+ """
25
+
26
+ def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False):
27
+ assert additive_mask.is_floating_point(), additive_mask.dtype
28
+ assert not additive_mask.requires_grad
29
+
30
+ if additive_mask.ndim == 2:
31
+ additive_mask = additive_mask.unsqueeze(0)
32
+
33
+ self.values = additive_mask
34
+ self.is_causal = is_causal
35
+ self.seq_len = additive_mask.shape[1]
36
+ self.to_seq_len = additive_mask.shape[0]
37
+
38
+ def to_bool(self) -> torch.Tensor:
39
+ """
40
+ .. warning: we assume here that True implies that the value should be computed
41
+ """
42
+ return self.values != float("-inf")
43
+
44
+ @classmethod
45
+ def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
46
+ """
47
+ Create an AttentionMask given a boolean pattern.
48
+ .. warning: we assume here that True implies that the value should be computed
49
+ """
50
+ assert x.dtype == torch.bool
51
+
52
+ additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
53
+ additive_mask.masked_fill_(x, 0.0)
54
+ additive_mask.masked_fill_(~x, float("-inf"))
55
+
56
+ return cls(additive_mask)
57
+
58
+ @classmethod
59
+ def from_multiplicative(cls: Type[Self], x: torch.Tensor) -> Self:
60
+ """
61
+ Create an AttentionMask given a multiplicative attention mask.
62
+ """
63
+ assert not x.dtype == torch.bool
64
+
65
+ additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
66
+ x = x.bool()
67
+
68
+ additive_mask.masked_fill_(x, 0.0)
69
+ additive_mask.masked_fill_(~x, float("-inf"))
70
+
71
+ return cls(additive_mask)
72
+
73
+ @classmethod
74
+ def make_causal(
75
+ cls: Type[Self],
76
+ seq_len: int,
77
+ to_seq_len: Optional[int] = None,
78
+ device: Optional[torch.device] = None,
79
+ dtype: Optional[torch.dtype] = None,
80
+ ) -> Self:
81
+ if not to_seq_len:
82
+ to_seq_len = seq_len
83
+
84
+ additive_mask = torch.triu(
85
+ torch.ones(seq_len, to_seq_len, device=device, dtype=dtype) * float("-inf"),
86
+ diagonal=1,
87
+ )
88
+ return cls(additive_mask=additive_mask, is_causal=True)
89
+
90
+ def make_crop(
91
+ self, seq_len: int, to_seq_len: Optional[int] = None
92
+ ) -> "AttentionMask":
93
+ """
94
+ Return a cropped attention mask, whose underlying tensor is a view of this one
95
+ """
96
+
97
+ if not to_seq_len:
98
+ to_seq_len = seq_len
99
+
100
+ return AttentionMask(
101
+ self.values[:, :seq_len, :to_seq_len], is_causal=self.is_causal
102
+ )
103
+
104
+ def __repr__(self):
105
+ return f"AttentionMask - causal {self.is_causal} - mask " + str(self.values)
106
+
107
+ @property
108
+ def device(self):
109
+ return self.values.device
110
+
111
+ @property
112
+ def is_sparse(self):
113
+ return False
114
+
115
+ @property
116
+ def ndim(self):
117
+ return len(self.values.shape)
118
+
119
+ @property
120
+ def dtype(self):
121
+ return self.values.dtype
122
+
123
+ @property
124
+ def shape(self):
125
+ return self.values.shape
126
+
127
+ def __add__(self, other):
128
+ return AttentionMask(self.values + other.values, is_causal=False)
129
+
130
+ def to(
131
+ self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
132
+ ) -> "AttentionMask":
133
+ assert device is None or isinstance(device, torch.device)
134
+ assert dtype is None or isinstance(dtype, torch.dtype)
135
+ assert device is not None or dtype is not None
136
+
137
+ # Noop if we don't need to create another instance
138
+ if ((device and device == self.device) or not device) and (
139
+ (dtype and dtype == self.dtype) or not dtype
140
+ ):
141
+ return self
142
+
143
+ return AttentionMask(self.values.to(device=device, dtype=dtype), self.is_causal)
.venv/lib/python3.11/site-packages/xformers/components/attention/base.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from abc import ABCMeta, abstractmethod
8
+ from dataclasses import asdict, dataclass
9
+ from typing import Optional, Type, TypeVar
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from xformers._deprecation_warning import deprecated_function
15
+ from xformers.components.attention import AttentionMask
16
+
17
+
18
+ @dataclass
19
+ class AttentionConfig:
20
+ """Parameters required for all Attentions.
21
+ Can accept and store extra parameters.
22
+ """
23
+
24
+ name: str # the registered name for this attention mechanism
25
+ dropout: float # dropout probability
26
+
27
+
28
+ Self = TypeVar("Self", bound="Attention")
29
+
30
+
31
+ # Define the common interface, every attention block needs to derive from it
32
+ class Attention(nn.Module, metaclass=ABCMeta):
33
+ r"""The base Attention mechanism, which is typically a sub-part of the multi-head attention"""
34
+
35
+ _causal_mask: Optional[AttentionMask] = None
36
+
37
+ @abstractmethod
38
+ def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
39
+ super().__init__()
40
+ deprecated_function(self)
41
+
42
+ # Requires the inputs to be projected
43
+ self.requires_input_projection = True
44
+
45
+ # Whether the head dimension needs to be present (if not it can be folded into the batch dimension)
46
+ self.requires_head_dimension = False
47
+
48
+ # key padding mask and attention mask must be passed in as separate arguments instead of a merged attention mask
49
+ self.requires_separate_masks = False
50
+
51
+ # Requires that K and Q have the same sequence length
52
+ self.requires_same_k_q_dimensions = False
53
+
54
+ # Whether the attention owns the single head/multihead mechanism
55
+ # so that the MHA wrapper should skip it
56
+ self.requires_skip_multi_head = False
57
+
58
+ # This attention requires a context length which is squared, often due to 2D pooling
59
+ self.requires_squared_context = False
60
+
61
+ # Whether this attention mechanism supports attention masks
62
+ self.supports_attention_mask = True
63
+ self.supports_key_padding_mask = False
64
+
65
+ @classmethod
66
+ def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
67
+ # Generate the class inputs from the config
68
+ fields = asdict(config)
69
+
70
+ # Skip all Nones so that default values are used
71
+ fields = {k: v for k, v in fields.items() if v is not None}
72
+
73
+ return cls(**fields)
74
+
75
+ @abstractmethod
76
+ def forward(
77
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
78
+ ) -> torch.Tensor:
79
+ raise NotImplementedError
80
+
81
+ @staticmethod
82
+ def _maybe_pad_sequence(x: torch.Tensor, mask: torch.Tensor):
83
+ """
84
+ If the sequence is shorter than the mask, return a padded view
85
+ """
86
+ if x.shape[-2] != mask.shape[-1]:
87
+ assert x.shape[-2] < mask.shape[-1], (
88
+ "Sequence is bigger than the provided mask, cannot infer what to do with it."
89
+ " Please update your attention mask"
90
+ )
91
+
92
+ pad_size = (0, 0, 0, mask.shape[-1] - x.shape[-2], 0, 0)
93
+ return torch.nn.functional.pad(x, pad_size, mode="constant", value=0.0)
94
+
95
+ return x
.venv/lib/python3.11/site-packages/xformers/components/attention/compositional.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # Credits: this is heavily inspired by the official implementation, present in
8
+ # https://github.com/sarthmit/Compositional-Attention
9
+ # Original author: Sarthak Mittal
10
+
11
+ # This is a simplified version, for the sake of clarity, and because some features could be exposed later
12
+ # via the library directly.
13
+ # In particular, code paths for TPUs, quantization and gumbel softmax have been removed
14
+ # We're also following the same dimension ordering as in the rest of the xformers library
15
+ # which is to say [Batch, Sequence, Embedding] wherever possible
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from torch import Tensor, nn
24
+
25
+ from xformers.components.attention import (
26
+ Attention,
27
+ AttentionConfig,
28
+ AttentionMask,
29
+ register_attention,
30
+ )
31
+ from xformers.components.attention.core import _softmax
32
+ from xformers.components.input_projection import InputProjection, InputProjectionConfig
33
+
34
+
35
+ def _either_or(a: Optional[int], b: int) -> int:
36
+ return a if a is not None else b
37
+
38
+
39
+ @dataclass
40
+ class CompositionalAttentionConfig(AttentionConfig):
41
+ dim_model: int
42
+ num_heads: int
43
+ dim_attn: Optional[int] = None
44
+ num_rules: Optional[int] = None
45
+ dim_key: Optional[int] = None
46
+ dim_value: Optional[int] = None
47
+ dim_selection: Optional[int] = None
48
+ dropout: float
49
+ qk_rule: bool = False
50
+ nonlinear: bool = False
51
+ q_compose: bool = False
52
+ bias: bool = True
53
+ causal: Optional[bool] = False
54
+ in_proj_container: Optional[InputProjection] = None
55
+ use_separate_proj_weight: Optional[bool] = False
56
+
57
+
58
+ @register_attention("compositional", CompositionalAttentionConfig)
59
+ class CompositionalAttention(Attention):
60
+ """Compositional Attention, as proposed in
61
+ "Compositional Attention: Disentangling search and retrieval"_, S. Mittal et al.
62
+
63
+ A key insight from this proposal is that the attention mechanism can be conceived as two steps:
64
+ a search and a retrieval operation. When queried, the model can search for the most relevant information
65
+ (Softmax(QKt)), then retrieve information given the Value.
66
+
67
+ Contrary to the original attention proposal, which does not consider interactions in between heads,
68
+ the compositional attention will consider all possible interactions and softmax over that dimension,
69
+ so that the information retrieved covers the most relevant dimensions. The number of heads and rules to
70
+ use is thus typically smaller than for a comparable traditional Transformer, and asking for the same number of heads
71
+ may not fit in memory.
72
+
73
+ Args:
74
+ dim_model: dimension of the incoming latent space
75
+ num_heads: number of heads *for the search operation*
76
+ dim_attn: dimension (embedding) of the attention
77
+ num_rules: number of rules to consider *for the retrieval operation*
78
+ dim_selection: dimension of the scoring/selection space for the retrievals
79
+ dim_key, dim_value: dimensions of K and V, if different from Q
80
+ dropout: attention dropout probability
81
+ qk_rule: QK product will drive the retrieval process
82
+ nonlinear: use a non linear method to score the retrievals
83
+ bias: use bias in the initial projection step
84
+ causal: causal computations (attend to the past only)
85
+
86
+ _"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ dim_model: int,
92
+ num_heads: int,
93
+ dim_attn: Optional[int] = None,
94
+ num_rules: Optional[int] = None,
95
+ dim_selection: Optional[int] = None,
96
+ dim_key: Optional[int] = None,
97
+ dim_value: Optional[int] = None,
98
+ dropout=0.0,
99
+ qk_rule=False,
100
+ nonlinear=False,
101
+ q_compose=False,
102
+ in_proj_container: Optional[InputProjection] = None,
103
+ use_separate_proj_weight: Optional[bool] = False,
104
+ bias=True,
105
+ causal=False,
106
+ *_,
107
+ **__,
108
+ ):
109
+ super().__init__()
110
+
111
+ # Define the inherited flags
112
+ self.requires_skip_multi_head = (
113
+ True # This attention owns the multi-head mechanism
114
+ )
115
+
116
+ # Handle defaults / undefined values
117
+ self.dim_model = dim_model
118
+ num_rules = _either_or(num_rules, num_heads)
119
+ dim_selection = _either_or(dim_selection, dim_model // num_heads)
120
+
121
+ # All the initial definition plumbing
122
+ dim_attn = _either_or(dim_attn, dim_model)
123
+ dim_key = _either_or(dim_key, dim_model)
124
+ dim_value = _either_or(dim_value, dim_model)
125
+
126
+ self.in_proj_container = (
127
+ in_proj_container
128
+ if in_proj_container is not None
129
+ else InputProjection(
130
+ query_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias),
131
+ key_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias)
132
+ if use_separate_proj_weight
133
+ else None,
134
+ value_proj_params=InputProjectionConfig(dim_model, dim_value, bias=bias)
135
+ if use_separate_proj_weight
136
+ else None,
137
+ )
138
+ )
139
+
140
+ self.num_heads = num_heads
141
+ self.num_rules = num_rules
142
+ self.qk_rule = qk_rule
143
+ self.dim_selection = dim_selection
144
+ self.nonlinear = nonlinear
145
+ self.q_compose = q_compose
146
+
147
+ self.dropout_module = nn.Dropout(dropout)
148
+ self.dim_head = dim_model // num_heads
149
+ self.value_dim = dim_attn // num_rules
150
+
151
+ assert (
152
+ self.value_dim * num_rules == dim_attn
153
+ ), "value_dim must be divisible by num_rules"
154
+
155
+ self.scaling = self.dim_head**-0.5
156
+ self.scaling_values = self.dim_selection**-0.5
157
+
158
+ self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias)
159
+
160
+ if self.qk_rule:
161
+ self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias)
162
+ if self.q_compose:
163
+ self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
164
+ else:
165
+ self.value_q = nn.Linear(
166
+ dim_model, self.dim_selection * self.num_heads, bias=bias
167
+ )
168
+ else:
169
+ if self.q_compose:
170
+ self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
171
+ else:
172
+ self.value_q = nn.Linear(
173
+ dim_model, self.dim_selection * self.num_heads, bias=bias
174
+ )
175
+ if self.nonlinear:
176
+ self.score_network: nn.Module = nn.Sequential(
177
+ nn.Linear(
178
+ self.dim_selection + self.value_dim,
179
+ self.dim_selection,
180
+ bias=bias,
181
+ ),
182
+ nn.ReLU(),
183
+ nn.Linear(self.dim_selection, 1, bias=bias),
184
+ )
185
+ else:
186
+ self.score_network = nn.Linear(
187
+ self.dim_selection + self.value_dim, 1, bias=bias
188
+ )
189
+
190
+ self.causal = causal
191
+
192
+ # Properties specific to this attention mechanism
193
+ self.supports_attention_mask = True
194
+ self.supports_key_padding_mask = False
195
+
196
+ self._reset_parameters()
197
+
198
+ def _reset_parameters(self):
199
+ # NOTE: in_proj_container is already initialized
200
+
201
+ if self.qk_rule:
202
+ nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2))
203
+ nn.init.xavier_uniform_(self.value_q.weight, gain=1 / math.sqrt(2))
204
+ else:
205
+ nn.init.xavier_uniform_(self.value_q.weight)
206
+ if self.nonlinear:
207
+ nn.init.xavier_uniform_(self.score_network[0].weight)
208
+ nn.init.xavier_uniform_(self.score_network[2].weight)
209
+ else:
210
+ nn.init.xavier_uniform_(self.score_network.weight)
211
+
212
+ nn.init.xavier_uniform_(self.out_proj.weight)
213
+ if self.out_proj.bias is not None:
214
+ nn.init.constant_(self.out_proj.bias, 0.0)
215
+
216
+ def forward(
217
+ self,
218
+ q: Tensor,
219
+ k: Tensor,
220
+ v: Tensor,
221
+ att_mask: Optional[Tensor] = None,
222
+ *args,
223
+ **kwargs,
224
+ ) -> Tensor:
225
+ """
226
+ Input shape: Time x Batch x Channel
227
+
228
+ Args:
229
+ att_mask (ByteTensor, optional): typically used to
230
+ implement causal attention, where the mask prevents the
231
+ attention from looking forward in time (default: None).
232
+ """
233
+
234
+ B, Sq, E = q.shape
235
+ _, Sk, _ = k.shape
236
+
237
+ assert E == self.dim_model
238
+
239
+ # First define projected query/key/values
240
+ # We keep the projected and original tensors in flight,
241
+ # depending on the options the original values could be reused
242
+ q_unprojected = q
243
+ q, k, v = self.in_proj_container(query=q, key=k, value=v)
244
+ q *= self.scaling
245
+
246
+ # Init causal mask if needed, now that we know the context length
247
+ if self.causal and (
248
+ self._causal_mask is None or self._causal_mask.shape[0] != Sk
249
+ ):
250
+ self._causal_mask = AttentionMask.make_causal(Sq, Sq, device=q.device)
251
+
252
+ # Convenience, create an attention mask if a tensor was passed
253
+ # This sanitizes different mask types being passed, from now on it's additive
254
+ if isinstance(att_mask, torch.Tensor):
255
+ # By default we don't know of the causality, and a check would be expensive
256
+ att_mask_additive: Optional[AttentionMask] = (
257
+ AttentionMask.from_bool(att_mask)
258
+ if att_mask.dtype == torch.bool
259
+ else AttentionMask(att_mask, is_causal=False)
260
+ )
261
+ else:
262
+ att_mask_additive = None
263
+
264
+ # Handle the attention and key padding masks
265
+ if self._causal_mask is not None:
266
+ # Optionally add the causal mask
267
+ if att_mask_additive is not None:
268
+ att_mask_additive += self._causal_mask
269
+ else:
270
+ att_mask_additive = self._causal_mask
271
+
272
+ # Flatten the heads or the rules
273
+ q = (
274
+ q.view(B, Sq, self.num_heads, self.dim_head)
275
+ .movedim(2, 1)
276
+ .flatten(0, 1) # [B * num_heads, Sq, dim_head]
277
+ )
278
+ k = (
279
+ k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1)
280
+ ) # [B * num_heads, Sk, dim_head]
281
+ v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1)
282
+
283
+ # Compute the search: Softmax(QKt)
284
+ attn_weights = torch.bmm(q, k.transpose(1, 2)) # [B * self.num_heads, Sq, Sk]
285
+
286
+ if att_mask_additive is not None:
287
+ attn_weights += att_mask_additive.values
288
+
289
+ attn_weights = _softmax(attn_weights, causal=self.causal)
290
+
291
+ attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
292
+ attn_probs = self.dropout_module(attn_weights)
293
+
294
+ # Now compute the information retrieval
295
+ # keep all the heads in flight, we'll score the different possibilities
296
+ # - compute all the possible retrievals
297
+ v = v.view(B, 1, self.num_rules, Sk, self.value_dim)
298
+ attn_probs = attn_probs.unsqueeze(2)
299
+ attn = torch.matmul(attn_probs, v).view(
300
+ B, self.num_heads, self.num_rules, Sq, self.value_dim
301
+ )
302
+
303
+ attn = attn.movedim(3, 1) # [B, Sq, H, Rules, Values]
304
+
305
+ # - search the most appropriate retrieval among all the values
306
+ if self.q_compose:
307
+ v_q = self.value_q(q.transpose(0, 1)).view(
308
+ B, Sq, self.num_heads, 1, self.dim_selection
309
+ )
310
+ else:
311
+ v_q = self.value_q(q_unprojected).view(
312
+ B, Sq, self.num_heads, 1, self.dim_selection
313
+ )
314
+
315
+ if self.qk_rule:
316
+ v_q *= self.scaling_values
317
+ v_k = (
318
+ self.value_k(attn)
319
+ .view(B, Sq, self.num_heads, self.num_rules, self.dim_selection)
320
+ .transpose(4, 3)
321
+ .contiguous()
322
+ )
323
+ v_score = torch.matmul(v_q, v_k).view(
324
+ B, Sq, self.num_heads, self.num_rules, 1
325
+ )
326
+ else:
327
+ v_q = v_q.expand(-1, -1, -1, self.num_rules, -1)
328
+ v_in = torch.cat([attn, v_q], dim=-1)
329
+ v_score = self.score_network(v_in).view(
330
+ B, Sq, self.num_heads, self.num_rules, 1
331
+ )
332
+
333
+ v_score = F.softmax(v_score, dim=3)
334
+
335
+ # - extracted values are the original attention (inc. all the values) weighted by value score
336
+ attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim)
337
+
338
+ # Final attention projection, same as other mechanisms
339
+ attn = self.out_proj(attn)
340
+
341
+ return attn
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from enum import Enum
8
+
9
+ from .base import FeatureMap, FeatureMapConfig
10
+ from .softmax import NormDistribution, SMHyperbolic, SMOrf, SMReg
11
+
12
+
13
+ class FeatureMapType(str, Enum):
14
+ SMOrf = "sm_orf"
15
+ SMHyp = "sm_hyp"
16
+ SMReg = "sm_reg" # regularized softmax kernel
17
+
18
+
19
+ __all__ = [
20
+ "SMOrf",
21
+ "SMReg",
22
+ "SMHyperbolic",
23
+ "NormDistribution",
24
+ "FeatureMapConfig",
25
+ "FeatureMap",
26
+ ]
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (865 Bytes). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/base.cpython-311.pyc ADDED
Binary file (3.02 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/softmax.cpython-311.pyc ADDED
Binary file (11.7 kB). View file
 
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/base.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from abc import abstractmethod
8
+ from dataclasses import asdict, dataclass
9
+ from typing import Optional, Type, TypeVar
10
+
11
+ import torch
12
+
13
+ """
14
+ Feature maps allow for a given query or key to be encoded in a different space.
15
+ """
16
+
17
+ Self = TypeVar("Self", bound="FeatureMap")
18
+
19
+
20
+ @dataclass
21
+ class FeatureMapConfig:
22
+ name: str
23
+ dim_features: int
24
+ iter_before_redraw: Optional[int]
25
+ normalize_inputs: Optional[bool]
26
+ epsilon: Optional[float]
27
+
28
+
29
+ class FeatureMap(torch.nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim_features: int,
33
+ iter_before_redraw: Optional[int] = None,
34
+ normalize_inputs: bool = False,
35
+ epsilon: float = 1e-6,
36
+ ):
37
+ super().__init__()
38
+
39
+ self.dim_features = dim_features
40
+ self.dim_feature_map = dim_features
41
+
42
+ self.iter_before_redraw = iter_before_redraw
43
+ self.features: Optional[torch.Tensor] = None
44
+ self.epsilon = epsilon
45
+ self.normalize_inputs = normalize_inputs
46
+
47
+ self._iter_counter = 0
48
+
49
+ @abstractmethod
50
+ def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
51
+ raise NotImplementedError()
52
+
53
+ @classmethod
54
+ def from_config(cls: Type[Self], config: FeatureMapConfig) -> Self:
55
+ # Generate the class inputs from the config
56
+ fields = asdict(config)
57
+
58
+ # Skip all Nones so that default values are used
59
+ fields = {k: v for k, v in fields.items() if v is not None}
60
+
61
+ return cls(**fields)
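The base class above mostly handles bookkeeping (feature dimension, redraw counter, config round-trip); a concrete map only needs _get_feature_map and a forward pass. A toy subclass, purely illustrative and not part of the library:

import torch
from xformers.components.attention.feature_maps.base import FeatureMap

class RandomProjection(FeatureMap):
    @torch.no_grad()
    def _get_feature_map(self, dim_input, dim_features, device):
        # Plain Gaussian projection; the library subclasses use orthogonal blocks instead
        return torch.randn(dim_input, dim_features, device=device) / dim_features**0.5

    def forward(self, x):
        # Draw lazily, mirroring the redraw logic kept by the base class
        if self.features is None or self.features.device != x.device:
            self.features = self._get_feature_map(x.shape[-1], self.dim_feature_map, x.device)
        self._iter_counter += 1
        return x @ self.features

phi = RandomProjection(dim_features=16)
print(phi(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 16])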
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/softmax.py ADDED
@@ -0,0 +1,288 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import math
8
+ from enum import Enum, auto
9
+ from typing import Optional
10
+
11
+ import torch
12
+ from torch.autograd.profiler import record_function
13
+
14
+ from .base import FeatureMap
15
+
16
+ """
17
+ A set of feature maps which approximate the softmax kernel, as per the Performers_ paper.
18
+
19
+ _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
20
+ https://arxiv.org/pdf/2009.14794v1.pdf
21
+ """
22
+
23
+
24
+ class NormDistribution(Enum):
25
+ Xi = auto()
26
+ Uniform = auto()
27
+
28
+
29
+ class SoftMaxPositiveEstimators(FeatureMap):
30
+ def __init__(
31
+ self,
32
+ dim_features: int,
33
+ iter_before_redraw: Optional[int],
34
+ normalize_inputs: bool = False,
35
+ epsilon: float = 1e-6,
36
+ softmax_temp: float = -1,
37
+ ):
38
+ super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon)
39
+ self.softmax_temp = softmax_temp
40
+
41
+ # Handle the scaling from all kernels by √m.
42
+ # This normalizes for all the feature maps involved
43
+ self.h_scale = math.log(math.sqrt(self.dim_features))
44
+
45
+ def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
46
+ with record_function("feature_map::pre_scale"):
47
+ # Re-draw counting logic
48
+ if (
49
+ (
50
+ self.iter_before_redraw is not None
51
+ and self._iter_counter > self.iter_before_redraw
52
+ )
53
+ or self.features is None
54
+ or self.features.device != x.device
55
+ ):
56
+ # The feature map is actually using half the dimension, we'll concatenate + and - features
57
+ self._iter_counter = 1
58
+ self.features = self._get_feature_map(
59
+ x.shape[-1], self.dim_feature_map, x.device
60
+ )
61
+
62
+ features = self.features
63
+ assert features is not None
64
+
65
+ if features.dtype != x.dtype:
66
+ self.features = features.to(x.dtype)
67
+
68
+ self._iter_counter += 1
69
+
70
+ # Normalization / softmax
71
+ if self.softmax_temp < 0:
72
+ # A = exp(QK.t/√d), so each input will be scaled by √√d
73
+ self.softmax_temp = x.shape[-1] ** -0.25
74
+
75
+ x_scaled = x * self.softmax_temp
76
+
77
+ # Compute the scaling factors in logspace, applied from within the exponential
78
+ # - diminish possible exponential overflow
79
+ # - remove a multiply across the batch, replace by an addition
80
+ norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
81
+ self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon
82
+
83
+ if self.normalize_inputs:
84
+ # L0 normalize the exponential term, can be useful for numerical stability
85
+ # This ensures that features +- offset is below 1
86
+ self.offset -= norm_x_2.max(1, keepdim=True)[0]
87
+
88
+ # Return the scaled inputs, the rest depends on the kernel being used
89
+ return x_scaled
90
+
91
+ @staticmethod
92
+ @torch.no_grad()
93
+ def _get_random_ortho_matrix(
94
+ blocks: int,
95
+ dim: int,
96
+ device: torch.device,
97
+ norm_distribution: NormDistribution = NormDistribution.Uniform,
98
+ ) -> torch.Tensor:
99
+ r"""
100
+ Generate a random matrix whose rows are exactly orthonormal
101
+
102
+ "How to generate random matrices from the classical compact groups", Mezzadri, 2007
103
+ https://arxiv.org/pdf/math-ph/0609050v2.pdf
104
+
105
+ .. note: the typical qr decomposition does not give uniform results, qr decomposition is not
106
+ unique and the qr decomposition routines are biased towards numerical stability. See the above
107
+ paper for more information.
108
+
109
+ .. note: this does not follow the original implementation from the Performers authors.
110
+ see docs/assets/kde plots to visualize the impact of using the R signs to correct Q
111
+ """
112
+
113
+ H = torch.randn((blocks, dim, dim), device=device, requires_grad=False)
114
+
115
+ # Randomly scale the norms of the features, Xi distributed
116
+ if norm_distribution == NormDistribution.Xi:
117
+ # NOTE: This averages to sqrt(d)
118
+ norms = torch.sqrt(torch.einsum("...d,...d->...", H, H))
119
+
120
+ Q, R = torch.linalg.qr(H)
121
+ Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=1, dim2=2))) @ Q
122
+
123
+ # Normalize if need be. Uniform NormDistribution does nothing, Q is already orthonormal
124
+ if norm_distribution == NormDistribution.Xi:
125
+ return torch.diag_embed(norms) @ Q
126
+
127
+ return Q
128
+
129
+
130
+ class SMOrf(SoftMaxPositiveEstimators):
131
+ """
132
+ "Positive random orthogonal features" softmax estimator,
133
+ SM_ort^m+, as proposed in the Performers_ paper, Lemma 1.
134
+
135
+ _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
136
+ https://arxiv.org/pdf/2009.14794v1.pdf
137
+ """
138
+
139
+ @torch.no_grad()
140
+ def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
141
+ """
142
+ Generate the projection matrix onto the random features
143
+
144
+ .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
145
+ and not uniformly random.
146
+ """
147
+
148
+ # Get per block random unitary matrices.
149
+ # We need enough of them to project the whole input dimension, regardless of the
150
+ # requested dimension of the features
151
+ features = self._get_random_ortho_matrix(
152
+ math.ceil(dim_input / dim_features),
153
+ dim_features,
154
+ norm_distribution=NormDistribution.Xi,
155
+ device=device,
156
+ )
157
+
158
+ return features.flatten(0, 1)[:dim_input]
159
+
160
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
161
+ # Softmax-dimension related scaling, shared for all kernels
162
+ x_scaled = super().pre_scale(x)
163
+ assert self.features is not None
164
+
165
+ # Project onto the random feature map.
166
+ x_scaled = x_scaled @ self.features
167
+ return torch.exp(x_scaled + self.offset)
168
+
169
+
170
+ class SMHyperbolic(SoftMaxPositiveEstimators):
171
+ """
172
+ "Positive random features hyperbolic" estimator, SMHyp+,
173
+ as proposed in the Performers_ paper, Lemma 1.
174
+
175
+ _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
176
+ https://arxiv.org/pdf/2009.14794v1.pdf
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ dim_features: int,
182
+ iter_before_redraw: Optional[int],
183
+ normalize_inputs: bool = False,
184
+ epsilon: float = 1e-6,
185
+ softmax_temp: float = -1,
186
+ ):
187
+ super().__init__(
188
+ dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
189
+ )
190
+
191
+ assert (
192
+ dim_features % 2 == 0
193
+ ), "The feature dimension needs to be even with this kernel"
194
+ self.dim_feature_map = self.dim_features // 2
195
+
196
+ @torch.no_grad()
197
+ def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
198
+ """
199
+ Generate the projection matrix onto the random features
200
+
201
+ .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
202
+ and not uniformly random.
203
+ """
204
+
205
+ # Get per block random unitary matrices.
206
+ # We need enough of them to project the whole input dimension, regardless of the
207
+ # requested dimension of the features
208
+ features = self._get_random_ortho_matrix(
209
+ math.ceil(dim_input / dim_features),
210
+ dim_features,
211
+ norm_distribution=NormDistribution.Xi,
212
+ device=device,
213
+ )
214
+
215
+ return features.flatten(0, 1)[:dim_input]
216
+
217
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
218
+ # Softmax-dimension related scaling, shared for all kernels
219
+ x_scaled = super().pre_scale(x)
220
+
221
+ # Project onto the random feature map, concatenate both + and - results
222
+ # This follows Lemma 1 in the original Performers Paper to best approximate a
223
+ # softmax kernel (cosh representation)
224
+ x_scaled = x_scaled @ self.features
225
+ return torch.cat(
226
+ [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
227
+ dim=-1,
228
+ )
229
+
230
+
231
+ class SMReg(SoftMaxPositiveEstimators):
232
+ """
233
+ "Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper.
234
+
235
+ _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
236
+ https://arxiv.org/pdf/2009.14794v1.pdf
237
+ """
238
+
239
+ def __init__(
240
+ self,
241
+ dim_features: int,
242
+ iter_before_redraw: Optional[int],
243
+ normalize_inputs: bool = False,
244
+ epsilon: float = 1e-6,
245
+ softmax_temp: float = -1,
246
+ ):
247
+ super().__init__(
248
+ dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
249
+ )
250
+
251
+ assert (
252
+ dim_features % 2 == 0
253
+ ), "The feature dimension needs to be even with this kernel"
254
+ self.dim_feature_map = self.dim_features // 2
255
+
256
+ @torch.no_grad()
257
+ def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
258
+ """
259
+ Generate the projection matrix onto the random features
260
+
261
+ .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
262
+ and not uniformly random.
263
+ """
264
+
265
+ # Get per block random unitary matrices.
266
+ # We need enough of them to project the whole input dimension, regardless of the
267
+ # requested dimension of the features
268
+ features = self._get_random_ortho_matrix(
269
+ math.ceil(dim_input / dim_features),
270
+ dim_features,
271
+ norm_distribution=NormDistribution.Uniform,
272
+ device=device,
273
+ ).flatten(0, 1)
274
+ norms = math.sqrt(dim_input) * torch.ones(features.shape[0], device=device)
275
+ return (torch.diag(norms) @ features)[:dim_input]
276
+
277
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
278
+ # Softmax-dimension related scaling, shared for all kernels
279
+ x_scaled = super().pre_scale(x)
280
+
281
+ # Project onto the random feature map, concatenate both + and - results
282
+ # This follows Lemma 1 in the original Performers Paper to best approximate a
283
+ # softmax kernel (cosh representation + sample regularization)
284
+ x_scaled = x_scaled @ self.features
285
+ return torch.cat(
286
+ [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
287
+ dim=-1,
288
+ )
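These feature maps are what make Performer-style linear attention possible: with positive features φ, softmax(QKᵀ)V can be approximated by φ(Q)(φ(K)ᵀV) with a per-query normalizer, so the cost is linear in sequence length. A hedged sketch using the SMOrf map defined above (illustrative sizes; the actual FAVOR attention in xformers wires this up with more care):

import torch
from xformers.components.attention.feature_maps import SMOrf

B, S, D, M = 2, 128, 64, 32
q, k, v = (torch.randn(B, S, D) for _ in range(3))

feature_map = SMOrf(dim_features=M, iter_before_redraw=None)
phi_q, phi_k = feature_map(q), feature_map(k)    # positive features, (B, S, M), same random draw

kv = torch.einsum("bsm,bsd->bmd", phi_k, v)                          # (B, M, D), linear in S
normalizer = phi_q @ phi_k.sum(dim=1, keepdim=True).transpose(1, 2)  # (B, S, 1)
approx = (phi_q @ kv) / normalizer.clamp_min(1e-6)                   # ≈ softmax attention output
print(approx.shape)                              # torch.Size([2, 128, 64])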
.venv/lib/python3.11/site-packages/xformers/components/attention/global_tokens.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Optional, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from xformers.components.attention import (
14
+ Attention,
15
+ AttentionConfig,
16
+ AttentionMask,
17
+ maybe_sparsify,
18
+ register_attention,
19
+ sparsify,
20
+ )
21
+ from xformers.components.attention.attention_patterns import (
22
+ causal_1d_pattern,
23
+ global_token_pattern,
24
+ )
25
+ from xformers.components.attention.core import scaled_dot_product_attention
26
+
27
+
28
+ @dataclass
29
+ class GlobalAttentionConfig(AttentionConfig):
30
+ attention_query_mask: torch.Tensor # Mark the queries which have global attention
31
+ causal: Optional[bool]
32
+ force_sparsity: Optional[bool]
33
+
34
+
35
+ @register_attention("global", GlobalAttentionConfig)
36
+ class GlobalAttention(Attention):
37
+ def __init__(
38
+ self,
39
+ dropout: float,
40
+ attention_query_mask: torch.Tensor,
41
+ causal: bool = False,
42
+ force_sparsity: bool = False,
43
+ *_,
44
+ **__,
45
+ ):
46
+ r"""
47
+ Global attention, as proposed for instance in BigBird_ or Longformer_.
48
+
49
+ Global means, in this case, that the queries positively labelled in the ```attention_query_mask``` can attend
50
+ to all the other queries. The queries negatively labelled in the ```attention_query_mask``` cannot attend to
51
+ any other query.
52
+
53
+ This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.
54
+
55
+ Args:
56
+ dropout (float): probability of an element to be zeroed
57
+ attention_query_mask (torch.Tensor): if true, this query can attend to all the others
58
+
59
+ """
60
+ super().__init__()
61
+
62
+ assert attention_query_mask.dtype == torch.bool, "A boolean mask is expected"
63
+ assert (
64
+ attention_query_mask.shape[1] == 1
65
+ and attention_query_mask.shape[0] > attention_query_mask.shape[1]
66
+ ), "A N x 1 query mask is expected"
67
+
68
+ self.attn_drop = nn.Dropout(dropout, inplace=False)
69
+ self.attention_mask = global_token_pattern(attention_query_mask[:, 0])
70
+ self.force_sparsity = force_sparsity
71
+
72
+ if causal:
73
+ self.attention_mask &= causal_1d_pattern(attention_query_mask.shape[1])
74
+
75
+ self.attention_mask = (
76
+ sparsify(self.attention_mask)
77
+ if self.force_sparsity
78
+ else maybe_sparsify(self.attention_mask)
79
+ )
80
+
81
+ # Properties specific to this attention mechanism
82
+ self.requires_same_k_q_dimensions = True
83
+ self.supports_attention_mask = False
84
+ self.supports_key_padding_mask = False
85
+
86
+ def forward(
87
+ self,
88
+ q: torch.Tensor,
89
+ k: torch.Tensor,
90
+ v: torch.Tensor,
91
+ att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
92
+ *_,
93
+ **__,
94
+ ):
95
+ # Make sure that the mask is on the right device
96
+ if self.attention_mask.device != q.device:
97
+ self.attention_mask = self.attention_mask.to(q.device)
98
+
99
+ # Mask-aware attention
100
+ if att_mask is not None:
101
+ if att_mask.dtype == torch.bool and isinstance(
102
+ self.attention_mask, AttentionMask
103
+ ):
104
+ if not isinstance(att_mask, AttentionMask):
105
+ att_mask = AttentionMask.from_bool(att_mask)
106
+ mask = self.attention_mask + att_mask
107
+ else:
108
+ mask = self.attention_mask & att_mask
109
+ else:
110
+ mask = self.attention_mask
111
+
112
+ # Handle q/k/v which would not fit the mask
113
+ seq_len = q.shape[-2]
114
+ q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))
115
+
116
+ # Normal attention with the global tokens mask
117
+ att = scaled_dot_product_attention(
118
+ q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
119
+ )
120
+
121
+ # Take into account a hypothetical padding
122
+ return att[:, :seq_len, :]
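A hedged usage sketch for the class above, with illustrative sizes: a 16-token sequence where the first four queries are marked global (the import path follows the file location in this diff):

import torch
from xformers.components.attention.global_tokens import GlobalAttention

seq_len, dim = 16, 32
query_mask = torch.zeros(seq_len, 1, dtype=torch.bool)
query_mask[:4, 0] = True                      # tokens 0..3 attend globally

attention = GlobalAttention(dropout=0.0, attention_query_mask=query_mask)
q = k = v = torch.randn(1, seq_len, dim)
out = attention(q, k, v)
print(out.shape)                              # torch.Size([1, 16, 32])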
.venv/lib/python3.11/site-packages/xformers/components/attention/linformer.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from xformers.components.attention import Attention, AttentionConfig, register_attention
14
+ from xformers.components.attention.core import scaled_dot_product_attention
15
+
16
+
17
+ @dataclass
18
+ class LinformerSelfAttentionConfig(AttentionConfig):
19
+ seq_len: int # dimension of the input sequence
20
+ k: Optional[int] # dimension of the internal space
21
+
22
+
23
+ @register_attention("linformer", LinformerSelfAttentionConfig)
24
+ class LinformerAttention(Attention):
25
+ def __init__(
26
+ self, dropout: float, seq_len: int, k: Optional[int] = None, *args, **kwargs
27
+ ):
28
+ """
29
+ Linformer attention mechanism,
30
+ from `Linformer: Self-Attention with Linear Complexity`_, Wang et al (2020).
31
+ The original notation is kept as is.
32
+
33
+ .. _`Linformer: Self-Attention with Linear Complexity` : https://arxiv.org/abs/2006.04768v2
34
+ """
35
+ super().__init__()
36
+
37
+ if k is None:
38
+ k = seq_len // 4
39
+
40
+ self.k = k
41
+ self.E = nn.Linear(seq_len, k, bias=False)
42
+ self.F = nn.Linear(seq_len, k, bias=False)
43
+ self.attn_drop = nn.Dropout(dropout, inplace=False)
44
+ self.seq_len = seq_len
45
+
46
+ # MHA related flags:
47
+ # kq need to have the same dimension
48
+ self.requires_same_k_q_dimensions = True
49
+
50
+ # This attention does not support attention masks
51
+ self.supports_attention_mask = False
52
+
53
+ def forward(
54
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
55
+ ):
56
+ # Handle a smaller dimension than expected
57
+ padding = 0
58
+ if q.shape[1] < self.seq_len:
59
+ padding = self.seq_len - q.shape[1]
60
+ pad_dims = (0, 0, 0, padding)
61
+ q = torch.nn.functional.pad(q, pad_dims)
62
+ k = torch.nn.functional.pad(k, pad_dims)
63
+ v = torch.nn.functional.pad(v, pad_dims)
64
+
65
+ k_projected = self.E(k.transpose(-2, -1)).transpose(-2, -1)
66
+ v_projected = self.F(v.transpose(-2, -1)).transpose(-2, -1)
67
+
68
+ y = scaled_dot_product_attention(
69
+ q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop
70
+ )
71
+
72
+ y = self.attn_drop(y)
73
+
74
+ return y[:, :-padding, :] if padding > 0 else y
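A hedged usage sketch for the class above: keys and values are projected from the full sequence length down to k, so the attention matrix is seq_len x k instead of seq_len x seq_len (sizes are illustrative):

import torch
from xformers.components.attention.linformer import LinformerAttention

attention = LinformerAttention(dropout=0.0, seq_len=256, k=64)
q = k = v = torch.randn(2, 256, 32)
out = attention(q, k, v)
print(out.shape)                              # torch.Size([2, 256, 32])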
.venv/lib/python3.11/site-packages/xformers/components/attention/ortho.py ADDED
@@ -0,0 +1,324 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import logging
8
+ from dataclasses import dataclass
9
+ from enum import Enum
10
+ from typing import Optional, Union
11
+
12
+ import torch
13
+ import torch.autograd.profiler as profiler
14
+ import torch.nn as nn
15
+ import torch.nn.functional as Fn
16
+
17
+ from xformers.components.attention import (
18
+ Attention,
19
+ AttentionConfig,
20
+ AttentionMask,
21
+ register_attention,
22
+ )
23
+ from xformers.components.attention.core import (
24
+ scaled_dot_product_attention,
25
+ scaled_query_key_softmax,
26
+ )
27
+
28
+ logger = logging.getLogger("xformers")
29
+
30
+
31
+ class LandmarkSelection(str, Enum):
32
+ Orthogonal = "orthogonal"
33
+ KMeans = "kmeans"
34
+ KMeans_Spherical = "kmeans_spherical"
35
+ Random = "random"
36
+
37
+
38
+ @dataclass
39
+ class OrthoformerAttentionConfig(AttentionConfig):
40
+ """
41
+ num_landmarks Number of landmarks to use for softmax approximation.
42
+ subsample_fraction Percentage of q_samples matrix to sample per iteration
43
+ landmark_selection Landmark selection strategy
44
+ """
45
+
46
+ num_landmarks: Optional[int]
47
+ subsample_fraction: Optional[float]
48
+ landmark_selection: Optional[LandmarkSelection]
49
+
50
+
51
+ @register_attention("orthoformer", OrthoformerAttentionConfig)
52
+ class OrthoFormerAttention(Attention):
53
+ def __init__(
54
+ self,
55
+ dropout: float,
56
+ num_landmarks: int = 32,
57
+ subsample_fraction: float = 1.0,
58
+ landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal,
59
+ *args,
60
+ **kwargs,
61
+ ):
62
+ """
63
+ Orthoformer_ attention mechanism.
64
+ ::
65
+
66
+ "Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers"
67
+ Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F., Feichtenhofer,
68
+ C., Vedaldi, A., Henriques, J. (2021)
69
+
70
+ Reference codebase: https://github.com/facebookresearch/Motionformer
71
+
72
+ .. _Orthoformer: https://arxiv.org/abs/2106.05392
73
+
74
+ """
75
+ super().__init__()
76
+
77
+ self.num_landmarks = num_landmarks
78
+ self.attn_drop = nn.Dropout(dropout)
79
+ self.subsample_fraction = subsample_fraction
80
+ self.landmark_selection = landmark_selection
81
+
82
+ # Properties specific to this attention mechanism
83
+ self.supports_attention_mask = True
84
+ self.supports_key_padding_mask = False
85
+
86
+ def forward(
87
+ self,
88
+ q: torch.Tensor,
89
+ k: torch.Tensor,
90
+ v: torch.Tensor,
91
+ att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
92
+ *args,
93
+ **kwargs,
94
+ ):
95
+ N = k.shape[1]
96
+
97
+ if self.num_landmarks == N:
98
+ # Default attention
99
+ x = scaled_dot_product_attention(q, k, v, att_mask)
100
+ else:
101
+ with torch.no_grad(), profiler.record_function("select landmarks"):
102
+ if self.landmark_selection == LandmarkSelection.Orthogonal:
103
+ landmarks = self._compute_orthogonal_landmarks(q)
104
+ elif self.landmark_selection == LandmarkSelection.Random:
105
+ half_L = self.num_landmarks // 2
106
+ landmarks_q = q[:, torch.randint(q.size(1), (half_L,)), :]
107
+ landmarks_k = k[:, torch.randint(k.size(1), (half_L,)), :]
108
+ landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2)
109
+ elif self.landmark_selection == LandmarkSelection.KMeans:
110
+ landmarks = self._cluster_landmarks(q)
111
+ elif self.landmark_selection == LandmarkSelection.KMeans_Spherical:
112
+ landmarks = self._cluster_landmarks(q, spherical=True)
113
+
114
+ if att_mask is not None:
115
+ logger.warning(
116
+ "Orthoformer: attention mask passed alongside with using landmarks to reduce dimensions. \
117
+ The two are typically not compatible"
118
+ )
119
+ # FIXME: Should we still accept a mask in that case ?
120
+ att_mask = None
121
+
122
+ # pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems
123
+ # like it could be uninitialized.
124
+ kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask)
125
+ # pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems
126
+ # like it could be uninitialized.
127
+ kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask)
128
+ x = torch.matmul(kernel_1, torch.matmul(kernel_2, v))
129
+ x = self.attn_drop(x)
130
+ return x
131
+
132
+ def _cluster_landmarks(
133
+ self,
134
+ q: torch.Tensor,
135
+ spherical: bool = False,
136
+ num_iters: int = 6,
137
+ ) -> torch.Tensor:
138
+ """
139
+ Construct a set of landmarks by k-means clustering of the (optionally sub-sampled) queries.
140
+ The spherical variant clusters unit-normalized queries with cosine similarity.
141
+ Returns landmarks with shape (B, M, D).
142
+ """
143
+
144
+ num_landmarks = min(self.num_landmarks, q.shape[1])
145
+
146
+ if self.subsample_fraction < 1.0:
147
+ num_samples = max(
148
+ int(self.subsample_fraction * q.size(-2)), num_landmarks
149
+ ) # Need at least M/2 samples of queries and keys
150
+ q_samples = q[:, torch.randint(q.size(-2), (num_samples,)), :] # (B, N, D)
151
+ else:
152
+ q_samples = q # (B, N, D)
153
+
154
+ if spherical:
155
+ q_samples_normalized = Fn.normalize(
156
+ q_samples, p=2, dim=-1
157
+ ) # may need to change default eps to eps=1e-8 for mixed precision compatibility
158
+ landmarks = self._kmeans_spherical(
159
+ q_samples_normalized, num_landmarks, num_iters
160
+ )
161
+ else:
162
+ landmarks = self._kmeans(q_samples, num_landmarks, num_iters)
163
+ return landmarks # (B, M, D)
164
+
165
+ def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10):
166
+ """
167
+ Arguments:
168
+ x: (B, N, D)
169
+ K: number of clusters
170
+ num_iters: the number of kmeans updates
171
+ """
172
+
173
+ B, N, D = x.size()
174
+ assert K <= N, f"{K} > {N}"
175
+
176
+ c = x[
177
+ :, torch.randperm(N, device=x.device)[:K], :
178
+ ].clone() # initialisation for the centroids
179
+
180
+ with profiler.record_function("kmeans"):
181
+ x_i = x.view(B, N, 1, D)
182
+ c_j = c.view(B, 1, K, D)
183
+ counts = c.new_zeros(B, K)
184
+ ones = x.new_ones((B, N))
185
+
186
+ for _ in range(num_iters):
187
+ # E step: assign points to the nearest cluster
188
+ D_ij = ((x_i - c_j) ** 2).sum(-1) # (B, N, K) squared distances
189
+ cl = D_ij.argmin(
190
+ dim=-1, keepdim=True
191
+ ).long() # (B, N, 1) index of point to nearest cluster
192
+
193
+ # M step: update the centroids
194
+ c.zero_()
195
+ c.scatter_add_(-2, cl.repeat(1, 1, D), x) # sum of points per cluster
196
+ counts.fill_(1e-6) # avoid div0
197
+ counts.scatter_add_(
198
+ -1, cl.squeeze(-1), ones
199
+ ) # number of points per cluster
200
+ c.divide_(counts.unsqueeze(-1)) # compute the average
201
+
202
+ return c
203
+
204
+ def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters=10):
205
+ """
206
+ Arguments:
207
+ x: (B, N, D)
208
+ """
209
+ B, N, D = x.size()
210
+ assert K <= N, f"{K} > {N}"
211
+
212
+ # initialisation for the centroids
213
+ c = x[:, torch.randperm(N, device=x.device)[:K], :].clone()
214
+
215
+ with profiler.record_function("kmeans_spherical"):
216
+ counts = c.new_zeros(B, K)
217
+ ones = x.new_ones((B, N))
218
+
219
+ for _ in range(num_iters):
220
+ # E step: assign points to the nearest cluster
221
+ D_ij = torch.matmul(
222
+ x, c.transpose(-2, -1)
223
+ ) # (B, N, K) cosine similarity
224
+ cl = D_ij.argmax(
225
+ dim=-1, keepdim=True
226
+ ).long() # (B, N, 1) index of point to nearest cluster
227
+
228
+ # M step: update the centroids
229
+ c.zero_()
230
+ c.scatter_add_(-2, cl.repeat(1, 1, D), x) # sum of points per cluster
231
+ counts.fill_(1e-6) # avoid div0
232
+ counts.scatter_add_(
233
+ -1, cl.squeeze(-1), ones
234
+ ) # number of points per cluster
235
+ c.divide_(counts.unsqueeze(-1)) # compute the average
236
+ c = Fn.normalize(c, p=2, dim=-1) # renormalise
237
+ return c
238
+
239
+ def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
240
+ """
241
+ Construct set of landmarks by recursively selecting new landmarks
242
+ that are maximally orthogonal to the existing set.
243
+ Returns near orthogonal landmarks with shape (B, M, D).
244
+ """
245
+
246
+ if self.subsample_fraction < 1.0:
247
+ # Need at least M samples of queries
248
+ num_samples = max(
249
+ int(self.subsample_fraction * q.size(-2)), self.num_landmarks
250
+ )
251
+ q_samples = q[
252
+ :, torch.randint(q.size(-2), (num_samples,), device=q.device), :
253
+ ]
254
+ else:
255
+ # (B, N, D)
256
+ q_samples = q
257
+
258
+ # may need to change default eps to eps=1e-8 for mixed precision compatibility
259
+ q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1)
260
+ B, N, D = q_samples_normalized.shape
261
+
262
+ selected_mask = torch.zeros((B, N, 1), device=q_samples_normalized.device)
263
+ landmark_mask = torch.ones(
264
+ (B, 1, 1), dtype=selected_mask.dtype, device=q_samples_normalized.device
265
+ )
266
+
267
+ #  Get initial random landmark
268
+ random_idx = torch.randint(
269
+ q_samples_normalized.size(-2), (B, 1, 1), device=q_samples_normalized.device
270
+ )
271
+ selected_mask.scatter_(-2, random_idx, landmark_mask)
272
+
273
+ #  Selected landmarks
274
+ selected_landmarks = torch.empty(
275
+ (B, self.num_landmarks, D),
276
+ device=q_samples_normalized.device,
277
+ dtype=q_samples_normalized.dtype,
278
+ )
279
+ selected_landmarks[:, 0, :] = q_samples_normalized[
280
+ torch.arange(q_samples_normalized.size(0)), random_idx.view(-1), :
281
+ ].view(B, D)
282
+
283
+ # Store computed cosine similarities
284
+ cos_sims = torch.empty(
285
+ (B, N, self.num_landmarks),
286
+ device=q_samples_normalized.device,
287
+ dtype=q_samples_normalized.dtype,
288
+ )
289
+
290
+ for M in range(1, self.num_landmarks):
291
+ with profiler.record_function("find new landmark"):
292
+ #  Calculate absolute cosine similarity between selected and unselected landmarks
293
+ # (B, N, D) * (B, D) -> (B, N)
294
+ cos_sims[:, :, M - 1] = torch.einsum(
295
+ "b n d, b d -> b n",
296
+ q_samples_normalized,
297
+ selected_landmarks[:, M - 1, :],
298
+ ).abs()
299
+
300
+ # (B, N, M) cosine similarities of current set of landmarks wrt all queries and keys
301
+ cos_sim_set = cos_sims[:, :, :M]
302
+
303
+ #  Get orthogonal landmark: landmark with smallest absolute cosine similarity:
304
+ # set cosine similarity for already selected landmarks to > 1
305
+ cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10
306
+
307
+ # (B,) - want max for non
308
+ selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1)
309
+
310
+ #  Add most orthogonal landmark to selected landmarks:
311
+ selected_landmarks[:, M, :] = q_samples_normalized[
312
+ torch.arange(q_samples_normalized.size(0)), selected_landmark_idx, :
313
+ ].view(B, D)
314
+
315
+ #  Removed selected indices from non-selected mask:
316
+ selected_mask.scatter_(
317
+ -2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask
318
+ )
319
+
320
+ # (B, M, D)
321
+ landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape(
322
+ B, -1, D
323
+ )
324
+ return landmarks # (B, M, D)
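A hedged usage sketch for the class above: attention is routed through a small set of (near-)orthogonal landmark queries instead of the full sequence (sizes are illustrative):

import torch
from xformers.components.attention.ortho import LandmarkSelection, OrthoFormerAttention

attention = OrthoFormerAttention(
    dropout=0.0,
    num_landmarks=16,
    landmark_selection=LandmarkSelection.Orthogonal,
)
q = k = v = torch.randn(2, 128, 32)
out = attention(q, k, v)
print(out.shape)                              # torch.Size([2, 128, 32])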
.venv/lib/python3.11/site-packages/xformers/components/attention/pooling.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import math
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from xformers.components.attention import Attention, AttentionConfig, register_attention
15
+
16
+
17
+ @dataclass
18
+ class PoolingAttentionConfig(AttentionConfig):
19
+ pool_size: int # size of the pooling window
20
+ stride: Optional[int] # stride of the pooling operator
21
+ padding: Optional[int]
22
+
23
+
24
+ @register_attention("pooling", PoolingAttentionConfig)
25
+ class Pooling(Attention):
26
+ def __init__(
27
+ self,
28
+ pool_size: int = 3,
29
+ stride: int = 1,
30
+ padding: Optional[int] = None,
31
+ *_,
32
+ **__,
33
+ ):
34
+ """
35
+ Pooling token mixing mechanism, as proposed in
36
+ `Metaformer is actually what you need for vision`_, Yu et al (2021).
37
+
38
+ The original notation is kept as is.
39
+
40
+ .. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf
41
+ """
42
+ super().__init__()
43
+
44
+ padding = padding if padding is not None else pool_size // 2
45
+ self.pool = nn.AvgPool2d(
46
+ pool_size,
47
+ stride=stride,
48
+ padding=padding, # use the padding resolved above (defaults to pool_size // 2)
49
+ count_include_pad=False,
50
+ )
51
+
52
+ # MHA related flags:
53
+ # kq need to have the same dimension
54
+ self.requires_same_k_q_dimensions = False
55
+
56
+ # This attention does not support attention masks
57
+ self.supports_attention_mask = False
58
+
59
+ # This "attention" (token mixing) skips the multihead attention altogether
60
+ self.requires_skip_multi_head = True
61
+ self.requires_input_projection = False
62
+
63
+ # This operator does not really handle q,k,v
64
+ self.requires_same_k_q_dimensions = True
65
+
66
+ # This attention requires the 2d structure out of the context,
67
+ # implicitly assumed to be a squared length
68
+ self.requires_squared_context = True
69
+
70
+ def forward(self, q: torch.Tensor, *_, **__):
71
+ # Expose the 2D token structure
72
+ B, HW, C = q.shape
73
+ H = int(math.sqrt(HW))
74
+ assert H * H == HW
75
+
76
+ q = q.transpose(-2, -1).reshape(B, C, H, H)
77
+
78
+ # 2D pool
79
+ x_pool = self.pool(q) - q # compensate for the residual path
80
+
81
+ # Get back to B HW C
82
+ return x_pool.flatten(2, 3).transpose(-2, -1)
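A hedged usage sketch for the token mixer above: it expects a flattened square token grid (H * H tokens) and needs no keys, values or projections (sizes are illustrative):

import torch
from xformers.components.attention.pooling import Pooling

mixer = Pooling(pool_size=3)
tokens = torch.randn(2, 14 * 14, 64)          # a 14 x 14 patch grid with 64 channels
out = mixer(tokens)
print(out.shape)                              # torch.Size([2, 196, 64])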
.venv/lib/python3.11/site-packages/xformers/components/attention/sparsity_config.py ADDED
@@ -0,0 +1,812 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ #
3
+ # This source code is licensed under the BSD license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ The code has been adapted from DeepSpeed
7
+ (https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/sparse_attention/sparsity_config.py)
8
+ """
9
+
10
+ import random
11
+
12
+ import torch
13
+
14
+
15
+ class SparsityConfig:
16
+ """Abstract Configuration class to store `sparsity configuration of a self attention layer`.
17
+ It contains shared property of different block-sparse sparsity patterns. However, each class
18
+ needs to extend it based on required property and functionality.
19
+ """
20
+
21
+ def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
22
+ """Initialize the Sparsity Pattern Config.
23
+ Arguments:
24
+ num_heads: required: an integer determining number of attention heads of the layer.
25
+ block_size: optional: an integer determining the block size. Current implementation of
26
+ sparse self-attention is based on blocked sparse matrices. In which this parameter
27
+ defines size of such blocks, `Block X Block`.
28
+ different_layout_per_head: optional: a boolean determining if each head should be
29
+ assigned a different sparsity layout; default is false and this will be satisfied
30
+ based on availability.
31
+ """
32
+
33
+ self.num_heads = num_heads
34
+ self.block_size = block_size
35
+ self.different_layout_per_head = different_layout_per_head
36
+ self.num_layout_heads = num_heads if different_layout_per_head else 1
37
+
38
+ def setup_layout(self, seq_len):
39
+ """Create layout tensor for the given sequence length
40
+ Arguments:
41
+ seq_len: required: an integer determining the underlying sequence length.
42
+ Return:
43
+ layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout
44
+ of all head; initialized with zero
45
+ """
46
+
47
+ if seq_len % self.block_size != 0:
48
+ raise ValueError(
49
+ f"Sequence Length, {seq_len}, needs to be divisible by Block size {self.block_size}!"
50
+ )
51
+ num_blocks = seq_len // self.block_size
52
+ # TODO Currently we allocate layout per head; needs to be updated if heads share a single layout.
53
+ layout = torch.zeros(
54
+ (self.num_heads, num_blocks, num_blocks), dtype=torch.int64
55
+ )
56
+ return layout
57
+
58
+ def check_and_propagate_first_head_layout(self, layout):
59
+ """If all heads require same sparsity layout, it propagate first head layout to all heads
60
+ Arguments:
61
+ layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
62
+ sparsity layout of all head; may not be completely set at this step
63
+ Return:
64
+ layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
65
+ layout of all head
66
+ """
67
+
68
+ if not self.different_layout_per_head:
69
+ layout[1 : self.num_heads, :, :] = layout[0, :, :]
70
+ return layout
71
+
72
+
73
+ class DenseSparsityConfig(SparsityConfig):
74
+ """Configuration class to store `Dense` configuration.
75
+ In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and
76
+ comprehension.
77
+ """
78
+
79
+ def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
80
+ """Initialize the Dense Sparsity Pattern Config.
81
+ In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison
82
+ and comprehension.
83
+ Arguments:
84
+ num_heads: required: an integer determining number of attention heads of the layer.
85
+ block_size: optional: an integer determining the block size. Current implementation of
86
+ sparse self-attention is based on blocked sparse matrices. In which this parameter
87
+ defines size of such blocks, `Block X Block`.
88
+ different_layout_per_head: optional: this is just for the sake of consistency with
89
+ other sparsity formats; can ignore it for DenseSparsityConfig
90
+ """
91
+
92
+ super().__init__(num_heads, block_size, different_layout_per_head)
93
+
94
+ def make_layout(self, seq_len):
95
+ """Set 1 to all blocks of the layout, meaning the pattern is dense; not sparse.
96
+ Arguments:
97
+ seq_len: required: an integer determining the underlying sequence length;
98
+ must be <= max sequence length
99
+ Return:
100
+ layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
101
+ layout of all head; for dense everything is 1
102
+ """
103
+
104
+ layout = self.setup_layout(seq_len)
105
+ layout[:, :, :] = 1
106
+ return layout
107
+
108
+
109
+ class FixedSparsityConfig(SparsityConfig):
110
+ """Configuration class to store `Fixed` sparsity configuration.
111
+ For more details about this sparsity config, please see `Generative Modeling with
112
+ Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
113
+ This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ num_heads,
119
+ block_size=16,
120
+ different_layout_per_head=False,
121
+ num_local_blocks=4,
122
+ num_global_blocks=1,
123
+ attention="bidirectional",
124
+ horizontal_global_attention=False,
125
+ num_different_global_patterns=1,
126
+ ):
127
+ """Initialize `Fixed` Sparsity Pattern Config.
128
+ For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
129
+ Arguments:
130
+ num_heads: required: an integer determining number of attention heads of the layer.
131
+ block_size: optional: an integer determining the block size. Current implementation of
132
+ sparse self-attention is based on blocked sparse matrices. In which this parameter
133
+ defines size of such blocks, `Block X Block`.
134
+ different_layout_per_head: optional: a boolean determining if each head should be
135
+ assigned a different sparsity layout; default is false and this will be satisfied
136
+ based on availability.
137
+ num_local_blocks: optional: an integer determining the number of blocks in local attention
138
+ window.
139
+ num_global_blocks: optional: an integer determining how many consecutive blocks in a local
140
+ window is used as the representative of the window for global attention.
141
+ attention: optional: a string determining attention type. Attention can be `unidirectional`,
142
+ such as autoregressive models, in which tokens attend only to tokens appear before them
143
+ in the context. Considering that, the upper triangular of attention matrix is empty as
144
+ above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
145
+ any other tokens before or after them. Then, the upper triangular part of the attention
146
+ matrix is mirror of the lower triangular in the above figure.
147
+ horizontal_global_attention: optional: a boolean determining if blocks that are global
148
+ representative of a local window, also attend to all other blocks. This is valid only if
149
+ attention type is `bidirectional`. Looking at the attention matrix, that means global
150
+ attention not only includes the vertical blocks, but also horizontal blocks.
151
+ num_different_global_patterns: optional: an integer determining number of different global
152
+ attentions layouts. While global attention can be fixed by which block/s are representative
153
+ of any local window, since there are multi-heads, each head can use a different global representative.
154
+ For example, with 4 blocks local window and global attention size of 1 block, we can have 4 different
155
+ versions in which the first, second, third, or fourth block of each local window can be global
156
+ representative of that window. This parameter determines how many of such patterns we want.
157
+ Of course, there is a limitation based on num_local_blocks and num_global_blocks.
158
+ """
159
+
160
+ super().__init__(num_heads, block_size, different_layout_per_head)
161
+
162
+ self.num_local_blocks = num_local_blocks
163
+
164
+ if num_local_blocks % num_global_blocks != 0:
165
+ raise ValueError(
166
+ f"""Number of blocks in a local window, {num_local_blocks},
167
+ must be divisible by number of global blocks, {num_global_blocks}!"""
168
+ )
169
+ self.num_global_blocks = num_global_blocks
170
+
171
+ if attention != "unidirectional" and attention != "bidirectional":
172
+ raise NotImplementedError(
173
+ 'only "uni/bi-directional" attentions are supported for now!'
174
+ )
175
+ self.attention = attention
176
+
177
+ if attention != "bidirectional" and horizontal_global_attention:
178
+ raise ValueError(
179
+ 'only "bi-directional" attentions can support horizontal global attention!'
180
+ )
181
+ self.horizontal_global_attention = horizontal_global_attention
182
+
183
+ if num_different_global_patterns > 1 and not different_layout_per_head:
184
+ raise ValueError(
185
+ """Number of different layouts cannot be more than one when you have set a single layout
186
+ for all heads! Set different_layout_per_head to True."""
187
+ )
188
+ if num_different_global_patterns > (num_local_blocks // num_global_blocks):
189
+ raise ValueError(
190
+ f"""Number of layout versions (num_different_global_patterns), {num_different_global_patterns},
191
+ cannot be larger than number of local window blocks divided by number of global blocks,
192
+ {num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!"""
193
+ )
194
+ self.num_different_global_patterns = num_different_global_patterns
195
+
196
+ def set_local_layout(self, h, layout):
197
+ """Sets local attention layout used by the given head in the sparse attention.
198
+ Arguments:
199
+ h: required: an integer determining head index
200
+ layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
201
+ sparsity layout of all head; may not be completely set at this step
202
+ Return:
203
+ layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
204
+ layout of all head in which local layout is set
205
+ """
206
+
207
+ num_blocks = layout.shape[1]
208
+ for i in range(0, num_blocks, self.num_local_blocks):
209
+ end = min(i + self.num_local_blocks, num_blocks)
210
+ for row in range(i, end):
211
+ for col in range(
212
+ i, (row + 1 if self.attention == "unidirectional" else end)
213
+ ):
214
+ layout[h, row, col] = 1
215
+ return layout
216
+
217
+ def set_global_layout(self, h, layout):
218
+ """Sets global attention layout used by the given head in the sparse attention.
219
+ Currently we set global blocks starting from the last block of a local window to the first one.
220
+ That means if a local window consists of 4 blocks and global attention size is one block, we use
221
+ block #4 in each local window as global. If we have different layout per head, then other heads
222
+ will get #3, #2, and #1. And if we have more heads (and different layout has set) than num of global
223
+ attentions, multiple head may have same global attentions.
224
+ Note) if horizontal_global_attention is set, global blocks will be set both horizontally and
225
+ vertically.
226
+ Arguments:
227
+ h: required: an integer determining head index
228
+ layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
229
+ sparsity layout of all head; may not be completely set at this step
230
+ Return:
231
+ layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
232
+ layout of all head in which global layout is set
233
+ """
234
+
235
+ num_blocks = layout.shape[1]
236
+ first_global_block_idx = (
237
+ self.num_local_blocks
238
+ - (1 + h % self.num_different_global_patterns) * self.num_global_blocks
239
+ )
240
+
241
+ # set all global blocks except the last one if (in last local window)
242
+ end = num_blocks - (num_blocks % self.num_local_blocks)
243
+ for i in range(first_global_block_idx, end, self.num_local_blocks):
244
+
245
+ # vertical global attention
246
+ first_row = 0 if self.attention == "bidirectional" else i
247
+ # (((i // self.num_local_blocks) + 1) * self.num_local_blocks)
248
+ # if (first_row < num_blocks):
249
+ layout[h, first_row:, i : i + self.num_global_blocks] = 1
250
+
251
+ # horizontal global attention; only in bidirectional attention
252
+ if self.horizontal_global_attention:
253
+ layout[h, i : i + self.num_global_blocks, :] = 1
254
+
255
+ # set last global blocks; handle possible short last local window
256
+ if end < num_blocks:
257
+ start = min(
258
+ end + first_global_block_idx, num_blocks - self.num_global_blocks
259
+ )
260
+ end = start + self.num_global_blocks
261
+
262
+ # vertical global attention
263
+ first_row = 0 if self.attention == "bidirectional" else start
264
+ # (((start // self.num_local_blocks) + 1) * self.num_local_blocks)
265
+ # if (first_row < num_blocks):
266
+ layout[h, first_row:, start:end] = 1
267
+
268
+ # horizontal global attention
269
+ if self.horizontal_global_attention:
270
+ layout[h, start:end, :] = 1
271
+ return layout
272
+
273
+ def make_layout(self, seq_len):
274
+ """Generates `Fixed` sparsity layout used by each head in the sparse attention.
275
+ Arguments:
276
+ seq_len: required: an integer determining the underlying sequence length.
277
+ Return:
278
+ layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed`
279
+ sparsity layout of all head
280
+ """
281
+
282
+ layout = self.setup_layout(seq_len)
283
+ for h in range(0, self.num_layout_heads):
284
+ layout = self.set_local_layout(h, layout)
285
+ layout = self.set_global_layout(h, layout)
286
+
287
+ layout = self.check_and_propagate_first_head_layout(layout)
288
+ return layout
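A hedged usage sketch for the class above: build the `Fixed` block-sparse layout for a 256-token sequence with 16 x 16 token blocks, 4 local blocks per window and 1 global block (values are illustrative):

from xformers.components.attention.sparsity_config import FixedSparsityConfig

config = FixedSparsityConfig(num_heads=8, block_size=16, num_local_blocks=4, num_global_blocks=1)
layout = config.make_layout(seq_len=256)
print(layout.shape)                            # torch.Size([8, 16, 16]) block-level mask
print(int(layout[0].sum()), "of", 16 * 16, "blocks kept")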
289
+
290
+
291
+ class VariableSparsityConfig(SparsityConfig):
292
+ """Configuration class to store `Variable` sparsity configuration.
293
+ This layout is an extension of FixedSparsityConfig in which:
294
+ - user can set random layout; default value is zero means no random block
295
+ - user can provide a list of local block sizes
296
+ - user can provide a list of global block indices.
297
+ For more details about `Fixed` sparsity config, please see `Generative Modeling with
298
+ Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
299
+ This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
300
+ """
301
+
302
+ def __init__(
303
+ self,
304
+ num_heads,
305
+ block_size=16,
306
+ different_layout_per_head=False,
307
+ num_random_blocks=0,
308
+ local_window_blocks=[4],
309
+ global_block_indices=[0],
310
+ global_block_end_indices=None,
311
+ attention="bidirectional",
312
+ horizontal_global_attention=False,
313
+ ):
314
+ """Initialize `Variable` Sparsity Pattern Config.
315
+ For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
316
+ Arguments:
317
+ num_heads: required: an integer determining number of attention heads of the layer.
318
+ block_size: optional: an integer determining the block size. Current implementation of sparse
319
+ self-attention is based on blocked sparse matrices. In which this parameter defines
320
+ size of such blocks, `Block X Block`.
321
+ different_layout_per_head: optional: a boolean determining if each head should be assigned a
322
+ different sparsity layout; default is false and this will be satisfied based on
323
+ availability. Currently this sparsity config can only assign single layout to all heads;
324
+ needs to be extended for different layout per head.
325
+ num_random_blocks: optional: an integer determining the number of random blocks in each block row.
326
+ local_window_blocks: optional: a list of integers determining the number of blocks in each
327
+ local attention window. It assumes first number determines # of blocks in the first local
328
+ window, second the second window, ..., and the last number determines the number of blocks
329
+ in the remaining local windows.
330
+ global_block_indices: optional: a list of integers determining which blocks are considered
331
+ as global attention. Given indices, determine the blocks that all other token blocks
332
+ attend to and they attend to all other token blocks. Default value is only index 0.
333
+ Notice that if global_block_end_indices parameter is set, this parameter is used as
334
+ starting index of each global window.
335
+ global_block_end_indices: optional: a list of integers determining end indices of global
336
+ window blocks. By default this is not used. But if it is set, it must have the same size
337
+ of global_block_indices parameter, and combining this two parameters, for each index i,
338
+ blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are
339
+ considered as global attention.
340
+ attention: optional: a string determining attention type. Attention can be `unidirectional`,
341
+ such as autoregressive models, in which tokens attend only to tokens appear before them
342
+ in the context. Considering that, the upper triangular of attention matrix is empty as
343
+ above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
344
+ any other tokens before or after them. Then, the upper triangular part of the attention
345
+ matrix is mirror of the lower triangular in the above figure.
346
+ horizontal_global_attention: optional: a boolean determining if blocks that are global
347
+ representative of a local window, also attend to all other blocks. This is valid only if
348
+ attention type is `bidirectional`. Looking at the attention matrix, that means global
349
+ attention not only includes the vertical blocks, but also horizontal blocks.
350
+ """
351
+
352
+ super().__init__(num_heads, block_size, different_layout_per_head)
353
+
354
+ self.num_random_blocks = num_random_blocks
355
+ self.local_window_blocks = local_window_blocks
356
+ self.global_block_indices = global_block_indices
357
+
358
+ if global_block_end_indices is not None:
359
+ if len(global_block_indices) != len(global_block_end_indices):
360
+ raise ValueError(
361
+ f"""Global block start indices length, {len(global_block_indices)}, must be same as
362
+ global block end indices length, {len(global_block_end_indices)}!"""
363
+ )
364
+ for _, (start_idx, end_idx) in enumerate(
365
+ zip(global_block_indices, global_block_end_indices)
366
+ ):
367
+ if start_idx >= end_idx:
368
+ raise ValueError(
369
+ f"""Global block start index, {start_idx}, must be smaller than global block end
370
+ index, {end_idx}!"""
371
+ )
372
+ self.global_block_end_indices = global_block_end_indices
373
+
374
+ if attention != "unidirectional" and attention != "bidirectional":
375
+ raise NotImplementedError(
376
+ 'only "uni/bi-directional" attentions are supported for now!'
377
+ )
378
+ self.attention = attention
379
+
380
+ if attention != "bidirectional" and horizontal_global_attention:
381
+ raise ValueError(
382
+ 'only "bi-directional" attentions can support horizontal global attention!'
383
+ )
384
+ self.horizontal_global_attention = horizontal_global_attention
385
+
386
+ def set_random_layout(self, h, layout):
387
+ """Sets random attention layout used by the given head in the sparse attention.
388
+ Note) By default, it assumes there will be a unique random block layout for all heads; unless
389
+ `different_layout_per_head` parameter is set in which each head can have a different random
390
+ layout.
391
+ Arguments:
392
+ h: required: an integer determining head index
393
+ layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
394
+ sparsity layout of all head; may not be completely set at this step
395
+ Return:
396
+ layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
397
+ layout of all head in which random layout is set
398
+ """
399
+
400
+ num_blocks = layout.shape[1]
401
+ if num_blocks < self.num_random_blocks:
402
+ raise ValueError(
403
+ f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number
404
+ of blocks in a row, {num_blocks}!"""
405
+ )
406
+ for row in range(0, num_blocks):
407
+ rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
408
+ layout[h, row, rnd_cols] = 1
409
+ return layout
410
+
+     def set_local_layout(self, h, layout):
+         """Sets the local attention layout used by the given head in the sparse attention.
+         Arguments:
+             h: required: an integer determining the head index
+             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the
+                 sparsity layout of all heads; may not be completely set at this step
+         Return:
+             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity
+                 layout of all heads in which the local layout is set
+         """
+
+         num_blocks = layout.shape[1]
+         start_block_idx = 0
+         end_block_idx = 0
+         for block_size in self.local_window_blocks:
+             end_block_idx += block_size
+             end_block_idx = min(end_block_idx, num_blocks)
+             for row in range(start_block_idx, end_block_idx):
+                 for col in range(
+                     start_block_idx,
+                     (row + 1 if self.attention == "unidirectional" else end_block_idx),
+                 ):
+                     layout[h, row, col] = 1
+             start_block_idx += block_size
+
+         # if there is any remaining unattended part, use the last local window block size as the local
+         # window for the remaining applicable local windows
+         for i in range(start_block_idx, num_blocks, block_size):
+             end_block_idx = min(i + block_size, num_blocks)
+             for row in range(i, end_block_idx):
+                 for col in range(
+                     i,
+                     (row + 1 if self.attention == "unidirectional" else end_block_idx),
+                 ):
+                     layout[h, row, col] = 1
+         return layout
+
+     def set_global_layout(self, h, layout):
+         """Sets the global attention layout used by the given head in the sparse attention.
+         Arguments:
+             h: required: an integer determining the head index
+             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the
+                 sparsity layout of all heads; may not be completely set at this step
+         Return:
+             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity
+                 layout of all heads in which the global layout is set
+         """
+
+         num_blocks = layout.shape[1]
+         if self.global_block_end_indices is None:
+             for idx in self.global_block_indices:
+                 # if the global block idx is in the range of the sequence blocks
+                 if idx < num_blocks:
+                     # global rows
+                     if self.horizontal_global_attention:
+                         layout[h, idx, :] = 1
+
+                     # global columns
+                     first_row = 0 if self.attention == "bidirectional" else idx
+                     layout[h, first_row:, idx] = 1
+         else:
+             for _, (start_idx, end_idx) in enumerate(
+                 zip(self.global_block_indices, self.global_block_end_indices)
+             ):
+                 # if the global block idx is in the range of the sequence blocks
+                 if start_idx < num_blocks:
+                     end_idx = min(end_idx, num_blocks)
+                     # global rows
+                     if self.horizontal_global_attention:
+                         layout[h, start_idx:end_idx, :] = 1
+
+                     # global columns
+                     first_row = 0 if self.attention == "bidirectional" else start_idx
+                     layout[h, first_row:, start_idx:end_idx] = 1
+         return layout
+
+     def make_layout(self, seq_len):
+         """Generates the `Variable` sparsity layout used by each head in the sparse attention.
+         Arguments:
+             seq_len: required: an integer determining the underlying sequence length of the layer.
+         Return:
+             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `Variable`
+                 sparsity layout of all heads
+         """
+
+         layout = self.setup_layout(seq_len)
+         for h in range(0, self.num_layout_heads):
+             layout = self.set_random_layout(h, layout)
+             layout = self.set_local_layout(h, layout)
+             layout = self.set_global_layout(h, layout)
+
+         layout = self.check_and_propagate_first_head_layout(layout)
+         return layout
+
+
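For orientation, a minimal usage sketch that is not part of the diff: assuming the methods above belong to the `Variable` sparsity config exposed as VariableSparsityConfig (its class header sits earlier in this file, so the name and constructor keywords are assumptions taken from the assignments shown), make_layout stitches the random, local, and global layouts into one 0/1 block mask per head.

    from xformers.components.attention.sparsity_config import VariableSparsityConfig

    # 1024 tokens with block_size=16 -> a 64 x 64 grid of blocks per head
    config = VariableSparsityConfig(num_heads=8, block_size=16, num_random_blocks=1)
    layout = config.make_layout(seq_len=1024)
    print(layout.shape)  # torch.Size([8, 64, 64]); a 1 marks a block that is attended to
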
+ class BigBirdSparsityConfig(SparsityConfig):
+     """Configuration class to store `BigBird` sparsity configuration.
+     For more details about this sparsity config, please see `Big Bird: Transformers for
+     Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf
+     This class extends the parent class `SparsityConfig` and customizes it for `BigBird` sparsity.
+     """
+
+     def __init__(
+         self,
+         num_heads,
+         block_size=16,
+         different_layout_per_head=False,
+         num_random_blocks=1,
+         num_sliding_window_blocks=3,
+         num_global_blocks=1,
+         attention="bidirectional",
+     ):
+         """Initialize the BigBird Sparsity Pattern Config.
+         For a usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
+         Arguments:
+             num_heads: required: an integer determining the number of attention heads of the layer.
+             block_size: optional: an integer determining the block size. The current implementation of
+                 sparse self-attention is based on blocked sparse matrices, in which this parameter
+                 defines the size of such blocks, `Block X Block`.
+             different_layout_per_head: optional: a boolean determining if each head should be assigned
+                 a different sparsity layout; default is false and this will be satisfied based on
+                 availability.
+             num_random_blocks: optional: an integer determining the number of random blocks in each
+                 block row.
+             num_sliding_window_blocks: optional: an integer determining the number of blocks in the
+                 sliding local attention window.
+             num_global_blocks: optional: an integer determining how many consecutive blocks, starting
+                 from index 0, are considered as global attention. Global block tokens will be attended
+                 to by all other block tokens and will attend to all other block tokens as well.
+             attention: optional: a string determining the attention type. Attention can be
+                 `unidirectional`, as in autoregressive models, in which tokens attend only to tokens
+                 that appear before them in the context. Considering that, the upper triangular part of
+                 the attention matrix is empty, as in the figure above. Or it can be `bidirectional`,
+                 such as BERT, in which tokens can attend to any other tokens before or after them.
+                 Then, the upper triangular part of the attention matrix is a mirror of the lower
+                 triangular part in the figure above.
+         """
+
+         super().__init__(num_heads, block_size, different_layout_per_head)
+
+         self.num_random_blocks = num_random_blocks
+         self.num_sliding_window_blocks = num_sliding_window_blocks
+         self.num_global_blocks = num_global_blocks
+
+         if attention != "unidirectional" and attention != "bidirectional":
+             raise NotImplementedError(
+                 'only "uni/bi-directional" attentions are supported for now!'
+             )
+         self.attention = attention
+
+     def set_random_layout(self, h, layout):
+         """Sets the random attention layout used by the given head in the sparse attention.
+         Note) By default, it assumes there will be a unique random block layout for all heads; unless the
+         `different_layout_per_head` parameter is set, in which case each head can have a different random
+         layout.
+         Arguments:
+             h: required: an integer determining the head index
+             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the
+                 sparsity layout of all heads; may not be completely set at this step
+         Return:
+             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity
+                 layout of all heads in which the random layout is set
+         """
+
+         num_blocks = layout.shape[1]
+         if num_blocks < self.num_random_blocks:
+             raise ValueError(
+                 f"""Number of random blocks, {self.num_random_blocks}, must be smaller than the overall number
+                 of blocks in a row, {num_blocks}!"""
+             )
+
+         for row in range(0, num_blocks):
+             sample_range = (
+                 range(0, num_blocks)
+                 if self.attention == "bidirectional"
+                 else range(0, row + 1)
+             )
+             rnd_cols = random.sample(sample_range, self.num_random_blocks)
+             layout[h, row, rnd_cols] = 1
+         return layout
+
+     def set_sliding_window_layout(self, h, layout):
+         """Sets the sliding local attention layout used by the given head in the sparse attention.
+         Arguments:
+             h: required: an integer determining the head index
+             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the
+                 sparsity layout of all heads; may not be completely set at this step
+         Return:
+             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity
+                 layout of all heads in which the local sliding window layout is set
+         """
+
+         num_blocks = layout.shape[1]
+         if num_blocks < self.num_sliding_window_blocks:
+             raise ValueError(
+                 f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than
+                 the overall number of blocks in a row, {num_blocks}!"""
+             )
+
+         w = self.num_sliding_window_blocks // 2
+         for row in range(0, num_blocks):
+             start = max(0, row - w)
+             end = min(row + w + 1, num_blocks)
+             layout[h, row, start:end] = 1
+         return layout
+
+     def set_global_layout_itc(self, h, layout):
+         """Sets the global attention layout used by the given head in the sparse attention.
+         Arguments:
+             h: required: an integer determining the head index
+             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the
+                 sparsity layout of all heads; may not be completely set at this step
+         Return:
+             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity
+                 layout of all heads in which the global layout is set
+         """
+
+         num_blocks = layout.shape[1]
+         if num_blocks < self.num_global_blocks:
+             raise ValueError(
+                 f"""Number of global blocks, {self.num_global_blocks}, must be smaller than the overall number
+                 of blocks in a row, {num_blocks}!"""
+             )
+
+         # global rows
+         layout[h, 0 : self.num_global_blocks, :] = 1
+
+         # global columns
+         layout[h, :, 0 : self.num_global_blocks] = 1
+
+         if self.attention == "unidirectional":
+             # zero out anything attending to the future
+             layout = torch.tril(layout)
+
+         return layout
+
+     def make_layout(self, seq_len):
+         """Generates the `BigBird` sparsity layout used by each head in the sparse attention.
+         Arguments:
+             seq_len: required: an integer determining the underlying sequence length of the layer.
+         Return:
+             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `BigBird`
+                 sparsity layout of all heads
+         """
+
+         layout = self.setup_layout(seq_len)
+         for h in range(0, self.num_layout_heads):
+             layout = self.set_random_layout(h, layout)
+             layout = self.set_sliding_window_layout(h, layout)
+             layout = self.set_global_layout_itc(h, layout)
+
+         layout = self.check_and_propagate_first_head_layout(layout)
+         return layout
+
+
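A minimal usage sketch for the BigBird config above, not part of the diff; the keyword arguments mirror the __init__ signature shown here, and the density check is only illustrative:

    from xformers.components.attention.sparsity_config import BigBirdSparsityConfig

    config = BigBirdSparsityConfig(
        num_heads=12,
        block_size=16,
        num_random_blocks=1,
        num_sliding_window_blocks=3,
        num_global_blocks=1,
    )
    layout = config.make_layout(seq_len=512)  # 512 tokens -> a 32 x 32 block grid per head
    density = layout.float().mean().item()    # fraction of non-empty blocks, typically well below 1.0
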
+ class BSLongformerSparsityConfig(SparsityConfig):
+     """Configuration class to store the edited `Longformer` sparsity configuration.
+     Note) this is a block-sparse version of the Longformer, which is slightly different from the
+     original Longformer (which uses element-wise sparsity).
+     For more details about this sparsity config, please see `Longformer:
+     The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf
+     This class extends the parent class `SparsityConfig` and customizes it for `Longformer` sparsity.
+     """
+
+     def __init__(
+         self,
+         num_heads,
+         block_size=16,
+         different_layout_per_head=False,
+         num_sliding_window_blocks=3,
+         global_block_indices=[0],
+         global_block_end_indices=None,
+         attention="bidirectional",
+     ):
+         """Initialize the edited `Longformer` Sparsity Pattern Config.
+         For a usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
+         Arguments:
+             num_heads: required: an integer determining the number of attention heads of the layer.
+             block_size: optional: an integer determining the block size. The current implementation of
+                 sparse self-attention is based on blocked sparse matrices, in which this parameter
+                 defines the size of such blocks, `Block X Block`.
+             different_layout_per_head: optional: a boolean determining if each head should be assigned a
+                 different sparsity layout; default is false and this will be satisfied based on
+                 availability.
+             num_sliding_window_blocks: optional: an integer determining the number of blocks in the
+                 sliding local attention window.
+             global_block_indices: optional: a list of integers determining which blocks are considered
+                 as global attention. Given these indices, all other token blocks attend to these blocks
+                 and these blocks attend to all other token blocks. The default value is only index 0.
+                 Notice that if the global_block_end_indices parameter is set, this parameter is used as
+                 the starting index of each global window.
+             global_block_end_indices: optional: a list of integers determining the end indices of global
+                 window blocks. By default this is not used. But if it is set, it must have the same size
+                 as the global_block_indices parameter; combining these two parameters, for each index i,
+                 blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are
+                 considered as global attention.
+             attention: optional: a string determining the attention type. Attention can be
+                 `unidirectional`, as in autoregressive models, in which tokens attend only to tokens
+                 that appear before them in the context. Considering that, the upper triangular part of
+                 the attention matrix is empty, as in the figure above. Or it can be `bidirectional`,
+                 such as BERT, in which tokens can attend to any other tokens before or after them.
+                 Then, the upper triangular part of the attention matrix is a mirror of the lower
+                 triangular part in the figure above.
+         """
+
+         super().__init__(num_heads, block_size, different_layout_per_head)
+
+         self.num_sliding_window_blocks = num_sliding_window_blocks
+         self.global_block_indices = global_block_indices
+         self.attention = attention
+
+         if global_block_end_indices is not None:
+             if len(global_block_indices) != len(global_block_end_indices):
+                 raise ValueError(
+                     f"""Global block start indices length, {len(global_block_indices)}, must be the same as
+                     global block end indices length, {len(global_block_end_indices)}!"""
+                 )
+             for _, (start_idx, end_idx) in enumerate(
+                 zip(global_block_indices, global_block_end_indices)
+             ):
+                 if start_idx >= end_idx:
+                     raise ValueError(
+                         f"""Global block start index, {start_idx}, must be smaller than global block end
+                         index, {end_idx}!"""
+                     )
+         self.global_block_end_indices = global_block_end_indices
+
+     def set_sliding_window_layout(self, h, layout):
+         """Sets the sliding local attention layout used by the given head in the sparse attention.
+         Arguments:
+             h: required: an integer determining the head index
+             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the
+                 sparsity layout of all heads; may not be completely set at this step
+         Return:
+             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity
+                 layout of all heads in which the local sliding window layout is set
+         """
+
+         num_blocks = layout.shape[1]
+         if num_blocks < self.num_sliding_window_blocks:
+             raise ValueError(
+                 f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller
+                 than the overall number of blocks in a row, {num_blocks}!"""
+             )
+
+         w = self.num_sliding_window_blocks // 2
+         for row in range(0, num_blocks):
+             start = max(0, row - w)
+             end = min(row + w + 1, num_blocks)
+             layout[h, row, start:end] = 1
+         return layout
+
+     def set_global_layout(self, h, layout):
+         """Sets the global attention layout used by the given head in the sparse attention.
+         Arguments:
+             h: required: an integer determining the head index
+             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the
+                 sparsity layout of all heads; may not be completely set at this step
+         Return:
+             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the sparsity
+                 layout of all heads in which the global layout is set
+         """
+
+         num_blocks = layout.shape[1]
+         if self.global_block_end_indices is None:
+             for idx in self.global_block_indices:
+                 # if the global block idx is in the range of the sequence blocks
+                 if idx < num_blocks:
+                     # global rows
+                     layout[h, idx, :] = 1
+
+                     # global columns
+                     layout[h, :, idx] = 1
+         else:
+             for _, (start_idx, end_idx) in enumerate(
+                 zip(self.global_block_indices, self.global_block_end_indices)
+             ):
+                 # if the global block idx is in the range of the sequence blocks
+                 if start_idx < num_blocks:
+                     end_idx = min(end_idx, num_blocks)
+                     # global rows
+                     layout[h, start_idx:end_idx, :] = 1
+
+                     # global columns
+                     layout[h, :, start_idx:end_idx] = 1
+         if self.attention == "unidirectional":
+             layout = torch.tril(layout)
+         return layout
+
+     def make_layout(self, seq_len):
+         """Generates the edited `Longformer` sparsity layout used by each head in the sparse attention.
+         Arguments:
+             seq_len: required: an integer determining the underlying sequence length of the layer.
+         Return:
+             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `BSLongformer`
+                 sparsity layout of all heads
+         """
+
+         layout = self.setup_layout(seq_len)
+         for h in range(0, self.num_layout_heads):
+             layout = self.set_sliding_window_layout(h, layout)
+             layout = self.set_global_layout(h, layout)
+
+         layout = self.check_and_propagate_first_head_layout(layout)
+         return layout
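
A hedged sketch for the block-sparse Longformer config above, not part of the diff: the layout returned by make_layout is a per-head 0/1 block mask that a block-sparse attention kernel can consume, and global_block_indices / global_block_end_indices mark blocks [0, 2) as a global window (the assert only restates what set_global_layout does for a bidirectional layout).

    from xformers.components.attention.sparsity_config import BSLongformerSparsityConfig

    config = BSLongformerSparsityConfig(
        num_heads=8,
        block_size=32,
        num_sliding_window_blocks=3,
        global_block_indices=[0],
        global_block_end_indices=[2],  # blocks [0, 2) are global
    )
    layout = config.make_layout(seq_len=2048)  # (8, 64, 64) block mask
    assert bool(layout[0, 0, :].all()) and bool(layout[0, :, 0].all())  # global block row and column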
.venv/lib/python3.11/site-packages/xformers/components/attention/utils.py ADDED
@@ -0,0 +1,108 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ from typing import Optional
+
+ import torch
+
+
+ # Reshapes the key padding mask from (batch_size, src_len) -> (batch_size * num_heads, 1, src_len)
+ def reshape_key_padding_mask(
+     key_padding_mask: torch.Tensor, batched_dim: int
+ ) -> torch.Tensor:
+     assert key_padding_mask.ndim == 2
+     batch_size, src_len = key_padding_mask.size()
+     num_heads = batched_dim // batch_size
+     return _reshape_key_padding_mask(key_padding_mask, batch_size, src_len, num_heads)
+
+
+ def _reshape_key_padding_mask(
+     key_padding_mask: torch.Tensor, batch_size: int, src_len: int, num_heads: int
+ ) -> torch.Tensor:
+     assert key_padding_mask.shape == (batch_size, src_len)
+     key_padding_mask = (
+         key_padding_mask.view(batch_size, 1, 1, src_len)
+         .expand(-1, num_heads, -1, -1)
+         .reshape(batch_size * num_heads, 1, src_len)
+     )
+     return key_padding_mask
+
+
+ # Combine the attention mask and key padding mask into a single mask
+ # Taken from https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
+ # Additive masking not yet supported
+ def maybe_merge_masks(
+     att_mask: Optional[torch.Tensor],
+     key_padding_mask: Optional[torch.Tensor],
+     batch_size: int,
+     src_len: int,
+     num_heads: int,
+     tgt_len: Optional[int] = None,
+ ) -> Optional[torch.Tensor]:
+     if tgt_len is None:
+         tgt_len = src_len
+     if key_padding_mask is not None:
+         assert key_padding_mask.shape == (batch_size, src_len)
+         key_padding_mask = _reshape_key_padding_mask(
+             key_padding_mask, batch_size, src_len, num_heads
+         )
+         if att_mask is None:
+             # make sure the dimensions of the key padding mask match those expected for att_mask
+             att_mask = key_padding_mask.expand(-1, tgt_len, -1)
+         # Assumption is that False means to mask.
+         elif att_mask.dtype == torch.bool:
+             att_mask = att_mask.logical_and(key_padding_mask)
+         else:
+             att_mask = att_mask.masked_fill(~key_padding_mask, float("-inf"))
+
+     return att_mask
+
+
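A short hedged sketch, not part of the diff, of how the two helpers above compose: a boolean key-padding mask is broadcast across heads by _reshape_key_padding_mask and merged with a boolean attention mask, with False meaning "masked".

    import torch
    from xformers.components.attention.utils import maybe_merge_masks

    batch_size, num_heads, src_len = 2, 4, 6
    key_padding_mask = torch.ones(batch_size, src_len, dtype=torch.bool)
    key_padding_mask[:, -2:] = False  # last two tokens are padding
    att_mask = torch.ones(src_len, src_len, dtype=torch.bool).tril()  # causal, boolean

    merged = maybe_merge_masks(att_mask, key_padding_mask, batch_size, src_len, num_heads)
    print(merged.shape)  # (batch_size * num_heads, src_len, src_len), i.e. (8, 6, 6)
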
+ # Assumes that the matrix passed in has had softmax applied to it.
+ def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=False):
+     """
+     Computing the Moore-Penrose inverse.
+     Use an iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose inverse via efficient
+     matrix-matrix multiplications.
+     """
+
+     i = torch.eye(
+         softmax_mat.size(-1), device=softmax_mat.device, dtype=softmax_mat.dtype
+     )
+     k = softmax_mat
+
+     # The entries of K are positive and ||K||_{\infty} = 1 due to softmax
+     if pinverse_original_init:
+         # This original implementation is more conservative when computing the coefficient of Z_0.
+         v = 1 / torch.max(torch.sum(k, dim=-2)) * k.transpose(-1, -2)
+     else:
+         # This is the exact coefficient computation, 1 / ||K||_1, of the initialization of Z_0, leading to
+         # faster convergence.
+         v = (
+             1
+             / torch.max(torch.sum(k, dim=-2), dim=-1).values[:, None, None]
+             * k.transpose(-1, -2)
+         )
+
+     for _ in range(n_iter):
+         kv = torch.matmul(k, v)
+         v = torch.matmul(
+             0.25 * v,
+             13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)),
+         )
+     return v
+
+
+ def bool_mask_to_additive(
+     mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32
+ ) -> torch.Tensor:
+     assert (
+         mask.dtype == torch.bool
+     ), "This util is meant to convert between bool masks and additive ones"
+
+     mask_ = torch.zeros_like(mask, dtype=dtype)
+     mask_[~mask] = float("-inf")
+     return mask_
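
Two hedged usage sketches for the utilities above, not part of the diff: bool_mask_to_additive turns a keep/drop mask into the additive form used with softmax scores, and iterative_pinv can be sanity-checked against torch.linalg.pinv on a small batched row-softmax matrix (how closely they match depends on conditioning and n_iter).

    import torch
    from xformers.components.attention.utils import bool_mask_to_additive, iterative_pinv

    keep = torch.tensor([[True, True, False]])
    print(bool_mask_to_additive(keep))  # tensor([[0., 0., -inf]])

    mat = torch.softmax(torch.randn(1, 8, 8), dim=-1)
    approx = iterative_pinv(mat, n_iter=6)
    exact = torch.linalg.pinv(mat)
    print(float((approx - exact).abs().max()))  # error shrinks as n_iter grows, for well-conditioned inputs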
.venv/lib/python3.11/site-packages/xformers/components/feedforward/__init__.py ADDED
@@ -0,0 +1,78 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ #
+ # This source code is licensed under the BSD license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ from pathlib import Path
+ from typing import Any, Callable, Dict, Set, Union
+
+ from xformers.utils import (
+     generate_matching_config,
+     get_registry_decorator,
+     import_all_modules,
+ )
+
+ from .base import Feedforward, FeedforwardConfig  # noqa
+
+ # CREDITS: Classy Vision registry mechanism
+
+ FEEDFORWARD_REGISTRY: Dict[str, Any] = {}
+ FEEDFORWARD_CLASS_NAMES: Set[str] = set()
+
+
+ def build_feedforward(config: Union[Dict[str, Any], FeedforwardConfig]):
+     """Builds a feedforward from a config.
+
+     This assumes a 'name' key in the config which is used to determine what
+     feedforward class to instantiate. For instance, a config `{"name": "my_feedforward",
+     "foo": "bar"}` will find a class that was registered as "my_feedforward"
+     (see :func:`register_feedforward`) and call .from_config on it."""
+
+     if not isinstance(config, FeedforwardConfig):
+         config_instance = generate_matching_config(
+             config, FEEDFORWARD_REGISTRY[config["name"]].config
+         )
+     else:
+         config_instance = config
+
+     return FEEDFORWARD_REGISTRY[config_instance.name].constructor.from_config(
+         config_instance
+     )
+
+
+ """Registers a Feedforward subclass.
+
+ This decorator allows xFormers to instantiate a subclass of Feedforward
+ from a configuration file, even if the class itself is not part of the
+ xFormers framework. To use it, apply this decorator to a Feedforward
+ subclass, like this:
+
+ .. code-block:: python
+
+     @dataclass
+     class MyConfig:
+         ...
+
+     @register_feedforward('my_ff', MyConfig)
+     class MyFeedforward(Feedforward):
+         ...
+
+ To instantiate a feedforward from a configuration file, see :func:`build_feedforward`."""
+ register_feedforward: Callable[
+     [str, Any], Callable[[Any], Any]
+ ] = get_registry_decorator(
+     FEEDFORWARD_REGISTRY, FEEDFORWARD_CLASS_NAMES, Feedforward, FeedforwardConfig
+ )
+
+ from .mlp import MLP  # noqa
+
+ __all__ = [
+     "MLP",
+     "Feedforward",
+     "build_feedforward",
+     "register_feedforward",
+ ]
+
+ # automatically import any Python files in the directory
+ import_all_modules(str(Path(__file__).parent), "xformers.components.feedforward")
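
A hedged end-to-end sketch of the registry flow above, not part of the diff: build the registered MLP feedforward from a plain dict config. The field names (dim_model, dropout, activation, hidden_layer_multiplier) follow the MLP config in xformers.components.feedforward.mlp as I understand it and should be treated as assumptions rather than something shown in this diff.

    import torch
    from xformers.components.feedforward import build_feedforward

    ff = build_feedforward(
        {
            "name": "MLP",  # MLP registers itself under this name via register_feedforward
            "dim_model": 256,
            "dropout": 0.0,
            "activation": "relu",
            "hidden_layer_multiplier": 4,
        }
    )
    y = ff(torch.randn(2, 16, 256))  # shape preserved: (2, 16, 256)

The registry decorator is what lets the string "MLP" resolve to a concrete class at build time, which is why import_all_modules runs at import to populate FEEDFORWARD_REGISTRY.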