# Source page metadata: last commit b135ec8 ("add repo link") by dhairyashil.
"""
Gradio web app for Shakespeare-style text generation using the trained GPT model.
This app provides an interactive interface for users to generate Shakespeare-style text
with customizable parameters.
"""
import logging
import os

import gradio as gr
import tiktoken
import torch

from model import GPT, GPTConfig
# Make the CPU the default device for newly created tensors. The generator
# class below still chooses its own device (CUDA when available) and moves
# the model and its input ids there explicitly with `.to(self.device)`.
torch.set_default_device('cpu')
class ShakespeareTextGenerator:
    """Sample Shakespeare-style text from a trained GPT checkpoint.

    The checkpoint file is read as a dict with ``'config'`` (keyword
    arguments for ``GPTConfig``), ``'model_state_dict'`` and, optionally,
    ``'dtype'`` entries.
    """

    def __init__(self, model_path='compressed_model_cpu_compatible.pt'):
        """Load the checkpoint at *model_path* and prepare model and tokenizer.

        The model runs on CUDA when it is available, otherwise on the CPU.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(model_path, map_location=self.device)

        # Rebuild the architecture from the saved hyper-parameters.
        self.config = GPTConfig(**checkpoint['config'])
        self.model = GPT(self.config)

        # Half precision is only used on GPU; otherwise the parameters keep
        # the dtype the model was constructed with (or are forced to float32).
        saved_dtype = checkpoint.get('dtype')
        if saved_dtype == 'float16' and self.device == 'cuda':
            self.model.half()
        elif saved_dtype == 'float32':
            self.model.float()
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        # GPT-2 BPE tokenizer; <|endoftext|> stops generation and seeds
        # empty prompts.
        self.tokenizer = tiktoken.get_encoding('gpt2')
        self.end_token = self.tokenizer.encode(
            '<|endoftext|>', allowed_special={'<|endoftext|>'}
        )[0]

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Return *logits* with tokens outside the top-k / nucleus set set to -inf.

        *logits* is the 2-D slice ``outputs[:, -1, :]`` used by ``generate``.
        ``top_k <= 0`` disables top-k; ``top_p >= 1.0`` disables top-p.
        """
        if top_k > 0:
            # torch.topk fails when k exceeds the vocabulary size.
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k).values[:, -1:]
            logits = logits.masked_fill(logits < kth_best, float('-inf'))
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the single most likely token.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            # Map the mask from sorted order back to vocabulary order.
            remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float('-inf'))
        return logits

    def generate(self,
                 prompt,
                 max_length=100,
                 temperature=0.7,
                 top_k=50,
                 top_p=0.9,
                 num_return_sequences=1):
        """Generate Shakespeare-style continuations of *prompt*.

        Args:
            prompt: Text to continue. An empty prompt is seeded with the
                ``<|endoftext|>`` token so the model has something to condition on.
            max_length: Maximum number of new tokens per sequence.
            temperature: Positive divisor applied to the logits; lower values
                make sampling more deterministic.
            top_k: Keep only the ``top_k`` highest-scoring tokens (0 disables).
            top_p: Nucleus-sampling threshold (values >= 1.0 disable it).
            num_return_sequences: Number of independent samples to draw.

        Returns:
            A list of decoded strings, one per sample. Each string is the
            prompt followed by its continuation, without the end-of-text marker.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            # Dividing the logits by 0 would produce inf/NaN probabilities.
            raise ValueError(f"temperature must be > 0, got {temperature}")

        # UI widgets may deliver whole numbers as floats; range() and
        # torch.topk() require ints.
        max_length = int(max_length)
        top_k = int(top_k)
        num_return_sequences = int(num_return_sequences)

        # Treat the prompt as plain text: a literal "<|endoftext|>" typed by
        # the user is tokenized as ordinary characters instead of raising.
        prompt_ids = self.tokenizer.encode(
            prompt, allowed_special=set(), disallowed_special=()
        )
        # The model needs at least one token to condition on.
        seeded = not prompt_ids
        if seeded:
            prompt_ids = [self.end_token]
        input_ids = torch.tensor(
            prompt_ids, dtype=torch.long, device=self.device
        ).unsqueeze(0)

        # Never feed more tokens than the model's context window, when the
        # config exposes one.
        block_size = getattr(self.config, 'block_size', None)

        generated_sequences = []
        with torch.no_grad():
            for _ in range(num_return_sequences):
                cur_ids = input_ids.clone()
                hit_end = False
                for _ in range(max_length):
                    context = cur_ids if block_size is None else cur_ids[:, -block_size:]
                    # The model returns a 2-tuple whose first element holds the logits.
                    outputs, _ = self.model(context)
                    next_token_logits = self._filter_logits(
                        outputs[:, -1, :] / temperature, top_k, top_p
                    )
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    cur_ids = torch.cat([cur_ids, next_token], dim=1)
                    if next_token.item() == self.end_token:
                        hit_end = True
                        break

                token_ids = cur_ids[0].tolist()
                if seeded:
                    token_ids = token_ids[1:]  # drop the synthetic seed token
                if hit_end:
                    token_ids = token_ids[:-1]  # hide the end-of-text marker
                generated_sequences.append(self.tokenizer.decode(token_ids))
        return generated_sequences
# Load the checkpoint once, when the module is imported; every call to
# generate_text reuses this single instance.
generator = ShakespeareTextGenerator(model_path='compressed_model_cpu_compatible.pt')
def generate_text(prompt, max_length, temperature, top_k, top_p, num_sequences):
    """Gradio callback: generate text and format it for the output textbox.

    Multiple samples are joined with a ``---`` separator. Any failure during
    generation is shown in the textbox as ``"Error: <message>"`` instead of
    propagating. The full traceback is logged so the cause is not lost.
    """
    try:
        sequences = generator.generate(
            prompt=prompt,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=num_sequences,
        )
    except Exception as e:  # UI boundary: report the error instead of crashing
        logging.getLogger(__name__).exception("Text generation failed")
        return f"Error: {e}"
    return "\n\n---\n\n".join(sequences)
def _build_interface():
    """Assemble the Gradio UI that feeds the prompt and sliders to generate_text."""
    prompt_box = gr.Textbox(
        lines=3,
        label="Prompt",
        placeholder="Enter your prompt here...",
        value="To be, or not to be,",
    )
    length_slider = gr.Slider(
        minimum=10, maximum=500, value=100, step=10,
        label="Maximum Length",
    )
    temperature_slider = gr.Slider(
        minimum=0.1, maximum=2.0, value=0.7, step=0.1,
        label="Temperature (randomness)",
    )
    top_k_slider = gr.Slider(
        minimum=0, maximum=100, value=50, step=5,
        label="Top-k",
    )
    top_p_slider = gr.Slider(
        minimum=0.0, maximum=1.0, value=0.9, step=0.05,
        label="Top-p (nucleus sampling)",
    )
    count_slider = gr.Slider(
        minimum=1, maximum=5, value=1, step=1,
        label="Number of Sequences",
    )
    result_box = gr.Textbox(lines=10, label="Generated Text")

    description = """Generate Shakespeare-style text using a fine-tuned GPT model. Training repository: [https://github.com/dhairyag/ShakespeareGPT-Forge](https://github.com/dhairyag/ShakespeareGPT-Forge)
Adjust the parameters to control the generation:
- Temperature: Higher values make the output more random
- Top-k: Limits the vocabulary to the k most likely tokens
- Top-p: Limits the cumulative probability of tokens considered
- Number of Sequences: Generate multiple variations"""

    # Each example row follows the input order:
    # prompt, max length, temperature, top-k, top-p, number of sequences.
    example_rows = [
        ["To be, or not to be,", 100, 0.7, 50, 0.9, 1],
        ["O Romeo, Romeo,", 150, 0.8, 40, 0.85, 2],
        ["All the world's a stage,", 200, 0.6, 60, 0.95, 1],
    ]

    return gr.Interface(
        fn=generate_text,
        inputs=[
            prompt_box,
            length_slider,
            temperature_slider,
            top_k_slider,
            top_p_slider,
            count_slider,
        ],
        outputs=result_box,
        title="Shakespeare-Style Text Generator",
        description=description,
        examples=example_rows,
    )


# Module-level Gradio interface object.
iface = _build_interface()

if __name__ == "__main__":
    # Serve the app when this file is executed directly.
    iface.launch()