from __future__ import annotations from contextlib import contextmanager from modules import img2img, processing, shared def cn_restore_unet_hook(p, cn_latest_network): if cn_latest_network is not None: unet = p.sd_model.model.diffusion_model cn_latest_network.restore(unet) class CNHijackRestore: def __init__(self): self.process = hasattr(processing, "__controlnet_original_process_images_inner") self.img2img = hasattr(img2img, "__controlnet_original_process_batch") def __enter__(self): if self.process: self.orig_process = processing.process_images_inner processing.process_images_inner = getattr( processing, "__controlnet_original_process_images_inner" ) if self.img2img: self.orig_img2img = img2img.process_batch img2img.process_batch = getattr( img2img, "__controlnet_original_process_batch" ) def __exit__(self, *args, **kwargs): if self.process: processing.process_images_inner = self.orig_process if self.img2img: img2img.process_batch = self.orig_img2img @contextmanager def cn_allow_script_control(): orig = False if "control_net_allow_script_control" in shared.opts.data: try: orig = shared.opts.data["control_net_allow_script_control"] shared.opts.data["control_net_allow_script_control"] = True yield finally: shared.opts.data["control_net_allow_script_control"] = orig else: yield