# YourMT3 / amt / src / model / perceiver_mod.py
# (Hugging Face file-page header residue: uploaded by mimbres, commit a03c9b4, raw size 40.8 kB)
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""perceiver_mod.py
Implementation of the PerceiverTF encoder with:
- ALiBi positional bias (plus optional RoPE or trainable absolute positional embeddings)
- Mixture of Experts (MoE) feedforward layer (Mixtral-style)
"""
import math
from einops import rearrange
from typing import Optional, Tuple, Union, List, Dict, Literal
import torch
from torch import nn
from transformers.models.perceiver.modeling_perceiver import PerceiverSelfOutput
from transformers.pytorch_utils import (apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer)
from model.perceiver_helper import MoEModelOutputWithCrossAttentions
from model.perceiver_helper import PerceiverTFPreTrainedModel, PerceiverTFConfig
from model.positional_encoding import AlibiPositionalBias, get_rotary_emb
from model.ops import get_layer_norm
from model.ff_layer import get_ff_layer
class PerceiverEmbeddings(nn.Module):
    """Learnable latent array of the Perceiver encoder.

    The array has shape ``(config.num_latents, config.d_latents)``. It can be supplied from outside
    through ``shared_emb`` (so it may be shared with the decoder's token embeddings); otherwise a
    fresh parameter is drawn from a standard normal distribution.
    """

    def __init__(self, config, shared_emb: Optional[nn.Parameter] = None):
        super().__init__()
        if shared_emb is None:
            self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents))
            return
        self.latents = shared_emb
        expected_shape = (config.num_latents, config.d_latents)
        assert self.latents.shape == expected_shape

    def forward(self, batch_size: int):
        """Broadcast the latents to ``(batch_size, num_latents, d_latents)`` (a view, no copy)."""
        return self.latents.expand(batch_size, -1, -1)
class PerceiverTFTrainablePE(nn.Module):
    """Trainable absolute positional embeddings for the PerceiverTF latent array.

    Parameters allocated per ``position_encoding_type``:
      - 'trainable' / 'tkd': a single table of shape (max_t, k, d)
      - 'td':  (max_t, d), broadcast over the k latents
      - 'tk':  (max_t, k), broadcast over the d channels
      - 'kdt': (k, d) shared over time, plus a separate temporal table (max_t, d)

    Raises:
        ValueError: for any other ``position_encoding_type``.
    """

    def __init__(self, position_encoding_type: Literal['trainable', 'tkd', 'td', 'tk', 'kdt'], max_t: int, k: int,
                 d: int) -> None:
        super().__init__()
        self.position_encoding_type = position_encoding_type
        self.max_t, self.k, self.d = max_t, k, d

        pe_type = position_encoding_type
        if pe_type == 'trainable' or pe_type == 'tkd':
            table_shape = (max_t, k, d)
        elif pe_type == 'td':
            table_shape = (max_t, d)
        elif pe_type == 'tk':
            table_shape = (max_t, k)
        elif pe_type == 'kdt':
            table_shape = (k, d)
        else:
            raise ValueError(f'unknown position encoding type {position_encoding_type}')
        self._pos_emb = nn.Parameter(torch.randn(*table_shape))
        if pe_type == 'kdt':
            # Extra table consumed by the temporal transformer.
            self._pos_emb_temporal = nn.Parameter(torch.randn(max_t, d))

    def forward(self):
        """Return ``(pos_emb, pos_emb_temporal)``.

        ``pos_emb`` is always shaped (max_t, k, d); broadcast dimensions are expanded views.
        ``pos_emb_temporal`` is the (max_t, d) table for 'kdt' and ``None`` otherwise.
        """
        pe_type = self.position_encoding_type
        table = self._pos_emb
        if pe_type == 'td':
            return table[:, None, :].expand(-1, self.k, -1), None
        if pe_type == 'tk':
            return table[:, :, None].expand(-1, -1, self.d), None
        if pe_type == 'kdt':
            return table[None].expand(self.max_t, -1, -1), self._pos_emb_temporal
        # 'trainable' / 'tkd': the table already has the full (max_t, k, d) shape.
        return table, None
class PerceiverAlibiSelfAttention(nn.Module):
    """Multi-headed {cross, self}-attention with optional ALiBi bias or rotary embeddings.

    - Can be used both in the encoder as well as in the decoder.
    - Modified from ``PerceiverSelfAttention`` in HuggingFace ``modeling_perceiver.py`` to support
      ALiBi positional bias ('alibi' / 'alibit') and RoPE ('rope').

    Args:
        config: Model config. Reads ``layer_norm_type``, ``layer_norm_eps``, ``dropout_rate``,
            ``position_encoding_type`` and ``rope_apply_to_keys``.
        is_cross_attention: If True, ``inputs`` get their own layer norm (``layernorm2``) and provide
            keys/values; otherwise keys/values are computed from ``hidden_states``.
        qk_channels: Projection size of queries and keys. Defaults to ``q_dim``.
        v_channels: Projection size of values (= output size). Defaults to ``qk_channels``.
        num_heads: Number of heads; must divide ``qk_channels`` and ``v_channels``.
        q_dim: Channel size of ``hidden_states``.
        kv_dim: Channel size of the key/value source. Defaults to ``q_dim`` when omitted.
        rotary_emb: Rotary-embedding module; required when ``position_encoding_type == 'rope'``.

    Raises:
        ValueError: If ``qk_channels`` or ``v_channels`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
        rotary_emb=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        # Q and K must have the same number of channels.
        # Default to preserving Q's input's shape.
        if qk_channels is None:
            qk_channels = q_dim
        # V's num_channels determines the shape of the output of QKV-attention.
        # Default to the same number of channels used in the key-query operation.
        if v_channels is None:
            v_channels = qk_channels
        # Without a default, nn.Linear(None, ...) below fails with an opaque TypeError. For
        # self-attention keys/values are projected from hidden_states, whose size is q_dim.
        if kv_dim is None:
            kv_dim = q_dim
        if qk_channels % num_heads != 0:
            raise ValueError(f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).")
        if v_channels % num_heads != 0:
            raise ValueError(f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).")
        self.qk_channels = qk_channels
        self.v_channels = v_channels
        self.qk_channels_per_head = self.qk_channels // num_heads
        self.v_channels_per_head = self.v_channels // num_heads

        # Layer normalization (the key/value source is only normalized separately for cross-attention).
        self.layernorm1 = get_layer_norm(q_dim, config.layer_norm_type, config.layer_norm_eps)
        if is_cross_attention:
            self.layernorm2 = get_layer_norm(kv_dim, config.layer_norm_type, config.layer_norm_eps)
        else:
            self.layernorm2 = nn.Identity()

        # Projection matrices
        self.query = nn.Linear(q_dim, qk_channels)
        self.key = nn.Linear(kv_dim, qk_channels)
        self.value = nn.Linear(kv_dim, v_channels)
        self.dropout = nn.Dropout(config.dropout_rate)

        # (Modified) ALiBi positional bias: fixed slopes for 'alibi', trainable slopes for 'alibit'.
        if config.position_encoding_type == 'alibi':
            self.alibi_bias = AlibiPositionalBias(heads=num_heads, total_heads=num_heads, trainable_slope=False)
        elif config.position_encoding_type == 'alibit':
            self.alibi_bias = AlibiPositionalBias(heads=num_heads, total_heads=num_heads, trainable_slope=True)
        else:
            self.alibi_bias = None

        # (Modified) RoPE
        if config.position_encoding_type == 'rope':
            assert rotary_emb is not None, "rotary_emb must be provided for RoPE."
            self.rotary_emb = rotary_emb
        else:
            self.rotary_emb = None
        self.rope_apply_to_keys = config.rope_apply_to_keys  # False by default

    def transpose_for_scores(self, x, channels_per_head):
        """Reshape ``(batch, seq, num_heads * channels_per_head)`` to ``(batch, num_heads, seq, channels_per_head)``."""
        new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Compute attention of ``hidden_states`` over itself, or over ``inputs`` if given.

        Args:
            hidden_states: Query source, ``(batch, seq_q, q_dim)``.
            attention_mask: Additive mask for self-attention; ignored (replaced by ``inputs_mask``)
                when ``inputs`` is given.
            head_mask: Multiplied into the attention probabilities.
            inputs: Key/value source for cross-attention, ``(batch, seq_kv, kv_dim)``.
            inputs_mask: Additive mask used for cross-attention.
            output_attentions: If True, also return the attention probabilities.

        Returns:
            ``(context,)`` or ``(context, attention_probs)``; ``context`` is
            ``(batch, seq_q, num_heads * v_channels_per_head)``.
        """
        hidden_states = self.layernorm1(hidden_states)
        # layernorm2 is nn.Identity for self-attention, so a None `inputs` passes through unchanged.
        inputs = self.layernorm2(inputs)

        # Project queries, keys and values to a common feature dimension. For cross-attention the
        # keys and values come from the inputs, and the inputs' mask replaces attention_mask.
        is_cross_attention = inputs is not None
        queries = self.query(hidden_states)
        if is_cross_attention:
            keys = self.key(inputs)
            values = self.value(inputs)
            attention_mask = inputs_mask
        else:
            keys = self.key(hidden_states)
            values = self.value(hidden_states)

        # (batch_size, time, channels) -> (batch_size, num_heads, time, channels per head)
        queries = self.transpose_for_scores(queries, self.qk_channels_per_head)
        keys = self.transpose_for_scores(keys, self.qk_channels_per_head)
        values = self.transpose_for_scores(values, self.v_channels_per_head)

        # (Modified) RoPE: applied to queries always, to keys only if rope_apply_to_keys is True.
        if self.rotary_emb is not None:
            queries = self.rotary_emb.apply_rotary_custom(queries)
            if self.rope_apply_to_keys is True:
                keys = self.rotary_emb.apply_rotary_custom(keys)

        # Raw attention scores: (batch, num_heads, seq_q, seq_kv)
        attention_scores = torch.matmul(queries, keys.transpose(-1, -2))

        # (Modified) ALiBi positional bias, broadcast to (b, num_heads, q_seq_len, k_seq_len).
        # NOTE: the bias is added BEFORE the 1/sqrt(head_dim) scaling below, so the effective ALiBi
        # slopes are divided by sqrt(head_dim). Kept as-is for compatibility with trained weights.
        if self.alibi_bias is not None:
            q_seq_len, k_seq_len = attention_scores.shape[-2:]
            attention_scores += self.alibi_bias(q_seq_len, k_seq_len)

        q_head_dim = queries.shape[-1]
        v_head_dim = values.shape[-1]
        hiddens = self.num_heads * v_head_dim

        attention_scores = attention_scores / math.sqrt(q_head_dim)
        if attention_mask is not None:
            # Apply the (additive) attention mask.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = torch.softmax(attention_scores, dim=-1)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, values)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (hiddens,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return (context_layer, attention_probs) if output_attentions else (context_layer,)
class PerceiverAlibiAttention(nn.Module):
    """Attention block: ``PerceiverAlibiSelfAttention`` + output projection + optional query residual.

    Modified from ``PerceiverAttention`` in HuggingFace ``modeling_perceiver.py`` to support ALiBi
    positional bias and RoPE.

    Args:
        config: Model config; ``cross_attention_shape_for_attention`` ('q' or 'kv') selects the
            default ``qk_channels`` for cross-attention.
        is_cross_attention: Whether keys/values come from a separate ``inputs`` tensor.
        qk_channels / v_channels: Projection sizes (see ``PerceiverAlibiSelfAttention``).
        num_heads: Number of attention heads.
        q_dim / kv_dim: Channel sizes of the query source and key/value source.
        use_query_residual: If True, add the (un-normalized) ``hidden_states`` to the output.
        rotary_emb: Rotary-embedding module forwarded to the self-attention.

    Raises:
        ValueError: For an unknown ``cross_attention_shape_for_attention``.
    """

    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
        use_query_residual=True,
        rotary_emb=None,
    ):
        super().__init__()
        # MultiHead attention
        if is_cross_attention and qk_channels is None:
            if config.cross_attention_shape_for_attention == "q":
                qk_channels = q_dim
            elif config.cross_attention_shape_for_attention == "kv":
                qk_channels = kv_dim
            else:
                raise ValueError(f"Unknown value {config.cross_attention_shape_for_attention} for "
                                 "cross_attention_shape_for_attention.")
        elif qk_channels is None:
            qk_channels = q_dim
        if v_channels is None:
            v_channels = qk_channels
        self.self = PerceiverAlibiSelfAttention(config,
                                                is_cross_attention=is_cross_attention,
                                                qk_channels=qk_channels,
                                                v_channels=v_channels,
                                                num_heads=num_heads,
                                                q_dim=q_dim,
                                                kv_dim=kv_dim,
                                                rotary_emb=rotary_emb)
        # Dense output block: cross-attention projects back to the query size; self-attention
        # keeps the value size.
        output_channels = q_dim if is_cross_attention else v_channels
        self.output = PerceiverSelfOutput(config, input_channels=self.self.v_channels, output_channels=output_channels)
        self.use_query_residual = use_query_residual
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads (indices in the original head numbering).

        Queries/keys and values are pruned with separate index sets because their per-head sizes
        may differ. The wrapped self-attention's ``num_heads``, ``qk_channels`` and ``v_channels``
        are updated so its reshapes stay consistent.

        NOTE(review): ``self.self.alibi_bias`` (and possibly ``rotary_emb``) were built for the
        original head count and are not updated here — confirm before pruning a model using them.
        """
        if len(heads) == 0:
            return
        attn = self.self
        heads, qk_index = find_pruneable_heads_and_indices(heads, attn.num_heads, attn.qk_channels_per_head,
                                                           self.pruned_heads)
        _, v_index = find_pruneable_heads_and_indices(heads, attn.num_heads, attn.v_channels_per_head,
                                                      self.pruned_heads)
        # Prune linear layers
        attn.query = prune_linear_layer(attn.query, qk_index)
        attn.key = prune_linear_layer(attn.key, qk_index)
        attn.value = prune_linear_layer(attn.value, v_index)
        self.output.dense = prune_linear_layer(self.output.dense, v_index, dim=1)
        # Update hyper params and store pruned heads
        attn.num_heads = attn.num_heads - len(heads)
        attn.qk_channels = attn.qk_channels_per_head * attn.num_heads
        attn.v_channels = attn.v_channels_per_head * attn.num_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Attend, project the context, and optionally add the query residual.

        Returns ``(attention_output,)`` or ``(attention_output, attention_probs)``.
        """
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            inputs,
            inputs_mask,
            output_attentions,
        )
        # Output projection
        attention_output = self.output(self_outputs[0])
        # Optionally include a residual to the original queries.
        # Consider omitting the residual if the semantics of query and output
        # are different, e.g. if queries are positions and outputs are pixels.
        if self.use_query_residual:
            attention_output = attention_output + hidden_states
        return (attention_output,) + self_outputs[1:]  # add attentions if we output them
class PerceiverAlibiLayer(nn.Module):
    """One PerceiverTF transformer layer: attention followed by a pre-norm feed-forward block.

    The attention part is ``PerceiverAlibiAttention`` (ALiBi bias / RoPE aware). The feed-forward
    part is built by ``get_ff_layer`` (presumably MLP / gMLP / MoE selected by config — confirm).

    ``forward`` returns ``(layer_output, router_logits)``, or
    ``(layer_output, attention_probs, router_logits)`` when ``output_attentions`` is True.
    """

    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
        widening_factor=1,
        use_query_residual=True,
        rotary_emb=None,
    ):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # feed-forward chunking axis: the sequence axis of (batch, seq, channels)
        attention_kwargs = dict(
            is_cross_attention=is_cross_attention,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            q_dim=q_dim,
            kv_dim=kv_dim,
            use_query_residual=use_query_residual,
            rotary_emb=rotary_emb,
        )
        self.attention = PerceiverAlibiAttention(config, **attention_kwargs)
        self.layernorm = get_layer_norm(q_dim, config.layer_norm_type, config.layer_norm_eps)
        self.mlp = get_ff_layer(config, input_size=q_dim, widening_factor=widening_factor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Apply attention, then the residual feed-forward block."""
        attended, *extra_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            inputs,
            inputs_mask,
            output_attentions,
        )
        # apply_chunking_to_forward runs feed_forward_chunk over chunks of `attended` along
        # seq_len_dim to save memory (identical result for position-wise layers).
        # NOTE(review): for chunk_size > 0 it torch.cat's the per-chunk outputs, which cannot handle
        # the (output, router_logits) tuple — this presumably only works with
        # chunk_size_feed_forward == 0; confirm.
        ff_output, router_logits = apply_chunking_to_forward(self.feed_forward_chunk,
                                                             self.chunk_size_feed_forward, self.seq_len_dim,
                                                             attended)
        return (ff_output + attended, *extra_outputs, router_logits)

    def feed_forward_chunk(self, attention_output):
        """Pre-norm feed-forward ``mlp(layernorm(x))``; returns ``(output, router_logits)``."""
        normed = self.layernorm(attention_output)
        # The ff layer always yields a pair; router_logits is only meaningful for MoE layers.
        ff_output, router_logits = self.mlp(normed)
        return ff_output, router_logits
class PerceiverTFEncoderBlock(nn.Module):
    """Construct a single block of PerceiverTF encoder:
    - Spectral Cross Attention (SCA)
    - Local latent transformer layers
    - Temporal transformer layers
    - added Alibi positional bias, RoPE, gMLP and MoE feedforward layer

    Raises:
        ValueError: If ``config.d_latents`` is not divisible by the self- or cross-attention head count.
    """

    def __init__(self,
                 config: PerceiverTFConfig,
                 kv_dim: Optional[int] = None,
                 sca_use_query_residual: bool = True,
                 rotary_emb_sca: Optional[nn.Module] = None,
                 rotary_emb_latent: Optional[nn.Module] = None,
                 rotary_emb_temporal: Optional[nn.Module] = None):
        super().__init__()
        self.config = config
        # Check that we can use multihead-attention with these shapes.
        if config.d_latents % config.num_self_attention_heads != 0:
            raise ValueError(f"num_z_channels ({config.d_latents}) must be divisible by"
                             f" num_self_attend_heads ({config.num_self_attention_heads}).")
        if config.d_latents % config.num_cross_attention_heads != 0:
            raise ValueError(f"num_z_channels ({config.d_latents}) must be divisible by"
                             f" num_cross_attend_heads ({config.num_cross_attention_heads}).")
        if kv_dim is None:
            kv_dim = config.kv_dim
        if sca_use_query_residual is None:
            sca_use_query_residual = config.sca_use_query_residual

        # Spectral Cross Attention (SCA): latents (queries) attend to one frame of the inputs.
        self.sca_attention_to_channel = config.attention_to_channel
        self.spectral_cross_attention = PerceiverAlibiAttention(config,
                                                                is_cross_attention=True,
                                                                qk_channels=config.qk_channels,
                                                                v_channels=config.v_channels,
                                                                num_heads=config.num_cross_attention_heads,
                                                                q_dim=config.d_latents,
                                                                kv_dim=kv_dim,
                                                                use_query_residual=sca_use_query_residual,
                                                                rotary_emb=rotary_emb_sca)  # (Modified) RoPE

        def _make_self_attention_layer(rotary_emb):
            """Build one self-attention PerceiverAlibiLayer over d_model channels."""
            return PerceiverAlibiLayer(
                config,
                is_cross_attention=False,
                qk_channels=config.qk_channels,  # projection dim for q and k.
                v_channels=config.v_channels,  # projection dim for v.
                num_heads=config.num_self_attention_heads,
                q_dim=config.d_model,
                kv_dim=config.d_model,
                widening_factor=config.ff_widening_factor,
                use_query_residual=config.use_query_residual,
                rotary_emb=rotary_emb,  # (Modified) RoPE
            )

        # Local latent transformer layers (attention across the K latents of one frame).
        self.local_transformer = nn.ModuleList(
            [_make_self_attention_layer(rotary_emb_latent) for _ in range(config.num_local_transformers_per_block)])
        # Temporal transformer layers (attention across the T frames of one latent).
        self.temporal_transformer = nn.ModuleList([
            _make_self_attention_layer(rotary_emb_temporal)
            for _ in range(config.num_temporal_transformers_per_block)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        local_attention_mask: Optional[torch.FloatTensor] = None,
        temporal_attention_mask: Optional[torch.FloatTensor] = None,
        local_head_mask: Optional[torch.FloatTensor] = None,
        temporal_head_mask: Optional[torch.FloatTensor] = None,
        pos_emb_temporal: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,  # Only used for MoE.
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, MoEModelOutputWithCrossAttentions]:
        """Run SCA -> local latent transformer layers -> temporal transformer layers.

        Shapes:
            hidden_states: (B, T, K, D) in, (B, T, K, D) out.
            inputs: (B, T, F, C)

        Args:
            hidden_states:
                Latent array (B, T, num_latents, d_latents) used as SCA queries. The caller expands
                the (K, D) latent array over T and may add positional embeddings.
            inputs: torch.FloatTensor
                The input sequence of shape (B, T, F, C); SCA keys/values.
            inputs_mask: torch.FloatTensor
                Additive mask, only used for SCA. By default, None.
            local_attention_mask:
                Additive mask for local self-attention. By default, None.
            temporal_attention_mask:
                Additive mask for temporal self-attention. By default, None.
            local_head_mask:
                Per-layer head masks, indexed by the layer's position within this block. By default, None.
            temporal_head_mask:
                Same as local_head_mask, for the temporal layers. By default, None.
            pos_emb_temporal:
                Optional. Temporal positional embedding of shape (max_t, D); its first T rows are
                added to the input of the first temporal layer. By default, None.
            output_attentions: bool
                Whether to return attention weights.
            output_hidden_states: bool
                Whether to return all intermediate hidden states. Local-stage entries are
                (B*T, K, D); temporal-stage entries are (B*K, T, D).
            output_router_logits: bool
                Whether to collect the router logits returned by every layer (only meaningful
                for MoE feed-forward layers).
            return_dict: bool
                Whether to return a MoEModelOutputWithCrossAttentions instead of a tuple.

        Returns:
            MoEModelOutputWithCrossAttentions, or (if ``return_dict`` is False) a tuple of the non-None
            values among (last_hidden_state, hidden_states, attentions, cross_attentions, router_logits).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None
        all_router_logits = () if output_router_logits else None

        batch_size, t, _, _ = hidden_states.size()  # (B, T, K, D)

        # Fold time into the batch: SCA and the local transformer operate on a single frame.
        hidden_states = rearrange(hidden_states, "b t k d -> (b t) k d")
        inputs = rearrange(inputs, "b t f c -> (b t) f c")

        # Apply the SCA between the latents (hidden_states) and inputs.
        layer_outputs = self.spectral_cross_attention(
            hidden_states,
            attention_mask=None,  # inputs_mask is used instead for cross-attention
            inputs=inputs,
            inputs_mask=inputs_mask,
            output_attentions=output_attentions,
        )
        hidden_states = layer_outputs[0]  # (B*T, K, D)
        if output_attentions:
            all_cross_attentions = all_cross_attentions + (layer_outputs[1],)

        # Apply the block of local latent transformer layers.
        for i, layer_module in enumerate(self.local_transformer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_head_mask = local_head_mask[i] if local_head_mask is not None else None
            layer_outputs = layer_module(
                hidden_states,
                attention_mask=local_attention_mask,
                head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]  # (B*T, K, D)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
            if output_router_logits:
                # Router logits are the LAST element of a PerceiverAlibiLayer output; index 2 only
                # exists when attention weights are returned too.
                all_router_logits = all_router_logits + (layer_outputs[-1],)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Reshape (B*T, K, D) to (B*K, T, D) for the temporal transformer.
        hidden_states = rearrange(hidden_states, "(b t) k d -> (b k) t d", b=batch_size)

        # Apply the block of temporal transformer layers.
        for i, layer_module in enumerate(self.temporal_transformer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_head_mask = temporal_head_mask[i] if temporal_head_mask is not None else None
            if i == 0 and pos_emb_temporal is not None:
                # Add temporal positional embeddings: (T, D) broadcasts over (B*K, T, D).
                hidden_states = hidden_states + pos_emb_temporal[:t]
            layer_outputs = layer_module(
                hidden_states,
                attention_mask=temporal_attention_mask,
                head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]  # (B*K, T, D)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
            if output_router_logits:
                all_router_logits = all_router_logits + (layer_outputs[-1],)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Reshape (B*K, T, D) to (B, T, K, D) for the next block.
        last_hidden_state = rearrange(hidden_states, "(b k) t d -> b t k d", b=batch_size)

        if not return_dict:
            return tuple(
                v for v in
                [last_hidden_state, all_hidden_states, all_self_attentions, all_cross_attentions, all_router_logits]
                if v is not None)
        return MoEModelOutputWithCrossAttentions(
            last_hidden_state=last_hidden_state,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
            router_logits=all_router_logits,
        )
class PerceiverTFEncoder(PerceiverTFPreTrainedModel):
    """PerceiverTFEncoder is an encoder model based on the Perceiver and Spectral Cross Attention (SCA).

    position_encoding_type: str
        The type of positional encoding to use. One of the following:
        - 'trainable': trainable positional embeddings
        - 'alibi': AlibiNet positional embeddings
        - 'alibit': AlibiNet positional embeddings with trainable slopes for each head
        - 'rope': RoPE (Rotary Positional Encoding)
        (experimental w/ 'trainable')
        - 'tkd': trainable PE (T,K,D) on latent (default for 'trainable')
        - 'td': trainable PE (T,D) on latent
        - 'tk': trainable PE (T,K) on latent
        - 'kdt': trainable PE (K,D) on latent, and (T,D) on temporal transformer
    """

    def __init__(self,
                 config: PerceiverTFConfig,
                 sca_use_query_residual: Optional[bool] = None,
                 shared_emb: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.config = config
        # Resolve the SCA residual flag once. Previously the attribute was only set when the argument
        # was None, so passing an explicit value left `self.sca_use_query_residual` undefined.
        if sca_use_query_residual is None:
            sca_use_query_residual = config.sca_use_query_residual  # True by default
        self.sca_use_query_residual = sca_use_query_residual
        self.position_encoding_type = config.position_encoding_type
        self.sca_attention_to_channel = config.attention_to_channel

        # Construct a latent array: (num_latents, d_latents).
        # NOTE(review): `shared_emb` is accepted but not used; the latent array is always created
        # fresh here. Confirm whether sharing with the decoder embedding was intended.
        self.latent_array = PerceiverEmbeddings(config)

        # Positional embeddings for the latent array.
        if self.position_encoding_type == 'rope':
            # (Modified) RoPE
            self.rotary_emb_sca = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_sca,
                                                 config.rope_partial_pe, config.rope_trainable)
            self.rotary_emb_latent = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_latent,
                                                    config.rope_partial_pe, config.rope_trainable)
            self.rotary_emb_temporal = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_temporal,
                                                      config.rope_partial_pe, config.rope_trainable)
        else:
            self.rotary_emb_sca = None
            self.rotary_emb_latent = None
            self.rotary_emb_temporal = None

        if self.position_encoding_type in ['alibi', 'alibit', 'rope', None]:
            # ALiBi is implemented within PerceiverAlibiSelfAttention and activated by config.
            # RoPE is implemented without using self.pos_emb.
            self.pos_emb = None
        else:
            k, d = self.latent_array.latents.size()
            max_t = int(config.num_max_positions) + 10  # 10 is headroom for future task tokens...
            # self.pos_emb() returns (pos_emb: (max_t, K, D), pos_emb_temporal: (max_t, D) for 'kdt', else None)
            self.pos_emb = PerceiverTFTrainablePE(self.position_encoding_type, max_t, k, d)

        # Construct the encoder blocks.
        self.blocks = nn.ModuleList([
            PerceiverTFEncoderBlock(
                config,
                kv_dim=config.kv_dim,
                sca_use_query_residual=sca_use_query_residual,
                rotary_emb_sca=self.rotary_emb_sca,  # (Modified) RoPE
                rotary_emb_latent=self.rotary_emb_latent,
                rotary_emb_temporal=self.rotary_emb_temporal) for _ in range(config.num_blocks)
        ])

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.latent_array.latents

    def set_input_embeddings(self, value):
        self.latent_array.latents = value

    # Temporary fix for a torch.compile issue: only the training path goes through the compiled wrapper.
    def forward(self, **kwargs):
        if self.training is True:
            return self._forward_compile(**kwargs)
        else:
            return self._forward_no_compile(**kwargs)

    def _forward_no_compile(self, **kwargs):
        return self._forward(**kwargs)

    @torch.compile
    def _forward_compile(self, **kwargs):
        return self._forward(**kwargs)

    def _forward(
        self,
        inputs: Optional[torch.FloatTensor] = None,  # (B, T, F, kv_dim)
        inputs_embeds: Optional[torch.FloatTensor] = None,  # (B, T, F, kv_dim)
        inputs_mask: Optional[torch.FloatTensor] = None,  # (B, F) Mask freq. of inputs in SCA.
        local_attention_mask: Optional[torch.FloatTensor] = None,  # (B, K)
        temporal_attention_mask: Optional[torch.FloatTensor] = None,  # (B, T)
        local_head_mask: Optional[torch.FloatTensor] = None,
        temporal_head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoEModelOutputWithCrossAttentions]:
        """Encode ``inputs`` (B, T, F, C) into latents (B, T, K, D) through ``config.num_blocks`` blocks.

        Raises:
            ValueError: If neither ``inputs`` nor ``inputs_embeds`` is given.
        """
        # Inputs and inputs_embeds are tied, and actually the same (following the T5 convention).
        # Inputs are convolutional features from audio. Don't confuse them with the latent
        # embeddings (`self.latent_array.latents`), which are used as the blocks' hidden_states.
        if inputs is None and inputs_embeds is not None:
            inputs = inputs_embeds
        elif inputs is None and inputs_embeds is None:
            raise ValueError("You must provide 'inputs' or 'inputs_embeds' argument.")
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, t, _, _ = inputs.size()

        # SCA attention to channels of inputs, instead of frequency bins.
        if self.sca_attention_to_channel is True:
            inputs = rearrange(inputs, "b t f c -> b t c f")

        # Prepare head masks (1.0 keeps a head). Masks are prepared for every layer of every block;
        # each block below receives only its own slice, indexed by layer position within the block.
        n_local = self.config.num_local_transformers_per_block
        n_temporal = self.config.num_temporal_transformers_per_block
        local_head_mask = self.get_head_mask(local_head_mask, self.config.num_blocks * n_local)
        temporal_head_mask = self.get_head_mask(temporal_head_mask, self.config.num_blocks * n_temporal)

        # Prepare attention mask: not implemented

        # Expand the latent embeddings by t: (B, K, D) --> (B, T, K, D)
        latent_embeddings = self.latent_array(batch_size=batch_size)  # (B, num_latents, d_latents)
        expanded_latent_embeddings = latent_embeddings.unsqueeze(1).expand(-1, t, -1, -1)

        # Add positional embeddings to the expanded latent embeddings: (B, T, K, D)
        if self.pos_emb is not None:
            pos_emb_latent, pos_emb_temporal = self.pos_emb.forward()
            # (max_t, K, D) -> (T, K, D) -> (B, T, K, D) auto-broadcasting
            expanded_latent_embeddings = expanded_latent_embeddings + pos_emb_latent[:t]
        else:
            pos_emb_temporal = None

        all_hidden_states = []
        all_attentions = []
        all_cross_attentions = []
        all_router_logits = []
        hidden_states = expanded_latent_embeddings

        for i, block in enumerate(self.blocks):
            block_output = block(hidden_states=hidden_states,
                                 inputs=inputs,
                                 inputs_mask=inputs_mask,
                                 local_attention_mask=local_attention_mask,
                                 temporal_attention_mask=temporal_attention_mask,
                                 local_head_mask=local_head_mask[i * n_local:(i + 1) * n_local],
                                 temporal_head_mask=temporal_head_mask[i * n_temporal:(i + 1) * n_temporal],
                                 pos_emb_temporal=pos_emb_temporal if i == 0 else None,
                                 output_attentions=output_attentions,
                                 output_hidden_states=output_hidden_states,
                                 output_router_logits=output_router_logits,
                                 return_dict=True)
            hidden_states = block_output.last_hidden_state
            if output_hidden_states:
                all_hidden_states.append(hidden_states)
            if output_attentions:
                all_attentions.append(block_output.attentions)
                all_cross_attentions.append(block_output.cross_attentions)
            if output_router_logits:
                all_router_logits.append(block_output.router_logits)

        last_hidden_states = hidden_states
        hidden_states_out = tuple(all_hidden_states) if all_hidden_states else None
        attentions_out = tuple(all_attentions) if all_attentions else None
        cross_attentions_out = tuple(all_cross_attentions) if all_cross_attentions else None
        router_logits_out = tuple(all_router_logits) if all_router_logits else None

        if not return_dict:
            return (last_hidden_states, hidden_states_out, attentions_out, cross_attentions_out, router_logits_out)
        return MoEModelOutputWithCrossAttentions(last_hidden_state=last_hidden_states,
                                                 hidden_states=hidden_states_out,
                                                 attentions=attentions_out,
                                                 cross_attentions=cross_attentions_out,
                                                 router_logits=router_logits_out)
def test():
    """Smoke test: build PerceiverTFEncoder for several PE types and print parameter counts.

    Dimension naming, following HuggingFace's Perceiver implementation:
      - ``q_dim`` is the latent dimension d_latents of ((B), num_latents, d_latents).
      - ``kv_dim`` is the actual input dimension D of (B, T, D).
      - ``qk_channels`` / ``v_channels`` are attention projection dims: (B, T, D) -> (B, T, C).
    PerceiverTF does not require that projection: it takes a latent tensor (B, num_latents, d_latents)
    and a conv_feat tensor (B, T, F, C); SCA and the local self-attention transformer operate on
    (B*T, F, C), with C = D = d_latents.
    """
    from model.ops import count_parameters

    # Test input sizes
    b, t = 2, 10  # batch, time steps (330 for 6s in paper)
    f, c = 128, 128  # freq / channels of conv_feat
    k, d = 24, 128  # num_latents, d_latents
    conv_feat = torch.randn(b, t, f, c)

    # Construct PerceiverTFEncoder with a 4-expert top-2 MoE feed-forward.
    config = PerceiverTFConfig()
    pe_types = ('alibi', 'alibit', 'trainable', 'tkd', 'td', 'tk', 'kdt', None)
    config.ff_layer_type = 'moe'
    config.moe_num_experts = 4
    config.moe_topk = 2

    for pe_type in pe_types:
        overrides = {
            'position_encoding_type': pe_type,
            'num_latents': k,
            'd_latents': d,
            'kv_dim': c,
            'qk_channels': d,
            'v_channels': d,
        }
        for attr_name, attr_value in overrides.items():
            setattr(config, attr_name, attr_value)

        encoder = PerceiverTFEncoder(config)
        encoder.eval()
        assert encoder.latent_array.latents.size() == (k, d)

        enc_hidden_state = encoder.forward(inputs_embeds=conv_feat).last_hidden_state  # [B, T, K, D]
        n_param = count_parameters(encoder)[1] // 1000
        print(config.position_encoding_type, f'num_param: {n_param}K')

    # Reference parameter counts:
    #   PE type              | num. param.
    #   None                 | 1397K
    #   alibi                | 1397K
    #   alibit (train slope) | 1397K
    #   tkd                  | 2442K
    #   td                   | 1441K
    #   tk                   | 1405K
    #   kdt                  | 1444K
    #   MLP                  | 2637K
    #   MoE (4 experts)      | 4411K
    #   MoE (6 experts)      | 5594K