# modular-inheritance / generate_graph_modular.py
# Molbap's picture
# Molbap HF Staff
# Upload generate_graph_modular.py
# 070e318 verified
import os
import ast
import json
import argparse
from collections import defaultdict, Counter
import re
def find_modular_files(transformers_path):
    """
    Collect every 'modular' Python file below the Transformers models package.

    The tree rooted at ``<transformers_path>/src/transformers/models`` is walked
    recursively; a file is kept when its name contains ``'modular'`` and ends
    with ``.py``.

    Args:
        transformers_path: Root directory of a local Transformers checkout.

    Returns:
        A list of file paths, in the order ``os.walk`` discovers them. The list
        is empty when the models directory does not exist.
    """
    models_root = os.path.join(transformers_path, 'src', 'transformers', 'models')
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(models_root)
        for filename in filenames
        if filename.endswith('.py') and 'modular' in filename
    ]
def _resolve_source_model(module_parts, level):
    """
    Return the model package an import points at, or '' if it is not a model.

    Args:
        module_parts: The dotted ``module`` of an ``ast.ImportFrom`` node, split
            on '.'.
        level: The node's relative-import depth (0 for an absolute import).

    The import is re-rooted at the ``transformers`` package and accepted only if
    it goes through ``models``. This assumes the importing file lives directly in
    ``models/<model>/``, which is where modular files are expected to sit.
    """
    if level == 0:
        if module_parts[0] != 'transformers':
            return ""
        rooted = module_parts[1:]
    elif level == 2:
        # From models/<model>/, '..' is the 'models' package itself.
        rooted = ['models'] + module_parts
    elif level == 3:
        # '...' is the 'transformers' package root.
        rooted = module_parts
    else:
        # Level 1 stays inside the importing model's own package; anything
        # deeper than 3 would leave the 'transformers' package entirely.
        return ""
    if len(rooted) >= 2 and rooted[0] == 'models':
        return rooted[1]
    return ""


def build_dependency_graph(modular_files):
    """
    Build a model dependency graph from the ``from ... import`` statements of
    each modular file.

    The derived model is the name of the directory containing the file. An
    import counts as a dependency when it names a modeling, configuration or
    processing module (or anything under ``transformers.models.``) that belongs
    to a *different* model package. Shared utility modules and other non-model
    packages are ignored.

    Args:
        modular_files: Iterable of paths to modular Python files.

    Returns:
        A dict mapping each derived model name to a list of
        ``{'source': <model>, 'imported_class': <name>}`` entries, one per
        imported name, in source order. Models without dependencies are absent.

    Files that cannot be read or parsed are reported on stdout and skipped.
    """
    dependencies = defaultdict(list)
    for file_path in modular_files:
        derived_model_name = os.path.basename(os.path.dirname(file_path))
        # Only reading and parsing are guarded, so a bug in the graph logic
        # below is not silently reported as a parse failure.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                tree = ast.parse(f.read(), filename=file_path)
        except (OSError, SyntaxError, ValueError) as e:
            # UnicodeDecodeError is a ValueError subclass.
            print(f"Could not parse {file_path}: {e}")
            continue

        for node in ast.walk(tree):
            if not isinstance(node, ast.ImportFrom) or not node.module:
                continue
            module = node.module
            if 'modeling_attn_mask_utils' in module:
                continue
            # NOTE: ast strips the leading dots of a relative import into
            # node.level, so node.module never starts with '.'; the relative
            # depth is taken into account by _resolve_source_model instead.
            is_relevant_import = (
                module.startswith('transformers.models.')
                or 'modeling_' in module
                or 'configuration_' in module
                or 'processing_' in module
            )
            if not is_relevant_import:
                continue
            path_parts = module.split('.')
            if len(path_parts) <= 1:
                # A single-segment module (e.g. '...modeling_utils') is a shared
                # utility or a file of the same model, never another model.
                continue
            source_model_name = _resolve_source_model(path_parts, node.level)
            if not source_model_name or source_model_name == derived_model_name:
                continue
            for alias in node.names:
                dependencies[derived_model_name].append({
                    'source': source_model_name,
                    'imported_class': alias.name
                })
    return dict(dependencies)
def print_debug_info(dependencies):
    """
    Print a human-readable summary of the model dependencies to stdout.

    Derived models are listed alphabetically; under each one, its source models
    are listed alphabetically together with the sorted names imported from them.

    Args:
        dependencies: Mapping of derived model name to a list of
            ``{'source': ..., 'imported_class': ...}`` entries.
    """
    print("--- Model Dependency Debug ---")
    if not dependencies:
        print("No modular dependencies found.")
        return
    for derived_model in sorted(dependencies):
        print(f"\n🎨 Derived Model: {derived_model}")
        # Group the imported names by the model they come from.
        classes_by_source = {}
        for entry in dependencies[derived_model]:
            classes_by_source.setdefault(entry['source'], []).append(entry['imported_class'])
        for source in sorted(classes_by_source):
            class_names = ', '.join(sorted(classes_by_source[source]))
            print(f" └── inherits from '{source}' (imports: {class_names})")
    print("\n--------------------------")
def _build_graph_data(dependencies):
    """
    Turn the dependency mapping into the ``{"nodes": [...], "links": [...]}``
    structure consumed by the D3 script.

    Each (source, derived) pair becomes one link labelled with the number of
    names imported across it. Each node carries ``is_base`` (it is imported
    from but is not itself a derived model) and a ``size`` between 1 and 3 that
    grows with the number of links touching the node.
    """
    derived_models = set(dependencies)
    node_ids = set(derived_models)
    source_models = set()
    # Counter keeps insertion order, so links come out in first-seen order.
    imports_per_edge = Counter()
    for derived_model, deps in dependencies.items():
        for dep in deps:
            node_ids.add(dep['source'])
            source_models.add(dep['source'])
            imports_per_edge[(dep['source'], derived_model)] += 1

    base_models = source_models - derived_models
    links = [
        {"source": source, "target": target, "label": f"{count} classes"}
        for (source, target), count in imports_per_edge.items()
    ]

    degree = Counter()
    for link in links:
        degree[link["source"]] += 1
        degree[link["target"]] += 1
    # default=1 avoids dividing by zero when there are no links at all.
    max_degree = max(degree.values(), default=1)

    nodes = [
        {
            "id": name,
            "is_base": name in base_models,
            "size": 1 + 2 * (degree[name] / max_degree),
        }
        for name in sorted(node_ids)
    ]
    return {"nodes": nodes, "links": links}


def generate_d3_visualization(dependencies, output_filename='d3_dependency_graph.html', hf_logo_path='hf-logo.svg'):
    """
    Write an interactive, zoomable D3.js HTML page that draws the dependency
    graph.

    Base models are drawn with the image at ``hf_logo_path``; derived models are
    drawn as circles with a boxed label. Links are labelled with the number of
    imported names. The page loads D3 v7 and the Inter font from their public
    CDNs, so those two resources are fetched when the page is opened.

    Args:
        dependencies: Mapping of derived model name to a list of
            ``{'source': ..., 'imported_class': ...}`` entries.
        output_filename: Path of the HTML file to write (UTF-8).
        hf_logo_path: Path or URL of the icon used for base models. It is
            embedded as an escaped JavaScript string literal, so backslashes
            and quotes in the path are safe.
    """
    # Graph data, serialised once and embedded in the page script.
    graph_json = json.dumps(_build_graph_data(dependencies), indent=4)

    # json.dumps produces a valid JS string literal; '</' is additionally
    # escaped so the value can never terminate the surrounding <script> tag.
    logo_js_literal = json.dumps(hf_logo_path).replace('</', '<\\/')

    # Static outline path, only embedded as the JS constant `hfLogoPath`,
    # which the page script does not otherwise use.
    hf_svg_path = (
        "M21.2,6.7c-0.2-0.2-0.5-0.3-0.8-0.3H3.6C3.3,6.4,3,6.5,2.8,6.7s-0.3,0.5-0.3,0.8v10.8c0,0.3,0.1,0.5,0.3,0.8 "
        "c0.2,0.2,0.5,0.3,0.8,0.3h16.8c0.3,0,0.5-0.1,0.8-0.3c0.2-0.2,0.3-0.5,0.3-0.8V7.5C21.5,7.2,21.4,6.9,21.2,6.7z "
        "M12,17.8L5.9,9.4h3.1 V8.3h6v1.1h3.1L12,17.8z"
    )

    # HTML / CSS / JS. Literal braces are doubled because this is an f-string.
    html_template = f"""
<!DOCTYPE html>
<html lang=\"en\">
<head>
<meta charset=\"UTF-8\">
<title>Transformers Modular Model Dependencies</title>
<style>
/* Google‑font – small fallback cost & optional */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
:root {{
--base‑size: 60px; /* icon radius helper */
}}
body {{
font-family: 'Inter', Arial, sans-serif;
margin: 0;
overflow: hidden;
background-color: transparent; /* requested transparency */
}}
svg {{
width: 100vw;
height: 100vh;
}}
.link {{
stroke: #999;
stroke-opacity: 0.6;
}}
.node-label {{
fill: #333;
pointer-events: none;
text-anchor: middle;
font-weight: 600;
}}
.link-label {{
fill: #555;
font-size: 10px;
pointer-events: none;
text-anchor: middle;
}}
.node.base path {{ fill: #ffbe0b; }}
.node.derived circle {{ fill: #1f77b4; }}
/* Legend styling */
#legend {{
position: fixed;
top: 18px;
left: 18px;
font-size: 20px;
background: rgba(255,255,255,0.92);
padding: 18px 28px;
border-radius: 10px;
border: 1.5px solid #bbb;
font-family: 'Inter', Arial, sans-serif;
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
z-index: 1000;
}}
</style>
</head>
<body>
<div id=\"legend\">🟡 base model (HF icon)<br>🔵 derived modular model<br>Edge label: #classes imported</div>
<svg id=\"dependency-graph\"></svg>
<script src=\"https://d3js.org/d3.v7.min.js\"></script>
<script>
const graphData = {graph_json};
const hfLogoPath = "{hf_svg_path}"; // kept for potential future use
const width = window.innerWidth;
const height = window.innerHeight;
const svg = d3.select('#dependency-graph')
.call(
d3.zoom().on('zoom', (event) => {{
g.attr('transform', event.transform);
}})
);
const g = svg.append('g');
// Forces – tweaked for tighter graph
const simulation = d3.forceSimulation(graphData.nodes)
.force('link', d3.forceLink(graphData.links).id(d => d.id).distance(500))
.force('charge', d3.forceManyBody().strength(-500))
.force('center', d3.forceCenter(width / 2, height / 2))
.force('collide', d3.forceCollide(0.01 * parseFloat(getComputedStyle(document.documentElement).getPropertyValue('--base‑size'))));
// Links
const link = g.append('g')
.selectAll('line')
.data(graphData.links)
.join('line')
.attr('class', 'link')
.attr('stroke-width', 1.5);
// Link‑labels (#classes)
const linkLabel = g.append('g')
.selectAll('text')
.data(graphData.links)
.join('text')
.attr('class', 'link-label')
.text(d => d.label);
// Nodes (base vs derived)
const node = g.append('g')
.selectAll('g')
.data(graphData.nodes)
.join('g')
.attr('class', d => d.is_base ? 'node base' : 'node derived')
.call(d3.drag()
.on('start', dragstarted)
.on('drag', dragged)
.on('end', dragended)
);
// Base‑model icon (HF logo)
node.filter(d => d.is_base)
.append('image')
.attr('xlink:href', {logo_js_literal})
.attr('x', -parseFloat(getComputedStyle(document.documentElement).getPropertyValue('--base‑size')) / 2)
.attr('y', -parseFloat(getComputedStyle(document.documentElement).getPropertyValue('--base‑size')) / 2)
.attr('width', parseFloat(getComputedStyle(document.documentElement).getPropertyValue('--base‑size')))
.attr('height', parseFloat(getComputedStyle(document.documentElement).getPropertyValue('--base‑size')));
// Base‑model label (below icon)
node.filter(d => d.is_base)
.append('text')
.attr('class', 'node-label')
.attr('y', d => 30 * d.size + 8) // keep under the icon
.style('font-size', d => `${{26 * d.size}}px`) // scale 26–78 px for size 1-3
.text(d => d.id);
// Derived‑model circle + label w/ background rect
const derived = node.filter(d => !d.is_base);
derived.append('circle')
.attr('r', d => 20 * d.size); // scaled
const labelGroup = derived.append('g').attr('class', 'label-group');
labelGroup.append('rect')
.attr('x', -45)
.attr('y', -18)
.attr('width', 90)
.attr('height', 36)
.attr('rx', 8)
.attr('fill', '#fffbe6')
.attr('stroke', '#ccc');
labelGroup.append('text')
.attr('class', 'node-label')
.attr('dy', '0.35em')
.style('font-size', '18px')
.text(d => d.id);
// Tick
simulation.on('tick', () => {{
link.attr('x1', d => d.source.x)
.attr('y1', d => d.source.y)
.attr('x2', d => d.target.x)
.attr('y2', d => d.target.y);
linkLabel.attr('x', d => (d.source.x + d.target.x) / 2)
.attr('y', d => (d.source.y + d.target.y) / 2);
node.attr('transform', d => `translate(${{d.x}}, ${{d.y}})`);
}});
// Drag helpers
function dragstarted(event, d) {{
if (!event.active) simulation.alphaTarget(0.3).restart();
d.fx = d.x; d.fy = d.y;
}}
function dragged(event, d) {{
d.fx = event.x; d.fy = event.y;
}}
function dragended(event, d) {{
if (!event.active) simulation.alphaTarget(0);
d.fx = null; d.fy = null;
}}
</script>
</body>
</html>
"""
    with open(output_filename, 'w', encoding='utf-8') as f:
        f.write(html_template)
    print(f"✅ D3.js visualization saved to '{output_filename}'. Open this file in your browser.")
if __name__ == "__main__":
    # Command-line entry point: scan a Transformers checkout, print the
    # dependency summary and write the D3 visualization with default settings.
    cli = argparse.ArgumentParser(
        description="Visualize modular model dependencies in Transformers using D3.js."
    )
    cli.add_argument(
        "transformers_path",
        type=str,
        help="The local path to the Hugging Face transformers repository.",
    )
    options = cli.parse_args()

    discovered_files = find_modular_files(options.transformers_path)
    if discovered_files:
        model_graph = build_dependency_graph(discovered_files)
        print_debug_info(model_graph)
        generate_d3_visualization(model_graph)
    else:
        # Nothing matched: most likely the path is not a repository root.
        print("No modular files found. Make sure the path to the transformers repository is correct.")