Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .venv/lib/python3.11/site-packages/xformers/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/__pycache__/_cpp_lib.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/__pycache__/_deprecation_warning.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/__pycache__/attn_bias_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/__pycache__/checkpoint.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/__pycache__/info.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/__pycache__/test.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/__pycache__/version.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_mem_eff_attention.py +373 -0
- .venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sp24.py +178 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__init__.py +124 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/_sputnik_sparse.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_mask.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_patterns.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/compositional.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/core.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/favor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/fourier_mix.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/global_tokens.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/lambda_layer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/linformer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/local.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/nystrom.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/ortho.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/pooling.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/random.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/scaled_dot_product.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/sparsity_config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/visual.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/_sputnik_sparse.py +121 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/attention_mask.py +143 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/base.py +95 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/compositional.py +341 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__init__.py +26 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/softmax.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/base.py +61 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/softmax.py +288 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/global_tokens.py +122 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/linformer.py +74 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/ortho.py +324 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/pooling.py +82 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/sparsity_config.py +812 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/utils.py +108 -0
- .venv/lib/python3.11/site-packages/xformers/components/feedforward/__init__.py +78 -0
.venv/lib/python3.11/site-packages/xformers/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.66 kB)
.venv/lib/python3.11/site-packages/xformers/__pycache__/_cpp_lib.cpython-311.pyc
ADDED
Binary file (8.38 kB)
.venv/lib/python3.11/site-packages/xformers/__pycache__/_deprecation_warning.cpython-311.pyc
ADDED
Binary file (661 Bytes)
.venv/lib/python3.11/site-packages/xformers/__pycache__/attn_bias_utils.cpython-311.pyc
ADDED
Binary file (22.6 kB)
.venv/lib/python3.11/site-packages/xformers/__pycache__/checkpoint.cpython-311.pyc
ADDED
Binary file (27.3 kB)
.venv/lib/python3.11/site-packages/xformers/__pycache__/info.cpython-311.pyc
ADDED
Binary file (4.55 kB)
.venv/lib/python3.11/site-packages/xformers/__pycache__/test.cpython-311.pyc
ADDED
Binary file (177 Bytes)
.venv/lib/python3.11/site-packages/xformers/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (8.17 kB)
.venv/lib/python3.11/site-packages/xformers/__pycache__/version.cpython-311.pyc
ADDED
Binary file (207 Bytes)
.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_mem_eff_attention.py
ADDED
@@ -0,0 +1,373 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
import random
from functools import partial

import torch
from torch.utils import benchmark

import xformers.ops
import xformers.ops.fmha as fmha
from xformers.attn_bias_utils import create_attn_bias, ref_attention
from xformers.benchmarks.utils import benchmark_main_helper, create_argparser

torch.backends.cuda.matmul.allow_tf32 = False

min_run_time = 0.5
device = torch.device("cuda")

NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
VISION_SHAPES = [
    # ViT
    (384, 197, 1, 88),
    (384, 197, 1, 80),
    (384, 197, 1, 64),
    (1024, 197, 1, 88),
    (1024, 197, 1, 80),
    (1024, 197, 1, 64),
    # ViT-Huge
    (32 * 16, 197, 1, 80),
    (32, 197, 16, 80),
    (32, 197, 16, 64),
    (32, 197, 16, 128),
    # ViT-Giant
    (16 * 16, 197, 1, 88),
    (16, 197, 16, 88),
    (16, 197, 16, 64),
    (16, 197, 16, 128),
    # FB models
    (1024, 82, 8, 64),
    (150, 256, 16, 64),
    (64, 256, 12, 64),
    # Stable diffusion (https://github.com/huggingface/diffusers/pull/532)
    (1, 4096, 16, 40),  # 512x512
    (1, 16384, 16, 40),  # 1024x1024
    (1, 4096, 16, 80),
    (1, 16384, 16, 80),
    # + bs4
    (4, 4096, 16, 40),
    (4, 16384, 16, 40),
    (4, 4096, 16, 80),
    (4, 16384, 16, 80),
    # ParlAI model
    (256, 4096, 16, 64),
    # Zetta B M H K
    (8, 2048, 20, 128),
]

LLM_SHAPES = [
    # LLaMa 70b - mp=8/16
    *sorted(itertools.product([1, 2], [2048, 4096, 8192], [4, 8], [128])),
    *sorted(
        itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128, 160, 256])
    ),
]


OPS = [
    (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp),
    (xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp),
    (xformers.ops.fmha.flash3.FwOp, xformers.ops.fmha.flash3.BwOp),
    (xformers.ops.fmha.ck.FwOp, xformers.ops.fmha.ck.BwOp),
]


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))


VISION_CASES, LLM_CASES = [
    list(
        product_dict(
            shape_q=SHAPES,
            num_threads=NUM_THREADS,
            dropout_p=[0.0],
            attn_bias_cfg=[(type(None), False)],
            dtype=[torch.half],
        )
    )
    for SHAPES in (VISION_SHAPES, LLM_SHAPES)
]

# Add more cases with some variations
for c in VISION_CASES.copy():
    c = c.copy()
    c.update(
        random.Random(str(c["shape_q"])).choice(
            [
                {"dropout_p": 0.3},
                {"attn_bias_cfg": (torch.Tensor, False)},
                {"attn_bias_cfg": (torch.Tensor, True)},
                {"dtype": torch.bfloat16},
                {"dtype": torch.float},
            ]
        )
    )
    VISION_CASES.append(c)


LLM_CASE_UPDATES = [
    {"attn_bias_cfg": (torch.Tensor, True)},
    {"attn_bias_cfg": (xformers.ops.LowerTriangularMask, False)},
    *[
        {
            "attn_bias_cfg": (
                xformers.ops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
                False,
            ),
            "Hkv": Hkv,
            "dtype": torch.bfloat16,
        }
        for Hkv in [1, 2]
    ],
]

for c in LLM_CASES.copy():
    for update in LLM_CASE_UPDATES:
        c = c.copy()
        c.update(update)
        LLM_CASES.append(c)

CASES = VISION_CASES + LLM_CASES


def create_tensors(shape_q, Hkv, dtype, requires_grad=False, packed=True):
    stacked_shape = list(shape_q)  # B, M, H, K
    Hq = shape_q[2]
    stacked_dim = 2 if packed else 0
    stacked_shape.insert(stacked_dim, 3)
    qkv = torch.rand(
        stacked_shape, device=device, dtype=dtype, requires_grad=requires_grad
    )
    q = torch.rand(shape_q, device=device, dtype=dtype, requires_grad=requires_grad)
    shape_kv = (shape_q[0], shape_q[1], Hkv, shape_q[3])
    k = (
        torch.rand(shape_kv, device=device, dtype=dtype, requires_grad=requires_grad)
        .reshape(shape_q[0], shape_q[1], 1, Hkv, shape_q[3])
        .expand(shape_q[0], shape_q[1], Hq // Hkv, Hkv, shape_q[3])
        .reshape(shape_q)
    )
    v = (
        torch.rand(shape_kv, device=device, dtype=dtype, requires_grad=requires_grad)
        .reshape(shape_q[0], shape_q[1], 1, Hkv, shape_q[3])
        .expand(shape_q[0], shape_q[1], Hq // Hkv, Hkv, shape_q[3])
        .reshape(shape_q)
    )

    return qkv, q, k, v


def mem_eff_attention_fw(
    shape_q,
    num_threads: int,
    attn_bias_cfg,
    dropout_p,
    dtype,
    packed=True,
    Hkv=None,
):
    B, M, Hq, K = shape_q
    Hkv = Hkv or Hq
    _, q, k, v = create_tensors(
        shape_q,
        Hkv,
        dtype,
        requires_grad=False,
        packed=packed,
    )
    attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
    if attn_bias_requires_grad:
        return

    dtype_str = {
        torch.bfloat16: "b16",
        torch.half: "f16",
        torch.float: "f32",
    }[dtype]
    sub_label = (
        f"{dtype_str} {B}-{M}-{Hq}-{Hkv}-{K}, p={dropout_p}, "
        f"BiasT={attn_bias_type.__name__}"
    )

    has_run = False
    for fw_op, bw_op in OPS:
        bias = create_attn_bias(
            attn_bias_type,
            batch_size=B,
            num_heads=Hq,
            num_heads_groups=Hq // Hkv,
            q_len=M,
            kv_len=M,
            dtype=dtype,
            device=device,
            requires_grad=attn_bias_requires_grad,
            fmt="BMHK",
            op=fw_op,
        )
        inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p)
        if isinstance(
            bias,
            (
                fmha.attn_bias.BlockDiagonalMask,
                fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
            ),
        ):
            q, k, v = [x.reshape([1, -1, *x.shape[2:]]) for x in [q, k, v]]
        if not fw_op.supports(inp):
            continue

        yield benchmark.Timer(
            stmt="fn(q, k, v, attn_bias, p)",
            globals={
                "q": q,
                "k": k,
                "v": v,
                "attn_bias": inp.attn_bias,
                "p": dropout_p,
                "fn": partial(
                    xformers.ops.memory_efficient_attention, op=(fw_op, bw_op)
                ),
            },
            label=f"attention (attn_bias={attn_bias_type})",
            description=fw_op.NAME,
            sub_label=sub_label,
            num_threads=num_threads,
        )
        has_run = True

    if not has_run:
        return

    yield benchmark.Timer(
        stmt="fn(q, k, v, attn_bias, p)",
        globals={
            "q": q,
            "k": k,
            "v": v,
            "attn_bias": inp.attn_bias,
            "p": dropout_p,
            "fn": ref_attention,
        },
        label=f"attention (attn_bias={attn_bias_type})",
        description="eager",
        sub_label=sub_label,
        num_threads=num_threads,
    )


def mem_eff_attention_bw(
    shape_q, num_threads: int, attn_bias_cfg, dropout_p, dtype, Hkv=None
):
    B, M, Hq, K = shape_q
    Hkv = Hkv or Hq
    _, q, k, v = create_tensors(
        shape_q,
        Hkv,
        dtype,
        requires_grad=True,
    )

    attn_bias_type, attn_bias_requires_grad = attn_bias_cfg

    dtype_str = {
        torch.bfloat16: "b16",
        torch.half: "f16",
        torch.float: "f32",
    }[dtype]
    sub_label = (
        f"{dtype_str} {B}-{M}-{Hq}-{Hkv}-{K}, p={dropout_p}, "
        f"BiasT={attn_bias_type.__name__}, BiasGrad={attn_bias_requires_grad}"
    )

    has_run = False
    for fw_op, bw_op in OPS:
        bias = create_attn_bias(
            attn_bias_type,
            batch_size=B,
            num_heads=Hq,
            num_heads_groups=Hq // Hkv,
            q_len=M,
            kv_len=M,
            dtype=dtype,
            device=device,
            requires_grad=attn_bias_requires_grad,
            fmt="BMHK",
            op=bw_op,
        )
        inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p)

        if not fw_op.supports(inp) or not bw_op.supports(inp):
            continue
        has_run = True
        out = xformers.ops.memory_efficient_attention(
            inp.query, inp.key, inp.value, inp.attn_bias, inp.p, op=(fw_op, bw_op)
        )
        grad_benchmark = torch.ones_like(q)

        yield benchmark.Timer(
            stmt="out.backward(grad, retain_graph=True)",
            globals={
                "out": out,
                "grad": grad_benchmark,
            },
            label=f"attention backward (attn_bias={attn_bias_type})",
            description=bw_op.NAME,
            sub_label=sub_label,
            num_threads=num_threads,
        )
        del out

    if not has_run:
        return
    yield benchmark.Timer(
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": ref_attention(q, k, v, inp.attn_bias, dropout_p),
            "grad": grad_benchmark,
        },
        label=f"attention backward (attn_bias={attn_bias_type})",
        description="vanilla",
        sub_label=sub_label,
        num_threads=num_threads,
    )


def main():
    arg_parser = create_argparser()
    arg_parser.add_argument(
        "--omit-forward",
        action="store_true",
        help="Do not run forward benchmarks",
    )
    arg_parser.add_argument(
        "--omit-backward",
        action="store_true",
        help="Do not run backward benchmarks",
    )
    args = arg_parser.parse_args()
    if not args.omit_forward:
        benchmark_main_helper(
            mem_eff_attention_fw,
            CASES,
            arg_parser=arg_parser,
            min_run_time=min_run_time,
        )
    if not args.omit_backward:
        benchmark_main_helper(
            mem_eff_attention_bw,
            CASES,
            arg_parser=arg_parser,
            min_run_time=min_run_time,
        )


if __name__ == "__main__":
    main()
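For reference, the call being timed above is xformers.ops.memory_efficient_attention on BMHK tensors; a minimal sketch of that entry point (assuming a CUDA build of xFormers, with illustrative shapes) looks like:

import torch
import xformers.ops as xops

B, M, H, K = 2, 1024, 8, 64  # batch, sequence length, heads, head dim (illustrative)
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.half)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.half)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.half)

# Same function the benchmark wraps with functools.partial above
out = xops.memory_efficient_attention(q, k, v, attn_bias=None, p=0.0)
print(out.shape)  # (B, M, H, K)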
.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sp24.py
ADDED
@@ -0,0 +1,178 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn
from utils import DTYPE2STR, benchmark_main_helper2, product_dict

import xformers.ops as xops

min_run_time = 0.5
device = torch.device("cuda")

CASES = list(
    product_dict(
        B_in_hidden_out_ft=[
            (2048 * 8, 2048, 2048 * 3, 2048),
            (2048, 5120, 5120 * 3, 5120),  # 13b
            (1024, 8192, 8192 * 3, 8192),  # 30b
            (2048, 8192, 8192 * 3, 8192),  # 30b
            (2048 * 2, 8192, 8192 * 3, 8192),  # 30b
            # DINO ViT-L: lg + sm crops (patch16)
            (64 * 2 * (14 * 14 + 1) + 64 * 8 * (6 * 6 + 1), 1024, 1024 * 4, 1024),
            # DINO ViT-g: lg + sm crops (patch16)
            (
                12 * 2 * (16 * 16 + 1 + 11) + 12 * 8 * (7 * 7 + 1 + 11),
                1536,
                1536 * 4,
                1536,
            ),
        ],
        dtype=[torch.half],
        bias=[False],
    )
)


class Mlp(nn.Module):
    LINEAR_CLS = nn.Linear

    def __init__(
        self, B_in_hidden_out_ft: Tuple[int, int, int, int], dtype, bias: bool, bw: bool
    ) -> None:
        B, in_ft, hid_ft, out_ft = B_in_hidden_out_ft
        super().__init__()
        self.label = "mlp"
        self.sub_label = (
            f"{DTYPE2STR[dtype]} ({B},{in_ft},{hid_ft},{out_ft}){' b' if bias else ''}"
        )
        self.fc1 = self.LINEAR_CLS(in_ft, hid_ft, bias=bias)
        self.act = nn.GELU()
        self.fc2 = self.LINEAR_CLS(hid_ft, out_ft, bias=bias)
        self.grad = torch.randn([B, out_ft], device="cuda", dtype=dtype)
        self.input = torch.randn(
            [B, in_ft], device="cuda", dtype=dtype, requires_grad=True
        )
        self.out = self.input
        self.to("cuda").to(dtype)

    def fw(self):
        x = self.input
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        self.out = x

    def bw(self):
        self.out.backward(self.grad, retain_graph=True)


class MlpDenseMask(Mlp):
    def fw(self):
        x = self.input
        x = self.fc1(x)

        mask = torch.ops.xformers.sparse24_largest_mask_2d(x)
        x = mask * x

        x = self.act(x)
        x = self.fc2(x)
        self.out = x


class MlpAct24(Mlp):
    def fw(self):
        x = self.input
        x = self.fc1(x)

        x = xops.sparsify24(x)

        x = self.act(x)
        x = self.fc2(x)
        self.out = x


class LinearW24(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        w_sparse = xops.sparsify24(
            self.weight,
            gradient="24dense",
            backend="cusparselt",
        )
        return F.linear(input, w_sparse, self.bias)


class MlpW24(Mlp):
    LINEAR_CLS = LinearW24


class MicrobenchmarkBase:
    def __init__(
        self, B_in_hidden_out_ft: Tuple[int, int, int, int], dtype, bias: bool, bw: bool
    ) -> None:
        B, in_ft, hid_ft, out_ft = B_in_hidden_out_ft
        super().__init__()
        self.label = "mlp"
        self.sub_label = (
            f"{DTYPE2STR[dtype]} ({B},{in_ft},{hid_ft},{out_ft}){' b' if bias else ''}"
        )
        self.input = torch.randn(
            [B, in_ft], device="cuda", dtype=dtype, requires_grad=True
        )
        self.input_colMajor = self.input.t().contiguous().t()
        self.input_sp = xops.sparsify24(self.input)

    def bw(self) -> None:
        return None


class MicrobenchmarkSparsify24(MicrobenchmarkBase):
    def fw(self) -> torch.Tensor:
        xops.sparsify24(self.input)
        return self.input


class MicrobenchmarkSp24ApplyDense(MicrobenchmarkBase):
    def fw(self) -> torch.Tensor:
        xops.sparsify24_like(self.input, pattern=self.input_sp, out_dense=True)
        return self.input


class MicrobenchmarkSp24ApplyDenseT(MicrobenchmarkBase):
    def fw(self) -> torch.Tensor:
        xops.sparsify24_like(self.input_colMajor, pattern=self.input_sp, out_dense=True)
        return self.input


class MicrobenchmarkInputClone(MicrobenchmarkBase):
    def fw(self) -> torch.Tensor:
        self.input.clone()
        return self.input


functions = {
    "act24": MlpAct24,
    "dense": Mlp,
    "w24": MlpW24,
    "s24_inp_sparsify24": MicrobenchmarkSparsify24,
    "s24_inp_apply_dense": MicrobenchmarkSp24ApplyDense,
    "s24_inp_apply_dense_t": MicrobenchmarkSp24ApplyDenseT,
    "s24_inp_clone": MicrobenchmarkInputClone,
}
benchmark_main_helper2(
    "sp24_fw", fw=True, cases=CASES, functions=functions, min_run_time=min_run_time
)
benchmark_main_helper2(
    "sp24_fwbw",
    fw=True,
    bw=True,
    cases=CASES,
    functions=functions,
    min_run_time=min_run_time,
)
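The MLP variants above differ only in where 2:4 sparsity is applied; a minimal sketch of the activation-sparsification path exercised by MlpAct24 (assuming a CUDA build that provides xformers.ops.sparsify24, and dimensions compatible with the 2:4 kernels) is:

import torch
import torch.nn as nn
import xformers.ops as xops

fc1 = nn.Linear(256, 512, bias=False).cuda().half()
fc2 = nn.Linear(512, 128, bias=False).cuda().half()
act = nn.GELU()

x = torch.randn(128, 256, device="cuda", dtype=torch.half)
x = xops.sparsify24(fc1(x))  # 2:4-sparsify the activation, as MlpAct24.fw does
out = fc2(act(x))            # downstream ops proceed as with a dense tensor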
.venv/lib/python3.11/site-packages/xformers/components/attention/__init__.py
ADDED
@@ -0,0 +1,124 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
from typing import Any, Callable, Dict, Set, Union

import torch

from xformers.utils import (
    generate_matching_config,
    get_registry_decorator,
    import_all_modules,
)

from ._sputnik_sparse import SparseCS
from .attention_mask import AttentionMask
from .base import Attention, AttentionConfig  # noqa

logger = logging.getLogger("xformers")


# CREDITS: Classy Vision registry mechanism

ATTENTION_REGISTRY: Dict[str, Any] = {}
ATTENTION_CLASS_NAMES: Set[str] = set()

# Arbitrary threshold for now,
# in between dense and sparse matrix algorithms for the attention mechanism
_DENSITY_THRESHOLD = 0.30  # noqa # from the sputnik paper, vs.
_USE_SPUTNIK = True


def build_attention(config: Union[Dict[str, Any], AttentionConfig]):
    """Builds an attention from a config.

    This assumes a 'name' key in the config which is used to determine what
    attention class to instantiate. For instance, a config `{"name": "my_attention",
    "foo": "bar"}` will find a class that was registered as "my_attention"
    (see :func:`register_attention`) and call .from_config on it."""

    if not isinstance(config, AttentionConfig):
        try:
            config_instance = generate_matching_config(
                config, ATTENTION_REGISTRY[config["name"]].config
            )
        except KeyError as e:
            name = config["name"]
            logger.warning(f"{name} not available among {ATTENTION_REGISTRY.keys()}")
            raise e
    else:
        config_instance = config

    return ATTENTION_REGISTRY[config_instance.name].constructor.from_config(
        config_instance
    )


"""Registers an Attention subclass.

This decorator allows xFormers to instantiate a subclass of Attention
from a configuration file, even if the class itself is not part of the
xFormers library. To use it, apply this decorator to an Attention
subclass, like this:

.. code-block:: python

    @dataclass
    class MyConfig:
        ...

    @register_attention('my_attention', MyConfig)
    class MyAttention(Attention):
        ...

To instantiate an attention from a configuration file, see :func:`build_attention`."""
register_attention: Callable[[str, Any], Callable[[Any], Any]] = get_registry_decorator(
    ATTENTION_REGISTRY, ATTENTION_CLASS_NAMES, Attention, AttentionConfig
)


def maybe_sparsify(matrix) -> Any:
    # Sparsify if that makes sense
    if torch.count_nonzero(matrix).item() / matrix.numel() > _DENSITY_THRESHOLD:
        # If not sparse, then AttentionMask is the reference type
        return AttentionMask.from_bool(matrix)

    return sparsify(matrix)


def sparsify(matrix):
    if _USE_SPUTNIK:
        return SparseCS(matrix)
    return matrix.to_sparse()


from .favor import FavorAttention  # noqa
from .global_tokens import GlobalAttention  # noqa
from .linformer import LinformerAttention  # noqa
from .local import LocalAttention  # noqa
from .nystrom import NystromAttention  # noqa
from .ortho import OrthoFormerAttention  # noqa
from .random import RandomAttention  # noqa
from .scaled_dot_product import ScaledDotProduct  # noqa

__all__ = [
    "ScaledDotProduct",
    "LocalAttention",
    "LinformerAttention",
    "NystromAttention",
    "RandomAttention",
    "OrthoFormerAttention",
    "GlobalAttention",
    "FavorAttention",
    "Attention",
    "AttentionMask",
    "build_attention",
    "register_attention",
]

# automatically import any Python files in the directory
import_all_modules(str(Path(__file__).parent), "xformers.components.attention")
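As a minimal illustration of the registry above (a sketch, assuming the stock components remain importable and that "scaled_dot_product" is among the registered names, as implied by the imports at the end of this file):

from xformers.components.attention import build_attention

# The "name" key selects the registered class; the remaining keys populate its config dataclass
attn = build_attention({"name": "scaled_dot_product", "dropout": 0.0})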
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (4.18 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/_sputnik_sparse.cpython-311.pyc
ADDED
Binary file (7.3 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_mask.cpython-311.pyc
ADDED
Binary file (7.48 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/attention_patterns.cpython-311.pyc
ADDED
Binary file (15.5 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/base.cpython-311.pyc
ADDED
Binary file (4.5 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/compositional.cpython-311.pyc
ADDED
Binary file (14.2 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/core.cpython-311.pyc
ADDED
Binary file (11.2 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/favor.cpython-311.pyc
ADDED
Binary file (7.42 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/fourier_mix.cpython-311.pyc
ADDED
Binary file (2.14 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/global_tokens.cpython-311.pyc
ADDED
Binary file (5.47 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/lambda_layer.cpython-311.pyc
ADDED
Binary file (3.85 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/linformer.cpython-311.pyc
ADDED
Binary file (4 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/local.cpython-311.pyc
ADDED
Binary file (5.22 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/nystrom.cpython-311.pyc
ADDED
Binary file (12.9 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/ortho.cpython-311.pyc
ADDED
Binary file (15.8 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/pooling.cpython-311.pyc
ADDED
Binary file (3.26 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/random.cpython-311.pyc
ADDED
Binary file (5.43 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/scaled_dot_product.cpython-311.pyc
ADDED
Binary file (5.46 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/sparsity_config.cpython-311.pyc
ADDED
Binary file (41.9 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (4.46 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/__pycache__/visual.cpython-311.pyc
ADDED
Binary file (4.68 kB)
.venv/lib/python3.11/site-packages/xformers/components/attention/_sputnik_sparse.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch

from xformers.ops import masked_matmul
from xformers.sparse import SparseCSRTensor

# TODO: this is here for BC
from xformers.sparse.utils import _csr_to_coo, _dense_to_sparse  # noqa: F401


class SparseCS:
    def __init__(self, matrix, device=None):
        if device is None:
            device = torch.device("cpu")
        if matrix.ndim == 2:
            matrix = matrix[None]
        assert matrix.ndim == 3
        self._mat = SparseCSRTensor.from_dense(matrix).to(device)

    @property
    def device(self):
        return self._mat.device

    @property
    def ndim(self):
        return self._mat.ndim

    @property
    def dtype(self):
        return self._mat.dtype

    @property
    def is_sparse(self):
        return True

    @property
    def shape(self):
        return self._mat.shape[1:]

    @property
    def values(self):
        return self._mat.values()

    @property
    def row_indices(self):
        return self._mat._csr_row_indices

    @property
    def column_indices(self):
        return self._mat._csr_column_indices

    @property
    def row_offsets(self):
        return self._mat._csr_row_offsets

    @property
    def _transp_info(self):
        return self._mat._csr_transp_info

    @classmethod
    def wrap(
        cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
    ):
        matrix = cls.__new__(cls)
        _shape = (values.shape[0],) + shape
        csr_matrix = SparseCSRTensor._wrap(
            _shape, values, row_indices, row_offsets, column_indices, _transp_info
        )
        matrix._mat = csr_matrix
        return matrix

    @classmethod
    def _wrap(cls, csr_matrix):
        assert isinstance(csr_matrix, SparseCSRTensor)
        matrix = cls.__new__(cls)
        matrix._mat = csr_matrix
        return matrix

    def __mul__(self, other):
        assert isinstance(other, (int, float))
        return type(self)._wrap(self._mat * other)

    def __add__(self, other):
        assert isinstance(other, type(self))
        return type(self)._wrap(self._mat + other._mat)

    def matmul_with_mask(self, a, b):
        return type(self)._wrap(masked_matmul(a, b, self._mat))

    def softmax(self):
        out = torch.nn.functional.softmax(self._mat, -1)
        return type(self)._wrap(out)

    def spmm(self, b):
        out = torch.bmm(self._mat, b)
        return out

    def transpose(self):
        out = torch.transpose(self._mat, -2, -1)
        return type(self)._wrap(out)

    def to(self, device):
        assert isinstance(device, torch.device)
        out = self._mat.to(device)
        return type(self)._wrap(out)

    def to_dense(self):
        return self._mat.to_dense()

    def logical_and(self, other: torch.Tensor):
        assert not isinstance(other, SparseCS)
        out = torch.logical_and(self._mat, other)
        return type(self)._wrap(out)

    def __and__(self, other):
        return self.logical_and(other)
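A small usage sketch for the wrapper above (not exercised against this exact version; construction and to_dense() use only methods defined in the class, while the sputnik kernels themselves target CUDA):

import torch
from xformers.components.attention._sputnik_sparse import SparseCS

# Build a CSR layout from a (batched) dense pattern and convert it back
dense = (torch.rand(1, 16, 16) > 0.7).float()
sparse = SparseCS(dense, device=torch.device("cpu"))
restored = sparse.to_dense()  # same layout and values as `dense`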
.venv/lib/python3.11/site-packages/xformers/components/attention/attention_mask.py
ADDED
@@ -0,0 +1,143 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional, Type, TypeVar

import torch

Self = TypeVar("Self", bound="AttentionMask")


class AttentionMask:
    """
    Holds an attention mask, along with a couple of helpers and attributes.

    .. note: this is an additive mask, meaning that coefficients which should be computed hold the '0.' value,
    and coefficients which should be skipped hold the '-inf' value. Any other value is possible if the purpose
    is to bias the attention computation for instance

    .. note: the attention mask dimensions are expected to be `[batch, to_sequence, from_sequence]`,
    `[to_sequence, from_sequence]`, or anything broadcastable in between
    """

    def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False):
        assert additive_mask.is_floating_point(), additive_mask.dtype
        assert not additive_mask.requires_grad

        if additive_mask.ndim == 2:
            additive_mask = additive_mask.unsqueeze(0)

        self.values = additive_mask
        self.is_causal = is_causal
        self.seq_len = additive_mask.shape[1]
        self.to_seq_len = additive_mask.shape[0]

    def to_bool(self) -> torch.Tensor:
        """
        .. warning: we assume here that True implies that the value should be computed
        """
        return self.values != float("-inf")

    @classmethod
    def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
        """
        Create an AttentionMask given a boolean pattern.
        .. warning: we assume here that True implies that the value should be computed
        """
        assert x.dtype == torch.bool

        additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
        additive_mask.masked_fill_(x, 0.0)
        additive_mask.masked_fill_(~x, float("-inf"))

        return cls(additive_mask)

    @classmethod
    def from_multiplicative(cls: Type[Self], x: torch.Tensor) -> Self:
        """
        Create an AttentionMask given a multiplicative attention mask.
        """
        assert not x.dtype == torch.bool

        additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
        x = x.bool()

        additive_mask.masked_fill_(x, 0.0)
        additive_mask.masked_fill_(~x, float("-inf"))

        return cls(additive_mask)

    @classmethod
    def make_causal(
        cls: Type[Self],
        seq_len: int,
        to_seq_len: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Self:
        if not to_seq_len:
            to_seq_len = seq_len

        additive_mask = torch.triu(
            torch.ones(seq_len, to_seq_len, device=device, dtype=dtype) * float("-inf"),
            diagonal=1,
        )
        return cls(additive_mask=additive_mask, is_causal=True)

    def make_crop(
        self, seq_len: int, to_seq_len: Optional[int] = None
    ) -> "AttentionMask":
        """
        Return a cropped attention mask, whose underlying tensor is a view of this one
        """

        if not to_seq_len:
            to_seq_len = seq_len

        return AttentionMask(
            self.values[:, :seq_len, :to_seq_len], is_causal=self.is_causal
        )

    def __repr__(self):
        return f"AttentionMask - causal {self.is_causal} - mask " + str(self.values)

    @property
    def device(self):
        return self.values.device

    @property
    def is_sparse(self):
        return False

    @property
    def ndim(self):
        return len(self.values.shape)

    @property
    def dtype(self):
        return self.values.dtype

    @property
    def shape(self):
        return self.values.shape

    def __add__(self, other):
        return AttentionMask(self.values + other.values, is_causal=False)

    def to(
        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
    ) -> "AttentionMask":
        assert device is None or isinstance(device, torch.device)
        assert dtype is None or isinstance(dtype, torch.dtype)
        assert device is not None or dtype is not None

        # Noop if we don't need to create another instance
        if ((device and device == self.device) or not device) and (
            (dtype and dtype == self.dtype) or not dtype
        ):
            return self

        return AttentionMask(self.values.to(device=device, dtype=dtype), self.is_causal)
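The additive convention documented in the class docstring can be exercised directly; a minimal sketch using only the constructors defined above:

import torch
from xformers.components.attention.attention_mask import AttentionMask

# Causal mask: 0. where attention is allowed, -inf above the diagonal
causal = AttentionMask.make_causal(seq_len=4)
print(causal.values)     # additive form, shape (1, 4, 4)
print(causal.to_bool())  # boolean form: True where the coefficient is computed

# From a boolean keep/skip pattern
keep = torch.rand(4, 4) > 0.5
mask = AttentionMask.from_bool(keep)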
.venv/lib/python3.11/site-packages/xformers/components/attention/base.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar

import torch
import torch.nn as nn

from xformers._deprecation_warning import deprecated_function
from xformers.components.attention import AttentionMask


@dataclass
class AttentionConfig:
    """Parameters required for all Attentions.
    Can accept and store extra parameters.
    """

    name: str  # the registered name for this attention mechanism
    dropout: float  # dropout probability


Self = TypeVar("Self", bound="Attention")


# Define the common interface, every attention block needs to derive from it
class Attention(nn.Module, metaclass=ABCMeta):
    r"""The base Attention mechanism, which is typically a sub-part of the multi-head attention"""

    _causal_mask: Optional[AttentionMask] = None

    @abstractmethod
    def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
        super().__init__()
        deprecated_function(self)

        # Requires the inputs to be projected
        self.requires_input_projection = True

        # Whether the head dimension needs to be present (if not it can be folded into the batch dimension)
        self.requires_head_dimension = False

        # key padding mask and attention mask must be passed in as separate arguments instead of a merged attention mask
        self.requires_separate_masks = False

        # Requires that K and Q have the same sequence length
        self.requires_same_k_q_dimensions = False

        # Whether the attention owns the single head/multihead mechanism
        # so that the MHA wrapper should skip it
        self.requires_skip_multi_head = False

        # This attention requires a context length which is squared, often due to 2D pooling
        self.requires_squared_context = False

        # Whether this attention mechanism supports attention masks
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    @classmethod
    def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
        # Generate the class inputs from the config
        fields = asdict(config)

        # Skip all Nones so that default values are used
        fields = {k: v for k, v in fields.items() if v is not None}

        return cls(**fields)

    @abstractmethod
    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _maybe_pad_sequence(x: torch.Tensor, mask: torch.Tensor):
        """
        If the sequence is shorter than the mask, return a padded view
        """
        if x.shape[-2] != mask.shape[-1]:
            assert x.shape[-2] < mask.shape[-1], (
                "Sequence is bigger than the provided mask, cannot infer what to do with it."
                " Please update your attention mask"
            )

            pad_size = (0, 0, 0, mask.shape[-1] - x.shape[-2], 0, 0)
            return torch.nn.functional.pad(x, pad_size, mode="constant", value=0.0)

        return x
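To tie this base class to the registry defined in __init__.py above, here is a hedged sketch of a custom subclass, mirroring the register_attention docstring (all names below are illustrative, not part of the library):

from dataclasses import dataclass

import torch
from xformers.components.attention import Attention, AttentionConfig, register_attention


@dataclass
class MyAttentionConfig(AttentionConfig):
    pass  # only the inherited `name` and `dropout` fields


@register_attention("my_attention", MyAttentionConfig)
class MyAttention(Attention):
    def __init__(self, dropout: float = 0.0, *args, **kwargs):
        super().__init__()
        self.drop = torch.nn.Dropout(dropout)

    def forward(self, q, k, v, *args, **kwargs):
        # Plain scaled dot-product, just to make the subclass concrete
        att = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return self.drop(att) @ v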
.venv/lib/python3.11/site-packages/xformers/components/attention/compositional.py
ADDED
@@ -0,0 +1,341 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# Credits: this is heavily inspired by the official implementation, present in
# https://github.com/sarthmit/Compositional-Attention
# Original author: Sarthak Mittal

# This is a simplified version, for the sake of clarity, and because some features could be exposed later
# via the library directly.
# In particular, code paths for TPUs, quantization and gumbel softmax have been removed
# We're also following the same dimension ordering as in the rest of the xformers library
# which is to say [Batch, Sequence, Embedding] wherever possible

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    register_attention,
)
from xformers.components.attention.core import _softmax
from xformers.components.input_projection import InputProjection, InputProjectionConfig


def _either_or(a: Optional[int], b: int) -> int:
    return a if a is not None else b


@dataclass
class CompositionalAttentionConfig(AttentionConfig):
    dim_model: int
    num_heads: int
    dim_attn: Optional[int] = None
    num_rules: Optional[int] = None
    dim_key: Optional[int] = None
    dim_value: Optional[int] = None
    dim_selection: Optional[int] = None
    dropout: float
    qk_rule: bool = False
    nonlinear: bool = False
    q_compose: bool = False
    bias: bool = True
    causal: Optional[bool] = False
    in_proj_container: Optional[InputProjection] = None
    use_separate_proj_weight: Optional[bool] = False


@register_attention("compositional", CompositionalAttentionConfig)
class CompositionalAttention(Attention):
    """Compositional Attention, as proposed in
    "Compositional Attention: Disentangling search and retrieval"_, S. Mittal et al.

    A key insight from this proposal is that the attention mechanism can be conceived as two steps:
    a search and a retrieval operation. When queried, the model can search for the most relevant information
    (Softmax(QKt)), then retrieve information given the Value.

    Contrary to the original attention proposal, which does not consider interactions in between heads,
    the compositional attention will consider all possible interactions and softmax over that dimension,
    so that the information retrieved covers the most relevant dimensions. The number of heads and rules to
    use is thus typically smaller than for a comparable traditional Transformer, and asking for the same number of heads
    may not fit in memory.

    Args:
        dim_model: dimension of the incoming latent space
        num_heads: number of heads *for the search operation*
        dim_attn: dimension (embedding) of the attention
        num_rules: number of rules to consider *for the retrieval operation*
        dim_selection: dimension of the scoring/selection space for the retrievals
        dim_key, dim_value: dimensions of K and V, if different from Q
        dropout: attention dropout probability
        qk_rule: QK product will drive the retrieval process
        nonlinear: use a non linear method to score the retrievals
        bias: use bias in the initial projection step
        causal: causal computations (attend to the past only)

    _"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf
    """

    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        dim_attn: Optional[int] = None,
        num_rules: Optional[int] = None,
        dim_selection: Optional[int] = None,
        dim_key: Optional[int] = None,
        dim_value: Optional[int] = None,
        dropout=0.0,
        qk_rule=False,
        nonlinear=False,
        q_compose=False,
        in_proj_container: Optional[InputProjection] = None,
        use_separate_proj_weight: Optional[bool] = False,
        bias=True,
        causal=False,
        *_,
        **__,
    ):
        super().__init__()

        # Define the inherited flags
        self.requires_skip_multi_head = (
            True  # This attention owns the multi-head mechanism
        )

        # Handle defaults / undefined values
        self.dim_model = dim_model
        num_rules = _either_or(num_rules, num_heads)
        dim_selection = _either_or(dim_selection, dim_model // num_heads)

        # All the initial definition plumbing
        dim_attn = _either_or(dim_attn, dim_model)
        dim_key = _either_or(dim_key, dim_model)
        dim_value = _either_or(dim_value, dim_model)
|
| 126 |
+
self.in_proj_container = (
|
| 127 |
+
in_proj_container
|
| 128 |
+
if in_proj_container is not None
|
| 129 |
+
else InputProjection(
|
| 130 |
+
query_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias),
|
| 131 |
+
key_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias)
|
| 132 |
+
if use_separate_proj_weight
|
| 133 |
+
else None,
|
| 134 |
+
value_proj_params=InputProjectionConfig(dim_model, dim_value, bias=bias)
|
| 135 |
+
if use_separate_proj_weight
|
| 136 |
+
else None,
|
| 137 |
+
)
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
self.num_heads = num_heads
|
| 141 |
+
self.num_rules = num_rules
|
| 142 |
+
self.qk_rule = qk_rule
|
| 143 |
+
self.dim_selection = dim_selection
|
| 144 |
+
self.nonlinear = nonlinear
|
| 145 |
+
self.q_compose = q_compose
|
| 146 |
+
|
| 147 |
+
self.dropout_module = nn.Dropout(dropout)
|
| 148 |
+
self.dim_head = dim_model // num_heads
|
| 149 |
+
self.value_dim = dim_attn // num_rules
|
| 150 |
+
|
| 151 |
+
assert (
|
| 152 |
+
self.value_dim * num_rules == dim_attn
|
| 153 |
+
), "value_dim must be divisible by num_rules"
|
| 154 |
+
|
| 155 |
+
self.scaling = self.dim_head**-0.5
|
| 156 |
+
self.scaling_values = self.dim_selection**-0.5
|
| 157 |
+
|
| 158 |
+
self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias)
|
| 159 |
+
|
| 160 |
+
if self.qk_rule:
|
| 161 |
+
self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias)
|
| 162 |
+
if self.q_compose:
|
| 163 |
+
self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
|
| 164 |
+
else:
|
| 165 |
+
self.value_q = nn.Linear(
|
| 166 |
+
dim_model, self.dim_selection * self.num_heads, bias=bias
|
| 167 |
+
)
|
| 168 |
+
else:
|
| 169 |
+
if self.q_compose:
|
| 170 |
+
self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
|
| 171 |
+
else:
|
| 172 |
+
self.value_q = nn.Linear(
|
| 173 |
+
dim_model, self.dim_selection * self.num_heads, bias=bias
|
| 174 |
+
)
|
| 175 |
+
if self.nonlinear:
|
| 176 |
+
self.score_network: nn.Module = nn.Sequential(
|
| 177 |
+
nn.Linear(
|
| 178 |
+
self.dim_selection + self.value_dim,
|
| 179 |
+
self.dim_selection,
|
| 180 |
+
bias=bias,
|
| 181 |
+
),
|
| 182 |
+
nn.ReLU(),
|
| 183 |
+
nn.Linear(self.dim_selection, 1, bias=bias),
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
self.score_network = nn.Linear(
|
| 187 |
+
self.dim_selection + self.value_dim, 1, bias=bias
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
self.causal = causal
|
| 191 |
+
|
| 192 |
+
# Properties specific to this attention mechanism
|
| 193 |
+
self.supports_attention_mask = True
|
| 194 |
+
self.supports_key_padding_mask = False
|
| 195 |
+
|
| 196 |
+
self._reset_parameters()
|
| 197 |
+
|
| 198 |
+
def _reset_parameters(self):
|
| 199 |
+
# NOTE: in_proj_container is already initialized
|
| 200 |
+
|
| 201 |
+
if self.qk_rule:
|
| 202 |
+
nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2))
|
| 203 |
+
nn.init.xavier_uniform_(self.value_q.weight, gain=1 / math.sqrt(2))
|
| 204 |
+
else:
|
| 205 |
+
nn.init.xavier_uniform_(self.value_q.weight)
|
| 206 |
+
if self.nonlinear:
|
| 207 |
+
nn.init.xavier_uniform_(self.score_network[0].weight)
|
| 208 |
+
nn.init.xavier_uniform_(self.score_network[2].weight)
|
| 209 |
+
else:
|
| 210 |
+
nn.init.xavier_uniform_(self.score_network.weight)
|
| 211 |
+
|
| 212 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
| 213 |
+
if self.out_proj.bias is not None:
|
| 214 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
| 215 |
+
|
| 216 |
+
def forward(
|
| 217 |
+
self,
|
| 218 |
+
q: Tensor,
|
| 219 |
+
k: Tensor,
|
| 220 |
+
v: Tensor,
|
| 221 |
+
att_mask: Optional[Tensor] = None,
|
| 222 |
+
*args,
|
| 223 |
+
**kwargs,
|
| 224 |
+
) -> Tensor:
|
| 225 |
+
"""
|
| 226 |
+
Input shape: Time x Batch x Channel
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
att_mask (ByteTensor, optional): typically used to
|
| 230 |
+
implement causal attention, where the mask prevents the
|
| 231 |
+
attention from looking forward in time (default: None).
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
B, Sq, E = q.shape
|
| 235 |
+
_, Sk, _ = k.shape
|
| 236 |
+
|
| 237 |
+
assert E == self.dim_model
|
| 238 |
+
|
| 239 |
+
# First define projected query/key/values
|
| 240 |
+
# We keep the projected and original tensors in flight,
|
| 241 |
+
# depending on the options the original values could be reused
|
| 242 |
+
q_unprojected = q
|
| 243 |
+
q, k, v = self.in_proj_container(query=q, key=k, value=v)
|
| 244 |
+
q *= self.scaling
|
| 245 |
+
|
| 246 |
+
# Init causal mask if needed, now that we know the context length
|
| 247 |
+
if self.causal and (
|
| 248 |
+
self._causal_mask is None or self._causal_mask.shape[0] != Sk
|
| 249 |
+
):
|
| 250 |
+
self._causal_mask = AttentionMask.make_causal(Sq, Sq, device=q.device)
|
| 251 |
+
|
| 252 |
+
# Convenience, create an attention mask if a tensor was passed
|
| 253 |
+
# This sanitizes different mask types being passed, from now on it's additive
|
| 254 |
+
if isinstance(att_mask, torch.Tensor):
|
| 255 |
+
# By default we don't know of the causality, and a check would be expensive
|
| 256 |
+
att_mask_additive: Optional[AttentionMask] = (
|
| 257 |
+
AttentionMask.from_bool(att_mask)
|
| 258 |
+
if att_mask.dtype == torch.bool
|
| 259 |
+
else AttentionMask(att_mask, is_causal=False)
|
| 260 |
+
)
|
| 261 |
+
else:
|
| 262 |
+
att_mask_additive = None
|
| 263 |
+
|
| 264 |
+
# Handle the attention and key padding masks
|
| 265 |
+
if self._causal_mask is not None:
|
| 266 |
+
# Optionally add the causal mask
|
| 267 |
+
if att_mask_additive is not None:
|
| 268 |
+
att_mask_additive += self._causal_mask
|
| 269 |
+
else:
|
| 270 |
+
att_mask_additive = self._causal_mask
|
| 271 |
+
|
| 272 |
+
# Flatten the heads or the rules
|
| 273 |
+
q = (
|
| 274 |
+
q.view(B, Sq, self.num_heads, self.dim_head)
|
| 275 |
+
.movedim(2, 1)
|
| 276 |
+
.flatten(0, 1) # [B * num_heads, Sq, dim_head]
|
| 277 |
+
)
|
| 278 |
+
k = (
|
| 279 |
+
k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1)
|
| 280 |
+
) # [B * num_heads, Sk, dim_head]
|
| 281 |
+
v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1)
|
| 282 |
+
|
| 283 |
+
# Compute the search: Softmax(QKt)
|
| 284 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2)) # [B * self.num_heads, Sq, Sk]
|
| 285 |
+
|
| 286 |
+
if att_mask_additive is not None:
|
| 287 |
+
attn_weights += att_mask_additive.values
|
| 288 |
+
|
| 289 |
+
attn_weights = _softmax(attn_weights, causal=self.causal)
|
| 290 |
+
|
| 291 |
+
attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
|
| 292 |
+
attn_probs = self.dropout_module(attn_weights)
|
| 293 |
+
|
| 294 |
+
# Now compute the information retrieval
|
| 295 |
+
# keep all the heads in flight, we'll score the different possibilities
|
| 296 |
+
# - compute all the possible retrievals
|
| 297 |
+
v = v.view(B, 1, self.num_rules, Sk, self.value_dim)
|
| 298 |
+
attn_probs = attn_probs.unsqueeze(2)
|
| 299 |
+
attn = torch.matmul(attn_probs, v).view(
|
| 300 |
+
B, self.num_heads, self.num_rules, Sq, self.value_dim
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
attn = attn.movedim(3, 1) # [B, Sq, H, Rules, Values]
|
| 304 |
+
|
| 305 |
+
# - search the most appropriate retrieval among all the values
|
| 306 |
+
if self.q_compose:
|
| 307 |
+
v_q = self.value_q(q.transpose(0, 1)).view(
|
| 308 |
+
B, Sq, self.num_heads, 1, self.dim_selection
|
| 309 |
+
)
|
| 310 |
+
else:
|
| 311 |
+
v_q = self.value_q(q_unprojected).view(
|
| 312 |
+
B, Sq, self.num_heads, 1, self.dim_selection
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
if self.qk_rule:
|
| 316 |
+
v_q *= self.scaling_values
|
| 317 |
+
v_k = (
|
| 318 |
+
self.value_k(attn)
|
| 319 |
+
.view(B, Sq, self.num_heads, self.num_rules, self.dim_selection)
|
| 320 |
+
.transpose(4, 3)
|
| 321 |
+
.contiguous()
|
| 322 |
+
)
|
| 323 |
+
v_score = torch.matmul(v_q, v_k).view(
|
| 324 |
+
B, Sq, self.num_heads, self.num_rules, 1
|
| 325 |
+
)
|
| 326 |
+
else:
|
| 327 |
+
v_q = v_q.expand(-1, -1, -1, self.num_rules, -1)
|
| 328 |
+
v_in = torch.cat([attn, v_q], dim=-1)
|
| 329 |
+
v_score = self.score_network(v_in).view(
|
| 330 |
+
B, Sq, self.num_heads, self.num_rules, 1
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
v_score = F.softmax(v_score, dim=3)
|
| 334 |
+
|
| 335 |
+
# - extracted values are the original attention (inc. all the values) weighted by value score
|
| 336 |
+
attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim)
|
| 337 |
+
|
| 338 |
+
# Final attention projection, same as other mechanisms
|
| 339 |
+
attn = self.out_proj(attn)
|
| 340 |
+
|
| 341 |
+
return attn
|
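A minimal usage sketch of the class added above, assuming this vendored xformers build and PyTorch import cleanly in the environment; the batch, sequence and width sizes below are illustrative assumptions, not values taken from the diff:

import torch
from xformers.components.attention.compositional import CompositionalAttention

# Illustrative sizes (assumption): batch 2, sequence 16, dim_model 64, 4 search heads and 4 retrieval rules
attn = CompositionalAttention(dim_model=64, num_heads=4, num_rules=4, dropout=0.0)
x = torch.randn(2, 16, 64)  # [Batch, Sequence, Embedding] ordering, as stated in the file header
out = attn(q=x, k=x, v=x)   # search (softmax over QK^t), then retrieval scored across the rules
print(out.shape)            # expected: torch.Size([2, 16, 64])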
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__init__.py
ADDED
@@ -0,0 +1,26 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from enum import Enum

from .base import FeatureMap, FeatureMapConfig
from .softmax import NormDistribution, SMHyperbolic, SMOrf, SMReg


class FeatureMapType(str, Enum):
    SMOrf = "sm_orf"
    SMHyp = "sm_hyp"
    SMReg = "sm_reg"  # regularized softmax kernel


__all__ = [
    "SMOrf",
    "SMReg",
    "SMHyperbolic",
    "NormDistribution",
    "FeatureMapConfig",
    "FeatureMap",
]
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (865 Bytes).
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/base.cpython-311.pyc
ADDED
Binary file (3.02 kB).
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/__pycache__/softmax.cpython-311.pyc
ADDED
Binary file (11.7 kB).
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/base.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from abc import abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar

import torch

"""
Feature maps allow for a given query or key to be encoded in a different space.
"""

Self = TypeVar("Self", bound="FeatureMap")


@dataclass
class FeatureMapConfig:
    name: str
    dim_features: int
    iter_before_redraw: Optional[int]
    normalize_inputs: Optional[bool]
    epsilon: Optional[float]


class FeatureMap(torch.nn.Module):
    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int] = None,
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
    ):
        super().__init__()

        self.dim_features = dim_features
        self.dim_feature_map = dim_features

        self.iter_before_redraw = iter_before_redraw
        self.features: Optional[torch.Tensor] = None
        self.epsilon = epsilon
        self.normalize_inputs = normalize_inputs

        self._iter_counter = 0

    @abstractmethod
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        raise NotImplementedError()

    @classmethod
    def from_config(cls: Type[Self], config: FeatureMapConfig) -> Self:
        # Generate the class inputs from the config
        fields = asdict(config)

        # Skip all Nones so that default values are used
        fields = {k: v for k, v in fields.items() if v is not None}

        return cls(**fields)
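To illustrate the contract that `FeatureMap` subclasses fulfil (implement `_get_feature_map`, cache the drawn projection in `self.features`), here is a toy subclass; it is an illustration only, not part of the library, and the sizes are assumptions:

import torch
from xformers.components.attention.feature_maps import FeatureMap

class RandomProjectionMap(FeatureMap):
    # Toy example: encode queries/keys through a fixed random Gaussian projection
    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        return torch.randn(dim_input, dim_features, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.features is None:  # draw the projection lazily, on the right device
            self.features = self._get_feature_map(x.shape[-1], self.dim_feature_map, x.device)
        return x @ self.features

print(RandomProjectionMap(dim_features=16)(torch.randn(2, 8, 32)).shape)  # torch.Size([2, 8, 16])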
.venv/lib/python3.11/site-packages/xformers/components/attention/feature_maps/softmax.py
ADDED
@@ -0,0 +1,288 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import math
from enum import Enum, auto
from typing import Optional

import torch
from torch.autograd.profiler import record_function

from .base import FeatureMap

"""
A set of feature maps which approximate the softmax kernel, as per the Performers_ paper.

_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""


class NormDistribution(Enum):
    Xi = auto()
    Uniform = auto()


class SoftMaxPositiveEstimators(FeatureMap):
    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon)
        self.softmax_temp = softmax_temp

        # Handle the scaling from all kernels by √m.
        # This normalizes for all the feature maps involved
        self.h_scale = math.log(math.sqrt(self.dim_features))

    def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
        with record_function("feature_map::pre_scale"):
            # Re-draw counting logic
            if (
                (
                    self.iter_before_redraw is not None
                    and self._iter_counter > self.iter_before_redraw
                )
                or self.features is None
                or self.features.device != x.device
            ):
                # The feature map is actually using half the dimension, we'll concatenate + and - features
                self._iter_counter = 1
                self.features = self._get_feature_map(
                    x.shape[-1], self.dim_feature_map, x.device
                )

            features = self.features
            assert features is not None

            if features.dtype != x.dtype:
                self.features = features.to(x.dtype)

            self._iter_counter += 1

            # Normalization / softmax
            if self.softmax_temp < 0:
                # A = exp(QK.t/√d), so each input will be scaled by √√d
                self.softmax_temp = x.shape[-1] ** -0.25

            x_scaled = x * self.softmax_temp

            # Compute the scaling factors in logspace, applied from within the exponential
            # - dimnish possible exponential overflow
            # - remove a multiply across the batch, replace by an addition
            norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
            self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon

            if self.normalize_inputs:
                # L0 normalize the exponential term, can be useful for numerical stability
                # This ensures that features +- offset is below 1
                self.offset -= norm_x_2.max(1, keepdim=True)[0]

        # Return the scaled inputs, the rest depends on the kernel being used
        return x_scaled

    @staticmethod
    @torch.no_grad()
    def _get_random_ortho_matrix(
        blocks: int,
        dim: int,
        device: torch.device,
        norm_distribution: NormDistribution = NormDistribution.Uniform,
    ) -> torch.Tensor:
        r"""
        Generate a random matrix whose rows are exactly orthonormal

        "How to generate random matrices from the classical compact groups", Mezzadri, 2007
        https://arxiv.org/pdf/math-ph/0609050v2.pdf

        .. note: the typical qr decomposition does not give uniform results, qr decomposition is not
        unique and the qr decomposition routines are biased towards numerical stability. See the above
        paper for more information.

        .. note: this does not follow the original implementation from the Performers authors.
        see docs/assets/kde plots to visualize the impact of using the R signs to correct Q
        """

        H = torch.randn((blocks, dim, dim), device=device, requires_grad=False)

        # Randomly scale the norms of the features, Xi distributed
        if norm_distribution == NormDistribution.Xi:
            # NOTE: This averages to sqrt(d)
            norms = torch.sqrt(torch.einsum("...d,...d->...", H, H))

        Q, R = torch.linalg.qr(H)
        Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=1, dim2=2))) @ Q

        # Normalize if need be. Uniform NormDistribution does nothing, Q is already orthonormal
        if norm_distribution == NormDistribution.Xi:
            return torch.diag_embed(norms) @ Q

        return Q


class SMOrf(SoftMaxPositiveEstimators):
    """
    "Positive random orthogonal features" softmax estimator,
    SM_ort^m+, as proposed in the Performers_ paper, Lemma 1.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
        and not uniformally random.
        """

        # Get per block random unitary matrices.
        # We need enough of them to project the whole input dimension, regardless of the
        # requested dimension of the features
        features = self._get_random_ortho_matrix(
            math.ceil(dim_input / dim_features),
            dim_features,
            norm_distribution=NormDistribution.Xi,
            device=device,
        )

        return features.flatten(0, 1)[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Softmax-dimension related scaling, shared for all kernels
        x_scaled = super().pre_scale(x)
        assert self.features is not None

        # Project onto the random feature map.
        x_scaled = x_scaled @ self.features
        return torch.exp(x_scaled + self.offset)


class SMHyperbolic(SoftMaxPositiveEstimators):
    """
    "Positive random features hyperbolic" estimator, SMHyp+,
    as proposed in the Performers_ paper, Lemma 1.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(
            dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
        )

        assert (
            dim_features % 2 == 0
        ), "The feature dimension needs to be even with this kernel"
        self.dim_feature_map = self.dim_features // 2

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
        and not uniformally random.
        """

        # Get per block random unitary matrices.
        # We need enough of them to project the whole input dimension, regardless of the
        # requested dimension of the features
        features = self._get_random_ortho_matrix(
            math.ceil(dim_input / dim_features),
            dim_features,
            norm_distribution=NormDistribution.Xi,
            device=device,
        )

        return features.flatten(0, 1)[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Softmax-dimension related scaling, shared for all kernels
        x_scaled = super().pre_scale(x)

        # Project onto the random feature map, concatenate both + and - results
        # This follows Lemma 1 in the original Performers Paper to best approximate a
        # softmax kernel (cosh representation)
        x_scaled = x_scaled @ self.features
        return torch.cat(
            [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
            dim=-1,
        )


class SMReg(SoftMaxPositiveEstimators):
    """
    "Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(
            dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
        )

        assert (
            dim_features % 2 == 0
        ), "The feature dimension needs to be even with this kernel"
        self.dim_feature_map = self.dim_features // 2

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
        and not uniformally random.
        """

        # Get per block random unitary matrices.
        # We need enough of them to project the whole input dimension, regardless of the
        # requested dimension of the features
        features = self._get_random_ortho_matrix(
            math.ceil(dim_input / dim_features),
            dim_features,
            norm_distribution=NormDistribution.Uniform,
            device=device,
        ).flatten(0, 1)
        norms = math.sqrt(dim_input) * torch.ones(features.shape[0], device=device)
        return (torch.diag(norms) @ features)[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Softmax-dimension related scaling, shared for all kernels
        x_scaled = super().pre_scale(x)

        # Project onto the random feature map, concatenate both + and - results
        # This follows Lemma 1 in the original Performers Paper to best approximate a
        # softmax kernel (cosh representation + sample regularization)
        x_scaled = x_scaled @ self.features
        return torch.cat(
            [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
            dim=-1,
        )
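These estimators are typically plugged into Performer-style linear attention: map q and k to positive features, then contract over the sequence axis once instead of forming the full S x S score matrix. A sketch under the assumption that the SMOrf class above is importable from this environment; the tensor shapes are illustrative, and this is not how xformers wires the map internally (its favor attention module does that):

import torch
from xformers.components.attention.feature_maps import SMOrf

feature_map = SMOrf(dim_features=64, iter_before_redraw=None)  # keep one random projection across calls

q, k, v = (torch.randn(2, 128, 32) for _ in range(3))
phi_q, phi_k = feature_map(q), feature_map(k)        # (B, S, 64), strictly positive features

kv = torch.einsum("bsf,bsd->bfd", phi_k, v)          # contract over keys once: O(S), not O(S^2)
normalizer = (phi_q @ phi_k.sum(dim=1).unsqueeze(-1)).clamp_min(1e-6)
approx_attention = (phi_q @ kv) / normalizer         # approximation of the softmax attention output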
.venv/lib/python3.11/site-packages/xformers/components/attention/global_tokens.py
ADDED
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    maybe_sparsify,
    register_attention,
    sparsify,
)
from xformers.components.attention.attention_patterns import (
    causal_1d_pattern,
    global_token_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass
class GlobalAttentionConfig(AttentionConfig):
    attention_query_mask: torch.Tensor  # Mark the queries which have global attention
    causal: Optional[bool]
    force_sparsity: Optional[bool]


@register_attention("global", GlobalAttentionConfig)
class GlobalAttention(Attention):
    def __init__(
        self,
        dropout: float,
        attention_query_mask: torch.Tensor,
        causal: bool = False,
        force_sparsity: bool = False,
        *_,
        **__,
    ):
        r"""
        Global attention, as proposed for instance in BigBird_ or Longformer_.

        Global means in that case that the queries positively labelled in the ```attention_query_mask``` can attend
        to all the other queries. The queries negatively labelled in the ```attention_query_mask``` cannot attend to
        any other query.

        This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.

        Args:
            dropout (float): probability of an element to be zeroed
            attention_query_mask (torch.Tensor): if true, this query can attend to all the others

        """
        super().__init__()

        assert attention_query_mask.dtype == torch.bool, "A boolean mask is expected"
        assert (
            attention_query_mask.shape[1] == 1
            and attention_query_mask.shape[0] > attention_query_mask.shape[1]
        ), "A N x 1 query mask is expected"

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.attention_mask = global_token_pattern(attention_query_mask[:, 0])
        self.force_sparsity = force_sparsity

        if causal:
            self.attention_mask &= causal_1d_pattern(attention_query_mask.shape[1])

        self.attention_mask = (
            sparsify(self.attention_mask)
            if self.force_sparsity
            else maybe_sparsify(self.attention_mask)
        )

        # Properties specific to this attention mechanism
        self.requires_same_k_q_dimensions = True
        self.supports_attention_mask = False
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        *_,
        **__,
    ):
        # Make sure that the mask is on the right device
        if self.attention_mask.device != q.device:
            self.attention_mask = self.attention_mask.to(q.device)

        # Mask-aware attention
        if att_mask is not None:
            if att_mask.dtype == torch.bool and isinstance(
                self.attention_mask, AttentionMask
            ):
                if not isinstance(att_mask, AttentionMask):
                    att_mask = AttentionMask.from_bool(att_mask)
                mask = self.attention_mask + att_mask
            else:
                mask = self.attention_mask & att_mask
        else:
            mask = self.attention_mask

        # Handle q/k/v which would not fit the mask
        seq_len = q.shape[-2]
        q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))

        # Normal attention with the global tokens mask
        att = scaled_dot_product_attention(
            q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
        )

        # Take into account an hypothetical padding
        return att[:, :seq_len, :]
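A rough usage sketch for the class above. The sequence length, the number of global positions, and the feature width are all assumptions for illustration; note that very sparse patterns may be converted to a sparse representation by maybe_sparsify, so the exact execution path can differ from this CPU-oriented sketch:

import torch
from xformers.components.attention.global_tokens import GlobalAttention

SEQ = 16
query_mask = torch.zeros(SEQ, 1, dtype=torch.bool)  # N x 1 boolean mask, as the asserts above require
query_mask[:6] = True                               # assumption: the first 6 positions act as global tokens

attn = GlobalAttention(dropout=0.0, attention_query_mask=query_mask)
x = torch.randn(2, SEQ, 32)
out = attn(q=x, k=x, v=x)  # flagged queries attend to every position; the others only to the flagged ones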
.venv/lib/python3.11/site-packages/xformers/components/attention/linformer.py
ADDED
@@ -0,0 +1,74 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass
class LinformerSelfAttentionConfig(AttentionConfig):
    seq_len: int  # dimension of the input sequence
    k: Optional[int]  # dimension of the internal space


@register_attention("linformer", LinformerSelfAttentionConfig)
class LinformerAttention(Attention):
    def __init__(
        self, dropout: float, seq_len: int, k: Optional[int] = None, *args, **kwargs
    ):
        """
        Linformer attention mechanism,
        from `Linformer: Self-Attention with Linear Complexity`_, Wang et al (2020).
        The original notation is kept as is.

        .. _`Linformer: Self-Attention with Linear Complexity` : https://arxiv.org/abs/2006.04768v2
        """
        super().__init__()

        if k is None:
            k = seq_len // 4

        self.k = k
        self.E = nn.Linear(seq_len, k, bias=False)
        self.F = nn.Linear(seq_len, k, bias=False)
        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.seq_len = seq_len

        # MHA related flags:
        # kq need to have the same dimension
        self.requires_same_k_q_dimensions = True

        # This attention does not support attention masks
        self.supports_attention_mask = False

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ):
        # Handle a smaller dimension than expected
        padding = 0
        if q.shape[1] < self.seq_len:
            padding = self.seq_len - q.shape[1]
            pad_dims = (0, 0, 0, padding)
            q = torch.nn.functional.pad(q, pad_dims)
            k = torch.nn.functional.pad(k, pad_dims)
            v = torch.nn.functional.pad(v, pad_dims)

        k_projected = self.E(k.transpose(-2, -1)).transpose(-2, -1)
        v_projected = self.F(v.transpose(-2, -1)).transpose(-2, -1)

        y = scaled_dot_product_attention(
            q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop
        )

        y = self.attn_drop(y)

        return y[:, :-padding, :] if padding > 0 else y
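The two `nn.Linear(seq_len, k)` projections E and F act on the sequence axis, so keys and values are compressed from `seq_len` positions down to `k` before the usual scaled dot-product attention. A usage sketch with illustrative sizes (the specific numbers are assumptions):

import torch
from xformers.components.attention.linformer import LinformerAttention

attn = LinformerAttention(dropout=0.0, seq_len=256, k=64)  # k would default to seq_len // 4 anyway
x = torch.randn(2, 256, 32)
out = attn(q=x, k=x, v=x)
print(out.shape)  # torch.Size([2, 256, 32]); attention was computed against only 64 projected positions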
.venv/lib/python3.11/site-packages/xformers/components/attention/ortho.py
ADDED
@@ -0,0 +1,324 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import torch
import torch.autograd.profiler as profiler
import torch.nn as nn
import torch.nn.functional as Fn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    register_attention,
)
from xformers.components.attention.core import (
    scaled_dot_product_attention,
    scaled_query_key_softmax,
)

logger = logging.getLogger("xformers")


class LandmarkSelection(str, Enum):
    Orthogonal = "orthogonal"
    KMeans = "kmeans"
    KMeans_Spherical = "kmeans_spherical"
    Random = "random"


@dataclass
class OrthoformerAttentionConfig(AttentionConfig):
    """
    num_landmarks       Number of landmarks to use for softmax approximation.
    subsample_fraction  Percentage of q_samples matrix to sample per iteration
    landmark_selection  Landmark selection strategy
    """

    num_landmarks: Optional[int]
    subsample_fraction: Optional[float]
    landmark_selection: Optional[LandmarkSelection]


@register_attention("orthoformer", OrthoformerAttentionConfig)
class OrthoFormerAttention(Attention):
    def __init__(
        self,
        dropout: float,
        num_landmarks: int = 32,
        subsample_fraction: float = 1.0,
        landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal,
        *args,
        **kwargs,
    ):
        """
        Orthoformer_ attention mechanism.
        ::

            "Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers"
            Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F., Feichtenhofer,
            C., Vedaldi, A., Henriques, J. (2021)

            Reference codebase: https://github.com/facebookresearch/Motionformer

        .. _Orthoformer: https://arxiv.org/abs/2106.05392

        """
        super().__init__()

        self.num_landmarks = num_landmarks
        self.attn_drop = nn.Dropout(dropout)
        self.subsample_fraction = subsample_fraction
        self.landmark_selection = landmark_selection

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
        *args,
        **kwargs,
    ):
        N = k.shape[1]

        if self.num_landmarks == N:
            # Default attention
            x = scaled_dot_product_attention(q, k, v, att_mask)
        else:
            with torch.no_grad(), profiler.record_function("select landmarks"):
                if self.landmark_selection == LandmarkSelection.Orthogonal:
                    landmarks = self._compute_orthogonal_landmarks(q)
                elif self.landmark_selection == LandmarkSelection.Random:
                    half_L = self.num_landmarks // 2
                    landmarks_q = q[:, torch.randint(q.size(1), (half_L,)), :]
                    landmarks_k = k[:, torch.randint(k.size(1), (half_L,)), :]
                    landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2)
                elif self.landmark_selection == LandmarkSelection.KMeans:
                    landmarks = self._cluster_landmarks(q)
                elif self.landmark_selection == LandmarkSelection.KMeans_Spherical:
                    landmarks = self._cluster_landmarks(q, spherical=True)

            if att_mask is not None:
                logger.warning(
                    "Orthoformer: attention mask passed alongside with using landmarks to reduce dimensions. \
                    The two are typically not compatible"
                )
                # FIXME: Should we still accept a mask in that case ?
                att_mask = None

            # pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems
            # like it could be uninitialized.
            kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask)
            # pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems
            # like it could be uninitialized.
            kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask)
            x = torch.matmul(kernel_1, torch.matmul(kernel_2, v))
        x = self.attn_drop(x)
        return x

    def _cluster_landmarks(
        self,
        q: torch.Tensor,
        spherical: bool = False,
        num_iters: int = 6,
    ) -> torch.Tensor:
        """
        Construct set of landmarks by recursively selecting new landmarks
        that are maximally orthogonal to the existing set.
        Returns near orthogonal landmarks with shape (B, M, D).
        """

        num_landmarks = min(self.num_landmarks, q.shape[1])

        if self.subsample_fraction < 1.0:
            num_samples = max(
                int(self.subsample_fraction * q.size(-2)), num_landmarks
            )  # Need at least M/2 samples of queries and keys
            q_samples = q[:, torch.randint(q.size(-2), (num_samples,)), :]  # (B, N, D)
        else:
            q_samples = q  # (B, N, D)

        if spherical:
            q_samples_normalized = Fn.normalize(
                q_samples, p=2, dim=-1
            )  # may need to change default eps to eps=1e-8 for mixed precision compatibility
            landmarks = self._kmeans_spherical(
                q_samples_normalized, num_landmarks, num_iters
            )
        else:
            landmarks = self._kmeans(q_samples, num_landmarks, num_iters)
        return landmarks  # (B, M, D)

    def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10):
        """
        Arguments:
            x: (B, N, D)
            K: number of clusters
            num_iters: the number of kmeans updates
        """

        B, N, D = x.size()
        assert K <= N, f"{K} > {N}"

        c = x[
            :, torch.randperm(N, device=x.device)[:K], :
        ].clone()  # initialisation for the centroids

        with profiler.record_function("kmeans"):
            x_i = x.view(B, N, 1, D)
            c_j = c.view(B, 1, K, D)
            counts = c.new_zeros(B, K)
            ones = x.new_ones((B, N))

            for _ in range(num_iters):
                # E step: assign points to the nearest cluster
                D_ij = ((x_i - c_j) ** 2).sum(-1)  # (B, N, K) squared distances
                cl = D_ij.argmin(
                    dim=-1, keepdim=True
                ).long()  # (B, N, 1) index of point to nearest cluster

                # M step: update the centroids
                c.zero_()
                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
                counts.fill_(1e-6)  # avoid div0
                counts.scatter_add_(
                    -1, cl.squeeze(-1), ones
                )  # number of points per cluster
                c.divide_(counts.unsqueeze(-1))  # compute the average

        return c

    def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters=10):
        """
        Arguments:
            x: (B, N, D)
        """
        B, N, D = x.size()
        assert K <= N, f"{K} > {N}"

        # initialisation for the centroids
        c = x[:, torch.randperm(N, device=x.device)[:K], :].clone()

        with profiler.record_function("kmeans_spherical"):
            counts = c.new_zeros(B, K)
            ones = x.new_ones((B, N))

            for _ in range(num_iters):
                # E step: assign points to the nearest cluster
                D_ij = torch.matmul(
                    x, c.transpose(-2, -1)
                )  # (B, N, K) cosine similarity
                cl = D_ij.argmax(
                    dim=-1, keepdim=True
                ).long()  # (B, N, 1) index of point to nearest cluster

                # M step: update the centroids
                c.zero_()
                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
                counts.fill_(1e-6)  # avoid div0
                counts.scatter_add_(
                    -1, cl.squeeze(-1), ones
                )  # number of points per cluster
                c.divide_(counts.unsqueeze(-1))  # compute the average
                c = Fn.normalize(c, p=2, dim=-1)  # renormalise
        return c

    def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
        """
        Construct set of landmarks by recursively selecting new landmarks
        that are maximally orthogonal to the existing set.
        Returns near orthogonal landmarks with shape (B, M, D).
        """

        if self.subsample_fraction < 1.0:
            # Need at least M samples of queries
            num_samples = max(
                int(self.subsample_fraction * q.size(-2)), self.num_landmarks
            )
            q_samples = q[
                :, torch.randint(q.size(-2), (num_samples,), device=q.device), :
            ]
        else:
            # (B, N, D)
            q_samples = q

        # may need to change default eps to eps=1e-8 for mixed precision compatibility
        q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1)
        B, N, D = q_samples_normalized.shape

        selected_mask = torch.zeros((B, N, 1), device=q_samples_normalized.device)
        landmark_mask = torch.ones(
            (B, 1, 1), dtype=selected_mask.dtype, device=q_samples_normalized.device
        )

        # Get initial random landmark
        random_idx = torch.randint(
            q_samples_normalized.size(-2), (B, 1, 1), device=q_samples_normalized.device
        )
        selected_mask.scatter_(-2, random_idx, landmark_mask)

        # Selected landmarks
        selected_landmarks = torch.empty(
            (B, self.num_landmarks, D),
            device=q_samples_normalized.device,
            dtype=q_samples_normalized.dtype,
        )
        selected_landmarks[:, 0, :] = q_samples_normalized[
            torch.arange(q_samples_normalized.size(0)), random_idx.view(-1), :
        ].view(B, D)

        # Store computed cosine similarities
        cos_sims = torch.empty(
            (B, N, self.num_landmarks),
            device=q_samples_normalized.device,
            dtype=q_samples_normalized.dtype,
        )

        for M in range(1, self.num_landmarks):
            with profiler.record_function("find new landmark"):
                # Calculate absolute cosine similarity between selected and unselected landmarks
                # (B, N, D) * (B, D) -> (B, N)
                cos_sims[:, :, M - 1] = torch.einsum(
                    "b n d, b d -> b n",
                    q_samples_normalized,
                    selected_landmarks[:, M - 1, :],
                ).abs()

                # (B, N, M) cosine similarities of current set of landmarks wrt all queries and keys
                cos_sim_set = cos_sims[:, :, :M]

                # Get orthogonal landmark: landmark with smallest absolute cosine similarity:
                # set cosine similarity for already selected landmarks to > 1
                cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10

                # (B,) - want max for non
                selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1)

                # Add most orthogonal landmark to selected landmarks:
                selected_landmarks[:, M, :] = q_samples_normalized[
                    torch.arange(q_samples_normalized.size(0)), selected_landmark_idx, :
                ].view(B, D)

                # Removed selected indices from non-selected mask:
                selected_mask.scatter_(
                    -2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask
                )

        # (B, M, D)
        landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape(
            B, -1, D
        )
        return landmarks  # (B, M, D)
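The landmark trick above replaces the full S x S softmax with two smaller ones (queries vs landmarks, landmarks vs keys). A usage sketch; the landmark count and tensor sizes are arbitrary assumptions, and it presumes this vendored module imports as-is:

import torch
from xformers.components.attention.ortho import LandmarkSelection, OrthoFormerAttention

attn = OrthoFormerAttention(
    dropout=0.0,
    num_landmarks=16,                                  # assumption: 16 landmarks for a length-128 sequence
    landmark_selection=LandmarkSelection.Orthogonal,
)
x = torch.randn(2, 128, 32)
out = attn(q=x, k=x, v=x)  # kernel_1 (128 x 16) and kernel_2 (16 x 128) stand in for the 128 x 128 map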
.venv/lib/python3.11/site-packages/xformers/components/attention/pooling.py
ADDED
@@ -0,0 +1,82 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention


@dataclass
class PoolingAttentionConfig(AttentionConfig):
    pool_size: int  # dimension of the input sequence
    stride: Optional[int]  # dimension of the internal space
    padding: Optional[int]


@register_attention("pooling", PoolingAttentionConfig)
class Pooling(Attention):
    def __init__(
        self,
        pool_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        *_,
        **__,
    ):
        """
        Pooling token mixing mechanism, as proposed in
        `Metaformer is actually what you need for vision`_, Yu et al (2021).

        The original notation is kept as is.

        .. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf
        """
        super().__init__()

        padding = padding if padding is not None else pool_size // 2
        self.pool = nn.AvgPool2d(
            pool_size,
            stride=stride,
            padding=pool_size // 2,
            count_include_pad=False,
        )

        # MHA related flags:
        # kq need to have the same dimension
        self.requires_same_k_q_dimensions = False

        # This attention does not support attention masks
        self.supports_attention_mask = False

        # This "attention" (token mixing) skips the multihead attention altogether
        self.requires_skip_multi_head = True
        self.requires_input_projection = False

        # This operator does not really handle q,k,v
        self.requires_same_k_q_dimensions = True

        # This attention requires the 2d structure out of the context,
        # implictly assumed to be a squared length
        self.requires_squared_context = True

    def forward(self, q: torch.Tensor, *_, **__):
        # Expose the 2D token structure
        B, HW, C = q.shape
        H = int(math.sqrt(HW))
        assert H * H == HW

        q = q.transpose(-2, -1).reshape(B, C, H, H)

        # 2D pool
        x_pool = self.pool(q) - q  # compensate for the residual path

        # Get back to B HW C
        return x_pool.flatten(2, 3).transpose(-2, -1)
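As a quick sanity check of the Pooling token mixer above, the sketch below feeds it a (B, HW, C) tensor with a square token grid. It assumes the mirrored xformers package is importable; the shapes are illustrative only.

import torch
from xformers.components.attention.pooling import Pooling

# 16 x 16 token grid flattened to HW = 256, channel dim C = 32
x = torch.randn(4, 256, 32)
mixer = Pooling(pool_size=3)
y = mixer(x)      # pure token mixing, no q/k/v projections involved
print(y.shape)    # torch.Size([4, 256, 32])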
.venv/lib/python3.11/site-packages/xformers/components/attention/sparsity_config.py
ADDED
|
@@ -0,0 +1,812 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
The code has been adopted from DeepSpeed
(https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/sparse_attention/sparsity_config.py)
"""

import random

import torch


class SparsityConfig:
    """Abstract Configuration class to store `sparsity configuration of a self attention layer`.
    It contains shared property of different block-sparse sparsity patterns. However, each class
    needs to extend it based on required property and functionality.
    """

    def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
        """Initialize the Sparsity Pattern Config.
        Arguments:
            num_heads: required: an integer determining number of attention heads of the layer.
            block_size: optional: an integer determining the block size. Current implementation of
                sparse self-attention is based on blocked sparse matrices. In which this parameter
                defines size of such blocks, `Block X Block`.
            different_layout_per_head: optional: a boolean determining if each head should be
                assigned a different sparsity layout; default is false and this will be satisfied
                based on availability.
        """

        self.num_heads = num_heads
        self.block_size = block_size
        self.different_layout_per_head = different_layout_per_head
        self.num_layout_heads = num_heads if different_layout_per_head else 1

    def setup_layout(self, seq_len):
        """Create layout tensor for the given sequence length
        Arguments:
            seq_len: required: an integer determining number of attention heads of the layer.
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout
                of all head; initialized with zero
        """

        if seq_len % self.block_size != 0:
            raise ValueError(
                f"Sequence Length, {seq_len}, needs to be dividable by Block size {self.block_size}!"
            )
        num_blocks = seq_len // self.block_size
        # TODO Currently we allocate layout per head; needs to be updated if heads share a single layout.
        layout = torch.zeros(
            (self.num_heads, num_blocks, num_blocks), dtype=torch.int64
        )
        return layout

    def check_and_propagate_first_head_layout(self, layout):
        """If all heads require same sparsity layout, it propagate first head layout to all heads
        Arguments:
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head
        """

        if not self.different_layout_per_head:
            layout[1 : self.num_heads, :, :] = layout[0, :, :]
        return layout


class DenseSparsityConfig(SparsityConfig):
    """Configuration class to store `Dense` configuration.
    In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and
    comprehension.
    """

    def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
        """Initialize the Dense Sparsity Pattern Config.
        In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison
        and comprehension.
        Arguments:
            num_heads: required: an integer determining number of attention heads of the layer.
            block_size: optional: an integer determining the block size. Current implementation of
                sparse self-attention is based on blocked sparse matrices. In which this parameter
                defines size of such blocks, `Block X Block`.
            different_layout_per_head: optional: this is just for the sake of consistency with
                other sparsity formats; can ignore it for DenseSparsityConfig
        """

        super().__init__(num_heads, block_size, different_layout_per_head)

    def make_layout(self, seq_len):
        """Set 1 to all blocks of the layout meanins the pattern is dense; not sparse.
        Arguments:
            seq_len: required: an integer determining the underling sequence length;
                must be <= max sequence length
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head; for dense everything is 1
        """

        layout = self.setup_layout(seq_len)
        layout[:, :, :] = 1
        return layout


class FixedSparsityConfig(SparsityConfig):
    """Configuration class to store `Fixed` sparsity configuration.
    For more details about this sparsity config, please see `Generative Modeling with
    Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
    This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
    """

    def __init__(
        self,
        num_heads,
        block_size=16,
        different_layout_per_head=False,
        num_local_blocks=4,
        num_global_blocks=1,
        attention="bidirectional",
        horizontal_global_attention=False,
        num_different_global_patterns=1,
    ):
        """Initialize `Fixed` Sparsity Pattern Config.
        For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
            num_heads: required: an integer determining number of attention heads of the layer.
            block_size: optional: an integer determining the block size. Current implementation of
                sparse self-attention is based on blocked sparse matrices. In which this parameter
                defines size of such blocks, `Block X Block`.
            different_layout_per_head: optional: a boolean determining if each head should be
                assigned a different sparsity layout; default is false and this will be satisfied
                based on availability.
            num_local_blocks: optional: an integer determining the number of blocks in local attention
                window.
            num_global_blocks: optional: an integer determining how many consecutive blocks in a local
                window is used as the representative of the window for global attention.
            attention: optional: a string determining attention type. Attention can be `unidirectional`,
                such as autoregressive models, in which tokens attend only to tokens appear before them
                in the context. Considering that, the upper triangular of attention matrix is empty as
                above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
                any other tokens before or after them. Then, the upper triangular part of the attention
                matrix is mirror of the lower triangular in the above figure.
            horizontal_global_attention: optional: a boolean determining if blocks that are global
                representative of a local window, also attend to all other blocks. This is valid only if
                attention type is `bidirectional`. Looking at the attention matrix, that means global
                attention not only includes the vertical blocks, but also horizontal blocks.
            num_different_global_patterns: optional: an integer determining number of different global
                attentions layouts. While global attention can be fixed by which block/s are representative
                of any local window, since there are multi-heads, each head can use a different global representative.
                For example, with 4 blocks local window and global attention size of 1 block, we can have 4 different
                versions in which the first, Second, third, or forth block of each local window can be global
                representative of that window. This parameter determines how many of such patterns we want.
                Of course, there is a limitation based on num_local_blocks and num_global_blocks.
        """

        super().__init__(num_heads, block_size, different_layout_per_head)

        self.num_local_blocks = num_local_blocks

        if num_local_blocks % num_global_blocks != 0:
            raise ValueError(
                f"""Number of blocks in a local window, {num_local_blocks},
                must be dividable by number of global blocks, {num_global_blocks}!"""
            )
        self.num_global_blocks = num_global_blocks

        if attention != "unidirectional" and attention != "bidirectional":
            raise NotImplementedError(
                'only "uni/bi-directional" attentions are supported for now!'
            )
        self.attention = attention

        if attention != "bidirectional" and horizontal_global_attention:
            raise ValueError(
                'only "bi-directional" attentions can support horizontal global attention!'
            )
        self.horizontal_global_attention = horizontal_global_attention

        if num_different_global_patterns > 1 and not different_layout_per_head:
            raise ValueError(
                """Number of different layouts cannot be more than one when you have set a single layout
                for all heads! Set different_layout_per_head to True."""
            )
        if num_different_global_patterns > (num_local_blocks // num_global_blocks):
            raise ValueError(
                f"""Number of layout versions (num_different_global_patterns), {num_different_global_patterns},
                cannot be larger than number of local window blocks divided by number of global blocks,
                {num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!"""
            )
        self.num_different_global_patterns = num_different_global_patterns

    def set_local_layout(self, h, layout):
        """Sets local attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which local layout is set
        """

        num_blocks = layout.shape[1]
        for i in range(0, num_blocks, self.num_local_blocks):
            end = min(i + self.num_local_blocks, num_blocks)
            for row in range(i, end):
                for col in range(
                    i, (row + 1 if self.attention == "unidirectional" else end)
                ):
                    layout[h, row, col] = 1
        return layout

    def set_global_layout(self, h, layout):
        """Sets global attention layout used by the given head in the sparse attention.
        Currently we set global blocks starting from the last block of a local window to the first one.
        That means if a local window consists of 4 blocks and global attention size is one block, we use
        block #4 in each local window as global. If we have different layout per head, then other heads
        will get #3, #2, and #1. And if we have more heads (and different layout has set) than num of global
        attentions, multiple head may have same global attentions.
        Note) if horizontal_global_attention is set, global blocks will be set both horizontally and
        vertically.
        Arguments:
            h: required: an integer determining head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which global layout is set
        """

        num_blocks = layout.shape[1]
        first_global_block_idx = (
            self.num_local_blocks
            - (1 + h % self.num_different_global_patterns) * self.num_global_blocks
        )

        # set all global blocks except the last one if (in last local window)
        end = num_blocks - (num_blocks % self.num_local_blocks)
        for i in range(first_global_block_idx, end, self.num_local_blocks):

            # vertical global attention
            first_row = 0 if self.attention == "bidirectional" else i
            # (((i // self.num_local_blocks) + 1) * self.num_local_blocks)
            # if (first_row < num_blocks):
            layout[h, first_row:, i : i + self.num_global_blocks] = 1

            # horizontal global attention; only in bidirectional attention
            if self.horizontal_global_attention:
                layout[h, i : i + self.num_global_blocks, :] = 1

        # set last global blocks; handle possible short last local window
        if end < num_blocks:
            start = min(
                end + first_global_block_idx, num_blocks - self.num_global_blocks
            )
            end = start + self.num_global_blocks

            # vertical global attention
            first_row = 0 if self.attention == "bidirectional" else start
            # (((start // self.num_local_blocks) + 1) * self.num_local_blocks)
            # if (first_row < num_blocks):
            layout[h, first_row:, start:end] = 1

            # horizontal global attention
            if self.horizontal_global_attention:
                layout[h, start:end, :] = 1
        return layout

    def make_layout(self, seq_len):
        """Generates `Fixed` sparsity layout used by each head in the sparse attention.
        Arguments:
            seq_len: required: an integer determining number of attention heads of the layer.
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed`
                sparsity layout of all head
        """

        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_local_layout(h, layout)
            layout = self.set_global_layout(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout


class VariableSparsityConfig(SparsityConfig):
    """Configuration class to store `Variable` sparsity configuration.
    This layout is an extension of FixedSparsityConfig in which:
      - user can set random layout; default value is zero means no random block
      - user can provide a list of local block sizes
      - user can provide a list of global block indices.
    For more details about `Fixed` sparsity config, please see `Generative Modeling with
    Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
    This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
    """

    def __init__(
        self,
        num_heads,
        block_size=16,
        different_layout_per_head=False,
        num_random_blocks=0,
        local_window_blocks=[4],
        global_block_indices=[0],
        global_block_end_indices=None,
        attention="bidirectional",
        horizontal_global_attention=False,
    ):
        """Initialize `Variable` Sparsity Pattern Config.
        For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
            num_heads: required: an integer determining number of attention heads of the layer.
            block_size: optional: an integer determining the block size. Current implementation of sparse
                self-attention is based on blocked sparse matrices. In which this parameter defines
                size of such blocks, `Block X Block`.
            different_layout_per_head: optional: a boolean determining if each head should be assigned a
                different sparsity layout; default is false and this will be satisfied based on
                availability. Currently this sparsity config can only assign single layout to all heads;
                needs to be extended for different layout per head.
            num_random_blocks: optional: an integer determining the number of random blocks in each block row.
            local_window_blocks: optional: a list of integers determining the number of blocks in each
                local attention window. It assumes first number determines # of blocks in the first local
                window, second the second window, ..., and the last number determines the number of blocks
                in the remaining local windows.
            global_block_indices: optional: a list of integers determining which blocks are considered
                as global attention. Given indices, determine the blocks that all other token blocks
                attend to and they attend to all other token blocks. Default value is only index 0.
                Notice that if global_block_end_indices parameter is set, this parameter is used as
                starting index of each global window.
            global_block_end_indices: optional: a list of integers determining end indices of global
                window blocks. By default this is not used. But if it is set, it must have the same size
                of global_block_indices parameter, and combining this two parameters, for each index i,
                blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are
                considered as global attention.
            attention: optional: a string determining attention type. Attention can be `unidirectional`,
                such as autoregressive models, in which tokens attend only to tokens appear before them
                in the context. Considering that, the upper triangular of attention matrix is empty as
                above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
                any other tokens before or after them. Then, the upper triangular part of the attention
                matrix is mirror of the lower triangular in the above figure.
            horizontal_global_attention: optional: a boolean determining if blocks that are global
                representative of a local window, also attend to all other blocks. This is valid only if
                attention type is `bidirectional`. Looking at the attention matrix, that means global
                attention not only includes the vertical blocks, but also horizontal blocks.
        """

        super().__init__(num_heads, block_size, different_layout_per_head)

        self.num_random_blocks = num_random_blocks
        self.local_window_blocks = local_window_blocks
        self.global_block_indices = global_block_indices

        if global_block_end_indices is not None:
            if len(global_block_indices) != len(global_block_end_indices):
                raise ValueError(
                    f"""Global block start indices length, {len(global_block_indices)}, must be same as
                    global block end indices length, {len(global_block_end_indices)}!"""
                )
            for _, (start_idx, end_idx) in enumerate(
                zip(global_block_indices, global_block_end_indices)
            ):
                if start_idx >= end_idx:
                    raise ValueError(
                        f"""Global block start index, {start_idx}, must be smaller than global block end
                        index, {end_idx}!"""
                    )
        self.global_block_end_indices = global_block_end_indices

        if attention != "unidirectional" and attention != "bidirectional":
            raise NotImplementedError(
                'only "uni/bi-directional" attentions are supported for now!'
            )
        self.attention = attention

        if attention != "bidirectional" and horizontal_global_attention:
            raise ValueError(
                'only "bi-directional" attentions can support horizontal global attention!'
            )
        self.horizontal_global_attention = horizontal_global_attention

    def set_random_layout(self, h, layout):
        """Sets random attention layout used by the given head in the sparse attention.
        Note) By default, it assumes there will be a unique random block layout for all heads; unless
        `different_layout_per_head` parameter is set in which each head can have a different random
        layout.
        Arguments:
            h: required: an integer determining head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which random layout is set
        """

        num_blocks = layout.shape[1]
        if num_blocks < self.num_random_blocks:
            raise ValueError(
                f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number
                of blocks in a row, {num_blocks}!"""
            )
        for row in range(0, num_blocks):
            rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
            layout[h, row, rnd_cols] = 1
        return layout

    def set_local_layout(self, h, layout):
        """Sets local attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which local layout is set
        """

        num_blocks = layout.shape[1]
        start_block_idx = 0
        end_block_idx = 0
        for block_size in self.local_window_blocks:
            end_block_idx += block_size
            end_block_idx = min(end_block_idx, num_blocks)
            for row in range(start_block_idx, end_block_idx):
                for col in range(
                    start_block_idx,
                    (row + 1 if self.attention == "unidirectional" else end_block_idx),
                ):
                    layout[h, row, col] = 1
            start_block_idx += block_size

        # if there is any remaining not attended part, use the lats local window block size as local
        # window for the remaining applicable local windows
        for i in range(start_block_idx, num_blocks, block_size):
            end_block_idx = min(i + block_size, num_blocks)
            for row in range(i, end_block_idx):
                for col in range(
                    i,
                    (row + 1 if self.attention == "unidirectional" else end_block_idx),
                ):
                    layout[h, row, col] = 1
        return layout

    def set_global_layout(self, h, layout):
        """Sets global attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which global layout is set
        """

        num_blocks = layout.shape[1]
        if self.global_block_end_indices is None:
            for idx in self.global_block_indices:
                # if global block idx is in the range of the sequence blocks
                if idx < num_blocks:
                    # global rows
                    if self.horizontal_global_attention:
                        layout[h, idx, :] = 1

                    # global columns
                    first_row = 0 if self.attention == "bidirectional" else idx
                    layout[h, first_row:, idx] = 1
        else:
            for _, (start_idx, end_idx) in enumerate(
                zip(self.global_block_indices, self.global_block_end_indices)
            ):
                # if global block idx is in the range of the sequence blocks
                if start_idx < num_blocks:
                    end_idx = min(end_idx, num_blocks)
                    # global rows
                    if self.horizontal_global_attention:
                        layout[h, start_idx:end_idx, :] = 1

                    # global columns
                    first_row = 0 if self.attention == "bidirectional" else start_idx
                    layout[h, first_row:, start_idx:end_idx] = 1
        return layout

    def make_layout(self, seq_len):
        """Generates `Variable` sparsity layout used by each head in the sparse attention.
        Arguments:
            seq_len: required: an integer determining number of attention heads of the layer.
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Variable`
                sparsity layout of all head
        """

        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_random_layout(h, layout)
            layout = self.set_local_layout(h, layout)
            layout = self.set_global_layout(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout


class BigBirdSparsityConfig(SparsityConfig):
    """Configuration class to store `BigBird` sparsity configuration.
    For more details about this sparsity config, please see `Big Bird: Transformers for
    Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf
    This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity.
    """

    def __init__(
        self,
        num_heads,
        block_size=16,
        different_layout_per_head=False,
        num_random_blocks=1,
        num_sliding_window_blocks=3,
        num_global_blocks=1,
        attention="bidirectional",
    ):
        """Initialize the BigBird Sparsity Pattern Config.
        For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
            num_heads: required: an integer determining number of attention heads of the layer.
            block_size: optional: an integer determining the block size. Current implementation of
                sparse self-attention is based on blocked sparse matrices. In which this parameter
                defines size of such blocks, `Block X Block`.
            different_layout_per_head: optional: a boolean determining if each head should be assigned
                a different sparsity layout; default is false and this will be satisfied based on
                availability.
            num_random_blocks: optional: an integer determining the number of random blocks in each
                block row.
            num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding
                local attention window.
            num_global_blocks: optional: an integer determining how many consecutive blocks, starting
                from index 0, are considered as global attention. Global block tokens will be attended
                by all other block tokens and will attend to all other block tokens as well.
            attention: optional: a string determining attention type. Attention can be `unidirectional`,
                such as autoregressive models, in which tokens attend only to tokens appear before them
                in the context. Considering that, the upper triangular of attention matrix is empty as
                above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
                any other tokens before or after them. Then, the upper triangular part of the attention
                matrix is mirror of the lower triangular in the above figure.
        """

        super().__init__(num_heads, block_size, different_layout_per_head)

        self.num_random_blocks = num_random_blocks
        self.num_sliding_window_blocks = num_sliding_window_blocks
        self.num_global_blocks = num_global_blocks

        if attention != "unidirectional" and attention != "bidirectional":
            raise NotImplementedError(
                'only "uni/bi-directional" attentions are supported for now!'
            )
        self.attention = attention

    def set_random_layout(self, h, layout):
        """Sets random attention layout used by the given head in the sparse attention.
        Note) By default, it assumes there will be a unique random block layout for all heads; unless
        `different_layout_per_head` parameter is set in which each head can have a different random layout.
        Arguments:
            h: required: an integer determining head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which random layout is set
        """

        num_blocks = layout.shape[1]
        if num_blocks < self.num_random_blocks:
            raise ValueError(
                f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number
                of blocks in a row, {num_blocks}!"""
            )

        for row in range(0, num_blocks):
            sample_range = (
                range(0, num_blocks)
                if self.attention == "bidirectional"
                else range(0, row + 1)
            )
            rnd_cols = random.sample(sample_range, self.num_random_blocks)
            layout[h, row, rnd_cols] = 1
        return layout

    def set_sliding_window_layout(self, h, layout):
        """Sets sliding local attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which local sliding window layout is set
        """

        num_blocks = layout.shape[1]
        if num_blocks < self.num_sliding_window_blocks:
            raise ValueError(
                f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than
                overall number of blocks in a row, {num_blocks}!"""
            )

        w = self.num_sliding_window_blocks // 2
        for row in range(0, num_blocks):
            start = max(0, row - w)
            end = min(row + w + 1, num_blocks)
            layout[h, row, start:end] = 1
        return layout

    def set_global_layout_itc(self, h, layout):
        """Sets global attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout
                of all head in which global layout is set
        """

        num_blocks = layout.shape[1]
        if num_blocks < self.num_global_blocks:
            raise ValueError(
                f"""Number of global blocks, {self.num_global_blocks}, must be smaller than overall number
                of blocks in a row, {num_blocks}!"""
            )

        # global rows
        layout[h, 0 : self.num_global_blocks, :] = 1

        # global columns
        layout[h, :, 0 : self.num_global_blocks] = 1

        if self.attention == "unidirectional":
            # zero out anything attending to the future
            layout = torch.tril(layout)

        return layout

    def make_layout(self, seq_len):
        """Generates `BigBird` sparsity layout used by each head in the sparse attention.
        Arguments:
            seq_len: required: an integer determining number of attention heads of the layer.
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BigBird`
                sparsity layout of all head
        """

        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_random_layout(h, layout)
            layout = self.set_sliding_window_layout(h, layout)
            layout = self.set_global_layout_itc(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout


class BSLongformerSparsityConfig(SparsityConfig):
    """Configuration class to store edited `Longformer` sparsity configuration.
    Note) this is a block-sparse version of the Longformer which is slightly different than original
    Longformer; which is element-wise sparsity.
    For more details about this sparsity config, please see `Longformer:
    The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf
    This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity.
    """

    def __init__(
        self,
        num_heads,
        block_size=16,
        different_layout_per_head=False,
        num_sliding_window_blocks=3,
        global_block_indices=[0],
        global_block_end_indices=None,
        attention="bidirectional",
    ):
        """Initialize the edited `Longformer` Sparsity Pattern Config.
        For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
            num_heads: required: an integer determining number of attention heads of the layer.
            block_size: optional: an integer determining the block size. Current implementation of sparse
                self-attention is based on blocked sparse matrices. In which this parameter defines size
                of such blocks, `Block X Block`.
            different_layout_per_head: optional: a boolean determining if each head should be assigned a
                different sparsity layout; default is false and this will be satisfied based on
                availability.
            num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding
                local attention window.
            global_block_indices: optional: a list of integers determining which blocks are considered
                as global attention. Given indices, determine the blocks that all other token blocks
                attend to and they attend to all other token blocks. Default value is only index 0.
                Notice that if global_block_end_indices parameter is set, this parameter is used as
                starting index of each global window.
            global_block_end_indices: optional: a list of integers determining end indices of global
                window blocks. By default this is not used. But if it is set, it must have the same size
                of global_block_indices parameter, and combining this two parameters, for each index i,
                blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are
                considered as global attention.
            attention: optional: a string determining attention type. Attention can be `unidirectional`,
                such as autoregressive models, in which tokens attend only to tokens appear before them
                in the context. Considering that, the upper triangular of attention matrix is empty as
                above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
                any other tokens before or after them. Then, the upper triangular part of the attention
                matrix is mirror of the lower triangular in the above figure.
        """

        super().__init__(num_heads, block_size, different_layout_per_head)

        self.num_sliding_window_blocks = num_sliding_window_blocks
        self.global_block_indices = global_block_indices
        self.attention = attention

        if global_block_end_indices is not None:
            if len(global_block_indices) != len(global_block_end_indices):
                raise ValueError(
                    f"""Global block start indices length, {len(global_block_indices)}, must be same as
                    global block end indices length, {len(global_block_end_indices)}!"""
                )
            for _, (start_idx, end_idx) in enumerate(
                zip(global_block_indices, global_block_end_indices)
            ):
                if start_idx >= end_idx:
                    raise ValueError(
                        f"""Global block start index, {start_idx}, must be smaller than global block end
                        index, {end_idx}!"""
                    )
        self.global_block_end_indices = global_block_end_indices

    def set_sliding_window_layout(self, h, layout):
        """Sets sliding local attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout
                of all head in which local sliding window layout is set
        """

        num_blocks = layout.shape[1]
        if num_blocks < self.num_sliding_window_blocks:
            raise ValueError(
                f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller
                than overall number of blocks in a row, {num_blocks}!"""
            )

        w = self.num_sliding_window_blocks // 2
        for row in range(0, num_blocks):
            start = max(0, row - w)
            end = min(row + w + 1, num_blocks)
            layout[h, row, start:end] = 1
        return layout

    def set_global_layout(self, h, layout):
        """Sets global attention layout used by the given head in the sparse attention.
        Arguments:
            h: required: an integer determining head index
            layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which global layout is set
        """

        num_blocks = layout.shape[1]
        if self.global_block_end_indices is None:
            for idx in self.global_block_indices:
                # if global block idx is in the range of the sequence blocks
                if idx < num_blocks:
                    # global rows
                    layout[h, idx, :] = 1

                    # global columns
                    layout[h, :, idx] = 1
        else:
            for _, (start_idx, end_idx) in enumerate(
                zip(self.global_block_indices, self.global_block_end_indices)
            ):
                # if global block idx is in the range of the sequence blocks
                if start_idx < num_blocks:
                    end_idx = min(end_idx, num_blocks)
                    # global rows
                    layout[h, start_idx:end_idx, :] = 1

                    # global columns
                    layout[h, :, start_idx:end_idx] = 1
        if self.attention == "unidirectional":
            layout = torch.tril(layout)
        return layout

    def make_layout(self, seq_len):
        """Generates edited `Longformer` sparsity layout used by each head in the sparse attention.
        Arguments:
            seq_len: required: an integer determining number of attention heads of the layer.
        Return:
            layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BSLongformer`
                sparsity layout of all head
        """

        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_sliding_window_layout(h, layout)
            layout = self.set_global_layout(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout
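The configs above only produce block-level 0/1 layout masks; a short sketch of how such a layout can be built and inspected using the classes defined in this file (the specific sizes are illustrative):

from xformers.components.attention.sparsity_config import (
    BigBirdSparsityConfig,
    FixedSparsityConfig,
)

seq_len, block = 256, 16  # 256 tokens -> a 16 x 16 grid of blocks
fixed = FixedSparsityConfig(num_heads=8, block_size=block, num_local_blocks=4)
bigbird = BigBirdSparsityConfig(num_heads=8, block_size=block, num_random_blocks=2)

for name, cfg in [("fixed", fixed), ("bigbird", bigbird)]:
    layout = cfg.make_layout(seq_len)           # (num_heads, 16, 16) tensor of 0/1 blocks
    density = layout.float().mean().item()
    print(f"{name}: {tuple(layout.shape)}, {density:.0%} of blocks kept")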
.venv/lib/python3.11/site-packages/xformers/components/attention/utils.py
ADDED
|
@@ -0,0 +1,108 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional

import torch


# Reshapes key padding mask from (batch_size, src_len) -> (batch_size * num_heads 1, src_len)
def reshape_key_padding_mask(
    key_padding_mask: torch.Tensor, batched_dim: int
) -> torch.Tensor:
    assert key_padding_mask.ndim == 2
    batch_size, src_len = key_padding_mask.size()
    num_heads = batched_dim // batch_size
    return _reshape_key_padding_mask(key_padding_mask, batch_size, src_len, num_heads)


def _reshape_key_padding_mask(
    key_padding_mask: torch.Tensor, batch_size: int, src_len: int, num_heads: int
) -> torch.Tensor:
    assert key_padding_mask.shape == (batch_size, src_len)
    key_padding_mask = (
        key_padding_mask.view(batch_size, 1, 1, src_len)
        .expand(-1, num_heads, -1, -1)
        .reshape(batch_size * num_heads, 1, src_len)
    )
    return key_padding_mask


# Combine the attention mask and key padding mask into a single mask
# Taken from https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
# Additive masking not yet supported
def maybe_merge_masks(
    att_mask: Optional[torch.Tensor],
    key_padding_mask: Optional[torch.Tensor],
    batch_size: int,
    src_len: int,
    num_heads: int,
    tgt_len: Optional[int] = None,
) -> Optional[torch.Tensor]:
    if tgt_len is None:
        tgt_len = src_len
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (batch_size, src_len)
        key_padding_mask = _reshape_key_padding_mask(
            key_padding_mask, batch_size, src_len, num_heads
        )
        if att_mask is None:
            # make sure dimensions of key padding mask are the same as those expected for att_mask
            att_mask = key_padding_mask.expand(-1, tgt_len, -1)
        # Assumption is that False means to mask.
        elif att_mask.dtype == torch.bool:
            att_mask = att_mask.logical_and(key_padding_mask)
        else:
            att_mask = att_mask.masked_fill(~key_padding_mask, float("-inf"))

    return att_mask


# Assumes that matrix passed in has had softmax applied to it.
def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=False):
    """
    Computing the Moore-Penrose inverse.
    Use an iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose inverse via efficient
    matrix-matrix multiplications.
    """

    i = torch.eye(
        softmax_mat.size(-1), device=softmax_mat.device, dtype=softmax_mat.dtype
    )
    k = softmax_mat

    # The entries of K are positive and ||K||_{\infty} = 1 due to softmax
    if pinverse_original_init:
        # This original implementation is more conservative to compute coefficient of Z_0.
        v = 1 / torch.max(torch.sum(k, dim=-2)) * k.transpose(-1, -2)
    else:
        # This is the exact coefficient computation, 1 / ||K||_1, of initialization of Z_0, leading to faster
        # convergence.
        v = (
            1
            / torch.max(torch.sum(k, dim=-2), dim=-1).values[:, None, None]
            * k.transpose(-1, -2)
        )

    for _ in range(n_iter):
        kv = torch.matmul(k, v)
        v = torch.matmul(
            0.25 * v,
            13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)),
        )
    return v


def bool_mask_to_additive(
    mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32
) -> torch.Tensor:
    assert (
        mask.dtype == torch.bool
    ), "This util is meant to convert in between bool masks and additive ones"

    mask_ = torch.zeros_like(mask, dtype=dtype)
    mask_[~mask] = float("-inf")
    return mask_
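A small sketch of how the mask helpers above compose, merging a boolean causal mask with a key padding mask and then converting the result to an additive mask (shapes chosen purely for illustration):

import torch
from xformers.components.attention.utils import (
    bool_mask_to_additive,
    maybe_merge_masks,
)

batch_size, num_heads, seq = 2, 4, 6
att_mask = torch.ones(seq, seq, dtype=torch.bool).tril()   # causal; False means masked
key_padding = torch.ones(batch_size, seq, dtype=torch.bool)
key_padding[:, -2:] = False                                # last two tokens are padding

merged = maybe_merge_masks(
    att_mask, key_padding, batch_size=batch_size, src_len=seq, num_heads=num_heads
)
print(merged.shape)                       # (batch_size * num_heads, seq, seq)
print(bool_mask_to_additive(merged)[0])   # 0 where allowed, -inf where masked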
.venv/lib/python3.11/site-packages/xformers/components/feedforward/__init__.py
ADDED
|
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from pathlib import Path
from typing import Any, Callable, Dict, Set, Union

from xformers.utils import (
    generate_matching_config,
    get_registry_decorator,
    import_all_modules,
)

from .base import Feedforward, FeedforwardConfig  # noqa

# CREDITS: Classy Vision registry mechanism

FEEDFORWARD_REGISTRY: Dict[str, Any] = {}
FEEDFORWARD_CLASS_NAMES: Set[str] = set()


def build_feedforward(config: Union[Dict[str, Any], FeedforwardConfig]):
    """Builds a feedforward from a config.

    This assumes a 'name' key in the config which is used to determine what
    attention class to instantiate. For instance, a config `{"name": "my_feedforward",
    "foo": "bar"}` will find a class that was registered as "my_feedforward"
    (see :func:`register_feedforward`) and call .from_config on it."""

    if not isinstance(config, FeedforwardConfig):
        config_instance = generate_matching_config(
            config, FEEDFORWARD_REGISTRY[config["name"]].config
        )
    else:
        config_instance = config

    return FEEDFORWARD_REGISTRY[config_instance.name].constructor.from_config(
        config_instance
    )


"""Registers a Feedforward subclass.

This decorator allows xFormers to instantiate a subclass of Feedforward
from a configuration file, even if the class itself is not part of the
xFormers framework. To use it, apply this decorator to a Feedforward
subclass, like this:

.. code-block:: python

    @dataclass
    class MyConfig:
        ...

    @register_feedforward('my_ff', MyConfig)
    class MyFeedforward(Feedforward):
        ...

To instantiate a feedforward from a configuration file, see :func:`build_feedforward`."""
register_feedforward: Callable[
    [str, Any], Callable[[Any], Any]
] = get_registry_decorator(
    FEEDFORWARD_REGISTRY, FEEDFORWARD_CLASS_NAMES, Feedforward, FeedforwardConfig
)

from .mlp import MLP  # noqa

__all__ = [
    "MLP",
    "Feedforward",
    "build_feedforward",
    "register_feedforward",
]

# automatically import any Python files in the directory
import_all_modules(str(Path(__file__).parent), "xformers.components.feedforward")
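To illustrate the registry mechanism above, the sketch below inspects what has been registered and builds the plain MLP from a dict config. The registry key "MLP" and the config keys (which mirror MlpConfig in mlp.py) are assumptions taken from elsewhere in the package, not from this file.

import torch
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, build_feedforward

# Names registered by the import side-effects above (contents depend on the install)
print(sorted(FEEDFORWARD_REGISTRY.keys()))

# build_feedforward resolves "MLP" in the registry and calls .from_config on it;
# the keys below are assumed to match MlpConfig (dim_model / dropout / activation /
# hidden_layer_multiplier).
ff = build_feedforward(
    {
        "name": "MLP",
        "dim_model": 64,
        "dropout": 0.0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
    }
)
print(ff(torch.randn(2, 16, 64)).shape)  # expected: torch.Size([2, 16, 64])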