File size: 2,811 Bytes
5ced46c
823cebb
 
 
 
 
a6b48a0
 
823cebb
5ced46c
a6b48a0
 
 
 
823cebb
a6b48a0
 
823cebb
a6b48a0
823cebb
 
 
 
a6b48a0
 
823cebb
 
a6b48a0
823cebb
 
 
a6b48a0
 
 
 
 
 
 
 
 
823cebb
a6b48a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2faad0e
a6b48a0
 
 
 
 
 
823cebb
a6b48a0
 
 
 
2faad0e
a6b48a0
 
 
2faad0e
a6b48a0
 
 
 
 
2faad0e
 
a6b48a0
 
 
 
 
2faad0e
823cebb
2faad0e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import spaces
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load model and tokenizer once at import time.
# GPT-2 (base, 124M) weights and its tokenizer are fetched from the Hugging
# Face Hub on first run and cached locally. Both are module-level globals
# consumed by get_next_token_probs below.
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

@spaces.GPU
def get_next_token_probs(text, top_k=5):
    """Return the top_k most probable next tokens for *text* as display strings.

    Each entry is formatted as: 1. "token" (12.3%), with literal spaces in the
    token rendered as the visible symbol "␣". Blank or whitespace-only input
    yields a list of top_k empty strings so the UI slots clear cleanly.
    """
    # Guard clause: nothing to predict from.
    if not text.strip():
        return [""] * top_k

    # Encode the prompt and run a single forward pass; no gradients needed.
    input_ids = tokenizer.encode(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(input_ids).logits

    # Probability distribution over the vocabulary at the final position.
    next_token_probs = torch.softmax(logits[0, -1, :], dim=0)
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)

    formatted = []
    for rank, (token_id, prob) in enumerate(zip(topk_indices, topk_probs), start=1):
        token_text = tokenizer.decode([token_id])
        # Render spaces visibly so e.g. " the" and "the" are distinguishable.
        visible = token_text.replace(" ", "␣")
        formatted.append(f"{rank}. \"{visible}\" ({prob.item()*100:.1f}%)")
    return formatted

# Create custom CSS injected into the Blocks app below:
#   .token-box  — grey monospace container wrapping the prediction list
#   .token-item — white card with a blue left border for each prediction row
#   footer rule — hides Gradio's default page footer
custom_css = """
.token-box {
    margin-top: 10px;
    padding: 15px;
    border-radius: 8px;
    background-color: #f7f7f7;
    font-family: monospace;
    font-size: 16px;
}
.token-item {
    margin: 8px 0;
    padding: 8px;
    background-color: white;
    border-left: 4px solid #2c8ecb;
    border-radius: 4px;
}
footer {display: none}
"""

# Assemble the minimal Blocks UI; visual styling comes from custom_css.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("### GPT-2 Next Token Predictor")

    # Prompt box — predictions refresh live on every edit.
    prompt = gr.Textbox(
        label="Text Input",
        placeholder="Type here and watch predictions update...",
        value="The weather tomorrow will be",
    )

    # Five Markdown slots, one per ranked prediction.
    # NOTE(review): gr.Box was removed in Gradio 4.x (gr.Group replaced it) —
    # confirm the pinned Gradio version before upgrading this dependency.
    with gr.Box(elem_classes=["token-box"]):
        gr.Markdown("##### Most likely next tokens:")
        prediction_rows = []
        for _ in range(5):
            prediction_rows.append(gr.Markdown(elem_classes=["token-item"]))

    def _refresh(text):
        # Thin wrapper so both event hooks share a single callback.
        return get_next_token_probs(text)

    # Live updates while typing, plus an initial fill when the page loads.
    prompt.change(
        fn=_refresh,
        inputs=prompt,
        outputs=prediction_rows,
    )
    demo.load(
        fn=_refresh,
        inputs=prompt,
        outputs=prediction_rows,
    )

# Launch the app
demo.launch()