import spaces
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load model and tokenizer
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
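# Note: "gpt2" is the 124M-parameter base checkpoint, which is small enough to
# run on CPU, so the model is not moved to a GPU here.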
def get_next_token_probs(text, top_k=5):
    # Handle empty input
    if not text.strip():
        return [""] * top_k

    # Tokenize input
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Get predictions
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

    # Get probabilities for the next token
    next_token_logits = logits[0, -1, :]
    next_token_probs = torch.softmax(next_token_logits, dim=0)

    # Get top-k tokens and their probabilities
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]

    # Format the results as strings
    formatted_results = []
    for i, (token, prob) in enumerate(zip(topk_tokens, topk_probs)):
        # Format probability as a percentage with 1 decimal place
        prob_percent = f"{prob.item()*100:.1f}%"
        # Make spaces visible by replacing them with the ␣ symbol
        display_token = token.replace(" ", "␣")
        # Format the output string
        formatted_results.append(f"{i+1}. \"{display_token}\" ({prob_percent})")
    return formatted_results
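# Illustrative example of the return value (tokens and percentages below are
# made up for documentation; actual numbers depend on the GPT-2 weights):
#   get_next_token_probs("The weather tomorrow will be")
#   -> ['1. "␣sunny" (12.3%)', '2. "␣a" (9.8%)', '3. "␣cloudy" (7.4%)', ...]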
# Create custom CSS
custom_css = """
.token-box {
    margin-top: 10px;
    padding: 15px;
    border-radius: 8px;
    background-color: #f7f7f7;
    font-family: monospace;
    font-size: 16px;
}
.token-item {
    margin: 8px 0;
    padding: 8px;
    background-color: white;
    border-left: 4px solid #2c8ecb;
    border-radius: 4px;
}
footer {display: none}
"""
# Create minimal interface
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("### GPT-2 Next Token Predictor")

    # Input textbox
    input_text = gr.Textbox(
        label="Text Input",
        placeholder="Type here and watch predictions update...",
        value="The weather tomorrow will be"
    )

    # Container for token displays (gr.Box exists in Gradio 3.x; on Gradio 4+ use gr.Group)
    with gr.Box(elem_classes=["token-box"]):
        gr.Markdown("##### Most likely next tokens:")
        token_outputs = [gr.Markdown(elem_classes=["token-item"]) for _ in range(5)]
    # Function to update tokens in real-time
    def update_tokens(text):
        return get_next_token_probs(text)

    # Set up the live update
    input_text.change(
        fn=update_tokens,
        inputs=input_text,
        outputs=token_outputs
    )

    # Initialize with default text
    demo.load(
        fn=update_tokens,
        inputs=input_text,
        outputs=token_outputs
    )

# Launch the app
demo.launch()
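# To try this outside of Hugging Face Spaces (assumes the dependencies below are
# installed; the `spaces` package is only needed on Spaces, so either
# `pip install spaces` as well or drop that import):
#   pip install gradio torch transformers
#   python app.py
# Gradio serves the demo on http://127.0.0.1:7860 by default.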