""" D3.js visualization module for interactive token attention visualization. """ def create_d3_visualization(data): """ Generate a complete, self-contained HTML string with embedded D3.js visualization. Args: data (dict): JSON structure with nodes and links from prepare_d3_data() Returns: str: Complete HTML string with embedded D3.js, CSS, and JavaScript """ # Get nodes by type input_nodes = [node for node in data.get('nodes', []) if node.get('type') == 'input'] output_nodes = [node for node in data.get('nodes', []) if node.get('type') == 'output'] links = data.get('links', []) # SVG dimensions width = 800 height = max(400, max(len(input_nodes), len(output_nodes)) * 50 + 100) # Calculate positions input_x = 100 output_x = width - 100 # Position nodes vertically def get_y_pos(index, total): if total <= 1: return height // 2 return 80 + (index * (height - 160)) / (total - 1) # Start building SVG svg_html = f"""

Token Attention Visualization

Step {data.get('step', 0) + 1} | {len(input_nodes)} input → {len(output_nodes)} output | {len(links)} connections

Input Tokens Output Tokens """ # Draw connections first (so they appear behind nodes) for link in links: # Find source and target nodes source_node = next((n for n in input_nodes + output_nodes if n['id'] == link['source']), None) target_node = next((n for n in input_nodes + output_nodes if n['id'] == link['target']), None) if source_node and target_node: # Get positions if source_node['type'] == 'input': source_idx = next((i for i, n in enumerate(input_nodes) if n['id'] == source_node['id']), 0) source_y = get_y_pos(source_idx, len(input_nodes)) source_x_pos = input_x + 20 # Offset from center of node else: source_idx = next((i for i, n in enumerate(output_nodes) if n['id'] == source_node['id']), 0) source_y = get_y_pos(source_idx, len(output_nodes)) source_x_pos = output_x - 20 if target_node['type'] == 'input': target_idx = next((i for i, n in enumerate(input_nodes) if n['id'] == target_node['id']), 0) target_y = get_y_pos(target_idx, len(input_nodes)) target_x_pos = input_x - 20 else: target_idx = next((i for i, n in enumerate(output_nodes) if n['id'] == target_node['id']), 0) target_y = get_y_pos(target_idx, len(output_nodes)) target_x_pos = output_x - 20 # Line properties based on weight stroke_width = max(1, min(8, link['weight'] * 20)) opacity = max(0.3, min(1.0, link['weight'] * 2)) color = "#4285f4" if link['type'] == 'input_to_output' else "#ea4335" # Create straight line svg_html += f''' ''' # Draw input nodes for i, node in enumerate(input_nodes): y = get_y_pos(i, len(input_nodes)) token_text = node['token'] # Clean token text - remove special prefix characters if token_text.startswith('Ġ'): token_text = token_text[1:] # Remove Ġ prefix if token_text.startswith('▁'): token_text = token_text[1:] # Remove ▁ prefix (SentencePiece) if token_text.startswith('##'): token_text = token_text[2:] # Remove ## prefix (BERT subwords) if len(token_text) > 15: token_text = token_text[:13] + "..." svg_html += f''' {token_text} ''' # Draw output nodes for i, node in enumerate(output_nodes): y = get_y_pos(i, len(output_nodes)) token_text = node['token'] # Clean token text - remove special prefix characters if token_text.startswith('Ġ'): token_text = token_text[1:] # Remove Ġ prefix if token_text.startswith('▁'): token_text = token_text[1:] # Remove ▁ prefix (SentencePiece) if token_text.startswith('##'): token_text = token_text[2:] # Remove ## prefix (BERT subwords) if len(token_text) > 15: token_text = token_text[:13] + "..." svg_html += f''' {token_text} ''' # Close SVG and add legend svg_html += '''
Input → Output
Line thickness = weight
''' return svg_html def create_d3_visualization_old(data): """ OLD VERSION - Generate a complete, self-contained HTML string with embedded D3.js visualization. Args: data (dict): JSON structure with nodes and links from prepare_d3_data() Returns: str: Complete HTML string with embedded D3.js, CSS, and JavaScript """ html_template = f"""
Step: {data.get('step', 0) + 1} / {data.get('total_steps', 1)}
Nodes: {len(data.get('nodes', []))} | Links: {len(data.get('links', []))}
Click nodes to filter connections
""" return html_template