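# NOTE(review): the following lines appear to be web-page residue (a status
# banner) captured together with this file; they are not source code.
# Spaces:
# Runtime error
# Runtime error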
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations | |
from dataclasses import dataclass | |
from diffusers import StableDiffusionXLPipeline | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as nnf | |
from diffusers.models import attention_processor | |
import einops | |
# Shorthand alias for the tensor type used in this module's annotations.
T = torch.Tensor
@dataclass
class StyleAlignedArgs:
    """Configuration for StyleAligned feature sharing.

    Declared as a dataclass so callers can override any field by keyword,
    e.g. ``StyleAlignedArgs(share_group_norm=False, keys_scale=2.0)``.
    """
    # Patch nn.GroupNorm layers so their statistics include the reference image
    # (see register_shared_norm).
    share_group_norm: bool = True
    # Patch nn.LayerNorm layers the same way (see register_shared_norm).
    share_layer_norm: bool = True
    # Append the first image of each batch half to the keys/values of every
    # image in that half (see SharedAttentionProcessor).
    share_attention: bool = True
    # Apply AdaIN toward the first image of each batch half to the projected
    # queries / keys / values respectively.
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = False
    # Instead of appending the reference, attend jointly over all tokens of all
    # images within each batch half.
    full_attention_share: bool = False
    # Multiplier applied to the appended reference keys for every image except
    # the first one of each half (passed to concat_first as ``scale``).
    keys_scale: float = 1.
    # Approximate fraction of self-attention layers that keep the default,
    # non-shared processor (see _get_switch_vec / init_attention_processors).
    only_self_level: float = 0.
def expand_first(feat: T, scale=1.) -> T:
    """Replace every item of each batch half with that half's first item.

    The batch (dim 0) is treated as two halves of ``b // 2`` items each; the
    items at index ``0`` and ``b // 2`` act as the references. The result has
    the same shape as ``feat``, so ``b`` must be even for the final reshape.

    Args:
        feat: tensor whose leading dim is the batch; any number of trailing
            dims is supported.
        scale: multiplier applied to every copy of a reference except the one
            occupying the reference's own slot in each half.

    Returns:
        Tensor shaped like ``feat``.
    """
    b = feat.shape[0]
    half = b // 2
    # Shape (2, 1, *feat.shape[1:]): one reference per batch half.
    feat_style = torch.stack((feat[0], feat[half])).unsqueeze(1)
    if scale == 1:
        # Nothing is scaled, so a broadcast view is enough (no copy).
        feat_style = feat_style.expand(2, half, *feat.shape[1:])
    else:
        # Materialize the copies so all but the first in each half can be
        # scaled. Repeat factors are built from the input rank instead of
        # being hard-coded for 4-D tensors.
        feat_style = feat_style.repeat(1, half, *([1] * (feat.dim() - 1)))
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)
def concat_first(feat: T, dim=2, scale=1.) -> T:
    """Append the per-half reference copy of ``feat`` to ``feat`` along ``dim``.

    The appended tensor is ``expand_first(feat, scale=scale)``, which has the
    same shape as ``feat``, so the result is twice as long along ``dim``.
    """
    return torch.cat([feat, expand_first(feat, scale=scale)], dim=dim)
def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
    """Return ``(mean, std)`` of ``feat`` over dim -2, keeping that dim.

    The std is the square root of ``Tensor.var``'s default estimate plus
    ``eps``; ``eps`` keeps the later division in ``adain`` finite.
    """
    mean = feat.mean(dim=-2, keepdim=True)
    std = feat.var(dim=-2, keepdim=True).add(eps).sqrt()
    return mean, std
def adain(feat: T) -> T:
    """Adaptive instance normalization toward each batch half's first item.

    ``feat`` is normalized with its own mean/std over dim -2 (``calc_mean_std``).
    It is then re-scaled and shifted with the statistics of the first item of
    its batch half (``expand_first``).
    """
    mean, std = calc_mean_std(feat)
    normalized = (feat - mean) / std
    return normalized * expand_first(std) + expand_first(mean)
class DefaultAttentionProcessor(nn.Module):
    """Attention processor that delegates to diffusers' ``AttnProcessor2_0``.

    Extra keyword arguments are accepted but are not forwarded to the wrapped
    processor.
    """

    def __init__(self):
        super().__init__()
        # The stock processor every call is delegated to.
        self.processor = attention_processor.AttnProcessor2_0()

    def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        wrapped = self.processor
        return wrapped(attn, hidden_states, encoder_hidden_states, attention_mask)
class SharedAttentionProcessor(DefaultAttentionProcessor):
    """Self-attention processor implementing StyleAligned feature sharing.

    Behaviour is configured from a ``StyleAlignedArgs``:

    * ``adain_queries`` / ``adain_keys`` / ``adain_values``: apply ``adain`` to
      the projected queries / keys / values.
    * ``share_attention``: append the keys/values of the first image of each
      batch half (``concat_first`` along the sequence axis), so every image
      also attends to that reference; ``keys_scale`` scales the appended keys.
    * ``full_attention_share``: instead run plain attention jointly over all
      tokens of all images within each batch half.
    """

    def __init__(self, style_aligned_args: StyleAlignedArgs):
        super().__init__()
        self.share_attention = style_aligned_args.share_attention
        self.adain_queries = style_aligned_args.adain_queries
        self.adain_keys = style_aligned_args.adain_keys
        self.adain_values = style_aligned_args.adain_values
        self.full_attention_share = style_aligned_args.full_attention_share
        self.keys_scale = style_aligned_args.keys_scale

    def shared_call(
            self,
            attn: attention_processor.Attention,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            **kwargs
    ):
        """Attention forward pass with optional AdaIN and reference sharing.

        Queries, keys and values are all projected from ``hidden_states``.
        ``encoder_hidden_states`` only supplies the batch size / sequence
        length used when preparing ``attention_mask``.
        """
        residual = hidden_states
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if self.adain_queries:
            query = adain(query)
        if self.adain_keys:
            key = adain(key)
        if self.adain_values:
            value = adain(value)
        if self.share_attention:
            # Append the reference tokens along the sequence axis (dim -2);
            # only the key/value length changes, the attention call does not.
            key = concat_first(key, -2, scale=self.keys_scale)
            value = concat_first(value, -2)
        hidden_states = nnf.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        if self.full_attention_share:
            _, n, _ = hidden_states.shape
            # Fold each batch half into one long sequence so all of its images
            # attend to each other, then unfold back to per-image sequences.
            hidden_states = einops.rearrange(hidden_states, '(k b) n d -> k (b n) d', k=2)
            hidden_states = super().__call__(attn, hidden_states, encoder_hidden_states=encoder_hidden_states,
                                             attention_mask=attention_mask, **kwargs)
            hidden_states = einops.rearrange(hidden_states, 'k (b n) d -> (k b) n d', n=n)
        else:
            # Self-attention: the input itself is passed as encoder_hidden_states.
            hidden_states = self.shared_call(attn, hidden_states, hidden_states, attention_mask, **kwargs)
        return hidden_states
def _get_switch_vec(total_num_layers, level): | |
if level == 0: | |
return torch.zeros(total_num_layers, dtype=torch.bool) | |
if level == 1: | |
return torch.ones(total_num_layers, dtype=torch.bool) | |
to_flip = level > .5 | |
if to_flip: | |
level = 1 - level | |
num_switch = int(level * total_num_layers) | |
vec = torch.arange(total_num_layers) | |
vec = vec % (total_num_layers // num_switch) | |
vec = vec == 0 | |
if to_flip: | |
vec = ~vec | |
return vec | |
def init_attention_processors(pipeline: StableDiffusionXLPipeline, style_aligned_args: StyleAlignedArgs | None = None):
    """Install attention processors on ``pipeline.unet``.

    Layers whose name does not contain ``'attn1'`` always get a
    ``DefaultAttentionProcessor``. Layers whose name contains ``'attn1'`` get a
    ``SharedAttentionProcessor(style_aligned_args)``, except in two cases where
    they get a ``DefaultAttentionProcessor``:

    * ``style_aligned_args`` is None;
    * the layer is selected by
      ``_get_switch_vec(num_self_layers, style_aligned_args.only_self_level)``.
    """
    unet = pipeline.unet
    names = list(unet.attn_processors.keys())
    num_self_layers = sum('attn1' in name for name in names)
    if style_aligned_args is None:
        only_self_vec = _get_switch_vec(num_self_layers, 1)
    else:
        only_self_vec = _get_switch_vec(num_self_layers, style_aligned_args.only_self_level)
    attn_procs = {}
    self_index = 0
    for name in names:
        if 'attn1' in name:
            # Index by the running count of self-attention layers. Using
            # `i // 2` only works if processors strictly alternate attn1/attn2.
            use_default = style_aligned_args is None or only_self_vec[self_index]
            self_index += 1
            if use_default:
                attn_procs[name] = DefaultAttentionProcessor()
            else:
                attn_procs[name] = SharedAttentionProcessor(style_aligned_args)
        else:
            attn_procs[name] = DefaultAttentionProcessor()
    unet.set_attn_processor(attn_procs)
def register_shared_norm(pipeline: StableDiffusionXLPipeline,
                         share_group_norm: bool = True,
                         share_layer_norm: bool = True, ):
    """Patch norm layers in ``pipeline.unet`` to normalize together with the reference.

    Each selected ``nn.GroupNorm`` (if ``share_group_norm``) and ``nn.LayerNorm``
    (if ``share_layer_norm``) gets its ``forward`` replaced by a wrapper that:

    1. appends ``expand_first`` of the input along dim -2 (``concat_first``);
    2. runs the original forward;
    3. keeps only the first ``n`` positions along dim -2.

    The original bound ``forward`` is stored once as ``orig_forward``, so
    registering again wraps the original rather than a previous wrapper.

    Returns:
        The patched layers: all GroupNorms first, then all LayerNorms.
    """
    def register_norm_forward(norm_layer: nn.GroupNorm | nn.LayerNorm) -> nn.GroupNorm | nn.LayerNorm:
        if not hasattr(norm_layer, 'orig_forward'):
            setattr(norm_layer, 'orig_forward', norm_layer.forward)
        orig_forward = norm_layer.orig_forward

        def forward_(hidden_states: T) -> T:
            n = hidden_states.shape[-2]
            hidden_states = concat_first(hidden_states, dim=-2)
            hidden_states = orig_forward(hidden_states)
            return hidden_states[..., :n, :]

        norm_layer.forward = forward_
        return norm_layer

    def get_norm_layers(module, norm_layers_: dict[str, list[nn.GroupNorm | nn.LayerNorm]]):
        # Collect matching norm layers; recurse into every other module.
        if isinstance(module, nn.LayerNorm) and share_layer_norm:
            norm_layers_['layer'].append(module)
        elif isinstance(module, nn.GroupNorm) and share_group_norm:
            norm_layers_['group'].append(module)
        else:
            for child in module.children():
                get_norm_layers(child, norm_layers_)

    norm_layers = {'group': [], 'layer': []}
    get_norm_layers(pipeline.unet, norm_layers)
    return [register_norm_forward(layer) for layer in norm_layers['group'] + norm_layers['layer']]
class Handler:
    """Installs and removes StyleAligned patches on a pipeline's UNet."""

    def __init__(self, pipeline: StableDiffusionXLPipeline):
        self.pipeline = pipeline
        # Norm layers currently patched by register_shared_norm.
        self.norm_layers = []

    def _restore_norm_layers(self):
        """Restore the original ``forward`` of every norm layer this handler patched."""
        for layer in self.norm_layers:
            layer.forward = layer.orig_forward
        self.norm_layers = []

    def register(self, style_aligned_args: StyleAlignedArgs, ):
        """Patch norm layers and attention processors per ``style_aligned_args``."""
        # Undo any previous registration first. Otherwise, layers patched by an
        # earlier call (e.g. GroupNorms when share_group_norm was True) would be
        # dropped from self.norm_layers and never restored by remove().
        self._restore_norm_layers()
        self.norm_layers = register_shared_norm(self.pipeline, style_aligned_args.share_group_norm,
                                                style_aligned_args.share_layer_norm)
        init_attention_processors(self.pipeline, style_aligned_args)

    def remove(self):
        """Restore patched norm layers and reinstall default attention processors."""
        self._restore_norm_layers()
        init_attention_processors(self.pipeline, None)