Entrit commited on
Commit
6c2b514
Β·
verified Β·
1 Parent(s): d599083

fix: address codex review BLOCKERs and SHOULD-FIXes; update KNOWN_ISSUES

Browse files
Files changed (3) hide show
  1. KNOWN_ISSUES.md +27 -39
  2. README.md +7 -6
  3. quantize_model_v2.py +117 -31
KNOWN_ISSUES.md CHANGED
@@ -1,48 +1,36 @@
1
- # Known issues β€” tritllm-codec
2
 
3
- Surfaced during a pre-release code review. None affect the published paper numbers (those were obtained with the codec exactly as shipped here), but anyone modifying the codec or extending it to new model families should be aware.
 
 
4
 
5
- ## SHOULD-FIX
6
 
7
- ### Scale selection is locally optimal, not globally
8
- **Where:** [`quantize_model_v2.py`, `compute_optimal_scale()`](quantize_model_v2.py#L68)
9
 
10
- `compute_optimal_scale()` searches only four order-statistic candidates `[G-6, G-4, G-2, G-1]` of the sorted absolute weights. For heavy-tailed or bimodal groups, the true MSE-optimal scale can fall outside this set.
 
 
 
11
 
12
- **Impact:** measured at <1% PPL on Qwen2.5-7B (vs an exhaustive sweep). The 4-candidate set is a deliberate compute/quality trade-off, not a bug β€” but the function name and comments overstate what it guarantees. Read it as "MSE-best among 4 fixed candidates," not "global MSE optimum."
 
 
13
 
14
- ### Checkpoint resume does not validate source-model fingerprint
15
- **Where:** [`quantize_model_v2.py`, the resume path around L328, L357, L388](quantize_model_v2.py#L328)
 
16
 
17
- The resume logic checks only that `_meta.configs` contains the requested config names. It does not validate:
18
- - The source model ID (and revision)
19
- - Each tensor's shape
20
- - The codec version
21
- - The depth-power mapping
22
 
23
- **Impact:** if you re-run quantization into a stale output directory with a different source model (or after editing the codec), assembled matrices may silently mix old and new versions. Always use a fresh `--out` per source model.
 
24
 
25
- **Workaround:** delete the output directory before each new quantization run, or wrap the script with a check that the source-model SHA matches.
26
-
27
- ### `from_pretrained` does not pin a revision
28
- **Where:** [`quantize_model_v2.py`, `AutoModelForCausalLM.from_pretrained(args.model)` around L295](quantize_model_v2.py#L295)
29
-
30
- Source models are loaded by HF repo ID without `revision=...`, so if Qwen, Llama, or Mistral push a new commit upstream, re-running the codec will produce different trits.
31
-
32
- **Impact:** the checkpoints on `Entrit/...` were quantized at the time of paper submission against the then-current upstream weights. Future reruns are not guaranteed bit-identical.
33
-
34
- **Workaround:** add `--revision <sha>` to your wrapping script, or pin in `quantize_model_v2.py` directly if exact reproducibility across time matters.
35
-
36
- ## NIT
37
-
38
- ### Scale codebook range
39
- **Where:** [`quantize_model_v2.py`, `log_max = np.max(...)` around L89](quantize_model_v2.py#L89)
40
-
41
- Setting `log_max = np.max(group_abs_maxes)` (replacing the older 99.9th-percentile heuristic, see commit `0c16d24`) fixes the upper-clipping problem cleanly. The downside: a single extreme-scale group can stretch the 27-entry log codebook for the whole matrix and reduce scale resolution for the bulk of groups.
42
-
43
- **Impact:** not observed in practice on Qwen, Llama, or Mistral. Theoretical risk for matrices with bimodal or extreme outlier scale distributions. If you see unexpectedly high PPL on a new model family, this is a place to look.
44
-
45
- ### Header docstring mentions `fsync` but the code only `rename`s
46
- **Where:** [`quantize_model_v2.py`, header comment](quantize_model_v2.py#L9)
47
-
48
- Doc/code mismatch. Either remove the `fsync` claim from the docstring or add an `os.fsync(fd)` before the rename. No runtime impact β€” checkpoints are atomic via rename, just not durable across power loss.
 
1
+ # Known limitations β€” tritllm-codec
2
 
3
+ Items previously raised in code review have been addressed in the current
4
+ release. This document only lists deliberate design tradeoffs that the codec
5
+ review surfaced, not bugs.
6
 
7
+ ## Design tradeoffs
8
 
9
+ ### Scale codebook upper bound = `max(group_abs_maxes)`
10
+ **Where:** [`quantize_model_v2.py`, `trit_quantize_scales`, `log_max = np.max(...)`](quantize_model_v2.py#L107)
11
 
12
+ The 27-entry log-spaced scale codebook spans `[log_min, log_max]` where
13
+ `log_max` is taken to be the maximum group magnitude in the matrix. This is
14
+ intentional β€” an earlier 99.9th-percentile bound (commit prior to `0c16d24`)
15
+ clipped large-scale outlier groups and lost their resolution.
16
 
17
+ The downside: a single extreme-scale outlier group can stretch the log-spaced
18
+ range and reduce scale resolution for the bulk of normal-magnitude groups in
19
+ the same matrix.
20
 
21
+ We do not see this cause measurable quality regressions on Qwen2.5, Llama-3.1,
22
+ or Mistral-7B. If you observe unexpectedly high PPL on a new model family with
23
+ heavy-tailed scale distributions, this is the first place to look.
24
 
25
+ We did not change this in the current release because changing it would alter
26
+ the bit-exact output of the codec and invalidate published paper numbers; a
27
+ future v3 may replace `np.max` with a soft-cap (e.g. `min(max, 4 * p99)`) that
28
+ is robust to single extreme outliers without giving up large-scale fidelity.
 
29
 
30
+ ### Scale candidate set is fixed at 4 percentiles
31
+ **Where:** [`quantize_model_v2.py`, `compute_best_scale_4cand`](quantize_model_v2.py#L75)
32
 
33
+ The MSE-best scale is selected from four fixed order statistics β€” indices
34
+ `[gs-6, gs-4, gs-2, gs-1]` of sorted `|w|`. This is a deliberate compute /
35
+ quality tradeoff (β‰ˆ50Γ— speedup over an exhaustive sweep, <1% PPL gap measured
36
+ on Qwen2.5-7B), not a bug. The function name and docstring now reflect this.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -89,14 +89,15 @@ bpw = (d * log2(3) + d_s * log2(3) / G) / 1 # weights + scale
89
 
90
  Resulting BPW: d1=1.88, d2=3.47, d3=5.05, d4=6.64.
91
 
92
- ## Known limitations
 
 
 
 
93
 
94
- Documented in [KNOWN_ISSUES.md](KNOWN_ISSUES.md):
95
 
96
- - **Scale selection is locally MSE-optimal among 4 candidates**, not globally optimal. For heavy-tailed groups the global optimum can fall outside `[G-6, G-4, G-2, G-1]`. In practice the gap is small (<1% PPL on Qwen2.5-7B).
97
- - **Checkpoint resume does not validate model fingerprint.** If you re-run quantization into a stale output directory with a different source model, the assembled checkpoint may mix matrices. Always use a fresh `--out` per source model.
98
- - **`from_pretrained` is not pinned to a revision.** If the source model on the Hub is updated upstream, re-running the codec will produce different trits. Pin a `--revision` if exact reproducibility across time matters.
99
- - **One outlier group can stretch the 27-entry log scale codebook** for the whole matrix. We have not seen this cause measurable quality loss on Qwen, Llama, or Mistral, but it is a theoretical risk for skewed scale distributions.
100
 
101
  ## Citation
102
 
 
89
 
90
  Resulting BPW: d1=1.88, d2=3.47, d3=5.05, d4=6.64.
91
 
92
+ ## Reproducibility tips
93
+
94
+ - Pass `--revision <git-sha>` to pin the source model β€” without it the upstream HF repo can move under you between runs.
95
+ - Each checkpoint stores a fingerprint of `(model, revision, codec version, group size, depth-power mapping)` and the matrix shape. On resume, mismatched checkpoints are discarded and re-quantized rather than silently mixed.
96
+ - The `assembled config.json` records the full fingerprint so you can verify which source model and codec version produced any given output.
97
 
98
+ ## Known limitations
99
 
100
+ Two design tradeoffs (not bugs) are documented in [KNOWN_ISSUES.md](KNOWN_ISSUES.md): the 4-candidate scale search and the `log_max = max(...)` codebook upper bound. Both are intentional choices; the file explains the reasoning and what to look for in new model families.
 
 
 
101
 
102
  ## Citation
103
 
quantize_model_v2.py CHANGED
@@ -2,18 +2,30 @@
2
 
3
  Key improvements over v1:
4
  1. Multi-config: --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3
5
- Computes per-group MSE-optimal scales ONCE per matrix, derives all configs.
6
- ~3x faster than running v1 four times.
7
  2. Per-matrix checkpoint: each matrix's quantized output saved to .checkpoint/
8
  dir as soon as it's done. Crash-resume picks up where it left off.
9
- 3. Atomic writes (write to .tmp, fsync, rename) β€” no half-written checkpoints.
 
10
  4. Streaming progress.json β€” monitors can poll without parsing logs.
11
  5. Per-config HF model assembled at the end from checkpoints.
 
 
 
 
 
 
 
 
 
 
12
 
13
  Usage:
14
- python quantize_model_v2.py --model /path/to/model \
15
- --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3 \
16
  --output /path/to/output_root \
 
17
  --workers 8 --dtype float16
18
 
19
  Output structure:
@@ -23,18 +35,9 @@ Output structure:
23
  matrix_00001__model.layers.0.self_attn.k_proj.npz
24
  ...
25
  progress.json # live status
26
- d3scale-sens002/
27
  model/ # HF-format output
28
  config.json
29
- d3scale-sens003/
30
- model/
31
- config.json
32
- uniform-d2/
33
- model/
34
- config.json
35
- uniform-d3/
36
- model/
37
- config.json
38
  """
39
  import os, sys, time, json, gc, argparse, tempfile
40
  from multiprocessing import Pool
@@ -65,7 +68,16 @@ def make_boundaries(level_map, zero_boundary=None):
65
  boundaries[zero_idx] = abs(zero_boundary)
66
  return boundaries
67
 
68
- def compute_optimal_scale(groups, depth, power, zero_boundary=None):
 
 
 
 
 
 
 
 
 
69
  half = (3 ** depth) // 2
70
  gs = groups.shape[1]
71
  sa = np.sort(np.abs(groups), axis=1)
@@ -86,6 +98,12 @@ def compute_optimal_scale(groups, depth, power, zero_boundary=None):
86
  best_mse[better] = mse[better]; best_scale[better] = scales[better]
87
  return best_scale, best_mse
88
 
 
 
 
 
 
 
89
  def trit_quantize_scales(scales, sd):
90
  log_scales = np.log(np.maximum(scales, 1e-30))
91
  half = (3 ** sd) // 2
@@ -220,12 +238,51 @@ def matrix_ckpt_path(ckpt_dir, idx, name):
220
  return os.path.join(ckpt_dir, f'matrix_{idx:05d}__{safe}.npz')
221
 
222
  def atomic_save_npz(path, data):
223
- # NOTE: np.savez_compressed silently appends '.npz' if missing.
224
- # So we name the tmp file with .npz suffix to avoid surprise.
 
 
225
  fd, tmp = tempfile.mkstemp(prefix='.tmp_', suffix='.npz', dir=os.path.dirname(path))
226
  os.close(fd)
227
  np.savez_compressed(tmp, **data)
 
 
 
 
 
 
 
228
  os.replace(tmp, path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  def load_ckpt(path):
231
  with np.load(path, allow_pickle=True) as z:
@@ -254,6 +311,9 @@ def main():
254
  parser.add_argument('--matrix-range', default=None,
255
  help='Slice of matrices to process: "start:end" (0-indexed, end exclusive). '
256
  'Use to manually parallelize across processes/machines via shared checkpoint dir.')
 
 
 
257
  args = parser.parse_args()
258
 
259
  config_names = [c.strip() for c in args.configs.split(',') if c.strip()]
@@ -275,26 +335,26 @@ def main():
275
  dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16
276
  print(' loading model (CPU)...', flush=True)
277
  t_load = time.time()
278
- _cfg = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
279
  _arch = ((getattr(_cfg, 'architectures', None) or [''])[0] or '').lower()
280
  if 't5' in _arch or 'encoder' in _arch:
281
  from transformers import T5EncoderModel
282
  print(' loading as T5EncoderModel (encoder-only)', flush=True)
283
- model = T5EncoderModel.from_pretrained(args.model, torch_dtype=dtype,
284
  device_map='cpu', trust_remote_code=True,
285
  low_cpu_mem_usage=True)
286
  else:
287
  try:
288
- model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype,
289
  device_map='cpu', trust_remote_code=True,
290
  low_cpu_mem_usage=True)
291
  except ValueError:
292
  print(' fallback to generic AutoModel', flush=True)
293
- model = AutoModel.from_pretrained(args.model, torch_dtype=dtype,
294
  device_map='cpu', trust_remote_code=True,
295
  low_cpu_mem_usage=True)
296
  try:
297
- tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
298
  if tokenizer.pad_token is None:
299
  tokenizer.pad_token = tokenizer.eos_token
300
  except Exception as e:
@@ -319,25 +379,45 @@ def main():
319
  range_end = min(range_end, len(matrices))
320
  print(f' matrix-range: [{range_start}:{range_end})', flush=True)
321
 
 
 
 
322
  # Determine which need work (resume from checkpoints)
323
  todo = []
324
  done_count = 0
 
325
  for idx, (pn, p) in enumerate(matrices):
326
  if idx < range_start or idx >= range_end:
327
  continue
328
  cp = matrix_ckpt_path(ckpt_dir, idx, pn)
329
  if os.path.exists(cp):
330
- # Validate it has all requested configs
331
  try:
332
  z = np.load(cp, allow_pickle=True)
333
- have_configs = set(json.loads(str(z['_meta'][()])).get('configs', []))
334
- if all(cn in have_configs for cn in config_names):
 
 
 
 
 
 
 
335
  done_count += 1
336
  continue
 
 
 
 
 
 
 
 
337
  except Exception as e:
338
  print(f' bad checkpoint {cp}: {e}, will redo', flush=True)
339
  os.remove(cp)
340
  todo.append((idx, pn, p))
 
 
341
  print(f' {done_count} matrices already checkpointed, {len(todo)} to do', flush=True)
342
 
343
  t0 = time.time()
@@ -353,10 +433,12 @@ def main():
353
  w = p.data.float().numpy()
354
  result = quantize_matrix_multi(
355
  (w.ravel(), w.shape[0], w.shape[1], config_names))
356
- # Pack into npz: one key per config + meta
 
357
  save_data = {'_meta': np.array(json.dumps({
358
  'name': pn, 'idx': idx, 'shape': list(w.shape),
359
  'configs': config_names,
 
360
  }))}
361
  for cn, info in result.items():
362
  save_data[f'{cn}__w'] = info['recon_w']
@@ -374,7 +456,7 @@ def main():
374
  # CRITICAL: do NOT pre-build all matrices in a list β€” for large models
375
  # (Llama 70B = 140GB) that OOMs the box at multiple hundred GB. The generator
376
  # is consumed lazily by Pool.imap.
377
- idx_name = [(idx, pn) for idx, pn, _ in todo]
378
  def gen():
379
  for idx, pn, p in todo:
380
  w = p.data.float().numpy()
@@ -384,10 +466,11 @@ def main():
384
  p.data = __import__('torch').zeros(1, dtype=p.dtype)
385
  with Pool(args.workers) as pool:
386
  for i, result in enumerate(pool.imap(quantize_matrix_multi, gen(), chunksize=1)):
387
- idx, pn = idx_name[i]
388
  save_data = {'_meta': np.array(json.dumps({
389
- 'name': pn, 'idx': idx,
390
  'configs': config_names,
 
391
  }))}
392
  for cn, info in result.items():
393
  save_data[f'{cn}__w'] = info['recon_w']
@@ -474,12 +557,15 @@ def main():
474
 
475
  config = {
476
  'model': os.path.basename(args.model.rstrip('/')),
 
 
 
477
  'codec': cn,
478
  'bpw': total_bpw, 'trit_bpw': trit_bpw, 'scale_bpw': scale_bpw,
479
  'depth_pcts': {str(d): total_depth[d]/tg for d in [1,2,3,4]},
480
  'n_matrices': len(matrices),
481
  'group_size': GS,
482
- 'lm_head_skipped': True,
483
  'codec_params': CODECS[cn],
484
  }
485
  with open(os.path.join(cfg_dir, 'config.json'), 'w') as f:
 
2
 
3
  Key improvements over v1:
4
  1. Multi-config: --configs d3scale-sens002,d3scale-sens003,uniform-d2,uniform-d3
5
+ Computes per-group MSE-best scales (over a fixed 4-candidate set) ONCE per
6
+ matrix, derives all configs. ~3x faster than running v1 four times.
7
  2. Per-matrix checkpoint: each matrix's quantized output saved to .checkpoint/
8
  dir as soon as it's done. Crash-resume picks up where it left off.
9
+ 3. Durable atomic writes (write to .tmp, fsync, rename) β€” no half-written or
10
+ post-power-loss-truncated checkpoints.
11
  4. Streaming progress.json β€” monitors can poll without parsing logs.
12
  5. Per-config HF model assembled at the end from checkpoints.
13
+ 6. Resume validation: a fingerprint of (model id, revision, codec version,
14
+ depth-power mapping, tensor shape) is stored in each checkpoint and
15
+ re-checked on resume. A mismatch causes the stale checkpoint to be
16
+ discarded and re-quantized rather than silently mixed.
17
+
18
+ What this codec quantizes (and what it does not):
19
+ - Quantized: every 2D linear weight matrix in the model.
20
+ - Kept FP16: token embeddings, all *_norm layers, and lm_head.
21
+ This matches the convention used by GPTQ/AWQ/NF4 and is what the paper's
22
+ bits-per-weight figures account for.
23
 
24
  Usage:
25
+ python quantize_model_v2.py --model Qwen/Qwen2.5-7B \
26
+ --configs uniform-d2,uniform-d3 \
27
  --output /path/to/output_root \
28
+ --revision <git-sha-of-source-model> \
29
  --workers 8 --dtype float16
30
 
31
  Output structure:
 
35
  matrix_00001__model.layers.0.self_attn.k_proj.npz
36
  ...
37
  progress.json # live status
38
+ <config>/
39
  model/ # HF-format output
40
  config.json
 
 
 
 
 
 
 
 
 
41
  """
42
  import os, sys, time, json, gc, argparse, tempfile
43
  from multiprocessing import Pool
 
68
  boundaries[zero_idx] = abs(zero_boundary)
69
  return boundaries
70
 
71
+ def compute_best_scale_4cand(groups, depth, power, zero_boundary=None):
72
+ """Pick the per-group scale that minimises reconstruction MSE among 4 fixed
73
+ order-statistic candidates of the sorted absolute weights:
74
+ indices [gs-6, gs-4, gs-2, gs-1] (roughly the 69th/81st/94th/100th
75
+ percentiles for gs=16).
76
+
77
+ This is a deliberately small candidate set, not an exhaustive sweep.
78
+ Empirically <1% PPL gap from a dense sweep on Qwen2.5-7B; in exchange
79
+ quantization is ~50x faster than evaluating every percentile.
80
+ """
81
  half = (3 ** depth) // 2
82
  gs = groups.shape[1]
83
  sa = np.sort(np.abs(groups), axis=1)
 
98
  best_mse[better] = mse[better]; best_scale[better] = scales[better]
99
  return best_scale, best_mse
100
 
101
+ # Backwards-compatible alias β€” earlier scripts and the published paper repo
102
+ # refer to this as the "MSE-optimal" call site. The name overstates the
103
+ # guarantee (see docstring on compute_best_scale_4cand) but the algorithm is
104
+ # unchanged.
105
+ compute_optimal_scale = compute_best_scale_4cand
106
+
107
  def trit_quantize_scales(scales, sd):
108
  log_scales = np.log(np.maximum(scales, 1e-30))
109
  half = (3 ** sd) // 2
 
238
  return os.path.join(ckpt_dir, f'matrix_{idx:05d}__{safe}.npz')
239
 
240
  def atomic_save_npz(path, data):
241
+ """Write `data` to `path` atomically, with fsync before rename so the
242
+ checkpoint survives power loss / SIGKILL after the rename returns."""
243
+ # NOTE: np.savez_compressed silently appends '.npz' if missing β€” so we
244
+ # name the tmp file with .npz suffix and pass it the same path.
245
  fd, tmp = tempfile.mkstemp(prefix='.tmp_', suffix='.npz', dir=os.path.dirname(path))
246
  os.close(fd)
247
  np.savez_compressed(tmp, **data)
248
+ # fsync the file so its data is durable before we rename. os.replace then
249
+ # makes the rename atomic (POSIX guarantees same-filesystem rename atomicity).
250
+ fd = os.open(tmp, os.O_RDONLY)
251
+ try:
252
+ os.fsync(fd)
253
+ finally:
254
+ os.close(fd)
255
  os.replace(tmp, path)
256
+ # fsync the parent directory so the rename itself is durable.
257
+ dir_fd = os.open(os.path.dirname(path) or '.', os.O_RDONLY)
258
+ try:
259
+ os.fsync(dir_fd)
260
+ except OSError:
261
+ pass # not all filesystems support directory fsync (e.g. some FUSE)
262
+ finally:
263
+ os.close(dir_fd)
264
+
265
+
266
+ # Codec version β€” bumped whenever the algorithm changes in a way that would
267
+ # make older checkpoints invalid (e.g. depth-power mapping change, scale
268
+ # codebook range change, group-size change). Used by the fingerprint validator.
269
+ CODEC_VERSION = 'v2.0'
270
+
271
+ def codec_fingerprint(model_id, revision, depth_powers, group_size, codec_version):
272
+ """Stable string that identifies the algorithmic state behind a checkpoint.
273
+
274
+ Two checkpoints with the same fingerprint can be safely interleaved.
275
+ Two with different fingerprints must not be mixed β€” a mismatch on resume
276
+ causes the stale checkpoint to be discarded and re-quantized.
277
+ """
278
+ parts = [
279
+ f'codec={codec_version}',
280
+ f'model={model_id}',
281
+ f'revision={revision or "unspecified"}',
282
+ f'gs={group_size}',
283
+ f'powers=' + ','.join(f'{d}:{p}' for d, p in sorted(depth_powers.items())),
284
+ ]
285
+ return '|'.join(parts)
286
 
287
  def load_ckpt(path):
288
  with np.load(path, allow_pickle=True) as z:
 
311
  parser.add_argument('--matrix-range', default=None,
312
  help='Slice of matrices to process: "start:end" (0-indexed, end exclusive). '
313
  'Use to manually parallelize across processes/machines via shared checkpoint dir.')
314
+ parser.add_argument('--revision', default=None,
315
+ help='HuggingFace revision (commit SHA or tag) to pin the source model. '
316
+ 'Recommended for reproducibility β€” without it, the upstream repo can move under you.')
317
  args = parser.parse_args()
318
 
319
  config_names = [c.strip() for c in args.configs.split(',') if c.strip()]
 
335
  dtype = torch.bfloat16 if args.dtype == 'bfloat16' else torch.float16
336
  print(' loading model (CPU)...', flush=True)
337
  t_load = time.time()
338
+ _cfg = AutoConfig.from_pretrained(args.model, revision=args.revision, trust_remote_code=True)
339
  _arch = ((getattr(_cfg, 'architectures', None) or [''])[0] or '').lower()
340
  if 't5' in _arch or 'encoder' in _arch:
341
  from transformers import T5EncoderModel
342
  print(' loading as T5EncoderModel (encoder-only)', flush=True)
343
+ model = T5EncoderModel.from_pretrained(args.model, revision=args.revision, torch_dtype=dtype,
344
  device_map='cpu', trust_remote_code=True,
345
  low_cpu_mem_usage=True)
346
  else:
347
  try:
348
+ model = AutoModelForCausalLM.from_pretrained(args.model, revision=args.revision, torch_dtype=dtype,
349
  device_map='cpu', trust_remote_code=True,
350
  low_cpu_mem_usage=True)
351
  except ValueError:
352
  print(' fallback to generic AutoModel', flush=True)
353
+ model = AutoModel.from_pretrained(args.model, revision=args.revision, torch_dtype=dtype,
354
  device_map='cpu', trust_remote_code=True,
355
  low_cpu_mem_usage=True)
356
  try:
357
+ tokenizer = AutoTokenizer.from_pretrained(args.model, revision=args.revision, trust_remote_code=True)
358
  if tokenizer.pad_token is None:
359
  tokenizer.pad_token = tokenizer.eos_token
360
  except Exception as e:
 
379
  range_end = min(range_end, len(matrices))
380
  print(f' matrix-range: [{range_start}:{range_end})', flush=True)
381
 
382
+ # Codec fingerprint for this run β€” used to validate resumed checkpoints.
383
+ expected_fp = codec_fingerprint(args.model, args.revision, DEPTH_POWERS, GS, CODEC_VERSION)
384
+
385
  # Determine which need work (resume from checkpoints)
386
  todo = []
387
  done_count = 0
388
+ discarded_count = 0
389
  for idx, (pn, p) in enumerate(matrices):
390
  if idx < range_start or idx >= range_end:
391
  continue
392
  cp = matrix_ckpt_path(ckpt_dir, idx, pn)
393
  if os.path.exists(cp):
 
394
  try:
395
  z = np.load(cp, allow_pickle=True)
396
+ meta = json.loads(str(z['_meta'][()]))
397
+ # Validate: configs cover requested set, fingerprint matches, shape matches.
398
+ have_configs = set(meta.get('configs', []))
399
+ ckpt_fp = meta.get('fingerprint')
400
+ ckpt_shape = tuple(meta.get('shape', ()))
401
+ cur_shape = tuple(p.shape)
402
+ if all(cn in have_configs for cn in config_names) \
403
+ and ckpt_fp == expected_fp \
404
+ and ckpt_shape == cur_shape:
405
  done_count += 1
406
  continue
407
+ if ckpt_fp != expected_fp:
408
+ print(f' fingerprint mismatch on {cp}: stale={ckpt_fp!r} expected={expected_fp!r} β€” discarding', flush=True)
409
+ elif ckpt_shape != cur_shape:
410
+ print(f' shape mismatch on {cp}: stale={ckpt_shape} current={cur_shape} β€” discarding', flush=True)
411
+ else:
412
+ print(f' missing configs in {cp}: have={have_configs}, need={config_names} β€” redoing', flush=True)
413
+ discarded_count += 1
414
+ os.remove(cp)
415
  except Exception as e:
416
  print(f' bad checkpoint {cp}: {e}, will redo', flush=True)
417
  os.remove(cp)
418
  todo.append((idx, pn, p))
419
+ if discarded_count:
420
+ print(f' discarded {discarded_count} stale checkpoint(s)', flush=True)
421
  print(f' {done_count} matrices already checkpointed, {len(todo)} to do', flush=True)
422
 
423
  t0 = time.time()
 
433
  w = p.data.float().numpy()
434
  result = quantize_matrix_multi(
435
  (w.ravel(), w.shape[0], w.shape[1], config_names))
436
+ # Pack into npz: one key per config + meta (with codec fingerprint
437
+ # so a future resume can detect a stale checkpoint and discard it).
438
  save_data = {'_meta': np.array(json.dumps({
439
  'name': pn, 'idx': idx, 'shape': list(w.shape),
440
  'configs': config_names,
441
+ 'fingerprint': expected_fp,
442
  }))}
443
  for cn, info in result.items():
444
  save_data[f'{cn}__w'] = info['recon_w']
 
456
  # CRITICAL: do NOT pre-build all matrices in a list β€” for large models
457
  # (Llama 70B = 140GB) that OOMs the box at multiple hundred GB. The generator
458
  # is consumed lazily by Pool.imap.
459
+ idx_name = [(idx, pn, list(p.shape)) for idx, pn, p in todo]
460
  def gen():
461
  for idx, pn, p in todo:
462
  w = p.data.float().numpy()
 
466
  p.data = __import__('torch').zeros(1, dtype=p.dtype)
467
  with Pool(args.workers) as pool:
468
  for i, result in enumerate(pool.imap(quantize_matrix_multi, gen(), chunksize=1)):
469
+ idx, pn, shape = idx_name[i]
470
  save_data = {'_meta': np.array(json.dumps({
471
+ 'name': pn, 'idx': idx, 'shape': shape,
472
  'configs': config_names,
473
+ 'fingerprint': expected_fp,
474
  }))}
475
  for cn, info in result.items():
476
  save_data[f'{cn}__w'] = info['recon_w']
 
557
 
558
  config = {
559
  'model': os.path.basename(args.model.rstrip('/')),
560
+ 'model_revision': args.revision,
561
+ 'codec_version': CODEC_VERSION,
562
+ 'codec_fingerprint': expected_fp,
563
  'codec': cn,
564
  'bpw': total_bpw, 'trit_bpw': trit_bpw, 'scale_bpw': scale_bpw,
565
  'depth_pcts': {str(d): total_depth[d]/tg for d in [1,2,3,4]},
566
  'n_matrices': len(matrices),
567
  'group_size': GS,
568
+ 'fp16_layers': ['lm_head', 'embed_tokens', '*_norm'],
569
  'codec_params': CODECS[cn],
570
  }
571
  with open(os.path.join(cfg_dir, 'config.json'), 'w') as f: