import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Load a small model
model_name = "distilgpt2"  # Small model suitable for a demo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()


class OutlineLogitsProcessor(LogitsProcessor):
    """
    A logits processor that softly steers generation toward an outline:
    at each step it boosts the logit of the next outline token. Because
    the boost is additive rather than a hard mask, outline tokens are
    strongly favored but not guaranteed.
    """

    def __init__(self, outline_tokens, tokenizer, boost_factor=10.0):
        self.outline_tokens = outline_tokens
        self.tokenizer = tokenizer
        self.boost_factor = boost_factor
        self.current_outline_idx = 0

    def __call__(self, input_ids, scores):
        # `scores` has shape (batch_size, vocab_size), matching the
        # standard transformers LogitsProcessor interface.
        if self.current_outline_idx < len(self.outline_tokens):
            # Get the next token from the outline
            target_token_id = self.outline_tokens[self.current_outline_idx]
            # Boost the logit of the target token
            scores[:, target_token_id] += self.boost_factor
            self.current_outline_idx += 1
        return scores


def generate_text(prompt, use_outline=False, outline_text="",
                  max_new_tokens=30, temperature=0.7):
    """Generate text with or without an outline constraint."""
    # Tokenize the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    logits_processor = None
    if use_outline and outline_text.strip():
        # Tokenize the outline. GPT-2-style tokenizers do not prepend a BOS
        # token, so every encoded token belongs to the outline.
        outline_tokens = tokenizer.encode(outline_text)
        logits_processor = [OutlineLogitsProcessor(outline_tokens, tokenizer)]

    # Next-token probability distributions, captured for visualization
    all_probs = []

    # Custom generation loop so we can capture and modify the per-step
    # distributions directly
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(input_ids)
            # Logits for the next token, shape (1, vocab_size)
            next_token_logits = outputs.logits[:, -1, :]

            if logits_processor:
                for processor in logits_processor:
                    next_token_logits = processor(input_ids, next_token_logits)

            # Apply temperature, then sample
            next_token_probs = torch.softmax(next_token_logits / temperature, dim=-1)
            # Capture the distribution we actually sample from, so the plot
            # reflects the outline boost when it is active
            all_probs.append(next_token_probs[0].detach().cpu().numpy())
            next_token = torch.multinomial(next_token_probs, num_samples=1)  # (1, 1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Stop if EOS token is generated
            if next_token.item() == tokenizer.eos_token_id:
                break

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    # Get top tokens and their probabilities for visualization
    top_tokens = []
    for probs in all_probs:
        top_indices = np.argsort(probs)[-5:][::-1]  # Top 5 tokens
        top_tokens.append(
            [(tokenizer.decode([int(idx)]), float(probs[idx])) for idx in top_indices]
        )

    return generated_text, top_tokens

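# ---------------------------------------------------------------------------
# Aside: the manual loop above exists so that per-step probabilities can be
# captured for the plot. If only the constrained text is needed, a minimal
# sketch of the same idea through the standard transformers generate() API
# would look like the commented example below (untested here; the prompt and
# outline strings are just placeholders):
#
#   from transformers import LogitsProcessorList
#
#   processors = LogitsProcessorList(
#       [OutlineLogitsProcessor(tokenizer.encode("safety, creativity"), tokenizer)]
#   )
#   ids = tokenizer.encode("The most interesting thing about AI is",
#                          return_tensors="pt")
#   out = model.generate(ids, max_new_tokens=30, do_sample=True,
#                        temperature=0.7, logits_processor=processors)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))
#
# This works because OutlineLogitsProcessor keeps the standard
# (input_ids, scores) signature with (batch, vocab)-shaped scores.
# ---------------------------------------------------------------------------
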
def create_probability_plot(top_tokens):
    """Create a heatmap visualization of token probabilities."""
    if not top_tokens:
        return None

    fig, ax = plt.subplots(figsize=(10, 6))

    # Number of generated positions and top-k candidates per position
    n_tokens = len(top_tokens)
    top_k = len(top_tokens[0])

    # Custom colormap that goes from light blue to dark blue
    colors = [(0.8, 0.9, 1.0), (0.0, 0.4, 0.8)]
    cmap = LinearSegmentedColormap.from_list("blue_gradient", colors)

    # Build the heatmap data: rows are probability ranks, columns are positions
    data = np.zeros((top_k, n_tokens))
    token_labels = []
    for i, token_probs in enumerate(top_tokens):
        tokens = [t[0] for t in token_probs]
        probs = [t[1] for t in token_probs]
        for j, prob in enumerate(probs):
            data[j, i] = prob
        # Label rows with the first position's top tokens; later positions
        # have their own top tokens, so this labeling is a simplification
        if i == 0:
            token_labels = tokens

    # Plot the heatmap
    im = ax.imshow(data, aspect='auto', cmap=cmap)

    # Add colorbar
    fig.colorbar(im, ax=ax, label='Probability')

    # Customize the plot
    ax.set_yticks(range(top_k))
    ax.set_yticklabels(token_labels)
    ax.set_xlabel('Token Position in Generated Sequence')
    ax.set_ylabel('Top Tokens (at first position)')
    ax.set_title('Token Probabilities During Generation')

    plt.tight_layout()
    return fig


def interface_fn(prompt, use_outline, outline_text):
    """Main function for the Gradio interface."""
    generated_text, top_tokens = generate_text(prompt, use_outline, outline_text)

    # Create visualization of token probabilities
    prob_plot = create_probability_plot(top_tokens)

    # Format token probabilities as text for display
    prob_text = ""
    for i, tokens in enumerate(top_tokens):
        prob_text += f"Position {i + 1}:\n"
        for token, prob in tokens:
            prob_text += f"  '{token}': {prob:.4f}\n"
        prob_text += "\n"

    return generated_text, prob_plot, prob_text


# Create the Gradio interface
with gr.Blocks(title="Structured Generation Demo") as demo:
    gr.Markdown("# Structured Generation Demo")
    gr.Markdown(
        "This demo shows how an outline can steer language model generation "
        "toward specific tokens."
    )

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter a prompt to start generation...",
                value="The most interesting thing about AI is",
            )
            use_outline = gr.Checkbox(label="Use Outline Constraint", value=False)
            outline_text = gr.Textbox(
                label="Outline Text (tokens to enforce in order)",
                placeholder="Enter tokens to enforce in the generation...",
                value="safety, creativity, and knowledge",
            )
            generate_btn = gr.Button("Generate Text")

        with gr.Column():
            output_text = gr.Textbox(label="Generated Text")
            prob_plot = gr.Plot(label="Token Probabilities")
            prob_text = gr.Textbox(label="Detailed Token Probabilities", lines=10)

    generate_btn.click(
        interface_fn,
        inputs=[prompt, use_outline, outline_text],
        outputs=[output_text, prob_plot, prob_text],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
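
# ---------------------------------------------------------------------------
# Note: the +10 logit boost above is a soft constraint, so sampling can still
# pick a different token. A hard variant (a hypothetical extension, not used
# by the demo above) would mask every other token instead:
#
#   class HardOutlineLogitsProcessor(OutlineLogitsProcessor):
#       def __call__(self, input_ids, scores):
#           if self.current_outline_idx < len(self.outline_tokens):
#               target = self.outline_tokens[self.current_outline_idx]
#               masked = torch.full_like(scores, float("-inf"))
#               masked[:, target] = scores[:, target]
#               self.current_outline_idx += 1
#               return masked  # only the outline token remains sampleable
#           return scores
# ---------------------------------------------------------------------------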