File size: 2,209 Bytes
fe6327d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os, sys
sys.path.insert(0, os.getcwd())
import argparse


def get_args(argv=None):
    """Parse the command-line options for merging a LyCORIS model into a checkpoint.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the original zero-argument call.

    Returns:
        An ``argparse.Namespace`` with ``base_model``, ``lycoris_model``,
        ``output_name``, ``is_v2``, ``device``, ``dtype`` and ``weight``.
    """
    parser = argparse.ArgumentParser()
    # Required positionals: argparse ignores ``default`` on these, so none is given.
    parser.add_argument(
        "base_model", type=str,
        help="The base model you want to merge the lyco model into",
    )
    parser.add_argument(
        "lycoris_model", type=str,
        help="the lyco model you want to merge into sd model",
    )
    # nargs='?' makes the positional optional, so the default is actually used
    # (on a required positional argparse would silently ignore it).
    parser.add_argument(
        "output_name", nargs="?", default="./out.pt", type=str,
        help="the output model (default: ./out.pt)",
    )
    parser.add_argument(
        "--is_v2", action="store_true",
        help="Your base model is sd v2 or not",
    )
    parser.add_argument(
        "--device", default="cpu", type=str,
        help="Which device you want to use to merge the weight",
    )
    parser.add_argument(
        "--dtype", default="float", type=str,
        help="dtype to save (e.g. float, fp16, bf16, float32)",
    )
    parser.add_argument(
        "--weight", default=1.0, type=float,
        help="weight for the lyco model to merge",
    )
    return parser.parse_args(argv)
# Parse the command line at import time, before the lycoris/torch imports
# below -- presumably so `--help` and argument errors return without paying
# for those heavy imports. main() reads all of its options from this global.
ARGS = get_args()


from lycoris_utils import merge
from lycoris.kohya_model_utils import (
    load_models_from_stable_diffusion_checkpoint,
    save_stable_diffusion_checkpoint,
    load_file
)

import torch


def _resolve_dtype(name):
    """Map a ``--dtype`` string to a ``torch.dtype``.

    Accepts ``float``, ``float16``, ``float32``, ``float64``, ``bfloat`` and
    ``bfloat16``, plus the short ``fp*`` / ``bf*`` aliases (``fp16``,
    ``bf16``, ...). Matching is case-insensitive.

    Raises:
        ValueError: if ``name`` is not a recognised dtype.
    """
    normalized = name.strip().lower()
    # Expand only a leading short prefix. A blanket str.replace('bf', 'bfloat')
    # turned an already-canonical "bfloat16" into "bfloatloat16" and rejected it.
    if normalized.startswith('fp'):
        normalized = 'float' + normalized[2:]
    elif normalized.startswith('bf') and not normalized.startswith('bfloat'):
        normalized = 'bfloat' + normalized[2:]
    dtype = {
        'float': torch.float,
        'float16': torch.float16,
        'float32': torch.float32,
        'float64': torch.float64,
        'bfloat': torch.bfloat16,
        'bfloat16': torch.bfloat16,
    }.get(normalized)
    if dtype is None:
        # Report what the user typed (the old message always printed "None").
        raise ValueError(f'Cannot Find the dtype "{name}"')
    return dtype


def _load_lycoris(path):
    """Load the LyCORIS weights stored at ``path``.

    Files with a ``.safetensors`` extension (case-insensitive) go through
    ``load_file``; anything else is read with ``torch.load``.
    """
    if os.path.splitext(path)[1].lower() == '.safetensors':
        return load_file(path)
    # NOTE(review): torch.load unpickles the file, so only use trusted inputs.
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
    # machines; presumably this matches the device the safetensors branch
    # returns -- confirm against kohya_model_utils.load_file.
    return torch.load(path, map_location='cpu')


def main():
    """Merge the LyCORIS model named in ``ARGS`` into the base checkpoint and save it.

    All options come from the module-level ``ARGS`` namespace. The merged
    result is written to ``ARGS.output_name`` in the requested ``--dtype``.

    Raises:
        ValueError: if ``ARGS.dtype`` is not a recognised dtype name.
    """
    # Validate --dtype before the slow model loads so a typo fails immediately.
    dtype = _resolve_dtype(ARGS.dtype)

    base = load_models_from_stable_diffusion_checkpoint(ARGS.is_v2, ARGS.base_model)
    lyco = _load_lycoris(ARGS.lycoris_model)

    merge(
        base,
        lyco,
        ARGS.weight,
        ARGS.device
    )

    # NOTE(review): the indices assume base is (text_encoder, vae, unet) and
    # that the None/0/0 slots are the source ckpt path, epoch and step
    # counters -- confirm against kohya_model_utils.
    save_stable_diffusion_checkpoint(
        ARGS.is_v2, ARGS.output_name,
        base[0], base[2],
        None, 0, 0, dtype,
        base[1]
    )


# Run the merge only when executed as a script, not when imported.
if __name__ == '__main__':
    main()