""" PyTorch GPT1 model."""
import math
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
)
from transformers.activations import get_activation
from configuration_gpt1 import GPT1Config
class GPT1MLP(nn.Module):
def __init__(self, config: GPT1Config):
super().__init__()
self.activation_fn = get_activation(config.hidden_act)
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_state):
hidden_state = self.fc1(hidden_state)
hidden_state = self.activation_fn(hidden_state)
hidden_state = self.fc2(hidden_state)
return hidden_state
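
# Illustrative usage sketch for GPT1MLP (not executed here; assumes GPT1Config
# exposes hidden_size/intermediate_size/hidden_act as constructor fields, as is
# conventional for transformers configs):
#
#     config = GPT1Config(hidden_size=768, intermediate_size=3072)
#     mlp = GPT1MLP(config)
#     out = mlp(torch.randn(2, 16, 768))   # -> (2, 16, 768)
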
class GPT1Attention(nn.Module):
def __init__(self, config: GPT1Config):
"""
Multi-head attention layer.
"""
super().__init__()
assert config.hidden_size % config.num_attention_heads == 0
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.attn_dropout = nn.Dropout(p=config.attention_dropout)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
def forward(self, hidden_state, attn_mask):
bs, seq_len, _ = hidden_state.size() # (batch_size, seq_len, dim)
# linearly project the inputs
Q = self.q_proj(hidden_state) # (batch_size, seq_len, n_heads * head_dim)
K = self.k_proj(hidden_state)
V = self.v_proj(hidden_state)
# split into n_heads to compute attention
queries = Q.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_len, head_dim)
keys = K.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
values = V.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# compute attention matmul
keys = keys.transpose(2, 3) # (batch_size, n_heads, head_dim, seq_len)
attn_scores = queries @ keys # (batch_size, n_heads, seq_len, seq_len)
# scale
attn_scores = attn_scores / math.sqrt(self.head_dim)
# mask
if attn_mask is not None:
attn_scores = attn_scores + attn_mask
# softmax (attention probabilities) + dropout
attn_probs = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(Q.dtype)
attn_probs = self.attn_dropout(attn_probs)
# matmul
attn_output = attn_probs @ values # (batch_size, n_heads, seq_len, head_dim)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bs, seq_len, self.hidden_size) # (batch_size, seq_len, n_heads * head_dim)
# final linear
attn_output = self.o_proj(attn_output)
return attn_output
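
# Note: the matmul/softmax/dropout sequence above is equivalent to PyTorch 2.x's
# fused torch.nn.functional.scaled_dot_product_attention. A hedged sketch (not
# used by this model), assuming `queries`/`values` and the *un-transposed*
# `keys` of shape (bs, n_heads, seq_len, head_dim) and an additive float mask:
#
#     attn_output = nn.functional.scaled_dot_product_attention(
#         queries, keys, values, attn_mask=attn_mask,
#         dropout_p=self.attn_dropout.p if self.training else 0.0)
#
# followed by the same transpose/reshape and o_proj as above.
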
class GPT1DecoderLayer(nn.Module):
def __init__(self, config: GPT1Config):
super().__init__()
self.attention = GPT1Attention(config)
self.mlp = GPT1MLP(config)
self.attention_norm = nn.LayerNorm(normalized_shape=config.hidden_size,
eps=config.layer_norm_eps)
self.mlp_norm = nn.LayerNorm(normalized_shape=config.hidden_size,
eps=config.layer_norm_eps)
self.res_dropout = nn.Dropout(p=config.resid_pdrop)
def forward(self, hidden_state, attn_mask):
# attention
residual = hidden_state
hidden_state = self.attention(hidden_state, attn_mask)
hidden_state = self.res_dropout(hidden_state)
hidden_state = residual + hidden_state
hidden_state = self.attention_norm(hidden_state)
# feed forward fully connected
residual = hidden_state
hidden_state = self.mlp(hidden_state)
hidden_state = self.res_dropout(hidden_state)
hidden_state = residual + hidden_state
hidden_state = self.mlp_norm(hidden_state)
return hidden_state
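
# Note: GPT1DecoderLayer uses post-layer-norm ordering (LayerNorm applied after
# each residual add), matching the original GPT/Transformer formulation rather
# than the pre-norm variant of later GPT models. Illustrative call (not
# executed here; assumes GPT1Config can be default-constructed):
#
#     config = GPT1Config()
#     layer = GPT1DecoderLayer(config)
#     out = layer(torch.randn(2, 16, config.hidden_size), attn_mask=None)
#     # out: (2, 16, config.hidden_size)
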
class GPT1PreTrainedModel(PreTrainedModel):
config_class = GPT1Config
supports_gradient_checkpointing = False
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class GPT1Model(GPT1PreTrainedModel):
def __init__(self, config: GPT1Config):
super().__init__(config)
# embeddings
self.embs = nn.Embedding(config.vocab_size, config.hidden_size)
self.embs_dropout = nn.Dropout(p=config.embd_pdrop)
# positional encoding (learned)
self.pos_emb = nn.Embedding(config.max_position_embeddings,
config.hidden_size)
self.layers = nn.ModuleList(
[GPT1DecoderLayer(config) for _ in range(config.num_hidden_layers)]
)
causal_mask = torch.full((1, config.max_position_embeddings, config.max_position_embeddings),
fill_value=float('-inf'))
self.register_buffer('causal_mask',
torch.triu(causal_mask, diagonal=1),
persistent=False)
self.mask_cache_len = config.max_position_embeddings
self.post_init()
def get_input_embeddings(self):
return self.embs
def set_input_embeddings(self, value):
self.embs = value
def forward(self, input_ids, attention_mask=None, *args, **kwargs):
position_ids = torch.arange(input_ids.size(-1),
dtype=torch.long,
device=input_ids.device).unsqueeze_(0)
input_embeds = self.embs(input_ids) # (bs, seq_len, dim)
position_embeds = self.pos_emb(position_ids)
hidden_state = self.embs_dropout(input_embeds) + position_embeds
        # grow the cached causal mask if a longer attention mask is passed in
        if attention_mask is not None and attention_mask.size(1) > self.mask_cache_len:
            seq_len = attention_mask.size(1)
            self.mask_cache_len = seq_len
            causal_mask = torch.full((1, seq_len, seq_len),
                                     fill_value=float('-inf'))
            self.register_buffer('causal_mask',
                                 torch.triu(causal_mask, diagonal=1),
                                 persistent=False)

        # slice the cached mask down to the current sequence length so it
        # broadcasts against the (bs, n_heads, seq_len, seq_len) attention scores
        seq_len = input_ids.size(-1)
        causal_mask = self.causal_mask[:, :seq_len, :seq_len].to(dtype=input_embeds.dtype,
                                                                 device=input_embeds.device)

        for layer in self.layers:
            hidden_state = layer(hidden_state, attn_mask=causal_mask)
return BaseModelOutput(
last_hidden_state=hidden_state
)
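
# Hedged usage sketch for the bare GPT1Model (not executed here; assumes
# GPT1Config can be default-constructed and the input fits within
# config.max_position_embeddings):
#
#     model = GPT1Model(GPT1Config())
#     input_ids = torch.randint(0, model.config.vocab_size, (1, 12))
#     out = model(input_ids)
#     # out.last_hidden_state: (1, 12, config.hidden_size)
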
class GPT1ForCausalLM(GPT1PreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: GPT1Config):
super().__init__(config)
self.model = GPT1Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        # initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embs
def set_input_embeddings(self, value):
self.model.embs = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_decoder(self):
return self.model
def set_decoder(self, decoder):
self.model = decoder
def forward(self, input_ids, labels=None, attention_mask=None,
*args, **kwargs):
output = self.model(input_ids, attention_mask)
hidden_state = output[0]
logits = self.lm_head(hidden_state).float()
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fn = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
loss = loss_fn(shift_logits, shift_labels)
return CausalLMOutput(
loss=loss,
logits=logits
)
    def prepare_inputs_for_generation(self, input_ids, attention_mask,
                                      *args, **kwargs):
        # generate() passes the full, un-truncated sequence, so the padding
        # mask it provides must match the input length
        assert attention_mask.size(1) == input_ids.size(1)
        seq_len = attention_mask.size(1)

        # build an additive causal mask of the same length; the model forward
        # only uses this mask's length when deciding whether to grow its
        # cached causal mask buffer
        attn_mask = torch.full((seq_len, seq_len), fill_value=float('-inf'),
                               device=input_ids.device)
        attn_mask = torch.triu(attn_mask, diagonal=1)

        return {
            'input_ids': input_ids,
            'attention_mask': attn_mask
        }
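
# Minimal smoke test, kept out of the import path. This is an illustrative
# sketch and assumes GPT1Config can be default-constructed with the field names
# used above (vocab_size, hidden_size, num_hidden_layers, ...).
if __name__ == "__main__":
    config = GPT1Config()
    model = GPT1ForCausalLM(config)

    # random token ids, reused as next-token labels for a tiny forward pass
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    output = model(input_ids, labels=input_ids)

    print("logits:", tuple(output.logits.shape))   # (2, 16, vocab_size)
    print("loss:", output.loss.item())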