# new-nlp-hw3-llama / configure_for_hf.py
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
)
from torch import nn
from hf.llama import CustomAttentionLLaMa


class MyLLaMaConfig(PretrainedConfig):
    model_type = "LLaMa"

    def __init__(
        self,
        embed_dim: int = 1536,
        n_layers: int = 24,
        n_heads: int = 24,
        n_chckpnt_segments: int = 24,
        **kwargs,
    ):
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_chckpnt_segments = n_chckpnt_segments
        super().__init__(**kwargs)


class MyLLaMa(PreTrainedModel):
    config_class = MyLLaMaConfig

    def __init__(self, config: MyLLaMaConfig):
        super().__init__(config)
        # Thin Hugging Face wrapper around the custom LLaMa implementation.
        self.model = CustomAttentionLLaMa(
            config.embed_dim,
            config.n_layers,
            config.n_heads,
            dropout=0,
            n_chckpnt_segments=config.n_chckpnt_segments,
        )

    def forward(self, tensor, labels=None):
        logits = self.model(tensor)["logits"]
        if labels is not None:
            # Assumes logits of shape (batch, seq_len, vocab); flatten so
            # cross_entropy receives (N, vocab) logits and (N,) targets.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
            )
            return {"loss": loss, "logits": logits}
        return {"logits": logits}


# Register the custom config and model so the Auto* factories can map the
# "LLaMa" model_type to these classes.
AutoConfig.register("LLaMa", MyLLaMaConfig)
AutoModel.register(MyLLaMaConfig, MyLLaMa)
AutoModelForCausalLM.register(MyLLaMaConfig, MyLLaMa)
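

# --- Usage sketch (illustration only, not part of the uploaded module) ---
# Assumes CustomAttentionLLaMa accepts a (batch, seq_len) LongTensor of token
# ids, tolerates the small hyperparameters below, and that vocabulary indices
# under 100 are valid; adjust to the real tokenizer and vocab before running.
if __name__ == "__main__":
    import torch

    config = MyLLaMaConfig(embed_dim=64, n_layers=2, n_heads=2, n_chckpnt_segments=2)
    model = AutoModelForCausalLM.from_config(config)

    tokens = torch.randint(0, 100, (1, 16))
    outputs = model(tokens, labels=tokens)
    print(outputs["loss"].item(), outputs["logits"].shape)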