# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
from dataclasses import dataclass
from diffusers import StableDiffusionXLPipeline
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from diffusers.models import attention_processor
import einops

T = torch.Tensor


@dataclass(frozen=True)
class StyleAlignedArgs:
    share_group_norm: bool = True
    share_layer_norm: bool = True
    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = False
    full_attention_share: bool = False
    keys_scale: float = 1.
    only_self_level: float = 0.


def expand_first(feat: T, scale=1.) -> T:
    # Broadcast the reference (first) image's features over each half of the
    # classifier-free-guidance batch: feat[0] is the reference in the
    # unconditional half, feat[b // 2] in the conditional half.
    b = feat.shape[0]
    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
    if scale == 1:
        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
    else:
        feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)


def concat_first(feat: T, dim=2, scale=1.) -> T:
    # Concatenate every image's features with the reference image's features.
    feat_style = expand_first(feat, scale=scale)
    return torch.cat((feat, feat_style), dim=dim)


def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
    feat_std = (feat.var(dim=-2, keepdim=True) + eps).sqrt()
    feat_mean = feat.mean(dim=-2, keepdim=True)
    return feat_mean, feat_std


def adain(feat: T) -> T:
    # Adaptive instance normalization: align each image's feature statistics
    # with those of the reference image.
    feat_mean, feat_std = calc_mean_std(feat)
    feat_style_mean = expand_first(feat_mean)
    feat_style_std = expand_first(feat_std)
    feat = (feat - feat_mean) / feat_std
    feat = feat * feat_style_std + feat_style_mean
    return feat


class DefaultAttentionProcessor(nn.Module):

    def __init__(self):
        super().__init__()
        self.processor = attention_processor.AttnProcessor2_0()

    def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        return self.processor(attn, hidden_states, encoder_hidden_states, attention_mask)


class SharedAttentionProcessor(DefaultAttentionProcessor):

    def shared_call(
            self,
            attn: attention_processor.Attention,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            **kwargs
    ):
        residual = hidden_states
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # if self.step >= self.start_inject:

        # Align query/key/value statistics with the reference image (AdaIN), then
        # optionally let every image also attend to the reference image's keys
        # and values (shared attention).
        if self.adain_queries:
            query = adain(query)
        if self.adain_keys:
            key = adain(key)
        if self.adain_values:
            value = adain(value)
        if self.share_attention:
            key = concat_first(key, -2, scale=self.keys_scale)
            value = concat_first(value, -2)
            hidden_states = nnf.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        else:
            hidden_states = nnf.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        # hidden_states = adain(hidden_states)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        if self.full_attention_share:
            b, n, d = hidden_states.shape
            # Merge all images in each half of the CFG batch into one long token
            # sequence so that attention is fully shared across the batch.
            hidden_states = einops.rearrange(hidden_states, '(k b) n d -> k (b n) d', k=2)
            hidden_states = super().__call__(attn, hidden_states, encoder_hidden_states=encoder_hidden_states,
                                             attention_mask=attention_mask, **kwargs)
            hidden_states = einops.rearrange(hidden_states, 'k (b n) d -> (k b) n d', n=n)
        else:
            hidden_states = self.shared_call(attn, hidden_states, hidden_states, attention_mask, **kwargs)

        return hidden_states

    def __init__(self, style_aligned_args: StyleAlignedArgs):
        super().__init__()
        self.share_attention = style_aligned_args.share_attention
        self.adain_queries = style_aligned_args.adain_queries
        self.adain_keys = style_aligned_args.adain_keys
        self.adain_values = style_aligned_args.adain_values
        self.full_attention_share = style_aligned_args.full_attention_share
        self.keys_scale = style_aligned_args.keys_scale


def _get_switch_vec(total_num_layers, level):
    # Boolean mask over the self-attention layers: True entries keep the default
    # processor; roughly a `level` fraction of the layers is marked True.
    if level == 0:
        return torch.zeros(total_num_layers, dtype=torch.bool)
    if level == 1:
        return torch.ones(total_num_layers, dtype=torch.bool)
    to_flip = level > .5
    if to_flip:
        level = 1 - level
    num_switch = int(level * total_num_layers)
    vec = torch.arange(total_num_layers)
    vec = vec % (total_num_layers // num_switch)
    vec = vec == 0
    if to_flip:
        vec = ~vec
    return vec


def init_attention_processors(pipeline: StableDiffusionXLPipeline,
                              style_aligned_args: StyleAlignedArgs | None = None):
    # Install shared self-attention processors; passing None restores the defaults.
    attn_procs = {}
    unet = pipeline.unet
    number_of_self, number_of_cross = 0, 0
    num_self_layers = len([name for name in unet.attn_processors.keys() if 'attn1' in name])
    if style_aligned_args is None:
        only_self_vec = _get_switch_vec(num_self_layers, 1)
    else:
        only_self_vec = _get_switch_vec(num_self_layers, style_aligned_args.only_self_level)
    for i, name in enumerate(unet.attn_processors.keys()):
        is_self_attention = 'attn1' in name
        if is_self_attention:
            number_of_self += 1
            if style_aligned_args is None or only_self_vec[i // 2]:
                attn_procs[name] = DefaultAttentionProcessor()
            else:
                attn_procs[name] = SharedAttentionProcessor(style_aligned_args)
        else:
            number_of_cross += 1
            attn_procs[name] = DefaultAttentionProcessor()

    unet.set_attn_processor(attn_procs)


def register_shared_norm(pipeline: StableDiffusionXLPipeline,
                         share_group_norm: bool = True,
                         share_layer_norm: bool = True,
                         ):
    def register_norm_forward(norm_layer: nn.GroupNorm | nn.LayerNorm) -> nn.GroupNorm | nn.LayerNorm:
        if not hasattr(norm_layer, 'orig_forward'):
            setattr(norm_layer, 'orig_forward', norm_layer.forward)
        orig_forward = norm_layer.orig_forward

        def forward_(hidden_states: T) -> T:
            # Normalize each image's features together with the reference
            # image's features, then drop the appended reference tokens.
            n = hidden_states.shape[-2]
            hidden_states = concat_first(hidden_states, dim=-2)
            hidden_states = orig_forward(hidden_states)
            return hidden_states[..., :n, :]

        norm_layer.forward = forward_
        return norm_layer

    def get_norm_layers(pipeline_, norm_layers_: dict[str, list[nn.GroupNorm | nn.LayerNorm]]):
        if isinstance(pipeline_, nn.LayerNorm) and share_layer_norm:
            norm_layers_['layer'].append(pipeline_)
        if isinstance(pipeline_, nn.GroupNorm) and share_group_norm:
            norm_layers_['group'].append(pipeline_)
        else:
            for layer in pipeline_.children():
                get_norm_layers(layer, norm_layers_)

    norm_layers = {'group': [], 'layer': []}
    get_norm_layers(pipeline.unet, norm_layers)
    return ([register_norm_forward(layer) for layer in norm_layers['group']] +
            [register_norm_forward(layer) for layer in norm_layers['layer']])


class Handler:

    def register(self, style_aligned_args: StyleAlignedArgs, ):
        # Enable StyleAligned: share norm statistics and install shared
        # self-attention processors on the pipeline's UNet.
        self.norm_layers = register_shared_norm(self.pipeline, style_aligned_args.share_group_norm,
                                                style_aligned_args.share_layer_norm)
        init_attention_processors(self.pipeline, style_aligned_args)

    def remove(self):
        # Restore the original norm forwards and default attention processors.
        for layer in self.norm_layers:
            layer.forward = layer.orig_forward
        self.norm_layers = []
        init_attention_processors(self.pipeline, None)

    def __init__(self, pipeline: StableDiffusionXLPipeline):
        self.pipeline = pipeline
        self.norm_layers = []
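

# Usage sketch: the checkpoint id, prompts, and argument overrides below are
# illustrative placeholders, not values fixed by this module. The first prompt
# in the batch acts as the style reference for the others.
#
#     import torch
#     from diffusers import StableDiffusionXLPipeline
#
#     pipeline = StableDiffusionXLPipeline.from_pretrained(
#         'stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16,
#     ).to('cuda')
#     handler = Handler(pipeline)
#     handler.register(StyleAlignedArgs(share_group_norm=False, share_layer_norm=False))
#     images = pipeline(['a cat, watercolor painting',
#                        'a dog, watercolor painting']).images
#     handler.remove()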