#!/usr/bin/env python3
""" Checkpoint Averaging Script
This script averages the model weights of all checkpoints in the specified path(s)
that match the given suffix wildcard. All checkpoints must come from the exact same
model. For any hope of decent results, the checkpoints should also come from the
same (or a resumed child) training session. The result is similar to maintaining a
running EMA (exponential moving average) of the model weights or performing SWA
(stochastic weight averaging), but applied post-training.
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import copy
import glob
import os
import re

import torch
import torch.nn as nn
from safetensors.torch import load_file as safetensors_load_file
from safetensors.torch import save_file as safetensors_save_file

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default=[], nargs='+', type=str, metavar='PATHS',
                    help='path(s)/glob pattern(s) of input folders containing checkpoints')
parser.add_argument('--output', type=str, default='avgmodel.pt', metavar='PATH',
                    help='output file name of the averaged checkpoint')
parser.add_argument('--suffix', default='', type=str, metavar='WILDCARD',
                    help='checkpoint suffix appended to each input pattern')
parser.add_argument('--min', type=int, default=500,
                    help='minimum iteration of checkpoints to average')
parser.add_argument('--max', type=int, default=-1,
                    help='maximum iteration of checkpoints to average (-1 = no upper bound)')
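
# Helper used by main() to apply the --min/--max iteration bounds. This is a
# minimal sketch that assumes the training iteration appears as the last run
# of digits in the checkpoint file name (e.g. 'checkpoint-1500.pt' -> 1500);
# adjust the parsing rule if your naming scheme differs.
def _iteration_of(filename):
    digits = re.findall(r'\d+', os.path.basename(filename))
    return int(digits[-1]) if digits else None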

def main():
    args = parser.parse_args()
    patterns = args.input
    sel_checkpoint_filenames = []
    for pattern in patterns:
        # Turn each input path into a glob pattern ending in the suffix.
        if not args.suffix.startswith('*'):
            pattern += '*'
        pattern += args.suffix
        checkpoint_filenames = glob.glob(pattern, recursive=True)
        # Apply the --min/--max iteration bounds; files without a parsable
        # iteration number are kept as-is.
        checkpoint_filenames = [
            f for f in checkpoint_filenames
            if (it := _iteration_of(f)) is None
            or (it >= args.min and (args.max < 0 or it <= args.max))
        ]
        if len(checkpoint_filenames) == 0:
            print("WARNING: No checkpoints matching '{}' with iteration >= {}".format(
                pattern, args.min))
            continue
        sel_checkpoint_filenames += checkpoint_filenames
    if not sel_checkpoint_filenames:
        print('Error: no checkpoints found to average')
        return

    avg_ckpt = {}
    avg_counts = {}
    for c in sel_checkpoint_filenames:
        if c.endswith(".safetensors"):
            checkpoint = safetensors_load_file(c)
        else:
            # weights_only=False: these checkpoints may contain pickled
            # nn.Module objects (PyTorch >= 2.6 defaults to weights_only=True).
            checkpoint = torch.load(c, map_location='cpu', weights_only=False)
        print(c)
        for k in checkpoint:
            # Skip EMA weights.
            if k.startswith("model_ema."):
                continue
            if k not in avg_ckpt:
                # First occurrence of this key: take it as the running sum.
                avg_ckpt[k] = checkpoint[k]
                print(f"Copy {k}")
                avg_counts[k] = 1
            elif isinstance(checkpoint[k], nn.Module):
                # Another occurrence of a previously seen nn.Module:
                # accumulate each of its parameters/buffers.
                avg_state_dict = avg_ckpt[k].state_dict()
                for m_k, m_v in checkpoint[k].state_dict().items():
                    if m_k not in avg_state_dict:
                        avg_state_dict[m_k] = copy.copy(m_v)
                        print(f"Copy {k}.{m_k}")
                    else:
                        avg_state_dict[m_k].data += m_v
                        print(f"Accumulate {k}.{m_k}")
                avg_ckpt[k].load_state_dict(avg_state_dict)
                avg_counts[k] += 1
            elif isinstance(checkpoint[k], (nn.Parameter, torch.Tensor)):
                # Another occurrence of a previously seen tensor/parameter
                # (safetensors yields torch.Tensor rather than nn.Parameter).
                avg_ckpt[k].data += checkpoint[k].data
                avg_counts[k] += 1
            else:
                print(f"NOT accumulating {type(checkpoint[k])}: {k}")

    for k in avg_ckpt:
        if isinstance(avg_ckpt[k], (nn.Parameter, torch.Tensor)):
            print(f"Averaging nn.Parameter: {k}")
            # Divide the running sum by the count, casting back so integer
            # buffers (e.g. num_batches_tracked) keep their original dtype.
            avg_ckpt[k].data = (avg_ckpt[k].data / avg_counts[k]).to(avg_ckpt[k].data.dtype)
        elif isinstance(avg_ckpt[k], nn.Module):
            print(f"Averaging nn.Module: {k}")
            avg_state_dict = avg_ckpt[k].state_dict()
            for m_k, m_v in avg_state_dict.items():
                m_v.data = (m_v.data / avg_counts[k]).to(m_v.data.dtype)
            avg_ckpt[k].load_state_dict(avg_state_dict)
        else:
            print(f"NOT averaging {type(avg_ckpt[k])}: {k}")
    if args.output.endswith(".safetensors"):
        safetensors_save_file(avg_ckpt, args.output)
    else:
        torch.save(avg_ckpt, args.output)
    print("=> Saved averaged checkpoint to '{}'".format(args.output))


if __name__ == '__main__':
    main()