davanstrien (HF Staff) committed verified commit 83af74a · Parent(s): f59ec74

Update app.py

Files changed (1):
  app.py (+7, -10)

app.py CHANGED
@@ -1,10 +1,10 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
 # Load model and tokenizer
-model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
-tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+model = GPT2LMHeadModel.from_pretrained("gpt2")
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 
 def get_next_token_probs(text):
     # Handle empty input
@@ -24,7 +24,7 @@ def get_next_token_probs(text):
     next_token_probs = torch.softmax(next_token_logits, dim=0)
 
     # Get top-5 tokens and their probabilities
-    topk_probs, topk_indices = torch.topk(next_token_probs, 20)
+    topk_probs, topk_indices = torch.topk(next_token_probs, 5)
     topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]
 
     # Format the results as strings
@@ -46,13 +46,10 @@ with gr.Blocks(css="footer {display: none}") as demo:
     # Input textbox
     input_text = gr.Textbox(
         label="Text Input",
-        placeholder="Type text here...",
+        placeholder="Type here and watch predictions update...",
         value="The weather tomorrow will be"
     )
 
-    # Predict button
-    predict_btn = gr.Button("Predict Next Tokens")
-
     # Simple header for results
     gr.Markdown("##### Most likely next tokens:")
 
@@ -65,8 +62,8 @@ with gr.Blocks(css="footer {display: none}") as demo:
 
     token_outputs = [token1, token2, token3, token4, token5]
 
-    # Set up button click event
-    predict_btn.click(
+    # Set up the live update
+    input_text.change(
         fn=get_next_token_probs,
         inputs=input_text,
         outputs=token_outputs
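
For context, this change swaps the SmolLM2-135M checkpoint for GPT-2, trims the top-k display from 20 back to the advertised 5 tokens, and replaces the "Predict Next Tokens" button with a live `Textbox.change` event so predictions refresh on every edit. Below is a minimal sketch of what the updated app.py plausibly looks like as a whole; the pieces the diff does not show (the empty-input branch, the forward pass, the output formatting, and how the five token textboxes are built) are assumptions, not the Space's actual code.

```python
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load model and tokenizer (as in the new version of app.py)
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model.eval()

def get_next_token_probs(text):
    # Handle empty input: one empty string per output textbox
    # (this branch is not visible in the diff; assumed)
    if not text.strip():
        return ["" for _ in range(5)]

    # Tokenize and run a forward pass (assumed; not shown in the diff context)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Probabilities for the token that would follow the last input position
    next_token_logits = logits[0, -1, :]
    next_token_probs = torch.softmax(next_token_logits, dim=0)

    # Top-5 tokens and their probabilities (the commit changes 20 -> 5 here)
    topk_probs, topk_indices = torch.topk(next_token_probs, 5)
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]

    # One "token: probability" string per output textbox (assumed format)
    return [f"{tok!r}: {prob:.3f}" for tok, prob in zip(topk_tokens, topk_probs)]

with gr.Blocks(css="footer {display: none}") as demo:
    # Input textbox
    input_text = gr.Textbox(
        label="Text Input",
        placeholder="Type here and watch predictions update...",
        value="The weather tomorrow will be",
    )

    # Simple header for results
    gr.Markdown("##### Most likely next tokens:")

    # Five output slots; labels are assumed, the diff only shows token1..token5
    token_outputs = [gr.Textbox(label=f"Token {i + 1}") for i in range(5)]

    # The commit replaces the "Predict Next Tokens" button with a live update:
    # every edit to the textbox re-runs get_next_token_probs
    input_text.change(
        fn=get_next_token_probs,
        inputs=input_text,
        outputs=token_outputs,
    )

demo.launch()
```

Since `.change` fires on every edit, a larger model might warrant debouncing or an explicit trigger; the gpt2 checkpoint is small enough that per-keystroke inference stays responsive.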