Spaces:
Sleeping
Sleeping
File size: 6,923 Bytes
c2b40ae aaeefcc c2b40ae aaeefcc c2b40ae aaeefcc c2b40ae aaeefcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# Load the tokenizer and the causal language model once, at import time.
# generate_text() reads these module-level objects on every call.
model_name = "distilgpt2" # Checkpoint identifier; a small model suitable for a demo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
class OutlineLogitsProcessor(LogitsProcessor):
    """Logits processor that nudges generation towards a fixed token sequence.

    Each call adds ``boost_factor`` to the logit of the next outline token and
    then advances to the following outline token, whether or not the boosted
    token is the one that ends up being sampled. This is a soft bias, not a
    hard constraint. Once every outline token has been used, scores are
    returned unchanged.

    The instance is stateful (it tracks its position in the outline), so a
    fresh instance is needed for every generation run.
    """

    def __init__(self, outline_tokens, tokenizer, boost_factor=10.0):
        """Store the outline and start at its first token.

        Args:
            outline_tokens: Sequence of token ids to favour, in order.
            tokenizer: Tokenizer that produced ``outline_tokens``. Stored for
                reference only; it is not used when processing scores.
            boost_factor: Value added to the target token's logit.
        """
        self.outline_tokens = outline_tokens
        self.tokenizer = tokenizer
        self.boost_factor = boost_factor
        self.current_outline_idx = 0

    def __call__(self, input_ids, scores):
        """Boost the next outline token in ``scores`` and return ``scores``.

        Args:
            input_ids: Token ids generated so far (unused).
            scores: Next-token logits whose last axis is the vocabulary,
                either a 1-D ``(vocab,)`` tensor or a batched
                ``(batch, vocab)`` tensor. Modified in place.

        Returns:
            The same ``scores`` tensor.
        """
        if self.current_outline_idx < len(self.outline_tokens):
            target_token_id = self.outline_tokens[self.current_outline_idx]
            # Index the vocabulary (last) axis explicitly: a bare
            # scores[token_id] would select a batch row when scores is 2-D.
            scores[..., target_token_id] += self.boost_factor
            self.current_outline_idx += 1
        return scores
def generate_text(prompt, use_outline=False, outline_text="", *,
                  max_new_tokens=30, temperature=0.7):
    """Sample a continuation of ``prompt``, optionally biased by an outline.

    Tokens are sampled one at a time from the module-level ``model``. When an
    outline is active, an ``OutlineLogitsProcessor`` boosts one outline token
    per step before sampling.

    Args:
        prompt: Text to continue.
        use_outline: Whether to apply the outline bias.
        outline_text: Text whose tokens are boosted in order. Ignored when
            blank or when ``use_outline`` is false.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Sampling temperature; must be positive. Applied after
            the outline boost.

    Returns:
        A ``(generated_text, top_tokens)`` tuple. ``generated_text`` is the
        decoded prompt plus continuation. ``top_tokens`` holds, for every
        generated position, the five most likely ``(token, probability)``
        pairs under the raw model distribution, i.e. before the outline boost
        and temperature are applied. An empty prompt yields ``("", [])``.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    if input_ids.numel() == 0:
        # Nothing to condition on; the model cannot run on a zero-length input.
        return "", []

    processors = []
    if use_outline and outline_text.strip():
        # Keep every outline token. GPT-2 style tokenizers prepend no BOS
        # token, so slicing off the first id would discard a real token.
        outline_tokens = tokenizer.encode(outline_text, add_special_tokens=False)
        processors.append(OutlineLogitsProcessor(outline_tokens, tokenizer))

    # Raw next-token distributions, one numpy vector per generated position.
    all_probs = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Logits for the next position only: a 1-D tensor over the vocabulary.
            next_logits = model(input_ids).logits[0, -1, :]

            # Capture before the processors run; they modify the logits in place.
            all_probs.append(torch.softmax(next_logits, dim=-1).cpu().numpy())

            for processor in processors:
                next_logits = processor(input_ids, next_logits)

            next_token_probs = torch.softmax(next_logits / temperature, dim=-1)
            next_token = torch.multinomial(next_token_probs, 1)  # shape (1,)

            # input_ids is (1, seq_len); add the batch axis before appending.
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)

            # Stop if EOS token is generated
            if next_token.item() == tokenizer.eos_token_id:
                break

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    # Top 5 tokens per position, most likely first.
    top_tokens = []
    for probs in all_probs:
        top_indices = np.argsort(probs)[-5:][::-1]
        top_tokens.append(
            [(tokenizer.decode([int(idx)]), float(probs[idx])) for idx in top_indices]
        )

    return generated_text, top_tokens
def _printable_token(token):
    """Return ``token`` in a form that is visible and safe as a plot label."""
    # Make whitespace visible and escape "$" so matplotlib does not try to
    # interpret the token as a mathtext expression.
    return token.replace("\n", "\\n").replace("\t", "\\t").replace("$", r"\$")


def create_probability_plot(top_tokens):
    """Draw a heatmap of the top-token probabilities at each generated position.

    Columns are positions in the generated sequence and rows are ranks
    (row 0 is the most likely token at that position). Each cell is coloured
    by probability and annotated with the token it stands for, because the
    token at a given rank differs from one position to the next.

    Args:
        top_tokens: One list per generated position, each containing
            ``(token, probability)`` pairs ordered from most to least likely.

    Returns:
        A matplotlib figure, or ``None`` when there is nothing to plot.
    """
    if not top_tokens:
        return None

    n_tokens = len(top_tokens)
    # Size the grid by the longest candidate list so ragged input cannot
    # index out of bounds; missing cells stay at probability 0.
    top_k = max(len(candidates) for candidates in top_tokens)
    if top_k == 0:
        return None

    data = np.zeros((top_k, n_tokens))
    for col, candidates in enumerate(top_tokens):
        for row, (_, prob) in enumerate(candidates):
            data[row, col] = prob

    fig, ax = plt.subplots(figsize=(10, 6))

    # Create a custom colormap that goes from light blue to dark blue
    colors = [(0.8, 0.9, 1.0), (0.0, 0.4, 0.8)]
    cmap = LinearSegmentedColormap.from_list("blue_gradient", colors)

    im = ax.imshow(data, aspect='auto', cmap=cmap)
    fig.colorbar(im, ax=ax, label='Probability')

    # Write each token into its own cell; switch the text colour at the
    # middle of the colour range so it stays legible on dark cells.
    midpoint = (data.max() + data.min()) / 2
    for col, candidates in enumerate(top_tokens):
        for row, (token, prob) in enumerate(candidates):
            ax.text(
                col, row, _printable_token(token),
                ha="center", va="center", rotation=90, fontsize=7,
                color="white" if prob > midpoint else "black",
            )

    # Rows are ranks, not fixed tokens, so label them by rank.
    ax.set_yticks(range(top_k))
    ax.set_yticklabels([f"#{rank}" for rank in range(1, top_k + 1)])
    ax.set_xlabel('Token Position in Generated Sequence')
    ax.set_ylabel('Top Tokens')
    ax.set_title('Token Probabilities During Generation')

    fig.tight_layout()
    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the returned Figure object can still be rendered.
    plt.close(fig)
    return fig
def interface_fn(prompt, use_outline, outline_text):
    """Gradio callback: generate text and build both probability views.

    Returns a tuple of the generated text, the probability figure produced by
    ``create_probability_plot``, and a plain-text listing of the top tokens
    and their probabilities for every generated position.
    """
    generated_text, top_tokens = generate_text(prompt, use_outline, outline_text)
    figure = create_probability_plot(top_tokens)

    # One section per position: a header, one line per candidate, a blank line.
    sections = []
    for position, candidates in enumerate(top_tokens, start=1):
        rows = "".join(f" '{token}': {prob:.4f}\n" for token, prob in candidates)
        sections.append(f"Position {position}:\n{rows}\n")

    return generated_text, figure, "".join(sections)
# Build the Gradio interface: inputs in the left column, results in the right.
with gr.Blocks(title="Structured Generation Demo") as demo:
    gr.Markdown("# Structured Generation Demo")
    gr.Markdown("This demo shows how outlines can constrain language model generation to include specific tokens.")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter a prompt to start generation...",
                value="The most interesting thing about AI is"
            )
            use_outline = gr.Checkbox(label="Use Outline Constraint", value=False)
            outline_text = gr.Textbox(
                label="Outline Text (tokens to enforce in order)",
                placeholder="Enter tokens to enforce in the generation...",
                value="safety, creativity, and knowledge"
            )
            generate_btn = gr.Button("Generate Text")

        with gr.Column():
            output_text = gr.Textbox(label="Generated Text")
            prob_plot = gr.Plot(label="Token Probabilities")
            prob_text = gr.Textbox(label="Detailed Token Probabilities", lines=10)

    # Event handlers must be registered inside the Blocks context.
    generate_btn.click(
        interface_fn,
        inputs=[prompt, use_outline, outline_text],
        outputs=[output_text, prob_plot, prob_text]
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()