# EditAnything / utils / stable_diffusion_reference.py
# Author: shgao - last commit 0c7479d ("update new demo")
# Based on https://raw.githubusercontent.com/okotaku/diffusers/feature/reference_only_control/examples/community/stable_diffusion_reference.py
# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np
import PIL.Image
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
UpBlock2D,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import PIL_INTERPOLATION, logging
import torch.nn.functional as F
# Module-level logger obtained through diffusers' logging utilities.
logger = logging.get_logger(name=__name__)  # pylint: disable=invalid-name
# Usage example carried over from the upstream diffusers community pipeline.
# NOTE(review): the class defined in this file has no ``from_pretrained`` or
# ``__call__`` of its own; the example presumably targets a concrete pipeline
# that inherits from it - confirm before publishing.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import UniPCMultistepScheduler
>>> from diffusers.utils import load_image
>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
>>> pipe = StableDiffusionReferencePipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
safety_checker=None,
torch_dtype=torch.float16
).to('cuda:0')
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
>>> result_img = pipe(ref_image=input_image,
prompt="1girl",
num_inference_steps=20,
reference_attn=True,
reference_adain=True).images[0]
>>> result_img.show()
```
"""
def torch_dfs(model: torch.nn.Module):
    """Return ``model`` followed by all of its descendants in depth-first pre-order.

    Children are visited in ``model.children()`` order. Unlike ``model.modules()``,
    a submodule held by several different parents is listed once for each of them.
    """
    return [model] + [
        descendant
        for child in model.children()
        for descendant in torch_dfs(child)
    ]
_ADAIN_EPS = 1e-6  # floor applied to variances before taking square roots in AdaIN


def _broadcast_mask(mask, hidden_states):
    """Resize ``mask`` to the spatial size of ``hidden_states`` and tile it to its full shape.

    ``mask`` is expected to be a ``(1, 1, H, W)`` float tensor: the ``repeat`` below only
    yields a ``hidden_states``-shaped result when its batch and channel dims are 1.
    Resizing with ``size=`` (instead of a scale factor derived from the height alone)
    keeps the width right for feature maps whose sides were rounded differently by the
    UNet's down/upsampling.

    Returns:
        A boolean tensor with the same shape as ``hidden_states``.
    """
    resized = F.interpolate(
        mask.to(hidden_states.device), size=tuple(hidden_states.shape[-2:])
    )
    return resized.repeat(
        hidden_states.shape[0], hidden_states.shape[1], 1, 1
    ).bool()


def _masked_var_mean(hidden_states, mask):
    """Gather the masked entries of a ``(B, C, h, w)`` map and their per-channel stats.

    Args:
        hidden_states: 4-D feature map.
        mask: boolean tensor shaped like ``hidden_states`` (see ``_broadcast_mask``).

    Returns:
        ``(values, var, mean)``: ``values`` is a detached copy of the masked entries
        reshaped to ``(B, C, K, 1)``; ``var``/``mean`` are ``(B, C, 1, 1)`` population
        statistics over those ``K`` positions.
    """
    batch, channels = hidden_states.shape[:2]
    values = hidden_states[mask].detach().clone().view(batch, channels, -1, 1)
    var, mean = torch.var_mean(values, dim=(2, 3), keepdim=True, correction=0)
    return values, var, mean


def _record_masked_stats(module, hidden_states, mean_bank, var_bank):
    """Write mode: append the ``ref_mask``-ed channel statistics to the given banks."""
    if module.gn_auto_machine_weight >= module.gn_weight:
        mask = _broadcast_mask(module.ref_mask, hidden_states)
        _, var, mean = _masked_var_mean(hidden_states, mask)
        mean_bank.append(mean)
        var_bank.append(var)


def _apply_masked_adain(module, hidden_states, mean_acc, var_acc, eps=_ADAIN_EPS):
    """Read mode: re-normalise the ``inpaint_mask``-ed region of ``hidden_states`` in place.

    The masked values are whitened with their own statistics and re-coloured with the
    reference statistics ``mean_acc`` / ``var_acc`` (AdaIN). When classifier-free guidance
    is on and ``style_fidelity > 0``, the batch entries selected by ``uc_mask`` become
    ``style_fidelity * original + (1 - style_fidelity) * normalised``.
    """
    mask = _broadcast_mask(module.inpaint_mask, hidden_states)
    values, var, mean = _masked_var_mean(hidden_states, mask)
    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
    values_uc = (((values - mean) / std) * std_acc) + mean_acc
    values_c = values_uc.clone()
    if module.do_classifier_free_guidance and module.style_fidelity > 0:
        values_c[module.uc_mask] = values[module.uc_mask]
    blended = (
        module.style_fidelity * values_c
        + (1.0 - module.style_fidelity) * values_uc
    )
    hidden_states[mask] = blended.view(-1)


def _reference_adain_step(module, hidden_states, mean_bank, var_bank, index):
    """Run the write/read AdaIN hook on the output of a block's ``index``-th layer."""
    if module.MODE == "write":
        _record_masked_stats(module, hidden_states, mean_bank, var_bank)
    if module.MODE == "read":
        if (
            module.gn_auto_machine_weight >= module.gn_weight
            and len(mean_bank) > 0
            and len(var_bank) > 0
        ):
            # Each bank entry is a (B, C, 1, 1) tensor; ``sum`` iterates its first
            # dimension, so this averages the reference statistics over the batch.
            mean_acc = sum(mean_bank[index]) / float(len(mean_bank[index]))
            var_acc = sum(var_bank[index]) / float(len(var_bank[index]))
            _apply_masked_adain(module, hidden_states, mean_acc, var_acc)


def _masked_reference_tokens(ref_mask, norm_hidden_states, ref_scale=1.0):
    """Select the self-attention tokens of a reference pass that fall inside ``ref_mask``.

    ``norm_hidden_states`` is treated as ``(batch, tokens, channels)``. The token grid's
    height and width are inferred by shrinking ``ref_mask`` by
    ``sqrt(mask_area / tokens)``; NOTE(review): this assumes the shrunken mask has exactly
    ``tokens`` cells, which may not hold for sizes that are not multiples of 64 - confirm.

    Returns:
        ``(batch, kept_tokens, channels)`` detached tensor of the masked tokens.
    """
    batch, tokens = norm_hidden_states.shape[:2]
    scale_ratio = ((ref_mask.shape[2] * ref_mask.shape[3]) / tokens) ** 0.5
    mask = F.interpolate(
        ref_mask.to(norm_hidden_states.device), scale_factor=1 / scale_ratio
    )
    grid = norm_hidden_states.view(
        batch, mask.shape[2], mask.shape[3], -1
    ).permute(0, 3, 1, 2)
    # With the default ref_scale of 1.0 these two resizes keep the sizes unchanged;
    # the knob is kept from the original implementation.
    grid = F.interpolate(grid, scale_factor=ref_scale, mode="bilinear")
    mask = F.interpolate(mask, scale_factor=ref_scale)
    mask = mask.repeat(grid.shape[0], grid.shape[1], 1, 1).bool()
    kept = grid[mask].detach().clone().view(grid.shape[0], grid.shape[1], -1)
    return kept.permute(0, 2, 1)


class StableDiffusionReferencePipeline:
    """Reference-only control helpers: attention injection and masked AdaIN.

    The methods read ``self.vae``, ``self.attention_auto_machine_weight``,
    ``self.gn_auto_machine_weight``, ``self.do_classifier_free_guidance``,
    ``self.uc_mask``, ``self.style_fidelity``, ``self.ref_mask`` and
    ``self.inpaint_mask`` but never set them; presumably a concrete pipeline that
    inherits from this class provides them (TODO confirm).

    Patched modules have two modes, switched with ``change_module_mode``:
    in "write" mode they record ref-masked features/statistics, and in "read"
    mode they consume (and then clear) what was recorded.
    """

    def prepare_ref_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        """Convert the reference image(s) into a tensor batch.

        Args:
            image: a PIL image, a list of PIL images, a tensor, or a list of tensors.
                PIL input is converted to RGB, resized to ``(width, height)`` with
                Lanczos resampling, scaled to ``[-1, 1]`` and laid out as NCHW.
                A list of tensors is concatenated along dim 0; a tensor is used as-is.
            width, height: target size for PIL input.
            batch_size: repeat factor when the batch holds a single image.
            num_images_per_prompt: repeat factor when the batch holds several images.
            device, dtype: placement of the returned tensor.
            do_classifier_free_guidance, guess_mode: when CFG is on and guess mode is
                off, the whole batch is duplicated along dim 0.

        Returns:
            The prepared image tensor.
        """
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]

            if isinstance(image[0], PIL.Image.Image):
                arrays = []
                for pil_image in image:
                    pil_image = pil_image.convert("RGB").resize(
                        (width, height), resample=PIL_INTERPOLATION["lanczos"]
                    )
                    arrays.append(np.array(pil_image)[None, :])
                image = np.concatenate(arrays, axis=0)
                image = np.array(image).astype(np.float32) / 255.0
                image = (image - 0.5) / 0.5
                image = torch.from_numpy(image.transpose(0, 3, 1, 2))
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)

        # A single image is tiled to the full batch; otherwise the image batch is
        # assumed to match the prompt batch and is repeated per prompt.
        repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)
        return image

    def prepare_ref_latents(
        self,
        refimage,
        batch_size,
        dtype,
        device,
        generator,
        do_classifier_free_guidance,
    ):
        """Encode the reference image with ``self.vae`` into scaled latents.

        Args:
            refimage: image tensor (e.g. from ``prepare_ref_image``).
            batch_size: number of latents required; fewer latents are tiled up to it.
            dtype, device: placement of the returned latents.
            generator: a single generator, or a list with one generator per sample.
            do_classifier_free_guidance: duplicate the latents along dim 0 when true.

        Returns:
            Latents multiplied by ``self.vae.config.scaling_factor``.

        Raises:
            ValueError: if ``batch_size`` is not a multiple of the encoded batch.
        """
        refimage = refimage.to(device=device, dtype=dtype)

        # Encode the reference image into latent space.
        if isinstance(generator, list):
            # NOTE(review): indexes refimage up to batch_size; assumes refimage holds
            # at least batch_size images when a generator list is passed.
            ref_image_latents = torch.cat(
                [
                    self.vae.encode(refimage[i : i + 1]).latent_dist.sample(
                        generator=generator[i]
                    )
                    for i in range(batch_size)
                ],
                dim=0,
            )
        else:
            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(
                generator=generator
            )
        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents

        # Tile the latents up to batch_size (using an mps-friendly repeat).
        if ref_image_latents.shape[0] < batch_size:
            if not batch_size % ref_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            ref_image_latents = ref_image_latents.repeat(
                batch_size // ref_image_latents.shape[0], 1, 1, 1
            )

        if do_classifier_free_guidance:
            ref_image_latents = torch.cat([ref_image_latents] * 2)

        # Align the device to avoid errors when concatenating with the model input.
        return ref_image_latents.to(device=device, dtype=dtype)

    def check_ref_input(self, reference_attn, reference_adain):
        """Require at least one reference mechanism to be enabled.

        Raises:
            AssertionError: if both flags are false. An explicit raise (rather than
                ``assert``) keeps the check active under ``python -O`` while keeping
                the exception type callers may already handle.
        """
        if not (reference_attn or reference_adain):
            raise AssertionError("`reference_attn` or `reference_adain` must be True.")

    def redefine_ref_model(
        self, model, reference_attn, reference_adain, model_type="unet"
    ):
        """Monkey-patch ``model`` so it can record ("write") and inject ("read") references.

        Args:
            model: the UNet (``model_type="unet"``), or an object whose ``nets[-1]`` is
                the ControlNet to patch (``model_type="controlnet"``).
            reference_attn: patch every ``BasicTransformerBlock`` so its self-attention
                also attends to ref-masked tokens recorded in write mode.
            reference_adain: (UNet only) patch the mid/down/up blocks so the
                inpaint-masked region of their features is re-normalised to the
                ref-masked statistics recorded in write mode.
            model_type: ``"unet"`` or ``"controlnet"``.

        Returns:
            ``(attn_modules, gn_modules)``: the patched modules, or ``None`` for a
            mechanism that is off; pass them to ``change_module_mode``.

        Raises:
            ValueError: if ``model_type`` is not ``"unet"`` or ``"controlnet"``.
        """
        pipeline = self

        def share_settings(module):
            """Copy the pipeline's reference-control settings onto a patched module."""
            module.attention_auto_machine_weight = (
                pipeline.attention_auto_machine_weight
            )
            module.gn_auto_machine_weight = pipeline.gn_auto_machine_weight
            module.do_classifier_free_guidance = pipeline.do_classifier_free_guidance
            module.uc_mask = pipeline.uc_mask
            module.style_fidelity = pipeline.style_fidelity
            module.ref_mask = pipeline.ref_mask

        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            class_labels: Optional[torch.LongTensor] = None,
        ):
            """BasicTransformerBlock forward with reference-aware self-attention."""
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(
                    hidden_states,
                    timestep,
                    class_labels,
                    hidden_dtype=hidden_states.dtype,
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # 1. Self-Attention
            cross_attention_kwargs = (
                cross_attention_kwargs if cross_attention_kwargs is not None else {}
            )
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                use_reference = self.attention_auto_machine_weight > self.attn_weight
                if self.MODE == "write" and use_reference:
                    # Record the ref-masked tokens; the attention itself is unchanged.
                    self.bank.append(
                        _masked_reference_tokens(self.ref_mask, norm_hidden_states)
                    )
                if self.MODE == "read" and use_reference:
                    # Attend to the current tokens plus the recorded reference tokens.
                    # NOTE: attention_mask is not forwarded here (as in the original).
                    ref_hidden_states = torch.cat(
                        [norm_hidden_states] + self.bank, dim=1
                    )
                    attn_output_uc = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=ref_hidden_states,
                        **cross_attention_kwargs,
                    )
                    attn_output_c = attn_output_uc.clone()
                    if self.do_classifier_free_guidance and self.style_fidelity > 0:
                        attn_output_c[self.uc_mask] = self.attn1(
                            norm_hidden_states[self.uc_mask],
                            encoder_hidden_states=norm_hidden_states[self.uc_mask],
                            **cross_attention_kwargs,
                        )
                    attn_output = (
                        self.style_fidelity * attn_output_c
                        + (1.0 - self.style_fidelity) * attn_output_uc
                    )
                    self.bank.clear()
                else:
                    # Plain self-attention (write mode, reference disabled for this
                    # block, or any other mode).
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )

            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep)
                    if self.use_ada_layer_norm
                    else self.norm2(hidden_states)
                )
                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)
            if self.use_ada_layer_norm_zero:
                norm_hidden_states = (
                    norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
                )
            ff_output = self.ff(norm_hidden_states)
            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output
            hidden_states = ff_output + hidden_states
            return hidden_states

        def hacked_mid_forward(self, *args, **kwargs):
            """Mid-block forward followed by the masked AdaIN hook."""
            x = self.original_forward(*args, **kwargs)
            if self.MODE == "write":
                _record_masked_stats(self, x, self.mean_bank, self.var_bank)
            if self.MODE == "read":
                if (
                    self.gn_auto_machine_weight >= self.gn_weight
                    and len(self.mean_bank) > 0
                    and len(self.var_bank) > 0
                ):
                    # Unlike the per-layer block hooks, the mid block averages over
                    # all recorded entries (one per write-mode forward).
                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
                    var_acc = sum(self.var_bank) / float(len(self.var_bank))
                    _apply_masked_adain(self, x, mean_acc, var_acc)
                self.mean_bank = []
                self.var_bank = []
            return x

        def hack_CrossAttnDownBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            """CrossAttnDownBlock2D forward with AdaIN hooks after each resnet and attention."""
            # attention_mask / encoder_attention_mask are accepted for signature
            # compatibility but are not forwarded to the attention layers.
            output_states = ()
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                hidden_states = resnet(hidden_states, temb)
                _reference_adain_step(
                    self, hidden_states, self.mean_bank0, self.var_bank0, i
                )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                _reference_adain_step(
                    self, hidden_states, self.mean_bank, self.var_bank, i
                )
                output_states = output_states + (hidden_states,)

            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
            """DownBlock2D forward with an AdaIN hook after each resnet."""
            output_states = ()
            for i, resnet in enumerate(self.resnets):
                hidden_states = resnet(hidden_states, temb)
                _reference_adain_step(
                    self, hidden_states, self.mean_bank, self.var_bank, i
                )
                output_states = output_states + (hidden_states,)

            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_CrossAttnUpBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            """CrossAttnUpBlock2D forward with AdaIN hooks.

            Not registered by ``hook_adain`` below (that branch is disabled), but kept
            available for experimentation.
            """
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                # Pop the matching skip connection.
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)
                _reference_adain_step(
                    self, hidden_states, self.mean_bank0, self.var_bank0, i
                )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                _reference_adain_step(
                    self, hidden_states, self.mean_bank, self.var_bank, i
                )

            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        def hacked_UpBlock2D_forward(
            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None
        ):
            """UpBlock2D forward with an AdaIN hook after each resnet."""
            for i, resnet in enumerate(self.resnets):
                # Pop the matching skip connection.
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)
                _reference_adain_step(
                    self, hidden_states, self.mean_bank, self.var_bank, i
                )

            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        def hook_attention(target, weight_of):
            """Patch every BasicTransformerBlock of ``target``; return them widest first.

            ``weight_of(i, n)`` gives the ``attn_weight`` of the i-th of n blocks; a block
            uses the reference only while ``attention_auto_machine_weight > attn_weight``.
            """
            blocks = [
                module
                for module in torch_dfs(target)
                if isinstance(module, BasicTransformerBlock)
            ]
            blocks = sorted(blocks, key=lambda x: -x.norm1.normalized_shape[0])
            for i, module in enumerate(blocks):
                # Keep the true original forward if this model was patched before.
                if getattr(module, "_original_inner_forward", None) is None:
                    module._original_inner_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(
                    module, BasicTransformerBlock
                )
                module.bank = []
                module.attn_weight = weight_of(i, len(blocks))
                share_settings(module)
            return blocks

        def hook_adain(unet):
            """Patch the mid/down/up blocks for masked AdaIN; return them, mid block first."""
            # A block participates only while gn_auto_machine_weight >= its gn_weight:
            # mid block 0, down blocks 1.0 down to 1/n, up blocks 0 up to (n-1)/n.
            unet.mid_block.gn_weight = 0
            blocks = [unet.mid_block]
            down_blocks = unet.down_blocks
            for w, block in enumerate(down_blocks):
                block.gn_weight = 1.0 - float(w) / float(len(down_blocks))
                blocks.append(block)
            up_blocks = unet.up_blocks
            for w, block in enumerate(up_blocks):
                block.gn_weight = float(w) / float(len(up_blocks))
                blocks.append(block)

            for i, block in enumerate(blocks):
                if getattr(block, "original_forward", None) is None:
                    block.original_forward = block.forward
                if i == 0:
                    block.forward = hacked_mid_forward.__get__(block, torch.nn.Module)
                elif isinstance(block, CrossAttnDownBlock2D):
                    block.forward = hack_CrossAttnDownBlock2D_forward.__get__(
                        block, CrossAttnDownBlock2D
                    )
                elif isinstance(block, DownBlock2D):
                    block.forward = hacked_DownBlock2D_forward.__get__(
                        block, DownBlock2D
                    )
                elif isinstance(block, UpBlock2D):
                    block.forward = hacked_UpBlock2D_forward.__get__(block, UpBlock2D)
                # CrossAttnUpBlock2D deliberately keeps its original forward.
                block.mean_bank0 = []
                block.var_bank0 = []
                block.mean_bank = []
                block.var_bank = []
                share_settings(block)
                block.inpaint_mask = pipeline.inpaint_mask
            return blocks

        if model_type == "unet":
            attn_modules = (
                hook_attention(model, lambda i, n: float(i) / float(n))
                if reference_attn
                else None
            )
            gn_modules = hook_adain(model) if reference_adain else None
        elif model_type == "controlnet":
            # Only hack the inpainting ControlNet (the last entry of ``nets``);
            # AdaIN is never applied to the ControlNet.
            target = model.nets[-1]
            attn_modules = (
                hook_attention(target, lambda i, n: 0.0) if reference_attn else None
            )
            gn_modules = None
        else:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected 'unet' or 'controlnet'."
            )
        return attn_modules, gn_modules

    def change_module_mode(self, mode, attn_modules, gn_modules):
        """Set ``MODE`` (e.g. "write" or "read") on every patched module.

        Either module list may be ``None`` (mechanism disabled) and is then skipped.
        """
        for modules in (attn_modules, gn_modules):
            if modules is not None:
                for module in modules:
                    module.MODE = mode