| """ |
| GPU Patch Script — Apply neuron permutation fix + lower MiMo alpha. |
| Run this ON THE GPU after cd /workspace/td_toolkit/hugging: |
| python3 patch_gpu.py |
| |
| What it does: |
| 1. Adds neuron permutation to transport.py fast path |
| 2. Adds _greedy_permutation() and _apply_permutation() helpers |
| 3. Updates fuse_weights() to apply permutations before blending |
| 4. Lowers MiMo alpha from 0.4 to 0.15 in config.py |
| 5. Lowers MiMo strength from 0.4 to 0.15 in td_start.td |
| 6. Adds torch import fix to heal.py (Bug #41) |
| """ |
|
|
| import os |
|
|
def patch_file(filepath, old, new, count=-1):
    """Replace ``old`` with ``new`` in the text file at ``filepath``, in place.

    Args:
        filepath: Path of the file to edit (relative to the working directory).
        old: Exact text to search for.
        new: Replacement text.
        count: Maximum number of occurrences to replace. ``-1`` (the default)
            replaces every occurrence, exactly like ``str.replace``.

    Returns:
        True if the file contains the patched text afterwards, either because
        it was just written or because an append-style patch (``old`` is part
        of ``new``) was already applied. False if the file does not exist or
        ``old`` is not found.
    """
    try:
        with open(filepath, 'r', encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        print(f" WARNING: {filepath} not found")
        return False

    # When the replacement contains the search text (e.g. inserting a line
    # after an existing one), ``old`` is still present after patching, so a
    # second run would apply the patch again. Detect that case up front.
    if old in new and new in content:
        print(f" Already patched: {filepath}")
        return True

    if old not in content:
        print(f" WARNING: patch target not found in {filepath}")
        print(f" Looking for: {old[:80]}...")
        return False

    content = content.replace(old, new, count)
    with open(filepath, 'w', encoding="utf-8") as f:
        f.write(content)
    print(f" PATCHED: {filepath}")
    return True
|
|
|
|
def _patch_heal_torch_import():
    """Ensure ``td_fuse/heal.py`` imports torch inside apply_qlora_standard (Bug #41).

    Returns:
        True if the import is present afterwards (already there or just
        added), False if heal.py, the function, or the peft import line
        to anchor on is missing.
    """
    try:
        with open("td_fuse/heal.py", 'r', encoding="utf-8") as f:
            heal_content = f.read()
    except FileNotFoundError:
        print(" WARNING: td_fuse/heal.py not found")
        return False

    idx = heal_content.find("def apply_qlora_standard")
    if idx == -1:
        print(" WARNING: apply_qlora_standard not found in heal.py")
        return False

    # Only the first 200 characters of the function are inspected for an
    # existing import (same window as the original check).
    if "import torch" in heal_content[idx:idx + 200]:
        print(" Already patched (torch import exists)")
        return True

    return patch_file(
        "td_fuse/heal.py",
        "from peft import get_peft_model, LoraConfig, TaskType\n",
        "from peft import get_peft_model, LoraConfig, TaskType\n import torch\n",
    )


def main():
    """Apply every GPU-side patch and print an honest summary.

    All paths are relative. Run this from the directory that contains
    ``td_fuse/`` and ``td_start.td``. The final banner only claims success
    when every step applied; otherwise it lists the steps that need
    attention.
    """
    print("=" * 60)
    print("TD GPU Patch — Neuron Permutation Fix")
    print("=" * 60)

    failed = []  # names of the steps that did not apply

    # Step 1: gentler MiMo blend in the merge config.
    print("\n[1/4] Patching config.py (MiMo alpha 0.4 → 0.15)...")
    if not patch_file(
        "td_fuse/config.py",
        'merge_alpha=0.4,',
        'merge_alpha=0.15,',
    ):
        failed.append("config.py")

    # Step 2: matching strength in the pipeline script.
    # NOTE(review): plain substring match -- it would also rewrite e.g.
    # "strength 0.45" into "strength 0.155". Check td_start.td if in doubt.
    print("\n[2/4] Patching td_start.td (strength 0.4 → 0.15)...")
    if not patch_file(
        "td_start.td",
        'strength 0.4',
        'strength 0.15',
    ):
        failed.append("td_start.td")

    # Step 3: torch import fix in heal.py.
    print("\n[3/4] Patching heal.py (torch import fix)...")
    if not _patch_heal_torch_import():
        failed.append("heal.py")

    # Step 4: replace transport.py wholesale with the permutation-aware version.
    print("\n[4/4] Rewriting transport.py with neuron permutation...")
    write_transport_py()
    print(" WROTE: td_fuse/transport.py")

    print("\n" + "=" * 60)
    if failed:
        print(f"SOME PATCHES DID NOT APPLY: {', '.join(failed)}")
        print("Check the warnings above before running the pipeline.")
    else:
        print("ALL PATCHES APPLIED!")
    print("=" * 60)
    print("\nWhat changed:")
    print(" • MiMo merge alpha: 0.4 → 0.15 (gentler blend)")
    print(" • Neuron permutation: MiMo's neurons get reorganised to match Qwen3")
    print(" • heal.py: torch import fix (Bug #41)")
    print("\nNow run the pipeline:")
    print(" export PYTHONPATH=$(pwd)")
    print(" python3 -m td_lang run td_start.td")
|
|
|
|
def write_transport_py():
    """Write the complete, permutation-aware ``td_fuse/transport.py``.

    The module source below is written as-is, after Python escape processing
    of the literal (doubled backslashes become single ones). It overwrites
    any existing file. The path is relative to the current working directory,
    and the ``td_fuse`` directory must already exist. The file is written as
    UTF-8 because the source contains non-ASCII characters (em-dashes).
    """
    code = '''\
"""
Transport and Merge Wrapper — interfaces with official T&M code.

This wraps the official repo at:
    github.com/chenhangcuisg-code/Cross-Architecture-Merging-for-Large-Language-Models/

We use THEIR code for:
    - Correlation distance computation (corr_distance_matrix)
    - Streaming Sinkhorn (sinkhorn_uniform_streaming)
    - Transport plan computation (compute_P, compute_Q_and_layer_costs)
    - Activation reconstruction (reconstruct_X)

We add:
    - Qwen3 thinking mode protection
    - MiMo MTP head handling
    - Falcon SSM component handling
    - Neuron permutation for scrambled models (MiMo)
    - Sequential merge protection (MagMax + orthogonal projection)
    - Progress reporting every 5 minutes
    - Timeouts to prevent infinite hangs

Findings: #01, #07, #24
"""

import sys
import time
import torch
import numpy as np
from pathlib import Path
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

from .config import MergeConfig, ModelConfig, TARGET


# ============================================================================
# PROGRESS TRACKER — prints status every 5 minutes so you know it's alive
# ============================================================================

class ProgressTracker:
    """Prints a heartbeat every interval_seconds so you know it's not stuck."""

    def __init__(self, task_name: str, interval_seconds: int = 300):
        self.task_name = task_name
        self.interval = interval_seconds
        self.start_time = time.time()
        self.last_report = self.start_time
        self.step = 0
        self.total_steps = 0
        print(f"\\n[{task_name}] Started at {time.strftime(\'%H:%M:%S\')}")

    def set_total(self, total: int):
        self.total_steps = total

    def tick(self, step_name: str = ""):
        """Call this inside loops. Prints progress once per interval."""
        self.step += 1
        now = time.time()
        elapsed = now - self.start_time
        since_last = now - self.last_report

        if since_last >= self.interval:
            pct = f"{self.step}/{self.total_steps} ({100*self.step/self.total_steps:.0f}%)" if self.total_steps else f"step {self.step}"
            eta = ""
            if self.total_steps and self.step > 0:
                rate = elapsed / self.step
                remaining = (self.total_steps - self.step) * rate
                eta = f", ETA {remaining/60:.1f} min"
            print(f"[{self.task_name}] HEARTBEAT — {pct}, elapsed {elapsed/60:.1f} min{eta} | {step_name}")
            sys.stdout.flush()
            self.last_report = now

    def done(self):
        elapsed = time.time() - self.start_time
        print(f"[{self.task_name}] Completed in {elapsed/60:.1f} min ({elapsed:.0f}s)")
        sys.stdout.flush()

    def check_timeout(self, timeout_seconds: int = 3600):
        """Raise TimeoutError if we have been running longer than timeout_seconds."""
        elapsed = time.time() - self.start_time
        if elapsed > timeout_seconds:
            raise TimeoutError(
                f"[{self.task_name}] TIMEOUT after {elapsed/60:.1f} min "
                f"(limit: {timeout_seconds/60:.0f} min). Something is wrong."
            )


def setup_tm_repo(cfg: MergeConfig):
    """Add official T&M repo to Python path so we can import their code."""
    repo_path = Path(cfg.tm_repo_path)
    core_path = repo_path / "core"

    if not core_path.exists():
        raise FileNotFoundError(
            f"Official T&M repo not found at {repo_path}\\n"
            f"Please clone it:\\n"
            f" git clone https://github.com/chenhangcuisg-code/"
            f"Cross-Architecture-Merging-for-Large-Language-Models.git"
        )

    # Add to path so we can import hot_transport etc.
    if str(core_path) not in sys.path:
        sys.path.insert(0, str(core_path))
        print(f"[transport] Added T&M core to path: {core_path}")


def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
    """
    Load calibration data for activation extraction.

    Mix: up to 600 Pile general-text samples (always attempted first), then
    neuralmagic Q&A samples to fill up to cfg.calibration_samples in total.
    If Pile fails to load, neuralmagic supplies everything.
    Each sample is truncated to cfg.calibration_seq_len tokens.

    Findings: #08
    """
    tracker = ProgressTracker("calibration-data", interval_seconds=120)
    print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")

    samples = []

    # --- Pile: general text (600 samples) ---
    try:
        pile = load_dataset(
            cfg.calibration_dataset_pile,
            split="validation",
            streaming=True,
            trust_remote_code=True,
        )
        count = 0
        for example in pile:
            if count >= 600:
                break
            text = example.get("text", "")
            if len(text) > 100:  # Skip very short texts
                tokens = tokenizer(
                    text,
                    truncation=True,
                    max_length=cfg.calibration_seq_len,
                    return_tensors="pt",
                )
                samples.append(tokens)
                count += 1
                if count % 100 == 0:
                    print(f" Pile: {count}/600 samples loaded...")
                    sys.stdout.flush()
        print(f" Pile general: {count} samples")
    except Exception as e:
        print(f" WARNING: Pile failed: {e}")
        print(f" Falling back to neuralmagic only")

    # --- neuralmagic: Q&A calibration (up to remaining) ---
    remaining = cfg.calibration_samples - len(samples)
    if remaining > 0:
        try:
            nm = load_dataset(
                cfg.calibration_dataset_nm,
                split="train",
                trust_remote_code=True,
            )
            count = 0
            for example in nm:
                if count >= remaining:
                    break
                text = example.get("text", example.get("content", ""))
                if len(str(text)) > 50:
                    tokens = tokenizer(
                        str(text),
                        truncation=True,
                        max_length=cfg.calibration_seq_len,
                        return_tensors="pt",
                    )
                    samples.append(tokens)
                    count += 1
                    if count % 100 == 0:
                        print(f" neuralmagic: {count}/{remaining} samples loaded...")
                        sys.stdout.flush()
            print(f" neuralmagic: {count} samples")
        except Exception as e:
            print(f" WARNING: neuralmagic failed: {e}")

    tracker.done()
    print(f"[transport] Total calibration samples: {len(samples)}")
    sys.stdout.flush()
    return samples


def extract_activations(
    model: AutoModelForCausalLM,
    calibration_data: list,
    device: str = "cuda",
) -> dict:
    """
    Extract intermediate activations from each layer of a model.

    Runs calibration data through the model with forward hooks on every
    module that has a self_attn attribute or whose name ends in ".mlp".
    Each hook mean-pools its output over dim 1 (the sequence dimension for
    batch-first outputs). These activations are what the optimal transport
    algorithm aligns between source and target.

    Hooks are always removed before returning, even if the timeout fires.

    Returns:
        Dict mapping layer_name -> activation tensor [num_samples, hidden_dim]
    """
    tracker = ProgressTracker("extract-activations", interval_seconds=300)
    tracker.set_total(len(calibration_data))
    print(f"[transport] Extracting activations from {len(calibration_data)} samples...")
    sys.stdout.flush()

    activations = {}
    hooks = []

    def make_hook(layer_name):
        def hook_fn(module, input, output):
            # Handle tuple outputs (some layers return tuples)
            if isinstance(output, tuple):
                act = output[0]
            else:
                act = output
            if layer_name not in activations:
                activations[layer_name] = []
            # Mean pool over sequence length -> [hidden_dim]
            activations[layer_name].append(
                act.detach().float().mean(dim=1).cpu()
            )
        return hook_fn

    # Register hooks on each transformer layer
    for name, module in model.named_modules():
        if hasattr(module, "self_attn") or name.endswith(".mlp"):
            hooks.append(module.register_forward_hook(make_hook(name)))

    # Forward pass on calibration data. try/finally guarantees the hooks are
    # removed even when check_timeout raises, so the model is not left with
    # hooks that keep capturing activations on every later forward pass.
    model.eval()
    try:
        with torch.no_grad():
            for i, tokens in enumerate(calibration_data):
                inputs = {k: v.to(device) for k, v in tokens.items()}
                try:
                    model(**inputs)
                except Exception as e:
                    print(f" WARNING: Sample {i} failed: {e}")
                    continue

                tracker.tick(f"sample {i+1}")

                if (i + 1) % 100 == 0:
                    print(f" Processed {i + 1}/{len(calibration_data)} samples")
                    sys.stdout.flush()

                # Timeout: 30 min for activation extraction
                tracker.check_timeout(timeout_seconds=1800)
    finally:
        for h in hooks:
            h.remove()

    # Stack activations: [num_samples, hidden_dim]
    layer_count = 0
    for key in activations:
        activations[key] = torch.cat(activations[key], dim=0)
        layer_count += 1

    print(f" Extracted {layer_count} layers, shapes: {activations[list(activations.keys())[0]].shape if activations else \'empty\'}")
    tracker.done()
    sys.stdout.flush()

    return activations


def compute_transport_plans(
    source_activations: dict,
    target_activations: dict,
    cfg: MergeConfig,
) -> dict:
    """
    Compute optimal transport plans between source and target activations.

    This is where the magic happens. We use the official T&M code's:
        - corr_distance_matrix: correlation distance between activation vectors
        - sinkhorn_uniform_streaming: memory-efficient Sinkhorn solver
        - compute_P: layer-level coupling (which source layers -> which target layers)
        - compute_Q_and_layer_costs: neuron-level coupling within each layer pair

    Returns:
        Dict with 'P' (layer coupling) and 'Q' (per-layer neuron coupling) matrices
    """
    print("[transport] Computing transport plans...")
    sys.stdout.flush()

    try:
        # Try importing official T&M code
        from hot_transport import (
            corr_distance_matrix,
            sinkhorn_uniform_streaming,
            compute_P,
            compute_Q_and_layer_costs,
        )
        print("[transport] Using official T&M implementation")
        return _compute_plans_official(
            source_activations, target_activations, cfg,
            corr_distance_matrix, sinkhorn_uniform_streaming,
            compute_P, compute_Q_and_layer_costs,
        )
    except ImportError:
        print("[transport] Official T&M code not available, using fallback")
        return _compute_plans_fallback(
            source_activations, target_activations, cfg
        )


def _compute_plans_official(
    source_act, target_act, cfg,
    corr_distance_matrix, sinkhorn_uniform_streaming,
    compute_P, compute_Q_and_layer_costs,
) -> dict:
    """Use the official T&M code to compute transport plans."""

    # Get matching layer pairs
    source_layers = sorted(source_act.keys())
    target_layers = sorted(target_act.keys())

    # Compute Q matrices (neuron-level) and layer costs
    Q_matrices, layer_costs = compute_Q_and_layer_costs(
        source_act, target_act,
        source_layers, target_layers,
    )

    # Compute P matrix (layer-level coupling)
    P = compute_P(layer_costs)

    return {
        "P": P,
        "Q": Q_matrices,
        "source_layers": source_layers,
        "target_layers": target_layers,
    }


def _compute_plans_fallback(
    source_act: dict,
    target_act: dict,
    cfg: MergeConfig,
) -> dict:
    """
    Fallback transport plan computation when official code isn't available.

    Smart routing:
        - Same-architecture models (same layer count): direct 1:1 layer matching
          Check if neurons are aligned (DeepSeek) or scrambled (MiMo)
        - Cross-architecture: sparse OT (only top-3 source layers per target)

    Permutations returned under "permutations" are always true one-to-one
    permutations, with perm[s] = t meaning source neuron s corresponds to
    target neuron t.
    """
    tracker = ProgressTracker("transport-plans", interval_seconds=300)

    source_layers = sorted(source_act.keys())
    target_layers = sorted(target_act.keys())

    n_source = len(source_layers)
    n_target = len(target_layers)

    print(f"[transport] Source layers: {n_source}, Target layers: {n_target}")
    sys.stdout.flush()

    # --- FAST PATH: same architecture (same layer count) ---
    # Both models have the same number of transformer layers
    # Match layers 1:1 but CHECK if neurons correspond
    # DeepSeek: same training base -> neurons aligned -> identity Q (fast)
    # MiMo: different training -> neurons scrambled -> need Sinkhorn permutation
    if n_source == n_target:
        print("[transport] Same layer count -- using direct 1:1 layer matching")
        sys.stdout.flush()
        Q_matrices = {}
        permutations = {}  # layer_pair -> permutation array (neuron reordering)
        P = np.eye(n_source) / n_source  # Identity coupling
        tracker.set_total(n_source)

        # Check first layer to decide: are neurons aligned or scrambled?
        first_sl = source_layers[0]
        first_tl = target_layers[0]
        S0 = source_act[first_sl].numpy()
        T0 = target_act[first_tl].numpy()
        if S0.shape[1] == T0.shape[1]:
            S0_norm = (S0 - S0.mean(0)) / (S0.std(0) + 1e-8)
            T0_norm = (T0 - T0.mean(0)) / (T0.std(0) + 1e-8)
            diag_corr = np.mean(np.sum(S0_norm * T0_norm, axis=0) / S0.shape[0])
            neurons_aligned = diag_corr > 0.3
        else:
            neurons_aligned = False

        if neurons_aligned:
            print(f"[transport] Neurons ARE aligned (diag_corr={diag_corr:.3f}) -- identity Q (fast)")
            print("[transport] This should take under 1 minute...")
        else:
            corr_val = diag_corr if S0.shape[1] == T0.shape[1] else 0.0
            print(f"[transport] Neurons NOT aligned (diag_corr={corr_val:.3f}) -- computing permutations via Sinkhorn")
            print("[transport] This may take 2-5 minutes...")
        sys.stdout.flush()

        for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
            S = source_act[sl].numpy()
            T = target_act[tl].numpy()

            if S.shape[1] == T.shape[1]:
                if neurons_aligned:
                    # Neurons already correspond (e.g. DeepSeek) -- identity Q
                    Q_matrices[(sl, tl)] = np.eye(S.shape[1]) / S.shape[1]
                else:
                    # Neurons are SCRAMBLED (e.g. MiMo) -- find the permutation
                    # 1. Compute correlation matrix between source and target neurons
                    S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
                    T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
                    corr = S_norm.T @ T_norm / S.shape[0]  # [hidden_dim, hidden_dim]

                    # 2. Run Sinkhorn on cost matrix to get soft transport plan
                    cost = 1.0 - corr
                    Q_soft = _sinkhorn(cost, reg=0.05, max_iter=cfg.sinkhorn_max_iter)

                    # 3. Extract hard assignment: for each source neuron, which target neuron?
                    perm = np.argmax(Q_soft, axis=1)  # source_neuron -> target_neuron

                    # 4. argmax may send several source neurons to the same target.
                    #    A permutation with duplicates would copy some neurons and
                    #    drop others, so always end with a true permutation.
                    if len(set(perm)) < len(perm) * 0.9:
                        # Too many collisions -- fall back to Hungarian-style greedy
                        perm = _greedy_permutation(corr)
                    else:
                        # Few or no collisions -- resolve them cheaply
                        perm = _repair_permutation(perm, corr)

                    permutations[(sl, tl)] = perm
                    Q_matrices[(sl, tl)] = Q_soft
            else:
                # Different dims -- do lightweight Sinkhorn on this pair only
                print(f" Layer {i}: dim mismatch ({S.shape[1]} vs {T.shape[1]}), using Sinkhorn...")
                S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
                T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
                corr = S_norm.T @ T_norm / S.shape[0]
                cost = 1.0 - corr
                Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)

            tracker.tick(f"{sl} -> {tl}")

            if (i + 1) % 10 == 0 or i == 0:
                print(f" Matched layer {i + 1}/{n_source}: {sl} -> {tl}")
                sys.stdout.flush()

            # Timeout: 15 min (permutation takes longer than identity)
            tracker.check_timeout(timeout_seconds=900)

        if permutations:
            print(f"[transport] Computed {len(permutations)} neuron permutations")
        print(f"[transport] Direct matching complete: {n_source} layer pairs")
        tracker.done()
        sys.stdout.flush()
        return {
            "P": P,
            "Q": Q_matrices,
            "permutations": permutations,
            "source_layers": source_layers,
            "target_layers": target_layers,
        }

    # --- CROSS-ARCHITECTURE PATH: sparse OT ---
    # Only compute top-3 source layers per target (not all NxN pairs)
    print(f"[transport] Cross-architecture -- using sparse OT (top-3 per target)")
    print(f"[transport] Estimated time: 5-15 minutes")
    sys.stdout.flush()

    # Step 1: Compute layer-level similarity (cheap: just mean activation correlation)
    print("[transport] Step 1/3: Computing layer-level similarities...")
    sys.stdout.flush()
    layer_costs = np.zeros((n_source, n_target))
    tracker.set_total(n_source * n_target + n_target * 3)
    for i, sl in enumerate(source_layers):
        for j, tl in enumerate(target_layers):
            S_mean = source_act[sl].mean(0).numpy()
            T_mean = target_act[tl].mean(0).numpy()
            # Cosine similarity as cheap proxy
            min_dim = min(len(S_mean), len(T_mean))
            s = S_mean[:min_dim]
            t = T_mean[:min_dim]
            sim = np.dot(s, t) / (np.linalg.norm(s) * np.linalg.norm(t) + 1e-8)
            layer_costs[i, j] = 1.0 - sim
            tracker.tick(f"layer sim {i},{j}")

        # Timeout: 30 min for cross-arch
        tracker.check_timeout(timeout_seconds=1800)

    print(f"[transport] Step 1/3 done: {n_source}x{n_target} similarities computed")
    sys.stdout.flush()

    # Step 2: For each target layer, only compute Q for top-3 most similar source layers
    print("[transport] Step 2/3: Computing neuron-level transport (top-3 per target)...")
    sys.stdout.flush()
    Q_matrices = {}
    for j, tl in enumerate(target_layers):
        top3 = np.argsort(layer_costs[:, j])[:3]
        for i in top3:
            sl = source_layers[i]
            S = source_act[sl].numpy()
            T = target_act[tl].numpy()

            # Lightweight Sinkhorn (50 iterations, not 100+)
            min_dim = min(S.shape[1], T.shape[1])
            S_sub = S[:, :min_dim]
            T_sub = T[:, :min_dim]
            S_norm = (S_sub - S_sub.mean(0)) / (S_sub.std(0) + 1e-8)
            T_norm = (T_sub - T_sub.mean(0)) / (T_sub.std(0) + 1e-8)
            corr = S_norm.T @ T_norm / S.shape[0]
            cost = 1.0 - corr
            Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
            tracker.tick(f"Q({sl},{tl})")

        if (j + 1) % 5 == 0 or j == 0:
            print(f" Target layer {j + 1}/{n_target}: matched to top-3 sources")
            sys.stdout.flush()

        # Timeout: 30 min for cross-arch
        tracker.check_timeout(timeout_seconds=1800)

    print(f"[transport] Step 2/3 done: {len(Q_matrices)} Q matrices computed")
    sys.stdout.flush()

    # Step 3: Layer coupling via Sinkhorn on layer costs
    print("[transport] Step 3/3: Computing layer coupling P matrix...")
    sys.stdout.flush()
    P = _sinkhorn(layer_costs, reg=0.1, max_iter=50)

    print(f"[transport] Sparse OT complete: {len(Q_matrices)} layer pairs computed")
    tracker.done()
    sys.stdout.flush()
    return {
        "P": P,
        "Q": Q_matrices,
        "permutations": {},
        "source_layers": source_layers,
        "target_layers": target_layers,
    }


def _sinkhorn(
    cost_matrix: np.ndarray,
    reg: float = 0.05,
    max_iter: int = 100,
) -> np.ndarray:
    """
    Basic Sinkhorn-Knopp algorithm for optimal transport.

    Solves: min <T, C> - reg * H(T)
    where H(T) is the entropy of the transport plan.

    This is the FALLBACK. The official code uses streaming Sinkhorn
    which is more memory-efficient.
    """
    n, m = cost_matrix.shape
    K = np.exp(-cost_matrix / reg)

    u = np.ones(n) / n
    v = np.ones(m) / m

    for iteration in range(max_iter):
        u = 1.0 / (K @ v + 1e-10)
        v = 1.0 / (K.T @ u + 1e-10)

    # Transport plan
    T = np.diag(u) @ K @ np.diag(v)
    return T


def _greedy_permutation(corr_matrix: np.ndarray) -> np.ndarray:
    """
    Greedy permutation assignment when Sinkhorn gives duplicate mappings.

    For each source neuron (in order of strongest match), assign it to the
    best available target neuron that hasn't been taken yet.
    Expects a square matrix; returns perm with perm[src] = tgt.
    """
    n = corr_matrix.shape[0]
    perm = np.full(n, -1, dtype=np.int64)
    taken = set()

    # Process source neurons by strength of their best match (strongest first)
    best_scores = np.max(corr_matrix, axis=1)
    order = np.argsort(-best_scores)

    for src in order:
        # Find best available target
        sorted_targets = np.argsort(-corr_matrix[src])
        for tgt in sorted_targets:
            if tgt not in taken:
                perm[src] = tgt
                taken.add(tgt)
                break

    # Safety: any unassigned source neurons get remaining targets
    remaining = set(range(n)) - taken
    for src in range(n):
        if perm[src] == -1:
            perm[src] = remaining.pop()

    return perm


def _repair_permutation(perm: np.ndarray, corr_matrix: np.ndarray) -> np.ndarray:
    """
    Turn an argmax assignment with a few collisions into a true permutation.

    For every target claimed by several source neurons, only the source with
    the highest correlation keeps it. The displaced sources are then matched
    greedily to the targets nobody claimed. An assignment that is already a
    permutation comes back unchanged.
    """
    perm = np.asarray(perm, dtype=np.int64).copy()
    n = len(perm)
    owner = np.full(n, -1, dtype=np.int64)  # target -> source that keeps it
    best = np.full(n, -np.inf)
    for src in range(n):
        tgt = perm[src]
        score = corr_matrix[src, tgt]
        if score > best[tgt]:
            best[tgt] = score
            owner[tgt] = src

    keep = np.zeros(n, dtype=bool)
    keep[owner[owner >= 0]] = True
    losers = np.flatnonzero(~keep)
    if len(losers) == 0:
        return perm

    # Exactly as many targets are free as sources were displaced.
    free = np.flatnonzero(owner < 0)
    sub = _greedy_permutation(corr_matrix[np.ix_(losers, free)])
    perm[losers] = free[sub]
    return perm


def _apply_permutation(source_w: torch.Tensor, perm: np.ndarray, key: str) -> torch.Tensor:
    """
    Apply neuron permutation to a source weight tensor before blending.

    The permutation rearranges MiMo's neurons to match Qwen3's ordering.
    Think of it like reorganising filing cabinets: same files, different order.

    perm[s] = t means source neuron s corresponds to target neuron t (this is
    how _compute_plans_fallback builds it). Moving source neuron s into slot t
    means result[t] = source_w[s], i.e. indexing with the INVERSE permutation.
    If perm is not a valid permutation of range(len(perm)), the weight is
    returned unchanged.

    Which dimension to permute depends on the weight type:
        - Input projections (q_proj, k_proj, v_proj, gate_proj, up_proj):
          shape [out_features, in_features] -> permute columns (dim 1)
          because input neurons need reordering
        - Output projections (o_proj, down_proj):
          shape [out_features, in_features] -> permute rows (dim 0)
          because output neurons need reordering
        - 1D weights (layer_norm, bias):
          permute directly (whenever the length matches)
    """
    perm = np.asarray(perm, dtype=np.int64)
    n = len(perm)
    if not np.array_equal(np.sort(perm), np.arange(n)):
        # Duplicates or out-of-range entries -- cannot reorder safely
        return source_w

    inverse = np.empty(n, dtype=np.int64)
    inverse[perm] = np.arange(n)
    perm_tensor = torch.from_numpy(inverse).long().to(source_w.device)

    if source_w.dim() == 1:
        # 1D: layer norms, biases
        if len(perm_tensor) == source_w.shape[0]:
            return source_w[perm_tensor]
        return source_w

    if source_w.dim() == 2:
        # 2D: linear layers
        out_features, in_features = source_w.shape

        # Output projections: neurons on dim 0 (rows)
        if any(proj in key for proj in ["o_proj", "down_proj"]):
            if len(perm_tensor) == out_features:
                return source_w[perm_tensor, :]
        # Input projections: neurons on dim 1 (columns)
        elif any(proj in key for proj in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]):
            if len(perm_tensor) == in_features:
                return source_w[:, perm_tensor]
        # Other 2D weights: try columns first (more common)
        else:
            if len(perm_tensor) == in_features:
                return source_w[:, perm_tensor]
            elif len(perm_tensor) == out_features:
                return source_w[perm_tensor, :]

    # Can't permute -- return unchanged
    return source_w


def fuse_weights(
    source_state: dict,
    target_model: AutoModelForCausalLM,
    transport_plans: dict,
    source_config: ModelConfig,
    cfg: MergeConfig,
    target_activations: dict = None,
) -> AutoModelForCausalLM:
    """
    Fuse source model weights into target model.

    For each target parameter that is not skipped:
        1. Map the target key to the source key (_map_key)
        2. On a shape mismatch, pad/truncate the source weight (_align_dimensions)
        3. If the plan holds a neuron permutation for that layer, apply it
        4. Blend with target: W_final = alpha * W_source + (1-alpha) * W_target
    The Q matrices are passed to _align_dimensions but are not currently used
    to project weights.

    Args:
        source_state: Source model state dict (can be on CPU -- will be moved per-param)
        target_model: Target model (on GPU)
        transport_plans: Transport plan matrices from compute_transport_plans
        source_config: Source model config
        cfg: Merge configuration

    Special handling per model:
        - DeepSeek: Direct merge (same architecture)
        - MiMo: Skip MTP heads, skip embeddings, apply neuron permutation
        - Llama: Layer mapping (32->36), skip embeddings, drop QKV bias
        - Falcon: Skip Mamba components, skip embeddings

    Returns:
        Target model with fused weights
    """
    tracker = ProgressTracker("fuse-weights", interval_seconds=300)
    print(f"\\n[transport] Fusing {source_config.name} -> target")
    alpha = source_config.merge_alpha

    try:
        # Try official fusion code first
        from generate_hot_residual import fuse_attention_only_from_hot_dir
        # TODO: Adapt official fusion to our pipeline
        # For now, fall through to manual fusion
        print("[transport] Official fusion code found but not wired in yet -- using manual fusion")
    except ImportError:
        pass

    # --- Manual fusion using transport plans ---
    # source_state is passed in (may be on CPU to save GPU memory)
    target_state = target_model.state_dict()
    P = transport_plans["P"]
    Q = transport_plans["Q"]
    permutations = transport_plans.get("permutations", {})

    # Build layer-index -> permutation lookup
    # permutations keys are (source_layer_name, target_layer_name) tuples
    # We need to map weight keys like "model.layers.5.self_attn.q_proj.weight"
    # to the permutation for layer 5.
    # NOTE(review): a decoder layer and its ".mlp" submodule both map to the
    # same index, so whichever permutation is seen last for that index wins.
    layer_perms = {}
    for (sl, tl), perm in permutations.items():
        # Extract layer index from target layer name (e.g. "model.layers.5.mlp" -> 5)
        parts = tl.split(".")
        for j, part in enumerate(parts):
            if part == "layers" and j + 1 < len(parts):
                try:
                    layer_idx = int(parts[j + 1])
                    layer_perms[layer_idx] = perm
                except ValueError:
                    pass
                break

    if permutations:
        print(f"[transport] Will apply neuron permutations to {len(layer_perms)} layers before blending")
    else:
        print("[transport] No neuron permutations needed (neurons already aligned)")

    fused_count = 0
    skipped_count = 0
    permuted_count = 0
    total_params = len(target_state)
    tracker.set_total(total_params)

    for target_key in target_state:
        tracker.tick(target_key)

        # Skip parameters we shouldn't merge
        if _should_skip(target_key, source_config):
            skipped_count += 1
            continue

        # Find corresponding source key
        source_key = _map_key(target_key, source_config)
        if source_key is None or source_key not in source_state:
            skipped_count += 1
            # Log first few misses to help debug key mapping issues
            if skipped_count <= 5:
                print(f" [skip] No source match for: {target_key} (mapped to: {source_key})")
                sys.stdout.flush()
            continue

        target_w = target_state[target_key]
        source_w = source_state[source_key]

        # Handle dimension mismatches
        if target_w.shape != source_w.shape:
            # Pad/truncate the source weight to the target shape
            source_w = _align_dimensions(source_w, target_w.shape, Q, target_key)
            if source_w is None:
                skipped_count += 1
                continue

        # --- NEURON PERMUTATION: rearrange source neurons to match target ---
        # This is what makes MiMo merge work -- without this, it's like
        # dumping one filing cabinet into another without matching folders.
        # NOTE(review): permuting the hidden dimension per layer independently
        # does not keep the residual stream consistent across layers.
        if layer_perms:
            # Extract layer index from this weight's key
            key_parts = target_key.split(".")
            for j, part in enumerate(key_parts):
                if part == "layers" and j + 1 < len(key_parts):
                    try:
                        lidx = int(key_parts[j + 1])
                        if lidx in layer_perms:
                            source_w = _apply_permutation(source_w, layer_perms[lidx], target_key)
                            permuted_count += 1
                    except ValueError:
                        pass
                    break

        # Blend: W_final = alpha * source + (1-alpha) * target
        fused_w = alpha * source_w.to(target_w.device) + (1 - alpha) * target_w
        target_state[target_key] = fused_w
        fused_count += 1

        # Apply thinking mode protection (inside loop -- check each key).
        # target_w is the untouched original tensor (fused_w is a new tensor),
        # so it can be used directly instead of rebuilding the state dict.
        if cfg.freeze_think_tokens and "embed_tokens" in target_key:
            for token_id in cfg.think_token_ids:
                if token_id < target_state[target_key].shape[0]:
                    # Restore original embedding for think tokens
                    target_state[target_key][token_id] = target_w[token_id]
                    print(f"[transport] Protected think token {token_id}")

        if fused_count % 50 == 0:
            print(f" Fused {fused_count} params so far (skipped {skipped_count})...")
            sys.stdout.flush()

        # Timeout: 20 min for weight fusion
        tracker.check_timeout(timeout_seconds=1200)

    # Load fused weights (strict=False: vision encoder may have bitsandbytes quant keys
    # that don't match the original key names -- we never modify vision weights anyway)
    missing, unexpected = target_model.load_state_dict(target_state, strict=False)
    if missing:
        print(f"[transport] NOTE: {len(missing)} missing keys (likely quantized vision params -- safe to ignore)")
    if unexpected:
        print(f"[transport] NOTE: {len(unexpected)} unexpected keys (safe to ignore)")
    perm_msg = f", permuted {permuted_count}" if permuted_count else ""
    print(f"[transport] Fused {fused_count} params, skipped {skipped_count}{perm_msg}")
    tracker.done()
    sys.stdout.flush()

    return target_model


def _should_skip(key: str, source_config: ModelConfig) -> bool:
    """Determine if a parameter should be skipped during merge."""

    # Skip vision encoder params (Qwen3-VL) -- these should never be merged
    if key.startswith("visual") or key.startswith("merger") or key.startswith("model.visual") or key.startswith("model.merger"):
        return True

    # Always skip if source model says to skip embeddings
    if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
        return True

    # Skip MiMo MTP heads
    if "drop_mtp_heads" in source_config.special_handling and "mtp_head" in key:
        return True

    # Skip Falcon Mamba-specific parameters
    if "drop_mamba_state_params" in source_config.special_handling:
        mamba_keys = ["mamba", "A_log", "dt_proj", ".D"]
        if any(mk in key for mk in mamba_keys):
            return True

    # Skip QKV bias for Llama (Qwen3 doesn't have it)
    if "drop_qkv_bias" in source_config.special_handling and ".bias" in key:
        if any(proj in key for proj in ["q_proj", "k_proj", "v_proj"]):
            return True

    return False


def _strip_vl_prefix(key: str) -> str:
    """
    Strip the 'language_model.' prefix that Qwen3-VL adds.

    Qwen3-VL wraps all language params under 'model.language_model.*'
    but source models (DeepSeek, MiMo, Llama, Falcon) use 'model.*' directly.

    Example:
        target: model.language_model.layers.0.self_attn.q_proj.weight
        source: model.layers.0.self_attn.q_proj.weight
    """
    # model.language_model.X -> model.X
    if "language_model." in key:
        return key.replace("language_model.", "")
    return key


def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
    """Map a target model parameter name to the corresponding source name."""

    # Step 1: Strip Qwen3-VL's language_model. prefix so we can match source keys
    source_key = _strip_vl_prefix(target_key)

    # For same-architecture models (DeepSeek), keys match directly after prefix strip
    if source_config.architecture == "transformer" and source_config.layers == 36:
        return source_key

    # For Llama (32 layers -> 36 layers), map layer indices
    if "layer_mapping_32_to_36" in source_config.special_handling:
        if "model.layers." in source_key:
            # Extract layer number
            parts = source_key.split(".")
            try:
                layer_idx = int(parts[2])
            except (IndexError, ValueError):
                return source_key

            # Map 36 target layers to 32 source layers (stride)
            source_layer = int(layer_idx * 32 / 36)
            parts[2] = str(source_layer)
            return ".".join(parts)

    # For MiMo (same layer count, different extras), keys mostly match
    if source_config.architecture == "transformer+mtp":
        if "mtp_head" in source_key:
            return None  # MTP heads don't exist in target
        return source_key

    # For Falcon hybrid, only attention and MLP keys map
    if source_config.architecture == "hybrid_ssm":
        if any(k in source_key for k in ["self_attn", "mlp", "layer_norm"]):
            return source_key  # These exist in both
        return None  # Mamba components don't map

    return source_key


def _align_dimensions(
    source_w: torch.Tensor,
    target_shape: tuple,
    Q_matrices: dict,
    key: str,
) -> Optional[torch.Tensor]:
    """
    Align source weight dimensions to target shape.

    1D and 2D weights are zero-padded and/or truncated to the target shape
    (the result is a new CPU tensor). Q_matrices and key are currently
    unused. Any other rank combination returns None so the caller skips
    the parameter.
    """
    if source_w.shape == target_shape:
        return source_w

    # Simple case: different width (FFN size difference)
    if len(source_w.shape) == 2 and len(target_shape) == 2:
        s_rows, s_cols = source_w.shape
        t_rows, t_cols = target_shape

        result = torch.zeros(target_shape, dtype=source_w.dtype)

        # Copy what fits
        min_rows = min(s_rows, t_rows)
        min_cols = min(s_cols, t_cols)
        result[:min_rows, :min_cols] = source_w[:min_rows, :min_cols]

        return result

    # 1D case (biases, layer norms)
    if len(source_w.shape) == 1 and len(target_shape) == 1:
        result = torch.zeros(target_shape, dtype=source_w.dtype)
        min_len = min(source_w.shape[0], target_shape[0])
        result[:min_len] = source_w[:min_len]
        return result

    # Can't align -- skip this parameter
    return None
'''
    with open("td_fuse/transport.py", 'w', encoding="utf-8") as f:
        f.write(code)
|
|
|
|
# Script entry point: apply all patches only when executed directly
# (e.g. `python3 patch_gpu.py`), not when imported.
if __name__ == "__main__":
    main()
|
|