# NOTE(review): removed Hugging Face file-viewer artifacts (status lines,
# commit hashes, and a line-number gutter) that were accidentally pasted
# into the source and would have been a syntax error.
import spaces
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load GPT-2 (base, 124M) weights and its matching BPE tokenizer at import
# time so the model is warm before the Gradio UI starts serving.
# NOTE(review): the first run downloads weights from the Hugging Face Hub,
# so this requires network access and may take a while.
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
@spaces.GPU
def get_next_token_probs(text, top_k=5):
    """Return the ``top_k`` most likely next tokens for *text* as display strings.

    Each entry is formatted like ``1. "␣word" (12.3%)``.  For empty or
    whitespace-only input, a list of ``top_k`` blank strings is returned so
    the UI slots stay populated.
    """
    # Guard clause: nothing to predict for blank input.
    if not text.strip():
        return [""] * top_k

    # Single forward pass; only the logits at the final position matter.
    ids = tokenizer.encode(text, return_tensors="pt")
    with torch.no_grad():
        last_logits = model(ids).logits[0, -1, :]

    # Softmax over the vocabulary, then keep the top-k candidates.
    distribution = torch.softmax(last_logits, dim=0)
    probs, indices = torch.topk(distribution, top_k)

    # Build ranked display strings; spaces are shown as "␣" so that
    # leading-space tokens are visible to the user.
    lines = []
    for rank, (idx, p) in enumerate(zip(indices, probs), start=1):
        token_text = tokenizer.decode([idx]).replace(" ", "␣")
        lines.append(f"{rank}. \"{token_text}\" ({p.item()*100:.1f}%)")
    return lines
# Create custom CSS
custom_css = """
.token-box {
margin-top: 10px;
padding: 15px;
border-radius: 8px;
background-color: #f7f7f7;
font-family: monospace;
font-size: 16px;
}
.token-item {
margin: 8px 0;
padding: 8px;
background-color: white;
border-left: 4px solid #2c8ecb;
border-radius: 4px;
}
footer {display: none}
"""
# Build the minimal live-prediction interface.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("### GPT-2 Next Token Predictor")

    # Input textbox; every edit triggers a prediction refresh below.
    input_text = gr.Textbox(
        label="Text Input",
        placeholder="Type here and watch predictions update...",
        value="The weather tomorrow will be"
    )

    # Styled container holding five ranked token rows (one Markdown each,
    # matching the 5 strings returned by get_next_token_probs).
    with gr.Box(elem_classes=["token-box"]):
        gr.Markdown("##### Most likely next tokens:")
        token_outputs = [gr.Markdown(elem_classes=["token-item"]) for _ in range(5)]

    def update_tokens(text):
        """Thin wrapper so both events below share one callback."""
        return get_next_token_probs(text)

    # Live update: re-run the prediction on every textbox change.
    input_text.change(
        fn=update_tokens,
        inputs=input_text,
        outputs=token_outputs
    )

    # Populate predictions for the default text as soon as the page loads.
    demo.load(
        fn=update_tokens,
        inputs=input_text,
        outputs=token_outputs
    )

# Launch the app.
# FIX: removed a stray trailing " |" extraction artifact after launch()
# that made this line a syntax error.
demo.launch()