from functools import partial

import torch
from torch.nn import functional as F

try:
    import xformers.ops as xops

    _XFORMERS_IMPORT_ERROR = None
except ImportError as err:
    # xformers is only needed by ``xformers_attn``; keep the pure-torch
    # implementations importable on machines without it.
    xops = None
    _XFORMERS_IMPORT_ERROR = err


def _require_xformers():
    """Raise an informative ImportError if xformers could not be imported."""
    if xops is None:
        raise ImportError("xformers is required for xformers_attn but could not be imported.") from _XFORMERS_IMPORT_ERROR


def get_rectangular_causal_mask(shape, q_seq_len, k_seq_len, device, dtype):
    """Create an additive causal mask aligned to the bottom-right corner.

    Queries are assumed to be the *last* ``q_seq_len`` positions of the key
    sequence. When ``q_seq_len <= k_seq_len``, query ``i`` may attend to keys
    ``0 .. k_seq_len - q_seq_len + i``. Allowed positions hold ``0`` and
    disallowed positions hold ``torch.finfo(dtype).min``. This is especially
    useful when the query length is smaller than the key length.

    The mask is carved out of a buffer whose last dimension is rounded up to a
    multiple of 8, as required by xformers. The returned tensor is a view of
    that buffer, so its row stride is the padded length.

    Args:
        shape: leading dimensions of the mask, e.g. ``(batch, heads)``.
        q_seq_len: number of queries.
        k_seq_len: number of keys.
        device: device on which to allocate the mask.
        dtype: floating-point dtype of the mask.

    Returns:
        torch.Tensor of shape ``(*shape, q_seq_len, k_seq_len)``.

    >>> m = get_rectangular_causal_mask((1, 1), 3, 5, "cpu", torch.float32)
    >>> m.shape
    torch.Size([1, 1, 3, 5])
    >>> (m[0, 0] < 0).long()
    tensor([[0, 0, 0, 1, 1],
            [0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0]])
    >>> (get_rectangular_causal_mask((1, 1), 2, 2, "cpu", torch.float32)[0, 0] < 0).long()
    tensor([[0, 1],
            [0, 0]])
    """
    # xformers requires the mask to be built with a shape that is a multiple of 8.
    next_multiple_8 = (k_seq_len + 7) // 8 * 8
    allowed = torch.ones((q_seq_len, k_seq_len), device=device, dtype=torch.bool)
    # Only the trailing q_seq_len key columns are causally restricted; earlier keys are always visible.
    allowed[:, -q_seq_len:] = torch.tril(allowed[:, -q_seq_len:], diagonal=0)
    output_mask = torch.zeros((*shape, q_seq_len, next_multiple_8), device=device, dtype=dtype)
    output_mask[..., :k_seq_len].masked_fill_(~allowed, torch.finfo(dtype).min)
    return output_mask[..., :k_seq_len]


def apply_attention_mask_(bias, attention_mask, queries_dtype):
    """Apply a padding attention mask (e.g., from HuggingFace generate) to an attention bias in-place.

    Args:
        bias (torch.Tensor, shape (batch_size, num_heads, q_seq_len, k_seq_len)):
            additive bias where ``0`` means "may attend". Modified in place.
        attention_mask (torch.Tensor, shape (batch_size, sequence_len)):
            ``0`` marks key positions that must not be attended to.
        queries_dtype: queries.dtype; used to get the minimum value for masked indices.

    Returns:
        None. ``bias`` is updated in place.
    """
    assert attention_mask.dim() == 2
    mask_length = attention_mask.shape[-1]
    # Positions that are currently allowed (bias == 0) _and_ padded out by attention_mask (== 0)
    # are set to min_dtype.
    padding_mask = bias[..., :mask_length].eq(0.0) & attention_mask[:, None, None, :].eq(0.0)
    min_dtype = torch.finfo(queries_dtype).min
    bias[..., :mask_length] = bias[..., :mask_length].masked_fill(padding_mask, min_dtype)
    # Query rows where every weight is min_dtype would produce NaNs in softmax. Their outputs
    # are not used anyway, so reset those rows to 0 (multiplying by False zeroes them).
    bias.mul_(~torch.all(bias == min_dtype, dim=-1, keepdim=True))


def xformers_attn(queries, keys, values, is_causal, attention_mask=None):
    """Attention via ``xformers.ops.memory_efficient_attention``.

    Args:
        queries, keys, values: tensors laid out as [batch, seq_len, heads, embed_dim].
            Queries are assumed to match the last part of the key / value sequences.
        is_causal: whether to apply causal masking.
        attention_mask: optional (batch, sequence_len) padding mask; only supported with is_causal=True.

    Returns:
        Attention output in the same [batch, seq_len, heads, embed_dim] layout as ``queries``.

    Raises:
        ImportError: if xformers is not installed.
        NotImplementedError: if ``attention_mask`` is given with ``is_causal=False``.
    """
    _require_xformers()
    # We would like to replace the mask generation with
    # xops.fmha.attn_bias.LowerTriangularFromBottomRightMask(), but that needs xformers>=0.0.23,
    # which is not compatible with torch<2.1.1, while llm-foundry requires torch<2.1.1.
    #
    # If queries have shape [batch, 1, heads, dim], there is only one query in the sequence.
    # There is then no notion of causal masking, so the mask is left as None. This is needed
    # to get the desired behavior with seq_len=1.
    bias = None
    if is_causal and queries.shape[1] == keys.shape[1] and attention_mask is None:
        bias = xops.LowerTriangularMask()
    elif is_causal and (queries.shape[1] > 1 or attention_mask is not None):
        # Build a causal mask that assumes the queries are at the end of the sequence.
        batch, q_seq_len, heads, _ = queries.shape
        k_seq_len = keys.shape[1]
        bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
        if attention_mask is not None:
            apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype)
    elif not is_causal and attention_mask is not None:
        raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.")
    return xops.memory_efficient_attention(queries, keys, values, attn_bias=bias)


def torch_attn(queries, keys, values, is_causal, attention_mask=None):
    """Attention via ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        queries, keys, values: tensors laid out as [batch, seq_len, heads, embed_dim].
            Queries are assumed to match the last part of the key / value sequences.
        is_causal: whether to apply causal masking.
        attention_mask: optional (batch, sequence_len) padding mask; only supported with is_causal=True.

    Returns:
        Contiguous attention output in the same [batch, seq_len, heads, embed_dim] layout as ``queries``.

    Raises:
        NotImplementedError: if ``attention_mask`` is given with ``is_causal=False``.
    """
    q_seq_len = queries.shape[1]
    k_seq_len = keys.shape[1]
    bias = None
    if attention_mask is not None:
        if not is_causal:
            raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.")
        # The padding mask differs per batch element, so the bias needs a real batch dimension.
        batch, _, heads, _ = queries.shape
        bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
        apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype)
        # The causal structure is already in the bias.
        is_causal = False
    elif is_causal and k_seq_len > q_seq_len > 1:
        # SDPA's is_causal aligns the mask to the top-left corner, but our queries are the
        # *last* positions, so build a bottom-right aligned mask instead. We would like to use
        # xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize(...) here.
        bias = get_rectangular_causal_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype)
        is_causal = False
    elif q_seq_len == 1:
        # A single (last) query may attend to every key; SDPA's top-left causal mask would
        # otherwise restrict it to the first key only.
        is_causal = False
    # Need to call contiguous in torch >= 2.1, otherwise later calls to .view() fail
    # (the behavior of scaled_dot_product_attention changed between 2.0 and 2.1).
    return (
        F.scaled_dot_product_attention(
            queries.transpose(1, 2),
            keys.transpose(1, 2),
            values.transpose(1, 2),
            attn_mask=bias,
            is_causal=is_causal,
        )
        .transpose(1, 2)
        .contiguous()
    )


ATTN_ACTIVATIONS = {
    "relu": F.relu,
    "relu_squared": lambda x: torch.pow(F.relu(x), 2),
    # "gelu": F.gelu,  # goes to NaN with bias so comment out for now
    "softplus": F.softplus,
    "identity": lambda x: x,
    "relu6": F.relu6,
    "sigmoid": torch.sigmoid,
    "softmax": partial(F.softmax, dim=-1),
}

ATTN_SEQ_SCALARS = {
    "max": lambda x: x,
    # "seq": lambda x: torch.arange(x) + 1,  # comment out for now more involved
    "avg": lambda x: (x - 1) / 2 + 1,
    "none": lambda _: 1,
}


def custom_attn(
    queries,
    keys,
    values,
    attn_activation,
    attn_seq_scalar,
    alpha,
    is_causal=False,
    attention_mask=None,
) -> torch.Tensor:
    """Naive reference implementation of attention with a pluggable activation (e.g. ReLU-attention).

    Computes ``scale(L) * act(q k^T / sqrt(d) + causal_bias) @ v``, where ``scale(L) = s(L) ** -alpha``,
    ``s = ATTN_SEQ_SCALARS[attn_seq_scalar]``, ``act = ATTN_ACTIVATIONS[attn_activation]`` and
    ``L`` is the key length.

    Args:
        queries, keys, values: tensors laid out as [batch, seq_len, heads, embed_dim].
        attn_activation: key into ATTN_ACTIVATIONS.
        attn_seq_scalar: key into ATTN_SEQ_SCALARS.
        alpha: exponent applied to the sequence-length scalar.
        is_causal: whether to apply a bottom-right aligned causal mask.
        attention_mask: not supported; must be None.

    Returns:
        Attention output in the [batch, q_seq_len, heads, embed_dim] layout.

    Raises:
        NotImplementedError: if ``attention_mask`` is not None.
    """
    if attention_mask is not None:
        raise NotImplementedError("attention_mask not yet implemented for custom_attn.")

    batch, q_seq_len, heads, embed_dim = queries.shape
    k_seq_len = keys.shape[1]

    inner_scale = embed_dim**-0.5
    attn_weight = torch.einsum("bqhd,bkhd->bhqk", inner_scale * queries, keys)
    if is_causal and q_seq_len > 1:
        attn_weight = attn_weight + get_rectangular_causal_mask(
            (batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype
        )

    # Scaling by s(L)^{-alpha}, i.e. 1 / s(L)^{alpha}.
    outer_scale = ATTN_SEQ_SCALARS[attn_seq_scalar](k_seq_len) ** -alpha
    attn_weight = outer_scale * ATTN_ACTIVATIONS[attn_activation](attn_weight)

    return torch.einsum("bhqk,bkhd->bqhd", attn_weight, values)


def get_attn_func(
    attn_name,
    attn_activation=None,
    attn_seq_scalar=None,
    alpha=None,
):
    """Return the attention function selected by ``attn_name``.

    Every returned callable takes ``(queries, keys, values, is_causal, attention_mask=None)``.

    Args:
        attn_name: one of "auto", "xformers_attn", "xformers_attn_variable_length",
            "torch_attn" or "custom_attn". "auto" picks xformers on CUDA when it is installed,
            otherwise torch SDPA.
        attn_activation, attn_seq_scalar, alpha: required for "custom_attn" only.

    Raises:
        ValueError: for an unknown ``attn_name``.
    """
    if attn_name == "auto":
        return xformers_attn if torch.cuda.is_available() and xops is not None else torch_attn
    elif attn_name == "xformers_attn":
        return xformers_attn
    elif attn_name == "xformers_attn_variable_length":
        # Upon changing the input sequence length, xformers attention changes
        # the stride dimension of the output tensor. This makes future calls to
        # .view() that collapses last two dimensions fail. One thus needs to
        # call .contiguous() on the output tensor. [#188]
        return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous()
    elif attn_name == "torch_attn":
        return torch_attn
    elif attn_name == "custom_attn":
        assert (
            attn_activation is not None and attn_seq_scalar is not None and alpha is not None
        ), "must provide attn-activation, attn-seq-scalar, attn-seq-scalar-alpha"

        # Bind the hyperparameters by name: custom_attn's first positional parameters are
        # queries/keys/values, so binding them positionally would shift every argument.
        def _custom_attn(queries, keys, values, is_causal=False, attention_mask=None):
            return custom_attn(
                queries,
                keys,
                values,
                attn_activation=attn_activation,
                attn_seq_scalar=attn_seq_scalar,
                alpha=alpha,
                is_causal=is_causal,
                attention_mask=attention_mask,
            )

        return _custom_attn
    else:
        raise ValueError(f"Unsupported attn-name: {attn_name}")