File size: 1,070 Bytes
16de183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
from collections import OrderedDict

import torch


def extract(ckpt):
    """Return the ``"model"`` state dict of *ckpt* minus every ``enc_q`` entry.

    Args:
        ckpt: Mapping that holds a ``"model"`` entry, itself a mapping from
            parameter names to values.

    Returns:
        An ``OrderedDict`` with a single ``"weight"`` entry: a plain dict of
        every ``ckpt["model"]`` item whose key does not contain the substring
        ``"enc_q"``, in the original iteration order.
    """
    state = ckpt["model"]
    kept = {}
    for name, tensor in state.items():
        if "enc_q" in name:
            # NOTE(review): enc_q entries are presumably only needed during
            # training and are dropped from the exported weights — confirm.
            continue
        kept[name] = tensor
    return OrderedDict(weight=kept)


def _load_weights(pth_path):
    """Load the checkpoint at *pth_path* and return its flat ``{name: tensor}`` dict.

    Two layouts are accepted: a checkpoint with a ``"model"`` entry, which
    is filtered through ``extract``, or one that already has a ``"weight"``
    entry.

    Raises:
        KeyError: if the checkpoint has neither a ``"model"`` nor a
            ``"weight"`` entry.
    """
    # NOTE(security): torch.load unpickles arbitrary Python objects; only
    # pass checkpoints that come from a trusted source.
    ckpt = torch.load(pth_path, map_location="cpu")
    if "model" in ckpt:
        # extract() wraps the filtered state dict under "weight"; unwrap it so
        # both branches return the same flat mapping.
        return extract(ckpt)["weight"]
    return ckpt["weight"]


def model_fusion(model_name, pth_path_1, pth_path_2, alpha=1):
    """Blend two checkpoints linearly and save the result to ``logs/<model_name>.pth``.

    Each fused tensor is ``alpha * w1 + (1 - alpha) * w2``, computed in
    float32. The saved file is an ``OrderedDict`` with a ``"weight"`` entry
    (the fused state dict) and an ``"info"`` string naming both sources.

    Args:
        model_name: Output file stem, written under the ``logs/`` directory.
            Missing directories are created.
        pth_path_1: Path to the first checkpoint (``"model"`` or ``"weight"``
            layout).
        pth_path_2: Path to the second checkpoint (same layouts).
        alpha: Weight given to the first model. The default of 1 keeps the
            first model's weights unchanged, matching the historical
            behaviour.

    Returns:
        A failure message string if the two checkpoints do not have the same
        parameter names and shapes. Otherwise ``None``; success is reported
        on stdout.
    """
    fail_msg = "Fail to merge the models. The model architectures are not the same."
    weights_1 = _load_weights(pth_path_1)
    weights_2 = _load_weights(pth_path_2)

    # Same parameter names (set comparison, order-independent) ...
    if weights_1.keys() != weights_2.keys():
        return fail_msg
    # ... and same shapes, so the blend never silently broadcasts or crashes.
    if any(weights_1[key].shape != weights_2[key].shape for key in weights_1):
        return fail_msg

    opt = OrderedDict(
        weight={
            key: alpha * value.float() + (1 - alpha) * weights_2[key].float()
            for key, value in weights_1.items()
        }
    )
    opt["info"] = f"Model fusion of {pth_path_1} and {pth_path_2}"

    out_path = f"logs/{model_name}.pth"
    # torch.save does not create parent directories itself.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torch.save(opt, out_path)
    print(f"Model fusion of {pth_path_1} and {pth_path_2} is done.")