# Artist / models / attn_injection.py
# (uploaded by fffiloni in "Upload 20 files", commit e02c605, verified)
# -*- coding : utf-8 -*-
# @FileName : attn_injection.py
# @Author : Ruixiang JIANG (Songrise)
# @Time : Mar 20, 2024
# @Github : https://github.com/songrise
# @Description: implement attention dump and attention injection for CPSD
from __future__ import annotations
from dataclasses import dataclass
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from diffusers.models import attention_processor
import einops
from diffusers.models import unet_2d_condition, attention, transformer_2d, resnet
from diffusers.models.unets import unet_2d_blocks
# from diffusers.models.unet_2d import CrossAttnUpBlock2D
from typing import Container, List, Optional
# Short alias for torch.Tensor, used in the type hints below.
T = torch.Tensor
import os
@dataclass(frozen=True)
class StyleAlignedArgs:
    """Immutable bundle of StyleAligned-style sharing options.

    The class only stores the flags; their meaning is defined by whatever code
    consumes this object (it is not read by the other functions in this file).
    """

    share_group_norm: bool = True
    # Was ``(True,)`` -- a one-element tuple caused by a trailing comma. It was
    # truthy but not a bool, so ``== True`` / ``is True`` checks failed.
    share_layer_norm: bool = True
    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = False
    full_attention_share: bool = False
    shared_score_scale: float = 1.0
    shared_score_shift: float = 0.0
    only_self_level: float = 0.0
def expand_first(
    feat: T,
    scale=1.0,
) -> T:
    """Broadcast the leading instance of each batch half over that half.

    The batch of size ``b`` is treated as two halves whose leading instances are
    ``feat[0]`` and ``feat[b // 2]``. The result has the same shape as ``feat``.
    Every row of the first half is a copy of ``feat[0]``, and every row of the
    second half is a copy of ``feat[b // 2]``. When ``scale != 1``, all copies
    except the first one in each half are multiplied by ``scale``.

    Args:
        feat: tensor of shape ``(b, ...)``. ``b`` is assumed even (not checked).
            Any number of trailing dimensions is supported.
        scale: multiplier applied to the non-leading copies in each half.

    Returns:
        Tensor with the same shape as ``feat``.
    """
    b = feat.shape[0]
    half = b // 2
    # Shape (2, 1, *feat.shape[1:]): the two leading instances plus a slot axis
    # along which they are broadcast.
    feat_style = torch.stack((feat[0], feat[half])).unsqueeze(1)
    if scale == 1:
        feat_style = feat_style.expand(2, half, *feat.shape[1:])
    else:
        # Repeat only along the slot axis, with one factor per dimension. The
        # previous hard-coded ``repeat(1, half, 1, 1, 1)`` was only correct for
        # 4-D input: it mixed up the halves for lower ranks and raised for
        # higher ranks.
        feat_style = feat_style.repeat(1, half, *([1] * (feat.dim() - 1)))
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)
def concat_first(feat: T, dim=2, scale=1.0) -> T:
    """Concatenate ``feat`` with its half-wise broadcast of leading instances.

    Args:
        feat: input tensor; its batch is split into two halves by ``expand_first``.
        dim: dimension along which the broadcast copy is appended.
        scale: forwarded to ``expand_first`` (scales the non-leading copies).

    Returns:
        ``feat`` followed by ``expand_first(feat, scale=scale)`` along ``dim``.
    """
    broadcast = expand_first(feat, scale=scale)
    combined = torch.cat([feat, broadcast], dim=dim)
    return combined
def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
    """Return the mean and standard deviation of ``feat`` over dimension -2.

    Both statistics keep the reduced dimension (size 1), so they broadcast back
    against ``feat``. The variance is ``Tensor.var``'s default (Bessel-corrected).
    ``eps`` is added before the square root for numerical stability.

    Returns:
        ``(mean, std)``.
    """
    reduce_dim = -2
    mean = feat.mean(dim=reduce_dim, keepdim=True)
    variance = feat.var(dim=reduce_dim, keepdim=True)
    std = torch.sqrt(variance + eps)
    return mean, std
def adain(feat: T) -> T:
    """AdaIN towards the leading instance of each batch half.

    Each instance is normalised with its own mean/std (over dimension -2, via
    ``calc_mean_std``). It is then re-scaled with the statistics of the leading
    instance of its half, broadcast by ``expand_first``.
    """
    mu, sigma = calc_mean_std(feat)
    target_mu = expand_first(mu)
    target_sigma = expand_first(sigma)
    normalized = (feat - mu) / sigma
    return normalized * target_sigma + target_mu
def my_adain(feat: T) -> T:
    """AdaIN towards instance 1 of each CFG half, leaving instance 0 untouched.

    The batch is split into two halves of size ``half = feat.shape[0] // 2``.
    Every instance is normalised with its own mean/std (``calc_mean_std``, over
    dimension -2). It is then re-scaled with the statistics of instance 1 of its
    half (index ``1`` or ``half + 1``).

    The rows at index ``0`` and ``half`` are finally restored to their original
    (un-normalised) values. This requires at least 2 instances per half.
    """
    half = feat.shape[0] // 2

    def _broadcast_style(stat: T) -> T:
        # Tile instance 1's statistic over its half and restore the original shape.
        picked = torch.stack((stat[1], stat[half + 1])).unsqueeze(1)
        return picked.expand(2, half, *stat.shape[1:]).reshape(*stat.shape)

    mean, std = calc_mean_std(feat)
    # Handles to the rows that must survive unchanged. They are views of the
    # input, which is never written to (the result is a new tensor).
    content_uncond = feat[0]
    content_cond = feat[half]
    styled = (feat - mean) / std * _broadcast_style(std) + _broadcast_style(mean)
    styled[0] = content_uncond
    styled[half] = content_cond
    return styled
class DefaultAttentionProcessor(nn.Module):
    """Plain attention: a thin wrapper around diffusers' ``AttnProcessor``.

    The non-SDPA ``AttnProcessor`` (rather than ``AttnProcessor2_0``) is used
    for compatibility with torch 1.11.0.
    """

    def __init__(self):
        super().__init__()
        self.processor = attention_processor.AttnProcessor()  # for torch 1.11.0

    def __call__(
        self,
        attn: attention_processor.Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        **kwargs,
    ):
        """Run vanilla attention.

        Extra keyword arguments (e.g. ``temb``, ``scale``) are accepted but not
        forwarded to the wrapped processor.
        """
        delegate = self.processor
        return delegate(attn, hidden_states, encoder_hidden_states, attention_mask)
class ArtistAttentionProcessor(DefaultAttentionProcessor):
    """Attention processor that injects projections between batch instances.

    Assumed batch layout (not checked): classifier-free guidance concatenates two
    equal halves, with the unconditional half first and the conditional half
    second. Within each half the instances are ordered
    ``[content/reconstruction, uncontrolled style, controlled style, ...]``.
    Instance ``i`` is at index ``i`` in the unconditional half and at index
    ``half_batch + i`` in the conditional half.

    For each enabled projection (query/key/value), injection does one of two
    things:

    - applies ``my_adain`` (``use_adain=True``), or
    - copies instance 1's projection into instance 2.

    For the query, ``use_content_to_style_injection`` instead writes instance
    0's *value* projection into instance 1's query. This happens after the
    optional AdaIN.
    """

    def __init__(
        self,
        inject_query: bool = True,
        inject_key: bool = True,
        inject_value: bool = True,
        use_adain: bool = False,
        name: str = None,
        use_content_to_style_injection=False,
    ):
        """
        Args:
            inject_query: enable injection on the query projection.
            inject_key: enable injection on the key projection.
            inject_value: enable injection on the value projection.
            use_adain: use ``my_adain`` instead of copying instance 1 -> 2.
            name: label stored on the processor (not used by the computation).
            use_content_to_style_injection: for the query, write instance 0's
                value projection into instance 1 instead of copying 1 -> 2.
        """
        super().__init__()
        self.inject_query = inject_query
        self.inject_key = inject_key
        self.inject_value = inject_value
        # Runtime switch: when False no injection is performed.
        self.share_enabled = True
        self.use_adain = use_adain
        self.__custom_name = name  # stored as _ArtistAttentionProcessor__custom_name
        self.content_to_style_injection = use_content_to_style_injection

    def __call__(
        self,
        attn: attention_processor.Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        """Compute attention, injecting Q/K/V between instances when enabled.

        Args:
            attn: the diffusers ``Attention`` module being processed.
            hidden_states: 3-D ``(batch, seq, dim)`` or 4-D ``(batch, C, H, W)``.
            encoder_hidden_states: keys/values source for cross-attention;
                ``None`` for self-attention.
            attention_mask: optional mask, prepared via ``attn.prepare_attention_mask``.
            temb: forwarded to ``attn.spatial_norm`` when that exists.
            scale: accepted for interface compatibility; not forwarded to the
                projection layers.

        Returns:
            Tensor with the same layout (3-D or 4-D) as ``hidden_states``.
        """
        # ---- mirrors diffusers' AttnProcessor up to the Q/K/V projections ----
        residual = hidden_states
        # args = () if USE_PEFT_BACKEND else (scale,)
        args = ()

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states, *args)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        # ---- feature injection ----
        # The per-CFG-half size gets its own name. ``batch_size`` (the full batch)
        # is still needed below to restore the 4-D layout; overwriting it made
        # that reshape use half the real batch.
        half_batch = query.shape[0] // 2
        if self.share_enabled and half_batch > 1:  # nothing to share with one instance
            query, key, value = self._inject(query, key, value, half_batch)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def _inject(self, query, key, value, half_batch):
        """Apply the configured Q/K/V injection and return ``(query, key, value)``.

        Tensors are modified in place, except where ``my_adain`` returns a new
        tensor. The copy-into-instance-2 paths index ``2`` and ``half_batch + 2``,
        so they need ``half_batch >= 3``.
        """
        # Instance-1 (style reference) projections, captured as views before any
        # write. When AdaIN rebinds ``query``, these still see the pre-AdaIN values.
        ref_q_uncond = query[1, ...].unsqueeze(0)
        ref_q_cond = query[half_batch + 1, ...].unsqueeze(0)
        ref_k_uncond = key[1, ...].unsqueeze(0)
        ref_k_cond = key[half_batch + 1, ...].unsqueeze(0)
        ref_v_uncond = value[1, ...].unsqueeze(0)
        ref_v_cond = value[half_batch + 1, ...].unsqueeze(0)

        if self.inject_query:
            if self.use_adain:
                query = my_adain(query)
            # NOTE(review): unlike key/value, this branch runs whether or not
            # AdaIN was applied -- confirm this is intended.
            if self.content_to_style_injection:
                # Content (instance 0) value projection becomes instance 1's query.
                query[1] = value[0, ...].unsqueeze(0)
                query[half_batch + 1] = value[half_batch, ...].unsqueeze(0)
            else:
                query[2] = ref_q_uncond
                query[half_batch + 2] = ref_q_cond

        if self.inject_key:
            if self.use_adain:
                key = my_adain(key)
            else:
                key[2] = ref_k_uncond
                key[half_batch + 2] = ref_k_cond

        if self.inject_value:
            if self.use_adain:
                value = my_adain(value)
            else:
                value[2] = ref_v_uncond
                value[half_batch + 2] = ref_v_cond

        return query, key, value
class ArtistResBlockWrapper(nn.Module):
    """Wrap a resnet block so instances share the content's residual branch.

    The wrapped block is assumed to compute
    ``out = (shortcut(x) + h) / output_scale_factor``. Here ``shortcut`` is
    ``block.conv_shortcut`` or the identity when that attribute is ``None``.

    With ``injection_method == "hidden"``, each CFG half (the batch split in two
    equal halves) reuses instance 0's ``h`` on top of every instance's own
    shortcut. Instance 1 (the uncontrolled style) keeps its original output.
    """

    def __init__(
        self, block: resnet.ResnetBlock2D, injection_method: str, name: str = None
    ):
        """
        Args:
            block: the resnet block to wrap (called as ``block(x, temb, scale)``).
            injection_method: only ``"hidden"`` is implemented.
            name: label stored on the wrapper (not used by the computation).
        """
        super().__init__()
        self.block = block
        self.output_scale_factor = self.block.output_scale_factor
        self.injection_method = injection_method
        self.name = name

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ):
        """Run the block and share instance 0's residual branch per CFG half.

        Raises:
            NotImplementedError: for any ``injection_method`` other than ``"hidden"``.
        """
        if self.injection_method != "hidden":
            raise NotImplementedError(f"Unknown injection method {self.injection_method}")

        # when disentangle, feat should be [recon, uncontrolled style, controlled style]
        feat = self.block(input_tensor, temb, scale)
        half = feat.shape[0] // 2
        if half <= 1:
            # One instance per half (or no CFG pairing at all): nothing to share.
            return feat

        output = self._share_residual(input_tensor, feat, half)
        # do not inject the feat for the 2nd instance, which is uncontrolled style
        output[1] = feat[1]
        output[half + 1] = feat[half + 1]
        # uncomment to not inject content to controlled style
        # output[2] = feat[2]
        # output[half + 2] = feat[half + 2]
        return output

    def _share_residual(self, input_tensor, feat, half):
        """Rebuild ``feat`` so every instance uses instance 0's ``h`` of its half."""
        # Identity shortcut when the block has no conv_shortcut; calling it
        # unconditionally crashed for such blocks.
        # NOTE(review): assumes the shortcut path has no other ops (e.g.
        # up/down-sampling) -- confirm for the wrapped block types.
        shortcut = input_tensor
        if self.block.conv_shortcut is not None:
            shortcut = self.block.conv_shortcut(input_tensor)
        factor = self.output_scale_factor
        # feat = (shortcut + h) / factor  =>  h = feat * factor - shortcut
        h_uncond = feat[0:1] * factor - shortcut[0:1]
        h_cond = feat[half:half + 1] * factor - shortcut[half:half + 1]
        # only share the h, the residual is not shared
        h_shared = torch.cat([h_uncond] * half + [h_cond] * half, dim=0)
        return (shortcut + h_shared) / factor
class SharedResBlockWrapper(nn.Module):
    """Wrap a resnet block so every instance uses instance 0's residual branch.

    The wrapped block is assumed to compute
    ``out = (shortcut(x) + h) / output_scale_factor``. Here ``shortcut`` is
    ``block.conv_shortcut`` or the identity when that attribute is ``None``.

    While ``share_enabled`` is True, each CFG half (the batch split in two equal
    halves) reuses instance 0's ``h`` on top of every instance's own shortcut.
    """

    def __init__(self, block: resnet.ResnetBlock2D):
        super().__init__()
        self.block = block
        self.output_scale_factor = self.block.output_scale_factor
        # Runtime switch: when False the wrapper just calls the block.
        self.share_enabled = True

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ):
        """Run the block, sharing instance 0's residual branch when enabled."""
        feat = self.block(input_tensor, temb, scale)
        if not self.share_enabled:
            return feat
        half = feat.shape[0] // 2
        if half <= 1:
            # One instance per half (or no CFG pairing at all): nothing to share.
            return feat
        return self._share_residual(input_tensor, feat, half)

    def _share_residual(self, input_tensor, feat, half):
        """Rebuild ``feat`` so every instance uses instance 0's ``h`` of its half."""
        # Identity shortcut when the block has no conv_shortcut; calling it
        # unconditionally crashed for such blocks.
        # NOTE(review): assumes the shortcut path has no other ops (e.g.
        # up/down-sampling) -- confirm for the wrapped block types.
        shortcut = input_tensor
        if self.block.conv_shortcut is not None:
            shortcut = self.block.conv_shortcut(input_tensor)
        factor = self.output_scale_factor
        # feat = (shortcut + h) / factor  =>  h = feat * factor - shortcut
        h_uncond = feat[0:1] * factor - shortcut[0:1]
        h_cond = feat[half:half + 1] * factor - shortcut[half:half + 1]
        # only share the h, the residual is not shared
        h_shared = torch.cat([h_uncond] * half + [h_cond] * half, dim=0)
        return (shortcut + h_shared) / factor
def register_attention_processors(
    pipe,
    base_dir: str = None,
    disentangle: bool = False,
    attn_mode: str = "artist",
    resnet_mode: str = "hidden",
    share_resblock: bool = True,
    share_attn: bool = True,
    share_cross_attn: bool = False,
    share_attn_layers: Optional[Container[int]] = None,
    share_resnet_layers: Optional[Container[int]] = None,
    c2s_layers: Optional[Container[int]] = (0, 1),
    share_query: bool = True,
    share_key: bool = True,
    share_value: bool = True,
    use_adain: bool = False,
):
    """Install the Artist resnet wrappers and attention processors on ``pipe.unet``.

    Only a subset of up blocks is modified: ``unet.up_blocks[1:]`` for SD and
    ``unet.up_blocks[:-1]`` for SDXL. Resnets and attention layers are numbered
    consecutively across those blocks starting at 0. The ``*_layers`` arguments
    select layers by that index. Each change is reported with ``print``.

    Args:
        pipe: a ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        base_dir: unused; kept for backward compatibility.
        disentangle: wrap selected resnets with ``ArtistResBlockWrapper`` (True)
            or ``SharedResBlockWrapper`` (False).
        attn_mode: only ``"artist"`` installs attention processors; other values
            leave attention untouched.
        resnet_mode: ``injection_method`` passed to ``ArtistResBlockWrapper``.
        share_resblock: whether to wrap resnets at all.
        share_attn: whether to install attention processors at all.
        share_cross_attn: also install a key/value-injecting processor on the
            cross-attention of every selected attention layer.
        share_attn_layers: attention-layer indices to patch; ``None`` patches none.
        share_resnet_layers: resnet indices to wrap; ``None`` wraps none.
        c2s_layers: attention-layer indices (among the patched ones) that use
            content-to-style query injection. Defaults to ``(0, 1)``; ``None``
            means no layer.
        share_query: ``inject_query`` for the self-attention processors.
        share_key: ``inject_key`` for the self-attention processors.
        share_value: ``inject_value`` for the self-attention processors.
        use_adain: ``use_adain`` for the self-attention processors.

    Raises:
        TypeError: if ``pipe`` is neither supported pipeline type (this used to
            surface as an ``UnboundLocalError``).
    """
    unet: unet_2d_condition.UNet2DConditionModel = pipe.unet
    if isinstance(pipe, StableDiffusionPipeline):
        # skip the first block, which is UpBlock2D
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[1:]
    elif isinstance(pipe, StableDiffusionXLPipeline):
        # presumably the last SDXL up block is the attention-free one -- confirm
        up_blocks = unet.up_blocks[:-1]
    else:
        raise TypeError(
            f"Unsupported pipeline type {type(pipe).__name__}; expected "
            "StableDiffusionPipeline or StableDiffusionXLPipeline"
        )
    # Immutable default instead of a shared list; None means "no c2s layers".
    c2s = () if c2s_layers is None else c2s_layers

    layer_idx_attn = 0
    layer_idx_resnet = 0
    for block in up_blocks:
        if share_resblock and share_resnet_layers is not None:
            resnet_wrappers = []
            for resnet_block in block.resnets:
                if layer_idx_resnet not in share_resnet_layers:
                    resnet_wrappers.append(resnet_block)  # use original implementation
                elif disentangle:
                    resnet_wrappers.append(
                        ArtistResBlockWrapper(
                            resnet_block,
                            injection_method=resnet_mode,
                            name=f"layer_{layer_idx_resnet}",
                        )
                    )
                    print(
                        f"Disentangle resnet {resnet_mode} set for layer {layer_idx_resnet}"
                    )
                else:
                    resnet_wrappers.append(SharedResBlockWrapper(resnet_block))
                    print(f"Share resnet feature set for layer {layer_idx_resnet}")
                layer_idx_resnet += 1
            block.resnets = nn.ModuleList(resnet_wrappers)  # actually apply the change

        if share_attn:
            for transformer_layer in block.attentions:
                # Only the first transformer block of each attention layer is patched.
                transformer_block: attention.BasicTransformerBlock = (
                    transformer_layer.transformer_blocks[0]
                )
                self_attn: attention_processor.Attention = transformer_block.attn1
                # cross attn does not inject unless share_cross_attn is set
                cross_attn: attention_processor.Attention = transformer_block.attn2
                selected = (
                    attn_mode == "artist"
                    and share_attn_layers is not None
                    and layer_idx_attn in share_attn_layers
                )
                if selected:
                    content_to_style = layer_idx_attn in c2s
                    pnp_inject_processor = ArtistAttentionProcessor(
                        inject_query=share_query,
                        inject_key=share_key,
                        inject_value=share_value,
                        use_adain=use_adain,
                        name=f"layer_{layer_idx_attn}_self",
                        use_content_to_style_injection=content_to_style,
                    )
                    self_attn.set_processor(pnp_inject_processor)
                    print(
                        f"Disentangled Pnp inject processor set for self-attention in layer {layer_idx_attn} with c2s={content_to_style}"
                    )
                    if share_cross_attn:
                        cross_attn_processor = ArtistAttentionProcessor(
                            inject_query=False,
                            inject_key=True,
                            inject_value=True,
                            use_adain=False,
                            name=f"layer_{layer_idx_attn}_cross",
                        )
                        cross_attn.set_processor(cross_attn_processor)
                        print(
                            f"Disentangled Pnp inject processor set for cross-attention in layer {layer_idx_attn}"
                        )
                layer_idx_attn += 1
def unset_attention_processors(
    pipe,
    unset_share_attn: bool = False,
    unset_share_resblock: bool = False,
):
    """Undo ``register_attention_processors`` on the same up blocks of ``pipe.unet``.

    Args:
        pipe: a ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        unset_share_attn: give the self- and cross-attention of the first
            transformer block in every attention layer a fresh
            ``DefaultAttentionProcessor``. This happens whether or not they were
            patched, so the restored processor may differ from the one the
            pipeline originally had.
        unset_share_resblock: replace every ``SharedResBlockWrapper`` /
            ``ArtistResBlockWrapper`` with the block it wraps.

    Raises:
        TypeError: if ``pipe`` is neither supported pipeline type (this used to
            surface as an ``UnboundLocalError``).
    """
    unet: unet_2d_condition.UNet2DConditionModel = pipe.unet
    if isinstance(pipe, StableDiffusionPipeline):
        # skip the first block, which is UpBlock2D
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[1:]
    elif isinstance(pipe, StableDiffusionXLPipeline):
        up_blocks = unet.up_blocks[:-1]
    else:
        raise TypeError(
            f"Unsupported pipeline type {type(pipe).__name__}; expected "
            "StableDiffusionPipeline or StableDiffusionXLPipeline"
        )

    wrapper_types = (SharedResBlockWrapper, ArtistResBlockWrapper)
    for block in up_blocks:
        if unset_share_resblock:
            block.resnets = nn.ModuleList(
                [
                    resnet_block.block
                    if isinstance(resnet_block, wrapper_types)
                    else resnet_block
                    for resnet_block in block.resnets
                ]
            )
        if unset_share_attn:
            for transformer_layer in block.attentions:
                transformer_block: attention.BasicTransformerBlock = (
                    transformer_layer.transformer_blocks[0]
                )
                transformer_block.attn1.set_processor(DefaultAttentionProcessor())
                transformer_block.attn2.set_processor(DefaultAttentionProcessor())