"""
ORIGINAL SOURCE CODE
https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
"""
import copy
import math

import numpy as np
import torch
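
# The helpers below patch the cross-attention modules of a Stable Diffusion UNet
# so that the attention of selected prompt tokens is confined to user-specified
# regions of interest, optionally seeded with a 2D Gaussian bump of extra
# attention weight (see ``init_attention_edit`` and the nested ``new_attention``).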


def gaussian_2d(x=0, y=0, mx=0, my=0, sx=1, sy=1):
    """2D Gaussian weight.

    Args:
        x, y (Tensor or float): sample coordinates
        mx, my (float): mean along x and y
        sx, sy (float): standard deviation along x and y
    """
    return (
        1
        / (2 * math.pi * sx * sy)
        * torch.exp(-((x - mx) ** 2 / (2 * sx ** 2) + (y - my) ** 2 / (2 * sy ** 2)))
    )
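
# Illustrative example (not part of the original module): evaluated on a meshgrid,
# gaussian_2d yields a 2D bump peaked at (mx, my); new_attention below uses such a
# bump as the extra attention weight placed inside a region, e.g.:
#   xs, ys = torch.meshgrid(torch.linspace(0, 8, 8), torch.linspace(0, 8, 8), indexing="ij")
#   bump = gaussian_2d(xs, ys, mx=4, my=4, sx=4.0, sy=4.0)  # shape (8, 8)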


def init_attention_edit(
    unet, tokens=[], rios=[], num_trailing_attn=[], length_prompt=0, noise_scale=[],
):
    """Initialize overwriting of the UNet's cross-attention modules.

    Args:
        unet: diffusers UNet whose ``CrossAttention`` modules get patched
        tokens: per-region lists of prompt token indices to steer
        rios: per-region bounds as (left, right, top, bottom) fractions in [0, 1]
        num_trailing_attn: per-region number of trailing (padding) tokens to include
        length_prompt: number of tokens in the prompt
        noise_scale: per-region multiplier applied to the region's attention before
            the Gaussian bump is added
    """
    rios_ = copy.deepcopy(rios)
    tokens_ = copy.deepcopy(tokens)
    num_regions = len(rios)
    num_trailing_attn_ = copy.deepcopy(num_trailing_attn)
    noise_scale_ = copy.deepcopy(noise_scale)
    # Append the indices of the trailing tokens that follow the prompt to each
    # region's token list, capped at index 76 (CLIP provides 77 token positions).
    for i in range(num_regions):
        total_unrelated = length_prompt + 1 + num_trailing_attn[i]
        for j in range(length_prompt + 1, total_unrelated):  # trailing tokens
            if j > 76:
                break
            tokens_[i].append(j)
        tokens_[i] = list(set(tokens_[i]))
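
    # Illustration (hypothetical numbers): with length_prompt=5 and
    # num_trailing_attn=[3], token indices 6, 7 and 8 (the padding positions right
    # after the prompt) are appended to the first region's token list.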

    def new_attention(self, query, key, value):
        attn_slice = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        attn_slice = attn_slice.softmax(dim=-1)
        # Only edit cross-attention maps (77 text-token columns), and only while
        # the one-shot ``use_edited_attn`` flag is set on the module.
        if self.use_edited_attn and attn_slice.size()[2] == 77:
            in_tokens = tokens_[:]
            rios = rios_[:]
            num_trailing_attn = num_trailing_attn_[:]
            # Unflatten the spatial dimension back into (height, width); the
            # leading 8 is the number of attention heads.
            dim = int(np.sqrt(attn_slice.size()[1]))
            attn_slice = attn_slice.view(8, dim, dim, 77)
            # Boolean mask accumulating, for every token of interest, where its
            # attention is allowed to stay.
            global_mask = torch.zeros_like(attn_slice, dtype=torch.bool)
            for i in range(num_regions):
                # Region bounds in attention-map coordinates.
                left = int(dim * rios[i][0])
                right = int(dim * rios[i][1])
                top = int(dim * rios[i][2])
                bottom = int(dim * rios[i][3])
                # Optionally add a centered Gaussian bump of attention inside the
                # region so the edited tokens do not start from a null response.
                if self.use_add_noise:
                    tmp = attn_slice[
                        :, top:bottom, left:right, in_tokens[i]
                    ].clone() * (noise_scale_[i] * self.noise_level)
                    # TODO: use a uniform distribution over the region
                    w = tmp.shape[2]
                    h = tmp.shape[1]
                    x = torch.linspace(0, h, h)
                    y = torch.linspace(0, w, w)
                    x, y = torch.meshgrid(x, y, indexing="ij")
                    noise = gaussian_2d(
                        x,
                        y,
                        mx=int(h / 2),
                        my=int(w / 2),
                        sx=float(h) / 2.0,
                        sy=float(w) / 2.0,
                    )
                    # Broadcast the (h, w) bump over heads and tokens.
                    noise = (
                        noise.unsqueeze(0)
                        .unsqueeze(-1)
                        .repeat(tmp.shape[0], 1, 1, tmp.shape[-1])
                        .to(attn_slice.device)
                    )
                    attn_slice[:, top:bottom, left:right, in_tokens[i]] = tmp + noise
                # First pass: mark everything outside the box as False for the
                # region's tokens, then OR the box into the global mask.
                mask = torch.ones_like(attn_slice, dtype=torch.bool)
                mask[:, :, right:, in_tokens[i]] = False
                mask[:, :, :left, in_tokens[i]] = False
                mask[:, :top, :, in_tokens[i]] = False
                mask[:, bottom:, :, in_tokens[i]] = False
                global_mask[..., in_tokens[i]] |= mask[..., in_tokens[i]]
                # Second pass: AND with the union of the box's row and column
                # bands, which trims the mask back around the current region.
                mask = torch.zeros_like(attn_slice, dtype=torch.bool)
                mask[:, top:bottom, :, in_tokens[i]] = True
                mask[:, :, left:right, in_tokens[i]] = True
                global_mask[..., in_tokens[i]] &= mask[..., in_tokens[i]]
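
            # Illustration (hypothetical numbers): for a 16x16 attention map
            # (dim=16) and rios[i] = (0.0, 0.5, 0.25, 0.75), the box is
            # left=0, right=8, top=4, bottom=12, so global_mask ends up True on
            # that 8x8 window for the region's tokens and False elsewhere.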
            # Keep the attention values inside the allowed windows and strongly
            # attenuate them (factor 0.01) everywhere else.
            zeros_indices = torch.where(~global_mask)
            global_mask = global_mask.clone().detach().half()  # assumes fp16 attention
            global_mask[zeros_indices] = 0.01
            attn_slice *= global_mask
            attn_slice = attn_slice.view(8, dim * dim, 77)
            # The editing flags are one-shot: reset them after an edited pass.
            self.use_add_noise = False
            self.use_edited_attn = False
        self.attn = attn_slice.clone()
        hidden_states = torch.matmul(attn_slice, value)
        # reshape hidden_states, e.g., torch.Size([1, 4096, 320])
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states
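
    # Note (describing the code above): ``new_attention`` is meant to mirror the
    # signature of ``CrossAttention._attention`` in the pinned diffusers revision,
    # so binding it onto each module below swaps the attention computation in place.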
    # Attach the editing state to every CrossAttention module in the UNet and
    # override its ``_attention`` with the bound closure defined above.
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention":
            module.last_attn_slice = None
            module.use_last_attn_slice = False
            module.use_last_attn_weights = False
            module.use_edited_attn = False
            module.save_last_attn_slice = False
            module.use_add_noise = False
            module.noise_level = 0.0
            module.attn = None
            # ``__get__`` turns the function into a method bound to this module.
            module._attention = new_attention.__get__(module, type(module))
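

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes an older
    # diffusers release (around the revision pinned above) whose UNet still uses
    # ``CrossAttention`` modules with a ``_attention`` method; the model id,
    # token indices and region below are placeholders.
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    init_attention_edit(
        pipe.unet,
        tokens=[[2, 3]],               # prompt token indices to steer (hypothetical)
        rios=[(0.0, 0.5, 0.0, 0.5)],   # (left, right, top, bottom) as fractions
        num_trailing_attn=[10],
        length_prompt=6,
        noise_scale=[1.0],
    )
    # Before a denoising step that should be edited, a driver would arm the
    # one-shot flags on each patched module, for example:
    for _, m in pipe.unet.named_modules():
        if type(m).__name__ == "CrossAttention":
            m.use_edited_attn = True
            m.use_add_noise = True
            m.noise_level = 1.0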