Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .venv/lib/python3.11/site-packages/xformers/components/__init__.py +86 -0
- .venv/lib/python3.11/site-packages/xformers/components/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/__pycache__/activations.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/__pycache__/input_projection.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/__pycache__/multi_head_dispatch.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/__pycache__/patch_embedding.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/__pycache__/residual.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/__pycache__/reversible.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/__pycache__/simplicial_embedding.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/activations.py +76 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/attention_patterns.py +295 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/core.py +248 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/favor.py +173 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/fourier_mix.py +35 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/lambda_layer.py +78 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/local.py +120 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/nystrom.py +295 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/random.py +126 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/scaled_dot_product.py +134 -0
- .venv/lib/python3.11/site-packages/xformers/components/attention/visual.py +96 -0
- .venv/lib/python3.11/site-packages/xformers/components/input_projection.py +102 -0
- .venv/lib/python3.11/site-packages/xformers/components/multi_head_dispatch.py +271 -0
- .venv/lib/python3.11/site-packages/xformers/components/patch_embedding.py +83 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__init__.py +87 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/param.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/rotary.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/sine.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/vocab.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/base.py +38 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/param.py +54 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/rotary.py +91 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/sine.py +46 -0
- .venv/lib/python3.11/site-packages/xformers/components/positional_embedding/vocab.py +65 -0
- .venv/lib/python3.11/site-packages/xformers/components/residual.py +192 -0
- .venv/lib/python3.11/site-packages/xformers/components/reversible.py +160 -0
- .venv/lib/python3.11/site-packages/xformers/components/simplicial_embedding.py +67 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__init__.py +130 -0
- .venv/lib/python3.11/site-packages/xformers/ops/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/_triton/k_index_select_cat.py +184 -0
- .venv/lib/python3.11/site-packages/xformers/ops/_triton/k_scaled_index_add.py +365 -0
- .venv/lib/python3.11/site-packages/xformers/ops/_triton/rmsnorm_kernels.py +163 -0
- .venv/lib/python3.11/site-packages/xformers/ops/_triton/rope_padded_kernels.py +226 -0
- .venv/lib/python3.11/site-packages/xformers/ops/_triton/tiled_matmul_kernels.py +430 -0
- .venv/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py +893 -0
- .venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/attn_bias.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck_decoder.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/xformers/components/__init__.py
ADDED
@@ -0,0 +1,86 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import warnings
+from dataclasses import fields
+from pathlib import Path
+from typing import Any, Dict, Union
+
+from xformers.utils import import_all_modules
+
+from .activations import Activation, build_activation  # noqa
+from .attention import Attention, build_attention  # noqa
+from .input_projection import InputProjection, InputProjectionConfig  # noqa
+from .multi_head_dispatch import MultiHeadDispatch  # noqa
+from .multi_head_dispatch import MultiHeadDispatchConfig
+from .patch_embedding import PatchEmbeddingConfig  # noqa
+from .patch_embedding import build_patch_embedding  # noqa
+from .residual import NormalizationType  # noqa
+from .residual import PostNorm  # noqa
+from .residual import PreNorm  # noqa
+from .residual import RequiresWrappedInputs  # noqa
+from .residual import Residual  # noqa
+from .residual import ResidualNormStyle  # noqa
+
+warnings.warn(
+    "xformers.components is deprecated and is not maintained anymore. "
+    "It might be removed in a future version of xFormers ",
+    FutureWarning,
+    stacklevel=2,
+)
+
+
+# automatically import any Python files in the directory
+import_all_modules(str(Path(__file__).parent), "xformers.components")
+
+
+def build_multi_head_attention(
+    multi_head_config: Union[MultiHeadDispatchConfig, Dict[str, Any]],
+):
+    """Builds a multihead attention from a config.
+
+    This assumes a 'name' key in the config which is used to determine what
+    attention class to instantiate. For instance, a config `{"name": "my_attention",
+    "foo": "bar"}` will find a class that was registered as "my_attention"
+    (see :func:`register_attention`) and call .from_config on it."""
+
+    if not isinstance(multi_head_config, MultiHeadDispatchConfig):
+        # Extract the required fields
+        field_names = list(map(lambda x: x.name, fields(MultiHeadDispatchConfig)))
+
+        # The missing fields get Noned
+        for k in field_names:
+            if k not in multi_head_config.keys():
+                multi_head_config[k] = None
+
+        # Could be that the attention needs to be instantiated
+        if not isinstance(multi_head_config["attention"], Attention):
+            # Convenience: fill in possible missing fields
+            if "num_heads" not in multi_head_config["attention"]:
+                multi_head_config["attention"]["num_heads"] = multi_head_config[
+                    "num_heads"
+                ]
+
+            if "dim_model" not in multi_head_config["attention"]:
+                multi_head_config["attention"]["dim_model"] = multi_head_config[
+                    "dim_model"
+                ]
+
+            if (
+                "dim_features" not in multi_head_config["attention"]
+                or multi_head_config["attention"]["dim_features"] is None
+            ):
+                multi_head_config["attention"]["dim_features"] = (
+                    multi_head_config["dim_model"] // multi_head_config["num_heads"]
+                )
+
+            multi_head_config["attention"] = build_attention(
+                multi_head_config["attention"]
+            )
+
+        multi_head_config = MultiHeadDispatchConfig(**multi_head_config)
+
+    return MultiHeadDispatch.from_config(multi_head_config)

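For reference, a minimal usage sketch of the factory above: it accepts a plain dict, Nones out the missing MultiHeadDispatchConfig fields, and instantiates the inner attention from its registered name. The dimensions and the "scaled_dot_product" name below are illustrative assumptions, not part of this diff.

import torch
from xformers.components import build_multi_head_attention

# Hypothetical sizes, chosen only for illustration
config = {
    "dim_model": 384,
    "num_heads": 4,
    "residual_dropout": 0.0,
    "attention": {
        "name": "scaled_dot_product",  # name registered by scaled_dot_product.py (listed above)
        "dropout": 0.0,
        "causal": False,
    },
}
mha = build_multi_head_attention(config)
x = torch.randn(2, 128, 384)  # (batch, sequence, dim_model)
y = mha(x, x, x)              # self-attention over x
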
.venv/lib/python3.11/site-packages/xformers/components/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (3.57 kB)

.venv/lib/python3.11/site-packages/xformers/components/__pycache__/activations.cpython-311.pyc
ADDED
Binary file (4.59 kB)

.venv/lib/python3.11/site-packages/xformers/components/__pycache__/input_projection.cpython-311.pyc
ADDED
Binary file (3.97 kB)

.venv/lib/python3.11/site-packages/xformers/components/__pycache__/multi_head_dispatch.cpython-311.pyc
ADDED
Binary file (12.6 kB)

.venv/lib/python3.11/site-packages/xformers/components/__pycache__/patch_embedding.cpython-311.pyc
ADDED
Binary file (4.53 kB)

.venv/lib/python3.11/site-packages/xformers/components/__pycache__/residual.cpython-311.pyc
ADDED
Binary file (9.56 kB)

.venv/lib/python3.11/site-packages/xformers/components/__pycache__/reversible.cpython-311.pyc
ADDED
Binary file (9.78 kB)

.venv/lib/python3.11/site-packages/xformers/components/__pycache__/simplicial_embedding.cpython-311.pyc
ADDED
Binary file (3.5 kB)

.venv/lib/python3.11/site-packages/xformers/components/activations.py
ADDED
@@ -0,0 +1,76 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from enum import Enum
+from typing import Optional
+
+import torch
+from torch import nn
+
+from xformers._deprecation_warning import deprecated_function
+
+
+class Activation(str, Enum):
+    SquaredReLU = "squared_relu"
+    GeLU = "gelu"
+    LeakyReLU = "leaky_relu"
+    ReLU = "relu"
+    SmeLU = "smelu"
+    StarReLU = "star_relu"
+
+
+# For unit testing / parity comparisons, probably not the fastest way
+class SquaredReLU(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        deprecated_function(self)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_ = torch.nn.functional.relu(x)
+        return x_ * x_
+
+
+class StarReLU(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        deprecated_function(self)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_ = torch.nn.functional.relu(x)
+        return 0.8944 * x_ * x_ - 0.4472
+
+
+class SmeLU(nn.Module):
+    def __init__(self, beta: float = 2.0) -> None:
+        super().__init__()
+        self.beta = beta
+        deprecated_function(self)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        relu = torch.where(
+            x >= self.beta,
+            x,
+            torch.tensor([0.0], device=x.device, dtype=x.dtype),
+        )
+        return torch.where(
+            torch.abs(x) <= self.beta,
+            ((x + self.beta) ** 2).type_as(x) / (4.0 * self.beta),
+            relu,
+        )
+
+
+def build_activation(activation: Optional[Activation]):
+    if not activation:
+        return nn.Identity()
+
+    return {
+        Activation.ReLU: nn.ReLU,
+        Activation.GeLU: nn.GELU,
+        Activation.LeakyReLU: nn.LeakyReLU,
+        Activation.SquaredReLU: SquaredReLU,
+        Activation.StarReLU: StarReLU,
+        Activation.SmeLU: SmeLU,
+    }[activation]()

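A short usage sketch for the activation factory above (the tensor shape is arbitrary):

import torch
from xformers.components.activations import Activation, build_activation

act = build_activation(Activation.SquaredReLU)  # relu(x) ** 2, per the class above
y = act(torch.randn(2, 16))
identity = build_activation(None)               # falsy input returns nn.Identity()
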
.venv/lib/python3.11/site-packages/xformers/components/attention/attention_patterns.py
ADDED
@@ -0,0 +1,295 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import math
+from typing import List
+
+import numpy as np
+import torch
+
+from xformers.components.attention.sparsity_config import (
+    BigBirdSparsityConfig,
+    BSLongformerSparsityConfig,
+    FixedSparsityConfig,
+    VariableSparsityConfig,
+)
+
+
+# generic nd cases
+def _generate_nd_grid(*sizes):
+    coords = [torch.arange(s) for s in sizes]
+    return torch.meshgrid(*coords)
+
+
+def local_nd_distance(*sizes, p=2.0, weights=None):
+    if weights is None:
+        weights = (1,) * len(sizes)
+    assert len(sizes) == len(weights)
+    grid = _generate_nd_grid(*sizes)
+    grid = [i.flatten() * w for i, w in zip(grid, weights)]
+    grid = torch.stack(grid, dim=1).float()
+    d = torch.cdist(grid, grid, p=p)
+    return d
+
+
+def local_nd_gaussian_distribution(*sizes, sigma=1):
+    d = local_nd_distance(*sizes, p=2.0) ** 2
+    d = torch.exp(-0.5 * sigma ** (-2.0) * d)
+    return d
+
+
+def local_nd_pattern(*sizes, distance, p=2.0):
+    d = local_nd_distance(*sizes, p=p)
+    return d < distance
+
+
+def axial_nd_pattern(*sizes):
+    # axial is a special case with p=0 and distance=2
+    d = local_nd_distance(*sizes, p=0)
+    return d < 2
+
+
+def random_pattern_from_probability_matrix(dist_matrix, nnz):
+    att = torch.zeros_like(dist_matrix, dtype=torch.bool)
+    # PyTorch multinomial wrongly doesn't support sampling when number of categories
+    # is > 2^24, arguing that it's because it's the max representable consecutive element
+    # in fp32 and that the kernels use float32. This is actually not true, and the kernels
+    # should work fine if double tensor is passed on CPU. This is a bug that was introduced
+    # in https://github.com/pytorch/pytorch/commit/bf04c2ca2f591d98ce57816f0ef0cd20a21bbf66
+    # when unifying the checks between CPU and CUDA. For now, just fall-back to numpy
+    if dist_matrix.numel() > 2**24:
+        dist_matrix = dist_matrix.double()
+        dist_matrix /= dist_matrix.sum()
+        idxs = np.random.choice(
+            dist_matrix.numel(), nnz, p=dist_matrix.flatten(), replace=False
+        )
+        idxs = torch.as_tensor(idxs)
+    else:
+        idxs = torch.multinomial(dist_matrix.flatten(), nnz, replacement=False)
+    att.view(-1)[idxs] = True
+    return att
+
+
+def global_token_pattern(attention_query_mask: torch.Tensor) -> torch.Tensor:
+    assert attention_query_mask.ndim == 1
+    assert attention_query_mask.dtype == torch.bool
+    attention_query_mask = attention_query_mask[None, :]
+    mask = attention_query_mask | attention_query_mask.transpose(1, 0)
+    return mask
+
+
+def random_pattern(attn_size: int, sparsity: float) -> torch.Tensor:
+    assert 0 < sparsity < 1
+    mask = torch.rand(attn_size, attn_size) > sparsity
+    return mask
+
+
+# 1d-specific cases
+def local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor:
+    assert (
+        window_size % 2 == 1
+    ), "The window size is assumed to be odd (counts self-attention + 2 wings)"
+    h_win_size = window_size // 2 + 1
+    return local_nd_pattern(attn_size, distance=h_win_size, p=1.0)
+
+
+def causal_1d_pattern(attn_size: int) -> torch.Tensor:
+    mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool))
+    return mask
+
+
+# 2d-specific cases
+def horizontal_axial_2d_distance(H, W, p=2.0):
+    d = local_nd_distance(H, W, p=p, weights=(1, 0))
+    return d
+
+
+def vertical_axial_2d_distance(H, W, p=2.0):
+    d = local_nd_distance(H, W, p=p, weights=(0, 1))
+    return d
+
+
+def local_2d_distance(H, W, p=2.0):
+    return local_nd_distance(H, W, p=p)
+
+
+def local_2d_gausian_distribution(H, W, sigma=1):
+    return local_nd_gaussian_distribution(H, W, sigma=sigma)
+
+
+def local_2d_pattern(H, W, distance, p=2.0):
+    return local_nd_pattern(H, W, distance=distance, p=p)
+
+
+def axial_2d_pattern(H, W):
+    return axial_nd_pattern(H, W)
+
+
+def swin_attention_pattern(H, W, window_size, shift_size=0):
+    assert H % window_size == 0
+    assert W % window_size == 0
+    assert 0 <= shift_size < window_size, "shift_size must in 0-window_size"
+
+    # input grid
+    i, j = _generate_nd_grid(H, W)
+    i, j = i + 0.5, j + 0.5
+
+    # anchors grid
+    # if shift is present, add extra element to the grid
+    # to account for the uneven partitioning
+    extra = int(shift_size % window_size != 0)
+    grid_h = H // window_size + extra
+    grid_w = W // window_size + extra
+
+    ii, jj = _generate_nd_grid(grid_h, grid_w)
+    # convert shift to be compatible with the paper representation
+    s = (-shift_size) % window_size
+    offset = window_size / 2 - s
+    ii = ii * window_size + offset
+    jj = jj * window_size + offset
+
+    input_coords = torch.stack([i.flatten(), j.flatten()], 1).float()
+    anchors_coords = torch.stack([ii.flatten(), jj.flatten()], 1).float()
+
+    anchor_id = torch.cdist(input_coords, anchors_coords, p=2).argmin(1)
+    mask = anchor_id[:, None] == anchor_id[None, :]
+    return mask
+
+
+def dilated_2d_pattern(H, W, k=2):
+    """
+    Returns a 2d pattern that samples 1 every k elements in the attention mask.
+    Can be seen as a form of downsampling, where every pixel attends to a downsampled
+    version of the input.
+    """
+    d_h = local_nd_distance(H, W, p=1, weights=(1, 0))
+    d_w = local_nd_distance(H, W, p=1, weights=(0, 1))
+    d = (d_h.floor() % k == 0) & (d_w.floor() % k == 0)
+    return d
+
+
+# Block sparse utils
+def block_sparsify_tensor(x, mask, block_size):
+    """
+    Block sparsify a tensor, given a mask and block size
+    """
+    ret = torch.empty(
+        (x.size(0), mask.sum(), block_size, block_size), dtype=x.dtype, device=x.device
+    )
+
+    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
+        ret[:, idx, :, :] = x[
+            :,
+            h,
+            i * block_size : (i + 1) * block_size,
+            j * block_size : (j + 1) * block_size,
+        ]
+    return ret
+
+
+def pattern_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor:
+    r"""
+    Given a mask pattern and blocksize, return the corresponding layout
+    which makes sure that all the positives in the mask are covered
+    """
+    assert mask.ndim >= 2, "We're expecting [Heads, Seq, Seq] or [Seq, Seq]"
+    _should_squeeze = False
+
+    if mask.ndim == 2:
+        mask = mask.unsqueeze(0)
+        _should_squeeze = True
+
+    assert (
+        mask.shape[1] % block_size == 0 and mask.shape[2] % block_size == 0
+    ), "We're only handling masks divisible by block_size"
+
+    # Now mark the mask
+    layout = torch.nn.functional.max_pool2d(
+        mask.to(torch.float), kernel_size=block_size, stride=block_size
+    )
+    layout = layout.to(torch.long)
+
+    if _should_squeeze:
+        layout.squeeze_(0)
+
+    return layout
+
+
+def alibi_pattern(threshold: float, mask_shape: torch.Size) -> torch.Tensor:
+    r"""
+    Use the additive bias computation from ALiBi_ to generate a mask.
+    Note that this mask can in turn be used to generate a blocksparse attention computation layout
+
+    .. note: mask_shape is expected to hold the [heads, seq, seq] dimensions
+
+    .. _ALiBi: https://arxiv.org/pdf/2108.12409.pdf
+    """
+
+    # CREDITS: code snippet from Ofir Press, one of the authors
+
+    def get_slopes(n: int):
+        def get_slopes_power_of_2(n: int) -> List[float]:
+            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+            ratio = start
+            return [start * ratio**i for i in range(n)]
+
+        # In the paper, we only train models that have 2^a heads for some a. This function has
+        # some good properties that only occur when the input is a power of 2. To maintain that even
+        # when the number of heads is not a power of 2, we use this workaround.
+        if math.log2(n).is_integer():
+            return get_slopes_power_of_2(n)
+        else:
+            closest_power_of_2 = 2 ** math.floor(math.log2(n))
+            return (
+                get_slopes_power_of_2(closest_power_of_2)
+                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+            )
+
+    maxpos = mask_shape[1]
+    attn_heads = mask_shape[0]
+    slopes = torch.Tensor(get_slopes(attn_heads))
+
+    # In the next line, the part after the * is what constructs the diagonal matrix
+    # (right matrix in Figure 3 in the paper).
+    # If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3,
+    # but one where all rows are identical.
+    # This works because the softmax operation is invariant to translation,
+    # and our bias functions are always linear.
+    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze(
+        0
+    ).unsqueeze(0).expand(attn_heads, -1, -1)
+    alibi = alibi.view(attn_heads, 1, maxpos)
+
+    # Now threshold arbitrarily, report the mask
+    return alibi < threshold
+
+
+def quick_fixed_layout(num_heads: int, block_size: int, seq_len: int):
+    config = FixedSparsityConfig(num_heads=num_heads, block_size=block_size)
+    return config.make_layout(seq_len)
+
+
+def quick_variable_layout(num_heads: int, block_size: int, seq_len: int):
+    config = VariableSparsityConfig(num_heads=num_heads, block_size=block_size)
+    return config.make_layout(seq_len)
+
+
+def quick_bigbird_layout(num_heads: int, block_size: int, seq_len: int):
+    config = BigBirdSparsityConfig(num_heads=num_heads, block_size=block_size)
+    return config.make_layout(seq_len)
+
+
+def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len: int):
+    config = BSLongformerSparsityConfig(num_heads=num_heads, block_size=block_size)
+    return config.make_layout(seq_len)
+
+
+def layout_to_pattern(layout: torch.Tensor, block_size: int):
+    r"""
+    create a pattern of shape [heads, seq, seq] out of a blocksparse
+    layout of shape [heads, seq/block_size, seq/block_size]
+    """
+    return torch.kron(layout, torch.ones(block_size, block_size))

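To illustrate how the helpers above compose, here is a small sketch turning a 1D local pattern into a block-sparse layout and back (the sizes are arbitrary):

import torch
from xformers.components.attention.attention_patterns import (
    layout_to_pattern,
    local_1d_pattern,
    pattern_to_layout,
)

seq_len, block_size = 256, 32
mask = local_1d_pattern(seq_len, window_size=5)   # bool [256, 256], odd window required
layout = pattern_to_layout(mask, block_size)      # long [8, 8], 1 where a block contains any True
coarse = layout_to_pattern(layout, block_size)    # back to [256, 256], block-aligned superset
assert bool((coarse.bool() | ~mask).all())        # every True in the mask is covered by the layout
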
.venv/lib/python3.11/site-packages/xformers/components/attention/core.py
ADDED
@@ -0,0 +1,248 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+import math
+from contextlib import nullcontext
+from typing import Optional, Union
+
+import torch
+
+from xformers import _has_cpp_library
+from xformers.components.attention.attention_mask import AttentionMask
+
+if _has_cpp_library:
+    from ._sputnik_sparse import SparseCS
+
+logger = logging.getLogger("xformers")
+
+
+def _create_random_sparsity(matrix, sparsity, divisible_by=4):
+    assert matrix.ndim == 3
+    keep = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity
+    nonzero = torch.nonzero(keep)
+    nnz = nonzero.shape[0]
+    # NOTE: need to make it a multiple of 4 for sputnik
+    nonzero = nonzero[: (nnz - nnz % divisible_by)]
+    i, j = nonzero.unbind(1)
+    output = torch.zeros_like(matrix)
+    bdim = torch.arange(matrix.shape[0], device=matrix.device)[:, None]
+    output[bdim, i, j] = matrix[bdim, i, j]
+    return output
+
+
+def _broadcast_batch(mask, batch_size):
+    if mask.ndim == 3:
+        return mask
+    assert mask.ndim == 2
+
+    mask = mask.coalesce()
+    values = mask.values()
+    indices = mask.indices()
+    nnz = len(values)
+    # strategy: repeat the indices and append the extra batch dimension to the indices
+    indices = indices.repeat(1, batch_size)
+    # now create the batch indices
+    batch_indices = torch.arange(batch_size, device=indices.device)
+    batch_indices = batch_indices[:, None].expand(batch_size, nnz).flatten()
+
+    # put them together
+    indices = torch.cat([batch_indices[None, :], indices], dim=0)
+
+    # now repeat the values
+    values = values.repeat(batch_size)
+
+    size = (batch_size,) + mask.shape
+
+    return torch.sparse_coo_tensor(indices, values, size)
+
+
+def _matmul_with_mask(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    mask: Optional[Union[torch.Tensor, "SparseCS"]],
+) -> torch.Tensor:
+    if mask is None:
+        return a @ b
+
+    if _has_cpp_library and mask.dtype == torch.bool:
+        if isinstance(mask, SparseCS):
+            return mask.matmul_with_mask(a, b)
+        if mask.is_sparse:
+            # perform broadcasting if needed
+            mask = _broadcast_batch(mask, a.shape[0])
+
+            # coalesced is not implemented for bool tensors, so need to cast
+            mask = mask.to(dtype=a.dtype)  # type: ignore  # mypy is missing the catch above
+
+        return torch.ops.xformers.matmul_with_mask(a, b, mask)
+
+    # Non optimized codepath
+    if _has_cpp_library:
+        assert not isinstance(mask, SparseCS)
+
+    att = a @ b
+    if mask.dtype == torch.bool:
+        assert not isinstance(mask, SparseCS)
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
+        # mask is presumed false == ignore
+        att[~mask] = float("-inf")
+    else:
+        # mask is presumed additive
+        # repeat if batch sizes don't match
+        if (
+            not isinstance(mask, SparseCS)
+            and mask.ndim == 3
+            and mask.shape[0] != att.shape[0]
+            and (att.shape[0] % mask.shape[0]) == 0
+        ):
+            repeat_factor = att.shape[0] // mask.shape[0]
+            mask = mask.repeat([repeat_factor, 1, 1])
+            logger.info("Mismatched batch dimensions for mask, repeating mask.")
+        att += mask
+    return att
+
+
+def _softmax(a: torch.Tensor, causal: bool = False) -> torch.Tensor:
+    if _has_cpp_library and isinstance(a, SparseCS):
+        return a.softmax()
+
+    if a.is_sparse:
+        return torch.sparse.softmax(a, dim=a.ndim - 1)
+
+    return torch.softmax(a, dim=a.ndim - 1)
+
+
+if _has_cpp_library:
+
+    class SparseBMM(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, a, b):
+            a = a.coalesce()
+            r = torch.bmm(a, b)
+            ctx.save_for_backward(a, b)
+            return r
+
+        @staticmethod
+        def backward(ctx, grad):
+            a, b = ctx.saved_tensors
+
+            # gradients w.r.t. a
+            ga = None
+            if ctx.needs_input_grad[0]:
+                ga = torch.ops.xformers.matmul_with_mask(grad, b.transpose(-2, -1), a)
+
+            # gradients w.r.t. b
+            gb = None
+            if ctx.needs_input_grad[1]:
+                gb = a.transpose(1, 2).bmm(grad)
+
+            return ga, gb
+
+    def _sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        """
+        Batch matrix multiply between a sparse matrix and a dense matrix
+        """
+        assert a.ndim == b.ndim == 3
+        assert a.shape[0] == b.shape[0]
+        assert a.shape[2] == b.shape[1]
+        return SparseBMM.apply(a, b)
+
+
+def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+    if _has_cpp_library:
+        if isinstance(a, SparseCS):
+            return a.spmm(b)
+        if a.is_sparse:
+            return _sparse_bmm(a, b)
+    return a @ b
+
+
+def _apply_dropout(att, dropout):
+    if dropout is None:
+        return att
+
+    # Dropout chokes on sparse tensors
+    if _has_cpp_library:
+        if isinstance(att, SparseCS):
+            values = att.values.clone()
+            values = dropout(values)
+            att = SparseCS.wrap(
+                att.shape,
+                values,
+                att.row_indices,
+                att.row_offsets,
+                att.column_indices,
+                att._transp_info,
+            )
+        elif att.is_sparse:
+            att = att.coalesce()
+            values = att.values().clone()  # protect against in-place dropout
+            values = dropout(values)
+            att = torch.sparse_coo_tensor(att.indices(), values, att.shape)
+        else:
+            # Simple dense case
+            att = dropout(att)
+
+        return att
+
+    # Non optimized vanilla dropout
+    att = dropout(att)
+    return att
+
+
+def scaled_query_key_softmax(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
+) -> torch.Tensor:
+    # TODO assume we have (N, S, hs) instead of (B, nh, S, hs), with N = B x nh
+    # this is needed due to limitations in sparse_bmm for now
+
+    # Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S)
+    q = q / math.sqrt(k.size(-1))
+
+    # Matmul with mask
+    if att_mask is not None and isinstance(att_mask, AttentionMask):
+        # Additive mask
+        mask: Optional[Union[SparseCS, torch.Tensor]] = att_mask.values
+    else:
+        mask = att_mask
+
+    att = _matmul_with_mask(q, k.transpose(-2, -1), mask)
+
+    # Softmax to get the attention probabilities
+    is_causal = isinstance(att_mask, AttentionMask) and att_mask.is_causal
+    att = _softmax(att, causal=is_causal)
+    return att
+
+
+def scaled_dot_product_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
+    dropout: Optional[torch.nn.Module] = None,
+) -> torch.Tensor:
+    autocast_disabled = (
+        _has_cpp_library
+        and isinstance(att_mask, SparseCS)
+        or (att_mask is not None and att_mask.is_sparse)
+    )
+    with torch.amp.autocast("cuda", enabled=False) if autocast_disabled else nullcontext():  # type: ignore
+        if autocast_disabled:
+            q, k, v = q.float(), k.float(), v.float()
+
+        att = scaled_query_key_softmax(q, k, att_mask=att_mask)
+
+        # Optional dropout, could be part of the masking in the future
+        att = _apply_dropout(att, dropout)
+
+        # Get to the predicted values, for all heads
+        # y = att @ v  # (N, S, S) x (N, S, hs) -> (N, S, hs)
+        y = bmm(att, v)
+    return y

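A minimal sketch of the dense code path of scaled_dot_product_attention above, with heads already folded into the batch dimension as the TODO notes (all sizes are made up):

import torch
from xformers.components.attention.core import scaled_dot_product_attention

N, S, hs = 2 * 4, 128, 64            # batch x heads folded together, sequence, head size
q = torch.randn(N, S, hs)
k = torch.randn(N, S, hs)
v = torch.randn(N, S, hs)

# att_mask=None takes the plain dense path: q @ k^T / sqrt(hs), softmax, then @ v
y = scaled_dot_product_attention(q, k, v, att_mask=None)
assert y.shape == (N, S, hs)
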
.venv/lib/python3.11/site-packages/xformers/components/attention/favor.py
ADDED
@@ -0,0 +1,173 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.amp import autocast
+
+from xformers.components.attention import Attention, AttentionConfig, register_attention
+from xformers.components.attention.feature_maps import (
+    FeatureMap,
+    FeatureMapType,
+    SMHyperbolic,
+    SMOrf,
+    SMReg,
+)
+
+logger = logging.getLogger("xformers")
+
+
+@dataclass
+class FavorAttentionConfig(AttentionConfig):
+    causal: Optional[bool]
+    dim_features: Optional[int] = None  # The dimensions of the random features
+    dim_head: Optional[
+        int
+    ] = None  # The embedding dimension of the inputs. Only useful to get a dim_features estimate
+    iter_before_redraw: Optional[
+        int
+    ] = None  # The number of iterations before the random features are re-drawn from scratch
+    feature_map: Optional[FeatureMapType] = None
+
+
+@register_attention("favor", FavorAttentionConfig)
+class FavorAttention(Attention):
+    def __init__(
+        self,
+        causal: bool = False,
+        dropout: float = 0.0,
+        dim_features: Optional[int] = None,
+        dim_head: Optional[int] = None,
+        iter_before_redraw: Optional[int] = None,
+        feature_map_type: FeatureMapType = FeatureMapType.SMReg,
+        normalize_inputs: bool = False,
+        *_,
+        **__,
+    ):
+        r"""
+        Kernelized attention, as proposed in Performers_
+        ("Rethinking attention with performers." K. Choromanski et al. (2020).).
+
+        FAVOR stands for "Fast Attention Via positive Orthogonal Random features"
+
+        Args:
+            dropout (float): the probability of an output to be randomly dropped at training time
+            dim_features (int): the dimension of the random features space
+            iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features
+            feature_map_type (FeatureMapType): the type of feature map being used,
+            for instance orthogonal random features.
+
+        .. _Performers: https://arxiv.org/pdf/2009.14794v1.pdf
+        """
+        super().__init__()
+
+        self.causal = causal
+        self.iter_before_redraw = (
+            (2 * iter_before_redraw)
+            if iter_before_redraw is not None
+            else iter_before_redraw
+        )  # This will be used for both key and query
+        self.normalize_inputs = normalize_inputs
+        self.feature_map_type = feature_map_type
+        self.attn_drop = nn.Dropout(dropout, inplace=True)
+
+        # Setup dimension-dependent variables
+        # Reasonable dimension default
+        if dim_features is None:
+            assert dim_head is not None, "dim_features or dim_head needs to be passed"
+            self.dim_features = math.ceil(dim_head * (1 + math.log2(dim_head)))
+            self.dim_features = 2 * (
+                self.dim_features // 2
+            )  # needs to be even for some variants
+            logger.info(
+                f"FAVOR: Automatically setting the random mapping dimension to {self.dim_features} from {dim_head}"
+            )
+        else:
+            self.dim_features = dim_features
+
+        feature_map_constructor = {
+            FeatureMapType.SMHyp: SMHyperbolic,
+            FeatureMapType.SMReg: SMReg,
+            FeatureMapType.SMOrf: SMOrf,
+        }[self.feature_map_type]
+
+        feature_settings = {
+            "dim_features": self.dim_features,
+            "iter_before_redraw": self.iter_before_redraw,
+            "normalize_inputs": self.normalize_inputs,
+        }
+
+        self.feature_map: FeatureMap = feature_map_constructor(**feature_settings)  # type: ignore
+
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = False
+
+    @staticmethod
+    def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
+        # Only promote fp16 buffers, bfloat16 would be fine for instance
+        return x.float() if x.dtype == torch.float16 else x
+
+    @staticmethod
+    def _causal_attention(
+        k_prime: torch.Tensor, q_prime: torch.Tensor, v: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Algorithm 1 in the paper
+        ref_v = torch.ones_like(v.unsqueeze(2))  # BATCH x SEQ x 1 x EMB
+        Gps = k_prime.unsqueeze(3) * v.unsqueeze(2)
+        Grenorm = k_prime.unsqueeze(3) * ref_v
+
+        # Consolidate against the feature dimension
+        att_raw = torch.einsum("bcfe,bcf->bce", Gps, q_prime)
+        att_norm = torch.einsum("bcfe,bcf->bce", Grenorm, q_prime)
+
+        # Cumulative sum over the sequence
+        att_raw = att_raw.cumsum(2)
+        att_norm = att_norm.cumsum(2)
+
+        return att_raw, att_norm
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        *_,
+        **__,
+    ):
+
+        # Project key and queries onto the feature map space
+        k_prime = self.feature_map(k)
+        q_prime = self.feature_map(q)
+
+        with autocast("cuda", enabled=False):
+            # The softmax kernel approximation for Favor will easily overflow
+            # Force the computations here to stay in fp32 for numerical stability
+            # Note that the dimensions are vastly reduced when compared to scaled_dot_product
+            k_prime = self._maybe_promote(k_prime)
+            q_prime = self._maybe_promote(q_prime)
+            v = self._maybe_promote(v)
+
+            if not self.causal:
+                att_normalization = q_prime @ (
+                    k_prime.transpose(-2, -1) @ torch.ones_like(v)
+                )
+                att_raw = q_prime @ (k_prime.transpose(-2, -1) @ v)
+            else:
+                # Actually compute attention
+                att_raw, att_normalization = self._causal_attention(k_prime, q_prime, v)
+
+            # Normalize
+            att = att_raw / att_normalization
+
+            if self.attn_drop is not None:
+                att = self.attn_drop(att)
+
+            return att

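A usage sketch for the Performer attention above, instantiated directly rather than through the registry; the dimensions are placeholders, and dim_features is derived automatically from dim_head as shown in the constructor:

import torch
from xformers.components.attention.favor import FavorAttention

attn = FavorAttention(causal=False, dropout=0.0, dim_head=64)
q = torch.randn(2, 128, 64)  # (batch x heads, sequence, head dim)
out = attn(q, q, q)          # linear-complexity softmax approximation
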
.venv/lib/python3.11/site-packages/xformers/components/attention/fourier_mix.py
ADDED
@@ -0,0 +1,35 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.amp import autocast
+
+from xformers.components.attention import Attention, AttentionConfig, register_attention
+
+
+@register_attention("fourier_mix", AttentionConfig)
+class FourierMix(Attention):
+    def __init__(self, dropout: float, *_, **__):
+        """
+        FFT-based pseudo-attention mechanism, from
+        "
+        "FNet: Mixing Tokens with Fourier Transforms"
+        Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf
+        """
+        super().__init__()
+        self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
+
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = False
+        self.requires_input_projection = False
+
+    def forward(self, q: torch.Tensor, *_, **__):
+        # Guard against autocast / fp16, not supported by torch.fft.fft2
+        with autocast("cuda", enabled=False):
+            att = torch.fft.fft2(q).real
+
+        att = self.attn_drop(att)
+
+        return att

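A sketch for the FNet-style mixer above; it only uses the first input and returns the real part of a 2D FFT over the (sequence, embedding) dimensions:

import torch
from xformers.components.attention.fourier_mix import FourierMix

mixer = FourierMix(dropout=0.0)
x = torch.randn(2, 128, 64)   # (batch, sequence, embedding)
out = mixer(x)                # same shape, no attention weights are materialized
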
.venv/lib/python3.11/site-packages/xformers/components/attention/lambda_layer.py
ADDED
@@ -0,0 +1,78 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import dataclass
+
+import torch
+
+from xformers.components.attention import Attention, AttentionConfig, register_attention
+
+
+def calc_rel_pos(n: int):
+    # Adapted from LucidRains
+    # https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py
+    rel_pos = torch.arange(n)[None, :] - torch.arange(n)[:, None]  # [n, n]
+    rel_pos += n - 1  # shift value range from [-n+1, n-1] to [0, 2n-2]
+    return rel_pos
+
+
+@dataclass
+class LambdaLayerConfig(AttentionConfig):
+    seq_len: int  # dimension of the input sequence
+    dim_head: int
+
+
+@register_attention("lambda", LambdaLayerConfig)
+class LambdaLayer(Attention):
+    def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
+        """
+        Attention approximation using Lambda layers, from
+        "Lambda networks: modeling long-range interactions without attention.", Bello, I. (2021).
+        """
+        super().__init__()
+
+        # Possible extensions:
+        # - support different dimensions for key and queries
+        # - support varying dimensions in between inputs and outputs
+        # - support u hyperparam
+
+        self.rel_pos_emb = torch.nn.Parameter(
+            torch.randn(2 * seq_len - 1, int(dim_head))
+        )
+        self.rel_pos = calc_rel_pos(seq_len)
+        self.attn_drop = torch.nn.Dropout(dropout, inplace=True)
+
+        # Properties specific to this attention mechanism
+        self.requires_same_k_q_dimensions = True
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = False
+
+    def forward(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
+    ):
+        """..NOTE: We're reusing the einsum notation suggested by the paper, changed in that
+        heads are folded in the batch dimension"""
+
+        content_lambda = torch.einsum("bnk,bnv->bkv", torch.softmax(k, dim=-1), v)
+        content_output = torch.einsum("bnk,bkv->bnv", q, content_lambda)
+
+        rel_pos_emb = self.rel_pos_emb[self.rel_pos]
+
+        # Handle real sequence length being possibly smaller
+        seq_len = q.shape[1]
+        rel_pos_emb = rel_pos_emb[:seq_len, :seq_len, :]
+
+        # Compute the position lambda for every possible combination in one go, then compute the
+        # position related contribution
+        position_lambdas = torch.einsum(
+            "mnk,bnv->bnkv", rel_pos_emb, v
+        )  # one lambda per position
+        position_output = (q.unsqueeze(2) @ position_lambdas).squeeze()
+        att = content_output + position_output
+
+        att = self.attn_drop(att)
+
+        return att

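A usage sketch for the lambda layer above; seq_len and dim_head are illustrative values, and queries/keys/values are expected with heads folded into the batch dimension, as the forward docstring notes:

import torch
from xformers.components.attention.lambda_layer import LambdaLayer

attn = LambdaLayer(dropout=0.0, seq_len=128, dim_head=64)
q = torch.randn(2, 128, 64)   # (batch x heads, sequence, dim_head)
out = attn(q, q, q)           # content lambda + position lambda contributions
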
.venv/lib/python3.11/site-packages/xformers/components/attention/local.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    maybe_sparsify,
    register_attention,
    sparsify,
)
from xformers.components.attention.attention_patterns import (
    causal_1d_pattern,
    local_1d_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass
class LocalAttentionConfig(AttentionConfig):
    causal: Optional[bool] = None
    window_size: Optional[int] = None
    force_sparsity: Optional[bool] = None


@register_attention("local", LocalAttentionConfig)
class LocalAttention(Attention):
    def __init__(
        self,
        dropout: float = 0.0,
        causal: bool = False,
        window_size: int = 5,
        force_sparsity: bool = False,
        *args,
        **kwargs,
    ):

        r"""
        An implementation of a sliding window attention, as proposed in RoutingTransformer_, Longformer_ or BigBird_


        Args:
            dropout (float): the probability of an output to be randomly dropped at training time
            causal (bool): apply a causal mask, in that the attention cannot be applied to the future
            window_size (int): the overall window size for local attention.
                Odd number is expected if the mask is not causal, as the window size will be evenly
                distributed on both sides of each query


        .. _RoutingTransformer: https://arxiv.org/pdf/2003.05997.pdf

        .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf

        .. _Longformer: https://arxiv.org/pdf/2004.05150.pdf

        """
        super().__init__()

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.causal = causal
        self.force_sparsity = force_sparsity

        if not self.causal:
            assert (
                window_size % 2 == 1
            ), "The window size is assumed to be odd (counts self-attention + 2 wings)"

        self.window_size = window_size
        self.attention_mask: Optional[torch.Tensor] = None
        self.requires_same_k_q_dimensions = True

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
        window_size = self.window_size * 2 + 1 if self.causal else self.window_size
        mask = local_1d_pattern(shape[1], window_size)

        if self.causal:
            mask &= causal_1d_pattern(shape[1])

        mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)

        return mask

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        *args,
        **kwargs,
    ):
        # Local window attention masking
        if self.attention_mask is None or self.attention_mask.shape[1] != q.shape[1]:
            self.attention_mask = self._get_local_mask(q.shape).to(q.device)

        # Take into account the optional user mask
        if att_mask is None:
            mask = self.attention_mask
        else:
            if isinstance(att_mask, AttentionMask):
                # Needed because & op not defined for SparseCS with AttentionMask
                att_mask = att_mask.to_bool()
            mask = self.attention_mask & att_mask

        return scaled_dot_product_attention(
            q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop
        )
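Not part of the upstream diff: a minimal usage sketch for the LocalAttention block above. The import path mirrors the file location shown here; the shapes and hyper-parameters are illustrative only.

# Illustrative only: sliding-window attention on a dummy (batch, sequence, head dim) batch
import torch
from xformers.components.attention.local import LocalAttention

attention = LocalAttention(dropout=0.0, causal=False, window_size=5)
q = k = v = torch.randn(2, 16, 8)
out = attention(q, k, v)  # same shape as the inputs, each query limited to a 5-wide window
print(out.shape)  # torch.Size([2, 16, 8])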
.venv/lib/python3.11/site-packages/xformers/components/attention/nystrom.py
ADDED
@@ -0,0 +1,295 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import (
    scaled_dot_product_attention,
    scaled_query_key_softmax,
)
from xformers.components.attention.utils import (
    bool_mask_to_additive,
    iterative_pinv,
    reshape_key_padding_mask,
)

logger = logging.getLogger("xformers")


@dataclass
class NystromSelfAttentionConfig(AttentionConfig):
    """
    num_heads               Number of heads.
    num_landmarks           Number of landmarks to use for softmax approximation. 64 is often sufficient for a good
                            approximation according to https://arxiv.org/pdf/2102.03902.pdf.
    causal                  Apply a causal mask, in that the attention cannot be applied to the future.
    use_razavi_pinverse     If true, use the iterative method from (Razavi et al. 2014) to approximate the
                            Moore-Penrose inverse, otherwise use the standard torch inverse.
    pinverse_original_init  True if using the original initialization when calculating the Moore-Penrose pseudo
                            inverse using the method from (Razavi et al. 2014).
                            False if using exact coefficient computation (leads to faster convergence).
    inv_iterations          Number of iterations for calculating the Moore-Penrose pseudo inverse.
    v_skip_connection       A module that will take V as input and will be added as a skip connection to the
                            softmax approximation. A skip connection is added in the paper to help with training.
    conv_kernel_size        Kernel size for the convolution optionally added to help in training.
                            If v_skip_connection is not specified, this will be used to define the default
                            depth-wise convolution used as a skip connection.
                            If both conv_kernel_size and v_skip_connection are None, no skip connection will
                            be added.
    landmark_pooling        Which module to use when computing landmarks. Default is the AvgPool module defined below.
    """

    num_heads: int
    num_landmarks: Optional[int]
    landmark_pooling: Optional[nn.Module]
    causal: Optional[bool]
    pinverse_original_init: Optional[bool]
    inv_iterations: Optional[int]
    v_skip_connection: Optional[nn.Module]
    conv_kernel_size: Optional[int]
    use_razavi_pinverse: Optional[bool]


class AvgPool(nn.Module):
    def __init__(self, n: int):
        super().__init__()
        self.n = n

    def forward(self, x: torch.Tensor):
        # Average independently for every segment in the sequence dimension
        seq_len = x.shape[1]
        head_dim = x.shape[2]
        segments = seq_len // self.n
        assert segments > 0, "num_landmarks should be smaller than the sequence length"

        # Dimensions are a match
        if seq_len % self.n == 0:
            return x.reshape(
                -1,
                self.n,
                segments,
                head_dim,
            ).mean(dim=-2)

        # Handle the last segment boundary being off
        n_round = self.n - seq_len % self.n

        x_avg_round = (
            x[:, : n_round * segments, :]
            .reshape(-1, n_round, segments, head_dim)
            .mean(dim=-2)
        )
        x_avg_off = (
            x[:, n_round * segments :, :]
            .reshape(-1, self.n - n_round, segments + 1, head_dim)
            .mean(dim=-2)
        )
        return torch.cat((x_avg_round, x_avg_off), dim=-2)


@register_attention("nystrom", NystromSelfAttentionConfig)
class NystromAttention(Attention):
    # TODO: update defaults for use_razavi_pinverse and inv_iterations
    def __init__(
        self,
        dropout: float,
        num_heads: int,
        num_landmarks: int = 64,
        landmark_pooling: Optional[nn.Module] = None,
        causal: bool = False,
        use_razavi_pinverse: bool = True,
        pinverse_original_init: bool = False,
        inv_iterations: int = 6,  # recommended default in paper was 6.
        v_skip_connection: Optional[nn.Module] = None,
        conv_kernel_size: Optional[int] = None,
        *args,
        **kwargs,
    ):
        """
        Nystrom attention mechanism, from Nystromformer_.
        ::

            "A Nystrom-based Algorithm for Approximating Self-Attention."
            Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021)

            Reference codebase: https://github.com/mlpen/Nystromformer

        .. _Nystromformer: https://arxiv.org/pdf/2102.03902.pdf

        """
        super().__init__()
        # merged key padding mask and attention mask is not accepted
        self.requires_separate_masks = True
        self.num_landmarks = num_landmarks
        # TODO: should be able to not have to pass in num_heads
        self.num_heads = num_heads
        self.use_razavi_pinverse = use_razavi_pinverse
        self.pinverse_original_init = pinverse_original_init
        self.inv_iterations = inv_iterations
        self.attn_drop = nn.Dropout(dropout)
        self.skip_connection = v_skip_connection
        self.causal = causal

        if self.skip_connection is None and conv_kernel_size is not None:
            self.skip_connection = nn.Conv2d(
                in_channels=self.num_heads,
                out_channels=self.num_heads,
                kernel_size=(conv_kernel_size, 1),
                padding=(conv_kernel_size // 2, 0),
                bias=False,
                groups=self.num_heads,
            )

        if landmark_pooling is not None:
            self.landmark_pooling = landmark_pooling
        else:
            self.landmark_pooling = AvgPool(n=self.num_landmarks)

        # Optional lower triangular masks for causal attention
        self.causal_mask_1: Optional[torch.Tensor] = None
        self.causal_mask_2: Optional[torch.Tensor] = None
        self.causal_mask_3: Optional[torch.Tensor] = None

        # This attention does not support attention masks
        self.supports_attention_mask = False
        self.supports_key_padding_mask = True

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ):
        r"""
        key_padding_mask    Only a key padding mask is accepted here. The size must be (batch size, sequence length)
                            or (batch size * num_heads, 1, sequence length). If the dimensions are not correct, the
                            mask will be ignored. An additive mask is expected, meaning float values using "-inf" to
                            mask values.
        """

        batched_dim = k.size(0)
        seq_len = k.size(-2)
        tt = {"dtype": q.dtype, "device": q.device}

        if key_padding_mask is not None:
            if key_padding_mask.dtype == torch.bool:
                logger.warning(
                    "Bool mask found, but an additive mask is expected. Converting but this is slow"
                )

                key_padding_mask = bool_mask_to_additive(key_padding_mask)

            if key_padding_mask.ndim == 2:
                key_padding_mask = reshape_key_padding_mask(
                    key_padding_mask, batched_dim
                )

            zeros = torch.zeros_like(key_padding_mask)
            ones = torch.ones_like(key_padding_mask)
            is_masked = torch.isinf(-key_padding_mask)

            # _mask takes 1 if the token is not padded, otherwise 0.
            _mask = torch.where(is_masked, zeros, ones)
            _mask = _mask.transpose(2, 1)
            assert _mask.shape == (batched_dim, q.shape[1], 1)

            # Mask q and k before pooling
            # https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31
            q = q * _mask
            k = k * _mask

            assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
                f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
                f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
            )

        if self.num_landmarks >= seq_len:
            mask: Optional[torch.Tensor] = None

            if self.causal:
                mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt)

            if key_padding_mask is not None:
                mask = key_padding_mask if mask is None else mask + key_padding_mask

            x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)

        else:
            q_landmarks = self.landmark_pooling(q)
            k_landmarks = self.landmark_pooling(k)

            if self.causal and (
                self.causal_mask_1 is None
                or (batched_dim, seq_len, self.num_landmarks)
                != self.causal_mask_1.size()
            ):
                self.causal_mask_1 = self._triu_mask(
                    batched_dim, seq_len, self.num_landmarks, **tt
                )
                self.causal_mask_2 = self._triu_mask(
                    batched_dim, self.num_landmarks, self.num_landmarks, **tt
                )
                self.causal_mask_3 = self._triu_mask(
                    batched_dim, self.num_landmarks, seq_len, **tt
                )

            mask_3: Optional[torch.Tensor] = self.causal_mask_3
            if key_padding_mask is not None:
                mask_3 = (
                    key_padding_mask if mask_3 is None else mask_3 + key_padding_mask
                )

            kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None)
            kernel_2 = scaled_query_key_softmax(
                q=q_landmarks, k=k_landmarks, att_mask=None
            )
            kernel_3 = scaled_dot_product_attention(
                q=q_landmarks, k=k, v=v, att_mask=mask_3
            )

            kernel_2_inv = (
                iterative_pinv(
                    kernel_2, self.inv_iterations, self.pinverse_original_init
                )
                if self.use_razavi_pinverse
                else torch.linalg.pinv(kernel_2)
            )

            x = torch.matmul(
                torch.matmul(
                    kernel_1,
                    kernel_2_inv,
                ),
                kernel_3,
            )

        if self.skip_connection:
            # Assumption here is that v is 3D.
            v_conv = self.skip_connection(
                v.reshape(-1, self.num_heads, v.size(-2), v.size(-1))
            )
            x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1))
        x = self.attn_drop(x)
        return x

    def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor:
        device = kwargs["device"]
        dtype = kwargs["dtype"]

        return torch.triu(
            torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"),
            diagonal=1,
        ).expand(
            dim_1, -1, -1
        )  # micro optim, save memory on the batch dimension
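Not part of the upstream diff: a minimal usage sketch for NystromAttention. Because num_heads is passed explicitly, the batch dimension is assumed to already be folded as (batch * num_heads); sizes are illustrative only.

# Illustrative only: Nystrom-approximated self-attention on folded (batch * heads, seq, head dim) inputs
import torch
from xformers.components.attention.nystrom import NystromAttention

attention = NystromAttention(dropout=0.0, num_heads=4, num_landmarks=16)
q = k = v = torch.randn(2 * 4, 128, 8)  # 16 landmarks < 128 tokens, so the landmark path is taken
out = attention(q, k, v)
print(out.shape)  # torch.Size([8, 128, 8])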
.venv/lib/python3.11/site-packages/xformers/components/attention/random.py
ADDED
@@ -0,0 +1,126 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    maybe_sparsify,
    register_attention,
    sparsify,
)
from xformers.components.attention.attention_patterns import (
    causal_1d_pattern,
    random_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass
class RandomAttentionConfig(AttentionConfig):
    r: Optional[
        float
    ]  # the ratio of keys that the query can attend to. 1.0 means dense attention
    constant_masking: Optional[
        bool
    ]  # whether the randomness is per query or defined at construction time
    force_sparsity: Optional[bool]  # use sparsity in any case (potentially slower)


@register_attention("random", RandomAttentionConfig)
class RandomAttention(Attention):
    def __init__(
        self,
        dropout: float,
        causal: bool = False,
        r: float = 0.01,
        constant_masking: bool = True,
        force_sparsity: bool = False,
        *args,
        **kwargs,
    ):
        """
        "Random" attention, as proposed for instance in BigBird_.
        Random means in that case that each query can attend to a random set of keys.
        This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.

        Args:
            r (float): the ratio in [0,1] of keys that the query can attend to
            constant_masking (bool): if true, keep the same random set for all queries.

        .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf

        """
        super().__init__()

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.causal = causal
        self.r = r
        self.rand_attention_mask: Optional[torch.Tensor] = None
        self.constant_masking = constant_masking
        self.force_sparsity = force_sparsity

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

        self.requires_same_k_q_dimensions = True

    def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor:
        sparsity = 1 - self.r
        mask = random_pattern(shape[1], sparsity=sparsity)

        if self.causal:
            mask &= causal_1d_pattern(shape[1])

        mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)

        return mask

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        *args,
        **kwargs,
    ):
        # Rand masking
        if not self.constant_masking or self.rand_attention_mask is None:
            self.rand_attention_mask = self._get_rand_mask(q.shape).to(q.device)

        # Mask-aware attention
        if att_mask is not None:
            if att_mask.dtype == torch.bool and isinstance(
                self.rand_attention_mask, AttentionMask
            ):
                mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask)
            else:
                if isinstance(att_mask, AttentionMask):
                    # Needed because & op not defined for SparseCS with AttentionMask
                    att_mask = att_mask.to_bool()
                mask = self.rand_attention_mask & att_mask
        else:
            mask = self.rand_attention_mask

        # Handle q/k/v which would not fit the mask
        seq_len = q.shape[-2]
        q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))

        # Normal attention with the random mask
        att = scaled_dot_product_attention(
            q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
        )

        # Take into account an hypothetical padding
        return att[:, :seq_len, :]
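Not part of the upstream diff: a minimal usage sketch for RandomAttention. r = 0.5 is chosen here so the mask stays dense on CPU (the sparse path is only triggered for very low densities or force_sparsity=True); values are illustrative only.

# Illustrative only: each query attends to a random subset of the keys
import torch
from xformers.components.attention.random import RandomAttention

attention = RandomAttention(dropout=0.0, causal=False, r=0.5, force_sparsity=False)
q = k = v = torch.randn(2, 64, 8)
out = attention(q, k, v)
print(out.shape)  # torch.Size([2, 64, 8])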
.venv/lib/python3.11/site-packages/xformers/components/attention/scaled_dot_product.py
ADDED
@@ -0,0 +1,134 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    register_attention,
)
from xformers.components.attention.core import scaled_dot_product_attention

logger = logging.getLogger("xformers")


@dataclass
class ScaledDotProductConfig(AttentionConfig):
    causal: Optional[bool]
    seq_len: Optional[int]
    to_seq_len: Optional[int]


@register_attention("scaled_dot_product", ScaledDotProductConfig)
class ScaledDotProduct(Attention):
    r"""
    Implementing the Scaled Dot-Product attention proposed in
    `Attention is all you need`_, Vaswani et al.

    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
    """

    mask: Optional[AttentionMask]

    def __init__(
        self,
        dropout: float = 0.0,
        causal: bool = False,
        seq_len: Optional[int] = None,
        to_seq_len: Optional[int] = None,
        *args,
        **kwargs,
    ):
        super().__init__()

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.causal = causal
        self.seq_len = seq_len

        if causal and seq_len is not None:
            self.mask = AttentionMask.make_causal(seq_len, to_seq_len)
        else:
            self.mask = None

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        att_mask    A 2D or 3D mask which ignores attention at certain positions.

                    - If the mask is boolean, a value of True will keep the value,
                      while a value of False will mask the value.

                      Key padding masks (dimension: batch x sequence length) and attention masks
                      (dimension: sequence length x sequence length OR batch x sequence length x sequence length)
                      can be combined and passed in here. Method maybe_merge_masks provided in the utils can be
                      used for that merging.

                    - If the mask has the float type, then an additive mask is expected (masked values are -inf)

        """

        # Convenience, create an attention mask if a tensor was passed
        if att_mask is not None and isinstance(att_mask, torch.Tensor):
            # By default we don't know of the causality, and a check would be expensive
            att_mask = (
                AttentionMask.from_bool(att_mask)
                if att_mask.dtype == torch.bool
                else AttentionMask(att_mask, is_causal=False)
            )

        # Handle a possibly deferred causal mask handling
        mask = self.mask
        if self.causal and self.mask is None:
            mask = AttentionMask.make_causal(
                seq_len=q.shape[-2],
                to_seq_len=q.shape[-2],
                device=q.device,
                dtype=q.dtype,
            )

        # Merge the optional causal mask and the user-provided mask
        if mask is not None:
            mask = mask.to(dtype=q.dtype, device=q.device)

            att_mask = att_mask + mask if att_mask is not None else mask

        # Try to handle a case where the sequence is smaller than the mask
        if (
            att_mask is not None
            and q.shape[-2] == k.shape[-2]
            and q.shape[-2] < att_mask.shape[1]
        ):
            if isinstance(att_mask, AttentionMask):
                att_mask = att_mask.make_crop(seq_len=q.shape[-2])
            else:
                logger.error(
                    "Mismatching sparse attention mask and sequence length."
                    + " Please pad the inputs or adjust the attention mask"
                )
                raise NotImplementedError

        # Attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
        y = scaled_dot_product_attention(
            q=q, k=k, v=v, att_mask=att_mask, dropout=self.attn_drop
        )
        return y
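Not part of the upstream diff: a minimal usage sketch for ScaledDotProduct with a deferred causal mask; shapes are illustrative only.

# Illustrative only: vanilla attention, the causal AttentionMask is built lazily from q.shape[-2]
import torch
from xformers.components.attention.scaled_dot_product import ScaledDotProduct

attention = ScaledDotProduct(dropout=0.1, causal=True)
q = k = v = torch.randn(2, 32, 16)
out = attention(q, k, v)
print(out.shape)  # torch.Size([2, 32, 16])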
.venv/lib/python3.11/site-packages/xformers/components/attention/visual.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import math
from dataclasses import dataclass

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention


@dataclass
class VisualAttentionConfig(AttentionConfig):
    dim_model: int  # dimension of the input sequence


class LKA(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
        )
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x: torch.Tensor):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)

        return u * attn


@register_attention("visual", VisualAttentionConfig)
class Visual(Attention):
    def __init__(
        self,
        dim_model: int,
        *_,
        **__,
    ):
        """
        Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022).
        The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network
        for the reference implementation

        .. Note: compared to the paper, this block contains the LKA (Large Kernel Attention)
            and the prior and posterior transformations (Conv2d and activation)

        .. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf
        """
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(dim_model, dim_model, 1),
            nn.GELU(),
            LKA(dim_model),
            nn.Conv2d(dim_model, dim_model, 1),
        )

        # MHA related flags:
        self.requires_same_k_q_dimensions = (
            True  # This mechanism only really supports self attention
        )
        self.supports_attention_mask = False
        self.requires_skip_multi_head = (
            True  # This mechanism skips the multihead attention altogether
        )
        self.requires_squared_context = (
            True  # Recovering the 2D structure from context assumes squared content
        )

        self.requires_input_projection = (
            False  # This mechanism does not require that the MHA projects inputs
        )

    def forward(self, q: torch.Tensor, *_, **__):
        # Expose the 2D token structure
        B, HW, C = q.shape
        H = int(math.sqrt(HW))
        assert H * H == HW

        x = q.transpose(-2, -1).reshape(B, C, H, H)

        # Large kernel attention
        residual = x.clone()
        x = self.block(x)
        x = x + residual

        # Get back to B HW C
        return x.flatten(2, 3).transpose(-2, -1)
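Not part of the upstream diff: a minimal usage sketch for the Visual (large kernel) attention. The token sequence must describe a square grid, so 256 tokens are interpreted here as a 16 x 16 feature map; sizes are illustrative only.

# Illustrative only: conv-based large-kernel attention, no multi-head wrapping needed
import torch
from xformers.components.attention.visual import Visual

attention = Visual(dim_model=64)
x = torch.randn(2, 256, 64)  # (batch, H*W, channels), with H*W a perfect square
out = attention(q=x)
print(out.shape)  # torch.Size([2, 256, 64])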
.venv/lib/python3.11/site-packages/xformers/components/input_projection.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# CREDITS: Inspired by https://github.com/pytorch/text/blob/master/torchtext/nn/modules/multiheadattention.py
# and the MultiHeadAttention implementation from PyTorch


import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn

from xformers._deprecation_warning import deprecated_function

logger = logging.getLogger("xformers")


@dataclass
class InputProjectionConfig:
    in_features: int
    out_features: int
    bias: bool


class InputProjection(nn.Module):
    """
    Handle all the input projections in one go, opportunistically fuse some operations.
    """

    def __init__(
        self,
        query_proj_params: InputProjectionConfig,
        key_proj_params: Optional[InputProjectionConfig],
        value_proj_params: Optional[InputProjectionConfig],
        use_separate_proj_weight: bool = True,
    ):

        super().__init__()
        deprecated_function(self)

        self.out_features = query_proj_params.out_features

        # Each input gets a separate projection
        self.q_proj = nn.Linear(
            query_proj_params.in_features,
            query_proj_params.out_features,
            query_proj_params.bias,
        )

        if key_proj_params is not None:
            self.k_proj = nn.Linear(
                key_proj_params.in_features,
                key_proj_params.out_features,
                key_proj_params.bias,
            )
        else:
            logger.info(
                "No Key projection parameters were passed, assuming that the weights"
                + " are shared with the query projection"
            )
            self.k_proj = self.q_proj

        if value_proj_params is not None:
            self.v_proj = nn.Linear(
                value_proj_params.in_features,
                value_proj_params.out_features,
                value_proj_params.bias,
            )
        else:
            logger.info(
                "No Value projection parameters were passed, assuming that the weights"
                + " are shared with the query projection"
            )
            self.v_proj = self.q_proj

        if not use_separate_proj_weight:
            # Compute optimization used at times, share the parameters in between Q/K/V
            with torch.no_grad():
                self.k_proj.weight = self.q_proj.weight
                self.v_proj.weight = self.q_proj.weight

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # One projection per input tensor

        # NOTE: Would it make sense to catch self attention + shared weights, to skip a projection step ?

        q, k, v = map(
            lambda fn, x: fn(x),
            [self.q_proj, self.k_proj, self.v_proj],
            [query, key, value],
        )

        return q, k, v
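Not part of the upstream diff: a minimal usage sketch for InputProjection (note the component calls deprecated_function, so a deprecation warning is expected); sizes are illustrative only.

# Illustrative only: project q/k/v in one helper, sharing the query projection for key and value
import torch
from xformers.components.input_projection import InputProjection, InputProjectionConfig

proj = InputProjection(
    query_proj_params=InputProjectionConfig(in_features=64, out_features=64, bias=True),
    key_proj_params=None,   # falls back to sharing the query projection
    value_proj_params=None,
)
x = torch.randn(2, 16, 64)
q, k, v = proj(x, x, x)
print(q.shape, k.shape, v.shape)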
.venv/lib/python3.11/site-packages/xformers/components/multi_head_dispatch.py
ADDED
@@ -0,0 +1,271 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from dataclasses import asdict, dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn.init import constant_

from xformers._deprecation_warning import deprecated_function
from xformers.components.attention import Attention
from xformers.components.input_projection import InputProjection, InputProjectionConfig
from xformers.components.positional_embedding import RotaryEmbedding

logger = logging.getLogger("xformers")


@dataclass
class MultiHeadDispatchConfig:
    dim_model: int
    num_heads: int
    attention: Attention
    bias: bool
    residual_dropout: float
    dim_key: Optional[int]
    dim_value: Optional[int]
    in_proj_container: Optional[InputProjection]
    use_separate_proj_weight: Optional[bool]
    use_rotary_embeddings: Optional[bool]
    out_proj: Optional[nn.Module]

    def __getitem__(self, item):
        return getattr(self, item)


# Move head forward and fold into batch dim. dimensions become (B * nh, S, hs)
def _fold_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):
    return t.view(B, S, H, Hs).transpose(1, 2).flatten(start_dim=0, end_dim=1)


# Move head forward and fold into batch dim. dimensions become (B, nh, S, hs)
def _split_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):
    return t.view(B, S, H, Hs).transpose(1, 2)


class MultiHeadDispatch(nn.Module):
    """
    A multi-head masked self-attention dispatch mechanism, with a projection at the end,
    following the architecture proposed in `Attention is all you need`_, Vaswani et al.

    The actual attention mechanism can vary, as well as the projections.
    This can be used to wrap the proposed attention mechanisms and make them multi-head aware,
    but it is optional.

    Args:
        dim_model: The model/embedding dimension
        num_heads: The number of heads being used
        attention: The attention mechanism (needs to be registered to the xformers library)
        bias: Whether to use bias for the projections : (Q, K, V, Output)
        residual_dropout: Amount of dropout on the residual path
        use_separate_proj_weight: Use different weights for the Q, K, V projections
        dim_key: Optionally use a different dimension for the key
        dim_value: Optionally use a different dimension for the value
        in_proj_container: Optionally provide the input projection module
        use_rotary_embeddings: Use rotary embeddings
        out_proj: Optionally provide the output projection module


    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
    """

    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        attention: Attention,
        bias: Tuple[bool, bool, bool, bool] = (True, True, True, True),
        residual_dropout: float = 0.0,
        use_separate_proj_weight: bool = True,
        dim_key: Optional[int] = None,
        dim_value: Optional[int] = None,
        in_proj_container: Optional[InputProjection] = None,
        use_rotary_embeddings: Optional[bool] = False,
        out_proj: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ):
        super().__init__()
        deprecated_function(self)

        if isinstance(bias, bool):
            logger.warning(
                "Single bias value provided for the MHA projections."
                + f" Assuming the same parameter ({bias}) is to be used everywhere"
            )
            bias = (bias, bias, bias, bias)

        assert (
            dim_model % num_heads == 0
        )  # static preset for now, each head works on 1/d the embeddings, could be relaxed
        assert num_heads > 0

        # Popular default is that all latent dimensions are the same
        dim_key, dim_value = map(lambda x: x if x else dim_model, (dim_key, dim_value))

        self.num_heads = num_heads
        self.dim_key_head = dim_key // num_heads
        self.dim_value_head = dim_value // num_heads
        self.dim_model = dim_model
        self.attention = attention

        # key, query, value projections for all heads
        # critical options are
        # - are we sharing weights ?
        # - are we adding biases ?
        if attention.requires_input_projection:
            self.in_proj_container = (
                in_proj_container
                if in_proj_container is not None
                else InputProjection(
                    query_proj_params=InputProjectionConfig(
                        dim_model, dim_key, bias=bias[0]
                    ),
                    key_proj_params=InputProjectionConfig(
                        dim_model, dim_key, bias=bias[1]
                    ),
                    value_proj_params=InputProjectionConfig(
                        dim_model, dim_value, bias=bias[2]
                    ),
                    use_separate_proj_weight=use_separate_proj_weight,
                )
            )

        # Optional rotary embeddings
        self.rotary_embeddings = (
            RotaryEmbedding(self.dim_key_head) if use_rotary_embeddings else None
        )

        # Regularization
        self.resid_drop = nn.Dropout(residual_dropout, inplace=False)

        # Output projection
        self.proj = (
            out_proj if out_proj else nn.Linear(dim_model, dim_model, bias=bias[3])
        )
        if isinstance(self.proj, nn.Linear) and self.proj.bias is not None:
            constant_(self.proj.bias, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        att_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Expected input dimensions are [batch size, sequence length, embed dim]
        Output dimensions are [batch size, sequence length, embed dim]
        """

        if key is None:
            key = query
        if value is None:
            value = query

        if query.shape[0] != key.shape[0] or query.shape[0] != value.shape[0]:
            max_batch = max((query.shape[0], key.shape[0], value.shape[0]))
            query, key, value = map(
                lambda x: x.expand(max_batch, -1, -1), [query, key, value]
            )

        B, S_Q, _ = query.size()  # Batch x Sequence x Embedding (latent)
        _, S_K, _ = key.size()  # K, Q's sequence length could differ

        # Catch different query and key length but a causal attention
        if S_Q != S_K:
            assert (
                not self.attention.requires_same_k_q_dimensions
            ), "This attention mechanism requires query and key to have the same sequence (context) lengths"

            if hasattr(self.attention, "causal"):
                assert not self.attention.causal, (
                    "Causal attention is not supported when key and query have different sequence lengths.\n"
                    + "In that case causality is ill-determined. Please pad your sequences accordingly"
                )

        kw_mask_args = {}
        if att_mask is not None:
            assert (
                self.attention.supports_attention_mask
            ), "This attention does not support attention masks"
            kw_mask_args["att_mask"] = att_mask

        if key_padding_mask is not None:
            assert (
                self.attention.supports_key_padding_mask
            ), "This attention does not support key padding masks"
            kw_mask_args["key_padding_mask"] = key_padding_mask

        if self.attention.requires_skip_multi_head:
            return self.attention(query, key, value, **kw_mask_args)

        # Calculate query, key, values for all heads in batch
        if self.attention.requires_input_projection:
            q, k, v = self.in_proj_container(query=query, key=key, value=value)
        else:
            k, q, v = key, query, value

        # Check the dimensions properly
        def check(t, name):
            assert (
                t.shape[2] % self.num_heads == 0
            ), f"the {name} embeddings need to be divisible by the number of heads"

        check(q, "projected query")
        check(v, "projected value")
        check(k, "projected key")

        # Optional: rotary embedding, add relative positioning information
        if self.rotary_embeddings:
            # rotary requires the head dimension
            q = _split_heads(q, B, S_Q, self.num_heads, self.dim_key_head)
            k = _split_heads(k, B, S_K, self.num_heads, self.dim_key_head)
            v = _split_heads(v, B, S_K, self.num_heads, self.dim_value_head)

            q, k = self.rotary_embeddings(q=q, k=k)

            if not self.attention.requires_head_dimension:
                q, k, v = q.flatten(0, 1), k.flatten(0, 1), v.flatten(0, 1)

        else:
            # Reshape k/q/v to either expose the heads, or fold the head dimension into the batch
            reshape_fn = (
                _split_heads if self.attention.requires_head_dimension else _fold_heads
            )

            q = reshape_fn(q, B, S_Q, self.num_heads, self.dim_key_head)
            k = reshape_fn(k, B, S_K, self.num_heads, self.dim_key_head)
            v = reshape_fn(v, B, S_K, self.num_heads, self.dim_value_head)

        # Self-attend
        y = self.attention(q, k, v, **kw_mask_args)

        # Re-assemble all head outputs side by side
        y = (
            y.view(B, self.num_heads, S_Q, self.dim_value_head)
            .transpose(1, 2)
            .flatten(start_dim=2, end_dim=3)
        )

        # Output projection, dropout and good to go
        y = self.resid_drop(self.proj(y))

        # Return the same sequence size as the input
        return y

    @classmethod
    def from_config(cls, config: MultiHeadDispatchConfig):
        # Generate the class inputs from the config
        fields = asdict(config)

        # Skip all Nones so that default values are used
        fields = {k: v for k, v in fields.items() if v is not None}

        return cls(**fields)
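Not part of the upstream diff: a minimal usage sketch showing MultiHeadDispatch wrapping a single-head attention mechanism (a deprecation warning is expected since the constructor calls deprecated_function); sizes are illustrative only.

# Illustrative only: make a registered attention mechanism multi-head aware
import torch
from xformers.components.attention.scaled_dot_product import ScaledDotProduct
from xformers.components.multi_head_dispatch import MultiHeadDispatch

mha = MultiHeadDispatch(
    dim_model=64,
    num_heads=4,
    attention=ScaledDotProduct(dropout=0.0, causal=False),
)
x = torch.randn(2, 16, 64)  # (batch, sequence, embedding)
y = mha(query=x, key=x, value=x)
print(y.shape)  # torch.Size([2, 16, 64])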
.venv/lib/python3.11/site-packages/xformers/components/patch_embedding.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from enum import Enum

import torch

from xformers._deprecation_warning import deprecated_function


class PoolType(str, Enum):
    Conv2D = "CONV_2D"
    # ...
    # TODO: Support more cases ?


@dataclass
class PatchEmbeddingConfig:
    """
    The configuration for the patch embedding layer, which takes the raw token passed in
    and returns a pooled representation along a given embedding dimension.

    This typically trades the spatial (context length) representation with the embedding size

    This is canonically used by ViT, but other papers (like MetaFormer or other hierarchical transformers)
    propose a more general use case for this
    """

    in_channels: int
    out_channels: int
    kernel_size: int
    stride: int
    padding: int = 0
    pool_type: PoolType = PoolType.Conv2D


class ConditionalReshape(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        deprecated_function(self)

    def forward(self, x):
        if x.ndim == 3:
            B, HW, C = x.shape
            # NOTE: We're assuming a square sample here
            H = int(math.sqrt(HW))
            assert H * H == HW, f"{H, HW}"
            x = x.transpose(1, 2).reshape(B, C, H, H)

        return x


class PatchToSequence(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        deprecated_function(self)

    def forward(self, x):
        return x.flatten(2, 3).transpose(1, 2).contiguous()  # B HW C


def build_patch_embedding(config: PatchEmbeddingConfig):
    if not isinstance(config, PatchEmbeddingConfig):
        config = PatchEmbeddingConfig(**config)

    if config.pool_type == PoolType.Conv2D:
        pool = torch.nn.Conv2d(
            config.in_channels,
            config.out_channels,
            kernel_size=config.kernel_size,
            stride=config.stride,
            padding=config.padding,
        )
    else:
        raise NotImplementedError

    # The patch embedding supposes that the input really is 2D in essence
    # If this block is in the middle of a stack, we need to reshape
    return torch.nn.Sequential(ConditionalReshape(), pool, PatchToSequence())
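Not part of the upstream diff: a minimal usage sketch for build_patch_embedding; sizes are illustrative only.

# Illustrative only: ViT-style patchification, trading spatial resolution for channels
import torch
from xformers.components.patch_embedding import PatchEmbeddingConfig, build_patch_embedding

patch_embed = build_patch_embedding(
    PatchEmbeddingConfig(in_channels=3, out_channels=96, kernel_size=4, stride=4)
)
images = torch.randn(2, 3, 32, 32)  # (batch, channels, H, W)
tokens = patch_embed(images)        # (batch, (H/4) * (W/4), out_channels)
print(tokens.shape)  # torch.Size([2, 64, 96])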
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__init__.py
ADDED
@@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from pathlib import Path
from typing import Any, Callable, Dict, Set, Union

from xformers.utils import (
    generate_matching_config,
    get_registry_decorator,
    import_all_modules,
)

from .base import PositionEmbedding, PositionEmbeddingConfig  # noqa

# CREDITS: Classy Vision registry mechanism

POSITION_EMBEDDING_REGISTRY: Dict[str, Any] = {}
POSITION_EMBEDDING_CLASS_NAMES: Set[str] = set()


def build_positional_embedding(config: Union[Dict[str, Any], PositionEmbeddingConfig]):
    """Builds a position encoding from a config.

    This assumes a 'name' key in the config which is used to determine what
    attention class to instantiate. For instance, a config `{"name": "my_position_encoding",
    "foo": "bar"}` will find a class that was registered as "my_position_encoding"
    (see :func:`register_positional_embedding`) and call .from_config on it."""

    if not isinstance(config, PositionEmbeddingConfig):
        config_instance = generate_matching_config(
            config, POSITION_EMBEDDING_REGISTRY[config["name"]].config
        )
    else:
        config_instance = config

    return POSITION_EMBEDDING_REGISTRY[config_instance.name].constructor.from_config(
        config_instance
    )


"""Registers a PositionEncoding subclass.

This decorator allows xFormers to instantiate a subclass of PositionEncoding
from a configuration file, even if the class itself is not part of the
xFormers framework. To use it, apply this decorator to a `PositionEncoding`
subclass, like this:

.. code-block:: python

    @dataclass
    class MyConfig:
        ...

    @register_positional_embedding('my_encoding', MyConfig)
    class MyEncoding(PositionEncoding):
        ...

To instantiate a position encoding from a configuration file, see :func:`build_positional_embedding`."""
register_positional_embedding: Callable[
    [str, Any], Callable[[Any], Any]
] = get_registry_decorator(
    POSITION_EMBEDDING_REGISTRY,
    POSITION_EMBEDDING_CLASS_NAMES,
    PositionEmbedding,
    PositionEmbeddingConfig,
)


from .rotary import RotaryEmbedding  # noqa
from .sine import SinePositionalEmbedding  # type: ignore  # noqa
from .vocab import VocabEmbedding  # noqa

__all__ = [
    "RotaryEmbedding",
    "SinePositionalEmbedding",
    "VocabEmbedding",
    "build_positional_embedding",
    "register_positional_embedding",
]

# automatically import any Python files in the directory
import_all_modules(
    str(Path(__file__).parent), "xformers.components.positional_embedding"
)
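Not part of the upstream diff: a minimal sketch of the registry flow above, registering a hypothetical no-op embedding ("noop", NoOpConfig and NoOpEmbedding are made up for illustration) and then building it by name from a plain config dict.

# Illustrative only: register a custom embedding so build_positional_embedding can find it by name
from dataclasses import dataclass

import torch

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    build_positional_embedding,
    register_positional_embedding,
)


@dataclass
class NoOpConfig(PositionEmbeddingConfig):
    pass


@register_positional_embedding("noop", NoOpConfig)
class NoOpEmbedding(PositionEmbedding):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x  # passthrough, just to exercise the registry


pos_emb = build_positional_embedding({"name": "noop", "dim_model": 64, "seq_len": 128})
print(pos_emb(torch.zeros(2, 128, 64)).shape)  # torch.Size([2, 128, 64])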
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.53 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/base.cpython-311.pyc
ADDED
Binary file (2.38 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/param.cpython-311.pyc
ADDED
Binary file (2.87 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/rotary.cpython-311.pyc
ADDED
Binary file (4.85 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/sine.cpython-311.pyc
ADDED
Binary file (2.67 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/__pycache__/vocab.cpython-311.pyc
ADDED
Binary file (3.52 kB)
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/base.py
ADDED
@@ -0,0 +1,38 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from abc import ABCMeta, abstractmethod
|
| 8 |
+
from dataclasses import asdict, dataclass
|
| 9 |
+
from typing import Type, TypeVar
|
| 10 |
+
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
from xformers._deprecation_warning import deprecated_function
|
| 14 |
+
|
| 15 |
+
Self = TypeVar("Self", bound="PositionEmbedding")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class PositionEmbeddingConfig:
|
| 20 |
+
name: str
|
| 21 |
+
dim_model: int
|
| 22 |
+
seq_len: int
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PositionEmbedding(nn.Module, metaclass=ABCMeta):
|
| 26 |
+
@abstractmethod
|
| 27 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 28 |
+
super().__init__()
|
| 29 |
+
deprecated_function(self)
|
| 30 |
+
|
| 31 |
+
@classmethod
|
| 32 |
+
def from_config(cls: Type[Self], config: PositionEmbeddingConfig) -> Self:
|
| 33 |
+
# Generate the class inputs from the config
|
| 34 |
+
fields = asdict(config)
|
| 35 |
+
|
| 36 |
+
# Skip all Nones so that default values are used
|
| 37 |
+
fields = {k: v for k, v in fields.items() if v is not None}
|
| 38 |
+
return cls(**fields)
|
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/param.py
ADDED
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass

import torch

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    register_positional_embedding,
)


@dataclass
class LearnablePositionalEmbeddingConfig(PositionEmbeddingConfig):
    name: str
    seq_len: int
    dim_model: int
    add_class_token: bool


@register_positional_embedding("learnable", LearnablePositionalEmbeddingConfig)
class LearnablePositionalEmbedding(PositionEmbedding):
    def __init__(
        self, seq_len: int, dim_model: int, add_class_token: bool = False, *_, **__
    ):
        super().__init__()

        # 0.02 is BERT initialization
        self.pos_emb = torch.nn.Parameter(
            torch.randn(1, seq_len + int(add_class_token), dim_model) * 0.02
        )

        self.class_token = (
            torch.nn.Parameter(torch.zeros(dim_model)) if add_class_token else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.class_token is not None:
            # Prepend class token
            clf_token = (
                torch.ones(x.shape[0], 1, self.pos_emb.shape[-1], device=x.device)
                * self.class_token
            )
            x = torch.cat([clf_token, x], dim=1)

        if x.ndim == 2:
            x = x.unsqueeze(-1)

        return x + self.pos_emb
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/rotary.py
ADDED
@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
# NOTE: Almost the same right now, moving parts to Triton is the next step

from typing import Tuple

import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin):
    # NOTE: This could probably be moved to Triton

    # Handle a possible sequence length mismatch in between q and k
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]

    return (x * cos) + (rotate_half(x) * sin)


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox


    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
    """

    def __init__(self, dim_model: int, *_, **__):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seq_len != self._seq_len_cached
            or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype
        ):
            self._seq_len_cached = seq_len
            t = torch.arange(
                x.shape[seq_dimension], device=x.device, dtype=torch.float32
            )
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k, seq_dimension=-2
        )

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
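A short usage sketch of the rotary embedding above, not part of the diff; the (batch, heads, sequence, head_dim) shapes are illustrative assumptions:

# Hypothetical usage sketch (shapes are made up for illustration).
import torch

from xformers.components.positional_embedding import RotaryEmbedding

rotary = RotaryEmbedding(dim_model=64)  # per-head dimension
q = torch.randn(2, 8, 128, 64)  # (batch, heads, seq, head_dim)
k = torch.randn(2, 8, 128, 64)

# Instead of adding a positional vector, queries and keys are rotated;
# cos/sin tables are cached based on the key's sequence length.
q_rot, k_rot = rotary(q, k)
assert q_rot.shape == q.shape and k_rot.shape == k.shape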
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/sine.py
ADDED
@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# Silence Mypy errors in this file.
# type: ignore

import math

import torch

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    register_positional_embedding,
)


@register_positional_embedding("sine", PositionEmbeddingConfig)
class SinePositionalEmbedding(PositionEmbedding):
    def __init__(self, dim_model: int, *args, **kwargs):
        super().__init__()
        self.dim_model = dim_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        pos = (
            torch.arange(0, seq_len, device=x.device, dtype=torch.float32)
            .unsqueeze(1)
            .repeat(1, self.dim_model)
        )
        dim = (
            torch.arange(0, self.dim_model, device=x.device, dtype=torch.float32)
            .unsqueeze(0)
            .repeat(seq_len, 1)
        )
        div = torch.exp(-math.log(10000) * (2 * (dim // 2) / self.dim_model))
        pos *= div
        pos[:, 0::2] = torch.sin(pos[:, 0::2])
        pos[:, 1::2] = torch.cos(pos[:, 1::2])

        output = x.unsqueeze(-1) if x.ndim == 2 else x

        return output + pos.unsqueeze(0)
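A minimal sketch of applying the sine embedding above, not part of the diff; the tensor sizes are illustrative assumptions:

# Hypothetical usage sketch (not part of the diffed source).
import torch

from xformers.components.positional_embedding import SinePositionalEmbedding

pos_emb = SinePositionalEmbedding(dim_model=32)
x = torch.randn(4, 20, 32)  # (batch, seq, dim)
y = pos_emb(x)              # sinusoidal positions added along the sequence axis
assert y.shape == x.shape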
.venv/lib/python3.11/site-packages/xformers/components/positional_embedding/vocab.py
ADDED
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.positional_embedding import (
    PositionEmbedding,
    PositionEmbeddingConfig,
    register_positional_embedding,
)


@dataclass
class VocabEmbeddingConfig(PositionEmbeddingConfig):
    vocab_size: int
    dropout: float


@register_positional_embedding("vocab", VocabEmbeddingConfig)
class VocabEmbedding(PositionEmbedding):
    def __init__(
        self,
        dim_model: int,
        seq_len: int,
        vocab_size: int,
        dropout: float = 0.0,
        *args,
        **kwargs
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.position_embeddings = nn.Embedding(seq_len, self.dim_model)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

        self.position_ids: Optional[torch.Tensor] = None

        self.init_weights()

    def init_weights(self, gain: float = 1.0):
        torch.nn.init.normal_(self.position_embeddings.weight, std=0.02 * gain)
        torch.nn.init.normal_(self.word_embeddings.weight, std=0.02 * gain)

    def forward(self, x: torch.Tensor):
        position_ids = torch.arange(x.shape[1], dtype=torch.long, device=x.device)[
            None, :
        ].repeat(x.shape[0], 1)

        X_token = self.word_embeddings(x)
        X_pos = self.position_embeddings(position_ids)

        X = X_token + X_pos
        X = self.dropout(X)

        return X
.venv/lib/python3.11/site-packages/xformers/components/residual.py
ADDED
@@ -0,0 +1,192 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from collections import namedtuple
from enum import Enum
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from xformers._deprecation_warning import deprecated_function


class ResidualNormStyle(str, Enum):
    """Support different residual path and norm styles.
    See "On Layer Normalization in the Transformer Architecture",
    Xiong et al., https://arxiv.org/pdf/2002.04745v1.pdf
    """

    Pre = "pre"
    Post = "post"
    DeepNorm = "deepnorm"


class NormalizationType(str, Enum):
    LayerNorm = "layernorm"
    Skip = "skip"
    # TODO: BatchNorm = "batchnorm"
    # TODO: GroupNorm = "groupnorm"


def get_normalization_layer(normalization_type: NormalizationType):
    class Skip(nn.Module):
        def __init__(self, *_, **__) -> None:
            super().__init__()
            deprecated_function(self)

        def forward(self, x: torch.Tensor, **_):
            return x

    return {
        NormalizationType.LayerNorm: nn.LayerNorm,
        NormalizationType.Skip: Skip,
    }[normalization_type]


class RequiresWrappedInputs:
    """Used to mark, through inheritance,
    the fact that this class will require inputs to be passed as a single list"""

    pass


# CREDITS: the following is inspired by FastAI's Transformer implementation
class Residual(nn.Module, RequiresWrappedInputs):
    """
    Object-oriented handling of the residual path

    This supports scaling of the residual path, as proposed by DeepNet_
    .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf

    .. Note: the wrapped layers must accept all the inputs as a single list
    """

    def __init__(self, layer: nn.Module, scale: Optional[float] = None):
        super().__init__()
        deprecated_function(self)
        self.layer = layer
        self.scale = scale

        # PreNorm and PostNorm require all the tensors to be passed as a list
        self.wrap_inputs = isinstance(layer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        if self.scale is not None:
            residue = inputs[0] * self.scale
        else:
            residue = inputs[0]

        if self.wrap_inputs:
            return residue + self.layer(inputs=inputs, **kwargs)

        else:
            return residue + self.layer(*inputs, **kwargs)


class PreNorm(nn.Module, RequiresWrappedInputs):
    """Adds a normalization before computing attention

    ..Note: If a list of inputs is passed, all of them get normalized"""

    def __init__(
        self,
        d_norm: int,
        sublayer: nn.Module,
        normalization: NormalizationType,
        use_triton: bool = True,
    ):

        super().__init__()
        deprecated_function(self)
        self.norm = get_normalization_layer(normalization)(d_norm)

        self.sublayer = sublayer
        self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        assert len(inputs) > 0

        # Perf improvement: if the inputs are all the same, only norm once
        ids = [id(x) for x in inputs]
        if ids.count(ids[0]) == len(ids):
            # The same tensor is passed multiple times
            x_norm = self.norm(inputs[0])
            inputs_normed = [x_norm for _ in inputs]
        else:
            # The inputs differ, norm them all
            inputs_normed = [self.norm(x_) for x_ in inputs]

        if self.wrap_inputs:
            return self.sublayer(inputs=inputs_normed, **kwargs)
        else:
            return self.sublayer(*inputs_normed, **kwargs)


class PostNorm(nn.Module, RequiresWrappedInputs):
    """Adds LayerNorm after computing attention"""

    def __init__(
        self,
        d_norm: int,
        sublayer: nn.Module,
        normalization: NormalizationType,
        use_triton: bool = True,
    ):
        super().__init__()
        deprecated_function(self)
        self.norm = get_normalization_layer(normalization)(d_norm)

        self.sublayer = sublayer
        self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        if self.wrap_inputs:
            x = self.sublayer(inputs=inputs, **kwargs)
        else:
            x = self.sublayer(*inputs, **kwargs)
        return self.norm(x)


DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])


def get_deepnorm_coefficients(
    encoder_layers: int, decoder_layers: int
) -> Tuple[Optional[DeepNormCoefficients], Optional[DeepNormCoefficients]]:
    """
    See DeepNet_.

    Returns alpha and beta depending on the number of encoder and decoder layers,
    first tuple is for the encoder and second for the decoder

    .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
    """

    N = encoder_layers
    M = decoder_layers

    if decoder_layers == 0:
        # Encoder only
        return (
            DeepNormCoefficients(alpha=(2 * N) ** 0.25, beta=(8 * N) ** -0.25),
            None,
        )

    elif encoder_layers == 0:
        # Decoder only
        return None, DeepNormCoefficients(alpha=(2 * M) ** 0.25, beta=(8 * M) ** -0.25)
    else:
        # Encoder/decoder
        encoder_coeffs = DeepNormCoefficients(
            alpha=0.81 * ((N**4) * M) ** 0.0625, beta=0.87 * ((N**4) * M) ** -0.0625
        )

        decoder_coeffs = DeepNormCoefficients(
            alpha=(3 * M) ** 0.25, beta=(12 * M) ** -0.25
        )

        return (encoder_coeffs, decoder_coeffs)
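As a small worked example of the DeepNorm coefficients above, the values follow directly from the encoder-only branch of the formula (illustrative only, not part of the diff):

# Illustrative check of get_deepnorm_coefficients.
from xformers.components.residual import get_deepnorm_coefficients

enc, dec = get_deepnorm_coefficients(encoder_layers=12, decoder_layers=0)
# Encoder-only: alpha = (2 * 12) ** 0.25 ≈ 2.213, beta = (8 * 12) ** -0.25 ≈ 0.319
print(enc.alpha, enc.beta)  # dec is None in the encoder-only case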
.venv/lib/python3.11/site-packages/xformers/components/reversible.py
ADDED
@@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

from xformers._deprecation_warning import deprecated_function
from xformers.components import RequiresWrappedInputs

# CREDITS: Code adapted from
# https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
# https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py,
# https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html


# pyre-fixme[13]: `cpu_state` is not initialized in the constructor.
class Deterministic(nn.Module):
    def __init__(self, net: nn.Module):
        super().__init__()
        deprecated_function(self)
        self.net = net
        self.cpu_state: torch.Tensor = torch.get_rng_state()
        self.cuda_in_fwd: bool = False
        self.gpu_devices: List[int] = []
        self.gpu_states: List[torch.Tensor] = []
        self.wrap_inputs = isinstance(net, RequiresWrappedInputs)

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            # Normal FW run
            if self.wrap_inputs:
                return self.net(inputs=args, **kwargs)
            else:
                return self.net(*args, **kwargs)

        else:  # pragma: no cover  # this is called in the backward pass, not picked up
            # This is analogous to checkpointing, reset the original random state
            rng_devices: List[int] = []
            if self.cuda_in_fwd:
                rng_devices = self.gpu_devices

            with torch.random.fork_rng(devices=rng_devices, enabled=True):
                torch.set_rng_state(self.cpu_state)
                if self.cuda_in_fwd:
                    set_device_states(self.gpu_devices, self.gpu_states)

                if self.wrap_inputs:
                    return self.net(inputs=args, **kwargs)
                else:
                    return self.net(*args, **kwargs)


class ReversibleBlock(nn.Module):
    def __init__(self, f: nn.Module, g: nn.Module, split_dim: int = -1):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)
        self.split_dim = split_dim

    def forward(self, x: torch.Tensor, f_args={}, g_args={}):
        x1, x2 = torch.chunk(x, 2, dim=-1)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=self.split_dim)

    def backward_pass(
        self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={}
    ):  # pragma: no cover  # this is covered, but called directly from C++
        y1, y2 = torch.chunk(y, 2, dim=self.split_dim)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=self.split_dim)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=self.split_dim)
            dx = torch.cat([dx1, dx2], dim=self.split_dim)

        return x, dx


class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        ctx.kwargs = kwargs
        for block in blocks:
            x = block(x, **kwargs)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(
        ctx, dy
    ):  # pragma: no cover # this is covered, but called directly from C++
        y = ctx.y
        kwargs = ctx.kwargs
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None


class ReversibleSequence(nn.Module):
    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        deprecated_function(self)

        # pyre-fixme[23]: Unable to unpack `torch.nn.Module` into 2 values.
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for f, g in blocks])

    def forward(self, x, arg_route=(True, False), **kwargs):
        f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)
        block_kwargs = {"f_args": f_args, "g_args": g_args}

        return _ReversibleFunction.apply(x, self.blocks, block_kwargs)
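A minimal sketch of driving the reversible sequence above, not part of the diff; the f/g sub-layers here are plain Linear modules chosen purely for illustration. Intermediate activations are not stored during the forward pass and are reconstructed block by block in backward_pass:

# Hypothetical usage sketch (f/g modules and sizes are illustrative assumptions).
import torch
import torch.nn as nn

from xformers.components.reversible import ReversibleSequence

dim = 32  # each reversible block operates on half of the last dimension
blocks = nn.ModuleList(
    [nn.ModuleList([nn.Linear(dim, dim), nn.Linear(dim, dim)]) for _ in range(4)]
)
seq = ReversibleSequence(blocks)

x = torch.randn(8, 16, 2 * dim, requires_grad=True)  # last dim is split into x1, x2
y = seq(x)
y.sum().backward()  # gradients flow through _ReversibleFunction.backward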
.venv/lib/python3.11/site-packages/xformers/components/simplicial_embedding.py
ADDED
@@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar

import torch

from xformers._deprecation_warning import deprecated_function

Self = TypeVar("Self", bound="SimplicialEmbedding")


@dataclass
class SimplicialEmbeddingConfig:
    L: int
    temperature: float


class SimplicialEmbedding(torch.nn.Module):
    """
    An implementation of the "Simplicial Embeddings"_, as proposed by Lavoie et. al

    Arguments:
        - L: the number of embedding chunks
        - temperature: optional scaling parameter for the softmax operation.
            A small (<1.) temperature will lead to a sparse representation (up to one-hot),
            while a large (>1.) temperature will make the vector more uniform

    _"Simplicial Embeddings": https://arxiv.org/pdf/2204.00616.pdf
    """

    def __init__(self, L: int, temperature: Optional[float] = None) -> None:
        super().__init__()
        deprecated_function(self)
        self.L = L
        self.temperature = temperature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert (
            x.shape[-1] % self.L == 0
        ), f"The embedding dimension {x.shape[-1]} is not divisible by the chosen L parameter {self.L}"

        # Separate the input tensor into V chunks
        B, C, E = x.shape
        V = E // self.L

        Vs = x.reshape(B, C, self.L, V)

        # Softmax normalize them, with the proposed temperature
        # This is done over the last dimension, so only within Vs
        if self.temperature is not None:
            Vs /= self.temperature

        Vs = torch.nn.functional.softmax(Vs, dim=-1)

        # Concatenate back and return
        return Vs.reshape(B, C, E)

    @classmethod
    def from_config(cls: Type[Self], config: SimplicialEmbeddingConfig) -> Self:
        # Generate the class inputs from the config
        fields = asdict(config)

        return cls(**fields)
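A quick illustrative check of the module above, not part of the diff: after the softmax, each of the L chunks of the last dimension sums to one (the tensor sizes are made up):

# Illustrative sketch; L must divide the embedding dimension.
import torch

from xformers.components.simplicial_embedding import SimplicialEmbedding

emb = SimplicialEmbedding(L=4, temperature=0.5)
x = torch.randn(2, 10, 64)  # (batch, context, embedding)
y = emb(x)

# Each of the L=4 chunks of size 16 is softmax-normalized, so it sums to 1.
chunks = y.reshape(2, 10, 4, 16)
assert torch.allclose(chunks.sum(dim=-1), torch.ones(2, 10, 4))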
.venv/lib/python3.11/site-packages/xformers/ops/__init__.py
ADDED
@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .fmha import (
    AttentionBias,
    AttentionOp,
    AttentionOpBase,
    LowerTriangularMask,
    MemoryEfficientAttentionCkOp,
    MemoryEfficientAttentionCutlassFwdFlashBwOp,
    MemoryEfficientAttentionCutlassOp,
    MemoryEfficientAttentionFlashAttentionOp,
    MemoryEfficientAttentionSplitKCkOp,
    memory_efficient_attention,
    memory_efficient_attention_backward,
    memory_efficient_attention_forward,
    memory_efficient_attention_forward_requires_grad,
)
from .indexing import index_select_cat, scaled_index_add
from .ipc import init_ipc
from .modpar_layers import ColumnParallelLinear, RowParallelLinear
from .rmsnorm import RMSNorm
from .rope_padded import rope_padded
from .seqpar import sequence_parallel_leading_matmul, sequence_parallel_trailing_matmul
from .sequence_parallel_fused_ops import (
    fused_allgather_and_anything,
    fused_allgather_and_linear,
    fused_anything_and_reducescatter,
    fused_linear_and_reducescatter,
)
from .sp24 import Sparse24Tensor, sparsify24, sparsify24_like
from .swiglu_op import (
    SwiGLU,
    SwiGLUEagerOp,
    SwiGLUFusedOp,
    SwiGLUOp,
    SwiGLUOpDispatch,
    SwiGLUPackedFusedOp,
    swiglu,
)
from .tiled_matmul import tiled_matmul
from .unbind import get_stack_strides, stack_or_none, unbind

# BW compatibility
AttentionMask = AttentionBias


def masked_matmul(a, b, mask=None):
    if torch.overrides.has_torch_function((a, b, mask)):
        return torch.overrides.handle_torch_function(
            masked_matmul, (a, b, mask), a, b, mask
        )

    att = a @ b

    if mask is None:
        return att

    if mask.dtype == torch.bool:
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
        # mask is presumed false == ignore
        att[~mask] = float("-inf")
    else:
        # mask is presumed additive
        att += mask
    return att


__all__ = [
    # fmha
    "AttentionBias",
    "AttentionMask",
    "AttentionOp",
    "AttentionOpBase",
    "LowerTriangularMask",
    "MemoryEfficientAttentionCutlassFwdFlashBwOp",
    "MemoryEfficientAttentionCutlassOp",
    "MemoryEfficientAttentionFlashAttentionOp",
    "MemoryEfficientAttentionCkOp",
    "MemoryEfficientAttentionSplitKCkOp",
    "memory_efficient_attention",
    "memory_efficient_attention_backward",
    "memory_efficient_attention_forward",
    "memory_efficient_attention_forward_requires_grad",
    # indexing
    "index_select_cat",
    "scaled_index_add",
    # ipc
    "init_ipc",
    # modpar_layers
    "ColumnParallelLinear",
    "RowParallelLinear",
    # rmsnorm
    "RMSNorm",
    # rope_padded
    "rope_padded",
    # seqpar
    "sequence_parallel_leading_matmul",
    "sequence_parallel_trailing_matmul",
    # sequence_parallel_fused_ops
    "fused_allgather_and_anything",
    "fused_allgather_and_linear",
    "fused_anything_and_reducescatter",
    "fused_linear_and_reducescatter",
    # swiglu_op
    "SwiGLU",
    "SwiGLUEagerOp",
    "SwiGLUFusedOp",
    "SwiGLUOp",
    "SwiGLUOpDispatch",
    "SwiGLUPackedFusedOp",
    "swiglu",
    # tiled_matmul
    "tiled_matmul",
    # unbind
    "get_stack_strides",
    "stack_or_none",
    "unbind",
    # sp24
    "sparsify24",
    "sparsify24_like",
    "Sparse24Tensor",
    # .
    "masked_matmul",
]
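A small illustrative example of the masked_matmul helper defined above, not part of the diff: boolean masks mark positions to keep (False entries become -inf), while floating-point masks are added to the scores.

# Illustrative sketch of masked_matmul.
import torch

from xformers.ops import masked_matmul

a = torch.randn(2, 4, 8)   # (batch, queries, dim)
b = torch.randn(2, 8, 4)   # (batch, dim, keys)

# Boolean mask: 2D masks are broadcast over the batch, e.g. a causal pattern.
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))
att = masked_matmul(a, b, causal)

# Additive mask: simply added to the attention scores.
bias = torch.zeros(2, 4, 4)
att_biased = masked_matmul(a, b, bias)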
.venv/lib/python3.11/site-packages/xformers/ops/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (3.34 kB)
.venv/lib/python3.11/site-packages/xformers/ops/_triton/k_index_select_cat.py
ADDED
@@ -0,0 +1,184 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch
import triton
import triton.language as tl


@triton.jit
def index_select_cat_fwd_kernel(
    output_ptr,  # *Pointer* to output tensor.
    source_ptr,  # *Pointer* to source tensor.
    index_ptr,  # *Pointer* to index tensor.
    num_indices,
    num_cols,
    stride0,  # Stride information of source tensor.
    stride1,
    BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
    BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
):
    pid0 = tl.program_id(axis=0)  # We use 2D launch grid
    pid1 = tl.program_id(axis=1)

    indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
    rows = tl.load(index_ptr + indices, mask=(indices < num_indices))
    cols = pid1 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    source_offsets = source_ptr + rows[:, None] * stride0 + cols[None, :] * stride1
    mask = (indices[:, None] < num_indices) & (cols[None, :] < num_cols)
    output = tl.load(source_offsets, mask=mask)

    output_offsets = output_ptr + indices[:, None] * stride0 + cols[None, :] * stride1
    tl.store(output_offsets, output, mask=mask)


def index_select_cat_fwd(
    output: torch.Tensor,
    source: torch.Tensor,
    index: torch.Tensor,
):
    if not (source.is_cuda and index.is_cuda):
        raise ValueError("The index tensor and the source tensor must be of type CUDA!")

    if not source.ndim == 2:
        raise ValueError(f"Expected 2-dimensional tensor, got {source.ndim}.")
    if not index.ndim == 1:
        raise ValueError(f"Expected 1-dimensional tensor, got {index.ndim}.")

    num_rows, num_cols = source.shape
    num_indices = index.shape[0]

    if not num_indices < num_rows:
        raise ValueError(
            "The number of indices cannot exceed the number of rows in the source matrix."
        )

    stride0, stride1 = source.stride(0), source.stride(1)

    def grid(meta):
        return (
            triton.cdiv(num_indices, meta["BLOCK_SIZE_INDEX"]),
            triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
        )

    index_select_cat_fwd_kernel[grid](
        output,
        source,
        index,
        num_indices,
        num_cols,
        stride0,
        stride1,
        BLOCK_SIZE_INDEX=1,
        BLOCK_SIZE_COL=512,
    )

    return output


@triton.jit
def index_select_cat_bwd_kernel(
    grad_source_ptr,  # *Pointer* to grad_source tensor.
    index_ptr,  # *Pointer* to index tensor.
    grad_output_ptr,  # *Pointer* to grad_output tensor.
    num_rows,
    num_indices,
    num_cols,
    stride0,  # Stride information of input and source tensor.
    stride1,
    BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
    BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
):
    pid0 = tl.program_id(axis=0)  # We use 3D launch grid
    pid1 = tl.program_id(axis=1)

    cols = pid1 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    # load grad_output
    grad_output_indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
    grad_output_offsets = (
        grad_output_ptr
        + grad_output_indices[:, None] * stride0
        + cols[None, :] * stride1
    )
    grad_output_mask = (grad_output_indices[:, None] < num_indices) & (
        cols[None, :] < num_cols
    )
    grad_output = tl.load(grad_output_offsets, mask=grad_output_mask).to(tl.float32)

    # select indices from grad_source
    grad_source_indices = tl.load(
        index_ptr + grad_output_indices, mask=(grad_output_indices < num_indices)
    )
    grad_source_offsets = (
        grad_source_ptr
        + grad_source_indices[:, None] * stride0
        + cols[None, :] * stride1
    )

    # compute scaled index add and save
    tl.store(grad_source_offsets, grad_output, mask=grad_output_mask)


def index_select_cat_bwd(
    grad_source: torch.Tensor,
    index: torch.Tensor,
    grad_output: torch.Tensor,
):
    if not (grad_source.is_cuda and grad_output.is_cuda):
        raise ValueError("The grad_source and grad_output tensor must be of type CUDA!")

    if not (grad_source.ndim == 2 and grad_output.ndim == 2):
        raise ValueError(
            f"The grad_source and grad_output must be three-dimensional "
            f"(got {grad_source.ndim} and {grad_output.ndim})!"
        )
    if not grad_source.shape[1] == grad_output.shape[1]:
        raise ValueError(
            f"The number of elements along dimension 1 of grad_source and grad_output must be the same "
            f"(got {grad_source.shape[1]} and {grad_output.shape[1]})"
        )

    num_rows, num_cols = grad_source.shape
    num_indices, num_cols = grad_output.shape
    if not num_rows >= num_indices:
        raise ValueError(
            f"The number of elements along dimension 0 of grad_source must be larger than that of grad_output "
            f"(got {num_rows} and {num_indices})!"
        )
    if not index.shape[0] == num_indices:
        raise ValueError(
            f"The number of indices and the number of elements along dimension 0 of grad_output must match "
            f"(got {index.shape[0]} and {num_indices})!"
        )

    stride0, stride1 = grad_source.stride(0), grad_source.stride(1)
    if not (grad_output.stride(0) == stride0 and grad_output.stride(1) == stride1):
        raise ValueError(
            f"The strides of the grad_source and grad_output tensors must match "
            f"(got {stride0} vs. {grad_output.stride(0)}, {stride1} vs. {grad_output.stride(1)})!"
        )

    def grid(meta):
        return (
            triton.cdiv(num_indices, meta["BLOCK_SIZE_INDEX"]),
            triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
        )

    index_select_cat_bwd_kernel[grid](
        grad_source,
        index,
        grad_output,
        num_rows,
        num_indices,
        num_cols,
        grad_source.stride(0),
        grad_source.stride(1),
        BLOCK_SIZE_INDEX=1,
        BLOCK_SIZE_COL=512,
    )

    return
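For orientation, a hypothetical way the forward kernel above could be driven, not part of the diff; it assumes a CUDA device with Triton available, and the output buffer layout is an assumption based on the strides the kernel uses (selected rows are written into the first len(index) rows of the output):

# Hypothetical usage sketch (assumes CUDA and Triton are available).
import torch

from xformers.ops._triton.k_index_select_cat import index_select_cat_fwd

source = torch.randn(1024, 512, device="cuda")
index = torch.tensor([3, 17, 42], device="cuda", dtype=torch.long)

# Allocate a buffer with the same strides as `source`; the kernel writes the
# gathered rows at positions 0..len(index)-1 using source's strides.
output = torch.empty_like(source)
index_select_cat_fwd(output, source, index)

assert torch.equal(output[: len(index)], source[index])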
.venv/lib/python3.11/site-packages/xformers/ops/_triton/k_scaled_index_add.py
ADDED
@@ -0,0 +1,365 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import triton
import triton.language as tl


@triton.jit
def scaled_index_add_fwd_kernel(
    input_ptr,  # *Pointer* to input tensor.
    index_ptr,  # *Pointer* to index tensor.
    source_ptr,  # *Pointer* to source tensor.
    scaling_ptr,  # *Pointer* to the scaling tensor.
    alpha,
    num_inp_indices,
    num_src_indices,
    num_rows,
    num_cols,
    stride0,  # Stride information of input and source tensor.
    stride1,
    stride2,
    BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
    BLOCK_SIZE_ROW: tl.constexpr,  # Number of rows each program should process.
    BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
    HAS_SCALING: tl.constexpr,  # Boolean indicating if the scaling factor is present.
):
    pid0 = tl.program_id(axis=0)  # We use 3D launch grid
    pid1 = tl.program_id(axis=1)
    pid2 = tl.program_id(axis=2)

    rows = pid1 * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    cols = pid2 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    # load source
    source_indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
    source_offsets = (
        source_ptr
        + source_indices[:, None, None] * stride0
        + rows[None, :, None] * stride1
        + cols[None, None, :] * stride2
    )
    source_mask = (
        (source_indices[:, None, None] < num_src_indices)
        & (rows[None, :, None] < num_rows)
        & (cols[None, None, :] < num_cols)
    )
    source = tl.load(source_offsets, mask=source_mask).to(tl.float32)

    # load input
    input_indices = tl.load(
        index_ptr + source_indices, mask=(source_indices < num_src_indices)
    )
    input_offsets = (
        input_ptr
        + input_indices[:, None, None] * stride0
        + rows[None, :, None] * stride1
        + cols[None, None, :] * stride2
    )
    x = tl.load(input_offsets, mask=source_mask).to(tl.float32)

    # compute scaled index add and save
    if HAS_SCALING:
        scaling = tl.load(
            scaling_ptr + cols[None, None, :] * stride2,
            mask=(cols[None, None, :] < num_cols),
        ).to(tl.float32)
        tl.store(input_offsets, x + alpha * scaling * source, mask=source_mask)
    else:
        tl.store(input_offsets, x + alpha * source, mask=source_mask)


def scaled_index_add_fwd(
    x: torch.Tensor,
    index: torch.Tensor,
    source: torch.Tensor,
    scaling: Optional[torch.Tensor],
    alpha: float,
):
    if not (x.is_cuda and index.is_cuda and source.is_cuda):
        raise ValueError(
            "The input tensor, the index tensor and the source tensor must be of type CUDA!"
        )

    if not (x.ndim == 3 and source.ndim == 3):
        raise ValueError(
            f"The input and source must be three-dimensional (got {x.ndim} and {source.ndim})!"
        )
    if not x.shape[1] == source.shape[1]:
        raise ValueError(
            f"The number of elements along dimension 1 of the input and source must be the same "
            f"(got {x.shape[1], } and {source.shape[1], })!"
        )
    if not x.shape[2] == source.shape[2]:
        raise ValueError(
            f"The number of elements along dimension 2 of the input and source must be the same "
            f"(got {x.shape[2], } and {source.shape[2], })!"
        )

    num_inp_indices, num_rows, num_cols = x.shape
    num_src_indices, num_rows, num_cols = source.shape
    if not num_inp_indices >= num_src_indices:
        raise ValueError(
            f"The number of elements along dimension 0 of the input must be larger than that of source "
            f"(got {num_inp_indices} and {num_src_indices})!"
        )
    if not index.shape[0] == num_src_indices:
        raise ValueError(
            f"The number of indices and source tensors must match (got {len(index)} and {len(source)})!"
        )

    stride0, stride1, stride2 = x.stride(0), x.stride(1), x.stride(2)
    if not (
        source.stride(0) == stride0
        and source.stride(1) == stride1
        and source.stride(2) == stride2
    ):
        raise ValueError(
            f"The strides of the source and input tensors must match (got {source.stride(0)} vs. {stride0}, "
            f"{source.stride(1)} vs. {stride1}, {source.stride(2)} vs. {stride2})!"
        )

    if scaling is None:
        HAS_SCALING = False
    else:
        HAS_SCALING = True
        if not scaling.is_cuda:
            raise ValueError("The scaling tensor must be of type CUDA!")
        if not (scaling.ndim == 1 and scaling.numel() == num_cols):
            raise ValueError(
                f"The scaling tensor must be a 1-dimensional tensor (got {scaling.ndim}) and its size "
                f"must be equal to the size of dimension 2 of source (got {scaling.numel()} vs. {num_cols})."
            )
        if not scaling.stride(0) == stride2:
            raise ValueError(
                f"The stride of scaling must match the stride2 of input (got {scaling.stride(0)} vs. {stride2})"
            )

    if not index.ndim == 1:
        raise ValueError(f"The index must be one-dimensional (got {index.ndim})!")

    def grid(meta):
        return (
            triton.cdiv(num_src_indices, meta["BLOCK_SIZE_INDEX"]),
            triton.cdiv(num_rows, meta["BLOCK_SIZE_ROW"]),
            triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
        )

    scaled_index_add_fwd_kernel[grid](
        x,
        index,
        source,
        scaling,
        alpha,
        num_inp_indices,
        num_src_indices,
        num_rows,
        num_cols,
        x.stride(0),
        x.stride(1),
        x.stride(2),
        BLOCK_SIZE_INDEX=1,
        BLOCK_SIZE_ROW=1,
        BLOCK_SIZE_COL=512,
        HAS_SCALING=HAS_SCALING,
    )

    return


@triton.jit
def scaled_index_add_bwd_kernel(
    grad_output_ptr,  # *Pointer* to input tensor.
    grad_source_ptr,  # *Pointer* to index tensor.
    grad_scaling_ptr,  # *Pointer* to source tensor.
    source_ptr,  # *Pointer* to the source tensor.
    scaling_ptr,  # *Pointer* to the scaling tensor.
    index_ptr,
    alpha,
    num_inp_indices,
    num_src_indices,
    num_rows,
    num_cols,
    stride0,  # Stride information of input and source tensor.
    stride1,
    stride2,
    BLOCK_SIZE_INDEX: tl.constexpr,  # Number of indices each program should process.
    BLOCK_SIZE_ROW: tl.constexpr,  # Number of rows each program should process.
    BLOCK_SIZE_COL: tl.constexpr,  # Number of cols each program should process.
    HAS_SCALING: tl.constexpr,  # Boolean indicating if the scaling factor is present.
):
    pid0 = tl.program_id(axis=0)  # We use 3D launch grid
    pid1 = tl.program_id(axis=1)
    pid2 = tl.program_id(axis=2)

    rows = pid1 * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    cols = pid2 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)

    # load source
    source_indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)
    source_offsets = (
        source_ptr
        + source_indices[:, None, None] * stride0
        + rows[None, :, None] * stride1
        + cols[None, None, :] * stride2
    )
    source_mask = (
        (source_indices[:, None, None] < num_src_indices)
        & (rows[None, :, None] < num_rows)
        & (cols[None, None, :] < num_cols)
    )
    source = tl.load(source_offsets, mask=source_mask).to(tl.float32)

    # load grad_output
    grad_output_indices = tl.load(
        index_ptr + source_indices, mask=(source_indices < num_src_indices)
    )
    grad_output_offsets = (
        grad_output_ptr
        + grad_output_indices * stride0
        + rows[None, :, None] * stride1
        + cols[None, None, :] * stride2
    )
    grad_output = tl.load(grad_output_offsets, mask=source_mask).to(tl.float32)

    # compute gradient
    grad_source_offsets = (
        grad_source_ptr
        + source_indices[:, None, None] * stride0
        + rows[None, :, None] * stride1
        + cols[None, None, :] * stride2
    )
    if HAS_SCALING:
        scaling = tl.load(
            scaling_ptr + cols[None, None, :] * stride2,
            mask=(cols[None, None, :] < num_cols),
        ).to(tl.float32)

        tl.store(grad_source_offsets, alpha * grad_output * scaling, mask=source_mask)

        grad_scaling_offsets = (
            grad_scaling_ptr
            + source_indices[:, None, None] * stride0
            + rows[None, :, None] * stride1
            + cols[None, None, :] * stride2
        )
        tl.store(grad_scaling_offsets, alpha * grad_output * source, mask=source_mask)
    else:
        tl.store(grad_source_offsets, alpha * grad_output, mask=source_mask)


def scaled_index_add_bwd(
    grad_output: torch.Tensor,
    grad_source: torch.Tensor,
    grad_scaling: Optional[torch.Tensor],
    source: torch.Tensor,
    scaling: Optional[torch.Tensor],
    index: torch.Tensor,
    alpha: float,
):
    if not (grad_output.is_cuda and grad_source.is_cuda):
        raise ValueError(
            "The grad_output tensor and grad_source tensor must be of type CUDA!"
        )

    if not (grad_output.ndim == 3 and source.ndim == 3):
        raise ValueError(
            f"The input and source must be three-dimensional (got {grad_output.ndim} and {source.ndim})!"
        )

    if not grad_output.shape[1] == source.shape[1]:
        raise ValueError(
            f"The number of elements along dimension 1 of the input and source must be the same "
            f"(got {grad_output.shape[1], } and {source.shape[1], })!"
        )
    if not grad_output.shape[2] == source.shape[2]:
        raise ValueError(
            f"The number of elements along dimension 2 of the input and source must be the same "
            f"(got {grad_output.shape[2], } and {source.shape[2], })!"
        )

    num_inp_indices, num_rows, num_cols = grad_output.shape
    num_src_indices, num_rows, num_cols = source.shape
    if not num_inp_indices >= num_src_indices:
        raise ValueError(
            f"The number of elements along dimension 0 of the input must be larger than that of source "
            f"(got {num_inp_indices} and {num_src_indices})!"
        )

    stride0, stride1, stride2 = source.stride(0), source.stride(1), source.stride(2)
    if not (
        grad_output.stride(0) == stride0
        and grad_output.stride(1) == stride1
        and grad_output.stride(2) == stride2
    ):
        raise ValueError(
            f"The strides of grad_output and source must match "
|
| 302 |
+
f"(got {grad_output.stride(0)} vs {stride0}, {grad_output.stride(1)} vs {stride1}, "
|
| 303 |
+
f"{grad_output.stride(2)} vs {stride2})!"
|
| 304 |
+
)
|
| 305 |
+
if not (
|
| 306 |
+
grad_source.stride(0) == stride0
|
| 307 |
+
and grad_source.stride(1) == stride1
|
| 308 |
+
and grad_source.stride(2) == stride2
|
| 309 |
+
):
|
| 310 |
+
raise ValueError(
|
| 311 |
+
f"The strides of grad_source and source must match "
|
| 312 |
+
f"(got {grad_source.stride(0)} vs {stride0}, {grad_source.stride(1)} vs {stride1}, "
|
| 313 |
+
f"{grad_source.stride(2)} vs {stride2})!"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
if scaling is not None and grad_scaling is not None:
|
| 317 |
+
HAS_SCALING = True
|
| 318 |
+
if not grad_scaling.is_cuda:
|
| 319 |
+
raise ValueError("The scaling tensor must be of type CUDA!")
|
| 320 |
+
if not (
|
| 321 |
+
grad_scaling.stride(0) == stride0
|
| 322 |
+
and grad_scaling.stride(1) == stride1
|
| 323 |
+
and grad_scaling.stride(2) == stride2
|
| 324 |
+
):
|
| 325 |
+
raise ValueError(
|
| 326 |
+
f"The strides of grad_scaling and source must match "
|
| 327 |
+
f"(got {grad_scaling.stride(0)} vs {stride0}, {grad_scaling.stride(1)} vs {stride1}, "
|
| 328 |
+
f"{grad_scaling.stride(2)} vs {stride2})!"
|
| 329 |
+
)
|
| 330 |
+
if not scaling.stride(0) == stride2:
|
| 331 |
+
raise ValueError(
|
| 332 |
+
f"The stride of scaling must match stride2 of source (got {scaling.stride(0)} vs. {stride2})!"
|
| 333 |
+
)
|
| 334 |
+
else:
|
| 335 |
+
HAS_SCALING = False
|
| 336 |
+
|
| 337 |
+
def grid(meta):
|
| 338 |
+
return (
|
| 339 |
+
triton.cdiv(num_src_indices, meta["BLOCK_SIZE_INDEX"]),
|
| 340 |
+
triton.cdiv(num_rows, meta["BLOCK_SIZE_ROW"]),
|
| 341 |
+
triton.cdiv(num_cols, meta["BLOCK_SIZE_COL"]),
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
scaled_index_add_bwd_kernel[grid](
|
| 345 |
+
grad_output,
|
| 346 |
+
grad_source,
|
| 347 |
+
grad_scaling,
|
| 348 |
+
source,
|
| 349 |
+
scaling,
|
| 350 |
+
index,
|
| 351 |
+
alpha,
|
| 352 |
+
num_inp_indices,
|
| 353 |
+
num_src_indices,
|
| 354 |
+
num_rows,
|
| 355 |
+
num_cols,
|
| 356 |
+
stride0,
|
| 357 |
+
stride1,
|
| 358 |
+
stride2,
|
| 359 |
+
BLOCK_SIZE_INDEX=1,
|
| 360 |
+
BLOCK_SIZE_ROW=1,
|
| 361 |
+
BLOCK_SIZE_COL=512,
|
| 362 |
+
HAS_SCALING=HAS_SCALING,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
return
|
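For reference: judging from the gradient terms stored by the backward kernel above, the op updates `x[index]` with `alpha * scaling * source`. A minimal pure-PyTorch sketch of that semantics follows; it is an illustration inferred from this file, not the Triton path, and the function name and shape assumptions here are hypothetical.

```python
import torch

def scaled_index_add_reference(x, index, source, scaling=None, alpha=1.0):
    # Assumed shapes: x [num_inp, rows, cols], source [num_src, rows, cols],
    # index [num_src] (rows of x to update), scaling [cols] or None.
    update = source if scaling is None else source * scaling
    return x.index_add(0, index, alpha * update)

# The backward kernel above stores the matching gradient terms:
#   grad_source  = alpha * grad_output[index] * scaling
#   grad_scaling = alpha * grad_output[index] * source   (per element; any
#   reduction down to scaling's shape would happen outside the kernel)
```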
.venv/lib/python3.11/site-packages/xformers/ops/_triton/rmsnorm_kernels.py
ADDED
@@ -0,0 +1,163 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl

try:
    from triton.language.extra.cuda.libdevice import rsqrt
except ImportError:
    try:
        from triton.language.math import rsqrt
    except ImportError:
        from triton.language.libdevice import rsqrt


@triton.jit
def _rms_norm_kernel(
    x_ptr,
    h1_ptr,
    w_ptr,
    eps,
    stride,
    N_COLS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INCLUDE_WEIGHT: tl.constexpr,
):
    row = tl.program_id(0).to(tl.int64)
    x_ptr += row * stride
    h1_ptr += row * stride

    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for offset in range(0, N_COLS, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        a = tl.load(
            x_ptr + cols, mask=cols < N_COLS, other=0.0, eviction_policy="evict_last"
        ).to(tl.float32)
        _mean += a * a
    rstd = rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
    for offset in range(0, N_COLS, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < N_COLS
        a = tl.load(
            x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
        ).to(tl.float32)
        if INCLUDE_WEIGHT:
            w = tl.load(w_ptr + cols, mask=mask)
            tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
        else:
            tl.store(h1_ptr + cols, a * rstd, mask=mask)


@triton.jit
def _rms_norm_add_kernel(
    x_ptr,
    y_ptr,
    h1_ptr,
    w_ptr,
    eps,
    stride,
    N_COLS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INCLUDE_WEIGHT: tl.constexpr,
):
    row = tl.program_id(0)
    x_ptr += row * stride
    y_ptr += row * stride
    h1_ptr += row * stride

    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for offset in range(0, N_COLS, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < N_COLS
        ax = tl.load(
            x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_last"
        ).to(tl.float32)
        ay = tl.load(
            y_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
        ).to(tl.float32)
        a = ax + ay
        tl.store(x_ptr + cols, a, mask=mask)
        _mean += a * a
    rstd = rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
    for offset in range(0, N_COLS, BLOCK_SIZE):
        cols = offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < N_COLS
        a = tl.load(
            x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
        ).to(tl.float32)
        if INCLUDE_WEIGHT:
            w = tl.load(w_ptr + cols, mask=mask)
            tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
        else:
            tl.store(h1_ptr + cols, a * rstd, mask=mask)


def _rms_norm_forward(x, attn_norm_weights, eps):
    if not x.is_contiguous():
        raise ValueError("data must be contiguous")
    if attn_norm_weights is not None:
        if not attn_norm_weights.is_contiguous():
            raise ValueError("weights must be contiguous")
    out = torch.empty_like(x)
    x_arg = x.reshape(-1, x.shape[-1])
    M, N = x_arg.shape
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    BLOCK_SIZE = max(BLOCK_SIZE, 128)
    BLOCK_SIZE = min(BLOCK_SIZE, 8192)
    # heuristics for number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    with torch.cuda.device(x.device):
        _rms_norm_kernel[(M,)](
            x_arg,
            out,
            attn_norm_weights,
            eps,
            x_arg.stride(0),
            N,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            INCLUDE_WEIGHT=attn_norm_weights is not None,
        )
    return out


def _rms_norm_add_forward(x, y, attn_norm_weights, eps):
    # x, y contiguous of the same shape [..., n]
    # output of the same shape, normed over the last dim.
    if not x.is_contiguous():
        raise ValueError("x must be contiguous")
    if not y.is_contiguous():
        raise ValueError("y must be contiguous")
    if attn_norm_weights is not None:
        if not attn_norm_weights.is_contiguous():
            raise ValueError("weights must be contiguous")
    out = torch.empty_like(x)
    x_arg = x.reshape(-1, x.shape[-1])
    y_arg = y.reshape(-1, x.shape[-1])
    M, N = x_arg.shape
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    BLOCK_SIZE = max(BLOCK_SIZE, 128)
    BLOCK_SIZE = min(BLOCK_SIZE, 8192)
    # heuristics for number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    with torch.cuda.device(x.device):
        _rms_norm_add_kernel[(M,)](
            x_arg,
            y_arg,
            out,
            attn_norm_weights,
            eps,
            x_arg.stride(0),
            N,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            INCLUDE_WEIGHT=attn_norm_weights is not None,
        )
    return out
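The two Triton kernels above compute a row-wise RMS normalization, optionally fused with a residual add and a per-feature weight multiply. As a point of reference, here is a minimal PyTorch sketch of the same math; the function names and the eps value are illustrative, not part of the file above.

```python
import torch

def rms_norm_reference(x, weight=None, eps=1e-6):
    # Matches _rms_norm_kernel: scale each row by rsqrt(mean(x^2) + eps),
    # then optionally multiply by a per-feature weight.
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    out = x.float() * rstd
    if weight is not None:
        out = out * weight
    return out.to(x.dtype)

def rms_norm_add_reference(x, y, weight=None, eps=1e-6):
    # Matches _rms_norm_add_kernel: x is updated in place with x + y,
    # and the normalized sum is returned.
    x.add_(y)
    return rms_norm_reference(x, weight, eps)
```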
.venv/lib/python3.11/site-packages/xformers/ops/_triton/rope_padded_kernels.py
ADDED
@@ -0,0 +1,226 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import triton  # type: ignore
import triton.language as tl  # type: ignore

try:
    from triton.language.extra.cuda.libdevice import pow
except ImportError:
    try:
        from triton.language.math import pow
    except ImportError:
        from triton.language.libdevice import pow


@triton.jit
def _rope_padded_kernel(
    xq,
    xk,
    xv,
    out_q,
    cache_k,
    cache_v,
    seqstartq,
    seqstartk,
    seqlenk,
    theta,
    linear_scale,
    use_dynamic_scaling: tl.constexpr,
    dynamic_old_context_len: tl.constexpr,
    dynamic_scale_factor: tl.constexpr,
    dynamic_low_freq_factor: tl.constexpr,
    dynamic_high_freq_factor: tl.constexpr,
    first_seqpos,
    seqpos,
    k_start: tl.constexpr,
    v_start: tl.constexpr,
    n_groups,
    dim: tl.constexpr,  # dimension of each head
    stride_xqM,
    stride_xqG,
    stride_xqH,
    stride_xkM,
    stride_xkG,
    stride_xkH,
    stride_xvM,
    stride_xvG,
    stride_xvH,
    stride_cachekM,
    stride_cachekG,
    stride_cachekH,
    stride_cachevM,
    stride_cachevG,
    stride_cachevH,
    stride_seqstartq,
    stride_seqstartk,
    stride_seqlenk,
    stride_outqM,
    stride_outqG,
    stride_outqH,
    stride_seqpos,
    internal_dtype: tl.constexpr,
    # If True, seqstartq and seqstartk are not used but rather we
    # assume that every batch element has the same number of
    # queries (i.e. num_queries := tl.num_programs(1) )
    # and the same cache space cache_padding_length.
    # Always False when called below.
    const_batch_strides: tl.constexpr,
    # If const_batch_strides==True, the common cache length for each batch element.
    # (Only the first seqlenk[i] elements are actually in use, and only the last
    # num_queries of those are actually written to.)
    cache_padding_length,
    # offset added to all values in seqlenk before using them.
    # Always 0 when called below.
    seqlenk_shift: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    adjacents: tl.constexpr,
):
    """
    Each letter in this diagram is a whole row of length dim.

    INPUT      xq          xk   xv

    head_dim ─►

    batch    qqqqqq        kk   vv
      │      qqqqqq        kk   vv
      ▼      qqqqqq        kk   vv

    head_idx: (goes across all heads of all 3 inputs)
    ▲          ▲   ▲        ▲   ▲            ▲
    │          │   │        │   │            │
    0          k_start      v_start          n_total_heads

    Output is to out_q (same shape as xq), an xk-shaped part
    of cache_k and an xv-shaped part of cache_v
    """
    query_pos_in_batch_elt = tl.program_id(0)
    batch_elt = tl.program_id(1)
    group_head_idx = tl.program_id(2)
    group_idx = group_head_idx % n_groups
    head_idx = group_head_idx // n_groups

    if internal_dtype == "f32":
        theta = theta.to(tl.float32)
    elif internal_dtype == "f64":
        theta = theta.to(tl.float64)

    if const_batch_strides:
        query_pos = query_pos_in_batch_elt + tl.num_programs(1) * batch_elt
        end_query_pos = tl.num_programs(1) * (batch_elt + 1)
    else:
        query_pos = query_pos_in_batch_elt + tl.load(
            seqstartq + batch_elt * stride_seqstartq
        )
        end_query_pos = tl.load(seqstartq + (batch_elt + 1) * stride_seqstartq)
        if query_pos >= end_query_pos:
            return

    is_q = head_idx < k_start
    is_v = head_idx >= v_start

    xq += query_pos * stride_xqM + head_idx * stride_xqH + group_idx * stride_xqG
    out_q += (
        query_pos * stride_outqM + head_idx * stride_outqH + group_idx * stride_outqG
    )

    if const_batch_strides:
        cache_start = cache_padding_length * batch_elt
    else:
        cache_start = tl.load(seqstartk + batch_elt * stride_seqstartk)
    end_of_batch_elt_cache = (
        cache_start + tl.load(seqlenk + batch_elt * stride_seqlenk) + seqlenk_shift
    )

    cache_pos = end_of_batch_elt_cache - (end_query_pos - query_pos)
    if seqpos is not None:
        seq_pos = tl.load(seqpos + query_pos * stride_seqpos)
    else:
        seq_pos = cache_pos - cache_start
        if first_seqpos is not None:
            seq_pos += tl.load(first_seqpos + batch_elt * stride_seqpos)
    cache_k += (
        (head_idx - k_start) * stride_cachekH
        + cache_pos * stride_cachekM
        + group_idx * stride_cachekG
    )
    xk += (
        query_pos * stride_xkM
        + (head_idx - k_start) * stride_xkH
        + group_idx * stride_xkG
    )
    in_qk = tl.where(is_q, xq, xk)
    out_qk = tl.where(is_q, out_q, cache_k)

    cache_v += (
        (head_idx - v_start) * stride_cachevH
        + cache_pos * stride_cachevM
        + group_idx * stride_cachevG
    )
    xv += (
        query_pos * stride_xvM
        + (head_idx - v_start) * stride_xvH
        + group_idx * stride_xvG
    )

    out = tl.where(is_v, cache_v, out_qk)
    x_in = tl.where(is_v, xv, in_qk)

    for offset in range(0, dim // 2, BLOCK_SIZE // 2):
        c = tl.arange(0, BLOCK_SIZE // 2)
        powers = (offset + c) * 2.0
        if adjacents:
            cols_re = (offset + c) * 2
            cols_im = cols_re + 1
        else:
            cols_re = offset + c
            cols_im = cols_re + dim // 2

        mask = cols_im < dim

        re_x = tl.load(x_in + cols_re, mask=mask)
        im_x = tl.load(x_in + cols_im, mask=mask)
        # freqs = seq_pos / (theta ** (powers / dim))
        freqs = pow(theta, powers / (-dim))

        if use_dynamic_scaling:
            lo_freq_wavelen = dynamic_old_context_len / dynamic_low_freq_factor
            hi_freq_wavelen = dynamic_old_context_len / dynamic_high_freq_factor

            wavelens = 6.28318530718 / freqs  # 2*pi
            is_low_freq = wavelens > lo_freq_wavelen
            freqs = tl.where(is_low_freq, freqs / dynamic_scale_factor, freqs)

            is_mid_freq = hi_freq_wavelen <= wavelens and wavelens <= lo_freq_wavelen

            smooth = (dynamic_old_context_len / wavelens - dynamic_low_freq_factor) / (
                dynamic_high_freq_factor - dynamic_low_freq_factor
            )
            freqs = tl.where(
                is_mid_freq,
                (1 - smooth) * freqs / dynamic_scale_factor + smooth * freqs,
                freqs,
            )

        freqs = seq_pos * freqs / linear_scale
        sines = tl.sin(freqs)
        cosines = tl.cos(freqs)
        re_out = re_x * cosines - im_x * sines
        im_out = im_x * cosines + re_x * sines

        re_out_ = tl.where(is_v, re_x, re_out)
        im_out_ = tl.where(is_v, im_x, im_out)
        if internal_dtype == "f64":
            if re_x.dtype == tl.bfloat16:
                # triton 2.0.0 crashes if you try to convert
                # float64 directly to bfloat16, so make an intermediate step.
                re_out_ = re_out_.to(tl.float32)
                im_out_ = im_out_.to(tl.float32)
        tl.store(out + cols_re, re_out_, mask=mask)
        tl.store(out + cols_im, im_out_, mask=mask)
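The core of the kernel above is the complex rotation applied to each (re, im) pair of a query/key row: re' = re*cos(f) - im*sin(f) and im' = im*cos(f) + re*sin(f), with f = seq_pos * theta^(-2i/dim) / linear_scale. A short reference sketch of that rotation for the "adjacents" pairing follows; it illustrates the math only (not the padded-cache indexing or dynamic scaling handled by the kernel), and the function name is illustrative.

```python
import torch

def rope_rotate_adjacent(x, seq_pos, theta=10000.0, linear_scale=1.0):
    # x: [..., dim] with dim even; pairs (x[..., 2i], x[..., 2i+1]) are rotated.
    dim = x.shape[-1]
    powers = torch.arange(0, dim, 2, device=x.device, dtype=torch.float32)
    freqs = seq_pos * theta ** (-powers / dim) / linear_scale
    cos, sin = torch.cos(freqs), torch.sin(freqs)
    re, im = x[..., 0::2].float(), x[..., 1::2].float()
    out = torch.empty_like(x, dtype=torch.float32)
    out[..., 0::2] = re * cos - im * sin
    out[..., 1::2] = im * cos + re * sin
    return out.to(x.dtype)
```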
.venv/lib/python3.11/site-packages/xformers/ops/_triton/tiled_matmul_kernels.py
ADDED
@@ -0,0 +1,430 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from typing import List, Tuple

import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time


def init_to_zero(*names):
    def result(nargs):
        for name in names:
            nargs[name].zero_()

    return result


def gen_config(
    block_m: int,
    block_n: int,
    block_k: int,
    stages: int,
    warps: int,
    split_k: int = 1,
    group_m: int = 8,
) -> triton.Config:
    """A more compact way to define a triton.Config, so it fits on one line"""

    return triton.Config(
        {
            "BLOCK_M": block_m,
            "BLOCK_N": block_n,
            "BLOCK_K": block_k,
            "SPLIT_K": split_k,
            "GROUP_M": group_m,
        },
        num_stages=stages,
        num_warps=warps,
        pre_hook=init_to_zero(*[f"C{i+1}{j+1}" for i in range(3) for j in range(3)])
        if split_k > 1
        else init_to_zero(),
    )


BASIC_MATMUL_CONFIGS = [
    gen_config(block_m=128, block_n=256, block_k=32, stages=3, warps=8),
    gen_config(block_m=256, block_n=128, block_k=32, stages=3, warps=8),
    gen_config(block_m=256, block_n=64, block_k=32, stages=4, warps=4),
    gen_config(block_m=64, block_n=256, block_k=32, stages=4, warps=4),
    gen_config(block_m=128, block_n=128, block_k=32, stages=4, warps=4),
    gen_config(block_m=128, block_n=64, block_k=32, stages=4, warps=4),
    gen_config(block_m=64, block_n=128, block_k=32, stages=4, warps=4),
    gen_config(block_m=128, block_n=32, block_k=32, stages=4, warps=4),
    gen_config(block_m=64, block_n=32, block_k=32, stages=5, warps=2),
]


INT8_MATMUL_CONFIGS = [
    gen_config(block_m=128, block_n=256, block_k=128, stages=3, warps=8),
    gen_config(block_m=256, block_n=128, block_k=128, stages=3, warps=8),
    gen_config(block_m=256, block_n=64, block_k=128, stages=4, warps=4),
    gen_config(block_m=64, block_n=256, block_k=128, stages=4, warps=4),
    gen_config(block_m=128, block_n=128, block_k=128, stages=4, warps=4),
    gen_config(block_m=128, block_n=64, block_k=64, stages=4, warps=4),
    gen_config(block_m=64, block_n=128, block_k=64, stages=4, warps=4),
    gen_config(block_m=128, block_n=32, block_k=64, stages=4, warps=4),
    gen_config(block_m=64, block_n=32, block_k=64, stages=5, warps=2),
]


IO_BOUND_MATMUL_CONFIGS_STAGES = [2, 3, 4, 5, 6]
IO_BOUND_MATMUL_CONFIGS_BLOCK_M = [16, 32]
IO_BOUND_MATMUL_CONFIGS_BLOCK_K = [32, 64]
IO_BOUND_MATMUL_CONFIGS_BLOCK_N = [32, 64, 128, 256]
IO_BOUND_MATMUL_CONFIGS_SPLIT_K = [1, 2, 4, 8, 16]


IO_BOUND_MATMUL_CONFIGS = [
    gen_config(
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        stages=stages,
        warps=2 if block_n <= 64 else 4,
        split_k=split_k,
    )
    for stages, block_m, block_k, block_n, split_k in itertools.product(
        IO_BOUND_MATMUL_CONFIGS_STAGES,
        IO_BOUND_MATMUL_CONFIGS_BLOCK_M,
        IO_BOUND_MATMUL_CONFIGS_BLOCK_K,
        IO_BOUND_MATMUL_CONFIGS_BLOCK_N,
        IO_BOUND_MATMUL_CONFIGS_SPLIT_K,
    )
]


TRITON_CONFIGS = BASIC_MATMUL_CONFIGS + INT8_MATMUL_CONFIGS + IO_BOUND_MATMUL_CONFIGS


def our_estimate_matmul_time(
    A11, B11, C11, M1, M2, M3, N1, N2, N3, K1, K2, K3, **kwargs
):
    """Call into Triton's upstream cost model, with the right args

    The upstream function expects arguments to have certain names. Since we
    renamed a few of them in our implementation, we rename them back.

    At the time of writing (July 2023) the arguments that Triton expects are:
    M, N, K, A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages.

    """
    return estimate_matmul_time(
        M=M1 + M2 + M3, N=N1 + N2 + N3, K=K1 + K2 + K3, A=A11, B=B11, C=C11, **kwargs
    )


def our_early_config_prune(config, named_args, **kwargs):
    new_named_args = named_args.copy()
    new_named_args["M"] = named_args["M1"] + named_args["M2"] + named_args["M3"]
    new_named_args["N"] = named_args["N1"] + named_args["N2"] + named_args["N3"]
    new_named_args["K"] = named_args["K1"] + named_args["K2"] + named_args["K3"]
    new_named_args["A"] = named_args["A11"]
    new_named_args["B"] = named_args["B11"]
    new_named_args["C"] = named_args["C11"]
    return early_config_prune(config, new_named_args, **kwargs)


@triton.autotune(
    configs=TRITON_CONFIGS,
    key=["M1", "M2", "M3", "N1", "N2", "N3", "K1", "K2", "K3"],
    prune_configs_by={
        "early_config_prune": our_early_config_prune,
        "perf_model": our_estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: all(
            k % (args["BLOCK_K"] * args["SPLIT_K"]) == 0
            for k in [args["K1"], args["K2"], args["K3"]]
        ),
    }
)
@triton.jit()
def _xformers_tiled_matmul_kernel(
    A11,
    A12,
    A13,
    A21,
    A22,
    A23,
    A31,
    A32,
    A33,
    B11,
    B12,
    B13,
    B21,
    B22,
    B23,
    B31,
    B32,
    B33,
    C11,
    C12,
    C13,
    C21,
    C22,
    C23,
    C31,
    C32,
    C33,
    M1,
    M2,
    M3,
    N1,
    N2,
    N3,
    K1,
    K2,
    K3,
    stride_am1,
    stride_am2,
    stride_am3,
    stride_ak1,
    stride_ak2,
    stride_ak3,
    stride_bk1,
    stride_bk2,
    stride_bk3,
    stride_bn1,
    stride_bn2,
    stride_bn3,
    stride_cm1,
    stride_cm2,
    stride_cm3,
    stride_cn1,
    stride_cn2,
    stride_cn3,
    BLOCK_M: tl.constexpr,  # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
    BLOCK_N: tl.constexpr,  # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
    BLOCK_K: tl.constexpr,  # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,  # DO NOT CHANGE NAME: MUST MATCH PERF MODEL
    EVEN_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
):
    # matrix multiplication
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)
    grid_m1 = tl.cdiv(M1, BLOCK_M)
    grid_m2 = tl.cdiv(M2, BLOCK_M)
    grid_m3 = tl.cdiv(M3, BLOCK_M)
    grid_n1 = tl.cdiv(N1, BLOCK_N)
    grid_n2 = tl.cdiv(N2, BLOCK_N)
    grid_n3 = tl.cdiv(N3, BLOCK_N)
    grid_m = grid_m1 + grid_m2 + grid_m3
    grid_n = grid_n1 + grid_n2 + grid_n3

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # We use tl.where to circumvent a regression in alignment auto-detection:
    # https://github.com/openai/triton/issues/1784

    A1 = tl.where(pid_m < grid_m1, A11, tl.where(pid_m < grid_m1 + grid_m2, A21, A31))
    A2 = tl.where(pid_m < grid_m1, A12, tl.where(pid_m < grid_m1 + grid_m2, A22, A32))
    A3 = tl.where(pid_m < grid_m1, A13, tl.where(pid_m < grid_m1 + grid_m2, A23, A33))
    B1 = tl.where(pid_n < grid_n1, B11, tl.where(pid_n < grid_n1 + grid_n2, B12, B13))
    B2 = tl.where(pid_n < grid_n1, B21, tl.where(pid_n < grid_n1 + grid_n2, B22, B23))
    B3 = tl.where(pid_n < grid_n1, B31, tl.where(pid_n < grid_n1 + grid_n2, B32, B33))
    C = tl.where(
        pid_m < grid_m1,
        tl.where(pid_n < grid_n1, C11, tl.where(pid_n < grid_n1 + grid_n2, C12, C13)),
        tl.where(
            pid_m < grid_m1 + grid_m2,
            tl.where(
                pid_n < grid_n1, C21, tl.where(pid_n < grid_n1 + grid_n2, C22, C23)
            ),
            tl.where(
                pid_n < grid_n1, C31, tl.where(pid_n < grid_n1 + grid_n2, C32, C33)
            ),
        ),
    )
    M = tl.where(pid_m < grid_m1, M1, tl.where(pid_m < grid_m1 + grid_m2, M2, M3))
    N = tl.where(pid_n < grid_n1, N1, tl.where(pid_n < grid_n1 + grid_n2, N2, N3))
    stride_ak = tl.where(
        pid_m < grid_m1,
        stride_ak1,
        tl.where(pid_m < grid_m1 + grid_m2, stride_ak2, stride_ak3),
    )
    stride_bk = tl.where(
        pid_n < grid_n1,
        stride_bk1,
        tl.where(pid_n < grid_n1 + grid_n2, stride_bk2, stride_bk3),
    )
    stride_cn = tl.where(
        pid_m < grid_m1,
        stride_cn1,
        tl.where(pid_m < grid_m1 + grid_m2, stride_cn2, stride_cn3),
    )
    stride_cm = tl.where(
        pid_n < grid_n1,
        stride_cm1,
        tl.where(pid_n < grid_n1 + grid_n2, stride_cm2, stride_cm3),
    )
    pid_m = tl.where(
        pid_m < grid_m1,
        pid_m,
        tl.where(pid_m < grid_m1 + grid_m2, pid_m - grid_m1, pid_m - grid_m1 - grid_m2),
    )
    pid_n = tl.where(
        pid_n < grid_n1,
        pid_n,
        tl.where(pid_n < grid_n1 + grid_n2, pid_n - grid_n1, pid_n - grid_n1 - grid_n2),
    )

    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    # pointers
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    grid_k1 = tl.cdiv(K1, BLOCK_K)
    grid_k2 = tl.cdiv(K2, BLOCK_K)
    grid_k3 = tl.cdiv(K3, BLOCK_K)
    for tile in range(pid_k, grid_k1 + grid_k2 + grid_k3, SPLIT_K):
        A = tl.where(tile < grid_k1, A1, tl.where(tile < grid_k1 + grid_k2, A2, A3))
        B = tl.where(tile < grid_k1, B1, tl.where(tile < grid_k1 + grid_k2, B2, B3))
        K = tl.where(tile < grid_k1, K1, tl.where(tile < grid_k1 + grid_k2, K2, K3))
        stride_am = tl.where(
            tile < grid_k1,
            stride_am1,
            tl.where(tile < grid_k1 + grid_k2, stride_am2, stride_am3),
        )
        stride_bn = tl.where(
            tile < grid_k1,
            stride_bn1,
            tl.where(tile < grid_k1 + grid_k2, stride_bn2, stride_bn3),
        )
        my_tile = tl.where(
            tile < grid_k1,
            tile,
            tl.where(
                tile < grid_k1 + grid_k2, tile - grid_k1, tile - grid_k1 - grid_k2
            ),
        )
        rk = my_tile * BLOCK_K + tl.arange(0, BLOCK_K)
        Ain = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
        Bin = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
        if EVEN_K:
            a = tl.load(Ain)
            b = tl.load(Bin)
        else:
            a = tl.load(Ain, mask=rk[None, :] < K, other=0.0)
            b = tl.load(Bin, mask=rk[:, None] < K, other=0.0)
        acc += tl.dot(a, b, allow_tf32=False)
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)


def _check_row_or_column(row_or_col_type, row_or_col_idx, tensor_name, dim_name, vals):
    assert len(vals) > 0
    for pos, val in enumerate(vals[1:]):
        assert val == vals[0], (
            f"the tensors on {row_or_col_type} {row_or_col_idx} of the {tensor_name} "
            f"must all have the same stride along the {dim_name} dimension, got "
            f"{vals[0]} at position 0 and {val} at position {pos + 1}"
        )
    return vals[0]


def _get_strides(
    ts: List[List[torch.Tensor]], tensor_name, dim_0_name, dim_1_name
) -> Tuple[List[int], List[int]]:
    strides_0 = [
        _check_row_or_column(
            "column", idx, tensor_name, dim_0_name, [y.stride(0) for y in x]
        )
        for idx, x in enumerate(zip(*ts))
    ]
    strides_1 = [
        _check_row_or_column(
            "row", idx, tensor_name, dim_1_name, [y.stride(1) for y in x]
        )
        for idx, x in enumerate(ts)
    ]
    assert all(s == 1 for s in strides_0) or all(s == 1 for s in strides_1)
    while len(strides_0) < 3:
        strides_0.append(1 if strides_0[0] == 1 else 0)
    while len(strides_1) < 3:
        strides_1.append(1 if strides_1[0] == 1 else 0)
    return strides_0, strides_1


def _launch_triton_matmul(
    a: List[List[torch.Tensor]],
    b: List[List[torch.Tensor]],
    c: List[List[torch.Tensor]],
    ms: List[int],
    ns: List[int],
    ks: List[int],
) -> None:
    strides_am, strides_ak = _get_strides(a, "first operand", "m", "k")
    strides_bk, strides_bn = _get_strides(b, "second operand", "k", "n")
    strides_cm, strides_cn = _get_strides(c, "output", "m", "n")

    # accumulator types
    ACC_TYPE = (
        tl.float32
        if c[0][0].dtype in [torch.float16, torch.bfloat16, torch.float32]
        else tl.int32
    )

    # launch kernel
    def grid(META):
        return (
            sum(triton.cdiv(m, META["BLOCK_M"]) for m in ms)
            * sum(triton.cdiv(n, META["BLOCK_N"]) for n in ns),
            META["SPLIT_K"],
        )

    _xformers_tiled_matmul_kernel[grid](
        *[
            a[min(i, len(a) - 1)][min(j, len(a[0]) - 1)]
            for i in range(3)
            for j in range(3)
        ],
        *[
            b[min(i, len(b) - 1)][min(j, len(b[0]) - 1)]
            for i in range(3)
            for j in range(3)
        ],
        *[
            c[min(i, len(c) - 1)][min(j, len(c[0]) - 1)]
            for i in range(3)
            for j in range(3)
        ],
        *[ms[i] if len(ms) > i else 0 for i in range(3)],
        *[ns[i] if len(ns) > i else 0 for i in range(3)],
        *[ks[i] if len(ks) > i else 0 for i in range(3)],
        *strides_am,
        *strides_ak,
        *strides_bk,
        *strides_bn,
        *strides_cm,
        *strides_cn,
        ACC_TYPE=ACC_TYPE,
    )
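Conceptually, the kernel above performs one matmul over operands supplied as up to 3x3 grids of tiles: each output tile is C[i][j] = sum over k of A[i][k] @ B[k][j], which is exactly the matmul of the concatenated operands. A hedged PyTorch sketch of that block structure follows; the function name is illustrative and, unlike _launch_triton_matmul, it returns one concatenated output instead of writing into a tile grid c.

```python
import torch

def tiled_matmul_reference(a, b):
    # a: grid of tiles [[A11, A12, ...], ...] indexed (m-block, k-block);
    # b: grid of tiles indexed (k-block, n-block).
    # Equivalent to concatenating the tiles and doing a single matmul.
    a_full = torch.cat([torch.cat(row, dim=1) for row in a], dim=0)  # [sum(ms), sum(ks)]
    b_full = torch.cat([torch.cat(row, dim=1) for row in b], dim=0)  # [sum(ks), sum(ns)]
    return a_full @ b_full  # [sum(ms), sum(ns)]
```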
.venv/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py
ADDED
@@ -0,0 +1,893 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the BSD license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Any, List, Optional, Sequence, Tuple, Type, Union, cast
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from . import (
|
| 11 |
+
attn_bias,
|
| 12 |
+
ck,
|
| 13 |
+
ck_decoder,
|
| 14 |
+
ck_splitk,
|
| 15 |
+
cutlass,
|
| 16 |
+
flash,
|
| 17 |
+
flash3,
|
| 18 |
+
triton_splitk,
|
| 19 |
+
)
|
| 20 |
+
from .attn_bias import (
|
| 21 |
+
VARLEN_BIASES,
|
| 22 |
+
AttentionBias,
|
| 23 |
+
BlockDiagonalMask,
|
| 24 |
+
LowerTriangularMask,
|
| 25 |
+
)
|
| 26 |
+
from .common import (
|
| 27 |
+
AttentionBwOpBase,
|
| 28 |
+
AttentionFwOpBase,
|
| 29 |
+
AttentionOp,
|
| 30 |
+
AttentionOpBase,
|
| 31 |
+
Context,
|
| 32 |
+
Gradients,
|
| 33 |
+
Inputs,
|
| 34 |
+
bmk2bmhk,
|
| 35 |
+
)
|
| 36 |
+
from .dispatch import (
|
| 37 |
+
_dispatch_bw,
|
| 38 |
+
_dispatch_fw,
|
| 39 |
+
_ensure_op_supports_or_raise,
|
| 40 |
+
_get_use_fa3,
|
| 41 |
+
_set_use_fa3,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
|
| 45 |
+
MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
|
| 46 |
+
MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp)
|
| 47 |
+
MemoryEfficientAttentionCkOp = (ck.FwOp, ck.BwOp)
|
| 48 |
+
MemoryEfficientAttentionCkDecoderOp = (ck_decoder.FwOp, ck.BwOp)
|
| 49 |
+
MemoryEfficientAttentionSplitKCkOp = (ck_splitk.FwOp, ck.BwOp)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _deserialize_bias(attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]) -> Any:
|
| 53 |
+
if attn_bias_tensor is None:
|
| 54 |
+
return attn_bias_ctx
|
| 55 |
+
return attn_bias_tensor
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Note: `torch.compile` only allows custom autograd functions
|
| 59 |
+
# to accept a subset of types. Therefore we serialize `op` objects
|
| 60 |
+
# to `str` before entering the function, and unserialize them inside.
|
| 61 |
+
# See also: https://github.com/pytorch/pytorch/issues/118395
|
| 62 |
+
_OPS_LOOKUP = {
|
| 63 |
+
flash.FwOp.NAME: flash.FwOp,
|
| 64 |
+
flash.BwOp.NAME: flash.BwOp,
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _serialize_op(op):
|
| 69 |
+
if op is not None and op.NAME in _OPS_LOOKUP:
|
| 70 |
+
return op.NAME
|
| 71 |
+
return op
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _unserialize_op(op):
|
| 75 |
+
if isinstance(op, str):
|
| 76 |
+
return _OPS_LOOKUP[op]
|
| 77 |
+
return op
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class _fMHA(torch.autograd.Function):
|
| 81 |
+
@staticmethod
|
| 82 |
+
# type: ignore
|
| 83 |
+
def forward(ctx, op_fw, op_bw, *args: Any) -> Any:
|
| 84 |
+
inp = Inputs(*args)
|
| 85 |
+
|
| 86 |
+
op_fw = _unserialize_op(op_fw)
|
| 87 |
+
op_bw = _unserialize_op(op_bw)
|
| 88 |
+
|
| 89 |
+
out, op_ctx = _memory_efficient_attention_forward_requires_grad(
|
| 90 |
+
inp=inp, op=op_fw
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# Saving attn_bias is a bit complicated, as the
|
| 94 |
+
# torch part should go in `save_for_backward`
|
| 95 |
+
if isinstance(inp.attn_bias, torch.Tensor):
|
| 96 |
+
attn_bias_tensor = inp.attn_bias
|
| 97 |
+
attn_bias_ctx = None
|
| 98 |
+
else:
|
| 99 |
+
attn_bias_tensor = None
|
| 100 |
+
attn_bias_ctx = inp.attn_bias
|
| 101 |
+
|
| 102 |
+
ctx.save_for_backward(
|
| 103 |
+
inp.query,
|
| 104 |
+
inp.key,
|
| 105 |
+
inp.value,
|
| 106 |
+
op_ctx.out,
|
| 107 |
+
op_ctx.lse,
|
| 108 |
+
)
|
| 109 |
+
ctx.rng_state = op_ctx.rng_state
|
| 110 |
+
ctx.attn_bias_tensor = attn_bias_tensor
|
| 111 |
+
if op_ctx.op_bw is not None:
|
| 112 |
+
if op_bw is not None and op_bw is not op_ctx.op_bw:
|
| 113 |
+
raise ValueError(
|
| 114 |
+
f"Specified op_bw={op_bw.NAME}, but forward op "
|
| 115 |
+
f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
|
| 116 |
+
)
|
| 117 |
+
op_bw = op_ctx.op_bw
|
| 118 |
+
if (
|
| 119 |
+
op_bw is not None
|
| 120 |
+
and isinstance(inp.attn_bias, VARLEN_BIASES)
|
| 121 |
+
and inp.attn_bias.q_seqinfo.seqstart.shape[0] > 2
|
| 122 |
+
and op_bw.VARLEN_LSE_PACKED != op_fw.VARLEN_LSE_PACKED
|
| 123 |
+
):
|
| 124 |
+
raise ValueError(
|
| 125 |
+
f"Specified op_bw={op_bw.NAME} is not compatible with the "
|
| 126 |
+
f"op_fw={op_fw.NAME}, because they use different format of logsumexp. "
|
| 127 |
+
f"NOTE: This is new with xFormers 0.0.28"
|
| 128 |
+
)
|
| 129 |
+
if op_bw is None and (
|
| 130 |
+
inp.query.requires_grad or inp.key.requires_grad or inp.value.requires_grad
|
| 131 |
+
):
|
| 132 |
+
varlen_lse_packed = _detect_lse_packed_or_raise(op_ctx.lse, inp)
|
| 133 |
+
if varlen_lse_packed is not None and op_fw is not None:
|
| 134 |
+
assert (
|
| 135 |
+
op_fw.VARLEN_LSE_PACKED == varlen_lse_packed
|
| 136 |
+
), f"{op_fw.NAME}: wrong value for `VARLEN_LSE_PACKED` ?"
|
| 137 |
+
# NOTE: We need to check tensor strides to decide which operator we run in the BW pass.
|
| 138 |
+
# Unfortunately, PyTorch only allows to call this function during the FW pass, so
|
| 139 |
+
# we decide the operator to use now.
|
| 140 |
+
op_bw = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed)
|
| 141 |
+
ctx.op_fw = op_fw
|
| 142 |
+
ctx.op_bw = op_bw
|
| 143 |
+
ctx.p = inp.p
|
| 144 |
+
# This allows to create gradients from a single storage,
|
| 145 |
+
# to avoid a "cat" in the BW pass.
|
| 146 |
+
        # The heuristic is approximative, but:
        # (1) It's not a big issue to create a shared storage
        # (2) The heuristic needs to pass `torch.compile`
        # (this is also why we run it in the FW pass, the BW pass is stricter)
        ctx.qkv_share_storage = (
            inp.query.shape[0] == inp.key.shape[0]
            and inp.query.shape[-1] == inp.value.shape[-1]
            and inp.query.stride(-2)
            == (inp.key.shape[-1] + inp.query.shape[-1] + inp.value.shape[-1])
        )

        ctx.scale = inp.scale
        ctx.attn_bias_ctx = attn_bias_ctx
        ctx.n_args = len(args)
        return out, op_ctx.lse

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad, grad_lse):
        # Re-create context
        query, key, value, out, lse = ctx.saved_tensors
        attn_bias_tensor = ctx.attn_bias_tensor
        rng_state = ctx.rng_state
        inp = Inputs(
            query=query,
            key=key,
            value=value,
            attn_bias=_deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
            p=ctx.p,
            scale=ctx.scale,
        )
        op_ctx = Context(
            lse=lse,
            out=out,
            rng_state=rng_state,
        )
        grads = _memory_efficient_attention_backward(
            ctx=op_ctx,
            inp=inp,
            grad=grad,
            op=ctx.op_bw,
            _skip_op_checks=True,
        )
        return (None, None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
            ctx.n_args - 2
        )


def memory_efficient_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[AttentionOp] = None,
    output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Implements the memory-efficient attention mechanism following
    `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.

    :Inputs shape:

    - Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
        the sequence length, H the number of heads, and K the embedding size per head

    - If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``

    - Inputs can also be of dimension 5 with GQA - see note below

    - Inputs can be non-contiguous - we only require the last dimension's stride to be 1


    :Equivalent pytorch code:

    .. code-block:: python

        scale = 1.0 / query.shape[-1] ** 0.5
        query = query * scale
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        attn = query @ key.transpose(-2, -1)
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = attn.softmax(-1)
        attn = F.dropout(attn, p)
        attn = attn @ value
        return attn.transpose(1, 2)

    :Examples:

    .. code-block:: python

        import xformers.ops as xops

        # Compute regular attention
        y = xops.memory_efficient_attention(q, k, v)

        # With a dropout of 0.2
        y = xops.memory_efficient_attention(q, k, v, p=0.2)

        # Causal attention
        y = xops.memory_efficient_attention(
            q, k, v,
            attn_bias=xops.LowerTriangularMask()
        )

    :Supported hardware:

        NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.

    :EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):

        MQA/GQA is an experimental feature supported only for the forward pass.
        If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
        in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
        ``H`` is the number of heads per group (8 in the example).

        Please note that xFormers will not automatically broadcast the inputs, so you will need
        to broadcast it manually before calling `memory_efficient_attention`.

    :GQA/MQA example:

    .. code-block:: python

        import torch
        import xformers.ops as xops

        B, M, K = 3, 32, 128
        kwargs = dict(device="cuda", dtype=torch.float16)
        q = torch.randn([B, M, 8, K], **kwargs)
        k = torch.randn([B, M, 2, K], **kwargs)
        v = torch.randn([B, M, 2, K], **kwargs)
        out_gqa = xops.memory_efficient_attention(
            q.reshape([B, M, 2, 4, K]),
            k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
            v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
        )

    Raises:
        NotImplementedError: if there is no operator available to compute the MHA
        ValueError: if inputs are invalid

    :parameter query: Tensor of shape ``[B, Mq, H, K]``
    :parameter key: Tensor of shape ``[B, Mkv, H, K]``
    :parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
    :parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
        For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
        This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
    :parameter p: Dropout probability. Disabled if set to ``0.0``
    :parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
        scale (q.shape[-1]**-0.5) will be used.
    :parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
        If set to ``None`` (recommended), xFormers \
        will dispatch to the best available operator, depending on the inputs \
        and options.
    :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
    """
    return _memory_efficient_attention(
        Inputs(
            query=query,
            key=key,
            value=value,
            p=p,
            attn_bias=attn_bias,
            scale=scale,
            output_dtype=output_dtype,
        ),
        op=op,
    )


torch.library.define(
    "xformer::memory_efficient_attention_forward",
    "(Tensor q, Tensor k, Tensor v, Tensor? b = None, float? p = 0.0, float? scale = None) -> Tensor",
)


@torch.library.impl("xformer::memory_efficient_attention_forward", "Meta")
def memory_efficient_attention_forward_meta(q, k, v):
    return q.new_empty(q.shape)


# torch.compile has issues when tracing through the op dispatch and ensure_op_support,
# so provide a wrapper to register it as a custom torch library op.
@torch.library.impl("xformer::memory_efficient_attention_forward", "CUDA")
def memory_efficient_attention_forward_torch_wrapper(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """
    This provides a torch-compilable wrapper op to
    memory_efficient_attention_forward in certain special cases.

    Note that the following are not supported:
    - `op` input (?)
    - certain attn_bias types (?)
    - output_dtype
    - K != Kv
    """
    return memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias,
        p,
        scale,
    )
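# A minimal usage sketch (an assumption about the registration above, not
# something this file asserts): once the schema is defined through
# torch.library, the CUDA wrapper should be reachable as
# `torch.ops.xformer.memory_efficient_attention_forward`, which is the entry
# point a torch.compile'd model would trace instead of the Python dispatch.
#
#     q = torch.randn([2, 64, 8, 64], device="cuda", dtype=torch.float16)
#     out = torch.ops.xformer.memory_efficient_attention_forward(q, q, q)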


def memory_efficient_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[Type[AttentionFwOpBase]] = None,
    output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
    """
    return _memory_efficient_attention_forward(
        Inputs(
            query=query,
            key=key,
            value=value,
            p=p,
            attn_bias=attn_bias,
            scale=scale,
            output_dtype=output_dtype,
        ),
        op=op,
    )


def memory_efficient_attention_forward_requires_grad(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[Type[AttentionFwOpBase]] = None,
    output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
    See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments
    See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass
    """
    if p != 0.0:
        raise NotImplementedError(
            "dropout is not supported on the non-autograd API."
            " If you want to use dropout, please call `memory_efficient_attention` directly"
        )
    out, ctx = _memory_efficient_attention_forward_requires_grad(
        Inputs(
            query=query,
            key=key,
            value=value,
            p=p,
            attn_bias=attn_bias,
            scale=scale,
            output_dtype=output_dtype,
        ),
        op=op,
    )
    return out, ctx.lse


def memory_efficient_attention_backward(
    grad: torch.Tensor,
    output: torch.Tensor,
    lse: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[Type[AttentionBwOpBase]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes the gradient of the attention.
    Returns a tuple (dq, dk, dv)
    See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
    `lse` is the tensor returned by
    :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad`
    """
    if p != 0.0:
        raise NotImplementedError(
            "dropout is not supported on the non-autograd API."
            " If you want to use dropout, please call `memory_efficient_attention` directly"
        )
    gradients = _memory_efficient_attention_backward(
        Context(out=output, lse=lse),
        Inputs(
            query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
        ),
        grad,
        op=op,
    )
    return (gradients.dq, gradients.dk, gradients.dv)
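# Example (an illustrative sketch, assuming a CUDA device with f16 support):
# the non-autograd API above is meant to be chained forward-then-backward,
# with the `lse` returned by the forward call passed back into the backward
# call together with the upstream gradient. Shapes follow the [B, M, H, K]
# convention documented in `memory_efficient_attention`.
#
#     q = torch.randn([2, 64, 8, 64], device="cuda", dtype=torch.float16)
#     k = torch.randn([2, 128, 8, 64], device="cuda", dtype=torch.float16)
#     v = torch.randn([2, 128, 8, 64], device="cuda", dtype=torch.float16)
#     out, lse = memory_efficient_attention_forward_requires_grad(q, k, v)
#     grad_out = torch.ones_like(out)  # stand-in for dLoss/dOut
#     dq, dk, dv = memory_efficient_attention_backward(grad_out, out, lse, q, k, v)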


def _memory_efficient_attention(
    inp: Inputs, op: Optional[AttentionOp] = None
) -> torch.Tensor:
    # fast-path that doesn't require computing the logsumexp for backward computation
    if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
        return _memory_efficient_attention_forward(
            inp, op=op[0] if op is not None else None
        )

    output_shape = inp.normalize_bmhk()

    op_fw = _serialize_op(op[0] if op is not None else None)
    op_bw = _serialize_op(op[1] if op is not None else None)
    return _fMHA.apply(
        op_fw, op_bw, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
    )[0].reshape(output_shape)


def _memory_efficient_attention_forward(
    inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> torch.Tensor:
    inp.validate_inputs()
    output_shape = inp.normalize_bmhk()
    if op is None:
        op = _dispatch_fw(inp, False)
    else:
        _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)

    out, *_ = op.apply(inp, needs_gradient=False)
    return out.reshape(output_shape)


def _memory_efficient_attention_forward_requires_grad(
    inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> Tuple[torch.Tensor, Context]:
    inp.validate_inputs()
    output_shape = inp.normalize_bmhk()
    if op is None:
        op = _dispatch_fw(inp, True)
    else:
        _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
    out = op.apply(inp, needs_gradient=True)
    assert out[1] is not None
    return (out[0].reshape(output_shape), out[1])


def _detect_lse_packed_or_raise(lse: torch.Tensor, inp: Inputs) -> Optional[bool]:
    """
    Detects the LSE format if we're in a varlen case.
    Returns `None` if the format is not relevant (eg not varlen)
    Raises an exception if the `lse` has the wrong shape
    """
    shape_mismatch_err = (
        "Input tensors have incompatible shapes.\n"
        f"  lse.shape  : {lse.shape}\n"
        f"  query.shape: {inp.query.shape}\n"
        f"  attn_bias  : {type(inp.attn_bias)}"
    )
    # 1. Check ndim & head dimensions
    # In any case, LSE should be [*, *GH]
    if lse.ndim != (inp.query.ndim - 1) or lse.shape[1:-1] != inp.query.shape[2:-1]:
        raise ValueError(shape_mismatch_err)
    lse_bm = [lse.shape[0], lse.shape[-1]]
    lse_packed_shape = [inp.query.shape[0], inp.query.shape[1]]
    lse_packed = lse_bm[0] == lse_packed_shape[0] and lse_bm >= lse_packed_shape
    # 2. Check correctness for varlen biases with query.shape = [1, M, *GH, K]
    #    Either [1, *GH, M] (packed)
    #    Or [num_seq, *GH, Mq] .. with `Mq >= max_q` (padded)
    if isinstance(inp.attn_bias, VARLEN_BIASES):
        si = inp.attn_bias.q_seqinfo
        lse_padded_shape = [si.seqstart.shape[0] - 1, si.max_seqlen]
        lse_padded = lse_bm[0] == lse_padded_shape[0] and lse_bm >= lse_padded_shape
        if lse_packed and lse_padded:
            return None
        elif lse_packed:
            return True
        elif lse_padded:
            return False
        raise ValueError(shape_mismatch_err)
    # 3. For non-varlen, shape must be [B, *GH] with query.shape=[B, M, *GH, K]
    if not lse_packed:
        raise ValueError(shape_mismatch_err)
    return None


def _memory_efficient_attention_backward(
    ctx: Context,
    inp: Inputs,
    grad: torch.Tensor,
    op: Optional[Type[AttentionBwOpBase]],
    *,
    _skip_op_checks: bool = False,
) -> Gradients:
    """Warning: grad/ctx.out is potentially in BMK format"""
    inp.validate_inputs()
    if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
        raise ValueError(
            "All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
            f"grad.shape : {grad.shape} \n"
            f"out.shape  : {ctx.out.shape} \n"
            f"query.shape: {inp.query.shape}"
        )
    shape_dq, shape_dk, shape_dv = tuple(
        x.shape for x in (inp.query, inp.key, inp.value)
    )
    inp.normalize_bmhk()
    varlen_lse_packed = _detect_lse_packed_or_raise(ctx.lse, inp)
    grad = bmk2bmhk(grad, 1)
    ctx.out = bmk2bmhk(ctx.out, 1)

    if op is None:
        op = _dispatch_bw(inp, varlen_lse_packed=varlen_lse_packed)
    elif not _skip_op_checks:
        _ensure_op_supports_or_raise(
            ValueError, "memory_efficient_attention_backward", op, inp
        )
        if varlen_lse_packed is not None and varlen_lse_packed != op.VARLEN_LSE_PACKED:
            raise ValueError(
                f"Wrong LSE format for {op.NAME} in variable seqlen case. "
                f"Double-check that the BW operator {op.NAME} is compatible "
                f"with the operator used in the FW pass."
            )

    grads = op.apply(ctx, inp, grad)
    grads.dq = grads.dq.reshape(shape_dq)
    grads.dk = grads.dk.reshape(shape_dk)
    grads.dv = grads.dv.reshape(shape_dv)
    return grads


def memory_efficient_attention_partial(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[Union[AttentionOp, Type[AttentionFwOpBase]]] = None,
    output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns a tuple (output, lse), where `output` is the attention in the style of
    memory_efficient_attention, and `lse` is extra data, a log-sum-exp.
    The outputs of calls to this with the same query and separate keys and values
    can be merged with merge_attentions to obtain the attention of the queries
    against the disjoint union of the keys and values.

    Warning: The backward pass of this function is quite restricted. In particular
    we assume that in the forward pass the outputs were only used in merge_attentions
    calculations, and that LSEs weren't used anywhere except in merge_attentions.
    """
    if p != 0.0:
        raise NotImplementedError("dropout is not supported.")
    fwop: Optional[Type[AttentionFwOpBase]] = op[0] if isinstance(op, tuple) else op
    inp = Inputs(
        query=query,
        key=key,
        value=value,
        p=p,
        attn_bias=attn_bias,
        scale=scale,
        output_dtype=output_dtype,
        is_partial=True,
    )

    is_grad = torch.is_grad_enabled() and any(
        x.requires_grad for x in [query, key, value]
    )

    if not is_grad:
        out, ctx = _memory_efficient_attention_forward_requires_grad(
            inp,
            op=fwop,
        )
        return out, ctx.lse

    if query.ndim == 5:
        raise ValueError("gradients not supported for 5D tensors")
    if isinstance(op, tuple):
        op_fw = _serialize_op(op[0])
        op_bw = _serialize_op(op[1])
    elif op is None:
        op_fw = op_bw = None
    else:
        op_fw = _serialize_op(op)
        op_bw = None
    return _fMHA.apply(
        op_fw,
        op_bw,
        inp.query,
        inp.key,
        inp.value,
        inp.attn_bias,
        inp.p,
        inp.scale,
        inp.output_dtype,
        inp.is_partial,
    )


def merge_attentions(
    attn_split: Union[torch.Tensor, Sequence[torch.Tensor]],
    lse_split: Union[torch.Tensor, Sequence[torch.Tensor]],
    write_lse: bool = True,
    output_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Combine attention output computed on different parts of K/V for the same
    query to get attention on the whole K/V. See https://arxiv.org/abs/2402.05099
    The result is equal to
        Out_full = (Out1 * exp(LSE1) + Out2 * exp(LSE2) + ...) / (exp(LSE1) + exp(LSE2) + ...)
        LSE_full = log(exp(LSE1) + exp(LSE2) + ...)

    Args:
        attn_split: attention outputs for chunks,
            either as a list of tensors of shapes [B, M, G, H, Kq] or [B, M, H, Kq]
            or as a single tensor of shape [num_chunks, B, M, G, H, Kq]
            or [num_chunks, B, M, H, Kq]
        lse_split: LSE for chunks,
            either as a list of tensors of shapes [B, G, H, M] or [B, H, M]
            or as a single tensor of shape [num_chunks, B, G, H, M] or [num_chunks, B, H, M]
        write_lse: whether to output LSE
        output_dtype: dtype of attn_out

    Returns:
        attn_out: [B, M, G, H, Kq] or [B, M, H, Kq]
        lse_out: [B, G, H, M] or [B, H, M] if write_lse
            or None otherwise
    """

    attn_is_concat = isinstance(attn_split, torch.Tensor)
    lse_is_concat = isinstance(lse_split, torch.Tensor)

    attn_requires_grad = (
        attn_split.requires_grad  # type: ignore
        if attn_is_concat
        else any(x.requires_grad for x in attn_split)
    )
    lse_requires_grad = (
        lse_split.requires_grad  # type: ignore
        if lse_is_concat
        else any(x.requires_grad for x in lse_split)
    )
    requires_grad = torch.is_grad_enabled() and (
        attn_requires_grad or lse_requires_grad
    )
    if requires_grad and not write_lse:
        raise ValueError("write_lse should be true if inputs require gradients.")

    concat_path = attn_is_concat and lse_is_concat and not requires_grad
    if concat_path:
        attn_split = cast(torch.Tensor, attn_split)
        lse_split = cast(torch.Tensor, lse_split)
        if attn_split.ndim != lse_split.ndim + 1:
            raise ValueError(
                f"Incompatible input shapes: {attn_split.shape=}, {lse_split.shape=}"
            )

        is_bmhk = attn_split.ndim == 5
        if is_bmhk:
            attn_split = attn_split.unsqueeze(3)
            lse_split = lse_split.unsqueeze(2)

        num_chunks, B, M, G, H, Kq = attn_split.shape
        num_chunks1, B1, G1, H1, M1 = lse_split.shape
        if B != B1 or G != G1 or H != H1 or num_chunks != num_chunks1 or M != M1:
            raise ValueError(
                f"Incompatible input shapes: {attn_split.shape=} {lse_split.shape=} "
                f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {num_chunks}/{num_chunks1}, {M}/{M1}"
            )

        attn_split = attn_split.permute(1, 3, 4, 0, 2, 5)
        lse_split = lse_split.permute(1, 2, 3, 0, 4)

        device = attn_split.device
        attn_dtype = attn_split.dtype
        lse_dtype = lse_split.dtype
    else:
        if attn_is_concat:
            attn_split = attn_split.unbind(0)  # type: ignore
        if lse_is_concat:
            lse_split = lse_split.unbind(0)  # type: ignore
        num_chunks = len(attn_split)
        if len(lse_split) != num_chunks:
            raise ValueError(
                f"Incompatible number of LSE and attention chunks: {len(attn_split)=}, {len(lse_split)=}"
            )

        attn_unsqueezed = []
        lse_unsqueezed = []
        is_bmhk = False
        for i in range(num_chunks):
            if attn_split[i].ndim != lse_split[i].ndim + 1:
                raise ValueError(
                    f"Incompatible input shapes for chunk {i}: {attn_split[i].shape=}, {lse_split[i].shape=}"
                )

            is_bmhk = attn_split[i].ndim == 4
            if is_bmhk:
                attn_unsqueezed.append(attn_split[i].unsqueeze(2))
                lse_unsqueezed.append(lse_split[i].unsqueeze(1))
            else:
                attn_unsqueezed.append(attn_split[i])
                lse_unsqueezed.append(lse_split[i])
        attn_split, lse_split = attn_unsqueezed, lse_unsqueezed

        B, M, G, H, Kq = attn_split[0].shape
        B1, G1, H1, M1 = lse_split[0].shape
        if B != B1 or G != G1 or H != H1 or M != M1:
            raise ValueError(
                f"Incompatible input shapes: {attn_split[0].shape=}, {lse_split[0].shape=} "
                f"{B}/{B1}, {G}/{G1}, {H}/{H1}, {M}/{M1}"
            )

        for i in range(num_chunks):
            if attn_split[i].shape != (B, M, G, H, Kq):
                raise ValueError(
                    f"Incompatible input shapes for attention chunk {i}: "
                    f"{attn_split[i].shape=}, {(B, M, G, H, Kq)=}"
                )
            if lse_split[i].shape != (B, G, H, M):
                raise ValueError(
                    f"Incompatible input shapes for LSE chunk {i}: "
                    f"{lse_split[i].shape=}, {(B, G, H, M)=}"
                )

            attn_split[i] = attn_split[i].permute(0, 2, 3, 1, 4)  # to (B, G, H, M, Kq)

        device = attn_split[0].device
        attn_dtype = attn_split[0].dtype
        lse_dtype = lse_split[0].dtype

    attn_out = torch.empty(
        B,
        M,
        G,
        H,
        Kq,
        device=device,
        dtype=output_dtype or attn_dtype,
        requires_grad=requires_grad,
    )
    if write_lse:
        lse_out = torch.empty(
            B, G, H, M, device=device, dtype=lse_dtype, requires_grad=requires_grad
        )
    else:
        lse_out = None

    if concat_path:
        triton_splitk.merge_attentions(attn_out, lse_out, attn_split, lse_split)  # type: ignore
    else:
        attn_out, lse_out = _MergeAttentions.apply(attn_out, lse_out, *attn_split, *lse_split)  # type: ignore

    if is_bmhk:
        attn_out = attn_out[:, :, 0]
        if lse_out is not None:
            lse_out = lse_out[:, 0]

    return attn_out, lse_out
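# Example (an illustrative sketch, assuming a CUDA device with f16 support):
# attention over K/V split into two chunks via `memory_efficient_attention_partial`,
# then recombined with `merge_attentions` using the formula in the docstring above.
#
#     q = torch.randn([2, 32, 8, 64], device="cuda", dtype=torch.float16)
#     k = torch.randn([2, 256, 8, 64], device="cuda", dtype=torch.float16)
#     v = torch.randn([2, 256, 8, 64], device="cuda", dtype=torch.float16)
#     out1, lse1 = memory_efficient_attention_partial(q, k[:, :128], v[:, :128])
#     out2, lse2 = memory_efficient_attention_partial(q, k[:, 128:], v[:, 128:])
#     out, _ = merge_attentions([out1, out2], [lse1, lse2], write_lse=False)
#     # `out` should match memory_efficient_attention(q, k, v) up to floating-point error.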


class _MergeAttentions(torch.autograd.Function):
    @staticmethod
    # type: ignore
    def forward(
        ctx, attn_out: torch.Tensor, lse_out: torch.Tensor, *inputs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        num_chunks = len(inputs) // 2
        attn_split, lse_split = inputs[:num_chunks], inputs[num_chunks:]

        triton_splitk.merge_attentions_varargs(attn_out, lse_out, attn_split, lse_split)

        ctx.save_for_backward(
            attn_out,
            lse_out,
            *inputs,
        )
        return attn_out, lse_out

    @staticmethod
    # type: ignore
    def backward(
        ctx, grad_attn: torch.Tensor, grad_lse: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], ...]:
        out, lse, *inputs = ctx.saved_tensors
        num_chunks = len(inputs) // 2
        attn_split, lse_split = inputs[:num_chunks], inputs[num_chunks:]
        dattn, dlse = triton_splitk.merge_attentions_varargs_backward(
            attn_split,
            lse_split,
            out,
            lse,
            grad_attn,
            grad_lse,
        )
        ret = [None, None] + dattn + dlse
        return tuple(ret)


ALL_FW_OPS: List[Type[AttentionFwOpBase]] = [
    cutlass.FwOp if torch.version.cuda else ck.FwOp,
    flash.FwOp,
    flash3.FwOp,
    triton_splitk.FwOp,
]

ALL_BW_OPS: List[Type[AttentionBwOpBase]] = [
    cutlass.BwOp if torch.version.cuda else ck.BwOp,
    flash.BwOp,
    flash3.BwOp,
]

__all__ = [
    "AttentionBias",
    "AttentionOp",
    "AttentionOpBase",
    "LowerTriangularMask",
    "MemoryEfficientAttentionCutlassFwdFlashBwOp",
    "MemoryEfficientAttentionCutlassOp",
    "MemoryEfficientAttentionFlashAttentionOp",
    "memory_efficient_attention",
    "MemoryEfficientAttentionCkOp",
    "MemoryEfficientAttentionCkDecoderOp",
    "ALL_FW_OPS",
    "ALL_BW_OPS",
    "attn_bias",
    "_get_use_fa3",
    "_set_use_fa3",
    "BlockDiagonalMask",
]
.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/__init__.cpython-311.pyc ADDED
    Binary file (36.5 kB)
.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/attn_bias.cpython-311.pyc ADDED
    Binary file (84.4 kB)
.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck.cpython-311.pyc ADDED
    Binary file (19.6 kB)
.venv/lib/python3.11/site-packages/xformers/ops/fmha/__pycache__/ck_decoder.cpython-311.pyc ADDED
    Binary file (6.87 kB)