#!/usr/bin/env python3
""" Checkpoint Averaging Script

This script averages all model weights for checkpoints in specified path that match
the specified filter wildcard. All checkpoints must be from the exact same model.

For any hope of decent results, the checkpoints should be from the same or child
(via resumes) training session. This can be viewed as similar to maintaining running
EMA (exponential moving average) of the model weights or performing SWA (stochastic
weight averaging), but post-training.

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import glob

import torch
import torch.nn as nn
from safetensors.torch import load_file as safetensors_load_file
from safetensors.torch import save_file as safetensors_save_file

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default=[], nargs="+", type=str, metavar='PATHS',
                    help='input path(s) or glob pattern(s) matching checkpoints')
parser.add_argument('--output', type=str, default='avgmodel.pt', metavar='PATH',
                    help='output file name of the averaged checkpoint')
parser.add_argument('--suffix', default='', type=str, metavar='WILDCARD',
                    help='checkpoint filename suffix appended to each input pattern')
# NOTE: --min/--max are parsed but not currently applied when selecting checkpoints.
parser.add_argument('--min', type=int, default=500, help='minimum iteration of checkpoints to average')
parser.add_argument('--max', type=int, default=-1, help='maximum iteration of checkpoints to average')
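
# Example usage (illustrative; the script name and paths below are hypothetical):
#
#   python avg_checkpoints.py --input runs/exp1/checkpoint \
#       --suffix '*.safetensors' --output avgmodel.safetensors
#
# Each --input entry is treated as a glob pattern; every matched checkpoint is
# loaded, summed weight-by-weight, and divided by the number of checkpoints.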

def main():
    args = parser.parse_args()
    patterns = args.input
    sel_checkpoint_filenames = []

    for pattern in patterns:
        if args.suffix:
            # Append the suffix to each input pattern as a trailing wildcard match.
            if not args.suffix.startswith('*'):
                pattern += '*'
            pattern += args.suffix

        checkpoint_filenames = glob.glob(pattern, recursive=True)
        if len(checkpoint_filenames) == 0:
            print("WARNING: No checkpoints matching '{}'".format(pattern))
            continue

        sel_checkpoint_filenames += checkpoint_filenames
    
    avg_ckpt = {}     # running sums of weights, keyed the same as the checkpoints
    avg_counts = {}   # number of checkpoints accumulated into each key
    for c in sel_checkpoint_filenames:
        if c.endswith(".safetensors"):
            # safetensors checkpoints are flat dicts of plain torch.Tensor values.
            checkpoint = safetensors_load_file(c)
        else:
            checkpoint = torch.load(c, map_location='cpu')
        print(c)

        for k in checkpoint:
            # Skip ema weights
            if k.startswith("model_ema."):
                continue
            if k not in avg_ckpt:
                avg_ckpt[k] = checkpoint[k]
                print(f"Copy {k}")
                avg_counts[k] = 1
            # Another occurrence of a previously seen nn.Module: accumulate its
            # state dict into the running sum.
            elif isinstance(checkpoint[k], nn.Module):
                avg_state_dict = avg_ckpt[k].state_dict()
                for m_k, m_v in checkpoint[k].state_dict().items():
                    if m_k not in avg_state_dict:
                        avg_state_dict[m_k] = m_v.clone()
                        print(f"Copy {k}.{m_k}")
                    else:
                        avg_state_dict[m_k].data += m_v
                        print(f"Accumulate {k}.{m_k}")
                avg_ckpt[k].load_state_dict(avg_state_dict)
                avg_counts[k] += 1
            # Another occurrence of a previously seen nn.Parameter or plain tensor
            # (safetensors checkpoints contain plain torch.Tensor values).
            elif isinstance(checkpoint[k], (nn.Parameter, torch.Tensor)):
                avg_ckpt[k].data += checkpoint[k].data
                avg_counts[k] += 1
            else:
                print(f"NOT copying {type(checkpoint[k])}: {k}")
    
    for k in avg_ckpt:
        # safetensors use plain torch.Tensor instead of nn.Parameter.
        if isinstance(avg_ckpt[k], (nn.Parameter, torch.Tensor)):
            print(f"Averaging tensor: {k}")
            # Divide in floating point, then cast back to the original dtype so that
            # integer buffers (e.g. BatchNorm's num_batches_tracked) don't fail on
            # in-place true division.
            avg_ckpt[k].data = (avg_ckpt[k].data / avg_counts[k]).to(avg_ckpt[k].dtype)
        elif isinstance(avg_ckpt[k], nn.Module):
            print(f"Averaging nn.Module: {k}")
            avg_state_dict = avg_ckpt[k].state_dict()
            for m_k, m_v in avg_state_dict.items():
                m_v.data = (m_v.data / avg_counts[k]).to(m_v.data.dtype)
            avg_ckpt[k].load_state_dict(avg_state_dict)
        else:
            print(f"NOT averaging {type(avg_ckpt[k])}: {k}")

    if args.output.endswith(".safetensors"):
        # Note: safetensors can only serialize flat dicts of plain tensors; if the
        # averaged checkpoint still contains nn.Module entries this call will fail.
        safetensors_save_file(avg_ckpt, args.output)
    else:
        torch.save(avg_ckpt, args.output)

    print("=> Saved state_dict to '{}'".format(args.output))


if __name__ == '__main__':
    main()