Entrit committed
Commit d599083 · verified · 1 parent: bf73af4

initial public release: code, README, KNOWN_ISSUES

Files changed (3)
  1. KNOWN_ISSUES.md +48 -0
  2. README.md +114 -0
  3. quantize_model_v2.py +492 -0
KNOWN_ISSUES.md ADDED
# Known issues — tritllm-codec

Surfaced during a pre-release code review. None affect the published paper numbers (those were obtained with the codec exactly as shipped here), but anyone modifying the codec or extending it to new model families should be aware.

## SHOULD-FIX

### Scale selection is locally optimal, not globally
**Where:** [`quantize_model_v2.py`, `compute_optimal_scale()`](quantize_model_v2.py#L68)

`compute_optimal_scale()` searches only four order-statistic candidates `[G-6, G-4, G-2, G-1]` of the sorted absolute weights. For heavy-tailed or bimodal groups, the true MSE-optimal scale can fall outside this set.

**Impact:** measured at <1% PPL on Qwen2.5-7B (vs an exhaustive sweep). The 4-candidate set is a deliberate compute/quality trade-off, not a bug — but the function name and comments overstate what it guarantees. Read it as "MSE-best among 4 fixed candidates," not "global MSE optimum."
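
The gap is easy to measure in isolation. A self-contained sketch (plain d=1 ternary with midpoint boundaries for brevity, not the full codec path) comparing the shipped 4-candidate search against an exhaustive sweep over all order statistics:

```python
# Sketch only: d=1 levels {-1,0,+1} with midpoint boundaries, not the full codec.
import numpy as np

rng = np.random.default_rng(0)
G = 16
groups = rng.standard_t(df=2, size=(256, G))   # heavy-tailed groups

levels = np.array([-1.0, 0.0, 1.0])
bounds = np.array([-0.5, 0.5])

def group_mse(scales):
    q = levels[np.clip(np.searchsorted(bounds, groups / scales[:, None]), 0, 2)]
    return np.mean((groups - q * scales[:, None]) ** 2, axis=1)

sa = np.sort(np.abs(groups), axis=1)           # order statistics of |w|

def best_mse(cand_idx):
    return np.min([group_mse(np.maximum(sa[:, k], 1e-30)) for k in cand_idx], axis=0)

four = best_mse([G - 6, G - 4, G - 2, G - 1])  # shipped candidate set
full = best_mse(range(G))                      # exhaustive sweep over all G candidates
frac = np.mean(full < four - 1e-12)            # fraction of groups where exhaustive wins
```

By construction `full` can never be worse than `four` per group; the interesting number is how often (and by how much) it is strictly better.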

### Checkpoint resume does not validate source-model fingerprint
**Where:** [`quantize_model_v2.py`, the resume path around L328, L357, L388](quantize_model_v2.py#L328)

The resume logic checks only that `_meta.configs` contains the requested config names. It does not validate:
- The source model ID (and revision)
- Each tensor's shape
- The codec version
- The depth-power mapping

**Impact:** if you re-run quantization into a stale output directory with a different source model (or after editing the codec), assembled matrices may silently mix old and new versions. Always use a fresh `--output` per source model.

**Workaround:** delete the output directory before each new quantization run, or wrap the script with a check that the source-model SHA matches.
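
A minimal sketch of such a guard (the file name and fields here are illustrative — this is not part of the shipped codec):

```python
# Illustrative guard — not in the shipped codec. Records a fingerprint on the
# first run and refuses to resume into a directory whose fingerprint differs.
import json, os

def check_fingerprint(out_dir, model_id, revision, codec_version):
    fp = {"model": model_id, "revision": revision, "codec": codec_version}
    path = os.path.join(out_dir, "fingerprint.json")
    if os.path.exists(path):
        with open(path) as f:
            recorded = json.load(f)
        if recorded != fp:
            raise RuntimeError(f"stale output dir: {recorded} != {fp}")
    else:
        os.makedirs(out_dir, exist_ok=True)
        with open(path, "w") as f:
            json.dump(fp, f)
```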

### `from_pretrained` does not pin a revision
**Where:** [`quantize_model_v2.py`, `AutoModelForCausalLM.from_pretrained(args.model)` around L295](quantize_model_v2.py#L295)

Source models are loaded by HF repo ID without `revision=...`, so if Qwen, Llama, or Mistral push a new commit upstream, re-running the codec will produce different trits.

**Impact:** the checkpoints on `Entrit/...` were quantized at the time of paper submission against the then-current upstream weights. Future reruns are not guaranteed bit-identical.

**Workaround:** add `--revision <sha>` to your wrapping script, or pin in `quantize_model_v2.py` directly if exact reproducibility across time matters.

## NIT

### Scale codebook range
**Where:** [`quantize_model_v2.py`, `log_max = np.max(...)` around L89](quantize_model_v2.py#L89)

Setting `log_max = np.max(group_abs_maxes)` (replacing the older 99.9th-percentile heuristic, see commit `0c16d24`) fixes the upper-clipping problem cleanly. The downside: a single extreme-scale group can stretch the 27-entry log codebook for the whole matrix and reduce scale resolution for the bulk of groups.

**Impact:** not observed in practice on Qwen, Llama, or Mistral. Theoretical risk for matrices with bimodal or extreme outlier scale distributions. If you see unexpectedly high PPL on a new model family, this is a place to look.

### Header docstring mentions `fsync` but the code only `rename`s
**Where:** [`quantize_model_v2.py`, header comment](quantize_model_v2.py#L9)

Doc/code mismatch. Either remove the `fsync` claim from the docstring or add an `os.fsync(fd)` before the rename. No runtime impact — checkpoints are atomic via rename, just not durable across power loss.
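
If durability matters to you, the fix is small. A sketch of the durable variant the docstring describes (the shipped `atomic_save_npz` omits the `fsync` step):

```python
# Durable atomic write: flush + fsync the temp file before the atomic rename.
import os, tempfile

def durable_atomic_write(path, payload):
    fd, tmp = tempfile.mkstemp(prefix=".tmp_", dir=os.path.dirname(os.path.abspath(path)))
    with os.fdopen(fd, "wb") as f:
        f.write(payload)
        f.flush()
        os.fsync(f.fileno())   # the step the docstring promises but the code skips
    os.replace(tmp, path)      # atomic rename on POSIX
```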
README.md ADDED
---
license: apache-2.0
tags:
- quantization
- ternary
- llm
- post-training-quantization
library_name: transformers
---

# tritllm-codec

Reference implementation of the balanced ternary post-training quantization codec from
**"Balanced Ternary Post-Training Quantization for Large Language Models"** (Stentzel, 2026).

Quantizes FP16 LLM weights to balanced ternary at configurable depth `d ∈ {1, 2, 3, 4}` (3, 9, 27, 81 levels per weight) with no calibration data and no per-model tuning. Output is dequantized FP16 safetensors that load into stock `transformers` and `lm-eval` without a custom loader.

## What gets quantized

The codec quantizes all 2D linear weight matrices in a model. **The following are kept in FP16 and not counted in the BPW total:**

- `lm_head` (output projection)
- Token embeddings (`embed_tokens`)
- All `*_norm` layers (RMSNorm, LayerNorm — these are 1D anyway)

This is the standard convention in quantization papers (see GPTQ, AWQ, NF4) and reflects the fact that embedding lookups and the final classifier are not GEMV-bound at inference time. Throughout the paper, "BPW" refers to the average bits-per-weight of the quantized matrices only.
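
The selection rule mirrors the filter `quantize_model_v2.py` applies when collecting matrices (the helper name here is illustrative):

```python
# Mirrors the filter in quantize_model_v2.py: only 2D weights whose name does
# not contain 'norm', 'embed', or 'lm_head' get quantized; the rest stay FP16.
def should_quantize(name, ndim):
    return ndim == 2 and not any(s in name for s in ("norm", "embed", "lm_head"))

assert should_quantize("model.layers.0.self_attn.q_proj.weight", 2)
assert not should_quantize("model.embed_tokens.weight", 2)       # kept FP16
assert not should_quantize("lm_head.weight", 2)                  # kept FP16
assert not should_quantize("model.layers.0.input_layernorm.weight", 1)
```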

## Install

```bash
pip install torch transformers safetensors numpy huggingface_hub
git clone https://huggingface.co/Entrit/tritllm-codec
cd tritllm-codec
```

## Quick start

```bash
# Quantize Qwen2.5-7B at uniform depth d=2 (3.47 bpw)
python quantize_model_v2.py \
    --model Qwen/Qwen2.5-7B \
    --configs uniform-d2 \
    --output ./out

# Multi-config single pass (computes scales once, derives 6 configs)
python quantize_model_v2.py \
    --model Qwen/Qwen2.5-7B \
    --configs uniform-d1,uniform-d2,uniform-d3,uniform-d4,d3scale-sens002,d3scale-sens003 \
    --output ./out
```

The output directory contains one HF-loadable model per config:

```
out/
  uniform-d2/
    model/
      config.json
      model.safetensors   # dequantized FP16
      tokenizer.json
      ...
```

Load like any HF model:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
m = AutoModelForCausalLM.from_pretrained("./out/uniform-d2/model")
t = AutoTokenizer.from_pretrained("./out/uniform-d2/model")
```

## Settled design (don't change unless reproducing an ablation)

| Parameter | Value | Notes |
|---|---|---|
| Group size `G` | 16 | Per Section 6.1 of the paper, gs=64 is also viable; gs=16 gives best PPL |
| Scale depth `d_s` | 3 | 27-entry log-spaced codebook per matrix |
| Power mapping | d1=1.0, d2=1.5, d3=1.2, d4=1.0 | Tuned once on Qwen2.5-7B, held fixed for all subsequent models |
| Scale candidates | indices `[G-6, G-4, G-2, G-1]` of sorted `\|w\|` | MSE-minimum over the 4 candidates is selected per group |
| Scale codebook range | `log_min` = 0.1th percentile of group `\|w\|`-maxes, `log_max` = max | Fixed in commit `0c16d24` (was 99.9th percentile, which clipped) |
| `lm_head`, embeddings, norms | kept FP16 | See "What gets quantized" above |

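The power mapping in the table warps the integer level grid toward zero. A self-contained sketch of the construction (mirroring `build_levels` in the script):

```python
import numpy as np

def build_levels(depth, power):
    # Integer grid -half..+half, warped by sign(n) * |n|**power (power > 1
    # pulls intermediate levels toward zero; endpoints stay at ±half).
    half = (3 ** depth) // 2
    n = np.arange(-half, half + 1) / max(half, 1)
    return np.sign(n) * np.abs(n) ** power * max(half, 1)

lv = build_levels(2, 1.5)           # d=2: 9 levels, power 1.5 per the table
assert len(lv) == 9 and lv[4] == 0.0 and lv[-1] == 4.0
assert abs(lv[5] - 0.5) < 1e-12     # grid point 1 (of 4) pulled in from 1.0 to 0.5
```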
## BPW calculation

```
bpw = d * log2(3) + d_s * log2(3) / G   # weights + scales only
    = d * 1.585 + 0.297                 # for G=16, d_s=3
```

Resulting BPW: d1=1.88, d2=3.47, d3=5.05, d4=6.64.

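The arithmetic checks out against the figures above:

```python
from math import log2

def bpw(d, d_s=3, G=16):
    # d trits per weight, plus d_s trits per G-weight group for the scale
    return d * log2(3) + d_s * log2(3) / G

assert [round(bpw(d), 2) for d in (1, 2, 3, 4)] == [1.88, 3.47, 5.05, 6.64]
```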
## Known limitations

Documented in [KNOWN_ISSUES.md](KNOWN_ISSUES.md):

- **Scale selection is locally MSE-optimal among 4 candidates**, not globally optimal. For heavy-tailed groups the global optimum can fall outside `[G-6, G-4, G-2, G-1]`. In practice the gap is small (<1% PPL on Qwen2.5-7B).
- **Checkpoint resume does not validate model fingerprint.** If you re-run quantization into a stale output directory with a different source model, the assembled checkpoint may mix matrices. Always use a fresh `--output` per source model.
- **`from_pretrained` is not pinned to a revision.** If the source model on the Hub is updated upstream, re-running the codec will produce different trits. Pin a `--revision` if exact reproducibility across time matters.
- **One outlier group can stretch the 27-entry log scale codebook** for the whole matrix. We have not seen this cause measurable quality loss on Qwen, Llama, or Mistral, but it is a theoretical risk for skewed scale distributions.

## Citation

```bibtex
@article{stentzel2026ternaryptq,
  title  = {Balanced Ternary Post-Training Quantization for Large Language Models},
  author = {Stentzel, Eric},
  year   = 2026,
  note   = {Entrit Systems}
}
```

## Models quantized with this codec

See the [Entrit organization page](https://huggingface.co/Entrit) for prequantized model checkpoints across Qwen2.5 (0.5B–72B), Llama-3.1-8B, and Mistral-7B at depths d=1 through d=4.
quantize_model_v2.py ADDED
"""Ternary quantizer v2 — multi-config single-pass with per-matrix checkpointing.

Key improvements over v1:
1. Multi-config: --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3
   Computes per-group MSE-optimal scales ONCE per matrix, derives all configs.
   ~3x faster than running v1 four times.
2. Per-matrix checkpoint: each matrix's quantized output saved to .checkpoint/
   dir as soon as it's done. Crash-resume picks up where it left off.
3. Atomic writes (write to .tmp, fsync, rename) — no half-written checkpoints.
4. Streaming progress.json — monitors can poll without parsing logs.
5. Per-config HF model assembled at the end from checkpoints.

Usage:
    python quantize_model_v2.py --model /path/to/model \
        --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3 \
        --output /path/to/output_root \
        --workers 8 --dtype float16

Output structure:
    output_root/
        .checkpoint/
            matrix_00000__model.layers.0.self_attn.q_proj.npz  # all configs in one file
            matrix_00001__model.layers.0.self_attn.k_proj.npz
            ...
        progress.json  # live status
        d3scale-sens002/
            model/  # HF-format output
            config.json
        d3scale-sens003/
            model/
            config.json
        uniform-d2/
            model/
            config.json
        uniform-d3/
            model/
            config.json
"""
import os, sys, time, json, gc, argparse, tempfile
from multiprocessing import Pool
import numpy as np

# ============================================================
# CODEC CORE (unchanged from v1)
# ============================================================
GS = 16
DEPTH_POWERS = {1: 1.0, 2: 1.5, 3: 1.2, 4: 1.0}

def build_levels(half, power):
    int_levels = np.arange(-half, half + 1).astype(np.float64)
    n = int_levels / max(half, 1)
    if power != 1.0:
        return np.sign(n) * np.abs(n) ** power * max(half, 1)
    return int_levels

def make_boundaries(level_map, zero_boundary=None):
    """Default = midpoints between levels. If zero_boundary given, override the
    boundaries straddling 0 (used for d1 with custom zero-zone width)."""
    boundaries = (level_map[:-1] + level_map[1:]) / 2
    if zero_boundary is not None:
        zero_idx = int(np.argmin(np.abs(level_map)))
        if zero_idx > 0:
            boundaries[zero_idx - 1] = -abs(zero_boundary)
        if zero_idx < len(level_map) - 1:
            boundaries[zero_idx] = abs(zero_boundary)
    return boundaries

def compute_optimal_scale(groups, depth, power, zero_boundary=None):
    half = (3 ** depth) // 2
    gs = groups.shape[1]
    sa = np.sort(np.abs(groups), axis=1)
    cand_idx = np.clip(np.array([gs-6, gs-4, gs-2, gs-1]), 0, gs-1)
    level_map = build_levels(half, power)
    boundaries = make_boundaries(level_map, zero_boundary)
    N = len(groups)
    best_scale = np.zeros(N); best_mse = np.full(N, np.inf)
    for ki in cand_idx:
        scales = np.maximum(sa[:, ki] / max(half, 1), 1e-30)
        normalized = groups / scales[:, None]
        idx = np.searchsorted(boundaries, normalized.ravel())
        idx = np.clip(idx, 0, len(level_map) - 1)
        q = level_map[idx].reshape(N, gs)
        recon = q * scales[:, None]
        mse = np.mean((groups - recon) ** 2, axis=1)
        better = mse < best_mse
        best_mse[better] = mse[better]; best_scale[better] = scales[better]
    return best_scale, best_mse

def trit_quantize_scales(scales, sd):
    log_scales = np.log(np.maximum(scales, 1e-30))
    half = (3 ** sd) // 2
    n_levels = 2 * half + 1
    log_min = np.percentile(log_scales, 0.1)
    log_max = np.max(log_scales)  # 100th pct — never clip large scales
    if log_max - log_min < 1e-9:
        log_max = log_min + 1e-9
    codebook_log = np.linspace(log_min, log_max, n_levels)
    idx = np.argmin(np.abs(log_scales[:, None] - codebook_log[None, :]), axis=1)
    return np.exp(codebook_log[idx])

def quantize_with_scale(groups, scale, depth, power, zero_boundary=None):
    half = (3 ** depth) // 2
    level_map = build_levels(half, power)
    boundaries = make_boundaries(level_map, zero_boundary)
    scale = np.maximum(scale, 1e-30)
    normalized = groups / scale[:, None]
    idx = np.searchsorted(boundaries, normalized.ravel())
    idx = np.clip(idx, 0, len(level_map) - 1)
    q = level_map[idx].reshape(groups.shape)
    return q * scale[:, None]

# ============================================================
# CODEC CONFIGS
# ============================================================
CODECS = {
    'd3scale-sens002': {'mode': 'adaptive', 'scale_depth': 3, 'threshold': 0.002},
    'd3scale-sens003': {'mode': 'adaptive', 'scale_depth': 3, 'threshold': 0.003},
    # d1 with narrow zero zone (zw=0.25): 3 levels {-1,0,+1}, zero only when |w|<0.25*scale.
    # Old default was zw=0.5 which made 97.5% of weights round to 0 (random-chance MMLU).
    'uniform-d1': {'mode': 'uniform', 'scale_depth': 3, 'depth': 1, 'zero_boundary': 0.25},
    'uniform-d2': {'mode': 'uniform', 'scale_depth': 3, 'depth': 2},
    'uniform-d3': {'mode': 'uniform', 'scale_depth': 3, 'depth': 3},
    'uniform-d4': {'mode': 'uniform', 'scale_depth': 3, 'depth': 4},
}

# ============================================================
# MULTI-CONFIG MATRIX QUANTIZATION
# ============================================================
def quantize_matrix_multi(args):
    """Quantize one matrix for ALL requested configs in a single pass.
    Returns dict: config_name -> (recon_w, depth_counts, weight_bits, scale_bits, n_groups)
    """
    w_flat, rows, cols, config_names = args
    w = w_flat.reshape(rows, cols)
    pad = (GS - cols % GS) % GS
    if pad > 0:
        w = np.pad(w, ((0, 0), (0, pad)))
    groups = w.reshape(-1, GS).astype(np.float64)
    N = len(groups)
    group_var = np.maximum(np.var(groups, axis=1), 1e-30)

    # Precompute optimal scale + MSE for every (depth, zero_boundary) combo used.
    # Adaptive uses default boundaries for d2/d3/d4; uniform configs may override (e.g. d1 zw=0.25).
    needed_keys = set()  # (depth, zero_boundary)
    for cn in config_names:
        cfg = CODECS[cn]
        if cfg['mode'] == 'adaptive':
            for d in (2, 3, 4):
                needed_keys.add((d, None))
        else:
            needed_keys.add((cfg['depth'], cfg.get('zero_boundary')))

    scales_per_key = {}
    mse_per_key = {}
    recon_per_key = {}
    for d, zb in sorted(needed_keys, key=lambda x: (x[0], x[1] or 0)):
        power = DEPTH_POWERS[d]
        opt_s, _ = compute_optimal_scale(groups, d, power, zero_boundary=zb)
        use_s = trit_quantize_scales(opt_s, 3)
        r = quantize_with_scale(groups, use_s, d, power, zero_boundary=zb)
        mse = np.mean((groups - r) ** 2, axis=1)
        scales_per_key[(d, zb)] = use_s
        mse_per_key[(d, zb)] = mse
        recon_per_key[(d, zb)] = r

    out = {}
    for cn in config_names:
        cfg = CODECS[cn]
        if cfg['mode'] == 'uniform':
            d = cfg['depth']
            zb = cfg.get('zero_boundary')
            recon = recon_per_key[(d, zb)]
            depth_counts = {1: 0, 2: 0, 3: 0, 4: 0}
            depth_counts[d] = N
            wb = N * GS * d * np.log2(3)
            sb = N * cfg['scale_depth'] * np.log2(3)
        else:  # adaptive
            eff_thresh = cfg['threshold'] * 5.5
            recon = np.zeros_like(groups)
            assigned = np.zeros(N, dtype=bool)
            depth_counts = {1: 0, 2: 0, 3: 0, 4: 0}
            wb = 0.0; sb = 0.0
            for d in [2, 3, 4]:
                unassigned = ~assigned
                if not np.any(unassigned):
                    break
                if d == 4:
                    recon[unassigned] = recon_per_key[(4, None)][unassigned]
                    n_d = int(np.sum(unassigned))
                    depth_counts[d] = n_d
                    wb += n_d * GS * d * np.log2(3)
                    sb += n_d * cfg['scale_depth'] * np.log2(3)
                    break
                mse_d = mse_per_key[(d, None)][unassigned]
                meets = (mse_d / group_var[unassigned]) < eff_thresh
                uidx = np.where(unassigned)[0]
                midx = uidx[meets]
                recon[midx] = recon_per_key[(d, None)][midx]
                assigned[midx] = True
                n_d = int(np.sum(meets))
                depth_counts[d] = n_d
                wb += n_d * GS * d * np.log2(3)
                sb += n_d * cfg['scale_depth'] * np.log2(3)

        recon_w = recon.reshape(rows, -1)[:, :cols].astype(np.float32)
        out[cn] = {
            'recon_w': recon_w,
            'depth_counts': depth_counts,
            'weight_bits': float(wb),
            'scale_bits': float(sb),
            'n_groups': N,
        }
    return out

# ============================================================
# CHECKPOINTING
# ============================================================
def matrix_ckpt_path(ckpt_dir, idx, name):
    safe = name.replace('/', '__').replace('.', '_')
    return os.path.join(ckpt_dir, f'matrix_{idx:05d}__{safe}.npz')

def atomic_save_npz(path, data):
    # NOTE: np.savez_compressed silently appends '.npz' if missing.
    # So we name the tmp file with .npz suffix to avoid surprise.
    fd, tmp = tempfile.mkstemp(prefix='.tmp_', suffix='.npz', dir=os.path.dirname(path))
    os.close(fd)
    np.savez_compressed(tmp, **data)
    os.replace(tmp, path)

def load_ckpt(path):
    with np.load(path, allow_pickle=True) as z:
        return {k: z[k] for k in z.files}

def write_progress(out_root, state):
    path = os.path.join(out_root, 'progress.json')
    fd, tmp = tempfile.mkstemp(prefix='.tmp_', dir=out_root)
    with os.fdopen(fd, 'w') as f:
        json.dump(state, f, indent=2)
    os.replace(tmp, path)

# ============================================================
# MAIN
# ============================================================
def main():
    parser = argparse.ArgumentParser(description='Multi-config ternary quantizer with checkpointing')
    parser.add_argument('--model', required=True)
    parser.add_argument('--configs', required=True,
                        help='Comma-separated codec names: ' + ','.join(CODECS.keys()))
    parser.add_argument('--output', required=True, help='Output root dir')
    parser.add_argument('--workers', type=int, default=1)
    parser.add_argument('--dtype', default='float16', choices=['float16', 'bfloat16'])
    parser.add_argument('--skip-assembly', action='store_true',
                        help='Quantize matrices and checkpoint only; skip final HF model assembly.')
    parser.add_argument('--matrix-range', default=None,
                        help='Slice of matrices to process: "start:end" (0-indexed, end exclusive). '
                             'Use to manually parallelize across processes/machines via shared checkpoint dir.')
    args = parser.parse_args()

    config_names = [c.strip() for c in args.configs.split(',') if c.strip()]
    for cn in config_names:
        if cn not in CODECS:
            print(f'ERROR: unknown codec {cn}', file=sys.stderr); sys.exit(2)

    os.makedirs(args.output, exist_ok=True)
    ckpt_dir = os.path.join(args.output, '.checkpoint')
    os.makedirs(ckpt_dir, exist_ok=True)

    print(f'=== Quantizing {args.model} ===', flush=True)
    print(f' configs: {config_names}', flush=True)
    print(f' workers: {args.workers}', flush=True)

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel

    dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16
    print(' loading model (CPU)...', flush=True)
    t_load = time.time()
    _cfg = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
    _arch = ((getattr(_cfg, 'architectures', None) or [''])[0] or '').lower()
    if 't5' in _arch or 'encoder' in _arch:
        from transformers import T5EncoderModel
        print(' loading as T5EncoderModel (encoder-only)', flush=True)
        model = T5EncoderModel.from_pretrained(args.model, torch_dtype=dtype,
                                               device_map='cpu', trust_remote_code=True,
                                               low_cpu_mem_usage=True)
    else:
        try:
            model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype,
                                                         device_map='cpu', trust_remote_code=True,
                                                         low_cpu_mem_usage=True)
        except ValueError:
            print(' fallback to generic AutoModel', flush=True)
            model = AutoModel.from_pretrained(args.model, torch_dtype=dtype,
                                              device_map='cpu', trust_remote_code=True,
                                              low_cpu_mem_usage=True)
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        print(f' tokenizer load failed (ok for encoder-only): {e}', flush=True)
        tokenizer = None
    print(f' loaded in {time.time()-t_load:.0f}s', flush=True)

    # Collect matrices to quantize (skip embeddings, norms, lm_head)
    matrices = []
    for pn, p in model.named_parameters():
        if p.dim() != 2 or 'norm' in pn or 'embed' in pn or 'lm_head' in pn:
            continue
        matrices.append((pn, p))
    print(f' {len(matrices)} matrices to quantize', flush=True)

    # Apply --matrix-range slice (for parallel sharded processing)
    range_start, range_end = 0, len(matrices)
    if args.matrix_range:
        s, e = args.matrix_range.split(':')
        range_start = int(s) if s else 0
        range_end = int(e) if e else len(matrices)
        range_end = min(range_end, len(matrices))
        print(f' matrix-range: [{range_start}:{range_end})', flush=True)

    # Determine which need work (resume from checkpoints)
    todo = []
    done_count = 0
    for idx, (pn, p) in enumerate(matrices):
        if idx < range_start or idx >= range_end:
            continue
        cp = matrix_ckpt_path(ckpt_dir, idx, pn)
        if os.path.exists(cp):
            # Validate it has all requested configs
            try:
                z = np.load(cp, allow_pickle=True)
                have_configs = set(json.loads(str(z['_meta'][()])).get('configs', []))
                if all(cn in have_configs for cn in config_names):
                    done_count += 1
                    continue
            except Exception as e:
                print(f' bad checkpoint {cp}: {e}, will redo', flush=True)
                os.remove(cp)
        todo.append((idx, pn, p))
    print(f' {done_count} matrices already checkpointed, {len(todo)} to do', flush=True)

    t0 = time.time()
    state = {
        'model': args.model, 'configs': config_names,
        'total_matrices': len(matrices),
        'done_matrices': done_count,
        'started_at': t0, 'updated_at': t0,
    }
    write_progress(args.output, state)

    def process_one(idx, pn, p):
        w = p.data.float().numpy()
        result = quantize_matrix_multi(
            (w.ravel(), w.shape[0], w.shape[1], config_names))
        # Pack into npz: one key per config + meta
        save_data = {'_meta': np.array(json.dumps({
            'name': pn, 'idx': idx, 'shape': list(w.shape),
            'configs': config_names,
        }))}
        for cn, info in result.items():
            save_data[f'{cn}__w'] = info['recon_w']
            save_data[f'{cn}__stats'] = np.array(json.dumps({
                'depth_counts': info['depth_counts'],
                'weight_bits': info['weight_bits'],
                'scale_bits': info['scale_bits'],
                'n_groups': info['n_groups'],
            }))
        atomic_save_npz(matrix_ckpt_path(ckpt_dir, idx, pn), save_data)
        return idx

    if args.workers > 1 and len(todo) > 1:
        # Streaming generator: yield (matrix, config_names) one at a time.
        # CRITICAL: do NOT pre-build all matrices in a list — for large models
        # (Llama 70B = 140GB) that OOMs the box at multiple hundred GB. The generator
        # is consumed lazily by Pool.imap.
        idx_name = [(idx, pn) for idx, pn, _ in todo]
        def gen():
            for idx, pn, p in todo:
                w = p.data.float().numpy()
                yield (w.ravel(), w.shape[0], w.shape[1], config_names)
                # Free the source tensor after we've handed off the numpy view.
                # The Pool worker has its own copy via pickle.
                p.data = __import__('torch').zeros(1, dtype=p.dtype)
        with Pool(args.workers) as pool:
            for i, result in enumerate(pool.imap(quantize_matrix_multi, gen(), chunksize=1)):
                idx, pn = idx_name[i]
                save_data = {'_meta': np.array(json.dumps({
                    'name': pn, 'idx': idx,
                    'configs': config_names,
                }))}
                for cn, info in result.items():
                    save_data[f'{cn}__w'] = info['recon_w']
                    save_data[f'{cn}__stats'] = np.array(json.dumps({
                        'depth_counts': info['depth_counts'],
                        'weight_bits': info['weight_bits'],
                        'scale_bits': info['scale_bits'],
                        'n_groups': info['n_groups'],
                    }))
                atomic_save_npz(matrix_ckpt_path(ckpt_dir, idx, pn), save_data)
                done_count += 1
                state['done_matrices'] = done_count
                state['updated_at'] = time.time()
                state['elapsed_s'] = time.time() - t0
                if (i+1) % 5 == 0 or (i+1) == len(todo):
                    write_progress(args.output, state)
                    eta = (len(todo) - (i+1)) * (time.time() - t0) / max(i+1, 1)
                    print(f' {done_count}/{len(matrices)} ({time.time()-t0:.0f}s, ETA {eta:.0f}s)', flush=True)
    else:
        for i, (idx, pn, p) in enumerate(todo):
            process_one(idx, pn, p)
            done_count += 1
            state['done_matrices'] = done_count
            state['updated_at'] = time.time()
            state['elapsed_s'] = time.time() - t0
            if (i+1) % 5 == 0 or (i+1) == len(todo):
                write_progress(args.output, state)
                eta = (len(todo) - (i+1)) * (time.time() - t0) / max(i+1, 1)
                print(f' {done_count}/{len(matrices)} ({time.time()-t0:.0f}s, ETA {eta:.0f}s)', flush=True)

    print(f' Quantization complete in {time.time()-t0:.0f}s', flush=True)

    # If we processed only a slice, don't assemble — leave that for the merge step.
    if args.matrix_range:
        # Verify which checkpoints exist for this slice; print summary
        slice_done = sum(1 for idx, (pn, p) in enumerate(matrices)
                         if range_start <= idx < range_end
                         and os.path.exists(matrix_ckpt_path(ckpt_dir, idx, pn)))
        print(f' slice [{range_start}:{range_end}): {slice_done} checkpointed', flush=True)
        return

    if args.skip_assembly:
        print(' --skip-assembly: not building HF model dirs', flush=True)
        return

    # ============================================================
    # ASSEMBLY: load each config from checkpoints, write HF model
    # ============================================================
    print(' Assembling HF models per config...', flush=True)
    for cn in config_names:
        cfg_dir = os.path.join(args.output, cn)
        os.makedirs(cfg_dir, exist_ok=True)
        model_dir = os.path.join(cfg_dir, 'model')

        # Aggregate stats
        total_groups = 0
        total_depth = {1: 0, 2: 0, 3: 0, 4: 0}
        total_wb = 0.0; total_sb = 0.0

        # Replace tensors in-place with this config's reconstruction
        name_to_param = {pn: p for pn, p in matrices}
        for idx, (pn, p) in enumerate(matrices):
            cp = matrix_ckpt_path(ckpt_dir, idx, pn)
            z = np.load(cp, allow_pickle=True)
            recon_w = z[f'{cn}__w']
            stats = json.loads(str(z[f'{cn}__stats'][()]))
            p.data = __import__('torch').from_numpy(recon_w).to(p.dtype)
            total_groups += stats['n_groups']
            for d in [1, 2, 3, 4]:
                total_depth[d] += stats['depth_counts'].get(str(d), stats['depth_counts'].get(d, 0))
            total_wb += stats['weight_bits']
            total_sb += stats['scale_bits']

        tg = max(total_groups, 1)
        trit_bpw = total_wb / (tg * GS)
        scale_bpw = total_sb / (tg * GS)
        total_bpw = trit_bpw + scale_bpw

        print(f' [{cn}] BPW={total_bpw:.3f} (trit={trit_bpw:.3f}+scale={scale_bpw:.3f})', flush=True)
        print(f' [{cn}] Saving to {model_dir}...', flush=True)
        model.save_pretrained(model_dir, safe_serialization=True)
        if tokenizer is not None:
            tokenizer.save_pretrained(model_dir)

        config = {
            'model': os.path.basename(args.model.rstrip('/')),
            'codec': cn,
            'bpw': total_bpw, 'trit_bpw': trit_bpw, 'scale_bpw': scale_bpw,
            'depth_pcts': {str(d): total_depth[d]/tg for d in [1, 2, 3, 4]},
            'n_matrices': len(matrices),
            'group_size': GS,
            'lm_head_skipped': True,
            'codec_params': CODECS[cn],
        }
        with open(os.path.join(cfg_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=2)
        print(f' [{cn}] DONE: {cfg_dir}', flush=True)

    print(f' ALL CONFIGS COMPLETE in {time.time()-t0:.0f}s total', flush=True)

if __name__ == '__main__':
    main()