| import gradio as gr |
| import torch |
| import html as html_lib |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
# Load the paraphrase LM once at import time; inference-only, so eval mode.
MODEL_ID = "BigSalmon/InformalToFormalLincoln123Paraphrase"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
|
|
def get_color(p):
    """Map a probability *p* to a (text-color, background-color) CSS pair.

    The hue scales linearly from 0 (red, p == 0) up to 120 (green),
    saturating at 120 for any p >= 1. The background is a translucent
    variant of the same hue.
    """
    hue = min(p * 120, 120)
    text_color = f"hsl({hue},80%,35%)"
    background = f"hsla({hue},80%,50%,0.15)"
    return text_color, background
|
|
def analyze_text(text, top_k):
    """Render interactive HTML showing the model's probability for each token.

    Every token after the first is colored by the probability the model
    assigned to it given the preceding tokens; hovering (or clicking to pin)
    a token shows a tooltip with the top-k alternative continuations at that
    position.

    Args:
        text: Input string to analyze. Whitespace-only input yields a
            placeholder message.
        top_k: Number of alternative tokens to list per position; coerced to
            an int >= 1 and clamped to the vocabulary size.

    Returns:
        A self-contained HTML string (inline <style>/<script> plus markup).
    """
    top_k = max(1, int(top_k))
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Paste some text and click Analyze.</p>"

    tokens = tokenizer.encode(text)
    # Truncate to 512 tokens to stay within the model's context window.
    if len(tokens) > 512:
        tokens = tokens[:512]

    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        all_logits = model(input_ids).logits[0].cpu()

    # Hoisted out of the loop: softmax over every position at once instead of
    # recomputing it per token.
    all_probs = torch.softmax(all_logits, dim=-1)
    # topk raises if k exceeds the vocabulary size, so clamp defensively.
    top_k = min(top_k, all_probs.size(-1))

    css = """<style>
.tc{display:flex;flex-wrap:wrap;gap:5px;padding:20px;line-height:2.4;font-family:'Segoe UI',sans-serif}
.tw{position:relative;display:inline-block}
.tk{padding:4px 7px;border-radius:6px;cursor:pointer;font-size:15px;transition:.2s;border:1px solid transparent;user-select:none}
.tw:hover .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999}
.tt{display:none;position:absolute;bottom:calc(100% + 8px);left:50%;transform:translateX(-50%);
background:#1a1a2e;color:#eee;padding:14px;border-radius:12px;font-size:13px;z-index:9999;
box-shadow:0 10px 30px rgba(0,0,0,.35);min-width:220px;max-height:350px;overflow-y:auto}
.tt::after{content:'';position:absolute;top:100%;left:0;width:100%;height:12px}
.tw:hover .tt{display:block}
.tw.pinned .tt{display:block}
.tw.pinned .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999;outline:2px solid #7fdbca}
.th{font-weight:700;font-size:14px;color:#7fdbca;border-bottom:1px solid #333;padding-bottom:6px;margin-bottom:6px}
.tp{color:#ffd700;margin-bottom:8px}
.at{color:#ff79c6;font-size:10px;text-transform:uppercase;letter-spacing:1px;margin-bottom:4px}
.aw{display:flex;justify-content:space-between;padding:2px 0;font-size:12px}
.aw .w{color:#c3cee3}.aw .p{color:#666;margin-left:14px}
.hi{font-weight:700;color:#7fdbca!important}
</style>
<script>
document.addEventListener('click', function(e) {
  const tk = e.target.closest('.tk');
  const tw = tk ? tk.closest('.tw') : null;
  if (tw) {
    const wasPinned = tw.classList.contains('pinned');
    document.querySelectorAll('.tw.pinned').forEach(el => el.classList.remove('pinned'));
    if (!wasPinned) tw.classList.add('pinned');
  } else if (!e.target.closest('.tt')) {
    document.querySelectorAll('.tw.pinned').forEach(el => el.classList.remove('pinned'));
  }
});
</script>"""

    parts = [css, '<div class="tc">']
    for i in range(len(tokens)):
        # Escape decoded token text: it is user-derived and goes into HTML.
        tok = html_lib.escape(tokenizer.decode([tokens[i]]))
        if i == 0:
            # The first token has no preceding context, so no probability
            # can be assigned; render it neutrally with no tooltip.
            parts.append(f'<div class="tw"><span class="tk" style="background:rgba(128,128,128,.1);color:#888">{tok}</span></div>')
            continue

        # Distribution over token i comes from the logits at position i-1.
        probs = all_probs[i - 1]
        actual_p = probs[tokens[i]].item()
        top_p, top_idx = probs.topk(top_k)
        color, bg = get_color(actual_p)

        rank = None  # 1-based rank of the actual token among the top-k, if present
        alts = ""
        for j in range(top_k):
            a_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
            a_p = top_p[j].item()
            hit = top_idx[j].item() == tokens[i]
            if hit:
                rank = j + 1
            cls = ' class="w hi"' if hit else ' class="w"'
            pcls = ' class="p hi"' if hit else ' class="p"'
            alts += f'<div class="aw"><span{cls}>{a_text}</span><span{pcls}>{a_p:.4f}</span></div>'

        rank_s = f"rank #{rank}" if rank else f"rank >{top_k}"
        tooltip = f'''<div class="tt">
<div class="th">“{tok}”</div>
<div class="tp">P = {actual_p:.4f} ({rank_s})</div>
<div class="at">Top {top_k} alternatives</div>{alts}</div>'''

        parts.append(f'<div class="tw"><span class="tk" style="background:{bg};color:{color}">{tok}</span>{tooltip}</div>')

    parts.append('</div>')
    return ''.join(parts)
|
|
|
|
def predict_next(text, num_candidates):
    """Render an HTML table of the model's top predicted next tokens.

    Args:
        text: Prompt string; the model scores the token that would follow
            it. Whitespace-only input yields a placeholder message.
        num_candidates: Number of candidates to display; coerced to an
            int >= 1 and clamped to the vocabulary size.

    Returns:
        An HTML string containing a ranked table with probability bars and
        log-probabilities.
    """
    num_candidates = max(1, int(num_candidates))
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Enter text and click Predict Next.</p>"

    tokens = tokenizer.encode(text)
    # Truncate to 512 tokens to stay within the model's context window.
    if len(tokens) > 512:
        tokens = tokens[:512]

    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        # Only the logits at the last position matter for next-token prediction.
        logits = model(input_ids).logits[0, -1].cpu()

    probs = torch.softmax(logits, dim=-1)
    # log_softmax is numerically stabler than log(softmax(...)), which can
    # underflow to 0 and produce -inf for low-probability tokens.
    log_probs = torch.log_softmax(logits, dim=-1)
    # topk raises if k exceeds the vocabulary size, so clamp defensively.
    num_candidates = min(num_candidates, probs.size(-1))
    top_p, top_idx = probs.topk(num_candidates)
    top_lp = log_probs[top_idx]

    rows = ""
    for j in range(num_candidates):
        # Escape decoded token text: it is model output going into HTML.
        tok_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
        p = top_p[j].item()
        lp = top_lp[j].item()
        bar_width = max(1, int(p * 100))  # percent width; keep tiny bars visible
        hue = min(p * 120, 120)  # red (unlikely) -> green (likely)
        rows += f"""<tr>
<td style="padding:6px 12px;font-weight:600;color:#e0e0e0;white-space:nowrap">{j+1}</td>
<td style="padding:6px 12px;font-family:monospace;font-size:15px;color:#7fdbca;white-space:nowrap">{tok_text}</td>
<td style="padding:6px 12px;width:100%">
<div style="background:hsla({hue},80%,50%,0.25);border-radius:4px;height:22px;width:{bar_width}%;min-width:2px;display:flex;align-items:center;padding-left:6px">
<span style="font-size:11px;color:hsl({hue},80%,70%);font-weight:600">{p:.4f}</span>
</div>
</td>
<td style="padding:6px 12px;font-family:monospace;font-size:13px;color:#888;white-space:nowrap">{lp:.4f}</td>
</tr>"""

    # Renamed from `html` to avoid confusion with the stdlib module name.
    table_html = f"""<div style="font-family:'Segoe UI',sans-serif;background:#1a1a2e;border-radius:12px;padding:16px;overflow-x:auto">
<div style="color:#ff79c6;font-size:11px;text-transform:uppercase;letter-spacing:1px;margin-bottom:10px">
Top {num_candidates} predicted next tokens</div>
<table style="width:100%;border-collapse:collapse">
<thead><tr style="border-bottom:1px solid #333">
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">#</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">TOKEN</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">PROBABILITY</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">LOG PROB</th>
</tr></thead>
<tbody>{rows}</tbody>
</table></div>"""
    return table_html
|
|
|
|
# Build the UI. NOTE: `theme` and `css` are gr.Blocks() constructor arguments;
# Blocks.launch() does not accept them, so passing them to launch() (as the
# original code did) raises TypeError at startup.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="footer{display:none!important}.main{max-width:960px;margin:auto}",
) as demo:
    gr.Markdown("# 🔍 Token Probability Explorer & Predictor\nPaste text, **hover** to preview or **click** a token to pin its tooltip open. Click elsewhere to dismiss.")

    text_input = gr.Textbox(label="Input Text", placeholder="Paste your text here…", lines=5)

    with gr.Row():
        top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1)
        num_candidates_input = gr.Number(label="# Next Token Candidates", value=10, minimum=1, maximum=200, step=1)

    with gr.Row():
        btn_analyze = gr.Button("Analyze", variant="primary")
        btn_predict = gr.Button("Predict Next", variant="secondary")

    output_analysis = gr.HTML(label="Analysis Output")
    output_prediction = gr.HTML(label="Predicted Next Tokens")

    btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis)
    btn_predict.click(fn=predict_next, inputs=[text_input, num_candidates_input], outputs=output_prediction)


if __name__ == "__main__":
    # Guarded so importing this module (e.g. for tests) does not start a server.
    demo.launch(server_name="0.0.0.0", server_port=7860)