File size: 2,362 Bytes
02cc20b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from safetensors.torch import load_file as safetensors_load_file
from safetensors.torch import save_file as safetensors_save_file
import sys, argparse

def load_ckpt(ckpt_filepath):
    """Load a checkpoint file onto the CPU.

    Args:
        ckpt_filepath: Path to the checkpoint. A path ending in
            ``.safetensors`` is read with safetensors; any other path is
            read with ``torch.load``.

    Returns:
        A ``(ckpt, state_dict)`` tuple. ``ckpt`` is the full loaded object
        when the file wraps its weights under a ``"state_dict"`` key, and
        ``None`` when the file holds only a bare state dict (always the case
        for ``.safetensors`` files).
    """
    print(f"Loading model from {ckpt_filepath}")
    if ckpt_filepath.endswith(".safetensors"):
        return None, safetensors_load_file(ckpt_filepath, device="cpu")

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(ckpt_filepath, map_location="cpu")
    if "state_dict" in ckpt:
        return ckpt, ckpt["state_dict"]

    # No wrapper: the loaded object is itself the state dict (this is what
    # save_ckpt writes when it is given ckpt=None), so report it the same way
    # as a safetensors file instead of failing with a KeyError.
    return None, ckpt

def save_ckpt(ckpt, ckpt_state_dict, ckpt_filepath):
    """Write a checkpoint to disk and print the destination path.

    Args:
        ckpt: Full checkpoint mapping whose extra fields should be kept in
            the output, or ``None`` when there is only a state dict.
        ckpt_state_dict: The state dict to store. It is always the one that
            gets written, even if ``ckpt["state_dict"]`` is a different
            object.
        ckpt_filepath: Output path. A path ending in ``.safetensors`` is
            written with safetensors (state dict only); any other path is
            written with ``torch.save``.
    """
    if ckpt_filepath.endswith(".safetensors"):
        safetensors_save_file(ckpt_state_dict, ckpt_filepath)
    elif ckpt is None:
        torch.save(ckpt_state_dict, ckpt_filepath)
    else:
        # Save a shallow copy carrying the state dict that was passed in, so
        # the argument is never silently ignored and the caller's ckpt is not
        # mutated as a side effect.
        torch.save({**ckpt, "state_dict": ckpt_state_dict}, ckpt_filepath)

    print(f"Saved to {ckpt_filepath}")

def load_two_models(base_ckpt_filepath, te_ckpt_filepath):
    """Load the base checkpoint and the text-encoder donor checkpoint.

    The base checkpoint is loaded first, then the donor. The full base
    checkpoint object is returned alongside its state dict because its other
    fields are needed again when the result is saved; only the state dict of
    the donor is kept.

    Returns:
        A ``(base_ckpt, base_state_dict, te_state_dict)`` tuple.
    """
    base_pair = load_ckpt(base_ckpt_filepath)
    donor_state_dict = load_ckpt(te_ckpt_filepath)[1]
    return (*base_pair, donor_state_dict)

def replace_text_encoder(base_state_dict, te_state_dict, prefix="cond_stage_model."):
    """Overwrite prefixed parameters of ``base_state_dict`` in place.

    Every key of ``base_state_dict`` that starts with ``prefix`` is replaced
    by the entry of the same name from ``te_state_dict``. Keys that are
    missing from ``te_state_dict`` or whose shapes differ are reported and
    left untouched.

    Args:
        base_state_dict: State dict to modify in place.
        te_state_dict: State dict providing the replacement parameters.
        prefix: Key prefix selecting the parameters to replace.

    Returns:
        The number of parameters that were replaced.
    """
    repl_count = 0
    # Only existing keys are reassigned, so the dict does not change size
    # while it is being iterated.
    for k in base_state_dict:
        if not k.startswith(prefix):
            continue
        if k not in te_state_dict:
            print(f"!!!! '{k}' not in TE checkpoint")
            continue
        if base_state_dict[k].shape != te_state_dict[k].shape:
            print(f"!!!! '{k}' shape mismatch: {base_state_dict[k].shape} vs {te_state_dict[k].shape} !!!!")
            continue
        print(k)
        base_state_dict[k] = te_state_dict[k]
        repl_count += 1
    return repl_count


def main():
    """Parse the command line, swap in the text-encoder weights, and save.

    Exits with status 1, without writing any output file, when no parameter
    could be replaced.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_ckpt", type=str, required=True, help="Path to the base checkpoint")
    parser.add_argument("--te_ckpt", type=str, required=True, help="Path to the checkpoint providing text encoder")
    parser.add_argument("--out_ckpt", type=str, required=True, help="Path to the output checkpoint")
    args = parser.parse_args()

    base_ckpt, base_state_dict, te_state_dict = load_two_models(args.base_ckpt, args.te_ckpt)

    repl_count = replace_text_encoder(base_state_dict, te_state_dict)
    if repl_count == 0:
        print("ERROR: No parameter replaced")
        # Signal the failure to calling shells/pipelines instead of exiting 0.
        sys.exit(1)

    print(f"{repl_count} parameters replaced")
    save_ckpt(base_ckpt, base_state_dict, args.out_ckpt)


# Guard the entry point so importing this module does not parse sys.argv.
if __name__ == "__main__":
    main()