# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Tuple

import pytest
import torch

from xformers.components import (
    InputProjection,
    InputProjectionConfig,
    MultiHeadDispatch,
)

# Automatically test all the registered attentions
from xformers.components.attention import (
    _DENSITY_THRESHOLD,
    ATTENTION_REGISTRY,
    build_attention,
)

DEVICES = (
    [torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")]
)

BATCH = 2
SEQ = 128 if torch.cuda.is_available() else 36
MODEL = 128 if torch.cuda.is_available() else 16
GLOBAL_ATTENTION_RATIO = (
    _DENSITY_THRESHOLD * 0.9
)  # Make sure that we test the sparse implementation, no matter the threshold

assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered"

_non_order_invariant_attentions = ["visual", "pooling"]

def _get_multihead(
    attention_name,
    attn_dropout,
    res_dropout,
    causal,
    heads,
    device,
    skip_output_projection=False,
    use_separate_proj_weights=True,
):
    test_config = {
        "name": attention_name,
        "dropout": attn_dropout,
        "causal": causal,
        "seq_len": SEQ,
        "window_size": SEQ // 8 + 1,  # local attention
        "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
        "dim_model": MODEL,
        "num_heads": heads,
        "dim_head": MODEL // heads,  # integer division: per-head dimension should be an int
        "num_rules": 2,  # Compositional Attention
        "r": 0.5,  # random attention, ratio of tokens that the attention can attend to
    }
    if skip_output_projection:

        def noop(x):
            return x

        test_config["out_proj"] = noop

    # Add some blocksparse layout to test the corresponding attention
    block_size = 16
    test_config["layout"] = torch.eye(
        SEQ // block_size, SEQ // block_size, dtype=torch.long
    )
    test_config["block_size"] = block_size

    attention = build_attention(test_config)

    # build a multi head dispatch to test this attention mechanism
    multi_head = MultiHeadDispatch(
        seq_len=SEQ,
        dim_model=MODEL,
        residual_dropout=res_dropout,
        num_heads=heads,
        attention=attention,
        use_separate_proj_weight=use_separate_proj_weights,
    ).to(device)

    return multi_head
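

# NOTE: the parametrize decorators were reconstructed from the function signature;
# the value grids are representative assumptions, and the exclusion of
# _non_order_invariant_attentions is inferred from the (otherwise unused) list above.
@pytest.mark.parametrize("attn_dropout", [0.0, 0.3])
@pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize(
    "attention_name", ATTENTION_REGISTRY.keys() - _non_order_invariant_attentions
)
@pytest.mark.parametrize("device", DEVICES)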
def test_order_invariance(
    attention_name: str,
    heads: int,
    attn_dropout: float,
    residual_dropout: float,
    causal: bool,
    device: torch.device,
):
    if (
        torch.version.hip
        and device == torch.device("cuda")
        and attention_name == "local"
    ):
        # Backend calls into Sputnik library which isn't built on ROCm
        device = torch.device("cpu")

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    multi_head = _get_multihead(
        attention_name,
        attn_dropout,
        residual_dropout,
        causal,
        heads,
        device,
        use_separate_proj_weights=False,
    )

    if (
        int(math.sqrt(SEQ)) ** 2 != SEQ
        and multi_head.attention.requires_squared_context
    ):
        pytest.skip(f"{attention_name} requires squared sequence lengths")

    # Check that we can pass a smaller sequence
    seqs = (
        [SEQ, SEQ // 2]
        if not multi_head.attention.requires_same_k_q_dimensions
        else [SEQ]
    )
    for seq in seqs:
        # Check that the attention is invariant to a permutation of K, V
        inputs = torch.rand(BATCH, seq, MODEL, device=device)
        shuffle = torch.randperm(inputs.shape[1])
        inputs_shuffled = inputs[:, shuffle, :].clone()

        results = multi_head(inputs, inputs, inputs)
        results_shuffled = multi_head(inputs, inputs_shuffled, inputs_shuffled)
        torch.allclose(results, results_shuffled)

        # Check that the attention is equivariant to a permutation of Q,
        # meaning that the result is permuted in the same way
        results_shuffled = multi_head(inputs_shuffled, inputs, inputs)
        torch.allclose(results[:, shuffle, :], results_shuffled)

        # Check that dropout actually drops some values
        if attn_dropout > 0:
            att_1 = multi_head(inputs, inputs_shuffled, inputs)
            att_2 = multi_head(inputs, inputs_shuffled, inputs)
            assert (att_1 != att_2).any()

        # Test AMP, if available
        if device.type == "cuda":
            with torch.cuda.amp.autocast(enabled=True):
                _ = multi_head(inputs, inputs_shuffled, inputs)
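

# NOTE: parametrize decorators reconstructed from the test signature;
# the value choices are representative assumptions.
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("device", DEVICES)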
def test_kqv_ordering(
    attention_name: str,
    heads: int,
    device: torch.device,
):
    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

    # Check kqv are not flipped
    # this will not catch all issues, but would catch a V being misplaced
    # make k and q complementary, so that QKt is all zero and attention is uniform
    q = torch.cat(
        (
            torch.rand((1, MODEL // 2), device=device),
            torch.zeros((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    k = torch.cat(
        (
            torch.zeros((1, MODEL // 2), device=device),
            torch.rand((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    v = torch.rand(BATCH, SEQ, MODEL, device=device)

    # Normal call
    res = multi_head(query=q, key=k, value=v)
    for i in range(BATCH):
        assert torch.allclose(res[i, :, :], res[i, 0, :].unsqueeze(-2))

    assert not torch.allclose(res[0, :, :], res[1, :, :])

    # Flip qkv, and check that we invert the above check properly
    res_false = multi_head(query=v, key=k, value=q)
    assert torch.allclose(res_false[0, :, :], res_false[1, :, :])
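

# NOTE: parametrize decorators reconstructed from the test signature;
# the value choices are representative assumptions.
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("device", DEVICES)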
def test_different_seqlen(
    attention_name: str,
    heads: int,
    device: torch.device,
):
    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

    # Check kqv are not flipped
    # this will not catch all issues, but would catch a V being misplaced
    # make k and q complementary, so that QKt is all zero and attention is uniform
    q = torch.cat(
        (
            torch.rand((1, MODEL // 2), device=device),
            torch.zeros((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    k = torch.cat(
        (
            torch.zeros((1, MODEL // 2), device=device),
            torch.rand((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    v = torch.rand(BATCH, SEQ, MODEL, device=device)

    # Normal call
    res = multi_head(query=q, key=k, value=v)

    # Changing sequence length by dividing by two to simulate differing sequence length
    q2 = torch.cat(
        (
            torch.rand((1, MODEL // 2), device=device),
            torch.zeros((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ // 2, MODEL))
    k2 = torch.cat(
        (
            torch.zeros((1, MODEL // 2), device=device),
            torch.rand((1, MODEL // 2), device=device),
        ),
        dim=1,
    ).expand((BATCH, SEQ // 2, MODEL))
    v2 = torch.rand(BATCH, SEQ // 2, MODEL, device=device)

    res2 = multi_head(query=q2, key=k2, value=v2)
    assert res.shape != res2.shape
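

# NOTE: parametrize decorators reconstructed from the test signature;
# parametrizing over both boolean values for each option is an assumption.
@pytest.mark.parametrize("proj_bias", [False, True])
@pytest.mark.parametrize("same_sizes", [False, True])
@pytest.mark.parametrize("same_settings", [False, True])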
def test_inproj(proj_bias: bool, same_sizes: bool, same_settings: bool):
    test_config = {
        "name": "scaled_dot_product",
        "dropout": 0.1,
        "causal": False,
        "seq_len": SEQ,
        "window_size": SEQ // 8 + 1,
        "num_heads": 1,
        "dim_head": MODEL,
    }

    attention = build_attention(test_config)

    # Construct the initial projection, test different options
    in_params = InputProjectionConfig(MODEL, MODEL, proj_bias)

    if same_settings:
        in_proj = InputProjection(in_params, None, None)
        out_features = MODEL
    else:
        out_features = MODEL if same_sizes else MODEL // 2
        in_params_flip = InputProjectionConfig(MODEL, out_features, proj_bias)
        in_proj = InputProjection(
            in_params_flip,  # Q proj
            in_params_flip,  # K proj
            in_params,  # V proj
        )

    # build a multi head dispatch to test this attention mechanism
    multi_head = MultiHeadDispatch(
        seq_len=SEQ,
        dim_model=MODEL,
        residual_dropout=0.1,
        num_heads=1,
        attention=attention,
        in_proj_container=in_proj,
        dim_key=out_features,
        dim_value=MODEL,
    )
    # Check kqv are not flipped
    # this will not catch all issues, but would catch a V being misplaced
    # make k and q complementary, so that QKt is all zero and attention is uniform
    q = torch.cat(
        (
            torch.rand((1, MODEL // 2)),
            torch.zeros((1, MODEL // 2)),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    k = torch.cat(
        (
            torch.zeros((1, MODEL // 2)),
            torch.rand((1, MODEL // 2)),
        ),
        dim=1,
    ).expand((BATCH, SEQ, MODEL))
    v = torch.rand(BATCH, SEQ, MODEL)

    # just check that a FW does not assert out
    _ = multi_head(query=q, key=k, value=v)
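

# NOTE: parametrize decorators reconstructed from the test signature;
# the value choices are representative assumptions.
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("device", DEVICES)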
def test_different_kq_dimensions(
    attention_name: str,
    heads: int,
    device: torch.device,
):
    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

    if multi_head.attention.requires_same_k_q_dimensions:
        # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
        pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")

    seq_q = SEQ // 2
    q = torch.rand((BATCH, seq_q, MODEL), device=device)
    k = torch.rand((BATCH, SEQ, MODEL), device=device)
    v = torch.rand((BATCH, SEQ, MODEL), device=device)

    res = multi_head(query=q, key=k, value=v)
    assert res.shape == torch.Size([BATCH, seq_q, MODEL])
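

# NOTE: parametrize decorators reconstructed from the test signature; the batch-size
# combinations are assumptions chosen so that the broadcast result always has BATCH rows,
# which is what the assertion at the end of the test expects.
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize(
    "batch_sizes", [(1, BATCH, BATCH), (BATCH, 1, BATCH), (BATCH, BATCH, 1)]
)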
def test_broadcast_batch_dimension(
    attention_name: str,
    heads: int,
    device: torch.device,
    batch_sizes: Tuple[int, int, int],
):
    Q_BATCH, K_BATCH, V_BATCH = batch_sizes
    multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

    if (
        int(math.sqrt(SEQ)) ** 2 != SEQ
        and multi_head.attention.requires_squared_context
    ):
        pytest.skip(f"{attention_name} requires squared sequence lengths")

    if multi_head.attention.requires_same_k_q_dimensions:
        # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
        pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")

    q = torch.rand((Q_BATCH, SEQ, MODEL), device=device)
    k = torch.rand((K_BATCH, SEQ, MODEL), device=device)
    v = torch.rand((V_BATCH, SEQ, MODEL), device=device)

    res = multi_head(query=q, key=k, value=v)
    assert res.shape == torch.Size([BATCH, SEQ, MODEL])
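

# NOTE: parametrize decorators reconstructed from the test signature; the test
# hard-codes a CUDA device below, so it is skipped when no GPU is available
# (skip condition is an assumption based on that device choice).
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())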
def test_causal(
    attention_name: str,
    heads: int,
):
    """
    Make sure that the causal flag is respected.
    The input data is orthogonal by design if causal is respected, but if the attention looks ahead this will fail.
    """
    torch.random.manual_seed(42)
    device = torch.device("cuda")

    multi_head = _get_multihead(
        attention_name,
        0.0,
        0.0,
        causal=True,
        heads=heads,
        device=device,
        skip_output_projection=True,
    )

    k = (
        torch.tril(torch.ones((SEQ, SEQ), device=device), diagonal=0)
        .unsqueeze(0)
        .expand(1, -1, -1)
    )
    q = (
        torch.triu(torch.ones((SEQ, SEQ), device=device), diagonal=0)
        .unsqueeze(0)
        .expand(1, -1, -1)
    )
    v = (
        torch.arange(SEQ, device=device)
        .float()
        .unsqueeze(0)
        .unsqueeze(-1)
        .expand(1, -1, SEQ)
    )

    # Make sure that we don't project, to keep the embeddings orthogonal
    multi_head.attention.requires_input_projection = False

    res = multi_head(query=q, key=k, value=v).squeeze(0)

    # Consolidate along the embedding, if causal was respected the amplitude should be sorted already
    res_sum = torch.sum(res, dim=1).cpu()

    assert torch.allclose(torch.sort(res_sum)[1], torch.arange(SEQ)) or torch.allclose(
        torch.sort(res_sum, descending=True)[1], torch.arange(SEQ)
    ), res_sum
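

# NOTE: parametrize decorators reconstructed from the test signature;
# the value choices are representative assumptions.
@pytest.mark.parametrize("attn_dropout", [0.0, 0.3])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())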
def test_torch_script_ability(
    attention_name: str,
    heads: int,
    attn_dropout: float,
):
    if attention_name in {"favor", "global", "local", "random"}:
        # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
        pytest.skip(f"{attention_name} does not support scripting yet.")

    device = torch.device("cpu")

    multi_head = _get_multihead(attention_name, attn_dropout, 0.0, False, heads, device)

    if (
        int(math.sqrt(SEQ)) ** 2 != SEQ
        and multi_head.attention.requires_squared_context
    ):
        pytest.skip(f"{attention_name} requires squared sequence lengths")

    # input for tracing the function
    q = torch.rand((BATCH, SEQ, MODEL), device=device)
    k = torch.rand((BATCH, SEQ, MODEL), device=device)
    v = torch.rand((BATCH, SEQ, MODEL), device=device)

    # to make sure dropout behaves deterministically
    torch.random.manual_seed(42)
    # tracing the attention module
    traced_multi_head = torch.jit.trace(multi_head, (q, k, v))

    # create new random inputs for testing the eager model and traced model
    q = torch.rand((BATCH, SEQ, MODEL), device=device)
    k = torch.rand((BATCH, SEQ, MODEL), device=device)
    v = torch.rand((BATCH, SEQ, MODEL), device=device)

    # to make sure dropout behaves deterministically need to set the seed again
    torch.random.manual_seed(42)
    res = multi_head(query=q, key=k, value=v)

    # to make sure dropout behaves deterministically need to set the seed again
    torch.random.manual_seed(42)
    res_traced = traced_multi_head(query=q, key=k, value=v)

    assert torch.allclose(res, res_traced)


# TODO: way more unit tests..