import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Union, Tuple, List, Callable, Dict

from torchvision.utils import save_image
from einops import rearrange, repeat


|
class AttentionBase:
    """Base class for attention editors. `__call__` is invoked by the patched
    attention forward of every registered layer; subclasses override `forward`
    to read or modify the attention maps."""

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def after_step(self):
        pass

    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            # all attention layers of this denoising step have been seen
            self.after_step()
        return out

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        # default behaviour: standard attention output with heads merged back
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
        return out

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0
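

# Illustrative sketch (not part of the original utilities): a minimal editor
# subclass showing the override contract defined by AttentionBase. The class
# name and the printed statistics are hypothetical; `forward` receives the
# per-head queries/keys/values, the raw similarity `sim`, the softmaxed map
# `attn`, an `is_cross` flag, the UNet location, and the head count.
class AttentionLogger(AttentionBase):
    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        kind = "cross" if is_cross else "self"
        print(f"step {self.cur_step} | layer {self.cur_att_layer} | {place_in_unet} | "
              f"{kind}-attn shape {tuple(attn.shape)}")
        # fall back to the default attention output (attn @ v with heads re-merged)
        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)

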
|
class AttentionStore(AttentionBase):
    """Accumulates self- and cross-attention maps over the denoising steps whose
    index lies in (min_step, max_step); `valid_steps` counts how many steps were
    accumulated."""

    def __init__(self, res=[32], min_step=0, max_step=1000):
        super().__init__()
        self.res = res
        self.min_step = min_step
        self.max_step = max_step
        self.valid_steps = 0

        self.self_attns = []
        self.cross_attns = []

        self.self_attns_step = []
        self.cross_attns_step = []

    def after_step(self):
        if self.min_step < self.cur_step < self.max_step:
            self.valid_steps += 1
            if len(self.self_attns) == 0:
                # copy (do not alias) the per-step buffers, otherwise the
                # clear() calls below would also wipe the accumulators
                self.self_attns = list(self.self_attns_step)
                self.cross_attns = list(self.cross_attns_step)
            else:
                for i in range(len(self.self_attns)):
                    self.self_attns[i] += self.self_attns_step[i]
                    self.cross_attns[i] += self.cross_attns_step[i]
        self.self_attns_step.clear()
        self.cross_attns_step.clear()

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        # only keep maps up to 64x64 spatial resolution to bound memory use
        if attn.shape[1] <= 64 ** 2:
            if is_cross:
                self.cross_attns_step.append(attn)
            else:
                self.self_attns_step.append(attn)
        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
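

# Hypothetical helper (an assumption, not part of the original code): average the
# cross-attention maps accumulated by an AttentionStore over steps, layers and
# heads at one spatial resolution. It assumes square latent maps and that every
# layer at that resolution uses the same number of heads (true for SD 1.x).
def aggregate_cross_attention(store: AttentionStore, res: int = 32, batch_size: int = 1):
    maps = [a for a in store.cross_attns if a.shape[1] == res ** 2]
    if len(maps) == 0 or store.valid_steps == 0:
        raise ValueError(f"no cross-attention maps stored at resolution {res}")
    # each stored tensor is a running sum over valid steps of shape
    # (batch * heads, res * res, num_tokens); undo the sum, then average layers/heads
    maps = torch.stack(maps, dim=0) / store.valid_steps
    maps = rearrange(maps, 'l (b h) hw t -> b (l h) hw t', b=batch_size).mean(dim=1)
    return rearrange(maps, 'b (ph pw) t -> b ph pw t', ph=res, pw=res)

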
|
def regiter_attention_editor_diffusers(model, editor: AttentionBase):
    """
    Register an attention editor with a Diffusers pipeline (adapted from [Prompt-to-Prompt]).
    """
    def ca_forward(self, place_in_unet):
        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
            """
            The attention computation follows the original LDM CrossAttention
            implementation, except that the output is routed through `editor`.
            """
            if encoder_hidden_states is not None:
                context = encoder_hidden_states
            if attention_mask is not None:
                mask = attention_mask

            # `to_out` may be an nn.ModuleList (Linear followed by Dropout);
            # use the Linear projection in that case
            to_out = self.to_out
            if isinstance(to_out, nn.ModuleList):
                to_out = to_out[0]

            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

            if mask is not None:
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                # (b, j) -> (b*h, 1, j); the singleton query axis broadcasts against sim
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)

            out = editor(
                q, k, v, sim, attn, is_cross, place_in_unet,
                self.heads, scale=self.scale)

            return to_out(out)

        return forward

    def register_editor(net, count, place_in_unet):
        for name, subnet in net.named_children():
            if net.__class__.__name__ == 'Attention':  # reached an attention module: patch its forward
                net.forward = ca_forward(net, place_in_unet)
                return count + 1
            elif hasattr(net, 'children'):
                count = register_editor(subnet, count, place_in_unet)
        return count

    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        if "down" in net_name:
            cross_att_count += register_editor(net, 0, "down")
        elif "mid" in net_name:
            cross_att_count += register_editor(net, 0, "mid")
        elif "up" in net_name:
            cross_att_count += register_editor(net, 0, "up")
    editor.num_att_layers = cross_att_count
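

def _example_register_diffusers_editor(model_path="runwayml/stable-diffusion-v1-5"):
    """Illustrative usage sketch, never called by this module. It assumes a
    Diffusers StableDiffusionPipeline whose UNet attention modules are named
    'Attention' (the class name checked by register_editor above); the model
    path and prompt are placeholders."""
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(model_path).to("cuda")
    editor = AttentionStore(res=[32])
    regiter_attention_editor_diffusers(pipe, editor)
    # registration sets num_att_layers to the number of patched attention modules
    assert editor.num_att_layers > 0
    image = pipe("a photo of a corgi", num_inference_steps=50).images[0]
    return image, editor

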
|
def regiter_attention_editor_ldm(model, editor: AttentionBase):
    """
    Register an attention editor with a CompVis Stable Diffusion (LDM) model (adapted from [Prompt-to-Prompt]).
    """
    def ca_forward(self, place_in_unet):
        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
            """
            The attention computation follows the original LDM CrossAttention
            implementation, except that the output is routed through `editor`.
            """
            if encoder_hidden_states is not None:
                context = encoder_hidden_states
            if attention_mask is not None:
                mask = attention_mask

            # `to_out` may be an nn.ModuleList (Linear followed by Dropout);
            # use the Linear projection in that case
            to_out = self.to_out
            if isinstance(to_out, nn.ModuleList):
                to_out = to_out[0]

            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

            if mask is not None:
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                # (b, j) -> (b*h, 1, j); the singleton query axis broadcasts against sim
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)

            out = editor(
                q, k, v, sim, attn, is_cross, place_in_unet,
                self.heads, scale=self.scale)

            return to_out(out)

        return forward

    def register_editor(net, count, place_in_unet):
        for name, subnet in net.named_children():
            if net.__class__.__name__ == 'CrossAttention':  # reached an attention module: patch its forward
                net.forward = ca_forward(net, place_in_unet)
                return count + 1
            elif hasattr(net, 'children'):
                count = register_editor(subnet, count, place_in_unet)
        return count

    cross_att_count = 0
    for net_name, net in model.model.diffusion_model.named_children():
        if "input" in net_name:
            cross_att_count += register_editor(net, 0, "input")
        elif "middle" in net_name:
            cross_att_count += register_editor(net, 0, "middle")
        elif "output" in net_name:
            cross_att_count += register_editor(net, 0, "output")
    editor.num_att_layers = cross_att_count
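

def _example_register_ldm_editor(model):
    """Illustrative usage sketch, never called by this module. `model` is assumed
    to be an already-loaded CompVis LatentDiffusion model, so that
    `model.model.diffusion_model` is the UNet and its attention modules are named
    'CrossAttention'. Sampling itself is omitted."""
    editor = AttentionStore(res=[32])
    regiter_attention_editor_ldm(model, editor)
    assert editor.num_att_layers > 0
    # call editor.reset() before each new sampling run so that cur_step and
    # cur_att_layer start from zero again
    return editor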