davanstrien (HF Staff) committed
Commit aaeefcc · verified · 1 Parent(s): 7685871

Update app.py

Files changed (1):
  1. app.py +189 -135
app.py CHANGED
@@ -1,140 +1,194 @@
- """
- app.py – Gradio demo for structured (constrained) generation with Outlines
- -----------------------------------------------------------------------
- Deploy this file (plus a requirements.txt) to a **Gradio** Space on
- Hugging Face. The UI is intentionally minimal so you can embed the Space
- in an `<iframe>` on a slide.
-
- **requirements.txt** (put this in the same repo):
- ```
- gradio>=4.28.0
- transformers>=4.40.0
- outlines>=0.0.36
- torch
- ```
-
- After pushing both files, Spaces will build the image automatically. The
- Space URL (e.g. `https://username-spacename.hf.space`) can be embedded
- with:
- ```html
- <iframe src="https://username-spacename.hf.space" width="640" height="480"></iframe>
- ```
- """
-
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import outlines  # structured-generation library
-
- MODEL_NAME = "distilgpt2"  # small & free to download (≈ 300 MB)
-
- # Load model / tokenizer once at start-up
- print("Loading model – first launch may take ~20 s on CPU…")
-
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_NAME,
-     torch_dtype=torch.float32,
- ).eval()
-
- # ---------------------------------------------------------------------------
- # 1️⃣ Helper: baseline generation (no constraints)
- # ---------------------------------------------------------------------------
-
- def generate_baseline(prompt: str, max_tokens: int = 64, temperature: float = 0.7):
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     model.to(device)
-     inputs = tokenizer(prompt, return_tensors="pt").to(device)
-     output_ids = model.generate(
-         **inputs,
-         max_new_tokens=max_tokens,
-         temperature=temperature,
-         top_k=50,
-     )
-     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
- # ---------------------------------------------------------------------------
- # 2️⃣ Helper: constrained generation with Outlines
- # ---------------------------------------------------------------------------
- # For demo purposes we request that *at some point* the token string “OpenAI”
- # appears in the output. Any regex that the `re` module accepts will work.
- # You can expose it as an additional textbox if you want users to edit it.
-
- PATTERN = r".*OpenAI.*"
-
- # Build a generator bound to the regex once; it rewires the model’s logits
- # so forbidden tokens get probability −∞ (effectively zero).
-
- generator = outlines.generate.regex(model, PATTERN)
-
-
- def generate_constrained(prompt: str, max_tokens: int = 64, temperature: float = 0.7):
-     return generator(prompt, max_tokens=max_tokens, temperature=temperature)
-
- # ---------------------------------------------------------------------------
- # 3️⃣ Helper: show top-10 next-token probabilities *before* and *after*
- #     applying the regex constraint, to make the effect visible.
- # ---------------------------------------------------------------------------
-
- def _topk_probs(logits: torch.Tensor, k: int = 10):
-     """Return {token: prob} for the k most likely tokens."""
-     probs = torch.softmax(logits, dim=-1)
-     topk = torch.topk(probs, k)
-     tokens = [tokenizer.decode(idx) for idx in topk.indices[0]]
-     return {t.replace("\n", "\\n"): float(p) for t, p in zip(tokens, topk.values[0])}
-
-
- def compare(prompt: str):
-     # Baseline text
-     baseline_text = generate_baseline(prompt)
-
-     # Constrained text
-     constrained_text = generate_constrained(prompt)
-
-     # Get logits for next token after the *prompt* (not after full generation)
      with torch.no_grad():
-         inputs = tokenizer(prompt, return_tensors="pt")
-         base_logits = model(**inputs).logits[:, -1, :]
-
-     # Apply Outlines’ regex sampler to obtain constrained logits
-     regex_sampler = outlines.samplers.RegexSampler(PATTERN)
-     constrained_logits = regex_sampler(base_logits.clone(), inputs.input_ids)
-
-     baseline_topk = _topk_probs(base_logits)
-     constrained_topk = _topk_probs(constrained_logits)
-
-     return baseline_text, constrained_text, baseline_topk, constrained_topk
-
- # ---------------------------------------------------------------------------
- # 4️⃣ Gradio UI – minimal so it fits nicely inside slides
- # ---------------------------------------------------------------------------
-
- def build_interface():
-     with gr.Blocks() as demo:
-         gr.Markdown("## Structured Generation Demo (Outlines)")
-
-         prompt = gr.Textbox(lines=3, label="Prompt", placeholder="e.g. A short story about innovative AI")
-         generate_btn = gr.Button("Generate")
-
-         with gr.Row():
-             baseline_out = gr.Textbox(label="Baseline output (unconstrained)")
-             constrained_out = gr.Textbox(label="Constrained output (must contain 'OpenAI')")
-
-         with gr.Row():
-             baseline_probs = gr.JSON(label="Top-10 next-token probs (baseline)")
-             constrained_probs = gr.JSON(label="Top-10 next-token probs (constrained)")
-
-         generate_btn.click(compare, inputs=prompt, outputs=[
-             baseline_out,
-             constrained_out,
-             baseline_probs,
-             constrained_probs,
-         ])
-
-     return demo
-
-
- demo = build_interface()
-
  if __name__ == "__main__":
-     demo.launch()
  import gradio as gr
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from matplotlib.colors import LinearSegmentedColormap
+
+ # Load a small model
+ model_name = "distilgpt2"  # Small model suitable for a demo
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+
+ class OutlineLogitsProcessor(LogitsProcessor):
+     """
+     A logits processor that softly enforces an outline: at each generation
+     step it boosts the logit of the next outline token, so the outline
+     tokens are strongly favoured, in order (a soft bias, not a hard mask).
+     """
+     def __init__(self, outline_tokens, tokenizer, boost_factor=10.0):
+         self.outline_tokens = outline_tokens
+         self.tokenizer = tokenizer
+         self.boost_factor = boost_factor
+         self.current_outline_idx = 0
+
+     def __call__(self, input_ids, scores):
+         if self.current_outline_idx < len(self.outline_tokens):
+             # Get the next token from the outline
+             target_token_id = self.outline_tokens[self.current_outline_idx]
+
+             # Boost the logit of the target token; the ellipsis index works
+             # for both 1-D scores (custom loop below) and 2-D batched scores
+             scores[..., target_token_id] += self.boost_factor
+             self.current_outline_idx += 1
+
+         return scores
+
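+ # A minimal usage sketch (illustrative only; this prompt and outline are
+ # examples, not part of the app's flow): the same processor also plugs into
+ # transformers' standard generate API via LogitsProcessorList.
+ #
+ #     from transformers import LogitsProcessorList
+ #     demo_ids = tokenizer.encode("The plan covers", return_tensors="pt")
+ #     demo_outline = tokenizer.encode(" safety and testing")
+ #     out = model.generate(
+ #         demo_ids,
+ #         max_new_tokens=len(demo_outline) + 5,
+ #         do_sample=True,
+ #         logits_processor=LogitsProcessorList(
+ #             [OutlineLogitsProcessor(demo_outline, tokenizer)]
+ #         ),
+ #     )
+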
+ def generate_text(prompt, use_outline=False, outline_text=""):
+     """Generate text with or without an outline constraint."""
+
+     # Tokenize the prompt
+     input_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+     logits_processor = None
+     if use_outline and outline_text.strip():
+         # Tokenize the outline (GPT-2's tokenizer adds no BOS token,
+         # so every encoded token is kept)
+         outline_tokens = tokenizer.encode(outline_text)
+         logits_processor = [OutlineLogitsProcessor(outline_tokens, tokenizer)]
+
+     # Store token probabilities for visualization
+     all_probs = []
+
+     # Capture the next-token probability distribution at each step
+     def capture_probs(logits):
+         probs = torch.softmax(logits[0, -1, :], dim=-1)
+         all_probs.append(probs.detach().cpu().numpy())
+
+     # Sampling settings; we sample by hand instead of calling model.generate
+     # so that per-step probabilities can be recorded
+     max_new_tokens = 30
+     temperature = 0.7
+
+     # Custom generation with probability capture
      with torch.no_grad():
+         for _ in range(max_new_tokens):
+             outputs = model(input_ids)
+             capture_probs(outputs.logits)
+
+             # Logits for the next-token position (1-D over the vocabulary)
+             logits = outputs.logits[0, -1, :]
+             if logits_processor:
+                 for processor in logits_processor:
+                     logits = processor(input_ids, logits)
+
+             next_token_probs = torch.softmax(logits / temperature, dim=-1)
+             next_token = torch.multinomial(next_token_probs, 1).unsqueeze(0)
+             input_ids = torch.cat([input_ids, next_token], dim=-1)
+
+             # Stop if EOS token is generated
+             if next_token.item() == tokenizer.eos_token_id:
+                 break
+
+     generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+     # Get top tokens and their probabilities for visualization
+     top_tokens = []
+     for probs in all_probs:
+         top_indices = np.argsort(probs)[-5:][::-1]  # Top 5 tokens
+         top_tokens.append([(tokenizer.decode([idx]), float(probs[idx])) for idx in top_indices])
+
+     return generated_text, top_tokens
+
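+ # Illustrative comparison (not executed at import time): the same prompt
+ # with and without the outline constraint.
+ # plain_text, _ = generate_text("The most interesting thing about AI is")
+ # outlined_text, _ = generate_text(
+ #     "The most interesting thing about AI is",
+ #     use_outline=True,
+ #     outline_text="safety, creativity, and knowledge",
+ # )
+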
+ def create_probability_plot(top_tokens):
+     """Create a visualization of token probabilities."""
+     if not top_tokens:
+         return None
+
+     fig, ax = plt.subplots(figsize=(10, 6))
+
+     # Number of generation steps and top-k per step
+     n_tokens = len(top_tokens)
+     top_k = len(top_tokens[0])
+
+     # Create a custom colormap that goes from light blue to dark blue
+     colors = [(0.8, 0.9, 1.0), (0.0, 0.4, 0.8)]
+     cmap = LinearSegmentedColormap.from_list("blue_gradient", colors)
+
+     # Create the heatmap-style visualization
+     data = np.zeros((top_k, n_tokens))
+     token_labels = []
+
+     for i, token_probs in enumerate(top_tokens):
+         # Extract tokens and probabilities
+         tokens = [t[0] for t in token_probs]
+         probs = [t[1] for t in token_probs]
+
+         # Store probabilities for visualization
+         for j, prob in enumerate(probs):
+             data[j, i] = prob
+
+         # Label the y-axis with the top tokens of the *first* position only
+         # (the ranked tokens differ at later positions)
+         if i == 0:
+             token_labels = tokens
+
+     # Plot the heatmap
+     im = ax.imshow(data, aspect='auto', cmap=cmap)
+
+     # Add colorbar
+     cbar = fig.colorbar(im, ax=ax, label='Probability')
+
+     # Customize the plot
+     ax.set_yticks(range(top_k))
+     ax.set_yticklabels(token_labels)
+     ax.set_xlabel('Token Position in Generated Sequence')
+     ax.set_ylabel('Top Tokens')
+     ax.set_title('Token Probabilities During Generation')
+
+     # Adjust layout
+     plt.tight_layout()
+     return fig
+
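+ # Quick preview with made-up numbers (hypothetical values, illustration only):
+ # fig = create_probability_plot(
+ #     [[(" the", 0.31), (" a", 0.12), (" to", 0.08), (" of", 0.05), ("\n", 0.04)]]
+ # )
+ # fig.savefig("token_probs.png")
+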
+ def interface_fn(prompt, use_outline, outline_text):
+     """Main function for the Gradio interface."""
+     generated_text, top_tokens = generate_text(prompt, use_outline, outline_text)
+
+     # Create visualization of token probabilities
+     prob_plot = create_probability_plot(top_tokens)
+
+     # Format token probabilities as text for display
+     prob_text = ""
+     for i, tokens in enumerate(top_tokens):
+         prob_text += f"Position {i+1}:\n"
+         for token, prob in tokens:
+             prob_text += f"  '{token}': {prob:.4f}\n"
+         prob_text += "\n"
+
+     return generated_text, prob_plot, prob_text
+
+ # Create the Gradio interface
+ with gr.Blocks(title="Structured Generation Demo") as demo:
+     gr.Markdown("# Structured Generation Demo")
+     gr.Markdown("This demo shows how an outline constraint can steer language model generation to include specific tokens.")
+
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(
+                 label="Prompt",
+                 placeholder="Enter a prompt to start generation...",
+                 value="The most interesting thing about AI is"
+             )
+
+             use_outline = gr.Checkbox(label="Use Outline Constraint", value=False)
+
+             outline_text = gr.Textbox(
+                 label="Outline Text (tokens to enforce in order)",
+                 placeholder="Enter tokens to enforce in the generation...",
+                 value="safety, creativity, and knowledge"
+             )
+
+             generate_btn = gr.Button("Generate Text")
+
+         with gr.Column():
+             output_text = gr.Textbox(label="Generated Text")
+             prob_plot = gr.Plot(label="Token Probabilities")
+             prob_text = gr.Textbox(label="Detailed Token Probabilities", lines=10)
+
+     generate_btn.click(
+         interface_fn,
+         inputs=[prompt, use_outline, outline_text],
+         outputs=[output_text, prob_plot, prob_text]
+     )
+
+ # Launch the app
  if __name__ == "__main__":
+     demo.launch()