""" prototypes.py — Build and manage the offline prototype library. The prototype library maps each (layer, head) pair to a set of K centroid attention-distribution vectors, learned via K-Means from the profiling corpus. At inference time, these centroids drive the O(n) scoring function without any query lookups. """ from __future__ import annotations import os import pickle import numpy as np from typing import Dict, Optional, List from sklearn.cluster import KMeans def build_prototypes( patterns: List[Dict], n_clusters: int = 4, max_seq_len: int = 512, random_state: int = 42, ) -> Dict: """ Cluster per-head attention patterns into prototype centroids. Args: patterns: Output of ``profile_model()`` — list of dicts mapping ``(layer, head) → np.ndarray`` of shape ``(seq_len,)``. n_clusters: Number of K-Means clusters per head (default 4). max_seq_len: Maximum sequence length to include in clustering. random_state: Random seed for reproducibility. Returns: prototypes: Dict mapping ``(layer, head) → {"centroids": np.ndarray}`` where centroids has shape ``(n_clusters, max_seq_len)``. """ if not patterns: raise ValueError("patterns list is empty. Run profile_model() first.") keys = sorted(patterns[0].keys()) prototypes = {} for (layer, head) in keys: data = np.array([ p[(layer, head)] for p in patterns if (layer, head) in p ]) # shape: (num_docs, max_seq_len) if len(data) == 0: continue k = min(n_clusters, len(data)) kmeans = KMeans(n_clusters=k, random_state=random_state, n_init=10) kmeans.fit(data) prototypes[(layer, head)] = { "centroids": kmeans.cluster_centers_.astype(np.float32), "labels": kmeans.labels_, "inertia": float(kmeans.inertia_), } return prototypes def save_prototypes(prototypes: Dict, path: str) -> None: """Serialize prototypes to disk.""" os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) with open(path, "wb") as f: pickle.dump(prototypes, f) print(f"[ProactiveCache] Prototypes saved to {path}") def load_prototypes(path: str) -> Dict: """Load prototypes from disk.""" if not os.path.exists(path): raise FileNotFoundError( f"Prototype file not found: {path}\n" "Run ProactiveCache.profile(model, ..., save_path='{path}') first." ) with open(path, "rb") as f: prototypes = pickle.load(f) print(f"[ProactiveCache] Loaded {len(prototypes)} prototypes from {path}") return prototypes def prototype_summary(prototypes: Dict) -> str: """Return a human-readable summary of a prototype library.""" num_pairs = len(prototypes) if num_pairs == 0: return "Empty prototype library." layers = sorted(set(layer for (layer, _) in prototypes)) heads_per_layer = sorted(set(head for (_, head) in prototypes)) sample_key = next(iter(prototypes)) n_clusters = prototypes[sample_key]["centroids"].shape[0] seq_len = prototypes[sample_key]["centroids"].shape[1] return ( f"ProactiveCache Prototype Library\n" f" Layers: {len(layers)} ({layers[0]}–{layers[-1]})\n" f" Heads per layer: {len(heads_per_layer)}\n" f" Total (L, H): {num_pairs}\n" f" Clusters/head: {n_clusters}\n" f" Profile seq_len: {seq_len}\n" )