Commit
·
dd850a7
1
Parent(s):
6df9665
initial
Browse files- .gitignore +104 -0
- README.md +72 -7
- api/__init__.py +0 -0
- app.py +656 -0
- claude.md +206 -0
- config.py +39 -0
- core/__init__.py +0 -0
- core/attention.py +279 -0
- core/cache.py +90 -0
- core/model_handler.py +187 -0
- requirements.txt +8 -0
- visualization/__init__.py +0 -0
- visualization/d3_viz.py +356 -0
- visualization/plotly_viz.py +443 -0
- visualization/simple_svg_viz.py +138 -0
- visualization/utils.py +223 -0
.gitignore
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
*.so
|
6 |
+
.Python
|
7 |
+
build/
|
8 |
+
develop-eggs/
|
9 |
+
dist/
|
10 |
+
downloads/
|
11 |
+
eggs/
|
12 |
+
.eggs/
|
13 |
+
lib/
|
14 |
+
lib64/
|
15 |
+
parts/
|
16 |
+
sdist/
|
17 |
+
var/
|
18 |
+
wheels/
|
19 |
+
pip-wheel-metadata/
|
20 |
+
share/python-wheels/
|
21 |
+
*.egg-info/
|
22 |
+
.installed.cfg
|
23 |
+
*.egg
|
24 |
+
MANIFEST
|
25 |
+
|
26 |
+
# Virtual environments
|
27 |
+
venv/
|
28 |
+
env/
|
29 |
+
ENV/
|
30 |
+
env.bak/
|
31 |
+
venv.bak/
|
32 |
+
.venv/
|
33 |
+
|
34 |
+
# IDE and editors
|
35 |
+
.vscode/
|
36 |
+
.idea/
|
37 |
+
*.swp
|
38 |
+
*.swo
|
39 |
+
*~
|
40 |
+
.claude/
|
41 |
+
|
42 |
+
# OS generated files
|
43 |
+
.DS_Store
|
44 |
+
.DS_Store?
|
45 |
+
._*
|
46 |
+
.Spotlight-V100
|
47 |
+
.Trashes
|
48 |
+
ehthumbs.db
|
49 |
+
Thumbs.db
|
50 |
+
|
51 |
+
# Logs and databases
|
52 |
+
*.log
|
53 |
+
*.sqlite3
|
54 |
+
*.db
|
55 |
+
|
56 |
+
# Model cache and downloads
|
57 |
+
models/
|
58 |
+
.cache/
|
59 |
+
huggingface_hub/
|
60 |
+
transformers_cache/
|
61 |
+
|
62 |
+
# Temporary files
|
63 |
+
*.tmp
|
64 |
+
*.temp
|
65 |
+
.tmp/
|
66 |
+
|
67 |
+
# Environment variables
|
68 |
+
.env
|
69 |
+
.env.local
|
70 |
+
.env.production
|
71 |
+
.env.staging
|
72 |
+
|
73 |
+
# Jupyter Notebook
|
74 |
+
.ipynb_checkpoints
|
75 |
+
|
76 |
+
# pytest
|
77 |
+
.pytest_cache/
|
78 |
+
.coverage
|
79 |
+
|
80 |
+
# mypy
|
81 |
+
.mypy_cache/
|
82 |
+
.dmypy.json
|
83 |
+
dmypy.json
|
84 |
+
|
85 |
+
# Local development
|
86 |
+
local_test.py
|
87 |
+
test_*.py
|
88 |
+
debug.py
|
89 |
+
|
90 |
+
# Gradio temporary files
|
91 |
+
.gradio/
|
92 |
+
gradio_cached_examples/
|
93 |
+
flagged/
|
94 |
+
|
95 |
+
# Large files that shouldn't be in git
|
96 |
+
*.bin
|
97 |
+
*.safetensors
|
98 |
+
*.pt
|
99 |
+
*.pth
|
100 |
+
*.ckpt
|
101 |
+
*.h5
|
102 |
+
|
103 |
+
# Documentation build
|
104 |
+
docs/_build/
|
README.md
CHANGED
@@ -1,14 +1,79 @@
|
|
1 |
---
|
2 |
title: Token Attention Visualizer
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license:
|
11 |
-
short_description: An interactive tool for visualizing attention patterns in Large Language Models
|
12 |
---
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: Token Attention Visualizer
|
3 |
+
emoji: 🔍
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.0.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: apache-2.0
|
|
|
11 |
---
|
12 |
|
13 |
+
# Token Attention Visualizer
|
14 |
+
|
15 |
+
An interactive tool for visualizing attention patterns in Large Language Models during text generation.
|
16 |
+
|
17 |
+
## Features
|
18 |
+
|
19 |
+
- 🚀 **Real-time Generation**: Generate text with any Hugging Face model
|
20 |
+
- 🔍 **Attention Visualization**: Explore attention patterns with clear visual representations
|
21 |
+
- 📊 **Dual Normalization**: Choose between separate or joint attention normalization
|
22 |
+
- ⚡ **Smart Caching**: Fast response with intelligent result caching
|
23 |
+
- 🎯 **Token Selection**: Use dropdown menus to select and filter token connections
|
24 |
+
- 📈 **Step Navigation**: Navigate through generation steps
|
25 |
+
- 🎨 **Customizable Threshold**: Filter weak attention connections
|
26 |
+
|
27 |
+
## How It Works
|
28 |
+
|
29 |
+
The visualizer shows how tokens attend to each other during text generation:
|
30 |
+
- **Blue lines**: Attention from input tokens to output tokens
|
31 |
+
- **Orange curves**: Attention between output tokens
|
32 |
+
- **Line thickness**: Represents attention weight strength
|
33 |
+
|
34 |
+
## Usage
|
35 |
+
|
36 |
+
1. **Load a Model**: Enter a Hugging Face model name (default: HuggingFaceTB/SmolLM-135M-Instruct)
|
37 |
+
2. **Enter Prompt**: Type your input text
|
38 |
+
3. **Configure Settings**: Adjust max tokens, temperature, and normalization
|
39 |
+
4. **Generate**: Click to generate text and visualize attention
|
40 |
+
5. **Explore**: Use dropdown menus to select tokens and view their attention patterns
|
41 |
+
|
42 |
+
## Technical Details
|
43 |
+
|
44 |
+
- Built with Gradio for the interface
|
45 |
+
- Visualization system with dropdown-based token selection
|
46 |
+
- Supports any Hugging Face causal language model
|
47 |
+
- Optimized for smaller models like SmolLM for efficient deployment
|
48 |
+
- Implements efficient attention processing and caching
|
49 |
+
|
50 |
+
## Local Development
|
51 |
+
|
52 |
+
```bash
|
53 |
+
# Clone the repository
|
54 |
+
git clone <repo-url>
|
55 |
+
cd token-attention-viz
|
56 |
+
|
57 |
+
# Install dependencies
|
58 |
+
pip install -r requirements.txt
|
59 |
+
|
60 |
+
# Run the app
|
61 |
+
python app.py
|
62 |
+
```
|
63 |
+
|
64 |
+
## Deployment
|
65 |
+
|
66 |
+
This app is designed for easy deployment on Hugging Face Spaces. Simply:
|
67 |
+
1. Create a new Space
|
68 |
+
2. Upload the project files
|
69 |
+
3. The app will automatically start
|
70 |
+
|
71 |
+
## Requirements
|
72 |
+
|
73 |
+
- Python 3.8+
|
74 |
+
- 4GB+ RAM (SmolLM models are lightweight)
|
75 |
+
- GPU acceleration optional (works well on CPU)
|
76 |
+
|
77 |
+
## License
|
78 |
+
|
79 |
+
Apache 2.0
|
api/__init__.py
ADDED
File without changes
|
app.py
ADDED
@@ -0,0 +1,656 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
# Add project root to path
|
9 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
10 |
+
|
11 |
+
from core.model_handler import ModelHandler
|
12 |
+
from core.attention import AttentionProcessor
|
13 |
+
from core.cache import AttentionCache
|
14 |
+
from config import Config
|
15 |
+
from visualization.d3_viz import create_d3_visualization
|
16 |
+
|
17 |
+
class TokenVisualizerApp:
    """Application state for the Gradio UI.

    Owns the model handler, the attention cache, and the most recent
    generation result, and converts raw attention data into the
    node/link JSON structure consumed by the D3.js visualization.
    """

    def __init__(self):
        self.config = Config()
        self.model_handler = ModelHandler(config=self.config)
        self.cache = AttentionCache(max_size=self.config.CACHE_SIZE)
        # Payload of the latest generation (tokens, matrices, text);
        # None until the first successful/cached generation.
        self.current_data = None
        self.model_loaded = False

    def load_model(self, model_name: str = None) -> str:
        """Load *model_name* (or the configured default when falsy).

        Returns a human-readable status string for the UI; also records
        success/failure in ``self.model_loaded``.
        """
        if not model_name:
            model_name = self.config.DEFAULT_MODEL

        success, message = self.model_handler.load_model(model_name)
        self.model_loaded = success

        if success:
            model_info = self.model_handler.get_model_info()
            return f"✅ Model loaded: {model_name}\n📊 Parameters: {model_info['num_parameters']:,}\n🖥️ Device: {model_info['device']}"
        else:
            return f"❌ Failed to load model: {message}"

    def generate_and_visualize(
        self,
        prompt: str,
        max_tokens: int,
        threshold: float,
        temperature: float,
        normalization: str,
        progress=gr.Progress()
    ):
        """Generate text with attention capture and populate ``self.current_data``.

        Returns a 1-tuple ``(info_text,)`` on every path so the caller's
        single-value unpack (``info, = ...``) always works.

        BUG FIX: the guard and error paths previously returned 3-tuples
        ``(None, message, None)`` while the success path returned a
        1-tuple; the UI handler unpacked exactly one value, so any error
        raised ``ValueError`` instead of surfacing the message.
        """
        if not self.model_loaded:
            return ("Please load a model first!",)

        if not prompt.strip():
            return ("Please enter a prompt!",)

        progress(0.2, desc="Checking cache...")

        # The key covers everything that affects the generated output.
        cache_key = self.cache.get_key(
            prompt, max_tokens,
            self.model_handler.model_name,
            temperature
        )
        cached = self.cache.get(cache_key)

        if cached:
            progress(0.5, desc="Using cached data...")
            self.current_data = cached
        else:
            progress(0.3, desc="Generating text...")

            attention_data, output_tokens, input_tokens, generated_text = \
                self.model_handler.generate_with_attention(
                    prompt, max_tokens, temperature
                )

            if attention_data is None:
                # On failure generated_text carries the error description.
                return (f"Generation failed: {generated_text}",)

            progress(0.6, desc="Processing attention...")

            # Process attention based on normalization method.
            if normalization == "separate":
                attention_matrices = AttentionProcessor.process_attention_separate(
                    attention_data, input_tokens, output_tokens
                )
            else:
                attention_matrices = AttentionProcessor.process_attention_joint(
                    attention_data, input_tokens, output_tokens
                )

            self.current_data = {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'attention_matrices': attention_matrices,
                'generated_text': generated_text,
                'attention_data': attention_data  # Keep raw for step updates
            }

            # Cache it for identical future requests.
            self.cache.set(cache_key, self.current_data)

        progress(1.0, desc="Complete!")

        info_text = f"📝 Generated: {self.current_data['generated_text']}\n"
        info_text += f"🔤 Input tokens: {len(self.current_data['input_tokens'])}\n"
        info_text += f"🔤 Output tokens: {len(self.current_data['output_tokens'])}"

        return (info_text,)

    def update_step(self, step_idx: int, threshold: float):
        """No-op placeholder after removing visualization."""
        return None

    def update_threshold(self, threshold: float, normalization: str):
        """No-op placeholder after removing visualization."""
        return None

    def filter_token_connections(self, token_idx: int, token_type: str, threshold: float):
        """Removed visualization; keep placeholder."""
        return None

    def reset_view(self, threshold: float):
        """Removed visualization; keep placeholder."""
        return None

    def on_d3_token_click(self, click_data: str, threshold: float):
        """Removed visualization; keep placeholder for compatibility."""
        return None, gr.update()

    def on_input_token_select(self, token_label: str, threshold: float):
        """Removed visualization; keep placeholder for compatibility."""
        return None

    def prepare_d3_data(self, step_idx: int, threshold: float = 0.01, filter_token: str = None):
        """
        Convert attention data to D3.js-friendly JSON format.

        Args:
            step_idx: Generation step to visualize (0-based)
            threshold: Minimum attention weight to include
            filter_token: Token to filter by (format: "[IN] token" or "[OUT] token" or "All tokens")

        Returns:
            dict: JSON structure with nodes and links for D3.js
        """
        if not self.current_data:
            return {"nodes": [], "links": []}

        input_tokens = self.current_data['input_tokens']
        output_tokens = self.current_data['output_tokens']
        attention_matrices = self.current_data['attention_matrices']

        # Clamp step_idx into range rather than raising.
        if step_idx >= len(attention_matrices):
            step_idx = len(attention_matrices) - 1

        attention_matrix = attention_matrices[step_idx]

        # Build node list: every input token plus output tokens up to step_idx.
        nodes = []

        for i, token in enumerate(input_tokens):
            nodes.append({
                "id": f"input_{i}",
                "token": token,
                "type": "input",
                "index": i
            })

        for i in range(step_idx + 1):
            if i < len(output_tokens):
                nodes.append({
                    "id": f"output_{i}",
                    "token": output_tokens[i],
                    "type": "output",
                    "index": i
                })

        # Parse the "[IN] tok" / "[OUT] tok" filter selection.
        # NOTE(review): matching is by token *text*, so duplicate tokens
        # resolve to the first occurrence — confirm this is acceptable.
        filter_type = None
        filter_idx = None
        if filter_token and filter_token != "All tokens":
            if filter_token.startswith("[IN] "):
                filter_type = "input"
                filter_token_text = filter_token[5:]  # Remove "[IN] " prefix
                filter_idx = next((i for i, token in enumerate(input_tokens) if token == filter_token_text), None)
            elif filter_token.startswith("[OUT] "):
                filter_type = "output"
                filter_token_text = filter_token[6:]  # Remove "[OUT] " prefix
                filter_idx = next((i for i, token in enumerate(output_tokens) if token == filter_token_text), None)

        # Create links from attention matrices - show ALL steps up to current step.
        links = []

        for current_step in range(step_idx + 1):
            if current_step < len(attention_matrices):
                step_attention = attention_matrices[current_step]

                # Links from input tokens to this output token.
                input_attention = step_attention['input_attention']
                if input_attention is not None:
                    for input_idx in range(len(input_tokens)):
                        if input_idx < len(input_attention):  # Check bounds
                            weight = float(input_attention[input_idx])
                            if weight >= threshold:
                                show_link = True
                                if filter_type == "input" and filter_idx is not None:
                                    # Only show connections involving the selected input token.
                                    show_link = (input_idx == filter_idx)
                                elif filter_type == "output" and filter_idx is not None:
                                    # Only show connections involving the selected output token.
                                    show_link = (current_step == filter_idx)

                                if show_link:
                                    links.append({
                                        "source": f"input_{input_idx}",
                                        "target": f"output_{current_step}",
                                        "weight": weight,
                                        "type": "input_to_output"
                                    })

                # Links from previous output tokens to this output token.
                output_attention = step_attention['output_attention']
                if output_attention is not None and current_step > 0:
                    for prev_output_idx in range(current_step):
                        if prev_output_idx < len(output_attention):  # Check bounds
                            weight = float(output_attention[prev_output_idx])
                            if weight >= threshold:
                                show_link = True
                                if filter_type == "input" and filter_idx is not None:
                                    # Don't show output-to-output connections when filtering by input.
                                    show_link = False
                                elif filter_type == "output" and filter_idx is not None:
                                    # Only show connections involving the selected output token.
                                    show_link = (prev_output_idx == filter_idx or current_step == filter_idx)

                                if show_link:
                                    links.append({
                                        "source": f"output_{prev_output_idx}",
                                        "target": f"output_{current_step}",
                                        "weight": weight,
                                        "type": "output_to_output"
                                    })

        return {
            "nodes": nodes,
            "links": links,
            "step": step_idx,
            "total_steps": len(attention_matrices),
            "input_count": len(input_tokens),
            "output_count": step_idx + 1
        }

    def create_d3_visualization_html(self, step_idx: int = 0, threshold: float = 0.01, filter_token: str = None):
        """
        Create D3.js visualization HTML for the current data.

        Args:
            step_idx: Generation step to visualize (0-based)
            threshold: Minimum attention weight to include
            filter_token: Token to filter by (format: "[IN] token" or "[OUT] token")

        Returns:
            str: HTML string for D3.js visualization
        """
        if not self.current_data:
            return "<div>No data available. Generate text first!</div>"

        d3_data = self.prepare_d3_data(step_idx, threshold, filter_token)

        viz_html = create_d3_visualization(d3_data)
        return viz_html

    def get_token_choices(self):
        """
        Get list of token choices for dropdown.

        Returns:
            list: List of token strings for dropdown options
        """
        if not self.current_data:
            return []

        input_tokens = self.current_data['input_tokens']
        output_tokens = self.current_data['output_tokens']

        # Create choices with prefixes to distinguish input/output.
        choices = ["All tokens"]
        choices.extend([f"[IN] {token}" for token in input_tokens])
        choices.extend([f"[OUT] {token}" for token in output_tokens])

        return choices
303 |
+
|
304 |
+
|
305 |
+
def create_gradio_interface():
    """Create the Gradio interface.

    Builds the full Blocks UI (controls on the left, D3 visualization on
    the right), wires the event handlers, and returns the demo object.
    """
    app = TokenVisualizerApp()

    with gr.Blocks(
        title="Token Attention Visualizer",
        # Custom CSS: light/dark header styling plus the orange load button.
        css="""
        /* Default/Light mode styles */
        .main-header {
            text-align: center;
            padding: 2rem 0 3rem 0;
            background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
            border-radius: 1rem;
            margin-bottom: 2rem;
            border: 1px solid #e2e8f0;
        }

        .main-title {
            font-size: 2.5rem;
            font-weight: 700;
            color: #1e293b;
            margin-bottom: 0.5rem;
            background: linear-gradient(135deg, #1e293b 0%, #3b82f6 100%);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            background-clip: text;
        }

        .main-subtitle {
            font-size: 1.125rem;
            color: #64748b;
            font-weight: 400;
        }

        .section-title {
            font-size: 1.25rem;
            font-weight: 600;
            color: #1e293b;
            margin-bottom: 1.5rem;
            padding-bottom: 0.5rem;
            border-bottom: 2px solid #e2e8f0;
        }

        /* Explicit light mode overrides */
        .light .main-header,
        [data-theme="light"] .main-header {
            background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
            border: 1px solid #e2e8f0;
        }

        .light .main-title,
        [data-theme="light"] .main-title {
            color: #1e293b;
            background: linear-gradient(135deg, #1e293b 0%, #3b82f6 100%);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            background-clip: text;
        }

        .light .main-subtitle,
        [data-theme="light"] .main-subtitle {
            color: #64748b;
        }

        .light .section-title,
        [data-theme="light"] .section-title {
            color: #1e293b;
            border-bottom: 2px solid #e2e8f0;
        }

        /* Dark mode styles with higher specificity */
        .dark .main-header,
        [data-theme="dark"] .main-header {
            background: linear-gradient(135deg, #1e293b 0%, #334155 100%) !important;
            border: 1px solid #475569 !important;
        }

        .dark .main-title,
        [data-theme="dark"] .main-title {
            color: #f1f5f9 !important;
            background: linear-gradient(135deg, #f1f5f9 0%, #60a5fa 100%) !important;
            -webkit-background-clip: text !important;
            -webkit-text-fill-color: transparent !important;
            background-clip: text !important;
        }

        .dark .main-subtitle,
        [data-theme="dark"] .main-subtitle {
            color: #cbd5e1 !important;
        }

        .dark .section-title,
        [data-theme="dark"] .section-title {
            color: #f1f5f9 !important;
            border-bottom: 2px solid #475569 !important;
        }

        /* System dark mode - only apply when no explicit theme is set */
        @media (prefers-color-scheme: dark) {
            :root:not([data-theme="light"]) .main-header {
                background: linear-gradient(135deg, #1e293b 0%, #334155 100%);
                border: 1px solid #475569;
            }

            :root:not([data-theme="light"]) .main-title {
                color: #f1f5f9;
                background: linear-gradient(135deg, #f1f5f9 0%, #60a5fa 100%);
                -webkit-background-clip: text;
                -webkit-text-fill-color: transparent;
                background-clip: text;
            }

            :root:not([data-theme="light"]) .main-subtitle {
                color: #cbd5e1;
            }

            :root:not([data-theme="light"]) .section-title {
                color: #f1f5f9;
                border-bottom: 2px solid #475569;
            }
        }

        .load-model-btn {
            background: linear-gradient(135deg, #f97316 0%, #ea580c 100%) !important;
            color: white !important;
            border: none !important;
            font-weight: 600 !important;
            padding: 0.75rem 2rem !important;
            border-radius: 0.5rem !important;
            box-shadow: 0 4px 6px -1px rgba(249, 115, 22, 0.25) !important;
            transition: all 0.2s ease !important;
        }

        .load-model-btn:hover {
            background: linear-gradient(135deg, #ea580c 0%, #dc2626 100%) !important;
            transform: translateY(-1px) !important;
            box-shadow: 0 6px 8px -1px rgba(249, 115, 22, 0.35) !important;
        }
        """
    ) as demo:
        gr.HTML("""
        <div class="main-header">
            <h1 class="main-title">Token Attention Visualizer</h1>
            <p class="main-subtitle">Interactive visualization of attention patterns in Large Language Models</p>
        </div>
        """)

        with gr.Row():
            # Left Panel - Controls
            with gr.Column(scale=1):
                gr.HTML('<h2 class="section-title">Model & Generation</h2>')

                # Model loading
                model_input = gr.Textbox(
                    label="Model Name",
                    value=app.config.DEFAULT_MODEL,
                    placeholder="Enter Hugging Face model name..."
                )
                load_model_btn = gr.Button("Load Model", variant="primary", elem_classes=["load-model-btn"])

                model_status = gr.Textbox(
                    label="Model Status",
                    value="No model loaded",
                    interactive=False,
                    lines=2
                )

                # Generation controls
                prompt_input = gr.Textbox(
                    label="Prompt",
                    value=app.config.DEFAULT_PROMPT,
                    lines=3,
                    placeholder="Enter your prompt here..."
                )

                max_tokens_input = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=app.config.DEFAULT_MAX_TOKENS,
                    step=1,
                    label="Max Tokens"
                )

                temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=app.config.DEFAULT_TEMPERATURE,
                    step=0.1,
                    label="Temperature"
                )

                generate_btn = gr.Button("Generate", variant="primary", size="lg")

                generated_info = gr.Textbox(
                    label="Generation Info",
                    interactive=False,
                    lines=4
                )

                gr.HTML('<h2 class="section-title">Visualization Controls</h2>')

                # Slider maximum is rescaled after each generation (see _generate).
                step_slider = gr.Slider(
                    minimum=0,
                    maximum=10,
                    value=0,
                    step=1,
                    label="Generation Step",
                    info="Navigate through generation steps"
                )

                threshold_slider = gr.Slider(
                    minimum=0.001,
                    maximum=0.5,
                    value=0.01,
                    step=0.001,
                    label="Attention Threshold",
                    info="Filter weak connections"
                )

                # Choices are repopulated after each generation (see _generate).
                token_dropdown = gr.Dropdown(
                    choices=["All tokens"],
                    value="All tokens",
                    label="Filter by Token",
                    info="Select a token to highlight"
                )

            # Right Panel - Visualization
            with gr.Column(scale=2):
                gr.HTML('<h2 class="section-title">Attention Visualization</h2>')

                d3_visualization = gr.HTML(
                    value="""<div style='height: 700px; display: flex; align-items: center; justify-content: center; font-size: 16px;'>
                    <div style='text-align: center;'>
                    <div style='font-size: 3rem; margin-bottom: 16px; opacity: 0.5;'>⚪</div>
                    <div style='font-weight: 500; margin-bottom: 8px;'>Ready to visualize</div>
                    <div>Generate text to see attention patterns</div>
                    </div>
                    </div>"""
                )

                # (Visualization output and overlay removed)

        # Instructions
        # NOTE(review): the help text below mentions a default of
        # "Llama-3.2-1B" and a "Reset View" button; the README says the
        # default model is SmolLM and no Reset button exists in this UI —
        # confirm and update the copy.
        with gr.Accordion("📖 How to Use", open=False):
            gr.Markdown(
                """
                ### Instructions:
                1. **Load a model** from Hugging Face (default: Llama-3.2-1B)
                2. **Enter a prompt** and configure generation settings
                3. **Click Generate** to create text and visualize attention
                4. **Interact with the visualization:**
                   - Use the **step slider** to navigate through generation steps
                   - Adjust the **threshold** to filter weak connections
                   - Click on **tokens** in the plot to filter their connections
                   - Click **Reset View** to show all connections

                ### Understanding the Visualization:
                - **Blue lines**: Attention from input to output tokens
                - **Orange curves**: Attention between output tokens
                - **Line thickness**: Represents attention weight strength
                - **Node colors**: Blue = input tokens, Coral = generated tokens
                """
            )

        # Event handlers
        load_model_btn.click(
            fn=app.load_model,
            inputs=[model_input],
            outputs=[model_status]
        )

        def _generate(prompt, max_tokens, threshold, temperature):
            # NOTE(review): this single-value unpack requires
            # generate_and_visualize to return a 1-tuple on every path,
            # including its guard/error paths — verify.
            info, = app.generate_and_visualize(
                prompt, max_tokens, threshold, temperature, "separate"  # Always use separate normalization
            )

            # Update visualization and dropdown choices
            max_steps = len(app.current_data['attention_matrices']) - 1 if app.current_data else 0
            viz_html = app.create_d3_visualization_html(step_idx=max_steps, threshold=0.01)  # Start with last step
            token_choices = app.get_token_choices()

            return info, viz_html, gr.update(choices=token_choices, value="All tokens"), gr.update(maximum=max_steps, value=max_steps)

        generate_btn.click(
            fn=_generate,
            inputs=[
                prompt_input,
                max_tokens_input,
                gr.State(app.config.DEFAULT_THRESHOLD),  # keep threshold in call but unused
                temperature_input
            ],
            outputs=[generated_info, d3_visualization, token_dropdown, step_slider]
        )

        # Event handlers for visualization controls
        def _update_visualization(step_idx, threshold, filter_token="All tokens"):
            """Update visualization when step or threshold changes."""
            viz_html = app.create_d3_visualization_html(step_idx=int(step_idx), threshold=threshold, filter_token=filter_token)
            return viz_html

        def _filter_by_token(selected_token, step_idx, threshold):
            """Update visualization when token filter changes."""
            viz_html = app.create_d3_visualization_html(step_idx=int(step_idx), threshold=threshold, filter_token=selected_token)
            return viz_html

        # Connect visualization controls
        step_slider.change(
            fn=_update_visualization,
            inputs=[step_slider, threshold_slider, token_dropdown],
            outputs=[d3_visualization]
        )

        threshold_slider.change(
            fn=_update_visualization,
            inputs=[step_slider, threshold_slider, token_dropdown],
            outputs=[d3_visualization]
        )

        token_dropdown.change(
            fn=_filter_by_token,
            inputs=[token_dropdown, step_slider, threshold_slider],
            outputs=[d3_visualization]
        )

        # Load default model on startup
        demo.load(
            fn=app.load_model,
            inputs=[gr.State(app.config.DEFAULT_MODEL)],
            outputs=[model_status]
        )

    return demo
639 |
+
|
640 |
+
if __name__ == "__main__":
    # Report the compute device up front so startup logs show what will be used.
    if torch.cuda.is_available():
        print(f"✅ CUDA available: {torch.cuda.get_device_name(0)}")
    else:
        print("⚠️ CUDA not available, using CPU")

    # Create and launch the app
    demo = create_gradio_interface()
    # NOTE: explicit launch configuration kept below for reference (disabled).
    """ demo.launch(
        share=False, # Set to True for public URL
        server_name="0.0.0.0", # Allow external connections
        server_port=7860, # Default Gradio port
        inbrowser=False # Don't auto-open browser
    ) """

    demo.launch()
|
claude.md
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Claude Code Instructions - Token Attention Visualizer
|
2 |
+
|
3 |
+
## Project Overview
|
4 |
+
You are helping to build a Token Attention Visualizer - a web-based tool that visualizes attention weights in Large Language Models (LLMs) during text generation. The tool shows how input tokens influence the generation of output tokens through interactive visualizations.
|
5 |
+
|
6 |
+
## Core Functionality
|
7 |
+
1. Accept a text prompt and generate tokens using a Llama model
|
8 |
+
2. Extract and process attention matrices from the model
|
9 |
+
3. Create an interactive visualization showing token relationships
|
10 |
+
4. Allow users to click tokens to filter connections
|
11 |
+
5. Provide step-by-step navigation through the generation process
|
12 |
+
|
13 |
+
## Tech Stack
|
14 |
+
- **Backend**: FastAPI
|
15 |
+
- **Frontend**: Gradio (for easy Hugging Face Spaces deployment)
|
16 |
+
- **Visualization**: Plotly (interactive graphs)
|
17 |
+
- **ML**: Transformers, PyTorch
|
18 |
+
- **Models**: Llama models (1B-3B range)
|
19 |
+
|
20 |
+
## Project Structure
|
21 |
+
```
|
22 |
+
token-attention-viz/
|
23 |
+
├── app.py # Main Gradio application
|
24 |
+
├── api/
|
25 |
+
│ ├── __init__.py
|
26 |
+
│ ├── server.py # FastAPI endpoints (optional)
|
27 |
+
│ └── models.py # Pydantic models
|
28 |
+
├── core/
|
29 |
+
│ ├── __init__.py
|
30 |
+
│ ├── model_handler.py # Model loading and generation
|
31 |
+
│ ├── attention.py # Attention processing
|
32 |
+
│ └── cache.py # Caching logic
|
33 |
+
├── visualization/
|
34 |
+
│ ├── __init__.py
|
35 |
+
│ ├── plotly_viz.py # Plotly visualization
|
36 |
+
│ └── utils.py # Token cleaning utilities
|
37 |
+
├── requirements.txt
|
38 |
+
└── config.py # Configuration settings
|
39 |
+
```
|
40 |
+
|
41 |
+
## Implementation Guidelines
|
42 |
+
|
43 |
+
### Critical Code to Preserve from Original Implementation
|
44 |
+
|
45 |
+
1. **Model Loading Logic**:
|
46 |
+
- Device and dtype detection based on GPU capability
|
47 |
+
- Pad token handling for models without it
|
48 |
+
- Error handling for model loading
|
49 |
+
|
50 |
+
2. **Attention Extraction**:
|
51 |
+
- BOS token removal from visualization
|
52 |
+
- EOS token handling
|
53 |
+
- Attention matrix extraction with proper indexing
|
54 |
+
|
55 |
+
3. **Token Cleaning Function**:
|
56 |
+
```python
|
57 |
+
def clean_label(token):
|
58 |
+
label = str(token)
|
59 |
+
label = label.replace('Ġ', ' ')
|
60 |
+
label = label.replace('▁', ' ')
|
61 |
+
label = label.replace('Ċ', '\\n')
|
62 |
+
label = label.replace('</s>', '[EOS]')
|
63 |
+
label = label.replace('<unk>', '[UNK]')
|
64 |
+
label = label.replace('<|begin_of_text|>', '[BOS]')
|
65 |
+
label = label.replace('<|end_of_text|>', '[EOS]')
|
66 |
+
label = re.sub(r'<0x[0-9A-Fa-f]{2}>', '', label)
|
67 |
+
return label.strip() if label.strip() else "[EMPTY]"
|
68 |
+
```
|
69 |
+
|
70 |
+
4. **Attention Processing with Separate Normalization**:
|
71 |
+
- Layer averaging across heads and layers
|
72 |
+
- Separate normalization for input and output attention
|
73 |
+
- Epsilon handling (1e-8) to avoid division by zero
|
74 |
+
|
75 |
+
5. **Interactive Features**:
|
76 |
+
- Token click handling to show specific connections
|
77 |
+
- Reset selection functionality
|
78 |
+
- Step-by-step navigation
|
79 |
+
- "All Connections" view
|
80 |
+
|
81 |
+
### Key Implementation Details
|
82 |
+
|
83 |
+
#### Model Handler (`core/model_handler.py`)
|
84 |
+
- Use `unsloth/Llama-3.2-1B-Instruct` as default model
|
85 |
+
- Implement proper device detection (CUDA if available)
|
86 |
+
- Use bfloat16 for GPUs with compute capability >= 8.0
|
87 |
+
- Generate with `output_attentions=True` and `return_dict_in_generate=True`
|
88 |
+
|
89 |
+
#### Attention Processing (`core/attention.py`)
|
90 |
+
- Extract attention for each generation step
|
91 |
+
- Average across all layers and heads
|
92 |
+
- Apply separate normalization (input and output attention normalized independently)
|
93 |
+
- Handle edge cases (first token has no output-to-output attention)
|
94 |
+
|
95 |
+
#### Visualization (`visualization/plotly_viz.py`)
|
96 |
+
- **Layout**:
|
97 |
+
- Input tokens on left (x=0.1)
|
98 |
+
- Output tokens on right (x=0.9)
|
99 |
+
- Use linspace for y-coordinates
|
100 |
+
- **Connections**:
|
101 |
+
- Blue lines for input→output attention
|
102 |
+
- Orange curved lines for output→output attention
|
103 |
+
- Line thickness proportional to attention weight
|
104 |
+
- Only show connections above threshold
|
105 |
+
- **Interactivity**:
|
106 |
+
- Click on any token to filter connections
|
107 |
+
- Highlight selected token in yellow
|
108 |
+
- Show previously generated tokens in pink
|
109 |
+
- Current generating token in coral
|
110 |
+
|
111 |
+
#### Gradio Interface (`app.py`)
|
112 |
+
- **Input Controls**:
|
113 |
+
- Text area for prompt
|
114 |
+
- Slider for max tokens (1-50)
|
115 |
+
- Slider for attention threshold (0.0-0.2, step 0.001)
|
116 |
+
- **Visualization Controls**:
|
117 |
+
- Step slider for navigation
|
118 |
+
- Reset Selection button
|
119 |
+
- Show All Connections button
|
120 |
+
- **Display**:
|
121 |
+
- Generated text output
|
122 |
+
- Interactive Plotly graph
|
123 |
+
|
124 |
+
### Performance Optimizations
|
125 |
+
|
126 |
+
1. **Caching**:
|
127 |
+
- Cache generated attention matrices by prompt+max_tokens hash
|
128 |
+
- LRU cache with configurable size (default 10)
|
129 |
+
- Store processed attention, not raw tensors
|
130 |
+
|
131 |
+
2. **Lazy Updates**:
|
132 |
+
- Only update changed traces when stepping through
|
133 |
+
- Don't recreate entire plot on threshold change
|
134 |
+
- Use Plotly's batch_update for multiple changes
|
135 |
+
|
136 |
+
3. **Memory Management**:
|
137 |
+
- Clear raw attention tensors after processing
|
138 |
+
- Convert to CPU tensors for storage
|
139 |
+
- Use float32 instead of original dtype for visualization
|
140 |
+
|
141 |
+
### Configuration (`config.py`)
|
142 |
+
```python
|
143 |
+
DEFAULT_MODEL = "unsloth/Llama-3.2-1B-Instruct"
|
144 |
+
DEFAULT_PROMPT = "The old wizard walked through the forest"
|
145 |
+
DEFAULT_MAX_TOKENS = 20
|
146 |
+
DEFAULT_THRESHOLD = 0.05
|
147 |
+
MIN_LINE_WIDTH = 0.5
|
148 |
+
MAX_LINE_WIDTH = 3.0
|
149 |
+
PLOT_WIDTH = 1000
|
150 |
+
PLOT_HEIGHT = 600
|
151 |
+
```
|
152 |
+
|
153 |
+
### Deployment Preparation
|
154 |
+
|
155 |
+
For Hugging Face Spaces deployment:
|
156 |
+
1. Create proper `requirements.txt` with pinned versions
|
157 |
+
2. Add `README.md` with Spaces metadata
|
158 |
+
3. Ensure model downloads work in Spaces environment
|
159 |
+
4. Set appropriate memory/GPU requirements
|
160 |
+
|
161 |
+
## Testing Instructions
|
162 |
+
|
163 |
+
1. **Basic Functionality**:
|
164 |
+
- Test with default prompt
|
165 |
+
- Verify attention matrices are extracted correctly
|
166 |
+
- Check visualization renders properly
|
167 |
+
|
168 |
+
2. **Interactive Features**:
|
169 |
+
- Click on input tokens - should show only their connections to outputs
|
170 |
+
- Click on output tokens - should show incoming connections
|
171 |
+
- Reset button should clear selection
|
172 |
+
- Step slider should navigate through generation
|
173 |
+
|
174 |
+
3. **Edge Cases**:
|
175 |
+
- Empty prompt
|
176 |
+
- Single token generation
|
177 |
+
- Very long prompts (>100 tokens)
|
178 |
+
- High/low threshold values
|
179 |
+
|
180 |
+
## Development Workflow
|
181 |
+
|
182 |
+
1. Start by implementing the model handler and verify generation works
|
183 |
+
2. Add attention extraction and processing
|
184 |
+
3. Create basic visualization without interactivity
|
185 |
+
4. Add interactive features one by one
|
186 |
+
5. Implement caching
|
187 |
+
6. Create Gradio interface
|
188 |
+
7. Test and optimize performance
|
189 |
+
8. Prepare for deployment
|
190 |
+
|
191 |
+
## Important Notes
|
192 |
+
|
193 |
+
- Preserve the token cleaning logic exactly as it handles special tokens
|
194 |
+
- Keep the BOS token removal logic for cleaner visualization
|
195 |
+
- Maintain separate normalization (not joint) for attention weights
|
196 |
+
- Ensure CUDA memory is properly managed to avoid OOM errors
|
197 |
+
- Test with different model sizes based on available GPU memory
|
198 |
+
|
199 |
+
## Common Issues and Solutions
|
200 |
+
|
201 |
+
1. **CUDA OOM**: Reduce batch size or use smaller model
|
202 |
+
2. **Slow Generation**: Enable GPU, use smaller model, or implement streaming
|
203 |
+
3. **Visualization Lag**: Reduce number of traces, implement virtualization
|
204 |
+
4. **Cache Misses**: Normalize prompt formatting before hashing
|
205 |
+
|
206 |
+
When implementing, prioritize functionality over optimization initially. Get the core visualization working first, then add caching and performance improvements.
|
config.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
@dataclass
class Config:
    """Application-wide settings for the Token Attention Visualizer.

    Grouped into model, generation, visualization, cache and UI sections.
    A default ``Config()`` instance carries the shipped defaults.
    """

    # Model settings
    DEFAULT_MODEL: str = "HuggingFaceTB/SmolLM-135M-Instruct"  # HF hub id loaded on startup
    DEVICE: str = "cpu"  # Force CPU usage (ModelHandler honours this even when CUDA exists)

    # Generation settings
    DEFAULT_MAX_TOKENS: int = 20
    DEFAULT_PROMPT: str = "The old wizard walked through the forest when he"
    DEFAULT_TEMPERATURE: float = 0.7
    DEFAULT_TOP_P: float = 0.95

    # Visualization settings
    DEFAULT_THRESHOLD: float = 0.05  # minimum attention weight drawn as a connection
    MIN_LINE_WIDTH: float = 0.5
    MAX_LINE_WIDTH: float = 3.0

    # Colors
    INPUT_COLOR: str = "skyblue"   # prompt-token nodes
    OUTPUT_COLOR: str = "coral"    # generated-token nodes
    CONNECTION_COLOR: str = "rgba(128, 128, 128, 0.3)"

    # Cache settings
    CACHE_SIZE: int = 10  # Number of generations to cache

    # UI settings
    PLOT_WIDTH: int = 1000
    PLOT_HEIGHT: int = 600

    # Node settings
    NODE_SIZE: int = 15
    NODE_LINE_WIDTH: float = 2

    # Font settings
    FONT_SIZE: int = 10
    FONT_FAMILY: str = "Arial, sans-serif"
|
core/__init__.py
ADDED
File without changes
|
core/attention.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from typing import List, Dict, Any, Optional
|
4 |
+
|
5 |
+
class AttentionProcessor:
    """Converts raw per-step generation attentions into per-token weight vectors.

    All methods expect an ``attention_data`` dict with keys:
      - ``'attentions'``: one entry per generated token; each entry is a
        sequence of per-layer tensors shaped (batch, heads, 1, seq_len)
      - ``'input_len_for_attention'``: number of prompt tokens (BOS excluded)
      - ``'output_len'``: number of generated tokens to report

    The ``process_*`` methods return one dict per generation step with
    ``'input_attention'`` (weights over prompt tokens) and
    ``'output_attention'`` (weights over previously generated tokens;
    ``None`` at step 0).
    """

    @staticmethod
    def _empty_steps(input_len: int, output_len: int) -> List[Dict[str, torch.Tensor]]:
        """Zero-filled result used when no attentions were captured."""
        return [{'input_attention': torch.zeros(input_len),
                 'output_attention': None} for _ in range(output_len)]

    @staticmethod
    def _average_step_attention(step_attentions, step_idx: int, input_len: int):
        """Average one step's attention across all layers and heads.

        Shared by both process_* methods (previously duplicated in each).
        Defensive: a layer whose sequence dimension is too short, or that
        raises, contributes exactly one zero row per list — the earlier
        duplicated loops could append a second zero row for the input list
        when an exception fired after the input append, skewing the average.

        Returns ``(avg_input, avg_output)``; ``avg_output`` is None at step 0.
        """
        input_layers = []
        output_layers = []
        input_stop = 1 + input_len  # key position 0 is BOS and is skipped

        for layer_idx, layer_attn in enumerate(step_attentions):
            heads = layer_attn.shape[1]
            in_vec = None
            out_vec = None
            try:
                # Attention from the token being generated (query position 0).
                if layer_attn.shape[3] >= input_stop:
                    in_vec = layer_attn[0, :, 0, 1:input_stop]
                    # Attention to previously generated tokens.
                    if step_idx > 0:
                        output_stop = input_stop + step_idx
                        if layer_attn.shape[3] >= output_stop:
                            out_vec = layer_attn[0, :, 0, input_stop:output_stop]
            except Exception as e:
                print(f"Error processing attention at step {step_idx}, layer {layer_idx}: {e}")
                in_vec = None
                out_vec = None

            # Exactly one append per layer per list keeps stack shapes aligned.
            input_layers.append(
                in_vec if in_vec is not None else
                torch.zeros((heads, input_len), device=layer_attn.device))
            if step_idx > 0:
                output_layers.append(
                    out_vec if out_vec is not None else
                    torch.zeros((heads, step_idx), device=layer_attn.device))

        if input_layers:
            # Mean over layers (stack dim 0) and heads (dim 1).
            avg_input = torch.mean(torch.stack(input_layers).float(), dim=[0, 1])
        else:
            avg_input = torch.zeros(input_len)

        avg_output = None
        if step_idx > 0:
            if output_layers:
                avg_output = torch.mean(torch.stack(output_layers).float(), dim=[0, 1])
            else:
                avg_output = torch.zeros(step_idx)

        return avg_input, avg_output

    @staticmethod
    def process_attention_separate(
        attention_data: Dict[str, Any],
        input_tokens: List[str],
        output_tokens: List[str]
    ) -> List[Dict[str, torch.Tensor]]:
        """Process attention with separate normalization for input and output.

        Input and output attention are each normalized to sum to 1 on their
        own, preserving the relative importance *within* each group.  Raw
        (unnormalized) averages are also returned for analysis.
        """
        attentions = attention_data['attentions']
        input_len = attention_data['input_len_for_attention']
        output_len = attention_data['output_len']

        if not attentions:
            return AttentionProcessor._empty_steps(input_len, output_len)

        epsilon = 1e-8  # guards against division by zero
        matrices = []

        for i in range(min(len(attentions), output_len)):
            avg_in, avg_out = AttentionProcessor._average_step_attention(
                attentions[i], i, input_len)

            normalized_in = avg_in / (avg_in.sum() + epsilon)
            normalized_out = None
            if i > 0 and avg_out is not None:
                normalized_out = avg_out / (avg_out.sum() + epsilon)

            matrices.append({
                'input_attention': normalized_in.cpu(),
                'output_attention': normalized_out.cpu() if normalized_out is not None else None,
                'raw_input_attention': avg_in.cpu(),  # kept raw for analysis
                'raw_output_attention': avg_out.cpu() if avg_out is not None else None
            })

        # Pad with zeros if generation produced fewer steps than output_len.
        while len(matrices) < output_len:
            matrices.append({
                'input_attention': torch.zeros(input_len),
                'output_attention': None,
                'raw_input_attention': torch.zeros(input_len),
                'raw_output_attention': None
            })

        return matrices

    @staticmethod
    def process_attention_joint(
        attention_data: Dict[str, Any],
        input_tokens: List[str],
        output_tokens: List[str]
    ) -> List[Dict[str, torch.Tensor]]:
        """Process attention with joint normalization across input and output.

        Input and output attention are concatenated and normalized together,
        preserving the relative importance *across* all tokens.
        """
        attentions = attention_data['attentions']
        input_len = attention_data['input_len_for_attention']
        output_len = attention_data['output_len']

        if not attentions:
            return AttentionProcessor._empty_steps(input_len, output_len)

        epsilon = 1e-8
        matrices = []

        for i in range(min(len(attentions), output_len)):
            avg_in, avg_out = AttentionProcessor._average_step_attention(
                attentions[i], i, input_len)

            if i > 0 and avg_out is not None:
                # Normalize the concatenated vector, then split it back.
                combined = torch.cat([avg_in, avg_out])
                normalized = combined / (combined.sum() + epsilon)
                normalized_in = normalized[:input_len]
                normalized_out = normalized[input_len:]
            else:
                # Step 0: only input attention exists.
                normalized_in = avg_in / (avg_in.sum() + epsilon)
                normalized_out = None

            matrices.append({
                'input_attention': normalized_in.cpu(),
                'output_attention': normalized_out.cpu() if normalized_out is not None else None
            })

        while len(matrices) < output_len:
            matrices.append({
                'input_attention': torch.zeros(input_len),
                'output_attention': None
            })

        return matrices

    @staticmethod
    def extract_attention_for_step(
        attention_data: Dict[str, Any],
        step: int,
        input_len: int
    ) -> Dict[str, torch.Tensor]:
        """Extract normalized attention for a single generation step.

        Lighter-weight than the process_* methods: layers with too-short
        sequences are simply skipped (no zero padding, no error trapping),
        and no raw averages are returned.
        """
        attentions = attention_data['attentions']

        if step >= len(attentions):
            return {'input_attention': torch.zeros(input_len),
                    'output_attention': None}

        input_layers = []
        output_layers = []
        input_stop = 1 + input_len  # skip BOS at key position 0

        for layer_attn in attentions[step]:
            if layer_attn.shape[3] >= input_stop:
                input_layers.append(layer_attn[0, :, 0, 1:input_stop])

            if step > 0:
                output_stop = input_stop + step
                if layer_attn.shape[3] >= output_stop:
                    output_layers.append(layer_attn[0, :, 0, input_stop:output_stop])

        if input_layers:
            avg_in = torch.mean(torch.stack(input_layers).float(), dim=[0, 1])
            normalized_in = avg_in / (avg_in.sum() + 1e-8)
        else:
            normalized_in = torch.zeros(input_len)

        normalized_out = None
        if step > 0 and output_layers:
            avg_out = torch.mean(torch.stack(output_layers).float(), dim=[0, 1])
            normalized_out = avg_out / (avg_out.sum() + 1e-8)

        return {'input_attention': normalized_in.cpu(),
                'output_attention': normalized_out.cpu() if normalized_out is not None else None}
|
core/cache.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any, Optional
|
2 |
+
import hashlib
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import pickle
|
6 |
+
import io
|
7 |
+
|
8 |
+
class AttentionCache:
    """In-memory LRU cache for processed generation/attention results.

    Values are stored pickled, with torch tensors inside attention-matrix
    lists converted to CPU numpy arrays, so cached entries do not pin GPU
    memory; they are rehydrated to torch tensors on ``get``.
    """

    def __init__(self, max_size: int = 10):
        self.cache = {}          # key -> pickled payload (bytes)
        self.access_order = []   # keys ordered least- to most-recently used
        self.max_size = max_size

    def get_key(self, prompt: str, max_tokens: int, model: str, temperature: float = 0.7) -> str:
        """Generate a deterministic cache key from the generation parameters."""
        data = f"{prompt}_{max_tokens}_{model}_{temperature}"
        return hashlib.md5(data.encode()).hexdigest()

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the cached entry for `key` (marking it recently used), or None."""
        if key not in self.cache:
            return None
        # Refresh LRU position.
        self.access_order.remove(key)
        self.access_order.append(key)
        return self._deserialize(self.cache[key])

    def set(self, key: str, data: Dict[str, Any]):
        """Store `data` under `key`, evicting the least-recently-used entry if full.

        Overwriting an existing key only refreshes its LRU position; it never
        triggers an eviction.  (The previous implementation evicted an
        unrelated entry on overwrite and left a duplicate key in
        ``access_order``, corrupting later LRU bookkeeping.)
        """
        if key in self.cache:
            self.access_order.remove(key)
        elif len(self.cache) >= self.max_size:
            oldest = self.access_order.pop(0)
            del self.cache[oldest]

        self.cache[key] = self._serialize(data)
        self.access_order.append(key)

    def _serialize(self, data: Dict[str, Any]) -> bytes:
        """Pickle `data`, converting torch tensors in attention lists to numpy.

        Detection mirrors the original heuristic: only lists whose *first*
        dict element contains a tensor are converted.
        """
        serialized = {}
        for key, value in data.items():
            if (isinstance(value, list) and value
                    and isinstance(value[0], dict)
                    and any(isinstance(v, torch.Tensor) for v in value[0].values())):
                serialized[key] = [
                    {k: (v.cpu().numpy() if isinstance(v, torch.Tensor) else v)
                     for k, v in item.items()}
                    for item in value
                ]
            else:
                serialized[key] = value

        buffer = io.BytesIO()
        pickle.dump(serialized, buffer)
        return buffer.getvalue()

    def _deserialize(self, data: bytes) -> Dict[str, Any]:
        """Unpickle a cached payload, converting numpy arrays back to tensors."""
        import numpy as np
        deserialized = pickle.load(io.BytesIO(data))

        for value in deserialized.values():
            if isinstance(value, list) and value and isinstance(value[0], dict):
                for item in value:
                    for k, v in item.items():
                        if isinstance(v, np.ndarray):
                            item[k] = torch.from_numpy(v)

        return deserialized

    def clear(self):
        """Drop all cached entries."""
        self.cache.clear()
        self.access_order.clear()

    def size(self) -> int:
        """Number of entries currently cached."""
        return len(self.cache)
|
core/model_handler.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
from typing import Tuple, Optional, List, Dict, Any
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
warnings.filterwarnings("ignore", category=UserWarning, module='transformers.generation')
|
7 |
+
|
8 |
+
class ModelHandler:
    """Wraps loading of a Hugging Face causal LM and attention-capturing generation.

    Responsibilities visible in this class:
    - lazy model/tokenizer loading with device and dtype selection
      (``load_model``),
    - text generation that also returns per-step attention tensors
      (``generate_with_attention``),
    - lightweight introspection of the loaded model (``get_model_info``).
    """

    def __init__(self, model_name: Optional[str] = None, config=None):
        # Model and tokenizer stay None until load_model() succeeds.
        self.model = None
        self.tokenizer = None
        # Torch device string ("cuda" / "cpu"); resolved in load_model().
        self.device = None
        self.model_name = model_name
        # Optional config object; only its DEVICE attribute is read here.
        self.config = config

    def load_model(self, model_name: Optional[str] = None) -> Tuple[bool, str]:
        """Load model with optimized settings.

        Args:
            model_name: Hub id or local path; overrides the one given to
                ``__init__`` when provided.

        Returns:
            ``(success, message)`` — never raises; all failures are folded
            into a ``(False, "Error ...")`` tuple.
        """
        if model_name:
            self.model_name = model_name

        if not self.model_name:
            return False, "No model name provided"

        try:
            print(f"Loading model: {self.model_name}...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            # Determine device and dtype
            if self.config and hasattr(self.config, 'DEVICE'):
                self.device = self.config.DEVICE
                # If config specifies CPU, force it even if CUDA is available
                if self.device == "cpu":
                    print("Forcing CPU usage as specified in config")
                elif self.device == "cuda" and not torch.cuda.is_available():
                    print("CUDA requested but not available, falling back to CPU")
                    self.device = "cpu"
            else:
                # Fallback to auto-detection if no config provided
                self.device = "cuda" if torch.cuda.is_available() else "cpu"

            # Use bfloat16 for Ampere GPUs (compute capability >= 8.0), otherwise float32
            if self.device == "cuda" and torch.cuda.is_available():
                capability = torch.cuda.get_device_capability()
                if capability[0] >= 8:
                    dtype = torch.bfloat16
                else:
                    dtype = torch.float32
            else:
                dtype = torch.float32

            # Load model. Eager attention is required because SDPA/flash
            # implementations do not return attention weights.
            # Fallback chain: (dtype + eager) -> (eager only) -> defaults.
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=dtype,
                    attn_implementation="eager"  # Force eager attention for attention extraction
                ).to(self.device)
                print(f"Model loaded on {self.device} with dtype {dtype} (eager attention)")
            except Exception as e:
                print(f"Error loading model with specific dtype: {e}")
                print("Attempting to load without specific dtype...")
                try:
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        attn_implementation="eager"
                    ).to(self.device)
                    print(f"Model loaded on {self.device} (default dtype, eager attention)")
                except Exception as e2:
                    print(f"Error with eager attention: {e2}")
                    print("Loading with default settings...")
                    # NOTE(review): with a non-eager implementation,
                    # generate_with_attention will likely get no attentions.
                    self.model = AutoModelForCausalLM.from_pretrained(self.model_name).to(self.device)
                    print(f"Model loaded on {self.device} (default settings)")

            # Handle pad token: many causal LMs ship without one, and
            # generate() needs a pad id for batching/warnings.
            if self.tokenizer.pad_token is None:
                if self.tokenizer.eos_token:
                    print("Setting pad_token to eos_token")
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                    if hasattr(self.model.config, 'pad_token_id') and self.model.config.pad_token_id is None:
                        self.model.config.pad_token_id = self.tokenizer.eos_token_id
                else:
                    print("Warning: No eos_token found to set as pad_token.")

            return True, f"Model loaded successfully on {self.device}"

        except Exception as e:
            return False, f"Error loading model: {str(e)}"

    def generate_with_attention(
        self,
        prompt: str,
        max_tokens: int = 30,
        temperature: float = 0.7,
        top_p: float = 0.95
    ) -> Tuple[Optional[Dict[str, Any]], List[str], List[str], str]:
        """Generate text and capture attention weights.

        Args:
            prompt: Raw text to condition generation on.
            max_tokens: Maximum number of new tokens to generate.
            temperature: Sampling temperature; sampling is disabled when 0.
            top_p: Nucleus-sampling probability mass.

        Returns:
            ``(attention_info, output_tokens, input_tokens, generated_text)``.
            ``attention_info`` is a dict with keys ``'attentions'`` (the raw
            per-step attentions from ``model.generate``), ``'input_len_for_attention'``
            and ``'output_len'`` — or ``None`` when the model is not loaded,
            generation fails, or no attentions were returned. In the failure
            cases the last tuple element carries the error message instead of
            generated text.
        """
        if not self.model or not self.tokenizer:
            return None, [], [], "Model not loaded"

        # Encode input
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        input_len_raw = input_ids.shape[1]

        print(f"Generating with input length: {input_len_raw}, max_new_tokens: {max_tokens}")

        # Generate with attention
        with torch.no_grad():
            attention_mask = torch.ones_like(input_ids)
            gen_kwargs = {
                "attention_mask": attention_mask,
                "max_new_tokens": max_tokens,
                "output_attentions": True,
                "return_dict_in_generate": True,
                "temperature": temperature,
                "top_p": top_p,
                # Greedy decoding when temperature == 0, sampling otherwise.
                "do_sample": temperature > 0
            }

            if self.tokenizer.pad_token_id is not None:
                gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id

            try:
                output = self.model.generate(input_ids, **gen_kwargs)
            except Exception as e:
                print(f"Error during generation: {e}")
                return None, [], [], f"Error during generation: {str(e)}"

        # Extract generated tokens (everything after the prompt).
        full_sequence = output.sequences[0]
        if full_sequence.shape[0] > input_len_raw:
            generated_ids = full_sequence[input_len_raw:]
        else:
            # Nothing was generated (e.g. immediate EOS).
            generated_ids = torch.tensor([], dtype=torch.long, device=self.device)

        # Convert to tokens
        output_tokens = self.tokenizer.convert_ids_to_tokens(generated_ids, skip_special_tokens=False)
        input_tokens_raw = self.tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)

        # Handle BOS token removal from visualization; the fallback literal
        # is the Llama-3-style marker — presumably for tokenizers that have
        # no bos_token attribute set (TODO confirm against target models).
        input_tokens = input_tokens_raw
        input_len_for_attention = input_len_raw
        bos_token = self.tokenizer.bos_token or '<|begin_of_text|>'

        if input_tokens_raw and input_tokens_raw[0] == bos_token:
            input_tokens = input_tokens_raw[1:]
            input_len_for_attention = input_len_raw - 1

        # Handle EOS token removal (only a single trailing EOS is stripped).
        eos_token = self.tokenizer.eos_token or '<|end_of_text|>'
        if output_tokens and output_tokens[-1] == eos_token:
            output_tokens = output_tokens[:-1]
            generated_ids = generated_ids[:-1]

        # Decode generated text
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Extract attention weights; absent when the attention implementation
        # does not expose them (see the load_model fallback chain).
        attentions = getattr(output, 'attentions', None)
        if attentions is None:
            print("Warning: 'attentions' not found in model output. Cannot visualize attention.")
            return None, output_tokens, input_tokens, generated_text

        # Return raw attention, tokens, and metadata
        return {
            'attentions': attentions,
            'input_len_for_attention': input_len_for_attention,
            'output_len': len(output_tokens)
        }, output_tokens, input_tokens, generated_text

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model.

        Returns:
            ``{"loaded": False}`` when no model is loaded, otherwise a dict
            with name, device, parameter count, dtype and vocab size.
        """
        if not self.model:
            return {"loaded": False}

        return {
            "loaded": True,
            "model_name": self.model_name,
            "device": str(self.device),
            "num_parameters": sum(p.numel() for p in self.model.parameters()),
            # dtype of the first parameter; assumes a uniform-dtype model.
            "dtype": str(next(self.model.parameters()).dtype),
            "vocab_size": self.tokenizer.vocab_size if self.tokenizer else 0
        }
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
transformers>=4.30.0
|
3 |
+
gradio>=4.0.0
|
4 |
+
plotly>=5.14.0
|
5 |
+
numpy>=1.24.0
|
6 |
+
accelerate>=0.20.0
|
7 |
+
sentencepiece>=0.1.99
|
8 |
+
protobuf>=3.20.0
|
visualization/__init__.py
ADDED
File without changes
|
visualization/d3_viz.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
D3.js visualization module for interactive token attention visualization.
|
3 |
+
"""
|
4 |
+
|
5 |
+
def create_d3_visualization(data):
|
6 |
+
"""
|
7 |
+
Generate a complete, self-contained HTML string with embedded D3.js visualization.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
data (dict): JSON structure with nodes and links from prepare_d3_data()
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
str: Complete HTML string with embedded D3.js, CSS, and JavaScript
|
14 |
+
"""
|
15 |
+
|
16 |
+
# Get nodes by type
|
17 |
+
input_nodes = [node for node in data.get('nodes', []) if node.get('type') == 'input']
|
18 |
+
output_nodes = [node for node in data.get('nodes', []) if node.get('type') == 'output']
|
19 |
+
links = data.get('links', [])
|
20 |
+
|
21 |
+
# SVG dimensions
|
22 |
+
width = 800
|
23 |
+
height = max(400, max(len(input_nodes), len(output_nodes)) * 50 + 100)
|
24 |
+
|
25 |
+
# Calculate positions
|
26 |
+
input_x = 100
|
27 |
+
output_x = width - 100
|
28 |
+
|
29 |
+
# Position nodes vertically
|
30 |
+
def get_y_pos(index, total):
|
31 |
+
if total <= 1:
|
32 |
+
return height // 2
|
33 |
+
return 80 + (index * (height - 160)) / (total - 1)
|
34 |
+
|
35 |
+
# Start building SVG
|
36 |
+
svg_html = f"""
|
37 |
+
<div style='display: flex; flex-direction: column; align-items: center; border: 1px solid #ddd; padding: 20px; margin: 10px; background: white; border-radius: 8px;'>
|
38 |
+
<div style='text-align: center; margin-bottom: 15px;'>
|
39 |
+
<h3 style='margin: 0; color: #333;'>Token Attention Visualization</h3>
|
40 |
+
<p style='margin: 5px 0; color: #666;'>Step {data.get('step', 0) + 1} | {len(input_nodes)} input → {len(output_nodes)} output | {len(links)} connections</p>
|
41 |
+
</div>
|
42 |
+
|
43 |
+
<svg width="{width}" height="{height}" style='border: 1px solid #eee; background: #fafafa; display: block;'>
|
44 |
+
<!-- Background grid -->
|
45 |
+
<defs>
|
46 |
+
<pattern id="grid" width="20" height="20" patternUnits="userSpaceOnUse">
|
47 |
+
<path d="M 20 0 L 0 0 0 20" fill="none" stroke="#f0f0f0" stroke-width="1"/>
|
48 |
+
</pattern>
|
49 |
+
</defs>
|
50 |
+
<rect width="100%" height="100%" fill="url(#grid)" />
|
51 |
+
|
52 |
+
<!-- Column headers -->
|
53 |
+
<text x="{input_x}" y="30" text-anchor="middle" font-size="16" font-weight="bold" fill="#4285f4">Input Tokens</text>
|
54 |
+
<text x="{output_x}" y="30" text-anchor="middle" font-size="16" font-weight="bold" fill="#ea4335">Output Tokens</text>
|
55 |
+
"""
|
56 |
+
|
57 |
+
# Draw connections first (so they appear behind nodes)
|
58 |
+
for link in links:
|
59 |
+
# Find source and target nodes
|
60 |
+
source_node = next((n for n in input_nodes + output_nodes if n['id'] == link['source']), None)
|
61 |
+
target_node = next((n for n in input_nodes + output_nodes if n['id'] == link['target']), None)
|
62 |
+
|
63 |
+
if source_node and target_node:
|
64 |
+
# Get positions
|
65 |
+
if source_node['type'] == 'input':
|
66 |
+
source_idx = next((i for i, n in enumerate(input_nodes) if n['id'] == source_node['id']), 0)
|
67 |
+
source_y = get_y_pos(source_idx, len(input_nodes))
|
68 |
+
source_x_pos = input_x + 20 # Offset from center of node
|
69 |
+
else:
|
70 |
+
source_idx = next((i for i, n in enumerate(output_nodes) if n['id'] == source_node['id']), 0)
|
71 |
+
source_y = get_y_pos(source_idx, len(output_nodes))
|
72 |
+
source_x_pos = output_x - 20
|
73 |
+
|
74 |
+
if target_node['type'] == 'input':
|
75 |
+
target_idx = next((i for i, n in enumerate(input_nodes) if n['id'] == target_node['id']), 0)
|
76 |
+
target_y = get_y_pos(target_idx, len(input_nodes))
|
77 |
+
target_x_pos = input_x - 20
|
78 |
+
else:
|
79 |
+
target_idx = next((i for i, n in enumerate(output_nodes) if n['id'] == target_node['id']), 0)
|
80 |
+
target_y = get_y_pos(target_idx, len(output_nodes))
|
81 |
+
target_x_pos = output_x - 20
|
82 |
+
|
83 |
+
# Line properties based on weight
|
84 |
+
stroke_width = max(1, min(8, link['weight'] * 20))
|
85 |
+
opacity = max(0.3, min(1.0, link['weight'] * 2))
|
86 |
+
color = "#4285f4" if link['type'] == 'input_to_output' else "#ea4335"
|
87 |
+
|
88 |
+
# Create straight line
|
89 |
+
svg_html += f'''
|
90 |
+
<line x1="{source_x_pos}" y1="{source_y}" x2="{target_x_pos}" y2="{target_y}"
|
91 |
+
stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>
|
92 |
+
'''
|
93 |
+
|
94 |
+
# Draw input nodes
|
95 |
+
for i, node in enumerate(input_nodes):
|
96 |
+
y = get_y_pos(i, len(input_nodes))
|
97 |
+
token_text = node['token']
|
98 |
+
|
99 |
+
# Clean token text - remove special prefix characters
|
100 |
+
if token_text.startswith('Ġ'):
|
101 |
+
token_text = token_text[1:] # Remove Ġ prefix
|
102 |
+
if token_text.startswith('▁'):
|
103 |
+
token_text = token_text[1:] # Remove ▁ prefix (SentencePiece)
|
104 |
+
if token_text.startswith('##'):
|
105 |
+
token_text = token_text[2:] # Remove ## prefix (BERT subwords)
|
106 |
+
|
107 |
+
if len(token_text) > 15:
|
108 |
+
token_text = token_text[:13] + "..."
|
109 |
+
|
110 |
+
svg_html += f'''
|
111 |
+
<g>
|
112 |
+
<circle cx="{input_x}" cy="{y}" r="12" fill="#4285f4" stroke="#1a73e8" stroke-width="2" opacity="0.9"/>
|
113 |
+
<text x="{input_x - 20}" y="{y + 4}" text-anchor="end" font-size="12" fill="#333" font-weight="bold">{token_text}</text>
|
114 |
+
</g>
|
115 |
+
'''
|
116 |
+
|
117 |
+
# Draw output nodes
|
118 |
+
for i, node in enumerate(output_nodes):
|
119 |
+
y = get_y_pos(i, len(output_nodes))
|
120 |
+
token_text = node['token']
|
121 |
+
|
122 |
+
# Clean token text - remove special prefix characters
|
123 |
+
if token_text.startswith('Ġ'):
|
124 |
+
token_text = token_text[1:] # Remove Ġ prefix
|
125 |
+
if token_text.startswith('▁'):
|
126 |
+
token_text = token_text[1:] # Remove ▁ prefix (SentencePiece)
|
127 |
+
if token_text.startswith('##'):
|
128 |
+
token_text = token_text[2:] # Remove ## prefix (BERT subwords)
|
129 |
+
|
130 |
+
if len(token_text) > 15:
|
131 |
+
token_text = token_text[:13] + "..."
|
132 |
+
|
133 |
+
svg_html += f'''
|
134 |
+
<g>
|
135 |
+
<circle cx="{output_x}" cy="{y}" r="12" fill="#ea4335" stroke="#d33b2c" stroke-width="2" opacity="0.9"/>
|
136 |
+
<text x="{output_x + 20}" y="{y + 4}" text-anchor="start" font-size="12" fill="#333" font-weight="bold">{token_text}</text>
|
137 |
+
</g>
|
138 |
+
'''
|
139 |
+
|
140 |
+
# Close SVG and add legend
|
141 |
+
svg_html += '''
|
142 |
+
</svg>
|
143 |
+
|
144 |
+
<div style='margin-top: 20px; padding: 16px; background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px;'>
|
145 |
+
<div style='display: flex; justify-content: center; align-items: center; gap: 32px; font-size: 12px; color: #64748b; font-family: Inter, sans-serif;'>
|
146 |
+
<div style='display: flex; align-items: center; gap: 8px;'>
|
147 |
+
<div style='width: 16px; height: 2px; background: #4285f4; border-radius: 1px;'></div>
|
148 |
+
<span style='color: #1e293b; font-weight: 500;'>Input → Output</span>
|
149 |
+
</div>
|
150 |
+
<div style='display: flex; align-items: center; gap: 8px;'>
|
151 |
+
<div style='display: flex; gap: 2px;'>
|
152 |
+
<div style='width: 8px; height: 1px; background: #64748b;'></div>
|
153 |
+
<div style='width: 8px; height: 2px; background: #64748b;'></div>
|
154 |
+
<div style='width: 8px; height: 3px; background: #64748b;'></div>
|
155 |
+
</div>
|
156 |
+
<span style='color: #1e293b; font-weight: 500;'>Line thickness = weight</span>
|
157 |
+
</div>
|
158 |
+
</div>
|
159 |
+
</div>
|
160 |
+
</div>
|
161 |
+
'''
|
162 |
+
|
163 |
+
return svg_html
|
164 |
+
|
165 |
+
def create_d3_visualization_old(data):
    """
    OLD VERSION - Generate a complete, self-contained HTML string with embedded D3.js visualization.

    Kept for reference only; despite the name it does not actually use D3 —
    the embedded script builds a plain SVG via string concatenation
    ("Simple visualization without D3 first - just to test").

    Args:
        data (dict): JSON structure with nodes and links from prepare_d3_data()

    Returns:
        str: Complete HTML string with embedded D3.js, CSS, and JavaScript
    """

    # The whole document is a single f-string; literal CSS/JS braces are
    # doubled ({{ }}) so Python only interpolates the intended expressions.
    # NOTE(review): `repr(data)` is embedded as a JS object literal — this
    # relies on the dict containing only str/int/float/list/dict values
    # whose Python repr is valid JavaScript; confirm with callers.
    html_template = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <style>
            .visualization-container {{
                width: 100%;
                height: 600px;
                border: 1px solid #ddd;
                border-radius: 8px;
                background: #fafafa;
                position: relative;
                overflow: hidden;
            }}

            .node {{
                cursor: pointer;
                stroke-width: 2px;
            }}

            .node.input {{
                fill: #4285f4;
                stroke: #1a73e8;
            }}

            .node.output {{
                fill: #ea4335;
                stroke: #d33b2c;
            }}

            .node.highlighted {{
                stroke-width: 4px;
                stroke: #ff6d00;
            }}

            .node.dimmed {{
                opacity: 0.3;
            }}

            .link {{
                stroke: #666;
                stroke-opacity: 0.6;
                fill: none;
            }}

            .link.input-to-output {{
                stroke: #4285f4;
            }}

            .link.output-to-output {{
                stroke: #ea4335;
            }}

            .link.highlighted {{
                stroke-opacity: 1;
                stroke-width: 3px;
            }}

            .link.dimmed {{
                stroke-opacity: 0.1;
            }}

            .token-label {{
                font-family: 'Courier New', monospace;
                font-size: 12px;
                text-anchor: middle;
                dominant-baseline: central;
                fill: white;
                font-weight: bold;
                pointer-events: none;
            }}

            .reset-btn {{
                position: absolute;
                top: 10px;
                right: 10px;
                padding: 8px 16px;
                background: #4285f4;
                color: white;
                border: none;
                border-radius: 4px;
                cursor: pointer;
                font-size: 12px;
                z-index: 100;
            }}

            .reset-btn:hover {{
                background: #1a73e8;
            }}

            .info-panel {{
                position: absolute;
                bottom: 10px;
                left: 10px;
                background: rgba(255, 255, 255, 0.9);
                padding: 8px 12px;
                border-radius: 4px;
                font-size: 11px;
                font-family: Arial, sans-serif;
                border: 1px solid #ddd;
            }}
        </style>
    </head>
    <body>
        <div class="visualization-container" id="viz-container">
            <button class="reset-btn" onclick="resetView()">Reset View</button>
            <div class="info-panel">
                <div>Step: {data.get('step', 0) + 1} / {data.get('total_steps', 1)}</div>
                <div>Nodes: {len(data.get('nodes', []))} | Links: {len(data.get('links', []))}</div>
                <div>Click nodes to filter connections</div>
            </div>
            <svg id="visualization"></svg>
        </div>

        <script>
            // Simple visualization without D3 first - just to test
            const data = {repr(data)};

            // Create simple HTML visualization
            const container = document.getElementById("viz-container");
            let html = "<div style='padding: 20px;'>";
            html += "<h3>Debug Info</h3>";
            html += "<p>Nodes: " + data.nodes.length + "</p>";
            html += "<p>Links: " + data.links.length + "</p>";

            // Simple SVG without D3
            html += "<svg width='800' height='400' style='border: 1px solid #ccc; background: white;'>";

            // Draw input nodes (left side)
            const inputNodes = data.nodes.filter(n => n.type === "input");
            const outputNodes = data.nodes.filter(n => n.type === "output");

            inputNodes.forEach((node, i) => {{
                const y = 50 + i * 40;
                html += `<circle cx="50" cy="${{y}}" r="15" fill="#4285f4" stroke="#1a73e8" stroke-width="2"/>`;
                html += `<text x="80" y="${{y + 5}}" font-size="12" fill="black">${{node.token}}</text>`;
            }});

            // Draw output nodes (right side)
            outputNodes.forEach((node, i) => {{
                const y = 50 + i * 40;
                html += `<circle cx="700" cy="${{y}}" r="15" fill="#ea4335" stroke="#d33b2c" stroke-width="2"/>`;
                html += `<text x="620" y="${{y + 5}}" font-size="12" fill="black" text-anchor="end">${{node.token}}</text>`;
            }});

            // Draw links
            data.links.forEach(link => {{
                const sourceNode = data.nodes.find(n => n.id === link.source);
                const targetNode = data.nodes.find(n => n.id === link.target);
                if (sourceNode && targetNode) {{
                    const sourceIdx = sourceNode.type === "input" ?
                        inputNodes.findIndex(n => n.id === sourceNode.id) :
                        outputNodes.findIndex(n => n.id === sourceNode.id);
                    const targetIdx = targetNode.type === "input" ?
                        inputNodes.findIndex(n => n.id === targetNode.id) :
                        outputNodes.findIndex(n => n.id === targetNode.id);

                    const sourceX = sourceNode.type === "input" ? 65 : 685;
                    const targetX = targetNode.type === "input" ? 65 : 685;
                    const sourceY = 50 + sourceIdx * 40;
                    const targetY = 50 + targetIdx * 40;

                    const strokeWidth = Math.max(1, link.weight * 10);
                    const color = link.type === "input_to_output" ? "#4285f4" : "#ea4335";

                    html += `<line x1="${{sourceX}}" y1="${{sourceY}}" x2="${{targetX}}" y2="${{targetY}}" stroke="${{color}}" stroke-width="${{strokeWidth}}" opacity="0.6"/>`;
                }}
            }});

            html += "</svg>";
            html += "</div>";

            container.innerHTML = html;

        </script>
    </body>
    </html>
    """

    return html_template
|
visualization/plotly_viz.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.graph_objects as go
|
2 |
+
import numpy as np
|
3 |
+
from typing import List, Dict, Any, Optional, Tuple, Callable
|
4 |
+
from .utils import (
|
5 |
+
clean_label, scale_weight_to_width, scale_weight_to_opacity,
|
6 |
+
get_node_positions, create_spline_path, format_attention_text,
|
7 |
+
get_color_for_weight, truncate_token_label
|
8 |
+
)
|
9 |
+
|
10 |
+
class AttentionVisualizer:
|
11 |
+
def __init__(self, config):
|
12 |
+
self.config = config
|
13 |
+
self.current_state = {
|
14 |
+
'selected_token': None,
|
15 |
+
'selected_type': None,
|
16 |
+
'current_step': 0,
|
17 |
+
'show_all': True
|
18 |
+
}
|
19 |
+
self.traces_info = {
|
20 |
+
'input_to_output': [],
|
21 |
+
'output_to_output': [],
|
22 |
+
'input_nodes_idx': None,
|
23 |
+
'output_nodes_idx': None
|
24 |
+
}
|
25 |
+
|
26 |
+
def create_interactive_plot(
|
27 |
+
self,
|
28 |
+
input_tokens: List[str],
|
29 |
+
output_tokens: List[str],
|
30 |
+
attention_matrices: List[Dict],
|
31 |
+
threshold: float = 0.05,
|
32 |
+
initial_step: int = 0,
|
33 |
+
normalization: str = "separate"
|
34 |
+
) -> go.Figure:
|
35 |
+
"""
|
36 |
+
Create the main interactive visualization.
|
37 |
+
"""
|
38 |
+
# Clean labels
|
39 |
+
input_labels = [clean_label(token) for token in input_tokens]
|
40 |
+
output_labels = [clean_label(token) for token in output_tokens]
|
41 |
+
|
42 |
+
num_input = len(input_labels)
|
43 |
+
num_output = len(output_labels)
|
44 |
+
num_steps = len(attention_matrices)
|
45 |
+
|
46 |
+
if num_input == 0 or num_output == 0 or num_steps == 0:
|
47 |
+
return self._create_empty_figure("No data to visualize")
|
48 |
+
|
49 |
+
# Get node positions
|
50 |
+
input_x, input_y, output_x, output_y = get_node_positions(num_input, num_output)
|
51 |
+
|
52 |
+
# Create connection traces
|
53 |
+
traces = []
|
54 |
+
self.traces_info = {
|
55 |
+
'input_to_output': [],
|
56 |
+
'output_to_output': [],
|
57 |
+
'input_nodes_idx': None,
|
58 |
+
'output_nodes_idx': None
|
59 |
+
}
|
60 |
+
|
61 |
+
# Input to output connections
|
62 |
+
for j in range(num_output):
|
63 |
+
for i in range(num_input):
|
64 |
+
weight = 0
|
65 |
+
if j < len(attention_matrices):
|
66 |
+
weight = attention_matrices[j]['input_attention'][i].item()
|
67 |
+
|
68 |
+
opacity = scale_weight_to_opacity(weight, threshold=threshold)
|
69 |
+
width = scale_weight_to_width(weight) if opacity > 0 else 0.5
|
70 |
+
|
71 |
+
trace = go.Scatter(
|
72 |
+
x=[input_x[i], output_x[j]],
|
73 |
+
y=[input_y[i], output_y[j]],
|
74 |
+
mode="lines",
|
75 |
+
line=dict(
|
76 |
+
color=get_color_for_weight(weight, "blue"),
|
77 |
+
width=width
|
78 |
+
),
|
79 |
+
opacity=opacity,
|
80 |
+
showlegend=False,
|
81 |
+
hoverinfo='text',
|
82 |
+
text=format_attention_text(input_labels[i], output_labels[j], weight),
|
83 |
+
hoverlabel=dict(bgcolor="lightskyblue", bordercolor="darkblue"),
|
84 |
+
name=f"in_to_out_{i}_{j}",
|
85 |
+
customdata=[(i, j)],
|
86 |
+
hovertemplate="Input→Output %{customdata[0]}→%{customdata[1]}<extra></extra>"
|
87 |
+
)
|
88 |
+
traces.append(trace)
|
89 |
+
self.traces_info['input_to_output'].append({
|
90 |
+
'input_idx': i,
|
91 |
+
'output_idx': j,
|
92 |
+
'trace_idx': len(traces) - 1
|
93 |
+
})
|
94 |
+
|
95 |
+
# Output to output connections
|
96 |
+
for j in range(1, num_output):
|
97 |
+
for i in range(j):
|
98 |
+
weight = 0
|
99 |
+
if j < len(attention_matrices) and attention_matrices[j]['output_attention'] is not None:
|
100 |
+
if i < len(attention_matrices[j]['output_attention']):
|
101 |
+
weight = attention_matrices[j]['output_attention'][i].item()
|
102 |
+
|
103 |
+
opacity = scale_weight_to_opacity(weight, threshold=threshold)
|
104 |
+
width = scale_weight_to_width(weight) if opacity > 0 else 0.5
|
105 |
+
|
106 |
+
# Create spline path for curved connection
|
107 |
+
path_x, path_y = create_spline_path(
|
108 |
+
output_x[i], output_y[i],
|
109 |
+
output_x[j], output_y[j],
|
110 |
+
control_offset=0.15
|
111 |
+
)
|
112 |
+
|
113 |
+
trace = go.Scatter(
|
114 |
+
x=path_x,
|
115 |
+
y=path_y,
|
116 |
+
mode="lines",
|
117 |
+
line=dict(
|
118 |
+
color=get_color_for_weight(weight, "orange"),
|
119 |
+
width=width,
|
120 |
+
shape='spline'
|
121 |
+
),
|
122 |
+
opacity=opacity,
|
123 |
+
showlegend=False,
|
124 |
+
hoverinfo='text',
|
125 |
+
text=format_attention_text(output_labels[i], output_labels[j], weight),
|
126 |
+
hoverlabel=dict(bgcolor="moccasin", bordercolor="darkorange"),
|
127 |
+
name=f"out_to_out_{i}_{j}"
|
128 |
+
)
|
129 |
+
traces.append(trace)
|
130 |
+
self.traces_info['output_to_output'].append({
|
131 |
+
'from_idx': i,
|
132 |
+
'to_idx': j,
|
133 |
+
'trace_idx': len(traces) - 1
|
134 |
+
})
|
135 |
+
|
136 |
+
# Input nodes
|
137 |
+
input_trace = go.Scatter(
|
138 |
+
x=input_x,
|
139 |
+
y=input_y,
|
140 |
+
mode="markers+text",
|
141 |
+
marker=dict(
|
142 |
+
size=self.config.NODE_SIZE,
|
143 |
+
color=self.config.INPUT_COLOR,
|
144 |
+
line=dict(width=self.config.NODE_LINE_WIDTH, color="darkblue")
|
145 |
+
),
|
146 |
+
selected=dict(
|
147 |
+
marker=dict(
|
148 |
+
size=self.config.NODE_SIZE + 6,
|
149 |
+
color="rgba(0, 0, 200, 0.9)"
|
150 |
+
)
|
151 |
+
),
|
152 |
+
unselected=dict(
|
153 |
+
marker=dict(
|
154 |
+
opacity=0.65
|
155 |
+
)
|
156 |
+
),
|
157 |
+
text=[truncate_token_label(label) for label in input_labels],
|
158 |
+
textfont=dict(size=self.config.FONT_SIZE, family=self.config.FONT_FAMILY),
|
159 |
+
textposition="middle left",
|
160 |
+
name="Input Tokens",
|
161 |
+
hovertemplate="Input: %{text}<br>Click to filter connections<extra></extra>",
|
162 |
+
customdata=[(i, 'input') for i in range(num_input)]
|
163 |
+
)
|
164 |
+
traces.append(input_trace)
|
165 |
+
self.traces_info['input_nodes_idx'] = len(traces) - 1
|
166 |
+
|
167 |
+
# Output nodes
|
168 |
+
output_colors = []
|
169 |
+
for j in range(num_output):
|
170 |
+
if j <= initial_step:
|
171 |
+
output_colors.append(self.config.OUTPUT_COLOR)
|
172 |
+
else:
|
173 |
+
output_colors.append("rgba(230, 230, 230, 0.8)")
|
174 |
+
|
175 |
+
output_trace = go.Scatter(
|
176 |
+
x=output_x,
|
177 |
+
y=output_y,
|
178 |
+
mode="markers+text",
|
179 |
+
marker=dict(
|
180 |
+
size=self.config.NODE_SIZE,
|
181 |
+
color=output_colors,
|
182 |
+
line=dict(width=self.config.NODE_LINE_WIDTH, color="darkred")
|
183 |
+
),
|
184 |
+
selected=dict(
|
185 |
+
marker=dict(
|
186 |
+
size=self.config.NODE_SIZE + 6,
|
187 |
+
color="rgba(200, 80, 0, 0.9)"
|
188 |
+
)
|
189 |
+
),
|
190 |
+
unselected=dict(
|
191 |
+
marker=dict(
|
192 |
+
opacity=0.65
|
193 |
+
)
|
194 |
+
),
|
195 |
+
text=[truncate_token_label(label) for label in output_labels],
|
196 |
+
textfont=dict(size=self.config.FONT_SIZE, family=self.config.FONT_FAMILY),
|
197 |
+
textposition="middle right",
|
198 |
+
name="Output Tokens",
|
199 |
+
hovertemplate="Output: %{text}<br>Click to filter connections<extra></extra>",
|
200 |
+
customdata=[(i, 'output') for i in range(num_output)]
|
201 |
+
)
|
202 |
+
traces.append(output_trace)
|
203 |
+
self.traces_info['output_nodes_idx'] = len(traces) - 1
|
204 |
+
|
205 |
+
# Create figure
|
206 |
+
fig = go.Figure(data=traces)
|
207 |
+
|
208 |
+
# Update layout
|
209 |
+
title = f"Token Attention Flow ({normalization.capitalize()} Normalization)"
|
210 |
+
fig.update_layout(
|
211 |
+
title=title,
|
212 |
+
xaxis=dict(
|
213 |
+
range=[-0.1, 1.1],
|
214 |
+
showgrid=False,
|
215 |
+
zeroline=False,
|
216 |
+
showticklabels=False,
|
217 |
+
fixedrange=True
|
218 |
+
),
|
219 |
+
yaxis=dict(
|
220 |
+
range=[0, 1],
|
221 |
+
showgrid=False,
|
222 |
+
zeroline=False,
|
223 |
+
showticklabels=False,
|
224 |
+
fixedrange=True
|
225 |
+
),
|
226 |
+
hovermode="closest",
|
227 |
+
clickmode="event+select",
|
228 |
+
dragmode="select",
|
229 |
+
width=self.config.PLOT_WIDTH,
|
230 |
+
height=max(self.config.PLOT_HEIGHT, num_input * 30, num_output * 30),
|
231 |
+
plot_bgcolor="white",
|
232 |
+
margin=dict(l=150, r=200, t=80, b=80),
|
233 |
+
hoverdistance=20,
|
234 |
+
hoverlabel=dict(font_size=12, font_family=self.config.FONT_FAMILY),
|
235 |
+
showlegend=True,
|
236 |
+
legend=dict(
|
237 |
+
yanchor="top",
|
238 |
+
y=0.99,
|
239 |
+
xanchor="left",
|
240 |
+
x=1.02
|
241 |
+
),
|
242 |
+
# Preserve UI state on updates
|
243 |
+
uirevision="constant"
|
244 |
+
)
|
245 |
+
|
246 |
+
# Add legend traces
|
247 |
+
fig.add_trace(go.Scatter(
|
248 |
+
x=[None], y=[None],
|
249 |
+
mode='lines',
|
250 |
+
line=dict(color='rgba(0, 0, 255, 0.6)', width=2),
|
251 |
+
name='Input→Output'
|
252 |
+
))
|
253 |
+
fig.add_trace(go.Scatter(
|
254 |
+
x=[None], y=[None],
|
255 |
+
mode='lines',
|
256 |
+
line=dict(color='rgba(255, 165, 0, 0.6)', width=2),
|
257 |
+
name='Output→Output'
|
258 |
+
))
|
259 |
+
|
260 |
+
# Add annotations
|
261 |
+
fig.add_annotation(
|
262 |
+
x=0.5, y=0.02,
|
263 |
+
text=f"Step {initial_step} / {num_steps-1}: Generating '{output_labels[initial_step] if initial_step < len(output_labels) else ''}'",
|
264 |
+
showarrow=False,
|
265 |
+
font=dict(size=12, color="darkred"),
|
266 |
+
xref="paper", yref="paper"
|
267 |
+
)
|
268 |
+
|
269 |
+
fig.add_annotation(
|
270 |
+
x=0.01, y=0.98,
|
271 |
+
text="💡 Click tokens to filter connections | Use step slider to navigate generation",
|
272 |
+
showarrow=False,
|
273 |
+
font=dict(size=10, color="gray"),
|
274 |
+
align="left",
|
275 |
+
xref="paper", yref="paper"
|
276 |
+
)
|
277 |
+
|
278 |
+
self.current_state['current_step'] = initial_step
|
279 |
+
|
280 |
+
return fig
|
281 |
+
|
282 |
+
def update_for_step(
    self,
    fig: go.Figure,
    step: int,
    attention_matrices: List[Dict],
    output_tokens: List[str],
    threshold: float = 0.05
) -> go.Figure:
    """
    Update visualization for a specific generation step.

    Mutates `fig` in place: restyles the connection traces so that only
    attention for generation steps <= `step` is visible, recolors the
    output-token nodes, and rewrites the step caption annotation.

    Args:
        fig: Figure previously built by this visualizer; its trace order
            must match the indices recorded in `self.traces_info`.
        step: 0-based generation step to display.
        attention_matrices: One dict per generated token, with keys
            'input_attention' (indexed per input token) and
            'output_attention' (indexed per earlier output token, may be
            None for step 0). Entries are read with `.item()`, so they
            are assumed to be torch/numpy scalar elements —
            NOTE(review): confirm with the producer of these matrices.
        output_tokens: Raw output token strings (cleaned for display).
        threshold: Weights below this are rendered fully transparent.

    Returns:
        The same `fig` object, updated.
    """
    # Out-of-range step: nothing to show, leave the figure untouched.
    if step >= len(attention_matrices):
        return fig

    output_labels = [clean_label(token) for token in output_tokens]

    # batch_update coalesces all property assignments into a single redraw.
    with fig.batch_update():
        # Update input-to-output connections for current step
        for conn_info in self.traces_info['input_to_output']:
            if conn_info['output_idx'] == step:
                weight = attention_matrices[step]['input_attention'][conn_info['input_idx']].item()
                opacity = scale_weight_to_opacity(weight, threshold=threshold)
                # Keep a hairline width even when invisible so the trace
                # stays cheap to re-show later.
                width = scale_weight_to_width(weight) if opacity > 0 else 0.5

                trace_idx = conn_info['trace_idx']
                fig.data[trace_idx].opacity = opacity
                fig.data[trace_idx].line.width = width
                fig.data[trace_idx].line.color = get_color_for_weight(weight, "blue")
            elif conn_info['output_idx'] > step:
                # Hide future connections
                fig.data[conn_info['trace_idx']].opacity = 0

        # Update output-to-output connections
        for conn_info in self.traces_info['output_to_output']:
            # 'output_attention' can be None (e.g. first generated token).
            if conn_info['to_idx'] == step and attention_matrices[step]['output_attention'] is not None:
                if conn_info['from_idx'] < len(attention_matrices[step]['output_attention']):
                    weight = attention_matrices[step]['output_attention'][conn_info['from_idx']].item()
                    opacity = scale_weight_to_opacity(weight, threshold=threshold)
                    width = scale_weight_to_width(weight) if opacity > 0 else 0.5

                    trace_idx = conn_info['trace_idx']
                    fig.data[trace_idx].opacity = opacity
                    fig.data[trace_idx].line.width = width
                    fig.data[trace_idx].line.color = get_color_for_weight(weight, "orange")
            elif conn_info['to_idx'] > step:
                # Hide future connections
                fig.data[conn_info['trace_idx']].opacity = 0

        # Update output node colors: generated tokens keep OUTPUT_COLOR,
        # not-yet-generated tokens are grayed out.
        output_colors = []
        for j in range(len(output_tokens)):
            if j <= step:
                output_colors.append(self.config.OUTPUT_COLOR)
            else:
                output_colors.append("rgba(230, 230, 230, 0.8)")

        if self.traces_info['output_nodes_idx'] is not None:
            fig.data[self.traces_info['output_nodes_idx']].marker.color = output_colors

        # Update step annotation (index 0 is the step caption added at
        # figure-creation time).
        fig.layout.annotations[0].text = f"Step {step} / {len(attention_matrices)-1}: Generating '{output_labels[step] if step < len(output_labels) else ''}'"

    self.current_state['current_step'] = step
    return fig
|
346 |
+
|
347 |
+
def filter_by_token(
    self,
    fig: go.Figure,
    token_idx: int,
    token_type: str,
    attention_matrices: List[Dict],
    threshold: float = 0.05
) -> go.Figure:
    """
    Filter connections to show only those related to selected token.

    Mutates `fig` in place so that only connections touching the clicked
    token remain visible, then records the selection in
    `self.current_state`.

    Args:
        fig: Figure previously built by this visualizer; trace order must
            match `self.traces_info`.
        token_idx: Index of the clicked token within its column.
        token_type: 'input' or 'output' — which column was clicked. Any
            other value leaves the figure unchanged except for the state
            bookkeeping at the end.
        attention_matrices: Per-step attention data (see update_for_step);
            entries are read with `.item()`.
        threshold: Weights below this are rendered fully transparent.

    Returns:
        The same `fig` object, updated.
    """
    with fig.batch_update():
        # Only connections up to the currently displayed step are eligible.
        current_step = self.current_state['current_step']

        if token_type == 'input':
            # Show only connections from this input token
            for conn_info in self.traces_info['input_to_output']:
                if conn_info['input_idx'] == token_idx and conn_info['output_idx'] <= current_step:
                    weight = attention_matrices[conn_info['output_idx']]['input_attention'][token_idx].item()
                    opacity = scale_weight_to_opacity(weight, threshold=threshold)
                    fig.data[conn_info['trace_idx']].opacity = opacity if opacity > 0 else 0
                else:
                    fig.data[conn_info['trace_idx']].opacity = 0

            # Hide all output-to-output connections
            for conn_info in self.traces_info['output_to_output']:
                fig.data[conn_info['trace_idx']].opacity = 0

        elif token_type == 'output':
            # Show connections to this output token
            for conn_info in self.traces_info['input_to_output']:
                if conn_info['output_idx'] == token_idx and token_idx <= current_step:
                    weight = attention_matrices[token_idx]['input_attention'][conn_info['input_idx']].item()
                    opacity = scale_weight_to_opacity(weight, threshold=threshold)
                    fig.data[conn_info['trace_idx']].opacity = opacity if opacity > 0 else 0
                else:
                    fig.data[conn_info['trace_idx']].opacity = 0

            # Show connections from/to this output token.
            # A connection is kept if the selected token is either its
            # destination ('to_idx') or its source ('from_idx'); both
            # branches guard against a None/short 'output_attention'.
            for conn_info in self.traces_info['output_to_output']:
                show = False
                if conn_info['to_idx'] == token_idx and token_idx <= current_step:
                    if attention_matrices[token_idx]['output_attention'] is not None:
                        if conn_info['from_idx'] < len(attention_matrices[token_idx]['output_attention']):
                            weight = attention_matrices[token_idx]['output_attention'][conn_info['from_idx']].item()
                            opacity = scale_weight_to_opacity(weight, threshold=threshold)
                            fig.data[conn_info['trace_idx']].opacity = opacity if opacity > 0 else 0
                            show = True
                elif conn_info['from_idx'] == token_idx and conn_info['to_idx'] <= current_step:
                    if attention_matrices[conn_info['to_idx']]['output_attention'] is not None:
                        if token_idx < len(attention_matrices[conn_info['to_idx']]['output_attention']):
                            weight = attention_matrices[conn_info['to_idx']]['output_attention'][token_idx].item()
                            opacity = scale_weight_to_opacity(weight, threshold=threshold)
                            fig.data[conn_info['trace_idx']].opacity = opacity if opacity > 0 else 0
                            show = True

                if not show:
                    fig.data[conn_info['trace_idx']].opacity = 0

    # Record the active selection so later updates can respect it.
    self.current_state['selected_token'] = token_idx
    self.current_state['selected_type'] = token_type
    self.current_state['show_all'] = False

    return fig
|
411 |
+
|
412 |
+
def show_all_connections(
    self,
    fig: go.Figure,
    attention_matrices: List[Dict],
    threshold: float = 0.05,
    output_tokens=None
) -> go.Figure:
    """
    Reset to show all connections for current step.

    Clears any token selection and re-renders the currently displayed
    step via update_for_step.

    Args:
        fig: Figure previously built by this visualizer.
        attention_matrices: Per-step attention data (one dict per
            generated token).
        threshold: Weights below this are rendered fully transparent.
        output_tokens: Optional list of raw output token strings used
            for the step caption. New, backward-compatible parameter:
            pass the real token list to get a correct caption.

    Returns:
        The same `fig` object, updated.
    """
    self.current_state['selected_token'] = None
    self.current_state['selected_type'] = None
    self.current_state['show_all'] = True

    # BUG FIX: the original passed `[clean_label(t) for t in
    # attention_matrices]` — i.e. the attention *dicts* — where
    # update_for_step expects output token strings, so the step caption
    # showed stringified dicts. Fall back to empty labels when the
    # caller does not supply the tokens.
    if output_tokens is None:
        output_tokens = [''] * len(attention_matrices)

    return self.update_for_step(
        fig,
        self.current_state['current_step'],
        attention_matrices,
        output_tokens,
        threshold
    )
|
432 |
+
|
433 |
+
def _create_empty_figure(self, message: str) -> go.Figure:
    """Build a placeholder figure whose only content is *message* as the title."""
    placeholder = go.Figure()
    placeholder.update_layout(
        title=message,
        xaxis={'visible': False},
        yaxis={'visible': False},
        width=self.config.PLOT_WIDTH,
        height=self.config.PLOT_HEIGHT
    )
    return placeholder
|
visualization/simple_svg_viz.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import List, Dict, Any, Optional, Tuple
|
3 |
+
from .utils import clean_label, scale_weight_to_width, scale_weight_to_opacity
|
4 |
+
|
5 |
+
class SimpleSVGVisualizer:
    """Render the token attention flow as a plain inline SVG (no D3).

    Two fixes relative to the original implementation:

    * ``scale_weight_to_opacity`` was called positionally as
      ``(weight, threshold)``, which bound ``threshold`` to that
      function's ``min_opacity`` parameter and silently ignored the real
      threshold; it is now passed by keyword.
    * Token labels are HTML-escaped before being embedded in the SVG, so
      tokens containing ``<``, ``>``, ``&`` or quotes can no longer break
      the markup.
    """

    def __init__(self, config):
        # config supplies PLOT_WIDTH, PLOT_HEIGHT, NODE_SIZE, FONT_SIZE,
        # INPUT_COLOR and OUTPUT_COLOR.
        self.config = config

    def create_visualization_html(
        self,
        input_tokens: List[str],
        output_tokens: List[str],
        attention_matrices: List[Dict],
        threshold: float = 0.05,
        initial_step: int = 0,
        selected_token: Optional[int] = None,
        selected_type: Optional[str] = None
    ) -> str:
        """Create a simple SVG visualization without D3.

        Args:
            input_tokens: Raw prompt tokens (left column).
            output_tokens: Raw generated tokens (right column).
            attention_matrices: One dict per generation step; the
                'input_attention' entry is indexed per input token and
                read with ``.item()`` (assumed torch/numpy scalar
                elements — NOTE(review): confirm with caller).
            threshold: Connections with weight <= threshold are not drawn.
            initial_step: Generation step to display.
            selected_token: Optional token index to filter connections by.
            selected_type: 'input' or 'output' for the selected token.

        Returns:
            HTML string containing the SVG plus a small click-handler
            script that forwards clicks to a hidden Gradio textbox.
        """
        from html import escape  # stdlib; local import keeps module-level deps unchanged

        # Clean labels
        input_labels = [clean_label(token) for token in input_tokens]
        output_labels = [clean_label(token) for token in output_tokens]

        # Canvas geometry
        width = self.config.PLOT_WIDTH
        height = self.config.PLOT_HEIGHT
        margin = 100

        input_x = margin
        output_x = width - margin

        # Create SVG elements
        svg_elements = []

        # Background
        svg_elements.append(f'<rect width="{width}" height="{height}" fill="white" stroke="#ddd"/>')

        # Title
        svg_elements.append(f'<text x="{width/2}" y="30" text-anchor="middle" font-size="16" font-weight="bold">Token Attention Flow</text>')

        # Calculate vertical positions
        input_y_positions = []
        output_y_positions = []

        if len(input_labels) > 0:
            input_spacing = (height - 2 * margin) / max(1, len(input_labels) - 1)
            input_y_positions = [margin + i * input_spacing for i in range(len(input_labels))]

        if len(output_labels) > 0:
            output_spacing = (height - 2 * margin) / max(1, len(output_labels) - 1)
            output_y_positions = [margin + i * output_spacing for i in range(len(output_labels))]

        # Draw connections for all steps up to and including initial_step
        for j in range(min(initial_step + 1, len(output_labels))):
            if j < len(attention_matrices):
                for i in range(len(input_labels)):
                    weight = attention_matrices[j]['input_attention'][i].item()

                    # Apply filtering
                    if selected_token is not None:
                        if selected_type == 'input' and i != selected_token:
                            continue
                        elif selected_type == 'output' and j != selected_token:
                            continue

                    if weight > threshold:
                        # FIX: pass threshold by keyword — the original
                        # positional call bound it to min_opacity.
                        opacity = scale_weight_to_opacity(weight, threshold=threshold)
                        width_val = scale_weight_to_width(weight)

                        svg_elements.append(
                            f'<line x1="{input_x}" y1="{input_y_positions[i]}" '
                            f'x2="{output_x}" y2="{output_y_positions[j]}" '
                            f'stroke="blue" stroke-width="{width_val}" opacity="{opacity}"/>'
                        )

        # Draw input nodes (labels escaped so markup cannot be broken)
        for i, label in enumerate(input_labels):
            y = input_y_positions[i]
            color = "yellow" if selected_token == i and selected_type == 'input' else self.config.INPUT_COLOR

            svg_elements.append(
                f'<circle cx="{input_x}" cy="{y}" r="{self.config.NODE_SIZE/2}" '
                f'fill="{color}" stroke="darkblue" stroke-width="2" '
                f'style="cursor: pointer" '
                f'onclick="handleTokenClick({i}, \'input\')"/>'
            )
            svg_elements.append(
                f'<text x="{input_x - self.config.NODE_SIZE/2 - 10}" y="{y + 5}" '
                f'text-anchor="end" font-size="{self.config.FONT_SIZE}">{escape(label)}</text>'
            )

        # Draw output nodes (not-yet-generated tokens are grayed out)
        for j, label in enumerate(output_labels):
            y = output_y_positions[j]
            color = "yellow" if selected_token == j and selected_type == 'output' else (
                self.config.OUTPUT_COLOR if j <= initial_step else "#e6e6e6"
            )

            svg_elements.append(
                f'<circle cx="{output_x}" cy="{y}" r="{self.config.NODE_SIZE/2}" '
                f'fill="{color}" stroke="darkred" stroke-width="2" '
                f'style="cursor: pointer" '
                f'onclick="handleTokenClick({j}, \'output\')"/>'
            )
            svg_elements.append(
                f'<text x="{output_x + self.config.NODE_SIZE/2 + 10}" y="{y + 5}" '
                f'text-anchor="start" font-size="{self.config.FONT_SIZE}">{escape(label)}</text>'
            )

        # Step info caption
        current_label = output_labels[initial_step] if initial_step < len(output_labels) else ""
        svg_elements.append(
            f'<text x="{width/2}" y="{height - 20}" text-anchor="middle" font-size="12" fill="darkred">'
            f'Step {initial_step} / {len(output_labels) - 1}: Generating "{escape(current_label)}"'
            f'</text>'
        )

        # Create HTML; the click handler writes the clicked token into the
        # hidden '#clicked-token-d3' textbox and fires an input event so
        # Gradio picks up the change.
        html = f"""
        <div style="width: 100%; overflow-x: auto;">
            <svg width="{width}" height="{height}" style="border: 1px solid #ddd;">
                {''.join(svg_elements)}
            </svg>
        </div>

        <script>
        function handleTokenClick(index, type) {{
            console.log('Token clicked:', index, type);
            const hiddenInput = document.querySelector('#clicked-token-d3 textarea');
            if (hiddenInput) {{
                const clickData = JSON.stringify({{index: index, type: type}});
                hiddenInput.value = clickData;
                hiddenInput.dispatchEvent(new Event('input', {{ bubbles: true }}));
            }}
        }}
        </script>
        """

        return html
|
visualization/utils.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import List, Tuple, Optional
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
def clean_label(token: str) -> str:
    """
    Cleans token labels for visualization.
    Handles various tokenizer-specific formatting.
    """
    # Ordered (marker, replacement) pairs; '</s>' must precede '<s>'.
    substitutions = (
        ('Ġ', ' '),   # GPT-2 style space
        ('▁', ' '),   # SentencePiece style space
        ('Ċ', '\\n'),  # newline marker
        ('</s>', '[EOS]'),
        ('<s>', '[BOS]'),
        ('<unk>', '[UNK]'),
        ('<pad>', '[PAD]'),
        ('<|begin_of_text|>', '[BOS]'),
        ('<|end_of_text|>', '[EOS]'),
        ('<|endoftext|>', '[EOS]'),
    )

    result = str(token)
    for marker, replacement in substitutions:
        result = result.replace(marker, replacement)

    # Drop byte-level encoding markers such as <0xE2>.
    result = re.sub(r'<0x[0-9A-Fa-f]{2}>', '', result)

    # Strip surrounding whitespace; empty results get a placeholder.
    result = result.strip()
    return result or "[EMPTY]"
|
34 |
+
|
35 |
+
def scale_weight_to_width(
    weight: float,
    min_width: float = 0.5,
    max_width: float = 3.0,
    scale_factor: float = 5.0
) -> float:
    """
    Scale attention weight to line width for visualization.

    Args:
        weight: Attention weight (0-1)
        min_width: Minimum line width
        max_width: Maximum line width
        scale_factor: Scaling factor for weight

    Returns:
        Scaled line width
    """
    # Amplify small weights, capping the normalized value at 1.0.
    normalized = min(1.0, weight * scale_factor)
    span = max_width - min_width
    return min_width + span * normalized
|
55 |
+
|
56 |
+
def scale_weight_to_opacity(
    weight: float,
    min_opacity: float = 0.1,
    max_opacity: float = 1.0,
    threshold: float = 0.0
) -> float:
    """
    Scale attention weight to opacity for visualization.

    Args:
        weight: Attention weight (0-1)
        min_opacity: Minimum opacity
        max_opacity: Maximum opacity
        threshold: Threshold below which opacity is 0

    Returns:
        Scaled opacity
    """
    # Below the cutoff the connection is fully transparent.
    if weight < threshold:
        return 0.0

    # Re-normalize the remaining [threshold, 1] range to [0, 1];
    # degenerate threshold >= 1 falls back to the raw weight.
    if threshold < 1.0:
        normalized = (weight - threshold) / (1.0 - threshold)
    else:
        normalized = weight

    return min_opacity + (max_opacity - min_opacity) * normalized
|
80 |
+
|
81 |
+
def get_node_positions(
    num_input: int,
    num_output: int,
    spacing: str = 'linear'
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Calculate node positions for visualization.

    Args:
        num_input: Number of input tokens
        num_output: Number of output tokens
        spacing: Spacing strategy ('linear', 'equal')

    Returns:
        Tuple of (input_x, input_y, output_x, output_y)
    """
    def _column_y(count: int) -> np.ndarray:
        # Vertical layout for one column of nodes.
        if spacing == 'linear':
            # Evenly spread over [0.1, 0.9]; a lone node is centered.
            return np.linspace(0.1, 0.9, count) if count > 1 else np.array([0.5])
        # Equal spacing: slots of size 0.8/(count+1), starting at 0.1.
        slot = 0.8 / (count + 1)
        return np.array([0.1 + (k + 1) * slot for k in range(count)])

    input_y = _column_y(num_input)
    output_y = _column_y(num_output)

    # Fixed columns: inputs on the left, outputs on the right.
    input_x = np.full(num_input, 0.1)
    output_x = np.full(num_output, 0.9)

    return input_x, input_y, output_x, output_y
|
113 |
+
|
114 |
+
def create_spline_path(
    start_x: float,
    start_y: float,
    end_x: float,
    end_y: float,
    control_offset: float = 0.15
) -> Tuple[List[float], List[float]]:
    """
    Create a spline path for output-to-output connections.

    Args:
        start_x, start_y: Starting position
        end_x, end_y: Ending position
        control_offset: Offset for control points

    Returns:
        Tuple of (x_path, y_path) for spline
    """
    # Control points bulge rightward by `control_offset` so the curve
    # arcs away from the node column.
    bulge_from = start_x + control_offset
    bulge_to = end_x + control_offset
    xs = [start_x, bulge_from, bulge_to, end_x]
    ys = [start_y, start_y, end_y, end_y]
    return xs, ys
|
147 |
+
|
148 |
+
def format_attention_text(
    from_token: str,
    to_token: str,
    weight: float,
    connection_type: str = "attention"
) -> str:
    """
    Format hover text for attention connections.

    Args:
        from_token: Source token
        to_token: Target token
        weight: Attention weight
        connection_type: Type of connection

    Returns:
        Formatted hover text
    """
    # Two-line hover label: "from → to" then the weight to 4 decimals.
    header = f"{from_token} → {to_token}"
    detail = f"{connection_type.capitalize()} Weight: {weight:.4f}"
    return f"{header}<br>{detail}"
|
170 |
+
|
171 |
+
def get_color_for_weight(
    weight: float,
    base_color: str = "blue",
    use_gradient: bool = True
) -> str:
    """
    Get color for attention weight visualization.

    Args:
        weight: Attention weight (0-1)
        base_color: Base color name
        use_gradient: Whether to use gradient based on weight

    Returns:
        Color string for plotly
    """
    if not use_gradient:
        # Flat colors, ignoring the weight entirely.
        flat = {
            "blue": "rgba(0, 0, 255, 0.6)",
            "orange": "rgba(255, 165, 0, 0.6)",
        }
        return flat.get(base_color, "rgba(128, 128, 128, 0.6)")

    # Gradient: heavier weights get a darker channel and higher alpha.
    alpha = 0.3 + weight * 0.4
    if base_color == "blue":
        channel = int(255 - weight * 155)  # 255 down to 100
        return f"rgba(0, {channel}, 255, {alpha})"
    if base_color == "orange":
        channel = int(255 - weight * 100)  # 255 down to 155
        return f"rgba(255, {channel}, 0, {alpha})"
    channel = int(200 - weight * 100)  # gray: 200 down to 100
    return f"rgba({channel}, {channel}, {channel}, {alpha})"
|
208 |
+
|
209 |
+
def truncate_token_label(token: str, max_length: int = 15) -> str:
    """
    Truncate long token labels for display.

    Args:
        token: Token string
        max_length: Maximum length

    Returns:
        Truncated token with ellipsis if needed
    """
    display = clean_label(token)
    # Short labels pass through unchanged; long ones are cut so that the
    # result (including the "...") is exactly max_length characters.
    if len(display) <= max_length:
        return display
    return display[:max_length - 3] + "..."
|