# xformers/tests/test_attentions.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Tuple
import pytest
import torch

from xformers.components import (
    InputProjection,
    InputProjectionConfig,
    MultiHeadDispatch,
)

# Automatically test all the registered attentions
from xformers.components.attention import (
    _DENSITY_THRESHOLD,
    ATTENTION_REGISTRY,
    build_attention,
)

DEVICES = (
    [torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")]
)

BATCH = 2
SEQ = 128 if torch.cuda.is_available() else 36
MODEL = 128 if torch.cuda.is_available() else 16
GLOBAL_ATTENTION_RATIO = (
    _DENSITY_THRESHOLD * 0.9
)  # Make sure that we test the sparse implementation, no matter the threshold

assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered"

_non_order_invariant_attentions = ["visual", "pooling"]


def _get_multihead(
    attention_name,
    attn_dropout,
    res_dropout,
    causal,
    heads,
    device,
    skip_output_projection=False,
    use_separate_proj_weights=True,
):
    test_config = {
        "name": attention_name,
        "dropout": attn_dropout,
        "causal": causal,
        "seq_len": SEQ,
        "window_size": SEQ // 8 + 1,  # local attention
        "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
        "dim_model": MODEL,
        "num_heads": heads,
        "dim_head": MODEL // heads,
        "num_rules": 2,  # Compositional Attention
        "r": 0.5,  # random attention, ratio of tokens that the attention can attend to
    }

    if skip_output_projection:

        def noop(x):
            return x

        test_config["out_proj"] = noop

    # Add some blocksparse layout to test the corresponding attention
    block_size = 16
    test_config["layout"] = torch.eye(
        SEQ // block_size, SEQ // block_size, dtype=torch.long
    )
    test_config["block_size"] = block_size

    attention = build_attention(test_config)

    # build a multi head dispatch to test this attention mechanism
    multi_head = MultiHeadDispatch(
        seq_len=SEQ,
        dim_model=MODEL,
        residual_dropout=res_dropout,
        num_heads=heads,
        attention=attention,
        use_separate_proj_weight=use_separate_proj_weights,
    ).to(device)

    return multi_head
@pytest.mark.parametrize("attn_dropout", [0.0, 0.3])
@pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize(
"attention_name", ATTENTION_REGISTRY.keys() - _non_order_invariant_attentions
)
@pytest.mark.parametrize("device", DEVICES)
def test_order_invariance(
attention_name: str,
heads: int,
attn_dropout: float,
residual_dropout: float,
causal: bool,
device: torch.device,
):
if (
torch.version.hip
and device == torch.device("cuda")
and attention_name == "local"
):
# Backend calls into Sputnik library which isn't built on ROCm
device = torch.device("cpu")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
multi_head = _get_multihead(
attention_name,
attn_dropout,
residual_dropout,
causal,
heads,
device,
use_separate_proj_weights=False,
)
if (
int(math.sqrt(SEQ)) ** 2 != SEQ
and multi_head.attention.requires_squared_context
):
pytest.skip(f"{attention_name} requires squared sequence lengths")
# Check that we can pass a smaller sequence
seqs = (
[SEQ, SEQ // 2]
if not multi_head.attention.requires_same_k_q_dimensions
else [SEQ]
)
for seq in seqs:
# Check that the attention is invariant to a permutation of K, V
inputs = torch.rand(BATCH, seq, MODEL, device=device)
shuffle = torch.randperm(inputs.shape[1])
inputs_shuffled = inputs[:, shuffle, :].clone()
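
        # Softmax-style attention reduces over (key, value) pairs, so permuting K and V
        # together should not change the output (up to numerical noise and dropout).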
        results = multi_head(inputs, inputs, inputs)
        results_shuffled = multi_head(inputs, inputs_shuffled, inputs_shuffled)
        torch.allclose(results, results_shuffled)

        # Check that the attention is equivariant to a permutation of Q,
        # meaning that the result is permuted in the same way
        results_shuffled = multi_head(inputs_shuffled, inputs, inputs)
        torch.allclose(results[:, shuffle, :], results_shuffled)

        # Check that dropout actually drops some values
        if attn_dropout > 0:
            att_1 = multi_head(inputs, inputs_shuffled, inputs)
            att_2 = multi_head(inputs, inputs_shuffled, inputs)
            assert (att_1 != att_2).any()

        # Test AMP, if available
        if device.type == "cuda":
            with torch.cuda.amp.autocast(enabled=True):
                _ = multi_head(inputs, inputs_shuffled, inputs)
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ["scaled_dot_product"])
@pytest.mark.parametrize("device", DEVICES)
def test_kqv_ordering(
attention_name: str,
heads: int,
device: torch.device,
):
multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)
# Check kqv are not flipped
# this will not catch all issues, but would catch a V being misplaced
# make k and q complimentary, so that QKt is all zero and attention is uniform
q = torch.cat(
(
torch.rand((1, MODEL // 2), device=device),
torch.zeros((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ, MODEL))
k = torch.cat(
(
torch.zeros((1, MODEL // 2), device=device),
torch.rand((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ, MODEL))
v = torch.rand(BATCH, SEQ, MODEL, device=device)
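
    # With QK^T all zero the attention weights are uniform over the sequence, so every
    # query position should produce the same average of V within a given batch element.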

    # Normal call
    res = multi_head(query=q, key=k, value=v)
    for i in range(BATCH):
        assert torch.allclose(res[i, :, :], res[i, 0, :].unsqueeze(-2))

    assert not torch.allclose(res[0, :, :], res[1, :, :])

    # Flip qkv, and check that we invert the above check properly
    res_false = multi_head(query=v, key=k, value=q)
    assert torch.allclose(res_false[0, :, :], res_false[1, :, :])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ["scaled_dot_product"])
@pytest.mark.parametrize("device", DEVICES)
def test_different_seqlen(
attention_name: str,
heads: int,
device: torch.device,
):
multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)
# Check kqv are not flipped
# this will not catch all issues, but would catch a V being misplaced
# make k and q complimentary, so that QKt is all zero and attention is uniform
q = torch.cat(
(
torch.rand((1, MODEL // 2), device=device),
torch.zeros((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ, MODEL))
k = torch.cat(
(
torch.zeros((1, MODEL // 2), device=device),
torch.rand((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ, MODEL))
v = torch.rand(BATCH, SEQ, MODEL, device=device)
# Normal call
res = multi_head(query=q, key=k, value=v)
# Changing sequence length by dividing by two to simulate differing sequence length
q2 = torch.cat(
(
torch.rand((1, MODEL // 2), device=device),
torch.zeros((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ // 2, MODEL))
k2 = torch.cat(
(
torch.zeros((1, MODEL // 2), device=device),
torch.rand((1, MODEL // 2), device=device),
),
dim=1,
).expand((BATCH, SEQ // 2, MODEL))
v2 = torch.rand(BATCH, SEQ // 2, MODEL, device=device)
res2 = multi_head(query=q2, key=k2, value=v2)
assert res.shape != res2.shape
@pytest.mark.parametrize("proj_bias", [False, True])
@pytest.mark.parametrize("same_sizes", [False, True])
@pytest.mark.parametrize("same_settings", [False, True])
def test_inproj(proj_bias: bool, same_sizes: bool, same_settings: bool):
test_config = {
"name": "scaled_dot_product",
"dropout": 0.1,
"causal": False,
"seq_len": SEQ,
"window_size": SEQ // 8 + 1,
"num_heads": 1,
"dim_head": MODEL,
}
attention = build_attention(test_config)
# Construct the initial projection, test different options
in_params = InputProjectionConfig(MODEL, MODEL, proj_bias)
if same_settings:
in_proj = InputProjection(in_params, None, None)
out_features = MODEL
else:
out_features = MODEL if same_sizes else MODEL // 2
in_params_flip = InputProjectionConfig(MODEL, out_features, proj_bias)
in_proj = InputProjection(
in_params_flip, # Q proj
in_params_flip, # K proj
in_params, # V proj
)
# build a multi head dispatch to test this attention mechanism
multi_head = MultiHeadDispatch(
seq_len=SEQ,
dim_model=MODEL,
residual_dropout=0.1,
num_heads=1,
attention=attention,
in_proj_container=in_proj,
dim_key=out_features,
dim_value=MODEL,
)
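
    # dim_key / dim_value tell the dispatcher about the (possibly asymmetric) output
    # sizes of the InputProjection container used for the Q/K and V projections.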

    # Check that kqv are not flipped
    # this will not catch all issues, but would catch a V being misplaced
    # make k and q complementary, so that QKt is all zero and attention is uniform
    q = torch.cat(
        (
            torch.rand((1, MODEL // 2)),
            torch.zeros((1, MODEL // 2)),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    k = torch.cat(
        (
            torch.zeros((1, MODEL // 2)),
            torch.rand((1, MODEL // 2)),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    v = torch.rand(BATCH, SEQ, MODEL)

    # just check that a forward pass does not assert out
    _ = multi_head(query=q, key=k, value=v)
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("device", DEVICES)
def test_different_kq_dimensions(
attention_name: str,
heads: int,
device: torch.device,
):
multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)
if multi_head.attention.requires_same_k_q_dimensions:
# pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")
seq_q = SEQ // 2
q = torch.rand((BATCH, seq_q, MODEL), device=device)
k = torch.rand((BATCH, SEQ, MODEL), device=device)
v = torch.rand((BATCH, SEQ, MODEL), device=device)
res = multi_head(query=q, key=k, value=v)
assert res.shape == torch.Size([BATCH, seq_q, MODEL])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize(
"batch_sizes",
[
(1, BATCH, BATCH),
(BATCH, 1, BATCH),
(BATCH, BATCH, 1),
(1, 1, BATCH),
(BATCH, 1, 1),
(1, BATCH, 1),
],
)
def test_broadcast_batch_dimension(
attention_name: str,
heads: int,
device: torch.device,
batch_sizes: Tuple[int, int, int],
):
Q_BATCH, K_BATCH, V_BATCH = batch_sizes
multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)
if (
int(math.sqrt(SEQ)) ** 2 != SEQ
and multi_head.attention.requires_squared_context
):
pytest.skip(f"{attention_name} requires squared sequence lengths")
if multi_head.attention.requires_same_k_q_dimensions:
# pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")
q = torch.rand((Q_BATCH, SEQ, MODEL), device=device)
k = torch.rand((K_BATCH, SEQ, MODEL), device=device)
v = torch.rand((V_BATCH, SEQ, MODEL), device=device)
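
    # Batch dimensions of size 1 are expected to broadcast against the larger ones,
    # so the output should always come back with the full BATCH size.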
    res = multi_head(query=q, key=k, value=v)
    assert res.shape == torch.Size([BATCH, SEQ, MODEL])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ["scaled_dot_product", "favor"])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA gpu")
def test_causal(
attention_name: str,
heads: int,
):
"""
Make sure that the causal flag is respected.
The input data is orthogonal by design if causal is respected, but if the attention looks ahead this will fail
"""
torch.random.manual_seed(42)
device = torch.device("cuda")
multi_head = _get_multihead(
attention_name,
0.0,
0.0,
causal=True,
heads=heads,
device=device,
skip_output_projection=True,
)
k = (
torch.tril(torch.ones((SEQ, SEQ), device=device), diagonal=0)
.unsqueeze(0)
.expand(1, -1, -1)
)
q = (
torch.triu(torch.ones((SEQ, SEQ), device=device), diagonal=0)
.unsqueeze(0)
.expand(1, -1, -1)
)
v = (
torch.arange(SEQ, device=device)
.float()
.unsqueeze(0)
.unsqueeze(-1)
.expand(1, -1, SEQ)
)
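
    # q_i has ones at positions >= i and k_j at positions <= j, so the raw score q_i . k_j
    # is nonzero only when j >= i. With a causal mask only j == i survives, and since v_j
    # encodes the position index the per-position outputs should come out (roughly) sorted.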

    # Make sure that we don't project, to keep the embeddings orthogonal
    multi_head.attention.requires_input_projection = False

    res = multi_head(query=q, key=k, value=v).squeeze(0)

    # Consolidate along the embedding, if causal was respected the amplitude should be sorted already
    res_sum = torch.sum(res, dim=1).cpu()

    assert torch.allclose(torch.sort(res_sum)[1], torch.arange(SEQ)) or torch.allclose(
        torch.sort(res_sum, descending=True)[1], torch.arange(SEQ)
    ), res_sum
@pytest.mark.parametrize("attn_dropout", [0.0, 0.1])
@pytest.mark.parametrize("heads", [2])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.skipif(torch.cuda.is_available(), reason="CUDA gpu not supported yet")
def test_torch_script_ability(
attention_name: str,
heads: int,
attn_dropout: float,
):
if attention_name in {"favor", "global", "local", "random"}:
# pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
pytest.skip(f"{attention_name} does not support scripting yet.")
device = torch.device("cpu")
multi_head = _get_multihead(attention_name, attn_dropout, 0.0, False, heads, device)
if (
int(math.sqrt(SEQ)) ** 2 != SEQ
and multi_head.attention.requires_squared_context
):
pytest.skip(f"{attention_name} requires squared sequence lengths")
# input for tracing the function
q = torch.rand((BATCH, SEQ, MODEL), device=device)
k = torch.rand((BATCH, SEQ, MODEL), device=device)
v = torch.rand((BATCH, SEQ, MODEL), device=device)
# to make sure dropout behaves deterministically
torch.random.manual_seed(42)
# tracing the attention module
traced_multi_head = torch.jit.trace(multi_head, (q, k, v))
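
    # Tracing specializes the module to the example shapes and control flow; below we
    # only check that the traced module matches eager mode on fresh inputs when the
    # dropout RNG is seeded identically for both runs.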

    # create new random inputs for testing the eager model and traced model
    q = torch.rand((BATCH, SEQ, MODEL), device=device)
    k = torch.rand((BATCH, SEQ, MODEL), device=device)
    v = torch.rand((BATCH, SEQ, MODEL), device=device)

    # to make sure dropout behaves deterministically need to set the seed again
    torch.random.manual_seed(42)
    res = multi_head(query=q, key=k, value=v)

    # to make sure dropout behaves deterministically need to set the seed again
    torch.random.manual_seed(42)
    res_traced = traced_multi_head(query=q, key=k, value=v)

    assert torch.allclose(res, res_traced)


# TODO: way more unit tests..