davanstrien (HF Staff) committed
Commit a6b48a0 · verified · 1 Parent(s): 889d49e

Update app.py

Files changed (1):
1. app.py (+69 -64)
app.py CHANGED
@@ -2,92 +2,97 @@ import spaces
 import gradio as gr
 import torch
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
-import matplotlib.pyplot as plt
-import seaborn as sns
-import os

 # Load model and tokenizer
-model_name = "gpt2"
-model = GPT2LMHeadModel.from_pretrained(model_name)
-tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+model = GPT2LMHeadModel.from_pretrained("gpt2")
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

 @spaces.GPU
-def get_token_probabilities(text, top_k=10):
-    # Tokenize the input text
-    input_ids = tokenizer.encode(text, return_tensors="pt")
-
-    # Get the last token's position
-    last_token_position = input_ids.shape[1] - 1
-
-    # Get model predictions
+def get_next_token_probs(text, top_k=5):
+    # Handle empty input
+    if not text.strip():
+        return [""] * top_k
+
+    # Tokenize input
+    input_ids = tokenizer.encode(text, return_tensors="pt")
+
+    # Get predictions
     with torch.no_grad():
         outputs = model(input_ids)
         logits = outputs.logits

-    # Get probabilities for the next token after the last token
-    next_token_logits = logits[0, last_token_position, :]
+    # Get probabilities for next token
+    next_token_logits = logits[0, -1, :]
     next_token_probs = torch.softmax(next_token_logits, dim=0)

-    # Get top k most likely tokens
+    # Get top-k tokens and their probabilities
     topk_probs, topk_indices = torch.topk(next_token_probs, top_k)
-
-    # Convert to numpy for easier handling
-    topk_probs = topk_probs.numpy()
-    topk_indices = topk_indices.numpy()
-
-    # Decode tokens
     topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]

-    # Create a plot
-    plt.figure(figsize=(10, 6))
-    sns.barplot(x=topk_probs, y=topk_tokens)
-    plt.title(f"Top {top_k} token probabilities after: '{text}'")
-    plt.xlabel("Probability")
-    plt.ylabel("Tokens")
-    plt.tight_layout()
-
-    # Ensure temp directory exists
-    os.makedirs("tmp", exist_ok=True)
-
-    # Save the plot to a file in the temp directory
-    plot_path = os.path.join("tmp", "token_probabilities.png")
-    plt.savefig(plot_path)
-    plt.close()
-
-    return plot_path, dict(zip(topk_tokens, topk_probs.tolist()))
-
-with gr.Blocks() as demo:
-    gr.Markdown("# GPT-2 Next Token Probability Visualizer")
-    gr.Markdown("Enter text and see the probabilities of possible next tokens.")
-
-    with gr.Row():
-        with gr.Column():
-            input_text = gr.Textbox(
-                label="Input Text",
-                placeholder="Type some text here...",
-                value="Hello, my name is"
-            )
-            top_k = gr.Slider(
-                minimum=5,
-                maximum=20,
-                value=10,
-                step=1,
-                label="Number of top tokens to show"
-            )
-            btn = gr.Button("Generate Probabilities")
-
-        with gr.Column():
-            output_image = gr.Image(label="Probability Distribution")
-            output_table = gr.JSON(label="Token Probabilities")
-
-    btn.click(
-        fn=get_token_probabilities,
-        inputs=[input_text, top_k],
-        outputs=[output_image, output_table]
-    )
-
-    gr.Examples(
-        inputs=[input_text, top_k],
-    )
+    # Format the results as strings
+    formatted_results = []
+    for i, (token, prob) in enumerate(zip(topk_tokens, topk_probs)):
+        # Format probability as percentage with 1 decimal place
+        prob_percent = f"{prob.item()*100:.1f}%"
+        # Clean up token display (remove leading space if present)
+        display_token = token.replace(" ", "␣")  # Replace space with visible space symbol
+        # Format the output string
+        formatted_results.append(f"{i+1}. \"{display_token}\" ({prob_percent})")
+
+    return formatted_results
+
+# Create custom CSS
+custom_css = """
+.token-box {
+    margin-top: 10px;
+    padding: 15px;
+    border-radius: 8px;
+    background-color: #f7f7f7;
+    font-family: monospace;
+    font-size: 16px;
+}
+.token-item {
+    margin: 8px 0;
+    padding: 8px;
+    background-color: white;
+    border-left: 4px solid #2c8ecb;
+    border-radius: 4px;
+}
+footer {display: none}
+"""
+
+# Create minimal interface
+with gr.Blocks(css=custom_css) as demo:
+    gr.Markdown("### GPT-2 Next Token Predictor")
+
+    # Input textbox
+    input_text = gr.Textbox(
+        label="Text Input",
+        placeholder="Type here and watch predictions update...",
+        value="The weather tomorrow will be"
+    )
+
+    # Container for token displays
+    with gr.Box(elem_classes=["token-box"]):
+        gr.Markdown("##### Most likely next tokens:")
+        token_outputs = [gr.Markdown(elem_classes=["token-item"]) for _ in range(5)]
+
+    # Function to update tokens in real-time
+    def update_tokens(text):
+        return get_next_token_probs(text)
+
+    # Set up the live update
+    input_text.change(
+        fn=update_tokens,
+        inputs=input_text,
+        outputs=token_outputs
+    )
+
+    # Initialize with default text
+    demo.load(
+        fn=update_tokens,
+        inputs=input_text,
+        outputs=token_outputs
+    )

 # Launch the app
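
For anyone wanting to sanity-check the new logits-to-probabilities logic outside the Space, here is a minimal standalone sketch of the same computation. It assumes only that torch and transformers are installed; the prompt is illustrative and the model/variable names mirror the commit.

# Standalone sketch of the next-token computation from this commit.
# Assumes `torch` and `transformers` are installed; prompt is illustrative.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model.eval()

text = "The weather tomorrow will be"
input_ids = tokenizer.encode(text, return_tensors="pt")

with torch.no_grad():
    logits = model(input_ids).logits

# Distribution over the vocabulary for the token following the prompt
probs = torch.softmax(logits[0, -1, :], dim=0)
topk_probs, topk_indices = torch.topk(probs, 5)

for rank, (idx, p) in enumerate(zip(topk_indices, topk_probs), start=1):
    token = tokenizer.decode([idx])
    print(f'{rank}. "{token}" ({p.item() * 100:.1f}%)')

Each of the five formatted strings produced this way maps onto one of the five token_outputs Markdown slots in the interface, which is why get_next_token_probs pads its result with empty strings when the input is blank.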