diff --git "a/ComfyUI-Advanced-ControlNet/adv_control/control_reference.py" "b/ComfyUI-Advanced-ControlNet/adv_control/control_reference.py" --- "a/ComfyUI-Advanced-ControlNet/adv_control/control_reference.py" +++ "b/ComfyUI-Advanced-ControlNet/adv_control/control_reference.py" @@ -1,1112 +1,1112 @@ -from typing import Callable, Union - -import math -import torch -from torch import Tensor - -import comfy.model_management -import comfy.sample -import comfy.model_patcher -import comfy.utils -from comfy.controlnet import ControlBase -from comfy.model_patcher import ModelPatcher -from comfy.ldm.modules.attention import BasicTransformerBlock -from comfy.ldm.modules.diffusionmodules import openaimodel - -from .logger import logger -from .utils import (AdvancedControlBase, ControlWeights, TimestepKeyframeGroup, TimestepKeyframe, AbstractPreprocWrapper, - broadcast_image_to_extend, ORIG_PREVIOUS_CONTROLNET, CONTROL_INIT_BY_ACN) - - -REF_READ_ATTN_CONTROL_LIST = "ref_read_attn_control_list" -REF_WRITE_ATTN_CONTROL_LIST = "ref_write_attn_control_list" -REF_READ_ADAIN_CONTROL_LIST = "ref_read_adain_control_list" -REF_WRITE_ADAIN_CONTROL_LIST = "ref_write_adain_control_list" - -REF_ATTN_CONTROL_LIST = "ref_attn_control_list" -REF_ADAIN_CONTROL_LIST = "ref_adain_control_list" -REF_CONTROL_LIST_ALL = "ref_control_list_all" -REF_CONTROL_INFO = "ref_control_info" -REF_ATTN_MACHINE_STATE = "ref_attn_machine_state" -REF_ADAIN_MACHINE_STATE = "ref_adain_machine_state" -REF_COND_IDXS = "ref_cond_idxs" -REF_UNCOND_IDXS = "ref_uncond_idxs" - -CONTEXTREF_OPTIONS_CLASS = "contextref_options_class" -CONTEXTREF_CLEAN_FUNC = "contextref_clean_func" -CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all" -CONTEXTREF_MACHINE_STATE = "contextref_machine_state" -CONTEXTREF_TEMP_COND_IDX = "contextref_temp_cond_idx" - -HIGHEST_VERSION_SUPPORT = 1 -RETURNED_CONTEXTREF_VERSION = 1 - - -class RefConst: - OPTS = "refcn_opts" - CREF_MODE = "contextref_mode" - - -class MachineState: - WRITE = "write" - READ = "read" - READ_WRITE = "read_write" - STYLEALIGN = "stylealign" - OFF = "off" - -def is_read(state: str): - return state in [MachineState.READ, MachineState.READ_WRITE] - -def is_write(state: str): - return state in [MachineState.WRITE, MachineState.READ_WRITE] - - -class ReferenceType: - ATTN = "reference_attn" - ADAIN = "reference_adain" - ATTN_ADAIN = "reference_attn+adain" - STYLE_ALIGN = "StyleAlign" - - _LIST = [ATTN, ADAIN, ATTN_ADAIN] - _LIST_ATTN = [ATTN, ATTN_ADAIN] - _LIST_ADAIN = [ADAIN, ATTN_ADAIN] - - @classmethod - def is_attn(cls, ref_type: str): - return ref_type in cls._LIST_ATTN - - @classmethod - def is_adain(cls, ref_type: str): - return ref_type in cls._LIST_ADAIN - - -class ReferenceOptions: - def __init__(self, reference_type: str, - attn_style_fidelity: float, adain_style_fidelity: float, - attn_ref_weight: float, adain_ref_weight: float, - attn_strength: float=1.0, adain_strength: float=1.0, - ref_with_other_cns: bool=False): - self.reference_type = reference_type - # attn - self.original_attn_style_fidelity = attn_style_fidelity - self.attn_style_fidelity = attn_style_fidelity - self.attn_ref_weight = attn_ref_weight - self.attn_strength = attn_strength - # adain - self.original_adain_style_fidelity = adain_style_fidelity - self.adain_style_fidelity = adain_style_fidelity - self.adain_ref_weight = adain_ref_weight - self.adain_strength = adain_strength - # other - self.ref_with_other_cns = ref_with_other_cns - - def clone(self): - return ReferenceOptions(reference_type=self.reference_type, - 
attn_style_fidelity=self.original_attn_style_fidelity, adain_style_fidelity=self.original_adain_style_fidelity, - attn_ref_weight=self.attn_ref_weight, adain_ref_weight=self.adain_ref_weight, - attn_strength=self.attn_strength, adain_strength=self.adain_strength, - ref_with_other_cns=self.ref_with_other_cns) - - @staticmethod - def create_combo(reference_type: str, style_fidelity: float, ref_weight: float, ref_with_other_cns: bool=False): - return ReferenceOptions(reference_type=reference_type, - attn_style_fidelity=style_fidelity, adain_style_fidelity=style_fidelity, - attn_ref_weight=ref_weight, adain_ref_weight=ref_weight, - ref_with_other_cns=ref_with_other_cns) - - @staticmethod - def create_from_kwargs(attn_style_fidelity=0.0, adain_style_fidelity=0.0, - attn_ref_weight=0.0, adain_ref_weight=0.0, - attn_strength=0.0, adain_strength=0.0, **kwargs): - has_attn = attn_strength > 0.0 - has_adain = adain_strength > 0.0 - if has_attn and has_adain: - reference_type = ReferenceType.ATTN_ADAIN - elif has_adain: - reference_type = ReferenceType.ADAIN - else: - reference_type = ReferenceType.ATTN - return ReferenceOptions(reference_type=reference_type, - attn_style_fidelity=float(attn_style_fidelity), adain_style_fidelity=float(adain_style_fidelity), - attn_ref_weight=float(attn_ref_weight), adain_ref_weight=float(adain_ref_weight), - attn_strength=float(attn_strength), adain_strength=float(adain_strength)) - - -class ReferencePreprocWrapper(AbstractPreprocWrapper): - error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of Reference preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input." 
- def __init__(self, condhint: Tensor): - super().__init__(condhint) - - -class ReferenceAdvanced(ControlBase, AdvancedControlBase): - CHANNEL_TO_MULT = {320: 1, 640: 2, 1280: 4} - - def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup): - super().__init__() - AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), allow_condhint_latents=True) - # TODO: allow vae_optional to be used instead of preprocessor - #require_vae=True - self._ref_opts = ref_opts - self.order = 0 - self.model_latent_format = None - self.model_sampling_current = None - self.should_apply_attn_effective_strength = False - self.should_apply_adain_effective_strength = False - self.should_apply_effective_masks = False - self.latent_shape = None - # ContextRef stuff - self.is_context_ref = False - self.contextref_cond_idx = -1 - self.contextref_version = RETURNED_CONTEXTREF_VERSION - - @property - def ref_opts(self): - if self._current_timestep_keyframe is not None and self._current_timestep_keyframe.has_control_weights(): - return self._current_timestep_keyframe.control_weights.extras.get(RefConst.OPTS, self._ref_opts) - return self._ref_opts - - def any_attn_strength_to_apply(self): - return self.should_apply_attn_effective_strength or self.should_apply_effective_masks - - def any_adain_strength_to_apply(self): - return self.should_apply_adain_effective_strength or self.should_apply_effective_masks - - def get_effective_strength(self): - effective_strength = self.strength - if self._current_timestep_keyframe is not None: - effective_strength = effective_strength * self._current_timestep_keyframe.strength - return effective_strength - - def get_effective_attn_mask_or_float(self, x: Tensor, channels: int, is_mid: bool): - if not self.should_apply_effective_masks: - return self.get_effective_strength() * self.ref_opts.attn_strength - if is_mid: - div = 8 - else: - div = self.CHANNEL_TO_MULT[channels] - real_mask = torch.ones([self.latent_shape[0], 1, self.latent_shape[2]//div, self.latent_shape[3]//div]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.attn_strength - self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number) - # mask is now shape [b, 1, h ,w]; need to turn into [b, h*w, 1] - b, c, h, w = real_mask.shape - real_mask = real_mask.permute(0, 2, 3, 1).reshape(b, h*w, c) - return real_mask - - def get_effective_adain_mask_or_float(self, x: Tensor): - if not self.should_apply_effective_masks: - return self.get_effective_strength() * self.ref_opts.adain_strength - b, c, h, w = x.shape - real_mask = torch.ones([b, 1, h, w]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.adain_strength - self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number) - return real_mask - - def get_contextref_mode_replace(self): - # used by ADE to get mode_replace for current keyframe - if self._current_timestep_keyframe.has_control_weights(): - return self._current_timestep_keyframe.control_weights.extras.get(RefConst.CREF_MODE, None) - return None - - def should_run(self): - running = super().should_run() - if not running: - return running - attn_run = False - adain_run = False - if ReferenceType.is_attn(self.ref_opts.reference_type): - # attn will run as long as neither weight or strength is zero - attn_run = not (math.isclose(self.ref_opts.attn_ref_weight, 0.0) or math.isclose(self.ref_opts.attn_strength, 0.0)) - if 
ReferenceType.is_adain(self.ref_opts.reference_type): - # adain will run as long as neither weight or strength is zero - adain_run = not (math.isclose(self.ref_opts.adain_ref_weight, 0.0) or math.isclose(self.ref_opts.adain_strength, 0.0)) - return attn_run or adain_run - - def pre_run_advanced(self, model, percent_to_timestep_function): - AdvancedControlBase.pre_run_advanced(self, model, percent_to_timestep_function) - if isinstance(self.cond_hint_original, AbstractPreprocWrapper): - self.cond_hint_original = self.cond_hint_original.condhint - self.model_latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint - self.model_sampling_current = model.model_sampling - # SDXL is more sensitive to style_fidelity according to sd-webui-controlnet comments; - # prepare all ref_opts accordingly - all_ref_opts = [self._ref_opts] - for kf in self.timestep_keyframes.keyframes: - if kf.has_control_weights() and RefConst.OPTS in kf.control_weights.extras: - all_ref_opts.append(kf.control_weights.extras[RefConst.OPTS]) - for ropts in all_ref_opts: - if type(model).__name__ == "SDXL": - ropts.attn_style_fidelity = ropts.original_attn_style_fidelity ** 3.0 - ropts.adain_style_fidelity = ropts.original_adain_style_fidelity ** 3.0 - else: - ropts.attn_style_fidelity = ropts.original_attn_style_fidelity - ropts.adain_style_fidelity = ropts.original_adain_style_fidelity - - def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int): - # normal ControlNet stuff - control_prev = None - if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) - - if self.timestep_range is not None: - if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: - return control_prev - - dtype = x_noisy.dtype - # cond_hint_original only matters for RefCN, NOT ContextRef - if self.cond_hint_original is not None: - # prepare cond_hint - it is a latent, NOT an image - #if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] != self.cond_hint.shape[2] or x_noisy.shape[3] != self.cond_hint.shape[3]: - if self.cond_hint is not None: - del self.cond_hint - self.cond_hint = None - # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling - if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length: - self.cond_hint = comfy.utils.common_upscale( - self.cond_hint_original[self.sub_idxs], - x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device) - else: - self.cond_hint = comfy.utils.common_upscale( - self.cond_hint_original, - x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device) - if x_noisy.shape[0] != self.cond_hint.shape[0]: - self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False) - # noise cond_hint based on sigma (current step) - self.cond_hint = self.model_latent_format.process_in(self.cond_hint) - self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None) - timestep = self.model_sampling_current.timestep(t) - self.should_apply_attn_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.attn_strength, 1.0)) - self.should_apply_adain_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and 
math.isclose(self.ref_opts.adain_strength, 1.0)) - # prepare mask - use direct_attn, so the mask dims will match source latents (and be smaller) - self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, direct_attn=True) - self.should_apply_effective_masks = self.latent_keyframes is not None or self.mask_cond_hint is not None or self.tk_mask_cond_hint is not None - self.latent_shape = list(x_noisy.shape) - # done preparing; model patches will take care of everything now. - # return normal controlnet stuff - return control_prev - - def cleanup_advanced(self): - super().cleanup_advanced() - del self.model_latent_format - self.model_latent_format = None - del self.model_sampling_current - self.model_sampling_current = None - self.should_apply_attn_effective_strength = False - self.should_apply_adain_effective_strength = False - self.should_apply_effective_masks = False - - def copy(self): - c = ReferenceAdvanced(self.ref_opts, self.timestep_keyframes) - c.order = self.order - c.is_context_ref = self.is_context_ref - self.copy_to(c) - self.copy_to_advanced(c) - return c - - # avoid deepcopy shenanigans by making deepcopy not do anything to the reference - # TODO: do the bookkeeping to do this in a proper way for all Adv-ControlNets - def __deepcopy__(self, memo): - return self - - -def handle_context_ref_setup(contextref_obj, transformer_options: dict, positive, negative): - transformer_options[CONTEXTREF_MACHINE_STATE] = MachineState.OFF - # verify version is compatible - if contextref_obj.version > HIGHEST_VERSION_SUPPORT: - raise Exception(f"AnimateDiff-Evolved's ContextRef v{contextref_obj.version} is not supported in currently-installed Advanced-ControlNet (only supports ContextRef up to v{HIGHEST_VERSION_SUPPORT}); " + - f"update your Advanced-ControlNet nodes for ContextRef to work.") - # init ReferenceOptions - cref_opt_dict = contextref_obj.tune.create_dict() # ContextRefTune obj from ADE - opts = ReferenceOptions.create_from_kwargs(**cref_opt_dict) - # init TimestepKeyframes - cref_tks_list = contextref_obj.keyframe.create_list_of_dicts() # ContextRefKeyframeGroup obj from ADE - timestep_keyframes = _create_tks_from_dict_list(cref_tks_list) - # create ReferenceAdvanced - cref = ReferenceAdvanced(ref_opts=opts, timestep_keyframes=timestep_keyframes) - cref.strength = contextref_obj.strength # ContextRef obj from ADE - cref.set_cond_hint_mask(contextref_obj.mask) - cref.order = 99 - cref.is_context_ref = True - context_ref_list = [cref] - transformer_options[CONTEXTREF_CONTROL_LIST_ALL] = context_ref_list - transformer_options[CONTEXTREF_OPTIONS_CLASS] = ReferenceOptions - _add_context_ref_to_conds([positive, negative], cref) - return context_ref_list - - -def _create_tks_from_dict_list(dlist: list[dict[str]]) -> TimestepKeyframeGroup: - tks = TimestepKeyframeGroup() - if dlist is None or len(dlist) == 0: - return tks - for d in dlist: - # scheduling - start_percent = d["start_percent"] - guarantee_steps = d["guarantee_steps"] - inherit_missing = d["inherit_missing"] - # values - strength = d["strength"] - mask = d["mask"] - tune = d["tune"] - mode = d["mode"] - weights = None - extras = {} - if tune is not None: - cref_opt_dict = tune.create_dict() # ContextRefTune obj from ADE - opts = ReferenceOptions.create_from_kwargs(**cref_opt_dict) - extras[RefConst.OPTS] = opts - if mode is not None: - extras[RefConst.CREF_MODE] = mode - weights = ControlWeights.default(extras=extras) - # create keyframe - tk = TimestepKeyframe(start_percent=start_percent, 
guarantee_steps=guarantee_steps, inherit_missing=inherit_missing, - strength=strength, mask_hint_orig=mask, control_weights=weights) - tks.add(tk) - return tks - - -def _add_context_ref_to_conds(conds: list[list[dict[str]]], context_ref: ReferenceAdvanced): - def _add_context_ref_to_existing_control(control: ControlBase, context_ref: ReferenceAdvanced): - curr_cn = control - while curr_cn is not None: - if type(curr_cn) == ReferenceAdvanced and curr_cn.is_context_ref: - break - if curr_cn.previous_controlnet is not None: - curr_cn = curr_cn.previous_controlnet - continue - orig_previous_controlnet = curr_cn.previous_controlnet - # NOTE: code is already in place to restore any ORIG_PREVIOUS_CONTROLNET props - setattr(curr_cn, ORIG_PREVIOUS_CONTROLNET, orig_previous_controlnet) - curr_cn.previous_controlnet = context_ref - curr_cn = orig_previous_controlnet - - def _add_context_ref(actual_cond: dict[str], context_ref: ReferenceAdvanced): - # if controls already present on cond, add it to the last previous_controlnet - if "control" in actual_cond: - return _add_context_ref_to_existing_control(actual_cond["control"], context_ref) - # otherwise, need to add it to begin with, and should mark that it should be cleaned after - actual_cond["control"] = context_ref - actual_cond[CONTROL_INIT_BY_ACN] = True - - # either add context_ref to end of existing cnet chain, or init 'control' key on actual cond - for cond in conds: - if cond is not None: - for sub_cond in cond: - actual_cond = sub_cond[1] - _add_context_ref(actual_cond, context_ref) - - -def ref_noise_latents(latents: Tensor, sigma: Tensor, noise: Tensor=None): - sigma = sigma.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - alpha_cumprod = 1 / ((sigma * sigma) + 1) - sqrt_alpha_prod = alpha_cumprod ** 0.5 - sqrt_one_minus_alpha_prod = (1. 
- alpha_cumprod) ** 0.5 - if noise is None: - # generator = torch.Generator(device="cuda") - # generator.manual_seed(0) - # noise = torch.empty_like(latents).normal_(generator=generator) - # generator = torch.Generator() - # generator.manual_seed(0) - # noise = torch.randn(latents.size(), generator=generator).to(latents.device) - noise = torch.randn_like(latents).to(latents.device) - return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise - - -def simple_noise_latents(latents: Tensor, sigma: float, noise: Tensor=None): - if noise is None: - noise = torch.rand_like(latents) - return latents + noise * sigma - - -class BankStylesBasicTransformerBlock: - def __init__(self): - # ref - self.bank = [] - self.style_cfgs = [] - self.cn_idx: list[int] = [] - # contextref - list of lists as each cond/uncond stored separately - self.c_bank: list[list] = [] - self.c_style_cfgs: list[list] = [] - self.c_cn_idx: list[list[int]] = [] - - def get_bank(self, cref_idx, ignore_contextref, cdevice=None): - if ignore_contextref or cref_idx >= len(self.c_bank): - return self.bank - real_c_bank_list = self.c_bank[cref_idx] - if cdevice != None: - real_c_bank_list = real_c_bank_list.copy() - for i in range(len(real_c_bank_list)): - real_c_bank_list[i] = real_c_bank_list[i].to(cdevice) - return self.bank + real_c_bank_list - - def get_avg_style_fidelity(self, cref_idx, ignore_contextref): - if ignore_contextref or cref_idx >= len(self.c_style_cfgs): - return sum(self.style_cfgs) / float(len(self.style_cfgs)) - combined = self.style_cfgs + self.c_style_cfgs[cref_idx] - return sum(combined) / float(len(combined)) - - def get_cn_idxs(self, cref_idx, ignore_contxtref): - if ignore_contxtref or cref_idx >= len(self.c_cn_idx): - return self.cn_idx - return self.cn_idx + self.c_cn_idx[cref_idx] - - def init_cref_for_idx(self, cref_idx: int): - # makes sure cref lists can accommodate cref_idx - if cref_idx < 0: - return - while cref_idx >= len(self.c_bank): - self.c_bank.append([]) - self.c_style_cfgs.append([]) - self.c_cn_idx.append([]) - - def clear_cref_for_idx(self, cref_idx: int): - if cref_idx < 0 or cref_idx >= len(self.c_bank): - return - self.c_bank[cref_idx] = [] - self.c_style_cfgs[cref_idx] = [] - self.c_cn_idx[cref_idx] = [] - - def clean_ref(self): - del self.bank - del self.style_cfgs - del self.cn_idx - self.bank = [] - self.style_cfgs = [] - self.cn_idx = [] - - def clean_contextref(self): - del self.c_bank - del self.c_style_cfgs - del self.c_cn_idx - self.c_bank = [] - self.c_style_cfgs = [] - self.c_cn_idx = [] - - def clean_all(self): - self.clean_ref() - self.clean_contextref() - - -class BankStylesTimestepEmbedSequential: - def __init__(self): - # ref - self.var_bank = [] - self.mean_bank = [] - self.style_cfgs = [] - self.cn_idx: list[int] = [] - # cref - self.c_var_bank: list[list] = [] - self.c_mean_bank: list[list] = [] - self.c_style_cfgs: list[list] = [] - self.c_cn_idx: list[list[int]] = [] - - def get_var_bank(self, cref_idx, ignore_contextref): - if ignore_contextref or cref_idx >= len(self.c_var_bank): - return self.var_bank - return self.var_bank + self.c_var_bank[cref_idx] - - def get_mean_bank(self, cref_idx, ignore_contextref): - if ignore_contextref or cref_idx >= len(self.c_mean_bank): - return self.mean_bank - return self.mean_bank + self.c_mean_bank[cref_idx] - - def get_style_cfgs(self, cref_idx, ignore_contextref): - if ignore_contextref or cref_idx >= len(self.c_style_cfgs): - return self.style_cfgs - return self.style_cfgs + self.c_style_cfgs[cref_idx] - - def 
get_cn_idxs(self, cref_idx, ignore_contextref): - if ignore_contextref or cref_idx >= len(self.c_cn_idx): - return self.cn_idx - return self.cn_idx + self.c_cn_idx[cref_idx] - - def init_cref_for_idx(self, cref_idx: int): - # makes sure cref lists can accommodate cref_idx - if cref_idx < 0: - return - while cref_idx >= len(self.c_var_bank): - self.c_var_bank.append([]) - self.c_mean_bank.append([]) - self.c_style_cfgs.append([]) - self.c_cn_idx.append([]) - - def clear_cref_for_idx(self, cref_idx: int): - if cref_idx < 0 or cref_idx >= len(self.c_var_bank): - return - self.c_var_bank[cref_idx] = [] - self.c_mean_bank[cref_idx] = [] - self.c_style_cfgs[cref_idx] = [] - self.c_cn_idx[cref_idx] = [] - - def clean_ref(self): - del self.mean_bank - del self.var_bank - del self.style_cfgs - del self.cn_idx - self.mean_bank = [] - self.var_bank = [] - self.style_cfgs = [] - self.cn_idx = [] - - def clean_contextref(self): - del self.c_var_bank - del self.c_mean_bank - del self.c_style_cfgs - del self.c_cn_idx - self.c_var_bank = [] - self.c_mean_bank = [] - self.c_style_cfgs = [] - self.c_cn_idx = [] - - def clean_all(self): - self.clean_ref() - self.clean_contextref() - - -class InjectionBasicTransformerBlockHolder: - def __init__(self, block: BasicTransformerBlock, idx=None): - if hasattr(block, "_forward"): # backward compatibility - self.original_forward = block._forward - else: - self.original_forward = block.forward - self.idx = idx - self.attn_weight = 1.0 - self.is_middle = False - self.bank_styles = BankStylesBasicTransformerBlock() - - def restore(self, block: BasicTransformerBlock): - if hasattr(block, "_forward"): # backward compatibility - block._forward = self.original_forward - else: - block.forward = self.original_forward - - def clean_ref(self): - self.bank_styles.clean_ref() - - def clean_contextref(self): - self.bank_styles.clean_contextref() - - def clean_all(self): - self.bank_styles.clean_all() - - -class InjectionTimestepEmbedSequentialHolder: - def __init__(self, block: openaimodel.TimestepEmbedSequential, idx=None, is_middle=False, is_input=False, is_output=False): - self.original_forward = block.forward - self.idx = idx - self.gn_weight = 1.0 - self.is_middle = is_middle - self.is_input = is_input - self.is_output = is_output - self.bank_styles = BankStylesTimestepEmbedSequential() - - def restore(self, block: openaimodel.TimestepEmbedSequential): - block.forward = self.original_forward - - def clean_ref(self): - self.bank_styles.clean_ref() - - def clean_contextref(self): - self.bank_styles.clean_contextref() - - def clean_all(self): - self.bank_styles.clean_all() - - -class ReferenceInjections: - def __init__(self, attn_modules: list['RefBasicTransformerBlock']=None, gn_modules: list['RefTimestepEmbedSequential']=None): - self.attn_modules = attn_modules if attn_modules else [] - self.gn_modules = gn_modules if gn_modules else [] - self.diffusion_model_orig_forward: Callable = None - - def clean_ref_module_mem(self): - for attn_module in self.attn_modules: - try: - attn_module.injection_holder.clean_ref() - except Exception: - pass - for gn_module in self.gn_modules: - try: - gn_module.injection_holder.clean_ref() - except Exception: - pass - - def clean_contextref_module_mem(self): - for attn_module in self.attn_modules: - try: - attn_module.injection_holder.clean_contextref() - except Exception: - pass - for gn_module in self.gn_modules: - try: - gn_module.injection_holder.clean_contextref() - except Exception: - pass - - def clean_all_module_mem(self): - for 
attn_module in self.attn_modules: - try: - attn_module.injection_holder.clean_all() - except Exception: - pass - for gn_module in self.gn_modules: - try: - gn_module.injection_holder.clean_all() - except Exception: - pass - - def cleanup(self): - self.clean_all_module_mem() - del self.attn_modules - self.attn_modules = [] - del self.gn_modules - self.gn_modules = [] - self.diffusion_model_orig_forward = None - - -def factory_forward_inject_UNetModel(reference_injections: ReferenceInjections): - def forward_inject_UNetModel(self, x: Tensor, *args, **kwargs): - # get control and transformer_options from kwargs - real_args = list(args) - real_kwargs = list(kwargs.keys()) - control = kwargs.get("control", None) - transformer_options: dict[str] = kwargs.get("transformer_options", {}) - # NOTE: adds support for both ReferenceCN and ContextRef, so need to track them separately - # get ReferenceAdvanced objects - ref_controlnets: list[ReferenceAdvanced] = transformer_options.get(REF_CONTROL_LIST_ALL, []) - context_controlnets: list[ReferenceAdvanced] = transformer_options.get(CONTEXTREF_CONTROL_LIST_ALL, []) - # clean contextref stuff if OFF - if len(context_controlnets) > 0 and transformer_options[CONTEXTREF_MACHINE_STATE] == MachineState.OFF: - reference_injections.clean_contextref_module_mem() - context_controlnets = [] - # discard any controlnets that should not run - ref_controlnets = [z for z in ref_controlnets if z.should_run()] - context_controlnets = [z for z in context_controlnets if z.should_run()] - # if nothing related to reference controlnets, do nothing special - if len(ref_controlnets) == 0 and len(context_controlnets) == 0: - return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs) - try: - # assign cond and uncond idxs - batched_number = len(transformer_options["cond_or_uncond"]) - per_batch = x.shape[0] // batched_number - indiv_conds = [] - for cond_type in transformer_options["cond_or_uncond"]: - indiv_conds.extend([cond_type] * per_batch) - transformer_options[REF_UNCOND_IDXS] = [i for i, z in enumerate(indiv_conds) if z == 1] - transformer_options[REF_COND_IDXS] = [i for i, z in enumerate(indiv_conds) if z == 0] - # check which controlnets do which thing - attn_controlnets = [] - adain_controlnets = [] - for control in ref_controlnets: - if ReferenceType.is_attn(control.ref_opts.reference_type): - attn_controlnets.append(control) - if ReferenceType.is_adain(control.ref_opts.reference_type): - adain_controlnets.append(control) - context_attn_controlnets = [] - context_adain_controlnets = [] - # for ease of access, store current contextref_cond_idx value - if len(context_controlnets) == 0: - transformer_options[CONTEXTREF_TEMP_COND_IDX] = -1 - else: - transformer_options[CONTEXTREF_TEMP_COND_IDX] = context_controlnets[0].contextref_cond_idx - # logger.info(f"{transformer_options[CONTEXTREF_MACHINE_STATE]}: {transformer_options[CONTEXTREF_TEMP_COND_IDX]}") - - for control in context_controlnets: - if ReferenceType.is_attn(control.ref_opts.reference_type): - context_attn_controlnets.append(control) - if ReferenceType.is_adain(control.ref_opts.reference_type): - context_adain_controlnets.append(control) - if len(adain_controlnets) > 0 or len(context_adain_controlnets) > 0: - # ComfyUI uses forward_timestep_embed with the TimestepEmbedSequential passed into it - orig_forward_timestep_embed = openaimodel.forward_timestep_embed - openaimodel.forward_timestep_embed = forward_timestep_embed_ref_inject_factory(orig_forward_timestep_embed) - - # if RefCN to be used, 
handle running diffusion with ref cond hints - if len(ref_controlnets) > 0: - for control in ref_controlnets: - read_attn_list = [] - write_attn_list = [] - read_adain_list = [] - write_adain_list = [] - - if ReferenceType.is_attn(control.ref_opts.reference_type): - write_attn_list.append(control) - if ReferenceType.is_adain(control.ref_opts.reference_type): - write_adain_list.append(control) - # apply lists - transformer_options[REF_READ_ATTN_CONTROL_LIST] = read_attn_list - transformer_options[REF_WRITE_ATTN_CONTROL_LIST] = write_attn_list - transformer_options[REF_READ_ADAIN_CONTROL_LIST] = read_adain_list - transformer_options[REF_WRITE_ADAIN_CONTROL_LIST] = write_adain_list - - orig_kwargs = kwargs - # disable other controlnets for this run, if specified - if not control.ref_opts.ref_with_other_cns: - kwargs = kwargs.copy() - kwargs["control"] = None - reference_injections.diffusion_model_orig_forward(control.cond_hint.to(dtype=x.dtype).to(device=x.device), *args, **kwargs) - kwargs = orig_kwargs - # prepare running diffusion for real now - read_attn_list = [] - write_attn_list = [] - read_adain_list = [] - write_adain_list = [] - - # add RefCNs to read lists - read_attn_list.extend(attn_controlnets) - read_adain_list.extend(adain_controlnets) - - # do contextref stuff, if needed - if len(context_controlnets) > 0: - # clean contextref stuff if first WRITE - # if context_controlnets[0].contextref_cond_idx == 0 and is_write(transformer_options[CONTEXTREF_MACHINE_STATE]): - # reference_injections.clean_contextref_module_mem() - ### add ContextRef to appropriate lists - # attn - if is_read(transformer_options[CONTEXTREF_MACHINE_STATE]): - read_attn_list.extend(context_attn_controlnets) - if is_write(transformer_options[CONTEXTREF_MACHINE_STATE]): - write_attn_list.extend(context_attn_controlnets) - # adain - if is_read(transformer_options[CONTEXTREF_MACHINE_STATE]): - read_adain_list.extend(context_adain_controlnets) - if is_write(transformer_options[CONTEXTREF_MACHINE_STATE]): - write_adain_list.extend(context_adain_controlnets) - # apply lists, containing both RefCN and ContextRef - transformer_options[REF_READ_ATTN_CONTROL_LIST] = read_attn_list - transformer_options[REF_WRITE_ATTN_CONTROL_LIST] = write_attn_list - transformer_options[REF_READ_ADAIN_CONTROL_LIST] = read_adain_list - transformer_options[REF_WRITE_ADAIN_CONTROL_LIST] = write_adain_list - # run diffusion for real - try: - return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs) - finally: - # increment current cond idx - if len(context_controlnets) > 0: - for cn in context_controlnets: - cn.contextref_cond_idx += 1 - finally: - # make sure ref banks are cleared no matter what happens - otherwise, RIP VRAM - reference_injections.clean_ref_module_mem() - if len(adain_controlnets) > 0 or len(context_adain_controlnets) > 0: - openaimodel.forward_timestep_embed = orig_forward_timestep_embed - - - return forward_inject_UNetModel - - -# dummy class just to help IDE keep track of injected variables -class RefBasicTransformerBlock(BasicTransformerBlock): - injection_holder: InjectionBasicTransformerBlockHolder = None - -def _forward_inject_BasicTransformerBlock(self: RefBasicTransformerBlock, x: Tensor, context: Tensor=None, transformer_options: dict[str]={}): - extra_options = {} - block = transformer_options.get("block", None) - block_index = transformer_options.get("block_index", 0) - transformer_patches = {} - transformer_patches_replace = {} - - for k in transformer_options: - if k == "patches": - 
transformer_patches = transformer_options[k] - elif k == "patches_replace": - transformer_patches_replace = transformer_options[k] - else: - extra_options[k] = transformer_options[k] - - extra_options["n_heads"] = self.n_heads - extra_options["dim_head"] = self.d_head - - if self.ff_in: - x_skip = x - x = self.ff_in(self.norm_in(x)) - if self.is_res: - x += x_skip - - n: Tensor = self.norm1(x) - if self.disable_self_attn: - context_attn1 = context - else: - context_attn1 = None - value_attn1 = None - - # Reference CN stuff - uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, []) - #c_idx_mask = transformer_options.get(REF_COND_IDXS, []) - # WRITE mode may have only 1 ReferenceAdvanced for RefCN at a time, other modes will have all ReferenceAdvanced - ref_write_cns: list[ReferenceAdvanced] = transformer_options.get(REF_WRITE_ATTN_CONTROL_LIST, []) - ref_read_cns: list[ReferenceAdvanced] = transformer_options.get(REF_READ_ATTN_CONTROL_LIST, []) - cref_cond_idx: int = transformer_options.get(CONTEXTREF_TEMP_COND_IDX, -1) - ignore_contextref_read = cref_cond_idx < 0 # if writing to bank, should NOT be read in the same execution - - cached_n = None - cref_write_cns: list[ReferenceAdvanced] = [] - # check if any WRITE cns are applicable; Reference CN WRITEs immediately, ContextREF WRITEs after READ completed - # if any refs to WRITE, save n and style_fidelity - for refcn in ref_write_cns: - if refcn.ref_opts.attn_ref_weight > self.injection_holder.attn_weight: - if cached_n is None: - cached_n = n.detach().clone() - # for ContextRef, make sure relevant lists are long enough to cond_idx - # store RefCN and ContextRef stuff separately - if refcn.is_context_ref: - cref_write_cns.append(refcn) - self.injection_holder.bank_styles.init_cref_for_idx(cref_cond_idx) - else: # Reference CN WRITE - self.injection_holder.bank_styles.bank.append(cached_n) - self.injection_holder.bank_styles.style_cfgs.append(refcn.ref_opts.attn_style_fidelity) - self.injection_holder.bank_styles.cn_idx.append(refcn.order) - if len(cref_write_cns) == 0: - del cached_n - - if "attn1_patch" in transformer_patches: - patch = transformer_patches["attn1_patch"] - if context_attn1 is None: - context_attn1 = n - value_attn1 = context_attn1 - for p in patch: - n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options) - - if block is not None: - transformer_block = (block[0], block[1], block_index) - else: - transformer_block = None - attn1_replace_patch = transformer_patches_replace.get("attn1", {}) - block_attn1 = transformer_block - if block_attn1 not in attn1_replace_patch: - block_attn1 = block - - if block_attn1 in attn1_replace_patch: - if context_attn1 is None: - context_attn1 = n - value_attn1 = n - n = self.attn1.to_q(n) - # Reference CN READ - use attn1_replace_patch appropriately - if len(ref_read_cns) > 0 and len(self.injection_holder.bank_styles.get_bank(cref_cond_idx, ignore_contextref_read)) > 0: - bank_styles = self.injection_holder.bank_styles - style_fidelity = bank_styles.get_avg_style_fidelity(cref_cond_idx, ignore_contextref_read) - real_bank = bank_styles.get_bank(cref_cond_idx, ignore_contextref_read, cdevice=n.device).copy() - real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read) - cn_idx = 0 - for idx, order in enumerate(real_cn_idxs): - # make sure matching ref cn is selected - for i in range(cn_idx, len(ref_read_cns)): - if ref_read_cns[i].order == order: - cn_idx = i - break - assert order == ref_read_cns[cn_idx].order - if 
ref_read_cns[cn_idx].any_attn_strength_to_apply(): - effective_strength = ref_read_cns[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle) - real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength) - n_uc = self.attn1.to_out(attn1_replace_patch[block_attn1]( - n, - self.attn1.to_k(torch.cat([context_attn1] + real_bank, dim=1)), - self.attn1.to_v(torch.cat([value_attn1] + real_bank, dim=1)), - extra_options)) - n_c = n_uc.clone() - if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0): - n_c[uc_idx_mask] = self.attn1.to_out(attn1_replace_patch[block_attn1]( - n[uc_idx_mask], - self.attn1.to_k(context_attn1[uc_idx_mask]), - self.attn1.to_v(value_attn1[uc_idx_mask]), - extra_options)) - n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc - bank_styles.clean_ref() - else: - context_attn1 = self.attn1.to_k(context_attn1) - value_attn1 = self.attn1.to_v(value_attn1) - n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) - n = self.attn1.to_out(n) - else: - # Reference CN READ - no attn1_replace_patch - if len(ref_read_cns) > 0 and len(self.injection_holder.bank_styles.get_bank(cref_cond_idx, ignore_contextref_read)) > 0: - if context_attn1 is None: - context_attn1 = n - bank_styles = self.injection_holder.bank_styles - style_fidelity = bank_styles.get_avg_style_fidelity(cref_cond_idx, ignore_contextref_read) - real_bank = bank_styles.get_bank(cref_cond_idx, ignore_contextref_read, cdevice=n.device).copy() - real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read) - cn_idx = 0 - for idx, order in enumerate(real_cn_idxs): - # make sure matching ref cn is selected - for i in range(cn_idx, len(ref_read_cns)): - if ref_read_cns[i].order == order: - cn_idx = i - break - assert order == ref_read_cns[cn_idx].order - if ref_read_cns[cn_idx].any_attn_strength_to_apply(): - effective_strength = ref_read_cns[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle) - real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength) - n_uc: Tensor = self.attn1( - n, - context=torch.cat([context_attn1] + real_bank, dim=1), - value=torch.cat([value_attn1] + real_bank, dim=1) if value_attn1 is not None else value_attn1) - n_c = n_uc.clone() - if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0): - n_c[uc_idx_mask] = self.attn1( - n[uc_idx_mask], - context=context_attn1[uc_idx_mask], - value=value_attn1[uc_idx_mask] if value_attn1 is not None else value_attn1) - n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc - bank_styles.clean_ref() - else: - n = self.attn1(n, context=context_attn1, value=value_attn1) - - # ContextRef CN WRITE - if len(cref_write_cns) > 0: - # clear so that ContextRef CNs can properly 'replace' previous value at cond_idx - self.injection_holder.bank_styles.clear_cref_for_idx(cref_cond_idx) - for refcn in cref_write_cns: - # add a whole list to match expected type when combining - self.injection_holder.bank_styles.c_bank[cref_cond_idx].append(cached_n.to(comfy.model_management.unet_offload_device())) - self.injection_holder.bank_styles.c_style_cfgs[cref_cond_idx].append(refcn.ref_opts.attn_style_fidelity) - self.injection_holder.bank_styles.c_cn_idx[cref_cond_idx].append(refcn.order) - del cached_n - - if "attn1_output_patch" in transformer_patches: - patch = transformer_patches["attn1_output_patch"] - for p in patch: - n = p(n, extra_options) 
- - x += n - if "middle_patch" in transformer_patches: - patch = transformer_patches["middle_patch"] - for p in patch: - x = p(x, extra_options) - - if self.attn2 is not None: - n = self.norm2(x) - if self.switch_temporal_ca_to_sa: - context_attn2 = n - else: - context_attn2 = context - value_attn2 = None - if "attn2_patch" in transformer_patches: - patch = transformer_patches["attn2_patch"] - value_attn2 = context_attn2 - for p in patch: - n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) - - attn2_replace_patch = transformer_patches_replace.get("attn2", {}) - block_attn2 = transformer_block - if block_attn2 not in attn2_replace_patch: - block_attn2 = block - - if block_attn2 in attn2_replace_patch: - if value_attn2 is None: - value_attn2 = context_attn2 - n = self.attn2.to_q(n) - context_attn2 = self.attn2.to_k(context_attn2) - value_attn2 = self.attn2.to_v(value_attn2) - n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) - n = self.attn2.to_out(n) - else: - n = self.attn2(n, context=context_attn2, value=value_attn2) - - if "attn2_output_patch" in transformer_patches: - patch = transformer_patches["attn2_output_patch"] - for p in patch: - n = p(n, extra_options) - - x += n - if self.is_res: - x_skip = x - x = self.ff(self.norm3(x)) - if self.is_res: - x += x_skip - - return x - - -class RefTimestepEmbedSequential(openaimodel.TimestepEmbedSequential): - injection_holder: InjectionTimestepEmbedSequentialHolder = None - -def forward_timestep_embed_ref_inject_factory(orig_timestep_embed_inject_factory: Callable): - def forward_timestep_embed_ref_inject(*args, **kwargs): - ts: RefTimestepEmbedSequential = args[0] - if not hasattr(ts, "injection_holder"): - return orig_timestep_embed_inject_factory(*args, **kwargs) - eps = 1e-6 - x: Tensor = orig_timestep_embed_inject_factory(*args, **kwargs) - y: Tensor = None - transformer_options: dict[str] = args[4] - # Reference CN stuff - uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, []) - #c_idx_mask = transformer_options.get(REF_COND_IDXS, []) - # WRITE mode will only have one ReferenceAdvanced, other modes will have all ReferenceAdvanced - ref_write_cns: list[ReferenceAdvanced] = transformer_options.get(REF_WRITE_ADAIN_CONTROL_LIST, []) - ref_read_cns: list[ReferenceAdvanced] = transformer_options.get(REF_READ_ADAIN_CONTROL_LIST, []) - cref_cond_idx: int = transformer_options.get(CONTEXTREF_TEMP_COND_IDX, -1) - ignore_contextref_read = cref_cond_idx < 0 # if writing to bank, should NOT be read in the same execution - - cached_var = None - cached_mean = None - cref_write_cns: list[ReferenceAdvanced] = [] - # if any refs to WRITE, save var, mean, and style_cfg - for refcn in ref_write_cns: - if refcn.ref_opts.adain_ref_weight > ts.injection_holder.gn_weight: - if cached_var is None: - cached_var, cached_mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) - if refcn.is_context_ref: - cref_write_cns.append(refcn) - ts.injection_holder.bank_styles.init_cref_for_idx(cref_cond_idx) - else: - ts.injection_holder.bank_styles.var_bank.append(cached_var) - ts.injection_holder.bank_styles.mean_bank.append(cached_mean) - ts.injection_holder.bank_styles.style_cfgs.append(refcn.ref_opts.adain_style_fidelity) - ts.injection_holder.bank_styles.cn_idx.append(refcn.order) - if len(cref_write_cns) == 0: - del cached_var - del cached_mean - - # if any refs to READ, do math with saved var, mean, and style_cfg - if len(ref_read_cns) > 0: - if 
len(ts.injection_holder.bank_styles.get_var_bank(cref_cond_idx, ignore_contextref_read)) > 0: - bank_styles = ts.injection_holder.bank_styles - var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) - std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 - y_uc = torch.zeros_like(x) - cn_idx = 0 - real_style_cfgs = bank_styles.get_style_cfgs(cref_cond_idx, ignore_contextref_read) - real_var_bank = bank_styles.get_var_bank(cref_cond_idx, ignore_contextref_read) - real_mean_bank = bank_styles.get_mean_bank(cref_cond_idx, ignore_contextref_read) - real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read) - for idx, order in enumerate(real_cn_idxs): - # make sure matching ref cn is selected - for i in range(cn_idx, len(ref_read_cns)): - if ref_read_cns[i].order == order: - cn_idx = i - break - assert order == ref_read_cns[cn_idx].order - style_fidelity = real_style_cfgs[idx] - var_acc = real_var_bank[idx] - mean_acc = real_mean_bank[idx] - std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 - sub_y_uc = (((x - mean) / std) * std_acc) + mean_acc - if ref_read_cns[cn_idx].any_adain_strength_to_apply(): - effective_strength = ref_read_cns[cn_idx].get_effective_adain_mask_or_float(x=x) - sub_y_uc = sub_y_uc * effective_strength + x * (1-effective_strength) - y_uc += sub_y_uc - # get average, if more than one - if len(real_cn_idxs) > 1: - y_uc /= len(real_cn_idxs) - y_c = y_uc.clone() - if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0): - y_c[uc_idx_mask] = x.to(y_c.dtype)[uc_idx_mask] - y = style_fidelity * y_c + (1.0 - style_fidelity) * y_uc - ts.injection_holder.bank_styles.clean_ref() - - # ContextRef CN WRITE - if len(cref_write_cns) > 0: - # clear so that ContextRef CNs can properly 'replace' previous value at cond_idx - ts.injection_holder.bank_styles.clear_cref_for_idx(cref_cond_idx) - for refcn in cref_write_cns: - # add a whole list to match expected type when combining - ts.injection_holder.bank_styles.c_var_bank[cref_cond_idx].append(cached_var) - ts.injection_holder.bank_styles.c_mean_bank[cref_cond_idx].append(cached_mean) - ts.injection_holder.bank_styles.c_style_cfgs[cref_cond_idx].append(refcn.ref_opts.adain_style_fidelity) - ts.injection_holder.bank_styles.c_cn_idx[cref_cond_idx].append(refcn.order) - del cached_var - del cached_mean - - if y is None: - y = x - return y.to(x.dtype) - - return forward_timestep_embed_ref_inject +from typing import Callable, Union + +import math +import torch +from torch import Tensor + +import comfy.model_management +import comfy.sample +import comfy.model_patcher +import comfy.utils +from comfy.controlnet import ControlBase +from comfy.model_patcher import ModelPatcher +from comfy.ldm.modules.attention import BasicTransformerBlock +from comfy.ldm.modules.diffusionmodules import openaimodel + +from .logger import logger +from .utils import (AdvancedControlBase, ControlWeights, TimestepKeyframeGroup, TimestepKeyframe, AbstractPreprocWrapper, + broadcast_image_to_extend, ORIG_PREVIOUS_CONTROLNET, CONTROL_INIT_BY_ACN) + + +REF_READ_ATTN_CONTROL_LIST = "ref_read_attn_control_list" +REF_WRITE_ATTN_CONTROL_LIST = "ref_write_attn_control_list" +REF_READ_ADAIN_CONTROL_LIST = "ref_read_adain_control_list" +REF_WRITE_ADAIN_CONTROL_LIST = "ref_write_adain_control_list" + +REF_ATTN_CONTROL_LIST = "ref_attn_control_list" +REF_ADAIN_CONTROL_LIST = "ref_adain_control_list" +REF_CONTROL_LIST_ALL = "ref_control_list_all" +REF_CONTROL_INFO = "ref_control_info" +REF_ATTN_MACHINE_STATE = 
"ref_attn_machine_state" +REF_ADAIN_MACHINE_STATE = "ref_adain_machine_state" +REF_COND_IDXS = "ref_cond_idxs" +REF_UNCOND_IDXS = "ref_uncond_idxs" + +CONTEXTREF_OPTIONS_CLASS = "contextref_options_class" +CONTEXTREF_CLEAN_FUNC = "contextref_clean_func" +CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all" +CONTEXTREF_MACHINE_STATE = "contextref_machine_state" +CONTEXTREF_TEMP_COND_IDX = "contextref_temp_cond_idx" + +HIGHEST_VERSION_SUPPORT = 1 +RETURNED_CONTEXTREF_VERSION = 1 + + +class RefConst: + OPTS = "refcn_opts" + CREF_MODE = "contextref_mode" + + +class MachineState: + WRITE = "write" + READ = "read" + READ_WRITE = "read_write" + STYLEALIGN = "stylealign" + OFF = "off" + +def is_read(state: str): + return state in [MachineState.READ, MachineState.READ_WRITE] + +def is_write(state: str): + return state in [MachineState.WRITE, MachineState.READ_WRITE] + + +class ReferenceType: + ATTN = "reference_attn" + ADAIN = "reference_adain" + ATTN_ADAIN = "reference_attn+adain" + STYLE_ALIGN = "StyleAlign" + + _LIST = [ATTN, ADAIN, ATTN_ADAIN] + _LIST_ATTN = [ATTN, ATTN_ADAIN] + _LIST_ADAIN = [ADAIN, ATTN_ADAIN] + + @classmethod + def is_attn(cls, ref_type: str): + return ref_type in cls._LIST_ATTN + + @classmethod + def is_adain(cls, ref_type: str): + return ref_type in cls._LIST_ADAIN + + +class ReferenceOptions: + def __init__(self, reference_type: str, + attn_style_fidelity: float, adain_style_fidelity: float, + attn_ref_weight: float, adain_ref_weight: float, + attn_strength: float=1.0, adain_strength: float=1.0, + ref_with_other_cns: bool=False): + self.reference_type = reference_type + # attn + self.original_attn_style_fidelity = attn_style_fidelity + self.attn_style_fidelity = attn_style_fidelity + self.attn_ref_weight = attn_ref_weight + self.attn_strength = attn_strength + # adain + self.original_adain_style_fidelity = adain_style_fidelity + self.adain_style_fidelity = adain_style_fidelity + self.adain_ref_weight = adain_ref_weight + self.adain_strength = adain_strength + # other + self.ref_with_other_cns = ref_with_other_cns + + def clone(self): + return ReferenceOptions(reference_type=self.reference_type, + attn_style_fidelity=self.original_attn_style_fidelity, adain_style_fidelity=self.original_adain_style_fidelity, + attn_ref_weight=self.attn_ref_weight, adain_ref_weight=self.adain_ref_weight, + attn_strength=self.attn_strength, adain_strength=self.adain_strength, + ref_with_other_cns=self.ref_with_other_cns) + + @staticmethod + def create_combo(reference_type: str, style_fidelity: float, ref_weight: float, ref_with_other_cns: bool=False): + return ReferenceOptions(reference_type=reference_type, + attn_style_fidelity=style_fidelity, adain_style_fidelity=style_fidelity, + attn_ref_weight=ref_weight, adain_ref_weight=ref_weight, + ref_with_other_cns=ref_with_other_cns) + + @staticmethod + def create_from_kwargs(attn_style_fidelity=0.0, adain_style_fidelity=0.0, + attn_ref_weight=0.0, adain_ref_weight=0.0, + attn_strength=0.0, adain_strength=0.0, **kwargs): + has_attn = attn_strength > 0.0 + has_adain = adain_strength > 0.0 + if has_attn and has_adain: + reference_type = ReferenceType.ATTN_ADAIN + elif has_adain: + reference_type = ReferenceType.ADAIN + else: + reference_type = ReferenceType.ATTN + return ReferenceOptions(reference_type=reference_type, + attn_style_fidelity=float(attn_style_fidelity), adain_style_fidelity=float(adain_style_fidelity), + attn_ref_weight=float(attn_ref_weight), adain_ref_weight=float(adain_ref_weight), + attn_strength=float(attn_strength), 
adain_strength=float(adain_strength))
+
+
+class ReferencePreprocWrapper(AbstractPreprocWrapper):
+    error_msg = "Invalid use of Reference Preprocess output. The output of the Reference preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
+    def __init__(self, condhint: Tensor):
+        super().__init__(condhint)
+
+
+class ReferenceAdvanced(ControlBase, AdvancedControlBase):
+    CHANNEL_TO_MULT = {320: 1, 640: 2, 1280: 4}
+
+    def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup):
+        super().__init__()
+        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), allow_condhint_latents=True)
+        # TODO: allow vae_optional to be used instead of preprocessor
+        #require_vae=True
+        self._ref_opts = ref_opts
+        self.order = 0
+        self.model_latent_format = None
+        self.model_sampling_current = None
+        self.should_apply_attn_effective_strength = False
+        self.should_apply_adain_effective_strength = False
+        self.should_apply_effective_masks = False
+        self.latent_shape = None
+        # ContextRef stuff
+        self.is_context_ref = False
+        self.contextref_cond_idx = -1
+        self.contextref_version = RETURNED_CONTEXTREF_VERSION
+
+    @property
+    def ref_opts(self):
+        if self._current_timestep_keyframe is not None and self._current_timestep_keyframe.has_control_weights():
+            return self._current_timestep_keyframe.control_weights.extras.get(RefConst.OPTS, self._ref_opts)
+        return self._ref_opts
+
+    def any_attn_strength_to_apply(self):
+        return self.should_apply_attn_effective_strength or self.should_apply_effective_masks
+
+    def any_adain_strength_to_apply(self):
+        return self.should_apply_adain_effective_strength or self.should_apply_effective_masks
+
+    def get_effective_strength(self):
+        effective_strength = self.strength
+        if self._current_timestep_keyframe is not None:
+            effective_strength = effective_strength * self._current_timestep_keyframe.strength
+        return effective_strength
+
+    def get_effective_attn_mask_or_float(self, x: Tensor, channels: int, is_mid: bool):
+        if not self.should_apply_effective_masks:
+            return self.get_effective_strength() * self.ref_opts.attn_strength
+        if is_mid:
+            div = 8
+        else:
+            div = self.CHANNEL_TO_MULT[channels]
+        real_mask = torch.ones([self.latent_shape[0], 1, self.latent_shape[2]//div, self.latent_shape[3]//div]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.attn_strength
+        self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
+        # mask is now shape [b, 1, h, w]; need to turn into [b, h*w, 1]
+        b, c, h, w = real_mask.shape
+        real_mask = real_mask.permute(0, 2, 3, 1).reshape(b, h*w, c)
+        return real_mask
+
+    def get_effective_adain_mask_or_float(self, x: Tensor):
+        if not self.should_apply_effective_masks:
+            return self.get_effective_strength() * self.ref_opts.adain_strength
+        b, c, h, w = x.shape
+        real_mask = torch.ones([b, 1, h, w]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.adain_strength
+        self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
+        return real_mask
+
+    def get_contextref_mode_replace(self):
+        # used by ADE to get mode_replace for current keyframe
+        if self._current_timestep_keyframe.has_control_weights():
+            return self._current_timestep_keyframe.control_weights.extras.get(RefConst.CREF_MODE, 
None)
+        return None
+
+    def should_run(self):
+        running = super().should_run()
+        if not running:
+            return running
+        attn_run = False
+        adain_run = False
+        if ReferenceType.is_attn(self.ref_opts.reference_type):
+            # attn will run as long as neither weight nor strength is zero
+            attn_run = not (math.isclose(self.ref_opts.attn_ref_weight, 0.0) or math.isclose(self.ref_opts.attn_strength, 0.0))
+        if ReferenceType.is_adain(self.ref_opts.reference_type):
+            # adain will run as long as neither weight nor strength is zero
+            adain_run = not (math.isclose(self.ref_opts.adain_ref_weight, 0.0) or math.isclose(self.ref_opts.adain_strength, 0.0))
+        return attn_run or adain_run
+
+    def pre_run_advanced(self, model, percent_to_timestep_function):
+        AdvancedControlBase.pre_run_advanced(self, model, percent_to_timestep_function)
+        if isinstance(self.cond_hint_original, AbstractPreprocWrapper):
+            self.cond_hint_original = self.cond_hint_original.condhint
+        self.model_latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
+        self.model_sampling_current = model.model_sampling
+        # SDXL is more sensitive to style_fidelity according to sd-webui-controlnet comments;
+        # prepare all ref_opts accordingly
+        all_ref_opts = [self._ref_opts]
+        for kf in self.timestep_keyframes.keyframes:
+            if kf.has_control_weights() and RefConst.OPTS in kf.control_weights.extras:
+                all_ref_opts.append(kf.control_weights.extras[RefConst.OPTS])
+        for ropts in all_ref_opts:
+            if type(model).__name__ == "SDXL":
+                ropts.attn_style_fidelity = ropts.original_attn_style_fidelity ** 3.0
+                ropts.adain_style_fidelity = ropts.original_adain_style_fidelity ** 3.0
+            else:
+                ropts.attn_style_fidelity = ropts.original_attn_style_fidelity
+                ropts.adain_style_fidelity = ropts.original_adain_style_fidelity
+
+    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
+        # normal ControlNet stuff
+        control_prev = None
+        if self.previous_controlnet is not None:
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
+
+        if self.timestep_range is not None:
+            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                return control_prev
+
+        dtype = x_noisy.dtype
+        # cond_hint_original only matters for RefCN, NOT ContextRef
+        if self.cond_hint_original is not None:
+            # prepare cond_hint - it is a latent, NOT an image
+            #if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] != self.cond_hint.shape[2] or x_noisy.shape[3] != self.cond_hint.shape[3]:
+            if self.cond_hint is not None:
+                del self.cond_hint
+            self.cond_hint = None
+            # if self.cond_hint_original length is greater than or equal to the real latent count, subdivide it before scaling
+            if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
+                self.cond_hint = comfy.utils.common_upscale(
+                    self.cond_hint_original[self.sub_idxs],
+                    x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device)
+            else:
+                self.cond_hint = comfy.utils.common_upscale(
+                    self.cond_hint_original,
+                    x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device)
+            if x_noisy.shape[0] != self.cond_hint.shape[0]:
+                self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
+            # noise cond_hint based on sigma (current step)
+            self.cond_hint = self.model_latent_format.process_in(self.cond_hint)
+            self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None)
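+        # derive the model timestep for this step's sigma, then decide whether effective
+        # strengths need applying (only when some stacked multiplier deviates from 1.0)
+        timestep = 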
self.model_sampling_current.timestep(t) + self.should_apply_attn_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.attn_strength, 1.0)) + self.should_apply_adain_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.adain_strength, 1.0)) + # prepare mask - use direct_attn, so the mask dims will match source latents (and be smaller) + self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, direct_attn=True) + self.should_apply_effective_masks = self.latent_keyframes is not None or self.mask_cond_hint is not None or self.tk_mask_cond_hint is not None + self.latent_shape = list(x_noisy.shape) + # done preparing; model patches will take care of everything now. + # return normal controlnet stuff + return control_prev + + def cleanup_advanced(self): + super().cleanup_advanced() + del self.model_latent_format + self.model_latent_format = None + del self.model_sampling_current + self.model_sampling_current = None + self.should_apply_attn_effective_strength = False + self.should_apply_adain_effective_strength = False + self.should_apply_effective_masks = False + + def copy(self): + c = ReferenceAdvanced(self.ref_opts, self.timestep_keyframes) + c.order = self.order + c.is_context_ref = self.is_context_ref + self.copy_to(c) + self.copy_to_advanced(c) + return c + + # avoid deepcopy shenanigans by making deepcopy not do anything to the reference + # TODO: do the bookkeeping to do this in a proper way for all Adv-ControlNets + def __deepcopy__(self, memo): + return self + + +def handle_context_ref_setup(contextref_obj, transformer_options: dict, positive, negative): + transformer_options[CONTEXTREF_MACHINE_STATE] = MachineState.OFF + # verify version is compatible + if contextref_obj.version > HIGHEST_VERSION_SUPPORT: + raise Exception(f"AnimateDiff-Evolved's ContextRef v{contextref_obj.version} is not supported in currently-installed Advanced-ControlNet (only supports ContextRef up to v{HIGHEST_VERSION_SUPPORT}); " + + f"update your Advanced-ControlNet nodes for ContextRef to work.") + # init ReferenceOptions + cref_opt_dict = contextref_obj.tune.create_dict() # ContextRefTune obj from ADE + opts = ReferenceOptions.create_from_kwargs(**cref_opt_dict) + # init TimestepKeyframes + cref_tks_list = contextref_obj.keyframe.create_list_of_dicts() # ContextRefKeyframeGroup obj from ADE + timestep_keyframes = _create_tks_from_dict_list(cref_tks_list) + # create ReferenceAdvanced + cref = ReferenceAdvanced(ref_opts=opts, timestep_keyframes=timestep_keyframes) + cref.strength = contextref_obj.strength # ContextRef obj from ADE + cref.set_cond_hint_mask(contextref_obj.mask) + cref.order = 99 + cref.is_context_ref = True + context_ref_list = [cref] + transformer_options[CONTEXTREF_CONTROL_LIST_ALL] = context_ref_list + transformer_options[CONTEXTREF_OPTIONS_CLASS] = ReferenceOptions + _add_context_ref_to_conds([positive, negative], cref) + return context_ref_list + + +def _create_tks_from_dict_list(dlist: list[dict[str]]) -> TimestepKeyframeGroup: + tks = TimestepKeyframeGroup() + if dlist is None or len(dlist) == 0: + return tks + for d in dlist: + # scheduling + start_percent = d["start_percent"] + guarantee_steps = d["guarantee_steps"] + inherit_missing = d["inherit_missing"] + # values + strength = d["strength"] + mask = d["mask"] + tune = d["tune"] + mode = d["mode"] + 
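+            # pack per-keyframe tune/mode into ControlWeights extras; they are read back
+            # later through RefConst.OPTS and RefConst.CREF_MODE (see ref_opts and get_contextref_mode_replace)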
weights = None + extras = {} + if tune is not None: + cref_opt_dict = tune.create_dict() # ContextRefTune obj from ADE + opts = ReferenceOptions.create_from_kwargs(**cref_opt_dict) + extras[RefConst.OPTS] = opts + if mode is not None: + extras[RefConst.CREF_MODE] = mode + weights = ControlWeights.default(extras=extras) + # create keyframe + tk = TimestepKeyframe(start_percent=start_percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing, + strength=strength, mask_hint_orig=mask, control_weights=weights) + tks.add(tk) + return tks + + +def _add_context_ref_to_conds(conds: list[list[dict[str]]], context_ref: ReferenceAdvanced): + def _add_context_ref_to_existing_control(control: ControlBase, context_ref: ReferenceAdvanced): + curr_cn = control + while curr_cn is not None: + if type(curr_cn) == ReferenceAdvanced and curr_cn.is_context_ref: + break + if curr_cn.previous_controlnet is not None: + curr_cn = curr_cn.previous_controlnet + continue + orig_previous_controlnet = curr_cn.previous_controlnet + # NOTE: code is already in place to restore any ORIG_PREVIOUS_CONTROLNET props + setattr(curr_cn, ORIG_PREVIOUS_CONTROLNET, orig_previous_controlnet) + curr_cn.previous_controlnet = context_ref + curr_cn = orig_previous_controlnet + + def _add_context_ref(actual_cond: dict[str], context_ref: ReferenceAdvanced): + # if controls already present on cond, add it to the last previous_controlnet + if "control" in actual_cond: + return _add_context_ref_to_existing_control(actual_cond["control"], context_ref) + # otherwise, need to add it to begin with, and should mark that it should be cleaned after + actual_cond["control"] = context_ref + actual_cond[CONTROL_INIT_BY_ACN] = True + + # either add context_ref to end of existing cnet chain, or init 'control' key on actual cond + for cond in conds: + if cond is not None: + for sub_cond in cond: + actual_cond = sub_cond[1] + _add_context_ref(actual_cond, context_ref) + + +def ref_noise_latents(latents: Tensor, sigma: Tensor, noise: Tensor=None): + sigma = sigma.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + alpha_cumprod = 1 / ((sigma * sigma) + 1) + sqrt_alpha_prod = alpha_cumprod ** 0.5 + sqrt_one_minus_alpha_prod = (1. 
- alpha_cumprod) ** 0.5 + if noise is None: + # generator = torch.Generator(device="cuda") + # generator.manual_seed(0) + # noise = torch.empty_like(latents).normal_(generator=generator) + # generator = torch.Generator() + # generator.manual_seed(0) + # noise = torch.randn(latents.size(), generator=generator).to(latents.device) + noise = torch.randn_like(latents).to(latents.device) + return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise + + +def simple_noise_latents(latents: Tensor, sigma: float, noise: Tensor=None): + if noise is None: + # gaussian noise, matching the noise model used in ref_noise_latents + noise = torch.randn_like(latents) + return latents + noise * sigma + + +class BankStylesBasicTransformerBlock: + def __init__(self): + # ref + self.bank = [] + self.style_cfgs = [] + self.cn_idx: list[int] = [] + # contextref - list of lists, as each cond/uncond is stored separately + self.c_bank: list[list] = [] + self.c_style_cfgs: list[list] = [] + self.c_cn_idx: list[list[int]] = [] + + def get_bank(self, cref_idx, ignore_contextref, cdevice=None): + if ignore_contextref or cref_idx >= len(self.c_bank): + return self.bank + real_c_bank_list = self.c_bank[cref_idx] + if cdevice is not None: + real_c_bank_list = real_c_bank_list.copy() + for i in range(len(real_c_bank_list)): + real_c_bank_list[i] = real_c_bank_list[i].to(cdevice) + return self.bank + real_c_bank_list + + def get_avg_style_fidelity(self, cref_idx, ignore_contextref): + if ignore_contextref or cref_idx >= len(self.c_style_cfgs): + return sum(self.style_cfgs) / float(len(self.style_cfgs)) + combined = self.style_cfgs + self.c_style_cfgs[cref_idx] + return sum(combined) / float(len(combined)) + + def get_cn_idxs(self, cref_idx, ignore_contextref): + if ignore_contextref or cref_idx >= len(self.c_cn_idx): + return self.cn_idx + return self.cn_idx + self.c_cn_idx[cref_idx] + + def init_cref_for_idx(self, cref_idx: int): + # makes sure cref lists can accommodate cref_idx + if cref_idx < 0: + return + while cref_idx >= len(self.c_bank): + self.c_bank.append([]) + self.c_style_cfgs.append([]) + self.c_cn_idx.append([]) + + def clear_cref_for_idx(self, cref_idx: int): + if cref_idx < 0 or cref_idx >= len(self.c_bank): + return + self.c_bank[cref_idx] = [] + self.c_style_cfgs[cref_idx] = [] + self.c_cn_idx[cref_idx] = [] + + def clean_ref(self): + del self.bank + del self.style_cfgs + del self.cn_idx + self.bank = [] + self.style_cfgs = [] + self.cn_idx = [] + + def clean_contextref(self): + del self.c_bank + del self.c_style_cfgs + del self.c_cn_idx + self.c_bank = [] + self.c_style_cfgs = [] + self.c_cn_idx = [] + + def clean_all(self): + self.clean_ref() + self.clean_contextref() + + +class BankStylesTimestepEmbedSequential: + def __init__(self): + # ref + self.var_bank = [] + self.mean_bank = [] + self.style_cfgs = [] + self.cn_idx: list[int] = [] + # cref + self.c_var_bank: list[list] = [] + self.c_mean_bank: list[list] = [] + self.c_style_cfgs: list[list] = [] + self.c_cn_idx: list[list[int]] = [] + + def get_var_bank(self, cref_idx, ignore_contextref): + if ignore_contextref or cref_idx >= len(self.c_var_bank): + return self.var_bank + return self.var_bank + self.c_var_bank[cref_idx] + + def get_mean_bank(self, cref_idx, ignore_contextref): + if ignore_contextref or cref_idx >= len(self.c_mean_bank): + return self.mean_bank + return self.mean_bank + self.c_mean_bank[cref_idx] + + def get_style_cfgs(self, cref_idx, ignore_contextref): + if ignore_contextref or cref_idx >= len(self.c_style_cfgs): + return self.style_cfgs + return self.style_cfgs + self.c_style_cfgs[cref_idx] + + def 
get_cn_idxs(self, cref_idx, ignore_contextref): + if ignore_contextref or cref_idx >= len(self.c_cn_idx): + return self.cn_idx + return self.cn_idx + self.c_cn_idx[cref_idx] + + def init_cref_for_idx(self, cref_idx: int): + # makes sure cref lists can accommodate cref_idx + if cref_idx < 0: + return + while cref_idx >= len(self.c_var_bank): + self.c_var_bank.append([]) + self.c_mean_bank.append([]) + self.c_style_cfgs.append([]) + self.c_cn_idx.append([]) + + def clear_cref_for_idx(self, cref_idx: int): + if cref_idx < 0 or cref_idx >= len(self.c_var_bank): + return + self.c_var_bank[cref_idx] = [] + self.c_mean_bank[cref_idx] = [] + self.c_style_cfgs[cref_idx] = [] + self.c_cn_idx[cref_idx] = [] + + def clean_ref(self): + del self.mean_bank + del self.var_bank + del self.style_cfgs + del self.cn_idx + self.mean_bank = [] + self.var_bank = [] + self.style_cfgs = [] + self.cn_idx = [] + + def clean_contextref(self): + del self.c_var_bank + del self.c_mean_bank + del self.c_style_cfgs + del self.c_cn_idx + self.c_var_bank = [] + self.c_mean_bank = [] + self.c_style_cfgs = [] + self.c_cn_idx = [] + + def clean_all(self): + self.clean_ref() + self.clean_contextref() + + +class InjectionBasicTransformerBlockHolder: + def __init__(self, block: BasicTransformerBlock, idx=None): + if hasattr(block, "_forward"): # backward compatibility + self.original_forward = block._forward + else: + self.original_forward = block.forward + self.idx = idx + self.attn_weight = 1.0 + self.is_middle = False + self.bank_styles = BankStylesBasicTransformerBlock() + + def restore(self, block: BasicTransformerBlock): + if hasattr(block, "_forward"): # backward compatibility + block._forward = self.original_forward + else: + block.forward = self.original_forward + + def clean_ref(self): + self.bank_styles.clean_ref() + + def clean_contextref(self): + self.bank_styles.clean_contextref() + + def clean_all(self): + self.bank_styles.clean_all() + + +class InjectionTimestepEmbedSequentialHolder: + def __init__(self, block: openaimodel.TimestepEmbedSequential, idx=None, is_middle=False, is_input=False, is_output=False): + self.original_forward = block.forward + self.idx = idx + self.gn_weight = 1.0 + self.is_middle = is_middle + self.is_input = is_input + self.is_output = is_output + self.bank_styles = BankStylesTimestepEmbedSequential() + + def restore(self, block: openaimodel.TimestepEmbedSequential): + block.forward = self.original_forward + + def clean_ref(self): + self.bank_styles.clean_ref() + + def clean_contextref(self): + self.bank_styles.clean_contextref() + + def clean_all(self): + self.bank_styles.clean_all() + + +class ReferenceInjections: + def __init__(self, attn_modules: list['RefBasicTransformerBlock']=None, gn_modules: list['RefTimestepEmbedSequential']=None): + self.attn_modules = attn_modules if attn_modules else [] + self.gn_modules = gn_modules if gn_modules else [] + self.diffusion_model_orig_forward: Callable = None + + def clean_ref_module_mem(self): + for attn_module in self.attn_modules: + try: + attn_module.injection_holder.clean_ref() + except Exception: + pass + for gn_module in self.gn_modules: + try: + gn_module.injection_holder.clean_ref() + except Exception: + pass + + def clean_contextref_module_mem(self): + for attn_module in self.attn_modules: + try: + attn_module.injection_holder.clean_contextref() + except Exception: + pass + for gn_module in self.gn_modules: + try: + gn_module.injection_holder.clean_contextref() + except Exception: + pass + + def clean_all_module_mem(self): + for 
attn_module in self.attn_modules: + try: + attn_module.injection_holder.clean_all() + except Exception: + pass + for gn_module in self.gn_modules: + try: + gn_module.injection_holder.clean_all() + except Exception: + pass + + def cleanup(self): + self.clean_all_module_mem() + del self.attn_modules + self.attn_modules = [] + del self.gn_modules + self.gn_modules = [] + self.diffusion_model_orig_forward = None + + +def factory_forward_inject_UNetModel(reference_injections: ReferenceInjections): + def forward_inject_UNetModel(self, x: Tensor, *args, **kwargs): + # get transformer_options from kwargs + transformer_options: dict[str] = kwargs.get("transformer_options", {}) + # NOTE: adds support for both ReferenceCN and ContextRef, so the two need to be tracked separately + # get ReferenceAdvanced objects + ref_controlnets: list[ReferenceAdvanced] = transformer_options.get(REF_CONTROL_LIST_ALL, []) + context_controlnets: list[ReferenceAdvanced] = transformer_options.get(CONTEXTREF_CONTROL_LIST_ALL, []) + # clean contextref stuff if OFF + if len(context_controlnets) > 0 and transformer_options[CONTEXTREF_MACHINE_STATE] == MachineState.OFF: + reference_injections.clean_contextref_module_mem() + context_controlnets = [] + # discard any controlnets that should not run + ref_controlnets = [z for z in ref_controlnets if z.should_run()] + context_controlnets = [z for z in context_controlnets if z.should_run()] + # if nothing related to reference controlnets, do nothing special + if len(ref_controlnets) == 0 and len(context_controlnets) == 0: + return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs) + try: + # assign cond and uncond idxs + batched_number = len(transformer_options["cond_or_uncond"]) + per_batch = x.shape[0] // batched_number + indiv_conds = [] + for cond_type in transformer_options["cond_or_uncond"]: + indiv_conds.extend([cond_type] * per_batch) + transformer_options[REF_UNCOND_IDXS] = [i for i, z in enumerate(indiv_conds) if z == 1] + transformer_options[REF_COND_IDXS] = [i for i, z in enumerate(indiv_conds) if z == 0] + # sort controlnets by which effect (attn and/or adain) they apply + attn_controlnets = [] + adain_controlnets = [] + for control in ref_controlnets: + if ReferenceType.is_attn(control.ref_opts.reference_type): + attn_controlnets.append(control) + if ReferenceType.is_adain(control.ref_opts.reference_type): + adain_controlnets.append(control) + context_attn_controlnets = [] + context_adain_controlnets = [] + # for ease of access, store current contextref_cond_idx value + if len(context_controlnets) == 0: + transformer_options[CONTEXTREF_TEMP_COND_IDX] = -1 + else: + transformer_options[CONTEXTREF_TEMP_COND_IDX] = context_controlnets[0].contextref_cond_idx + # logger.info(f"{transformer_options[CONTEXTREF_MACHINE_STATE]}: {transformer_options[CONTEXTREF_TEMP_COND_IDX]}") + + for control in context_controlnets: + if ReferenceType.is_attn(control.ref_opts.reference_type): + context_attn_controlnets.append(control) + if ReferenceType.is_adain(control.ref_opts.reference_type): + context_adain_controlnets.append(control) + if len(adain_controlnets) > 0 or len(context_adain_controlnets) > 0: + # ComfyUI uses forward_timestep_embed with the TimestepEmbedSequential passed into it + orig_forward_timestep_embed = openaimodel.forward_timestep_embed + openaimodel.forward_timestep_embed = forward_timestep_embed_ref_inject_factory(orig_forward_timestep_embed) + + # if RefCN to be used, 
handle running diffusion with ref cond hints + if len(ref_controlnets) > 0: + for control in ref_controlnets: + read_attn_list = [] + write_attn_list = [] + read_adain_list = [] + write_adain_list = [] + + if ReferenceType.is_attn(control.ref_opts.reference_type): + write_attn_list.append(control) + if ReferenceType.is_adain(control.ref_opts.reference_type): + write_adain_list.append(control) + # apply lists + transformer_options[REF_READ_ATTN_CONTROL_LIST] = read_attn_list + transformer_options[REF_WRITE_ATTN_CONTROL_LIST] = write_attn_list + transformer_options[REF_READ_ADAIN_CONTROL_LIST] = read_adain_list + transformer_options[REF_WRITE_ADAIN_CONTROL_LIST] = write_adain_list + + orig_kwargs = kwargs + # disable other controlnets for this run, if specified + if not control.ref_opts.ref_with_other_cns: + kwargs = kwargs.copy() + kwargs["control"] = None + reference_injections.diffusion_model_orig_forward(control.cond_hint.to(dtype=x.dtype).to(device=x.device), *args, **kwargs) + kwargs = orig_kwargs + # prepare running diffusion for real now + read_attn_list = [] + write_attn_list = [] + read_adain_list = [] + write_adain_list = [] + + # add RefCNs to read lists + read_attn_list.extend(attn_controlnets) + read_adain_list.extend(adain_controlnets) + + # do contextref stuff, if needed + if len(context_controlnets) > 0: + # clean contextref stuff if first WRITE + # if context_controlnets[0].contextref_cond_idx == 0 and is_write(transformer_options[CONTEXTREF_MACHINE_STATE]): + # reference_injections.clean_contextref_module_mem() + ### add ContextRef to appropriate lists + # attn + if is_read(transformer_options[CONTEXTREF_MACHINE_STATE]): + read_attn_list.extend(context_attn_controlnets) + if is_write(transformer_options[CONTEXTREF_MACHINE_STATE]): + write_attn_list.extend(context_attn_controlnets) + # adain + if is_read(transformer_options[CONTEXTREF_MACHINE_STATE]): + read_adain_list.extend(context_adain_controlnets) + if is_write(transformer_options[CONTEXTREF_MACHINE_STATE]): + write_adain_list.extend(context_adain_controlnets) + # apply lists, containing both RefCN and ContextRef + transformer_options[REF_READ_ATTN_CONTROL_LIST] = read_attn_list + transformer_options[REF_WRITE_ATTN_CONTROL_LIST] = write_attn_list + transformer_options[REF_READ_ADAIN_CONTROL_LIST] = read_adain_list + transformer_options[REF_WRITE_ADAIN_CONTROL_LIST] = write_adain_list + # run diffusion for real + try: + return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs) + finally: + # increment current cond idx + if len(context_controlnets) > 0: + for cn in context_controlnets: + cn.contextref_cond_idx += 1 + finally: + # make sure ref banks are cleared no matter what happens - otherwise, RIP VRAM + reference_injections.clean_ref_module_mem() + if len(adain_controlnets) > 0 or len(context_adain_controlnets) > 0: + openaimodel.forward_timestep_embed = orig_forward_timestep_embed + + + return forward_inject_UNetModel + + +# dummy class just to help IDE keep track of injected variables +class RefBasicTransformerBlock(BasicTransformerBlock): + injection_holder: InjectionBasicTransformerBlockHolder = None + +def _forward_inject_BasicTransformerBlock(self: RefBasicTransformerBlock, x: Tensor, context: Tensor=None, transformer_options: dict[str]={}): + extra_options = {} + block = transformer_options.get("block", None) + block_index = transformer_options.get("block_index", 0) + transformer_patches = {} + transformer_patches_replace = {} + + for k in transformer_options: + if k == "patches": + 
transformer_patches = transformer_options[k] + elif k == "patches_replace": + transformer_patches_replace = transformer_options[k] + else: + extra_options[k] = transformer_options[k] + + extra_options["n_heads"] = self.n_heads + extra_options["dim_head"] = self.d_head + + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + + n: Tensor = self.norm1(x) + if self.disable_self_attn: + context_attn1 = context + else: + context_attn1 = None + value_attn1 = None + + # Reference CN stuff + uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, []) + #c_idx_mask = transformer_options.get(REF_COND_IDXS, []) + # WRITE mode may have only 1 ReferenceAdvanced for RefCN at a time, other modes will have all ReferenceAdvanced + ref_write_cns: list[ReferenceAdvanced] = transformer_options.get(REF_WRITE_ATTN_CONTROL_LIST, []) + ref_read_cns: list[ReferenceAdvanced] = transformer_options.get(REF_READ_ATTN_CONTROL_LIST, []) + cref_cond_idx: int = transformer_options.get(CONTEXTREF_TEMP_COND_IDX, -1) + ignore_contextref_read = cref_cond_idx < 0 # if writing to bank, should NOT be read in the same execution + + cached_n = None + cref_write_cns: list[ReferenceAdvanced] = [] + # check if any WRITE cns are applicable; Reference CN WRITEs immediately, ContextREF WRITEs after READ completed + # if any refs to WRITE, save n and style_fidelity + for refcn in ref_write_cns: + if refcn.ref_opts.attn_ref_weight > self.injection_holder.attn_weight: + if cached_n is None: + cached_n = n.detach().clone() + # for ContextRef, make sure relevant lists are long enough to cond_idx + # store RefCN and ContextRef stuff separately + if refcn.is_context_ref: + cref_write_cns.append(refcn) + self.injection_holder.bank_styles.init_cref_for_idx(cref_cond_idx) + else: # Reference CN WRITE + self.injection_holder.bank_styles.bank.append(cached_n) + self.injection_holder.bank_styles.style_cfgs.append(refcn.ref_opts.attn_style_fidelity) + self.injection_holder.bank_styles.cn_idx.append(refcn.order) + if len(cref_write_cns) == 0: + del cached_n + + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + if context_attn1 is None: + context_attn1 = n + value_attn1 = context_attn1 + for p in patch: + n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options) + + if block is not None: + transformer_block = (block[0], block[1], block_index) + else: + transformer_block = None + attn1_replace_patch = transformer_patches_replace.get("attn1", {}) + block_attn1 = transformer_block + if block_attn1 not in attn1_replace_patch: + block_attn1 = block + + if block_attn1 in attn1_replace_patch: + if context_attn1 is None: + context_attn1 = n + value_attn1 = n + n = self.attn1.to_q(n) + # Reference CN READ - use attn1_replace_patch appropriately + if len(ref_read_cns) > 0 and len(self.injection_holder.bank_styles.get_bank(cref_cond_idx, ignore_contextref_read)) > 0: + bank_styles = self.injection_holder.bank_styles + style_fidelity = bank_styles.get_avg_style_fidelity(cref_cond_idx, ignore_contextref_read) + real_bank = bank_styles.get_bank(cref_cond_idx, ignore_contextref_read, cdevice=n.device).copy() + real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read) + cn_idx = 0 + for idx, order in enumerate(real_cn_idxs): + # make sure matching ref cn is selected + for i in range(cn_idx, len(ref_read_cns)): + if ref_read_cns[i].order == order: + cn_idx = i + break + assert order == ref_read_cns[cn_idx].order + if 
ref_read_cns[cn_idx].any_attn_strength_to_apply(): + effective_strength = ref_read_cns[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle) + real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength) + n_uc = self.attn1.to_out(attn1_replace_patch[block_attn1]( + n, + self.attn1.to_k(torch.cat([context_attn1] + real_bank, dim=1)), + self.attn1.to_v(torch.cat([value_attn1] + real_bank, dim=1)), + extra_options)) + n_c = n_uc.clone() + if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0): + n_c[uc_idx_mask] = self.attn1.to_out(attn1_replace_patch[block_attn1]( + n[uc_idx_mask], + self.attn1.to_k(context_attn1[uc_idx_mask]), + self.attn1.to_v(value_attn1[uc_idx_mask]), + extra_options)) + n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc + bank_styles.clean_ref() + else: + context_attn1 = self.attn1.to_k(context_attn1) + value_attn1 = self.attn1.to_v(value_attn1) + n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) + n = self.attn1.to_out(n) + else: + # Reference CN READ - no attn1_replace_patch + if len(ref_read_cns) > 0 and len(self.injection_holder.bank_styles.get_bank(cref_cond_idx, ignore_contextref_read)) > 0: + if context_attn1 is None: + context_attn1 = n + bank_styles = self.injection_holder.bank_styles + style_fidelity = bank_styles.get_avg_style_fidelity(cref_cond_idx, ignore_contextref_read) + real_bank = bank_styles.get_bank(cref_cond_idx, ignore_contextref_read, cdevice=n.device).copy() + real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read) + cn_idx = 0 + for idx, order in enumerate(real_cn_idxs): + # make sure matching ref cn is selected + for i in range(cn_idx, len(ref_read_cns)): + if ref_read_cns[i].order == order: + cn_idx = i + break + assert order == ref_read_cns[cn_idx].order + if ref_read_cns[cn_idx].any_attn_strength_to_apply(): + effective_strength = ref_read_cns[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle) + real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength) + n_uc: Tensor = self.attn1( + n, + context=torch.cat([context_attn1] + real_bank, dim=1), + value=torch.cat([value_attn1] + real_bank, dim=1) if value_attn1 is not None else value_attn1) + n_c = n_uc.clone() + if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0): + n_c[uc_idx_mask] = self.attn1( + n[uc_idx_mask], + context=context_attn1[uc_idx_mask], + value=value_attn1[uc_idx_mask] if value_attn1 is not None else value_attn1) + n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc + bank_styles.clean_ref() + else: + n = self.attn1(n, context=context_attn1, value=value_attn1) + + # ContextRef CN WRITE + if len(cref_write_cns) > 0: + # clear so that ContextRef CNs can properly 'replace' previous value at cond_idx + self.injection_holder.bank_styles.clear_cref_for_idx(cref_cond_idx) + for refcn in cref_write_cns: + # append into the per-cond list at cond_idx so entries line up when banks are combined in get_bank + self.injection_holder.bank_styles.c_bank[cref_cond_idx].append(cached_n.to(comfy.model_management.unet_offload_device())) + self.injection_holder.bank_styles.c_style_cfgs[cref_cond_idx].append(refcn.ref_opts.attn_style_fidelity) + self.injection_holder.bank_styles.c_cn_idx[cref_cond_idx].append(refcn.order) + del cached_n + + if "attn1_output_patch" in transformer_patches: + patch = transformer_patches["attn1_output_patch"] + for p in patch: + n = p(n, extra_options) 
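+ # note: RefCN banks are consumed and cleared within the same READ (clean_ref above), while + # ContextRef entries persist in c_bank across the context window until cleaned up externally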
+ + x += n + if "middle_patch" in transformer_patches: + patch = transformer_patches["middle_patch"] + for p in patch: + x = p(x, extra_options) + + if self.attn2 is not None: + n = self.norm2(x) + if self.switch_temporal_ca_to_sa: + context_attn2 = n + else: + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] + value_attn2 = context_attn2 + for p in patch: + n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) + + attn2_replace_patch = transformer_patches_replace.get("attn2", {}) + block_attn2 = transformer_block + if block_attn2 not in attn2_replace_patch: + block_attn2 = block + + if block_attn2 in attn2_replace_patch: + if value_attn2 is None: + value_attn2 = context_attn2 + n = self.attn2.to_q(n) + context_attn2 = self.attn2.to_k(context_attn2) + value_attn2 = self.attn2.to_v(value_attn2) + n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) + n = self.attn2.to_out(n) + else: + n = self.attn2(n, context=context_attn2, value=value_attn2) + + if "attn2_output_patch" in transformer_patches: + patch = transformer_patches["attn2_output_patch"] + for p in patch: + n = p(n, extra_options) + + x += n + if self.is_res: + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + + return x + + +class RefTimestepEmbedSequential(openaimodel.TimestepEmbedSequential): + injection_holder: InjectionTimestepEmbedSequentialHolder = None + +def forward_timestep_embed_ref_inject_factory(orig_timestep_embed_inject_factory: Callable): + def forward_timestep_embed_ref_inject(*args, **kwargs): + ts: RefTimestepEmbedSequential = args[0] + if not hasattr(ts, "injection_holder"): + return orig_timestep_embed_inject_factory(*args, **kwargs) + eps = 1e-6 + x: Tensor = orig_timestep_embed_inject_factory(*args, **kwargs) + y: Tensor = None + transformer_options: dict[str] = args[4] + # Reference CN stuff + uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, []) + #c_idx_mask = transformer_options.get(REF_COND_IDXS, []) + # WRITE mode will only have one ReferenceAdvanced, other modes will have all ReferenceAdvanced + ref_write_cns: list[ReferenceAdvanced] = transformer_options.get(REF_WRITE_ADAIN_CONTROL_LIST, []) + ref_read_cns: list[ReferenceAdvanced] = transformer_options.get(REF_READ_ADAIN_CONTROL_LIST, []) + cref_cond_idx: int = transformer_options.get(CONTEXTREF_TEMP_COND_IDX, -1) + ignore_contextref_read = cref_cond_idx < 0 # if writing to bank, should NOT be read in the same execution + + cached_var = None + cached_mean = None + cref_write_cns: list[ReferenceAdvanced] = [] + # if any refs to WRITE, save var, mean, and style_cfg + for refcn in ref_write_cns: + if refcn.ref_opts.adain_ref_weight > ts.injection_holder.gn_weight: + if cached_var is None: + cached_var, cached_mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + if refcn.is_context_ref: + cref_write_cns.append(refcn) + ts.injection_holder.bank_styles.init_cref_for_idx(cref_cond_idx) + else: + ts.injection_holder.bank_styles.var_bank.append(cached_var) + ts.injection_holder.bank_styles.mean_bank.append(cached_mean) + ts.injection_holder.bank_styles.style_cfgs.append(refcn.ref_opts.adain_style_fidelity) + ts.injection_holder.bank_styles.cn_idx.append(refcn.order) + if len(cref_write_cns) == 0: + del cached_var + del cached_mean + + # if any refs to READ, do math with saved var, mean, and style_cfg + if len(ref_read_cns) > 0: + if 
len(ts.injection_holder.bank_styles.get_var_bank(cref_cond_idx, ignore_contextref_read)) > 0: + bank_styles = ts.injection_holder.bank_styles + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + y_uc = torch.zeros_like(x) + cn_idx = 0 + real_style_cfgs = bank_styles.get_style_cfgs(cref_cond_idx, ignore_contextref_read) + real_var_bank = bank_styles.get_var_bank(cref_cond_idx, ignore_contextref_read) + real_mean_bank = bank_styles.get_mean_bank(cref_cond_idx, ignore_contextref_read) + real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read) + for idx, order in enumerate(real_cn_idxs): + # make sure matching ref cn is selected + for i in range(cn_idx, len(ref_read_cns)): + if ref_read_cns[i].order == order: + cn_idx = i + break + assert order == ref_read_cns[cn_idx].order + style_fidelity = real_style_cfgs[idx] + var_acc = real_var_bank[idx] + mean_acc = real_mean_bank[idx] + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + sub_y_uc = (((x - mean) / std) * std_acc) + mean_acc + if ref_read_cns[cn_idx].any_adain_strength_to_apply(): + effective_strength = ref_read_cns[cn_idx].get_effective_adain_mask_or_float(x=x) + sub_y_uc = sub_y_uc * effective_strength + x * (1-effective_strength) + y_uc += sub_y_uc + # get average, if more than one + if len(real_cn_idxs) > 1: + y_uc /= len(real_cn_idxs) + y_c = y_uc.clone() + if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0): + y_c[uc_idx_mask] = x.to(y_c.dtype)[uc_idx_mask] + y = style_fidelity * y_c + (1.0 - style_fidelity) * y_uc + ts.injection_holder.bank_styles.clean_ref() + + # ContextRef CN WRITE + if len(cref_write_cns) > 0: + # clear so that ContextRef CNs can properly 'replace' previous value at cond_idx + ts.injection_holder.bank_styles.clear_cref_for_idx(cref_cond_idx) + for refcn in cref_write_cns: + # append into the per-cond list at cond_idx so entries line up when banks are combined in get_var_bank/get_mean_bank + ts.injection_holder.bank_styles.c_var_bank[cref_cond_idx].append(cached_var) + ts.injection_holder.bank_styles.c_mean_bank[cref_cond_idx].append(cached_mean) + ts.injection_holder.bank_styles.c_style_cfgs[cref_cond_idx].append(refcn.ref_opts.adain_style_fidelity) + ts.injection_holder.bank_styles.c_cn_idx[cref_cond_idx].append(refcn.order) + del cached_var + del cached_mean + + if y is None: + y = x + return y.to(x.dtype) + + return forward_timestep_embed_ref_inject
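+ + # Illustrative sketch (assumption, not part of this module): factory_forward_inject_UNetModel + # is expected to wrap the UNet forward roughly like below, where `diffusion_model`, + # `attn_blocks`, and `gn_blocks` are hypothetical stand-ins for the objects patched elsewhere: + # ref_injections = ReferenceInjections(attn_modules=attn_blocks, gn_modules=gn_blocks) + # ref_injections.diffusion_model_orig_forward = diffusion_model.forward + # diffusion_model.forward = factory_forward_inject_UNetModel(ref_injections).__get__(diffusion_model, type(diffusion_model))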