| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutput |
| from transformers.generation import GenerationMixin |
|
|
|
|
| from .configuration_my_gpt import MyGPTConfig |
| from .untrained_model import GPTModel |
|
|
| import os |
|
|
|
|
|
|
class MyGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face-compatible causal-LM wrapper around the local GPTModel.

    Bridges the HF ecosystem (Trainer, ``generate``, ``from_pretrained``)
    to the custom backbone defined in ``untrained_model.GPTModel``.
    """

    config_class = MyGPTConfig
    main_input_name = "input_ids"

    def __init__(self, config):
        super().__init__(config)

        # Translate the HF-style config attribute names into the plain
        # dict that the local GPTModel constructor expects.
        backbone_cfg = {
            "vocab_size": config.vocab_size,
            "context_length": config.max_position_embeddings,
            "emb_dim": config.hidden_size,
            "n_heads": config.num_attention_heads,
            "n_layers": config.num_hidden_layers,
            "drop_rate": config.drop_rate,
            "qkv_bias": config.qkv_bias,
        }
        self.model = GPTModel(backbone_cfg)

        # Standard HF post-construction hook (weight init / tying).
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the backbone and optionally compute the next-token loss.

        Args:
            input_ids: Token-id tensor — presumably (batch, seq); TODO
                confirm against GPTModel's expected input.
            labels: Optional tensor shaped like ``input_ids``; shifted
                internally so position ``t`` predicts token ``t + 1``.
            **kwargs: Extra arguments forwarded by HF training/generation
                loops (e.g. ``attention_mask``) — accepted but not used
                by the backbone.

        Returns:
            CausalLMOutput carrying ``logits`` and, when ``labels`` are
            provided, the cross-entropy ``loss``.
        """
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Causal-LM shift: drop the final logit and the first label so
            # each position is scored against the following token.
            pred = logits[..., :-1, :].contiguous()
            gold = labels[..., 1:].contiguous()
            loss = nn.CrossEntropyLoss()(
                pred.view(-1, pred.size(-1)),
                gold.view(-1),
            )

        return CausalLMOutput(loss=loss, logits=logits)
|
|
|
|