"""Patch a Gemma-4 SFT checkpoint that is missing some weights.

DeepSpeed ZeRO3 sometimes drops sliding-window-layer K/V weights when saving.
This script copies the missing weights from the base model into the SFT
checkpoint, producing a complete consolidated model.safetensors (and removing
any stale model.safetensors.index.json that would shadow it) so the model can
be loaded by vLLM.

Usage:
    python3 patch_gemma_checkpoint.py \
        --base /path/to/Gemma-4-E4B-it \
        --sft  /path/to/sft/output_dir \
        [--out /path/to/output_dir]   # default: patch in place under --sft
"""
| import argparse |
| import json |
| import os |
| import shutil |
| from pathlib import Path |
|
|
| import torch |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
|
|
|
|
| def main(): |
| ap = argparse.ArgumentParser() |
| ap.add_argument("--base", required=True, help="path to base model dir") |
| ap.add_argument("--sft", required=True, help="path to sft model dir") |
| ap.add_argument("--out", default=None, help="output dir (default: --sft)") |
| args = ap.parse_args() |
|
|
| base_dir = Path(args.base) |
| sft_dir = Path(args.sft) |
| out_dir = Path(args.out) if args.out else sft_dir |
| out_dir.mkdir(parents=True, exist_ok=True) |
|
|
| base_st = base_dir / "model.safetensors" |
| sft_st = sft_dir / "model.safetensors" |
| if not base_st.exists(): |
| |
| base_st_index = base_dir / "model.safetensors.index.json" |
| assert base_st_index.exists(), f"missing {base_st} or its index" |
| if not sft_st.exists(): |
| sft_st_index = sft_dir / "model.safetensors.index.json" |
| assert sft_st_index.exists(), f"missing {sft_st} or its index" |
|
|
| |
| print(f"[1/4] loading SFT weights from {sft_dir}...") |
| sft_tensors = {} |
| sft_index_file = sft_dir / "model.safetensors.index.json" |
| if sft_st.exists(): |
| with safe_open(sft_st, framework="pt") as f: |
| for k in f.keys(): |
| sft_tensors[k] = f.get_tensor(k) |
| else: |
| idx = json.loads(sft_index_file.read_text()) |
| for shard in set(idx["weight_map"].values()): |
| with safe_open(sft_dir / shard, framework="pt") as f: |
| for k in f.keys(): |
| sft_tensors[k] = f.get_tensor(k) |
|
|
| |
| print(f"[2/4] scanning base weights from {base_dir}...") |
| base_keys_to_files = {} |
| if base_st.exists(): |
| with safe_open(base_st, framework="pt") as f: |
| for k in f.keys(): |
| base_keys_to_files[k] = base_st |
| else: |
| idx = json.loads((base_dir / "model.safetensors.index.json").read_text()) |
| for k, shard in idx["weight_map"].items(): |
| base_keys_to_files[k] = base_dir / shard |
|
|
| base_keys = set(base_keys_to_files.keys()) |
| sft_keys = set(sft_tensors.keys()) |
| missing = sorted(base_keys - sft_keys) |
| extra = sorted(sft_keys - base_keys) |
|
|
| print(f" base keys: {len(base_keys)}") |
| print(f" sft keys: {len(sft_keys)}") |
| print(f" missing in sft: {len(missing)} (will copy from base)") |
| print(f" extra in sft : {len(extra)} (kept as-is)") |
|
|
| if not missing: |
| print("[OK] nothing to patch; sft is already complete") |
| return |
|
|
| |
| print(f"[3/4] copying {len(missing)} missing weights from base...") |
| by_shard = {} |
| for k in missing: |
| by_shard.setdefault(base_keys_to_files[k], []).append(k) |
|
|
| for shard_path, keys in by_shard.items(): |
| with safe_open(shard_path, framework="pt") as f: |
| for k in keys: |
| t = f.get_tensor(k) |
| if t.dtype != torch.bfloat16: |
| t = t.to(torch.bfloat16) |
| sft_tensors[k] = t |
|
|
| |
| out_path = out_dir / "model.safetensors" |
| print(f"[4/4] writing patched checkpoint -> {out_path}") |
| |
| if out_dir == sft_dir: |
| for stale in [out_dir / "model.safetensors.index.json"]: |
| if stale.exists(): |
| print(f" removing stale {stale}") |
| stale.unlink() |
| save_file(sft_tensors, str(out_path), metadata={"format": "pt"}) |
| print(f"[OK] saved {len(sft_tensors)} tensors to {out_path}") |
| print(f" size: {out_path.stat().st_size / 1e9:.2f} GB") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|