# olm-chat-7b / open_lm / attention.py
# Uploaded by henhenhahi111112 using huggingface_hub (commit af6e330, verified; 10.3 kB).
from functools import partial
import torch
from torch.nn import functional as F
import xformers.ops as xops
def get_rectangular_causal_mask(shape, q_seq_len, k_seq_len, device, dtype):
    """Create a bottom-right-aligned causal additive mask of shape ``(*shape, q_seq_len, k_seq_len)``.

    The queries are assumed to be the last ``q_seq_len`` positions of the key sequence, so (for
    ``q_seq_len <= k_seq_len``) query ``i`` may attend to keys ``0 .. k_seq_len - q_seq_len + i``.
    Allowed positions hold ``0``; blocked positions hold ``torch.finfo(dtype).min`` (a large finite
    negative value, not ``-inf``).

    This is especially useful when query length < key length. The returned tensor is a view into a
    buffer whose last dimension is padded up to a multiple of 8, as required by xformers for
    attention-bias tensors.

    NOTE(review): when ``q_seq_len > k_seq_len`` the triangle ends up aligned to the top-left
    instead (row ``i`` sees keys ``0 .. i``); confirm whether callers ever hit that case.

    Args:
        shape: leading dimensions of the mask, e.g. ``(batch, heads)``; any length is accepted.
        q_seq_len: number of queries.
        k_seq_len: number of keys.
        device: device for the returned tensor.
        dtype: floating dtype of the returned tensor.

    Returns:
        Tensor of shape ``(*shape, q_seq_len, k_seq_len)``.

    >>> m = get_rectangular_causal_mask((1, 1), 3, 5, "cpu", torch.float32)
    >>> (m == 0).int().tolist()
    [[[[1, 1, 1, 0, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]]]
    >>> bool((m[m != 0] == torch.finfo(torch.float32).min).all())
    True
    >>> m.stride(-2)
    8
    """
    # xformers requires the bias to be carved out of a buffer whose last dim is a multiple of 8.
    next_multiple_8 = (k_seq_len + 7) // 8 * 8
    allowed = torch.ones((q_seq_len, k_seq_len), device=device, dtype=torch.bool)
    # Only the trailing q_seq_len key columns need the triangular structure; earlier keys are
    # always visible to every query.
    allowed[:, -q_seq_len:] = torch.tril(allowed[:, -q_seq_len:], diagonal=0)
    padded = torch.zeros((*shape, q_seq_len, next_multiple_8), device=device, dtype=dtype)
    # `...` keeps this valid for any number of leading dimensions in `shape`.
    mask = padded[..., :k_seq_len]
    mask.masked_fill_(~allowed, torch.finfo(dtype).min)
    return mask
def apply_attention_mask_(bias, attention_mask, queries_dtype):
    """Apply a padding attention mask (e.g., from HuggingFace generate) to an additive bias, in place.

    Entries of ``bias`` that currently allow attention (value ``0``) are set to
    ``torch.finfo(queries_dtype).min`` wherever ``attention_mask`` is ``0`` (do not attend). After
    that, every query row whose entries are *all* equal to that minimum is reset to zero (fully
    unmasked). Such rows are not used downstream, and leaving them fully masked leads to NaNs.

    Args:
        bias (torch.Tensor, shape (batch_size, num_heads, q_seq_len, k_seq_len)): additive attention
            bias, modified in place. The padded result is written back into ``bias``, so its leading
            dimensions must already match the broadcast with ``attention_mask``. A size-1 batch
            dimension cannot be expanded here.
        attention_mask (torch.Tensor, shape (batch_size, sequence_len)): non-zero = attend,
            zero = padding.
        queries_dtype: queries.dtype; used to get the minimum value for masked indices.

    Returns:
        bias_with_mask (torch.Tensor): ``bias`` itself (same object), after in-place modification.

    Raises:
        AssertionError: if ``attention_mask`` is not 2-D.
    """
    assert attention_mask.dim() == 2
    # From https://github.com/huggingface/transformers/blob/f738ab3b5d30e30c43a4c3d00ca8939f8a4d4427/src/transformers/models/llama/modeling_llama.py#L1089C1-L1091C117
    mask_length = attention_mask.shape[-1]
    min_dtype = torch.finfo(queries_dtype).min
    # Block positions the bias allows (== 0) _and_ attention_mask marks as padding (== 0).
    allowed_by_bias = bias[..., :mask_length].eq(0.0)
    is_padding = attention_mask[:, None, None, :].eq(0.0)
    padding_mask = allowed_by_bias & is_padding
    bias[..., :mask_length] = bias[..., :mask_length].masked_fill(padding_mask, min_dtype)
    # Disable masking for query rows where all attention weights are min_dtype: we won't use them
    # anyway, and keeping them fully masked leads to NaNs. See
    # https://github.com/huggingface/transformers/blob/f738ab3b5d30e30c43a4c3d00ca8939f8a4d4427/src/transformers/modeling_attn_mask_utils.py#L189
    bias.mul_(~torch.all(bias == min_dtype, dim=-1, keepdim=True))
    return bias
def xformers_attn(queries, keys, values, is_causal, attention_mask=None):
    """Attention via ``xformers.ops.memory_efficient_attention``.

    xformers expects queries, keys and values as ``[batch, seq_len, heads, embed_dim]``. For causal
    attention the queries are assumed to be the last part of the key/value sequence.

    Args:
        queries: ``[batch, q_seq_len, heads, embed_dim]``.
        keys: ``[batch, k_seq_len, heads, embed_dim]``.
        values: ``[batch, k_seq_len, heads, embed_dim]``.
        is_causal: whether to apply causal masking.
        attention_mask: optional ``[batch, k_seq_len]`` padding mask (``0`` = do not attend); only
            supported with ``is_causal=True``.

    Raises:
        NotImplementedError: if ``attention_mask`` is given while ``is_causal`` is False.
    """
    # Ideally the rectangular mask would be xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
    # (see https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask),
    # but that needs xformers>=0.0.23, which is incompatible with torch<2.1.1 as required by llm-foundry.
    num_queries = queries.shape[1]
    num_keys = keys.shape[1]
    has_padding = attention_mask is not None

    if not is_causal:
        if has_padding:
            raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.")
        attn_bias = None
    elif num_queries == num_keys and not has_padding:
        attn_bias = xops.LowerTriangularMask()
    elif num_queries > 1 or has_padding:
        # Causal mask that assumes the queries sit at the end of the sequence.
        batch, _, heads, _ = queries.shape
        attn_bias = get_rectangular_causal_mask((batch, heads), num_queries, num_keys, queries.device, queries.dtype)
        if has_padding:
            apply_attention_mask_(attn_bias, attention_mask, queries_dtype=queries.dtype)
    else:
        # A single query (shape [batch, 1, heads, dim]) has no notion of causal masking: it may attend
        # to every key, so no bias is used. This is needed for the desired behavior with seq_len=1.
        attn_bias = None

    return xops.memory_efficient_attention(queries, keys, values, attn_bias=attn_bias)
def torch_attn(queries, keys, values, is_causal, attention_mask=None):
    """Attention via ``torch.nn.functional.scaled_dot_product_attention``.

    Tensors use the ``[batch, seq_len, heads, head_dim]`` layout (as in ``xformers_attn``). They are
    transposed to PyTorch's ``[batch, heads, seq_len, head_dim]`` layout for the call, and the result
    is transposed back.

    Causal masking assumes the queries are the *last* ``q_seq_len`` positions of the key/value
    sequence, so the causal diagonal is aligned to the bottom-right of the score matrix. A single
    query therefore attends to every (non-padded) key.

    Args:
        queries: ``[batch, q_seq_len, heads, head_dim]``.
        keys: ``[batch, k_seq_len, heads, head_dim]``.
        values: ``[batch, k_seq_len, heads, head_dim]``.
        is_causal: whether to apply bottom-right-aligned causal masking.
        attention_mask: optional ``[batch, k_seq_len]`` padding mask where ``0`` marks positions that
            must not be attended to. Only supported together with ``is_causal=True``.

    Returns:
        Contiguous tensor of shape ``[batch, q_seq_len, heads, head_dim]``.

    Raises:
        NotImplementedError: if ``attention_mask`` is given while ``is_causal`` is False.
    """
    q_seq_len = queries.shape[1]
    k_seq_len = keys.shape[1]
    attn_mask = None

    if attention_mask is not None:
        if not is_causal:
            raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.")
        # Padding differs per batch element, so the bias needs a real batch dimension.
        # apply_attention_mask_ writes into the bias in place and cannot expand a size-1 batch dim.
        batch, _, heads, _ = queries.shape
        attn_mask = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
        apply_attention_mask_(attn_mask, attention_mask, queries_dtype=queries.dtype)
        # The causal structure is already encoded in attn_mask.
        is_causal = False
    elif is_causal and k_seq_len > q_seq_len > 1:
        # SDPA's is_causal=True aligns the triangle to the top-left, which is wrong when the queries
        # are the tail of a longer key sequence. Build a bottom-right-aligned mask instead, broadcast
        # over batch and heads. Same as xformers_attn, we would like to use:
        # xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device)
        attn_mask = get_rectangular_causal_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype)
        is_causal = False
    elif q_seq_len == 1:
        # A single query may attend to all keys. With is_causal=True, SDPA's top-left-aligned mask
        # would restrict it to key 0.
        is_causal = False

    output = F.scaled_dot_product_attention(
        queries.transpose(1, 2),
        keys.transpose(1, 2),
        values.transpose(1, 2),
        attn_mask=attn_mask,
        is_causal=is_causal,
    )
    # Need to call contiguous in torch >= 2.1, otherwise later calls to .view() fail.
    # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of
    # scaled_dot_product_attention changed between 2.0 and 2.1.
    return output.transpose(1, 2).contiguous()
def _relu_squared(x):
    # Squared ReLU: relu(x) ** 2.
    return torch.pow(F.relu(x), 2)


def _identity(x):
    # Leave the attention scores unchanged.
    return x


# Activations that custom_attn can apply to the attention scores, keyed by `attn_activation` name.
# NOTE: F.gelu is deliberately not offered; per the original note it went to NaN together with the
# attention bias.
ATTN_ACTIVATIONS = dict(
    relu=F.relu,
    relu_squared=_relu_squared,
    softplus=F.softplus,
    identity=_identity,
    relu6=F.relu6,
    sigmoid=F.sigmoid,
    softmax=partial(F.softmax, dim=-1),
)
# Maps an `attn_seq_scalar` name to L(seq_len); custom_attn scales its weights by L ** -alpha.
# "max" uses the full key length, "avg" the mean 1-based position, "none" disables the scaling.
# NOTE: a per-position "seq" variant (torch.arange(seq_len) + 1) is not offered yet; it is more involved.
ATTN_SEQ_SCALARS = dict(
    max=lambda seq_len: seq_len,
    avg=lambda seq_len: (seq_len - 1) / 2 + 1,
    none=lambda seq_len: 1,
)
def custom_attn(
    queries,
    keys,
    values,
    attn_activation,
    attn_seq_scalar,
    alpha,
    is_causal=False,
    attention_mask=None,
) -> torch.Tensor:
    """Naive reference attention with a configurable activation in place of softmax.

    Follows https://arxiv.org/pdf/2309.08586.pdf (ReLU attention) and adapts the reference code in
    https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html.
    The scores are ``(q / sqrt(head_dim)) . k``. Then ``ATTN_ACTIVATIONS[attn_activation]`` is
    applied, and the weights are multiplied by ``L ** -alpha`` with
    ``L = ATTN_SEQ_SCALARS[attn_seq_scalar](k_seq_len)``.

    Args:
        queries: ``[batch, q_seq_len, heads, head_dim]``.
        keys: ``[batch, k_seq_len, heads, head_dim]``.
        values: ``[batch, k_seq_len, heads, head_dim]``.
        attn_activation: key into ``ATTN_ACTIVATIONS``.
        attn_seq_scalar: key into ``ATTN_SEQ_SCALARS``.
        alpha: exponent of the sequence-length scaling.
        is_causal: apply a causal mask that assumes the queries are the last positions of the key
            sequence. This is skipped for a single query, which may attend to every key.
        attention_mask: not supported; must be None.

    Returns:
        Tensor of shape ``[batch, q_seq_len, heads, head_dim]``.

    Raises:
        NotImplementedError: if ``attention_mask`` is not None.
        KeyError: for an unknown ``attn_activation`` or ``attn_seq_scalar``.
    """
    if attention_mask is not None:
        raise NotImplementedError("attention_mask not yet implemented for custom_attn.")
    batch, q_seq_len, heads, embed_dim = queries.shape
    _, k_seq_len, _, _ = keys.shape

    inner_scale = embed_dim**-0.5
    attn_weight = torch.einsum("bqhd,bkhd->bhqk", inner_scale * queries, keys)

    blocked = None
    if is_causal and q_seq_len > 1:
        causal_bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
        # The additive bias is needed for softmax, whose normalization must ignore masked keys.
        attn_weight += causal_bias
        blocked = causal_bias != 0

    # Scale by L^{-alpha}, where L is derived from the key length.
    outer_scale = ATTN_SEQ_SCALARS[attn_seq_scalar](k_seq_len) ** -alpha
    attn_weight = outer_scale * ATTN_ACTIVATIONS[attn_activation](attn_weight)
    if blocked is not None:
        # The mask value is finite (finfo.min), so activations that do not send large negatives to 0
        # (e.g. "identity") would leak it into the output. Zero masked weights explicitly. For the
        # other activations these entries are already 0, so this is a no-op.
        attn_weight = attn_weight.masked_fill(blocked, 0.0)
    return torch.einsum("bhqk,bkhd->bqhd", attn_weight, values)
def get_attn_func(
    attn_name,
    attn_activation=None,
    attn_seq_scalar=None,
    alpha=None,
):
    """Return the attention implementation selected by ``attn_name``.

    The returned callable is invoked as ``fn(queries, keys, values, is_causal, attention_mask=None)``.

    Args:
        attn_name: one of ``"auto"`` (``xformers_attn`` when CUDA is available, else
            ``torch_attn``), ``"xformers_attn"``, ``"xformers_attn_variable_length"``,
            ``"torch_attn"`` or ``"custom_attn"``.
        attn_activation: ``custom_attn`` activation name; required for ``"custom_attn"``.
        attn_seq_scalar: ``custom_attn`` sequence scalar name; required for ``"custom_attn"``.
        alpha: ``custom_attn`` scaling exponent; required for ``"custom_attn"``.

    Raises:
        AssertionError: ``"custom_attn"`` requested without all three custom parameters.
        ValueError: unknown ``attn_name``.
    """
    if attn_name == "auto":
        return xformers_attn if torch.cuda.is_available() else torch_attn
    if attn_name == "xformers_attn":
        return xformers_attn
    if attn_name == "xformers_attn_variable_length":
        # Upon changing the input sequence length, xformers attention changes the stride dimension
        # of the output tensor. This makes future calls to .view() that collapse the last two
        # dimensions fail, so .contiguous() is called on the output tensor. [#188]
        def _xformers_attn_contiguous(*args, **kwargs):
            return xformers_attn(*args, **kwargs).contiguous()

        return _xformers_attn_contiguous
    if attn_name == "torch_attn":
        return torch_attn
    if attn_name == "custom_attn":
        assert (
            attn_activation is not None and attn_seq_scalar is not None and alpha is not None
        ), "must provide attn-activation, attn-seq-scalar, attn-seq-scalar-alpha"

        # A positional functools.partial would bind these three settings to queries/keys/values.
        # Forward the call-time tensors into custom_attn's leading slots instead.
        def _custom_attn(queries, keys, values, is_causal=False, attention_mask=None):
            return custom_attn(
                queries,
                keys,
                values,
                attn_activation,
                attn_seq_scalar,
                alpha,
                is_causal=is_causal,
                attention_mask=attention_mask,
            )

        return _custom_attn
    raise ValueError(f"Unsupported attn-name: {attn_name}")