asadshahab committed
Commit dd850a7 · 1 Parent(s): 6df9665
.gitignore ADDED
@@ -0,0 +1,104 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # Virtual environments
+ venv/
+ env/
+ ENV/
+ env.bak/
+ venv.bak/
+ .venv/
+
+ # IDE and editors
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+ .claude/
+
+ # OS generated files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
+
+ # Logs and databases
+ *.log
+ *.sqlite3
+ *.db
+
+ # Model cache and downloads
+ models/
+ .cache/
+ huggingface_hub/
+ transformers_cache/
+
+ # Temporary files
+ *.tmp
+ *.temp
+ .tmp/
+
+ # Environment variables
+ .env
+ .env.local
+ .env.production
+ .env.staging
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pytest
+ .pytest_cache/
+ .coverage
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Local development
+ local_test.py
+ test_*.py
+ debug.py
+
+ # Gradio temporary files
+ .gradio/
+ gradio_cached_examples/
+ flagged/
+
+ # Large files that shouldn't be in git
+ *.bin
+ *.safetensors
+ *.pt
+ *.pth
+ *.ckpt
+ *.h5
+
+ # Documentation build
+ docs/_build/
README.md CHANGED
@@ -1,14 +1,79 @@
  ---
  title: Token Attention Visualizer
- emoji: 📈
- colorFrom: purple
- colorTo: yellow
  sdk: gradio
- sdk_version: 5.42.0
  app_file: app.py
  pinned: false
- license: mit
- short_description: An interactive tool for visualizing attention patterns in La
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
  title: Token Attention Visualizer
+ emoji: 🔍
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
+ sdk_version: 4.0.0
  app_file: app.py
  pinned: false
+ license: apache-2.0
  ---

+ # Token Attention Visualizer
+
+ An interactive tool for visualizing attention patterns in Large Language Models during text generation.
+
+ ## Features
+
+ - 🚀 **Real-time Generation**: Generate text with any Hugging Face model
+ - 🔍 **Attention Visualization**: Explore attention patterns with clear visual representations
+ - 📊 **Dual Normalization**: Choose between separate or joint attention normalization
+ - ⚡ **Smart Caching**: Fast response with intelligent result caching
+ - 🎯 **Token Selection**: Use dropdown menus to select and filter token connections
+ - 📈 **Step Navigation**: Navigate through generation steps
+ - 🎨 **Customizable Threshold**: Filter weak attention connections
+
+ ## How It Works
+
+ The visualizer shows how tokens attend to each other during text generation:
+ - **Blue lines**: Attention from input tokens to output tokens
+ - **Orange curves**: Attention between output tokens
+ - **Line thickness**: Represents attention weight strength
+
+ ## Usage
+
+ 1. **Load a Model**: Enter a Hugging Face model name (default: HuggingFaceTB/SmolLM-135M-Instruct)
+ 2. **Enter Prompt**: Type your input text
+ 3. **Configure Settings**: Adjust max tokens, temperature, and normalization
+ 4. **Generate**: Click to generate text and visualize attention
+ 5. **Explore**: Use dropdown menus to select tokens and view their attention patterns
+
+ ## Technical Details
+
+ - Built with Gradio for the interface
+ - Visualization system with dropdown-based token selection
+ - Supports any Hugging Face causal language model
+ - Optimized for smaller models like SmolLM for efficient deployment
+ - Implements efficient attention processing and caching
+
+ ## Local Development
+
+ ```bash
+ # Clone the repository
+ git clone <repo-url>
+ cd token-attention-viz
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Run the app
+ python app.py
+ ```
+
+ ## Deployment
+
+ This app is designed for easy deployment on Hugging Face Spaces. Simply:
+ 1. Create a new Space
+ 2. Upload the project files
+ 3. The app will automatically start
+
+ ## Requirements
+
+ - Python 3.8+
+ - 4GB+ RAM (SmolLM models are lightweight)
+ - GPU acceleration optional (works well on CPU)
+
+ ## License
+
+ Apache 2.0
api/__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,656 @@
+ import gradio as gr
+ import torch
+ import sys
+ import os
+ import json
+ from pathlib import Path
+
+ # Add project root to path
+ sys.path.insert(0, str(Path(__file__).parent))
+
+ from core.model_handler import ModelHandler
+ from core.attention import AttentionProcessor
+ from core.cache import AttentionCache
+ from config import Config
+ from visualization.d3_viz import create_d3_visualization
+
+ class TokenVisualizerApp:
+     def __init__(self):
+         self.config = Config()
+         self.model_handler = ModelHandler(config=self.config)
+         self.cache = AttentionCache(max_size=self.config.CACHE_SIZE)
+         self.current_data = None
+         self.model_loaded = False
+
+     def load_model(self, model_name: str = None) -> str:
+         """Load the model and return a status message."""
+         if not model_name:
+             model_name = self.config.DEFAULT_MODEL
+
+         success, message = self.model_handler.load_model(model_name)
+         self.model_loaded = success
+
+         if success:
+             model_info = self.model_handler.get_model_info()
+             return f"✅ Model loaded: {model_name}\n📊 Parameters: {model_info['num_parameters']:,}\n🖥️ Device: {model_info['device']}"
+         else:
+             return f"❌ Failed to load model: {message}"
+
+     def generate_and_visualize(
+         self,
+         prompt: str,
+         max_tokens: int,
+         threshold: float,
+         temperature: float,
+         normalization: str,
+         progress=gr.Progress()
+     ):
+         """Main generation function (no visualization). Always returns a one-element tuple with the info text."""
+         if not self.model_loaded:
+             return ("Please load a model first!",)
+
+         if not prompt.strip():
+             return ("Please enter a prompt!",)
+
+         progress(0.2, desc="Checking cache...")
+
+         # Check cache
+         cache_key = self.cache.get_key(
+             prompt, max_tokens,
+             self.model_handler.model_name,
+             temperature
+         )
+         cached = self.cache.get(cache_key)
+
+         if cached:
+             progress(0.5, desc="Using cached data...")
+             self.current_data = cached
+         else:
+             progress(0.3, desc="Generating text...")
+
+             # Generate new
+             attention_data, output_tokens, input_tokens, generated_text = \
+                 self.model_handler.generate_with_attention(
+                     prompt, max_tokens, temperature
+                 )
+
+             if attention_data is None:
+                 return (f"Generation failed: {generated_text}",)
+
+             progress(0.6, desc="Processing attention...")
+
+             # Process attention based on normalization method
+             if normalization == "separate":
+                 attention_matrices = AttentionProcessor.process_attention_separate(
+                     attention_data, input_tokens, output_tokens
+                 )
+             else:
+                 attention_matrices = AttentionProcessor.process_attention_joint(
+                     attention_data, input_tokens, output_tokens
+                 )
+
+             self.current_data = {
+                 'input_tokens': input_tokens,
+                 'output_tokens': output_tokens,
+                 'attention_matrices': attention_matrices,
+                 'generated_text': generated_text,
+                 'attention_data': attention_data  # Keep raw for step updates
+             }
+
+             # Cache it
+             self.cache.set(cache_key, self.current_data)
+
+         progress(1.0, desc="Complete!")
+
+         # Create info text
+         info_text = f"📝 Generated: {self.current_data['generated_text']}\n"
+         info_text += f"🔤 Input tokens: {len(self.current_data['input_tokens'])}\n"
+         info_text += f"🔤 Output tokens: {len(self.current_data['output_tokens'])}"
+
+         return (info_text,)
+     def update_step(self, step_idx: int, threshold: float):
+         """No-op placeholder after removing visualization."""
+         return None
+
+     def update_threshold(self, threshold: float, normalization: str):
+         """No-op placeholder after removing visualization."""
+         return None
+
+     def filter_token_connections(self, token_idx: int, token_type: str, threshold: float):
+         """Removed visualization; keep placeholder."""
+         return None
+
+     def reset_view(self, threshold: float):
+         """Removed visualization; keep placeholder."""
+         return None
+
+     def on_d3_token_click(self, click_data: str, threshold: float):
+         """Removed visualization; keep placeholder for compatibility."""
+         return None, gr.update()
+
+     def on_input_token_select(self, token_label: str, threshold: float):
+         """Removed visualization; keep placeholder for compatibility."""
+         return None
+
+     def prepare_d3_data(self, step_idx: int, threshold: float = 0.01, filter_token: str = None):
+         """
+         Convert attention data to D3.js-friendly JSON format.
+
+         Args:
+             step_idx: Generation step to visualize (0-based)
+             threshold: Minimum attention weight to include
+             filter_token: Token to filter by (format: "[IN] token", "[OUT] token", or "All tokens")
+
+         Returns:
+             dict: JSON structure with nodes and links for D3.js
+         """
+         if not self.current_data:
+             return {"nodes": [], "links": []}
+
+         input_tokens = self.current_data['input_tokens']
+         output_tokens = self.current_data['output_tokens']
+         attention_matrices = self.current_data['attention_matrices']
+
+         # Ensure step_idx is within bounds
+         if step_idx >= len(attention_matrices):
+             step_idx = len(attention_matrices) - 1
+
+         attention_matrix = attention_matrices[step_idx]
+
+         # Create nodes
+         nodes = []
+
+         # Add input nodes
+         for i, token in enumerate(input_tokens):
+             nodes.append({
+                 "id": f"input_{i}",
+                 "token": token,
+                 "type": "input",
+                 "index": i
+             })
+
+         # Add output nodes (up to current step)
+         for i in range(step_idx + 1):
+             if i < len(output_tokens):
+                 nodes.append({
+                     "id": f"output_{i}",
+                     "token": output_tokens[i],
+                     "type": "output",
+                     "index": i
+                 })
+
+         # Parse filter token
+         filter_type = None
+         filter_idx = None
+         if filter_token and filter_token != "All tokens":
+             if filter_token.startswith("[IN] "):
+                 filter_type = "input"
+                 filter_token_text = filter_token[5:]  # Remove "[IN] " prefix
+                 filter_idx = next((i for i, token in enumerate(input_tokens) if token == filter_token_text), None)
+             elif filter_token.startswith("[OUT] "):
+                 filter_type = "output"
+                 filter_token_text = filter_token[6:]  # Remove "[OUT] " prefix
+                 filter_idx = next((i for i, token in enumerate(output_tokens) if token == filter_token_text), None)
+
+         # Create links from attention matrices - show ALL steps up to current step
+         links = []
+
+         # Show connections for all steps up to and including step_idx
+         for current_step in range(step_idx + 1):
+             if current_step < len(attention_matrices):
+                 step_attention = attention_matrices[current_step]
+
+                 # Links from input tokens to this output token
+                 input_attention = step_attention['input_attention']
+                 if input_attention is not None:
+                     for input_idx in range(len(input_tokens)):
+                         if input_idx < len(input_attention):  # Check bounds
+                             weight = float(input_attention[input_idx])
+                             if weight >= threshold:
+                                 # Apply filtering
+                                 show_link = True
+                                 if filter_type == "input" and filter_idx is not None:
+                                     # Only show connections involving the selected input token
+                                     show_link = (input_idx == filter_idx)
+                                 elif filter_type == "output" and filter_idx is not None:
+                                     # Only show connections involving the selected output token
+                                     show_link = (current_step == filter_idx)
+
+                                 if show_link:
+                                     links.append({
+                                         "source": f"input_{input_idx}",
+                                         "target": f"output_{current_step}",
+                                         "weight": weight,
+                                         "type": "input_to_output"
+                                     })
+
+                 # Links from previous output tokens to this output token
+                 output_attention = step_attention['output_attention']
+                 if output_attention is not None and current_step > 0:
+                     for prev_output_idx in range(current_step):
+                         if prev_output_idx < len(output_attention):  # Check bounds
+                             weight = float(output_attention[prev_output_idx])
+                             if weight >= threshold:
+                                 # Apply filtering
+                                 show_link = True
+                                 if filter_type == "input" and filter_idx is not None:
+                                     # Don't show output-to-output connections when filtering by input
+                                     show_link = False
+                                 elif filter_type == "output" and filter_idx is not None:
+                                     # Only show connections involving the selected output token
+                                     show_link = (prev_output_idx == filter_idx or current_step == filter_idx)
+
+                                 if show_link:
+                                     links.append({
+                                         "source": f"output_{prev_output_idx}",
+                                         "target": f"output_{current_step}",
+                                         "weight": weight,
+                                         "type": "output_to_output"
+                                     })
+
+         return {
+             "nodes": nodes,
+             "links": links,
+             "step": step_idx,
+             "total_steps": len(attention_matrices),
+             "input_count": len(input_tokens),
+             "output_count": step_idx + 1
+         }
+
+     def create_d3_visualization_html(self, step_idx: int = 0, threshold: float = 0.01, filter_token: str = None):
+         """
+         Create D3.js visualization HTML for the current data.
+
+         Args:
+             step_idx: Generation step to visualize (0-based)
+             threshold: Minimum attention weight to include
+             filter_token: Token to filter by (format: "[IN] token" or "[OUT] token")
+
+         Returns:
+             str: HTML string for D3.js visualization
+         """
+         if not self.current_data:
+             return "<div>No data available. Generate text first!</div>"
+
+         d3_data = self.prepare_d3_data(step_idx, threshold, filter_token)
+
+         viz_html = create_d3_visualization(d3_data)
+         return viz_html
+
+     def get_token_choices(self):
+         """
+         Get the list of token choices for the dropdown.
+
+         Returns:
+             list: List of token strings for dropdown options
+         """
+         if not self.current_data:
+             return []
+
+         input_tokens = self.current_data['input_tokens']
+         output_tokens = self.current_data['output_tokens']
+
+         # Create choices with prefixes to distinguish input/output
+         choices = ["All tokens"]
+         choices.extend([f"[IN] {token}" for token in input_tokens])
+         choices.extend([f"[OUT] {token}" for token in output_tokens])
+
+         return choices
+
+
+ def create_gradio_interface():
+     """Create the Gradio interface."""
+     app = TokenVisualizerApp()
+
+     with gr.Blocks(
+         title="Token Attention Visualizer",
+         css="""
+         /* Default/Light mode styles */
+         .main-header {
+             text-align: center;
+             padding: 2rem 0 3rem 0;
+             background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
+             border-radius: 1rem;
+             margin-bottom: 2rem;
+             border: 1px solid #e2e8f0;
+         }
+
+         .main-title {
+             font-size: 2.5rem;
+             font-weight: 700;
+             color: #1e293b;
+             margin-bottom: 0.5rem;
+             background: linear-gradient(135deg, #1e293b 0%, #3b82f6 100%);
+             -webkit-background-clip: text;
+             -webkit-text-fill-color: transparent;
+             background-clip: text;
+         }
+
+         .main-subtitle {
+             font-size: 1.125rem;
+             color: #64748b;
+             font-weight: 400;
+         }
+
+         .section-title {
+             font-size: 1.25rem;
+             font-weight: 600;
+             color: #1e293b;
+             margin-bottom: 1.5rem;
+             padding-bottom: 0.5rem;
+             border-bottom: 2px solid #e2e8f0;
+         }
+
+         /* Explicit light mode overrides */
+         .light .main-header,
+         [data-theme="light"] .main-header {
+             background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
+             border: 1px solid #e2e8f0;
+         }
+
+         .light .main-title,
+         [data-theme="light"] .main-title {
+             color: #1e293b;
+             background: linear-gradient(135deg, #1e293b 0%, #3b82f6 100%);
+             -webkit-background-clip: text;
+             -webkit-text-fill-color: transparent;
+             background-clip: text;
+         }
+
+         .light .main-subtitle,
+         [data-theme="light"] .main-subtitle {
+             color: #64748b;
+         }
+
+         .light .section-title,
+         [data-theme="light"] .section-title {
+             color: #1e293b;
+             border-bottom: 2px solid #e2e8f0;
+         }
+
+         /* Dark mode styles with higher specificity */
+         .dark .main-header,
+         [data-theme="dark"] .main-header {
+             background: linear-gradient(135deg, #1e293b 0%, #334155 100%) !important;
+             border: 1px solid #475569 !important;
+         }
+
+         .dark .main-title,
+         [data-theme="dark"] .main-title {
+             color: #f1f5f9 !important;
+             background: linear-gradient(135deg, #f1f5f9 0%, #60a5fa 100%) !important;
+             -webkit-background-clip: text !important;
+             -webkit-text-fill-color: transparent !important;
+             background-clip: text !important;
+         }
+
+         .dark .main-subtitle,
+         [data-theme="dark"] .main-subtitle {
+             color: #cbd5e1 !important;
+         }
+
+         .dark .section-title,
+         [data-theme="dark"] .section-title {
+             color: #f1f5f9 !important;
+             border-bottom: 2px solid #475569 !important;
+         }
+
+         /* System dark mode - only apply when no explicit theme is set */
+         @media (prefers-color-scheme: dark) {
+             :root:not([data-theme="light"]) .main-header {
+                 background: linear-gradient(135deg, #1e293b 0%, #334155 100%);
+                 border: 1px solid #475569;
+             }
+
+             :root:not([data-theme="light"]) .main-title {
+                 color: #f1f5f9;
+                 background: linear-gradient(135deg, #f1f5f9 0%, #60a5fa 100%);
+                 -webkit-background-clip: text;
+                 -webkit-text-fill-color: transparent;
+                 background-clip: text;
+             }
+
+             :root:not([data-theme="light"]) .main-subtitle {
+                 color: #cbd5e1;
+             }
+
+             :root:not([data-theme="light"]) .section-title {
+                 color: #f1f5f9;
+                 border-bottom: 2px solid #475569;
+             }
+         }
+
+         .load-model-btn {
+             background: linear-gradient(135deg, #f97316 0%, #ea580c 100%) !important;
+             color: white !important;
+             border: none !important;
+             font-weight: 600 !important;
+             padding: 0.75rem 2rem !important;
+             border-radius: 0.5rem !important;
+             box-shadow: 0 4px 6px -1px rgba(249, 115, 22, 0.25) !important;
+             transition: all 0.2s ease !important;
+         }
+
+         .load-model-btn:hover {
+             background: linear-gradient(135deg, #ea580c 0%, #dc2626 100%) !important;
+             transform: translateY(-1px) !important;
+             box-shadow: 0 6px 8px -1px rgba(249, 115, 22, 0.35) !important;
+         }
+         """
+     ) as demo:
+         gr.HTML("""
+             <div class="main-header">
+                 <h1 class="main-title">Token Attention Visualizer</h1>
+                 <p class="main-subtitle">Interactive visualization of attention patterns in Large Language Models</p>
+             </div>
+         """)
+
+         with gr.Row():
+             # Left Panel - Controls
+             with gr.Column(scale=1):
+                 gr.HTML('<h2 class="section-title">Model & Generation</h2>')
+
+                 # Model loading
+                 model_input = gr.Textbox(
+                     label="Model Name",
+                     value=app.config.DEFAULT_MODEL,
+                     placeholder="Enter Hugging Face model name..."
+                 )
+                 load_model_btn = gr.Button("Load Model", variant="primary", elem_classes=["load-model-btn"])
+
+                 model_status = gr.Textbox(
+                     label="Model Status",
+                     value="No model loaded",
+                     interactive=False,
+                     lines=2
+                 )
+
+                 # Generation controls
+                 prompt_input = gr.Textbox(
+                     label="Prompt",
+                     value=app.config.DEFAULT_PROMPT,
+                     lines=3,
+                     placeholder="Enter your prompt here..."
+                 )
+
+                 max_tokens_input = gr.Slider(
+                     minimum=1,
+                     maximum=50,
+                     value=app.config.DEFAULT_MAX_TOKENS,
+                     step=1,
+                     label="Max Tokens"
+                 )
+
+                 temperature_input = gr.Slider(
+                     minimum=0.0,
+                     maximum=2.0,
+                     value=app.config.DEFAULT_TEMPERATURE,
+                     step=0.1,
+                     label="Temperature"
+                 )
+
+                 generate_btn = gr.Button("Generate", variant="primary", size="lg")
+
+                 generated_info = gr.Textbox(
+                     label="Generation Info",
+                     interactive=False,
+                     lines=4
+                 )
+
+                 gr.HTML('<h2 class="section-title">Visualization Controls</h2>')
+
+                 step_slider = gr.Slider(
+                     minimum=0,
+                     maximum=10,
+                     value=0,
+                     step=1,
+                     label="Generation Step",
+                     info="Navigate through generation steps"
+                 )
+
+                 threshold_slider = gr.Slider(
+                     minimum=0.001,
+                     maximum=0.5,
+                     value=0.01,
+                     step=0.001,
+                     label="Attention Threshold",
+                     info="Filter weak connections"
+                 )
+
+                 token_dropdown = gr.Dropdown(
+                     choices=["All tokens"],
+                     value="All tokens",
+                     label="Filter by Token",
+                     info="Select a token to highlight"
+                 )
+
+             # Right Panel - Visualization
+             with gr.Column(scale=2):
+                 gr.HTML('<h2 class="section-title">Attention Visualization</h2>')
+
+                 d3_visualization = gr.HTML(
+                     value="""<div style='height: 700px; display: flex; align-items: center; justify-content: center; font-size: 16px;'>
+                         <div style='text-align: center;'>
+                             <div style='font-size: 3rem; margin-bottom: 16px; opacity: 0.5;'>⚪</div>
+                             <div style='font-weight: 500; margin-bottom: 8px;'>Ready to visualize</div>
+                             <div>Generate text to see attention patterns</div>
+                         </div>
+                     </div>"""
+                 )
+
+         # (Visualization output and overlay removed)
+
+         # Instructions
+         with gr.Accordion("📖 How to Use", open=False):
+             gr.Markdown(
+                 """
+                 ### Instructions:
+                 1. **Load a model** from Hugging Face (default: Llama-3.2-1B)
+                 2. **Enter a prompt** and configure generation settings
+                 3. **Click Generate** to create text and visualize attention
+                 4. **Interact with the visualization:**
+                    - Use the **step slider** to navigate through generation steps
+                    - Adjust the **threshold** to filter weak connections
+                    - Click on **tokens** in the plot to filter their connections
+                    - Click **Reset View** to show all connections
+
+                 ### Understanding the Visualization:
+                 - **Blue lines**: Attention from input to output tokens
+                 - **Orange curves**: Attention between output tokens
+                 - **Line thickness**: Represents attention weight strength
+                 - **Node colors**: Blue = input tokens, Coral = generated tokens
+                 """
+             )
+
+         # Event handlers
+         load_model_btn.click(
+             fn=app.load_model,
+             inputs=[model_input],
+             outputs=[model_status]
+         )
+
+         def _generate(prompt, max_tokens, threshold, temperature):
+             (info,) = app.generate_and_visualize(
+                 prompt, max_tokens, threshold, temperature, "separate"  # Always use separate normalization
+             )
+
+             # Update visualization and dropdown choices
+             max_steps = len(app.current_data['attention_matrices']) - 1 if app.current_data else 0
+             viz_html = app.create_d3_visualization_html(step_idx=max_steps, threshold=0.01)  # Start with last step
+             token_choices = app.get_token_choices()
+
+             return info, viz_html, gr.update(choices=token_choices, value="All tokens"), gr.update(maximum=max_steps, value=max_steps)
+
+         generate_btn.click(
+             fn=_generate,
+             inputs=[
+                 prompt_input,
+                 max_tokens_input,
+                 gr.State(app.config.DEFAULT_THRESHOLD),  # keep threshold in call but unused
+                 temperature_input
+             ],
+             outputs=[generated_info, d3_visualization, token_dropdown, step_slider]
+         )
+
+         # Event handlers for visualization controls
+         def _update_visualization(step_idx, threshold, filter_token="All tokens"):
+             """Update visualization when step or threshold changes."""
+             viz_html = app.create_d3_visualization_html(step_idx=int(step_idx), threshold=threshold, filter_token=filter_token)
+             return viz_html
+
+         def _filter_by_token(selected_token, step_idx, threshold):
+             """Update visualization when the token filter changes."""
+             viz_html = app.create_d3_visualization_html(step_idx=int(step_idx), threshold=threshold, filter_token=selected_token)
+             return viz_html
+
+         # Connect visualization controls
+         step_slider.change(
+             fn=_update_visualization,
+             inputs=[step_slider, threshold_slider, token_dropdown],
+             outputs=[d3_visualization]
+         )
+
+         threshold_slider.change(
+             fn=_update_visualization,
+             inputs=[step_slider, threshold_slider, token_dropdown],
+             outputs=[d3_visualization]
+         )
+
+         token_dropdown.change(
+             fn=_filter_by_token,
+             inputs=[token_dropdown, step_slider, threshold_slider],
+             outputs=[d3_visualization]
+         )
+
+         # Load default model on startup
+         demo.load(
+             fn=app.load_model,
+             inputs=[gr.State(app.config.DEFAULT_MODEL)],
+             outputs=[model_status]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     # Check if CUDA is available
+     if torch.cuda.is_available():
+         print(f"✅ CUDA available: {torch.cuda.get_device_name(0)}")
+     else:
+         print("⚠️ CUDA not available, using CPU")
+
+     # Create and launch the app
+     demo = create_gradio_interface()
+     # demo.launch(
+     #     share=False,            # Set to True for a public URL
+     #     server_name="0.0.0.0",  # Allow external connections
+     #     server_port=7860,       # Default Gradio port
+     #     inbrowser=False         # Don't auto-open the browser
+     # )
+
+     demo.launch()
claude.md ADDED
@@ -0,0 +1,206 @@
+ # Claude Code Instructions - Token Attention Visualizer
+
+ ## Project Overview
+ You are helping to build a Token Attention Visualizer - a web-based tool that visualizes attention weights in Large Language Models (LLMs) during text generation. The tool shows how input tokens influence the generation of output tokens through interactive visualizations.
+
+ ## Core Functionality
+ 1. Accept a text prompt and generate tokens using a Llama model
+ 2. Extract and process attention matrices from the model
+ 3. Create an interactive visualization showing token relationships
+ 4. Allow users to click tokens to filter connections
+ 5. Provide step-by-step navigation through the generation process
+
+ ## Tech Stack
+ - **Backend**: FastAPI
+ - **Frontend**: Gradio (for easy Hugging Face Spaces deployment)
+ - **Visualization**: Plotly (interactive graphs)
+ - **ML**: Transformers, PyTorch
+ - **Models**: Llama models (1B-3B range)
+
+ ## Project Structure
+ ```
+ token-attention-viz/
+ ├── app.py                # Main Gradio application
+ ├── api/
+ │   ├── __init__.py
+ │   ├── server.py         # FastAPI endpoints (optional)
+ │   └── models.py         # Pydantic models
+ ├── core/
+ │   ├── __init__.py
+ │   ├── model_handler.py  # Model loading and generation
+ │   ├── attention.py      # Attention processing
+ │   └── cache.py          # Caching logic
+ ├── visualization/
+ │   ├── __init__.py
+ │   ├── plotly_viz.py     # Plotly visualization
+ │   └── utils.py          # Token cleaning utilities
+ ├── requirements.txt
+ └── config.py             # Configuration settings
+ ```
+
+ ## Implementation Guidelines
+
+ ### Critical Code to Preserve from the Original Implementation
+
+ 1. **Model Loading Logic**:
+    - Device and dtype detection based on GPU capability
+    - Pad token handling for models without one
+    - Error handling for model loading
+
+ 2. **Attention Extraction**:
+    - BOS token removal from visualization
+    - EOS token handling
+    - Attention matrix extraction with proper indexing
+
+ 3. **Token Cleaning Function**:
+    ```python
+    import re
+
+    def clean_label(token):
+        label = str(token)
+        label = label.replace('Ġ', ' ')
+        label = label.replace('▁', ' ')
+        label = label.replace('Ċ', '\\n')
+        label = label.replace('</s>', '[EOS]')
+        label = label.replace('<unk>', '[UNK]')
+        label = label.replace('<|begin_of_text|>', '[BOS]')
+        label = label.replace('<|end_of_text|>', '[EOS]')
+        label = re.sub(r'<0x[0-9A-Fa-f]{2}>', '', label)
+        return label.strip() if label.strip() else "[EMPTY]"
+    ```
+
+ 4. **Attention Processing with Separate Normalization**:
+    - Layer averaging across heads and layers
+    - Separate normalization for input and output attention
+    - Epsilon handling (1e-8) to avoid division by zero
+
+ 5. **Interactive Features**:
+    - Token click handling to show specific connections
+    - Reset selection functionality
+    - Step-by-step navigation
+    - "All Connections" view
+
+ ### Key Implementation Details
+
+ #### Model Handler (`core/model_handler.py`)
+ - Use `unsloth/Llama-3.2-1B-Instruct` as the default model
+ - Implement proper device detection (CUDA if available)
+ - Use bfloat16 for GPUs with compute capability >= 8.0
+ - Generate with `output_attentions=True` and `return_dict_in_generate=True`
+
89
+ #### Attention Processing (`core/attention.py`)
90
+ - Extract attention for each generation step
91
+ - Average across all layers and heads
92
+ - Apply separate normalization (input and output attention normalized independently)
93
+ - Handle edge cases (first token has no output-to-output attention)
94
+
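The separate-normalization step can be sketched like this (a NumPy stand-in for the torch logic; the function name is illustrative):

```python
import numpy as np

def normalize_separately(input_attn, output_attn, eps=1e-8):
    """Normalize input- and output-attention independently; eps avoids /0."""
    norm_in = input_attn / (input_attn.sum() + eps)
    norm_out = None
    if output_attn is not None and output_attn.size > 0:
        norm_out = output_attn / (output_attn.sum() + eps)
    return norm_in, norm_out
```

Each group sums to (approximately) 1 on its own, so the relative ordering within the input tokens survives even when output-to-output attention dominates.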
#### Visualization (`visualization/plotly_viz.py`)
- **Layout**:
  - Input tokens on left (x=0.1)
  - Output tokens on right (x=0.9)
  - Use linspace for y-coordinates
- **Connections**:
  - Blue lines for input→output attention
  - Orange curved lines for output→output attention
  - Line thickness proportional to attention weight
  - Only show connections above threshold
- **Interactivity**:
  - Click on any token to filter connections
  - Highlight selected token in yellow
  - Show previously generated tokens in pink
  - Current generating token in coral

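A minimal sketch of the layout math, assuming the 0-1 plot coordinates above and the MIN/MAX line widths from the configuration; the affine weight-to-width mapping is one reasonable choice, and both function names are illustrative:

```python
import numpy as np

def token_positions(n_input, n_output):
    """Left column for input tokens (x=0.1), right column for output (x=0.9)."""
    input_y = np.linspace(0.9, 0.1, n_input) if n_input > 1 else np.array([0.5])
    output_y = np.linspace(0.9, 0.1, n_output) if n_output > 1 else np.array([0.5])
    return (0.1, input_y), (0.9, output_y)

def line_width(weight, min_w=0.5, max_w=3.0):
    """Scale line thickness with the attention weight, clipped to the bounds."""
    return float(np.clip(min_w + weight * (max_w - min_w), min_w, max_w))
```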
#### Gradio Interface (`app.py`)
- **Input Controls**:
  - Text area for prompt
  - Slider for max tokens (1-50)
  - Slider for attention threshold (0.0-0.2, step 0.001)
- **Visualization Controls**:
  - Step slider for navigation
  - Reset Selection button
  - Show All Connections button
- **Display**:
  - Generated text output
  - Interactive Plotly graph

### Performance Optimizations

1. **Caching**:
   - Cache generated attention matrices by prompt+max_tokens hash
   - LRU cache with configurable size (default 10)
   - Store processed attention, not raw tensors

2. **Lazy Updates**:
   - Only update changed traces when stepping through
   - Don't recreate the entire plot on threshold change
   - Use Plotly's batch_update for multiple changes

3. **Memory Management**:
   - Clear raw attention tensors after processing
   - Convert to CPU tensors for storage
   - Use float32 instead of the original dtype for visualization

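The caching strategy above can be sketched as a key-derivation helper; normalizing the prompt before hashing also addresses the cache-miss issue noted at the end of this document (the helper name is illustrative):

```python
import hashlib

def cache_key(prompt, max_tokens, model_name, temperature=0.7):
    """Hash the generation parameters into a stable cache key."""
    payload = f"{prompt.strip()}_{max_tokens}_{model_name}_{temperature}"
    return hashlib.md5(payload.encode()).hexdigest()
```

Two prompts that differ only in surrounding whitespace map to the same key, while any change to the token budget or model produces a different one.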
### Configuration (`config.py`)
```python
DEFAULT_MODEL = "unsloth/Llama-3.2-1B-Instruct"
DEFAULT_PROMPT = "The old wizard walked through the forest"
DEFAULT_MAX_TOKENS = 20
DEFAULT_THRESHOLD = 0.05
MIN_LINE_WIDTH = 0.5
MAX_LINE_WIDTH = 3.0
PLOT_WIDTH = 1000
PLOT_HEIGHT = 600
```

### Deployment Preparation

For Hugging Face Spaces deployment:
1. Create a proper `requirements.txt` with pinned versions
2. Add a `README.md` with Spaces metadata
3. Ensure model downloads work in the Spaces environment
4. Set appropriate memory/GPU requirements

## Testing Instructions

1. **Basic Functionality**:
   - Test with the default prompt
   - Verify attention matrices are extracted correctly
   - Check the visualization renders properly

2. **Interactive Features**:
   - Click on input tokens - should show only their connections to outputs
   - Click on output tokens - should show incoming connections
   - Reset button should clear the selection
   - Step slider should navigate through generation

3. **Edge Cases**:
   - Empty prompt
   - Single token generation
   - Very long prompts (>100 tokens)
   - High/low threshold values

## Development Workflow

1. Start by implementing the model handler and verify generation works
2. Add attention extraction and processing
3. Create a basic visualization without interactivity
4. Add interactive features one by one
5. Implement caching
6. Create the Gradio interface
7. Test and optimize performance
8. Prepare for deployment

## Important Notes

- Preserve the token cleaning logic exactly, as it handles special tokens
- Keep the BOS token removal logic for a cleaner visualization
- Maintain separate normalization (not joint) for attention weights
- Ensure CUDA memory is properly managed to avoid OOM errors
- Test with different model sizes based on available GPU memory

## Common Issues and Solutions

1. **CUDA OOM**: Reduce the batch size or use a smaller model
2. **Slow Generation**: Enable the GPU, use a smaller model, or implement streaming
3. **Visualization Lag**: Reduce the number of traces, implement virtualization
4. **Cache Misses**: Normalize prompt formatting before hashing

When implementing, prioritize functionality over optimization initially. Get the core visualization working first, then add caching and performance improvements.
config.py ADDED
@@ -0,0 +1,39 @@
from dataclasses import dataclass

@dataclass
class Config:
    # Model settings
    DEFAULT_MODEL: str = "HuggingFaceTB/SmolLM-135M-Instruct"
    DEVICE: str = "cpu"  # Force CPU usage

    # Generation settings
    DEFAULT_MAX_TOKENS: int = 20
    DEFAULT_PROMPT: str = "The old wizard walked through the forest when he"
    DEFAULT_TEMPERATURE: float = 0.7
    DEFAULT_TOP_P: float = 0.95

    # Visualization settings
    DEFAULT_THRESHOLD: float = 0.05
    MIN_LINE_WIDTH: float = 0.5
    MAX_LINE_WIDTH: float = 3.0

    # Colors
    INPUT_COLOR: str = "skyblue"
    OUTPUT_COLOR: str = "coral"
    CONNECTION_COLOR: str = "rgba(128, 128, 128, 0.3)"

    # Cache settings
    CACHE_SIZE: int = 10  # Number of generations to cache

    # UI settings
    PLOT_WIDTH: int = 1000
    PLOT_HEIGHT: int = 600

    # Node settings
    NODE_SIZE: int = 15
    NODE_LINE_WIDTH: float = 2.0

    # Font settings
    FONT_SIZE: int = 10
    FONT_FAMILY: str = "Arial, sans-serif"
core/__init__.py ADDED
File without changes
core/attention.py ADDED
@@ -0,0 +1,279 @@
import torch
from typing import List, Dict, Any

class AttentionProcessor:
    @staticmethod
    def process_attention_separate(
        attention_data: Dict[str, Any],
        input_tokens: List[str],
        output_tokens: List[str]
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Process attention with separate normalization for input and output.
        This preserves the relative importance within each group.
        """
        attentions = attention_data['attentions']
        input_len_for_attention = attention_data['input_len_for_attention']
        output_len = attention_data['output_len']

        if not attentions:
            return [{'input_attention': torch.zeros(input_len_for_attention),
                     'output_attention': None} for _ in range(output_len)]

        attention_matrices = []
        num_steps = len(attentions)

        if num_steps == 0:
            print("Warning: No attention steps found in output.")
            return [{'input_attention': torch.zeros(input_len_for_attention),
                     'output_attention': None} for _ in range(output_len)]

        steps_to_process = min(num_steps, output_len)

        for i in range(steps_to_process):
            step_attentions = attentions[i]
            input_attention_layers = []
            output_attention_layers = []

            for layer_idx, layer_attn in enumerate(step_attentions):
                try:
                    # Extract attention to input tokens (skip BOS token at position 0)
                    input_indices = slice(1, 1 + input_len_for_attention)
                    if layer_attn.shape[3] >= input_indices.stop:
                        # Get attention from current token (position 0 in generation) to input
                        input_attn = layer_attn[0, :, 0, input_indices]
                        input_attention_layers.append(input_attn)

                        # Extract attention to previous output tokens
                        if i > 0:
                            output_indices = slice(1 + input_len_for_attention, 1 + input_len_for_attention + i)
                            if layer_attn.shape[3] >= output_indices.stop:
                                output_attn = layer_attn[0, :, 0, output_indices]
                                output_attention_layers.append(output_attn)
                            else:
                                output_attention_layers.append(
                                    torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                                )
                    else:
                        input_attention_layers.append(
                            torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
                        )
                        if i > 0:
                            output_attention_layers.append(
                                torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                            )

                except Exception as e:
                    print(f"Error processing attention at step {i}, layer {layer_idx}: {e}")
                    input_attention_layers.append(
                        torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
                    )
                    if i > 0:
                        output_attention_layers.append(
                            torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                        )

            # Average across layers and heads
            if input_attention_layers:
                avg_input_attn = torch.mean(torch.stack(input_attention_layers).float(), dim=[0, 1])
            else:
                avg_input_attn = torch.zeros(input_len_for_attention)

            avg_output_attn = None
            if i > 0 and output_attention_layers:
                avg_output_attn = torch.mean(torch.stack(output_attention_layers).float(), dim=[0, 1])
            elif i > 0:
                avg_output_attn = torch.zeros(i)

            # Normalize separately with epsilon for numerical stability
            epsilon = 1e-8
            input_sum = avg_input_attn.sum() + epsilon
            normalized_input_attn = avg_input_attn / input_sum

            normalized_output_attn = None
            if i > 0 and avg_output_attn is not None:
                output_sum = avg_output_attn.sum() + epsilon
                normalized_output_attn = avg_output_attn / output_sum

            attention_matrices.append({
                'input_attention': normalized_input_attn.cpu(),
                'output_attention': normalized_output_attn.cpu() if normalized_output_attn is not None else None,
                'raw_input_attention': avg_input_attn.cpu(),  # Keep raw values for analysis
                'raw_output_attention': avg_output_attn.cpu() if avg_output_attn is not None else None
            })

        # Fill remaining steps with zeros if needed
        while len(attention_matrices) < output_len:
            attention_matrices.append({
                'input_attention': torch.zeros(input_len_for_attention),
                'output_attention': None,
                'raw_input_attention': torch.zeros(input_len_for_attention),
                'raw_output_attention': None
            })

        return attention_matrices

    @staticmethod
    def process_attention_joint(
        attention_data: Dict[str, Any],
        input_tokens: List[str],
        output_tokens: List[str]
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Process attention with joint normalization across input and output.
        This preserves the relative importance across all tokens.
        """
        attentions = attention_data['attentions']
        input_len_for_attention = attention_data['input_len_for_attention']
        output_len = attention_data['output_len']

        if not attentions:
            return [{'input_attention': torch.zeros(input_len_for_attention),
                     'output_attention': None} for _ in range(output_len)]

        attention_matrices = []
        num_steps = len(attentions)

        if num_steps == 0:
            print("Warning: No attention steps found in output.")
            return [{'input_attention': torch.zeros(input_len_for_attention),
                     'output_attention': None} for _ in range(output_len)]

        steps_to_process = min(num_steps, output_len)

        for i in range(steps_to_process):
            step_attentions = attentions[i]
            input_attention_layers = []
            output_attention_layers = []

            for layer_idx, layer_attn in enumerate(step_attentions):
                try:
                    # Extract attention to input tokens
                    input_indices = slice(1, 1 + input_len_for_attention)
                    if layer_attn.shape[3] >= input_indices.stop:
                        input_attn = layer_attn[0, :, 0, input_indices]
                        input_attention_layers.append(input_attn)

                        # Extract attention to previous output tokens
                        if i > 0:
                            output_indices = slice(1 + input_len_for_attention, 1 + input_len_for_attention + i)
                            if layer_attn.shape[3] >= output_indices.stop:
                                output_attn = layer_attn[0, :, 0, output_indices]
                                output_attention_layers.append(output_attn)
                            else:
                                output_attention_layers.append(
                                    torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                                )
                    else:
                        input_attention_layers.append(
                            torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
                        )
                        if i > 0:
                            output_attention_layers.append(
                                torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                            )

                except Exception as e:
                    print(f"Error processing attention at step {i}, layer {layer_idx}: {e}")
                    input_attention_layers.append(
                        torch.zeros((layer_attn.shape[1], input_len_for_attention), device=layer_attn.device)
                    )
                    if i > 0:
                        output_attention_layers.append(
                            torch.zeros((layer_attn.shape[1], i), device=layer_attn.device)
                        )

            # Average across layers and heads
            if input_attention_layers:
                avg_input_attn = torch.mean(torch.stack(input_attention_layers).float(), dim=[0, 1])
            else:
                avg_input_attn = torch.zeros(input_len_for_attention)

            avg_output_attn = None
            if i > 0 and output_attention_layers:
                avg_output_attn = torch.mean(torch.stack(output_attention_layers).float(), dim=[0, 1])
            elif i > 0:
                avg_output_attn = torch.zeros(i)

            # Joint normalization
            epsilon = 1e-8
            if i > 0 and avg_output_attn is not None:
                # Concatenate and normalize together
                combined_attn = torch.cat([avg_input_attn, avg_output_attn])
                sum_attn = combined_attn.sum() + epsilon
                normalized_combined = combined_attn / sum_attn
                normalized_input_attn = normalized_combined[:input_len_for_attention]
                normalized_output_attn = normalized_combined[input_len_for_attention:]
            else:
                # Only input attention available
                sum_attn = avg_input_attn.sum() + epsilon
                normalized_input_attn = avg_input_attn / sum_attn
                normalized_output_attn = None

            attention_matrices.append({
                'input_attention': normalized_input_attn.cpu(),
                'output_attention': normalized_output_attn.cpu() if normalized_output_attn is not None else None
            })

        # Fill remaining steps with zeros if needed
        while len(attention_matrices) < output_len:
            attention_matrices.append({
                'input_attention': torch.zeros(input_len_for_attention),
                'output_attention': None
            })

        return attention_matrices

    @staticmethod
    def extract_attention_for_step(
        attention_data: Dict[str, Any],
        step: int,
        input_len: int
    ) -> Dict[str, torch.Tensor]:
        """
        Extract attention weights for a specific generation step.
        Optimized to only process the needed step.
        """
        attentions = attention_data['attentions']

        if step >= len(attentions):
            return {
                'input_attention': torch.zeros(input_len),
                'output_attention': None
            }

        step_attentions = attentions[step]
        input_attention_layers = []
        output_attention_layers = []

        for layer_attn in step_attentions:
            # Extract input attention
            input_indices = slice(1, 1 + input_len)
            if layer_attn.shape[3] >= input_indices.stop:
                input_attn = layer_attn[0, :, 0, input_indices]
                input_attention_layers.append(input_attn)

            # Extract output attention if there are previous outputs
            if step > 0:
                output_indices = slice(1 + input_len, 1 + input_len + step)
                if layer_attn.shape[3] >= output_indices.stop:
                    output_attn = layer_attn[0, :, 0, output_indices]
                    output_attention_layers.append(output_attn)

        # Average and normalize
        if input_attention_layers:
            avg_input = torch.mean(torch.stack(input_attention_layers).float(), dim=[0, 1])
            normalized_input = avg_input / (avg_input.sum() + 1e-8)
        else:
            normalized_input = torch.zeros(input_len)

        normalized_output = None
        if step > 0 and output_attention_layers:
            avg_output = torch.mean(torch.stack(output_attention_layers).float(), dim=[0, 1])
            normalized_output = avg_output / (avg_output.sum() + 1e-8)

        return {
            'input_attention': normalized_input.cpu(),
            'output_attention': normalized_output.cpu() if normalized_output is not None else None
        }
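To see how the joint variant differs from the separate one, its normalization step can be reduced to a NumPy sketch (a stand-in for the torch code above; `normalize_joint` is an illustrative name):

```python
import numpy as np

def normalize_joint(input_attn, output_attn, eps=1e-8):
    """Jointly normalize input and output attention so they sum to 1 together."""
    if output_attn is None:
        s = input_attn.sum() + eps
        return input_attn / s, None
    combined = np.concatenate([input_attn, output_attn])
    combined = combined / (combined.sum() + eps)
    # Split back into the two groups after the shared normalization
    return combined[:len(input_attn)], combined[len(input_attn):]
```

Here the two groups share one probability mass, so a heavily attended output token shrinks the input weights, which is exactly what the separate scheme avoids.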
core/cache.py ADDED
@@ -0,0 +1,90 @@
from typing import Dict, Any, Optional
import hashlib
import io
import pickle

import numpy as np
import torch

class AttentionCache:
    def __init__(self, max_size: int = 10):
        self.cache = {}
        self.access_order = []
        self.max_size = max_size

    def get_key(self, prompt: str, max_tokens: int, model: str, temperature: float = 0.7) -> str:
        """Generate cache key from parameters"""
        data = f"{prompt}_{max_tokens}_{model}_{temperature}"
        return hashlib.md5(data.encode()).hexdigest()

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Retrieve cached data"""
        if key in self.cache:
            # Move to end (LRU)
            self.access_order.remove(key)
            self.access_order.append(key)
            return self._deserialize(self.cache[key])
        return None

    def set(self, key: str, data: Dict[str, Any]):
        """Store data in cache"""
        if key in self.cache:
            # Overwriting an existing entry: just refresh its recency
            self.access_order.remove(key)
        elif len(self.cache) >= self.max_size:
            # Remove least recently used
            oldest = self.access_order.pop(0)
            del self.cache[oldest]

        self.cache[key] = self._serialize(data)
        self.access_order.append(key)

    def _serialize(self, data: Dict[str, Any]) -> bytes:
        """Serialize data for caching, handling torch tensors"""
        serialized = {}
        for key, value in data.items():
            if isinstance(value, list) and len(value) > 0:
                # Check if it's a list of dicts with tensors (attention matrices)
                if isinstance(value[0], dict) and any(isinstance(v, torch.Tensor) for v in value[0].values()):
                    # Convert tensors to CPU numpy arrays and serialize
                    serialized_list = []
                    for item in value:
                        serialized_item = {}
                        for k, v in item.items():
                            if isinstance(v, torch.Tensor):
                                serialized_item[k] = v.cpu().numpy()
                            else:
                                serialized_item[k] = v
                        serialized_list.append(serialized_item)
                    serialized[key] = serialized_list
                else:
                    serialized[key] = value
            else:
                serialized[key] = value

        buffer = io.BytesIO()
        pickle.dump(serialized, buffer)
        return buffer.getvalue()

    def _deserialize(self, data: bytes) -> Dict[str, Any]:
        """Deserialize data from cache, restoring torch tensors"""
        buffer = io.BytesIO(data)
        deserialized = pickle.load(buffer)

        # Convert numpy arrays back to tensors where needed
        for key, value in deserialized.items():
            if isinstance(value, list) and len(value) > 0:
                if isinstance(value[0], dict):
                    # Restore numpy arrays (formerly tensors) to torch tensors
                    for item in value:
                        for k, v in item.items():
                            if isinstance(v, np.ndarray):
                                item[k] = torch.from_numpy(v)

        return deserialized

    def clear(self):
        """Clear the entire cache"""
        self.cache.clear()
        self.access_order.clear()

    def size(self) -> int:
        """Get current cache size"""
        return len(self.cache)
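The LRU bookkeeping in `get`/`set` can be exercised in isolation with a minimal pure-Python mirror (illustrative; the real class additionally serializes tensors, which is omitted here):

```python
class MiniLRU:
    """Minimal mirror of AttentionCache's eviction bookkeeping."""
    def __init__(self, max_size=2):
        self.cache, self.order, self.max_size = {}, [], max_size

    def get(self, key):
        if key not in self.cache:
            return None
        self.order.remove(key)
        self.order.append(key)  # mark as most recently used
        return self.cache[key]

    def set(self, key, value):
        if key in self.cache:
            self.order.remove(key)  # overwrite: refresh recency only
        elif len(self.cache) >= self.max_size:
            oldest = self.order.pop(0)  # evict least recently used
            del self.cache[oldest]
        self.cache[key] = value
        self.order.append(key)
```

With `max_size=2`, touching `"a"` before inserting `"c"` means `"b"` is the entry that gets evicted.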
core/model_handler.py ADDED
@@ -0,0 +1,187 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Tuple, Optional, List, Dict, Any
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module='transformers.generation')

class ModelHandler:
    def __init__(self, model_name: str = None, config=None):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_name = model_name
        self.config = config

    def load_model(self, model_name: str = None) -> Tuple[bool, str]:
        """Load model with optimized settings"""
        if model_name:
            self.model_name = model_name

        if not self.model_name:
            return False, "No model name provided"

        try:
            print(f"Loading model: {self.model_name}...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            # Determine device and dtype
            if self.config and hasattr(self.config, 'DEVICE'):
                self.device = self.config.DEVICE
                # If config specifies CPU, force it even if CUDA is available
                if self.device == "cpu":
                    print("Forcing CPU usage as specified in config")
                elif self.device == "cuda" and not torch.cuda.is_available():
                    print("CUDA requested but not available, falling back to CPU")
                    self.device = "cpu"
            else:
                # Fall back to auto-detection if no config is provided
                self.device = "cuda" if torch.cuda.is_available() else "cpu"

            # Use bfloat16 for Ampere GPUs (compute capability >= 8.0), otherwise float32
            if self.device == "cuda" and torch.cuda.is_available():
                capability = torch.cuda.get_device_capability()
                if capability[0] >= 8:
                    dtype = torch.bfloat16
                else:
                    dtype = torch.float32
            else:
                dtype = torch.float32

            # Load model
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=dtype,
                    attn_implementation="eager"  # Force eager attention for attention extraction
                ).to(self.device)
                print(f"Model loaded on {self.device} with dtype {dtype} (eager attention)")
            except Exception as e:
                print(f"Error loading model with specific dtype: {e}")
                print("Attempting to load without specific dtype...")
                try:
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        attn_implementation="eager"
                    ).to(self.device)
                    print(f"Model loaded on {self.device} (default dtype, eager attention)")
                except Exception as e2:
                    print(f"Error with eager attention: {e2}")
                    print("Loading with default settings...")
                    self.model = AutoModelForCausalLM.from_pretrained(self.model_name).to(self.device)
                    print(f"Model loaded on {self.device} (default settings)")

            # Handle pad token
            if self.tokenizer.pad_token is None:
                if self.tokenizer.eos_token:
                    print("Setting pad_token to eos_token")
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                    if hasattr(self.model.config, 'pad_token_id') and self.model.config.pad_token_id is None:
                        self.model.config.pad_token_id = self.tokenizer.eos_token_id
                else:
                    print("Warning: No eos_token found to set as pad_token.")

            return True, f"Model loaded successfully on {self.device}"

        except Exception as e:
            return False, f"Error loading model: {str(e)}"

    def generate_with_attention(
        self,
        prompt: str,
        max_tokens: int = 30,
        temperature: float = 0.7,
        top_p: float = 0.95
    ) -> Tuple[Optional[Dict], List[str], List[str], str]:
        """
        Generate text and capture attention weights.
        Returns: (attention_data, output_tokens, input_tokens, generated_text)
        """
        if not self.model or not self.tokenizer:
            return None, [], [], "Model not loaded"

        # Encode input
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        input_len_raw = input_ids.shape[1]

        print(f"Generating with input length: {input_len_raw}, max_new_tokens: {max_tokens}")

        # Generate with attention
        with torch.no_grad():
            attention_mask = torch.ones_like(input_ids)
            gen_kwargs = {
                "attention_mask": attention_mask,
                "max_new_tokens": max_tokens,
                "output_attentions": True,
                "return_dict_in_generate": True,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": temperature > 0
            }

            if self.tokenizer.pad_token_id is not None:
                gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id

            try:
                output = self.model.generate(input_ids, **gen_kwargs)
            except Exception as e:
                print(f"Error during generation: {e}")
                return None, [], [], f"Error during generation: {str(e)}"

        # Extract generated tokens
        full_sequence = output.sequences[0]
        if full_sequence.shape[0] > input_len_raw:
            generated_ids = full_sequence[input_len_raw:]
        else:
            generated_ids = torch.tensor([], dtype=torch.long, device=self.device)

        # Convert to tokens
        output_tokens = self.tokenizer.convert_ids_to_tokens(generated_ids, skip_special_tokens=False)
        input_tokens_raw = self.tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)

        # Handle BOS token removal from visualization
        input_tokens = input_tokens_raw
        input_len_for_attention = input_len_raw
        bos_token = self.tokenizer.bos_token or '<|begin_of_text|>'

        if input_tokens_raw and input_tokens_raw[0] == bos_token:
            input_tokens = input_tokens_raw[1:]
            input_len_for_attention = input_len_raw - 1

        # Handle EOS token removal
        eos_token = self.tokenizer.eos_token or '<|end_of_text|>'
        if output_tokens and output_tokens[-1] == eos_token:
            output_tokens = output_tokens[:-1]
            generated_ids = generated_ids[:-1]

        # Decode generated text
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Extract attention weights
        attentions = getattr(output, 'attentions', None)
        if attentions is None:
            print("Warning: 'attentions' not found in model output. Cannot visualize attention.")
            return None, output_tokens, input_tokens, generated_text

        # Return raw attention, tokens, and metadata
        return {
            'attentions': attentions,
            'input_len_for_attention': input_len_for_attention,
            'output_len': len(output_tokens)
        }, output_tokens, input_tokens, generated_text

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        if not self.model:
            return {"loaded": False}

        return {
            "loaded": True,
            "model_name": self.model_name,
            "device": str(self.device),
            "num_parameters": sum(p.numel() for p in self.model.parameters()),
            "dtype": str(next(self.model.parameters()).dtype),
            "vocab_size": self.tokenizer.vocab_size if self.tokenizer else 0
        }
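The BOS/EOS trimming in `generate_with_attention` can be isolated into a small helper for testing (the helper name is illustrative; the default token strings match the Llama 3 conventions used in the file):

```python
def trim_special_tokens(input_tokens, output_tokens,
                        bos="<|begin_of_text|>", eos="<|end_of_text|>"):
    """Drop a leading BOS from the input and a trailing EOS from the output."""
    if input_tokens and input_tokens[0] == bos:
        input_tokens = input_tokens[1:]
    if output_tokens and output_tokens[-1] == eos:
        output_tokens = output_tokens[:-1]
    return input_tokens, output_tokens
```

Sequences without the special tokens pass through unchanged, which is why the visualization code can apply this unconditionally.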
requirements.txt ADDED
@@ -0,0 +1,8 @@
torch>=2.0.0
transformers>=4.30.0
gradio>=4.0.0
plotly>=5.14.0
numpy>=1.24.0
accelerate>=0.20.0
sentencepiece>=0.1.99
protobuf>=3.20.0
visualization/__init__.py ADDED
File without changes
visualization/d3_viz.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ D3.js visualization module for interactive token attention visualization.
3
+ """
4
+
5
+ def create_d3_visualization(data):
6
+ """
7
+ Generate a complete, self-contained HTML string with embedded D3.js visualization.
8
+
9
+ Args:
10
+ data (dict): JSON structure with nodes and links from prepare_d3_data()
11
+
12
+ Returns:
13
+ str: Complete HTML string with embedded D3.js, CSS, and JavaScript
14
+ """
15
+
16
+ # Get nodes by type
17
+ input_nodes = [node for node in data.get('nodes', []) if node.get('type') == 'input']
18
+ output_nodes = [node for node in data.get('nodes', []) if node.get('type') == 'output']
19
+ links = data.get('links', [])
20
+
21
+ # SVG dimensions
22
+ width = 800
23
+ height = max(400, max(len(input_nodes), len(output_nodes)) * 50 + 100)
24
+
25
+ # Calculate positions
26
+ input_x = 100
27
+ output_x = width - 100
28
+
29
+ # Position nodes vertically
30
+ def get_y_pos(index, total):
31
+ if total <= 1:
32
+ return height // 2
33
+ return 80 + (index * (height - 160)) / (total - 1)
34
+
35
+ # Start building SVG
36
+ svg_html = f"""
37
+ <div style='display: flex; flex-direction: column; align-items: center; border: 1px solid #ddd; padding: 20px; margin: 10px; background: white; border-radius: 8px;'>
38
+ <div style='text-align: center; margin-bottom: 15px;'>
39
+ <h3 style='margin: 0; color: #333;'>Token Attention Visualization</h3>
40
+ <p style='margin: 5px 0; color: #666;'>Step {data.get('step', 0) + 1} | {len(input_nodes)} input → {len(output_nodes)} output | {len(links)} connections</p>
41
+ </div>
42
+
43
+ <svg width="{width}" height="{height}" style='border: 1px solid #eee; background: #fafafa; display: block;'>
44
+ <!-- Background grid -->
45
+ <defs>
46
+ <pattern id="grid" width="20" height="20" patternUnits="userSpaceOnUse">
47
+ <path d="M 20 0 L 0 0 0 20" fill="none" stroke="#f0f0f0" stroke-width="1"/>
48
+ </pattern>
49
+ </defs>
50
+ <rect width="100%" height="100%" fill="url(#grid)" />
51
+
52
+ <!-- Column headers -->
53
+ <text x="{input_x}" y="30" text-anchor="middle" font-size="16" font-weight="bold" fill="#4285f4">Input Tokens</text>
54
+ <text x="{output_x}" y="30" text-anchor="middle" font-size="16" font-weight="bold" fill="#ea4335">Output Tokens</text>
55
+ """
56
+
+    # Draw connections first (so they appear behind nodes)
+    for link in links:
+        # Find source and target nodes
+        source_node = next((n for n in input_nodes + output_nodes if n['id'] == link['source']), None)
+        target_node = next((n for n in input_nodes + output_nodes if n['id'] == link['target']), None)
+
+        if source_node and target_node:
+            # Get positions
+            if source_node['type'] == 'input':
+                source_idx = next((i for i, n in enumerate(input_nodes) if n['id'] == source_node['id']), 0)
+                source_y = get_y_pos(source_idx, len(input_nodes))
+                source_x_pos = input_x + 20  # Offset from center of node
+            else:
+                source_idx = next((i for i, n in enumerate(output_nodes) if n['id'] == source_node['id']), 0)
+                source_y = get_y_pos(source_idx, len(output_nodes))
+                source_x_pos = output_x - 20
+
+            if target_node['type'] == 'input':
+                target_idx = next((i for i, n in enumerate(input_nodes) if n['id'] == target_node['id']), 0)
+                target_y = get_y_pos(target_idx, len(input_nodes))
+                target_x_pos = input_x - 20
+            else:
+                target_idx = next((i for i, n in enumerate(output_nodes) if n['id'] == target_node['id']), 0)
+                target_y = get_y_pos(target_idx, len(output_nodes))
+                target_x_pos = output_x - 20
+
+            # Line properties based on weight
+            stroke_width = max(1, min(8, link['weight'] * 20))
+            opacity = max(0.3, min(1.0, link['weight'] * 2))
+            color = "#4285f4" if link['type'] == 'input_to_output' else "#ea4335"
+
+            # Create straight line
+            svg_html += f'''
+            <line x1="{source_x_pos}" y1="{source_y}" x2="{target_x_pos}" y2="{target_y}"
+                  stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>
+            '''
+
+    # Draw input nodes
+    for i, node in enumerate(input_nodes):
+        y = get_y_pos(i, len(input_nodes))
+        token_text = node['token']
+
+        # Clean token text - remove special prefix characters
+        if token_text.startswith('Ġ'):
+            token_text = token_text[1:]  # Remove Ġ prefix
+        if token_text.startswith('▁'):
+            token_text = token_text[1:]  # Remove ▁ prefix (SentencePiece)
+        if token_text.startswith('##'):
+            token_text = token_text[2:]  # Remove ## prefix (BERT subwords)
+
+        if len(token_text) > 15:
+            token_text = token_text[:13] + "..."
+
+        svg_html += f'''
+        <g>
+            <circle cx="{input_x}" cy="{y}" r="12" fill="#4285f4" stroke="#1a73e8" stroke-width="2" opacity="0.9"/>
+            <text x="{input_x - 20}" y="{y + 4}" text-anchor="end" font-size="12" fill="#333" font-weight="bold">{token_text}</text>
+        </g>
+        '''
+
+    # Draw output nodes
+    for i, node in enumerate(output_nodes):
+        y = get_y_pos(i, len(output_nodes))
+        token_text = node['token']
+
+        # Clean token text - remove special prefix characters
+        if token_text.startswith('Ġ'):
+            token_text = token_text[1:]  # Remove Ġ prefix
+        if token_text.startswith('▁'):
+            token_text = token_text[1:]  # Remove ▁ prefix (SentencePiece)
+        if token_text.startswith('##'):
+            token_text = token_text[2:]  # Remove ## prefix (BERT subwords)
+
+        if len(token_text) > 15:
+            token_text = token_text[:13] + "..."
+
+        svg_html += f'''
+        <g>
+            <circle cx="{output_x}" cy="{y}" r="12" fill="#ea4335" stroke="#d33b2c" stroke-width="2" opacity="0.9"/>
+            <text x="{output_x + 20}" y="{y + 4}" text-anchor="start" font-size="12" fill="#333" font-weight="bold">{token_text}</text>
+        </g>
+        '''
+
+    # Close SVG and add legend
+    svg_html += '''
+    </svg>
+
+    <div style='margin-top: 20px; padding: 16px; background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px;'>
+        <div style='display: flex; justify-content: center; align-items: center; gap: 32px; font-size: 12px; color: #64748b; font-family: Inter, sans-serif;'>
+            <div style='display: flex; align-items: center; gap: 8px;'>
+                <div style='width: 16px; height: 2px; background: #4285f4; border-radius: 1px;'></div>
+                <span style='color: #1e293b; font-weight: 500;'>Input → Output</span>
+            </div>
+            <div style='display: flex; align-items: center; gap: 8px;'>
+                <div style='display: flex; gap: 2px;'>
+                    <div style='width: 8px; height: 1px; background: #64748b;'></div>
+                    <div style='width: 8px; height: 2px; background: #64748b;'></div>
+                    <div style='width: 8px; height: 3px; background: #64748b;'></div>
+                </div>
+                <span style='color: #1e293b; font-weight: 500;'>Line thickness = weight</span>
+            </div>
+        </div>
+    </div>
+    </div>
+    '''
+
+    return svg_html
+
+def create_d3_visualization_old(data):
+    """
+    OLD VERSION - Generate a complete, self-contained HTML string with embedded D3.js visualization.
+
+    Args:
+        data (dict): JSON structure with nodes and links from prepare_d3_data()
+
+    Returns:
+        str: Complete HTML string with embedded D3.js, CSS, and JavaScript
+    """
+
+    html_template = f"""
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <meta charset="utf-8">
+        <style>
+            .visualization-container {{
+                width: 100%;
+                height: 600px;
+                border: 1px solid #ddd;
+                border-radius: 8px;
+                background: #fafafa;
+                position: relative;
+                overflow: hidden;
+            }}
+
+            .node {{
+                cursor: pointer;
+                stroke-width: 2px;
+            }}
+
+            .node.input {{
+                fill: #4285f4;
+                stroke: #1a73e8;
+            }}
+
+            .node.output {{
+                fill: #ea4335;
+                stroke: #d33b2c;
+            }}
+
+            .node.highlighted {{
+                stroke-width: 4px;
+                stroke: #ff6d00;
+            }}
+
+            .node.dimmed {{
+                opacity: 0.3;
+            }}
+
+            .link {{
+                stroke: #666;
+                stroke-opacity: 0.6;
+                fill: none;
+            }}
+
+            .link.input-to-output {{
+                stroke: #4285f4;
+            }}
+
+            .link.output-to-output {{
+                stroke: #ea4335;
+            }}
+
+            .link.highlighted {{
+                stroke-opacity: 1;
+                stroke-width: 3px;
+            }}
+
+            .link.dimmed {{
+                stroke-opacity: 0.1;
+            }}
+
+            .token-label {{
+                font-family: 'Courier New', monospace;
+                font-size: 12px;
+                text-anchor: middle;
+                dominant-baseline: central;
+                fill: white;
+                font-weight: bold;
+                pointer-events: none;
+            }}
+
+            .reset-btn {{
+                position: absolute;
+                top: 10px;
+                right: 10px;
+                padding: 8px 16px;
+                background: #4285f4;
+                color: white;
+                border: none;
+                border-radius: 4px;
+                cursor: pointer;
+                font-size: 12px;
+                z-index: 100;
+            }}
+
+            .reset-btn:hover {{
+                background: #1a73e8;
+            }}
+
+            .info-panel {{
+                position: absolute;
+                bottom: 10px;
+                left: 10px;
+                background: rgba(255, 255, 255, 0.9);
+                padding: 8px 12px;
+                border-radius: 4px;
+                font-size: 11px;
+                font-family: Arial, sans-serif;
+                border: 1px solid #ddd;
+            }}
+        </style>
+    </head>
+    <body>
+        <div class="visualization-container" id="viz-container">
+            <button class="reset-btn" onclick="resetView()">Reset View</button>
+            <div class="info-panel">
+                <div>Step: {data.get('step', 0) + 1} / {data.get('total_steps', 1)}</div>
+                <div>Nodes: {len(data.get('nodes', []))} | Links: {len(data.get('links', []))}</div>
+                <div>Click nodes to filter connections</div>
+            </div>
+            <svg id="visualization"></svg>
+        </div>
+
+    <script>
+        // Simple visualization without D3 first - just to test.
+        // json.dumps (not repr) is required so True/False/None serialize as
+        // valid JavaScript literals; `import json` must be present at module top.
+        const data = {json.dumps(data)};
+
+        // Stub so the Reset View button does not throw; re-renders the static SVG.
+        function resetView() {{ container.innerHTML = html; }}
+
+        // Create simple HTML visualization
+        const container = document.getElementById("viz-container");
+        let html = "<div style='padding: 20px;'>";
+        html += "<h3>Debug Info</h3>";
+        html += "<p>Nodes: " + data.nodes.length + "</p>";
+        html += "<p>Links: " + data.links.length + "</p>";
+
+        // Simple SVG without D3
+        html += "<svg width='800' height='400' style='border: 1px solid #ccc; background: white;'>";
+
+        // Draw input nodes (left side)
+        const inputNodes = data.nodes.filter(n => n.type === "input");
+        const outputNodes = data.nodes.filter(n => n.type === "output");
+
+        inputNodes.forEach((node, i) => {{
+            const y = 50 + i * 40;
+            html += `<circle cx="50" cy="${{y}}" r="15" fill="#4285f4" stroke="#1a73e8" stroke-width="2"/>`;
+            html += `<text x="80" y="${{y + 5}}" font-size="12" fill="black">${{node.token}}</text>`;
+        }});
+
+        // Draw output nodes (right side)
+        outputNodes.forEach((node, i) => {{
+            const y = 50 + i * 40;
+            html += `<circle cx="700" cy="${{y}}" r="15" fill="#ea4335" stroke="#d33b2c" stroke-width="2"/>`;
+            html += `<text x="620" y="${{y + 5}}" font-size="12" fill="black" text-anchor="end">${{node.token}}</text>`;
+        }});
+
+        // Draw links
+        data.links.forEach(link => {{
+            const sourceNode = data.nodes.find(n => n.id === link.source);
+            const targetNode = data.nodes.find(n => n.id === link.target);
+            if (sourceNode && targetNode) {{
+                const sourceIdx = sourceNode.type === "input" ?
+                    inputNodes.findIndex(n => n.id === sourceNode.id) :
+                    outputNodes.findIndex(n => n.id === sourceNode.id);
+                const targetIdx = targetNode.type === "input" ?
+                    inputNodes.findIndex(n => n.id === targetNode.id) :
+                    outputNodes.findIndex(n => n.id === targetNode.id);
+
+                const sourceX = sourceNode.type === "input" ? 65 : 685;
+                const targetX = targetNode.type === "input" ? 65 : 685;
+                const sourceY = 50 + sourceIdx * 40;
+                const targetY = 50 + targetIdx * 40;
+
+                const strokeWidth = Math.max(1, link.weight * 10);
+                const color = link.type === "input_to_output" ? "#4285f4" : "#ea4335";
+
+                html += `<line x1="${{sourceX}}" y1="${{sourceY}}" x2="${{targetX}}" y2="${{targetY}}" stroke="${{color}}" stroke-width="${{strokeWidth}}" opacity="0.6"/>`;
+            }}
+        }});
+
+        html += "</svg>";
+        html += "</div>";
+
+        container.innerHTML = html;
+    </script>
+    </body>
+    </html>
+    """
+
+    return html_template
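The SVG renderer above styles each connection from its attention weight: the stroke width clamps `weight * 20` into `[1, 8]` and the opacity clamps `weight * 2` into `[0.3, 1.0]`. A minimal standalone sketch of that mapping (the function name is illustrative, not from the commit):

```python
def stroke_style(weight):
    """Map an attention weight (0-1) to (stroke_width, opacity) as in the SVG renderer."""
    stroke_width = max(1, min(8, weight * 20))   # clamp into [1, 8]
    opacity = max(0.3, min(1.0, weight * 2))     # clamp into [0.3, 1.0]
    return stroke_width, opacity

print(stroke_style(0.06))  # (1.2, 0.3)
print(stroke_style(0.5))   # (8, 1.0)
```

Weights at or above 0.4 saturate both width and opacity, so only the low-weight range is visually distinguished.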
visualization/plotly_viz.py ADDED
@@ -0,0 +1,443 @@
+import plotly.graph_objects as go
+import numpy as np
+from typing import List, Dict, Any, Optional, Tuple, Callable
+from .utils import (
+    clean_label, scale_weight_to_width, scale_weight_to_opacity,
+    get_node_positions, create_spline_path, format_attention_text,
+    get_color_for_weight, truncate_token_label
+)
+
+class AttentionVisualizer:
+    def __init__(self, config):
+        self.config = config
+        self.current_state = {
+            'selected_token': None,
+            'selected_type': None,
+            'current_step': 0,
+            'show_all': True
+        }
+        self.traces_info = {
+            'input_to_output': [],
+            'output_to_output': [],
+            'input_nodes_idx': None,
+            'output_nodes_idx': None
+        }
+
+    def create_interactive_plot(
+        self,
+        input_tokens: List[str],
+        output_tokens: List[str],
+        attention_matrices: List[Dict],
+        threshold: float = 0.05,
+        initial_step: int = 0,
+        normalization: str = "separate"
+    ) -> go.Figure:
+        """
+        Create the main interactive visualization.
+        """
+        # Clean labels
+        input_labels = [clean_label(token) for token in input_tokens]
+        output_labels = [clean_label(token) for token in output_tokens]
+
+        num_input = len(input_labels)
+        num_output = len(output_labels)
+        num_steps = len(attention_matrices)
+
+        if num_input == 0 or num_output == 0 or num_steps == 0:
+            return self._create_empty_figure("No data to visualize")
+
+        # Get node positions
+        input_x, input_y, output_x, output_y = get_node_positions(num_input, num_output)
+
+        # Create connection traces
+        traces = []
+        self.traces_info = {
+            'input_to_output': [],
+            'output_to_output': [],
+            'input_nodes_idx': None,
+            'output_nodes_idx': None
+        }
+
+        # Input to output connections
+        for j in range(num_output):
+            for i in range(num_input):
+                weight = 0
+                if j < len(attention_matrices):
+                    weight = attention_matrices[j]['input_attention'][i].item()
+
+                opacity = scale_weight_to_opacity(weight, threshold=threshold)
+                width = scale_weight_to_width(weight) if opacity > 0 else 0.5
+
+                trace = go.Scatter(
+                    x=[input_x[i], output_x[j]],
+                    y=[input_y[i], output_y[j]],
+                    mode="lines",
+                    line=dict(
+                        color=get_color_for_weight(weight, "blue"),
+                        width=width
+                    ),
+                    opacity=opacity,
+                    showlegend=False,
+                    hoverinfo='text',
+                    text=format_attention_text(input_labels[i], output_labels[j], weight),
+                    hoverlabel=dict(bgcolor="lightskyblue", bordercolor="darkblue"),
+                    name=f"in_to_out_{i}_{j}",
+                    customdata=[(i, j)],
+                    hovertemplate="Input→Output %{customdata[0]}→%{customdata[1]}<extra></extra>"
+                )
+                traces.append(trace)
+                self.traces_info['input_to_output'].append({
+                    'input_idx': i,
+                    'output_idx': j,
+                    'trace_idx': len(traces) - 1
+                })
+
+        # Output to output connections
+        for j in range(1, num_output):
+            for i in range(j):
+                weight = 0
+                if j < len(attention_matrices) and attention_matrices[j]['output_attention'] is not None:
+                    if i < len(attention_matrices[j]['output_attention']):
+                        weight = attention_matrices[j]['output_attention'][i].item()
+
+                opacity = scale_weight_to_opacity(weight, threshold=threshold)
+                width = scale_weight_to_width(weight) if opacity > 0 else 0.5
+
+                # Create spline path for curved connection
+                path_x, path_y = create_spline_path(
+                    output_x[i], output_y[i],
+                    output_x[j], output_y[j],
+                    control_offset=0.15
+                )
+
+                trace = go.Scatter(
+                    x=path_x,
+                    y=path_y,
+                    mode="lines",
+                    line=dict(
+                        color=get_color_for_weight(weight, "orange"),
+                        width=width,
+                        shape='spline'
+                    ),
+                    opacity=opacity,
+                    showlegend=False,
+                    hoverinfo='text',
+                    text=format_attention_text(output_labels[i], output_labels[j], weight),
+                    hoverlabel=dict(bgcolor="moccasin", bordercolor="darkorange"),
+                    name=f"out_to_out_{i}_{j}"
+                )
+                traces.append(trace)
+                self.traces_info['output_to_output'].append({
+                    'from_idx': i,
+                    'to_idx': j,
+                    'trace_idx': len(traces) - 1
+                })
+
+        # Input nodes
+        input_trace = go.Scatter(
+            x=input_x,
+            y=input_y,
+            mode="markers+text",
+            marker=dict(
+                size=self.config.NODE_SIZE,
+                color=self.config.INPUT_COLOR,
+                line=dict(width=self.config.NODE_LINE_WIDTH, color="darkblue")
+            ),
+            selected=dict(
+                marker=dict(
+                    size=self.config.NODE_SIZE + 6,
+                    color="rgba(0, 0, 200, 0.9)"
+                )
+            ),
+            unselected=dict(
+                marker=dict(
+                    opacity=0.65
+                )
+            ),
+            text=[truncate_token_label(label) for label in input_labels],
+            textfont=dict(size=self.config.FONT_SIZE, family=self.config.FONT_FAMILY),
+            textposition="middle left",
+            name="Input Tokens",
+            hovertemplate="Input: %{text}<br>Click to filter connections<extra></extra>",
+            customdata=[(i, 'input') for i in range(num_input)]
+        )
+        traces.append(input_trace)
+        self.traces_info['input_nodes_idx'] = len(traces) - 1
+
+        # Output nodes
+        output_colors = []
+        for j in range(num_output):
+            if j <= initial_step:
+                output_colors.append(self.config.OUTPUT_COLOR)
+            else:
+                output_colors.append("rgba(230, 230, 230, 0.8)")
+
+        output_trace = go.Scatter(
+            x=output_x,
+            y=output_y,
+            mode="markers+text",
+            marker=dict(
+                size=self.config.NODE_SIZE,
+                color=output_colors,
+                line=dict(width=self.config.NODE_LINE_WIDTH, color="darkred")
+            ),
+            selected=dict(
+                marker=dict(
+                    size=self.config.NODE_SIZE + 6,
+                    color="rgba(200, 80, 0, 0.9)"
+                )
+            ),
+            unselected=dict(
+                marker=dict(
+                    opacity=0.65
+                )
+            ),
+            text=[truncate_token_label(label) for label in output_labels],
+            textfont=dict(size=self.config.FONT_SIZE, family=self.config.FONT_FAMILY),
+            textposition="middle right",
+            name="Output Tokens",
+            hovertemplate="Output: %{text}<br>Click to filter connections<extra></extra>",
+            customdata=[(i, 'output') for i in range(num_output)]
+        )
+        traces.append(output_trace)
+        self.traces_info['output_nodes_idx'] = len(traces) - 1
+
+        # Create figure
+        fig = go.Figure(data=traces)
+
+        # Update layout
+        title = f"Token Attention Flow ({normalization.capitalize()} Normalization)"
+        fig.update_layout(
+            title=title,
+            xaxis=dict(
+                range=[-0.1, 1.1],
+                showgrid=False,
+                zeroline=False,
+                showticklabels=False,
+                fixedrange=True
+            ),
+            yaxis=dict(
+                range=[0, 1],
+                showgrid=False,
+                zeroline=False,
+                showticklabels=False,
+                fixedrange=True
+            ),
+            hovermode="closest",
+            clickmode="event+select",
+            dragmode="select",
+            width=self.config.PLOT_WIDTH,
+            height=max(self.config.PLOT_HEIGHT, num_input * 30, num_output * 30),
+            plot_bgcolor="white",
+            margin=dict(l=150, r=200, t=80, b=80),
+            hoverdistance=20,
+            hoverlabel=dict(font_size=12, font_family=self.config.FONT_FAMILY),
+            showlegend=True,
+            legend=dict(
+                yanchor="top",
+                y=0.99,
+                xanchor="left",
+                x=1.02
+            ),
+            # Preserve UI state on updates
+            uirevision="constant"
+        )
+
+        # Add legend traces
+        fig.add_trace(go.Scatter(
+            x=[None], y=[None],
+            mode='lines',
+            line=dict(color='rgba(0, 0, 255, 0.6)', width=2),
+            name='Input→Output'
+        ))
+        fig.add_trace(go.Scatter(
+            x=[None], y=[None],
+            mode='lines',
+            line=dict(color='rgba(255, 165, 0, 0.6)', width=2),
+            name='Output→Output'
+        ))
+
+        # Add annotations
+        fig.add_annotation(
+            x=0.5, y=0.02,
+            text=f"Step {initial_step} / {num_steps-1}: Generating '{output_labels[initial_step] if initial_step < len(output_labels) else ''}'",
+            showarrow=False,
+            font=dict(size=12, color="darkred"),
+            xref="paper", yref="paper"
+        )
+
+        fig.add_annotation(
+            x=0.01, y=0.98,
+            text="💡 Click tokens to filter connections | Use step slider to navigate generation",
+            showarrow=False,
+            font=dict(size=10, color="gray"),
+            align="left",
+            xref="paper", yref="paper"
+        )
+
+        self.current_state['current_step'] = initial_step
+
+        return fig
+
+    def update_for_step(
+        self,
+        fig: go.Figure,
+        step: int,
+        attention_matrices: List[Dict],
+        output_tokens: List[str],
+        threshold: float = 0.05
+    ) -> go.Figure:
+        """
+        Update visualization for a specific generation step.
+        """
+        if step >= len(attention_matrices):
+            return fig
+
+        output_labels = [clean_label(token) for token in output_tokens]
+
+        with fig.batch_update():
+            # Update input-to-output connections for current step
+            for conn_info in self.traces_info['input_to_output']:
+                if conn_info['output_idx'] == step:
+                    weight = attention_matrices[step]['input_attention'][conn_info['input_idx']].item()
+                    opacity = scale_weight_to_opacity(weight, threshold=threshold)
+                    width = scale_weight_to_width(weight) if opacity > 0 else 0.5
+
+                    trace_idx = conn_info['trace_idx']
+                    fig.data[trace_idx].opacity = opacity
+                    fig.data[trace_idx].line.width = width
+                    fig.data[trace_idx].line.color = get_color_for_weight(weight, "blue")
+                elif conn_info['output_idx'] > step:
+                    # Hide future connections
+                    fig.data[conn_info['trace_idx']].opacity = 0
+
+            # Update output-to-output connections
+            for conn_info in self.traces_info['output_to_output']:
+                if conn_info['to_idx'] == step and attention_matrices[step]['output_attention'] is not None:
+                    if conn_info['from_idx'] < len(attention_matrices[step]['output_attention']):
+                        weight = attention_matrices[step]['output_attention'][conn_info['from_idx']].item()
+                        opacity = scale_weight_to_opacity(weight, threshold=threshold)
+                        width = scale_weight_to_width(weight) if opacity > 0 else 0.5
+
+                        trace_idx = conn_info['trace_idx']
+                        fig.data[trace_idx].opacity = opacity
+                        fig.data[trace_idx].line.width = width
+                        fig.data[trace_idx].line.color = get_color_for_weight(weight, "orange")
+                elif conn_info['to_idx'] > step:
+                    # Hide future connections
+                    fig.data[conn_info['trace_idx']].opacity = 0
+
+            # Update output node colors
+            output_colors = []
+            for j in range(len(output_tokens)):
+                if j <= step:
+                    output_colors.append(self.config.OUTPUT_COLOR)
+                else:
+                    output_colors.append("rgba(230, 230, 230, 0.8)")
+
+            if self.traces_info['output_nodes_idx'] is not None:
+                fig.data[self.traces_info['output_nodes_idx']].marker.color = output_colors
+
+            # Update step annotation
+            fig.layout.annotations[0].text = f"Step {step} / {len(attention_matrices)-1}: Generating '{output_labels[step] if step < len(output_labels) else ''}'"
+
+        self.current_state['current_step'] = step
+        return fig
+
+    def filter_by_token(
+        self,
+        fig: go.Figure,
+        token_idx: int,
+        token_type: str,
+        attention_matrices: List[Dict],
+        threshold: float = 0.05
+    ) -> go.Figure:
+        """
+        Filter connections to show only those related to the selected token.
+        """
+        with fig.batch_update():
+            current_step = self.current_state['current_step']
+
+            if token_type == 'input':
+                # Show only connections from this input token
+                for conn_info in self.traces_info['input_to_output']:
+                    if conn_info['input_idx'] == token_idx and conn_info['output_idx'] <= current_step:
+                        weight = attention_matrices[conn_info['output_idx']]['input_attention'][token_idx].item()
+                        opacity = scale_weight_to_opacity(weight, threshold=threshold)
+                        fig.data[conn_info['trace_idx']].opacity = opacity if opacity > 0 else 0
+                    else:
+                        fig.data[conn_info['trace_idx']].opacity = 0
+
+                # Hide all output-to-output connections
+                for conn_info in self.traces_info['output_to_output']:
+                    fig.data[conn_info['trace_idx']].opacity = 0
+
+            elif token_type == 'output':
+                # Show connections to this output token
+                for conn_info in self.traces_info['input_to_output']:
+                    if conn_info['output_idx'] == token_idx and token_idx <= current_step:
+                        weight = attention_matrices[token_idx]['input_attention'][conn_info['input_idx']].item()
+                        opacity = scale_weight_to_opacity(weight, threshold=threshold)
+                        fig.data[conn_info['trace_idx']].opacity = opacity if opacity > 0 else 0
+                    else:
+                        fig.data[conn_info['trace_idx']].opacity = 0
+
+                # Show connections from/to this output token
+                for conn_info in self.traces_info['output_to_output']:
+                    show = False
+                    if conn_info['to_idx'] == token_idx and token_idx <= current_step:
+                        if attention_matrices[token_idx]['output_attention'] is not None:
+                            if conn_info['from_idx'] < len(attention_matrices[token_idx]['output_attention']):
+                                weight = attention_matrices[token_idx]['output_attention'][conn_info['from_idx']].item()
+                                opacity = scale_weight_to_opacity(weight, threshold=threshold)
+                                fig.data[conn_info['trace_idx']].opacity = opacity if opacity > 0 else 0
+                                show = True
+                    elif conn_info['from_idx'] == token_idx and conn_info['to_idx'] <= current_step:
+                        if attention_matrices[conn_info['to_idx']]['output_attention'] is not None:
+                            if token_idx < len(attention_matrices[conn_info['to_idx']]['output_attention']):
+                                weight = attention_matrices[conn_info['to_idx']]['output_attention'][token_idx].item()
+                                opacity = scale_weight_to_opacity(weight, threshold=threshold)
+                                fig.data[conn_info['trace_idx']].opacity = opacity if opacity > 0 else 0
+                                show = True
+
+                    if not show:
+                        fig.data[conn_info['trace_idx']].opacity = 0
+
+        self.current_state['selected_token'] = token_idx
+        self.current_state['selected_type'] = token_type
+        self.current_state['show_all'] = False
+
+        return fig
+
+    def show_all_connections(
+        self,
+        fig: go.Figure,
+        attention_matrices: List[Dict],
+        output_tokens: List[str],
+        threshold: float = 0.05
+    ) -> go.Figure:
+        """
+        Reset to show all connections for current step.
+        """
+        self.current_state['selected_token'] = None
+        self.current_state['selected_type'] = None
+        self.current_state['show_all'] = True
+
+        # update_for_step expects the raw output token strings and cleans
+        # them itself; iterating over attention_matrices here would yield
+        # dicts, not tokens, so callers must pass output_tokens through.
+        return self.update_for_step(
+            fig,
+            self.current_state['current_step'],
+            attention_matrices,
+            output_tokens,
+            threshold
+        )
+
+    def _create_empty_figure(self, message: str) -> go.Figure:
+        """Create an empty figure with a message."""
+        fig = go.Figure()
+        fig.update_layout(
+            title=message,
+            xaxis={'visible': False},
+            yaxis={'visible': False},
+            width=self.config.PLOT_WIDTH,
+            height=self.config.PLOT_HEIGHT
+        )
+        return fig
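Throughout `AttentionVisualizer`, each connection's width and opacity come from the `scale_weight_to_width` / `scale_weight_to_opacity` helpers imported from `visualization/utils.py` (defined further below). A minimal standalone sketch of that mapping, with the default parameters inlined and the module's default `threshold=0.05` assumed:

```python
def width_for(weight, min_width=0.5, max_width=3.0, scale_factor=5.0):
    """Line width: weight is amplified by scale_factor, saturating at 1.0."""
    scaled = min(1.0, weight * scale_factor)  # saturates once weight >= 0.2
    return min_width + (max_width - min_width) * scaled

def opacity_for(weight, min_opacity=0.1, max_opacity=1.0, threshold=0.05):
    """Opacity: hard cutoff below threshold, then linear rescaling above it."""
    if weight < threshold:
        return 0.0  # connection is hidden entirely
    normalized = (weight - threshold) / (1.0 - threshold)
    return min_opacity + (max_opacity - min_opacity) * normalized

print(width_for(0.0))    # 0.5  (thinnest line)
print(width_for(0.5))    # 3.0  (saturated)
print(opacity_for(0.01)) # 0.0  (below threshold, hidden)
```

The cutoff-then-rescale design means weights just above the threshold still render faintly (opacity near 0.1) instead of popping in at full strength.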
visualization/simple_svg_viz.py ADDED
@@ -0,0 +1,138 @@
+import json
+from typing import List, Dict, Any, Optional, Tuple
+from .utils import clean_label, scale_weight_to_width, scale_weight_to_opacity
+
+class SimpleSVGVisualizer:
+    def __init__(self, config):
+        self.config = config
+
+    def create_visualization_html(
+        self,
+        input_tokens: List[str],
+        output_tokens: List[str],
+        attention_matrices: List[Dict],
+        threshold: float = 0.05,
+        initial_step: int = 0,
+        selected_token: Optional[int] = None,
+        selected_type: Optional[str] = None
+    ) -> str:
+        """Create a simple SVG visualization without D3."""
+        # Clean labels
+        input_labels = [clean_label(token) for token in input_tokens]
+        output_labels = [clean_label(token) for token in output_tokens]
+
+        # Calculate positions
+        width = self.config.PLOT_WIDTH
+        height = self.config.PLOT_HEIGHT
+        margin = 100
+
+        input_x = margin
+        output_x = width - margin
+
+        # Create SVG elements
+        svg_elements = []
+
+        # Background
+        svg_elements.append(f'<rect width="{width}" height="{height}" fill="white" stroke="#ddd"/>')
+
+        # Title
+        svg_elements.append(f'<text x="{width/2}" y="30" text-anchor="middle" font-size="16" font-weight="bold">Token Attention Flow</text>')
+
+        # Calculate vertical positions
+        input_y_positions = []
+        output_y_positions = []
+
+        if len(input_labels) > 0:
+            input_spacing = (height - 2 * margin) / max(1, len(input_labels) - 1)
+            input_y_positions = [margin + i * input_spacing for i in range(len(input_labels))]
+
+        if len(output_labels) > 0:
+            output_spacing = (height - 2 * margin) / max(1, len(output_labels) - 1)
+            output_y_positions = [margin + i * output_spacing for i in range(len(output_labels))]
+
+        # Draw connections
+        for j in range(min(initial_step + 1, len(output_labels))):
+            if j < len(attention_matrices):
+                for i in range(len(input_labels)):
+                    weight = attention_matrices[j]['input_attention'][i].item()
+
+                    # Apply filtering
+                    if selected_token is not None:
+                        if selected_type == 'input' and i != selected_token:
+                            continue
+                        elif selected_type == 'output' and j != selected_token:
+                            continue
+
+                    if weight > threshold:
+                        opacity = scale_weight_to_opacity(weight, threshold)
+                        width_val = scale_weight_to_width(weight)
+
+                        svg_elements.append(
+                            f'<line x1="{input_x}" y1="{input_y_positions[i]}" '
+                            f'x2="{output_x}" y2="{output_y_positions[j]}" '
+                            f'stroke="blue" stroke-width="{width_val}" opacity="{opacity}"/>'
+                        )
+
+        # Draw input nodes
+        for i, label in enumerate(input_labels):
+            y = input_y_positions[i]
+            color = "yellow" if selected_token == i and selected_type == 'input' else self.config.INPUT_COLOR
+
+            svg_elements.append(
+                f'<circle cx="{input_x}" cy="{y}" r="{self.config.NODE_SIZE/2}" '
+                f'fill="{color}" stroke="darkblue" stroke-width="2" '
+                f'style="cursor: pointer" '
+                f'onclick="handleTokenClick({i}, \'input\')"/>'
+            )
+            svg_elements.append(
+                f'<text x="{input_x - self.config.NODE_SIZE/2 - 10}" y="{y + 5}" '
+                f'text-anchor="end" font-size="{self.config.FONT_SIZE}">{label}</text>'
+            )
+
+        # Draw output nodes
+        for j, label in enumerate(output_labels):
+            y = output_y_positions[j]
+            color = "yellow" if selected_token == j and selected_type == 'output' else (
+                self.config.OUTPUT_COLOR if j <= initial_step else "#e6e6e6"
+            )
+
+            svg_elements.append(
+                f'<circle cx="{output_x}" cy="{y}" r="{self.config.NODE_SIZE/2}" '
+                f'fill="{color}" stroke="darkred" stroke-width="2" '
+                f'style="cursor: pointer" '
+                f'onclick="handleTokenClick({j}, \'output\')"/>'
+            )
+            svg_elements.append(
+                f'<text x="{output_x + self.config.NODE_SIZE/2 + 10}" y="{y + 5}" '
+                f'text-anchor="start" font-size="{self.config.FONT_SIZE}">{label}</text>'
+            )
+
+        # Step info
+        svg_elements.append(
+            f'<text x="{width/2}" y="{height - 20}" text-anchor="middle" font-size="12" fill="darkred">'
+            f'Step {initial_step} / {len(output_labels) - 1}: Generating "{output_labels[initial_step] if initial_step < len(output_labels) else ""}"'
+            f'</text>'
+        )
+
+        # Create HTML
+        html = f"""
+        <div style="width: 100%; overflow-x: auto;">
+            <svg width="{width}" height="{height}" style="border: 1px solid #ddd;">
+                {''.join(svg_elements)}
+            </svg>
+        </div>
+
+        <script>
+        function handleTokenClick(index, type) {{
+            console.log('Token clicked:', index, type);
+            const hiddenInput = document.querySelector('#clicked-token-d3 textarea');
+            if (hiddenInput) {{
+                const clickData = JSON.stringify({{index: index, type: type}});
+                hiddenInput.value = clickData;
+                hiddenInput.dispatchEvent(new Event('input', {{ bubbles: true }}));
+            }}
+        }}
+        </script>
+        """
+
+        return html
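`create_visualization_html` spreads the token nodes evenly between the top and bottom margins, pinning the first node at the top margin and the last at the bottom one. A standalone sketch of that vertical-layout computation (the height and margin values here are illustrative, not from the config):

```python
def y_positions(n, height, margin):
    """Evenly space n node centers between margin and height - margin,
    mirroring the y-position computation in create_visualization_html."""
    if n == 0:
        return []
    # max(1, n - 1) avoids division by zero for a single node,
    # which then sits at the top margin (same behavior as the source).
    spacing = (height - 2 * margin) / max(1, n - 1)
    return [margin + i * spacing for i in range(n)]

print(y_positions(3, 500, 100))  # [100.0, 250.0, 400.0]
```

One consequence of this scheme is that spacing shrinks as the token count grows, so labels can start to overlap for long sequences; the `overflow-x: auto` wrapper only helps horizontally.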
visualization/utils.py ADDED
@@ -0,0 +1,223 @@
+import re
+from typing import List, Tuple, Optional
+import numpy as np
+
+def clean_label(token: str) -> str:
+    """
+    Cleans token labels for visualization.
+    Handles various tokenizer-specific formatting.
+    """
+    label = str(token)
+
+    # Handle common tokenizer prefixes
+    label = label.replace('Ġ', ' ')    # GPT-2 style space
+    label = label.replace('▁', ' ')    # SentencePiece style space
+    label = label.replace('Ċ', '\\n')  # Newline
+
+    # Handle special tokens
+    label = label.replace('</s>', '[EOS]')
+    label = label.replace('<s>', '[BOS]')
+    label = label.replace('<unk>', '[UNK]')
+    label = label.replace('<pad>', '[PAD]')
+    label = label.replace('<|begin_of_text|>', '[BOS]')
+    label = label.replace('<|end_of_text|>', '[EOS]')
+    label = label.replace('<|endoftext|>', '[EOS]')
+
+    # Remove byte-level encoding markers
+    label = re.sub(r'<0x[0-9A-Fa-f]{2}>', '', label)
+
+    # Clean up whitespace
+    label = label.strip()
+
+    # Return cleaned label or placeholder
+    return label if label else "[EMPTY]"
+
+ def scale_weight_to_width(
+     weight: float,
+     min_width: float = 0.5,
+     max_width: float = 3.0,
+     scale_factor: float = 5.0
+ ) -> float:
+     """
+     Scale an attention weight to a line width for visualization.
+
+     Args:
+         weight: Attention weight in [0, 1]
+         min_width: Minimum line width
+         max_width: Maximum line width
+         scale_factor: Multiplier applied to the weight before clamping
+
+     Returns:
+         Line width in [min_width, max_width]
+     """
+     scaled = min(1.0, weight * scale_factor)
+     return min_width + (max_width - min_width) * scaled
+
+ def scale_weight_to_opacity(
+     weight: float,
+     min_opacity: float = 0.1,
+     max_opacity: float = 1.0,
+     threshold: float = 0.0
+ ) -> float:
+     """
+     Scale an attention weight to an opacity for visualization.
+
+     Args:
+         weight: Attention weight in [0, 1]
+         min_opacity: Minimum opacity
+         max_opacity: Maximum opacity
+         threshold: Weights below this value are fully transparent
+
+     Returns:
+         Opacity in [min_opacity, max_opacity], or 0.0 below the threshold
+     """
+     if weight < threshold:
+         return 0.0
+
+     # Linear scaling above the threshold
+     normalized = (weight - threshold) / (1.0 - threshold) if threshold < 1.0 else weight
+     return min_opacity + (max_opacity - min_opacity) * normalized
+
+ def get_node_positions(
+     num_input: int,
+     num_output: int,
+     spacing: str = 'linear'
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+     """
+     Calculate node positions for the bipartite token layout.
+
+     Args:
+         num_input: Number of input tokens
+         num_output: Number of output tokens
+         spacing: Spacing strategy ('linear' or 'equal')
+
+     Returns:
+         Tuple of (input_x, input_y, output_x, output_y)
+     """
+     # Y positions (vertical)
+     if spacing == 'linear':
+         input_y = np.linspace(0.1, 0.9, num_input) if num_input > 1 else np.array([0.5])
+         output_y = np.linspace(0.1, 0.9, num_output) if num_output > 1 else np.array([0.5])
+     else:  # equal spacing
+         total_height = 0.8
+         input_spacing = total_height / (num_input + 1)
+         output_spacing = total_height / (num_output + 1)
+         input_y = np.array([0.1 + (i + 1) * input_spacing for i in range(num_input)])
+         output_y = np.array([0.1 + (i + 1) * output_spacing for i in range(num_output)])
+
+     # X positions (horizontal): inputs on the left, outputs on the right
+     input_x = np.full(num_input, 0.1)
+     output_x = np.full(num_output, 0.9)
+
+     return input_x, input_y, output_x, output_y
+
+ def create_spline_path(
+     start_x: float,
+     start_y: float,
+     end_x: float,
+     end_y: float,
+     control_offset: float = 0.15
+ ) -> Tuple[List[float], List[float]]:
+     """
+     Create a spline path for output-to-output connections.
+
+     Args:
+         start_x, start_y: Starting position
+         end_x, end_y: Ending position
+         control_offset: Horizontal offset for the control points
+
+     Returns:
+         Tuple of (path_x, path_y) control points for the spline
+     """
+     # Bulge the curve horizontally so it clears the node column
+     path_x = [
+         start_x,
+         start_x + control_offset,
+         end_x + control_offset,
+         end_x
+     ]
+     path_y = [
+         start_y,
+         start_y,
+         end_y,
+         end_y
+     ]
+
+     return path_x, path_y
+
+ def format_attention_text(
+     from_token: str,
+     to_token: str,
+     weight: float,
+     connection_type: str = "attention"
+ ) -> str:
+     """
+     Format hover text for attention connections.
+
+     Args:
+         from_token: Source token
+         to_token: Target token
+         weight: Attention weight
+         connection_type: Type of connection (e.g. "attention")
+
+     Returns:
+         HTML-formatted hover text
+     """
+     return (
+         f"{from_token} → {to_token}<br>"
+         f"{connection_type.capitalize()} Weight: {weight:.4f}"
+     )
+
+ def get_color_for_weight(
+     weight: float,
+     base_color: str = "blue",
+     use_gradient: bool = True
+ ) -> str:
+     """
+     Get an rgba color string for an attention weight.
+
+     Args:
+         weight: Attention weight in [0, 1]
+         base_color: Base color name ("blue", "orange", or fallback gray)
+         use_gradient: Whether to vary intensity and opacity with the weight
+
+     Returns:
+         Color string usable by Plotly
+     """
+     if not use_gradient:
+         if base_color == "blue":
+             return "rgba(0, 0, 255, 0.6)"
+         elif base_color == "orange":
+             return "rgba(255, 165, 0, 0.6)"
+         else:
+             return "rgba(128, 128, 128, 0.6)"
+
+     # Gradient: heavier weights get a darker color and higher opacity
+     if base_color == "blue":
+         # Light blue to dark blue
+         intensity = int(255 - weight * 155)  # 255 down to 100
+         return f"rgba(0, {intensity}, 255, {0.3 + weight * 0.4})"
+     elif base_color == "orange":
+         # Light orange to dark orange
+         intensity = int(255 - weight * 100)  # 255 down to 155
+         return f"rgba(255, {intensity}, 0, {0.3 + weight * 0.4})"
+     else:
+         # Gray scale
+         intensity = int(200 - weight * 100)  # 200 down to 100
+         return f"rgba({intensity}, {intensity}, {intensity}, {0.3 + weight * 0.4})"
+
+ def truncate_token_label(token: str, max_length: int = 15) -> str:
+     """
+     Truncate long token labels for display.
+
+     Args:
+         token: Token string
+         max_length: Maximum label length
+
+     Returns:
+         Cleaned token, truncated with an ellipsis if needed
+     """
+     cleaned = clean_label(token)
+     if len(cleaned) > max_length:
+         return cleaned[:max_length - 3] + "..."
+     return cleaned
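The scaling helpers above are pure functions, so they are easy to sanity-check in isolation. A minimal sketch (the two function bodies are copied inline from the diff so the snippet runs standalone):

```python
# Inline copies of the scaling helpers from visualization/utils.py,
# reproduced here so the check runs without the package installed.

def scale_weight_to_width(weight, min_width=0.5, max_width=3.0, scale_factor=5.0):
    # Amplify the weight, clamp to 1.0, then map into [min_width, max_width]
    scaled = min(1.0, weight * scale_factor)
    return min_width + (max_width - min_width) * scaled

def scale_weight_to_opacity(weight, min_opacity=0.1, max_opacity=1.0, threshold=0.0):
    # Weights below the threshold are fully transparent
    if weight < threshold:
        return 0.0
    normalized = (weight - threshold) / (1.0 - threshold) if threshold < 1.0 else weight
    return min_opacity + (max_opacity - min_opacity) * normalized

# With the default scale_factor of 5.0, any weight >= 0.2 saturates at max_width
print(scale_weight_to_width(0.2))              # 3.0
print(scale_weight_to_width(0.0))              # 0.5
print(round(scale_weight_to_opacity(0.5), 4))  # 0.55
print(scale_weight_to_opacity(0.05, threshold=0.1))  # 0.0
```

Note the design consequence of `scale_factor=5.0`: attention weights, which typically sum to 1 across many tokens, rarely approach 1.0 individually, so the amplification makes mid-strength connections visually distinguishable while clamping keeps strong ones bounded.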