File size: 5,863 Bytes
dd850a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import json
from typing import List, Dict, Any, Optional, Tuple
from .utils import clean_label, scale_weight_to_width, scale_weight_to_opacity

class SimpleSVGVisualizer:
    """Render a token attention-flow diagram as inline SVG (no D3 dependency).

    Input tokens are drawn as a column of nodes on the left, output tokens as a
    column on the right, and attention weights above a threshold as lines
    connecting them.
    """

    def __init__(self, config):
        """Store the plotting configuration.

        Args:
            config: Object exposing ``PLOT_WIDTH``, ``PLOT_HEIGHT``,
                ``NODE_SIZE``, ``FONT_SIZE``, ``INPUT_COLOR`` and
                ``OUTPUT_COLOR`` attributes (all read by this class).
        """
        self.config = config

    @staticmethod
    def _spread_positions(count: int, height: float, margin: float) -> List[float]:
        """Return ``count`` y-coordinates evenly spaced from ``margin`` to ``height - margin``.

        A single item is placed at ``margin``; ``count <= 0`` yields an empty list.
        """
        if count <= 0:
            return []
        # max(1, ...) avoids a division by zero when there is only one node.
        spacing = (height - 2 * margin) / max(1, count - 1)
        return [margin + i * spacing for i in range(count)]

    @staticmethod
    def _as_number(value) -> float:
        """Return an attention weight as a plain Python number.

        Tensor / NumPy scalars are unwrapped with ``.item()``; plain Python
        numbers are accepted via ``float``.
        """
        return value.item() if hasattr(value, "item") else float(value)

    def _node_elements(self, cx, cy, fill, stroke, index, kind, label, label_x, anchor) -> List[str]:
        """Return the clickable circle and its text label for one token node.

        ``label`` must already be HTML-escaped; ``kind`` is ``'input'`` or
        ``'output'`` and is passed to the ``handleTokenClick`` JS callback.
        """
        return [
            f'<circle cx="{cx}" cy="{cy}" r="{self.config.NODE_SIZE/2}" '
            f'fill="{fill}" stroke="{stroke}" stroke-width="2" '
            f'style="cursor: pointer" '
            f'onclick="handleTokenClick({index}, \'{kind}\')"/>',
            f'<text x="{label_x}" y="{cy + 5}" '
            f'text-anchor="{anchor}" font-size="{self.config.FONT_SIZE}">{label}</text>',
        ]

    def create_visualization_html(
        self,
        input_tokens: List[str],
        output_tokens: List[str],
        attention_matrices: List[Dict],
        threshold: float = 0.05,
        initial_step: int = 0,
        selected_token: Optional[int] = None,
        selected_type: Optional[str] = None
    ) -> str:
        """Create a simple SVG visualization without D3.

        Args:
            input_tokens: Source tokens, drawn top-to-bottom on the left.
            output_tokens: Generated tokens, drawn top-to-bottom on the right.
            attention_matrices: One entry per output step; entry ``j`` must
                hold an ``'input_attention'`` sequence indexable by input
                position.
            threshold: Edges whose weight is ``<= threshold`` are not drawn.
            initial_step: Last output step whose edges are drawn; output nodes
                after it are greyed out.
            selected_token: Index of the highlighted token, or ``None``.
            selected_type: ``'input'`` or ``'output'``, i.e. which column
                ``selected_token`` refers to. When set, only edges touching the
                selected token are drawn.

        Returns:
            An HTML snippet containing the SVG plus a ``handleTokenClick``
            script. The script writes ``{index, type}`` JSON into the
            ``#clicked-token-d3 textarea`` element, if that element exists.
        """
        from html import escape  # token text must not be interpreted as markup

        cfg = self.config

        # Clean labels, then HTML-escape them: tokens such as "</s>" or "a<b"
        # would otherwise be parsed as SVG/HTML and break (or inject into) the page.
        input_labels = [escape(clean_label(token)) for token in input_tokens]
        output_labels = [escape(clean_label(token)) for token in output_tokens]

        # Calculate positions
        width = cfg.PLOT_WIDTH
        height = cfg.PLOT_HEIGHT
        margin = 100
        input_x = margin
        output_x = width - margin
        input_y_positions = self._spread_positions(len(input_labels), height, margin)
        output_y_positions = self._spread_positions(len(output_labels), height, margin)

        svg_elements = [
            # Background
            f'<rect width="{width}" height="{height}" fill="white" stroke="#ddd"/>',
            # Title
            f'<text x="{width/2}" y="30" text-anchor="middle" font-size="16" font-weight="bold">Token Attention Flow</text>',
        ]

        # Draw connections for each output step up to and including
        # initial_step that also has an attention entry.
        filter_inputs = selected_token is not None and selected_type == 'input'
        filter_outputs = selected_token is not None and selected_type == 'output'
        n_steps = min(initial_step + 1, len(output_labels), len(attention_matrices))
        for j in range(n_steps):
            if filter_outputs and j != selected_token:
                continue  # only the selected output's row is shown
            row = attention_matrices[j]['input_attention']
            for i, y_in in enumerate(input_y_positions):
                if filter_inputs and i != selected_token:
                    continue
                weight = self._as_number(row[i])
                if weight > threshold:
                    opacity = scale_weight_to_opacity(weight, threshold)
                    width_val = scale_weight_to_width(weight)
                    svg_elements.append(
                        f'<line x1="{input_x}" y1="{y_in}" '
                        f'x2="{output_x}" y2="{output_y_positions[j]}" '
                        f'stroke="blue" stroke-width="{width_val}" opacity="{opacity}"/>'
                    )

        # Draw input nodes
        for i, (label, y) in enumerate(zip(input_labels, input_y_positions)):
            is_selected = selected_type == 'input' and selected_token == i
            fill = "yellow" if is_selected else cfg.INPUT_COLOR
            svg_elements.extend(self._node_elements(
                input_x, y, fill, "darkblue", i, 'input', label,
                input_x - cfg.NODE_SIZE/2 - 10, "end",
            ))

        # Draw output nodes
        for j, (label, y) in enumerate(zip(output_labels, output_y_positions)):
            if selected_type == 'output' and selected_token == j:
                fill = "yellow"
            elif j <= initial_step:
                fill = cfg.OUTPUT_COLOR
            else:
                fill = "#e6e6e6"  # not yet generated at this step
            svg_elements.extend(self._node_elements(
                output_x, y, fill, "darkred", j, 'output', label,
                output_x + cfg.NODE_SIZE/2 + 10, "start",
            ))

        # Step info. The lower bound guards against negative indexing, which
        # would otherwise caption the last token even though no edges are drawn.
        current = output_labels[initial_step] if 0 <= initial_step < len(output_labels) else ""
        svg_elements.append(
            f'<text x="{width/2}" y="{height - 20}" text-anchor="middle" font-size="12" fill="darkred">'
            f'Step {initial_step} / {len(output_labels) - 1}: Generating "{current}"'
            f'</text>'
        )

        # Create HTML (local name avoids shadowing the stdlib `html` module)
        markup = f"""
        <div style="width: 100%; overflow-x: auto;">
            <svg width="{width}" height="{height}" style="border: 1px solid #ddd;">
                {''.join(svg_elements)}
            </svg>
        </div>
        
        <script>
        function handleTokenClick(index, type) {{
            console.log('Token clicked:', index, type);
            const hiddenInput = document.querySelector('#clicked-token-d3 textarea');
            if (hiddenInput) {{
                const clickData = JSON.stringify({{index: index, type: type}});
                hiddenInput.value = clickData;
                hiddenInput.dispatchEvent(new Event('input', {{ bubbles: true }}));
            }}
        }}
        </script>
        """
        
        return markup