# simplified_phi2/phi2_model.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Any, cast
from .attention import ParallelAttentionBlock, KVCache
from .phi2_configuration import Phi2Config
class Phi2PreTrainedModel(PreTrainedModel):
config_class = Phi2Config # not necessary unless you want to register model with auto classes
supports_gradient_checkpointing = False
# _no_split_modules = ["ParallelAttentionBlock"]
def __init__(self, config: Phi2Config):
super().__init__(config)
self.config = config
def _init_weights(self, module: nn.Module) -> None:
# initialize weights - will get overwritten by saved weights in from_pretrained() if they exist
if isinstance(module, (nn.Linear,)):
module.weight.data.normal_(mean=0.0, std=self.config.weight_initialization_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.weight_initialization_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
if module.bias is not None:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor, # dim: (batch_size, seq_len)
past_key_values: KVCache | None = None,  # must be named "past_key_values" so HF generate() can round-trip the cache
key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
**kwargs,  # absorbs any extra arguments that HF generate() passes along
) -> dict[str, Any]:
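# How HF generate() uses this method (a sketch of the contract, not defined in this file):
# on the first step there is no cache, so a fresh KVCache sized by
# config.initial_cos_sin_cache_len is built and the full prompt is passed through;
# on every later step the cache returned by the previous forward() call comes back
# in as `past_key_values`, its seqlen_offset is advanced, and only the newest token
# is fed to the model.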
kv_cache = past_key_values
if kv_cache is None:
kv_cache = KVCache(
max_seqlen=self.config.initial_cos_sin_cache_len,
max_batch_size=input_ids.shape[0],
seqlen_offset=0,
batch_size_offset=0,
kv_block_map={},
lengths_per_sample=None,
)
else:
# assume `kv_cache` already holds every token except the newest one, so only the last token in `input_ids` needs to be processed this step
kv_cache.seqlen_offset = input_ids.shape[1] - 1
input_ids = cast(torch.LongTensor, input_ids[:, -1].unsqueeze(-1))
return { # to be passed to forward()
"input_ids": input_ids,
"kv_cache": kv_cache,
"key_padding_mask": key_padding_mask,
}
class Embedding(nn.Module):
"""Token embedding with dropout."""
def __init__(
self,
vocab_size: int,
d_embedding: int,
embd_pdrop: float,
) -> None:
super().__init__()
self.embeddings = nn.Embedding(vocab_size, d_embedding)
self.dropout = nn.Dropout(embd_pdrop)
def forward(
self,
input_ids: torch.LongTensor, # dim: (batch_size, seq_len)
) -> torch.FloatTensor:
x = self.embeddings( # dim: (batch_size, seq_len, d_embedding)
input_ids.view(-1, input_ids.size()[-1])
)
x = self.dropout(x)
return x
class Phi2Model(Phi2PreTrainedModel):
def __init__(self, config: Phi2Config) -> None:
super().__init__(config)
self.embedding = Embedding(
vocab_size=config.vocab_size,
d_embedding=config.d_embedding,
embd_pdrop=config.embd_pdrop,
)
self.parallel_blocks = nn.ModuleList([
ParallelAttentionBlock(
resid_pdrop=config.resid_pdrop,
layer_norm_epsilon=config.layer_norm_epsilon,
d_embedding=config.d_embedding,
n_attn_heads=config.n_attn_heads,
block_n=i,
initial_cos_sin_cache_len=config.initial_cos_sin_cache_len,
attn_pdrop=config.attn_pdrop,
use_flash_rotary=config.use_flash_rotary,
use_flash_attn=config.use_flash_attn,
use_fused_dense=config.use_fused_dense,
checkpointing=config.checkpointing,
)
for i in range(config.n_attn_blocks)
])
self.gradient_checkpointing_disable() # https://github.com/cybertronai/gradient-checkpointing - I think this is turned off due to flash attention?
self.post_init() # calls self._init_weights() for all modules
"""
def get_input_embeddings(self) -> nn.Embedding:
return self.embedding.embeddings
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
self.embedding.embeddings = new_embeddings
"""
def forward(
self,
input_ids: torch.LongTensor,
kv_cache: KVCache | None = None,
key_padding_mask: torch.BoolTensor | None = None,
) -> torch.FloatTensor:
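# input_ids dim: (batch_size, seq_len) -> hidden states dim: (batch_size, seq_len, d_embedding).
# The same kv_cache object is handed to every block, each of which is expected to read
# and update its own entry in place, so the cache persists across generate() iterations.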
x = self.embedding(input_ids)
for block in self.parallel_blocks:
x = block(
x,
kv_cache=kv_cache,
key_padding_mask=key_padding_mask,
)
return x
class Phi2ModelForCausalLM(Phi2PreTrainedModel):
def __init__(self, config: Phi2Config) -> None:
super().__init__(config)
self.model = Phi2Model(config)
self.lm_head_layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
self.lm_head_linear = nn.Linear(config.d_embedding, config.vocab_size)
self.loss_fn = nn.CrossEntropyLoss()
self.post_init() # calls self._init_weights() for all modules
def forward(
self,
input_ids: torch.LongTensor,
kv_cache: KVCache | None = None,
key_padding_mask: torch.BoolTensor | None = None,
labels: torch.LongTensor | None = None,
**kwargs,  # absorbs any extra arguments HF generate() passes to forward()
) -> CausalLMOutputWithPast:
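# backbone hidden states -> final layer norm -> linear head -> vocab logits; logits are
# cast to float32 so the softmax/loss is computed in full precision even when the model
# itself runs in fp16/bf16.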
x = self.model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
x = self.lm_head_layer_norm(x)
logits = self.lm_head_linear(x).to(torch.float32)
# standard next-token loss; assumes the HF convention that `labels` are the unshifted
# input_ids, so the logits at position t are scored against the token at position t+1
loss = (
    self.loss_fn(
        logits[..., :-1, :].reshape(-1, logits.size(-1)),
        labels[..., 1:].reshape(-1),
    )
    if labels is not None
    else None
)
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=kv_cache)
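

# --- Illustrative usage (not part of the model / repo API) -------------------
# A minimal greedy-decoding sketch showing how the pieces above fit together:
# prepare_inputs_for_generation() builds or advances the KVCache, forward() returns
# logits plus the cache, and only the newest token is fed back in on the next
# iteration. The helper name and the greedy argmax sampling are assumptions for
# illustration only; HF generate() drives the model through essentially the same loop.
@torch.no_grad()
def _greedy_decode_sketch(
    model: Phi2ModelForCausalLM,
    input_ids: torch.LongTensor,  # dim: (batch_size, seq_len)
    max_new_tokens: int = 16,
) -> torch.LongTensor:
    # assumes prompt length + max_new_tokens fits within config.initial_cos_sin_cache_len
    kv_cache: KVCache | None = None
    generated = input_ids
    for _ in range(max_new_tokens):
        inputs = model.prepare_inputs_for_generation(generated, past_key_values=kv_cache)
        output = model(**inputs)  # CausalLMOutputWithPast
        kv_cache = output.past_key_values  # reuse the same cache object next iteration
        next_token = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        generated = cast(torch.LongTensor, torch.cat([generated, next_token], dim=-1))
    return generated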