import torch
import torch.nn as nn
from modules import devices, lowvram, shared, scripts

# Older webui builds may not provide `devices.cond_cast_unet`; fall back to identity.
cond_cast_unet = getattr(devices, 'cond_cast_unet', lambda x: x)

from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.modules.diffusionmodules.openaimodel import UNetModel


class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails, so in practice this
        # forwards every name other than the ones defined on this class to torch.
        if item == 'cat':
            return self.cat

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

    def cat(self, tensors, *args, **kwargs):
        """Concatenate like ``torch.cat``; for a pair, first resize the first tensor.

        When exactly two tensors are given and their last two dimensions
        differ, the first one is interpolated (nearest) to the last two
        dimensions of the second before concatenation.
        """
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)


th = TorchHijackForUnet()


class ControlParams:
    """Plain container for the settings of one control unit.

    The attributes are stored as given; ``guidance_stopped`` is later
    rewritten on every denoiser step by ``UnetHook``'s schedule handler.
    """

    def __init__(
        self,
        control_model,
        hint_cond,
        guess_mode,
        weight,
        guidance_stopped,
        start_guidance_percent,
        stop_guidance_percent,
        advanced_weighting,
        is_adapter,
        is_extra_cond
    ):
        self.control_model = control_model
        self.hint_cond = hint_cond
        self.guess_mode = guess_mode
        self.weight = weight
        self.guidance_stopped = guidance_stopped
        self.start_guidance_percent = start_guidance_percent
        self.stop_guidance_percent = stop_guidance_percent
        self.advanced_weighting = advanced_weighting
        self.is_adapter = is_adapter
        self.is_extra_cond = is_extra_cond


class UnetHook(nn.Module):
    """Replaces a UNet's ``forward`` with one that injects control signals.

    Usage order, as required by the attributes each method reads:
    ``notify()`` sets ``control_params`` / ``is_vanilla_samplers`` /
    ``guess_mode``, ``hook()`` installs the patched forward, and
    ``restore()`` puts the original forward back.
    """

    def __init__(self, lowvram=False) -> None:
        super().__init__()
        self.lowvram = lowvram
        self.batch_cond_available = True
        self.only_mid_control = shared.opts.data.get("control_net_only_mid_control", False)

    def hook(self, model):
        """Patch ``model.forward`` and register the guidance schedule callback.

        The original forward is kept on ``model._original_forward`` so that
        ``restore()`` can undo the patch.
        """
        outer = self

        def guidance_schedule_handler(x):
            # Called by the webui on every cfg-denoiser step; marks each unit as
            # stopped while the sampling progress is outside its [start, stop] window.
            for param in self.control_params:
                current_sampling_percent = (x.sampling_step / x.total_sampling_steps)
                param.guidance_stopped = (
                    current_sampling_percent < param.start_guidance_percent
                    or current_sampling_percent > param.stop_guidance_percent
                )

        def cfg_based_adder(base, x, require_autocast, is_adapter=False):
            """Add control signal ``x`` to ``base``, optionally only to the cond half."""
            # Untouched accumulator slots are still the float 0.0 they were initialised with.
            if isinstance(x, float):
                return base + x

            if require_autocast:
                # Zero-pad x along dim 1 up to base's shape (inpaint workaround).
                zeros = torch.zeros_like(base)
                zeros[:, :x.shape[1], ...] = x
                x = zeros

            # assume the input format is [cond, uncond] and they have same shape
            # (vanilla samplers use [uncond, cond] instead, handled below)
            if base.shape[0] % 2 == 0 and (self.guess_mode or shared.opts.data.get("control_net_cfg_based_guidance", False)):
                if self.is_vanilla_samplers:
                    uncond, cond = base.chunk(2)
                    if x.shape[0] % 2 == 0:
                        _, x_cond = x.chunk(2)
                        return torch.cat([uncond, cond + x_cond], dim=0)
                    if is_adapter:
                        return torch.cat([uncond, cond + x], dim=0)
                else:
                    cond, uncond = base.chunk(2)
                    if x.shape[0] % 2 == 0:
                        x_cond, _ = x.chunk(2)
                        return torch.cat([cond + x_cond, uncond], dim=0)
                    if is_adapter:
                        return torch.cat([cond + x, uncond], dim=0)

            return base + x

        def forward(self, x, timesteps=None, context=None, **kwargs):
            # NOTE: `self` here is the UNet the function gets bound to; the hook
            # instance is reachable through `outer`.
            total_control = [0.0] * 13
            total_adapter = [0.0] * 4
            total_extra_cond = torch.zeros([0, context.shape[-1]]).to(devices.get_device_for("controlnet"))
            only_mid_control = outer.only_mid_control
            require_inpaint_hijack = False

            # handle external cond first
            for param in outer.control_params:
                if param.guidance_stopped or not param.is_extra_cond:
                    continue
                if outer.lowvram:
                    param.control_model.to(devices.get_device_for("controlnet"))
                control = param.control_model(x=x, hint=param.hint_cond, timesteps=timesteps, context=context)
                total_extra_cond = torch.cat([total_extra_cond, control.clone().squeeze(0) * param.weight])

            # check if it's non-batch-cond mode (lowvram, edit model etc)
            if context.shape[0] % 2 != 0 and outer.batch_cond_available:
                outer.batch_cond_available = False
                if len(total_extra_cond) > 0 or outer.guess_mode or shared.opts.data.get("control_net_cfg_based_guidance", False):
                    print("Warning: StyleAdapter and cfg/guess mode may not works due to non-batch-cond inference")

            # concat styleadapter to cond, pad uncond to same length
            if len(total_extra_cond) > 0 and outer.batch_cond_available:
                total_extra_cond = torch.repeat_interleave(total_extra_cond.unsqueeze(0), context.shape[0] // 2, dim=0)
                if outer.is_vanilla_samplers:
                    uncond, cond = context.chunk(2)
                    cond = torch.cat([cond, total_extra_cond], dim=1)
                    uncond = torch.cat([uncond, uncond[:, -total_extra_cond.shape[1]:, :]], dim=1)
                    context = torch.cat([uncond, cond], dim=0)
                else:
                    cond, uncond = context.chunk(2)
                    cond = torch.cat([cond, total_extra_cond], dim=1)
                    uncond = torch.cat([uncond, uncond[:, -total_extra_cond.shape[1]:, :]], dim=1)
                    context = torch.cat([cond, uncond], dim=0)

            # handle unet injection stuff
            for param in outer.control_params:
                if param.guidance_stopped or param.is_extra_cond:
                    continue
                if outer.lowvram:
                    param.control_model.to(devices.get_device_for("controlnet"))

                # hires stuffs
                # note that this method may not work if hr_scale < 1.1
                if abs(x.shape[-1] - param.hint_cond.shape[-1] // 8) > 8:
                    only_mid_control = shared.opts.data.get("control_net_only_midctrl_hires", True)
                    # If you want to completely disable control net, uncomment this.
                    # return self._original_forward(x, timesteps=timesteps, context=context, **kwargs)

                # inpaint model workaround
                x_in = x
                control_model = param.control_model.control_model
                if not param.is_adapter and x.shape[1] != control_model.input_blocks[0][0].in_channels and x.shape[1] == 9:
                    # inpaint_model: 4 data + 4 downscaled image + 1 mask
                    x_in = x[:, :4, ...]
                    require_inpaint_hijack = True

                assert param.hint_cond is not None, "Controlnet is enabled but no input image is given"
                control = param.control_model(x=x_in, hint=param.hint_cond, timesteps=timesteps, context=context)
                control_scales = ([param.weight] * 13)

                if outer.lowvram:
                    param.control_model.to("cpu")
                if param.guess_mode:
                    if param.is_adapter:
                        # Scale element-wise: `weight * [..]` would raise TypeError for a
                        # float weight and merely repeat the list for an int weight.
                        control_scales = [param.weight * scale for scale in (0.25, 0.62, 0.825, 1.0)]
                    else:
                        control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)]
                if param.advanced_weighting is not None:
                    control_scales = param.advanced_weighting

                control = [c * scale for c, scale in zip(control, control_scales)]
                target = total_adapter if param.is_adapter else total_control
                for idx, item in enumerate(control):
                    target[idx] += item

            control = total_control
            assert timesteps is not None, f"insufficient timestep: {timesteps}"
            hs = []
            with th.no_grad():
                t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
                emb = self.time_embed(t_emb)
                h = x.type(self.dtype)
                for i, module in enumerate(self.input_blocks):
                    h = module(h, emb, context)

                    # t2i-adapter: inject one accumulated adapter feature after every third input block
                    if ((i + 1) % 3 == 0) and total_adapter:
                        h = cfg_based_adder(h, total_adapter.pop(0), require_inpaint_hijack, is_adapter=True)

                    hs.append(h)
                h = self.middle_block(h, emb, context)

            # The last accumulated control entry goes to the middle block output.
            control_in = control.pop()
            h = cfg_based_adder(h, control_in, require_inpaint_hijack)

            for i, module in enumerate(self.output_blocks):
                if only_mid_control:
                    hs_input = hs.pop()
                    h = th.cat([h, hs_input], dim=1)
                else:
                    hs_input, control_input = hs.pop(), control.pop()
                    h = th.cat([h, cfg_based_adder(hs_input, control_input, require_inpaint_hijack)], dim=1)
                h = module(h, emb, context)

            h = h.type(x.dtype)
            return self.out(h)

        def forward2(*args, **kwargs):
            # webui will handle other components
            try:
                if shared.cmd_opts.lowvram:
                    lowvram.send_everything_to_cpu()

                return forward(*args, **kwargs)
            finally:
                # Always move the control models back to CPU in lowvram mode,
                # even when the forward pass raised.
                if self.lowvram:
                    for param in self.control_params:
                        param.control_model.to("cpu")

        model._original_forward = model.forward
        # Bind to the model so that `forward` receives the UNet as its `self`.
        model.forward = forward2.__get__(model, UNetModel)
        scripts.script_callbacks.on_cfg_denoiser(guidance_schedule_handler)

    def notify(self, params, is_vanilla_samplers):  # lint: list[ControlParams]
        """Store the active control units and sampler layout for the patched forward."""
        self.is_vanilla_samplers = is_vanilla_samplers
        self.control_params = params
        self.guess_mode = any(param.guess_mode for param in params)

    def restore(self, model):
        """Remove the registered callbacks and put back ``model``'s original forward."""
        scripts.script_callbacks.remove_current_script_callbacks()
        if hasattr(self, "control_params"):
            del self.control_params

        if not hasattr(model, "_original_forward"):
            # no such handle, ignore
            return

        model.forward = model._original_forward
        del model._original_forward