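"""Attention-processor and ResNet-block wrappers for content/style feature sharing.

This module patches a Stable Diffusion / SDXL UNet so that, within each
classifier-free-guidance half of the batch (assumed layout: content image,
style image, stylization target(s)), query/key/value projections and ResNet
hidden features from the content and style branches can be injected into the
target branch, optionally normalized with AdaIN.
"""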
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import List, Optional

import einops
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers.models import (
    attention,
    attention_processor,
    resnet,
    transformer_2d,
    unet_2d_condition,
)
from diffusers.models.unets import unet_2d_blocks

T = torch.Tensor


@dataclass(frozen=True)
class StyleAlignedArgs:
    share_group_norm: bool = True
    share_layer_norm: bool = True
    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = False
    full_attention_share: bool = False
    shared_score_scale: float = 1.0
    shared_score_shift: float = 0.0
    only_self_level: float = 0.0
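
# Note: StyleAlignedArgs mirrors the StyleAligned handler configuration and is not
# referenced elsewhere in this file; e.g. StyleAlignedArgs(adain_values=True) builds
# a frozen config with value-AdaIN enabled.
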


def expand_first(
    feat: T,
    scale: float = 1.0,
) -> T:
    # Broadcast the first feature of each classifier-free-guidance half across that
    # half, optionally scaling every copy except the first.
    b = feat.shape[0]
    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
    if scale == 1:
        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
    else:
        feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)


def concat_first(feat: T, dim: int = 2, scale: float = 1.0) -> T:
    feat_style = expand_first(feat, scale=scale)
    return torch.cat((feat, feat_style), dim=dim)


def calc_mean_std(feat: T, eps: float = 1e-5) -> tuple[T, T]:
    feat_std = (feat.var(dim=-2, keepdim=True) + eps).sqrt()
    feat_mean = feat.mean(dim=-2, keepdim=True)
    return feat_mean, feat_std


def adain(feat: T) -> T:
    feat_mean, feat_std = calc_mean_std(feat)
    feat_style_mean = expand_first(feat_mean)
    feat_style_std = expand_first(feat_std)
    feat = (feat - feat_mean) / feat_std
    feat = feat * feat_style_std + feat_style_mean
    return feat


def my_adain(feat: T) -> T:
    # AdaIN over a batch assumed to be laid out, per classifier-free-guidance half,
    # as [content, style, stylization target(s)]: every feature is renormalized to
    # the style branch's statistics, then the two content rows are restored.
    batch_size = feat.shape[0] // 2
    feat_mean, feat_std = calc_mean_std(feat)
    feat_uncond_content, feat_cond_content = feat[0], feat[batch_size]

    feat_style_mean = torch.stack((feat_mean[1], feat_mean[batch_size + 1])).unsqueeze(
        1
    )
    feat_style_mean = feat_style_mean.expand(2, batch_size, *feat_mean.shape[1:])
    feat_style_mean = feat_style_mean.reshape(*feat_mean.shape)

    feat_style_std = torch.stack((feat_std[1], feat_std[batch_size + 1])).unsqueeze(1)
    feat_style_std = feat_style_std.expand(2, batch_size, *feat_std.shape[1:])
    feat_style_std = feat_style_std.reshape(*feat_std.shape)

    feat = (feat - feat_mean) / feat_std
    feat = feat * feat_style_std + feat_style_mean
    feat[0] = feat_uncond_content
    feat[batch_size] = feat_cond_content
    return feat

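# Illustrative sanity check (not part of the original module): with the assumed
# [content, style, target] layout per CFG half, my_adain should leave both content
# rows (indices 0 and batch_size) untouched while renormalizing the remaining rows.
def _check_my_adain(num_per_half: int = 3, tokens: int = 16, channels: int = 8) -> None:
    x = torch.randn(2 * num_per_half, tokens, channels)
    y = my_adain(x.clone())
    assert torch.allclose(y[0], x[0])
    assert torch.allclose(y[num_per_half], x[num_per_half])

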
class DefaultAttentionProcessor(nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = attention_processor.AttnProcessor()

    def __call__(
        self,
        attn: attention_processor.Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        **kwargs,
    ):
        return self.processor(
            attn, hidden_states, encoder_hidden_states, attention_mask
        )


class ArtistAttentionProcessor(DefaultAttentionProcessor):
    def __init__(
        self,
        inject_query: bool = True,
        inject_key: bool = True,
        inject_value: bool = True,
        use_adain: bool = False,
        name: Optional[str] = None,
        use_content_to_style_injection: bool = False,
    ):
        super().__init__()

        self.inject_query = inject_query
        self.inject_key = inject_key
        self.inject_value = inject_value
        self.share_enabled = True
        self.use_adain = use_adain

        self.__custom_name = name
        self.content_to_style_injection = use_content_to_style_injection

    def __call__(
        self,
        attn: attention_processor.Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        residual = hidden_states

        args = ()

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        # Each classifier-free-guidance half is assumed to be laid out as
        # [content, style, stylization target(s)]. Use a dedicated variable for the
        # half-batch size so the full `batch_size` captured above stays intact for
        # the reshape at the end of this method.
        half_batch = query.shape[0] // 2
        if self.share_enabled and half_batch > 1:
            # Style-branch (index 1 of each half) projections.
            ref_q_uncond, ref_q_cond = (
                query[1, ...].unsqueeze(0),
                query[half_batch + 1, ...].unsqueeze(0),
            )
            ref_k_uncond, ref_k_cond = (
                key[1, ...].unsqueeze(0),
                key[half_batch + 1, ...].unsqueeze(0),
            )
            ref_v_uncond, ref_v_cond = (
                value[1, ...].unsqueeze(0),
                value[half_batch + 1, ...].unsqueeze(0),
            )
            if self.inject_query:
                if self.use_adain:
                    query = my_adain(query)

                if self.content_to_style_injection:
                    # Content-to-style: the style branch queries with the content
                    # branch's value projections.
                    content_v_uncond = value[0, ...].unsqueeze(0)
                    content_v_cond = value[half_batch, ...].unsqueeze(0)
                    query[1] = content_v_uncond
                    query[half_batch + 1] = content_v_cond
                else:
                    query[2] = ref_q_uncond
                    query[half_batch + 2] = ref_q_cond
            if self.inject_key:
                if self.use_adain:
                    key = my_adain(key)
                else:
                    key[2] = ref_k_uncond
                    key[half_batch + 2] = ref_k_cond

            if self.inject_value:
                if self.use_adain:
                    value = my_adain(value)
                else:
                    value[2] = ref_v_uncond
                    value[half_batch + 2] = ref_v_cond

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        hidden_states = attn.to_out[0](hidden_states, *args)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class ArtistResBlockWrapper(nn.Module):
    def __init__(
        self,
        block: resnet.ResnetBlock2D,
        injection_method: str,
        name: Optional[str] = None,
    ):
        super().__init__()
        self.block = block
        self.output_scale_factor = self.block.output_scale_factor
        self.injection_method = injection_method
        self.name = name

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ):
        if self.injection_method == "hidden":
            feat = self.block(input_tensor, temb, scale)
            batch_size = feat.shape[0] // 2
            if batch_size == 1:
                return feat

            # Content (reconstruction) branch outputs for both CFG halves.
            recon_feat_uncond, recon_feat_cond = (
                feat[0, ...].unsqueeze(0),
                feat[batch_size, ...].unsqueeze(0),
            )

            # Apply the block's shortcut (if any) so the residual can be recovered.
            if self.block.conv_shortcut is not None:
                input_tensor = self.block.conv_shortcut(input_tensor)
            input_content_uncond, input_content_cond = (
                input_tensor[0, ...].unsqueeze(0),
                input_tensor[batch_size, ...].unsqueeze(0),
            )

            # Recover the content branch's residual, h = out * scale_factor - shortcut,
            # and share it across every sample in the batch.
            recon_feat_uncond, recon_feat_cond = (
                recon_feat_uncond * self.output_scale_factor,
                recon_feat_cond * self.output_scale_factor,
            )
            h_content_uncond, h_content_cond = (
                recon_feat_uncond - input_content_uncond,
                recon_feat_cond - input_content_cond,
            )

            h_shared = torch.cat(
                ([h_content_uncond] * batch_size) + ([h_content_cond] * batch_size),
                dim=0,
            )

            output_feat_shared = (input_tensor + h_shared) / self.output_scale_factor

            # Leave the style branch (index 1 of each half) with its original output.
            output_feat_shared[1] = feat[1]
            output_feat_shared[batch_size + 1] = feat[batch_size + 1]

            return output_feat_shared
        else:
            raise NotImplementedError(
                f"Unknown injection method {self.injection_method}"
            )


class SharedResBlockWrapper(nn.Module):
    def __init__(self, block: resnet.ResnetBlock2D):
        super().__init__()
        self.block = block
        self.output_scale_factor = self.block.output_scale_factor
        self.share_enabled = True

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ):
        if self.share_enabled:
            feat = self.block(input_tensor, temb, scale)
            batch_size = feat.shape[0] // 2
            if batch_size == 1:
                return feat

            # Content branch outputs for both CFG halves.
            feat_uncond, feat_cond = (
                feat[0, ...].unsqueeze(0),
                feat[batch_size, ...].unsqueeze(0),
            )

            if self.block.conv_shortcut is not None:
                input_tensor = self.block.conv_shortcut(input_tensor)
            input_content_uncond, input_content_cond = (
                input_tensor[0, ...].unsqueeze(0),
                input_tensor[batch_size, ...].unsqueeze(0),
            )

            feat_uncond, feat_cond = (
                feat_uncond * self.output_scale_factor,
                feat_cond * self.output_scale_factor,
            )
            h_content_uncond, h_content_cond = (
                feat_uncond - input_content_uncond,
                feat_cond - input_content_cond,
            )

            h_shared = torch.cat(
                ([h_content_uncond] * batch_size) + ([h_content_cond] * batch_size),
                dim=0,
            )
            output_shared = (input_tensor + h_shared) / self.output_scale_factor
            return output_shared
        else:
            return self.block(input_tensor, temb, scale)

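# Hypothetical convenience helper (not in the original module): ArtistAttentionProcessor
# and SharedResBlockWrapper both expose a `share_enabled` flag, so feature sharing can
# be toggled per denoising step without re-registering anything.
def set_share_enabled(pipe, enabled: bool) -> None:
    for module in pipe.unet.modules():
        if isinstance(module, (ArtistAttentionProcessor, SharedResBlockWrapper)):
            module.share_enabled = enabled

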
def register_attention_processors(
    pipe,
    base_dir: Optional[str] = None,
    disentangle: bool = False,
    attn_mode: str = "artist",
    resnet_mode: str = "hidden",
    share_resblock: bool = True,
    share_attn: bool = True,
    share_cross_attn: bool = False,
    share_attn_layers: Optional[List[int]] = None,
    share_resnet_layers: Optional[List[int]] = None,
    c2s_layers: Optional[List[int]] = [0, 1],
    share_query: bool = True,
    share_key: bool = True,
    share_value: bool = True,
    use_adain: bool = False,
):
    unet: unet_2d_condition.UNet2DConditionModel = pipe.unet
    # Only patch the cross-attention up blocks: skip the attention-free UpBlock2D
    # (the first up block in SD 1.x/2.x, the last one in SDXL).
    if isinstance(pipe, StableDiffusionPipeline):
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[1:]
    elif isinstance(pipe, StableDiffusionXLPipeline):
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[:-1]
    else:
        raise NotImplementedError(f"Unsupported pipeline type: {type(pipe)}")
    layer_idx_attn = 0
    layer_idx_resnet = 0
    for block in up_blocks:
        if share_resblock:
            if share_resnet_layers is not None:
                resnet_wrappers = []
                resnets = block.resnets
                for resnet_block in resnets:
                    if layer_idx_resnet not in share_resnet_layers:
                        resnet_wrappers.append(resnet_block)
                    else:
                        if disentangle:
                            resnet_wrappers.append(
                                ArtistResBlockWrapper(
                                    resnet_block,
                                    injection_method=resnet_mode,
                                    name=f"layer_{layer_idx_resnet}",
                                )
                            )
                            print(
                                f"Disentangle resnet {resnet_mode} set for layer {layer_idx_resnet}"
                            )
                        else:
                            resnet_wrappers.append(SharedResBlockWrapper(resnet_block))
                            print(
                                f"Share resnet feature set for layer {layer_idx_resnet}"
                            )

                    layer_idx_resnet += 1
                block.resnets = nn.ModuleList(resnet_wrappers)
        if share_attn:
            for transformer_layer in block.attentions:
                transformer_block: attention.BasicTransformerBlock = (
                    transformer_layer.transformer_blocks[0]
                )
                self_attn: attention_processor.Attention = transformer_block.attn1
                cross_attn: attention_processor.Attention = transformer_block.attn2

                if attn_mode == "artist":
                    if (
                        share_attn_layers is not None
                        and layer_idx_attn in share_attn_layers
                    ):
                        content_to_style = (
                            c2s_layers is not None and layer_idx_attn in c2s_layers
                        )
                        pnp_inject_processor = ArtistAttentionProcessor(
                            inject_query=share_query,
                            inject_key=share_key,
                            inject_value=share_value,
                            use_adain=use_adain,
                            name=f"layer_{layer_idx_attn}_self",
                            use_content_to_style_injection=content_to_style,
                        )
                        self_attn.set_processor(pnp_inject_processor)
                        print(
                            f"Disentangled PnP inject processor set for self-attention "
                            f"in layer {layer_idx_attn} with c2s={content_to_style}"
                        )
                        if share_cross_attn:
                            cross_attn_processor = ArtistAttentionProcessor(
                                inject_query=False,
                                inject_key=True,
                                inject_value=True,
                                use_adain=False,
                                name=f"layer_{layer_idx_attn}_cross",
                            )
                            cross_attn.set_processor(cross_attn_processor)
                            print(
                                f"Disentangled PnP inject processor set for "
                                f"cross-attention in layer {layer_idx_attn}"
                            )
                layer_idx_attn += 1


def unset_attention_processors(
    pipe,
    unset_share_attn: bool = False,
    unset_share_resblock: bool = False,
):
    unet: unet_2d_condition.UNet2DConditionModel = pipe.unet
    if isinstance(pipe, StableDiffusionPipeline):
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[1:]
    elif isinstance(pipe, StableDiffusionXLPipeline):
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[:-1]
    else:
        raise NotImplementedError(f"Unsupported pipeline type: {type(pipe)}")
    block_idx = 1
    layer_idx = 0
    for block in up_blocks:
        if unset_share_resblock:
            resnet_origs = []
            resnets = block.resnets
            for resnet_block in resnets:
                if isinstance(
                    resnet_block, (SharedResBlockWrapper, ArtistResBlockWrapper)
                ):
                    # Restore the wrapped ResnetBlock2D.
                    resnet_origs.append(resnet_block.block)
                else:
                    resnet_origs.append(resnet_block)
            block.resnets = nn.ModuleList(resnet_origs)
        if unset_share_attn:
            for transformer_layer in block.attentions:
                layer_idx += 1
                transformer_block: attention.BasicTransformerBlock = (
                    transformer_layer.transformer_blocks[0]
                )
                self_attn: attention_processor.Attention = transformer_block.attn1
                cross_attn: attention_processor.Attention = transformer_block.attn2
                self_attn.set_processor(DefaultAttentionProcessor())
                cross_attn.set_processor(DefaultAttentionProcessor())
        block_idx += 1
        layer_idx = 0
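

# Minimal usage sketch (not part of the original module). It assumes a CUDA device,
# the public stabilityai/stable-diffusion-xl-base-1.0 checkpoint, and that the caller
# prepares latents with the [content, style, target] layout per CFG half that the
# processors above expect.
if __name__ == "__main__":
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    register_attention_processors(
        pipe,
        disentangle=True,
        attn_mode="artist",
        resnet_mode="hidden",
        share_resnet_layers=[0, 1],
        share_attn_layers=[0, 1, 2],
    )
    # ... run the pipeline's denoising loop on the batched latents here ...
    unset_attention_processors(pipe, unset_share_attn=True, unset_share_resblock=True)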