# Based on https://raw.githubusercontent.com/okotaku/diffusers/feature/reference_only_control/examples/community/stable_diffusion_reference.py
# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F

from diffusers import StableDiffusionPipeline
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UpBlock2D,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import PIL_INTERPOLATION, logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import UniPCMultistepScheduler
        >>> from diffusers.utils import load_image

        >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")

        >>> pipe = StableDiffusionReferencePipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5",
        ...     safety_checker=None,
        ...     torch_dtype=torch.float16,
        ... ).to("cuda:0")
        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

        >>> result_img = pipe(
        ...     ref_image=input_image,
        ...     prompt="1girl",
        ...     num_inference_steps=20,
        ...     reference_attn=True,
        ...     reference_adain=True,
        ... ).images[0]
        >>> result_img.show()
        ```
"""


def torch_dfs(model: torch.nn.Module):
    """Return `model` and all of its descendants in depth-first order."""
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result
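# Usage sketch (illustrative only): torch_dfs flattens the module tree so the
# patcher below can locate every BasicTransformerBlock in a UNet. Here `unet`
# is assumed to be a loaded `UNet2DConditionModel`:
#
#     attn_blocks = [m for m in torch_dfs(unet) if isinstance(m, BasicTransformerBlock)]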
class StableDiffusionReferencePipeline(StableDiffusionPipeline):
    def prepare_ref_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]

            if isinstance(image[0], PIL.Image.Image):
                images = []
                for image_ in image:
                    image_ = image_.convert("RGB")
                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                    image_ = np.array(image_)
                    image_ = image_[None, :]
                    images.append(image_)
                image = images

                image = np.concatenate(image, axis=0)
                image = np.array(image).astype(np.float32) / 255.0
                # Normalize to [-1, 1] and move channels first: (B, C, H, W).
                image = (image - 0.5) / 0.5
                image = image.transpose(0, 3, 1, 2)
                image = torch.from_numpy(image)
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)

        image_batch_size = image.shape[0]
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image

    def prepare_ref_latents(
        self,
        refimage,
        batch_size,
        dtype,
        device,
        generator,
        do_classifier_free_guidance,
    ):
        refimage = refimage.to(device=device, dtype=dtype)

        # encode the reference image into latent space so it can be fed through the UNet
        if isinstance(generator, list):
            ref_image_latents = [
                self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
                for i in range(batch_size)
            ]
            ref_image_latents = torch.cat(ref_image_latents, dim=0)
        else:
            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents

        # duplicate ref_image_latents for each generation per prompt, using an mps-friendly method
        if ref_image_latents.shape[0] < batch_size:
            if batch_size % ref_image_latents.shape[0] != 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)

        ref_image_latents = (
            torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
        )

        # align device to prevent device errors when concatenating with the latent model input
        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
        return ref_image_latents

    def check_ref_input(self, reference_attn, reference_adain):
        assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
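    # The methods below monkey-patch a UNet (or ControlNet) so that selected
    # attention blocks and group-norm blocks run in one of two modes:
    # "write" banks masked reference features/statistics, and "read" consumes
    # them on the next forward pass. See the usage sketch at the end of this
    # file for how the two passes are interleaved in a denoising loop.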
    def redefine_ref_model(self, model, reference_attn, reference_adain, model_type="unet"):
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
        ):
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype)
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # 1. Self-Attention
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if self.MODE == "write":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        # Bank only the tokens inside the reference mask,
                        # resized to this block's token resolution.
                        scale_ratio = (
                            (self.ref_mask.shape[2] * self.ref_mask.shape[3]) / norm_hidden_states.shape[1]
                        ) ** 0.5
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(norm_hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        resize_norm_hidden_states = norm_hidden_states.view(
                            norm_hidden_states.shape[0],
                            this_ref_mask.shape[2],
                            this_ref_mask.shape[3],
                            -1,
                        ).permute(0, 3, 1, 2)
                        ref_scale = 1.0
                        resize_norm_hidden_states = F.interpolate(
                            resize_norm_hidden_states,
                            scale_factor=ref_scale,
                            mode="bilinear",
                        )
                        this_ref_mask = F.interpolate(this_ref_mask, scale_factor=ref_scale)
                        this_ref_mask = this_ref_mask.repeat(
                            resize_norm_hidden_states.shape[0],
                            resize_norm_hidden_states.shape[1],
                            1,
                            1,
                        ).bool()
                        masked_norm_hidden_states = (
                            resize_norm_hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(
                                resize_norm_hidden_states.shape[0],
                                resize_norm_hidden_states.shape[1],
                                -1,
                            )
                        )
                        masked_norm_hidden_states = masked_norm_hidden_states.permute(0, 2, 1)
                        self.bank.append(masked_norm_hidden_states)
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if self.MODE == "read":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        # Attend over the current tokens concatenated with the
                        # banked masked reference tokens.
                        ref_hidden_states = torch.cat([norm_hidden_states] + self.bank, dim=1)
                        attn_output_uc = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=ref_hidden_states,
                            # attention_mask=attention_mask,
                            **cross_attention_kwargs,
                        )
                        attn_output_c = attn_output_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            attn_output_c[self.uc_mask] = self.attn1(
                                norm_hidden_states[self.uc_mask],
                                encoder_hidden_states=norm_hidden_states[self.uc_mask],
                                **cross_attention_kwargs,
                            )
                        attn_output = (
                            self.style_fidelity * attn_output_c
                            + (1.0 - self.style_fidelity) * attn_output_uc
                        )
                        self.bank.clear()
                    else:
                        attn_output = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                            attention_mask=attention_mask,
                            **cross_attention_kwargs,
                        )

            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states
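        # The group-norm hacks below all apply the same masked AdaIN-style
        # transfer. A compact sketch of that core step (illustrative helper;
        # the hacked forwards inline this logic rather than calling it):
        def _masked_adain_sketch(masked_x, mean_acc, var_acc, eps=1e-6):
            # Normalize masked features by their own statistics, then re-scale
            # with statistics banked from the reference ("write") pass.
            var, mean = torch.var_mean(masked_x, dim=(2, 3), keepdim=True, correction=0)
            std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
            std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
            return ((masked_x - mean) / std) * std_acc + mean_acc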
        def hacked_mid_forward(self, *args, **kwargs):
            eps = 1e-6
            x = self.original_forward(*args, **kwargs)
            if self.MODE == "write":
                if self.gn_auto_machine_weight >= self.gn_weight:
                    # Bank mean/var of the features inside the reference mask.
                    scale_ratio = self.ref_mask.shape[2] / x.shape[2]
                    this_ref_mask = F.interpolate(self.ref_mask.to(x.device), scale_factor=1 / scale_ratio)
                    this_ref_mask = this_ref_mask.repeat(x.shape[0], x.shape[1], 1, 1).bool()
                    masked_x = x[this_ref_mask].detach().clone().view(x.shape[0], x.shape[1], -1, 1)
                    var, mean = torch.var_mean(masked_x, dim=(2, 3), keepdim=True, correction=0)
                    self.mean_bank.append(mean)
                    self.var_bank.append(var)
            if self.MODE == "read":
                if (
                    self.gn_auto_machine_weight >= self.gn_weight
                    and len(self.mean_bank) > 0
                    and len(self.var_bank) > 0
                ):
                    # Re-normalize the features inside the inpaint mask with
                    # the banked reference statistics.
                    scale_ratio = self.inpaint_mask.shape[2] / x.shape[2]
                    this_inpaint_mask = F.interpolate(self.inpaint_mask.to(x.device), scale_factor=1 / scale_ratio)
                    this_inpaint_mask = this_inpaint_mask.repeat(x.shape[0], x.shape[1], 1, 1).bool()
                    masked_x = x[this_inpaint_mask].detach().clone().view(x.shape[0], x.shape[1], -1, 1)
                    var, mean = torch.var_mean(masked_x, dim=(2, 3), keepdim=True, correction=0)
                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
                    var_acc = sum(self.var_bank) / float(len(self.var_bank))
                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                    x_uc = (((masked_x - mean) / std) * std_acc) + mean_acc
                    x_c = x_uc.clone()
                    if self.do_classifier_free_guidance and self.style_fidelity > 0:
                        x_c[self.uc_mask] = masked_x[self.uc_mask]
                    masked_x = self.style_fidelity * x_c + (1.0 - self.style_fidelity) * x_uc
                    x[this_inpaint_mask] = masked_x.view(-1)
                self.mean_bank = []
                self.var_bank = []
            return x
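        # Note: unlike hacked_mid_forward above, the block-level forwards below
        # bank one [mean]/[var] entry per resnet in "write" mode and read them
        # back by resnet index `i` in "read" mode, so the write and read passes
        # must traverse identical module sequences.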
        def hack_CrossAttnDownBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6

            # TODO(Patrick, William) - attention mask is not used
            output_states = ()

            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                hidden_states = resnet(hidden_states, temb)

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean; wrapped in a list so "read" can
                        # average the entry it looks up by resnet index
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank0.append([mean])
                        self.var_bank0.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank0) > 0
                        and len(self.var_bank0) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank0[i]) / float(len(self.mean_bank0[i]))
                        var_acc = sum(self.var_bank0[i]) / float(len(self.var_bank0[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    # attention_mask=attention_mask,
                    # encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

                output_states = output_states + (hidden_states,)

            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)

                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
            eps = 1e-6

            output_states = ()

            for i, resnet in enumerate(self.resnets):
                hidden_states = resnet(hidden_states, temb)

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean, banked per resnet index
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

                output_states = output_states + (hidden_states,)

            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []
            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)

                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_CrossAttnUpBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6

            # TODO(Patrick, William) - attention mask is not used
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean, banked per resnet index
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank0.append([mean])
                        self.var_bank0.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank0) > 0
                        and len(self.var_bank0) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank0[i]) / float(len(self.mean_bank0[i]))
                        var_acc = sum(self.var_bank0[i]) / float(len(self.var_bank0[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    # attention_mask=attention_mask,
                    # encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean, banked per resnet index
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            eps = 1e-6
            for i, resnet in enumerate(self.resnets):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean, banked per resnet index
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        scale_ratio = self.inpaint_mask.shape[2] / hidden_states.shape[2]
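        # The patching below rebinds each hacked function as a method via the
        # descriptor protocol. Minimal illustration of the mechanism (with a
        # hypothetical function `f` and module `m`):
        #
        #     m.forward = f.__get__(m, type(m))   # `self` inside f is now `m`
        #
        # which is why every hacked forward above takes `self` explicitly.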
                        this_inpaint_mask = F.interpolate(
                            self.inpaint_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        this_inpaint_mask = this_inpaint_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_inpaint_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(masked_hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((masked_hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            hidden_states_c[self.uc_mask] = masked_hidden_states[self.uc_mask]
                        masked_hidden_states = (
                            self.style_fidelity * hidden_states_c
                            + (1.0 - self.style_fidelity) * hidden_states_uc
                        )
                        hidden_states[this_inpaint_mask] = masked_hidden_states.view(-1)

            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        if model_type == "unet":
            if reference_attn:
                attn_modules = [
                    module
                    for module in torch_dfs(model)
                    if isinstance(module, BasicTransformerBlock)
                ]
                attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

                for i, module in enumerate(attn_modules):
                    module._original_inner_forward = module.forward
                    module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                    module.bank = []
                    module.attn_weight = float(i) / float(len(attn_modules))
                    module.attention_auto_machine_weight = self.attention_auto_machine_weight
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = self.do_classifier_free_guidance
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
            else:
                attn_modules = None

            if reference_adain:
                gn_modules = [model.mid_block]
                model.mid_block.gn_weight = 0

                down_blocks = model.down_blocks
                for w, module in enumerate(down_blocks):
                    module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
                    gn_modules.append(module)

                up_blocks = model.up_blocks
                for w, module in enumerate(up_blocks):
                    module.gn_weight = float(w) / float(len(up_blocks))
                    gn_modules.append(module)

                for i, module in enumerate(gn_modules):
                    if getattr(module, "original_forward", None) is None:
                        module.original_forward = module.forward
                    if i == 0:
                        # mid_block
                        module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
                    elif isinstance(module, CrossAttnDownBlock2D):
                        module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
                    elif isinstance(module, DownBlock2D):
                        module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
                    # elif isinstance(module, CrossAttnUpBlock2D):
                    #     module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
                    elif isinstance(module, UpBlock2D):
                        module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
                    module.mean_bank0 = []
                    module.var_bank0 = []
                    module.mean_bank = []
                    module.var_bank = []
                    module.attention_auto_machine_weight = self.attention_auto_machine_weight
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = self.do_classifier_free_guidance
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
                    module.inpaint_mask = self.inpaint_mask
            else:
                gn_modules = None
        elif model_type == "controlnet":
            # only hack the inpainting controlnet
            model = model.nets[-1]
            if reference_attn:
                attn_modules = [
                    module
                    for module in torch_dfs(model)
                    if isinstance(module, BasicTransformerBlock)
                ]
                attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

                for i, module in enumerate(attn_modules):
                    module._original_inner_forward = module.forward
                    module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                    module.bank = []
                    # unlike the UNet branch (float(i) / float(len(attn_modules))),
                    # every controlnet block gets weight 0.0 so all participate
                    module.attn_weight = 0.0
                    module.attention_auto_machine_weight = self.attention_auto_machine_weight
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = self.do_classifier_free_guidance
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
            else:
                attn_modules = None
            gn_modules = None

        return attn_modules, gn_modules

    def change_module_mode(self, mode, attn_modules, gn_modules):
        if attn_modules is not None:
            for module in attn_modules:
                module.MODE = mode
        if gn_modules is not None:
            for module in gn_modules:
                module.MODE = mode
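# End-to-end sketch of how these hooks are typically wired into a denoising
# loop (illustrative only; `pipe`, `prompt_embeds`, `latents`, and the setup of
# uc_mask / style_fidelity / the auto-machine weights are assumptions based on
# the methods above, not part of this file):
#
#     pipe.check_ref_input(reference_attn=True, reference_adain=True)
#     ref_image = pipe.prepare_ref_image(ref_image, 512, 512, 1, 1, pipe.device, pipe.unet.dtype)
#     ref_latents = pipe.prepare_ref_latents(ref_image, 1, pipe.unet.dtype, pipe.device, None, True)
#     attn_modules, gn_modules = pipe.redefine_ref_model(pipe.unet, True, True, model_type="unet")
#     for t in pipe.scheduler.timesteps:
#         noise = torch.randn_like(ref_latents)
#         ref_input = pipe.scheduler.add_noise(ref_latents, noise, t.reshape(1))
#         pipe.change_module_mode("write", attn_modules, gn_modules)
#         pipe.unet(ref_input, t, encoder_hidden_states=prompt_embeds)  # bank reference features
#         pipe.change_module_mode("read", attn_modules, gn_modules)
#         noise_pred = pipe.unet(latents, t, encoder_hidden_states=prompt_embeds).sample
#         ...  # classifier-free guidance and scheduler.step as in a standard SD loop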