# DragDiffusion/utils/attn_utils.py
# (uploaded by peter850421 using huggingface_hub, commit e1ebf71)
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
class AttentionBase:
    """
    Base attention editor: plain scaled-dot-product attention plus bookkeeping of which
    attention call (layer) and which step is currently running.

    ``num_att_layers`` must be set externally to the number of attention calls that make up
    one step. It starts at -1, in which case the layer counter never wraps around.
    """

    def __init__(self):
        # counters advanced by __call__
        self.cur_att_layer = 0
        self.cur_step = 0
        # number of attention calls per step; -1 until configured by the caller
        self.num_att_layers = -1

    def after_step(self):
        """Hook run after the last attention layer of a step has been processed (no-op here)."""
        pass

    def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        """Compute attention via ``forward`` and advance the layer / step counters."""
        result = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer != self.num_att_layers:
            return result
        # every attention layer of this step has run: wrap around and begin the next step
        self.cur_att_layer = 0
        self.cur_step += 1
        self.after_step()
        return result

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Plain attention over ``(b, h, n, d)`` tensors, merged back to ``(b, n, h * d)``.

        ``is_cross``, ``place_in_unet``, ``num_heads`` and any extra keyword arguments
        (e.g. ``scale``) are not used here; scaled_dot_product_attention's default scaling
        applies.
        """
        return rearrange(
            F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False),
            'b h n d -> b n (h d)',
        )

    def reset(self):
        """Rewind the step and layer counters; ``num_att_layers`` is left untouched."""
        self.cur_step, self.cur_att_layer = 0, 0
class MutualSelfAttentionControl(AttentionBase):
    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5):
        """
        Mutual self-attention control for a Stable-Diffusion model.

        Args:
            start_step: first step of the control window; only used to build the default
                ``step_idx``.
            start_layer: first layer of the control window; only used to build the default
                ``layer_idx``.
            layer_idx: layer indices at which mutual self-attention is applied; defaults to
                ``list(range(start_layer, 16))``.
            step_idx: step indices at which mutual self-attention is applied; defaults to
                ``list(range(start_step, total_steps))``.
            total_steps: total number of steps.
            guidance_scale: when > 1.0 the batch is treated as carrying an unconditional half
                followed by a conditional half.
        """
        super().__init__()
        self.total_steps = total_steps
        self.start_step = start_step
        self.start_layer = start_layer
        # NOTE(review): 16 is presumably the number of controllable layers in the SD UNet — confirm
        if layer_idx is None:
            layer_idx = list(range(start_layer, 16))
        if step_idx is None:
            step_idx = list(range(start_step, total_steps))
        self.layer_idx = layer_idx
        self.step_idx = step_idx
        # decides whether the batch contains an unconditional branch
        self.guidance_scale = guidance_scale
        print("step_idx: ", self.step_idx)
        print("layer_idx: ", self.layer_idx)

    def _attend_to_source(self, q, k, v):
        """
        Let the first two samples' queries both attend to the keys/values of sample 0.

        The two query sets are concatenated along the token axis so a single
        scaled_dot_product_attention call serves both, then split back into a batch of two
        and merged to ``(b, n, h * d)``.
        """
        joint_q = torch.cat([q[0:1], q[1:2]], dim=2)
        attended = F.scaled_dot_product_attention(joint_q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False)
        # undo the token-axis concatenation: one batch entry per query set
        attended = torch.cat(attended.chunk(2, dim=2), dim=0)
        return rearrange(attended, 'b h n d -> b n (h d)')

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Attention forward: mutual self-attention inside the configured step/layer window,
        plain attention (the base class) everywhere else and for cross-attention.
        """
        # NOTE(review): `// 2` presumably maps alternating self/cross attention calls to one
        # index per block — confirm against the registration order
        in_window = (
            not is_cross
            and self.cur_step in self.step_idx
            and self.cur_att_layer // 2 in self.layer_idx
        )
        if not in_window:
            return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
        if self.guidance_scale > 1.0:
            # presumably batch = [uncond source, uncond target, cond source, cond target]
            uncond_out = self._attend_to_source(q[0:2], k[0:2], v[0:2])
            cond_out = self._attend_to_source(q[2:4], k[2:4], v[2:4])
            return torch.cat([uncond_out, cond_out], dim=0)
        return self._attend_to_source(q, k, v)
# forward function for default attention processor
# modified from __call__ function of AttnProcessor in diffusers
def override_attn_proc_forward(attn, editor, place_in_unet):
    """
    Build a replacement ``forward`` for an attention module that routes the attention
    computation through ``editor``.

    Args:
        attn: attention module; its ``to_q``/``to_k``/``to_v``/``to_out`` projections,
            ``heads`` and ``scale`` are used when the returned function runs.
        editor: callable invoked as
            ``editor(q, k, v, is_cross, place_in_unet, num_heads, scale=attn.scale)``
            that returns the merged attention output.
        place_in_unet: tag forwarded unchanged to ``editor``.

    Returns:
        A function usable as the module's ``forward``.
    """
    def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
        """
        Similar to the original LDM CrossAttention forward, except that the attention
        itself is delegated to ``editor``.

        ``encoder_hidden_states`` (when given) takes precedence over ``context``. Attention
        masks (``attention_mask`` / ``mask``) are accepted for signature compatibility but
        are not forwarded to the editor, so attention is computed unmasked.
        """
        if encoder_hidden_states is not None:
            context = encoder_hidden_states
        # When to_out is a ModuleList (presumably [Linear, Dropout] as in diffusers), only the
        # output projection is applied; the remaining entries are skipped.
        to_out = attn.to_out[0] if isinstance(attn.to_out, nn.ModuleList) else attn.to_out
        heads = attn.heads
        q = attn.to_q(x)
        is_cross = context is not None
        if not is_cross:
            context = x
        k = attn.to_k(context)
        v = attn.to_v(context)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=heads) for t in (q, k, v))
        # the only difference from the stock processor: the editor computes the attention
        out = editor(q, k, v, is_cross, place_in_unet, heads, scale=attn.scale)
        return to_out(out)
    return forward
# forward function for lora attention processor
# modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1
def override_lora_attn_proc_forward(attn, editor, place_in_unet):
    """
    Build a replacement ``forward`` for an attention module carrying LoRA layers, delegating
    the attention computation to ``editor``.

    Modified from ``LoRAAttnProcessor2_0.__call__`` in diffusers v0.17.1; the scaled
    dot-product attention call is replaced by ``editor``.

    Args:
        attn: attention module. Reads ``to_q``/``to_k``/``to_v``/``to_out``, ``heads``,
            ``scale``, ``group_norm``, ``norm_cross``, ``residual_connection`` and
            ``rescale_output_factor``, plus the LoRA layers ``attn.processor.to_q_lora``,
            ``to_k_lora``, ``to_v_lora`` and ``to_out_lora``.
        editor: callable invoked as
            ``editor(query, key, value, is_cross, place_in_unet, num_heads, scale=attn.scale)``
            that returns the merged attention output.
        place_in_unet: tag forwarded unchanged to ``editor``.

    Returns:
        ``forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0)``.
    """
    def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, lora_scale=1.0):
        # input as received, added back at the end when attn.residual_connection is set
        residual = hidden_states
        input_ndim = hidden_states.ndim
        # must be decided before encoder_hidden_states is defaulted to hidden_states below
        is_cross = encoder_hidden_states is not None
        if input_ndim == 4:
            # (batch, channel, height, width) -> (batch, height * width, channel)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
            # NOTE(review): the prepared mask is never passed to `editor` below, so it does
            # not affect the output
        if attn.group_norm is not None:
            # normalize over the channel axis, which group_norm expects in position 1
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        # base projection plus the LoRA delta, weighted by lora_scale
        query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)
        if encoder_hidden_states is None:
            # self-attention: keys/values come from the same (normalized) hidden states
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)
        # split heads: (b, n, h*d) -> (b, h, n, d)
        query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value))
        # the only difference
        hidden_states = editor(
            query, key, value, is_cross, place_in_unet,
            attn.heads, scale=attn.scale)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            # back to (batch, channel, height, width)
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
    return forward
def register_attention_editor_diffusers(model, editor: 'AttentionBase', attn_processor='attn_proc'):
    """
    Register an attention editor on a diffusers pipeline (approach adapted from
    Prompt-to-Prompt).

    Every module whose class is named ``Attention`` inside the ``down``/``mid``/``up``
    children of ``model.unet`` gets its ``forward`` replaced by one that routes attention
    through ``editor``. Modules nested inside an ``Attention`` module are not visited.

    Args:
        model: pipeline-like object exposing ``unet``.
        editor: attention editor; its ``num_att_layers`` is set to the number of patched
            modules.
        attn_processor: ``'attn_proc'`` for the default processor forward, or
            ``'lora_attn_proc'`` for the LoRA processor forward.

    Raises:
        NotImplementedError: if ``attn_processor`` is not one of the supported values.
    """
    # validate before touching any module so a bad argument never leaves the UNet half-patched
    if attn_processor not in ('attn_proc', 'lora_attn_proc'):
        raise NotImplementedError(f"attention processor '{attn_processor}' is not implemented")

    def make_forward(net, place_in_unet):
        # choose the forward factory matching the requested processor type
        if attn_processor == 'attn_proc':
            return override_attn_proc_forward(net, editor, place_in_unet)
        return override_lora_attn_proc_forward(net, editor, place_in_unet)

    def register_editor(net, count, place_in_unet):
        # depth-first search; returns `count` plus the number of Attention modules patched
        if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
            net.forward = make_forward(net, place_in_unet)
            return count + 1
        for subnet in net.children():
            count = register_editor(subnet, count, place_in_unet)
        return count

    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        # the first matching tag wins, mirroring a down/mid/up if-elif chain
        for place_in_unet in ("down", "mid", "up"):
            if place_in_unet in net_name:
                cross_att_count += register_editor(net, 0, place_in_unet)
                break
    editor.num_att_layers = cross_att_count