""" | |
ORIGINAL SOURCE CODE | |
https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 | |
""" | |
import copy
import math

import numpy as np
import torch


def gaussian_2d(x=0, y=0, mx=0, my=0, sx=1, sy=1):
    """2D Gaussian weight.

    Args:
        x, y: sample coordinates (scalars or tensors of the same shape).
        mx, my: mean (center) of the Gaussian along each axis.
        sx, sy: standard deviation along each axis.
    """
    return (
        1
        / (2 * math.pi * sx * sy)
        * torch.exp(-((x - mx) ** 2 / (2 * sx ** 2) + (y - my) ** 2 / (2 * sy ** 2)))
    )
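
# Usage sketch for gaussian_2d (illustrative values only; this mirrors how it is
# used inside new_attention when use_add_noise is enabled):
#
#   h, w = 16, 16
#   xs = torch.linspace(0, h, h)
#   ys = torch.linspace(0, w, w)
#   xx, yy = torch.meshgrid(xs, ys, indexing="ij")
#   weights = gaussian_2d(xx, yy, mx=h // 2, my=w // 2, sx=h / 2.0, sy=w / 2.0)
#   # weights has shape (h, w) and peaks at the region center.

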
def init_attention_edit(
    unet, tokens=[], rios=[], num_trailing_attn=[], length_prompt=0, noise_scale=[],
):
    """Initialize overwriting of the UNet cross-attention modules.

    Args:
        unet: the diffusers UNet whose CrossAttention modules get patched.
        tokens: per-region lists of prompt token indices to steer.
        rios: per-region boxes [left, right, top, bottom], given as fractions
            of the attention-map size.
        num_trailing_attn: per-region number of trailing tokens to include.
        length_prompt: length of the prompt in tokens.
        noise_scale: per-region scale for the added Gaussian noise.
    """
    rios_ = copy.deepcopy(rios)
    tokens_ = copy.deepcopy(tokens)
    num_regions = len(rios)
    num_trailing_attn_ = copy.deepcopy(num_trailing_attn)
    noise_scale_ = copy.deepcopy(noise_scale)
    for i in range(num_regions):
        # Extend each region's token list with trailing tokens right after the
        # prompt (capped at index 76, the last position of CLIP's 77-token sequence).
        total_unrelated = length_prompt + 1 + num_trailing_attn[i]
        for j in range(length_prompt + 1, total_unrelated):
            if j > 76:
                break
            tokens_[i].append(j)
        tokens_[i] = list(set(tokens_[i]))

    def new_attention(self, query, key, value):
        # Replacement for CrossAttention._attention: computes the usual attention
        # weights, then restricts the chosen tokens' attention to their regions.
        attn_slice = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        attn_slice = attn_slice.softmax(dim=-1)
        # Only edit cross-attention maps (77 = CLIP text sequence length).
        if self.use_edited_attn and attn_slice.size()[2] == 77:
            in_tokens = tokens_[:]
            rios = rios_[:]
            num_trailing_attn = num_trailing_attn_[:]
            dim = int(np.sqrt(attn_slice.size()[1]))
            # Reshape to (heads, height, width, tokens); 8 heads are assumed here.
            attn_slice = attn_slice.view(8, dim, dim, 77)
            # Mask collecting, for every steered token, the region it may attend to.
            global_mask = torch.zeros_like(attn_slice, dtype=torch.bool)
            for i in range(num_regions):
                # Region bounds in attention-map coordinates.
                left = int(dim * rios[i][0])
                right = int(dim * rios[i][1])
                top = int(dim * rios[i][2])
                bottom = int(dim * rios[i][3])
                # Add noise inside the region to avoid nullity.
                if self.use_add_noise:
                    tmp = attn_slice[
                        :, top:bottom, left:right, in_tokens[i]
                    ].clone() * (noise_scale_[i] * self.noise_level)
                    # TODO: try a uniform distribution over the region instead.
                    h = tmp.shape[1]
                    w = tmp.shape[2]
                    x = torch.linspace(0, h, h)
                    y = torch.linspace(0, w, w)
                    x, y = torch.meshgrid(x, y, indexing="ij")
                    # Gaussian weight centered on the region, decaying towards its borders.
                    noise_g = gaussian_2d(
                        x,
                        y,
                        mx=int(h / 2),
                        my=int(w / 2),
                        sx=float(h) / 2.0,
                        sy=float(w) / 2.0,
                    )
                    noise = (
                        noise_g.unsqueeze(0)
                        .unsqueeze(-1)
                        .repeat(tmp.shape[0], 1, 1, tmp.shape[-1])
                        .to(attn_slice.device)
                    )
                    attn_slice[:, top:bottom, left:right, in_tokens[i]] = tmp + noise
                # True only inside this region's box for the region's tokens.
                mask = torch.ones_like(attn_slice, dtype=torch.bool)
                mask[:, :, right:, in_tokens[i]] = False
                mask[:, :, :left, in_tokens[i]] = False
                mask[:, :top, :, in_tokens[i]] = False
                mask[:, bottom:, :, in_tokens[i]] = False
                global_mask[..., in_tokens[i]] |= mask[..., in_tokens[i]]
                mask = torch.zeros_like(attn_slice, dtype=torch.bool)
                mask[:, top:bottom, :, in_tokens[i]] = True
                mask[:, :, left:right, in_tokens[i]] = True
                global_mask[..., in_tokens[i]] &= mask[..., in_tokens[i]]
            # Keep full weight inside each region, downweight everything else to 0.01.
            zeros_indices = torch.where(~global_mask)
            global_mask = global_mask.clone().detach().half()
            global_mask[zeros_indices] = 0.01
            attn_slice *= global_mask
            attn_slice = attn_slice.view(8, dim * dim, 77)
            # Apply the edit only once; the flags must be re-armed by the caller.
            self.use_add_noise = False
            self.use_edited_attn = False
            self.attn = attn_slice.clone()
        hidden_states = torch.matmul(attn_slice, value)
        # Reshape back to (batch, seq_len, dim), e.g. torch.Size([1, 4096, 320]).
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

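    # Flags read by new_attention (expected to be set externally, e.g. from the
    # sampling loop): use_edited_attn applies the region masking once and then
    # auto-resets; use_add_noise adds the centered Gaussian bump inside each
    # region; noise_level multiplies noise_scale when rescaling that region.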
    # Patch every CrossAttention module in the UNet with the edited attention.
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention":
            module.last_attn_slice = None
            module.use_last_attn_slice = False
            module.use_last_attn_weights = False
            module.use_edited_attn = False
            module.save_last_attn_slice = False
            module.use_add_noise = False
            module.noise_level = 0.0
            module.attn = None
            # Bind new_attention as this module's _attention method.
            module._attention = new_attention.__get__(module, type(module))
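

# Minimal usage sketch (assumptions, not taken from this repo: a Stable Diffusion
# v1-style diffusers pipeline from around the commit referenced above, whose UNet
# still exposes CrossAttention modules with `_attention` and
# `reshape_batch_dim_to_heads`; the model id, prompt length, and token indices
# below are purely illustrative):
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   init_attention_edit(
#       pipe.unet,
#       tokens=[[2, 3]],               # prompt token indices to steer into the region
#       rios=[[0.0, 0.5, 0.0, 0.5]],   # [left, right, top, bottom] as fractions
#       num_trailing_attn=[10],
#       length_prompt=5,
#       noise_scale=[0.05],
#   )
#   # Before a denoising step, arm the edited attention on every patched module:
#   for module in pipe.unet.modules():
#       if type(module).__name__ == "CrossAttention":
#           module.use_edited_attn = True
#           module.use_add_noise = True
#           module.noise_level = 1.0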