import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Load a small model
model_name = "distilgpt2"  # Small model suitable for a demo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
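model.eval()  # inference only; disables dropout

# Optional sketch (assumption: a GPU-enabled Space). The tensors built in
# generate_text() below stay on the CPU, so they would need the same
# .to(device) treatment before enabling this.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)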

class OutlineLogitsProcessor(LogitsProcessor):
    """
    A logits processor that nudges generation toward an outline: at each
    step it boosts the logit of the next outline token. Because the boost
    is additive in logit space (multiplying the token's unnormalized
    probability by exp(boost_factor)), this is a strong soft bias, not a
    hard constraint.

    Note: the custom sampling loop below calls this processor with a 1-D
    logits vector for the next token, not the (batch_size, vocab_size)
    scores that `model.generate` would pass.
    """

    def __init__(self, outline_tokens, tokenizer, boost_factor=10.0):
        self.outline_tokens = outline_tokens
        self.tokenizer = tokenizer
        self.boost_factor = boost_factor
        self.current_outline_idx = 0  # stateful: use a fresh instance per generation

    def __call__(self, input_ids, scores):
        if self.current_outline_idx < len(self.outline_tokens):
            # Get the next token from the outline
            target_token_id = self.outline_tokens[self.current_outline_idx]
            # Boost the raw logit of the target token
            scores[target_token_id] += self.boost_factor
            self.current_outline_idx += 1
        return scores
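
# A minimal sketch (illustrative; not called by the app) of how an outline
# processor could plug into Hugging Face's built-in `generate` API instead of
# the custom sampling loop below. `generate` passes `scores` of shape
# (batch_size, vocab_size), so the boost must index the vocab dimension.
def generate_with_builtin_api(prompt, outline_tokens, max_new_tokens=30):
    from transformers import LogitsProcessorList

    class BatchedOutlineLogitsProcessor(OutlineLogitsProcessor):
        def __call__(self, input_ids, scores):
            if self.current_outline_idx < len(self.outline_tokens):
                target_token_id = self.outline_tokens[self.current_outline_idx]
                scores[:, target_token_id] += self.boost_factor  # batched scores
                self.current_outline_idx += 1
            return scores

    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output_ids = model.generate(
        input_ids,
        logits_processor=LogitsProcessorList(
            [BatchedOutlineLogitsProcessor(outline_tokens, tokenizer)]
        ),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,  # silence GPT-2's missing-pad warning
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)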

def generate_text(prompt, use_outline=False, outline_text=""):
    """Generate text with or without an outline constraint."""
    # Tokenize the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    logits_processor = None
    if use_outline and outline_text.strip():
        # Tokenize the outline. GPT-2's tokenizer does not prepend a BOS
        # token, so the first token must not be dropped; just suppress any
        # special tokens explicitly.
        outline_tokens = tokenizer.encode(outline_text, add_special_tokens=False)
        logits_processor = [OutlineLogitsProcessor(outline_tokens, tokenizer)]

    # Store token probabilities for visualization
    all_probs = []

    def capture_probs(next_token_logits):
        """Record the model's next-token distribution for visualization."""
        probs = torch.softmax(next_token_logits, dim=-1)
        all_probs.append(probs.detach().cpu().numpy())

    # Generation parameters for the custom sampling loop below
    max_new_tokens = 30
    temperature = 0.7
    # Custom generation loop with probability capture
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(input_ids)
            next_token_logits = outputs.logits[0, -1, :]
            # Capture the raw distribution before any constraint is applied
            capture_probs(next_token_logits)
            if logits_processor:
                for processor in logits_processor:
                    next_token_logits = processor(input_ids, next_token_logits)
            next_token_probs = torch.softmax(next_token_logits / temperature, dim=-1)
            next_token = torch.multinomial(next_token_probs, 1)
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
            # Stop if EOS token is generated
            if next_token.item() == tokenizer.eos_token_id:
                break
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    # Get top tokens and their probabilities for visualization
    top_tokens = []
    for probs in all_probs:
        top_indices = np.argsort(probs)[-5:][::-1]  # Top 5 tokens
        top_tokens.append([(tokenizer.decode([idx]), float(probs[idx])) for idx in top_indices])

    return generated_text, top_tokens
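
# Quick smoke test (illustrative; uncomment to try outside the UI):
# text, top = generate_text(
#     "The most interesting thing about AI is",
#     use_outline=True,
#     outline_text="safety, creativity, and knowledge",
# )
# print(text)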

def create_probability_plot(top_tokens):
    """Create a heatmap of top-k token probabilities at each position."""
    if not top_tokens:
        return None

    fig, ax = plt.subplots(figsize=(10, 6))

    # Number of generated positions and the top-k captured per position
    n_tokens = len(top_tokens)
    top_k = len(top_tokens[0])

    # Create a custom colormap that goes from light blue to dark blue
    colors = [(0.8, 0.9, 1.0), (0.0, 0.4, 0.8)]
    cmap = LinearSegmentedColormap.from_list("blue_gradient", colors)

    # Build the heatmap data: each column is a generation step, each row a
    # probability rank. The actual top-5 tokens differ from step to step,
    # so rows are labeled by rank rather than by any single step's tokens.
    data = np.zeros((top_k, n_tokens))
    for i, token_probs in enumerate(top_tokens):
        for j, (_, prob) in enumerate(token_probs):
            data[j, i] = prob

    # Plot the heatmap
    im = ax.imshow(data, aspect='auto', cmap=cmap)

    # Add colorbar
    fig.colorbar(im, ax=ax, label='Probability')

    # Customize the plot
    ax.set_yticks(range(top_k))
    ax.set_yticklabels([f"Rank {j + 1}" for j in range(top_k)])
    ax.set_xlabel('Token Position in Generated Sequence')
    ax.set_ylabel('Top Tokens by Rank')
    ax.set_title('Token Probabilities During Generation')

    plt.tight_layout()
    return fig

def interface_fn(prompt, use_outline, outline_text):
    """Main function for the Gradio interface."""
    generated_text, top_tokens = generate_text(prompt, use_outline, outline_text)

    # Create visualization of token probabilities
    prob_plot = create_probability_plot(top_tokens)

    # Format token probabilities as text for display
    prob_text = ""
    for i, tokens in enumerate(top_tokens):
        prob_text += f"Position {i + 1}:\n"
        for token, prob in tokens:
            prob_text += f"  '{token}': {prob:.4f}\n"
        prob_text += "\n"

    return generated_text, prob_plot, prob_text

# Create the Gradio interface
with gr.Blocks(title="Structured Generation Demo") as demo:
    gr.Markdown("# Structured Generation Demo")
    gr.Markdown("This demo shows how outlines can constrain language model generation to include specific tokens.")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter a prompt to start generation...",
                value="The most interesting thing about AI is",
            )
            use_outline = gr.Checkbox(label="Use Outline Constraint", value=False)
            outline_text = gr.Textbox(
                label="Outline Text (tokens to enforce in order)",
                placeholder="Enter tokens to enforce in the generation...",
                value="safety, creativity, and knowledge",
            )
            generate_btn = gr.Button("Generate Text")
        with gr.Column():
            output_text = gr.Textbox(label="Generated Text")
            prob_plot = gr.Plot(label="Token Probabilities")
            prob_text = gr.Textbox(label="Detailed Token Probabilities", lines=10)

    generate_btn.click(
        interface_fn,
        inputs=[prompt, use_outline, outline_text],
        outputs=[output_text, prob_plot, prob_text],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()