#!/usr/bin/env python3
""" Checkpoint Averaging Script

This script averages all model weights for checkpoints in the specified path(s) that
match the specified filter wildcard. All checkpoints must be from the exact same model.

For any hope of decent results, the checkpoints should be from the same or child
(via resumes) training session. This can be viewed as similar to maintaining a running
EMA (exponential moving average) of the model weights, or performing SWA (stochastic
weight averaging), but post-training.

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import torch
import torch.nn as nn
import argparse
import os
import glob
import hashlib
from collections import OrderedDict, defaultdict
import re
import copy

from safetensors.torch import load_file as safetensors_load_file
from safetensors.torch import save_file as safetensors_save_file

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', nargs="+", type=str, metavar='PATHS',
                    help='path(s) to base input folder containing checkpoints')
parser.add_argument('--output', type=str, default='avgmodel.pt', metavar='PATH',
                    help='output file name of the averaged checkpoint')
parser.add_argument('--suffix', default='', type=str, metavar='WILDCARD',
                    help='checkpoint suffix')
parser.add_argument('--min', type=int, default=500,
                    help='Minimal iteration of checkpoints to average')
parser.add_argument('--max', type=int, default=-1,
                    help='Maximum iteration of checkpoints to average')


def main():
    args = parser.parse_args()
    patterns = args.input
    sel_checkpoint_filenames = []
    for pattern in patterns:
        if args.suffix is not None:
            if not args.suffix.startswith('*'):
                pattern += '*'
            pattern += args.suffix
        checkpoint_filenames = glob.glob(pattern, recursive=True)
        if len(checkpoint_filenames) == 0:
            print("WARNING: No checkpoints matching '{}' and iteration >= {} in '{}'".format(
                pattern, args.min, args.input))
            continue
        sel_checkpoint_filenames += checkpoint_filenames

    avg_ckpt = {}
    avg_counts = {}
    for c in sel_checkpoint_filenames:
        # safetensors checkpoints are flat tensor dicts; everything else goes through torch.load.
        if c.endswith(".safetensors"):
            checkpoint = safetensors_load_file(c)
        else:
            checkpoint = torch.load(c, map_location='cpu')
        print(c)
        for k in checkpoint:
            # Skip ema weights
            if k.startswith("model_ema."):
                continue
            if k not in avg_ckpt:
                # First occurrence of this key: use it as the accumulator.
                avg_ckpt[k] = checkpoint[k]
                print(f"Copy {k}")
                avg_counts[k] = 1
            # Another occurrence of a previously seen nn.Module.
            elif isinstance(checkpoint[k], nn.Module):
                # print(f"nn.Module: {k}")
                avg_state_dict = avg_ckpt[k].state_dict()
                param_state_dict = checkpoint[k]
                for m_k, m_v in param_state_dict.state_dict().items():
                    if m_k not in avg_state_dict:
                        avg_state_dict[m_k] = copy.copy(m_v)
                        print(f"Copy {k}.{m_k}")
                    else:
                        avg_state_dict[m_k].data += m_v
                        print(f"Accumulate {k}.{m_k}")
                avg_ckpt[k].load_state_dict(avg_state_dict)
                # Count every contributing checkpoint so the later division averages correctly.
                avg_counts[k] += 1
            # Another occurrence of a previously seen nn.Parameter.
            elif isinstance(checkpoint[k], (nn.Parameter, torch.Tensor)):
                # print(f"nn.Parameter: {k}")
                avg_ckpt[k].data += checkpoint[k].data
                avg_counts[k] += 1
            else:
                print(f"NOT copying {type(checkpoint[k])}: {k}")

    for k in avg_ckpt:
        # safetensors use torch.Tensor instead of nn.Parameter.
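        # Second pass: divide each accumulated entry by the number of checkpoints
        # that contributed to it, handling flat tensors and nested modules separately.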
        if isinstance(avg_ckpt[k], (nn.Parameter, torch.Tensor)):
            print(f"Averaging nn.Parameter: {k}")
            # Divide out of place and cast back so integer buffers keep their dtype.
            avg_ckpt[k].data = (avg_ckpt[k].data / avg_counts[k]).to(avg_ckpt[k].data.dtype)
        elif isinstance(avg_ckpt[k], nn.Module):
            print(f"Averaging nn.Module: {k}")
            avg_state_dict = avg_ckpt[k].state_dict()
            for m_k, m_v in avg_state_dict.items():
                m_v.data = (m_v.data / avg_counts[k]).to(m_v.data.dtype)
            avg_ckpt[k].load_state_dict(avg_state_dict)
        else:
            print(f"NOT averaging {type(avg_ckpt[k])}: {k}")

    if args.output.endswith(".safetensors"):
        safetensors_save_file(avg_ckpt, args.output)
    else:
        torch.save(avg_ckpt, args.output)
    print("=> Saved state_dict to '{}'".format(args.output))


if __name__ == '__main__':
    main()
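
# Example invocation (a sketch; the script filename, checkpoint directory,
# suffix, and output name below are hypothetical, not part of this script):
#
#   python avg_checkpoints.py --input ./output/train/ --suffix checkpoint.pt \
#       --output avg_model.pt
#
# This globs './output/train/*checkpoint.pt', accumulates the matching weights
# key by key, and writes the element-wise average to 'avg_model.pt'. If the
# input checkpoints were plain state dicts, the result loads as usual, e.g.:
#
#   model.load_state_dict(torch.load('avg_model.pt', map_location='cpu'))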