# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.utils import is_flash_attn_2_available
from transformers.utils import logging

from .common import MLP, RMSNorm


if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

logger = logging.get_logger(__name__)


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    """Derive the indexing data needed to strip padding from a batch.

    Args:
        attention_mask: 2D padding mask; non-zero entries mark tokens to keep.

    Returns:
        A tuple `(indices, cu_seqlens, max_seqlen_in_batch)`:
        flat indices of the non-zero mask entries, the int32 cumulative
        per-row token counts with a leading 0, and the largest per-row
        token count as a Python int.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # `.item()` pulls the value back to the host as a Python number.
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class PerceiverAttention(nn.Module):
    def __init__(self, config) -> None:
        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        self.num_heads = config.perceiver_config.resampler_n_heads
        self.head_dim = config.perceiver_config.resampler_head_dim
        self.num_key_value_heads = config.perceiver_config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
        self.attention_dropout = config.perceiver_config.attention_dropout

        if self.qk_layer_norms:
            self.q_layer_norm = RMSNorm(self.head_dim)
            self.k_layer_norm = RMSNorm(self.head_dim)

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.is_causal = False

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Runs Perceiver attention: queries come from `latents`, keys/values from `[context; latents]` concatenated
        along the `seq` dimension.

        :param latents: Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
        :param context: Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
        :param attention_mask: Optional additive mask of shape [bsz, 1, n_latents, kv_seq_len].
        :param position_ids: Unused; accepted for interface compatibility.
        :param past_key_value: Optional `(key, value)` pair that is prepended to the new keys/values.
        :param output_attentions: Whether to return the attention probabilities.
        :param use_cache: Whether to return the `(key, value)` pair used for this call.

        :return: `(attn_output, attn_weights, past_key_value)` where `attn_output` has shape
            [bsz, n_latents, embed_dim], `attn_weights` is `None` unless `output_attentions` is set, and
            `past_key_value` is `None` unless `use_cache` is set.
        :raises ValueError: if the attention weights, attention mask or attention output have an unexpected shape.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        # Keys and values share the same input, so build the concatenation only once.
        kv_inputs = torch.cat([context, latents], dim=-2)

        query_states = self.q_proj(latents).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = (
            self.k_proj(kv_inputs).view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        )
        value_states = (
            self.v_proj(kv_inputs).view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        )

        # No positional embedding is applied to queries/keys in this module.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if self.qk_layer_norms:
            query_states = self.q_layer_norm(query_states)
            key_states = self.k_layer_norm(key_states)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class PerceiverFlashAttention2(PerceiverAttention):
    """
    Perceiver flash attention module. This module inherits from `PerceiverAttention` as the weights of the module
    stay untouched. The only required change is on the forward pass, where it needs to correctly call the public API
    of flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Flash-attention variant of `PerceiverAttention.forward`.

        :param latents: Tensor of shape [bsz, n_latents, embed_dim]; source of the queries.
        :param context: Tensor of shape [bsz, seq, embed_dim]; concatenated with `latents` to form keys/values.
        :param attention_mask: Optional 2D padding mask over the key/value sequence (non-zero = keep).
        :param position_ids: Unused; accepted for interface compatibility.
        :param past_key_value: Optional `(key, value)` pair that is prepended to the new keys/values.
        :param output_attentions: Attention weights are never materialised on this path; `None` is always returned.
        :param use_cache: Whether to return the `(key, value)` pair used for this call.

        :return: `(attn_output, None, past_key_value)`.
        :raises ValueError: if a sliced cache does not have `sliding_window - 1` positions.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
        #   Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
        kv_inputs = torch.cat([context, latents], dim=-2)
        query_states = self.q_proj(latents)
        key_states = self.k_proj(kv_inputs)
        value_states = self.v_proj(kv_inputs)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # NOTE(review): unlike `PerceiverAttention.forward`, this path does not apply `q_layer_norm` /
        # `k_layer_norm` when `self.qk_layer_norms` is set. Left unchanged to keep numerics stable for
        # existing checkpoints; confirm whether the two paths are meant to agree.

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

            # Activate slicing cache only if the config has a `sliding_window` attribute
            if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
                slicing_tokens = kv_seq_len - self.config.sliding_window

                past_key = past_key_value[0]
                past_value = past_key_value[1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1,"
                        f" head_dim`), got {past_key.shape}"
                    )

                past_key_value = (past_key, past_value)

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                "The input hidden states seems to be silently casted in float32, this might be related to the fact"
                " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the layout expected by Flash Attention: (batch, seq, heads, head_dim)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=False,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        # This path never materialises attention probabilities. Always bind `attn_weights` so that
        # `output_attentions=True` returns `None` instead of raising `UnboundLocalError`.
        if output_attentions:
            logger.warning_once(
                "`output_attentions=True` is not supported with flash attention; attention weights are returned as"
                " `None`."
            )
        attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Length of the query sequence; used to unpad the queries and to re-pad the output.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
            use_sliding_windows (`bool`, *optional*):
                Whether to activate sliding window attention.
        """
        # `window_size` is only passed (and `config.sliding_window` only read) when sliding windows are requested.
        window_kwargs = {}
        if use_sliding_windows:
            window_kwargs["window_size"] = (self.config.sliding_window, self.config.sliding_window)

        # No padding mask: call the dense kernel directly.
        if attention_mask is None:
            return flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=self.is_causal,
                **window_kwargs,
            )

        # A padding mask was given: unpad, run the variable-length kernel, then pad back.
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=self.is_causal,
            **window_kwargs,
        )

        return pad_input(attn_output_unpad, indices_q, batch_size, query_length)

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """
        Strip padded positions from the query/key/value tensors according to `attention_mask`.

        Returns `(query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            # Newer flash-attn releases return an extra trailing value from `unpad_input`;
            # `*_` keeps the unpacking valid for both the 4- and 5-value forms.
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


class PerceiverLayer(nn.Module):
    """One resampler block: pre-normed latent/context attention followed by an MLP, each with a residual."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.hidden_act = config.perceiver_config.hidden_act
        self.n_latents = config.perceiver_config.resampler_n_latents
        self.depth = config.perceiver_config.resampler_depth
        self.rms_norm_eps = config.rms_norm_eps
        self.intermediate_size = self.hidden_size * 4

        self.input_latents_norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        self.input_context_norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        self.self_attn = (
            PerceiverAttention(config)
            if not getattr(config, "_flash_attn_2_enabled", False)
            else PerceiverFlashAttention2(config)
        )
        self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        self.mlp = MLP(
            activation=self.hidden_act,
            input_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            output_size=self.hidden_size,
        )

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask over `context` of size
                `(batch, sequence_length)` where padding elements are indicated by 0. When omitted, every context
                position is attended to. The latent positions are always appended as non-padding.
            position_ids (`torch.LongTensor`, *optional*): accepted for interface compatibility; not forwarded.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): accepted for interface compatibility; not
                forwarded to the attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, an extra entry is appended to the output. Caching is not forwarded to the
                attention module, so that entry is `None`.

        Returns:
            A tuple `(latents,)`, followed by the attention weights if `output_attentions` is set, followed by the
            (always `None`) present key/value entry if `use_cache` is set.
        """
        residual = latents

        latents = self.input_latents_norm(latents)
        context = self.input_context_norm(context)

        if attention_mask is None:
            attention_mask = torch.ones(
                size=(context.size(0), context.size(1)),
                dtype=torch.bool,
                device=context.device,
            )
        # Keys/values are [context; latents], so extend the mask with always-visible latent positions.
        attention_mask = torch.cat(
            [
                attention_mask,
                torch.ones(
                    size=(attention_mask.size(0), latents.size(1)),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
            ],
            dim=-1,
        )

        # Read the flag the same way `__init__` does, so a config without the attribute selects the eager path
        # here too instead of raising `AttributeError`.
        use_flash_attn = getattr(self.config, "_flash_attn_2_enabled", False)
        if not use_flash_attn:
            # The eager path needs an additive 4D mask; the flash path consumes the 2D padding mask directly.
            attention_mask = _prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents)

        latents, self_attn_weights, present_key_value = self.self_attn(
            latents=latents,
            context=context,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        latents = residual + latents
        residual = latents

        latents = self.post_attention_layernorm(latents)
        latents = self.mlp(latents)
        latents = residual + latents

        outputs = (latents,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class PerceiverResampler(nn.Module):
    def __init__(
        self,
        config,
    ) -> None:
        """
        Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
        MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
        returns a Tensor of shape [bsz, n_latents, embed_dim].

        :param config: Model configuration. The fields read here are:
            `config.hidden_size` -- dimensionality of the incoming embeddings and of the returned latents;
            `config.rms_norm_eps` -- epsilon of the final RMSNorm;
            `config.perceiver_config.resampler_depth` -- number of `PerceiverLayer` blocks;
            `config.perceiver_config.resampler_n_latents` -- number of latent embeddings to resample
            ("compress") the input sequence to;
            `config.perceiver_config.hidden_act` -- stored as `self.hidden_act`.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.hidden_act = config.perceiver_config.hidden_act
        self.n_latents = config.perceiver_config.resampler_n_latents
        self.depth = config.perceiver_config.resampler_depth
        self.rms_norm_eps = config.rms_norm_eps

        # Create Latents for Perceiver
        self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size))

        # Create Transformer Blocks
        self.layers = nn.ModuleList([PerceiverLayer(config) for _ in range(self.depth)])
        self.norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps)

    def forward(
        self,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Resample `context` into the learned latents.

        :param context: Tensor whose first dimension is the batch size; passed unchanged to every layer.
        :param attention_mask: Optional padding mask over `context`, forwarded to every layer.
        :return: The latents after all layers and the final norm.
        """
        # Broadcast the learned latents across the batch dimension of `context`.
        latents = repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])

        for perceiver_layer in self.layers:
            layer_outputs = perceiver_layer(
                latents,
                context,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
            )

            latents = layer_outputs[0]

        latents = self.norm(latents)

        return latents