from typing import Optional
import copy
import math
import torch
from torch import nn, Tensor
from torchtune.modules import (
CausalSelfAttention,
FeedForward,
KVCache,
RMSNorm,
RotaryPositionalEmbeddings,
# TransformerDecoder, replaced with our custom implementation.
TransformerDecoderLayer,
)
from masked_apply import MaskedApply
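# NOTE (assumption, not verified against masked_apply.py): MaskedApply is constructed
# below with one linear layer per color and called as embedding_transform(x, colors),
# so it is expected to route each position through the layer selected by its color id,
# roughly:
#
#     out = torch.zeros_like(x)
#     for i, layer in enumerate(layers):
#         selected = (colors == i).unsqueeze(-1)
#         out = torch.where(selected, layer(x), out)
#
# With strict=True it presumably rejects color ids outside [0, len(layers)).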
def initialize_identity_linear(size):
    """Linear layer initialized to the identity map (weight = I, bias = 0)."""
    layer = nn.Linear(size, size)
    layer.weight.data.copy_(torch.eye(size))
    layer.bias.data.copy_(torch.zeros(size))
    return layer
def initialize_linear(size):
    """Linear layer with PyTorch's default initialization."""
    return nn.Linear(size, size)
def initialize_kaiming_uniform_linear(size):
    """Linear layer with Kaiming-uniform weights and zero bias."""
    layer = nn.Linear(size, size)
    nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
    layer.bias.data.copy_(torch.zeros(size))
    return layer
def initialize_zeros_linear(size):
    """Linear layer initialized to all zeros (weight = 0, bias = 0)."""
    layer = nn.Linear(size, size)
    layer.weight.data.copy_(torch.zeros(size))
    layer.bias.data.copy_(torch.zeros(size))
    return layer
INITIALIZATION_OPTIONS = {
"identity": initialize_identity_linear,
"default": initialize_linear,
"kaiming_uniform": initialize_kaiming_uniform_linear,
"zeros": initialize_zeros_linear,
}
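# Illustrative sketch of how the initializers above differ (not executed here):
#
#     x = torch.randn(2, 8)
#     torch.allclose(initialize_identity_linear(8)(x), x)   # identity: output == input
#     bool((initialize_zeros_linear(8)(x) == 0).all())      # zeros: output == 0
#
# "identity" makes each color transform start as a pass-through of its (possibly
# normalized) input, while "zeros" makes its additive contribution start at zero.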
def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
"""
Return a list of ``n`` identical layers.
Args:
module (nn.Module): module to be cloned
n (int): number of clones
Returns:
nn.ModuleList: list of ``n`` identical layers
"""
    # FIXME: copy.deepcopy() is not guaranteed to work for every nn.Module.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
class ColoringTransformerDecoder(nn.Module):
"""
See torchtune.models.llama2.TransformerDecoder for the original implementation.
"""
def __init__(
self,
tok_embeddings: nn.Embedding,
embedding_transform: nn.Module,
layer: TransformerDecoderLayer,
num_layers: int,
norm: nn.Module,
output: nn.Linear,
        embedding_norm: Optional[nn.Module] = None,
) -> None:
super().__init__()
self.tok_embeddings = tok_embeddings
self.embedding_transform = embedding_transform
self.embedding_norm = embedding_norm
self.layers = _get_clones(layer, num_layers)
self.norm = norm
self.output = output
def forward(
self,
tokens: Tensor,
mask: Optional[Tensor] = None,
colors: Optional[Tensor] = None,
curr_pos: int = 0
) -> Tensor:
"""
Args:
tokens (Tensor): input tensor with shape [b x s]
            mask (Optional[Tensor]): attention mask tensor, defaults to None.
            colors (Optional[Tensor]): per-token color ids with shape [b x s], used to
                select which color-specific embedding transform is applied to each
                token. Defaults to None.
            curr_pos (int): current position in the seq, defaults to 0.
                Only relevant when incrementally decoding.
Returns:
Tensor: output tensor with shape [b x s x v]
Notation used for tensor shapes:
- b: batch size
- s: sequence length
- v: vocab size
- d: embed dim
"""
# input tensor of shape [b, s]
bsz, seq_len = tokens.shape
# shape: [b, s, d]
h = self.tok_embeddings(tokens)
# Apply normalization before embedding transform to improve
# training stability.
ch = h
if self.embedding_norm is not None:
# TODO: norm does an in-place operation, so we need to clone the input
ch = self.embedding_norm(h.clone())
# Apply the embedding transform (e.g. color layer)
ch = self.embedding_transform(ch, colors)
# Add the output of the color transform to the embeddings
h = h + ch
# TODO: Fix the masking logic to not rely on checking kv_cache
if seq_len > 1 and self.layers[0].attn.kv_cache is not None:
mask = torch.full(
(1, 1, seq_len, seq_len), float("-inf"), device=tokens.device
)
mask = torch.triu(mask, diagonal=curr_pos + 1)
for layer in self.layers:
# shape: [b, s, d]
h = layer(h, mask, curr_pos)
# shape: [b, s, d]
h = self.norm(h)
# shape: [b, s, v]
output = self.output(h).float()
return output
def coloring_llama2_7b(color_layer_initialization: str, norm_before_color_layer: bool = False, max_batch_size: Optional[int] = None) -> ColoringTransformerDecoder:
"""Builder for creating a Llama2 model initialized w/ the default 7b parameter values.
From https://arxiv.org/abs/2307.09288, these default values are:
- vocab_size: 32,000
- embed_dim: 4,096
- num_layers: 32
- num_heads: 32
- num_kv_heads: 32
- max_seq_len: 4,096
- norm_eps: 1e-5
Args:
max_batch_size (Optional[int]): Maximum batch size to be passed to KVCache.
Returns:
A ``TransformerDecoder`` instance of the Llama2 model.
"""
return coloring_llama2(
color_layer_initialization=color_layer_initialization,
vocab_size=32_000,
num_layers=32,
num_heads=32,
num_kv_heads=32,
embed_dim=4096,
max_seq_len=4096,
num_colors=4, # color for default, instruction, input, response
max_batch_size=max_batch_size,
attn_dropout=0.0,
norm_eps=1e-5,
norm_before_color_layer=norm_before_color_layer
)
def _scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
"""Scale hidden dimension for MLP to keep number of parameters and computation constant.
Args:
dim (int): Input dimension.
        multiple_of (int): Round the scaled dimension up to the nearest multiple of `multiple_of` for clean computation.
Returns:
Scaled hidden dimension.
"""
# Scale hidden dimension by (2/3)4d for SwiGLU to keep number of
# parameters and computation constant
hidden_dim = 4 * int(2 * dim / 3)
    # Round hidden dimension up to the nearest multiple of `multiple_of`
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
return hidden_dim
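# Worked example for _scale_hidden_dim_for_mlp, matching the Llama2-7B config:
# for dim=4096, 4 * int(2 * 4096 / 3) = 10920, which rounds up to the next
# multiple of 256, giving hidden_dim = 11008.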
def coloring_llama2(
color_layer_initialization: str,
vocab_size: int,
num_layers: int,
num_heads: int,
num_kv_heads: int,
embed_dim: int,
max_seq_len: int,
num_colors: int,
norm_before_color_layer: bool = False,
attn_dropout: float = 0.0,
max_batch_size: Optional[int] = None,
norm_eps: float = 1e-5,
) -> ColoringTransformerDecoder:
    """Build a ``ColoringTransformerDecoder`` from the given Llama2 hyperparameters,
    with ``num_colors`` color-specific embedding transforms applied to the token
    embeddings."""
if color_layer_initialization not in INITIALIZATION_OPTIONS:
raise ValueError(f"Invalid color_layer_initialization: {color_layer_initialization}. Expected one of {list(INITIALIZATION_OPTIONS.keys())}.")
color_layer_initializer = INITIALIZATION_OPTIONS[color_layer_initialization]
head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads
kv_cache = (
KVCache(
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
n_kv_heads=num_heads,
head_dim=head_dim,
)
if max_batch_size is not None
else None
)
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
self_attn = CausalSelfAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
pos_embeddings=rope,
kv_cache=kv_cache,
max_seq_len=max_seq_len,
attn_dropout=attn_dropout,
)
hidden_dim = _scale_hidden_dim_for_mlp(embed_dim)
mlp = FeedForward(dim=embed_dim, hidden_dim=hidden_dim, linear_class=nn.Linear)
layer = TransformerDecoderLayer(
attn=self_attn,
mlp=mlp,
sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
)
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
embedding_transform = MaskedApply(
[color_layer_initializer(embed_dim) for _ in range(num_colors)],
strict=True
)
embedding_norm = RMSNorm(embed_dim, eps=norm_eps) if norm_before_color_layer else None
return ColoringTransformerDecoder(
tok_embeddings=tok_embeddings,
embedding_transform=embedding_transform,
embedding_norm=embedding_norm,
layer=layer,
num_layers=num_layers,
norm=RMSNorm(embed_dim, eps=norm_eps),
output=output_proj,
)
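

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only): build a tiny model with the same
    # wiring as the 7B builder above and run one forward pass with per-token colors.
    # The hyperparameters below are made up for speed; this assumes the torchtune
    # modules and the local MaskedApply behave as they are used in this file.
    model = coloring_llama2(
        color_layer_initialization="zeros",
        vocab_size=128,
        num_layers=2,
        num_heads=4,
        num_kv_heads=4,
        embed_dim=64,
        max_seq_len=32,
        num_colors=4,
        norm_before_color_layer=True,
    )
    tokens = torch.randint(0, 128, (2, 8))  # [b=2, s=8]
    colors = torch.randint(0, 4, (2, 8))    # per-token color ids in [0, num_colors)
    logits = model(tokens, colors=colors)
    print(logits.shape)                     # expected: torch.Size([2, 8, 128])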