""" PyTorch GPT1 model.""" import math import torch from torch import nn from transformers import PreTrainedModel from transformers.modeling_outputs import ( BaseModelOutput, CausalLMOutput, ) from transformers.activations import get_activation from configuration_gpt1 import GPT1Config class GPT1MLP(nn.Module): def __init__(self, config: GPT1Config): super().__init__() self.activation_fn = get_activation(config.hidden_act) self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_state): hidden_state = self.fc1(hidden_state) hidden_state = self.activation_fn(hidden_state) hidden_state = self.fc2(hidden_state) return hidden_state class GPT1Attention(nn.Module): def __init__(self, config: GPT1Config): """ Multi-head attention layer. """ super().__init__() assert config.hidden_size % config.num_attention_heads == 0 self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.attn_dropout = nn.Dropout(p=config.attention_dropout) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) def forward(self, hidden_state, attn_mask): bs, seq_len, _ = hidden_state.size() # (batch_size, seq_len, dim) # linearly project the inputs Q = self.q_proj(hidden_state) # (batch_size, seq_len, n_heads * head_dim) K = self.k_proj(hidden_state) V = self.v_proj(hidden_state) # split into n_heads to compute attention queries = Q.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_len, head_dim) keys = K.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2) values = V.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # compute attention matmul keys = keys.transpose(2, 3) # (batch_size, n_heads, head_dim, seq_len) attn_scores = queries @ keys # (batch_size, n_heads, seq_len, seq_len) # scale attn_scores = attn_scores / math.sqrt(self.head_dim) # mask if attn_mask is not None: attn_scores = attn_scores + attn_mask # softmax (attention probabilities) + dropout attn_probs = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(Q.dtype) attn_probs = self.attn_dropout(attn_probs) # matmul attn_output = attn_probs @ values # (batch_size, n_heads, seq_len, head_dim) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bs, seq_len, self.hidden_size) # (batch_size, seq_len, n_heads * head_dim) # final linear attn_output = self.o_proj(attn_output) return attn_output class GPT1DecoderLayer(nn.Module): def __init__(self, config: GPT1Config): super().__init__() self.attention = GPT1Attention(config) self.mlp = GPT1MLP(config) self.attention_norm = nn.LayerNorm(normalized_shape=config.hidden_size, eps=config.layer_norm_eps) self.mlp_norm = nn.LayerNorm(normalized_shape=config.hidden_size, eps=config.layer_norm_eps) self.res_dropout = nn.Dropout(p=config.resid_pdrop) def forward(self, hidden_state, attn_mask): # attention residual = hidden_state hidden_state = self.attention(hidden_state, attn_mask) hidden_state = self.res_dropout(hidden_state) hidden_state = residual + hidden_state hidden_state = self.attention_norm(hidden_state) # feed forward fully connected residual = hidden_state hidden_state = 
        hidden_state = self.mlp(hidden_state)
        hidden_state = self.res_dropout(hidden_state)
        hidden_state = residual + hidden_state
        hidden_state = self.mlp_norm(hidden_state)
        return hidden_state


class GPT1PreTrainedModel(PreTrainedModel):
    config_class = GPT1Config
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class GPT1Model(GPT1PreTrainedModel):
    def __init__(self, config: GPT1Config):
        super().__init__(config)

        # embeddings
        self.embs = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embs_dropout = nn.Dropout(p=config.embd_pdrop)

        # positional encoding (learned)
        self.pos_emb = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.layers = nn.ModuleList(
            [GPT1DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

        causal_mask = torch.full(
            (1, config.max_position_embeddings, config.max_position_embeddings),
            fill_value=float('-inf'),
        )
        self.register_buffer('causal_mask', torch.triu(causal_mask, diagonal=1), persistent=False)
        self.mask_cache_len = config.max_position_embeddings

        self.post_init()

    def get_input_embeddings(self):
        return self.embs

    def set_input_embeddings(self, value):
        self.embs = value

    def forward(self, input_ids, attention_mask=None, *args, **kwargs):
        seq_len = input_ids.size(-1)
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device).unsqueeze_(0)

        input_embeds = self.embs(input_ids)  # (bs, seq_len, dim)
        position_embeds = self.pos_emb(position_ids)
        hidden_state = self.embs_dropout(input_embeds) + position_embeds

        # grow the cached causal mask if a longer attention_mask is passed in;
        # only its length is used here -- padding information is not applied
        if attention_mask is not None and attention_mask.size(1) > self.mask_cache_len:
            mask_len = attention_mask.size(1)
            self.mask_cache_len = mask_len
            causal_mask = torch.full((1, mask_len, mask_len), fill_value=float('-inf'))
            self.register_buffer('causal_mask', torch.triu(causal_mask, diagonal=1), persistent=False)

        # slice the cached mask down to the current sequence length so it
        # broadcasts against the (bs, n_heads, seq_len, seq_len) attention scores
        causal_mask = self.causal_mask[..., :seq_len, :seq_len].to(dtype=input_embeds.dtype, device=input_embeds.device)

        for layer in self.layers:
            hidden_state = layer(hidden_state, attn_mask=causal_mask)

        return BaseModelOutput(
            last_hidden_state=hidden_state
        )


class GPT1ForCausalLM(GPT1PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: GPT1Config):
        super().__init__(config)
        self.model = GPT1Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        # initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embs

    def set_input_embeddings(self, value):
        self.model.embs = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_decoder(self):
        return self.model

    def set_decoder(self, decoder):
        self.model = decoder

    def forward(self, input_ids, labels=None, attention_mask=None, *args, **kwargs):
        output = self.model(input_ids, attention_mask)
        hidden_state = output[0]
        logits = self.lm_head(hidden_state).float()

        loss = None
        if labels is not None:
            # shift so that tokens < n predict token n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fn = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            loss = loss_fn(shift_logits, shift_labels)
        return CausalLMOutput(
            loss=loss,
            logits=logits
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask, *args, **kwargs):
        assert attention_mask.size(1) == input_ids.size(1)
        seq_len = attention_mask.size(1)
        attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
        attn_mask = torch.triu(attn_mask, diagonal=1)
        return {
            'input_ids': input_ids,
            'attention_mask': attn_mask
        }
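

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds a tiny,
# randomly initialized model and runs a single forward pass with labels to
# check shapes and the loss path. It assumes GPT1Config from
# configuration_gpt1.py accepts the fields read by the modules above as
# keyword arguments; the hyperparameter values below are illustrative only,
# not the original GPT-1 settings.
if __name__ == "__main__":
    config = GPT1Config(
        vocab_size=100,
        hidden_size=32,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        hidden_act="gelu",
        max_position_embeddings=64,
        layer_norm_eps=1e-5,
        attention_dropout=0.1,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        initializer_range=0.02,
    )
    model = GPT1ForCausalLM(config)
    model.eval()

    # random token ids standing in for a tokenized prompt
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    with torch.no_grad():
        out = model(input_ids, labels=input_ids)

    print(out.logits.shape)  # expected: torch.Size([2, 16, 100])
    print(out.loss)          # cross-entropy loss of the untrained model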