| | """ |
| | Threshold Circuit Pruner v5 (Refactored) |
| | |
| | Streamlined pruning framework with 6 core methods: |
| | - magnitude: Greedy weight reduction |
| | - zero: Try zeroing individual weights |
| | - evolutionary: GPU-parallel genetic algorithm |
| | - exhaustive_mag: Provably optimal for small circuits |
| | - architecture: Search flat 2-layer alternatives |
| | - compositional: For circuits built from known-optimal components |
| | |
| | Usage: |
| | python prune.py threshold-xor --methods evo |
| | python prune.py threshold-xor --methods exh_mag |
| | python prune.py threshold-crc16-mag53 --methods comp |
| | python prune.py --list |
| | """ |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import json |
| | import time |
| | import random |
| | import argparse |
| | import math |
| | import gc |
| | import re |
| | import importlib.util |
| | from concurrent.futures import ThreadPoolExecutor, as_completed |
| | from pathlib import Path |
| | from dataclasses import dataclass, field |
| | from typing import Dict, List, Tuple, Optional, Callable |
| | from safetensors.torch import load_file, save_file |
| | from collections import defaultdict |
| | from functools import lru_cache |
| | from itertools import combinations, product |
| | import warnings |
| |
|
# Suppress every warning category for the lifetime of the process.
warnings.filterwarnings('ignore')

# Hard-coded base directory; the results directory is a child of it.
CIRCUITS_PATH = Path('D:/threshold-logic-circuits')
RESULTS_PATH = CIRCUITS_PATH.joinpath('pruned_results')
| |
|
| |
|
@dataclass
class VRAMConfig:
    """GPU memory budget plus a snapshot of CUDA device 0's capacity.

    ``total_bytes`` and ``device_name`` are set after construction: from the
    properties of CUDA device 0 when CUDA is available, otherwise ``0`` and
    ``"CPU"``.
    """

    target_residency: float = 0.75  # fraction of total memory targeted for use
    safety_margin: float = 0.05     # fraction subtracted from the target as headroom

    def __post_init__(self):
        # CPU defaults first; overwritten below only when a CUDA device exists.
        self.total_bytes = 0
        self.device_name = "CPU"
        if not torch.cuda.is_available():
            return
        device_props = torch.cuda.get_device_properties(0)
        self.total_bytes = device_props.total_memory
        self.device_name = device_props.name

    @property
    def total_gb(self) -> float:
        """Total device memory in decimal gigabytes (bytes / 1e9)."""
        return self.total_bytes / 1e9

    @property
    def available_bytes(self) -> int:
        """Byte budget: total * (target_residency - safety_margin), truncated to int."""
        usable_fraction = self.target_residency - self.safety_margin
        return int(self.total_bytes * usable_fraction)

    def current_usage(self) -> Dict:
        """Return allocated/free gigabytes and utilization from the CUDA allocator.

        All three figures are ``0`` when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return {'allocated_gb': 0, 'free_gb': 0, 'utilization': 0}

        in_use = torch.cuda.memory_allocated()
        if self.total_bytes > 0:
            utilization = in_use / self.total_bytes
        else:
            utilization = 0
        return {
            'allocated_gb': in_use / 1e9,
            'free_gb': (self.total_bytes - in_use) / 1e9,
            'utilization': utilization,
        }
| |
|
| |
|
def clear_vram():
    """Run Python garbage collection, then flush the CUDA cache when CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
| |
|
| |
|
@dataclass
class Config:
    """Run-wide settings shared by the evaluator and every pruning method."""

    # --- General ---
    device: str = 'cuda'
    fitness_threshold: float = 0.9999  # a candidate is accepted when its fitness is >= this
    verbose: bool = True               # enables progress printing
    vram: VRAMConfig = field(default_factory=VRAMConfig)

    # --- Method switches (presumably read by the driver; not used in this part of the file) ---
    run_magnitude: bool = False
    run_zero: bool = False
    run_evolutionary: bool = False
    run_exhaustive_mag: bool = False
    run_architecture: bool = False
    run_compositional: bool = False

    # --- Magnitude / exhaustive search ---
    magnitude_passes: int = 100      # upper bound on magnitude-reduction passes
    exhaustive_max_params: int = 12  # exhaustive search is skipped above this parameter count
    exhaustive_target_mag: int = -1  # negative: scan magnitudes from 0 upward; >= 0: test only that one

    # --- Evolutionary search ---
    evo_generations: int = 2000
    evo_pop_size: int = 0               # 0 (or less): population size is derived from the memory budget
    evo_elite_ratio: float = 0.05       # fraction of the population copied unchanged each generation
    evo_mutation_rate: float = 0.15     # starting per-gene mutation probability (adapted during the run)
    evo_mutation_strength: float = 2.0  # truncated to int; mutations are integers in [-s, s]
    evo_crossover_rate: float = 0.3     # probability that a child is selected for crossover
    evo_parsimony: float = 0.001        # weight of the magnitude penalty in the adjusted score

    # --- Architecture search ---
    arch_hidden_neurons: int = 3
    arch_max_weight: int = 3
    arch_max_mag: int = 20
| |
|
| |
|
@dataclass
class CircuitSpec:
    """Descriptive metadata for one circuit (populated from its ``config.json``)."""

    name: str
    path: Path
    inputs: int       # number of circuit inputs
    outputs: int      # number of circuit outputs
    neurons: int
    layers: int
    parameters: int
    description: str = ""
| |
|
| |
|
@dataclass
class PruneResult:
    """Outcome of a single pruning method run."""

    method: str                             # identifier of the pruning method
    original_stats: Dict                    # weight statistics before pruning
    final_stats: Dict                       # weight statistics of ``final_weights``
    final_weights: Dict[str, torch.Tensor]  # resulting weight tensors keyed by name
    fitness: float                          # fitness reported for ``final_weights``
    time_seconds: float                     # wall-clock duration of the run
    metadata: Dict = field(default_factory=dict)  # method-specific extras; fresh dict per instance
| |
|
| |
|
class ComputationGraph:
    """Parses weight structure to build dependency graph.

    Neurons are discovered from the key names of a weight dictionary
    (``<neuron>.weight`` / ``<neuron>.bias``). Each neuron gets a depth used
    only for ordering, and ``forward_single`` evaluates the neurons in that
    order with a ``>= 0`` step activation.
    """

    def __init__(self, weights: Dict[str, torch.Tensor], n_inputs: int, n_outputs: int, device: str):
        self.device = device
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.weights = weights
        self.neurons = {}
        self.neuron_order = []
        self.output_neurons = []
        self.layer_groups = defaultdict(list)
        self._parse_structure()
        self._build_execution_order()

    def _parse_structure(self):
        """Group tensors by neuron name and register every neuron that has a weight."""
        params_by_neuron = defaultdict(dict)
        for key, tensor in self.weights.items():
            if '.weight' in key:
                entry = params_by_neuron[key.replace('.weight', '')]
                entry['weight'] = key
                entry['weight_shape'] = tensor.shape
            elif '.bias' in key:
                params_by_neuron[key.replace('.bias', '')]['bias'] = key

        for neuron_name, params in params_by_neuron.items():
            if 'weight' not in params:
                # A bias with no matching weight does not define a neuron.
                continue
            shape = params.get('weight_shape', (1,))
            depth = self._estimate_depth(neuron_name)
            self.neurons[neuron_name] = {
                'weight_key': params.get('weight'),
                'bias_key': params.get('bias'),
                'weight_shape': shape,
                'input_size': shape[-1],
                'depth': depth,
            }
            self.layer_groups[depth].append(neuron_name)

        self._infer_depth_from_shapes()
        self._identify_outputs()

    def _estimate_depth(self, name: str) -> int:
        """Depth guess from the name: first ``layer<N>`` number plus the count of dots."""
        layer_match = re.search(r'layer(\d+)', name)
        layer_number = int(layer_match.group(1)) if layer_match else 0
        return layer_number + name.count('.')

    def _infer_depth_from_shapes(self):
        """Pin neurons whose fan-in equals ``n_inputs`` to depth 0 and rebuild the groups."""
        for info in self.neurons.values():
            if info.get('input_size', 0) == self.n_inputs:
                info['depth'] = 0
                info['input_source'] = 'raw'

        regrouped = defaultdict(list)
        for name, info in self.neurons.items():
            regrouped[info['depth']].append(name)
        self.layer_groups = regrouped

    def _identify_outputs(self):
        """Pick the ``n_outputs`` deepest neurons (ties by name) that no other name contains as ``<name>.``."""

        def has_child(name: str) -> bool:
            marker = name + '.'
            return any(marker in other for other in self.neurons if other != name)

        leaves = [name for name in self.neurons if not has_child(name)]
        ranked = sorted(leaves, key=lambda n: (-self.neurons[n]['depth'], n))
        self.output_neurons = ranked[:self.n_outputs]

    def _build_execution_order(self):
        """Order neurons by ascending depth, then by name."""
        self.neuron_order = sorted(self.neurons, key=lambda n: (self.neurons[n]['depth'], n))

    def forward_single(self, inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Evaluate every neuron in execution order and return the stacked output activations.

        Neurons whose weight key is absent from ``weights`` are skipped.
        """
        activations = {'input': inputs}
        for neuron_name in self.neuron_order:
            info = self.neurons[neuron_name]
            w_key = info['weight_key']
            b_key = info['bias_key']
            if not (w_key and w_key in weights):
                continue

            w = weights[w_key]
            if w.dim() == 1:
                w = w.unsqueeze(0)
            source = self._get_neuron_input(neuron_name, activations, inputs, w.shape[-1])

            if source.dim() == 1:
                pre_activation = source @ w.flatten()
            else:
                pre_activation = (source.unsqueeze(-2) @ w.T.unsqueeze(0)).squeeze(-2)
            # Collapse the trailing singleton left by a single-row weight matrix.
            if pre_activation.dim() > 1 and pre_activation.shape[-1] == 1:
                pre_activation = pre_activation.squeeze(-1)
            if b_key and b_key in weights:
                pre_activation = pre_activation + weights[b_key].squeeze()

            # Step activation: fires when the weighted sum is non-negative.
            activations[neuron_name] = (pre_activation >= 0).float()
        return self._collect_outputs(activations)

    def _get_neuron_input(self, neuron_name: str, activations: Dict, raw_input: torch.Tensor, expected_size: int) -> torch.Tensor:
        """Choose the tensor feeding a neuron.

        Order of preference: the raw input (neuron marked ``raw`` or fan-in
        equal to ``n_inputs``); for ``layer2`` / ``.out`` neurons, the stacked
        activations that share the neuron's base prefix when their count
        matches the fan-in; otherwise the raw input, truncated to
        ``expected_size`` when it is wide enough.
        """
        info = self.neurons.get(neuron_name, {})
        if info.get('input_source') == 'raw' or expected_size == self.n_inputs:
            return raw_input

        if 'layer2' in neuron_name or '.out' in neuron_name:
            base = neuron_name.replace('.layer2', '').replace('.out', '')
            hidden_keys = [
                key for key in activations
                if key.startswith(base) and key != neuron_name and key != 'input'
            ]
            if len(hidden_keys) == expected_size:
                ordered = sorted(hidden_keys)
                return torch.stack([activations[key] for key in ordered], dim=-1)

        if raw_input.shape[-1] >= expected_size:
            return raw_input[..., :expected_size]
        return raw_input

    def _collect_outputs(self, activations: Dict) -> torch.Tensor:
        """Stack the available output activations in name order; zeros when none were computed."""
        found = [activations[name] for name in sorted(self.output_neurons) if name in activations]
        if not found:
            return torch.zeros(self.n_outputs, device=self.device)
        if found[0].dim() > 0:
            return torch.stack(found, dim=-1)
        return torch.stack(found)
| |
|
| |
|
class AdaptiveCircuit:
    """Adaptive threshold circuit with automatic evaluation.

    Loads a circuit directory (``config.json`` plus a ``.safetensors`` weight
    file), optionally picks up a ``forward`` function from ``model.py`` in the
    same directory, builds an exhaustive truth table over all input bit
    patterns, and provides helpers to convert between the weight dictionary
    and a single flat vector.
    """

    def __init__(self, path: Path, device: str = 'cuda', weights_file: str = None):
        """Load spec, weights, optional native forward, topology and test cases.

        Args:
            path: Circuit directory.
            device: Torch device the weights and test tensors are placed on.
            weights_file: Optional weight file name inside ``path``.
        """
        self.path = Path(path)
        self.device = device
        self.spec = self._load_spec()
        self.weights = self._load_weights(weights_file)
        self.weight_keys = list(self.weights.keys())
        self.n_weights = sum(t.numel() for t in self.weights.values())

        self.native_forward = self._try_load_native_forward()
        self.has_native = self.native_forward is not None
        print(f" [LOAD] Native forward: {'FOUND' if self.has_native else 'NOT FOUND'}")

        print(f" [LOAD] Parsing circuit topology...")
        self.graph = ComputationGraph(self.weights, self.spec.inputs, self.spec.outputs, device)
        print(f" [LOAD] Found {len(self.graph.neurons)} neurons across {len(self.graph.layer_groups)} layers")
        print(f" [LOAD] Output neurons: {sorted(self.graph.output_neurons)}")

        print(f" [LOAD] Building test cases...")
        self.test_inputs, self.test_expected = self._build_tests()
        self.n_cases = self.test_inputs.shape[0]
        print(f" [LOAD] Generated {self.n_cases} test cases")

        self._compile_fast_forward()
        print(f" [LOAD] Circuit ready: {self.n_weights} weight parameters")

    def _try_load_native_forward(self) -> Optional[Callable]:
        """Return ``model.py``'s ``forward`` if it can be imported, else ``None``.

        NOTE: this executes ``model.py`` from the circuit directory, so only
        point the tool at directories you trust. Any failure while importing
        is treated as "no native forward" (best effort).
        """
        model_py = self.path / 'model.py'
        if not model_py.exists():
            return None
        try:
            spec = importlib.util.spec_from_file_location("circuit_model", model_py)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            if hasattr(module, 'forward'):
                return module.forward
            return None
        except Exception:
            return None

    def evaluate_native(self, inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Evaluate ``inputs`` with the native forward, falling back to the parsed graph.

        The native function is first tried on the whole batch; when that
        raises or does not return a tensor, it is called once per row with a
        list of ints and CPU weights. A row whose call raises contributes an
        all-zero output row.
        """
        if not self.has_native:
            return self.graph.forward_single(inputs, weights)
        try:
            out = self.native_forward(inputs, weights)
            if isinstance(out, torch.Tensor):
                return out.to(self.device) if out.dim() > 1 else out.unsqueeze(-1).to(self.device)
        except Exception:
            # Deliberate best effort: batched call unsupported, use the per-row path.
            pass
        cpu_weights = {k: v.cpu() for k, v in weights.items()}
        results = []
        for i in range(inputs.shape[0]):
            inp = [int(x) for x in inputs[i].cpu().tolist()]
            try:
                out = self.native_forward(inp, cpu_weights)
                results.append([float(x) for x in out] if isinstance(out, (list, tuple)) else [float(out)])
            except Exception:
                results.append([0.0] * self.spec.outputs)
        return torch.tensor(results, device=self.device, dtype=torch.float32)

    def _load_spec(self) -> "CircuitSpec":
        """Read ``config.json`` and build the circuit spec; missing keys fall back to defaults."""
        with open(self.path / 'config.json') as f:
            cfg = json.load(f)
        return CircuitSpec(
            name=cfg.get('name', self.path.name),
            path=self.path,
            inputs=cfg.get('inputs', cfg.get('input_size', 0)),
            outputs=cfg.get('outputs', cfg.get('output_size', 0)),
            neurons=cfg.get('neurons', 0),
            layers=cfg.get('layers', 0),
            parameters=cfg.get('parameters', 0),
            description=cfg.get('description', '')
        )

    def _load_weights(self, weights_file: str = None) -> Dict[str, torch.Tensor]:
        """Load the weight tensors as float32 on ``self.device``.

        Uses ``weights_file`` when given, else ``model.safetensors``; when the
        chosen file does not exist, the first ``*.safetensors`` file found in
        the directory is used instead (if any).
        """
        if weights_file:
            sf = self.path / weights_file
        else:
            sf = self.path / 'model.safetensors'
        if not sf.exists():
            candidates = list(self.path.glob('*.safetensors'))
            sf = candidates[0] if candidates else sf
        w = load_file(str(sf))
        return {k: v.float().to(self.device) for k, v in w.items()}

    def _build_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build (inputs, expected) using the native forward when present, else the graph."""
        if self.has_native:
            return self._build_native_tests()
        return self._build_exhaustive_tests()

    def _build_exhaustive_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Enumerate all 2**n input patterns; expected outputs come from the parsed graph.

        Raises:
            ValueError: if the circuit has more than 24 inputs.
        """
        n = self.spec.inputs
        if n > 24:
            raise ValueError(f"Input space too large: 2^{n}")
        n_cases = 2 ** n
        idx = torch.arange(n_cases, device=self.device, dtype=torch.long)
        bits = torch.arange(n, device=self.device, dtype=torch.long)
        # Column b holds bit b of the case index (least-significant bit first).
        inputs = ((idx.unsqueeze(1) >> bits) & 1).float()
        expected = self.graph.forward_single(inputs, self.weights)
        return inputs, expected

    def _build_native_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Enumerate all 2**n input patterns; expected outputs come from the native forward.

        Raises:
            ValueError: if the circuit has more than 20 inputs.
        """
        n = self.spec.inputs
        if n > 20:
            raise ValueError(f"Input space too large: 2^{n}")
        n_cases = 2 ** n
        inputs_list, expected_list = [], []
        cpu_weights = {k: v.cpu() for k, v in self.weights.items()}
        for i in range(n_cases):
            inp = [(i >> b) & 1 for b in range(n)]
            inputs_list.append(inp)
            out = self.native_forward(inp, cpu_weights)
            expected_list.append([float(x) for x in out] if isinstance(out, (list, tuple)) else [float(out)])
        return (torch.tensor(inputs_list, device=self.device, dtype=torch.float32),
                torch.tensor(expected_list, device=self.device, dtype=torch.float32))

    def _compile_fast_forward(self):
        """Record each tensor's (key, start, end, shape) slot in the flat vector and cache the base vector."""
        self.weight_layout = []
        offset = 0
        for key in self.weight_keys:
            size = self.weights[key].numel()
            self.weight_layout.append((key, offset, offset + size, self.weights[key].shape))
            offset += size
        self.base_vector = self.weights_to_vector(self.weights)

    def weights_to_vector(self, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Concatenate the flattened tensors in ``weight_keys`` order."""
        return torch.cat([weights[k].flatten() for k in self.weight_keys])

    def vector_to_weights(self, vector: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Slice a flat vector back into named tensors (views of ``vector``, not copies)."""
        weights = {}
        for key, start, end, shape in self.weight_layout:
            weights[key] = vector[start:end].view(shape)
        return weights

    def clone_weights(self) -> Dict[str, torch.Tensor]:
        """Return an independent copy of every weight tensor."""
        return {k: v.clone() for k, v in self.weights.items()}

    def stats(self, weights: Dict[str, torch.Tensor] = None) -> Dict:
        """Summarize a weight dictionary.

        Args:
            weights: Dictionary to summarize; ``None`` means the circuit's own
                weights. An explicitly passed empty dictionary is summarized
                as empty rather than replaced by the circuit's weights.

        Returns:
            Dict with ``total`` element count, ``nonzero`` count, ``sparsity``
            (``1 - nonzero / total``, or 0 when empty), ``magnitude`` (sum of
            absolute values) and ``max_weight`` (largest absolute value, or 0
            when there are no elements).
        """
        w = self.weights if weights is None else weights
        tensors = list(w.values())
        total = sum(t.numel() for t in tensors)
        nonzero = sum((t != 0).sum().item() for t in tensors)
        mag = sum(t.abs().sum().item() for t in tensors)
        # Zero-element tensors are skipped: max() over an empty tensor raises.
        maxw = max((t.abs().max().item() for t in tensors if t.numel() > 0), default=0)
        return {
            'total': total,
            'nonzero': nonzero,
            'sparsity': 1 - nonzero / total if total else 0,
            'magnitude': mag,
            'max_weight': maxw
        }

    def save_weights(self, weights: Dict[str, torch.Tensor], suffix: str = 'pruned') -> Path:
        """Write ``weights`` (moved to CPU) to ``model_<suffix>.safetensors`` and return its path."""
        path = self.path / f'model_{suffix}.safetensors'
        save_file({k: v.cpu() for k, v in weights.items()}, str(path))
        return path
| |
|
| |
|
class BatchedEvaluator:
    """GPU-optimized batched population evaluation.

    Fitness is the fraction of test cases for which every output matches the
    expected value. Populations are scored one individual at a time, in
    chunks of at most ``max_batch`` rows.
    """

    def __init__(self, circuit: "AdaptiveCircuit", cfg: "Config"):
        """Bind to a circuit's test set, size the batch limit and sanity-check the original weights."""
        self.circuit = circuit
        self.cfg = cfg
        self.device = cfg.device
        self.test_inputs = circuit.test_inputs
        self.test_expected = circuit.test_expected
        self.n_cases = circuit.n_cases
        self.n_weights = circuit.n_weights

        if cfg.verbose:
            print(f" [EVAL] Initializing evaluator...")

        self._calculate_batch_size()
        self._validate_evaluation()

        if cfg.verbose:
            print(f" [EVAL] Evaluator ready: batch={self.max_batch:,}, native={circuit.has_native}")

    def _calculate_batch_size(self):
        """Derive ``max_batch`` from the VRAM budget, clamped to [1000, 5,000,000]."""
        # Per-individual estimate: two float32 copies of the weights, one float32
        # output table, plus a fixed 4096-byte allowance.
        bytes_per_ind = self.n_weights * 4 * 2 + self.n_cases * self.circuit.spec.outputs * 4 + 4096
        available = self.cfg.vram.available_bytes
        self.max_batch = max(1000, min(available // max(bytes_per_ind, 1), 5_000_000))

    def _validate_evaluation(self):
        """Warn (when verbose) if the circuit's own weights score below 0.999."""
        fitness = self.evaluate_single(self.circuit.weights)
        if fitness < 0.999 and self.cfg.verbose:
            print(f" [EVAL WARNING] Original weights fitness={fitness:.4f}")

    def _forward(self, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the test inputs through the native forward when present, else the parsed graph."""
        if self.circuit.has_native:
            return self.circuit.evaluate_native(self.test_inputs, weights)
        return self.circuit.graph.forward_single(self.test_inputs, weights)

    def _fraction_correct(self, outputs: torch.Tensor) -> Optional[torch.Tensor]:
        """Return the fraction of fully matching rows, or ``None`` when shapes cannot be aligned.

        A 1-D output is expanded across the rows of the expected table before
        comparing. Any remaining shape mismatch yields ``None`` instead of
        being compared through broadcasting, which would give a meaningless
        score.
        """
        expected = self.test_expected
        if outputs.shape != expected.shape and outputs.dim() == 1:
            outputs = outputs.unsqueeze(0).expand(expected.shape[0], -1)
        if outputs.shape != expected.shape:
            return None
        correct = (outputs == expected).all(dim=-1).float().sum()
        return correct / self.n_cases

    def evaluate_single(self, weights: Dict[str, torch.Tensor]) -> float:
        """Return the fitness of one weight dictionary (0.0 on an output-shape mismatch)."""
        with torch.no_grad():
            score = self._fraction_correct(self._forward(weights))
            return 0.0 if score is None else score.item()

    def evaluate_population(self, population: torch.Tensor) -> torch.Tensor:
        """Return one fitness value per row of ``population`` (rows are flat weight vectors)."""
        pop_size = population.shape[0]
        if pop_size > self.max_batch:
            return self._evaluate_chunked(population)
        return self._evaluate_sequential(population)

    def _evaluate_sequential(self, population: torch.Tensor) -> torch.Tensor:
        """Score each row in turn; rows whose output shape mismatches keep fitness 0."""
        pop_size = population.shape[0]
        fitness = torch.zeros(pop_size, device=self.device)
        with torch.no_grad():
            for i in range(pop_size):
                weights = self.circuit.vector_to_weights(population[i])
                score = self._fraction_correct(self._forward(weights))
                if score is not None:
                    fitness[i] = score
        return fitness

    def _evaluate_chunked(self, population: torch.Tensor) -> torch.Tensor:
        """Score the population in consecutive slices of at most ``max_batch`` rows."""
        pop_size = population.shape[0]
        fitness = torch.zeros(pop_size, device=self.device)
        for start in range(0, pop_size, self.max_batch):
            end = min(start + self.max_batch, pop_size)
            fitness[start:end] = self._evaluate_sequential(population[start:end])
        return fitness
| |
|
| |
|
| | |
| | |
| | |
| |
|
def prune_magnitude(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Iterative magnitude reduction - try reducing each weight toward zero.

    Each pass visits every non-zero weight in random order and moves it one
    unit toward zero, keeping the change only when the circuit's fitness
    stays at or above ``cfg.fitness_threshold``. Passes repeat until one
    makes no reduction, no non-zero weight remains, or
    ``cfg.magnitude_passes`` is reached.

    Returns:
        PruneResult with method ``'magnitude'``, the reduced weights and
        their re-evaluated fitness.
    """
    start = time.perf_counter()
    weights = circuit.clone_weights()
    original = circuit.stats(weights)

    if cfg.verbose:
        print(f" Starting magnitude reduction (passes={cfg.magnitude_passes})...")

    for pass_num in range(cfg.magnitude_passes):
        candidates = []
        for name, tensor in weights.items():
            # One bulk transfer per tensor instead of one .item() call per element.
            for i, val in enumerate(tensor.flatten().tolist()):
                if val == 0:
                    continue
                # Step one unit toward zero, but never across it: a weight with
                # |val| < 1 would otherwise flip sign and bounce between +/-
                # values on later passes without ever lowering the magnitude.
                new_val = 0.0 if abs(val) <= 1 else val - math.copysign(1.0, val)
                candidates.append((name, i, tensor.shape, val, new_val))

        if not candidates:
            break

        random.shuffle(candidates)
        reductions = 0

        for name, idx, shape, old_val, new_val in candidates:
            flat = weights[name].flatten()
            flat[idx] = new_val
            weights[name] = flat.view(shape)
            if evaluator.evaluate_single(weights) >= cfg.fitness_threshold:
                reductions += 1
            else:
                # Change broke the circuit: restore the previous value.
                flat[idx] = old_val
                weights[name] = flat.view(shape)

        if cfg.verbose:
            stats = circuit.stats(weights)
            print(f" Pass {pass_num}: {reductions} reductions, mag={stats['magnitude']:.0f}")

        if reductions == 0:
            break

    return PruneResult(
        method='magnitude',
        original_stats=original,
        final_stats=circuit.stats(weights),
        final_weights=weights,
        fitness=evaluator.evaluate_single(weights),
        time_seconds=time.perf_counter() - start
    )
| |
|
| |
|
def prune_zero(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Zero pruning - try setting each non-zero weight to zero.

    Non-zero weights are visited once, in random order. A weight stays zeroed
    only if the circuit's fitness remains at or above
    ``cfg.fitness_threshold``; otherwise its previous value is restored.
    """
    started_at = time.perf_counter()
    weights = circuit.clone_weights()
    original = circuit.stats(weights)

    # (tensor name, flat position, tensor shape, current value) per non-zero weight.
    candidates = [
        (name, position, tensor.shape, value)
        for name, tensor in weights.items()
        for position, value in enumerate(tensor.flatten().tolist())
        if value != 0
    ]
    random.shuffle(candidates)

    if cfg.verbose:
        print(f" Testing {len(candidates)} non-zero weights for zeroing...")

    zeroed = 0
    for name, position, shape, previous in candidates:
        flat = weights[name].flatten()
        flat[position] = 0
        weights[name] = flat.view(shape)
        still_valid = evaluator.evaluate_single(weights) >= cfg.fitness_threshold
        if still_valid:
            zeroed += 1
            continue
        # Zeroing broke the circuit: put the old value back.
        flat[position] = previous
        weights[name] = flat.view(shape)

    if cfg.verbose:
        print(f" Zeroed {zeroed}/{len(candidates)} weights")

    final_fitness = evaluator.evaluate_single(weights)
    return PruneResult(
        method='zero',
        original_stats=original,
        final_stats=circuit.stats(weights),
        final_weights=weights,
        fitness=final_fitness,
        time_seconds=time.perf_counter() - started_at
    )
| |
|
| |
|
def prune_evolutionary(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Evolutionary search with GPU-optimized parallel population evaluation.

    The population starts as copies of the original weight vector, part of it
    perturbed. Each generation scores individuals as
    ``fitness - evo_parsimony * magnitude / n_weights``, remembers the best
    individual whose fitness meets ``cfg.fitness_threshold``, then builds the
    next generation from an elite slice plus softmax-sampled children that go
    through single-point crossover and integer mutation. The run stops early
    after more than 200 generations without improvement.

    Returns:
        PruneResult with method ``'evolutionary'``; ``metadata`` holds the
        number of generations actually run and the population size.
    """
    start = time.perf_counter()
    original = circuit.stats()

    if cfg.evo_pop_size > 0:
        pop_size = cfg.evo_pop_size
    else:
        # Size the population from the remaining memory budget (3 float32 copies per individual).
        bytes_per_ind = circuit.n_weights * 4 * 3
        available = cfg.vram.available_bytes - torch.cuda.memory_allocated() if torch.cuda.is_available() else cfg.vram.available_bytes
        pop_size = max(10000, min(available // max(bytes_per_ind, 1), evaluator.max_batch, 500000))

    elite_size = max(1, int(pop_size * cfg.evo_elite_ratio))

    if cfg.verbose:
        print(f" [EVO] Population: {pop_size:,}, Elite: {elite_size:,}, Generations: {cfg.evo_generations}")

    base_vector = circuit.weights_to_vector(circuit.weights)
    population = base_vector.unsqueeze(0).expand(pop_size, -1).clone()

    # Seeding: the first n_exact rows stay exact copies, rows up to the midpoint
    # get light perturbation (+/-1 on ~10% of genes), the rest heavier
    # perturbation (+/-2 on ~25% of genes).
    n_exact = max(elite_size, pop_size // 10)
    light_muts = max(1, circuit.n_weights // 10)
    heavy_muts = max(1, circuit.n_weights // 4)
    for i in range(n_exact, pop_size // 2):
        mut_idx = torch.randperm(circuit.n_weights)[:light_muts]
        population[i, mut_idx] += torch.randint(-1, 2, (light_muts,), device=cfg.device, dtype=population.dtype)
    for i in range(pop_size // 2, pop_size):
        mut_idx = torch.randperm(circuit.n_weights)[:heavy_muts]
        population[i, mut_idx] += torch.randint(-2, 3, (heavy_muts,), device=cfg.device, dtype=population.dtype)

    best_weights = circuit.clone_weights()
    best_fitness = evaluator.evaluate_single(best_weights)
    best_mag = original['magnitude']
    best_score = best_fitness - cfg.evo_parsimony * best_mag / circuit.n_weights if best_fitness >= cfg.fitness_threshold else -float('inf')
    stagnant = 0
    mutation_rate = cfg.evo_mutation_rate
    # Tracked explicitly so the metadata is well-defined even when
    # cfg.evo_generations is 0 and the loop body never runs.
    generations_run = 0

    for gen in range(cfg.evo_generations):
        generations_run = gen + 1
        fitness = evaluator.evaluate_population(population)
        magnitudes = population.abs().sum(dim=1)
        adjusted = fitness - cfg.evo_parsimony * magnitudes / circuit.n_weights

        valid_mask = fitness >= cfg.fitness_threshold
        n_valid = valid_mask.sum().item()

        if n_valid > 0:
            # Only individuals meeting the fitness threshold may become the new best.
            valid_adjusted = adjusted.clone()
            valid_adjusted[~valid_mask] = -float('inf')
            best_idx = valid_adjusted.argmax().item()
            if adjusted[best_idx] > best_score:
                best_score = adjusted[best_idx].item()
                best_fitness = fitness[best_idx].item()
                best_weights = circuit.vector_to_weights(population[best_idx].clone())
                best_mag = magnitudes[best_idx].item()
                stagnant = 0
            else:
                stagnant += 1
        else:
            stagnant += 1

        # Adaptive mutation: raise the rate after long stagnation, lower it after an improvement.
        if stagnant > 50:
            mutation_rate = min(0.5, mutation_rate * 1.1)
        elif stagnant == 0:
            mutation_rate = max(0.01, mutation_rate * 0.95)

        if cfg.verbose and (gen % 50 == 0 or gen == cfg.evo_generations - 1):
            print(f" Gen {gen:4d} | valid: {n_valid:6,}/{pop_size:,} | mag: {best_mag:.0f} | stag: {stagnant}")

        # Selection: keep the top elite_size rows, sample the remaining parents by softmax of the score.
        sorted_idx = adjusted.argsort(descending=True)
        elite = population[sorted_idx[:elite_size]].clone()
        probs = F.softmax(adjusted * 10, dim=0)
        parent_idx = torch.multinomial(probs, pop_size - elite_size, replacement=True)
        children = population[parent_idx].clone()

        # Single-point crossover between consecutive selected children. A cut
        # point needs at least two genes, so single-parameter circuits skip it
        # (random.randint(1, 0) would raise ValueError).
        if cfg.evo_crossover_rate > 0 and circuit.n_weights > 1:
            cross_mask = torch.rand(len(children), device=cfg.device) < cfg.evo_crossover_rate
            cross_idx = torch.where(cross_mask)[0]
            for i in range(0, len(cross_idx) - 1, 2):
                p1, p2 = cross_idx[i].item(), cross_idx[i + 1].item()
                point = random.randint(1, circuit.n_weights - 1)
                temp = children[p1, point:].clone()
                children[p1, point:] = children[p2, point:]
                children[p2, point:] = temp

        # Mutation: each gene changes with probability mutation_rate by an integer in [-s, s].
        mut_mask = torch.rand_like(children) < mutation_rate
        mutations = torch.randint(-int(cfg.evo_mutation_strength), int(cfg.evo_mutation_strength) + 1,
                                  children.shape, device=cfg.device, dtype=children.dtype)
        children = children + mut_mask * mutations
        population = torch.cat([elite, children], dim=0)

        if stagnant > 200:
            if cfg.verbose:
                print(f" [EVO] Early stop at generation {gen}")
            break

    if cfg.verbose:
        final_stats = circuit.stats(best_weights)
        print(f" [EVO] Final: mag={final_stats['magnitude']:.0f} (was {original['magnitude']:.0f})")

    return PruneResult(
        method='evolutionary',
        original_stats=original,
        final_stats=circuit.stats(best_weights),
        final_weights=best_weights,
        fitness=best_fitness,
        time_seconds=time.perf_counter() - start,
        metadata={'generations': generations_run, 'population_size': pop_size}
    )
| |
|
| |
|
| | def _partitions(total: int, n: int, max_val: int): |
| | """Generate all ways to partition 'total' into 'n' non-negative integers <= max_val.""" |
| | if n == 0: |
| | if total == 0: |
| | yield [] |
| | return |
| | for i in range(min(total, max_val) + 1): |
| | for rest in _partitions(total - i, n - 1, max_val): |
| | yield [i] + rest |
| |
|
| |
|
| | def _all_signs(abs_vals: list): |
| | """Generate all sign combinations for absolute values.""" |
| | if not abs_vals: |
| | yield [] |
| | return |
| | for rest in _all_signs(abs_vals[1:]): |
| | if abs_vals[0] == 0: |
| | yield [0] + rest |
| | else: |
| | yield [abs_vals[0]] + rest |
| | yield [-abs_vals[0]] + rest |
| |
|
| |
|
def _configs_at_magnitude(mag: int, n_params: int):
    """Generate all n_params-length configs with given total magnitude.

    Yields tuples of signed integers whose absolute values sum to ``mag``.
    """
    unsigned_layouts = _partitions(mag, n_params, mag)
    yield from (
        tuple(signed)
        for layout in unsigned_layouts
        for signed in _all_signs(layout)
    )
| |
|
| |
|
def prune_exhaustive_mag(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Exhaustive search by magnitude - finds provably optimal solutions.

    Enumerates every integer weight configuration by total magnitude, from 0
    up to the original magnitude (or only ``cfg.exhaustive_target_mag`` when
    it is >= 0), and stops at the first magnitude that has a configuration
    meeting ``cfg.fitness_threshold``. Circuits with more than
    ``cfg.exhaustive_max_params`` parameters are skipped and returned
    unchanged with ``metadata={'skipped': True}``.

    Returns:
        PruneResult with method ``'exhaustive_mag'``. On success the weights
        are the first solution found and ``metadata`` lists up to 100
        solutions; when nothing is found the original weights are returned.
    """
    start = time.perf_counter()
    original = circuit.stats()

    n_params = original['total']
    max_mag = int(original['magnitude'])
    target_mag = cfg.exhaustive_target_mag

    if n_params > cfg.exhaustive_max_params:
        if cfg.verbose:
            print(f" [EXHAUSTIVE] Skipping: {n_params} params > max {cfg.exhaustive_max_params}")
        return PruneResult(
            method='exhaustive_mag',
            original_stats=original,
            final_stats=original,
            final_weights=circuit.clone_weights(),
            fitness=evaluator.evaluate_single(circuit.weights),
            time_seconds=time.perf_counter() - start,
            metadata={'skipped': True}
        )

    if cfg.verbose:
        print(f" [EXHAUSTIVE] Parameters: {n_params}, Original magnitude: {max_mag}")
        if target_mag >= 0:
            print(f" [EXHAUSTIVE] Target magnitude: {target_mag}")

    weight_keys = list(circuit.weights.keys())
    weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
    weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}

    def vector_to_weights(vec):
        """Rebuild a weight dictionary from a flat sequence of values."""
        weights = {}
        idx = 0
        for k in weight_keys:
            size = weight_sizes[k]
            weights[k] = torch.tensor(vec[idx:idx+size], dtype=torch.float32, device=cfg.device).view(weight_shapes[k])
            idx += size
        return weights

    all_solutions = []
    optimal_mag = None
    total_tested = 0

    mag_range = [target_mag] if target_mag >= 0 else range(0, max_mag + 1)

    for mag in mag_range:
        configs = list(_configs_at_magnitude(mag, n_params))
        if not configs:
            continue

        if cfg.verbose:
            print(f" Magnitude {mag}: {len(configs):,} configurations...", end=" ", flush=True)

        valid = []
        batch_size = min(100000, len(configs))

        for batch_start in range(0, len(configs), batch_size):
            batch = configs[batch_start:batch_start + batch_size]
            population = torch.tensor(batch, dtype=torch.float32, device=cfg.device)
            try:
                fitness = evaluator.evaluate_population(population)
            except Exception:
                # Fall back to scoring one configuration at a time. Exception
                # (not a bare except) so Ctrl-C and SystemExit still abort.
                fitness = torch.tensor([evaluator.evaluate_single(vector_to_weights(c)) for c in batch], device=cfg.device)
            for i, is_valid in enumerate((fitness >= cfg.fitness_threshold).tolist()):
                if is_valid:
                    valid.append(batch[i])

        total_tested += len(configs)

        if valid:
            if cfg.verbose:
                print(f"FOUND {len(valid)} solutions!")
            optimal_mag = mag
            all_solutions = valid

            if cfg.verbose and len(valid) <= 50:
                print(f" Solutions:")
                for i, sol in enumerate(valid[:20]):
                    nz = sum(1 for v in sol if v != 0)
                    print(f" {i+1}: mag={sum(abs(v) for v in sol)}, nz={nz}, {sol}")
            # Magnitudes are scanned in ascending order, so the first hit is the smallest.
            break
        else:
            if cfg.verbose:
                print("none")

    if all_solutions:
        best_weights = vector_to_weights(all_solutions[0])
        best_fitness = evaluator.evaluate_single(best_weights)
    else:
        # Nothing found: report the original circuit and its magnitude.
        best_weights = circuit.clone_weights()
        best_fitness = evaluator.evaluate_single(best_weights)
        optimal_mag = max_mag

    if cfg.verbose:
        print(f" [EXHAUSTIVE] Tested: {total_tested:,}, Optimal: {optimal_mag}, Solutions: {len(all_solutions)}")

    return PruneResult(
        method='exhaustive_mag',
        original_stats=original,
        final_stats=circuit.stats(best_weights),
        final_weights=best_weights,
        fitness=best_fitness,
        time_seconds=time.perf_counter() - start,
        metadata={'optimal_magnitude': optimal_mag, 'solutions_count': len(all_solutions), 'all_solutions': all_solutions[:100]}
    )
| |
|
| |
|
def prune_architecture(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Architecture search - find the minimum-magnitude flat 2-layer architecture.

    Enumerates every integer parameter vector of a fully connected
    ``inputs -> cfg.arch_hidden_neurons -> outputs`` threshold network (a unit
    fires when ``w . x + b >= 0``) in order of increasing total magnitude
    (sum of absolute weights and biases).  Each parameter is limited to
    ``[-cfg.arch_max_weight, cfg.arch_max_weight]``.  The search stops at the
    first magnitude, up to ``cfg.arch_max_mag``, for which at least one vector
    reaches ``cfg.fitness_threshold`` on the circuit's test cases.

    The flat network is independent of the circuit's own graph, so the
    returned ``final_weights``, ``final_stats`` and ``fitness`` describe the
    unmodified circuit.  The search outcome is reported through ``metadata``:

    - ``optimal_magnitude``: smallest magnitude with a solution, else ``None``.
    - ``solutions_count``: number of solutions found at that magnitude.
    - ``all_solutions``: up to 100 of those solutions as flat parameter lists,
      laid out as ``[hidden_0 weights, hidden_0 bias, ..., output_0 weights,
      output_0 bias, ...]`` (same cap as the exhaustive search uses).
    """
    start = time.perf_counter()
    original = circuit.stats()

    n_hidden = cfg.arch_hidden_neurons
    n_inputs = circuit.spec.inputs
    n_outputs = circuit.spec.outputs
    max_weight = cfg.arch_max_weight
    max_mag = cfg.arch_max_mag

    # One weight per incoming connection plus one bias for every unit.
    n_params = n_hidden * (n_inputs + 1) + n_outputs * (n_hidden + 1)

    if cfg.verbose:
        print(f" [ARCH] Hidden: {n_hidden}, Params: {n_params}, Max magnitude: {max_mag}")

    # NOTE(review): the broadcasting below assumes test_inputs is
    # (n_cases, n_inputs) and test_expected is (n_cases, n_outputs) - confirm.
    test_inputs = circuit.test_inputs
    test_expected = circuit.test_expected

    # Number of candidate vectors scored per eval_flat call.
    eval_chunk = 500000

    def eval_flat(configs: torch.Tensor) -> torch.Tensor:
        """Return the fraction of test cases each row of ``configs`` gets right."""
        # Slice the flat vectors into per-unit weight and bias columns.
        idx = 0
        hidden_w, hidden_b = [], []
        for _ in range(n_hidden):
            hidden_w.append(configs[:, idx:idx + n_inputs])
            idx += n_inputs
            hidden_b.append(configs[:, idx:idx + 1])
            idx += 1
        output_w, output_b = [], []
        for _ in range(n_outputs):
            output_w.append(configs[:, idx:idx + n_hidden])
            idx += n_hidden
            output_b.append(configs[:, idx:idx + 1])
            idx += 1

        # Hidden layer: one (batch, n_cases) activation map per hidden unit.
        hidden_acts = []
        for h in range(n_hidden):
            act = (hidden_w[h].unsqueeze(1) * test_inputs.unsqueeze(0)).sum(dim=2) + hidden_b[h]
            hidden_acts.append((act >= 0).float())
        hidden_stack = torch.stack(hidden_acts, dim=2)

        # Output layer, fed by the stacked hidden activations.
        outputs = []
        for o in range(n_outputs):
            out = (hidden_stack * output_w[o].unsqueeze(1)).sum(dim=2) + output_b[o]
            outputs.append((out >= 0).float())

        if n_outputs == 1:
            predicted = outputs[0]
            expected = test_expected.squeeze()
        else:
            predicted = torch.stack(outputs, dim=2)
            expected = test_expected

        correct = (predicted == expected.unsqueeze(0)).float().mean(dim=1)
        if n_outputs > 1:
            # First mean ran over test cases; this one averages the outputs.
            correct = correct.mean(dim=1)
        return correct

    @lru_cache(maxsize=None)
    def partitions(total: int, n_slots: int, max_val: int) -> list:
        """List every way to write ``total`` as ``n_slots`` ordered values in ``[0, max_val]``."""
        if n_slots == 0:
            return [()] if total == 0 else []
        if n_slots == 1:
            return [(total,)] if total <= max_val else []
        result = []
        for v in range(min(total, max_val) + 1):
            for rest in partitions(total - v, n_slots - 1, max_val):
                result.append((v,) + rest)
        return result

    def signs_for_partition(partition: tuple) -> torch.Tensor:
        """Expand absolute values into all +/- sign assignments, one row each."""
        n = len(partition)
        nonzero_idx = [i for i, v in enumerate(partition) if v != 0]
        k = len(nonzero_idx)
        if k == 0:
            return torch.zeros(1, n, device=cfg.device, dtype=torch.float32)
        n_patterns = 2 ** k
        configs = torch.zeros(n_patterns, n, device=cfg.device, dtype=torch.float32)
        for i, idx in enumerate(nonzero_idx):
            # Bit i of the row number picks the sign of the i-th non-zero slot.
            signs = ((torch.arange(n_patterns, device=cfg.device) >> i) & 1) * 2 - 1
            configs[:, idx] = signs.float() * partition[idx]
        return configs

    all_solutions = []
    optimal_mag = None
    total_tested = 0

    for target_mag in range(1, max_mag + 1):
        sign_blocks = [
            signs_for_partition(partition)
            for partition in partitions(target_mag, n_params, max_weight)
        ]
        if not sign_blocks:
            continue
        configs = torch.cat(sign_blocks, dim=0)

        if cfg.verbose:
            print(f" Magnitude {target_mag}: {configs.shape[0]:,} configs...", end=" ", flush=True)

        valid = []
        for i in range(0, configs.shape[0], eval_chunk):
            batch = configs[i:i + eval_chunk]
            fitness = eval_flat(batch)
            valid.extend(batch[fitness >= cfg.fitness_threshold].cpu().tolist())

        total_tested += configs.shape[0]

        if valid:
            if cfg.verbose:
                print(f"FOUND {len(valid)} solutions!")
            # Magnitudes are visited in increasing order, so the first hit is minimal.
            optimal_mag = target_mag
            all_solutions = valid
            break
        if cfg.verbose:
            print("none")

    if cfg.verbose:
        print(f" [ARCH] Tested: {total_tested:,}, Optimal: {optimal_mag}, Solutions: {len(all_solutions)}")

    return PruneResult(
        method='architecture',
        original_stats=original,
        final_stats=original,
        final_weights=circuit.clone_weights(),
        fitness=evaluator.evaluate_single(circuit.weights),
        time_seconds=time.perf_counter() - start,
        metadata={
            'optimal_magnitude': optimal_mag,
            'solutions_count': len(all_solutions),
            # Previously only the count survived; keep the vectors themselves.
            'all_solutions': all_solutions[:100],
        }
    )
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
# Catalogue of known building blocks, keyed by component type.
#
# Each entry records the component's input/output/neuron counts, a declared
# total 'magnitude', and a list of alternative 'solutions'.  A solution maps a
# neuron label ('h1', 'h2', ..., 'out') to a (list, int) pair - presumably
# (weights, bias); confirm against whatever consumes these entries.
#
# In the code visible alongside this table, prune_compositional reads only
# len(entry['solutions']) and entry['magnitude']; the solution values
# themselves are never applied to a circuit.
#
# NOTE(review): the declared magnitudes do not match the sum of absolute
# values of every listed solution.  Summing |weights| + |bias| gives:
#   xor (declared 7):         7, 7, 6, 6, 9, 9
#   xor3 (declared 10):       10, 13, 10, 13, 10, 13
#   passthrough (declared 2): 1, 3
# Verify the table before relying on it for anything beyond counting.
OPTIMAL_COMPONENTS = dict(
    xor=dict(
        inputs=2,
        outputs=1,
        neurons=3,
        magnitude=7,
        solutions=[
            dict(h1=([-1, 1], 0), h2=([1, -1], 0), out=([-1, -1], 1)),
            dict(h1=([1, -1], 0), h2=([-1, 1], 0), out=([-1, -1], 1)),
            dict(h1=([-1, 1], 0), h2=([-1, 1], 0), out=([1, -1], 0)),
            dict(h1=([1, -1], 0), h2=([1, -1], 0), out=([-1, 1], 0)),
            dict(h1=([1, 1], -1), h2=([-1, -1], 1), out=([1, 1], -1)),
            dict(h1=([-1, -1], 1), h2=([1, 1], -1), out=([1, 1], -1)),
        ],
    ),
    xor3=dict(
        inputs=3,
        outputs=1,
        neurons=4,
        magnitude=10,
        solutions=[
            dict(h1=([0, 0, -1], 0), h2=([-1, 1, -1], 0), h3=([-1, -1, 1], 0), out=([1, -1, -1], 0)),
            dict(h1=([0, 0, 1], -1), h2=([1, -1, 1], -1), h3=([1, 1, -1], -1), out=([-1, 1, 1], 0)),
            dict(h1=([0, -1, 0], 0), h2=([-1, -1, 1], 0), h3=([1, -1, -1], 0), out=([-1, -1, 1], 0)),
            dict(h1=([0, 1, 0], -1), h2=([1, 1, -1], -1), h3=([-1, 1, 1], -1), out=([1, 1, -1], 0)),
            dict(h1=([-1, 0, 0], 0), h2=([-1, 1, -1], 0), h3=([1, -1, -1], 0), out=([-1, 1, -1], 0)),
            dict(h1=([1, 0, 0], -1), h2=([1, -1, 1], -1), h3=([-1, 1, 1], -1), out=([1, -1, 1], 0)),
        ],
    ),
    passthrough=dict(
        inputs=1,
        outputs=1,
        neurons=1,
        magnitude=2,
        solutions=[
            dict(out=([1], 0)),
            dict(out=([2], -1)),
        ],
    ),
)
| |
|
| |
|
def prune_compositional(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """
    Compositional analysis for circuits built from known-optimal components.

    Groups the circuit's neurons by name prefix, classifies each group as one
    of the component types in OPTIMAL_COMPONENTS ('passthrough', 'xor',
    'xor3') or as 'unknown', and reports:

    - how many combinations of catalogued solutions exist (the product of the
      per-component solution counts), and
    - the theoretical total magnitude if every component used a catalogued
      solution (the sum of the per-component 'magnitude' values).

    This function does NOT modify weights: no catalogued solution is mapped
    onto the circuit's tensors, so ``final_weights``, ``final_stats`` and
    ``fitness`` always describe the original circuit.  The findings are
    returned in ``metadata``:

    - on success: ``components`` (list of ``(type, prefix)``),
      ``total_combinations`` and ``theoretical_magnitude``;
    - if any group is unrecognized: ``status='unknown_components'`` and
      ``unknown`` (the offending prefixes).
    """
    start = time.perf_counter()
    original = circuit.stats()

    if cfg.verbose:
        print(f" [COMP] Analyzing circuit structure...")

    neuron_names = list(circuit.graph.neurons.keys())

    # A neuron's group is everything before its last '.'; a name without a
    # dot forms its own group (rsplit then returns the whole name).
    prefixes = {name.rsplit('.', 1)[0] for name in neuron_names}

    # Classify each group by size and by the presence of hidden-unit labels.
    # NOTE(review): labels are matched as substrings ('h1' in name), so a name
    # such as 'h10' also matches - confirm against the naming convention.
    components = []
    for prefix in sorted(prefixes):
        related = [n for n in neuron_names if n == prefix or n.startswith(prefix + '.')]
        n_neurons = len(related)

        if n_neurons == 1:
            components.append({'type': 'passthrough', 'prefix': prefix, 'neurons': related})
        elif n_neurons == 3 and any('h1' in n or 'h2' in n for n in related):
            components.append({'type': 'xor', 'prefix': prefix, 'neurons': related})
        elif n_neurons == 4 and any('h1' in n or 'h2' in n or 'h3' in n for n in related):
            components.append({'type': 'xor3', 'prefix': prefix, 'neurons': related})
        else:
            components.append({'type': 'unknown', 'prefix': prefix, 'neurons': related, 'n': n_neurons})

    type_counts = defaultdict(int)
    for c in components:
        type_counts[c['type']] += 1

    if cfg.verbose:
        print(f" [COMP] Detected components:")
        for ctype, count in sorted(type_counts.items()):
            if ctype in OPTIMAL_COMPONENTS:
                n_solutions = len(OPTIMAL_COMPONENTS[ctype]['solutions'])
                print(f" - {ctype}: {count} instances × {n_solutions} solutions each")
            else:
                print(f" - {ctype}: {count} instances (no known optimal)")

    # Any unrecognized group makes the compositional bound meaningless.
    unknown_components = [c for c in components if c['type'] == 'unknown']
    if unknown_components:
        if cfg.verbose:
            print(f" [COMP] Cannot use compositional search - unknown components found")
            for c in unknown_components[:5]:
                print(f" - {c['prefix']}: {c['n']} neurons")
        return PruneResult(
            method='compositional',
            original_stats=original,
            final_stats=original,
            final_weights=circuit.clone_weights(),
            fitness=evaluator.evaluate_single(circuit.weights),
            time_seconds=time.perf_counter() - start,
            metadata={'status': 'unknown_components', 'unknown': [c['prefix'] for c in unknown_components]}
        )

    known = [c for c in components if c['type'] in OPTIMAL_COMPONENTS]

    total_combos = 1
    for c in known:
        total_combos *= len(OPTIMAL_COMPONENTS[c['type']]['solutions'])

    if cfg.verbose:
        print(f" [COMP] Total combinations: {total_combos:,}")

    # The earlier implementation looped over samples/combinations here, but no
    # solution was ever written into the weights: the sampling loop re-scored
    # the same unmodified clone up to 1,000,000 times and the enumeration loop
    # only accumulated unused records.  Neither loop influenced the result, so
    # the same summary is produced with one evaluation and no per-combination
    # work.
    if total_combos > 10_000_000:
        if cfg.verbose:
            print(f" [COMP] Too many combinations, using sampling...")
        n_samples = min(1_000_000, total_combos)
        baseline_fitness = evaluator.evaluate_single(circuit.clone_weights())
        n_valid = n_samples if baseline_fitness >= cfg.fitness_threshold else 0
        if n_valid and cfg.verbose:
            print(f" [COMP] Found {n_valid} valid from {n_samples:,} samples")
    elif cfg.verbose:
        print(f" [COMP] Enumerating {total_combos:,} combinations...")
        print(f" Tested {total_combos:,}/{total_combos:,} - done")

    # Lower bound if every component were swapped for a catalogued solution.
    theoretical_mag = sum(OPTIMAL_COMPONENTS[c['type']]['magnitude'] for c in known)

    if cfg.verbose:
        print(f" [COMP] Theoretical optimal magnitude: {theoretical_mag}")
        print(f" [COMP] Original magnitude: {original['magnitude']:.0f}")
        if theoretical_mag < original['magnitude']:
            print(f" [COMP] Potential reduction: {(1 - theoretical_mag/original['magnitude'])*100:.1f}%")

    return PruneResult(
        method='compositional',
        original_stats=original,
        final_stats=original,
        final_weights=circuit.clone_weights(),
        fitness=evaluator.evaluate_single(circuit.weights),
        time_seconds=time.perf_counter() - start,
        metadata={
            'components': [(c['type'], c['prefix']) for c in components],
            'total_combinations': total_combos,
            'theoretical_magnitude': theoretical_mag
        }
    )
| |
|
| |
|
| | |
| | |
| | |
| |
|
def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneResult]:
    """Run every enabled pruning method on ``circuit`` and print a comparison.

    Prints a header with device and circuit statistics, verifies that the
    unpruned circuit reaches ``cfg.fitness_threshold``, runs each method whose
    ``cfg.run_<method>`` flag is set, and prints a summary table sorted by
    final magnitude.

    Returns:
        Mapping of method name to its PruneResult.  Methods that raised are
        reported on stdout (with traceback) and omitted.  An empty dict is
        returned when the circuit fails the baseline fitness check.
    """
    rule = '=' * 70

    print(f"\n{rule}")
    print(f" PRUNING: {circuit.spec.name}")
    print(f"{rule}")

    usage = cfg.vram.current_usage()
    print(f" VRAM: {cfg.vram.total_gb:.1f} GB total, {usage['free_gb']:.1f} GB free")
    print(f" Device: {cfg.vram.device_name}")

    original = circuit.stats()
    original_mag = original['magnitude']
    print(f" Inputs: {circuit.spec.inputs}, Outputs: {circuit.spec.outputs}")
    print(f" Neurons: {circuit.spec.neurons}, Layers: {circuit.spec.layers}")
    print(f" Parameters: {original['total']}, Non-zero: {original['nonzero']}")
    print(f" Magnitude: {original_mag:.0f}")
    print(f" Test cases: {circuit.n_cases}")
    print(f"{rule}")

    evaluator = BatchedEvaluator(circuit, cfg)
    initial_fitness = evaluator.evaluate_single(circuit.weights)
    print(f"\n Initial fitness: {initial_fitness:.6f}")

    # Pruning a circuit that is already wrong cannot produce a valid result.
    if initial_fitness < cfg.fitness_threshold:
        print(" ERROR: Circuit doesn't pass baseline!")
        return {}

    results = {}
    methods = [
        ('magnitude', cfg.run_magnitude, lambda: prune_magnitude(circuit, evaluator, cfg)),
        ('zero', cfg.run_zero, lambda: prune_zero(circuit, evaluator, cfg)),
        ('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
        ('exhaustive_mag', cfg.run_exhaustive_mag, lambda: prune_exhaustive_mag(circuit, evaluator, cfg)),
        ('architecture', cfg.run_architecture, lambda: prune_architecture(circuit, evaluator, cfg)),
        ('compositional', cfg.run_compositional, lambda: prune_compositional(circuit, evaluator, cfg)),
    ]

    enabled = [(label, runner) for label, is_enabled, runner in methods if is_enabled]
    print(f"\n Running {len(enabled)} pruning methods...")
    print(f"{rule}")

    for i, (name, fn) in enumerate(enabled):
        print(f"\n[{i + 1}/{len(enabled)}] {name.upper()}")
        print("-" * 50)
        try:
            clear_vram()
            r = fn()
            results[name] = r
            print(f" Fitness: {r.fitness:.6f}, Magnitude: {r.final_stats.get('magnitude', 0):.0f}, Time: {r.time_seconds:.1f}s")
        except Exception as e:
            # One failing method must not abort the remaining ones.
            print(f" ERROR: {e}")
            import traceback
            traceback.print_exc()

    print(f"\n{rule}")
    print(" SUMMARY")
    print(f"{rule}")
    print(f"\n{'Method':<15} {'Fitness':<10} {'Magnitude':<12} {'Reduction':<12} {'Time':<10}")
    print("-" * 60)
    # Show the measured baseline fitness rather than assuming a perfect score.
    print(f"{'Original':<15} {initial_fitness:<10.4f} {original_mag:<12.0f} {'-':<12} {'-':<10}")

    def magnitude_of(result):
        # A result without a magnitude sorts last and can never be "best";
        # the same default is used for ordering and for selection.
        return result.final_stats.get('magnitude', float('inf'))

    best_method, best_mag = None, float('inf')
    for name, r in sorted(results.items(), key=lambda item: magnitude_of(item[1])):
        mag = magnitude_of(r)
        # mag < original_mag also guarantees original_mag > 0 for mag >= 0.
        reduction = f"{(1 - mag/original_mag)*100:.1f}%" if mag < original_mag else "-"
        # Attach the unit before padding so the 's' sits next to the number.
        elapsed = f"{r.time_seconds:.1f}s"
        print(f"{name:<15} {r.fitness:<10.4f} {mag:<12.0f} {reduction:<12} {elapsed:<10}")
        if r.fitness >= cfg.fitness_threshold and mag < best_mag:
            best_mag, best_method = mag, name

    if best_method:
        if original_mag > 0:
            print(f"\n BEST: {best_method} ({(1 - best_mag/original_mag)*100:.1f}% reduction)")
        else:
            # A zero-magnitude original has no meaningful percentage.
            print(f"\n BEST: {best_method}")

    return results
| |
|
| |
|
def discover_circuits(base: Path = CIRCUITS_PATH) -> List[CircuitSpec]:
    """Find all circuits stored under ``base``.

    A circuit is a sub-directory that contains a ``config.json`` and at least
    one ``*.safetensors`` file.  Missing config fields default to ``0`` (or to
    the directory name for ``name``).  Discovery is best-effort: directories
    whose config cannot be read or parsed are skipped.

    Returns:
        Circuit specs sorted by ``(inputs, neurons)``, with the name as a
        tie-breaker so the order does not depend on directory listing order.
        An empty list is returned when ``base`` is not an existing directory.
    """
    if not base.is_dir():
        return []

    circuits = []
    for d in base.iterdir():
        config_file = d / 'config.json'
        # any() stops at the first match instead of listing every weights file.
        if not (d.is_dir() and config_file.exists() and any(d.glob('*.safetensors'))):
            continue
        try:
            with open(config_file, encoding='utf-8') as f:
                cfg = json.load(f)
            circuits.append(CircuitSpec(
                name=cfg.get('name', d.name),
                path=d,
                inputs=cfg.get('inputs', 0),
                outputs=cfg.get('outputs', 0),
                neurons=cfg.get('neurons', 0),
                layers=cfg.get('layers', 0),
                parameters=cfg.get('parameters', 0)
            ))
        except Exception:
            # Best-effort scan: skip unreadable or malformed entries, but let
            # KeyboardInterrupt/SystemExit propagate (a bare except would not).
            continue
    return sorted(circuits, key=lambda spec: (spec.inputs, spec.neurons, str(spec.name)))
| |
|
| |
|
def main():
    """Command-line entry point.

    Modes, in order of precedence:

    - ``--list``: print the discovered circuits and exit.
    - ``--all``: run the enabled methods on every discovered circuit with at
      most ``--max-inputs`` inputs; a failure on one circuit is reported and
      the run continues.
    - ``<circuit>``: run on one circuit directory (looked up as given, then
      with a ``threshold-`` prefix).  With ``--save``, the lowest-magnitude
      result that meets the fitness threshold is written out.
    - no mode given: print help and usage examples.

    ``--methods`` switches the named methods on.  It does not switch any
    method off; which methods run by default is decided by Config.
    """
    parser = argparse.ArgumentParser(description='Threshold Circuit Pruner v5')
    parser.add_argument('circuit', nargs='?', help='Circuit name')
    parser.add_argument('--weights', type=str, help='Specific .safetensors file')
    parser.add_argument('--list', action='store_true')
    parser.add_argument('--all', action='store_true')
    parser.add_argument('--max-inputs', type=int, default=10)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--methods', type=str, help='Comma-separated: mag,zero,evo,exh,arch,comp')
    parser.add_argument('--fitness', type=float, default=0.9999)
    parser.add_argument('--quiet', action='store_true')
    parser.add_argument('--save', action='store_true')
    parser.add_argument('--evo-pop', type=int, default=0)
    parser.add_argument('--evo-gens', type=int, default=2000)
    parser.add_argument('--exhaustive-max-params', type=int, default=12)
    parser.add_argument('--target-mag', type=int, default=-1)
    parser.add_argument('--arch-hidden', type=int, default=3)
    parser.add_argument('--arch-max-weight', type=int, default=3)
    parser.add_argument('--arch-max-mag', type=int, default=20)

    args = parser.parse_args()

    if args.list:
        specs = discover_circuits()
        print(f"\nAvailable circuits ({len(specs)}):\n")
        for s in specs:
            print(f" {s.name:<40} {s.inputs}in/{s.outputs}out {s.neurons}N {s.parameters}P")
        return

    vram_cfg = VRAMConfig()
    cfg = Config(
        device=args.device,
        fitness_threshold=args.fitness,
        verbose=not args.quiet,
        vram=vram_cfg,
        evo_pop_size=args.evo_pop,
        evo_generations=args.evo_gens,
        exhaustive_max_params=args.exhaustive_max_params,
        exhaustive_target_mag=args.target_mag,
        arch_hidden_neurons=args.arch_hidden,
        arch_max_weight=args.arch_max_weight,
        arch_max_mag=args.arch_max_mag
    )

    if args.methods:
        method_map = {
            'mag': 'magnitude', 'magnitude': 'magnitude',
            'zero': 'zero',
            'evo': 'evolutionary', 'evolutionary': 'evolutionary',
            'exh': 'exhaustive_mag', 'exh_mag': 'exhaustive_mag', 'exhaustive': 'exhaustive_mag',
            'arch': 'architecture', 'architecture': 'architecture',
            'comp': 'compositional', 'compositional': 'compositional'
        }
        for m in args.methods.lower().split(','):
            m = m.strip()
            if not m:
                continue
            if m in method_map:
                setattr(cfg, f'run_{method_map[m]}', True)
            else:
                # A typo used to be dropped silently, running a different
                # method set than the one the user asked for.
                print(f"WARNING: unknown method '{m}' ignored (valid: {', '.join(sorted(method_map))})")

    def ensure_results_dir():
        # Created only when a run actually happens; parents=True so a missing
        # parent directory does not raise.
        RESULTS_PATH.mkdir(parents=True, exist_ok=True)

    if args.all:
        specs = [s for s in discover_circuits() if s.inputs <= args.max_inputs]
        print(f"\nRunning on {len(specs)} circuits...")
        if specs:
            ensure_results_dir()
        for spec in specs:
            try:
                circuit = AdaptiveCircuit(spec.path, cfg.device)
                run_all_methods(circuit, cfg)
                clear_vram()
            except Exception as e:
                print(f"ERROR on {spec.name}: {e}")
    elif args.circuit:
        path = CIRCUITS_PATH / args.circuit
        if not path.exists():
            path = CIRCUITS_PATH / f'threshold-{args.circuit}'
        if not path.exists():
            print(f"Circuit not found: {args.circuit}")
            return

        ensure_results_dir()
        circuit = AdaptiveCircuit(path, cfg.device, args.weights)
        results = run_all_methods(circuit, cfg)

        if args.save and results:
            # Filter by fitness BEFORE ranking: otherwise a low-magnitude but
            # failing result blocks saving a valid one.
            passing = [r for r in results.values() if r.fitness >= cfg.fitness_threshold]
            if passing:
                best = min(passing, key=lambda r: r.final_stats.get('magnitude', float('inf')))
                saved_path = circuit.save_weights(best.final_weights, f'pruned_{best.method}')
                print(f"\nSaved to: {saved_path}")
            else:
                print("\nNothing saved: no result meets the fitness threshold")
    else:
        parser.print_help()
        print("\n\nExamples:")
        print(" python prune.py --list")
        print(" python prune.py threshold-xor --methods evo")
        print(" python prune.py threshold-xor --methods exh --exhaustive-max-params 20")
        print(" python prune.py threshold-crc16-mag53 --methods comp")
        print(" python prune.py --all --max-inputs 8 --methods mag,zero")
|
| |
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
| |
|