"""perceiver_mod.py |
|
|
|
Implementation of the PerceiverTF encoder with: |
|
- AliBi positional bias |
|
- Mixtral of Experts (MoE) feedforward layer |
|
|
|
""" |
|
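# Shape conventions used throughout this module (see the docstrings below):
#   inputs:        (B, T, F, C) spectral input features (time, frequency, channels)
#   hidden_states: (B, T, K, D) latent array (K = num_latents, D = d_latents)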
import math
from einops import rearrange
from typing import Optional, Tuple, Union, Literal

import torch
from torch import nn
from transformers.models.perceiver.modeling_perceiver import PerceiverSelfOutput
from transformers.pytorch_utils import (apply_chunking_to_forward, find_pruneable_heads_and_indices,
                                        prune_linear_layer)
from model.perceiver_helper import MoEModelOutputWithCrossAttentions, PerceiverTFPreTrainedModel, PerceiverTFConfig
from model.positional_encoding import AlibiPositionalBias, get_rotary_emb
from model.ops import get_layer_norm
from model.ff_layer import get_ff_layer


class PerceiverEmbeddings(nn.Module):
    """Construct the latent embeddings, sharable with token embeddings in the decoder."""

    def __init__(self, config, shared_emb: Optional[nn.Parameter] = None):
        super().__init__()
        if shared_emb is not None:
            self.latents = shared_emb
            assert self.latents.shape == (config.num_latents, config.d_latents)
        else:
            self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents))

    def forward(self, batch_size: int):
        return self.latents.expand(batch_size, -1, -1)


class PerceiverTFTrainablePE(nn.Module):
    """Construct the trainable absolute positional embeddings."""

    def __init__(self, position_encoding_type: Literal['trainable', 'tkd', 'td', 'tk', 'kdt'], max_t: int, k: int,
                 d: int) -> None:
        super().__init__()
        self.position_encoding_type = position_encoding_type
        self.max_t = max_t
        self.k = k
        self.d = d

        if position_encoding_type in ['trainable', 'tkd']:
            self._pos_emb = nn.Parameter(torch.randn(max_t, k, d))
        elif position_encoding_type == 'td':
            self._pos_emb = nn.Parameter(torch.randn(max_t, d))
        elif position_encoding_type == 'tk':
            self._pos_emb = nn.Parameter(torch.randn(max_t, k))
        elif position_encoding_type == 'kdt':
            self._pos_emb = nn.Parameter(torch.randn(k, d))
            self._pos_emb_temporal = nn.Parameter(torch.randn(max_t, d))
        else:
            raise ValueError(f'unknown position encoding type {position_encoding_type}')

    def forward(self):
        pos_emb_temporal = None

        if self.position_encoding_type in ['trainable', 'tkd']:
            pos_emb = self._pos_emb
        elif self.position_encoding_type == 'td':
            pos_emb = self._pos_emb.unsqueeze(1).expand(-1, self.k, -1)
        elif self.position_encoding_type == 'tk':
            pos_emb = self._pos_emb.unsqueeze(-1).expand(-1, -1, self.d)
        elif self.position_encoding_type == 'kdt':
            pos_emb = self._pos_emb.unsqueeze(0).expand(self.max_t, -1, -1)
            pos_emb_temporal = self._pos_emb_temporal
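        # Every variant broadcasts/expands pos_emb to a full (max_t, K, D) tensor;
        # only the 'kdt' variant additionally returns a separate (max_t, D) temporal embedding.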

        return pos_emb, pos_emb_temporal


class PerceiverAlibiSelfAttention(nn.Module):
    """
    Multi-headed {cross, self}-attention with ALiBi/rotary positional bias:
    - Can be used both in the encoder and in the decoder.
    - Modified from PerceiverSelfAttention in modeling_perceiver.py to support ALiBi positional bias.

    """

    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
        rotary_emb=None,
    ):
        super().__init__()
        self.num_heads = num_heads

        # Q and K must have the same number of channels.
        if qk_channels is None:
            qk_channels = q_dim

        # V's number of channels determines the shape of the attention output.
        if v_channels is None:
            v_channels = qk_channels
        if qk_channels % num_heads != 0:
            raise ValueError(f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).")
        if v_channels % num_heads != 0:
            raise ValueError(f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).")

        self.qk_channels = qk_channels
        self.v_channels = v_channels
        self.qk_channels_per_head = self.qk_channels // num_heads
        self.v_channels_per_head = self.v_channels // num_heads

        # Layer normalization
        self.layernorm1 = get_layer_norm(q_dim, config.layer_norm_type, config.layer_norm_eps)
        if is_cross_attention:
            self.layernorm2 = get_layer_norm(kv_dim, config.layer_norm_type, config.layer_norm_eps)
        else:
            self.layernorm2 = nn.Identity()

        # Projection matrices
        self.query = nn.Linear(q_dim, qk_channels)
        self.key = nn.Linear(kv_dim, qk_channels)
        self.value = nn.Linear(kv_dim, v_channels)

        self.dropout = nn.Dropout(config.dropout_rate)

        # ALiBi positional bias ('alibit' makes the per-head slopes trainable)
        if config.position_encoding_type == 'alibi':
            self.alibi_bias = AlibiPositionalBias(heads=num_heads, total_heads=num_heads, trainable_slope=False)
        elif config.position_encoding_type == 'alibit':
            self.alibi_bias = AlibiPositionalBias(heads=num_heads, total_heads=num_heads, trainable_slope=True)
        else:
            self.alibi_bias = None

        # Rotary positional embedding (RoPE)
        if config.position_encoding_type == 'rope':
            assert rotary_emb is not None, "rotary_emb must be provided for RoPE."
            self.rotary_emb = rotary_emb
        else:
            self.rotary_emb = None
        self.rope_apply_to_keys = config.rope_apply_to_keys

    def transpose_for_scores(self, x, channels_per_head):
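        # (B, seq, num_heads * channels_per_head) -> (B, num_heads, seq, channels_per_head)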
        new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        hidden_states = self.layernorm1(hidden_states)
        inputs = self.layernorm2(inputs)  # nn.Identity() passes `None` through for self-attention

        # Project queries, keys and values to a common feature dimension. If this is
        # used as a cross-attention module, the keys and values come from the inputs;
        # the attention mask must ensure the inputs's non-relevant tokens are not attended to.
        is_cross_attention = inputs is not None
        queries = self.query(hidden_states)

        if is_cross_attention:
            keys = self.key(inputs)
            values = self.value(inputs)
            attention_mask = inputs_mask
        else:
            keys = self.key(hidden_states)
            values = self.value(hidden_states)

        # Reshape channels for multi-head attention:
        # (batch_size, time, channels) -> (batch_size, num_heads, time, channels per head)
        queries = self.transpose_for_scores(queries, self.qk_channels_per_head)
        keys = self.transpose_for_scores(keys, self.qk_channels_per_head)
        values = self.transpose_for_scores(values, self.v_channels_per_head)

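        # RoPE: rotate the query (and optionally key) head vectors by position-dependent angles.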
        if self.rotary_emb is not None:
            queries = self.rotary_emb.apply_rotary_custom(queries)
            if self.rope_apply_to_keys is True:
                keys = self.rotary_emb.apply_rotary_custom(keys)

        # Take the dot product between the queries and keys to get the raw attention scores.
        attention_scores = torch.matmul(queries, keys.transpose(-1, -2))

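        # ALiBi adds a per-head linear penalty proportional to the query/key distance
        # |i - j| to the attention scores (Press et al., "Train Short, Test Long").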
        if self.alibi_bias is not None:
            _, _, q_seq_len, k_seq_len = attention_scores.shape
            attention_scores += self.alibi_bias(q_seq_len, k_seq_len)

        _, _, _, q_head_dim = queries.shape
        _, _, _, v_head_dim = values.shape
        hiddens = self.num_heads * v_head_dim

        attention_scores = attention_scores / math.sqrt(q_head_dim)

        if attention_mask is not None:
            # Apply the (precomputed) attention mask.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, values)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (hiddens,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class PerceiverAlibiAttention(nn.Module):
    """
    Attention module, including a dense (output projection) block:
    modified from PerceiverAttention in modeling_perceiver.py to support ALiBi positional bias.
    """

    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
        use_query_residual=True,
        rotary_emb=None,
    ):
        super().__init__()

        # Multi-head attention
        if is_cross_attention and qk_channels is None:
            if config.cross_attention_shape_for_attention == "q":
                qk_channels = q_dim
            elif config.cross_attention_shape_for_attention == "kv":
                qk_channels = kv_dim
            else:
                raise ValueError(f"Unknown value {config.cross_attention_shape_for_attention} for "
                                 "cross_attention_shape_for_attention.")
        else:
            if qk_channels is None:
                qk_channels = q_dim
            if v_channels is None:
                v_channels = qk_channels
        self.self = PerceiverAlibiSelfAttention(config,
                                                is_cross_attention=is_cross_attention,
                                                qk_channels=qk_channels,
                                                v_channels=v_channels,
                                                num_heads=num_heads,
                                                q_dim=q_dim,
                                                kv_dim=kv_dim,
                                                rotary_emb=rotary_emb)

        # Dense block
        if is_cross_attention:
            output_channels = q_dim
        else:
            output_channels = v_channels
        self.output = PerceiverSelfOutput(config, input_channels=self.self.v_channels, output_channels=output_channels)
        self.use_query_residual = use_query_residual
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # Note: PerceiverAlibiSelfAttention stores its head config as `num_heads` /
        # `qk_channels_per_head` (not `num_attention_heads` / `attention_head_size`).
        heads, index = find_pruneable_heads_and_indices(heads, self.self.num_heads, self.self.qk_channels_per_head,
                                                        self.pruned_heads)

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads (assumes v_channels == qk_channels)
        self.self.num_heads = self.self.num_heads - len(heads)
        self.self.qk_channels = self.self.qk_channels_per_head * self.self.num_heads
        self.self.v_channels = self.self.v_channels_per_head * self.self.num_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            inputs,
            inputs_mask,
            output_attentions,
        )

        # Output projection
        attention_output = self.output(self_outputs[0])

        # Optionally include a residual to the original queries.
        # Consider omitting the residual if the semantics of query and output
        # are different, e.g. if queries are positions and outputs are pixels.
        if self.use_query_residual:
            attention_output = attention_output + hidden_states

        outputs = (attention_output,) + self_outputs[1:]
        return outputs


class PerceiverAlibiLayer(nn.Module):
    """Construct a single PerceiverTF layer with:
    - ALiBi positional bias
    - RoPE
    - Mixture of Experts (MoE) feedforward layer

    """

    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
        widening_factor=1,
        use_query_residual=True,
        rotary_emb=None,
    ):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = PerceiverAlibiAttention(config,
                                                 is_cross_attention=is_cross_attention,
                                                 qk_channels=qk_channels,
                                                 v_channels=v_channels,
                                                 num_heads=num_heads,
                                                 q_dim=q_dim,
                                                 kv_dim=kv_dim,
                                                 use_query_residual=use_query_residual,
                                                 rotary_emb=rotary_emb)
        self.layernorm = get_layer_norm(q_dim, config.layer_norm_type, config.layer_norm_eps)

        # Feedforward block; for MoE it returns (hidden_states, router_logits).
        self.mlp = get_ff_layer(config, input_size=q_dim, widening_factor=widening_factor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            inputs,
            inputs_mask,
            output_attentions,
        )
        attention_output = attention_outputs[0]

        outputs = attention_outputs[1:]
        """apply_chunking_to_forward:
        This function chunks the input_tensors into smaller input tensor parts of size
        chunk_size over the dimension chunk_dim. It then applies a layer forward_fn to
        each chunk independently to save memory. If the forward_fn is independent across
        the chunk_dim, this function yields the same result as not applying it.
        """
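        # Note: feed_forward_chunk returns a (hidden_states, router_logits) tuple;
        # apply_chunking_to_forward passes it through unchanged only when
        # config.chunk_size_feed_forward == 0 (no chunking). With chunking enabled,
        # it would try to torch.cat the tuple outputs and fail.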
        layer_output, router_logits = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward,
                                                                self.seq_len_dim, attention_output)

        layer_output = layer_output + attention_output
        outputs = (layer_output,) + outputs + (router_logits,)
        return outputs

    def feed_forward_chunk(self, attention_output):
        layer_output = self.layernorm(attention_output)
        layer_output, router_logits = self.mlp(layer_output)
        return layer_output, router_logits


class PerceiverTFEncoderBlock(nn.Module):
    """Construct a single block of the PerceiverTF encoder:
    - Spectral Cross Attention (SCA)
    - Local latent transformer layers
    - Temporal transformer layers
    - with ALiBi positional bias, RoPE, gMLP and MoE feedforward layers added
    """

    def __init__(self,
                 config: PerceiverTFConfig,
                 kv_dim: Optional[int] = None,
                 sca_use_query_residual: Optional[bool] = True,
                 rotary_emb_sca: Optional[nn.Module] = None,
                 rotary_emb_latent: Optional[nn.Module] = None,
                 rotary_emb_temporal: Optional[nn.Module] = None):
        super().__init__()
        self.config = config

        # Check that multi-head attention is possible with these shapes.
        if config.d_latents % config.num_self_attention_heads != 0:
            raise ValueError(f"num_z_channels ({config.d_latents}) must be divisible by"
                             f" num_self_attend_heads ({config.num_self_attention_heads}).")
        if config.d_latents % config.num_cross_attention_heads != 0:
            raise ValueError(f"num_z_channels ({config.d_latents}) must be divisible by"
                             f" num_cross_attend_heads ({config.num_cross_attention_heads}).")

        if kv_dim is None:
            kv_dim = config.kv_dim
        if sca_use_query_residual is None:
            sca_use_query_residual = config.sca_use_query_residual

        # Spectral Cross Attention (SCA)
        self.sca_attention_to_channel = config.attention_to_channel
        self.spectral_cross_attention = PerceiverAlibiAttention(config,
                                                                is_cross_attention=True,
                                                                qk_channels=config.qk_channels,
                                                                v_channels=config.v_channels,
                                                                num_heads=config.num_cross_attention_heads,
                                                                q_dim=config.d_latents,
                                                                kv_dim=kv_dim,
                                                                use_query_residual=sca_use_query_residual,
                                                                rotary_emb=rotary_emb_sca)

        local_transformer_layers = []
        for _ in range(config.num_local_transformers_per_block):
            layer = PerceiverAlibiLayer(
                config,
                is_cross_attention=False,
                qk_channels=config.qk_channels,
                v_channels=config.v_channels,
                num_heads=config.num_self_attention_heads,
                q_dim=config.d_model,
                kv_dim=config.d_model,
                widening_factor=config.ff_widening_factor,
                use_query_residual=config.use_query_residual,
                rotary_emb=rotary_emb_latent,
            )
            local_transformer_layers.append(layer)
        self.local_transformer = nn.ModuleList(local_transformer_layers)

        temporal_transformer_layers = []
        for _ in range(config.num_temporal_transformers_per_block):
            layer = PerceiverAlibiLayer(
                config,
                is_cross_attention=False,
                qk_channels=config.qk_channels,
                v_channels=config.v_channels,
                num_heads=config.num_self_attention_heads,
                q_dim=config.d_model,
                kv_dim=config.d_model,
                widening_factor=config.ff_widening_factor,
                use_query_residual=config.use_query_residual,
                rotary_emb=rotary_emb_temporal,
            )
            temporal_transformer_layers.append(layer)
        self.temporal_transformer = nn.ModuleList(temporal_transformer_layers)

    def forward(
        self,
        hidden_states: torch.Tensor,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        local_attention_mask: Optional[torch.FloatTensor] = None,
        temporal_attention_mask: Optional[torch.FloatTensor] = None,
        local_head_mask: Optional[torch.FloatTensor] = None,
        temporal_head_mask: Optional[torch.FloatTensor] = None,
        pos_emb_temporal: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, MoEModelOutputWithCrossAttentions]:
        """
        Inputs:
            hidden_states: (B, T, K, D)
            inputs: (B, T, F, C)
        Returns:
            hidden_states: (B, T, K, D)

        Args:
            hidden_states:
                latent_array (B, T, num_latents, d_latents) for SCA. The latent array
                with shape (B, K, D) is expanded by t, and positional embeddings are
                added to it.
            inputs: torch.FloatTensor
                The input sequence of shape (B, T, F, C).
            inputs_mask: torch.FloatTensor
                Only used for SCA. By default, None.
            local_attention_mask:
                Used for local self-attention. By default, None.
            temporal_attention_mask:
                Used for temporal self-attention. By default, None.
            local_head_mask:
                By default, None.
            temporal_head_mask:
                By default, None.
            pos_emb_temporal:
                Optional, shape (max_t, d_latents). Used for temporal self-attention. By default, None.
            output_attentions: bool
                Whether to return attention weights.
            output_hidden_states: bool
                Whether to return all hidden states. If False, only the last hidden
                state is returned.
            output_router_logits: bool
                Whether to return router logits for MoE. If False, only the last hidden
                state is returned.
            return_dict: bool
                Whether to return a MoEModelOutputWithCrossAttentions instead of a tuple.
        """

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None
        all_router_logits = () if output_router_logits else None

        batch_size, t, num_latents, d_latents = hidden_states.size()

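        # Spectral Cross Attention (SCA): fold time into the batch so that each time
        # step's latents attend only to that step's spectral frame.
        # hidden_states: (B, T, K, D) -> (B*T, K, D); inputs: (B, T, F, C) -> (B*T, F, C)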
        hidden_states = rearrange(hidden_states, "b t k d -> (b t) k d")
        inputs = rearrange(inputs, "b t f c -> (b t) f c")

        layer_outputs = self.spectral_cross_attention(
            hidden_states,
            attention_mask=None,
            inputs=inputs,
            inputs_mask=inputs_mask,
            output_attentions=output_attentions,
        )
        hidden_states = layer_outputs[0]

        if output_attentions:
            all_cross_attentions = all_cross_attentions + (layer_outputs[1],)

        for i, layer_module in enumerate(self.local_transformer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = local_head_mask[i] if local_head_mask is not None else None
            layer_outputs = layer_module(
                hidden_states,
                attention_mask=local_attention_mask,
                head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
            if output_router_logits:
                all_router_logits = all_router_logits + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

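        # Temporal transformer layers: fold latents into the batch so attention now
        # runs along the time axis: (B*T, K, D) -> (B*K, T, D)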
        hidden_states = rearrange(hidden_states, "(b t) k d -> (b k) t d", b=batch_size)

        for i, layer_module in enumerate(self.temporal_transformer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = temporal_head_mask[i] if temporal_head_mask is not None else None

            if i == 0 and pos_emb_temporal is not None:
                # Add temporal positional embeddings before the first temporal layer.
                hidden_states = hidden_states + pos_emb_temporal[:t]

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=temporal_attention_mask,
                head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
            if output_router_logits:
                all_router_logits = all_router_logits + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        last_hidden_state = hidden_states

        # Restore the (B, T, K, D) layout before returning.
        last_hidden_state = rearrange(last_hidden_state, "(b k) t d -> b t k d", b=batch_size)

        if not return_dict:
            return tuple(
                v for v in
                [last_hidden_state, all_hidden_states, all_self_attentions, all_cross_attentions, all_router_logits]
                if v is not None)
        return MoEModelOutputWithCrossAttentions(
            last_hidden_state=last_hidden_state,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
            router_logits=all_router_logits,
        )


class PerceiverTFEncoder(PerceiverTFPreTrainedModel):
    """PerceiverTFEncoder is an encoder model based on the Perceiver and Spectral Cross Attention (SCA).

    position_encoding_type: str
        The type of positional encoding to use. One of the following:
        - 'trainable': trainable positional embeddings
        - 'alibi': ALiBi positional bias
        - 'alibit': ALiBi positional bias with a trainable slope for each head
        - 'rope': RoPE (Rotary Positional Encoding)
        (experimental, used with 'trainable')
        - 'tkd': trainable PE (T,K,D) on latents (default for 'trainable')
        - 'td': trainable PE (T,D) on latents
        - 'tk': trainable PE (T,K) on latents
        - 'kdt': trainable PE (K,D) on latents, and (T,D) on the temporal transformer

    """

    def __init__(self,
                 config: PerceiverTFConfig,
                 sca_use_query_residual: Optional[bool] = None,
                 shared_emb: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.config = config

        if sca_use_query_residual is None:
            sca_use_query_residual = config.sca_use_query_residual
        self.sca_use_query_residual = sca_use_query_residual
        self.position_encoding_type = config.position_encoding_type
        self.sca_attention_to_channel = config.attention_to_channel

        self.latent_array = PerceiverEmbeddings(config)

        # RoPE
        if self.position_encoding_type == 'rope':
            self.rotary_emb_sca = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_sca,
                                                 config.rope_partial_pe, config.rope_trainable)
            self.rotary_emb_latent = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_latent,
                                                    config.rope_partial_pe, config.rope_trainable)
            self.rotary_emb_temporal = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_temporal,
                                                      config.rope_partial_pe, config.rope_trainable)
        else:
            self.rotary_emb_sca = None
            self.rotary_emb_latent = None
            self.rotary_emb_temporal = None

        # Trainable absolute positional embeddings (not needed for ALiBi/RoPE)
        if self.position_encoding_type in ['alibi', 'alibit', 'rope', None]:
            self.pos_emb = None
        else:
            k, d = self.latent_array.latents.size()
            max_t = int(config.num_max_positions) + 10
            self.pos_emb = PerceiverTFTrainablePE(self.position_encoding_type, max_t, k, d)
            """
            self.pos_emb() returns:
                pos_emb: (max_t, K, D)
                pos_emb_temporal: (max_t, D), only for the 'kdt' variant
            """

        blocks = []
        for _ in range(config.num_blocks):
            block = PerceiverTFEncoderBlock(config,
                                            kv_dim=config.kv_dim,
                                            sca_use_query_residual=sca_use_query_residual,
                                            rotary_emb_sca=self.rotary_emb_sca,
                                            rotary_emb_latent=self.rotary_emb_latent,
                                            rotary_emb_temporal=self.rotary_emb_temporal)
            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)

        self.post_init()

    def get_input_embeddings(self):
        return self.latent_array.latents

    def set_input_embeddings(self, value):
        self.latent_array.latents = value

    """Temporary fix for a torch.compile issue: the compiled path is used only in
    training; evaluation falls back to the eager path."""

    def forward(self, **kwargs):
        if self.training is True:
            return self._forward_compile(**kwargs)
        else:
            return self._forward_no_compile(**kwargs)

    def _forward_no_compile(self, **kwargs):
        return self._forward(**kwargs)

    @torch.compile
    def _forward_compile(self, **kwargs):
        return self._forward(**kwargs)

    def _forward(
        self,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        local_attention_mask: Optional[torch.FloatTensor] = None,
        temporal_attention_mask: Optional[torch.FloatTensor] = None,
        local_head_mask: Optional[torch.FloatTensor] = None,
        temporal_head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoEModelOutputWithCrossAttentions]:

        # Allow `inputs_embeds` as an alias for `inputs`.
        if inputs is None and inputs_embeds is not None:
            inputs = inputs_embeds
        elif inputs is None and inputs_embeds is None:
            raise ValueError("You must provide either the 'inputs' or 'inputs_embeds' argument.")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, t, _f, _c = inputs.size()
        device = inputs.device

        # When attention_to_channel is set, swap the frequency and channel axes so
        # that SCA attends over channels instead of frequency bins.
        if self.sca_attention_to_channel is True:
            inputs = rearrange(inputs, "b t f c -> b t c f")

        # Prepare head masks; get_head_mask returns [None] * num_layers when the mask is None.
        local_head_mask = self.get_head_mask(local_head_mask,
                                             self.config.num_blocks * self.config.num_local_transformers_per_block)
        temporal_head_mask = self.get_head_mask(
            temporal_head_mask, self.config.num_blocks * self.config.num_temporal_transformers_per_block)

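        # Expand the shared latent array across time: (B, K, D) -> (B, T, K, D).
        # The same latents initialize every time step; positional embeddings (if any)
        # are added next to disambiguate the steps.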
        latent_embeddings = self.latent_array(batch_size=batch_size)
        expanded_latent_embeddings = latent_embeddings.unsqueeze(1).expand(-1, t, -1, -1)

        if self.pos_emb is not None:
            pos_emb_latent, pos_emb_temporal = self.pos_emb.forward()
            expanded_latent_embeddings = expanded_latent_embeddings + pos_emb_latent[:t]
        else:
            pos_emb_temporal = None

        all_hidden_states = []
        all_attentions = []
        all_cross_attentions = []
        all_router_logits = []

        hidden_states = expanded_latent_embeddings

        for i, block in enumerate(self.blocks):
            block_output = block(hidden_states=hidden_states,
                                 inputs=inputs,
                                 inputs_mask=inputs_mask,
                                 local_attention_mask=local_attention_mask,
                                 temporal_attention_mask=temporal_attention_mask,
                                 local_head_mask=local_head_mask,
                                 temporal_head_mask=temporal_head_mask,
                                 pos_emb_temporal=pos_emb_temporal if i == 0 else None,  # only in the first block
                                 output_attentions=output_attentions,
                                 output_hidden_states=output_hidden_states,
                                 output_router_logits=output_router_logits,
                                 return_dict=True)

            hidden_states = block_output.last_hidden_state

            if output_hidden_states:
                all_hidden_states.append(hidden_states)
            if output_attentions:
                all_attentions.append(block_output.attentions)
                all_cross_attentions.append(block_output.cross_attentions)
            if output_router_logits:
                all_router_logits.append(block_output.router_logits)
        last_hidden_states = hidden_states

        if not return_dict:
            return (last_hidden_states, tuple(all_hidden_states) if all_hidden_states else None,
                    tuple(all_attentions) if all_attentions else None,
                    tuple(all_cross_attentions) if all_cross_attentions else None,
                    tuple(all_router_logits) if all_router_logits else None)

        return MoEModelOutputWithCrossAttentions(
            last_hidden_state=last_hidden_states,
            hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
            attentions=tuple(all_attentions) if all_attentions else None,
            cross_attentions=tuple(all_cross_attentions) if all_cross_attentions else None,
            router_logits=tuple(all_router_logits) if all_router_logits else None)


def test():
    # A quick smoke test / parameter-count comparison across PE types.
    from model.ops import count_parameters

    b = 2
    t = 10
    f = 128
    c = 128
    k = 24
    d = 128
    conv_feat = torch.randn(b, t, f, c)

    config = PerceiverTFConfig()
    pe_types = ['alibi', 'alibit', 'trainable', 'tkd', 'td', 'tk', 'kdt', None]
    config.ff_layer_type = 'moe'
    config.moe_num_experts = 4
    config.moe_topk = 2

    for pe_type in pe_types:
        config.position_encoding_type = pe_type
        config.num_latents = k
        config.d_latents = d
        config.kv_dim = c
        config.qk_channels = d
        config.v_channels = d
        encoder = PerceiverTFEncoder(config)
        encoder.eval()
        assert encoder.latent_array.latents.size() == (k, d)

        enc_hidden_state = encoder.forward(inputs_embeds=conv_feat).last_hidden_state

        n_param = count_parameters(encoder)[1] // 1000
        print(config.position_encoding_type, f'num_param: {n_param}K')
""" |
|
PE type | num. param. |
|
None | 1397K |
|
alibi | 1397K |
|
alibit (train slope) | 1397K |
|
tkd | 2442K |
|
td | 1441K |
|
tk | 1405K |
|
kdt | 1444K |
|
|
|
MLP | 2637K |
|
MoE (4 experts) | 4411K |
|
MoE (6 experts) | 5594K |
|
""" |
|
|
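

if __name__ == '__main__':
    # Run the smoke test when this module is executed directly.
    test()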