davanstrien HF Staff committed on
Commit
823cebb
·
verified ·
1 Parent(s): bcec63c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+
8
# Checkpoint identifier shared by the tokenizer and the language model.
model_name = "gpt2"

# Loaded once at import time so every request reuses the same objects.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
12
+
13
def get_token_probabilities(text, top_k=10):
    """Plot and return the model's most likely next tokens after ``text``.

    Args:
        text: Prompt whose continuation is predicted.
        top_k: How many of the highest-probability tokens to report. Cast to
            ``int`` (a UI slider may deliver a float) and capped at the
            vocabulary size.

    Returns:
        A ``(image_path, probabilities)`` tuple: the path of a PNG bar chart
        and a dict mapping each decoded token to its probability.

    Raises:
        gr.Error: If ``text`` is empty or tokenizes to nothing, since there
            is then no position to predict from.
    """
    import tempfile

    if not text:
        raise gr.Error("Please enter some text.")

    input_ids = tokenizer.encode(text, return_tensors="pt")
    if input_ids.shape[1] == 0:
        raise gr.Error("The input text produced no tokens.")
    # Keep the input on whatever device the model lives on.
    input_ids = input_ids.to(model.device)

    with torch.no_grad():
        logits = model(input_ids).logits

    # The logits at the final position score the token that would come next.
    next_token_probs = torch.softmax(logits[0, -1, :], dim=0)

    # torch.topk needs an integer k no larger than the vocabulary size.
    top_k = min(int(top_k), next_token_probs.shape[0])
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)
    topk_probs = topk_probs.cpu().tolist()
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices.cpu().tolist()]

    # Matplotlib treats text containing an even number of unescaped "$" as
    # mathtext; escape them so prompts and tokens are drawn literally. The
    # returned dict keeps the raw, unescaped tokens.
    def literal(label):
        return label.replace("$", r"\$")

    # Use an explicit figure/axes instead of pyplot's implicit "current
    # figure", and always close the figure, even if plotting or saving fails.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        sns.barplot(x=topk_probs, y=[literal(tok) for tok in topk_tokens], ax=ax)
        ax.set_title(f"Top {top_k} token probabilities after: '{literal(text)}'")
        ax.set_xlabel("Probability")
        ax.set_ylabel("Tokens")
        fig.tight_layout()

        # A unique file per call: a single fixed filename would let
        # concurrent requests overwrite each other's chart. The file is left
        # on disk (delete=False) so the caller can read it after we return.
        with tempfile.NamedTemporaryFile(
            prefix="token_probabilities_", suffix=".png", delete=False
        ) as image_file:
            fig.savefig(image_file, format="png")
            image_path = image_file.name
    finally:
        plt.close(fig)

    return image_path, dict(zip(topk_tokens, topk_probs))
52
+
53
def interface():
    """Assemble the Gradio Blocks UI for the next-token visualizer.

    Returns:
        The constructed ``gr.Blocks`` app; launching it is left to the caller.
    """
    # Every example uses the same slider value, so build the rows from prompts.
    sample_prompts = (
        "Hello, my name is",
        "The capital of France is",
        "Once upon a time",
        "The best way to learn is to",
    )
    example_rows = [[prompt, 10] for prompt in sample_prompts]

    with gr.Blocks() as demo:
        gr.Markdown("# GPT-2 Next Token Probability Visualizer")
        gr.Markdown("Enter text and see the probabilities of possible next tokens.")

        with gr.Row():
            # Left column: user controls.
            with gr.Column():
                prompt_box = gr.Textbox(
                    label="Input Text",
                    placeholder="Type some text here...",
                    value="Hello, my name is",
                )
                count_slider = gr.Slider(
                    minimum=5,
                    maximum=20,
                    value=10,
                    step=1,
                    label="Number of top tokens to show",
                )
                run_button = gr.Button("Generate Probabilities")

            # Right column: rendered chart plus the raw numbers.
            with gr.Column():
                chart = gr.Image(label="Probability Distribution")
                table = gr.JSON(label="Token Probabilities")

        run_button.click(
            fn=get_token_probabilities,
            inputs=[prompt_box, count_slider],
            outputs=[chart, table],
        )

        # Clicking an example only fills in the two input controls.
        gr.Examples(
            examples=example_rows,
            inputs=[prompt_box, count_slider],
        )

    return demo
95
+
96
if __name__ == "__main__":
    # Script entry point: build the UI, then start serving it.
    demo = interface()
    demo.launch()