initial public release: code, README, KNOWN_ISSUES
- KNOWN_ISSUES.md +48 -0
- README.md +114 -0
- quantize_model_v2.py +492 -0
KNOWN_ISSUES.md
ADDED
@@ -0,0 +1,48 @@
# Known issues — tritllm-codec

Surfaced during a pre-release code review. None affect the published paper numbers (those were obtained with the codec exactly as shipped here), but anyone modifying the codec or extending it to new model families should be aware.

## SHOULD-FIX

### Scale selection is locally optimal, not globally
**Where:** [`quantize_model_v2.py`, `compute_optimal_scale()`](quantize_model_v2.py#L68)

`compute_optimal_scale()` searches only four order-statistic candidates `[G-6, G-4, G-2, G-1]` of the sorted absolute weights. For heavy-tailed or bimodal groups, the true MSE-optimal scale can fall outside this set.

**Impact:** measured at <1% PPL on Qwen2.5-7B (vs an exhaustive sweep). The 4-candidate set is a deliberate compute/quality trade-off, not a bug — but the function name and comments overstate what it guarantees. Read it as "MSE-best among 4 fixed candidates," not "global MSE optimum."
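The gap can be bounded empirically. A minimal sketch (synthetic heavy-tailed groups, linear levels with no power mapping, so simpler than the shipped codec) comparing the 4-candidate search against an exhaustive sweep over all G order statistics:

```python
import numpy as np

def quantize_group(g, scale, half=4):
    # Round-to-nearest onto integer levels [-half, half]; half=4 matches d=2.
    return np.clip(np.round(g / scale), -half, half) * scale

def best_mse(g, candidate_idx, half=4):
    # MSE of the best scale among the given order-statistic candidates.
    sa = np.sort(np.abs(g))
    return min(np.mean((g - quantize_group(g, max(sa[k] / half, 1e-30), half)) ** 2)
               for k in candidate_idx)

rng = np.random.default_rng(0)
G = 16
groups = rng.standard_t(df=2, size=(2000, G))   # heavy-tailed weight groups

four = [G - 6, G - 4, G - 2, G - 1]             # shipped candidate set
full = range(G)                                  # exhaustive sweep
mse4 = np.mean([best_mse(g, four) for g in groups])
mseN = np.mean([best_mse(g, full) for g in groups])
print(f"excess MSE of 4-candidate search: {mse4 / mseN - 1:.1%}")
```

Since the four candidates are a subset of the exhaustive sweep, `mse4 >= mseN` always; the printed ratio is the price of the shortcut on this distribution.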

### Checkpoint resume does not validate source-model fingerprint
**Where:** [`quantize_model_v2.py`, the resume path around L328, L357, L388](quantize_model_v2.py#L328)

The resume logic checks only that `_meta.configs` contains the requested config names. It does not validate:
- The source model ID (and revision)
- Each tensor's shape
- The codec version
- The depth-power mapping

**Impact:** if you re-run quantization into a stale output directory with a different source model (or after editing the codec), assembled matrices may silently mix old and new versions. Always use a fresh `--out` per source model.

**Workaround:** delete the output directory before each new quantization run, or wrap the script with a check that the source-model SHA matches.
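A minimal pre-flight guard (hypothetical wrapper, names are ours) that records a fingerprint on first use and refuses a stale directory:

```python
import json, os, sys, tempfile

def guard_output_dir(out_root, model_id, revision, codec_version):
    """Pre-flight check, not part of the shipped script: record a fingerprint
    on first use, refuse the output dir if a later run's fingerprint differs."""
    fp = {'model': model_id, 'revision': revision, 'codec': codec_version}
    path = os.path.join(out_root, 'fingerprint.json')
    if os.path.exists(path):
        with open(path) as f:
            old = json.load(f)
        if old != fp:
            sys.exit(f'fingerprint mismatch: {old} != {fp}; use a fresh --out')
    else:
        os.makedirs(out_root, exist_ok=True)
        with open(path, 'w') as f:
            json.dump(fp, f)

demo = tempfile.mkdtemp()
guard_output_dir(demo, 'Qwen/Qwen2.5-7B', 'main', 'v2')   # records fingerprint
guard_output_dir(demo, 'Qwen/Qwen2.5-7B', 'main', 'v2')   # same fingerprint: passes
```

Call it before invoking `quantize_model_v2.py`; a third call with a different model ID exits instead of letting checkpoints mix.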

### `from_pretrained` does not pin a revision
**Where:** [`quantize_model_v2.py`, `AutoModelForCausalLM.from_pretrained(args.model)` around L295](quantize_model_v2.py#L295)

Source models are loaded by HF repo ID without `revision=...`, so if Qwen, Llama, or Mistral push a new commit upstream, re-running the codec will produce different trits.

**Impact:** the checkpoints on `Entrit/...` were quantized at the time of paper submission against the then-current upstream weights. Future reruns are not guaranteed bit-identical.

**Workaround:** add `--revision <sha>` to your wrapping script, or pin in `quantize_model_v2.py` directly if exact reproducibility across time matters.
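A sketch of the pinned variant, assuming a new `--revision` flag that the shipped script does not have (`revision=` itself is a real `from_pretrained` parameter):

```python
import argparse

def build_parser():
    # Parser extended with a --revision flag (our addition, not in the shipped script).
    p = argparse.ArgumentParser()
    p.add_argument('--model', required=True)
    p.add_argument('--revision', default=None,
                   help='Hub commit SHA, tag, or branch for reproducible reruns')
    return p

def load_source_model(args, dtype='float16'):
    # Deferred import: torch/transformers are heavy and not needed to parse args.
    import torch
    from transformers import AutoModelForCausalLM
    return AutoModelForCausalLM.from_pretrained(
        args.model,
        revision=args.revision,   # None falls back to the default branch head
        torch_dtype=torch.bfloat16 if dtype == 'bfloat16' else torch.float16,
        device_map='cpu', trust_remote_code=True, low_cpu_mem_usage=True)

args = build_parser().parse_args(['--model', 'Qwen/Qwen2.5-7B', '--revision', 'abc123'])
print(args.revision)
```

Pinning to a full commit SHA (rather than a branch name) is what actually freezes the weights over time.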

## NIT

### Scale codebook range
**Where:** [`quantize_model_v2.py`, `log_max = np.max(...)` around L89](quantize_model_v2.py#L89)

Setting `log_max = np.max(group_abs_maxes)` (replacing the older 99.9th-percentile heuristic, see commit `0c16d24`) fixes the upper-clipping problem cleanly. The downside: a single extreme-scale group can stretch the 27-entry log codebook for the whole matrix and reduce scale resolution for the bulk of groups.

**Impact:** not observed in practice on Qwen, Llama, or Mistral. Theoretical risk for matrices with bimodal or extreme outlier scale distributions. If you see unexpectedly high PPL on a new model family, this is a place to look.
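The stretching effect is easy to demonstrate with a standalone mock of `trit_quantize_scales` (synthetic log-normal scales plus one planted outlier; the helper name is ours):

```python
import numpy as np

def quantize_scales(scales, n_levels=27):
    # Log-spaced codebook over [0.1th percentile, max], as in trit_quantize_scales.
    log_s = np.log(np.maximum(scales, 1e-30))
    lo, hi = np.percentile(log_s, 0.1), np.max(log_s)
    codebook = np.linspace(lo, hi, n_levels)
    idx = np.argmin(np.abs(log_s[:, None] - codebook[None, :]), axis=1)
    return np.exp(codebook[idx])

rng = np.random.default_rng(0)
bulk = np.exp(rng.normal(0.0, 0.3, size=4096))   # typical per-group scales
outlier = np.array([np.exp(8.0)])                # one extreme-scale group

# Mean |log error| on the bulk, with and without the outlier in the matrix.
err_bulk = np.mean(np.abs(np.log(quantize_scales(bulk)) - np.log(bulk)))
stretched = np.concatenate([bulk, outlier])
err_stretched = np.mean(np.abs(np.log(quantize_scales(stretched)[:-1]) - np.log(bulk)))
print(err_bulk, err_stretched)
```

The single outlier widens the codebook span by several nats, so the spacing seen by the 4096 ordinary groups grows accordingly.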

### Header docstring mentions `fsync` but the code only `rename`s
**Where:** [`quantize_model_v2.py`, header comment](quantize_model_v2.py#L9)

Doc/code mismatch. Either remove the `fsync` claim from the docstring or add an `os.fsync(fd)` before the rename. No runtime impact — checkpoints are atomic via rename, just not durable across power loss.
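If durability matters for your runs, the fsync variant the docstring describes looks like this (a sketch, not the shipped `atomic_save_npz`; POSIX semantics assumed):

```python
import os, tempfile

def durable_atomic_write(path, payload: bytes):
    """Write-to-tmp, fsync, rename: atomic AND durable across power loss."""
    d = os.path.dirname(os.path.abspath(path))
    fd, tmp = tempfile.mkstemp(prefix='.tmp_', dir=d)
    try:
        os.write(fd, payload)
        os.fsync(fd)              # flush file data to disk before the rename
    finally:
        os.close(fd)
    os.replace(tmp, path)         # atomic on POSIX
    dirfd = os.open(d, os.O_RDONLY)
    try:
        os.fsync(dirfd)           # persist the directory entry as well
    finally:
        os.close(dirfd)
```

The directory fsync is the commonly forgotten half: without it the rename itself may be lost on power failure even though the file data was flushed.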
README.md
ADDED
@@ -0,0 +1,114 @@
---
license: apache-2.0
tags:
- quantization
- ternary
- llm
- post-training-quantization
library_name: transformers
---

# tritllm-codec

Reference implementation of the balanced ternary post-training quantization codec from
**"Balanced Ternary Post-Training Quantization for Large Language Models"** (Stentzel, 2026).

Quantizes FP16 LLM weights to balanced ternary at configurable depth `d ∈ {1, 2, 3, 4}` (3, 9, 27, 81 levels per weight) with no calibration data and no per-model tuning. Output is dequantized FP16 safetensors that load into stock `transformers` and `lm-eval` without a custom loader.

## What gets quantized

The codec quantizes all 2D linear weight matrices in a model. **The following are kept in FP16 and not counted in the BPW total:**

- `lm_head` (output projection)
- Token embeddings (`embed_tokens`)
- All `*_norm` layers (RMSNorm, LayerNorm — these are 1D anyway)

This is the standard convention in quantization papers (see GPTQ, AWQ, NF4) and reflects the fact that embedding lookups and the final classifier are not GEMV-bound at inference time. Throughout the paper, "BPW" refers to the average bits-per-weight of the quantized matrices only.
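The selection rule reduces to a simple name-and-shape filter, mirroring the matrix-collection loop in `quantize_model_v2.py`:

```python
def should_quantize(param_name: str, ndim: int) -> bool:
    """Mirror of the matrix-collection filter in quantize_model_v2.py:
    only 2D weights, excluding norms, embeddings, and the lm_head."""
    if ndim != 2:
        return False
    return not any(tok in param_name for tok in ('norm', 'embed', 'lm_head'))

print(should_quantize('model.layers.0.self_attn.q_proj.weight', 2))  # True
print(should_quantize('model.embed_tokens.weight', 2))               # False
print(should_quantize('lm_head.weight', 2))                          # False
```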

## Install

```bash
pip install torch transformers safetensors numpy huggingface_hub
git clone https://huggingface.co/Entrit/tritllm-codec
cd tritllm-codec
```

## Quick start

```bash
# Quantize Qwen2.5-7B at uniform depth d=2 (3.47 bpw)
python quantize_model_v2.py \
    --model Qwen/Qwen2.5-7B \
    --configs uniform-d2 \
    --out ./out

# Multi-config single pass (computes scales once, derives 6 configs)
python quantize_model_v2.py \
    --model Qwen/Qwen2.5-7B \
    --configs uniform-d1,uniform-d2,uniform-d3,uniform-d4,d3scale-sens002,d3scale-sens003 \
    --out ./out
```

The output directory contains one HF-loadable model per config:

```
out/
  uniform-d2/
    model/
      config.json
      model.safetensors   # dequantized FP16
      tokenizer.json
      ...
```

Load like any HF model:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
m = AutoModelForCausalLM.from_pretrained("./out/uniform-d2/model")
t = AutoTokenizer.from_pretrained("./out/uniform-d2/model")
```

## Settled design (don't change unless reproducing an ablation)

| Parameter | Value | Notes |
|---|---|---|
| Group size `G` | 16 | Per Section 6.1 of the paper, gs=64 is also viable; gs=16 gives best PPL |
| Scale depth `d_s` | 3 | 27-entry log-spaced codebook per matrix |
| Power mapping | d1=1.0, d2=1.5, d3=1.2, d4=1.0 | Tuned once on Qwen2.5-7B, held fixed for all subsequent models |
| Scale candidates | indices `[G-6, G-4, G-2, G-1]` of sorted `\|w\|` | MSE-minimum over the 4 candidates is selected per group |
| Scale codebook range | `log_min` = 0.1th percentile of group `\|w\|`-maxes, `log_max` = max | Fixed in commit `0c16d24` (was 99.9th percentile, which clipped) |
| `lm_head`, embeddings, norms | kept FP16 | See "What gets quantized" above |

## BPW calculation

```
bpw = d * log2(3) + d_s * log2(3) / G   # weights + scales only
    = d * 1.585 + 0.297                 # for G=16, d_s=3
```

Resulting BPW: d1=1.88, d2=3.47, d3=5.05, d4=6.64.
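The formula is easy to check in a few lines:

```python
from math import log2

def bpw(d, d_s=3, G=16):
    # d trits per weight, plus d_s trits of scale amortized over a group of G.
    return d * log2(3) + d_s * log2(3) / G

for d in (1, 2, 3, 4):
    print(f"d={d}: {bpw(d):.2f} bpw")
```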

## Known limitations

Documented in [KNOWN_ISSUES.md](KNOWN_ISSUES.md):

- **Scale selection is locally MSE-optimal among 4 candidates**, not globally optimal. For heavy-tailed groups the global optimum can fall outside `[G-6, G-4, G-2, G-1]`. In practice the gap is small (<1% PPL on Qwen2.5-7B).
- **Checkpoint resume does not validate model fingerprint.** If you re-run quantization into a stale output directory with a different source model, the assembled checkpoint may mix matrices. Always use a fresh `--out` per source model.
- **`from_pretrained` is not pinned to a revision.** If the source model on the Hub is updated upstream, re-running the codec will produce different trits. Pin a `--revision` if exact reproducibility across time matters.
- **One outlier group can stretch the 27-entry log scale codebook** for the whole matrix. We have not seen this cause measurable quality loss on Qwen, Llama, or Mistral, but it is a theoretical risk for skewed scale distributions.

## Citation

```
@article{stentzel2026ternaryptq,
  title  = {Balanced Ternary Post-Training Quantization for Large Language Models},
  author = {Stentzel, Eric},
  year   = 2026,
  note   = {Entrit Systems}
}
```

## Models quantized with this codec

See the [Entrit organization page](https://huggingface.co/Entrit) for prequantized model checkpoints across Qwen2.5 (0.5B–72B), Llama-3.1-8B, and Mistral-7B at depths d=1 through d=4.
quantize_model_v2.py
ADDED
@@ -0,0 +1,492 @@
"""Ternary quantizer v2 — multi-config single-pass with per-matrix checkpointing.

Key improvements over v1:
1. Multi-config: --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3
   Computes per-group MSE-optimal scales ONCE per matrix, derives all configs.
   ~3x faster than running v1 four times.
2. Per-matrix checkpoint: each matrix's quantized output saved to .checkpoint/
   dir as soon as it's done. Crash-resume picks up where it left off.
3. Atomic writes (write to .tmp, fsync, rename) — no half-written checkpoints.
4. Streaming progress.json — monitors can poll without parsing logs.
5. Per-config HF model assembled at the end from checkpoints.

Usage:
    python quantize_model_v2.py --model /path/to/model \
        --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3 \
        --output /path/to/output_root \
        --workers 8 --dtype float16

Output structure:
    output_root/
        .checkpoint/
            matrix_00000__model.layers.0.self_attn.q_proj.npz  # all configs in one file
            matrix_00001__model.layers.0.self_attn.k_proj.npz
            ...
        progress.json          # live status
        d3scale-sens002/
            model/             # HF-format output
                config.json
        d3scale-sens003/
            model/
                config.json
        uniform-d2/
            model/
                config.json
        uniform-d3/
            model/
                config.json
"""
import os, sys, time, json, gc, argparse, tempfile
from multiprocessing import Pool
import numpy as np

# ============================================================
# CODEC CORE (unchanged from v1)
# ============================================================
GS = 16
DEPTH_POWERS = {1: 1.0, 2: 1.5, 3: 1.2, 4: 1.0}

def build_levels(half, power):
    int_levels = np.arange(-half, half + 1).astype(np.float64)
    n = int_levels / max(half, 1)
    if power != 1.0:
        return np.sign(n) * np.abs(n) ** power * max(half, 1)
    return int_levels

def make_boundaries(level_map, zero_boundary=None):
    """Default = midpoints between levels. If zero_boundary given, override the
    boundaries straddling 0 (used for d1 with custom zero-zone width)."""
    boundaries = (level_map[:-1] + level_map[1:]) / 2
    if zero_boundary is not None:
        zero_idx = int(np.argmin(np.abs(level_map)))
        if zero_idx > 0:
            boundaries[zero_idx - 1] = -abs(zero_boundary)
        if zero_idx < len(level_map) - 1:
            boundaries[zero_idx] = abs(zero_boundary)
    return boundaries

def compute_optimal_scale(groups, depth, power, zero_boundary=None):
    half = (3 ** depth) // 2
    gs = groups.shape[1]
    sa = np.sort(np.abs(groups), axis=1)
    cand_idx = np.clip(np.array([gs-6, gs-4, gs-2, gs-1]), 0, gs-1)
    level_map = build_levels(half, power)
    boundaries = make_boundaries(level_map, zero_boundary)
    N = len(groups)
    best_scale = np.zeros(N); best_mse = np.full(N, np.inf)
    for ki in cand_idx:
        scales = np.maximum(sa[:, ki] / max(half, 1), 1e-30)
        normalized = groups / scales[:, None]
        idx = np.searchsorted(boundaries, normalized.ravel())
        idx = np.clip(idx, 0, len(level_map) - 1)
        q = level_map[idx].reshape(N, gs)
        recon = q * scales[:, None]
        mse = np.mean((groups - recon) ** 2, axis=1)
        better = mse < best_mse
        best_mse[better] = mse[better]; best_scale[better] = scales[better]
    return best_scale, best_mse

def trit_quantize_scales(scales, sd):
    log_scales = np.log(np.maximum(scales, 1e-30))
    half = (3 ** sd) // 2
    n_levels = 2 * half + 1
    log_min = np.percentile(log_scales, 0.1)
    log_max = np.max(log_scales)  # 100th pct — never clip large scales
    if log_max - log_min < 1e-9:
        log_max = log_min + 1e-9
    codebook_log = np.linspace(log_min, log_max, n_levels)
    idx = np.argmin(np.abs(log_scales[:, None] - codebook_log[None, :]), axis=1)
    return np.exp(codebook_log[idx])

def quantize_with_scale(groups, scale, depth, power, zero_boundary=None):
    half = (3 ** depth) // 2
    level_map = build_levels(half, power)
    boundaries = make_boundaries(level_map, zero_boundary)
    scale = np.maximum(scale, 1e-30)
    normalized = groups / scale[:, None]
    idx = np.searchsorted(boundaries, normalized.ravel())
    idx = np.clip(idx, 0, len(level_map) - 1)
    q = level_map[idx].reshape(groups.shape)
    return q * scale[:, None]

# ============================================================
# CODEC CONFIGS
# ============================================================
CODECS = {
    'd3scale-sens002': {'mode': 'adaptive', 'scale_depth': 3, 'threshold': 0.002},
    'd3scale-sens003': {'mode': 'adaptive', 'scale_depth': 3, 'threshold': 0.003},
    # d1 with narrow zero zone (zw=0.25): 3 levels {-1,0,+1}, zero only when |w|<0.25*scale.
    # Old default was zw=0.5 which made 97.5% of weights round to 0 (random-chance MMLU).
    'uniform-d1': {'mode': 'uniform', 'scale_depth': 3, 'depth': 1, 'zero_boundary': 0.25},
    'uniform-d2': {'mode': 'uniform', 'scale_depth': 3, 'depth': 2},
    'uniform-d3': {'mode': 'uniform', 'scale_depth': 3, 'depth': 3},
    'uniform-d4': {'mode': 'uniform', 'scale_depth': 3, 'depth': 4},
}

# ============================================================
# MULTI-CONFIG MATRIX QUANTIZATION
# ============================================================
def quantize_matrix_multi(args):
    """Quantize one matrix for ALL requested configs in a single pass.
    Returns dict: config_name -> (recon_w, depth_counts, weight_bits, scale_bits, n_groups)
    """
    w_flat, rows, cols, config_names = args
    w = w_flat.reshape(rows, cols)
    pad = (GS - cols % GS) % GS
    if pad > 0:
        w = np.pad(w, ((0, 0), (0, pad)))
    groups = w.reshape(-1, GS).astype(np.float64)
    N = len(groups)
    group_var = np.maximum(np.var(groups, axis=1), 1e-30)

    # Precompute optimal scale + MSE for every (depth, zero_boundary) combo used.
    # Adaptive uses default boundaries for d2/d3/d4; uniform configs may override (e.g. d1 zw=0.25).
    needed_keys = set()  # (depth, zero_boundary)
    for cn in config_names:
        cfg = CODECS[cn]
        if cfg['mode'] == 'adaptive':
            for d in (2, 3, 4):
                needed_keys.add((d, None))
        else:
            needed_keys.add((cfg['depth'], cfg.get('zero_boundary')))

    scales_per_key = {}
    mse_per_key = {}
    recon_per_key = {}
    for d, zb in sorted(needed_keys, key=lambda x: (x[0], x[1] or 0)):
        power = DEPTH_POWERS[d]
        opt_s, _ = compute_optimal_scale(groups, d, power, zero_boundary=zb)
        use_s = trit_quantize_scales(opt_s, 3)
        r = quantize_with_scale(groups, use_s, d, power, zero_boundary=zb)
        mse = np.mean((groups - r) ** 2, axis=1)
        scales_per_key[(d, zb)] = use_s
        mse_per_key[(d, zb)] = mse
        recon_per_key[(d, zb)] = r

    out = {}
    for cn in config_names:
        cfg = CODECS[cn]
        if cfg['mode'] == 'uniform':
            d = cfg['depth']
            zb = cfg.get('zero_boundary')
            recon = recon_per_key[(d, zb)]
            depth_counts = {1: 0, 2: 0, 3: 0, 4: 0}
            depth_counts[d] = N
            wb = N * GS * d * np.log2(3)
            sb = N * cfg['scale_depth'] * np.log2(3)
        else:  # adaptive
            eff_thresh = cfg['threshold'] * 5.5
            recon = np.zeros_like(groups)
            assigned = np.zeros(N, dtype=bool)
            depth_counts = {1: 0, 2: 0, 3: 0, 4: 0}
            wb = 0.0; sb = 0.0
            for d in [2, 3, 4]:
                unassigned = ~assigned
                if not np.any(unassigned):
                    break
                if d == 4:
                    recon[unassigned] = recon_per_key[(4, None)][unassigned]
                    n_d = int(np.sum(unassigned))
                    depth_counts[d] = n_d
                    wb += n_d * GS * d * np.log2(3)
                    sb += n_d * cfg['scale_depth'] * np.log2(3)
                    break
                mse_d = mse_per_key[(d, None)][unassigned]
                meets = (mse_d / group_var[unassigned]) < eff_thresh
                uidx = np.where(unassigned)[0]
                midx = uidx[meets]
                recon[midx] = recon_per_key[(d, None)][midx]
                assigned[midx] = True
                n_d = int(np.sum(meets))
                depth_counts[d] = n_d
                wb += n_d * GS * d * np.log2(3)
                sb += n_d * cfg['scale_depth'] * np.log2(3)

        recon_w = recon.reshape(rows, -1)[:, :cols].astype(np.float32)
        out[cn] = {
            'recon_w': recon_w,
            'depth_counts': depth_counts,
            'weight_bits': float(wb),
            'scale_bits': float(sb),
            'n_groups': N,
        }
    return out

# ============================================================
# CHECKPOINTING
# ============================================================
def matrix_ckpt_path(ckpt_dir, idx, name):
    safe = name.replace('/', '__').replace('.', '_')
    return os.path.join(ckpt_dir, f'matrix_{idx:05d}__{safe}.npz')

def atomic_save_npz(path, data):
    # NOTE: np.savez_compressed silently appends '.npz' if missing.
    # So we name the tmp file with .npz suffix to avoid surprise.
    fd, tmp = tempfile.mkstemp(prefix='.tmp_', suffix='.npz', dir=os.path.dirname(path))
    os.close(fd)
    np.savez_compressed(tmp, **data)
    os.replace(tmp, path)

def load_ckpt(path):
    with np.load(path, allow_pickle=True) as z:
        return {k: z[k] for k in z.files}

def write_progress(out_root, state):
    path = os.path.join(out_root, 'progress.json')
    fd, tmp = tempfile.mkstemp(prefix='.tmp_', dir=out_root)
    with os.fdopen(fd, 'w') as f:
        json.dump(state, f, indent=2)
    os.replace(tmp, path)

# ============================================================
# MAIN
# ============================================================
def main():
    parser = argparse.ArgumentParser(description='Multi-config ternary quantizer with checkpointing')
    parser.add_argument('--model', required=True)
    parser.add_argument('--configs', required=True,
                        help='Comma-separated codec names: ' + ','.join(CODECS.keys()))
    parser.add_argument('--output', required=True, help='Output root dir')
    parser.add_argument('--workers', type=int, default=1)
    parser.add_argument('--dtype', default='float16', choices=['float16', 'bfloat16'])
    parser.add_argument('--skip-assembly', action='store_true',
                        help='Quantize matrices and checkpoint only; skip final HF model assembly.')
    parser.add_argument('--matrix-range', default=None,
                        help='Slice of matrices to process: "start:end" (0-indexed, end exclusive). '
                             'Use to manually parallelize across processes/machines via shared checkpoint dir.')
    args = parser.parse_args()

    config_names = [c.strip() for c in args.configs.split(',') if c.strip()]
    for cn in config_names:
        if cn not in CODECS:
            print(f'ERROR: unknown codec {cn}', file=sys.stderr); sys.exit(2)

    os.makedirs(args.output, exist_ok=True)
    ckpt_dir = os.path.join(args.output, '.checkpoint')
    os.makedirs(ckpt_dir, exist_ok=True)

    print(f'=== Quantizing {args.model} ===', flush=True)
    print(f'  configs: {config_names}', flush=True)
    print(f'  workers: {args.workers}', flush=True)

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel

    dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16
    print('  loading model (CPU)...', flush=True)
    t_load = time.time()
    _cfg = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
    _arch = ((getattr(_cfg, 'architectures', None) or [''])[0] or '').lower()
    if 't5' in _arch or 'encoder' in _arch:
        from transformers import T5EncoderModel
        print('  loading as T5EncoderModel (encoder-only)', flush=True)
        model = T5EncoderModel.from_pretrained(args.model, torch_dtype=dtype,
                                               device_map='cpu', trust_remote_code=True,
                                               low_cpu_mem_usage=True)
    else:
        try:
            model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype,
                                                         device_map='cpu', trust_remote_code=True,
                                                         low_cpu_mem_usage=True)
        except ValueError:
            print('  fallback to generic AutoModel', flush=True)
            model = AutoModel.from_pretrained(args.model, torch_dtype=dtype,
                                              device_map='cpu', trust_remote_code=True,
                                              low_cpu_mem_usage=True)
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        print(f'  tokenizer load failed (ok for encoder-only): {e}', flush=True)
        tokenizer = None
    print(f'  loaded in {time.time()-t_load:.0f}s', flush=True)

    # Collect matrices to quantize (skip embeddings, norms, lm_head)
    matrices = []
    for pn, p in model.named_parameters():
        if p.dim() != 2 or 'norm' in pn or 'embed' in pn or 'lm_head' in pn:
            continue
        matrices.append((pn, p))
    print(f'  {len(matrices)} matrices to quantize', flush=True)

    # Apply --matrix-range slice (for parallel sharded processing)
    range_start, range_end = 0, len(matrices)
    if args.matrix_range:
        s, e = args.matrix_range.split(':')
        range_start = int(s) if s else 0
        range_end = int(e) if e else len(matrices)
        range_end = min(range_end, len(matrices))
        print(f'  matrix-range: [{range_start}:{range_end})', flush=True)

    # Determine which need work (resume from checkpoints)
    todo = []
    done_count = 0
    for idx, (pn, p) in enumerate(matrices):
        if idx < range_start or idx >= range_end:
            continue
        cp = matrix_ckpt_path(ckpt_dir, idx, pn)
        if os.path.exists(cp):
            # Validate it has all requested configs
            try:
                with np.load(cp, allow_pickle=True) as z:
                    have_configs = set(json.loads(str(z['_meta'][()])).get('configs', []))
                if all(cn in have_configs for cn in config_names):
                    done_count += 1
                    continue
            except Exception as e:
                print(f'  bad checkpoint {cp}: {e}, will redo', flush=True)
                os.remove(cp)
        todo.append((idx, pn, p))
    print(f'  {done_count} matrices already checkpointed, {len(todo)} to do', flush=True)

    t0 = time.time()
    state = {
        'model': args.model, 'configs': config_names,
        'total_matrices': len(matrices),
        'done_matrices': done_count,
        'started_at': t0, 'updated_at': t0,
    }
    write_progress(args.output, state)
|
| 351 |
+
|
| 352 |
+
def process_one(idx, pn, p):
|
| 353 |
+
w = p.data.float().numpy()
|
| 354 |
+
result = quantize_matrix_multi(
|
| 355 |
+
(w.ravel(), w.shape[0], w.shape[1], config_names))
|
| 356 |
+
# Pack into npz: one key per config + meta
|
| 357 |
+
save_data = {'_meta': np.array(json.dumps({
|
| 358 |
+
'name': pn, 'idx': idx, 'shape': list(w.shape),
|
| 359 |
+
'configs': config_names,
|
| 360 |
+
}))}
|
| 361 |
+
for cn, info in result.items():
|
| 362 |
+
save_data[f'{cn}__w'] = info['recon_w']
|
| 363 |
+
save_data[f'{cn}__stats'] = np.array(json.dumps({
|
| 364 |
+
'depth_counts': info['depth_counts'],
|
| 365 |
+
'weight_bits': info['weight_bits'],
|
| 366 |
+
'scale_bits': info['scale_bits'],
|
| 367 |
+
'n_groups': info['n_groups'],
|
| 368 |
+
}))
|
| 369 |
+
atomic_save_npz(matrix_ckpt_path(ckpt_dir, idx, pn), save_data)
|
| 370 |
+
return idx
|
| 371 |
+
|
| 372 |
+
    if args.workers > 1 and len(todo) > 1:
        # Streaming generator: yield (matrix, config_names) one at a time.
        # CRITICAL: do NOT pre-build all matrices in a list — for large models
        # (Llama 70B = 140GB) that OOMs the box at multiple hundred GB. The generator
        # is consumed lazily by Pool.imap.
        idx_name = [(idx, pn) for idx, pn, _ in todo]
        def gen():
            for idx, pn, p in todo:
                w = p.data.float().numpy()
                yield (w.ravel(), w.shape[0], w.shape[1], config_names)
                # Free the source tensor after we've handed off the numpy view.
                # The Pool worker has its own copy via pickle.
                p.data = __import__('torch').zeros(1, dtype=p.dtype)
        with Pool(args.workers) as pool:
            for i, result in enumerate(pool.imap(quantize_matrix_multi, gen(), chunksize=1)):
                idx, pn = idx_name[i]
                save_data = {'_meta': np.array(json.dumps({
                    'name': pn, 'idx': idx,
                    'configs': config_names,
                }))}
                for cn, info in result.items():
                    save_data[f'{cn}__w'] = info['recon_w']
                    save_data[f'{cn}__stats'] = np.array(json.dumps({
                        'depth_counts': info['depth_counts'],
                        'weight_bits': info['weight_bits'],
                        'scale_bits': info['scale_bits'],
                        'n_groups': info['n_groups'],
                    }))
                atomic_save_npz(matrix_ckpt_path(ckpt_dir, idx, pn), save_data)
                done_count += 1
                state['done_matrices'] = done_count
                state['updated_at'] = time.time()
                state['elapsed_s'] = time.time() - t0
                if (i+1) % 5 == 0 or (i+1) == len(todo):
                    write_progress(args.output, state)
                    eta = (len(todo) - (i+1)) * (time.time() - t0) / max(i+1, 1)
                    print(f' {done_count}/{len(matrices)} ({time.time()-t0:.0f}s, ETA {eta:.0f}s)', flush=True)
    else:
        for i, (idx, pn, p) in enumerate(todo):
            process_one(idx, pn, p)
            done_count += 1
            state['done_matrices'] = done_count
            state['updated_at'] = time.time()
            state['elapsed_s'] = time.time() - t0
            if (i+1) % 5 == 0 or (i+1) == len(todo):
                write_progress(args.output, state)
                eta = (len(todo) - (i+1)) * (time.time() - t0) / max(i+1, 1)
                print(f' {done_count}/{len(matrices)} ({time.time()-t0:.0f}s, ETA {eta:.0f}s)', flush=True)

    print(f' Quantization complete in {time.time()-t0:.0f}s', flush=True)

    # If we processed only a slice, don't assemble — leave that for the merge step.
    if args.matrix_range:
        # Verify which checkpoints exist for this slice; print summary
        slice_done = sum(1 for idx, (pn, p) in enumerate(matrices)
                         if range_start <= idx < range_end
                         and os.path.exists(matrix_ckpt_path(ckpt_dir, idx, pn)))
        print(f' slice [{range_start}:{range_end}): {slice_done} checkpointed', flush=True)
        return

    if args.skip_assembly:
        print(' --skip-assembly: not building HF model dirs', flush=True)
        return

    # ============================================================
    # ASSEMBLY: load each config from checkpoints, write HF model
    # ============================================================
    print(' Assembling HF models per config...', flush=True)
    for cn in config_names:
        cfg_dir = os.path.join(args.output, cn)
        os.makedirs(cfg_dir, exist_ok=True)
        model_dir = os.path.join(cfg_dir, 'model')

        # Aggregate stats
        total_groups = 0
        total_depth = {1: 0, 2: 0, 3: 0, 4: 0}
        total_wb = 0.0; total_sb = 0.0

        # Replace tensors in-place with this config's reconstruction
        name_to_param = {pn: p for pn, p in matrices}
        for idx, (pn, p) in enumerate(matrices):
            cp = matrix_ckpt_path(ckpt_dir, idx, pn)
            z = np.load(cp, allow_pickle=True)
            recon_w = z[f'{cn}__w']
            stats = json.loads(str(z[f'{cn}__stats'][()]))
            p.data = __import__('torch').from_numpy(recon_w).to(p.dtype)
            total_groups += stats['n_groups']
            for d in [1, 2, 3, 4]:
                total_depth[d] += stats['depth_counts'].get(str(d), stats['depth_counts'].get(d, 0))
            total_wb += stats['weight_bits']
            total_sb += stats['scale_bits']

        tg = max(total_groups, 1)
        trit_bpw = total_wb / (tg * GS)
        scale_bpw = total_sb / (tg * GS)
        total_bpw = trit_bpw + scale_bpw

        print(f' [{cn}] BPW={total_bpw:.3f} (trit={trit_bpw:.3f}+scale={scale_bpw:.3f})', flush=True)
        print(f' [{cn}] Saving to {model_dir}...', flush=True)
        model.save_pretrained(model_dir, safe_serialization=True)
        if tokenizer is not None:
            tokenizer.save_pretrained(model_dir)

        config = {
            'model': os.path.basename(args.model.rstrip('/')),
            'codec': cn,
            'bpw': total_bpw, 'trit_bpw': trit_bpw, 'scale_bpw': scale_bpw,
            'depth_pcts': {str(d): total_depth[d]/tg for d in [1, 2, 3, 4]},
            'n_matrices': len(matrices),
            'group_size': GS,
            'lm_head_skipped': True,
            'codec_params': CODECS[cn],
        }
        with open(os.path.join(cfg_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=2)
        print(f' [{cn}] DONE: {cfg_dir}', flush=True)

    print(f' ALL CONFIGS COMPLETE in {time.time()-t0:.0f}s total', flush=True)


if __name__ == '__main__':
    main()