import math
from typing import Optional, Sequence

import torch
from torch import Tensor, nn

from mobileclip.modules.common.transformer import (
    PositionalEmbedding,
    TransformerEncoder,
    get_normalization_layer,
)
from mobileclip.modules.text.repmixer import RepMixerBlock
from mobileclip import logger


class TextTransformer(nn.Module):
    """Text encoder: a token embedding, an optional positional embedding, a
    stack of transformer (and, for the "mct" variant, RepMixer) blocks, and a
    learned projection to the shared embedding space."""

    def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None:
        super().__init__()

        # Read model configuration.
        model_dim = cfg["dim"]
        no_scale_embedding = cfg.get("no_scale_embedding", False)
        no_pos_embedding = cfg.get("no_pos_embedding", False)
        embed_dropout = cfg.get("embed_dropout", 0.0)
        norm_layer = cfg["norm_layer"]
        variant = cfg["model_name"]
        self.vocab_size = cfg["vocab_size"]
        self.projection_dim = projection_dim
        # Token embedding layer and optional embedding scale.
        # NOTE: embed_scale is stored but not applied in forward_embedding.
        self.embedding_layer = nn.Embedding(
            embedding_dim=model_dim, num_embeddings=self.vocab_size
        )
        self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5

        context_length = cfg["context_length"]
        assert (
            context_length is not None
        ), "Context length can't be None. Please set it in the model configuration."

        # Optional learned positional embedding over the context length.
        self.positional_embedding = (
            None
            if no_pos_embedding
            else PositionalEmbedding(
                num_embeddings=context_length, embedding_dim=model_dim
            )
        )

        self.embedding_dropout = nn.Dropout(p=embed_dropout)
        n_transformer_layers = cfg["n_transformer_layers"]

        # FFN multipliers for each transformer layer. A scalar is broadcast to
        # all layers; otherwise a per-layer sequence is expected.
        ffn_multipliers = cfg["ffn_multiplier_per_layer"]
        if isinstance(ffn_multipliers, (float, int)):
            ffn_multipliers = [ffn_multipliers] * n_transformer_layers

        if not isinstance(ffn_multipliers, Sequence):
            logger.error(
                "{} expects FFN multipliers as a list, whose length is the same as"
                " the number of transformer layers. Got: {}".format(
                    self.__class__.__name__, type(ffn_multipliers)
                )
            )
        elif (
            isinstance(ffn_multipliers, Sequence)
            and len(ffn_multipliers) != n_transformer_layers
        ):
            logger.error(
                "We need an FFN multiplier for each transformer layer. Got {} FFN"
                " multipliers while the number of transformer layers = {}".format(
                    len(ffn_multipliers), n_transformer_layers
                )
            )
        # Round each FFN latent dimension up to the nearest multiple of 16.
        ffn_dims = [
            int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0)
            for ffn_mult in ffn_multipliers
        ]
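        # Example (illustrative): model_dim=512, ffn_mult=3.3
        #   ceil(512 * 3.3 / 16) * 16 = ceil(105.6) * 16 = 106 * 16 = 1696,
        # i.e. each FFN latent dimension is rounded up to a multiple of 16.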

        # Number of attention heads for each transformer layer. An int is
        # broadcast to all layers; otherwise a per-layer sequence is expected.
        mha_heads = cfg["n_heads_per_layer"]
        if isinstance(mha_heads, int):
            mha_heads = [mha_heads] * n_transformer_layers

        if not isinstance(mha_heads, Sequence):
            logger.error(
                "{} expects MHA heads as a list, whose length is the same as the"
                " number of transformer layers. Got: {}".format(
                    self.__class__.__name__, type(mha_heads)
                )
            )
        elif (
            isinstance(mha_heads, Sequence)
            and len(mha_heads) != n_transformer_layers
        ):
            logger.error(
                "{} needs MHA heads for each transformer layer. Got {} MHA heads"
                " while the number of transformer layers = {}".format(
                    self.__class__.__name__, len(mha_heads), n_transformer_layers
                )
            )
        if variant == "base":
            # Plain stack of transformer encoder blocks.
            self.transformer = nn.ModuleList(
                [
                    TransformerEncoder(
                        embed_dim=model_dim,
                        num_heads=mha_heads[layer_idx],
                        ffn_latent_dim=ffn_dims[layer_idx],
                        transformer_norm_layer=norm_layer,
                    )
                    for layer_idx in range(n_transformer_layers)
                ]
            )
        elif variant == "mct":
            # "mct" variant: a RepMixer block before and after the transformer
            # encoder stack.
            self.transformer = nn.ModuleList([RepMixerBlock(dim=model_dim)])
            self.transformer.extend(
                [
                    TransformerEncoder(
                        embed_dim=model_dim,
                        num_heads=mha_heads[layer_idx],
                        ffn_latent_dim=ffn_dims[layer_idx],
                        transformer_norm_layer=norm_layer,
                    )
                    for layer_idx in range(n_transformer_layers)
                ]
            )
            self.transformer.extend([RepMixerBlock(dim=model_dim)])
        else:
            raise ValueError("Unrecognized text encoder variant {}".format(variant))

        self.final_layer_norm = get_normalization_layer(
            num_features=model_dim, norm_type=norm_layer
        )

        # Learned projection from the model dimension to the shared embedding space.
        self.projection_layer = nn.Parameter(
            torch.empty(model_dim, self.projection_dim)
        )
        self.model_dim = model_dim
        self.causal_masking = cfg["causal_masking"]

    def forward_embedding(self, text_tokens: Tensor) -> Tensor:
        """Return text embeddings for all tokens.

        Args:
            text_tokens: A tensor of token indices. Shape: [batch_size, context_length]

        Returns:
            A tensor of shape [batch_size, context_length, hidden_dim].
        """
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.embedding_layer(text_tokens)
        seq_len = token_emb.shape[1]
        if self.positional_embedding is not None:
            token_emb = token_emb + self.positional_embedding(seq_len).to(
                token_emb.dtype
            )
        token_emb = self.embedding_dropout(token_emb)
        return token_emb

    def build_attention_mask(self, context_length: int, batch_size: int) -> Tensor:
        """Build causal attention mask [batch_size, context_length, context_length]."""
        # PyTorch uses an additive attention mask: -inf above the diagonal blocks
        # attention to future tokens, zeros elsewhere leave attention unchanged.
        mask = torch.empty(context_length, context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and everything below it
        mask = mask.unsqueeze(0)  # add a batch dimension
        mask = mask.expand(batch_size, -1, -1)
        return mask
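    # Illustration (not part of the model): for context_length=3 the mask above is
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so that, with additive attention scores, position i can only attend to
    # positions <= i.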

    def encode_text(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs
    ) -> Tensor:
        """Return text token embeddings.

        Args:
            text_tokens: A tensor of token indices. Shape: [batch_size, context_length]
            key_padding_mask: A tensor of boolean values used as the padding mask.
                Shape: [batch_size, context_length]
            return_all_tokens: If True, return embeddings for all tokens; defaults to
                False, which returns only the projected EOT token embedding.

        Returns:
            A tensor of shape [batch_size, context_length, hidden_dim] if
            return_all_tokens is True, otherwise a tensor of shape
            [batch_size, projection_dim].
        """
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.forward_embedding(text_tokens)

        # Build the causal mask if requested; it replaces the key padding mask.
        attn_mask = None
        if self.causal_masking:
            attn_mask = self.build_attention_mask(
                context_length=text_tokens.shape[1], batch_size=text_tokens.shape[0]
            )
            attn_mask = attn_mask.to(device=token_emb.device, dtype=token_emb.dtype)
            key_padding_mask = None

        # Run the token embeddings through the transformer (and RepMixer) blocks.
        for layer in self.transformer:
            token_emb = layer(
                token_emb,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )

        token_emb = self.final_layer_norm(token_emb)

        if return_all_tokens:
            return token_emb

        # Take the embedding at the EOT position (argmax works because the EOT
        # token is assumed to have the highest token id, as in CLIP) and project
        # it to the shared embedding space.
        token_emb = token_emb[
            torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)
        ]
        token_emb = token_emb @ self.projection_layer
        return token_emb
    def forward(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs
    ) -> Tensor:
        # [batch_size, context_length] --> [batch_size, projection_dim]
        # (or [batch_size, context_length, hidden_dim] if return_all_tokens=True)
        text_tokens = self.encode_text(
            text_tokens=text_tokens,
            key_padding_mask=key_padding_mask,
            return_all_tokens=return_all_tokens,
            *args,
            **kwargs
        )
        return text_tokens
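

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The config keys mirror the ones read
# in TextTransformer.__init__ above; the concrete values are assumptions for
# demonstration and do not correspond to a released MobileCLIP configuration.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    example_cfg = {
        "dim": 512,  # model / hidden dimension (assumed)
        "vocab_size": 49408,  # tokenizer vocabulary size (assumed)
        "context_length": 77,  # maximum sequence length (assumed)
        "norm_layer": "layer_norm",  # normalization layer name (assumed)
        "model_name": "base",  # "base" or "mct"
        "n_transformer_layers": 4,
        "ffn_multiplier_per_layer": 4.0,
        "n_heads_per_layer": 8,
        "causal_masking": True,
    }
    model = TextTransformer(cfg=example_cfg, projection_dim=512)
    tokens = torch.randint(
        0, example_cfg["vocab_size"], size=(2, example_cfg["context_length"])
    )
    # Expected output shape: [2, 512] (projected EOT embeddings).
    print(model(tokens).shape)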