# next-token / app.py
# Author: davanstrien (HF Staff)
# Commit a6b48a0 (verified): "Update app.py" · 2.81 kB
import spaces
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Pretrained checkpoint shared by the model and its tokenizer. Both are loaded
# once at import time; get_next_token_probs reads these module-level objects.
MODEL_NAME = "gpt2"
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
@spaces.GPU
def get_next_token_probs(text, top_k=5):
    """Return the ``top_k`` most likely next tokens for ``text`` as display strings.

    The text is tokenized, run through the module-level GPT-2 ``model``, and the
    vocabulary is ranked by the softmax probability of the token that follows
    the last input token.

    Args:
        text: Prompt to continue. ``None``, empty, or whitespace-only input
            returns ``top_k`` empty strings without running the model.
        top_k: Number of predictions to return (one per output slot).

    Returns:
        A list of exactly ``max(top_k, 0)`` strings such as
        ``'1. "␣sunny" (12.3%)'``. Spaces in a token are shown as ``␣`` and
        newlines as ``↵`` so whitespace tokens stay visible in Markdown.
    """
    if not text or not text.strip():
        return [""] * top_k

    input_ids = tokenizer.encode(text, return_tensors="pt")
    # GPT-2 has a fixed context window (config.n_positions); longer inputs
    # overflow the position embeddings and raise. Only the most recent tokens
    # matter for next-token prediction, so keep the tail.
    input_ids = input_ids[:, -model.config.n_positions:]
    # Feed the inputs on whatever device the model lives on.
    input_ids = input_ids.to(model.device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(input_ids).logits

    # Distribution over the vocabulary for the position after the last token.
    next_token_probs = torch.softmax(logits[0, -1, :], dim=-1)

    # torch.topk raises when k is negative or exceeds the vocabulary size.
    k = max(0, min(top_k, next_token_probs.shape[-1]))
    topk_probs, topk_indices = torch.topk(next_token_probs, k)

    formatted_results = []
    for rank, (prob, token_id) in enumerate(
        zip(topk_probs.tolist(), topk_indices.tolist()), start=1
    ):
        token = tokenizer.decode([token_id])
        # Make whitespace visible: spaces -> "␣", newlines -> "↵" (a raw
        # newline would otherwise break the Markdown line).
        display_token = token.replace(" ", "␣").replace("\n", "↵")
        formatted_results.append(f'{rank}. "{display_token}" ({prob * 100:.1f}%)')

    # Keep one entry per requested slot even if fewer tokens were available.
    formatted_results.extend([""] * (top_k - len(formatted_results)))
    return formatted_results
# Custom CSS for the prediction panel (.token-box) and each prediction row
# (.token-item); it also hides the Gradio footer. Background colors use Gradio
# theme variables, with the original light colors as fallbacks, so the panel
# follows the active (light or dark) theme instead of forcing white backgrounds.
custom_css = """
.token-box {
    margin-top: 10px;
    padding: 15px;
    border-radius: 8px;
    background-color: var(--background-fill-secondary, #f7f7f7);
    font-family: monospace;
    font-size: 16px;
}
.token-item {
    margin: 8px 0;
    padding: 8px;
    background-color: var(--background-fill-primary, white);
    border-left: 4px solid #2c8ecb;
    border-radius: 4px;
}
footer {display: none}
"""
# Number of candidate next tokens shown; each one gets its own Markdown slot,
# and the same value is passed as top_k so the two can never drift apart.
NUM_PREDICTIONS = 5

# Create minimal interface
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("### GPT-2 Next Token Predictor")

    # Input textbox
    input_text = gr.Textbox(
        label="Text Input",
        placeholder="Type here and watch predictions update...",
        value="The weather tomorrow will be",
    )

    # Container for token displays. gr.Column replaces gr.Box, which was
    # removed in Gradio 4; it takes elem_classes, so the CSS still applies.
    with gr.Column(elem_classes=["token-box"]):
        gr.Markdown("##### Most likely next tokens:")
        token_outputs = [
            gr.Markdown(elem_classes=["token-item"]) for _ in range(NUM_PREDICTIONS)
        ]

    def update_tokens(text):
        """Return one formatted prediction string per output slot."""
        return get_next_token_probs(text, top_k=NUM_PREDICTIONS)

    # Refresh predictions on every edit of the textbox ...
    input_text.change(
        fn=update_tokens,
        inputs=input_text,
        outputs=token_outputs,
    )
    # ... and once on page load, so the default text shows predictions.
    demo.load(
        fn=update_tokens,
        inputs=input_text,
        outputs=token_outputs,
    )

# Launch the app
demo.launch()