"""Phi-2 model classes built on Hugging Face ``PreTrainedModel``.

Defines the shared pretrained base class, the token embedding, the stack of
parallel attention blocks, and the causal language-modeling head.
"""

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Any, cast

from .attention import ParallelAttentionBlock, KVCache
from .phi2_configuration import Phi2Config


class Phi2PreTrainedModel(PreTrainedModel):
    """Common base class: weight initialization and generation-input preparation."""

    config_class = Phi2Config  # not necessary unless you want to register model with auto classes
    supports_gradient_checkpointing = False
    # _no_split_modules = ["ParallelAttentionBlock"]

    def __init__(self, config: Phi2Config):
        super().__init__(config)
        self.config = config

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the parameters of a single submodule.

        Linear and embedding weights are drawn from a normal distribution with
        standard deviation ``config.weight_initialization_range``; biases are
        zeroed; layer-norm weights are set to one.  These values get
        overwritten by saved weights in ``from_pretrained()`` if they exist.
        """
        std = self.config.weight_initialization_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # keep the padding row at exactly zero
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            # weight is None when the layer was built with
            # elementwise_affine=False; guard it the same way as the bias
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,  # dim: (batch_size, seq_len)
        past_key_values: KVCache | None = None,  # has to be named this
        key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
        **kwargs,  # has to be here
    ) -> dict[str, Any]:
        """Build the keyword arguments handed to ``forward()`` for one generation step.

        Args:
            input_ids: token ids, dim (batch_size, seq_len).
            past_key_values: cache carried over from the previous step, if any.
            key_padding_mask: passed through to ``forward()`` untouched.
            **kwargs: accepted and ignored.

        Returns:
            A dict with the keys ``"input_ids"``, ``"kv_cache"`` and
            ``"key_padding_mask"``.
        """
        kv_cache = past_key_values
        # Truthiness test: None (or any other falsy value) means "no cache yet".
        if not kv_cache:
            # No cache: create a fresh one and feed the whole sequence.
            kv_cache = KVCache(
                max_seqlen=self.config.initial_cos_sin_cache_len,
                max_batch_size=input_ids.shape[0],
                seqlen_offset=0,
                batch_size_offset=0,
                kv_block_map={},
                lengths_per_sample=None,
            )
        else:
            # assume that `kv_cache` has cached all tokens up to the last token in `input_ids`
            kv_cache.seqlen_offset = input_ids.shape[1] - 1
            # only the newest token is fed; dim becomes (batch_size, 1)
            input_ids = cast(torch.LongTensor, input_ids[:, -1].unsqueeze(-1))

        return {  # to be passed to forward()
            "input_ids": input_ids,
            "kv_cache": kv_cache,
            "key_padding_mask": key_padding_mask,
        }


class Embedding(nn.Module):
    """Token embedding with dropout."""

    def __init__(
        self,
        vocab_size: int,
        d_embedding: int,
        embd_pdrop: float,
    ) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, d_embedding)
        self.dropout = nn.Dropout(embd_pdrop)

    def forward(
        self,
        input_ids: torch.LongTensor,  # dim: (batch_size, seq_len)
    ) -> torch.FloatTensor:
        """Look up token embeddings and apply dropout.

        Any leading dimensions of ``input_ids`` are flattened into one, so the
        result has dim (batch_size, seq_len, d_embedding).
        """
        flat_ids = input_ids.view(-1, input_ids.size()[-1])
        embedded = self.embeddings(flat_ids)  # dim: (batch_size, seq_len, d_embedding)
        return self.dropout(embedded)


class Phi2Model(Phi2PreTrainedModel):
    """Embedding layer followed by ``config.n_attn_blocks`` parallel attention blocks."""

    def __init__(self, config: Phi2Config) -> None:
        super().__init__(config)
        self.embedding = Embedding(
            vocab_size=config.vocab_size,
            d_embedding=config.d_embedding,
            embd_pdrop=config.embd_pdrop,
        )
        self.parallel_blocks = nn.ModuleList([
            ParallelAttentionBlock(
                resid_pdrop=config.resid_pdrop,
                layer_norm_epsilon=config.layer_norm_epsilon,
                d_embedding=config.d_embedding,
                n_attn_heads=config.n_attn_heads,
                block_n=i,
                initial_cos_sin_cache_len=config.initial_cos_sin_cache_len,
                attn_pdrop=config.attn_pdrop,
                use_flash_rotary=config.use_flash_rotary,
                use_flash_attn=config.use_flash_attn,
                use_fused_dense=config.use_fused_dense,
                checkpointing=config.checkpointing,
            )
            for i in range(config.n_attn_blocks)
        ])
        # https://github.com/cybertronai/gradient-checkpointing - I think this is turned off due to flash attention?
        self.gradient_checkpointing_disable()
        self.post_init()  # calls self._init_weights() for all modules

    # Currently disabled accessors for the input embedding table:
    #
    # def get_input_embeddings(self) -> nn.Embedding:
    #     return self.embedding.embeddings
    #
    # def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
    #     self.embedding.embeddings = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        """Embed ``input_ids`` and run the result through every block in order.

        ``kv_cache`` and ``key_padding_mask`` are forwarded unchanged to each
        block.  Returns the output of the last block.
        """
        hidden = self.embedding(input_ids)
        for block in self.parallel_blocks:
            hidden = block(
                hidden,
                kv_cache=kv_cache,
                key_padding_mask=key_padding_mask,
            )
        return hidden


class Phi2ModelForCausalLM(Phi2PreTrainedModel):
    """``Phi2Model`` followed by a layer norm and a linear projection to the vocabulary."""

    def __init__(self, config: Phi2Config) -> None:
        super().__init__(config)
        self.model = Phi2Model(config)
        self.lm_head_layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
        self.lm_head_linear = nn.Linear(config.d_embedding, config.vocab_size)
        self.loss_fn = nn.CrossEntropyLoss()
        self.post_init()  # calls self._init_weights() for all modules

    def forward(
        self,
        input_ids: torch.LongTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,  # has to be here
    ) -> CausalLMOutputWithPast:
        """Compute float32 logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: token ids fed to the underlying ``Phi2Model``.
            kv_cache: optional cache, forwarded to the model and returned as
                ``past_key_values`` in the output.
            key_padding_mask: forwarded to the model unchanged.
            labels: optional targets; must have as many elements as there are
                logit positions.
            **kwargs: accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (``None`` when no labels
            are given), ``logits`` and ``past_key_values``.
        """
        hidden = self.model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
        hidden = self.lm_head_layer_norm(hidden)
        logits = self.lm_head_linear(hidden).to(torch.float32)

        loss = None
        if labels is not None:
            # NOTE(review): logits and labels are compared position by
            # position; no next-token shift is applied here, so callers must
            # supply labels that are already aligned with the logits.
            # reshape() rather than view(): labels produced by slicing are
            # non-contiguous and view(-1) would raise on them.
            loss = self.loss_fn(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=kv_cache)