File size: 2,346 Bytes
823cebb
 
e25871b
823cebb
 
e25871b
 
823cebb
b24cb59
a6b48a0
 
e25871b
823cebb
a6b48a0
 
823cebb
a6b48a0
823cebb
 
 
 
a6b48a0
 
823cebb
 
e25871b
 
823cebb
 
a6b48a0
 
 
 
 
b24cb59
 
a6b48a0
 
823cebb
a6b48a0
 
b24cb59
 
e25871b
2faad0e
a6b48a0
 
 
83af74a
a6b48a0
 
823cebb
b24cb59
 
 
e25871b
 
2faad0e
83af74a
 
b24cb59
a6b48a0
 
2faad0e
 
a6b48a0
 
b24cb59
a6b48a0
 
2faad0e
823cebb
2faad0e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the SmolLM2-135M causal LM and its matching tokenizer once, at import
# time; get_next_token_probs reads these module-level objects on every call.
_CHECKPOINT = "HuggingFaceTB/SmolLM2-135M"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

def get_next_token_probs(text):
    # Handle empty input
    if not text.strip():
        return ["No input text"] * 20
    
    # Tokenize input
    input_ids = tokenizer.encode(text, return_tensors="pt")
    
    # Get predictions
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits
    
    # Get probabilities for next token
    next_token_logits = logits[0, -1, :]
    next_token_probs = torch.softmax(next_token_logits, dim=0)
    
    # Get top-20 tokens and their probabilities
    topk_probs, topk_indices = torch.topk(next_token_probs, 20)
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]
    
    # Format the results as strings
    formatted_results = []
    for i, (token, prob) in enumerate(zip(topk_tokens, topk_probs)):
        # Format probability as percentage with 1 decimal place
        prob_percent = f"{prob.item()*100:.1f}%"
        # Clean up token display (replace space with visible space symbol)
        display_token = token.replace(" ", "␣")
        # Format the output string
        formatted_results.append(f"{i+1}. \"{display_token}\" ({prob_percent})")
    
    return formatted_results

# Number of ranked predictions displayed. get_next_token_probs returns 20
# entries by default; keep the two values in sync because outputs are filled
# positionally.
NUM_PREDICTIONS = 20

# Minimal Blocks UI: a prompt textbox followed by one Markdown line per
# predicted token. The CSS hides the Gradio footer.
with gr.Blocks(css="footer {display: none}") as demo:
    gr.Markdown("### SmolLM2 Next Token Predictor")

    input_text = gr.Textbox(
        label="Text Input",
        placeholder="Type here and watch predictions update...",
        value="The weather tomorrow will be",
    )

    gr.Markdown("##### Most likely next tokens:")

    # One Markdown component per rank; the i-th string returned by
    # get_next_token_probs is written into the i-th component.
    token_outputs = [gr.Markdown() for _ in range(NUM_PREDICTIONS)]

    # Recompute the predictions whenever the textbox content changes...
    input_text.change(
        fn=get_next_token_probs,
        inputs=input_text,
        outputs=token_outputs,
    )

    # ...and once when the page loads, so the default prompt is populated.
    demo.load(
        fn=get_next_token_probs,
        inputs=input_text,
        outputs=token_outputs,
    )


# Start the server only when executed as a script, so importing this module
# does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()