import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Union, Tuple, List, Callable, Dict

from torchvision.utils import save_image
from einops import rearrange, repeat


class AttentionBase:
    """Base attention editor called in place of an attention layer's ``attn @ v``.

    The registration helpers in this module replace each attention module's
    ``forward`` with a closure that computes the attention probabilities and
    then calls this editor.  The editor counts its calls: every
    ``num_att_layers`` calls it advances ``cur_step`` by one and fires
    ``after_step``; the first call of each cycle fires ``before_step``.
    """

    def __init__(self):
        self.cur_step = 0
        # Overwritten by the registration helpers with the number of patched layers.
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def before_step(self):
        """Hook run before the first attention layer of each step (no-op by default)."""

    def after_step(self):
        """Hook run after the last attention layer of each step (no-op by default)."""

    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """Dispatch to ``forward`` and maintain the layer / step counters."""
        if self.cur_att_layer == 0:
            self.before_step()
        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.after_step()
        return out

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """Default behavior: plain ``attn @ v`` with heads merged back into channels.

        ``attn`` and ``v`` have the heads folded into the batch axis
        (``(b h) ...``); the result is returned as ``b n (h d)``.
        """
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
        return out

    def reset(self):
        """Reset the step and layer counters."""
        self.cur_step = 0
        self.cur_att_layer = 0


class AttentionStore(AttentionBase):
    """Editor that records attention probability maps and sums them over steps.

    Maps with ``attn.shape[1] <= 64 ** 2`` are collected per step (larger ones
    are skipped to avoid running out of memory).  At the end of each step whose
    (already incremented) ``cur_step`` lies strictly between ``min_step`` and
    ``max_step``, the step's maps are added element-wise into ``self_attns`` /
    ``cross_attns`` and ``valid_steps`` is incremented.

    Args:
        res: Stored as ``self.res``; not read by this class.  Defaults to ``[32]``.
        min_step: Exclusive lower bound on ``cur_step`` for accumulation.
        max_step: Exclusive upper bound on ``cur_step`` for accumulation.
    """

    def __init__(self, res=None, min_step=0, max_step=1000):
        super().__init__()
        # None sentinel: a list default would be shared between instances.
        self.res = [32] if res is None else res
        self.min_step = min_step
        self.max_step = max_step
        self.valid_steps = 0

        self.self_attns = []  # per-layer self-attention maps summed over valid steps
        self.cross_attns = []  # per-layer cross-attention maps summed over valid steps
        self.self_attns_step = []  # self-attention maps collected during the current step
        self.cross_attns_step = []  # cross-attention maps collected during the current step

    @staticmethod
    def _accumulate(total, step):
        """Add the per-layer maps of ``step`` element-wise into ``total`` and return it."""
        if len(total) == 0:
            # First contribution: copy the list so it is not shared with ``step``.
            return list(step)
        for i, step_attn in enumerate(step):
            total[i] += step_attn
        return total

    def after_step(self):
        if self.min_step < self.cur_step < self.max_step:
            self.valid_steps += 1
            self.self_attns = self._accumulate(self.self_attns, self.self_attns_step)
            self.cross_attns = self._accumulate(self.cross_attns, self.cross_attns_step)
        # Rebind instead of .clear() so lists handed to the totals are never emptied.
        self.self_attns_step = []
        self.cross_attns_step = []

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        if attn.shape[1] <= 64 ** 2:  # avoid OOM
            if is_cross:
                self.cross_attns_step.append(attn)
            else:
                self.self_attns_step.append(attn)
        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)


def _make_attention_forward(attn_module, editor, place_in_unet):
    """Build a replacement ``forward`` for ``attn_module`` that routes through ``editor``.

    The returned function mirrors the LDM ``CrossAttention`` computation
    (project q/k/v, split heads, scaled dot product, optional mask, softmax),
    but delegates the ``attn @ v`` product to ``editor`` and then applies the
    module's output projection.  Only ``to_out[0]`` is applied when ``to_out``
    is an ``nn.ModuleList``.
    """

    def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None,
                **cross_attention_kwargs):
        # Accept both the diffusers argument names and the LDM ones.  Extra
        # keyword arguments are accepted and ignored.
        if encoder_hidden_states is not None:
            context = encoder_hidden_states
        if attention_mask is not None:
            mask = attention_mask

        to_out = attn_module.to_out
        if isinstance(to_out, nn.ModuleList):
            to_out = to_out[0]

        h = attn_module.heads
        q = attn_module.to_q(x)
        is_cross = context is not None
        context = context if is_cross else x
        k = attn_module.to_k(context)
        v = attn_module.to_v(context)
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))

        sim = torch.einsum('b i d, b j d -> b i j', q, k) * attn_module.scale

        if mask is not None:
            # Positions where the mask is False are filled with a large negative
            # value; assumes a boolean mask (True = attend) -- TODO confirm callers.
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            # ``repeat`` already expands over heads and adds the query axis:
            # (b, j) -> (b*h, 1, j), which broadcasts against sim (b*h, i, j).
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        attn = sim.softmax(dim=-1)
        # The only difference from the stock implementation: the editor computes the output.
        out = editor(
            q, k, v, sim, attn, is_cross, place_in_unet,
            attn_module.heads, scale=attn_module.scale)
        return to_out(out)

    return forward


def _patch_attention_modules(net, editor, place_in_unet, class_name):
    """Patch every module named ``class_name`` in ``net``'s subtree; return how many.

    A matching module is patched and its own children are not searched further.
    """
    if net.__class__.__name__ == class_name:
        net.forward = _make_attention_forward(net, editor, place_in_unet)
        return 1
    return sum(
        _patch_attention_modules(child, editor, place_in_unet, class_name)
        for child in net.children())


def _place_in_unet(net_name, places):
    """Return the first entry of ``places`` that occurs in ``net_name``, else None."""
    return next((place for place in places if place in net_name), None)


def regiter_attention_editor_diffusers(model, editor: AttentionBase):
    """
    Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]

    Patches every ``Attention`` module under the ``down`` / ``mid`` / ``up``
    children of ``model.unet``, sets ``editor.num_att_layers`` to the number
    patched, and links ``editor.model`` and ``model.editor``.
    """
    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        place = _place_in_unet(net_name, ("down", "mid", "up"))
        if place is not None:
            cross_att_count += _patch_attention_modules(net, editor, place, 'Attention')
    editor.num_att_layers = cross_att_count
    editor.model = model
    model.editor = editor


def regiter_attention_editor_ldm(model, editor: AttentionBase):
    """
    Register a attention editor to Stable Diffusion model, refer from [Prompt-to-Prompt]

    Patches every ``CrossAttention`` module under the ``input`` / ``middle`` /
    ``output`` children of ``model.model.diffusion_model`` and sets
    ``editor.num_att_layers`` to the number patched.
    """
    cross_att_count = 0
    for net_name, net in model.model.diffusion_model.named_children():
        place = _place_in_unet(net_name, ("input", "middle", "output"))
        if place is not None:
            cross_att_count += _patch_attention_modules(net, editor, place, 'CrossAttention')
    editor.num_att_layers = cross_att_count


# Correctly spelled aliases; the original (misspelled) names remain the primary API.
register_attention_editor_diffusers = regiter_attention_editor_diffusers
register_attention_editor_ldm = regiter_attention_editor_ldm