|
|
import json |
|
|
from typing import Optional, Set |
|
|
|
|
|
import torch.nn as nn |
|
|
from safetensors.torch import save_file as safe_save |
|
|
|
|
|
from lora_diffusion.lora import DEFAULT_TARGET_REPLACE, LoraInjectedLinear, LoraInjectedConv2d |
|
|
|
|
|
|
|
|
def _find_modules_with_ancestor(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_class=None,
    exclude_children_of=None,
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Returns all matching modules, along with the parent of those modules and the
    names they are referenced by as the full ancestor name.

    This is a copy of _find_modules_v2 from lora.py. It was copied instead of
    refactored to keep the implementation in lora.py in sync with cloneofsimo.

    Args:
        model: root module whose ``named_modules()`` are scanned.
        ancestor_class: set of class *names* (strings). Only modules at or
            below a module whose ``__class__.__name__`` is in this set are
            searched. ``None`` treats every module of ``model`` (including
            ``model`` itself) as an ancestor. When ancestors nest, the same
            match is yielded once per enclosing ancestor.
        search_class: classes matched with ``isinstance``; defaults to
            ``[nn.Linear]``.
        exclude_children_of: classes whose *direct* children are skipped;
            defaults to ``[LoraInjectedLinear, LoraInjectedConv2d]``. An empty
            sequence disables the exclusion.

    Yields:
        ``(parent, name, module, fullname, anc_name)``. ``parent`` holds
        ``module`` as submodule ``name``. ``fullname`` is the dotted path of
        ``module`` relative to the ancestor. ``anc_name`` is the ancestor's
        dotted name within ``model``.
    """
    if exclude_children_of is None:
        exclude_children_of = [
            LoraInjectedLinear,
            LoraInjectedConv2d,
        ]
    if search_class is None:
        search_class = [nn.Linear]

    # isinstance() accepts a tuple of classes; build the tuples once rather
    # than materialising a list of booleans for every visited module.
    search_types = tuple(search_class)
    exclude_types = tuple(exclude_children_of)

    if ancestor_class is not None:
        ancestors = (
            (name, module)
            for name, module in model.named_modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # No filter: every module is a candidate ancestor (snapshotted as a list).
        ancestors = [(name, module) for name, module in model.named_modules()]

    for anc_name, ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if not isinstance(module, search_types):
                continue

            # Resolve the direct parent so callers can address `name` on it.
            *path, name = fullname.split(".")
            parent = ancestor
            for part in path:
                parent = parent.get_submodule(part)

            # Skip matches whose direct parent is an excluded class (by
            # default, the LoRA-injected wrapper layers).
            if exclude_types and isinstance(parent, exclude_types):
                continue

            yield parent, name, module, fullname, anc_name
|
|
|
|
|
|
|
|
def get_extra_networks_diffsuers_key(name, child_module, full_child_name):
    """
    Compute the diffusers key that is compatible with the extra networks feature in the webui.

    The key is ``"{name}.{full_child_name}"`` with every ``.`` replaced by
    ``_``. Only modules whose class is named ``LoraInjectedLinear`` or
    ``LoraInjectedConv2d`` are supported. For any other module a message is
    printed and ``None`` is returned.
    """
    module_type = child_module.__class__.__name__
    if module_type not in ("LoraInjectedLinear", "LoraInjectedConv2d"):
        print(f"Unsupported module type {module_type}")
        return None

    dotted = f"{name}.{full_child_name}"
    return dotted.replace('.', '_')
|
|
|
|
|
|
|
|
def get_extra_networks_ups_down(model, target_replace_module=None):
    """
    Get a list of loras and keys for saving to extra networks.

    Args:
        model: module tree searched for ``LoraInjectedLinear`` /
            ``LoraInjectedConv2d`` layers.
        target_replace_module: set of ancestor class names to search under;
            defaults to ``DEFAULT_TARGET_REPLACE``.

    Returns:
        A list of ``(key, lora_up, lora_down)`` tuples, where ``key`` is
        produced by ``get_extra_networks_diffsuers_key``. Modules for which no
        key can be computed are reported and skipped.

    Raises:
        ValueError: if no injected LoRA module with a valid key is found.
    """
    if target_replace_module is None:
        target_replace_module = DEFAULT_TARGET_REPLACE

    loras = []

    for _parent, child_name, child_module, fullname, ancestor_name in _find_modules_with_ancestor(
        model,
        target_replace_module,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):
        key = get_extra_networks_diffsuers_key(
            ancestor_name, child_module, fullname)
        if key is None:
            # A None key would be formatted into a literal "..._None.lora_*"
            # tensor name when saving, and several such modules would
            # overwrite each other. Skip the module instead.
            print(f"{ancestor_name}, {child_name} did not converst to key.")
            continue
        loras.append((key, child_module.lora_up, child_module.lora_down))

    if not loras:
        raise ValueError("No lora injected.")

    return loras
|
|
|
|
|
|
|
|
def save_extra_networks(modelmap=None, outpath="./lora.safetensors"):
    """
    Saves the Lora from multiple modules in a single safetensor file that is compatible with extra_networks.

    modelmap is a dictionary of {
        "module name": (module, target_replace_module)
    }

    metadata contains a mapping from the keys to the normal lora keys.

    Tensor names are prefixed with "lora_unet" for the entry named "unet" and
    "lora_te" for every other entry. For each module name, metadata stores:
      - the JSON-encoded list of target classes under ``name``;
      - ``"{name}:{i}:up"``, ``"{name}:{i}:down"`` and ``"{name}:{i}:rank"``
        for every LoRA pair;
      - ``"{prefix}_rank"``, which is the rank of the last pair seen (only
        written when that rank is truthy).

    Args:
        modelmap: mapping described above. ``None`` (the default) means an
            empty mapping.
        outpath: destination path of the safetensors file.
    """
    if modelmap is None:
        modelmap = {}

    weights = {}
    metadata = {"lora_key_encoding": "extra_network_diffusers"}

    for name, (model, target_replace_module) in modelmap.items():
        metadata[name] = json.dumps(list(target_replace_module))
        prefix = "lora_unet" if name == "unet" else "lora_te"
        rank = None

        for i, (_key, _up, _down) in enumerate(
            get_extra_networks_ups_down(model, target_replace_module)
        ):
            # The down projection is presumably nn.Linear (out_features) or
            # nn.Conv2d (out_channels); only a missing attribute falls through.
            try:
                rank = _down.out_features
            except AttributeError:
                rank = _down.out_channels

            up_key = f"{prefix}_{_key}.lora_up.weight"
            down_key = f"{prefix}_{_key}.lora_down.weight"
            weights[up_key] = _up.weight
            weights[down_key] = _down.weight
            metadata[f"{name}:{i}:up"] = up_key
            metadata[f"{name}:{i}:down"] = down_key
            metadata[f"{name}:{i}:rank"] = str(rank)

        if rank:
            metadata[f"{prefix}_rank"] = f"{rank}"

    print(f"Saving weights to {outpath}")
    safe_save(weights, outpath, metadata)
|
|
|