# -*- coding: utf-8 -*-
# @FileName : attn_injection.py
# @Author   : Ruixiang JIANG (Songrise)
# @Time     : Mar 20, 2024
# @Github   : https://github.com/songrise
# @Description: implement attention dump and attention injection for CPSD

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import List, Optional

import einops
import torch
import torch.nn as nn
from torch.nn import functional as nnf

from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers.models import (
    attention,
    attention_processor,
    resnet,
    transformer_2d,
    unet_2d_condition,
)
from diffusers.models.unets import unet_2d_blocks

# from diffusers.models.unet_2d import CrossAttnUpBlock2D

T = torch.Tensor
@dataclass
class StyleAlignedArgs:
    share_group_norm: bool = True
    share_layer_norm: bool = True
    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = False
    full_attention_share: bool = False
    shared_score_scale: float = 1.0
    shared_score_shift: float = 0.0
    only_self_level: float = 0.0
def expand_first(
    feat: T,
    scale: float = 1.0,
) -> T:
    # broadcast the reference features (index 0 and b // 2) to the whole batch
    b = feat.shape[0]
    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
    if scale == 1:
        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
    else:
        feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)


def concat_first(feat: T, dim: int = 2, scale: float = 1.0) -> T:
    # concatenate the broadcast reference features to every instance along `dim`
    feat_style = expand_first(feat, scale=scale)
    return torch.cat((feat, feat_style), dim=dim)
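

def _demo_concat_first() -> None:
    # Hedged illustration, not part of the original code: expand_first / concat_first
    # follow the StyleAligned-style convention where the reference features sit at
    # index 0 and b // 2 of a CFG-doubled batch. Concatenating along the token
    # dimension doubles the key/value sequence length. Shapes below are placeholders.
    k = torch.randn(4, 77, 640)         # hypothetical (batch, tokens, channels)
    k_shared = concat_first(k, dim=-2)  # reference tokens appended to every instance
    assert k_shared.shape == (4, 154, 640)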
def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
    # per-instance mean/std over the token dimension
    feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
    feat_mean = feat.mean(dim=-2, keepdims=True)
    return feat_mean, feat_std


def adain(feat: T) -> T:
    # renormalize every instance to the reference statistics (index 0 and b // 2)
    feat_mean, feat_std = calc_mean_std(feat)
    feat_style_mean = expand_first(feat_mean)
    feat_style_std = expand_first(feat_std)
    feat = (feat - feat_mean) / feat_std
    feat = feat * feat_style_std + feat_style_mean
    return feat
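

def _demo_adain() -> None:
    # Hedged illustration, not part of the original code: adain() assumes a
    # CFG-doubled batch [uncond_0 ... uncond_{n-1}, cond_0 ... cond_{n-1}] and
    # renormalizes each half to the statistics of its first instance
    # (index 0 and b // 2). The shapes below are placeholders.
    feats = torch.randn(4, 77, 640)  # 2 images x 2 (uncond / cond)
    aligned = adain(feats)
    assert aligned.shape == feats.shape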
def my_adain(feat: T) -> T:
    # Assumed batch layout (with CFG): [recon, uncontrolled style, controlled style]
    # repeated for the unconditional and conditional halves, i.e. 2 * batch_size rows.
    batch_size = feat.shape[0] // 2
    feat_mean, feat_std = calc_mean_std(feat)
    # keep the content (reconstruction) features so they can be restored at the end
    feat_uncond_content, feat_cond_content = feat[0], feat[batch_size]
    # broadcast the style branch's (index 1) statistics to each half of the batch
    feat_style_mean = torch.stack((feat_mean[1], feat_mean[batch_size + 1])).unsqueeze(
        1
    )
    feat_style_mean = feat_style_mean.expand(2, batch_size, *feat_mean.shape[1:])
    feat_style_mean = feat_style_mean.reshape(*feat_mean.shape)  # (2 * batch_size, 1, D)
    feat_style_std = torch.stack((feat_std[1], feat_std[batch_size + 1])).unsqueeze(1)
    feat_style_std = feat_style_std.expand(2, batch_size, *feat_std.shape[1:])
    feat_style_std = feat_style_std.reshape(*feat_std.shape)
    feat = (feat - feat_mean) / feat_std
    feat = feat * feat_style_std + feat_style_mean
    # restore the reconstruction instances so the content features are not renormalized
    feat[0] = feat_uncond_content
    feat[batch_size] = feat_cond_content
    return feat
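

def _demo_my_adain() -> None:
    # Hedged illustration, not part of the original code: my_adain() assumes the
    # disentangled batch layout used by the processors below, i.e.
    # [recon, uncontrolled style, controlled style] repeated for the uncond and
    # cond halves (6 rows). Statistics come from the style branch (index 1), while
    # the reconstruction rows (0 and 3) are left untouched.
    feats = torch.randn(6, 77, 640)
    out = my_adain(feats)
    assert out.shape == feats.shape
    assert torch.allclose(out[0], feats[0]) and torch.allclose(out[3], feats[3])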
class DefaultAttentionProcessor(nn.Module):
    def __init__(self):
        super().__init__()
        # self.processor = attention_processor.AttnProcessor2_0()
        self.processor = attention_processor.AttnProcessor()  # for torch 1.11.0

    def __call__(
        self,
        attn: attention_processor.Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        **kwargs,
    ):
        return self.processor(
            attn, hidden_states, encoder_hidden_states, attention_mask
        )
class ArtistAttentionProcessor(DefaultAttentionProcessor):
    def __init__(
        self,
        inject_query: bool = True,
        inject_key: bool = True,
        inject_value: bool = True,
        use_adain: bool = False,
        name: str = None,
        use_content_to_style_injection: bool = False,
    ):
        super().__init__()
        self.inject_query = inject_query
        self.inject_key = inject_key
        self.inject_value = inject_value
        self.share_enabled = True
        self.use_adain = use_adain
        self.__custom_name = name
        self.content_to_style_injection = use_content_to_style_injection

    def __call__(
        self,
        attn: attention_processor.Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        ####### code adapted from the original diffusers attention implementation
        residual = hidden_states
        # args = () if USE_PEFT_BACKEND else (scale,)
        args = ()

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        ######## injection begins here; the style image is assumed to be the 2nd instance
        ######## in each half of the batch
        half_batch = query.shape[0] // 2  # divide by 2 since CFG is used
        if self.share_enabled and half_batch > 1:  # when == 1, there is nothing to inject into
            ref_q_uncond, ref_q_cond = query[1, ...].unsqueeze(0), query[
                half_batch + 1, ...
            ].unsqueeze(0)
            ref_k_uncond, ref_k_cond = key[1, ...].unsqueeze(0), key[
                half_batch + 1, ...
            ].unsqueeze(0)
            ref_v_uncond, ref_v_cond = value[1, ...].unsqueeze(0), value[
                half_batch + 1, ...
            ].unsqueeze(0)
            if self.inject_query:
                if self.use_adain:
                    query = my_adain(query)
                    if self.content_to_style_injection:
                        # inject the content values into the style branch's queries
                        content_v_uncond = value[0, ...].unsqueeze(0)
                        content_v_cond = value[half_batch, ...].unsqueeze(0)
                        query[1] = content_v_uncond
                        query[half_batch + 1] = content_v_cond
                else:
                    query[2] = ref_q_uncond
                    query[half_batch + 2] = ref_q_cond
            if self.inject_key:
                if self.use_adain:
                    key = my_adain(key)
                else:
                    key[2] = ref_k_uncond
                    key[half_batch + 2] = ref_k_cond
            if self.inject_value:
                if self.use_adain:
                    value = my_adain(value)
                else:
                    value[2] = ref_v_uncond
                    value[half_batch + 2] = ref_v_cond

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
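

# Hedged usage sketch (not part of the original code): a processor instance can be
# attached to a single self-attention module directly; register_attention_processors
# below does this for all cross-attention up blocks. `unet` is assumed to be a loaded
# diffusers UNet2DConditionModel.
#
#   attn1 = unet.up_blocks[1].attentions[0].transformer_blocks[0].attn1
#   attn1.set_processor(
#       ArtistAttentionProcessor(
#           inject_query=True, inject_key=True, inject_value=True, use_adain=True
#       )
#   )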
class ArtistResBlockWrapper(nn.Module):
    def __init__(
        self, block: resnet.ResnetBlock2D, injection_method: str, name: str = None
    ):
        super().__init__()
        self.block = block
        self.output_scale_factor = self.block.output_scale_factor
        self.injection_method = injection_method
        self.name = name

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ):
        if self.injection_method == "hidden":
            # with disentangle, feat holds [recon, uncontrolled style, controlled style]
            # per CFG half
            feat = self.block(input_tensor, temb, scale)
            batch_size = feat.shape[0] // 2
            if batch_size == 1:
                return feat
            # the features of the reconstruction
            recon_feat_uncond, recon_feat_cond = feat[0, ...].unsqueeze(0), feat[
                batch_size, ...
            ].unsqueeze(0)
            # shortcut (residual) path; conv_shortcut is None when channels already match
            if self.block.conv_shortcut is not None:
                input_tensor = self.block.conv_shortcut(input_tensor)
            input_content_uncond, input_content_cond = input_tensor[0, ...].unsqueeze(
                0
            ), input_tensor[batch_size, ...].unsqueeze(0)
            # since feat = (input + h) / scale, recover h for the reconstruction
            recon_feat_uncond, recon_feat_cond = (
                recon_feat_uncond * self.output_scale_factor,
                recon_feat_cond * self.output_scale_factor,
            )
            h_content_uncond, h_content_cond = (
                recon_feat_uncond - input_content_uncond,
                recon_feat_cond - input_content_cond,
            )
            # only share h; the shortcut (residual) is not shared
            h_shared = torch.cat(
                ([h_content_uncond] * batch_size) + ([h_content_cond] * batch_size),
                dim=0,
            )
            output_feat_shared = (input_tensor + h_shared) / self.output_scale_factor
            # do not inject the feature into the 2nd instance, which is the uncontrolled style
            output_feat_shared[1] = feat[1]
            output_feat_shared[batch_size + 1] = feat[batch_size + 1]
            # uncomment to not inject content into the controlled style
            # output_feat_shared[2] = feat[2]
            # output_feat_shared[batch_size + 2] = feat[batch_size + 2]
            return output_feat_shared
        else:
            raise NotImplementedError(
                f"Unknown injection method {self.injection_method}"
            )
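

# Worked note (hedged, not from the original comments): ResnetBlock2D computes
#   out = (shortcut(x) + h(x, t)) / output_scale_factor,
# so the hidden branch can be recovered as
#   h = out * output_scale_factor - shortcut(x).
# Broadcasting the reconstruction's h across the batch while keeping each instance's
# own shortcut(x) injects content structure without overwriting the style branches'
# residual paths.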
class SharedResBlockWrapper(nn.Module):
    def __init__(self, block: resnet.ResnetBlock2D):
        super().__init__()
        self.block = block
        self.output_scale_factor = self.block.output_scale_factor
        self.share_enabled = True

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ):
        if self.share_enabled:
            feat = self.block(input_tensor, temb, scale)
            batch_size = feat.shape[0] // 2
            if batch_size == 1:
                return feat
            # the features of the reconstruction
            feat_uncond, feat_cond = feat[0, ...].unsqueeze(0), feat[
                batch_size, ...
            ].unsqueeze(0)
            # shortcut (residual) path; conv_shortcut is None when channels already match
            if self.block.conv_shortcut is not None:
                input_tensor = self.block.conv_shortcut(input_tensor)
            input_content_uncond, input_content_cond = input_tensor[0, ...].unsqueeze(
                0
            ), input_tensor[batch_size, ...].unsqueeze(0)
            # since feat = (input + h) / scale, recover h for the reconstruction
            feat_uncond, feat_cond = (
                feat_uncond * self.output_scale_factor,
                feat_cond * self.output_scale_factor,
            )
            h_content_uncond, h_content_cond = (
                feat_uncond - input_content_uncond,
                feat_cond - input_content_cond,
            )
            # only share h; the shortcut (residual) is not shared
            h_shared = torch.cat(
                ([h_content_uncond] * batch_size) + ([h_content_cond] * batch_size),
                dim=0,
            )
            output_shared = (input_tensor + h_shared) / self.output_scale_factor
            return output_shared
        else:
            return self.block(input_tensor, temb, scale)
def register_attention_processors(
    pipe,
    base_dir: str = None,
    disentangle: bool = False,
    attn_mode: str = "artist",
    resnet_mode: str = "hidden",
    share_resblock: bool = True,
    share_attn: bool = True,
    share_cross_attn: bool = False,
    share_attn_layers: Optional[List[int]] = None,
    share_resnet_layers: Optional[List[int]] = None,
    c2s_layers: Optional[List[int]] = [0, 1],
    share_query: bool = True,
    share_key: bool = True,
    share_value: bool = True,
    use_adain: bool = False,
):
    unet: unet_2d_condition.UNet2DConditionModel = pipe.unet
    if isinstance(pipe, StableDiffusionPipeline):
        # skip the first block, which is a plain UpBlock2D without attention
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[1:]
    elif isinstance(pipe, StableDiffusionXLPipeline):
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[:-1]
    else:
        raise NotImplementedError(f"Unsupported pipeline type {type(pipe)}")

    layer_idx_attn = 0
    layer_idx_resnet = 0
    for block in up_blocks:
        # each cross-attention up block holds several transformer layers and resnets
        if share_resblock:
            if share_resnet_layers is not None:
                resnet_wrappers = []
                for resnet_block in block.resnets:
                    if layer_idx_resnet not in share_resnet_layers:
                        # keep the original implementation
                        resnet_wrappers.append(resnet_block)
                    elif disentangle:
                        resnet_wrappers.append(
                            ArtistResBlockWrapper(
                                resnet_block,
                                injection_method=resnet_mode,
                                name=f"layer_{layer_idx_resnet}",
                            )
                        )
                        print(
                            f"Disentangle resnet {resnet_mode} set for layer {layer_idx_resnet}"
                        )
                    else:
                        resnet_wrappers.append(SharedResBlockWrapper(resnet_block))
                        print(
                            f"Share resnet feature set for layer {layer_idx_resnet}"
                        )
                    layer_idx_resnet += 1
                # actually apply the change
                block.resnets = nn.ModuleList(resnet_wrappers)
        if share_attn:
            for transformer_layer in block.attentions:
                transformer_block: attention.BasicTransformerBlock = (
                    transformer_layer.transformer_blocks[0]
                )
                self_attn: attention_processor.Attention = transformer_block.attn1
                # cross-attention is only injected when share_cross_attn is set
                cross_attn: attention_processor.Attention = transformer_block.attn2
                if attn_mode == "artist":
                    if (
                        share_attn_layers is not None
                        and layer_idx_attn in share_attn_layers
                    ):
                        content_to_style = (
                            c2s_layers is not None and layer_idx_attn in c2s_layers
                        )
                        pnp_inject_processor = ArtistAttentionProcessor(
                            inject_query=share_query,
                            inject_key=share_key,
                            inject_value=share_value,
                            use_adain=use_adain,
                            name=f"layer_{layer_idx_attn}_self",
                            use_content_to_style_injection=content_to_style,
                        )
                        self_attn.set_processor(pnp_inject_processor)
                        print(
                            f"Disentangled PnP inject processor set for self-attention in layer {layer_idx_attn} with c2s={content_to_style}"
                        )
                        if share_cross_attn:
                            cross_attn_processor = ArtistAttentionProcessor(
                                inject_query=False,
                                inject_key=True,
                                inject_value=True,
                                use_adain=False,
                                name=f"layer_{layer_idx_attn}_cross",
                            )
                            cross_attn.set_processor(cross_attn_processor)
                            print(
                                f"Disentangled PnP inject processor set for cross-attention in layer {layer_idx_attn}"
                            )
                layer_idx_attn += 1
def unset_attention_processors(
    pipe,
    unset_share_attn: bool = False,
    unset_share_resblock: bool = False,
):
    unet: unet_2d_condition.UNet2DConditionModel = pipe.unet
    if isinstance(pipe, StableDiffusionPipeline):
        # skip the first block, which is a plain UpBlock2D without attention
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[1:]
    elif isinstance(pipe, StableDiffusionXLPipeline):
        up_blocks: List[unet_2d_blocks.CrossAttnUpBlock2D] = unet.up_blocks[:-1]
    else:
        raise NotImplementedError(f"Unsupported pipeline type {type(pipe)}")

    for block in up_blocks:
        if unset_share_resblock:
            resnet_origs = []
            for resnet_block in block.resnets:
                if isinstance(
                    resnet_block, (SharedResBlockWrapper, ArtistResBlockWrapper)
                ):
                    # unwrap back to the original ResnetBlock2D
                    resnet_origs.append(resnet_block.block)
                else:
                    resnet_origs.append(resnet_block)
            block.resnets = nn.ModuleList(resnet_origs)
        if unset_share_attn:
            for transformer_layer in block.attentions:
                transformer_block: attention.BasicTransformerBlock = (
                    transformer_layer.transformer_blocks[0]
                )
                self_attn: attention_processor.Attention = transformer_block.attn1
                cross_attn: attention_processor.Attention = transformer_block.attn2
                self_attn.set_processor(DefaultAttentionProcessor())
                cross_attn.set_processor(DefaultAttentionProcessor())
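

if __name__ == "__main__":
    # Hedged usage sketch, not the full CPSD workflow: it only demonstrates how the
    # processors are registered on a diffusers pipeline and later removed. The model
    # id, device, and layer lists below are placeholder assumptions; the actual method
    # additionally relies on the batch layout
    # [reconstruction, uncontrolled style, controlled style] under CFG.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    register_attention_processors(
        pipe,
        disentangle=True,
        attn_mode="artist",
        resnet_mode="hidden",
        share_resblock=True,
        share_attn=True,
        share_attn_layers=list(range(6)),
        share_resnet_layers=list(range(6)),
        use_adain=True,
    )

    # ... run the stylization / sampling loop here (omitted) ...

    unset_attention_processors(
        pipe, unset_share_attn=True, unset_share_resblock=True
    )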