File size: 3,957 Bytes
fe6327d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os, sys
sys.path.insert(0, os.getcwd())
import argparse


def get_args():
    """Parse the command-line arguments of the locon extraction script.

    Three required positionals (``base_model``, ``db_model``,
    ``output_name``) are followed by optional flags selecting the extraction
    mode, the per-layer parameter used by that mode, the device, the output
    format, and the sparse-bias / CP-decomposition switches.

    Returns:
        argparse.Namespace: the parsed arguments, read from ``sys.argv``.

    Raises:
        SystemExit: raised by argparse for ``--help`` or invalid input,
            including a ``--mode`` value outside the supported set.
    """
    parser = argparse.ArgumentParser()

    # NOTE: the three positionals below are required (no ``nargs``), so
    # argparse never applies their ``default`` values; they are kept only so
    # ``parser.get_default`` keeps returning what it always did.
    parser.add_argument(
        "base_model", help="The model which use it to train the dreambooth model",
        default='', type=str
    )
    parser.add_argument(
        "db_model", help="the dreambooth model you want to extract the locon",
        default='', type=str
    )
    parser.add_argument(
        "output_name", help="the output model",
        default='./out.pt', type=str
    )
    parser.add_argument(
        "--is_v2", help="Your base/db model is sd v2 or not",
        default=False, action="store_true"
    )
    parser.add_argument(
        "--device", help="Which device you want to use to extract the locon",
        default='cpu', type=str
    )
    parser.add_argument(
        "--mode",
        help=(
            'extraction mode, can be "fixed", "threshold", "ratio", "quantile". '
            'If not "fixed", linear_dim and conv_dim will be ignored'
        ),
        # Reject unknown modes at parse time instead of failing later with a
        # bare KeyError after the checkpoints have already been loaded.
        choices=('fixed', 'threshold', 'ratio', 'quantile'),
        default='fixed', type=str
    )
    parser.add_argument(
        "--safetensors", help='use safetensors to save locon model',
        default=False, action="store_true"
    )
    parser.add_argument(
        "--linear_dim", help="network dim for linear layer in fixed mode",
        default=1, type=int
    )
    parser.add_argument(
        "--conv_dim", help="network dim for conv layer in fixed mode",
        default=1, type=int
    )
    parser.add_argument(
        "--linear_threshold", help="singular value threshold for linear layer in threshold mode",
        default=0., type=float
    )
    parser.add_argument(
        "--conv_threshold", help="singular value threshold for conv layer in threshold mode",
        default=0., type=float
    )
    parser.add_argument(
        "--linear_ratio", help="singular ratio for linear layer in ratio mode",
        default=0., type=float
    )
    parser.add_argument(
        "--conv_ratio", help="singular ratio for conv layer in ratio mode",
        default=0., type=float
    )
    parser.add_argument(
        "--linear_quantile", help="singular value quantile for linear layer quantile mode",
        default=1., type=float
    )
    parser.add_argument(
        "--conv_quantile", help="singular value quantile for conv layer quantile mode",
        default=1., type=float
    )
    parser.add_argument(
        "--use_sparse_bias", help="enable sparse bias",
        default=False, action="store_true"
    )
    parser.add_argument(
        "--sparsity", help="sparsity for sparse bias",
        default=0.98, type=float
    )
    parser.add_argument(
        "--disable_cp", help="don't use cp decomposition",
        default=False, action="store_true"
    )
    return parser.parse_args()
# Parse the command line at module level, before the lycoris/torch imports
# below are executed.
# NOTE(review): presumably done so that --help and argument errors exit
# without paying for those imports -- confirm. It also means that merely
# importing this module parses sys.argv.
ARGS = get_args()


from lycoris.utils import extract_diff
from lycoris.kohya_model_utils import load_models_from_stable_diffusion_checkpoint

import torch
from safetensors.torch import save_file


def main():
    """Run the extraction described by the module-level ``ARGS``.

    Loads the base and dreambooth checkpoints, selects the linear/conv
    parameter pair that belongs to ``ARGS.mode``, calls ``extract_diff`` on
    the two loaded models and writes the returned state dict to
    ``ARGS.output_name`` -- via safetensors when ``--safetensors`` was given,
    otherwise via ``torch.save``.

    Raises:
        KeyError: if ``ARGS.mode`` is not one of the known extraction modes.
    """
    opts = ARGS
    base_models = load_models_from_stable_diffusion_checkpoint(opts.is_v2, opts.base_model)
    tuned_models = load_models_from_stable_diffusion_checkpoint(opts.is_v2, opts.db_model)

    # One (linear, conv) parameter pair per extraction mode.
    params_by_mode = {
        'fixed': (opts.linear_dim, opts.conv_dim),
        'threshold': (opts.linear_threshold, opts.conv_threshold),
        'ratio': (opts.linear_ratio, opts.conv_ratio),
        'quantile': (opts.linear_quantile, opts.conv_quantile),
    }
    linear_param, conv_param = params_by_mode[opts.mode]

    use_cp = not opts.disable_cp
    weights = extract_diff(
        base_models, tuned_models,
        opts.mode,
        linear_param, conv_param,
        opts.device,
        opts.use_sparse_bias, opts.sparsity,
        use_cp,
    )

    # Both writers take (state_dict, path), so pick one and call it once.
    writer = save_file if opts.safetensors else torch.save
    writer(weights, opts.output_name)


# Run the extraction only when this file is executed as a script.
if __name__ == '__main__':
    main()