"""Helpers for merging a lora_diffusion LoRA (.safetensors) into a Stable Diffusion checkpoint.

Modified copies of functions from the lora_diffusion package that operate on an
already-loaded safeloras dict ({"metadata": ..., "weights": ...}) rather than on a
safetensors.safe_open handle.
"""

import json
import os.path as osp
import shutil
from itertools import groupby
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline

# The wildcard imports provide the lora_diffusion helpers used below
# (monkeypatch_or_replace_lora_extended, apply_learned_embed_in_clip, collapse_lora,
# monkeypatch_remove_lora, EMBED_FLAG, DEFAULT_TARGET_REPLACE,
# TEXT_ENCODER_DEFAULT_TARGET_REPLACE, convert_*_state_dict, ...).
from lora_diffusion.cli_lora_add import *
from lora_diffusion.lora import *
from lora_diffusion.to_ckpt_v2 import *


def monkeypatch_or_replace_safeloras(models, safeloras):
    loras = parse_safeloras(safeloras)

    for name, (lora, ranks, target) in loras.items():
        model = getattr(models, name, None)

        if not model:
            print(f"No model provided for {name}, contained in Lora")
            continue

        monkeypatch_or_replace_lora_extended(model, lora, target, ranks)


def parse_safeloras(
    safeloras,
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
    """
    Converts a loaded safeloras dict that contains a set of module
    Loras into Parameters and other information.

    Output is a dictionary of {
        "module name": (
            [list of weights],
            [list of ranks],
            target_replacement_modules
        )
    }
    """
    loras = {}
    # metadata = safeloras.metadata()
    metadata = safeloras["metadata"]
    safeloras_ = safeloras["weights"]

    get_name = lambda k: k.split(":")[0]

    keys = list(safeloras_.keys())
    keys.sort(key=get_name)

    for name, module_keys in groupby(keys, get_name):
        info = metadata.get(name)

        if not info:
            raise ValueError(
                f"Tensor {name} has no metadata - is this a Lora safetensor?"
            )

        # Skip Textual Inversion embeds
        if info == EMBED_FLAG:
            continue

        # Handle Loras
        # Extract the targets
        target = json.loads(info)

        # Build the result lists - preallocate them so we can insert by index
        module_keys = list(module_keys)
        ranks = [4] * (len(module_keys) // 2)
        weights = [None] * len(module_keys)

        for key in module_keys:
            # Split the model name and index out of the key
            _, idx, direction = key.split(":")
            idx = int(idx)

            # Add the rank
            ranks[idx] = int(metadata[f"{name}:{idx}:rank"])

            # Insert the weight into the list
            idx = idx * 2 + (1 if direction == "down" else 0)
            # weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
            weights[idx] = nn.parameter.Parameter(safeloras_[key])

        loras[name] = (weights, ranks, target)

    return loras


def parse_safeloras_embeds(
    safeloras,
) -> Dict[str, torch.Tensor]:
    """
    Converts a loaded safeloras dict that contains Textual Inversion embeds into
    a dictionary of embed_token: Tensor.
    """
    embeds = {}
    metadata = safeloras["metadata"]
    safeloras_ = safeloras["weights"]

    for key in safeloras_.keys():
        # Only handle Textual Inversion embeds
        meta = metadata.get(key)
        if not meta or meta != EMBED_FLAG:
            continue

        embeds[key] = safeloras_[key]

    return embeds


def patch_pipe(
    pipe,
    maybe_unet_path,
    token: Optional[str] = None,
    r: int = 4,
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    idempotent_token=True,
    unet_target_replace_module=DEFAULT_TARGET_REPLACE,
    text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
    # `maybe_unet_path` is expected to be an already-loaded safeloras dict here; the
    # remaining arguments are kept for signature compatibility with
    # lora_diffusion.lora.patch_pipe, even though several of them are unused.
    safeloras = maybe_unet_path
    monkeypatch_or_replace_safeloras(pipe, safeloras)
    tok_dict = parse_safeloras_embeds(safeloras)

    if patch_ti:
        apply_learned_embed_in_clip(
            tok_dict,
            pipe.text_encoder,
            pipe.tokenizer,
            token=token,
            idempotent=idempotent_token,
        )

    return tok_dict


def lora_convert(model_path, as_half):
    """
    Modified version of lora_diffusion.to_ckpt_v2.convert_to_ckpt
    """
    assert model_path is not None, "Must provide a model path!"

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location="cpu")
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {
        "model.diffusion_model." + k: v for k, v in unet_state_dict.items()
    }

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location="cpu")
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location="cpu")
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {
        "cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
    }

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if as_half:
        state_dict = {k: v.half() for k, v in state_dict.items()}

    return state_dict


def merge(
    path_1: str,
    path_2: str,
    alpha_1: float = 0.5,
):
    # `path_1` is a diffusers pipeline directory (or hub id); `path_2` is the
    # already-loaded safeloras dict consumed by patch_pipe above.
    loaded_pipeline = StableDiffusionPipeline.from_pretrained(
        path_1,
    ).to("cpu")

    tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)

    # Fold the LoRA weights into the base weights, then strip the LoRA modules
    collapse_lora(loaded_pipeline.unet, alpha_1)
    collapse_lora(loaded_pipeline.text_encoder, alpha_1)

    monkeypatch_remove_lora(loaded_pipeline.unet)
    monkeypatch_remove_lora(loaded_pipeline.text_encoder)

    _tmp_output = "./merge.tmp"

    loaded_pipeline.save_pretrained(_tmp_output)
    state_dict = lora_convert(_tmp_output, as_half=True)
    # remove the tmp_output folder
    shutil.rmtree(_tmp_output)

    # Pack the Textual Inversion embeds into an embedding dict.
    # NOTE: torch.stack raises if the Lora file contains no embeds.
    keys = sorted(tok_dict.keys())
    tok_catted = torch.stack([tok_dict[k] for k in keys])
    ret = {
        "string_to_token": {"*": torch.tensor(265)},
        "string_to_param": {"*": tok_catted},
        "name": "",
    }

    return state_dict, ret
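

# --- Usage sketch (not part of the original module) -----------------------------
# A minimal example of how merge() might be driven from the command line. The
# argument names, the repacking of the .safetensors file via safetensors.safe_open,
# and the torch.save output format below are assumptions about the surrounding
# tooling, not something this module defines.
if __name__ == "__main__":
    import argparse

    from safetensors import safe_open

    parser = argparse.ArgumentParser(
        description="Merge a lora_diffusion .safetensors Lora into a checkpoint"
    )
    parser.add_argument("base_model", help="diffusers pipeline directory or hub id")
    parser.add_argument("lora_safetensors", help="lora_diffusion-style .safetensors file")
    parser.add_argument("output_ckpt", help="where to write the merged checkpoint")
    parser.add_argument("--alpha", type=float, default=0.5)
    args = parser.parse_args()

    # Repack the file into the {"metadata": ..., "weights": ...} dict that
    # parse_safeloras()/parse_safeloras_embeds() expect.
    with safe_open(args.lora_safetensors, framework="pt", device="cpu") as f:
        safeloras = {
            "metadata": f.metadata() or {},
            "weights": {k: f.get_tensor(k) for k in f.keys()},
        }

    state_dict, embeds = merge(args.base_model, safeloras, alpha_1=args.alpha)

    # `embeds` holds the Textual Inversion embedding dict; save or apply it
    # separately as needed.
    torch.save({"state_dict": state_dict}, args.output_ckpt)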