|
import os |
|
import re |
|
import sys |
|
|
|
from modules import sd_models, shared |
|
from modules.paths import extensions_builtin_dir |
|
|
|
from scripts.animatediff_logger import logger_animatediff as logger |
|
|
|
# Put the built-in Lora extension directory on the import path so the lazy
# `import network` / `import networks` statements used further down can resolve
# (presumably those modules live in that directory -- confirm against the host app).
# The membership check keeps sys.path free of duplicate entries when this module
# is executed more than once in the same interpreter.
_lora_extension_dir = f"{extensions_builtin_dir}/Lora"
if _lora_extension_dir not in sys.path:
    sys.path.append(_lora_extension_dir)
|
|
|
class AnimateDiffLora:
    """Install and remove a patched ``networks.load_network`` that understands motion LoRA files.

    ``hack()`` replaces ``networks.load_network`` with a wrapper that builds the
    network itself when the state dict's first key contains ``'motion_modules'``
    and delegates to the original loader otherwise. ``restore()`` puts the
    original loader back. Both are no-ops when ``v2`` is false.
    """

    # Loader that was installed in ``networks.load_network`` before ``hack()``
    # replaced it; ``None`` while no patch is active. This is a class attribute,
    # so the "patched" state is shared by every instance.
    original_load_network = None

    def __init__(self, v2: bool):
        """Remember whether patching is enabled (``hack``/``restore`` do nothing if not)."""
        self.v2 = v2

    @staticmethod
    def _convert_mm_name_to_compvis(key: str):
        """Split a motion LoRA state-dict key into ``(sd_module_key, network_part)``.

        The key is split around its ``'_lora.'`` marker. The part before the
        marker has ``'processor.'`` removed and ``'to_out'`` rewritten to
        ``'to_out.0'``; the part after the marker is prefixed with ``'lora_'``.

        Raises:
            ValueError: if ``key`` does not contain exactly one ``'_lora.'`` marker.
        """
        parts = re.split(r'(_lora\.)', key)
        if len(parts) != 3:
            # Same exception type the bare tuple-unpack would raise, but with a
            # message that names the offending key.
            raise ValueError(
                f"Unexpected motion LoRA key {key!r}: expected exactly one '_lora.' marker."
            )
        sd_module_key, _, network_part = parts
        sd_module_key = sd_module_key.replace("processor.", "").replace("to_out", "to_out.0")
        return sd_module_key, 'lora_' + network_part

    def hack(self):
        """Replace ``networks.load_network`` with the motion-LoRA-aware wrapper.

        Does nothing when ``v2`` is false or when a patch is already installed.
        """
        if not self.v2:
            return

        if AnimateDiffLora.original_load_network is not None:
            logger.info("AnimateDiff LoRA already hacked")
            return

        logger.info("Hacking LoRA module to support motion LoRA")
        # Imported lazily: these modules are only needed once patching is requested.
        import network
        import networks

        AnimateDiffLora.original_load_network = networks.load_network
        # Bind to a local so the wrapper keeps calling the loader captured here,
        # even after the class attribute is reset by restore().
        original_load_network = AnimateDiffLora.original_load_network
        convert_mm_name_to_compvis = AnimateDiffLora._convert_mm_name_to_compvis

        def mm_load_network(name, network_on_disk):
            """Load ``network_on_disk``; handle motion LoRA here, delegate anything else."""
            net = network.Network(name, network_on_disk)
            net.mtime = os.path.getmtime(network_on_disk.filename)

            sd = sd_models.read_state_dict(network_on_disk.filename)

            # Only the first key is inspected to decide whether this is a motion
            # LoRA. next(iter(...)) avoids copying every key into a list, and the
            # "" default lets an empty state dict fall through to the original
            # loader instead of raising IndexError.
            first_key = next(iter(sd), "")
            if 'motion_modules' not in first_key:
                # Drop our copy of the weights before the original loader runs.
                del sd
                return original_load_network(name, network_on_disk)

            logger.info(f"Loading motion LoRA {name} from {network_on_disk.filename}")
            matched_networks = {}

            # Group the individual weight tensors by the sd module they target.
            for key_network, weight in sd.items():
                key, network_part = convert_mm_name_to_compvis(key_network)
                sd_module = shared.sd_model.network_layer_mapping.get(key, None)

                assert sd_module is not None, f"Failed to find sd module for key {key}."

                if key not in matched_networks:
                    matched_networks[key] = network.NetworkWeights(
                        network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

                matched_networks[key].w[network_part] = weight

            # Build one network module per target using the first registered module type.
            for key, weights in matched_networks.items():
                net_module = networks.module_types[0].create_module(net, weights)
                assert net_module is not None, "Failed to create motion module LoRA"
                net.modules[key] = net_module

            return net

        networks.load_network = mm_load_network

    def restore(self):
        """Reinstall the loader saved by ``hack()`` and clear the saved reference.

        Does nothing when ``v2`` is false or when no patch is installed.
        """
        if not self.v2:
            return

        if AnimateDiffLora.original_load_network is None:
            logger.info("AnimateDiff LoRA already restored")
            return

        logger.info("Restoring hacked LoRA")
        import networks

        networks.load_network = AnimateDiffLora.original_load_network
        AnimateDiffLora.original_load_network = None
|
|