import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Any, cast

from .attention import ParallelAttentionBlock, KVCache
from .phi2_configuration import Phi2Config


class Phi2PreTrainedModel(PreTrainedModel):
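    """Base class shared by the Phi-2 modules.

    Ties the modules to Phi2Config, provides weight initialization, and
    prepares the inputs (token ids, KV cache, padding mask) used during
    generation.
    """
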
    config_class = Phi2Config
    supports_gradient_checkpointing = False

    def __init__(self, config: Phi2Config):
        super().__init__(config)
        self.config = config

    def _init_weights(self, module: nn.Module) -> None:
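        """Initialize module weights.

        nn.Linear and nn.Embedding weights are drawn from a normal
        distribution with std config.weight_initialization_range; biases and
        padding rows are zeroed, and nn.LayerNorm is reset to weight 1 / bias 0.
        """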
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.weight_initialization_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.weight_initialization_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
    ) -> dict[str, Any]:
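        """Assemble the kwargs consumed by forward during generation.

        On the first call a fresh, empty KVCache is allocated; on later calls
        the cache's seqlen_offset is advanced and only the newest token is
        passed through the model, since earlier positions are already cached.
        """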
        if kv_cache is None:
            kv_cache = KVCache(
                max_seqlen=self.config.initial_cos_sin_cache_len,
                max_batch_size=input_ids.shape[0],
                seqlen_offset=0,
                batch_size_offset=0,
                kv_block_map={},
                lengths_per_sample=None,
            )
        else:
            # Earlier positions are already in the cache, so advance the offset
            # and forward only the most recent token.
            kv_cache.seqlen_offset = input_ids.shape[1] - 1
            input_ids = cast(torch.LongTensor, input_ids[:, -1].unsqueeze(-1))

        return {
            "input_ids": input_ids,
            "kv_cache": kv_cache,
            "key_padding_mask": key_padding_mask,
        }


class Embedding(nn.Module):
    """Token embedding with dropout from Phi2."""

    def __init__(
        self,
        vocab_size: int,
        d_embedding: int,
        embd_pdrop: float,
    ) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, d_embedding)
        self.dropout = nn.Dropout(embd_pdrop)

    def forward(
        self,
        input_ids: torch.LongTensor,
    ) -> torch.FloatTensor:
        x = self.embeddings(input_ids.view(-1, input_ids.size(-1)))
        x = self.dropout(x)
        return x


class Phi2Model(Phi2PreTrainedModel):
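    """Phi-2 backbone: token embedding followed by a stack of parallel attention blocks."""
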
    def __init__(self, config: Phi2Config) -> None:
        super().__init__(config)
        # Token embedding with dropout. Despite the attribute name, rotary
        # position embeddings are applied inside the attention blocks.
        self.rotary_embedding = Embedding(
            vocab_size=config.vocab_size,
            d_embedding=config.d_embedding,
            embd_pdrop=config.embd_pdrop,
        )
        self.parallel_blocks = nn.ModuleList([
            ParallelAttentionBlock(
                resid_pdrop=config.resid_pdrop,
                layer_norm_epsilon=config.layer_norm_epsilon,
                d_embedding=config.d_embedding,
                n_attn_heads=config.n_attn_heads,
                block_n=i,
                initial_cos_sin_cache_len=config.initial_cos_sin_cache_len,
                attn_pdrop=config.attn_pdrop,
                use_flash_rotary=config.use_flash_rotary,
                use_flash_attn=config.use_flash_attn,
                use_fused_dense=config.use_fused_dense,
                checkpointing=config.checkpointing,
            )
            for i in range(config.n_blocks)
        ])
        self.gradient_checkpointing_disable()
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.rotary_embedding.embeddings

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        self.rotary_embedding.embeddings = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
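        """Embed input_ids and run the hidden states through every parallel
        attention block, returning the final hidden states."""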
        x = self.rotary_embedding(input_ids)
        for block in self.parallel_blocks:
            x = block(
                x,
                kv_cache=kv_cache,
                key_padding_mask=key_padding_mask,
            )
        return x


class Phi2ModelForCausalLM(Phi2PreTrainedModel):
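    """Phi2Model backbone with a LayerNorm + Linear language-modeling head."""
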
    def __init__(self, config: Phi2Config) -> None:
        super().__init__(config)
        self.pretrained_model = Phi2Model(config)
        self.lm_head_layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
        self.lm_head_linear = nn.Linear(config.d_embedding, config.vocab_size)
        self.loss_fn = nn.CrossEntropyLoss()
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
        labels: torch.LongTensor | None = None,
    ) -> CausalLMOutputWithPast:
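        """Run the backbone and the LM head.

        If labels are provided, also compute the causal language-modeling loss
        (logits shifted against labels, Hugging Face style).
        """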
        x = self.pretrained_model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
        x = self.lm_head_layer_norm(x)
        logits = self.lm_head_linear(x).to(torch.float32)
        loss = None
        if labels is not None:
            # Hugging Face convention: labels align with input_ids, so shift by
            # one position so that the logits at position t predict token t + 1.
            shifted_logits = logits[..., :-1, :].contiguous()
            shifted_labels = labels[..., 1:].contiguous()
            loss = self.loss_fn(
                shifted_logits.view(-1, shifted_logits.size(-1)),
                shifted_labels.view(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)
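

# Example usage (a minimal sketch, not exercised by the library itself). It
# assumes Phi2Config can be constructed with default values for its fields
# (vocab_size, d_embedding, n_blocks, ...); see phi2_configuration.py for the
# actual signature.
#
#   config = Phi2Config()
#   model = Phi2ModelForCausalLM(config)
#   input_ids = torch.randint(0, config.vocab_size, (1, 16))
#   output = model(input_ids, labels=input_ids)
#   output.loss.backward()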