import copy
from typing import Optional

import torch
from torch import nn, Tensor

from torchtune.modules import (
    CausalSelfAttention,
    FeedForward,
    KVCache,
    RMSNorm,
    RotaryPositionalEmbeddings,
    # TransformerDecoder, replaced with our custom implementation.
    TransformerDecoderLayer,
)

from masked_apply import MaskedApply


def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    """
    Return a list of ``n`` identical layers.

    Args:
        module (nn.Module): module to be cloned
        n (int): number of clones

    Returns:
        nn.ModuleList: list of ``n`` identical layers
    """
    # FIXME: copy.deepcopy() is not defined on nn.Module
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


class ColoringTransformerDecoder(nn.Module):
    """
    See torchtune.models.llama2.TransformerDecoder for the original
    implementation.
    """

    def __init__(
        self,
        tok_embeddings: nn.Embedding,
        embedding_transform: nn.Module,
        layer: TransformerDecoderLayer,
        num_layers: int,
        norm: nn.Module,
        output: nn.Linear,
    ) -> None:
        super().__init__()
        self.tok_embeddings = tok_embeddings
        self.embedding_transform = embedding_transform
        self.layers = _get_clones(layer, num_layers)
        self.norm = norm
        self.output = output

    def forward(
        self,
        tokens: Tensor,
        mask: Optional[Tensor] = None,
        colors: Optional[Tensor] = None,
        curr_pos: int = 0,
    ) -> Tensor:
        """
        Args:
            tokens (Tensor): input tensor with shape [b x s]
            mask (Optional[Tensor]): attention mask tensor, defaults to None.
            colors (Optional[Tensor]): per-token color indices with shape [b x s],
                used by ``embedding_transform`` to select which transform is applied
                to each token embedding. Defaults to None.
            curr_pos (int): current position in the seq, defaults to 0.
                Only relevant when incrementally decoding.

        Returns:
            Tensor: output tensor with shape [b x s x v]

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - v: vocab size
            - d: embed dim
        """
        # input tensor of shape [b, s]
        bsz, seq_len = tokens.shape

        # shape: [b, s, d]
        h = self.tok_embeddings(tokens)
        h = self.embedding_transform(h, colors)

        # TODO: Fix the masking logic to not rely on checking kv_cache
        if seq_len > 1 and self.layers[0].attn.kv_cache is not None:
            mask = torch.full(
                (1, 1, seq_len, seq_len), float("-inf"), device=tokens.device
            )
            mask = torch.triu(mask, diagonal=curr_pos + 1)

        for layer in self.layers:
            # shape: [b, s, d]
            h = layer(h, mask, curr_pos)

        # shape: [b, s, d]
        h = self.norm(h)

        # shape: [b, s, v]
        output = self.output(h).float()
        return output


def colouring_llama2_7b(
    max_batch_size: Optional[int] = None,
) -> ColoringTransformerDecoder:
    """Builder for creating a Llama2 model initialized w/ the default 7B parameter values.

    From https://arxiv.org/abs/2307.09288, these default values are:
        - vocab_size: 32,000
        - embed_dim: 4,096
        - num_layers: 32
        - num_heads: 32
        - num_kv_heads: 32
        - max_seq_len: 4,096
        - norm_eps: 1e-5

    Args:
        max_batch_size (Optional[int]): Maximum batch size to be passed to KVCache.

    Returns:
        A ``ColoringTransformerDecoder`` instance of the Llama2 model.
    """
    return colouring_llama2(
        vocab_size=32_000,
        num_layers=32,
        num_heads=32,
        num_kv_heads=32,
        embed_dim=4096,
        max_seq_len=4096,
        num_colors=4,  # one color each for default, instruction, input, response
        max_batch_size=max_batch_size,
        attn_dropout=0.0,
        norm_eps=1e-5,
    )


def _scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
    """Scale hidden dimension for MLP to keep number of parameters and computation constant.

    Args:
        dim (int): Input dimension.
        multiple_of (int): Round scaled dimension to nearest multiple of
            `multiple_of` for clean computation.

    Returns:
        Scaled hidden dimension.
    """
    # Scale hidden dimension by (2/3)4d for SwiGLU to keep number of
    # parameters and computation constant
    hidden_dim = 4 * int(2 * dim / 3)
    # Round hidden dimension to nearest multiple of `multiple_of`
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim


def colouring_llama2(
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    num_colors: int,
    attn_dropout: float = 0.0,
    max_batch_size: Optional[int] = None,
    norm_eps: float = 1e-5,
) -> ColoringTransformerDecoder:
    """Build a ``ColoringTransformerDecoder`` with the given Llama2 hyperparameters.

    ``num_colors`` sets how many per-color linear transforms are applied to the
    token embeddings via ``MaskedApply``.
    """
    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads
    kv_cache = (
        KVCache(
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            n_kv_heads=num_heads,
            head_dim=head_dim,
        )
        if max_batch_size is not None
        else None
    )
    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
    self_attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=rope,
        kv_cache=kv_cache,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    hidden_dim = _scale_hidden_dim_for_mlp(embed_dim)
    mlp = FeedForward(dim=embed_dim, hidden_dim=hidden_dim, linear_class=nn.Linear)
    layer = TransformerDecoderLayer(
        attn=self_attn,
        mlp=mlp,
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
    return ColoringTransformerDecoder(
        tok_embeddings=tok_embeddings,
        embedding_transform=MaskedApply(
            [nn.Linear(embed_dim, embed_dim) for _ in range(num_colors)]
        ),
        layer=layer,
        num_layers=num_layers,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )