""" PyTorch GPT1 model."""
import math
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
)
from transformers.activations import get_activation
from configuration_gpt1 import GPT1Config
class GPT1MLP(nn.Module):
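    """
    Position-wise feed-forward block: hidden_size -> intermediate_size ->
    hidden_size, with the activation set by config.hidden_act.
    """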
def __init__(self, config: GPT1Config):
super().__init__()
self.activation_fn = get_activation(config.hidden_act)
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_state):
hidden_state = self.fc1(hidden_state)
hidden_state = self.activation_fn(hidden_state)
hidden_state = self.fc2(hidden_state)
return hidden_state
class GPT1Attention(nn.Module):
def __init__(self, config: GPT1Config):
"""
Multi-head attention layer.
"""
super().__init__()
assert config.hidden_size % config.num_attention_heads == 0
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.attn_dropout = nn.Dropout(p=config.attention_dropout)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
def forward(self, hidden_state, attn_mask):
bs, seq_len, _ = hidden_state.size() # (batch_size, seq_len, dim)
# linearly project the inputs
Q = self.q_proj(hidden_state) # (batch_size, seq_len, n_heads * head_dim)
K = self.k_proj(hidden_state)
V = self.v_proj(hidden_state)
# split into n_heads to compute attention
queries = Q.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_len, head_dim)
keys = K.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
values = V.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# compute attention matmul
keys = keys.transpose(2, 3) # (batch_size, n_heads, head_dim, seq_len)
attn_scores = queries @ keys # (batch_size, n_heads, seq_len, seq_len)
# scale
attn_scores = attn_scores / math.sqrt(self.head_dim)
# mask
if attn_mask is not None:
attn_scores = attn_scores + attn_mask
# softmax (attention probabilities) + dropout
attn_probs = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(Q.dtype)
attn_probs = self.attn_dropout(attn_probs)
# matmul
attn_output = attn_probs @ values # (batch_size, n_heads, seq_len, head_dim)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bs, seq_len, self.hidden_size) # (batch_size, seq_len, n_heads * head_dim)
# final linear
attn_output = self.o_proj(attn_output)
return attn_output
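

# Sanity-check sketch (not part of the original model code): with dropout
# disabled, the manual attention above should match PyTorch's fused kernel,
# assuming PyTorch >= 2.0 for scaled_dot_product_attention. The helper name
# and tolerance below are illustrative only.
def _attention_matches_sdpa(config: GPT1Config, atol: float = 1e-5) -> bool:
    torch.manual_seed(0)
    attn = GPT1Attention(config).eval()  # eval() disables attention dropout
    bs, seq_len = 2, 8
    x = torch.randn(bs, seq_len, config.hidden_size)
    causal = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
    manual = attn(x, causal)
    # same projections, but the attention itself runs through the fused kernel
    q = attn.q_proj(x).view(bs, seq_len, attn.num_heads, attn.head_dim).transpose(1, 2)
    k = attn.k_proj(x).view(bs, seq_len, attn.num_heads, attn.head_dim).transpose(1, 2)
    v = attn.v_proj(x).view(bs, seq_len, attn.num_heads, attn.head_dim).transpose(1, 2)
    fused = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    fused = attn.o_proj(fused.transpose(1, 2).reshape(bs, seq_len, attn.hidden_size))
    return torch.allclose(manual, fused, atol=atol)
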
class GPT1DecoderLayer(nn.Module):
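    """
    Decoder block: masked multi-head self-attention followed by the MLP, each
    with a residual connection, residual dropout and a post-LayerNorm (the
    norm is applied after the residual add, as in the original GPT-1).
    """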
def __init__(self, config: GPT1Config):
super().__init__()
self.attention = GPT1Attention(config)
self.mlp = GPT1MLP(config)
self.attention_norm = nn.LayerNorm(normalized_shape=config.hidden_size,
eps=config.layer_norm_eps)
self.mlp_norm = nn.LayerNorm(normalized_shape=config.hidden_size,
eps=config.layer_norm_eps)
self.res_dropout = nn.Dropout(p=config.resid_pdrop)
def forward(self, hidden_state, attn_mask):
# attention
residual = hidden_state
hidden_state = self.attention(hidden_state, attn_mask)
hidden_state = self.res_dropout(hidden_state)
hidden_state = residual + hidden_state
hidden_state = self.attention_norm(hidden_state)
        # position-wise feed-forward (MLP) block
residual = hidden_state
hidden_state = self.mlp(hidden_state)
hidden_state = self.res_dropout(hidden_state)
hidden_state = residual + hidden_state
hidden_state = self.mlp_norm(hidden_state)
return hidden_state
class GPT1PreTrainedModel(PreTrainedModel):
config_class = GPT1Config
supports_gradient_checkpointing = False
def _init_weights(self, module):
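        # weights ~ N(0, initializer_range); biases and any padding embedding
        # row are zeroed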
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class GPT1Model(GPT1PreTrainedModel):
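    """
    The bare GPT-1 transformer: token embeddings plus learned positional
    embeddings, followed by a stack of decoder layers. Returns the final
    hidden states (no LM head).
    """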
def __init__(self, config: GPT1Config):
super().__init__(config)
# embeddings
self.embs = nn.Embedding(config.vocab_size, config.hidden_size)
self.embs_dropout = nn.Dropout(p=config.embd_pdrop)
# positional encoding (learned)
self.pos_emb = nn.Embedding(config.max_position_embeddings,
config.hidden_size)
self.layers = nn.ModuleList(
[GPT1DecoderLayer(config) for _ in range(config.num_hidden_layers)]
)
self.post_init()
def get_input_embeddings(self):
return self.embs
def set_input_embeddings(self, value):
self.embs = value
def forward(self, input_ids, *args, **kwargs):
position_ids = torch.arange(input_ids.size(-1),
dtype=torch.long,
device=input_ids.device).unsqueeze_(0)
input_embeds = self.embs(input_ids) # (bs, seq_len, dim)
position_embeds = self.pos_emb(position_ids)
hidden_state = self.embs_dropout(input_embeds) + position_embeds
seq_len = input_ids.size(-1)
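        # additive causal mask: strictly upper-triangular -inf, broadcast over
        # batch and heads, so each position attends only to itself and the past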
attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'))
attn_mask = torch.triu(attn_mask, diagonal=1)
causal_mask = attn_mask.to(dtype=input_embeds.dtype,
device=input_embeds.device)
for layer in self.layers:
hidden_state = layer(hidden_state, attn_mask=causal_mask)
return BaseModelOutput(
last_hidden_state=hidden_state
)
class GPT1ForCausalLM(GPT1PreTrainedModel):
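    """
    GPT-1 with a language-modeling head on top. The head weights can be tied
    to the input embeddings (see _tied_weights_keys and
    config.tie_word_embeddings).
    """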
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: GPT1Config):
super().__init__(config)
self.model = GPT1Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        # initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embs
def set_input_embeddings(self, value):
self.model.embs = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_decoder(self):
return self.model
def set_decoder(self, decoder):
self.model = decoder
def forward(self, input_ids, labels=None, *args, **kwargs):
output = self.model(input_ids)
hidden_state = output[0]
logits = self.lm_head(hidden_state).float()
loss = None
if labels is not None:
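            # shift so that tokens < n predict token n (next-token prediction)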
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fn = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
loss = loss_fn(shift_logits, shift_labels)
return CausalLMOutput(
loss=loss,
logits=logits
)
def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
return { 'input_ids': input_ids }
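

# Minimal usage sketch (illustrative, not part of the model code). It assumes
# GPT1Config() provides workable defaults and that the installed transformers
# version wires GenerationMixin into PreTrainedModel, so that .generate() and
# prepare_inputs_for_generation() above are picked up.
if __name__ == "__main__":
    config = GPT1Config()
    model = GPT1ForCausalLM(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    with torch.no_grad():
        out = model(input_ids, labels=input_ids)
    print(out.logits.shape)  # (1, 16, vocab_size)
    print(out.loss)          # next-token cross-entropy on the random ids

    # greedy decoding of a few extra tokens
    generated = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print(generated.shape)   # (1, 24)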