# NOTE(review): the original file began with Hugging Face Spaces page-scrape
# residue (status badges "Spaces: Running", "File size: 3,028 Bytes", commit
# hashes 5ced46c/823cebb/2faad0e, and a line-number gutter). Preserved here as
# a comment so the module parses; it is not program code.
import spaces
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import matplotlib.pyplot as plt
import seaborn as sns
import os
# Load the pretrained GPT-2 language model and its matching tokenizer once at
# module import time, so every Gradio request reuses the same instances.
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
@spaces.GPU
def get_token_probabilities(text, top_k=10):
    """Visualize GPT-2's next-token distribution for *text*.

    Args:
        text: Prompt whose next-token probabilities are computed.
        top_k: Number of most-likely tokens to show. May arrive as a float
            from the Gradio slider; it is coerced to int.

    Returns:
        A tuple ``(plot_path, probs)`` where ``plot_path`` is the saved bar
        chart PNG and ``probs`` maps each decoded token to its probability.
    """
    # gr.Slider can deliver a float; torch.topk requires an integer k.
    top_k = int(top_k)

    # Tokenize the input text.
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Forward pass without gradient tracking (inference only).
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

    # Logits at the final position predict the token that follows the prompt.
    next_token_logits = logits[0, -1, :]
    next_token_probs = torch.softmax(next_token_logits, dim=0)

    # Top-k most likely continuations.
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)

    # Move to CPU before .numpy(): .numpy() raises on CUDA tensors, and this
    # function runs on GPU hardware when @spaces.GPU allocates one.
    topk_probs = topk_probs.cpu().numpy()
    topk_indices = topk_indices.cpu().numpy()

    # Decode each candidate token id back to text.
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]

    # Render a horizontal bar chart of the distribution.
    plt.figure(figsize=(10, 6))
    sns.barplot(x=topk_probs, y=topk_tokens)
    plt.title(f"Top {top_k} token probabilities after: '{text}'")
    plt.xlabel("Probability")
    plt.ylabel("Tokens")
    plt.tight_layout()

    # Ensure temp directory exists, then save the figure there.
    os.makedirs("tmp", exist_ok=True)
    plot_path = os.path.join("tmp", "token_probabilities.png")
    plt.savefig(plot_path)
    plt.close()  # free the figure; matplotlib keeps figures alive otherwise

    return plot_path, dict(zip(topk_tokens, topk_probs.tolist()))
# --- Gradio UI: text + slider in, probability plot + JSON table out --------
# (Fixed: the original final line ended with a stray " |" page-scrape
# artifact after demo.launch(), which was a syntax error.)
with gr.Blocks() as demo:
    gr.Markdown("# GPT-2 Next Token Probability Visualizer")
    gr.Markdown("Enter text and see the probabilities of possible next tokens.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Type some text here...",
                value="Hello, my name is"
            )
            top_k = gr.Slider(
                minimum=5,
                maximum=20,
                value=10,
                step=1,
                label="Number of top tokens to show"
            )
            btn = gr.Button("Generate Probabilities")
        with gr.Column():
            output_image = gr.Image(label="Probability Distribution")
            output_table = gr.JSON(label="Token Probabilities")

    # Wire the button to the inference function.
    btn.click(
        fn=get_token_probabilities,
        inputs=[input_text, top_k],
        outputs=[output_image, output_table]
    )

    # Clickable example prompts.
    gr.Examples(
        examples=[
            ["Hello, my name is", 10],
            ["The capital of France is", 10],
            ["Once upon a time", 10],
            ["The best way to learn is to", 10]
        ],
        inputs=[input_text, top_k],
    )

# Launch the app
demo.launch()