""" Checkpoint Cleaning Script

Takes training checkpoints with GPU tensors, optimizer state, extra dict keys, etc.
and outputs a CPU tensor checkpoint with only the `state_dict`, along with a SHA256
hash for model zoo compatibility.

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""

import argparse
import hashlib
import os
import shutil

import torch

from timm.models import load_state_dict

# safetensors is an optional dependency, only required for --safetensors output
try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='output path')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                    help='do not use the EMA version of the weights even if present')
parser.add_argument('--no-hash', dest='no_hash', action='store_true',
                    help='no hash in output filename')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
                    help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
parser.add_argument('--safetensors', action='store_true',
                    help='save weights using safetensors instead of the default torch way (pickle)')


def main():
    args = parser.parse_args()

    if os.path.exists(args.output):
        print("Error: Output filename ({}) already exists.".format(args.output))
        exit(1)

    clean_checkpoint(
        args.checkpoint,
        args.output,
        not args.no_use_ema,
        args.no_hash,
        args.clean_aux_bn,
        safe_serialization=args.safetensors,
    )


def clean_checkpoint(
        checkpoint,
        output,
        use_ema=True,
        no_hash=False,
        clean_aux_bn=False,
        safe_serialization: bool = False,
):
    # Load an existing checkpoint to CPU, strip everything but the state_dict, and re-save.
    if checkpoint and os.path.isfile(checkpoint):
        print("=> Loading checkpoint '{}'".format(checkpoint))
        state_dict = load_state_dict(checkpoint, use_ema=use_ema)
        new_state_dict = {}
        for k, v in state_dict.items():
            if clean_aux_bn and 'aux_bn' in k:
                # If all aux_bn keys are removed, the SplitBN layers will end up as normal
                # and load with the unmodified model using BatchNorm2d.
                continue
            # strip the 'module.' prefix left by DataParallel/DistributedDataParallel
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        print("=> Loaded state_dict from '{}'".format(checkpoint))

        # derive the output root/base name and extension from --output, falling back
        # to the input checkpoint's name when no output path is given
        ext = ''
        if output:
            checkpoint_root, checkpoint_base = os.path.split(output)
            checkpoint_base, ext = os.path.splitext(checkpoint_base)
        else:
            checkpoint_root = ''
            checkpoint_base = os.path.split(checkpoint)[1]
            checkpoint_base = os.path.splitext(checkpoint_base)[0]

        # save to a temp file first so the SHA256 hash can be embedded in the final name
        temp_filename = '__' + checkpoint_base
        if safe_serialization:
            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
            safetensors.torch.save_file(new_state_dict, temp_filename)
        else:
            torch.save(new_state_dict, temp_filename)

        with open(temp_filename, 'rb') as f:
            sha_hash = hashlib.sha256(f.read()).hexdigest()

        if ext:
            final_ext = ext
        else:
            final_ext = '.safetensors' if safe_serialization else '.pth'

        if no_hash:
            final_filename = checkpoint_base + final_ext
        else:
            # append the first 8 hex chars of the SHA256, e.g. model-1a2b3c4d.pth
            final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext

        shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
        print("=> Saved state_dict to '{}', SHA256: {}".format(final_filename, sha_hash))
        return final_filename
    else:
        print("Error: Checkpoint ({}) doesn't exist".format(checkpoint))
        return ''


if __name__ == '__main__':
    main()