from typing import Literal, Union, Dict

import os
import shutil

import fire
import torch
from diffusers import StableDiffusionPipeline
from safetensors.torch import safe_open, save_file

from .lora import (
    tune_lora_scale,
    patch_pipe,
    collapse_lora,
    monkeypatch_remove_lora,
)
from .lora_manager import lora_join
from .to_ckpt_v2 import convert_to_ckpt


def _text_lora_path(path: str) -> str:
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def add(
    path_1: str,
    path_2: str,
    output_path: str,
    alpha_1: float = 0.5,
    alpha_2: float = 0.5,
    mode: Literal[
        "lpl",
        "upl",
        "upl-ckpt-v2",
        "ljl",
    ] = "lpl",
    with_text_lora: bool = False,
):
    """Merge two LoRA files, or merge a LoRA into a base pipeline.

    Modes:
        lpl         : LoRA + LoRA -> merged LoRA (.pt or .safetensors).
        upl         : base pipeline + LoRA -> merged diffusers pipeline folder.
        upl-ckpt-v2 : base pipeline + LoRA -> single .ckpt plus a .pt textual embedding.
        ljl         : LoRA + LoRA -> joined LoRA (.safetensors only); alphas are ignored.
    """
| print("Lora Add, mode " + mode) | |
| if mode == "lpl": | |
| if path_1.endswith(".pt") and path_2.endswith(".pt"): | |
| for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + ( | |
| [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")] | |
| if with_text_lora | |
| else [] | |
| ): | |
| print("Loading", _path_1, _path_2) | |
| out_list = [] | |
| if opt == "text_encoder": | |
| if not os.path.exists(_path_1): | |
| print(f"No text encoder found in {_path_1}, skipping...") | |
| continue | |
| if not os.path.exists(_path_2): | |
| print(f"No text encoder found in {_path_1}, skipping...") | |
| continue | |
| l1 = torch.load(_path_1) | |
| l2 = torch.load(_path_2) | |
| l1pairs = zip(l1[::2], l1[1::2]) | |
| l2pairs = zip(l2[::2], l2[1::2]) | |
| for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs): | |
| # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape) | |
| x1.data = alpha_1 * x1.data + alpha_2 * x2.data | |
| y1.data = alpha_1 * y1.data + alpha_2 * y2.data | |
| out_list.append(x1) | |
| out_list.append(y1) | |
| if opt == "unet": | |
| print("Saving merged UNET to", output_path) | |
| torch.save(out_list, output_path) | |
| elif opt == "text_encoder": | |
| print("Saving merged text encoder to", _text_lora_path(output_path)) | |
| torch.save( | |
| out_list, | |
| _text_lora_path(output_path), | |
| ) | |

        elif path_1.endswith(".safetensors") and path_2.endswith(".safetensors"):
            safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
            safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

            metadata = dict(safeloras_1.metadata())
            metadata.update(dict(safeloras_2.metadata()))

            ret_tensor = {}

            for keys in set(list(safeloras_1.keys()) + list(safeloras_2.keys())):
                if keys.startswith("text_encoder") or keys.startswith("unet"):
                    tens1 = safeloras_1.get_tensor(keys)
                    tens2 = safeloras_2.get_tensor(keys)

                    tens = alpha_1 * tens1 + alpha_2 * tens2
                    ret_tensor[keys] = tens
                else:
                    if keys in safeloras_1.keys():
                        tens1 = safeloras_1.get_tensor(keys)
                    else:
                        tens1 = safeloras_2.get_tensor(keys)

                    ret_tensor[keys] = tens1

            save_file(ret_tensor, output_path, metadata)

    elif mode == "upl":
        print(
            f"Merging UNET/CLIP from {path_1} with LoRA from {path_2} to {output_path}. Merging ratio : {alpha_1}."
        )

        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        patch_pipe(loaded_pipeline, path_2)

        collapse_lora(loaded_pipeline.unet, alpha_1)
        collapse_lora(loaded_pipeline.text_encoder, alpha_1)

        monkeypatch_remove_lora(loaded_pipeline.unet)
        monkeypatch_remove_lora(loaded_pipeline.text_encoder)

        loaded_pipeline.save_pretrained(output_path)

    elif mode == "upl-ckpt-v2":
        assert output_path.endswith(".ckpt"), "Only .ckpt files are supported"
        name = os.path.basename(output_path)[0:-5]

        print(
            f"You will be using {name} as the token in the A1111 webui. Make sure {name} is a unique enough token."
        )

        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)

        collapse_lora(loaded_pipeline.unet, alpha_1)
        collapse_lora(loaded_pipeline.text_encoder, alpha_1)

        monkeypatch_remove_lora(loaded_pipeline.unet)
        monkeypatch_remove_lora(loaded_pipeline.text_encoder)

        _tmp_output = output_path + ".tmp"

        loaded_pipeline.save_pretrained(_tmp_output)
        convert_to_ckpt(_tmp_output, output_path, as_half=True)
        # remove the temporary diffusers folder once the .ckpt has been written
        shutil.rmtree(_tmp_output)

        # Export the learned token embeddings in the A1111 textual-inversion .pt format.
        keys = sorted(tok_dict.keys())
        tok_catted = torch.stack([tok_dict[k] for k in keys])
        ret = {
            "string_to_token": {"*": torch.tensor(265)},
            "string_to_param": {"*": tok_catted},
            "name": name,
        }

        torch.save(ret, output_path[:-5] + ".pt")
        print(
            f"Textual embedding saved as {output_path[:-5]}.pt. Put it in the embeddings folder and use it as {name} in the A1111 repo."
        )

    elif mode == "ljl":
        print("Using Join mode : alpha will not have an effect here.")
        assert path_1.endswith(".safetensors") and path_2.endswith(
            ".safetensors"
        ), "Only .safetensors files are supported"

        safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
        safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

        total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
        save_file(total_tensor, output_path, total_metadata)

    else:
        print("Unknown mode", mode)
        raise ValueError(f"Unknown mode {mode}")


def main():
    fire.Fire(add)


if __name__ == "__main__":
    main()
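

# A minimal usage sketch: `fire.Fire(add)` exposes the positional arguments
# (path_1, path_2, output_path) followed by keyword flags. The `lora_add`
# entry-point name, the module path, and every path/model id below are
# placeholder assumptions rather than guarantees of this file; if no console
# script is installed, run the module directly with
# `python -m <package>.<this_module> ...` instead.
#
#   # merge two LoRA .safetensors files at a 50/50 ratio (lpl mode)
#   lora_add ./loraA.safetensors ./loraB.safetensors ./merged.safetensors \
#       --alpha_1 0.5 --alpha_2 0.5 --mode lpl
#
#   # bake a LoRA into a base pipeline and export an A1111-style .ckpt
#   lora_add ./base-pipeline-or-hub-id ./loraA.safetensors ./merged.ckpt \
#       --alpha_1 0.8 --mode "upl-ckpt-v2"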