import math |
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import repeat |
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask |
|
from transformers.utils import is_flash_attn_2_available |
|
from transformers.utils import logging |
|
|
|
from .common import MLP, RMSNorm |
|
|
|
|
|
if is_flash_attn_2_available(): |
|
from flash_attn import flash_attn_func, flash_attn_varlen_func |
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
def _get_unpad_data(attention_mask): |
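    # Compute the metadata flash_attn_varlen_func needs to skip padding tokens:
    # flat indices of the non-padding positions, the cumulative sequence lengths
    # per batch element, and the longest unpadded sequence length.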
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
|
max_seqlen_in_batch = seqlens_in_batch.max().item() |
|
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
return ( |
|
indices, |
|
cu_seqlens, |
|
max_seqlen_in_batch, |
|
) |
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
""" |
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
|
""" |
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
if n_rep == 1: |
|
return hidden_states |
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
|
|
|
|
|
class PerceiverAttention(nn.Module): |
|
def __init__(self, config) -> None: |
|
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" |
|
super().__init__() |
|
|
|
self.config = config |
|
|
|
self.hidden_size = config.hidden_size |
|
self.num_heads = config.perceiver_config.resampler_n_heads |
|
self.head_dim = config.perceiver_config.resampler_head_dim |
|
self.num_key_value_heads = config.perceiver_config.num_key_value_heads |
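        # Grouped-query attention: each key/value head is shared by
        # num_heads // num_key_value_heads query heads.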
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver |
|
self.attention_dropout = config.perceiver_config.attention_dropout |
|
|
|
if self.qk_layer_norms: |
|
self.q_layer_norm = RMSNorm(self.head_dim) |
|
self.k_layer_norm = RMSNorm(self.head_dim) |
|
|
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) |
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) |
|
|
|
        self.is_causal = False  # the latents attend bidirectionally over the full (context, latents) sequence
|
|
|
def forward( |
|
self, |
|
latents: torch.Tensor, |
|
context: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
""" |
|
        Runs Perceiver cross-attention: the latents attend over the concatenation of (context, latents) along the `seq` dimension.
|
:param context: Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample. |
|
:param latents: Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to. |
|
:return: Tensor of shape [bsz, n_latents, embed_dim] representing attention over latents w/ cross from context. |
|
""" |
|
bsz, q_len, _ = latents.size() |
|
kv_seq_len = q_len + context.size()[1] |
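        # Queries come from the latents only; keys/values are computed over the
        # concatenation of the context and the latents.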
|
|
|
query_states = self.q_proj(latents).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = ( |
|
self.k_proj(torch.cat([context, latents], dim=-2)) |
|
.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim) |
|
.transpose(1, 2) |
|
) |
|
value_states = ( |
|
self.v_proj(torch.cat([context, latents], dim=-2)) |
|
.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim) |
|
.transpose(1, 2) |
|
) |
|
|
|
kv_seq_len = key_states.shape[-2] |
|
if past_key_value is not None: |
|
kv_seq_len += past_key_value[0].shape[-2] |
|
|
|
|
|
|
|
if past_key_value is not None: |
|
|
|
key_states = torch.cat([past_key_value[0], key_states], dim=2) |
|
value_states = torch.cat([past_key_value[1], value_states], dim=2) |
|
|
|
past_key_value = (key_states, value_states) if use_cache else None |
|
|
|
if self.qk_layer_norms: |
|
query_states = self.q_layer_norm(query_states) |
|
key_states = self.k_layer_norm(key_states) |
|
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
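        # Standard scaled dot-product attention: softmax(QK^T / sqrt(head_dim)) V.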
|
|
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) |
|
|
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): |
|
raise ValueError( |
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" |
|
f" {attn_weights.size()}" |
|
) |
|
|
|
if attention_mask is not None: |
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): |
|
raise ValueError( |
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" |
|
) |
|
|
|
attn_weights = attn_weights + attention_mask |
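        # Upcast to float32 for a numerically stable softmax, then cast back to the query dtype.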
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): |
|
raise ValueError( |
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" |
|
f" {attn_output.size()}" |
|
) |
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) |
|
|
|
attn_output = self.o_proj(attn_output) |
|
|
|
if not output_attentions: |
|
attn_weights = None |
|
|
|
return attn_output, attn_weights, past_key_value |
|
|
|
|
|
class PerceiverFlashAttention2(PerceiverAttention): |
|
""" |
|
    Perceiver flash attention module. This module inherits from `PerceiverAttention`, as the weights of the module
    stay untouched. The only required change is in the forward pass, where it needs to correctly call the public API
    of flash attention and deal with padding tokens in case the input contains any of them.
|
""" |
|
|
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
|
|
def forward( |
|
self, |
|
latents: torch.Tensor, |
|
context: torch.Tensor, |
|
attention_mask: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
**kwargs, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
bsz, q_len, _ = latents.size() |
|
kv_seq_len = q_len + context.size()[1] |
|
|
|
|
|
|
|
query_states = self.q_proj(latents) |
|
key_states = self.k_proj(torch.cat([context, latents], dim=-2)) |
|
value_states = self.v_proj(torch.cat([context, latents], dim=-2)) |
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
kv_seq_len = key_states.shape[-2] |
|
if past_key_value is not None: |
|
kv_seq_len += past_key_value[0].shape[-2] |
|
|
|
if past_key_value is not None: |
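            # With a sliding-window config, only the most recent `sliding_window`
            # cached key/value positions (and mask columns) are kept.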
|
|
|
if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window: |
|
slicing_tokens = kv_seq_len - self.config.sliding_window |
|
|
|
past_key = past_key_value[0] |
|
past_value = past_key_value[1] |
|
|
|
past_key = past_key[:, :, slicing_tokens:, :].contiguous() |
|
past_value = past_value[:, :, slicing_tokens:, :].contiguous() |
|
|
|
if past_key.shape[-2] != self.config.sliding_window - 1: |
|
raise ValueError( |
|
"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1," |
|
f" head_dim`), got {past_key.shape}" |
|
) |
|
|
|
past_key_value = (past_key, past_value) |
|
|
|
if attention_mask is not None: |
|
attention_mask = attention_mask[:, slicing_tokens:] |
|
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) |
|
|
|
key_states = torch.cat([past_key_value[0], key_states], dim=2) |
|
value_states = torch.cat([past_key_value[1], value_states], dim=2) |
|
|
|
past_key_value = (key_states, value_states) if use_cache else None |
|
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
dropout_rate = 0.0 if not self.training else self.attention_dropout |
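        # Flash attention kernels only support fp16/bf16. If activations were upcast to
        # float32 (e.g. by layer norms), cast them back down before calling the kernel.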
|
|
|
|
|
|
|
|
|
input_dtype = query_states.dtype |
|
if input_dtype == torch.float32: |
|
|
|
if hasattr(self.config, "_pre_quantization_dtype"): |
|
target_dtype = self.config._pre_quantization_dtype |
|
else: |
|
target_dtype = self.q_proj.weight.dtype |
|
|
|
logger.warning_once( |
|
"The input hidden states seems to be silently casted in float32, this might be related to the fact" |
|
" you have upcasted embedding or layer norm layers in float32. We will cast back the input in" |
|
f" {target_dtype}." |
|
) |
|
|
|
query_states = query_states.to(target_dtype) |
|
key_states = key_states.to(target_dtype) |
|
value_states = value_states.to(target_dtype) |
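        # flash_attn expects tensors in (batch, seq_len, num_heads, head_dim) layout,
        # so transpose back from (batch, num_heads, seq_len, head_dim).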
|
|
|
|
|
query_states = query_states.transpose(1, 2) |
|
key_states = key_states.transpose(1, 2) |
|
value_states = value_states.transpose(1, 2) |
|
|
|
attn_output = self._flash_attention_forward( |
|
query_states, |
|
key_states, |
|
value_states, |
|
attention_mask, |
|
q_len, |
|
dropout=dropout_rate, |
|
use_sliding_windows=False, |
|
) |
|
|
|
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() |
|
attn_output = self.o_proj(attn_output) |
|
|
|
        # Flash attention kernels do not return attention weights
        attn_weights = None
|
|
|
return attn_output, attn_weights, past_key_value |
|
|
|
def _flash_attention_forward( |
|
self, |
|
query_states, |
|
key_states, |
|
value_states, |
|
attention_mask, |
|
query_length, |
|
dropout=0.0, |
|
softmax_scale=None, |
|
use_sliding_windows=False, |
|
): |
|
""" |
|
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        the input is first unpadded, the attention scores are computed, and the output is finally re-padded.
|
|
|
Args: |
|
query_states (`torch.Tensor`): |
|
Input query states to be passed to Flash Attention API |
|
key_states (`torch.Tensor`): |
|
Input key states to be passed to Flash Attention API |
|
value_states (`torch.Tensor`): |
|
Input value states to be passed to Flash Attention API |
|
attention_mask (`torch.Tensor`): |
|
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the |
|
position of padding tokens and 1 for the position of non-padding tokens. |
|
            dropout (`float`, *optional*):
                Attention dropout probability.
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
|
use_sliding_windows (`bool`, *optional*): |
|
Whether to activate sliding window attention. |
|
""" |
|
|
|
if attention_mask is not None: |
|
batch_size = query_states.shape[0] |
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( |
|
query_states, key_states, value_states, attention_mask, query_length |
|
) |
|
|
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens |
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens |
|
|
|
if not use_sliding_windows: |
|
attn_output_unpad = flash_attn_varlen_func( |
|
query_states, |
|
key_states, |
|
value_states, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_k=cu_seqlens_k, |
|
max_seqlen_q=max_seqlen_in_batch_q, |
|
max_seqlen_k=max_seqlen_in_batch_k, |
|
dropout_p=dropout, |
|
softmax_scale=softmax_scale, |
|
causal=self.is_causal, |
|
) |
|
else: |
|
attn_output_unpad = flash_attn_varlen_func( |
|
query_states, |
|
key_states, |
|
value_states, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_k=cu_seqlens_k, |
|
max_seqlen_q=max_seqlen_in_batch_q, |
|
max_seqlen_k=max_seqlen_in_batch_k, |
|
dropout_p=dropout, |
|
softmax_scale=softmax_scale, |
|
causal=self.is_causal, |
|
window_size=(self.config.sliding_window, self.config.sliding_window), |
|
) |
|
|
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) |
|
else: |
|
if not use_sliding_windows: |
|
attn_output = flash_attn_func( |
|
query_states, |
|
key_states, |
|
value_states, |
|
dropout, |
|
softmax_scale=softmax_scale, |
|
causal=self.is_causal, |
|
) |
|
else: |
|
attn_output = flash_attn_func( |
|
query_states, |
|
key_states, |
|
value_states, |
|
dropout, |
|
softmax_scale=softmax_scale, |
|
causal=self.is_causal, |
|
window_size=(self.config.sliding_window, self.config.sliding_window), |
|
) |
|
|
|
return attn_output |
|
|
|
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): |
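        # Remove padding tokens from q/k/v and build the cumulative sequence lengths
        # expected by flash_attn_varlen_func.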
|
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) |
|
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape |
|
|
|
key_layer = index_first_axis( |
|
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k |
|
) |
|
value_layer = index_first_axis( |
|
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k |
|
) |
|
if query_length == kv_seq_len: |
|
query_layer = index_first_axis( |
|
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k |
|
) |
|
cu_seqlens_q = cu_seqlens_k |
|
max_seqlen_in_batch_q = max_seqlen_in_batch_k |
|
indices_q = indices_k |
|
elif query_length == 1: |
|
max_seqlen_in_batch_q = 1 |
|
cu_seqlens_q = torch.arange( |
|
batch_size + 1, dtype=torch.int32, device=query_layer.device |
|
) |
|
indices_q = cu_seqlens_q[:-1] |
|
query_layer = query_layer.squeeze(1) |
|
else: |
|
|
|
attention_mask = attention_mask[:, -query_length:] |
|
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) |
|
|
|
return ( |
|
query_layer, |
|
key_layer, |
|
value_layer, |
|
indices_q, |
|
(cu_seqlens_q, cu_seqlens_k), |
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k), |
|
) |
|
|
|
|
|
class PerceiverLayer(nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
self.hidden_size = config.hidden_size |
|
self.hidden_act = config.perceiver_config.hidden_act |
|
self.n_latents = config.perceiver_config.resampler_n_latents |
|
self.depth = config.perceiver_config.resampler_depth |
|
self.rms_norm_eps = config.rms_norm_eps |
|
self.intermediate_size = self.hidden_size * 4 |
|
|
|
self.input_latents_norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) |
|
self.input_context_norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) |
|
self.self_attn = ( |
|
PerceiverAttention(config) |
|
if not getattr(config, "_flash_attn_2_enabled", False) |
|
else PerceiverFlashAttention2(config) |
|
) |
|
self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) |
|
self.mlp = MLP( |
|
activation=self.hidden_act, |
|
input_size=self.hidden_size, |
|
intermediate_size=self.intermediate_size, |
|
output_size=self.hidden_size, |
|
) |
|
|
|
def forward( |
|
self, |
|
latents: torch.Tensor, |
|
context: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = False, |
|
**kwargs, |
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
|
""" |
|
Args: |
|
latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
|
context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size |
|
`(batch, sequence_length)` where padding elements are indicated by 0. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more detail. |
|
use_cache (`bool`, *optional*): |
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding |
|
(see `past_key_values`). |
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states |
|
""" |
|
residual = latents |
|
|
|
latents = self.input_latents_norm(latents) |
|
context = self.input_context_norm(context) |
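        # Build a padding mask over the concatenated (context, latents) sequence: if none is
        # given, everything is attendable; the latent positions are always unmasked.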
|
|
|
if attention_mask is None: |
|
attention_mask = torch.ones( |
|
size=(context.size(0), context.size(1)), |
|
dtype=torch.bool, |
|
device=context.device, |
|
) |
|
attention_mask = torch.cat( |
|
[ |
|
attention_mask, |
|
torch.ones( |
|
size=(attention_mask.size(0), latents.size(1)), |
|
dtype=attention_mask.dtype, |
|
device=attention_mask.device, |
|
), |
|
], |
|
dim=-1, |
|
) |
|
latents, self_attn_weights, present_key_value = self.self_attn( |
|
latents=latents, |
|
context=context, |
|
attention_mask=( |
|
_prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents) |
|
                if not getattr(self.config, "_flash_attn_2_enabled", False)
|
else attention_mask |
|
), |
|
) |
|
latents = residual + latents |
|
residual = latents |
|
|
|
latents = self.post_attention_layernorm(latents) |
|
latents = self.mlp(latents) |
|
latents = residual + latents |
|
|
|
outputs = (latents,) |
|
|
|
if output_attentions: |
|
outputs += (self_attn_weights,) |
|
|
|
if use_cache: |
|
outputs += (present_key_value,) |
|
|
|
return outputs |
|
|
|
|
|
class PerceiverResampler(nn.Module): |
|
def __init__( |
|
self, |
|
config, |
|
) -> None: |
|
""" |
|
Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or |
|
        MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed set of `n_latents` latent queries, then
|
returns a Tensor of shape [bsz, n_latents, embed_dim]. |
|
:param embed_dim: Dimensionality of embeddings being fed to the Perceiver Resampler (also dimensionality of |
|
                          latent embeddings *returned* by the Perceiver Resampler). Could be, e.g., ViT embed_dim, ResNet
                          pool dim, and so on.
|
:param depth: Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3). |
|
:param n_heads: Number of heads in each Transformer block (for multi-headed self-attention). |
|
:param head_dim: Dimensionality of each head projection in the Transformer block. |
|
:param n_latents: Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). |
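
        A minimal usage sketch (assuming `config` exposes the attributes read in `__init__`,
        e.g. `hidden_size`, `rms_norm_eps` and the `perceiver_config` fields; the batch size
        and context length below are illustrative only):

            resampler = PerceiverResampler(config)
            image_hidden_states = torch.randn(2, 576, config.hidden_size)  # e.g. ViT patch embeddings
            resampled = resampler(image_hidden_states)  # [2, n_latents, hidden_size]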
|
""" |
|
super().__init__() |
|
self.config = config |
|
self.hidden_size = config.hidden_size |
|
self.hidden_act = config.perceiver_config.hidden_act |
|
self.n_latents = config.perceiver_config.resampler_n_latents |
|
self.depth = config.perceiver_config.resampler_depth |
|
self.rms_norm_eps = config.rms_norm_eps |
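        # Learned latent queries that the context is resampled into (shared across the batch).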
|
|
|
|
|
self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size)) |
|
|
|
|
|
self.layers = nn.ModuleList([PerceiverLayer(config) for _ in range(self.depth)]) |
|
self.norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) |
|
|
|
def forward( |
|
self, |
|
context: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
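        # Broadcast the learned latent queries across the batch dimension.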
|
latents = repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0]) |
|
|
|
for perceiver_layer in self.layers: |
|
layer_outputs = perceiver_layer( |
|
latents, |
|
context, |
|
attention_mask=attention_mask, |
|
position_ids=None, |
|
past_key_value=None, |
|
output_attentions=False, |
|
use_cache=False, |
|
) |
|
|
|
latents = layer_outputs[0] |
|
|
|
latents = self.norm(latents) |
|
|
|
return latents |
|
|