|  |  | 
					
						
						|  | """ | 
					
						
						|  | modular_graph_and_candidates.py | 
					
						
						|  | ================================ | 
					
						
						|  | Create **one** rich view that combines | 
					
						
1.  The *dependency graph* between existing **modular_*.py** implementations in
    🤗 Transformers (blue/🟡) **and**
2.  The list of *missing* modular models (full-red nodes) **plus** similarity
    edges (full-red links) between highly-overlapping modelling files — the
    output of *find_modular_candidates.py* — so you can immediately spot good
    refactor opportunities.
					
						
						|  |  | 
					
						
─── Usage ───
					
						
						|  |  | 
					
						
						|  | ```bash | 
					
						
						|  | python modular_graph_and_candidates.py /path/to/transformers \ | 
					
						
--multimodal         # keep only models whose modelling code mentions
                     # "pixel_values" ≥ 3 times
					
						
						|  | --sim-threshold 0.5  # Jaccard cutoff (default 0.50) | 
					
						
						|  | --out graph.html     # output HTML file name | 
					
						
						|  | ``` | 
					
						
						|  |  | 
					
						
Colour legend in the generated HTML:
* 🟡 **base model** — has modular shards *imported* by others but no parent
* 🔵 **derived modular model** — has a `modular_*.py` and inherits from ≥ 1 model
* 🔴 **candidate** — no `modular_*.py` yet (and/or very similar to another)
* red edges = high-similarity links (Jaccard or embedding, per `--sim-method`),
  i.e. potential to factorise
					
						
						|  | """ | 
					
						
						|  | from __future__ import annotations | 
					
						
						|  |  | 
					
						
						|  | import argparse | 
					
						
						|  | import ast | 
					
						
						|  | import json | 
					
						
						|  | import re | 
					
						
						|  | import tokenize | 
					
						
						|  | from collections import Counter, defaultdict | 
					
						
						|  | from itertools import combinations | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from typing import Dict, List, Set, Tuple | 
					
						
						|  | from sentence_transformers import SentenceTransformer, util | 
					
						
						|  | from tqdm import tqdm | 
					
						
						|  | import numpy as np | 
					
						
						|  | import spaces | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | SIM_DEFAULT     = 0.5 | 
					
						
						|  | PIXEL_MIN_HITS  = 0 | 
					
						
						|  | HTML_DEFAULT = "d3_modular_graph.html" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _strip_source(code: str) -> str: | 
					
						
						|  | """Remove docβstrings, comments and import lines to keep only the core code.""" | 
					
						
						|  | code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code) | 
					
						
						|  | code = re.sub(r"#.*", "", code) | 
					
						
						|  | return "\n".join(ln for ln in code.splitlines() | 
					
						
						|  | if not re.match(r"\s*(from|import)\s+", ln)) | 
					
						
						|  |  | 
					
						
						|  | def _tokenise(code: str) -> Set[str]: | 
					
						
						|  | """Extract identifiers using regex - more robust than tokenizer for malformed code.""" | 
					
						
						|  | toks: Set[str] = set() | 
					
						
						|  | for match in re.finditer(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', code): | 
					
						
						|  | toks.add(match.group()) | 
					
						
						|  | return toks | 
					
						
						|  |  | 
					
						
						|  | def build_token_bags(models_root: Path) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]: | 
					
						
						|  | """Return tokenβbags of every `modeling_*.py` plus a pixelβvalue counter.""" | 
					
						
						|  | bags: Dict[str, List[Set[str]]] = defaultdict(list) | 
					
						
						|  | pixel_hits: Dict[str, int] = defaultdict(int) | 
					
						
						|  | for mdl_dir in sorted(p for p in models_root.iterdir() if p.is_dir()): | 
					
						
						|  | for py in mdl_dir.rglob("modeling_*.py"): | 
					
						
						|  | try: | 
					
						
						|  | text = py.read_text(encoding="utfβ8") | 
					
						
						|  | pixel_hits[mdl_dir.name] += text.count("pixel_values") | 
					
						
						|  | bags[mdl_dir.name].append(_tokenise(_strip_source(text))) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"β οΈ  Skipped {py}: {e}") | 
					
						
						|  | return bags, pixel_hits | 
					
						
						|  |  | 
					
						
						|  | def _jaccard(a: Set[str], b: Set[str]) -> float: | 
					
						
						|  | return 0.0 if (not a or not b) else len(a & b) / len(a | b) | 
					
						
						|  |  | 
					
						
						|  | def similarity_clusters(bags: Dict[str, List[Set[str]]], thr: float) -> Dict[Tuple[str,str], float]: | 
					
						
						|  | """Return {(modelA, modelB): score} for pairs with Jaccard β₯ *thr*.""" | 
					
						
						|  | largest = {m: max(ts, key=len) for m, ts in bags.items() if ts} | 
					
						
						|  | out: Dict[Tuple[str,str], float] = {} | 
					
						
						|  | for m1, m2 in combinations(sorted(largest.keys()), 2): | 
					
						
						|  | s = _jaccard(largest[m1], largest[m2]) | 
					
						
						|  | if s >= thr: | 
					
						
						|  | out[(m1, m2)] = s | 
					
						
						|  | return out | 
					
						
						|  |  | 
					
						
						|  | @spaces.GPU | 
					
						
						|  | def embedding_similarity_clusters(models_root: Path, missing: List[str], thr: float) -> Dict[Tuple[str, str], float]: | 
					
						
						|  | model = SentenceTransformer("codesage/codesage-large-v2", device="cuda", trust_remote_code=True) | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | cfg = model[0].auto_model.config | 
					
						
						|  | pos_limit = int(getattr(cfg, "n_positions", getattr(cfg, "max_position_embeddings"))) | 
					
						
						|  | except Exception: | 
					
						
						|  | pos_limit = 1024 | 
					
						
						|  |  | 
					
						
						|  | seq_len = min(pos_limit, 2048) | 
					
						
						|  | model.max_seq_length = seq_len | 
					
						
						|  | model[0].max_seq_length = seq_len | 
					
						
						|  | model[0].tokenizer.model_max_length = seq_len | 
					
						
						|  |  | 
					
						
						|  | texts = {} | 
					
						
						|  | for name in tqdm(missing, desc="Reading modeling files"): | 
					
						
						|  | if any(skip in name.lower() for skip in ["mobilebert", "lxmert"]): | 
					
						
						|  | print(f"Skipping {name} (causes GPU abort)") | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | code = "" | 
					
						
						|  | for py in (models_root / name).rglob("modeling_*.py"): | 
					
						
						|  | try: | 
					
						
						|  | code += _strip_source(py.read_text(encoding="utf-8")) + "\n" | 
					
						
						|  | except Exception: | 
					
						
						|  | continue | 
					
						
						|  | texts[name] = code.strip() or " " | 
					
						
						|  |  | 
					
						
						|  | names = list(texts) | 
					
						
						|  | all_embeddings = [] | 
					
						
						|  |  | 
					
						
						|  | print(f"Encoding embeddings for {len(names)} models...") | 
					
						
						|  | batch_size = 4 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | temp_cache_path = Path("temp_embeddings.npz") | 
					
						
						|  | final_cache_path = Path("embeddings_cache.npz") | 
					
						
						|  | start_idx = 0 | 
					
						
						|  | emb_dim = getattr(model, "get_sentence_embedding_dimension", lambda: 768)() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if final_cache_path.exists(): | 
					
						
						|  | try: | 
					
						
						|  | cached = np.load(final_cache_path, allow_pickle=True) | 
					
						
						|  | cached_names = list(cached["names"]) | 
					
						
						|  | if names == cached_names: | 
					
						
						|  | print(f"β
 Using final embeddings cache ({len(cached_names)} models)") | 
					
						
						|  | return compute_similarities_from_cache(thr) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"β οΈ  Failed to load final cache: {e}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if temp_cache_path.exists(): | 
					
						
						|  | try: | 
					
						
						|  | cached = np.load(temp_cache_path, allow_pickle=True) | 
					
						
						|  | cached_names = list(cached["names"]) | 
					
						
						|  | if names[:len(cached_names)] == cached_names: | 
					
						
						|  | loaded = cached["embeddings"].astype(np.float32) | 
					
						
						|  | all_embeddings.append(loaded) | 
					
						
						|  | start_idx = len(cached_names) | 
					
						
						|  | print(f"π Resuming from temp cache: {start_idx}/{len(names)} models") | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"β οΈ  Failed to load temp cache: {e}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for i in tqdm(range(start_idx, len(names), batch_size), desc="Batches", leave=False): | 
					
						
						|  | batch_names = names[i:i+batch_size] | 
					
						
						|  | batch_texts = [texts[name] for name in batch_names] | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | print(f"Processing batch: {batch_names}") | 
					
						
						|  | emb = model.encode(batch_texts, convert_to_numpy=True, show_progress_bar=False) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"β οΈ  GPU worker error for batch {batch_names}: {type(e).__name__}: {e}") | 
					
						
						|  | emb = np.zeros((len(batch_names), emb_dim), dtype=np.float32) | 
					
						
						|  |  | 
					
						
						|  | all_embeddings.append(emb) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | cur = np.vstack(all_embeddings).astype(np.float32) | 
					
						
						|  | np.savez( | 
					
						
						|  | temp_cache_path, | 
					
						
						|  | embeddings=cur, | 
					
						
						|  | names=np.array(names[:i+len(batch_names)], dtype=object), | 
					
						
						|  | ) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"β οΈ  Failed to write temp cache: {e}") | 
					
						
						|  |  | 
					
						
						|  | if (i - start_idx) % (3 * batch_size) == 0 and torch.cuda.is_available(): | 
					
						
						|  | torch.cuda.empty_cache() | 
					
						
						|  | torch.cuda.synchronize() | 
					
						
						|  | print(f"π§Ή Cleared GPU cache after batch {(i - start_idx)//batch_size + 1}") | 
					
						
						|  |  | 
					
						
						|  | embeddings = np.vstack(all_embeddings).astype(np.float32) | 
					
						
						|  | norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12 | 
					
						
						|  | embeddings = embeddings / norms | 
					
						
						|  |  | 
					
						
						|  | print("Computing pairwise similarities...") | 
					
						
						|  | sims_mat = embeddings @ embeddings.T | 
					
						
						|  |  | 
					
						
						|  | out = {} | 
					
						
						|  | matrix_size = embeddings.shape[0] | 
					
						
						|  | processed_names = names[:matrix_size] | 
					
						
						|  | for i in range(matrix_size): | 
					
						
						|  | for j in range(i + 1, matrix_size): | 
					
						
						|  | s = float(sims_mat[i, j]) | 
					
						
						|  | if s >= thr: | 
					
						
						|  | out[(processed_names[i], processed_names[j])] = s | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | np.savez(final_cache_path, embeddings=embeddings, names=np.array(names, dtype=object)) | 
					
						
						|  | print(f"πΎ Final embeddings saved to {final_cache_path}") | 
					
						
						|  |  | 
					
						
						|  | if temp_cache_path.exists(): | 
					
						
						|  | temp_cache_path.unlink() | 
					
						
						|  | print(f"π§Ή Cleaned up temp cache") | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"β οΈ Failed to save final cache: {e}") | 
					
						
						|  |  | 
					
						
						|  | return out | 
					
						
						|  |  | 
					
						
						|  | def compute_similarities_from_cache(threshold: float) -> Dict[Tuple[str, str], float]: | 
					
						
						|  | """Compute similarities from cached embeddings without reprocessing.""" | 
					
						
						|  | embeddings_path = Path("embeddings_cache.npz") | 
					
						
						|  |  | 
					
						
						|  | if not embeddings_path.exists(): | 
					
						
						|  | return {} | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | cached = np.load(embeddings_path, allow_pickle=True) | 
					
						
						|  | embeddings = cached["embeddings"].astype(np.float32) | 
					
						
						|  | names = list(cached["names"]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12 | 
					
						
						|  | embeddings = embeddings / norms | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | sims_mat = embeddings @ embeddings.T | 
					
						
						|  |  | 
					
						
						|  | out = {} | 
					
						
						|  | for i in range(len(names)): | 
					
						
						|  | for j in range(i + 1, len(names)): | 
					
						
						|  | s = float(sims_mat[i, j]) | 
					
						
						|  | if s >= threshold: | 
					
						
						|  | out[(names[i], names[j])] = s | 
					
						
						|  |  | 
					
						
						|  | print(f"β‘ Computed {len(out)} similarities from cache (threshold: {threshold})") | 
					
						
						|  | return out | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"β οΈ  Failed to compute from cache: {e}") | 
					
						
						|  | return {} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def modular_files(models_root: Path) -> List[Path]: | 
					
						
						|  | return [p for p in models_root.rglob("modular_*.py") if p.suffix == ".py"] | 
					
						
						|  |  | 
					
						
						|  | def dependency_graph(modular_files: List[Path], models_root: Path) -> Dict[str, List[Dict[str,str]]]: | 
					
						
						|  | """Return {derived_model: [{source, imported_class}, ...]} | 
					
						
						|  |  | 
					
						
						|  | Only `modeling_*` imports are kept; anything coming from configuration/processing/ | 
					
						
						|  | image* utils is ignored so the visual graph focuses strictly on modelling code. | 
					
						
						|  | Excludes edges to sources whose model name is not a model dir. | 
					
						
						|  | """ | 
					
						
						|  | model_names = {p.name for p in models_root.iterdir() if p.is_dir()} | 
					
						
						|  | deps: Dict[str, List[Dict[str,str]]] = defaultdict(list) | 
					
						
						|  | for fp in modular_files: | 
					
						
						|  | derived = fp.parent.name | 
					
						
						|  | try: | 
					
						
						|  | tree = ast.parse(fp.read_text(encoding="utfβ8"), filename=str(fp)) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"β οΈ  AST parse failed for {fp}: {e}") | 
					
						
						|  | continue | 
					
						
						|  | for node in ast.walk(tree): | 
					
						
						|  | if not isinstance(node, ast.ImportFrom) or not node.module: | 
					
						
						|  | continue | 
					
						
						|  | mod = node.module | 
					
						
						|  |  | 
					
						
						|  | if ("modeling_" not in mod or | 
					
						
						|  | "configuration_" in mod or | 
					
						
						|  | "processing_" in mod or | 
					
						
						|  | "image_processing" in mod or | 
					
						
						|  | "modeling_attn_mask_utils" in mod): | 
					
						
						|  | continue | 
					
						
						|  | parts = re.split(r"[./]", mod) | 
					
						
						|  | src = next((p for p in parts if p not in {"", "models", "transformers"}), "") | 
					
						
						|  | if not src or src == derived or src not in model_names: | 
					
						
						|  | continue | 
					
						
						|  | for alias in node.names: | 
					
						
						|  | deps[derived].append({"source": src, "imported_class": alias.name}) | 
					
						
						|  | return dict(deps) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_missing_models(models_root: Path, multimodal: bool = False) -> Tuple[List[str], Dict[str, List[Set[str]]], Dict[str, int]]: | 
					
						
						|  | """Get list of models missing modular implementations.""" | 
					
						
						|  | bags, pix_hits = build_token_bags(models_root) | 
					
						
						|  | mod_files = modular_files(models_root) | 
					
						
						|  | models_with_modular = {p.parent.name for p in mod_files} | 
					
						
						|  | missing = [m for m in bags if m not in models_with_modular] | 
					
						
						|  |  | 
					
						
						|  | if multimodal: | 
					
						
						|  | missing = [m for m in missing if pix_hits[m] >= PIXEL_MIN_HITS] | 
					
						
						|  |  | 
					
						
						|  | return missing, bags, pix_hits | 
					
						
						|  |  | 
					
						
						|  | def compute_similarities(models_root: Path, missing: List[str], bags: Dict[str, List[Set[str]]], | 
					
						
						|  | threshold: float, sim_method: str) -> Dict[Tuple[str, str], float]: | 
					
						
						|  | """Compute similarities between missing models using specified method.""" | 
					
						
						|  | if sim_method == "jaccard": | 
					
						
						|  | return similarity_clusters({m: bags[m] for m in missing}, threshold) | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | embeddings_path = Path("embeddings_cache.npz") | 
					
						
						|  | if embeddings_path.exists(): | 
					
						
						|  | cached_sims = compute_similarities_from_cache(threshold) | 
					
						
						|  | if cached_sims: | 
					
						
						|  | return cached_sims | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | return embedding_similarity_clusters(models_root, missing, threshold) | 
					
						
						|  |  | 
					
						
						|  | def build_graph_json( | 
					
						
						|  | transformers_dir: Path, | 
					
						
						|  | threshold: float = SIM_DEFAULT, | 
					
						
						|  | multimodal: bool = False, | 
					
						
						|  | sim_method: str = "jaccard", | 
					
						
						|  | ) -> dict: | 
					
						
						|  | """Return the {nodes, links} dict that D3 needs.""" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | embeddings_cache = Path("embeddings_cache.npz") | 
					
						
						|  | print(f"π Cache file exists: {embeddings_cache.exists()}, sim_method: {sim_method}") | 
					
						
						|  |  | 
					
						
						|  | if sim_method == "embedding" and embeddings_cache.exists(): | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | cached_sims = compute_similarities_from_cache(threshold) | 
					
						
						|  | print(f"π Got {len(cached_sims)} cached similarities") | 
					
						
						|  |  | 
					
						
						|  | if cached_sims: | 
					
						
						|  |  | 
					
						
						|  | cached_data = np.load(embeddings_cache, allow_pickle=True) | 
					
						
						|  | missing = list(cached_data["names"]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | models_root = transformers_dir / "src/transformers/models" | 
					
						
						|  | mod_files = modular_files(models_root) | 
					
						
						|  | deps = dependency_graph(mod_files, models_root) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | nodes = set(missing) | 
					
						
						|  | links = [] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for drv, lst in deps.items(): | 
					
						
						|  | for d in lst: | 
					
						
						|  | links.append({ | 
					
						
						|  | "source": d["source"], | 
					
						
						|  | "target": drv, | 
					
						
						|  | "label": f"{sum(1 for x in lst if x['source'] == d['source'])} imports", | 
					
						
						|  | "cand": False | 
					
						
						|  | }) | 
					
						
						|  | nodes.update({d["source"], drv}) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for (a, b), s in cached_sims.items(): | 
					
						
						|  | links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True}) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | targets = {lk["target"] for lk in links if not lk["cand"]} | 
					
						
						|  | sources = {lk["source"] for lk in links if not lk["cand"]} | 
					
						
						|  |  | 
					
						
						|  | nodelist = [] | 
					
						
						|  | for n in sorted(nodes): | 
					
						
						|  | if n in missing and n not in sources and n not in targets: | 
					
						
						|  | cls = "cand" | 
					
						
						|  | elif n in sources and n not in targets: | 
					
						
						|  | cls = "base" | 
					
						
						|  | else: | 
					
						
						|  | cls = "derived" | 
					
						
						|  | nodelist.append({"id": n, "cls": cls, "sz": 1}) | 
					
						
						|  |  | 
					
						
						|  | print(f"β‘ Built graph from cache: {len(nodelist)} nodes, {len(links)} links") | 
					
						
						|  | return {"nodes": nodelist, "links": links} | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"β οΈ Cache-only build failed: {e}, falling back to full build") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | models_root = transformers_dir / "src/transformers/models" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | missing, bags, pix_hits = get_missing_models(models_root, multimodal) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mod_files = modular_files(models_root) | 
					
						
						|  | deps = dependency_graph(mod_files, models_root) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | sims = compute_similarities(models_root, missing, bags, threshold, sim_method) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | nodes: Set[str] = set() | 
					
						
						|  | links: List[dict] = [] | 
					
						
						|  |  | 
					
						
						|  | for drv, lst in deps.items(): | 
					
						
						|  | for d in lst: | 
					
						
						|  | links.append({ | 
					
						
						|  | "source": d["source"], | 
					
						
						|  | "target": drv, | 
					
						
						|  | "label": f"{sum(1 for x in lst if x['source'] == d['source'])} imports", | 
					
						
						|  | "cand": False | 
					
						
						|  | }) | 
					
						
						|  | nodes.update({d["source"], drv}) | 
					
						
						|  |  | 
					
						
						|  | for (a, b), s in sims.items(): | 
					
						
						|  | links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True}) | 
					
						
						|  | nodes.update({a, b}) | 
					
						
						|  |  | 
					
						
						|  | nodes.update(missing) | 
					
						
						|  |  | 
					
						
						|  | deg = Counter() | 
					
						
						|  | for lk in links: | 
					
						
						|  | deg[lk["source"]] += 1 | 
					
						
						|  | deg[lk["target"]] += 1 | 
					
						
						|  | max_deg = max(deg.values() or [1]) | 
					
						
						|  |  | 
					
						
						|  | targets = {lk["target"] for lk in links if not lk["cand"]} | 
					
						
						|  | sources = {lk["source"] for lk in links if not lk["cand"]} | 
					
						
						|  | missing_only = [m for m in missing if m not in sources and m not in targets] | 
					
						
						|  | nodes.update(missing_only) | 
					
						
						|  |  | 
					
						
						|  | nodelist = [] | 
					
						
						|  | for n in sorted(nodes): | 
					
						
						|  | if n in missing_only: | 
					
						
						|  | cls = "cand" | 
					
						
						|  | elif n in sources and n not in targets: | 
					
						
						|  | cls = "base" | 
					
						
						|  | else: | 
					
						
						|  | cls = "derived" | 
					
						
						|  | nodelist.append({"id": n, "cls": cls, "sz": 1 + 2*(deg[n]/max_deg)}) | 
					
						
						|  |  | 
					
						
						|  | graph = {"nodes": nodelist, "links": links} | 
					
						
						|  | return graph | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def generate_html(graph: dict) -> str: | 
					
						
						|  | """Return the full HTML string with inlined CSS/JS + graph JSON.""" | 
					
						
						|  | js = JS.replace("__GRAPH_DATA__", json.dumps(graph, separators=(",", ":"))) | 
					
						
						|  | return HTML.replace("__CSS__", CSS).replace("__JS__", js) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Stylesheet inlined into the page by generate_html() (substituted for the
# "__CSS__" placeholder of HTML). Colours come from CSS variables with a
# dark-mode override via prefers-color-scheme. The .base/.derived/.cand
# selectors correspond to the "cls" values that build_graph_json() assigns to
# nodes; .link.cand styles the red similarity edges.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');

:root{
--bg:#ffffff;
--text:#222222;
--muted:#555555;
--outline:#ffffff;
}
@media (prefers-color-scheme: dark){
:root{
--bg:#0b0d10;
--text:#e8e8e8;
--muted:#c8c8c8;
--outline:#000000;
}
}

body{ margin:0; font-family:'Inter',Arial,sans-serif; background:var(--bg); overflow:hidden; }
svg{ width:100vw; height:100vh; }

.link{ stroke:#999; stroke-opacity:.6; }
.link.cand{ stroke:#e63946; stroke-width:2.5; }

.node-label{
fill:var(--text);
pointer-events:none;
text-anchor:middle;
font-weight:600;
paint-order:stroke fill;
stroke:var(--outline);
stroke-width:3px;
}
.link-label{
fill:var(--muted);
pointer-events:none;
text-anchor:middle;
font-size:10px;
paint-order:stroke fill;
stroke:var(--bg);
stroke-width:2px;
}

.node.base image{ width:60px; height:60px; transform:translate(-30px,-30px); }
.node.derived circle{ fill:#1f77b4; }
.node.cand circle, .node.cand path{ fill:#e63946; }

#legend{
position:fixed; top:18px; left:18px;
background:rgba(255,255,255,.92);
padding:18px 28px; border-radius:10px; border:1.5px solid #bbb;
font-size:18px; box-shadow:0 2px 8px rgba(0,0,0,.08);
}
@media (prefers-color-scheme: dark){
#legend{ background:rgba(20,22,25,.92); color:#e8e8e8; border-color:#444; }
}
"""
					
						
						|  |  | 
					
						
# Client-side script inlined by generate_html(). "__GRAPH_DATA__" is replaced
# with the graph JSON. It relies on d3 v7, which the HTML template loads
# before this script.
#
# Rendering:
# - links, link labels and node groups go inside <svg id="dependency">, with
#   zoom on the svg and drag on nodes;
# - a force simulation lays them out;
# - the #toggleRed checkbox shows/hides candidate nodes, edges and their labels.
#
# "base" nodes use the image at HF_LOGO_URI. The relative path is resolved by
# the browser against the page URL, so presumably the file must be served next
# to the generated HTML. The yellow-circle fallback runs only if HF_LOGO_URI
# is set to an empty string.
JS = """
function updateVisibility() {
const show = document.getElementById('toggleRed').checked;
svg.selectAll('.link.cand').style('display', show ? null : 'none');
svg.selectAll('.node.cand').style('display', show ? null : 'none');
svg.selectAll('.link-label').filter(d => d.cand).style('display', show ? null : 'none');
}
document.getElementById('toggleRed').addEventListener('change', updateVisibility);

const HF_LOGO_URI = "./static/hf-logo.svg";
const graph = __GRAPH_DATA__;
const W = innerWidth, H = innerHeight;
const svg = d3.select('#dependency').call(d3.zoom().on('zoom', e => g.attr('transform', e.transform)));
const g = svg.append('g');

const link = g.selectAll('line')
.data(graph.links)
.join('line')
.attr('class', d => d.cand ? 'link cand' : 'link');

const linkLbl = g.selectAll('text.link-label')
.data(graph.links)
.join('text')
.attr('class', 'link-label')
.text(d => d.label);

const node = g.selectAll('g.node')
.data(graph.nodes)
.join('g')
.attr('class', d => `node ${d.cls}`)
.call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd));

const baseSel = node.filter(d => d.cls === 'base');
if (HF_LOGO_URI){
baseSel.append('image').attr('href', HF_LOGO_URI);
}else{
baseSel.append('circle').attr('r', d => 22*d.sz).attr('fill', '#ffbe0b');
}
node.filter(d => d.cls !== 'base').append('circle').attr('r', d => 20*d.sz);

node.append('text').attr('class','node-label').attr('dy','-2.4em').text(d => d.id);

const sim = d3.forceSimulation(graph.nodes)
.force('link', d3.forceLink(graph.links).id(d => d.id).distance(520))
.force('charge', d3.forceManyBody().strength(-600))
.force('center', d3.forceCenter(W / 2, H / 2))
.force('collide', d3.forceCollide(d => 50));

sim.on('tick', () => {
link.attr('x1', d=>d.source.x).attr('y1', d=>d.source.y)
.attr('x2', d=>d.target.x).attr('y2', d=>d.target.y);
linkLbl.attr('x', d=> (d.source.x+d.target.x)/2)
.attr('y', d=> (d.source.y+d.target.y)/2);
node.attr('transform', d=>`translate(${d.x},${d.y})`);
});

function dragStart(e,d){ if(!e.active) sim.alphaTarget(.3).restart(); d.fx=d.x; d.fy=d.y; }
function dragged(e,d){ d.fx=e.x; d.fy=e.y; }
function dragEnd(e,d){ if(!e.active) sim.alphaTarget(0); d.fx=d.fy=null; }
"""
					
						
						|  |  | 
					
						
						|  | HTML = """ | 
					
						
						|  | <!DOCTYPE html> | 
					
						
						|  | <html lang='en'><head><meta charset='UTF-8'> | 
					
						
						|  | <title>Transformers modular graph</title> | 
					
						
						|  | <style>__CSS__</style></head><body> | 
					
						
						|  | <div id='legend'> | 
					
						
						|  | π‘ base<br>π΅ modular<br>π΄ candidate<br>red edgeΒ = high embedding similarity<br><br> | 
					
						
						|  | <label><input type="checkbox" id="toggleRed" checked> Show candidates edges and nodes</label> | 
					
						
						|  | </div> | 
					
						
						|  | <svg id='dependency'></svg> | 
					
						
						|  | <script src='https://d3js.org/d3.v7.min.js'></script> | 
					
						
						|  | <script>__JS__</script></body></html> | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def write_html(graph_data: dict, path: Path): | 
					
						
						|  | path.write_text(generate_html(graph_data), encoding="utf-8") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def main(): | 
					
						
						|  | ap = argparse.ArgumentParser(description="Visualise modular dependencies + candidates") | 
					
						
						|  | ap.add_argument("transformers", help="Path to local π€ transformers repo root") | 
					
						
						|  | ap.add_argument("--multimodal", action="store_true", help="filter to models with β₯3 'pixel_values'") | 
					
						
						|  | ap.add_argument("--sim-threshold", type=float, default=SIM_DEFAULT) | 
					
						
						|  | ap.add_argument("--out", default=HTML_DEFAULT) | 
					
						
						|  | ap.add_argument("--sim-method", choices=["jaccard", "embedding"], default="jaccard", | 
					
						
						|  | help="Similarity method: 'jaccard' or 'embedding'") | 
					
						
						|  | args = ap.parse_args() | 
					
						
						|  |  | 
					
						
						|  | graph = build_graph_json( | 
					
						
						|  | transformers_dir=Path(args.transformers).expanduser().resolve(), | 
					
						
						|  | threshold=args.sim_threshold, | 
					
						
						|  | multimodal=args.multimodal, | 
					
						
						|  | sim_method=args.sim_method, | 
					
						
						|  | ) | 
					
						
						|  | write_html(graph, Path(args.out).expanduser()) | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | main() | 
					
						
						|  |  |