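"""Structured generation demo.

A Gradio app in which a small causal LM (distilgpt2) generates text while an
optional OutlineLogitsProcessor boosts the logits of a user-supplied token
sequence, and the top-5 next-token probabilities at each step are shown as a
heatmap.
"""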
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Load a small model
model_name = "distilgpt2"  # Small model suitable for a demo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
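# distilgpt2 is small enough to run comfortably on CPU, so this demo does no
# explicit device placement; moving the model with .to("cuda") would speed
# things up if a GPU is available.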

class OutlineLogitsProcessor(LogitsProcessor):
    """
    A logits processor that steers generation toward an outline by
    boosting the logit of the next outline token at each step.

    Note: this is a soft constraint -- a finite boost raises the target
    token's probability but does not guarantee it is sampled. The
    processor is stateful, so use a fresh instance per generation.
    """
    def __init__(self, outline_tokens, tokenizer, boost_factor=10.0):
        self.outline_tokens = outline_tokens
        self.tokenizer = tokenizer
        self.boost_factor = boost_factor
        self.current_outline_idx = 0
        
    def __call__(self, input_ids, scores):
        # `scores` has shape (batch_size, vocab_size), as in the standard
        # LogitsProcessor interface
        if self.current_outline_idx < len(self.outline_tokens):
            # Get the next token from the outline
            target_token_id = self.outline_tokens[self.current_outline_idx]
            
            # Boost the logit of the target token for every batch item
            scores[:, target_token_id] += self.boost_factor
            self.current_outline_idx += 1
            
        return scores
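
# A note on hard constraints: the boost above only steers sampling. If the
# outline must appear verbatim, a stricter variant (a sketch, not part of the
# original demo) would mask every other token instead of boosting one:
#
#     def __call__(self, input_ids, scores):
#         if self.current_outline_idx < len(self.outline_tokens):
#             target = self.outline_tokens[self.current_outline_idx]
#             mask = torch.full_like(scores, float("-inf"))
#             mask[:, target] = 0.0  # only the target keeps a finite logit
#             scores = scores + mask
#             self.current_outline_idx += 1
#         return scores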

def generate_text(prompt, use_outline=False, outline_text=""):
    """Generate text with or without an outline constraint."""
    
    # Tokenize the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    
    logits_processor = None
    if use_outline and outline_text.strip():
        # Tokenize the outline; GPT-2 tokenizers add no BOS/special
        # tokens by default, so keep every token
        outline_tokens = tokenizer.encode(outline_text)
        logits_processor = [OutlineLogitsProcessor(outline_tokens, tokenizer)]
    
    # Store token probabilities for visualization
    all_probs = []
    
    # Capture the model's next-token distribution (before any logits
    # processing) for visualization
    def capture_probs(logits):
        probs = torch.softmax(logits[0, -1, :], dim=-1)
        all_probs.append(probs.detach().cpu().numpy())
    
    # Sampling parameters for the manual generation loop below
    max_new_tokens = 30
    temperature = 0.7
    
    # Custom generation loop with probability capture
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(input_ids)
            capture_probs(outputs.logits)
            
            # Logits for the next token, shape (1, vocab_size)
            next_logits = outputs.logits[:, -1, :]
            
            if logits_processor:
                for processor in logits_processor:
                    next_logits = processor(input_ids, next_logits)
            
            # Sample the next token from the temperature-scaled distribution
            next_token_probs = torch.softmax(next_logits / temperature, dim=-1)
            next_token = torch.multinomial(next_token_probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            
            # Stop if the EOS token is generated
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    
    # Get top tokens and their probabilities for visualization
    top_tokens = []
    for probs in all_probs:
        top_indices = np.argsort(probs)[-5:][::-1]  # Top 5 tokens
        top_tokens.append([(tokenizer.decode([idx]), float(probs[idx])) for idx in top_indices])
    
    return generated_text, top_tokens
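
# Example usage (the prompt and outline here are just the UI defaults):
#
#     text, top = generate_text("The most interesting thing about AI is")
#     text, top = generate_text(
#         "The most interesting thing about AI is",
#         use_outline=True,
#         outline_text="safety, creativity, and knowledge",
#     )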

def create_probability_plot(top_tokens):
    """Create a visualization of token probabilities."""
    if not top_tokens:
        return None
    
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Number of tokens and top-k
    n_tokens = len(top_tokens)
    top_k = len(top_tokens[0])
    
    # Create a custom colormap that goes from light blue to dark blue
    colors = [(0.8, 0.9, 1.0), (0.0, 0.4, 0.8)]
    cmap = LinearSegmentedColormap.from_list("blue_gradient", colors)
    
    # Build the heatmap data: rows are probability ranks (row 0 = most
    # likely token at each step), columns are positions in the sequence
    data = np.zeros((top_k, n_tokens))
    
    for i, token_probs in enumerate(top_tokens):
        for j, (_, prob) in enumerate(token_probs):
            data[j, i] = prob
    
    # Y-tick labels show the top tokens at the first position only;
    # later positions may rank different tokens
    token_labels = [t[0] for t in top_tokens[0]]
    
    # Plot the heatmap
    im = ax.imshow(data, aspect='auto', cmap=cmap)
    
    # Add colorbar
    cbar = fig.colorbar(im, ax=ax, label='Probability')
    
    # Customize the plot
    ax.set_yticks(range(top_k))
    ax.set_yticklabels(token_labels)
    ax.set_xlabel('Token Position in Generated Sequence')
    ax.set_ylabel('Top Tokens (ranked; labels from first position)')
    ax.set_title('Token Probabilities During Generation')
    
    # Adjust layout and return the figure
    plt.tight_layout()
    return fig

def interface_fn(prompt, use_outline, outline_text):
    """Main function for the Gradio interface."""
    generated_text, top_tokens = generate_text(prompt, use_outline, outline_text)
    
    # Create visualization of token probabilities
    prob_plot = create_probability_plot(top_tokens)
    
    # Format token probabilities as text for display
    prob_text = ""
    for i, tokens in enumerate(top_tokens):
        prob_text += f"Position {i+1}:\n"
        for token, prob in tokens:
            prob_text += f"  '{token}': {prob:.4f}\n"
        prob_text += "\n"
    
    return generated_text, prob_plot, prob_text

# Create the Gradio interface
with gr.Blocks(title="Structured Generation Demo") as demo:
    gr.Markdown("# Structured Generation Demo")
    gr.Markdown("This demo shows how outlines can constrain language model generation to include specific tokens.")
    
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter a prompt to start generation...",
                value="The most interesting thing about AI is"
            )
            
            use_outline = gr.Checkbox(label="Use Outline Constraint", value=False)
            
            outline_text = gr.Textbox(
                label="Outline Text (tokens to enforce in order)",
                placeholder="Enter tokens to enforce in the generation...",
                value="safety, creativity, and knowledge"
            )
            
            generate_btn = gr.Button("Generate Text")
        
        with gr.Column():
            output_text = gr.Textbox(label="Generated Text")
            prob_plot = gr.Plot(label="Token Probabilities")
            prob_text = gr.Textbox(label="Detailed Token Probabilities", lines=10)
    
    generate_btn.click(
        interface_fn,
        inputs=[prompt, use_outline, outline_text],
        outputs=[output_text, prob_plot, prob_text]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()