|
import copy
import math
from typing import Optional

import torch
from torch import nn, Tensor

from torchtune.modules import (
    CausalSelfAttention,
    FeedForward,
    KVCache,
    RMSNorm,
    RotaryPositionalEmbeddings,
    TransformerDecoderLayer,
)

from masked_apply import MaskedApply
|
|
# Initializers for the per-color linear layers applied to the token embeddings
# (selected by name via INITIALIZATION_OPTIONS below).
def initialize_identity_linear(size):
    layer = nn.Linear(size, size)
    layer.weight.data.copy_(torch.eye(size))
    layer.bias.data.copy_(torch.zeros(size))
    return layer
|
|
def initialize_linear(size):
    return nn.Linear(size, size)
|
|
def initialize_kaiming_uniform_linear(size):
    layer = nn.Linear(size, size)
    nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
    layer.bias.data.copy_(torch.zeros(size))
    return layer
|
|
def initialize_zeros_linear(size):
    layer = nn.Linear(size, size)
    layer.weight.data.copy_(torch.zeros(size, size))
    layer.bias.data.copy_(torch.zeros(size))
    return layer
|
|
INITIALIZATION_OPTIONS = {
    "identity": initialize_identity_linear,
    "default": initialize_linear,
    "kaiming_uniform": initialize_kaiming_uniform_linear,
    "zeros": initialize_zeros_linear,
}
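# Example: INITIALIZATION_OPTIONS["identity"](4096) returns a 4096 x 4096 nn.Linear
# whose weight is the identity matrix and whose bias is zero, so the color
# transform initially passes its input through unchanged.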
|
|
|
def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    """
    Return a list of ``n`` identical layers.

    Args:
        module (nn.Module): module to be cloned
        n (int): number of clones

    Returns:
        nn.ModuleList: list of ``n`` identical layers
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
|
|
class ColoringTransformerDecoder(nn.Module):
    """
    A TransformerDecoder variant that adds a color-conditioned residual to the
    token embeddings before the decoder layers: each token's embedding is
    (optionally) normalized, transformed by ``embedding_transform`` conditioned
    on the ``colors`` argument to ``forward``, and the result is added back to
    the original embedding.

    See torchtune.models.llama2.TransformerDecoder for the original implementation.
    """

    def __init__(
        self,
        tok_embeddings: nn.Embedding,
        embedding_transform: nn.Module,
        layer: TransformerDecoderLayer,
        num_layers: int,
        norm: nn.Module,
        output: nn.Linear,
        embedding_norm: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.tok_embeddings = tok_embeddings
        self.embedding_transform = embedding_transform
        self.embedding_norm = embedding_norm
        self.layers = _get_clones(layer, num_layers)
        self.norm = norm
        self.output = output

    def forward(
        self,
        tokens: Tensor,
        mask: Optional[Tensor] = None,
        colors: Optional[Tensor] = None,
        curr_pos: int = 0,
    ) -> Tensor:
        """
        Args:
            tokens (Tensor): input tensor with shape [b x s]
            mask (Optional[Tensor]): attention mask tensor, defaults to None.
            colors (Optional[Tensor]): per-token color indices consumed by
                ``embedding_transform``, defaults to None.
            curr_pos (int): current position in the seq, defaults to 0.
                Only relevant when incrementally decoding.

        Returns:
            Tensor: output tensor with shape [b x s x v]

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - v: vocab size
            - d: embed dim
        """
        bsz, seq_len = tokens.shape

        # [b, s] -> [b, s, d]
        h = self.tok_embeddings(tokens)

        # Optionally normalize the embeddings before the color transform.
        ch = h
        if self.embedding_norm is not None:
            ch = self.embedding_norm(h.clone())

        # Apply the per-color transform and add the result back as a residual.
        ch = self.embedding_transform(ch, colors)
        h = h + ch

        if seq_len > 1 and self.layers[0].attn.kv_cache is not None:
            mask = torch.full(
                (1, 1, seq_len, seq_len), float("-inf"), device=tokens.device
            )
            mask = torch.triu(mask, diagonal=curr_pos + 1)

        for layer in self.layers:
            # [b, s, d]
            h = layer(h, mask, curr_pos)

        # [b, s, d]
        h = self.norm(h)

        # [b, s, v]
        output = self.output(h).float()
        return output
|
|
def coloring_llama2_7b(
    color_layer_initialization: str,
    norm_before_color_layer: bool = False,
    max_batch_size: Optional[int] = None,
) -> ColoringTransformerDecoder:
    """Builder for creating a Llama2 model initialized with the default 7B parameter values.

    From https://arxiv.org/abs/2307.09288, these default values are:
    - vocab_size: 32,000
    - embed_dim: 4,096
    - num_layers: 32
    - num_heads: 32
    - num_kv_heads: 32
    - max_seq_len: 4,096
    - norm_eps: 1e-5

    Args:
        color_layer_initialization (str): how to initialize the per-color linear layers.
            One of "identity", "default", "kaiming_uniform", or "zeros".
        norm_before_color_layer (bool): whether to apply RMSNorm to the embeddings
            before the color transform. Default: False.
        max_batch_size (Optional[int]): Maximum batch size to be passed to KVCache.

    Returns:
        A ``ColoringTransformerDecoder`` instance of the Llama2 model.
    """
    return coloring_llama2(
        color_layer_initialization=color_layer_initialization,
        vocab_size=32_000,
        num_layers=32,
        num_heads=32,
        num_kv_heads=32,
        embed_dim=4096,
        max_seq_len=4096,
        num_colors=4,
        max_batch_size=max_batch_size,
        attn_dropout=0.0,
        norm_eps=1e-5,
        norm_before_color_layer=norm_before_color_layer,
    )
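# A minimal usage sketch (illustrative, not from the original source). The ``colors``
# tensor is assumed to hold one integer color index per token, which MaskedApply
# uses to pick a per-color linear layer in ColoringTransformerDecoder.forward:
#
#   model = coloring_llama2_7b("zeros", norm_before_color_layer=True)
#   logits = model(tokens, colors=colors)  # tokens: [b, s] -> logits: [b, s, 32000]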
|
|
|
def _scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
    """Scale hidden dimension for MLP to keep number of parameters and computation constant.

    Args:
        dim (int): Input dimension.
        multiple_of (int): Round scaled dimension up to the nearest multiple of `multiple_of`
            for clean computation.

    Returns:
        Scaled hidden dimension.
    """
    # Scale hidden dimension by (2/3) * 4 to keep the parameter count of the gated
    # FeedForward roughly constant, then round up to a multiple of `multiple_of`.
    hidden_dim = 4 * int(2 * dim / 3)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim
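# Worked example: for embed_dim = 4096, 4 * int(2 * 4096 / 3) = 10920, which rounds
# up to the next multiple of 256, giving hidden_dim = 11008 (the Llama2-7B FFN width).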
|
|
|
|
|
def coloring_llama2(
    color_layer_initialization: str,
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    num_colors: int,
    norm_before_color_layer: bool = False,
    attn_dropout: float = 0.0,
    max_batch_size: Optional[int] = None,
    norm_eps: float = 1e-5,
) -> ColoringTransformerDecoder:
    """Build a ``ColoringTransformerDecoder`` with the given Llama2 hyperparameters and
    ``num_colors`` per-color linear layers (initialized per ``color_layer_initialization``)
    applied to the token embeddings.
    """
    if color_layer_initialization not in INITIALIZATION_OPTIONS:
        raise ValueError(
            f"Invalid color_layer_initialization: {color_layer_initialization}. "
            f"Expected one of {list(INITIALIZATION_OPTIONS.keys())}."
        )
    color_layer_initializer = INITIALIZATION_OPTIONS[color_layer_initialization]

    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    kv_cache = (
        KVCache(
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            n_kv_heads=num_heads,
            head_dim=head_dim,
        )
        if max_batch_size is not None
        else None
    )
    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
    self_attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=rope,
        kv_cache=kv_cache,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    hidden_dim = _scale_hidden_dim_for_mlp(embed_dim)
    mlp = FeedForward(dim=embed_dim, hidden_dim=hidden_dim, linear_class=nn.Linear)
    layer = TransformerDecoderLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
    # One linear layer per color; MaskedApply is expected to select among them
    # based on the per-token ``colors`` passed to forward().
    embedding_transform = MaskedApply(
        [color_layer_initializer(embed_dim) for _ in range(num_colors)],
        strict=True,
    )
    embedding_norm = RMSNorm(embed_dim, eps=norm_eps) if norm_before_color_layer else None

    return ColoringTransformerDecoder(
        tok_embeddings=tok_embeddings,
        embedding_transform=embedding_transform,
        embedding_norm=embedding_norm,
        layer=layer,
        num_layers=num_layers,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )
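# A small-configuration sketch for quick smoke tests (illustrative values, not from
# the original source):
#
#   model = coloring_llama2(
#       color_layer_initialization="identity",
#       vocab_size=128,
#       num_layers=2,
#       num_heads=4,
#       num_kv_heads=4,
#       embed_dim=64,
#       max_seq_len=256,
#       num_colors=4,
#   )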
|
|