Spaces: Running on Zero
| import os | |
| import typing | |
| import torch | |
| from safetensors import safe_open | |
| import lora as comfy_lora | |
| import comfy.utils as comfy_utils | |
| import comfy.model_patcher | |
| import folder_paths | |
def _get_model_state_dict(model: typing.Any) -> dict:
    """Return the model's state dict, preferring ComfyUI's ``model_state_dict``.

    Falls back to the standard ``state_dict()`` when the ComfyUI accessor is
    absent. Some ComfyUI versions require an explicit filter-prefix argument,
    hence the ``TypeError`` retry with ``None``.
    """
    accessor = getattr(model, "model_state_dict", None)
    if accessor is None:
        return model.state_dict()
    try:
        return accessor()
    except TypeError:
        # Signature variant that takes a filter prefix; None means "all keys".
        return accessor(None)
def build_newbie_lora_key_map(model) -> dict:
    """Map every plausible LoRA key spelling onto the model's canonical key.

    For each ``*.weight`` entry in the model's state dict, register the alias
    spellings that common LoRA trainers emit — PEFT-style
    ``base_model.model.`` and diffusers-style ``transformer.`` prefixes, the
    ``diffusion_model.``-stripped short name (plus ``unet.base_model.model.``
    for it), and LyCORIS-style underscore-joined names — all pointing at the
    full state-dict key. First registration wins on alias collisions.
    """
    suffix = ".weight"
    key_map: dict = {}
    for full_key in _get_model_state_dict(model):
        if not full_key.endswith(suffix):
            continue
        base = full_key[: -len(suffix)]

        aliases = {
            base,
            f"base_model.model.{base}",
            f"transformer.{base}",
            "lycoris_" + base.replace(".", "_"),
        }
        if base.startswith("diffusion_model."):
            short = base[len("diffusion_model."):]
            aliases.update((
                short,
                f"base_model.model.{short}",
                f"transformer.{short}",
                f"unet.base_model.model.{short}",
                "lycoris_" + short.replace(".", "_"),
            ))

        for alias in aliases:
            key_map.setdefault(alias, full_key)
    return key_map
def load_newbie_lora_state_dict(lora_name: str) -> tuple:
    """Resolve *lora_name* in the ComfyUI ``loras`` folder and load it.

    Args:
        lora_name: filename (as listed by folder_paths) of the LoRA to load.

    Returns:
        ``(state_dict, metadata)`` — the LoRA tensors plus, for
        ``.safetensors`` files, the embedded metadata dict (``{}`` otherwise).

    Raises:
        ValueError: empty name, a directory path, or a file that does not
            deserialize into a state dict.
        FileNotFoundError: name not present in models/loras.
    """
    if not lora_name:
        raise ValueError("LoRA name is empty.")
    lora_path = folder_paths.get_full_path("loras", lora_name)
    if lora_path is None:
        raise FileNotFoundError(f"LoRA '{lora_name}' not found in models/loras folder.")
    if os.path.isdir(lora_path):
        raise ValueError(f"'{lora_path}' is a directory. Please select a LoRA file instead of a folder.")
    metadata = {}
    # Case-insensitive extension check: previously '.SafeTensors' files
    # silently skipped the metadata read.
    if lora_path.lower().endswith('.safetensors'):
        with safe_open(lora_path, framework="pt", device="cpu") as f:
            metadata = f.metadata() or {}
    # safe_load=True blocks arbitrary-code-execution via pickle for
    # non-safetensors checkpoints — LoRA files are untrusted downloads.
    sd = comfy_utils.load_torch_file(lora_path, safe_load=True)
    if not isinstance(sd, dict):
        raise ValueError(f"Loaded LoRA '{lora_name}' does not contain a valid state dict.")
    return sd, metadata
def apply_newbie_lora_to_model(
    model,
    lora_name: str,
    strength: float,
) -> comfy.model_patcher.ModelPatcher:
    """Clone *model* with the named LoRA's patches applied at *strength*.

    Args:
        model: a ComfyUI ModelPatcher, or a raw model which gets wrapped.
        lora_name: filename inside the models/loras folder.
        strength: user-facing patch strength; multiplied by the
            alpha/rank scale read from the file's metadata, when present.

    Returns:
        A cloned, patched ModelPatcher — or the input unchanged when
        strength is 0.0 or the file yields no applicable patches.
    """
    if strength == 0.0:
        # NOTE(review): the input is returned as-is here, so a raw (non-
        # ModelPatcher) model is NOT wrapped on this path — confirm callers
        # tolerate that.
        return model
    if not isinstance(model, comfy.model_patcher.ModelPatcher):
        # NOTE(review): ModelPatcher usually also takes load/offload devices;
        # confirm this bare wrap works on the targeted ComfyUI version.
        model = comfy.model_patcher.ModelPatcher(model)
    lora_sd, metadata = load_newbie_lora_state_dict(lora_name)

    # Derive the alpha/rank scale from the file metadata. Safetensors
    # metadata values are always strings and may be missing or malformed
    # (e.g. JSON config blobs), so any parse failure falls back to the
    # neutral scale instead of crashing the node with a ValueError.
    scale = 1.0
    if metadata:
        try:
            lora_rank = float(metadata.get("lora_rank", 0))
            lora_alpha = float(metadata.get("lora_alpha", lora_rank))
        except (TypeError, ValueError):
            lora_rank = 0.0
            lora_alpha = 0.0
        if lora_rank > 0:
            scale = lora_alpha / lora_rank
    final_strength = strength * scale

    to_load = build_newbie_lora_key_map(model.model)
    patches = comfy_lora.load_lora(lora_sd, to_load, log_missing=True)
    if not patches:
        print(f"Warning: No valid patches found in LoRA '{lora_name}'.")
        return model
    patched_model = model.clone()
    patched_model.add_patches(patches, strength_patch=float(final_strength), strength_model=1.0)
    return patched_model