import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# Load a small model
model_name = "distilgpt2"  # Small model suitable for a demo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
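
# Note: this demo "softly" constrains generation by adding a constant boost to
# the logit of the next outline token at each step. Dedicated structured-
# generation libraries (e.g. `outlines`) instead mask invalid tokens outright;
# the boost approach here keeps sampling stochastic and easy to visualize.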
class OutlineLogitsProcessor(LogitsProcessor):
    """
    A logits processor that steers generation toward an outline.

    At each step it adds `boost_factor` to the logit of the next outline
    token, so the outline tokens are strongly favored (but not strictly
    forced) in order.
    """
    def __init__(self, outline_tokens, tokenizer, boost_factor=10.0):
        self.outline_tokens = outline_tokens
        self.tokenizer = tokenizer
        self.boost_factor = boost_factor
        self.current_outline_idx = 0

    def __call__(self, input_ids, scores):
        # `scores` is the 1-D logit vector for the last position; the custom
        # generation loop below calls this directly rather than via
        # model.generate().
        if self.current_outline_idx < len(self.outline_tokens):
            # Boost the probability of the next token from the outline
            target_token_id = self.outline_tokens[self.current_outline_idx]
            scores[target_token_id] += self.boost_factor
            self.current_outline_idx += 1
        return scores
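
# Rough intuition for boost_factor=10.0 (illustrative math, assuming an
# otherwise flat logit vector over distilgpt2's 50,257-token vocabulary):
# softmax assigns the boosted token exp(10) / (exp(10) + 50256) ≈ 0.30, so the
# outline token becomes dominant but not guaranteed; raising the boost makes
# the constraint harder.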
def generate_text(prompt, use_outline=False, outline_text=""):
    """Generate text with or without an outline constraint."""
    # Tokenize the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    logits_processor = None
    if use_outline and outline_text.strip():
        # Tokenize the outline. The GPT-2 tokenizer does not prepend a BOS
        # token, so every token id is kept.
        outline_tokens = tokenizer.encode(outline_text)
        logits_processor = [OutlineLogitsProcessor(outline_tokens, tokenizer)]
    # Store token probabilities for visualization
    all_probs = []

    def capture_probs(logits):
        # Record the unconstrained next-token distribution (taken before any
        # outline boost is applied)
        probs = torch.softmax(logits[0, -1, :], dim=-1)
        all_probs.append(probs.detach().cpu().numpy())
        return logits
    # Sampling temperature, applied manually in the loop below
    temperature = 0.7
    # Custom generation loop with probability capture
    with torch.no_grad():
        for _ in range(30):  # Generate up to 30 tokens
            outputs = model(input_ids)
            capture_probs(outputs.logits)
            # Work with the 1-D logit vector for the last position only
            next_token_logits = outputs.logits[0, -1, :] / temperature
            if logits_processor:
                for processor in logits_processor:
                    next_token_logits = processor(input_ids, next_token_logits)
            next_token_probs = torch.softmax(next_token_logits, dim=-1)
            # Sample one token and reshape to (1, 1) for concatenation
            next_token = torch.multinomial(next_token_probs, 1).unsqueeze(0)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            # Stop if EOS token is generated
            if next_token.item() == tokenizer.eos_token_id:
                break
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    # Get top tokens and their probabilities for visualization
    top_tokens = []
    for probs in all_probs:
        top_indices = np.argsort(probs)[-5:][::-1]  # Top 5 tokens, most likely first
        top_tokens.append([(tokenizer.decode([idx]), float(probs[idx])) for idx in top_indices])

    return generated_text, top_tokens
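
# Example call (illustrative; note the leading space in the outline — GPT-2's
# BPE maps " safety" and "safety" to different token ids, and mid-sentence
# words are normally preceded by a space):
#   text, top = generate_text(
#       "The most interesting thing about AI is",
#       use_outline=True,
#       outline_text=" safety, creativity, and knowledge",
#   )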
def create_probability_plot(top_tokens):
    """Create a heatmap of top-token probabilities at each generation step."""
    if not top_tokens:
        return None

    fig, ax = plt.subplots(figsize=(10, 6))

    # Number of generation steps and top-k per step
    n_tokens = len(top_tokens)
    top_k = len(top_tokens[0])

    # Create a custom colormap that goes from light blue to dark blue
    colors = [(0.8, 0.9, 1.0), (0.0, 0.4, 0.8)]
    cmap = LinearSegmentedColormap.from_list("blue_gradient", colors)

    # Fill the heatmap: rows are probability ranks, columns are positions.
    # The actual top-5 tokens differ at every position, so rows are labeled
    # by rank rather than by any single position's tokens.
    data = np.zeros((top_k, n_tokens))
    for i, token_probs in enumerate(top_tokens):
        for j, (_, prob) in enumerate(token_probs):
            data[j, i] = prob

    # Plot the heatmap
    im = ax.imshow(data, aspect='auto', cmap=cmap)

    # Add colorbar
    fig.colorbar(im, ax=ax, label='Probability')

    # Customize the plot
    ax.set_yticks(range(top_k))
    ax.set_yticklabels([f"Rank {j + 1}" for j in range(top_k)])
    ax.set_xlabel('Token Position in Generated Sequence')
    ax.set_ylabel('Top Tokens by Rank')
    ax.set_title('Token Probabilities During Generation')

    # Adjust layout
    plt.tight_layout()
    return fig
def interface_fn(prompt, use_outline, outline_text):
    """Main function for the Gradio interface."""
    generated_text, top_tokens = generate_text(prompt, use_outline, outline_text)

    # Create visualization of token probabilities
    prob_plot = create_probability_plot(top_tokens)

    # Format token probabilities as text for display
    prob_text = ""
    for i, tokens in enumerate(top_tokens):
        prob_text += f"Position {i + 1}:\n"
        for token, prob in tokens:
            prob_text += f"  '{token}': {prob:.4f}\n"
        prob_text += "\n"

    return generated_text, prob_plot, prob_text
# Create the Gradio interface
with gr.Blocks(title="Structured Generation Demo") as demo:
    gr.Markdown("# Structured Generation Demo")
    gr.Markdown("This demo shows how outlines can constrain language model generation to include specific tokens.")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter a prompt to start generation...",
                value="The most interesting thing about AI is"
            )
            use_outline = gr.Checkbox(label="Use Outline Constraint", value=False)
            outline_text = gr.Textbox(
                label="Outline Text (tokens to enforce in order)",
                placeholder="Enter tokens to enforce in the generation...",
                value="safety, creativity, and knowledge"
            )
            generate_btn = gr.Button("Generate Text")
        with gr.Column():
            output_text = gr.Textbox(label="Generated Text")
            prob_plot = gr.Plot(label="Token Probabilities")
            prob_text = gr.Textbox(label="Detailed Token Probabilities", lines=10)

    generate_btn.click(
        interface_fn,
        inputs=[prompt, use_outline, outline_text],
        outputs=[output_text, prob_plot, prob_text]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()