import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import gradio as gr
set_seed(67)
device = "cpu"
# Initialize models and tokenizer
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
draft_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", torch_dtype=torch.bfloat16).to(device)
verify_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct", torch_dtype=torch.bfloat16).to(device)
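

# Draft step of speculative decoding: the small model proposes up to `gamma`
# tokens greedily, stopping early if its confidence (top-1 probability) falls
# below the threshold or it emits EOS. The per-step distributions are returned
# so the verifier can run the accept/reject test.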
def draft(input_ids, gamma, confidence_threshold, eos_token, past_kv):
    generated = input_ids.clone()
    draft_probs = []
    for _ in range(gamma):
        with torch.no_grad():
            outputs = draft_model(
                # With a warm cache, only the newest token needs a forward pass.
                generated if past_kv is None else generated[:, -1:],
                past_key_values=past_kv,
                use_cache=True
            )
        logits = outputs.logits[:, -1, :]
        past_kv = outputs.past_key_values
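        # Stop speculating once the draft model is unsure: a low top-1
        # probability makes rejection by the verifier likely.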
        probs = torch.softmax(logits, dim=-1)
        confidence = probs.max().item()
        if confidence < confidence_threshold and len(draft_probs) > 0:
            break
        next_token = torch.argmax(probs, dim=-1, keepdim=True)
        draft_probs.append(probs)
        generated = torch.cat([generated, next_token], dim=-1)
        if next_token.item() == eos_token:
            break
    return generated, draft_probs, past_kv
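

# Verify step: the large model scores all drafted tokens in one forward pass.
# A drafted token x with draft probability q(x) and target probability p(x)
# is kept outright if q(x) <= p(x), otherwise with probability p(x)/q(x); at
# the first rejection a replacement is sampled from the residual distribution
# max(p - q, 0) (renormalized) and verification stops.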
def verify(drafted, drafted_probs, eos_token, past_kv):
    draft_len = len(drafted_probs)
    with torch.no_grad():
        if past_kv is None:
            target_outputs = verify_model(drafted, use_cache=True)
            # Logits at the positions that predict each drafted token.
            target_logits = target_outputs.logits[:, -draft_len - 1:-1, :]
        else:
            # Only the last accepted token plus the new draft tokens are
            # uncached; feed just those.
            target_outputs = verify_model(
                drafted[:, -(draft_len + 1):],
                past_key_values=past_kv,
                use_cache=True
            )
            target_logits = target_outputs.logits[:, :-1, :]
    past_kv = target_outputs.past_key_values
    target_probs = torch.softmax(target_logits, dim=-1)
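
    # Walk the drafted tokens left to right; everything after the first
    # rejection is discarded.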
    accepted_tokens = []
    num_accepted = 0
    for i in range(draft_len):
        q = drafted_probs[i]
        p = target_probs[:, i, :]
        token = drafted[:, i - draft_len]
        x = token[0].item()
        q_x = q[0, x].item()
        p_x = p[0, x].item()
        if q_x <= p_x:
            accepted_tokens.append(x)
            num_accepted += 1
        else:
            r = torch.rand(1).item()
            acceptance_rate = p_x / q_x
            if r < acceptance_rate:
                accepted_tokens.append(x)
                num_accepted += 1
            else:
                # Rejected: resample from the residual distribution max(p - q, 0).
                adjusted = torch.clamp(p - q, min=0)
                adjusted = adjusted / adjusted.sum()
                new_token = torch.multinomial(adjusted, num_samples=1)[0].item()
                accepted_tokens.append(new_token)
                break
        if accepted_tokens[-1] == eos_token:
            break
    return accepted_tokens, num_accepted, past_kv
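

# Streaming generation loop for Gradio: yields an updated HTML snapshot after
# every draft and every verify step, so tokens appear optimistically and are
# then re-colored as accepted, rejected, or resampled.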
def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
    # Prepare input
    messages = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    eos_token = tokenizer.eos_token_id
    im_end_token = tokenizer.convert_tokens_to_ids("<|im_end|>")
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    result = inputs["input_ids"].clone()
    draft_kv = None
    verify_kv = None
    total_drafted = 0
    total_accepted = 0
    steps = []
    # Track the clean output tokens (only accepted/resampled)
    clean_output_tokens = []
    all_tokens = []
    # Metadata for ALL tokens: 'accepted', 'rejected', or 'resampled'
    all_token_metadata = []
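
    # Render the current state: the clean output, the color-coded token
    # stream, the running acceptance rate, and a per-step accept/reject log.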
    def build_html():
        html = "<div style='font-family: monospace;'>"
        # Clean final output box
        html += "<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
        html += "<b>Final Output (Clean):</b><br/>"
        if clean_output_tokens:
            clean_text = tokenizer.decode(clean_output_tokens)
            # Escape angle brackets so special tokens like <|im_end|> show as text.
            html += clean_text.replace("<", "&lt;").replace(">", "&gt;")
        html += "</div>"
        # Detailed output box
        html += "<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
        html += "<b>Detailed Output (All Tokens):</b><br/>"
        if all_tokens:
            for i, token_id in enumerate(all_tokens):
                token_text = tokenizer.decode([token_id])
                token_display = token_text.replace("<", "&lt;").replace(">", "&gt;")
                if i < len(all_token_metadata):
                    if all_token_metadata[i] == 'accepted':
                        html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token_display}</span>"
                    elif all_token_metadata[i] == 'resampled':
                        html += f"<span style='background: #5AADCC; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token_display}</span>"
                    elif all_token_metadata[i] == 'rejected':
                        html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 1px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"
                else:
                    html += token_display
        html += "</div>"
        # Acceptance rate
        if total_drafted > 0:
            html += "<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
            html += f"<b>Acceptance Rate:</b> {total_accepted}/{total_drafted} = {total_accepted/total_drafted*100:.1f}%"
            html += "</div>"
        # Decoding steps
        html += "<div style='margin-bottom: 10px;'><b>Decoding Steps:</b></div>"
        for i, step in enumerate(steps):
            html += "<div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'>"
            html += f"<b>Step {i+1}:</b> "
            for j, token in enumerate(step["drafted"]):
                token_display = token.replace("<", "&lt;").replace(">", "&gt;")
                if j < step["accepted"]:
                    html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 2px; border-radius: 3px;'>{token_display}</span>"
                else:
                    html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"
            if step["resampled"]:
                resampled_display = step["resampled"].replace("<", "&lt;").replace(">", "&gt;")
                html += f" → <span style='background: #5AADCC; padding: 2px 4px; border-radius: 3px;'>{resampled_display}</span>"
            html += "</div>"
        html += "</div>"
        return html
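
    # Main loop: draft -> show optimistically -> verify -> reconcile -> commit.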
    while result.shape[-1] - inputs["input_ids"].shape[-1] < max_tokens:
        # Draft: propose tokens with the small model and display them
        # optimistically as accepted until the verifier weighs in.
        drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)
        drafted_token_ids = drafted[0, -len(drafted_probs):].tolist()
        drafted_tokens = [tokenizer.decode([t]) for t in drafted_token_ids]
        clean_output_tokens.extend(drafted_token_ids)
        all_tokens.extend(drafted_token_ids)
        all_token_metadata.extend(['accepted'] * len(drafted_token_ids))
        temp_step = {
            "drafted": drafted_tokens,
            "accepted": len(drafted_tokens),
            "resampled": None
        }
        steps.append(temp_step)
        total_drafted += len(drafted_probs)
        yield build_html()
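
        # Verify: re-score the drafted tokens with the large model, then roll
        # back the optimistic bookkeeping and re-tag each token by its fate.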
        accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)
        total_accepted += num_accepted
        clean_output_tokens = clean_output_tokens[:-len(drafted_token_ids)]
        all_token_metadata = all_token_metadata[:-len(drafted_token_ids)]
        for i in range(len(drafted_token_ids)):
            if i < num_accepted:
                all_token_metadata.append('accepted')
            else:
                all_token_metadata.append('rejected')
        clean_output_tokens.extend(accepted_tokens)
        if num_accepted < len(accepted_tokens):
            all_tokens.append(accepted_tokens[-1])
            all_token_metadata.append('resampled')
        steps[-1] = {
            "drafted": drafted_tokens,
            "accepted": num_accepted,
            "resampled": tokenizer.decode([accepted_tokens[-1]]) if num_accepted < len(accepted_tokens) else None
        }
        yield build_html()
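
        # Commit the accepted (and possibly resampled) tokens, then trim both
        # KV caches so they end one position before the new sequence end: the
        # next draft/verify call re-feeds the final token.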
        result = torch.cat([result, torch.tensor([accepted_tokens], device=device)], dim=-1)
        new_len = result.shape[-1]
        if draft_kv is not None:
            draft_kv.crop(max_length=new_len - 1)
        if verify_kv is not None:
            # Cropping the verify cache to the pre-commit length + num_accepted
            # would leave a stale entry whenever every drafted token is accepted
            # (no resample), so crop to new_len - 1 instead.
            verify_kv.crop(max_length=new_len - 1)
        if eos_token in accepted_tokens or im_end_token in accepted_tokens:
            break
    yield build_html()
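

# Gradio wiring: `generate_visual` is a generator, so the HTML pane streams an
# update after every yield.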
demo = gr.Interface(
    fn=generate_visual,
    inputs=[
        gr.Textbox(label="Prompt", value="What is the capital of France?", lines=3),
        gr.Slider(minimum=10, maximum=100, value=50, step=10, label="Max Tokens"),
        gr.Slider(minimum=1, maximum=30, value=15, step=1, label="Gamma (draft lookahead)"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
    ],
    outputs=gr.HTML(label="Speculative Decoding Visualization"),
    title="🚀 Speculative Decoding Demo",
    description="""
**Speculative Decoding Visualization** using SmolLM2 models

- **Draft Model**: HuggingFaceTB/SmolLM2-135M-Instruct (fast)
- **Verify Model**: HuggingFaceTB/SmolLM2-1.7B-Instruct (accurate)

**Color Legend:**
- 🟢 Green = Accepted tokens from draft model
- 🔴 Red = Rejected tokens (with strikethrough)
- 🔵 Blue = Resampled tokens from verify model

**Watch the tokens stream in real-time!** Draft tokens appear immediately, then get accepted or rejected by the verify model.
""",
    examples=[
        ["What is the capital of France?", 80, 15, 0.5],
        ["Complete the python function \n def fibonacci(n):", 50, 15, 0.5],
        ["Explain the concept of attention in transformers", 60, 10, 0.6]
    ]
)
if __name__ == "__main__":
    demo.launch()