toto10's picture
c1a010c78919f8efef07b18490f90b54ce38605a218a48907aa972358dd0fbc1
9c5ffc0
raw
history blame
1.6 kB
from __future__ import annotations
from contextlib import contextmanager
from modules import img2img, processing, shared
def cn_restore_unet_hook(p, cn_latest_network):
    """Call ``cn_latest_network.restore`` on the UNet of ``p.sd_model``.

    Args:
        p: processing object; only ``p.sd_model.model.diffusion_model`` is read,
            and only when a network is given.
        cn_latest_network: object exposing ``restore(unet)``, or ``None``.

    Returns:
        None. When ``cn_latest_network`` is ``None`` this is a no-op and ``p``
        is never accessed.
    """
    if cn_latest_network is None:
        return
    # Resolve the diffusion model first, then hand it to the network.
    diffusion_model = p.sd_model.model.diffusion_model
    cn_latest_network.restore(diffusion_model)
class CNHijackRestore:
    """Context manager that temporarily swaps saved original functions back in.

    For each of ``processing.process_images_inner`` and
    ``img2img.process_batch``: if the owning module carries a saved original
    under ``__controlnet_original_process_images_inner`` /
    ``__controlnet_original_process_batch`` (presumably the pre-hijack function
    stored by ControlNet — confirm against the code that sets it), entering the
    context installs that saved original and exiting reinstalls whatever
    function was current at entry. Modules without the saved attribute are left
    untouched.

    Attributes:
        process: whether ``processing`` had the saved original at construction.
        img2img: whether ``img2img`` had the saved original at construction.
    """

    def __init__(self):
        # Decided once at construction; __enter__/__exit__ trust these flags.
        self.process = hasattr(processing, "__controlnet_original_process_images_inner")
        self.img2img = hasattr(img2img, "__controlnet_original_process_batch")

    def __enter__(self):
        if self.process:
            self.orig_process = processing.process_images_inner
            # getattr with a string literal is required: a dotted
            # ``processing.__controlnet_...`` reference inside a class body
            # would be name-mangled to ``_CNHijackRestore__controlnet_...``.
            processing.process_images_inner = getattr(
                processing, "__controlnet_original_process_images_inner"
            )
        if self.img2img:
            self.orig_img2img = img2img.process_batch
            img2img.process_batch = getattr(
                img2img, "__controlnet_original_process_batch"
            )
        # Return the manager so ``with CNHijackRestore() as h:`` binds it
        # instead of None.
        return self

    def __exit__(self, *args, **kwargs):
        # Put back the functions that were current at __enter__. Returning
        # None lets any exception raised in the with-body propagate.
        if self.process:
            processing.process_images_inner = self.orig_process
        if self.img2img:
            img2img.process_batch = self.orig_img2img
@contextmanager
def cn_allow_script_control():
    """Force ``control_net_allow_script_control`` to ``True`` inside the block.

    If the key is present in ``shared.opts.data``, it is set to ``True`` for
    the duration of the ``with`` body. Its previous value is written back on
    exit, even when the body raises. If the key is absent, the context manager
    changes nothing.
    """
    key = "control_net_allow_script_control"
    if key not in shared.opts.data:
        # Key absent from the options dict: leave everything as it is.
        yield
        return
    # Fallback written back if reading the current value were to fail.
    previous = False
    try:
        previous = shared.opts.data[key]
        shared.opts.data[key] = True
        yield
    finally:
        shared.opts.data[key] = previous