| """ |
| Compute a Fixed Corpus Basis (FCB) for cross-document and |
| cross-model stable state vector extraction. |
| |
| The FCB is the principal subspace of the key manifold computed |
| from a diverse reference corpus. Unlike per-document SVD, |
| the FCB is document-independent — all documents projected |
| with the same FCB exist in the same coordinate system. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import gc |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| from llama_cpp import Llama |
|
|
| from kvcos.core.blob_parser import parse_state_blob |
| from kvcos.core.state_extractor import MARStateExtractor |
| from scripts.generate_alignment_dataset import DOCUMENTS |
|
|
|
|
| def main() -> int: |
| parser = argparse.ArgumentParser(description="Compute Fixed Corpus Basis") |
| parser.add_argument("--model", required=True) |
| parser.add_argument("--layer-range", type=int, nargs=2, default=[8, 24]) |
| parser.add_argument("--gate-start", type=int, default=6) |
| parser.add_argument("--rank", type=int, default=122) |
| parser.add_argument("--output", required=True) |
| args = parser.parse_args() |
|
|
| llm = Llama(model_path=args.model, n_ctx=2048, n_gpu_layers=-1, verbose=False) |
| meta = llm.metadata |
| n_kv = int(meta.get("llama.attention.head_count_kv", "8")) |
| head_dim = int(meta.get("llama.embedding_length", "4096")) // int( |
| meta.get("llama.attention.head_count", "32") |
| ) |
| model_name = meta.get("general.name", "unknown") |
|
|
| print(f"Model: {model_name} ({n_kv} KV heads, {head_dim} head_dim)") |
| print(f"Layer range: {args.layer_range}, gate_start: {args.gate_start}") |
| print(f"Collecting key tensors from {len(DOCUMENTS)} documents...") |
|
|
| key_tensors: list[torch.Tensor] = [] |
| for i, doc in enumerate(DOCUMENTS): |
| llm.reset() |
| llm(doc.strip(), max_tokens=1, temperature=0.0) |
| s = llm.save_state() |
| parsed = parse_state_blob( |
| bytes(s.llama_state), n_kv_heads=n_kv, head_dim=head_dim |
| ) |
| key_tensors.append(parsed.keys) |
| if (i + 1) % 10 == 0: |
| print(f" {i + 1}/{len(DOCUMENTS)}") |
| del llm |
| gc.collect() |
|
|
| print("Computing corpus SVD...") |
| basis = MARStateExtractor.compute_corpus_basis( |
| key_tensors=key_tensors, |
| layer_range=tuple(args.layer_range), |
| gate_start=args.gate_start, |
| rank=args.rank, |
| ) |
|
|
| output_path = Path(args.output) |
| output_path.parent.mkdir(parents=True, exist_ok=True) |
| torch.save( |
| { |
| "basis": basis, |
| "model_name": model_name, |
| "layer_range": args.layer_range, |
| "gate_start": args.gate_start, |
| "rank": args.rank, |
| "n_corpus_docs": len(DOCUMENTS), |
| "key_tensors": key_tensors, |
| }, |
| str(output_path), |
| ) |
|
|
| print(f"Basis shape: {basis.shape}") |
| print(f"Saved: {output_path}") |
| return 0 |
|
|
|
|
| if __name__ == "__main__": |
| sys.exit(main()) |
|
|