#!/usr/bin/env python3
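"""Strip a training checkpoint down to an inference-only state dict.

Keeps only the model weights from a checkpoint that may also contain optimizer,
scaler, and bookkeeping state, and writes them either as a .pt dict with a single
'model' key or as a pure-state-dict .safetensors file (chosen by the --out extension).

Example invocation (the script filename is assumed from the log prefix used below;
paths are placeholders):

    python strip_checkpoint.py --in checkpoints/last.pt --out export/model.safetensors
"""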

import argparse
from pathlib import Path
import torch
from safetensors.torch import save_file as safetensors_save_file


def main():
    parser = argparse.ArgumentParser(
        description="Strip training checkpoint to inference-only weights (FP32)."
    )
    parser.add_argument("--in", dest="inp", type=str, required=True, help="Path to training checkpoint .pt")
    parser.add_argument("--out", dest="out", type=str, required=True, help="Path to save weights-only .pt or .safetensors")
    # Output format is inferred from --out extension
    args = parser.parse_args()

    in_path = Path(args.inp)
    out_path = Path(args.out)

    assert in_path.is_file(), f"Input file does not exist: {in_path}"
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Note: newer PyTorch releases default torch.load to weights_only=True; pass weights_only=False if the checkpoint holds non-tensor objects.
    ckpt = torch.load(str(in_path), map_location="cpu")

    # Primary (project) format: {'step', 'model', 'optim', 'scaler', 'best_f1'}
    if isinstance(ckpt, dict) and "model" in ckpt:
        state_dict = ckpt["model"]
    # Secondary common format: {'state_dict': model.state_dict(), ...}
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        # Fallback: checkpoint is already a pure state_dict
        assert isinstance(ckpt, dict) and all(isinstance(v, torch.Tensor) for v in ckpt.values()), (
            "Checkpoint is not a recognized format: expected keys 'model' or 'state_dict', "
            "or a pure state_dict (name->Tensor)."
        )
        state_dict = ckpt

    # In the future, the state dict could be cast to bfloat16 here if necessary, e.g.:
    # state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}

    suffix = out_path.suffix.lower()
    if suffix == ".safetensors":
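        # Note: safetensors' save_file rejects tensors that share storage (e.g. tied
        # weights); such tensors may need to be cloned before saving.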
        safetensors_save_file(state_dict, str(out_path))
        print(f"[strip_checkpoint] Saved safetensors (pure state_dict) to: {out_path}")
    else:
        to_save = {"model": state_dict}
        torch.save(to_save, str(out_path))
        print(f"[strip_checkpoint] Saved dict with only 'model' to: {out_path}")


if __name__ == "__main__":
    main()