Enhance UI styling and layout for Next-Token Predictor with gradient backgrounds, improved button designs, and updated tooltip functionality
35d29d6
| import gradio as gr | |
| import json | |
| import os | |
| import time | |
| import torch | |
| from typing import List, Dict, Tuple | |
| from dotenv import load_dotenv | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Load environment variables from .env file
load_dotenv()

# Configuration
MODEL_ID = "Qwen/Qwen3-0.6B"
# Optional Hugging Face access token; an empty string means "not configured".
HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')

# Initialize model and tokenizer once at import time (local inference).
# The token is forwarded so gated/private MODEL_IDs also load; `or None`
# falls back to the library's default credential lookup when it is unset
# (passing an empty string would be treated as an explicit, invalid token).
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN or None)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN or None)
def show_token(token: str) -> str:
    """Return a human-readable rendering of a decoded token.

    - Tokens with any visible character are returned unchanged.
    - The empty string is returned unchanged (there is nothing to show).
    - Whitespace-only tokens that contain a newline are rendered character
      by character: each "\\n" becomes "⏎" and every other whitespace
      character becomes "␣" (e.g. "\\n\\n" -> "⏎⏎", "\\n  " -> "⏎␣␣").
    - Other whitespace-only tokens become "␣" for a single character or
      "␣×N" for a run of N characters.
    """
    if not token or token.strip():
        return token
    if "\n" in token:
        # Keep line breaks distinguishable from spaces in multi-char tokens.
        return "".join("⏎" if ch == "\n" else "␣" for ch in token)
    return "␣" if len(token) == 1 else f"␣×{len(token)}"
def predict_next_token(text: str, top_k: int = 10, temperature: float = 1.0, top_p: float = 0.9) -> Tuple[List[Dict], str]:
    """Return the most likely next tokens for ``text`` using the local model.

    Runs one forward pass of the module-level ``model`` and takes the logits
    at the last position. It divides them by ``temperature`` and applies a
    softmax. It then keeps the smallest probability-sorted prefix whose
    cumulative probability reaches ``top_p`` (nucleus filtering). Finally it
    returns at most ``top_k`` of those tokens.

    Args:
        text: Prompt to continue. ``None``, empty and whitespace-only input
            is rejected with a message instead of running the model.
        top_k: Maximum number of candidates to return (<= 0 yields none).
        temperature: Logit divisor. Values <= 0 are clamped to a tiny
            positive number (near-greedy) instead of dividing by zero.
        top_p: Nucleus threshold. Values >= 1 keep the whole vocabulary.

    Returns:
        ``(tokens_data, status)``:
        - ``tokens_data`` is a list of ``{"token": str, "prob": float}``
          dicts in descending probability order. ``prob`` is the
          temperature-scaled softmax value; it is not renormalised after
          filtering.
        - ``status`` is a timing message for the UI. If any exception is
          raised, the list is empty and the message starts with "❌ Error:".
    """
    if not text or not text.strip():
        return [], "Please enter some text to predict from"
    start_time = time.perf_counter()
    try:
        encoded = tokenizer(text, return_tensors="pt", padding=False).to(model.device)
        # A plain forward pass yields the raw next-token logits. generate()
        # with output_scores returns *processed* scores, and it warns about
        # the model's sampling defaults when do_sample=False.
        with torch.no_grad():
            logits = model(**encoded).logits[0, -1, :]

        safe_temperature = max(float(temperature), 1e-5)
        probs = torch.softmax(logits.float() / safe_temperature, dim=-1)

        # Nucleus (top-p) cutoff: keep every token whose cumulative
        # probability is still below top_p, plus the one that crosses it.
        # Float rounding can leave the cumulative sum just under 1.0, so the
        # count is clamped to the vocabulary size rather than searched for.
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        cutoff = int((cumulative_probs < top_p).sum().item()) + 1
        cutoff = min(cutoff, sorted_probs.numel())

        # Top-k is applied on top of the nucleus.
        final_k = max(0, min(int(top_k), cutoff))
        token_ids = sorted_indices[:final_k].tolist()
        token_probs = sorted_probs[:final_k].tolist()

        tokens_data = [
            {"token": tokenizer.decode([token_id]), "prob": prob}
            for token_id, prob in zip(token_ids, token_probs)
        ]
        prediction_time = int((time.perf_counter() - start_time) * 1000)
        return tokens_data, f"Prediction time: {prediction_time}ms"
    except Exception as e:
        # Deliberate best-effort: the UI shows the error instead of crashing.
        return [], f"❌ Error: {str(e)}"
def create_clickable_token_display(tokens_data: List[Dict]) -> str:
    """Render token predictions as a static HTML list.

    Each entry of ``tokens_data`` must provide ``"token"`` (str) and
    ``"prob"`` (float). Each row stores the raw token in a ``data-token``
    attribute. It shows the token (formatted by ``show_token``) beside its
    probability as a percentage with two decimals. Only inline hover styling
    is attached; there is no click handler (no JavaScript).

    Token text comes from the model vocabulary and can contain ``<``, ``&``
    or quotes. It is therefore HTML-escaped before being placed in element
    content or attributes.
    """
    from html import escape

    parts = ["""
    <div id="token-predictions" style="font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace; background: #0e162b; border: 1px solid #1c2945; border-radius: 14px; padding: 12px;">
    """]
    for token_data in tokens_data:
        raw_token = token_data['token']
        token_attr = escape(raw_token, quote=True)
        token_display = escape(show_token(raw_token))
        percentage = f"{token_data['prob'] * 100:.2f}%"
        parts.append(f"""
        <div class="token-prediction" data-token="{token_attr}"
             style="display: grid; grid-template-columns: 1fr auto; gap: 8px; align-items: center; padding: 8px 10px; margin: 4px 0; border-radius: 10px; background: #0f1930; border: 1px solid #22365e; cursor: pointer; transition: background 0.2s;"
             onmouseover="this.style.background='#1a2b4a'"
             onmouseout="this.style.background='#0f1930'">
            <div style="color: #e6f1ff; font-size: 14px;">{token_display}</div>
            <div style="color: #9ab0d0; font-size: 12px;">{percentage}</div>
        </div>
        """)
    parts.append("""
    </div>
    """)
    return "".join(parts)
| # Custom CSS to match the Token Visualizer gradient color scheme | |
| custom_css = """ | |
| /* Main container with gradient background like Token Visualizer */ | |
| .gradio-container { | |
| background: linear-gradient(135deg, #0f172a 0%, #1e293b 50%, #0f172a 100%) !important; | |
| color: #e2e8f0 !important; | |
| min-height: 100vh !important; | |
| } | |
| /* Blocks with subtle transparency and borders */ | |
| .block { | |
| background: rgba(15, 23, 42, 0.8) !important; | |
| border: 1px solid rgba(148, 163, 184, 0.2) !important; | |
| border-radius: 12px !important; | |
| backdrop-filter: blur(8px) !important; | |
| } | |
| /* Tab styling */ | |
| .tab-item { | |
| background: rgba(30, 41, 59, 0.6) !important; | |
| border: 1px solid rgba(148, 163, 184, 0.2) !important; | |
| border-radius: 8px !important; | |
| } | |
| /* Token buttons with modern gradient and hover effects */ | |
| .token-button { | |
| background: linear-gradient(135deg, #1e293b 0%, #334155 100%) !important; | |
| border: 1px solid rgba(148, 163, 184, 0.3) !important; | |
| color: #e2e8f0 !important; | |
| border-radius: 8px !important; | |
| margin: 0px !important; | |
| padding: 3px 8px !important; | |
| font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace !important; | |
| transition: all 0.3s ease !important; | |
| font-size: 12px !important; | |
| box-shadow: 0 1px 3px rgba(0, 0, 0, 0.3) !important; | |
| } | |
| .token-button:hover { | |
| background: linear-gradient(135deg, #334155 0%, #475569 100%) !important; | |
| border-color: rgba(148, 163, 184, 0.5) !important; | |
| transform: translateY(-1px) !important; | |
| box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4) !important; | |
| } | |
| /* Input fields styling */ | |
| .gr-textbox { | |
| background: rgba(30, 41, 59, 0.6) !important; | |
| border: 1px solid rgba(148, 163, 184, 0.3) !important; | |
| border-radius: 8px !important; | |
| color: #e2e8f0 !important; | |
| } | |
| /* Slider styling */ | |
| .gr-slider { | |
| background: rgba(30, 41, 59, 0.6) !important; | |
| } | |
| .gr-slider input[type="range"] { | |
| background: linear-gradient(to right, #3b82f6, #06b6d4) !important; | |
| } | |
| /* Labels and text */ | |
| .gr-label { | |
| color: #cbd5e1 !important; | |
| font-weight: 500 !important; | |
| } | |
| .gr-info { | |
| color: #94a3b8 !important; | |
| background: rgba(30, 41, 59, 0.4) !important; | |
| border-radius: 6px !important; | |
| padding: 4px 8px !important; | |
| border: 1px solid rgba(148, 163, 184, 0.2) !important; | |
| } | |
| /* Remove Gradio's default spacing between buttons */ | |
| .token-button + .token-button { | |
| margin-top: 0px !important; | |
| } | |
| /* Remove gaps in the column containing buttons */ | |
| div:has(> .token-button) { | |
| gap: 0px !important; | |
| } | |
| /* Target Gradio's automatic spacing */ | |
| .block > div > div { | |
| gap: 0px !important; | |
| } | |
| /* Hide spinner arrows on number inputs */ | |
| input[type="number"]::-webkit-outer-spin-button, | |
| input[type="number"]::-webkit-inner-spin-button { | |
| -webkit-appearance: none !important; | |
| margin: 0 !important; | |
| } | |
| input[type="number"] { | |
| -moz-appearance: textfield !important; | |
| } | |
| /* Header styling to match Token Visualizer */ | |
| h1, h2, h3, h4 { | |
| background: linear-gradient(135deg, #e2e8f0 0%, #94a3b8 100%) !important; | |
| -webkit-background-clip: text !important; | |
| -webkit-text-fill-color: transparent !important; | |
| background-clip: text !important; | |
| } | |
| /* Interactive tooltip icons in labels */ | |
| .gr-label { | |
| position: relative !important; | |
| } | |
| /* Add tooltip functionality with JavaScript */ | |
| .gr-label:has-text("ⓘ"):hover::after { | |
| content: attr(data-tooltip); | |
| position: absolute; | |
| background: rgba(15, 23, 42, 0.95); | |
| color: #e2e8f0; | |
| padding: 8px 12px; | |
| border-radius: 6px; | |
| font-size: 12px; | |
| max-width: 250px; | |
| border: 1px solid rgba(148, 163, 184, 0.3); | |
| z-index: 1000; | |
| top: 100%; | |
| left: 0; | |
| margin-top: 5px; | |
| white-space: pre-wrap; | |
| } | |
| """ | |
# Create Gradio interface
# Size of the pre-created prediction button pool. The Top-K slider uses the
# same value as its maximum so K can never exceed the number of buttons.
MAX_TOKEN_BUTTONS = 15
DEFAULT_PROMPT = "Twinkle, twinkle, little "


def _status_html(message: str) -> str:
    """Wrap a status/timing message (HTML-escaped) in the muted status style."""
    from html import escape
    return f"<div style='color: #94a3b8; font-size: 12px; opacity: 0.8;'>{escape(message)}</div>"


with gr.Blocks(css=custom_css, title="Next-Token Predictor") as app:
    gr.HTML("""
    <div style="text-align: center; padding: 32px 20px; background: linear-gradient(135deg, rgba(15, 23, 42, 0.9) 0%, rgba(30, 41, 59, 0.9) 50%, rgba(15, 23, 42, 0.9) 100%); border-bottom: 1px solid rgba(148, 163, 184, 0.2); backdrop-filter: blur(8px);">
        <h1 style="background: linear-gradient(135deg, #e2e8f0 0%, #94a3b8 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; margin: 0; font-size: 28px; font-weight: 600;">Next-Token Predictor</h1>
        <p style="color: #94a3b8; margin: 12px 0 0 0; font-size: 16px; opacity: 0.9;">Explore how AI predicts the next word! Click on predictions to append them.</p>
    </div>
    """)

    with gr.Column():
        text_input = gr.Textbox(
            label="Enter your prompt:",
            placeholder="Type anything... predictions update automatically!",
            value=DEFAULT_PROMPT,
            lines=3,
            info="💡 Try: 'The weather today is', 'I think that', 'Once upon a time'"
        )

    # Next Token Predictions directly below input
    with gr.Column():
        gr.HTML("<h4 style='background: linear-gradient(135deg, #e2e8f0 0%, #94a3b8 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; margin: 0; font-weight: 600;'>Next Token Predictions</h4>")
        # Pre-create a fixed pool of hidden buttons. The update callback
        # relabels them and toggles their visibility to match the current
        # predictions.
        token_buttons = [
            gr.Button(
                value="",
                visible=False,
                elem_classes=["token-button"],
                size="sm"
            )
            for _ in range(MAX_TOKEN_BUTTONS)
        ]

    # Parameter controls below predictions
    with gr.Row():
        with gr.Column():
            top_k = gr.Slider(
                minimum=5,
                maximum=MAX_TOKEN_BUTTONS,
                value=10,
                step=1,
                label="Top-K",
                info="Number of most likely words to consider",
                show_label=True,
                interactive=True
            )
        with gr.Column():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Temperature",
                info="Controls randomness in predictions",
                show_label=True,
                interactive=True
            )
        with gr.Column():
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-P",
                info="Probability threshold for word selection",
                show_label=True,
                interactive=True
            )

    timing_info = gr.HTML(value=_status_html("✨ Predictions update as you type!"))

    # Predictions currently shown on the buttons; click handlers read this
    # to know which token a button index refers to.
    current_tokens = gr.State([])

    def update_predictions_and_buttons(text, k, temp, p):
        """Recompute predictions and map them onto the fixed button pool.

        Returns a list in the same order as ``outputs``:
        - the status HTML;
        - the raw ``tokens_data`` (stored in ``current_tokens``);
        - one update per button, showing "token (xx.x%)" or hiding it.
        """
        tokens_data, timing = predict_next_token(text, int(k), float(temp), float(p))
        button_updates = []
        for i in range(MAX_TOKEN_BUTTONS):
            if i < len(tokens_data):
                token = tokens_data[i]['token']
                prob = tokens_data[i]['prob']
                button_label = f"{show_token(token)} ({prob*100:.1f}%)"
                button_updates.append(gr.Button(value=button_label, visible=True))
            else:
                button_updates.append(gr.Button(visible=False))
        return [_status_html(timing), tokens_data] + button_updates

    def append_token_to_input(current_text, tokens_data, button_index):
        """Append the token behind ``button_index`` to the prompt.

        Returns the prompt unchanged when the index is stale or out of range.
        """
        current_text = current_text or ""
        if tokens_data and 0 <= button_index < len(tokens_data):
            return current_text + tokens_data[button_index]['token']
        return current_text

    # Auto-predict on any input change
    prediction_inputs = [text_input, top_k, temperature, top_p]
    outputs = [timing_info, current_tokens] + token_buttons
    for component in prediction_inputs:
        component.change(
            update_predictions_and_buttons,
            inputs=prediction_inputs,
            outputs=outputs
        )

    # Set up click handlers for each token button. `idx=i` binds the loop
    # value now; a bare closure over `i` would see only its final value.
    for i, btn in enumerate(token_buttons):
        btn.click(
            lambda text, tokens, idx=i: append_token_to_input(text, tokens, idx),
            inputs=[text_input, current_tokens],
            outputs=[text_input]
        )

    # Load initial predictions on app start from the components' current
    # values, so the defaults are defined in exactly one place.
    app.load(
        update_predictions_and_buttons,
        inputs=prediction_inputs,
        outputs=outputs
    )
def main() -> None:
    """Serve the Gradio app locally without creating a public share link."""
    app.launch(share=False)


if __name__ == "__main__":
    main()