"""A minimal GPT-1 style decoder-only model built on Hugging Face's PreTrainedModel
and PyTorch's nn.TransformerEncoder (run with a causal mask)."""

import torch
from torch import nn
from transformers import PreTrainedModel

from configuration_gpt1 import GPT1Config


class GPT1PreTrainedModel(PreTrainedModel):
    config_class = GPT1Config
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize linear and embedding weights from N(0, initializer_range)."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class GPT1Model(GPT1PreTrainedModel):
    def __init__(self, config: GPT1Config):
        super().__init__(config)
        # token embeddings
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embeddings_dropout = nn.Dropout(config.embd_pdrop)
        # positional encoding (learned, as in the original GPT-1)
        self.pos_emb = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # GPT-1 blocks are decoder-only (self-attention + FFN), so an encoder layer
        # driven with a causal mask performs the same computation
        dec_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.attention_dropout,
            activation=config.hidden_act,
            layer_norm_eps=config.layer_norm_eps,
            batch_first=True,
        )
        self.layers = nn.TransformerEncoder(dec_layer, config.num_hidden_layers)
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        seq_len = input_ids.size(-1)
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)

        input_embeds = self.embeddings(input_ids)      # (bs, seq_len, dim)
        position_embeds = self.pos_emb(position_ids)   # (seq_len, dim), broadcast over the batch
        hidden_state = self.embeddings_dropout(input_embeds + position_embeds)

        # causal (upper-triangular) mask so each position only attends to earlier positions
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(hidden_state.device)
        # transformers convention: attention_mask is 1 for real tokens, 0 for padding;
        # nn.TransformerEncoder expects True for positions that should be ignored
        src_key_padding_mask = attention_mask == 0 if attention_mask is not None else None

        output = self.layers(
            hidden_state,
            mask=causal_mask,
            src_key_padding_mask=src_key_padding_mask,
            is_causal=True,
        )
        return output


class GPT1ModelForCausalLM(GPT1PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: GPT1Config):
        super().__init__(config)
        self.model = GPT1Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None):
        output = self.model(input_ids, attention_mask=attention_mask)
        logits = self.lm_head(output).float()

        if labels is not None:
            # next-token prediction: logits at position t are scored against the token at t + 1
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1)
            loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}
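

# ---------------------------------------------------------------------------
# Quick usage sketch. Everything below is illustrative and rests on two
# assumptions that are NOT guaranteed by the code above: (1) GPT1Config in
# configuration_gpt1.py provides usable defaults for every field referenced
# here, and (2) it defines model_type = "gpt1", which AutoConfig.register
# requires. Registering with the Auto* classes is optional; it only lets the
# generic factories resolve these custom classes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    # optional: make the custom classes discoverable through the Auto* API
    AutoConfig.register("gpt1", GPT1Config)  # assumes GPT1Config.model_type == "gpt1"
    AutoModelForCausalLM.register(GPT1Config, GPT1ModelForCausalLM)

    config = GPT1Config()  # assumes sensible defaults in configuration_gpt1.py
    model = GPT1ModelForCausalLM(config)

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    labels = input_ids.clone()  # causal LM: the targets are the inputs shifted by one

    out = model(input_ids, attention_mask=attention_mask, labels=labels)
    print(out["loss"].item(), out["logits"].shape)  # scalar loss, (2, 16, vocab_size)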