# next-token / app.py
# davanstrien's picture
# davanstrien HF Staff
# Update app.py
# e25871b verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the causal LM and its tokenizer from one shared checkpoint identifier
# so the two can never point at different models.
MODEL_ID = "HuggingFaceTB/SmolLM2-135M"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
def get_next_token_probs(text, top_k=20):
    """Return the most likely next tokens for ``text`` as display strings.

    Tokenizes ``text`` with the module-level ``tokenizer``, runs the
    module-level ``model`` on it, applies a softmax to the logits at the last
    sequence position, and formats the most probable tokens as
    ``'<rank>. "<token>" (<percent>%)'``.

    Args:
        text: Prompt to continue. ``None``, empty, or whitespace-only input
            short-circuits without calling the tokenizer or the model.
        top_k: Number of candidates to report. Defaults to 20.

    Returns:
        A list of strings ordered from most to least probable. It holds
        ``top_k`` entries, or fewer if the model scores fewer than ``top_k``
        candidate tokens. For blank input the list holds ``top_k`` copies of
        ``"No input text"`` so every output slot still receives a value.
    """
    # Handle missing or blank input before touching the tokenizer.
    if text is None or not text.strip():
        return ["No input text"] * top_k

    # Tokenize input
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Get predictions; no gradients are needed for inference.
    with torch.no_grad():
        logits = model(input_ids).logits

    # Probabilities for the token that follows the last input position
    # (first batch entry, last position, all vocabulary scores).
    next_token_probs = torch.softmax(logits[0, -1, :], dim=0)

    # torch.topk raises when asked for more entries than exist, so clamp k.
    k = min(top_k, next_token_probs.numel())
    topk_probs, topk_indices = torch.topk(next_token_probs, k)

    # Convert once to plain Python ints/floats instead of handling 0-dim
    # tensors element by element.
    token_ids = topk_indices.tolist()
    probs = topk_probs.tolist()

    formatted_results = []
    for rank, (token_id, prob) in enumerate(zip(token_ids, probs), start=1):
        # Decode each id on its own so every candidate gets a separate string.
        token = tokenizer.decode([token_id])
        # Clean up token display (replace space with visible space symbol)
        display_token = token.replace(" ", "␣")
        # Probability as a percentage with 1 decimal place
        formatted_results.append(f'{rank}. "{display_token}" ({prob * 100:.1f}%)')
    return formatted_results
# Build the page: a heading, one textbox, and a fixed column of result lines.
with gr.Blocks(css="footer {display: none}") as demo:
    gr.Markdown("### SmolLM2 Next Token Predictor")

    # Textbox pre-filled with a sample prompt
    input_text = gr.Textbox(
        value="The weather tomorrow will be",
        placeholder="Type here and watch predictions update...",
        label="Text Input",
    )

    # Heading above the prediction lines
    gr.Markdown("##### Most likely next tokens:")

    # One Markdown component per reported token (20 in total)
    token_outputs = []
    for _ in range(20):
        token_outputs.append(gr.Markdown())

    # Both events below use exactly the same callback wiring.
    prediction_event = dict(
        fn=get_next_token_probs,
        inputs=input_text,
        outputs=token_outputs,
    )

    # Recompute whenever the textbox content changes
    input_text.change(**prediction_event)

    # Fill in predictions for the default text when the page loads
    demo.load(**prediction_event)

# Launch the app
demo.launch()