"""
D3.js visualization module for interactive token attention visualization.
"""
def _column_positions(count, height):
    """Return evenly spaced y coordinates for ``count`` nodes in one column.

    A single node is centred vertically; otherwise nodes are spread between
    y=80 and y=height-80.
    """
    if count <= 1:
        return [height // 2] * count
    return [80 + (index * (height - 160)) / (count - 1) for index in range(count)]


def _token_label(token, max_len=15):
    """Turn a raw tokenizer token into a short label that is safe to embed.

    Strips the subword markers ``Ġ``, ``▁`` (SentencePiece) and ``##`` (BERT
    subwords), caps the label at ``max_len`` characters (ellipsis included)
    and HTML-escapes the result.
    """
    from html import escape

    # Applied in sequence, so a token carrying several markers loses all of them.
    for prefix in ('Ġ', '▁', '##'):
        if token.startswith(prefix):
            token = token[len(prefix):]
    if len(token) > max_len:
        # Reserve three characters for the ellipsis so the label never
        # exceeds max_len.
        token = token[:max_len - 3] + "..."
    # Escape last: truncating after escaping could cut an entity in half.
    return escape(token)


def create_d3_visualization(data):
    """Build the token attention visualization text for one generation step.

    Args:
        data (dict): Structure with the keys read below:
            ``nodes`` - list of dicts with ``id``, ``type`` (``'input'`` or
            ``'output'``; other types are ignored) and ``token`` (str);
            ``links`` - list of dicts with ``source``, ``target``, ``weight``
            and ``type``;
            ``step`` - zero-based step index (defaults to 0).

    Returns:
        str: Header lines (title, step summary, column captions), one line
        per link whose endpoints are both known nodes, then one cleaned and
        HTML-escaped label per input node and per output node.

    Raises:
        KeyError: If a node lacks ``token``, or a link lacks ``source``,
            ``target``, ``weight`` or ``type``.
    """
    nodes = data.get('nodes', [])
    input_nodes = [node for node in nodes if node.get('type') == 'input']
    output_nodes = [node for node in nodes if node.get('type') == 'output']
    links = data.get('links', [])

    # Canvas geometry: inputs in a left column, outputs in a right column.
    width = 800
    height = max(400, max(len(input_nodes), len(output_nodes)) * 50 + 100)
    input_x = 100
    output_x = width - 100

    parts = [
        "\n"
        "Token Attention Visualization\n"
        f"Step {data.get('step', 0) + 1} | {len(input_nodes)} input → "
        f"{len(output_nodes)} output | {len(links)} connections\n"
        "Input Tokens\n"
        "Output Tokens\n"
    ]

    # Map node id -> (x, y) anchor on the inner edge of the node's column,
    # built once instead of rescanning the node lists for every link.
    anchors = {}
    if links:
        columns = (
            (output_nodes, output_x - 20),
            (input_nodes, input_x + 20),
        )
        for column, anchor_x in columns:
            ys = _column_positions(len(column), height)
            # Walk backwards so the first node with a given id wins; inputs
            # are written last so they take precedence over outputs.
            for node, y in zip(reversed(column), reversed(ys)):
                anchors[node['id']] = (anchor_x, y)

    # Connections come first so they end up behind the nodes.
    for link in links:
        start = anchors.get(link['source'])
        end = anchors.get(link['target'])
        if start is None or end is None:
            # Link references a node that is in neither column.
            continue

        # Heavier attention weight -> thicker, more opaque line.
        weight = link['weight']
        stroke_width = max(1, min(8, weight * 20))
        opacity = max(0.3, min(1.0, weight * 2))
        color = "#4285f4" if link['type'] == 'input_to_output' else "#ea4335"

        # NOTE(review): start/end, stroke_width, opacity and color are not
        # interpolated into the emitted text; the per-link template looks
        # like it lost its markup - confirm against the intended output.
        parts.append("\n")

    # Node labels: inputs first, then outputs.
    for node in input_nodes:
        parts.append(f"\n{_token_label(node['token'])}\n")
    for node in output_nodes:
        parts.append(f"\n{_token_label(node['token'])}\n")

    parts.append("\n")
    return "".join(parts)
def create_d3_visualization_old(data):
    """
    OLD VERSION - Build the status text block for a visualization step.

    Args:
        data (dict): Structure read through the keys ``step`` (zero-based,
            default 0), ``total_steps`` (default 1), ``nodes`` and ``links``
            (both default to empty lists).

    Returns:
        str: Newline-delimited text with a "Reset View" caption, the
        one-based step counter, the node and link counts, and a usage hint.

    NOTE(review): the visible template holds only these text lines, with no
    markup or script - confirm whether that is intended.
    """
    current_step = data.get('step', 0) + 1
    step_count = data.get('total_steps', 1)
    node_count = len(data.get('nodes', []))
    link_count = len(data.get('links', []))

    # Empty first and last entries reproduce the leading and trailing newline.
    rows = [
        "",
        "Reset View",
        f"Step: {current_step} / {step_count}",
        f"Nodes: {node_count} | Links: {link_count}",
        "Click nodes to filter connections",
        "",
    ]
    return "\n".join(rows)