# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
# Derived from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from sentencepiece import SentencePieceProcessor


@dataclass
class ItaliaConfig:
    block_size: int = 4096
    vocab_size: int = 50_000
    padding_multiple: int = 512
    padded_vocab_size: int = 50176
    head_size: int = 160
    n_layer: int = 34
    n_head: int = 32
    n_embd: int = 5120
    rotary_percentage: float = 0.4
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = True
    n_query_groups: int = 32
    shared_attention_norm: bool = True
    norm_eps: float = 1e-5
    intermediate_size: int = 12800
    rope_condense_ratio: int = 1
    rope_n_elem: int = 64
    rope_base: int = 10000


class Tokenizer:
    def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
        checkpoint_dir = Path(checkpoint_dir)
        if not checkpoint_dir.exists():
            raise NotADirectoryError(
                f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
            )

        self.use_bos = True
        self.bos_id = None
        self.eos_id = None

        if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
            self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
            self.backend = "sentencepiece"
            self.bos_id = self.processor.bos_id()
            self.eos_id = self.processor.eos_id()
        else:
            raise FileNotFoundError(
                f"tokenizer.model not found in {str(checkpoint_dir)}"
            )

    @property
    def vocab_size(self) -> int:
        return self.processor.vocab_size()

    def token_to_id(self, token: str) -> int:
        return self.processor.piece_to_id(token)

    def encode(
        self,
        string: str,
        device: Optional[torch.device] = None,
        max_length: int = -1,
    ) -> torch.Tensor:
        tokens = self.processor.encode(string)
        tokens = [self.bos_id] + tokens
        if max_length > 0:
            tokens = tokens[:max_length]
        return torch.tensor(tokens, dtype=torch.int, device=device)

    def decode(self, tensor: torch.Tensor) -> str:
        tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
        return self.processor.decode(tokens).strip()
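# Usage sketch (illustrative only, not part of the original module): round-trip a
# string through the SentencePiece tokenizer. The checkpoint path is hypothetical
# and only needs to contain a `tokenizer.model` file.
#
#   tokenizer = Tokenizer("checkpoints/italia")
#   ids = tokenizer.encode("Buongiorno, Italia!", max_length=32)
#   text = tokenizer.decode(ids)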
class Italia(nn.Module):
    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(
            config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
        )
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=nn.LayerNorm(config.n_embd, eps=config.norm_eps),
            )
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    @property
    def max_seq_length(self) -> int:
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """When doing inference, the sequences used might be shorter than the model's context length.
        This allows setting a smaller number to avoid allocating unused memory.
        """
        if value > self.config.block_size:
            raise ValueError(
                f"Cannot attend to {value}, block size is only {self.config.block_size}"
            )
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        elif value != self.cos.size(0):
            self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def reset_parameters(self) -> None:
        self.cos, self.sin = self.rope_cache()

    def forward(
        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(
                f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
            )

        if input_pos is not None:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        x = self.transformer.ln_f(x)
        return self.lm_head(x)  # (b, t, vocab_size)

    def rope_cache(
        self, device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.config.rope_n_elem,
            device=device,
            condense_ratio=self.config.rope_condense_ratio,
            base=self.config.rope_base,
        )

    def set_kv_cache(
        self,
        batch_size: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        if rope_cache_length is None:
            rope_cache_length = self.cos.size(-1)
        max_seq_length = self.max_seq_length

        for block in self.transformer.h:
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size, max_seq_length, rope_cache_length, device, dtype
            )

        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
            self.mask_cache = build_mask_cache(max_seq_length, device)

    def clear_kv_cache(self) -> None:
        self.mask_cache = None
        for block in self.transformer.h:
            block.attn.kv_cache = None


class Block(nn.Module):
    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        self.norm_1 = nn.LayerNorm(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        # only allocate a second norm when the attention and MLP branches do not share one
        self.norm_2 = (
            None
            if config.shared_attention_norm
            else nn.LayerNorm(config.n_embd, eps=config.norm_eps)
        )
        self.mlp = MLP(config)
        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        n_1 = self.norm_1(x)
        h = self.attn(n_1, cos, sin, mask, input_pos)
        n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
        x = self.mlp(n_2) + h + x
        return x
class CausalSelfAttention(nn.Module):
    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        linear_module = nn.Linear
        self.attn = linear_module(config.n_embd, shape, bias=config.bias)
        self.proj = linear_module(config.n_embd, config.n_embd, bias=config.bias)
        self.kv_cache: Optional["KVCache"] = None
        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, T, _ = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        qkv = self.attn(x)

        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(
            B, T, self.config.n_query_groups, total_qkv, self.config.head_size
        )
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        # split batched computation into three
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)

        q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
        k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
        k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)

        if input_pos is not None:
            if not isinstance(self.kv_cache, KVCache):
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            k, v = self.kv_cache(input_pos, k, v)

        y = self.scaled_dot_product_attention(q, k, v, mask)

        y = y.reshape(B, T, self.config.n_embd)  # re-assemble all head outputs side by side

        # output projection
        return self.proj(y)

    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        scale = 1.0 / math.sqrt(self.config.head_size)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
        )
        return y.transpose(1, 2)

    def build_kv_cache(
        self,
        batch_size: int,
        max_seq_length: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "KVCache":
        heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
        v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
        if rope_cache_length is None:
            if self.config.rotary_percentage != 1.0:
                raise TypeError(
                    "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
                )
            k_shape = v_shape
        else:
            k_shape = (
                batch_size,
                heads,
                max_seq_length,
                rope_cache_length + self.config.head_size - self.config.rope_n_elem,
            )
        return KVCache(k_shape, v_shape, device=device, dtype=dtype)


class MLP(nn.Module):
    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc(x)
        x = torch.nn.functional.gelu(x, approximate="tanh")
        return self.proj(x)
""" # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) # Create position indexes `[0, 1, ..., seq_len - 1]` seq_idx = torch.arange(seq_len, device=device) / condense_ratio # Calculate the product of position index and $\theta_i$ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) return torch.cos(idx_theta), torch.sin(idx_theta) def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: head_size = x.size(-1) x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) roped = (x * cos) + (rotated * sin) return roped.to(dtype=x.dtype) class KVCache(nn.Module): def __init__( self, k_shape: Tuple[int, int, int, int], v_shape: Tuple[int, int, int, int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> None: super().__init__() self.register_buffer( "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False ) self.register_buffer( "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False ) def forward( self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: # move the buffer to the activation dtype for when AMP is used self.k = self.k.to(k.dtype) self.v = self.v.to(v.dtype) # update the cache k = self.k.index_copy_(2, input_pos, k) v = self.v.index_copy_(2, input_pos, v) return k, v def reset_parameters(self) -> None: torch.nn.init.zeros_(self.k) torch.nn.init.zeros_(self.v) def build_mask_cache( max_seq_length: int, device: Optional[torch.device] = None ) -> torch.Tensor: ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool) return torch.tril(ones).unsqueeze(0).unsqueeze(0)