from typing import Callable, Union

import math
import torch
from torch import Tensor

import comfy.model_management
import comfy.sample
import comfy.model_patcher
import comfy.utils
from comfy.controlnet import ControlBase
from comfy.model_patcher import ModelPatcher
from comfy.ldm.modules.attention import BasicTransformerBlock
from comfy.ldm.modules.diffusionmodules import openaimodel

from .logger import logger
from .utils import (AdvancedControlBase, ControlWeights, TimestepKeyframeGroup, TimestepKeyframe,
                    AbstractPreprocWrapper, broadcast_image_to_extend, ORIG_PREVIOUS_CONTROLNET, CONTROL_INIT_BY_ACN)


REF_READ_ATTN_CONTROL_LIST = "ref_read_attn_control_list"
REF_WRITE_ATTN_CONTROL_LIST = "ref_write_attn_control_list"
REF_READ_ADAIN_CONTROL_LIST = "ref_read_adain_control_list"
REF_WRITE_ADAIN_CONTROL_LIST = "ref_write_adain_control_list"

REF_ATTN_CONTROL_LIST = "ref_attn_control_list"
REF_ADAIN_CONTROL_LIST = "ref_adain_control_list"
REF_CONTROL_LIST_ALL = "ref_control_list_all"
REF_CONTROL_INFO = "ref_control_info"
REF_ATTN_MACHINE_STATE = "ref_attn_machine_state"
REF_ADAIN_MACHINE_STATE = "ref_adain_machine_state"
REF_COND_IDXS = "ref_cond_idxs"
REF_UNCOND_IDXS = "ref_uncond_idxs"

CONTEXTREF_OPTIONS_CLASS = "contextref_options_class"
CONTEXTREF_CLEAN_FUNC = "contextref_clean_func"
CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all"
CONTEXTREF_MACHINE_STATE = "contextref_machine_state"
CONTEXTREF_TEMP_COND_IDX = "contextref_temp_cond_idx"
HIGHEST_VERSION_SUPPORT = 1
RETURNED_CONTEXTREF_VERSION = 1


class RefConst:
    OPTS = "refcn_opts"
    CREF_MODE = "contextref_mode"


class MachineState:
    WRITE = "write"
    READ = "read"
    READ_WRITE = "read_write"
    STYLEALIGN = "stylealign"
    OFF = "off"

def is_read(state: str):
    return state in [MachineState.READ, MachineState.READ_WRITE]

def is_write(state: str):
    return state in [MachineState.WRITE, MachineState.READ_WRITE]


class ReferenceType:
    ATTN = "reference_attn"
    ADAIN = "reference_adain"
    ATTN_ADAIN = "reference_attn+adain"
    STYLE_ALIGN = "StyleAlign"

    _LIST = [ATTN, ADAIN, ATTN_ADAIN]
    _LIST_ATTN = [ATTN, ATTN_ADAIN]
    _LIST_ADAIN = [ADAIN, ATTN_ADAIN]

    @classmethod
    def is_attn(cls, ref_type: str):
        return ref_type in cls._LIST_ATTN

    @classmethod
    def is_adain(cls, ref_type: str):
        return ref_type in cls._LIST_ADAIN


class ReferenceOptions:
    def __init__(self, reference_type: str,
                 attn_style_fidelity: float, adain_style_fidelity: float,
                 attn_ref_weight: float, adain_ref_weight: float,
                 attn_strength: float=1.0, adain_strength: float=1.0,
                 ref_with_other_cns: bool=False):
        self.reference_type = reference_type
        # attn
        self.original_attn_style_fidelity = attn_style_fidelity
        self.attn_style_fidelity = attn_style_fidelity
        self.attn_ref_weight = attn_ref_weight
        self.attn_strength = attn_strength
        # adain
        self.original_adain_style_fidelity = adain_style_fidelity
        self.adain_style_fidelity = adain_style_fidelity
        self.adain_ref_weight = adain_ref_weight
        self.adain_strength = adain_strength
        # other
        self.ref_with_other_cns = ref_with_other_cns

    def clone(self):
        return ReferenceOptions(reference_type=self.reference_type,
                                attn_style_fidelity=self.original_attn_style_fidelity,
                                adain_style_fidelity=self.original_adain_style_fidelity,
                                attn_ref_weight=self.attn_ref_weight, adain_ref_weight=self.adain_ref_weight,
                                attn_strength=self.attn_strength, adain_strength=self.adain_strength,
                                ref_with_other_cns=self.ref_with_other_cns)

    @staticmethod
    def create_combo(reference_type: str, style_fidelity: float, ref_weight: float, ref_with_other_cns: bool=False):
        return ReferenceOptions(reference_type=reference_type,
                                attn_style_fidelity=style_fidelity, adain_style_fidelity=style_fidelity,
                                attn_ref_weight=ref_weight, adain_ref_weight=ref_weight,
                                ref_with_other_cns=ref_with_other_cns)

    @staticmethod
    def create_from_kwargs(attn_style_fidelity=0.0, adain_style_fidelity=0.0,
                           attn_ref_weight=0.0, adain_ref_weight=0.0,
                           attn_strength=0.0, adain_strength=0.0, **kwargs):
        has_attn = attn_strength > 0.0
        has_adain = adain_strength > 0.0
        if has_attn and has_adain:
            reference_type = ReferenceType.ATTN_ADAIN
        elif has_adain:
            reference_type = ReferenceType.ADAIN
        else:
            reference_type = ReferenceType.ATTN
        return ReferenceOptions(reference_type=reference_type,
                                attn_style_fidelity=float(attn_style_fidelity), adain_style_fidelity=float(adain_style_fidelity),
                                attn_ref_weight=float(attn_ref_weight), adain_ref_weight=float(adain_ref_weight),
                                attn_strength=float(attn_strength), adain_strength=float(adain_strength))
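# Hedged usage sketch (illustrative, not called anywhere in this module): the two
# factory methods above cover the common cases - create_combo for a single shared
# style_fidelity/ref_weight across attn and adain, create_from_kwargs for the option
# dicts handed over by AnimateDiff-Evolved. The values below are arbitrary assumptions.
def _example_reference_options() -> ReferenceOptions:
    opts = ReferenceOptions.create_combo(reference_type=ReferenceType.ATTN_ADAIN,
                                         style_fidelity=0.5, ref_weight=1.0)
    # equivalent per-component construction; note create_from_kwargs infers
    # reference_type from which strengths are nonzero
    _ = ReferenceOptions.create_from_kwargs(attn_style_fidelity=0.5, adain_style_fidelity=0.5,
                                            attn_ref_weight=1.0, adain_ref_weight=1.0,
                                            attn_strength=1.0, adain_strength=1.0)
    return opts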
class ReferencePreprocWrapper(AbstractPreprocWrapper):
    error_msg = "Invalid use of Reference Preprocess output. The output of Reference preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
    def __init__(self, condhint: Tensor):
        super().__init__(condhint)


class ReferenceAdvanced(ControlBase, AdvancedControlBase):
    CHANNEL_TO_MULT = {320: 1, 640: 2, 1280: 4}

    def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup):
        super().__init__()
        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), allow_condhint_latents=True)
        # TODO: allow vae_optional to be used instead of preprocessor
        #require_vae=True
        self._ref_opts = ref_opts
        self.order = 0
        self.model_latent_format = None
        self.model_sampling_current = None
        self.should_apply_attn_effective_strength = False
        self.should_apply_adain_effective_strength = False
        self.should_apply_effective_masks = False
        self.latent_shape = None
        # ContextRef stuff
        self.is_context_ref = False
        self.contextref_cond_idx = -1
        self.contextref_version = RETURNED_CONTEXTREF_VERSION

    @property
    def ref_opts(self):
        if self._current_timestep_keyframe is not None and self._current_timestep_keyframe.has_control_weights():
            return self._current_timestep_keyframe.control_weights.extras.get(RefConst.OPTS, self._ref_opts)
        return self._ref_opts

    def any_attn_strength_to_apply(self):
        return self.should_apply_attn_effective_strength or self.should_apply_effective_masks

    def any_adain_strength_to_apply(self):
        return self.should_apply_adain_effective_strength or self.should_apply_effective_masks

    def get_effective_strength(self):
        effective_strength = self.strength
        if self._current_timestep_keyframe is not None:
            effective_strength = effective_strength * self._current_timestep_keyframe.strength
        return effective_strength

    def get_effective_attn_mask_or_float(self, x: Tensor, channels: int, is_mid: bool):
        if not self.should_apply_effective_masks:
            return self.get_effective_strength() * self.ref_opts.attn_strength
        if is_mid:
            div = 8
        else:
            div = self.CHANNEL_TO_MULT[channels]
        real_mask = torch.ones([self.latent_shape[0], 1, self.latent_shape[2]//div, self.latent_shape[3]//div]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.attn_strength
        self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
        # mask is now shape [b, 1, h, w]; need to turn into [b, h*w, 1]
        b, c, h, w = real_mask.shape
        real_mask = real_mask.permute(0, 2, 3, 1).reshape(b, h*w, c)
        return real_mask
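    # NOTE (illustrative shape walk-through, assuming a 512x512 image -> 64x64 latent):
    # at channels=640 the mult is 2, so the mask above is built as [b, 1, 32, 32] and
    # then flattened to [b, 32*32, 1] so it can broadcast against attention activations
    # of shape [b, h*w, c].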
    def get_effective_adain_mask_or_float(self, x: Tensor):
        if not self.should_apply_effective_masks:
            return self.get_effective_strength() * self.ref_opts.adain_strength
        b, c, h, w = x.shape
        real_mask = torch.ones([b, 1, h, w]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.adain_strength
        self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
        return real_mask

    def get_contextref_mode_replace(self):
        # used by ADE to get mode_replace for current keyframe
        if self._current_timestep_keyframe.has_control_weights():
            return self._current_timestep_keyframe.control_weights.extras.get(RefConst.CREF_MODE, None)
        return None

    def should_run(self):
        running = super().should_run()
        if not running:
            return running
        attn_run = False
        adain_run = False
        if ReferenceType.is_attn(self.ref_opts.reference_type):
            # attn will run as long as neither weight nor strength is zero
            attn_run = not (math.isclose(self.ref_opts.attn_ref_weight, 0.0) or math.isclose(self.ref_opts.attn_strength, 0.0))
        if ReferenceType.is_adain(self.ref_opts.reference_type):
            # adain will run as long as neither weight nor strength is zero
            adain_run = not (math.isclose(self.ref_opts.adain_ref_weight, 0.0) or math.isclose(self.ref_opts.adain_strength, 0.0))
        return attn_run or adain_run

    def pre_run_advanced(self, model, percent_to_timestep_function):
        AdvancedControlBase.pre_run_advanced(self, model, percent_to_timestep_function)
        if isinstance(self.cond_hint_original, AbstractPreprocWrapper):
            self.cond_hint_original = self.cond_hint_original.condhint
        self.model_latent_format = model.latent_format  # LatentFormat object, used to process_in latent cond_hint
        self.model_sampling_current = model.model_sampling
        # SDXL is more sensitive to style_fidelity according to sd-webui-controlnet comments;
        # prepare all ref_opts accordingly
        all_ref_opts = [self._ref_opts]
        for kf in self.timestep_keyframes.keyframes:
            if kf.has_control_weights() and RefConst.OPTS in kf.control_weights.extras:
                all_ref_opts.append(kf.control_weights.extras[RefConst.OPTS])
        for ropts in all_ref_opts:
            if type(model).__name__ == "SDXL":
                ropts.attn_style_fidelity = ropts.original_attn_style_fidelity ** 3.0
                ropts.adain_style_fidelity = ropts.original_adain_style_fidelity ** 3.0
            else:
                ropts.attn_style_fidelity = ropts.original_attn_style_fidelity
                ropts.adain_style_fidelity = ropts.original_adain_style_fidelity
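    # NOTE (illustrative): cubing pulls fidelity toward 0 for SDXL - e.g. an input
    # style_fidelity of 0.5 becomes 0.5 ** 3 = 0.125, so the style-transferred
    # "uncond" path dominates the blend for that model family.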
    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
        # normal ControlNet stuff
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        dtype = x_noisy.dtype
        # cond_hint_original only matters for RefCN, NOT ContextRef
        if self.cond_hint_original is not None:
            # prepare cond_hint - it is a latent, NOT an image
            #if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] != self.cond_hint.shape[2] or x_noisy.shape[3] != self.cond_hint.shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
            self.cond_hint = None
            # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
            if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
                self.cond_hint = comfy.utils.common_upscale(
                    self.cond_hint_original[self.sub_idxs],
                    x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device)
            else:
                self.cond_hint = comfy.utils.common_upscale(
                    self.cond_hint_original,
                    x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device)
            if x_noisy.shape[0] != self.cond_hint.shape[0]:
                self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
            # noise cond_hint based on sigma (current step)
            self.cond_hint = self.model_latent_format.process_in(self.cond_hint)
            self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None)
        timestep = self.model_sampling_current.timestep(t)
        self.should_apply_attn_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.attn_strength, 1.0))
        self.should_apply_adain_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.adain_strength, 1.0))
        # prepare mask - use direct_attn, so the mask dims will match source latents (and be smaller)
        self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, direct_attn=True)
        self.should_apply_effective_masks = self.latent_keyframes is not None or self.mask_cond_hint is not None or self.tk_mask_cond_hint is not None
        self.latent_shape = list(x_noisy.shape)
        # done preparing; model patches will take care of everything now.
        # return normal controlnet stuff
        return control_prev

    def cleanup_advanced(self):
        super().cleanup_advanced()
        del self.model_latent_format
        self.model_latent_format = None
        del self.model_sampling_current
        self.model_sampling_current = None
        self.should_apply_attn_effective_strength = False
        self.should_apply_adain_effective_strength = False
        self.should_apply_effective_masks = False

    def copy(self):
        c = ReferenceAdvanced(self.ref_opts, self.timestep_keyframes)
        c.order = self.order
        c.is_context_ref = self.is_context_ref
        self.copy_to(c)
        self.copy_to_advanced(c)
        return c

    # avoid deepcopy shenanigans by making deepcopy not do anything to the reference
    # TODO: do the bookkeeping to do this in a proper way for all Adv-ControlNets
    def __deepcopy__(self, memo):
        return self


def handle_context_ref_setup(contextref_obj, transformer_options: dict, positive, negative):
    transformer_options[CONTEXTREF_MACHINE_STATE] = MachineState.OFF
    # verify version is compatible
    if contextref_obj.version > HIGHEST_VERSION_SUPPORT:
        raise Exception(f"AnimateDiff-Evolved's ContextRef v{contextref_obj.version} is not supported in currently-installed Advanced-ControlNet (only supports ContextRef up to v{HIGHEST_VERSION_SUPPORT}); " +
                        "update your Advanced-ControlNet nodes for ContextRef to work.")
    # init ReferenceOptions
    cref_opt_dict = contextref_obj.tune.create_dict()  # ContextRefTune obj from ADE
    opts = ReferenceOptions.create_from_kwargs(**cref_opt_dict)
    # init TimestepKeyframes
    cref_tks_list = contextref_obj.keyframe.create_list_of_dicts()  # ContextRefKeyframeGroup obj from ADE
    timestep_keyframes = _create_tks_from_dict_list(cref_tks_list)
    # create ReferenceAdvanced
    cref = ReferenceAdvanced(ref_opts=opts, timestep_keyframes=timestep_keyframes)
    cref.strength = contextref_obj.strength  # ContextRef obj from ADE
    cref.set_cond_hint_mask(contextref_obj.mask)
    cref.order = 99
    cref.is_context_ref = True
    context_ref_list = [cref]
    transformer_options[CONTEXTREF_CONTROL_LIST_ALL] = context_ref_list
    transformer_options[CONTEXTREF_OPTIONS_CLASS] = ReferenceOptions
    _add_context_ref_to_conds([positive, negative], cref)
    return context_ref_list
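# Hedged sketch (illustrative): the keyframe dicts consumed by _create_tks_from_dict_list
# below come from ADE's ContextRefKeyframeGroup.create_list_of_dicts(); a minimal entry
# with no tune/mask/mode overrides could look like this. Values are assumptions.
def _example_contextref_keyframe_dict() -> dict:
    return {
        "start_percent": 0.0, "guarantee_steps": 1, "inherit_missing": True,
        "strength": 1.0, "mask": None, "tune": None, "mode": None,
    }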
def _create_tks_from_dict_list(dlist: list[dict[str]]) -> TimestepKeyframeGroup:
    tks = TimestepKeyframeGroup()
    if dlist is None or len(dlist) == 0:
        return tks
    for d in dlist:
        # scheduling
        start_percent = d["start_percent"]
        guarantee_steps = d["guarantee_steps"]
        inherit_missing = d["inherit_missing"]
        # values
        strength = d["strength"]
        mask = d["mask"]
        tune = d["tune"]
        mode = d["mode"]
        weights = None
        extras = {}
        if tune is not None:
            cref_opt_dict = tune.create_dict()  # ContextRefTune obj from ADE
            opts = ReferenceOptions.create_from_kwargs(**cref_opt_dict)
            extras[RefConst.OPTS] = opts
        if mode is not None:
            extras[RefConst.CREF_MODE] = mode
        weights = ControlWeights.default(extras=extras)
        # create keyframe
        tk = TimestepKeyframe(start_percent=start_percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing,
                              strength=strength, mask_hint_orig=mask, control_weights=weights)
        tks.add(tk)
    return tks


def _add_context_ref_to_conds(conds: list[list[dict[str]]], context_ref: ReferenceAdvanced):
    def _add_context_ref_to_existing_control(control: ControlBase, context_ref: ReferenceAdvanced):
        curr_cn = control
        while curr_cn is not None:
            if type(curr_cn) == ReferenceAdvanced and curr_cn.is_context_ref:
                break
            if curr_cn.previous_controlnet is not None:
                curr_cn = curr_cn.previous_controlnet
                continue
            orig_previous_controlnet = curr_cn.previous_controlnet
            # NOTE: code is already in place to restore any ORIG_PREVIOUS_CONTROLNET props
            setattr(curr_cn, ORIG_PREVIOUS_CONTROLNET, orig_previous_controlnet)
            curr_cn.previous_controlnet = context_ref
            curr_cn = orig_previous_controlnet

    def _add_context_ref(actual_cond: dict[str], context_ref: ReferenceAdvanced):
        # if controls already present on cond, add it to the last previous_controlnet
        if "control" in actual_cond:
            return _add_context_ref_to_existing_control(actual_cond["control"], context_ref)
        # otherwise, need to add it to begin with, and should mark that it should be cleaned after
        actual_cond["control"] = context_ref
        actual_cond[CONTROL_INIT_BY_ACN] = True

    # either add context_ref to end of existing cnet chain, or init 'control' key on actual cond
    for cond in conds:
        if cond is not None:
            for sub_cond in cond:
                actual_cond = sub_cond[1]
                _add_context_ref(actual_cond, context_ref)


def ref_noise_latents(latents: Tensor, sigma: Tensor, noise: Tensor=None):
    sigma = sigma.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    alpha_cumprod = 1 / ((sigma * sigma) + 1)
    sqrt_alpha_prod = alpha_cumprod ** 0.5
    sqrt_one_minus_alpha_prod = (1. - alpha_cumprod) ** 0.5
    if noise is None:
        # generator = torch.Generator(device="cuda")
        # generator.manual_seed(0)
        # noise = torch.empty_like(latents).normal_(generator=generator)
        # generator = torch.Generator()
        # generator.manual_seed(0)
        # noise = torch.randn(latents.size(), generator=generator).to(latents.device)
        noise = torch.randn_like(latents).to(latents.device)
    return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise
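# Hedged note on the math above: with alpha_cumprod = 1 / (sigma^2 + 1), the return value
# is the standard DDPM-style forward noising x_t = sqrt(a)*x_0 + sqrt(1-a)*eps evaluated
# at the sampler's current sigma. Tiny worked example (illustrative, not called anywhere):
# at sigma = 1.0, alpha_cumprod = 0.5, so latents and noise mix with equal weight ~0.707.
def _example_ref_noise_latents() -> Tensor:
    latents = torch.zeros([1, 4, 8, 8])
    sigma = torch.tensor([1.0])
    return ref_noise_latents(latents, sigma, noise=torch.ones_like(latents))  # ~0.707 everywhere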
def simple_noise_latents(latents: Tensor, sigma: float, noise: Tensor=None):
    if noise is None:
        noise = torch.rand_like(latents)
    return latents + noise * sigma


class BankStylesBasicTransformerBlock:
    def __init__(self):
        # ref
        self.bank = []
        self.style_cfgs = []
        self.cn_idx: list[int] = []
        # contextref - list of lists, as each cond/uncond is stored separately
        self.c_bank: list[list] = []
        self.c_style_cfgs: list[list] = []
        self.c_cn_idx: list[list[int]] = []

    def get_bank(self, cref_idx, ignore_contextref, cdevice=None):
        if ignore_contextref or cref_idx >= len(self.c_bank):
            return self.bank
        real_c_bank_list = self.c_bank[cref_idx]
        if cdevice is not None:
            real_c_bank_list = real_c_bank_list.copy()
            for i in range(len(real_c_bank_list)):
                real_c_bank_list[i] = real_c_bank_list[i].to(cdevice)
        return self.bank + real_c_bank_list

    def get_avg_style_fidelity(self, cref_idx, ignore_contextref):
        if ignore_contextref or cref_idx >= len(self.c_style_cfgs):
            return sum(self.style_cfgs) / float(len(self.style_cfgs))
        combined = self.style_cfgs + self.c_style_cfgs[cref_idx]
        return sum(combined) / float(len(combined))

    def get_cn_idxs(self, cref_idx, ignore_contextref):
        if ignore_contextref or cref_idx >= len(self.c_cn_idx):
            return self.cn_idx
        return self.cn_idx + self.c_cn_idx[cref_idx]

    def init_cref_for_idx(self, cref_idx: int):
        # makes sure cref lists can accommodate cref_idx
        if cref_idx < 0:
            return
        while cref_idx >= len(self.c_bank):
            self.c_bank.append([])
            self.c_style_cfgs.append([])
            self.c_cn_idx.append([])

    def clear_cref_for_idx(self, cref_idx: int):
        if cref_idx < 0 or cref_idx >= len(self.c_bank):
            return
        self.c_bank[cref_idx] = []
        self.c_style_cfgs[cref_idx] = []
        self.c_cn_idx[cref_idx] = []

    def clean_ref(self):
        del self.bank
        del self.style_cfgs
        del self.cn_idx
        self.bank = []
        self.style_cfgs = []
        self.cn_idx = []

    def clean_contextref(self):
        del self.c_bank
        del self.c_style_cfgs
        del self.c_cn_idx
        self.c_bank = []
        self.c_style_cfgs = []
        self.c_cn_idx = []

    def clean_all(self):
        self.clean_ref()
        self.clean_contextref()
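# Hedged sketch (illustrative, not called anywhere): RefCN entries live in .bank while
# ContextRef entries live per-cond in .c_bank; get_bank() concatenates the two for a READ.
# Shapes and orders below are arbitrary assumptions.
def _example_bank_read_write():
    bs = BankStylesBasicTransformerBlock()
    # a RefCN WRITE appends directly to the flat lists
    bs.bank.append(torch.zeros([1, 16, 320]))
    bs.style_cfgs.append(0.5)
    bs.cn_idx.append(0)
    # a ContextRef WRITE first sizes the per-cond lists, then appends at its cond idx
    bs.init_cref_for_idx(0)
    bs.c_bank[0].append(torch.ones([1, 16, 320]))
    bs.c_style_cfgs[0].append(1.0)
    bs.c_cn_idx[0].append(99)
    combined = bs.get_bank(cref_idx=0, ignore_contextref=False)
    assert len(combined) == 2  # RefCN entry followed by ContextRef entry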
class BankStylesTimestepEmbedSequential:
    def __init__(self):
        # ref
        self.var_bank = []
        self.mean_bank = []
        self.style_cfgs = []
        self.cn_idx: list[int] = []
        # cref
        self.c_var_bank: list[list] = []
        self.c_mean_bank: list[list] = []
        self.c_style_cfgs: list[list] = []
        self.c_cn_idx: list[list[int]] = []

    def get_var_bank(self, cref_idx, ignore_contextref):
        if ignore_contextref or cref_idx >= len(self.c_var_bank):
            return self.var_bank
        return self.var_bank + self.c_var_bank[cref_idx]

    def get_mean_bank(self, cref_idx, ignore_contextref):
        if ignore_contextref or cref_idx >= len(self.c_mean_bank):
            return self.mean_bank
        return self.mean_bank + self.c_mean_bank[cref_idx]

    def get_style_cfgs(self, cref_idx, ignore_contextref):
        if ignore_contextref or cref_idx >= len(self.c_style_cfgs):
            return self.style_cfgs
        return self.style_cfgs + self.c_style_cfgs[cref_idx]

    def get_cn_idxs(self, cref_idx, ignore_contextref):
        if ignore_contextref or cref_idx >= len(self.c_cn_idx):
            return self.cn_idx
        return self.cn_idx + self.c_cn_idx[cref_idx]

    def init_cref_for_idx(self, cref_idx: int):
        # makes sure cref lists can accommodate cref_idx
        if cref_idx < 0:
            return
        while cref_idx >= len(self.c_var_bank):
            self.c_var_bank.append([])
            self.c_mean_bank.append([])
            self.c_style_cfgs.append([])
            self.c_cn_idx.append([])

    def clear_cref_for_idx(self, cref_idx: int):
        if cref_idx < 0 or cref_idx >= len(self.c_var_bank):
            return
        self.c_var_bank[cref_idx] = []
        self.c_mean_bank[cref_idx] = []
        self.c_style_cfgs[cref_idx] = []
        self.c_cn_idx[cref_idx] = []

    def clean_ref(self):
        del self.mean_bank
        del self.var_bank
        del self.style_cfgs
        del self.cn_idx
        self.mean_bank = []
        self.var_bank = []
        self.style_cfgs = []
        self.cn_idx = []

    def clean_contextref(self):
        del self.c_var_bank
        del self.c_mean_bank
        del self.c_style_cfgs
        del self.c_cn_idx
        self.c_var_bank = []
        self.c_mean_bank = []
        self.c_style_cfgs = []
        self.c_cn_idx = []

    def clean_all(self):
        self.clean_ref()
        self.clean_contextref()


class InjectionBasicTransformerBlockHolder:
    def __init__(self, block: BasicTransformerBlock, idx=None):
        if hasattr(block, "_forward"):  # backward compatibility
            self.original_forward = block._forward
        else:
            self.original_forward = block.forward
        self.idx = idx
        self.attn_weight = 1.0
        self.is_middle = False
        self.bank_styles = BankStylesBasicTransformerBlock()

    def restore(self, block: BasicTransformerBlock):
        if hasattr(block, "_forward"):  # backward compatibility
            block._forward = self.original_forward
        else:
            block.forward = self.original_forward

    def clean_ref(self):
        self.bank_styles.clean_ref()

    def clean_contextref(self):
        self.bank_styles.clean_contextref()

    def clean_all(self):
        self.bank_styles.clean_all()


class InjectionTimestepEmbedSequentialHolder:
    def __init__(self, block: openaimodel.TimestepEmbedSequential, idx=None, is_middle=False, is_input=False, is_output=False):
        self.original_forward = block.forward
        self.idx = idx
        self.gn_weight = 1.0
        self.is_middle = is_middle
        self.is_input = is_input
        self.is_output = is_output
        self.bank_styles = BankStylesTimestepEmbedSequential()

    def restore(self, block: openaimodel.TimestepEmbedSequential):
        block.forward = self.original_forward

    def clean_ref(self):
        self.bank_styles.clean_ref()

    def clean_contextref(self):
        self.bank_styles.clean_contextref()

    def clean_all(self):
        self.bank_styles.clean_all()


class ReferenceInjections:
    def __init__(self, attn_modules: list['RefBasicTransformerBlock']=None, gn_modules: list['RefTimestepEmbedSequential']=None):
        self.attn_modules = attn_modules if attn_modules else []
        self.gn_modules = gn_modules if gn_modules else []
        self.diffusion_model_orig_forward: Callable = None

    def clean_ref_module_mem(self):
        for attn_module in self.attn_modules:
            try:
                attn_module.injection_holder.clean_ref()
            except Exception:
                pass
        for gn_module in self.gn_modules:
            try:
                gn_module.injection_holder.clean_ref()
            except Exception:
                pass

    def clean_contextref_module_mem(self):
        for attn_module in self.attn_modules:
            try:
                attn_module.injection_holder.clean_contextref()
            except Exception:
                pass
        for gn_module in self.gn_modules:
            try:
                gn_module.injection_holder.clean_contextref()
            except Exception:
                pass

    def clean_all_module_mem(self):
        for attn_module in self.attn_modules:
            try:
                attn_module.injection_holder.clean_all()
            except Exception:
                pass
        for gn_module in self.gn_modules:
            try:
                gn_module.injection_holder.clean_all()
            except Exception:
                pass

    def cleanup(self):
        self.clean_all_module_mem()
        del self.attn_modules
        self.attn_modules = []
        del self.gn_modules
        self.gn_modules = []
        self.diffusion_model_orig_forward = None
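# Hedged wiring sketch (the actual patching code lives elsewhere in this repo; names
# below are illustrative): the factory below closes over a ReferenceInjections holding
# the UNet's original forward, so the injected forward can run extra reference "WRITE"
# passes before delegating to it. Something along these lines:
#   injections = ReferenceInjections(attn_modules=..., gn_modules=...)
#   injections.diffusion_model_orig_forward = diffusion_model.forward
#   diffusion_model.forward = factory_forward_inject_UNetModel(injections).__get__(diffusion_model)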
def factory_forward_inject_UNetModel(reference_injections: ReferenceInjections):
    def forward_inject_UNetModel(self, x: Tensor, *args, **kwargs):
        # get control and transformer_options from kwargs
        real_args = list(args)
        real_kwargs = list(kwargs.keys())
        control = kwargs.get("control", None)
        transformer_options: dict[str] = kwargs.get("transformer_options", {})
        # NOTE: adds support for both ReferenceCN and ContextRef, so need to track them separately
        # get ReferenceAdvanced objects
        ref_controlnets: list[ReferenceAdvanced] = transformer_options.get(REF_CONTROL_LIST_ALL, [])
        context_controlnets: list[ReferenceAdvanced] = transformer_options.get(CONTEXTREF_CONTROL_LIST_ALL, [])
        # clean contextref stuff if OFF
        if len(context_controlnets) > 0 and transformer_options[CONTEXTREF_MACHINE_STATE] == MachineState.OFF:
            reference_injections.clean_contextref_module_mem()
            context_controlnets = []
        # discard any controlnets that should not run
        ref_controlnets = [z for z in ref_controlnets if z.should_run()]
        context_controlnets = [z for z in context_controlnets if z.should_run()]
        # if nothing related to reference controlnets, do nothing special
        if len(ref_controlnets) == 0 and len(context_controlnets) == 0:
            return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs)
        try:
            # assign cond and uncond idxs
            batched_number = len(transformer_options["cond_or_uncond"])
            per_batch = x.shape[0] // batched_number
            indiv_conds = []
            for cond_type in transformer_options["cond_or_uncond"]:
                indiv_conds.extend([cond_type] * per_batch)
            transformer_options[REF_UNCOND_IDXS] = [i for i, z in enumerate(indiv_conds) if z == 1]
            transformer_options[REF_COND_IDXS] = [i for i, z in enumerate(indiv_conds) if z == 0]
            # check which controlnets do which thing
            attn_controlnets = []
            adain_controlnets = []
            for control in ref_controlnets:
                if ReferenceType.is_attn(control.ref_opts.reference_type):
                    attn_controlnets.append(control)
                if ReferenceType.is_adain(control.ref_opts.reference_type):
                    adain_controlnets.append(control)
            context_attn_controlnets = []
            context_adain_controlnets = []
            # for ease of access, store current contextref_cond_idx value
            if len(context_controlnets) == 0:
                transformer_options[CONTEXTREF_TEMP_COND_IDX] = -1
            else:
                transformer_options[CONTEXTREF_TEMP_COND_IDX] = context_controlnets[0].contextref_cond_idx
                # logger.info(f"{transformer_options[CONTEXTREF_MACHINE_STATE]}: {transformer_options[CONTEXTREF_TEMP_COND_IDX]}")
            for control in context_controlnets:
                if ReferenceType.is_attn(control.ref_opts.reference_type):
                    context_attn_controlnets.append(control)
                if ReferenceType.is_adain(control.ref_opts.reference_type):
                    context_adain_controlnets.append(control)
            if len(adain_controlnets) > 0 or len(context_adain_controlnets) > 0:
                # ComfyUI uses forward_timestep_embed with the TimestepEmbedSequential passed into it
                orig_forward_timestep_embed = openaimodel.forward_timestep_embed
                openaimodel.forward_timestep_embed = forward_timestep_embed_ref_inject_factory(orig_forward_timestep_embed)
            # if RefCN to be used, handle running diffusion with ref cond hints
            if len(ref_controlnets) > 0:
                for control in ref_controlnets:
                    read_attn_list = []
                    write_attn_list = []
                    read_adain_list = []
                    write_adain_list = []
                    if ReferenceType.is_attn(control.ref_opts.reference_type):
                        write_attn_list.append(control)
                    if ReferenceType.is_adain(control.ref_opts.reference_type):
                        write_adain_list.append(control)
                    # apply lists
                    transformer_options[REF_READ_ATTN_CONTROL_LIST] = read_attn_list
                    transformer_options[REF_WRITE_ATTN_CONTROL_LIST] = write_attn_list
                    transformer_options[REF_READ_ADAIN_CONTROL_LIST] = read_adain_list
                    transformer_options[REF_WRITE_ADAIN_CONTROL_LIST] = write_adain_list
                    orig_kwargs = kwargs
                    # disable other controlnets for this run, if specified
                    if not control.ref_opts.ref_with_other_cns:
                        kwargs = kwargs.copy()
                        kwargs["control"] = None
                    reference_injections.diffusion_model_orig_forward(control.cond_hint.to(dtype=x.dtype).to(device=x.device), *args, **kwargs)
                    kwargs = orig_kwargs
            # prepare running diffusion for real now
            read_attn_list = []
            write_attn_list = []
            read_adain_list = []
            write_adain_list = []
            # add RefCNs to read lists
            read_attn_list.extend(attn_controlnets)
            read_adain_list.extend(adain_controlnets)
            # do contextref stuff, if needed
            if len(context_controlnets) > 0:
                # clean contextref stuff if first WRITE
                # if context_controlnets[0].contextref_cond_idx == 0 and is_write(transformer_options[CONTEXTREF_MACHINE_STATE]):
                #     reference_injections.clean_contextref_module_mem()
                ### add ContextRef to appropriate lists
                # attn
                if is_read(transformer_options[CONTEXTREF_MACHINE_STATE]):
                    read_attn_list.extend(context_attn_controlnets)
                if is_write(transformer_options[CONTEXTREF_MACHINE_STATE]):
                    write_attn_list.extend(context_attn_controlnets)
                # adain
                if is_read(transformer_options[CONTEXTREF_MACHINE_STATE]):
                    read_adain_list.extend(context_adain_controlnets)
                if is_write(transformer_options[CONTEXTREF_MACHINE_STATE]):
                    write_adain_list.extend(context_adain_controlnets)
            # apply lists, containing both RefCN and ContextRef
            transformer_options[REF_READ_ATTN_CONTROL_LIST] = read_attn_list
            transformer_options[REF_WRITE_ATTN_CONTROL_LIST] = write_attn_list
            transformer_options[REF_READ_ADAIN_CONTROL_LIST] = read_adain_list
            transformer_options[REF_WRITE_ADAIN_CONTROL_LIST] = write_adain_list
            # run diffusion for real
            try:
                return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs)
            finally:
                # increment current cond idx
                if len(context_controlnets) > 0:
                    for cn in context_controlnets:
                        cn.contextref_cond_idx += 1
        finally:
            # make sure ref banks are cleared no matter what happens - otherwise, RIP VRAM
            reference_injections.clean_ref_module_mem()
            if len(adain_controlnets) > 0 or len(context_adain_controlnets) > 0:
                openaimodel.forward_timestep_embed = orig_forward_timestep_embed

    return forward_inject_UNetModel


# dummy class just to help IDE keep track of injected variables
class RefBasicTransformerBlock(BasicTransformerBlock):
    injection_holder: InjectionBasicTransformerBlockHolder = None
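# Hedged sketch (illustrative, not called anywhere): the READ path below blends a
# bank-augmented attention result n_uc with a bank-free result n_c computed only for
# uncond indices, mirroring n = style_fidelity * n_c + (1.0 - style_fidelity) * n_uc.
def _example_style_fidelity_blend() -> Tensor:
    n_uc = torch.zeros([2, 16, 320])  # stands in for attention over [context + bank]
    n_c = torch.ones([2, 16, 320])    # stands in for attention over context alone
    style_fidelity = 0.5
    return style_fidelity * n_c + (1.0 - style_fidelity) * n_uc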
def _forward_inject_BasicTransformerBlock(self: RefBasicTransformerBlock, x: Tensor, context: Tensor=None, transformer_options: dict[str]={}):
    extra_options = {}
    block = transformer_options.get("block", None)
    block_index = transformer_options.get("block_index", 0)
    transformer_patches = {}
    transformer_patches_replace = {}

    for k in transformer_options:
        if k == "patches":
            transformer_patches = transformer_options[k]
        elif k == "patches_replace":
            transformer_patches_replace = transformer_options[k]
        else:
            extra_options[k] = transformer_options[k]
    extra_options["n_heads"] = self.n_heads
    extra_options["dim_head"] = self.d_head

    if self.ff_in:
        x_skip = x
        x = self.ff_in(self.norm_in(x))
        if self.is_res:
            x += x_skip

    n: Tensor = self.norm1(x)
    if self.disable_self_attn:
        context_attn1 = context
    else:
        context_attn1 = None
    value_attn1 = None

    # Reference CN stuff
    uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, [])
    #c_idx_mask = transformer_options.get(REF_COND_IDXS, [])
    # WRITE mode may have only 1 ReferenceAdvanced for RefCN at a time, other modes will have all ReferenceAdvanced
    ref_write_cns: list[ReferenceAdvanced] = transformer_options.get(REF_WRITE_ATTN_CONTROL_LIST, [])
    ref_read_cns: list[ReferenceAdvanced] = transformer_options.get(REF_READ_ATTN_CONTROL_LIST, [])
    cref_cond_idx: int = transformer_options.get(CONTEXTREF_TEMP_COND_IDX, -1)
    ignore_contextref_read = cref_cond_idx < 0  # if writing to bank, should NOT be read in the same execution
    cached_n = None
    cref_write_cns: list[ReferenceAdvanced] = []
    # check if any WRITE cns are applicable; Reference CN WRITEs immediately, ContextRef WRITEs after READ completed
    # if any refs to WRITE, save n and style_fidelity
    for refcn in ref_write_cns:
        if refcn.ref_opts.attn_ref_weight > self.injection_holder.attn_weight:
            if cached_n is None:
                cached_n = n.detach().clone()
            # for ContextRef, make sure relevant lists are long enough to accommodate cond_idx;
            # store RefCN and ContextRef stuff separately
            if refcn.is_context_ref:
                cref_write_cns.append(refcn)
                self.injection_holder.bank_styles.init_cref_for_idx(cref_cond_idx)
            else:
                # Reference CN WRITE
                self.injection_holder.bank_styles.bank.append(cached_n)
                self.injection_holder.bank_styles.style_cfgs.append(refcn.ref_opts.attn_style_fidelity)
                self.injection_holder.bank_styles.cn_idx.append(refcn.order)
    if len(cref_write_cns) == 0:
        del cached_n

    if "attn1_patch" in transformer_patches:
        patch = transformer_patches["attn1_patch"]
        if context_attn1 is None:
            context_attn1 = n
        value_attn1 = context_attn1
        for p in patch:
            n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

    if block is not None:
        transformer_block = (block[0], block[1], block_index)
    else:
        transformer_block = None
    attn1_replace_patch = transformer_patches_replace.get("attn1", {})
    block_attn1 = transformer_block
    if block_attn1 not in attn1_replace_patch:
        block_attn1 = block

    if block_attn1 in attn1_replace_patch:
        if context_attn1 is None:
            context_attn1 = n
            value_attn1 = n
        n = self.attn1.to_q(n)
        # Reference CN READ - use attn1_replace_patch appropriately
        if len(ref_read_cns) > 0 and len(self.injection_holder.bank_styles.get_bank(cref_cond_idx, ignore_contextref_read)) > 0:
            bank_styles = self.injection_holder.bank_styles
            style_fidelity = bank_styles.get_avg_style_fidelity(cref_cond_idx, ignore_contextref_read)
            real_bank = bank_styles.get_bank(cref_cond_idx, ignore_contextref_read, cdevice=n.device).copy()
            real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read)
            cn_idx = 0
            for idx, order in enumerate(real_cn_idxs):
                # make sure matching ref cn is selected
                for i in range(cn_idx, len(ref_read_cns)):
                    if ref_read_cns[i].order == order:
                        cn_idx = i
                        break
                assert order == ref_read_cns[cn_idx].order
                if ref_read_cns[cn_idx].any_attn_strength_to_apply():
                    effective_strength = ref_read_cns[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
                    real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1 - effective_strength)
            n_uc = self.attn1.to_out(attn1_replace_patch[block_attn1](
                n,
                self.attn1.to_k(torch.cat([context_attn1] + real_bank, dim=1)),
                self.attn1.to_v(torch.cat([value_attn1] + real_bank, dim=1)),
                extra_options))
            n_c = n_uc.clone()
            if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
                n_c[uc_idx_mask] = self.attn1.to_out(attn1_replace_patch[block_attn1](
                    n[uc_idx_mask],
                    self.attn1.to_k(context_attn1[uc_idx_mask]),
                    self.attn1.to_v(value_attn1[uc_idx_mask]),
                    extra_options))
            n = style_fidelity * n_c + (1.0 - style_fidelity) * n_uc
            bank_styles.clean_ref()
        else:
            context_attn1 = self.attn1.to_k(context_attn1)
            value_attn1 = self.attn1.to_v(value_attn1)
            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
            n = self.attn1.to_out(n)
    else:
        # Reference CN READ - no attn1_replace_patch
        if len(ref_read_cns) > 0 and len(self.injection_holder.bank_styles.get_bank(cref_cond_idx, ignore_contextref_read)) > 0:
            if context_attn1 is None:
                context_attn1 = n
            bank_styles = self.injection_holder.bank_styles
            style_fidelity = bank_styles.get_avg_style_fidelity(cref_cond_idx, ignore_contextref_read)
            real_bank = bank_styles.get_bank(cref_cond_idx, ignore_contextref_read, cdevice=n.device).copy()
            real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read)
            cn_idx = 0
            for idx, order in enumerate(real_cn_idxs):
                # make sure matching ref cn is selected
                for i in range(cn_idx, len(ref_read_cns)):
                    if ref_read_cns[i].order == order:
                        cn_idx = i
                        break
                assert order == ref_read_cns[cn_idx].order
                if ref_read_cns[cn_idx].any_attn_strength_to_apply():
                    effective_strength = ref_read_cns[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
                    real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1 - effective_strength)
            n_uc: Tensor = self.attn1(
                n,
                context=torch.cat([context_attn1] + real_bank, dim=1),
                value=torch.cat([value_attn1] + real_bank, dim=1) if value_attn1 is not None else value_attn1)
            n_c = n_uc.clone()
            if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
                n_c[uc_idx_mask] = self.attn1(
                    n[uc_idx_mask],
                    context=context_attn1[uc_idx_mask],
                    value=value_attn1[uc_idx_mask] if value_attn1 is not None else value_attn1)
            n = style_fidelity * n_c + (1.0 - style_fidelity) * n_uc
            bank_styles.clean_ref()
        else:
            n = self.attn1(n, context=context_attn1, value=value_attn1)

    # ContextRef CN WRITE
    if len(cref_write_cns) > 0:
        # clear so that ContextRef CNs can properly 'replace' previous value at cond_idx
        self.injection_holder.bank_styles.clear_cref_for_idx(cref_cond_idx)
        for refcn in cref_write_cns:
            # add a whole list to match expected type when combining
            self.injection_holder.bank_styles.c_bank[cref_cond_idx].append(cached_n.to(comfy.model_management.unet_offload_device()))
            self.injection_holder.bank_styles.c_style_cfgs[cref_cond_idx].append(refcn.ref_opts.attn_style_fidelity)
            self.injection_holder.bank_styles.c_cn_idx[cref_cond_idx].append(refcn.order)
        del cached_n

    if "attn1_output_patch" in transformer_patches:
        patch = transformer_patches["attn1_output_patch"]
        for p in patch:
            n = p(n, extra_options)

    x += n
    if "middle_patch" in transformer_patches:
        patch = transformer_patches["middle_patch"]
        for p in patch:
            x = p(x, extra_options)

    if self.attn2 is not None:
        n = self.norm2(x)
        if self.switch_temporal_ca_to_sa:
            context_attn2 = n
        else:
            context_attn2 = context
        value_attn2 = None
        if "attn2_patch" in transformer_patches:
            patch = transformer_patches["attn2_patch"]
            value_attn2 = context_attn2
            for p in patch:
                n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

        attn2_replace_patch = transformer_patches_replace.get("attn2", {})
        block_attn2 = transformer_block
        if block_attn2 not in attn2_replace_patch:
            block_attn2 = block

        if block_attn2 in attn2_replace_patch:
            if value_attn2 is None:
                value_attn2 = context_attn2
            n = self.attn2.to_q(n)
            context_attn2 = self.attn2.to_k(context_attn2)
            value_attn2 = self.attn2.to_v(value_attn2)
            n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
            n = self.attn2.to_out(n)
        else:
            n = self.attn2(n, context=context_attn2, value=value_attn2)

    if "attn2_output_patch" in transformer_patches:
        patch = transformer_patches["attn2_output_patch"]
        for p in patch:
            n = p(n, extra_options)

    x += n
    if self.is_res:
        x_skip = x
    x = self.ff(self.norm3(x))
    if self.is_res:
        x += x_skip

    return x
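# Hedged design note: the ContextRef WRITE above offloads cached activations to
# comfy.model_management.unet_offload_device() (typically the CPU), and get_bank()
# moves them back to the attention device on READ - trading transfer time for VRAM
# when many conds/context windows each keep their own bank entry.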
class RefTimestepEmbedSequential(openaimodel.TimestepEmbedSequential):
    injection_holder: InjectionTimestepEmbedSequentialHolder = None


def forward_timestep_embed_ref_inject_factory(orig_timestep_embed_inject_factory: Callable):
    def forward_timestep_embed_ref_inject(*args, **kwargs):
        ts: RefTimestepEmbedSequential = args[0]
        if not hasattr(ts, "injection_holder"):
            return orig_timestep_embed_inject_factory(*args, **kwargs)
        eps = 1e-6
        x: Tensor = orig_timestep_embed_inject_factory(*args, **kwargs)
        y: Tensor = None
        transformer_options: dict[str] = args[4]
        # Reference CN stuff
        uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, [])
        #c_idx_mask = transformer_options.get(REF_COND_IDXS, [])
        # WRITE mode will only have one ReferenceAdvanced, other modes will have all ReferenceAdvanced
        ref_write_cns: list[ReferenceAdvanced] = transformer_options.get(REF_WRITE_ADAIN_CONTROL_LIST, [])
        ref_read_cns: list[ReferenceAdvanced] = transformer_options.get(REF_READ_ADAIN_CONTROL_LIST, [])
        cref_cond_idx: int = transformer_options.get(CONTEXTREF_TEMP_COND_IDX, -1)
        ignore_contextref_read = cref_cond_idx < 0  # if writing to bank, should NOT be read in the same execution
        cached_var = None
        cached_mean = None
        cref_write_cns: list[ReferenceAdvanced] = []
        # if any refs to WRITE, save var, mean, and style_cfg
        for refcn in ref_write_cns:
            if refcn.ref_opts.adain_ref_weight > ts.injection_holder.gn_weight:
                if cached_var is None:
                    cached_var, cached_mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                if refcn.is_context_ref:
                    cref_write_cns.append(refcn)
                    ts.injection_holder.bank_styles.init_cref_for_idx(cref_cond_idx)
                else:
                    ts.injection_holder.bank_styles.var_bank.append(cached_var)
                    ts.injection_holder.bank_styles.mean_bank.append(cached_mean)
                    ts.injection_holder.bank_styles.style_cfgs.append(refcn.ref_opts.adain_style_fidelity)
                    ts.injection_holder.bank_styles.cn_idx.append(refcn.order)
        if len(cref_write_cns) == 0:
            del cached_var
            del cached_mean
        # if any refs to READ, do math with saved var, mean, and style_cfg
        if len(ref_read_cns) > 0:
            if len(ts.injection_holder.bank_styles.get_var_bank(cref_cond_idx, ignore_contextref_read)) > 0:
                bank_styles = ts.injection_holder.bank_styles
                var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                y_uc = torch.zeros_like(x)
                cn_idx = 0
                real_style_cfgs = bank_styles.get_style_cfgs(cref_cond_idx, ignore_contextref_read)
                real_var_bank = bank_styles.get_var_bank(cref_cond_idx, ignore_contextref_read)
                real_mean_bank = bank_styles.get_mean_bank(cref_cond_idx, ignore_contextref_read)
                real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read)
                for idx, order in enumerate(real_cn_idxs):
                    # make sure matching ref cn is selected
                    for i in range(cn_idx, len(ref_read_cns)):
                        if ref_read_cns[i].order == order:
                            cn_idx = i
                            break
                    assert order == ref_read_cns[cn_idx].order
                    style_fidelity = real_style_cfgs[idx]
                    var_acc = real_var_bank[idx]
                    mean_acc = real_mean_bank[idx]
                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                    sub_y_uc = (((x - mean) / std) * std_acc) + mean_acc
                    if ref_read_cns[cn_idx].any_adain_strength_to_apply():
                        effective_strength = ref_read_cns[cn_idx].get_effective_adain_mask_or_float(x=x)
                        sub_y_uc = sub_y_uc * effective_strength + x * (1 - effective_strength)
                    y_uc += sub_y_uc
                # get average, if more than one
                if len(real_cn_idxs) > 1:
                    y_uc /= len(real_cn_idxs)
                y_c = y_uc.clone()
                if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
                    y_c[uc_idx_mask] = x.to(y_c.dtype)[uc_idx_mask]
                y = style_fidelity * y_c + (1.0 - style_fidelity) * y_uc
                ts.injection_holder.bank_styles.clean_ref()
        # ContextRef CN WRITE
        if len(cref_write_cns) > 0:
            # clear so that ContextRef CNs can properly 'replace' previous value at cond_idx
            ts.injection_holder.bank_styles.clear_cref_for_idx(cref_cond_idx)
            for refcn in cref_write_cns:
                # add a whole list to match expected type when combining
                ts.injection_holder.bank_styles.c_var_bank[cref_cond_idx].append(cached_var)
                ts.injection_holder.bank_styles.c_mean_bank[cref_cond_idx].append(cached_mean)
                ts.injection_holder.bank_styles.c_style_cfgs[cref_cond_idx].append(refcn.ref_opts.adain_style_fidelity)
                ts.injection_holder.bank_styles.c_cn_idx[cref_cond_idx].append(refcn.order)
            del cached_var
            del cached_mean

        if y is None:
            y = x
        return y.to(x.dtype)

    return forward_timestep_embed_ref_inject
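# Hedged sketch (illustrative, not called anywhere): the AdaIN READ above renormalizes
# x to the stored reference statistics, y = ((x - mean) / std) * std_ref + mean_ref,
# computed per-channel over the spatial dims. Dummy stats stand in for a real bank entry.
def _example_adain_renormalize() -> Tensor:
    eps = 1e-6
    x = torch.randn([1, 320, 8, 8])
    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
    var_ref = torch.ones_like(var)     # assumed reference variance
    mean_ref = torch.zeros_like(mean)  # assumed reference mean
    std_ref = torch.maximum(var_ref, torch.zeros_like(var_ref) + eps) ** 0.5
    return ((x - mean) / std) * std_ref + mean_ref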