custom_gpt2 / app.py
Gbssreejith's picture
Update app.py
57c78bf verified
raw
history blame contribute delete
No virus
19.9 kB
import copy
import torch
import math
import torch.nn as nn
from torch.nn.parameter import Parameter
import random
import numpy as np
from load_weights import load_weight
from sklearn.model_selection import train_test_split
from transformers import GPT2TokenizerFast
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
torch.manual_seed(42)
import nltk
nltk.download('punkt')
from transformers import GPT2Tokenizer
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
import datetime
import time
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
from tqdm import trange
import gradio as gr
import re
def gelu(x):
    """Apply the tanh approximation of the GELU activation element-wise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.

    Args:
        x: input tensor.

    Returns:
        Tensor with the same shape as ``x``.
    """
    cubic_term = x + 0.044715 * torch.pow(x, 3)
    gate = torch.tanh(math.sqrt(2 / math.pi) * cubic_term)
    return 0.5 * x * (1 + gate)
class Conv1D(nn.Module):
    """Affine projection of the last dimension with the weight stored as (nx, nf).

    Computes ``x @ weight + bias`` over the flattened leading dimensions.

    Args:
        nf: number of output features.
        nx: number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        # Weight is drawn from a normal distribution with std 0.02; bias starts at zero.
        self.weight = Parameter(nn.init.normal_(torch.empty(nx, nf), std=0.02))
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        """Project ``x`` from ``nx`` to ``nf`` features, keeping all leading dimensions."""
        out_shape = x.size()[:-1] + (self.nf,)
        # Collapse the leading dimensions so a single addmm handles every position.
        flat = x.view(-1, x.size(-1))
        return torch.addmm(self.bias, flat, self.weight).view(*out_shape)
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, TF style (epsilon inside the square root).

    Args:
        hidden_size: size of the normalized (last) dimension.
        eps: value added to the variance before taking the square root.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalize ``x`` along its last dimension, then apply the learned scale and shift."""
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
class Attention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache.

    Args:
        nx: embedding width of the input (also used as the attention state size).
        n_ctx: side length of the lower-triangular causal mask buffer.
        config: object providing ``n_head``.
        scale: when true, attention scores are divided by sqrt(head feature size).
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular matrix of ones, registered under the name "bias".
        causal_mask = torch.tril(torch.ones(n_ctx, n_ctx))
        self.register_buffer("bias", causal_mask.view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        """Return the masked, softmax-weighted sum of ``v`` for queries ``q`` and keys ``k``."""
        scores = torch.matmul(q, k)
        if self.scale:
            scores = scores / math.sqrt(v.size(-1))
        n_query, n_key = scores.size(-2), scores.size(-1)
        # Take the mask rows belonging to the last n_query positions out of n_key.
        mask = self.bias[:, :, n_key - n_query:n_key, :n_key]
        # Masked-out entries become -1e10 so the softmax gives them ~zero weight.
        scores = scores * mask - 1e10 * (1 - mask)
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, v)

    def merge_heads(self, x):
        """Fold the head dimension back into the feature dimension."""
        x = x.permute(0, 2, 1, 3).contiguous()
        leading = x.size()[:-2]
        return x.view(*leading, x.size(-2) * x.size(-1))  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        """Split the last dimension into ``n_head`` heads.

        Returns (batch, head, head_features, seq_length) when ``k`` is true,
        otherwise (batch, head, seq_length, head_features).
        """
        head_features = x.size(-1) // self.n_head
        x = x.view(*x.size()[:-1], self.n_head, head_features)  # in Tensorflow implem: fct split_states
        axes = (0, 2, 3, 1) if k else (0, 2, 1, 3)
        return x.permute(*axes)

    def forward(self, x, layer_past=None):
        """Attend over ``x`` (plus cached keys/values when ``layer_past`` is given).

        Returns:
            Tuple ``(output, present)`` where ``present`` stacks the keys
            (transposed to match the value layout) and the values, for reuse
            as ``layer_past`` on a later call.
        """
        query, key, value = self.c_attn(x).split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            # Cached keys are stored transposed; flip them back before concatenating.
            past_key = layer_past[0].transpose(-2, -1)
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        # Transpose keys so both tensors share a shape and can be stacked.
        present = torch.stack((key.transpose(-2, -1), value))
        attended = self.merge_heads(self._attn(query, key, value))
        return self.c_proj(attended), present
class MLP(nn.Module):
    """Position-wise feed-forward network: project up, apply GELU, project back.

    Args:
        n_state: width of the hidden layer (the caller in this file passes 4 * n_embd).
        config: object providing ``n_embd``.
    """

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        n_embd = config.n_embd
        self.c_fc = Conv1D(n_state, n_embd)
        self.c_proj = Conv1D(n_embd, n_state)
        self.act = gelu

    def forward(self, x):
        """Return ``c_proj(act(c_fc(x)))``."""
        hidden = self.act(self.c_fc(x))
        return self.c_proj(hidden)
class Block(nn.Module):
    """One transformer layer: pre-norm attention and pre-norm MLP, each with a residual add.

    Args:
        n_ctx: context size forwarded to the attention module.
        config: object providing ``n_embd`` and ``layer_norm_epsilon`` (and
            whatever ``Attention`` / ``MLP`` read from it).
        scale: forwarded to ``Attention``.
    """

    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        width = config.n_embd
        eps = config.layer_norm_epsilon
        self.ln_1 = LayerNorm(width, eps=eps)
        self.attn = Attention(width, n_ctx, config, scale)
        self.ln_2 = LayerNorm(width, eps=eps)
        self.mlp = MLP(4 * width, config)

    def forward(self, x, layer_past=None):
        """Run the layer and return ``(hidden_states, present)``."""
        attn_out, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + attn_out
        x = x + self.mlp(self.ln_2(x))
        return x, present
class GPT2Model(nn.Module):
    """GPT-2 transformer body: token + position embeddings, a stack of blocks, final layer norm.

    Args:
        config: object providing ``n_layer``, ``n_embd``, ``vocab_size``,
            ``n_positions``, ``n_ctx`` and ``layer_norm_epsilon``.
    """

    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        # Build one block and deep-copy it so every layer starts from the same initial weights.
        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def set_embeddings_weights(self, model_embeddings_weights):
        """Create ``self.decoder`` as a bias-free linear layer sharing the given weight.

        ``forward`` of this class does not use ``self.decoder``.
        """
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        """Encode ``input_ids`` into hidden states.

        Args:
            input_ids: integer tensor of token ids; the last dimension is the sequence.
            position_ids: optional position indices; when omitted they are
                ``past_length .. past_length + seq_len - 1``.
            token_type_ids: optional ids embedded with the token embedding
                table and added to the input embeddings.
            past: optional list with one cached key/value tensor per layer, as
                returned in ``presents`` by a previous call.

        Returns:
            Tuple ``(hidden_states, presents)`` where ``presents`` holds one
            key/value tensor per layer.

        Raises:
            ValueError: if a token id is outside ``[0, vocab_size)`` or a
                position index is outside the position embedding table.
        """
        # Reject both negative and too-large ids up front; either would otherwise
        # surface as an opaque indexing error inside the embedding lookup.
        if ((input_ids < 0) | (input_ids >= self.n_vocab)).any():
            raise ValueError(f"Invalid token ID found in input_ids: {input_ids}")
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # Cached keys are stored with the sequence on the second-to-last axis.
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long,
                                        device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        n_positions = self.wpe.num_embeddings
        if ((position_ids < 0) | (position_ids >= n_positions)).any():
            raise ValueError(
                f"Position index out of range: positions must lie in [0, {n_positions}), "
                f"got past_length={past_length} with {input_ids.size(-1)} new token(s)"
            )
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        presents = []
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, layer_past)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents
class GPT2LMHead(nn.Module):
    """Language-model head: a bias-free linear layer whose weight is the embedding matrix.

    Args:
        model_embeddings_weights: parameter to share with the decoder; dimension 0
            becomes the output size and dimension 1 the input size.
        config: object providing ``n_embd``.
    """

    def __init__(self, model_embeddings_weights, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.set_embeddings_weights(model_embeddings_weights)

    def set_embeddings_weights(self, model_embeddings_weights):
        """(Re)create ``self.decoder`` and point its weight at the given parameter."""
        weight_shape = model_embeddings_weights.shape
        n_out, n_in = weight_shape[0], weight_shape[1]
        self.decoder = nn.Linear(n_in, n_out, bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, hidden_state):
        """Map hidden states to logits via the tied decoder."""
        # Truncated Language modeling logits (we remove the last token)
        # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
        return self.decoder(hidden_state)
import torch.nn.functional as F
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    ``logits`` is modified in place and also returned.

    Args:
        logits: logits distribution of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the ``top_k`` highest-scoring tokens per row.
        top_p: if > 0.0, keep the smallest set of top tokens per row whose
            cumulative probability reaches ``top_p`` (the token that crosses
            the threshold is kept).
        filter_value: value written over filtered logits.

    Returns:
        The filtered ``logits`` tensor.
    """
    assert logits.dim() == 2  # batch size x vocabulary size
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a logit below the k-th largest logit of their row.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Map the mask from sorted order back to vocabulary order, row by row.
        # Indexing ``sorted_indices`` with the mask would flatten it and lose
        # which batch row each index belongs to.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
class GPT2LMHeadModel(nn.Module):
    """GPT-2 transformer with a language-modeling head tied to the token embeddings.

    Args:
        config: configuration object forwarded to ``GPT2Model`` and ``GPT2LMHead``.
    """

    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__()
        # Kept so that generate() can look up optional settings such as eos_token_id.
        self.config = config
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)

    def set_tied(self):
        """Make sure we are sharing the embeddings."""
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        """Run the transformer and the LM head.

        Returns:
            ``(lm_logits, presents)``, or ``(loss, lm_logits, presents)`` when
            ``lm_labels`` is given. The loss is the cross-entropy between the
            logits at position t and the label at position t + 1.
        """
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)
        outputs = (lm_logits, presents)
        if lm_labels is not None:
            # Shift so that each position predicts the following token.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs
        return outputs

    def generate(
        self, input_ids, max_length, temperature=1.0, top_k=0, top_p=0.9, repetition_penalty=1.0, device='cuda'
    ):
        """Sample up to ``max_length`` new tokens after ``input_ids``.

        Args:
            input_ids: tensor of shape (batch, seq) holding the prompt token ids.
            max_length: maximum number of tokens to append.
            temperature: divisor applied to the logits before filtering.
            top_k: forwarded to ``top_k_top_p_filtering``.
            top_p: forwarded to ``top_k_top_p_filtering``.
            repetition_penalty: values above 1.0 make already generated tokens
                less likely; 1.0 disables the penalty.
            device: device the inputs are moved to; it must be the device the
                model itself lives on.

        Returns:
            Tensor of shape (batch, seq + n_generated) with the prompt followed
            by the sampled tokens.
        """
        import torch.nn.functional as F

        self.eval()
        input_ids = input_ids.to(device)
        # The config is not required to define an EOS id; without one, generation
        # simply runs for max_length steps.
        eos_token_id = getattr(self.config, "eos_token_id", None)
        past = None
        generated = input_ids
        with torch.no_grad():
            for _ in range(max_length):
                outputs = self(input_ids, past=past)
                next_token_logits = outputs[0][:, -1, :]
                past = outputs[1]
                if repetition_penalty != 1.0:
                    # Push seen tokens toward "less likely" regardless of sign:
                    # positive logits are divided, negative logits are multiplied.
                    seen = next_token_logits.gather(1, generated)
                    seen = torch.where(seen < 0, seen * repetition_penalty, seen / repetition_penalty)
                    next_token_logits = next_token_logits.scatter(1, generated, seen)
                next_token_logits = next_token_logits / temperature
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)
                if eos_token_id is not None and (next_token == eos_token_id).all():
                    break
                # With the key/value cache only the newest token has to be fed back in.
                input_ids = next_token
        return generated
class GPT2Config:
    """Plain container for the GPT-2 hyperparameters used by the model classes.

    Every constructor argument is stored unchanged as an attribute of the same
    name, except ``vocab_size_or_config_json_file``, which is stored as
    ``vocab_size``.
    """

    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
    ):
        # Vocabulary and sequence-length limits.
        self.vocab_size = vocab_size_or_config_json_file
        self.n_ctx = n_ctx
        self.n_positions = n_positions

        # Network shape.
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head

        # Numerical settings.
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network with the default configuration and load the checkpoint weights.
config = GPT2Config()
model = GPT2LMHeadModel(config)
# Without CUDA the checkpoint tensors are remapped to the CPU; with CUDA the
# storage locations saved in the file are used as-is.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(
    r'epoch_5.pth',
    map_location=None if torch.cuda.is_available() else 'cpu',
)
model = load_weight(model, state_dict)
model.to(device)
print(model)
model.eval()

# GPT-2 tokenizer, with the end-of-text token reused as the padding token.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    ``logits`` is modified in place and also returned.

    Args:
        logits: logits distribution of shape (batch size x vocabulary size).
        top_k: if > 0, keep only the ``top_k`` highest-scoring tokens per row.
        top_p: if > 0.0, keep the smallest set of top tokens per row whose
            cumulative probability reaches ``top_p`` (the token that crosses
            the threshold is kept).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        filter_value: value written over filtered logits.

    Returns:
        The filtered ``logits`` tensor.
    """
    assert logits.dim() == 2, "Expected logits dimension to be 2 (batch size x vocabulary size)"
    # Safety check; int() lets whole-number floats through to torch.topk.
    top_k = min(int(top_k), logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a logit below the k-th largest logit of their row.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token above the threshold
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        # Scatter the mask from sorted order back to vocabulary order. This keeps
        # the per-row structure, filters every flagged token (not only one per
        # row) and is a no-op when nothing is flagged.
        remove = torch.zeros_like(sorted_remove).scatter_(1, sorted_indices, sorted_remove)
        logits[remove] = filter_value
    return logits
# prompt_text = "What is the classical conceptualisation of oxidation and reduction in redox reactions?"
# prompt = f"\n<|startoftext|>[WP] {prompt_text} \n[RESPONSE]"
# input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
# max_length = 50
# temperature = 0.7
# top_k = 50
# top_p = 0.95
# repetition_penalty = 1.0
# with torch.no_grad():
# for _ in range(max_length):
# outputs = model(input_ids)
# logits = outputs[0]
# next_token_logits = logits[:, -1, :] / temperature
# # Apply repetition penalty
# for i in range(input_ids.size(0)):
# for token_id in set(input_ids[i].tolist()):
# next_token_logits[0, token_id] /= repetition_penalty
# # Filter logits using top-k and/or top-p filtering
# filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
# next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
# input_ids = torch.cat([input_ids, next_token], dim=-1).to(device)
# import re
# # generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
# # wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
# print(input_ids[0])
# generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
# wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
# print(wp_responses)
# Define the generation function
def generate_text(prompt_text, max_length=50, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0):
    """Generate a response for ``prompt_text`` with the module-level model and tokenizer.

    The prompt is wrapped as ``"\\n[WP] <prompt> \\n[RESPONSE]"``, tokens are
    sampled one at a time, and the text following the ``[RESPONSE]`` marker is
    returned.

    Args:
        prompt_text: user prompt.
        max_length: maximum number of tokens to sample.
        temperature: divisor applied to the next-token logits.
        top_k: forwarded to ``top_k_top_p_filtering``.
        top_p: forwarded to ``top_k_top_p_filtering``.
        repetition_penalty: values above 1.0 make tokens already in the
            sequence less likely; 1.0 disables the penalty.

    Returns:
        The decoded response segment, or an empty string if none is found.
    """
    prompt = f"\n[WP] {prompt_text} \n[RESPONSE]"
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # The model only has position embeddings for this many tokens.
    max_context = config.n_positions
    eos_token_id = tokenizer.eos_token_id
    with torch.no_grad():
        # int() guards against whole-number floats arriving from the UI.
        for _ in range(int(max_length)):
            # Feed at most the last max_context tokens so long prompts cannot
            # run past the position embedding table.
            outputs = model(input_ids[:, -max_context:])
            logits = outputs[0]
            next_token_logits = logits[:, -1, :] / temperature
            # Apply repetition penalty
            if repetition_penalty != 1.0:
                # Penalize every row's own tokens; positive logits are divided and
                # negative logits multiplied so the penalty always lowers the score.
                seen = next_token_logits.gather(1, input_ids)
                seen = torch.where(seen < 0, seen * repetition_penalty, seen / repetition_penalty)
                next_token_logits = next_token_logits.scatter(1, input_ids, seen)
            # Filter logits using top-k and/or top-p filtering
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=int(top_k), top_p=top_p)
            next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
            # Stop once the model emits end-of-text instead of sampling past it.
            if eos_token_id is not None and bool((next_token == eos_token_id).all()):
                break
            input_ids = torch.cat([input_ids, next_token], dim=-1)
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    # Split on the "[WP] ...\n" and "[RESPONSE]" markers; after dropping the text
    # before the first marker, element 1 is the text following "[RESPONSE]".
    wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
    return wp_responses[1] if len(wp_responses) > 1 else ""
# Define example prompts
# These are offered as clickable examples that fill the prompt textbox.
examples = [
    "What is the classical conceptualisation of oxidation and reduction in redox reactions?",
    "What is the difference between alkenes and alkynes in terms of reactivity?",
    "What is the first law of thermodynamics in thermodynamics?",
    "What is a popular type of vending machine in banking services?",
    "What has the worldwide campaign against smallpox led to?"
]
# Define the Gradio interface using Blocks
with gr.Blocks() as demo:
    # Title row.
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center'>GPT-2 Text Generator</h1>")
    # Main row: generation controls in one column, generated text in the other.
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=2, placeholder="Enter prompt here...", label="Prompt")
            # Slider defaults mirror the default arguments of generate_text.
            max_length = gr.Slider(minimum=10, maximum=100, step=1, value=50, label="Max Length")
            temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
            top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top K")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.95, label="Top P")
            repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Repetition Penalty")
            generate_button = gr.Button("Generate")
        with gr.Column():
            output_text = gr.Textbox(lines=20, label="Generated Text")
    # Inputs are passed to generate_text positionally, in the order listed here.
    generate_button.click(
        fn=generate_text,
        inputs=[prompt, max_length, temperature, top_k, top_p, repetition_penalty],
        outputs=output_text
    )
    # Example prompts only populate the prompt textbox.
    gr.Examples(
        examples=examples,
        inputs=prompt,
    )
# Start the app at import time, with a share link requested and debug mode on.
demo.launch(share=True,debug=True)