|
from functools import partial |
|
|
|
import torch |
|
from torch.nn import functional as F |
|
import xformers.ops as xops |
|
|
|
|
|
def get_rectangular_causal_mask(shape, q_seq_len, k_seq_len, device, dtype): |
|
"""Create a rectangular causal mask. |
|
|
|
This is especially useful when query length < key length, and ensures that the attention tensor comes from a tensor |
|
that initially has dimensions that are a multiple of 8, as required by xformers. |
|
|
|
>>> get_rectangular_causal_mask((1, 1), 2, 2, "cpu", torch.float32) |
|
tensor([[[[0., -inf], |
|
[0., 0.]]]]) |
|
>>> get_rectangular_causal_mask((1, 1), 3, 5, "cpu", torch.float32) |
|
tensor([[[[0., 0., 0., -inf, -inf], |
|
[0., 0., 0., 0., -inf], |
|
[0., 0., 0., 0., 0.]]]]) |
|
>>> get_rectangular_causal_mask((1, 1), 5, 5, "cpu", torch.float32) |
|
tensor([[[[0., -inf, -inf, -inf, -inf], |
|
[0., 0., -inf, -inf, -inf], |
|
[0., 0., 0., -inf, -inf], |
|
[0., 0., 0., 0., -inf], |
|
[0., 0., 0., 0., 0.]]]]) |
|
""" |
|
|
|
next_multiple_8 = (k_seq_len + 7) // 8 * 8 |
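
    # Boolean mask of attendable positions: each query may see every "extra" key on the
    # left, plus a lower-triangular block over the last q_seq_len key positions.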
|
|
|
    mask = torch.ones((q_seq_len, k_seq_len), device=device, dtype=torch.bool)
|
mask[:, -q_seq_len:] = torch.tril(mask[:, -q_seq_len:], diagonal=0) |
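
    # Allocate the float bias padded out to the multiple of 8, fill disallowed positions
    # with the minimum representable value of dtype, then slice back to the true key length.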
|
|
|
output_mask = torch.zeros((*shape, q_seq_len, next_multiple_8), device=device, dtype=dtype) |
|
output_mask[:, :, :, :k_seq_len].masked_fill_(~mask, torch.finfo(dtype).min) |
|
return output_mask[:, :, :, :k_seq_len] |
|
|
|
|
|
def apply_attention_mask_(bias, attention_mask, queries_dtype): |
|
"""Applies attention mask (e.g., from HuggingFace generate) to an attention bias mask in-place. |
|
|
|
Args: |
|
bias (torch.Tensor, shape (batch_size, num_heads, q_seq_len, k_seq_len)) |
|
attention_mask (torch.Tensor, shape (batch_size, sequence_len)) |
|
queries_dtype: queries.dtype; used to get minimum value for masked indices. |
|
|
|
Returns: |
|
bias_with_mask (torch.Tensor, shape (batch_size, num_heads, q_seq_len, k_seq_len)) |
|
""" |
|
|
|
assert attention_mask.dim() == 2 |
|
|
|
mask_length = attention_mask.shape[-1] |
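
    # Positions the bias currently allows (== 0) that the attention_mask marks as padding
    # (== 0) are pushed down to the minimum value of the queries' dtype.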
|
|
|
|
|
padding_mask = bias[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) |
|
min_dtype = torch.finfo(queries_dtype).min |
|
bias[..., :mask_length] = bias[..., :mask_length].masked_fill(padding_mask, min_dtype) |
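
    # If masking leaves an entire row at the minimum value (e.g., a fully padded query row),
    # reset that row to zeros so a subsequent softmax over it does not produce NaNs.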
|
|
|
|
|
|
|
|
|
bias.mul_(~torch.all(bias == min_dtype, dim=-1, keepdim=True)) |
|
|
|
|
|
def xformers_attn(queries, keys, values, is_causal, attention_mask=None): |
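    """Compute attention with xformers' memory_efficient_attention.

    Tensors are laid out as (batch, seq_len, num_heads, head_dim). When is_causal is set and
    the query length differs from the key length (e.g., decoding with a KV cache), an explicit
    rectangular causal bias is built; an optional HuggingFace-style attention_mask is folded
    into that bias.
    """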
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bias = None |
|
if is_causal and queries.shape[1] == keys.shape[1] and attention_mask is None: |
|
bias = xops.LowerTriangularMask() |
|
elif is_causal and (queries.shape[1] > 1 or attention_mask is not None): |
|
|
|
batch, q_seq_len, heads, _ = queries.shape |
|
k_seq_len = keys.shape[1] |
|
bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype) |
|
if attention_mask is not None: |
|
apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype) |
|
elif not is_causal and attention_mask is not None: |
|
raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.") |
|
return xops.memory_efficient_attention(queries, keys, values, attn_bias=bias) |
|
|
|
|
|
def torch_attn(queries, keys, values, is_causal, attention_mask=None): |
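    """Compute attention with torch.nn.functional.scaled_dot_product_attention.

    Inputs use the same (batch, seq_len, num_heads, head_dim) layout as xformers_attn; they are
    transposed to (batch, num_heads, seq_len, head_dim) for SDPA and transposed back afterwards.
    """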
|
|
|
|
|
|
|
if is_causal and keys.shape[1] > queries.shape[1] > 1: |
|
q_seq_len = queries.shape[1] |
|
k_seq_len = keys.shape[1] |
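
        # Build an explicit rectangular causal mask so each query still sees the full
        # cached key prefix; SDPA's own is_causal flag is not used in this branch.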
|
|
|
|
|
mask = get_rectangular_causal_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype) |
|
if attention_mask is not None: |
|
apply_attention_mask_(mask, attention_mask, queries_dtype=queries.dtype) |
|
return ( |
|
F.scaled_dot_product_attention( |
|
queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask |
|
) |
|
.transpose(1, 2) |
|
.contiguous() |
|
) |
|
else: |
|
if attention_mask is None: |
|
bias = None |
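
            # A single query token may attend to every key, so causal masking is unnecessary.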
|
|
|
            if queries.shape[1] == 1:
|
is_causal = False |
|
else: |
|
if not is_causal: |
|
raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.") |
|
|
|
batch, q_seq_len, heads, _ = queries.shape |
|
k_seq_len = keys.shape[1] |
|
bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype) |
|
if attention_mask is not None: |
|
apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype) |
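
            # Causality is already encoded in the explicit bias above, so disable SDPA's
            # built-in causal masking rather than applying it twice.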
|
|
|
is_causal = False |
|
return ( |
|
F.scaled_dot_product_attention( |
|
queries.transpose(1, 2), |
|
keys.transpose(1, 2), |
|
values.transpose(1, 2), |
|
attn_mask=bias, |
|
is_causal=is_causal, |
|
) |
|
.transpose(1, 2) |
|
.contiguous() |
|
) |
|
|
|
|
|
ATTN_ACTIVATIONS = { |
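    # Point-wise activations applied to the (biased) attention scores in custom_attn.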
|
"relu": F.relu, |
|
"relu_squared": lambda x: torch.pow(F.relu(x), 2), |
|
|
|
"softplus": F.softplus, |
|
"identity": lambda x: x, |
|
"relu6": F.relu6, |
|
"sigmoid": F.sigmoid, |
|
"softmax": partial(F.softmax, dim=-1), |
|
} |
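
# Scalars as a function of the key sequence length; custom_attn raises the chosen scalar
# to the power -alpha and multiplies it into the activated attention weights.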
|
|
|
ATTN_SEQ_SCALARS = { |
|
"max": lambda x: x, |
|
|
|
"avg": lambda x: (x - 1) / 2 + 1, |
|
"none": lambda _: 1, |
|
} |
|
|
|
|
|
def custom_attn( |
|
queries, |
|
keys, |
|
values, |
|
attn_activation, |
|
attn_seq_scalar, |
|
alpha, |
|
is_causal=False, |
|
attention_mask=None, |
|
) -> torch.Tensor: |
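    """Attention with a configurable score activation and sequence-length-dependent rescaling.

    Scores are scaled by head_dim ** -0.5, an additive causal bias is applied when requested,
    the chosen activation from ATTN_ACTIVATIONS is applied, and the result is rescaled by
    ATTN_SEQ_SCALARS[attn_seq_scalar](k_seq_len) ** -alpha before aggregating the values.
    """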
|
|
|
|
|
if attention_mask is not None: |
|
raise NotImplementedError("attention_mask not yet implemented for custom_attn.") |
|
|
|
batch, q_seq_len, heads, embed_dim = queries.shape |
|
_, k_seq_len, _, _ = keys.shape |
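
    # Start from an all-zero additive bias; swap in a rectangular causal bias when causal
    # masking is requested for more than one query position.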
|
|
|
attn_bias = torch.zeros(batch, heads, q_seq_len, k_seq_len, device=queries.device, dtype=queries.dtype) |
|
if is_causal and queries.shape[1] > 1: |
|
attn_bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype) |
|
|
|
inner_scale = embed_dim**-0.5 |
|
attn_weight = torch.einsum("bqhd,bkhd->bhqk", inner_scale * queries, keys) |
|
attn_weight += attn_bias |
|
|
|
|
|
    outer_scale = ATTN_SEQ_SCALARS[attn_seq_scalar](k_seq_len) ** -alpha
|
    attn_weight = outer_scale * ATTN_ACTIVATIONS[attn_activation](attn_weight)
|
|
|
return torch.einsum("bhqk,bkhd->bqhd", attn_weight, values) |
|
|
|
|
|
def get_attn_func( |
|
attn_name, |
|
attn_activation=None, |
|
attn_seq_scalar=None, |
|
alpha=None, |
|
): |
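    """Return the attention function selected by attn_name, configuring custom_attn if needed."""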
|
if attn_name == "auto": |
|
return xformers_attn if torch.cuda.is_available() else torch_attn |
|
elif attn_name == "xformers_attn": |
|
return xformers_attn |
|
elif attn_name == "xformers_attn_variable_length": |
|
|
|
|
|
|
|
|
|
return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous() |
|
elif attn_name == "torch_attn": |
|
return torch_attn |
|
elif attn_name == "custom_attn": |
|
assert ( |
|
attn_activation is not None and attn_seq_scalar is not None and alpha is not None |
|
), "must provide attn-activation, attn-seq-scalar, attn-seq-scalar-alpha" |
|
        return partial(
            custom_attn,
            attn_activation=attn_activation,
            attn_seq_scalar=attn_seq_scalar,
            alpha=alpha,
        )
|
else: |
|
raise ValueError(f"Unsupported attn-name: {attn_name}") |
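

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API): pick an attention
    # function by name and run it on arbitrary random inputs laid out as
    # (batch, seq_len, num_heads, head_dim).
    attn_fn = get_attn_func("torch_attn")
    q = torch.randn(2, 5, 4, 32)
    k = torch.randn(2, 5, 4, 32)
    v = torch.randn(2, 5, 4, 32)
    out = attn_fn(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([2, 5, 4, 32])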
|
|