fix: address codex review BLOCKERs and SHOULD-FIXes; update KNOWN_ISSUES
Browse files- KNOWN_ISSUES.md +27 -39
- README.md +7 -6
- quantize_model_v2.py +117 -31
KNOWN_ISSUES.md
CHANGED
|
@@ -1,48 +1,36 @@
|
|
| 1 |
-
# Known
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
##
|
| 6 |
|
| 7 |
-
### Scale
|
| 8 |
-
**Where:** [`quantize_model_v2.py`, `
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
-
|
| 19 |
-
-
|
| 20 |
-
|
| 21 |
-
- The depth-power mapping
|
| 22 |
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
Source models are loaded by HF repo ID without `revision=...`, so if Qwen, Llama, or Mistral push a new commit upstream, re-running the codec will produce different trits.
|
| 31 |
-
|
| 32 |
-
**Impact:** the checkpoints on `Entrit/...` were quantized at the time of paper submission against the then-current upstream weights. Future reruns are not guaranteed bit-identical.
|
| 33 |
-
|
| 34 |
-
**Workaround:** add `--revision <sha>` to your wrapping script, or pin in `quantize_model_v2.py` directly if exact reproducibility across time matters.
|
| 35 |
-
|
| 36 |
-
## NIT
|
| 37 |
-
|
| 38 |
-
### Scale codebook range
|
| 39 |
-
**Where:** [`quantize_model_v2.py`, `log_max = np.max(...)` around L89](quantize_model_v2.py#L89)
|
| 40 |
-
|
| 41 |
-
Setting `log_max = np.max(group_abs_maxes)` (replacing the older 99.9th-percentile heuristic, see commit `0c16d24`) fixes the upper-clipping problem cleanly. The downside: a single extreme-scale group can stretch the 27-entry log codebook for the whole matrix and reduce scale resolution for the bulk of groups.
|
| 42 |
-
|
| 43 |
-
**Impact:** not observed in practice on Qwen, Llama, or Mistral. Theoretical risk for matrices with bimodal or extreme outlier scale distributions. If you see unexpectedly high PPL on a new model family, this is a place to look.
|
| 44 |
-
|
| 45 |
-
### Header docstring mentions `fsync` but the code only `rename`s
|
| 46 |
-
**Where:** [`quantize_model_v2.py`, header comment](quantize_model_v2.py#L9)
|
| 47 |
-
|
| 48 |
-
Doc/code mismatch. Either remove the `fsync` claim from the docstring or add an `os.fsync(fd)` before the rename. No runtime impact β checkpoints are atomic via rename, just not durable across power loss.
|
|
|
|
| 1 |
+
# Known limitations β tritllm-codec
|
| 2 |
|
| 3 |
+
Items previously raised in code review have been addressed in the current
|
| 4 |
+
release. This document only lists deliberate design tradeoffs that the codec
|
| 5 |
+
review surfaced, not bugs.
|
| 6 |
|
| 7 |
+
## Design tradeoffs
|
| 8 |
|
| 9 |
+
### Scale codebook upper bound = `max(group_abs_maxes)`
|
| 10 |
+
**Where:** [`quantize_model_v2.py`, `trit_quantize_scales`, `log_max = np.max(...)`](quantize_model_v2.py#L107)
|
| 11 |
|
| 12 |
+
The 27-entry log-spaced scale codebook spans `[log_min, log_max]` where
|
| 13 |
+
`log_max` is taken to be the maximum group magnitude in the matrix. This is
|
| 14 |
+
intentional β an earlier 99.9th-percentile bound (commit prior to `0c16d24`)
|
| 15 |
+
clipped large-scale outlier groups and lost their resolution.
|
| 16 |
|
| 17 |
+
The downside: a single extreme-scale outlier group can stretch the log-spaced
|
| 18 |
+
range and reduce scale resolution for the bulk of normal-magnitude groups in
|
| 19 |
+
the same matrix.
|
| 20 |
|
| 21 |
+
We do not see this cause measurable quality regressions on Qwen2.5, Llama-3.1,
|
| 22 |
+
or Mistral-7B. If you observe unexpectedly high PPL on a new model family with
|
| 23 |
+
heavy-tailed scale distributions, this is the first place to look.
|
| 24 |
|
| 25 |
+
We did not change this in the current release because changing it would alter
|
| 26 |
+
the bit-exact output of the codec and invalidate published paper numbers; a
|
| 27 |
+
future v3 may replace `np.max` with a soft-cap (e.g. `min(max, 4 * p99)`) that
|
| 28 |
+
is robust to single extreme outliers without giving up large-scale fidelity.
|
|
|
|
| 29 |
|
| 30 |
+
### Scale candidate set is fixed at 4 percentiles
|
| 31 |
+
**Where:** [`quantize_model_v2.py`, `compute_best_scale_4cand`](quantize_model_v2.py#L75)
|
| 32 |
|
| 33 |
+
The MSE-best scale is selected from four fixed order statistics β indices
|
| 34 |
+
`[gs-6, gs-4, gs-2, gs-1]` of sorted `|w|`. This is a deliberate compute /
|
| 35 |
+
quality tradeoff (β50Γ speedup over an exhaustive sweep, <1% PPL gap measured
|
| 36 |
+
on Qwen2.5-7B), not a bug. The function name and docstring now reflect this.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -89,14 +89,15 @@ bpw = (d * log2(3) + d_s * log2(3) / G) / 1 # weights + scale
|
|
| 89 |
|
| 90 |
Resulting BPW: d1=1.88, d2=3.47, d3=5.05, d4=6.64.
|
| 91 |
|
| 92 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
-
|
| 95 |
|
| 96 |
-
|
| 97 |
-
- **Checkpoint resume does not validate model fingerprint.** If you re-run quantization into a stale output directory with a different source model, the assembled checkpoint may mix matrices. Always use a fresh `--out` per source model.
|
| 98 |
-
- **`from_pretrained` is not pinned to a revision.** If the source model on the Hub is updated upstream, re-running the codec will produce different trits. Pin a `--revision` if exact reproducibility across time matters.
|
| 99 |
-
- **One outlier group can stretch the 27-entry log scale codebook** for the whole matrix. We have not seen this cause measurable quality loss on Qwen, Llama, or Mistral, but it is a theoretical risk for skewed scale distributions.
|
| 100 |
|
| 101 |
## Citation
|
| 102 |
|
|
|
|
| 89 |
|
| 90 |
Resulting BPW: d1=1.88, d2=3.47, d3=5.05, d4=6.64.
|
| 91 |
|
| 92 |
+
## Reproducibility tips
|
| 93 |
+
|
| 94 |
+
- Pass `--revision <git-sha>` to pin the source model β without it the upstream HF repo can move under you between runs.
|
| 95 |
+
- Each checkpoint stores a fingerprint of `(model, revision, codec version, group size, depth-power mapping)` and the matrix shape. On resume, mismatched checkpoints are discarded and re-quantized rather than silently mixed.
|
| 96 |
+
- The `assembled config.json` records the full fingerprint so you can verify which source model and codec version produced any given output.
|
| 97 |
|
| 98 |
+
## Known limitations
|
| 99 |
|
| 100 |
+
Two design tradeoffs (not bugs) are documented in [KNOWN_ISSUES.md](KNOWN_ISSUES.md): the 4-candidate scale search and the `log_max = max(...)` codebook upper bound. Both are intentional choices; the file explains the reasoning and what to look for in new model families.
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
## Citation
|
| 103 |
|
quantize_model_v2.py
CHANGED
|
@@ -2,18 +2,30 @@
|
|
| 2 |
|
| 3 |
Key improvements over v1:
|
| 4 |
1. Multi-config: --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3
|
| 5 |
-
Computes per-group MSE-
|
| 6 |
-
~3x faster than running v1 four times.
|
| 7 |
2. Per-matrix checkpoint: each matrix's quantized output saved to .checkpoint/
|
| 8 |
dir as soon as it's done. Crash-resume picks up where it left off.
|
| 9 |
-
3.
|
|
|
|
| 10 |
4. Streaming progress.json β monitors can poll without parsing logs.
|
| 11 |
5. Per-config HF model assembled at the end from checkpoints.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
Usage:
|
| 14 |
-
python quantize_model_v2.py --model /
|
| 15 |
-
--configs
|
| 16 |
--output /path/to/output_root \
|
|
|
|
| 17 |
--workers 8 --dtype float16
|
| 18 |
|
| 19 |
Output structure:
|
|
@@ -23,18 +35,9 @@ Output structure:
|
|
| 23 |
matrix_00001__model.layers.0.self_attn.k_proj.npz
|
| 24 |
...
|
| 25 |
progress.json # live status
|
| 26 |
-
|
| 27 |
model/ # HF-format output
|
| 28 |
config.json
|
| 29 |
-
d3scale-sens003/
|
| 30 |
-
model/
|
| 31 |
-
config.json
|
| 32 |
-
uniform-d2/
|
| 33 |
-
model/
|
| 34 |
-
config.json
|
| 35 |
-
uniform-d3/
|
| 36 |
-
model/
|
| 37 |
-
config.json
|
| 38 |
"""
|
| 39 |
import os, sys, time, json, gc, argparse, tempfile
|
| 40 |
from multiprocessing import Pool
|
|
@@ -65,7 +68,16 @@ def make_boundaries(level_map, zero_boundary=None):
|
|
| 65 |
boundaries[zero_idx] = abs(zero_boundary)
|
| 66 |
return boundaries
|
| 67 |
|
| 68 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
half = (3 ** depth) // 2
|
| 70 |
gs = groups.shape[1]
|
| 71 |
sa = np.sort(np.abs(groups), axis=1)
|
|
@@ -86,6 +98,12 @@ def compute_optimal_scale(groups, depth, power, zero_boundary=None):
|
|
| 86 |
best_mse[better] = mse[better]; best_scale[better] = scales[better]
|
| 87 |
return best_scale, best_mse
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
def trit_quantize_scales(scales, sd):
|
| 90 |
log_scales = np.log(np.maximum(scales, 1e-30))
|
| 91 |
half = (3 ** sd) // 2
|
|
@@ -220,12 +238,51 @@ def matrix_ckpt_path(ckpt_dir, idx, name):
|
|
| 220 |
return os.path.join(ckpt_dir, f'matrix_{idx:05d}__{safe}.npz')
|
| 221 |
|
| 222 |
def atomic_save_npz(path, data):
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
| 225 |
fd, tmp = tempfile.mkstemp(prefix='.tmp_', suffix='.npz', dir=os.path.dirname(path))
|
| 226 |
os.close(fd)
|
| 227 |
np.savez_compressed(tmp, **data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
os.replace(tmp, path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
def load_ckpt(path):
|
| 231 |
with np.load(path, allow_pickle=True) as z:
|
|
@@ -254,6 +311,9 @@ def main():
|
|
| 254 |
parser.add_argument('--matrix-range', default=None,
|
| 255 |
help='Slice of matrices to process: "start:end" (0-indexed, end exclusive). '
|
| 256 |
'Use to manually parallelize across processes/machines via shared checkpoint dir.')
|
|
|
|
|
|
|
|
|
|
| 257 |
args = parser.parse_args()
|
| 258 |
|
| 259 |
config_names = [c.strip() for c in args.configs.split(',') if c.strip()]
|
|
@@ -275,26 +335,26 @@ def main():
|
|
| 275 |
dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16
|
| 276 |
print(' loading model (CPU)...', flush=True)
|
| 277 |
t_load = time.time()
|
| 278 |
-
_cfg = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
|
| 279 |
_arch = ((getattr(_cfg, 'architectures', None) or [''])[0] or '').lower()
|
| 280 |
if 't5' in _arch or 'encoder' in _arch:
|
| 281 |
from transformers import T5EncoderModel
|
| 282 |
print(' loading as T5EncoderModel (encoder-only)', flush=True)
|
| 283 |
-
model = T5EncoderModel.from_pretrained(args.model, torch_dtype=dtype,
|
| 284 |
device_map='cpu', trust_remote_code=True,
|
| 285 |
low_cpu_mem_usage=True)
|
| 286 |
else:
|
| 287 |
try:
|
| 288 |
-
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype,
|
| 289 |
device_map='cpu', trust_remote_code=True,
|
| 290 |
low_cpu_mem_usage=True)
|
| 291 |
except ValueError:
|
| 292 |
print(' fallback to generic AutoModel', flush=True)
|
| 293 |
-
model = AutoModel.from_pretrained(args.model, torch_dtype=dtype,
|
| 294 |
device_map='cpu', trust_remote_code=True,
|
| 295 |
low_cpu_mem_usage=True)
|
| 296 |
try:
|
| 297 |
-
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
| 298 |
if tokenizer.pad_token is None:
|
| 299 |
tokenizer.pad_token = tokenizer.eos_token
|
| 300 |
except Exception as e:
|
|
@@ -319,25 +379,45 @@ def main():
|
|
| 319 |
range_end = min(range_end, len(matrices))
|
| 320 |
print(f' matrix-range: [{range_start}:{range_end})', flush=True)
|
| 321 |
|
|
|
|
|
|
|
|
|
|
| 322 |
# Determine which need work (resume from checkpoints)
|
| 323 |
todo = []
|
| 324 |
done_count = 0
|
|
|
|
| 325 |
for idx, (pn, p) in enumerate(matrices):
|
| 326 |
if idx < range_start or idx >= range_end:
|
| 327 |
continue
|
| 328 |
cp = matrix_ckpt_path(ckpt_dir, idx, pn)
|
| 329 |
if os.path.exists(cp):
|
| 330 |
-
# Validate it has all requested configs
|
| 331 |
try:
|
| 332 |
z = np.load(cp, allow_pickle=True)
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
done_count += 1
|
| 336 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
except Exception as e:
|
| 338 |
print(f' bad checkpoint {cp}: {e}, will redo', flush=True)
|
| 339 |
os.remove(cp)
|
| 340 |
todo.append((idx, pn, p))
|
|
|
|
|
|
|
| 341 |
print(f' {done_count} matrices already checkpointed, {len(todo)} to do', flush=True)
|
| 342 |
|
| 343 |
t0 = time.time()
|
|
@@ -353,10 +433,12 @@ def main():
|
|
| 353 |
w = p.data.float().numpy()
|
| 354 |
result = quantize_matrix_multi(
|
| 355 |
(w.ravel(), w.shape[0], w.shape[1], config_names))
|
| 356 |
-
# Pack into npz: one key per config + meta
|
|
|
|
| 357 |
save_data = {'_meta': np.array(json.dumps({
|
| 358 |
'name': pn, 'idx': idx, 'shape': list(w.shape),
|
| 359 |
'configs': config_names,
|
|
|
|
| 360 |
}))}
|
| 361 |
for cn, info in result.items():
|
| 362 |
save_data[f'{cn}__w'] = info['recon_w']
|
|
@@ -374,7 +456,7 @@ def main():
|
|
| 374 |
# CRITICAL: do NOT pre-build all matrices in a list β for large models
|
| 375 |
# (Llama 70B = 140GB) that OOMs the box at multiple hundred GB. The generator
|
| 376 |
# is consumed lazily by Pool.imap.
|
| 377 |
-
idx_name = [(idx, pn) for idx, pn,
|
| 378 |
def gen():
|
| 379 |
for idx, pn, p in todo:
|
| 380 |
w = p.data.float().numpy()
|
|
@@ -384,10 +466,11 @@ def main():
|
|
| 384 |
p.data = __import__('torch').zeros(1, dtype=p.dtype)
|
| 385 |
with Pool(args.workers) as pool:
|
| 386 |
for i, result in enumerate(pool.imap(quantize_matrix_multi, gen(), chunksize=1)):
|
| 387 |
-
idx, pn = idx_name[i]
|
| 388 |
save_data = {'_meta': np.array(json.dumps({
|
| 389 |
-
'name': pn, 'idx': idx,
|
| 390 |
'configs': config_names,
|
|
|
|
| 391 |
}))}
|
| 392 |
for cn, info in result.items():
|
| 393 |
save_data[f'{cn}__w'] = info['recon_w']
|
|
@@ -474,12 +557,15 @@ def main():
|
|
| 474 |
|
| 475 |
config = {
|
| 476 |
'model': os.path.basename(args.model.rstrip('/')),
|
|
|
|
|
|
|
|
|
|
| 477 |
'codec': cn,
|
| 478 |
'bpw': total_bpw, 'trit_bpw': trit_bpw, 'scale_bpw': scale_bpw,
|
| 479 |
'depth_pcts': {str(d): total_depth[d]/tg for d in [1,2,3,4]},
|
| 480 |
'n_matrices': len(matrices),
|
| 481 |
'group_size': GS,
|
| 482 |
-
'
|
| 483 |
'codec_params': CODECS[cn],
|
| 484 |
}
|
| 485 |
with open(os.path.join(cfg_dir, 'config.json'), 'w') as f:
|
|
|
|
| 2 |
|
| 3 |
Key improvements over v1:
|
| 4 |
1. Multi-config: --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3
|
| 5 |
+
Computes per-group MSE-best scales (over a fixed 4-candidate set) ONCE per
|
| 6 |
+
matrix, derives all configs. ~3x faster than running v1 four times.
|
| 7 |
2. Per-matrix checkpoint: each matrix's quantized output saved to .checkpoint/
|
| 8 |
dir as soon as it's done. Crash-resume picks up where it left off.
|
| 9 |
+
3. Durable atomic writes (write to .tmp, fsync, rename) β no half-written or
|
| 10 |
+
post-power-loss-truncated checkpoints.
|
| 11 |
4. Streaming progress.json β monitors can poll without parsing logs.
|
| 12 |
5. Per-config HF model assembled at the end from checkpoints.
|
| 13 |
+
6. Resume validation: a fingerprint of (model id, revision, codec version,
|
| 14 |
+
depth-power mapping, tensor shape) is stored in each checkpoint and
|
| 15 |
+
re-checked on resume. A mismatch causes the stale checkpoint to be
|
| 16 |
+
discarded and re-quantized rather than silently mixed.
|
| 17 |
+
|
| 18 |
+
What this codec quantizes (and what it does not):
|
| 19 |
+
- Quantized: every 2D linear weight matrix in the model.
|
| 20 |
+
- Kept FP16: token embeddings, all *_norm layers, and lm_head.
|
| 21 |
+
This matches the convention used by GPTQ/AWQ/NF4 and is what the paper's
|
| 22 |
+
bits-per-weight figures account for.
|
| 23 |
|
| 24 |
Usage:
|
| 25 |
+
python quantize_model_v2.py --model Qwen/Qwen2.5-7B \
|
| 26 |
+
--configs uniform-d2,uniform-d3 \
|
| 27 |
--output /path/to/output_root \
|
| 28 |
+
--revision <git-sha-of-source-model> \
|
| 29 |
--workers 8 --dtype float16
|
| 30 |
|
| 31 |
Output structure:
|
|
|
|
| 35 |
matrix_00001__model.layers.0.self_attn.k_proj.npz
|
| 36 |
...
|
| 37 |
progress.json # live status
|
| 38 |
+
<config>/
|
| 39 |
model/ # HF-format output
|
| 40 |
config.json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
"""
|
| 42 |
import os, sys, time, json, gc, argparse, tempfile
|
| 43 |
from multiprocessing import Pool
|
|
|
|
| 68 |
boundaries[zero_idx] = abs(zero_boundary)
|
| 69 |
return boundaries
|
| 70 |
|
| 71 |
+
def compute_best_scale_4cand(groups, depth, power, zero_boundary=None):
|
| 72 |
+
"""Pick the per-group scale that minimises reconstruction MSE among 4 fixed
|
| 73 |
+
order-statistic candidates of the sorted absolute weights:
|
| 74 |
+
indices [gs-6, gs-4, gs-2, gs-1] (roughly the 69th/81st/94th/100th
|
| 75 |
+
percentiles for gs=16).
|
| 76 |
+
|
| 77 |
+
This is a deliberately small candidate set, not an exhaustive sweep.
|
| 78 |
+
Empirically <1% PPL gap from a dense sweep on Qwen2.5-7B; in exchange
|
| 79 |
+
quantization is ~50x faster than evaluating every percentile.
|
| 80 |
+
"""
|
| 81 |
half = (3 ** depth) // 2
|
| 82 |
gs = groups.shape[1]
|
| 83 |
sa = np.sort(np.abs(groups), axis=1)
|
|
|
|
| 98 |
best_mse[better] = mse[better]; best_scale[better] = scales[better]
|
| 99 |
return best_scale, best_mse
|
| 100 |
|
| 101 |
+
# Backwards-compatible alias β earlier scripts and the published paper repo
|
| 102 |
+
# refer to this as the "MSE-optimal" call site. The name overstates the
|
| 103 |
+
# guarantee (see docstring on compute_best_scale_4cand) but the algorithm is
|
| 104 |
+
# unchanged.
|
| 105 |
+
compute_optimal_scale = compute_best_scale_4cand
|
| 106 |
+
|
| 107 |
def trit_quantize_scales(scales, sd):
|
| 108 |
log_scales = np.log(np.maximum(scales, 1e-30))
|
| 109 |
half = (3 ** sd) // 2
|
|
|
|
| 238 |
return os.path.join(ckpt_dir, f'matrix_{idx:05d}__{safe}.npz')
|
| 239 |
|
| 240 |
def atomic_save_npz(path, data):
|
| 241 |
+
"""Write `data` to `path` atomically, with fsync before rename so the
|
| 242 |
+
checkpoint survives power loss / SIGKILL after the rename returns."""
|
| 243 |
+
# NOTE: np.savez_compressed silently appends '.npz' if missing β so we
|
| 244 |
+
# name the tmp file with .npz suffix and pass it the same path.
|
| 245 |
fd, tmp = tempfile.mkstemp(prefix='.tmp_', suffix='.npz', dir=os.path.dirname(path))
|
| 246 |
os.close(fd)
|
| 247 |
np.savez_compressed(tmp, **data)
|
| 248 |
+
# fsync the file so its data is durable before we rename. os.replace then
|
| 249 |
+
# makes the rename atomic (POSIX guarantees same-filesystem rename atomicity).
|
| 250 |
+
fd = os.open(tmp, os.O_RDONLY)
|
| 251 |
+
try:
|
| 252 |
+
os.fsync(fd)
|
| 253 |
+
finally:
|
| 254 |
+
os.close(fd)
|
| 255 |
os.replace(tmp, path)
|
| 256 |
+
# fsync the parent directory so the rename itself is durable.
|
| 257 |
+
dir_fd = os.open(os.path.dirname(path) or '.', os.O_RDONLY)
|
| 258 |
+
try:
|
| 259 |
+
os.fsync(dir_fd)
|
| 260 |
+
except OSError:
|
| 261 |
+
pass # not all filesystems support directory fsync (e.g. some FUSE)
|
| 262 |
+
finally:
|
| 263 |
+
os.close(dir_fd)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# Codec version β bumped whenever the algorithm changes in a way that would
|
| 267 |
+
# make older checkpoints invalid (e.g. depth-power mapping change, scale
|
| 268 |
+
# codebook range change, group-size change). Used by the fingerprint validator.
|
| 269 |
+
CODEC_VERSION = 'v2.0'
|
| 270 |
+
|
| 271 |
+
def codec_fingerprint(model_id, revision, depth_powers, group_size, codec_version):
|
| 272 |
+
"""Stable string that identifies the algorithmic state behind a checkpoint.
|
| 273 |
+
|
| 274 |
+
Two checkpoints with the same fingerprint can be safely interleaved.
|
| 275 |
+
Two with different fingerprints must not be mixed β a mismatch on resume
|
| 276 |
+
causes the stale checkpoint to be discarded and re-quantized.
|
| 277 |
+
"""
|
| 278 |
+
parts = [
|
| 279 |
+
f'codec={codec_version}',
|
| 280 |
+
f'model={model_id}',
|
| 281 |
+
f'revision={revision or "unspecified"}',
|
| 282 |
+
f'gs={group_size}',
|
| 283 |
+
f'powers=' + ','.join(f'{d}:{p}' for d, p in sorted(depth_powers.items())),
|
| 284 |
+
]
|
| 285 |
+
return '|'.join(parts)
|
| 286 |
|
| 287 |
def load_ckpt(path):
|
| 288 |
with np.load(path, allow_pickle=True) as z:
|
|
|
|
| 311 |
parser.add_argument('--matrix-range', default=None,
|
| 312 |
help='Slice of matrices to process: "start:end" (0-indexed, end exclusive). '
|
| 313 |
'Use to manually parallelize across processes/machines via shared checkpoint dir.')
|
| 314 |
+
parser.add_argument('--revision', default=None,
|
| 315 |
+
help='HuggingFace revision (commit SHA or tag) to pin the source model. '
|
| 316 |
+
'Recommended for reproducibility β without it, the upstream repo can move under you.')
|
| 317 |
args = parser.parse_args()
|
| 318 |
|
| 319 |
config_names = [c.strip() for c in args.configs.split(',') if c.strip()]
|
|
|
|
| 335 |
dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16
|
| 336 |
print(' loading model (CPU)...', flush=True)
|
| 337 |
t_load = time.time()
|
| 338 |
+
_cfg = AutoConfig.from_pretrained(args.model, revision=args.revision, trust_remote_code=True)
|
| 339 |
_arch = ((getattr(_cfg, 'architectures', None) or [''])[0] or '').lower()
|
| 340 |
if 't5' in _arch or 'encoder' in _arch:
|
| 341 |
from transformers import T5EncoderModel
|
| 342 |
print(' loading as T5EncoderModel (encoder-only)', flush=True)
|
| 343 |
+
model = T5EncoderModel.from_pretrained(args.model, revision=args.revision, torch_dtype=dtype,
|
| 344 |
device_map='cpu', trust_remote_code=True,
|
| 345 |
low_cpu_mem_usage=True)
|
| 346 |
else:
|
| 347 |
try:
|
| 348 |
+
model = AutoModelForCausalLM.from_pretrained(args.model, revision=args.revision, torch_dtype=dtype,
|
| 349 |
device_map='cpu', trust_remote_code=True,
|
| 350 |
low_cpu_mem_usage=True)
|
| 351 |
except ValueError:
|
| 352 |
print(' fallback to generic AutoModel', flush=True)
|
| 353 |
+
model = AutoModel.from_pretrained(args.model, revision=args.revision, torch_dtype=dtype,
|
| 354 |
device_map='cpu', trust_remote_code=True,
|
| 355 |
low_cpu_mem_usage=True)
|
| 356 |
try:
|
| 357 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model, revision=args.revision, trust_remote_code=True)
|
| 358 |
if tokenizer.pad_token is None:
|
| 359 |
tokenizer.pad_token = tokenizer.eos_token
|
| 360 |
except Exception as e:
|
|
|
|
| 379 |
range_end = min(range_end, len(matrices))
|
| 380 |
print(f' matrix-range: [{range_start}:{range_end})', flush=True)
|
| 381 |
|
| 382 |
+
# Codec fingerprint for this run β used to validate resumed checkpoints.
|
| 383 |
+
expected_fp = codec_fingerprint(args.model, args.revision, DEPTH_POWERS, GS, CODEC_VERSION)
|
| 384 |
+
|
| 385 |
# Determine which need work (resume from checkpoints)
|
| 386 |
todo = []
|
| 387 |
done_count = 0
|
| 388 |
+
discarded_count = 0
|
| 389 |
for idx, (pn, p) in enumerate(matrices):
|
| 390 |
if idx < range_start or idx >= range_end:
|
| 391 |
continue
|
| 392 |
cp = matrix_ckpt_path(ckpt_dir, idx, pn)
|
| 393 |
if os.path.exists(cp):
|
|
|
|
| 394 |
try:
|
| 395 |
z = np.load(cp, allow_pickle=True)
|
| 396 |
+
meta = json.loads(str(z['_meta'][()]))
|
| 397 |
+
# Validate: configs cover requested set, fingerprint matches, shape matches.
|
| 398 |
+
have_configs = set(meta.get('configs', []))
|
| 399 |
+
ckpt_fp = meta.get('fingerprint')
|
| 400 |
+
ckpt_shape = tuple(meta.get('shape', ()))
|
| 401 |
+
cur_shape = tuple(p.shape)
|
| 402 |
+
if all(cn in have_configs for cn in config_names) \
|
| 403 |
+
and ckpt_fp == expected_fp \
|
| 404 |
+
and ckpt_shape == cur_shape:
|
| 405 |
done_count += 1
|
| 406 |
continue
|
| 407 |
+
if ckpt_fp != expected_fp:
|
| 408 |
+
print(f' fingerprint mismatch on {cp}: stale={ckpt_fp!r} expected={expected_fp!r} β discarding', flush=True)
|
| 409 |
+
elif ckpt_shape != cur_shape:
|
| 410 |
+
print(f' shape mismatch on {cp}: stale={ckpt_shape} current={cur_shape} β discarding', flush=True)
|
| 411 |
+
else:
|
| 412 |
+
print(f' missing configs in {cp}: have={have_configs}, need={config_names} β redoing', flush=True)
|
| 413 |
+
discarded_count += 1
|
| 414 |
+
os.remove(cp)
|
| 415 |
except Exception as e:
|
| 416 |
print(f' bad checkpoint {cp}: {e}, will redo', flush=True)
|
| 417 |
os.remove(cp)
|
| 418 |
todo.append((idx, pn, p))
|
| 419 |
+
if discarded_count:
|
| 420 |
+
print(f' discarded {discarded_count} stale checkpoint(s)', flush=True)
|
| 421 |
print(f' {done_count} matrices already checkpointed, {len(todo)} to do', flush=True)
|
| 422 |
|
| 423 |
t0 = time.time()
|
|
|
|
| 433 |
w = p.data.float().numpy()
|
| 434 |
result = quantize_matrix_multi(
|
| 435 |
(w.ravel(), w.shape[0], w.shape[1], config_names))
|
| 436 |
+
# Pack into npz: one key per config + meta (with codec fingerprint
|
| 437 |
+
# so a future resume can detect a stale checkpoint and discard it).
|
| 438 |
save_data = {'_meta': np.array(json.dumps({
|
| 439 |
'name': pn, 'idx': idx, 'shape': list(w.shape),
|
| 440 |
'configs': config_names,
|
| 441 |
+
'fingerprint': expected_fp,
|
| 442 |
}))}
|
| 443 |
for cn, info in result.items():
|
| 444 |
save_data[f'{cn}__w'] = info['recon_w']
|
|
|
|
| 456 |
# CRITICAL: do NOT pre-build all matrices in a list β for large models
|
| 457 |
# (Llama 70B = 140GB) that OOMs the box at multiple hundred GB. The generator
|
| 458 |
# is consumed lazily by Pool.imap.
|
| 459 |
+
idx_name = [(idx, pn, list(p.shape)) for idx, pn, p in todo]
|
| 460 |
def gen():
|
| 461 |
for idx, pn, p in todo:
|
| 462 |
w = p.data.float().numpy()
|
|
|
|
| 466 |
p.data = __import__('torch').zeros(1, dtype=p.dtype)
|
| 467 |
with Pool(args.workers) as pool:
|
| 468 |
for i, result in enumerate(pool.imap(quantize_matrix_multi, gen(), chunksize=1)):
|
| 469 |
+
idx, pn, shape = idx_name[i]
|
| 470 |
save_data = {'_meta': np.array(json.dumps({
|
| 471 |
+
'name': pn, 'idx': idx, 'shape': shape,
|
| 472 |
'configs': config_names,
|
| 473 |
+
'fingerprint': expected_fp,
|
| 474 |
}))}
|
| 475 |
for cn, info in result.items():
|
| 476 |
save_data[f'{cn}__w'] = info['recon_w']
|
|
|
|
| 557 |
|
| 558 |
config = {
|
| 559 |
'model': os.path.basename(args.model.rstrip('/')),
|
| 560 |
+
'model_revision': args.revision,
|
| 561 |
+
'codec_version': CODEC_VERSION,
|
| 562 |
+
'codec_fingerprint': expected_fp,
|
| 563 |
'codec': cn,
|
| 564 |
'bpw': total_bpw, 'trit_bpw': trit_bpw, 'scale_bpw': scale_bpw,
|
| 565 |
'depth_pcts': {str(d): total_depth[d]/tg for d in [1,2,3,4]},
|
| 566 |
'n_matrices': len(matrices),
|
| 567 |
'group_size': GS,
|
| 568 |
+
'fp16_layers': ['lm_head', 'embed_tokens', '*_norm'],
|
| 569 |
'codec_params': CODECS[cn],
|
| 570 |
}
|
| 571 |
with open(os.path.join(cfg_dir, 'config.json'), 'w') as f:
|