File size: 14,303 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import os
import html
import json
import time
import shutil

import torch
import tqdm
import gradio as gr
import safetensors.torch
from modules.merging.merge import merge_models
from modules.merging.merge_utils import TRIPLE_METHODS

from modules import shared, images, sd_models, sd_vae, sd_models_config, devices


def run_pnginfo(image):
    """Extract embedded generation info from an image for the PNG-info tab.

    Returns a 3-tuple ``('', geninfo, info_html)``: ``geninfo`` is the raw
    parameters text returned by ``images.read_info_from_image`` and
    ``info_html`` is one escaped ``<div>`` per info item (the parameters text
    first), skipping the ``UserComment`` entry. A ``None`` image yields
    three empty strings.
    """
    if image is None:
        return '', '', ''
    geninfo, extra_items = images.read_info_from_image(image)
    # 'parameters' goes first; an identically named key from extra_items overrides its value
    merged = {'parameters': geninfo}
    merged.update(extra_items)
    rows = [
        f"<div><b>{html.escape(str(name))}</b>: {html.escape(str(value))}</div>"
        for name, value in merged.items()
        if name != 'UserComment'
    ]
    return '', geninfo, ''.join(rows)


def create_config(ckpt_result, config_source, a, b, c):
    """Copy a source model's non-default YAML config next to a merge result.

    Args:
        ckpt_result: path of the produced checkpoint; the config is written to the
            same path with its extension replaced by ``.yaml``.
        config_source: 0 = first non-default config among a, b, c; 1 = config of b;
            2 = config of c; any other value copies nothing.
        a, b, c: source checkpoint paths (may be empty/None).

    Does nothing when no non-default config is found, or when the selected config
    already is the destination file (``shutil.copyfile`` would raise
    ``SameFileError``).
    """
    def config(x):
        res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
        # the project-wide default config is not worth copying
        return res if res != shared.sd_default_config else None

    if config_source == 0:
        cfg = config(a) or config(b) or config(c)
    elif config_source == 1:
        cfg = config(b)
    elif config_source == 2:
        cfg = config(c)
    else:
        cfg = None
    if cfg is None:
        return
    filename, _ = os.path.splitext(ckpt_result)
    checkpoint_filename = filename + ".yaml"
    if os.path.abspath(cfg) == os.path.abspath(checkpoint_filename):
        return  # source and destination are the same file
    shared.log.info(f"Copying config: {cfg} -> {checkpoint_filename}")
    shutil.copyfile(cfg, checkpoint_filename)


def to_half(tensor, enable):
    """Cast ``tensor`` to float16 only when ``enable`` is truthy and it is float32.

    Any other dtype (or ``enable`` falsy) returns the very same tensor object.
    """
    should_cast = enable and tensor.dtype == torch.float
    return tensor.half() if should_cast else tensor


def _apply_block_weights(kwargs, prefix):
    """Resolve the ``prefix`` merge weight (``"alpha"`` or ``"beta"``) in ``kwargs`` in place.

    Consumes (pops) the UI fields ``<prefix>_base``, ``<prefix>_in_blocks``,
    ``<prefix>_mid_block``, ``<prefix>_out_blocks`` and ``<prefix>_preset`` so they
    are not forwarded to ``merge_models``.

    Precedence: when all four manual block fields are truthy they are parsed into a
    flat list of floats (in-/out-block fields are comma-separated); otherwise a
    truthy preset replaces ``kwargs[prefix]``; otherwise ``kwargs[prefix]`` is left
    untouched.

    Raises:
        ValueError: a manual value is not numeric, or the list length is not
            26 (or 20 for SDXL).
    """
    base = kwargs.pop(f"{prefix}_base", None)
    in_blocks = kwargs.pop(f"{prefix}_in_blocks", None)
    mid_block = kwargs.pop(f"{prefix}_mid_block", None)
    out_blocks = kwargs.pop(f"{prefix}_out_blocks", None)
    preset = kwargs.pop(f"{prefix}_preset", None)
    if base and in_blocks and mid_block and out_blocks:
        raw = [base, *in_blocks.split(","), mid_block, *out_blocks.split(",")]
        weights = [float(x) for x in raw]  # ValueError on non-numeric entries
        if len(weights) not in (26, 20):
            raise ValueError(f"{prefix.capitalize()} Block Weights are wrong length (26 or 20 for SDXL)")
        kwargs[prefix] = weights
    elif preset:
        kwargs[prefix] = preset


def run_modelmerger(id_task, **kwargs):  # pylint: disable=unused-argument
    """Merge two or three checkpoints with ``merge_models`` and save the result.

    ``kwargs`` carries the UI form values. Model-name fields and the
    ``alpha_*``/``beta_*`` block-weight fields are consumed here; the remaining
    values, plus the resolved ``models`` dict and device settings, are forwarded
    to ``merge_models``.

    Returns:
        A list of four Gradio updates for the checkpoint dropdowns followed by a
        status message string. Every return path ends ``shared.state``.
    """
    shared.state.begin('merge')
    t0 = time.time()

    def fail(message):
        shared.log.error(f"Merge: {message}")
        shared.state.textinfo = message
        shared.state.end()
        return [*[gr.update() for _ in range(4)], message]

    def refreshed(message):
        # refresh all four checkpoint dropdowns and report `message`
        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], message]

    # Validate the selection before resolving any name: dereferencing `.filename` on an
    # unresolved match would raise instead of reporting a clear error.
    primary_name = kwargs.pop("primary_model_name", None)
    secondary_name = kwargs.pop("secondary_model_name", None)
    tertiary_name = kwargs.pop("tertiary_model_name", None)
    merge_mode = kwargs.get("merge_mode", None)
    needs_tertiary = merge_mode in TRIPLE_METHODS
    if primary_name in [None, 'None']:
        return fail("Failed: Merging requires a primary model.")
    if secondary_name in [None, 'None']:
        return fail("Failed: Merging requires a secondary model.")
    if needs_tertiary and tertiary_name in [None, 'None']:
        return fail(f"Failed: Interpolation method ({merge_mode}) requires a tertiary model.")

    primary_model_info = sd_models.get_closet_checkpoint_match(primary_name)
    if primary_model_info is None:
        return fail(f"Failed: Primary model not found: {primary_name}")
    secondary_model_info = sd_models.get_closet_checkpoint_match(secondary_name)
    if secondary_model_info is None:
        return fail(f"Failed: Secondary model not found: {secondary_name}")
    tertiary_match = None
    if tertiary_name not in [None, 'None']:
        tertiary_match = sd_models.get_closet_checkpoint_match(tertiary_name)
        if tertiary_match is None and needs_tertiary:
            return fail(f"Failed: Tertiary model not found: {tertiary_name}")
    # the tertiary model is only recorded in metadata when the merge method uses it
    tertiary_model_info = tertiary_match if needs_tertiary else None

    kwargs["models"] = {
        "model_a": primary_model_info.filename,
        "model_b": secondary_model_info.filename,
    }
    if tertiary_match is not None:
        kwargs["models"]["model_c"] = tertiary_match.filename

    try:
        _apply_block_weights(kwargs, "alpha")
        _apply_block_weights(kwargs, "beta")
    except ValueError as e:
        return fail(f"Failed: Malformed manual block weight: {e}")

    # Resolve the output path up front so an existing file is reported before the
    # (potentially long) merge runs instead of after it.
    custom_name = kwargs.get("custom_name", None) or "Unnamed_Merge"
    checkpoint_format = kwargs.get("checkpoint_format", None) or "safetensors"
    filename = f"{custom_name}.{checkpoint_format}"
    ckpt_dir = shared.opts.ckpt_dir or sd_models.model_path
    output_modelname = os.path.join(ckpt_dir, filename)
    if os.path.exists(output_modelname) and not kwargs.get("overwrite", False):
        message = f"Model already exists: {output_modelname}"
        shared.state.textinfo = message
        shared.state.end()
        return refreshed(message)

    if kwargs["device"] == "gpu":
        kwargs["device"] = devices.device
    elif kwargs["device"] == "shuffle":
        # tensors stay on CPU; the accelerator is passed separately as work_device
        kwargs["device"] = torch.device("cpu")
        kwargs["work_device"] = devices.device
    else:
        kwargs["device"] = torch.device("cpu")
    if kwargs.pop("unload", False):
        sd_models.unload_model_weights()

    try:
        theta_0 = merge_models(**kwargs)
    except Exception as e:
        return fail(f"{e}")

    try:
        theta_0 = theta_0.to_dict()  # TensorDict -> Dict if necessary
    except Exception:
        pass  # best-effort: a plain dict has no to_dict(), use it as-is

    bake_in_vae_filename = sd_vae.vae_dict.get(kwargs.get("bake_in_vae", None), None)
    if bake_in_vae_filename is not None:
        shared.log.info(f"Merge VAE='{bake_in_vae_filename}'")
        shared.state.textinfo = 'Merge VAE'
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename)
        for key in vae_dict.keys():
            theta_0_key = 'first_stage_model.' + key
            # only replace VAE keys the merged model already has
            if theta_0_key in theta_0:
                theta_0[theta_0_key] = to_half(vae_dict[key], kwargs.get("precision", "fp16") == "fp16")
        del vae_dict

    shared.state.textinfo = "merge saving"
    metadata = None
    if kwargs.get("save_metadata", False):
        metadata = {"format": "pt", "sd_merge_models": {}}
        merge_recipe = {
            "type": "SDNext",  # indicate this model was merged with webui's built-in merger
            "primary_model_hash": primary_model_info.sha256,
            "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
            "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
            "merge_mode": kwargs.get('merge_mode', None),
            "alpha": kwargs.get('alpha', None),
            "beta": kwargs.get('beta', None),
            "precision": kwargs.get('precision', None),
            "custom_name": custom_name,
        }
        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

        def add_model_metadata(checkpoint_info):
            # record this source model and inherit any merge history it carries
            checkpoint_info.calculate_shorthash()
            metadata["sd_merge_models"][checkpoint_info.sha256] = {
                "name": checkpoint_info.name,
                "legacy_hash": checkpoint_info.hash,
                "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
            }
            metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))

        add_model_metadata(primary_model_info)
        if secondary_model_info:
            add_model_metadata(secondary_model_info)
        if tertiary_model_info:
            add_model_metadata(tertiary_model_info)
        # safetensors metadata values must be strings
        metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])

    _, extension = os.path.splitext(output_modelname)
    try:
        if extension.lower() == ".safetensors":
            safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
        else:
            torch.save(theta_0, output_modelname)
    except Exception as e:
        return fail(f"Failed: saving {output_modelname}: {e}")
    del theta_0  # drop the merged weights so torch_gc below can actually free them

    t1 = time.time()
    shared.log.info(f"Merge complete: saved='{output_modelname}' time={t1-t0:.2f}")
    sd_models.list_models()
    created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
    if created_model:
        created_model.calculate_shorthash()
    devices.torch_gc(force=True)
    shared.state.end()
    return refreshed(f"Model saved to {output_modelname}")


def run_modelconvert(model, checkpoint_formats, precision, conv_type, custom_name, unet_conv, text_encoder_conv,
                     vae_conv, others_conv, fix_clip):
    """Re-save a checkpoint with converted precision and optional EMA handling.

    Args:
        model: key into ``sd_models.checkpoints_list``.
        checkpoint_formats: output formats; ``"safetensors"`` writes a ``.safetensors``
            file, any other value a ``.ckpt`` via ``torch.save``.
        precision: ``"full"``/``"fp32"`` (unchanged), ``"fp16"`` or ``"bf16"``.
        conv_type: ``"ema-only"`` (substitute EMA weights), ``"no-ema"`` (drop
            ``model_ema.`` keys) or anything else to keep every key; appended to the
            default output name unless it is ``"disabled"``.
        custom_name: output base name; ``""`` uses ``<model_name>-<precision>[-<conv_type>]``.
        unet_conv, text_encoder_conv, vae_conv, others_conv: per-component action:
            ``"convert"`` applies the precision, ``"copy"`` keeps the tensor, anything
            else (e.g. ``"delete"``) drops it.
        fix_clip: when true, reset CLIP ``position_ids`` to ``0..76``.

    Returns:
        ``"Checkpoint saved to <path><br>"`` per written file, or an ``"Error: ..."`` message.
        ``shared.state`` is always ended once conversion has begun, even on exceptions.
    """
    # position_ids in clip is int64. model_ema.num_updates is int32, so only float dtypes are cast
    dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
    dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}

    def conv_fp16(t: torch.Tensor):
        return t.half() if t.dtype in dtypes_to_fp16 else t

    def conv_bf16(t: torch.Tensor):
        return t.bfloat16() if t.dtype in dtypes_to_bf16 else t

    def conv_full(t):
        return t

    _g_precision_func = {
        "full": conv_full,
        "fp32": conv_full,
        "fp16": conv_fp16,
        "bf16": conv_bf16,
    }

    def check_weight_type(k: str) -> str:
        if k.startswith("model.diffusion_model"):
            return "unet"
        elif k.startswith("first_stage_model"):
            return "vae"
        elif k.startswith("cond_stage_model"):
            return "clip"
        return "other"

    def load_model(path):
        if path.endswith(".safetensors"):
            m = safetensors.torch.load_file(path, device="cpu")
        else:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
            m = torch.load(path, map_location="cpu")
        return m["state_dict"] if "state_dict" in m else m

    def fix_model(weights, fix_clip=False):
        # code from model-toolkit: rename NovelAI-style CLIP keys and optionally repair position_ids
        nai_keys = {
            'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
            'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
            'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
        }
        for k in list(weights.keys()):
            if not isinstance(k, str):
                continue
            for old_prefix, new_prefix in nai_keys.items():
                if k.startswith(old_prefix):
                    weights[k.replace(old_prefix, new_prefix)] = weights.pop(k)
                    shared.log.warning(f"Model convert: fixed NovelAI error key: {k}")
                    break
        if fix_clip:
            pos_key = "cond_stage_model.transformer.text_model.embeddings.position_ids"
            if pos_key in weights:
                correct = torch.arange(77, dtype=torch.int64).unsqueeze(0)  # shape (1, 77)
                now = weights[pos_key].to(torch.int64)
                mismatch = correct.ne(now)
                broken = [idx for idx in range(77) if mismatch[0][idx]]
                weights[pos_key] = correct
                if broken:
                    shared.log.warning(f"Model convert: fixed broken CLiP: {broken}")
        return weights

    # validate inputs before touching shared state so errors never leave it stuck
    if not model:
        return "Error: you must choose a model"
    if not checkpoint_formats:
        return "Error: at least choose one model save format"
    conv_func = _g_precision_func.get(precision)
    if conv_func is None:
        return f"Error: unknown precision: {precision}"
    model_info = sd_models.checkpoints_list.get(model)
    if model_info is None:
        return f"Error: model not found: {model}"

    extra_opt = {
        "unet": unet_conv,
        "clip": text_encoder_conv,
        "vae": vae_conv,
        "other": others_conv
    }
    converted = {}

    def _hf(wk: str, t: torch.Tensor):
        # apply the per-component action; non-tensor entries are dropped
        if not isinstance(t, torch.Tensor):
            return
        action = extra_opt[check_weight_type(wk)]
        if action == "convert":
            converted[wk] = conv_func(t)
        elif action == "copy":
            converted[wk] = t

    shared.state.begin('convert')
    try:
        shared.state.textinfo = f"Loading {model_info.filename}..."
        shared.log.info(f"Model convert loading: {model_info.filename}")
        state_dict = load_model(model_info.filename)

        shared.log.info("Model convert: running")
        if conv_type == "ema-only":
            for k in tqdm.tqdm(state_dict):
                # EMA key: "model_ema." + key minus its first 6 chars ("model."), dots removed
                ema_k = "___"
                try:
                    ema_k = "model_ema." + k[6:].replace(".", "")
                except Exception:
                    pass
                if ema_k in state_dict:
                    _hf(k, state_dict[ema_k])
                elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
                    _hf(k, state_dict[k])
        elif conv_type == "no-ema":
            for k, v in tqdm.tqdm(state_dict.items()):
                if "model_ema." not in k:
                    _hf(k, v)
        else:
            for k, v in tqdm.tqdm(state_dict.items()):
                _hf(k, v)
        del state_dict  # release unconverted source tensors before saving

        converted = fix_model(converted, fix_clip=fix_clip)
        # same checkpoint folder the merger writes to and list_models scans
        ckpt_dir = shared.opts.ckpt_dir or sd_models.model_path
        save_name = f"{model_info.model_name}-{precision}"
        if conv_type != "disabled":
            save_name += f"-{conv_type}"
        if custom_name != "":
            save_name = custom_name
        saved = []
        for fmt in checkpoint_formats:
            ext = ".safetensors" if fmt == "safetensors" else ".ckpt"
            save_path = os.path.join(ckpt_dir, save_name + ext)
            shared.log.info(f"Model convert saving: {save_path}")
            if fmt == "safetensors":
                safetensors.torch.save_file(converted, save_path)
            else:
                torch.save({"state_dict": converted}, save_path)
            saved.append(f"Checkpoint saved to {save_path}<br>")
        return "".join(saved)
    finally:
        shared.state.end()