import plotly.graph_objects as go
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Callable
from .utils import (
clean_label, scale_weight_to_width, scale_weight_to_opacity,
get_node_positions, create_spline_path, format_attention_text,
get_color_for_weight, truncate_token_label
)
class AttentionVisualizer:
    """Build and update an interactive Plotly figure of token attention.

    Input tokens and output tokens are drawn as two marker traces. Every
    input→output pair and every earlier-output→later-output pair gets its own
    line trace so that it can be restyled or hidden individually.

    Each element of ``attention_matrices`` is expected to be a mapping with an
    ``'input_attention'`` entry and an ``'output_attention'`` entry (which may
    be ``None``); both are indexed by token position and their elements must
    provide ``.item()``.

    Attributes:
        config: Settings object; this class reads NODE_SIZE, NODE_LINE_WIDTH,
            INPUT_COLOR, OUTPUT_COLOR, FONT_SIZE, FONT_FAMILY, PLOT_WIDTH and
            PLOT_HEIGHT from it.
        current_state: Selected token index/type, current step and a
            ``show_all`` flag.
        traces_info: For each connection, its endpoints and its index in
            ``fig.data``; plus the indices of the two node traces.
    """

    # Fill colour of output nodes that lie after the current step.
    _PENDING_OUTPUT_COLOR = "rgba(230, 230, 230, 0.8)"

    def __init__(self, config):
        """Store the configuration and reset interaction state.

        Args:
            config: Settings object (see the class docstring for the
                attributes that are read from it).
        """
        self.config = config
        self.current_state = {
            'selected_token': None,
            'selected_type': None,
            'current_step': 0,
            'show_all': True
        }
        self.traces_info = self._empty_traces_info()
        # Raw output tokens last seen by create_interactive_plot /
        # update_for_step; show_all_connections needs them to redraw.
        self._output_tokens: List[str] = []

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _empty_traces_info() -> Dict[str, Any]:
        """Return a fresh, empty trace-bookkeeping dictionary."""
        return {
            'input_to_output': [],
            'output_to_output': [],
            'input_nodes_idx': None,
            'output_nodes_idx': None
        }

    @staticmethod
    def _input_weight(attention_matrices, output_idx, input_idx):
        """Return the input→output weight, or None if that step is missing."""
        if output_idx >= len(attention_matrices):
            return None
        return attention_matrices[output_idx]['input_attention'][input_idx].item()

    @staticmethod
    def _output_weight(attention_matrices, to_idx, from_idx):
        """Return the output→output weight, or None if it is not available.

        "Not available" covers a missing step, an ``'output_attention'`` of
        ``None`` and a ``from_idx`` beyond the stored attention row.
        """
        if to_idx >= len(attention_matrices):
            return None
        row = attention_matrices[to_idx]['output_attention']
        if row is None or from_idx >= len(row):
            return None
        return row[from_idx].item()

    @staticmethod
    def _opacity_and_width(weight, threshold):
        """Map a weight to ``(opacity, line width)`` via the utils helpers.

        A connection whose opacity is not positive gets the fixed width 0.5.
        """
        opacity = scale_weight_to_opacity(weight, threshold=threshold)
        width = scale_weight_to_width(weight) if opacity > 0 else 0.5
        return opacity, width

    def _restyle_connection(self, trace, weight, threshold, palette):
        """Apply opacity, width and colour for ``weight`` to a line trace."""
        opacity, width = self._opacity_and_width(weight, threshold)
        trace.opacity = opacity
        trace.line.width = width
        trace.line.color = get_color_for_weight(weight, palette)

    @staticmethod
    def _filtered_opacity(weight, threshold):
        """Opacity used while a token filter is active (never negative)."""
        if weight is None:
            return 0
        opacity = scale_weight_to_opacity(weight, threshold=threshold)
        return opacity if opacity > 0 else 0

    def _output_node_colors(self, num_output, step):
        """Colour outputs up to ``step`` normally and later ones as pending."""
        return [
            self.config.OUTPUT_COLOR if j <= step else self._PENDING_OUTPUT_COLOR
            for j in range(num_output)
        ]

    @staticmethod
    def _step_caption(step, num_steps, output_labels):
        """Build the text of the step annotation."""
        label = output_labels[step] if step < len(output_labels) else ''
        return f"Step {step} / {num_steps - 1}: Generating '{label}'"

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def create_interactive_plot(
        self,
        input_tokens: List[str],
        output_tokens: List[str],
        attention_matrices: List[Dict],
        threshold: float = 0.05,
        initial_step: int = 0,
        normalization: str = "separate"
    ) -> go.Figure:
        """
        Create the main interactive visualization.

        Trace order in the returned figure is: all input→output lines
        (grouped by output token), all output→output lines, the input node
        trace, the output node trace, then two legend-only line traces.
        ``self.traces_info`` is rebuilt to match.

        Every connection is drawn with its own weight regardless of
        ``initial_step``; ``initial_step`` only affects output node colours
        and the step annotation.

        Args:
            input_tokens: Raw input tokens (passed through ``clean_label``).
            output_tokens: Raw output tokens (passed through ``clean_label``).
            attention_matrices: One attention mapping per generation step.
            threshold: Forwarded to ``scale_weight_to_opacity``.
            initial_step: Step shown in the annotation and used for the
                initial output node colours.
            normalization: Only used in the figure title.

        Returns:
            The assembled figure, or an empty placeholder figure when there
            are no input tokens, no output tokens or no attention steps.
        """
        input_labels = [clean_label(token) for token in input_tokens]
        output_labels = [clean_label(token) for token in output_tokens]
        num_input = len(input_labels)
        num_output = len(output_labels)
        num_steps = len(attention_matrices)
        if num_input == 0 or num_output == 0 or num_steps == 0:
            return self._create_empty_figure("No data to visualize")

        input_x, input_y, output_x, output_y = get_node_positions(num_input, num_output)

        traces = []
        self.traces_info = self._empty_traces_info()
        self._output_tokens = list(output_tokens)

        # Input to output connections
        for j in range(num_output):
            for i in range(num_input):
                weight = self._input_weight(attention_matrices, j, i)
                if weight is None:
                    weight = 0
                opacity, width = self._opacity_and_width(weight, threshold)
                traces.append(go.Scatter(
                    x=[input_x[i], output_x[j]],
                    y=[input_y[i], output_y[j]],
                    mode="lines",
                    line=dict(
                        color=get_color_for_weight(weight, "blue"),
                        width=width
                    ),
                    opacity=opacity,
                    showlegend=False,
                    hoverinfo='text',
                    text=format_attention_text(input_labels[i], output_labels[j], weight),
                    hoverlabel=dict(bgcolor="lightskyblue", bordercolor="darkblue"),
                    name=f"in_to_out_{i}_{j}",
                    # One entry per point so the template resolves at both
                    # ends of the line.
                    customdata=[(i, j), (i, j)],
                    # NOTE(review): in Plotly a hovertemplate takes precedence
                    # over hoverinfo/text, so `text` above is not what is
                    # shown on hover -- confirm this is intended.
                    hovertemplate="Input→Output %{customdata[0]}→%{customdata[1]}"
                ))
                self.traces_info['input_to_output'].append({
                    'input_idx': i,
                    'output_idx': j,
                    'trace_idx': len(traces) - 1
                })

        # Output to output connections (each output to every earlier output)
        for j in range(1, num_output):
            for i in range(j):
                weight = self._output_weight(attention_matrices, j, i)
                if weight is None:
                    weight = 0
                opacity, width = self._opacity_and_width(weight, threshold)
                # Curved path so these lines do not overlap the node column.
                path_x, path_y = create_spline_path(
                    output_x[i], output_y[i],
                    output_x[j], output_y[j],
                    control_offset=0.15
                )
                traces.append(go.Scatter(
                    x=path_x,
                    y=path_y,
                    mode="lines",
                    line=dict(
                        color=get_color_for_weight(weight, "orange"),
                        width=width,
                        shape='spline'
                    ),
                    opacity=opacity,
                    showlegend=False,
                    hoverinfo='text',
                    text=format_attention_text(output_labels[i], output_labels[j], weight),
                    hoverlabel=dict(bgcolor="moccasin", bordercolor="darkorange"),
                    name=f"out_to_out_{i}_{j}"
                ))
                self.traces_info['output_to_output'].append({
                    'from_idx': i,
                    'to_idx': j,
                    'trace_idx': len(traces) - 1
                })

        # Input nodes
        traces.append(go.Scatter(
            x=input_x,
            y=input_y,
            mode="markers+text",
            marker=dict(
                size=self.config.NODE_SIZE,
                color=self.config.INPUT_COLOR,
                line=dict(width=self.config.NODE_LINE_WIDTH, color="darkblue")
            ),
            selected=dict(
                marker=dict(
                    size=self.config.NODE_SIZE + 6,
                    color="rgba(0, 0, 200, 0.9)"
                )
            ),
            unselected=dict(
                marker=dict(
                    opacity=0.65
                )
            ),
            text=[truncate_token_label(label) for label in input_labels],
            textfont=dict(size=self.config.FONT_SIZE, family=self.config.FONT_FAMILY),
            textposition="middle left",
            name="Input Tokens",
            # "<br>" is Plotly's line break inside hover text.
            hovertemplate="Input: %{text}<br>Click to filter connections",
            customdata=[(i, 'input') for i in range(num_input)]
        ))
        self.traces_info['input_nodes_idx'] = len(traces) - 1

        # Output nodes
        traces.append(go.Scatter(
            x=output_x,
            y=output_y,
            mode="markers+text",
            marker=dict(
                size=self.config.NODE_SIZE,
                color=self._output_node_colors(num_output, initial_step),
                line=dict(width=self.config.NODE_LINE_WIDTH, color="darkred")
            ),
            selected=dict(
                marker=dict(
                    size=self.config.NODE_SIZE + 6,
                    color="rgba(200, 80, 0, 0.9)"
                )
            ),
            unselected=dict(
                marker=dict(
                    opacity=0.65
                )
            ),
            text=[truncate_token_label(label) for label in output_labels],
            textfont=dict(size=self.config.FONT_SIZE, family=self.config.FONT_FAMILY),
            textposition="middle right",
            name="Output Tokens",
            hovertemplate="Output: %{text}<br>Click to filter connections",
            customdata=[(i, 'output') for i in range(num_output)]
        ))
        self.traces_info['output_nodes_idx'] = len(traces) - 1

        fig = go.Figure(data=traces)

        title = f"Token Attention Flow ({normalization.capitalize()} Normalization)"
        fig.update_layout(
            title=title,
            xaxis=dict(
                range=[-0.1, 1.1],
                showgrid=False,
                zeroline=False,
                showticklabels=False,
                fixedrange=True
            ),
            yaxis=dict(
                range=[0, 1],
                showgrid=False,
                zeroline=False,
                showticklabels=False,
                fixedrange=True
            ),
            hovermode="closest",
            clickmode="event+select",
            dragmode="select",
            width=self.config.PLOT_WIDTH,
            # Grow the plot with the token count (30 px per token).
            height=max(self.config.PLOT_HEIGHT, num_input * 30, num_output * 30),
            plot_bgcolor="white",
            margin=dict(l=150, r=200, t=80, b=80),
            hoverdistance=20,
            hoverlabel=dict(font_size=12, font_family=self.config.FONT_FAMILY),
            showlegend=True,
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=1.02
            ),
            # Preserve UI state on updates
            uirevision="constant"
        )

        # Legend-only traces (no data points); appended after every trace
        # recorded in traces_info, so the stored indices stay valid.
        fig.add_trace(go.Scatter(
            x=[None], y=[None],
            mode='lines',
            line=dict(color='rgba(0, 0, 255, 0.6)', width=2),
            name='Input→Output'
        ))
        fig.add_trace(go.Scatter(
            x=[None], y=[None],
            mode='lines',
            line=dict(color='rgba(255, 165, 0, 0.6)', width=2),
            name='Output→Output'
        ))

        # The step annotation must stay the FIRST annotation:
        # update_for_step rewrites annotations[0].
        fig.add_annotation(
            x=0.5, y=0.02,
            text=self._step_caption(initial_step, num_steps, output_labels),
            showarrow=False,
            font=dict(size=12, color="darkred"),
            xref="paper", yref="paper"
        )
        fig.add_annotation(
            x=0.01, y=0.98,
            text="💡 Click tokens to filter connections | Use step slider to navigate generation",
            showarrow=False,
            font=dict(size=10, color="gray"),
            align="left",
            xref="paper", yref="paper"
        )

        self.current_state['current_step'] = initial_step
        return fig

    def update_for_step(
        self,
        fig: go.Figure,
        step: int,
        attention_matrices: List[Dict],
        output_tokens: List[str],
        threshold: float = 0.05
    ) -> go.Figure:
        """
        Update visualization for a specific generation step.

        Every connection whose target output index is ``<= step`` is restyled
        from its attention weight; connections to later outputs are hidden.
        Restyling all earlier connections (not only those of ``step`` itself)
        means lines hidden by a backwards step or by ``filter_by_token``
        reappear when jumping forward or resetting the filter.

        Args:
            fig: Figure previously produced by ``create_interactive_plot``.
            step: Generation step to display.
            attention_matrices: One attention mapping per generation step.
            output_tokens: Raw output tokens (passed through ``clean_label``).
            threshold: Forwarded to ``scale_weight_to_opacity``.

        Returns:
            ``fig``, modified in place. It is returned unchanged when
            ``step`` is negative or not smaller than
            ``len(attention_matrices)``.
        """
        if not 0 <= step < len(attention_matrices):
            return fig

        self._output_tokens = list(output_tokens)
        output_labels = [clean_label(token) for token in output_tokens]

        with fig.batch_update():
            for conn in self.traces_info['input_to_output']:
                trace = fig.data[conn['trace_idx']]
                if conn['output_idx'] > step:
                    # Hide future connections
                    trace.opacity = 0
                    continue
                weight = self._input_weight(
                    attention_matrices, conn['output_idx'], conn['input_idx']
                )
                self._restyle_connection(
                    trace, 0 if weight is None else weight, threshold, "blue"
                )

            for conn in self.traces_info['output_to_output']:
                trace = fig.data[conn['trace_idx']]
                if conn['to_idx'] > step:
                    # Hide future connections
                    trace.opacity = 0
                    continue
                weight = self._output_weight(
                    attention_matrices, conn['to_idx'], conn['from_idx']
                )
                # A missing weight is drawn as weight 0, matching how
                # create_interactive_plot draws it initially.
                self._restyle_connection(
                    trace, 0 if weight is None else weight, threshold, "orange"
                )

            nodes_idx = self.traces_info['output_nodes_idx']
            if nodes_idx is not None:
                fig.data[nodes_idx].marker.color = self._output_node_colors(
                    len(output_labels), step
                )

            # annotations[0] is the step caption added by
            # create_interactive_plot; skip figures without annotations
            # (e.g. the empty placeholder figure).
            if fig.layout.annotations:
                fig.layout.annotations[0].text = self._step_caption(
                    step, len(attention_matrices), output_labels
                )

        self.current_state['current_step'] = step
        return fig

    def filter_by_token(
        self,
        fig: go.Figure,
        token_idx: int,
        token_type: str,
        attention_matrices: List[Dict],
        threshold: float = 0.05
    ) -> go.Figure:
        """
        Filter connections to show only those related to selected token.

        Only trace opacity is changed. Connections whose target output lies
        after the current step stay hidden.

        Args:
            fig: Figure previously produced by ``create_interactive_plot``.
            token_idx: Index of the selected token within its token list.
            token_type: ``'input'`` or ``'output'``. Any other value leaves
                the traces untouched but still records the selection.
            attention_matrices: One attention mapping per generation step.
            threshold: Forwarded to ``scale_weight_to_opacity``.

        Returns:
            ``fig``, modified in place.
        """
        current_step = self.current_state['current_step']
        input_links = self.traces_info['input_to_output']
        output_links = self.traces_info['output_to_output']

        with fig.batch_update():
            if token_type == 'input':
                # Keep only lines leaving the selected input token.
                for conn in input_links:
                    trace = fig.data[conn['trace_idx']]
                    if conn['input_idx'] == token_idx and conn['output_idx'] <= current_step:
                        weight = self._input_weight(
                            attention_matrices, conn['output_idx'], token_idx
                        )
                        trace.opacity = self._filtered_opacity(weight, threshold)
                    else:
                        trace.opacity = 0
                # Output→output lines never involve an input token.
                for conn in output_links:
                    fig.data[conn['trace_idx']].opacity = 0

            elif token_type == 'output':
                # Keep only lines arriving at the selected output token.
                for conn in input_links:
                    trace = fig.data[conn['trace_idx']]
                    if conn['output_idx'] == token_idx and token_idx <= current_step:
                        weight = self._input_weight(
                            attention_matrices, token_idx, conn['input_idx']
                        )
                        trace.opacity = self._filtered_opacity(weight, threshold)
                    else:
                        trace.opacity = 0
                # Keep output→output lines that start or end at the token.
                # from_idx < to_idx always holds, so in both cases the weight
                # lives at attention_matrices[to_idx]['output_attention'][from_idx].
                for conn in output_links:
                    trace = fig.data[conn['trace_idx']]
                    from_idx, to_idx = conn['from_idx'], conn['to_idx']
                    if token_idx in (from_idx, to_idx) and to_idx <= current_step:
                        weight = self._output_weight(attention_matrices, to_idx, from_idx)
                        trace.opacity = self._filtered_opacity(weight, threshold)
                    else:
                        trace.opacity = 0

        self.current_state['selected_token'] = token_idx
        self.current_state['selected_type'] = token_type
        self.current_state['show_all'] = False
        return fig

    def show_all_connections(
        self,
        fig: go.Figure,
        attention_matrices: List[Dict],
        threshold: float = 0.05,
        output_tokens: Optional[List[str]] = None
    ) -> go.Figure:
        """
        Reset to show all connections for current step.

        Clears the token selection and redraws through ``update_for_step``.

        Args:
            fig: Figure previously produced by ``create_interactive_plot``.
            attention_matrices: One attention mapping per generation step.
            threshold: Forwarded to ``scale_weight_to_opacity``.
            output_tokens: Raw output tokens. When omitted, the tokens last
                passed to ``create_interactive_plot`` or ``update_for_step``
                on this instance are used.

        Returns:
            ``fig``, modified in place.
        """
        self.current_state['selected_token'] = None
        self.current_state['selected_type'] = None
        self.current_state['show_all'] = True
        # The output tokens cannot be derived from attention_matrices (those
        # are attention mappings, not tokens), so use the remembered ones.
        tokens = self._output_tokens if output_tokens is None else output_tokens
        return self.update_for_step(
            fig,
            self.current_state['current_step'],
            attention_matrices,
            tokens,
            threshold
        )

    def _create_empty_figure(self, message: str) -> go.Figure:
        """Create an empty figure with a message.

        The message is shown as the title; both axes are hidden and the
        figure uses the configured width and height.
        """
        fig = go.Figure()
        fig.update_layout(
            title=message,
            xaxis={'visible': False},
            yaxis={'visible': False},
            width=self.config.PLOT_WIDTH,
            height=self.config.PLOT_HEIGHT
        )
        return fig