# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torch.fft as fft
from torch import nn
from einops import rearrange, repeat
from torchvision.utils import save_image

from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_xformers_available

# Note: a local AttentionBase is defined below and shadows this import.
from .masactrl_utils import AttentionBase
# from masactrl.masactrl import MutualSelfAttentionControl

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
class AttentionBase:
    """Base attention editor: tracks the current denoising step and attention layer."""

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def after_step(self):
        pass

    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            # one full pass over all attention layers is finished
            self.after_step()
        return out

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
        return out

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0
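
# Sketch of how a patched self-attention block is expected to call an editor
# derived from AttentionBase. Illustrative only: the real hook is installed by
# masactrl_utils, and the names `editor`, `heads`, and `scale` below are
# assumptions, not part of this file.
#
#   sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale   # (b*h, n, n)
#   attn = sim.softmax(dim=-1)
#   out = editor(q, k, v, sim, attn, is_cross=False,
#                place_in_unet="up", num_heads=heads, scale=scale)
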
class MaskPromptedStyleAttentionControl(AttentionBase):
    def __init__(self, start_step=4, start_layer=10, style_attn_step=35, layer_idx=None, step_idx=None, total_steps=50, style_guidance=0.1,
                 only_masked_region=False, guidance=0.0,
                 style_mask=None, source_mask=None, de_bug=False):
        """
        Mask-prompted style attention control (MaskPromptedSAC).

        Args:
            start_step: the step at which to start mutual self-attention control
            start_layer: the U-Net layer at which to start mutual self-attention control
            style_attn_step: the step after which the target query is used for style attention
            layer_idx: list of layers to apply mutual self-attention control (overrides start_layer)
            step_idx: list of steps to apply mutual self-attention control (overrides start_step)
            total_steps: the total number of denoising steps
            style_guidance: scale for extrapolating the target output away from the content output
            only_masked_region: if True, restrict style attention to the masked region only
            guidance: stored for external use; not consumed by this controller
            style_mask: binary mask over the style image
            source_mask: binary mask over the source image
            de_bug: if True, drop into pdb inside the control branches
        """
        super().__init__()
        self.total_steps = total_steps
        self.total_layers = 16
        self.start_step = start_step
        self.start_layer = start_layer
        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        print("using MaskPromptStyleAttentionControl")
        print("MaskedSAC at denoising steps: ", self.step_idx)
        print("MaskedSAC at U-Net layers: ", self.layer_idx)

        self.de_bug = de_bug
        self.style_guidance = style_guidance
        self.only_masked_region = only_masked_region
        self.style_attn_step = style_attn_step
        self.self_attns = []
        self.cross_attns = []
        self.guidance = guidance
        self.style_mask = style_mask
        self.source_mask = source_mask

    def after_step(self):
        self.self_attns = []
        self.cross_attns = []
    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask, k_mask, **kwargs):
        B = q.shape[0] // num_heads
        H = W = int(np.sqrt(q.shape[1]))
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)

        # recompute the similarity over the rearranged heads
        sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
        if q_mask is not None:
            sim = sim.masked_fill(q_mask.unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
        if k_mask is not None:
            sim = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
        attn = sim.softmax(-1) if attn is None else attn

        if len(attn) == 2 * len(v):
            v = torch.cat([v] * 2)
        out = torch.einsum("h i j, h j d -> h i d", attn, v)
        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
        return out
    def attn_batch_fg_bg(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask, k_mask, **kwargs):
        B = q.shape[0] // num_heads
        H = W = int(np.sqrt(q.shape[1]))
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
        sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")

        # split the similarity into foreground and background branches;
        # when k_mask is given it overrides the q_mask split
        if q_mask is not None:
            sim_fg = sim.masked_fill(q_mask.unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
            sim_bg = sim.masked_fill(q_mask.unsqueeze(0) == 1, -torch.finfo(sim.dtype).max)
        if k_mask is not None:
            sim_fg = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
            sim_bg = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 1, -torch.finfo(sim.dtype).max)
        sim = torch.cat([sim_fg, sim_bg])
        attn = sim.softmax(-1)

        if len(attn) == 2 * len(v):
            v = torch.cat([v] * 2)
        out = torch.einsum("h i j, h j d -> h i d", attn, v)
        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
        return out
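
    # Shape note (derived from the code above): sim_fg and sim_bg are each
    # (h, B*n, B*n); concatenating them gives (2h, B*n, B*n), so the output batch
    # dimension is doubled and callers split it with .chunk(2) into the
    # foreground and background renderings.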
    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Attention forward function.
        """
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)

        B = q.shape[0] // num_heads // 2
        H = W = int(np.sqrt(q.shape[1]))

        if self.style_mask is not None and self.source_mask is not None:
            # mask = self.aggregate_cross_attn_map(idx=self.cur_token_idx)  # (4, H, W)
            height, width = self.style_mask.shape[-2:]
            mask_style = self.style_mask    # (H, W)
            mask_source = self.source_mask  # (H, W)
            scale = int(np.sqrt(height * width / q.shape[1]))
            # downsample the masks to the current attention resolution
            spatial_mask_source = F.interpolate(mask_source, (height // scale, width // scale)).reshape(-1, 1)
            spatial_mask_style = F.interpolate(mask_style, (height // scale, width // scale)).reshape(-1, 1)
        else:
            spatial_mask_source = None
            spatial_mask_style = None

        if spatial_mask_style is None or spatial_mask_source is None:
            out_s, out_c, out_t = self.style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
        else:
            if self.only_masked_region:
                out_s, out_c, out_t = self.mask_prompted_style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
            else:
                out_s, out_c, out_t = self.separate_mask_prompted_style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)

        out = torch.cat([out_s, out_c, out_t], dim=0)
        return out
    def style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
        if self.de_bug:
            import pdb; pdb.set_trace()

        qs, qc, qt = q.chunk(3)

        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)

        if self.cur_step < self.style_attn_step:
            out_t = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        else:
            out_t = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
            if self.style_guidance >= 0:
                out_t = out_c + (out_t - out_c) * self.style_guidance
        return out_s, out_c, out_t
    def mask_prompted_style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
        qs, qc, qt = q.chunk(3)

        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c = self.attn_batch(qc, k[num_heads:2 * num_heads], v[num_heads:2 * num_heads], sim[num_heads:2 * num_heads], attn[num_heads:2 * num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c_new = self.attn_batch(qc, k[num_heads:2 * num_heads], v[num_heads:2 * num_heads], sim[num_heads:2 * num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)

        if self.de_bug:
            import pdb; pdb.set_trace()

        if self.cur_step < self.style_attn_step:
            out_t = out_c  # self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
        else:
            out_t_fg = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c_fg = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            if self.style_guidance >= 0:
                out_t = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
            # keep the styled output inside the source mask and the content output outside it
            out_t = out_t * spatial_mask_source + out_c * (1 - spatial_mask_source)

        if self.de_bug:
            import pdb; pdb.set_trace()

        # print(torch.sum(out_t * (1 - spatial_mask_source) - out_c * (1 - spatial_mask_source)))
        return out_s, out_c, out_t
    def separate_mask_prompted_style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
        if self.de_bug:
            import pdb; pdb.set_trace()

        # To prevent query confusion, render foreground and background separately according to the mask.
        qs, qc, qt = q.chunk(3)
        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)

        if self.cur_step < self.style_attn_step:
            out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c_fg, out_c_bg = out_c.chunk(2)
            out_t = out_c_fg * spatial_mask_source + out_c_bg * (1 - spatial_mask_source)
        else:
            out_t = self.attn_batch_fg_bg(qt, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_t_fg, out_t_bg = out_t.chunk(2)
            out_c_fg, out_c_bg = out_c.chunk(2)
            if self.style_guidance >= 0:
                out_t_fg = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
                out_t_bg = out_c_bg + (out_t_bg - out_c_bg) * self.style_guidance
            out_t = out_t_fg * spatial_mask_source + out_t_bg * (1 - spatial_mask_source)

        return out_s, out_t, out_t
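

# ---------------------------------------------------------------------------
# Minimal usage sketch (comments only, since this module uses a relative import
# and is not meant to be run directly). The registration helper name and the
# mask resolution below are assumptions for illustration, not part of this file.
#
#   style_mask = torch.ones(1, 1, 512, 512)   # binary mask over the style image
#   source_mask = torch.ones(1, 1, 512, 512)  # binary mask over the source image
#   controller = MaskPromptedStyleAttentionControl(
#       start_step=4, start_layer=10, style_attn_step=35, total_steps=50,
#       style_guidance=1.2, only_masked_region=False,
#       style_mask=style_mask, source_mask=source_mask,
#   )
#   # hook the controller into a diffusers UNet, e.g. via a registration helper
#   # from masactrl_utils (hypothetical call):
#   # register_attention_editor(unet, controller)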