from typing import List, Optional, Callable

import torch
import torch.nn.functional as F

from config import RunConfig
from constants import OUT_INDEX, STRUCT_INDEX, STYLE_INDEX
from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
from utils import attention_utils
from utils.adain import masked_adain
from utils.model_utils import get_stable_diffusion_model
from utils.segmentation import Segmentor


class AppearanceTransferModel:
    """Stable Diffusion wrapper that adds cross-image attention and masked AdaIN.

    On construction the model installs a custom attention processor on every
    ``Attention`` module found under the UNet's "down", "up" and "mid"
    children. The processor can overwrite the keys/values of the batch entry
    at ``OUT_INDEX`` with those of the ``STRUCT_INDEX`` or ``STYLE_INDEX``
    entry. ``get_adain_callback`` returns a per-step callback that tracks the
    current step and applies ``masked_adain`` to the latents.
    """

    # Schedule for the key/value swap in the attention processor: on steps
    # that are a multiple of _STRUCT_INJECT_EVERY and below
    # _STRUCT_INJECT_BEFORE_STEP the structure entry is injected; on every
    # other step the appearance (style) entry is injected.
    _STRUCT_INJECT_EVERY = 5
    _STRUCT_INJECT_BEFORE_STEP = 40

    def __init__(self, config: RunConfig, pipe: Optional[CrossImageAttentionStableDiffusionPipeline] = None):
        """Build the model and register the attention processors.

        Args:
            config: Run configuration. This class reads ``prompt``,
                ``object_noun``, ``adain_range``, ``use_masked_adain`` and
                ``contrast_strength`` from it.
            pipe: Pipeline to wrap. When ``None``, one is obtained from
                ``get_stable_diffusion_model()``.
        """
        self.config = config
        self.pipe = get_stable_diffusion_model() if pipe is None else pipe
        self.register_attention_control()
        self.segmentor = Segmentor(prompt=config.prompt, object_nouns=[config.object_noun])
        # Populated later through set_latents / set_noise / set_masks.
        self.latents_app, self.latents_struct = None, None
        self.zs_app, self.zs_struct = None, None
        self.image_app_mask_32, self.image_app_mask_64 = None, None
        self.image_struct_mask_32, self.image_struct_mask_64 = None, None
        # Key/value swapping in the attention processor is gated on this flag.
        self.enable_edit = False
        # Last step index reported to the AdaIN callback.
        self.step = 0

    def set_latents(self, latents_app: torch.Tensor, latents_struct: torch.Tensor):
        """Store the appearance and structure latents on the model."""
        self.latents_app = latents_app
        self.latents_struct = latents_struct

    def set_noise(self, zs_app: torch.Tensor, zs_struct: torch.Tensor):
        """Store the appearance and structure noise tensors on the model."""
        self.zs_app = zs_app
        self.zs_struct = zs_struct

    def set_masks(self, masks: List[torch.Tensor]):
        """Store the four object masks.

        Args:
            masks: Exactly four items, unpacked in this order: appearance
                mask (32), structure mask (32), appearance mask (64),
                structure mask (64). Any other length raises ``ValueError``
                from the unpacking.
        """
        (
            self.image_app_mask_32,
            self.image_struct_mask_32,
            self.image_app_mask_64,
            self.image_struct_mask_64,
        ) = masks

    def get_adain_callback(self) -> Callable[[int, int, torch.FloatTensor], None]:
        """Return a per-step callback that applies masked AdaIN to the latents.

        The returned callable takes ``(st, timestep, latents)``, records
        ``st`` in ``self.step``, and mutates ``latents`` in place. It returns
        ``None``.
        """

        def callback(st: int, timestep: int, latents: torch.FloatTensor) -> None:
            # `timestep` is unused; it is presumably part of the signature the
            # pipeline expects from a step callback -- TODO confirm.
            self.step = st
            adain_range = self.config.adain_range

            # Compute the masks using prompt mixing self-segmentation and use the masks for AdaIN operation
            # NOTE(review): this branch does not check `config.use_masked_adain`, while the
            # attention processor only calls `segmentor.update_attention` when that flag is
            # set. With the flag off the segmentor may have no attention maps here -- confirm.
            if self.step == adain_range.start:
                self.set_masks(self.segmentor.get_object_masks())

            # Apply AdaIN operation using the computed masks
            if adain_range.start <= self.step < adain_range.end:
                # Indices 0 and 1 are hard-coded here rather than taken from the
                # *_INDEX constants; presumably output and appearance -- TODO confirm.
                latents[0] = masked_adain(
                    latents[0],
                    latents[1],
                    self.image_struct_mask_64,
                    self.image_app_mask_64,
                )

        return callback

    def register_attention_control(self):
        """Install ``AttentionProcessor`` on the UNet's attention modules.

        Walks each child of ``self.pipe.unet`` whose name contains "down",
        "up" or "mid" (checked in that order) and replaces the processor of
        every module whose class is named ``Attention``. Each processor is
        tagged with ``"<place>_<n>"``, where ``n`` counts attention modules
        from 1 within that child.
        """
        model_self = self

        class AttentionProcessor:
            """Attention processor with optional cross-image key/value swapping."""

            def __init__(self, place_in_unet: str):
                """Remember the processor's location tag.

                Raises:
                    ImportError: If ``torch.nn.functional`` has no
                        ``scaled_dot_product_attention``.
                """
                self.place_in_unet = place_in_unet
                if not hasattr(F, "scaled_dot_product_attention"):
                    raise ImportError("AttnProcessor2_0 requires torch 2.0, to use it, please upgrade torch to 2.0.")

            def __call__(self,
                         attn,
                         hidden_states: torch.Tensor,
                         encoder_hidden_states: Optional[torch.Tensor] = None,
                         attention_mask=None,
                         temb=None,
                         perform_swap: bool = False):
                """Run attention for ``attn``, optionally swapping keys/values.

                Args:
                    attn: The attention module being processed. Its
                        projections, norms and ``heads`` are used.
                    hidden_states: Input features, either 3-D or 4-D. A 4-D
                        input is flattened over its last two dims and
                        restored on return.
                    encoder_hidden_states: Context for cross-attention.
                        ``None`` means self-attention.
                    attention_mask: Optional mask, prepared through
                        ``attn.prepare_attention_mask``.
                    temb: Passed to ``attn.spatial_norm`` when that is set.
                    perform_swap: Whether the caller requests the cross-image
                        key/value swap for this call.

                Returns:
                    The attended hidden states, in the same layout as the input.
                """
                residual = hidden_states

                if attn.spatial_norm is not None:
                    hidden_states = attn.spatial_norm(hidden_states, temb)

                input_ndim = hidden_states.ndim

                if input_ndim == 4:
                    batch_size, channel, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                batch_size, sequence_length, _ = (
                    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                )

                if attention_mask is not None:
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                    attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                if attn.group_norm is not None:
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                query = attn.to_q(hidden_states)

                is_cross = encoder_hidden_states is not None
                if not is_cross:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)

                inner_dim = key.shape[-1]
                head_dim = inner_dim // attn.heads

                should_mix = False

                # Potentially apply our cross image attention operation
                # To do so, we need to be in a self-attention layer in the decoder part of the denoising network
                if perform_swap and not is_cross and "up" in self.place_in_unet and model_self.enable_edit:
                    if attention_utils.should_mix_keys_and_values(model_self, hidden_states):
                        should_mix = True
                        inject_structure = (
                            model_self.step % model_self._STRUCT_INJECT_EVERY == 0
                            and model_self.step < model_self._STRUCT_INJECT_BEFORE_STEP
                        )
                        # The assignments below write in place along the first (batch) dimension.
                        if inject_structure:
                            # Inject the structure's keys and values
                            key[OUT_INDEX] = key[STRUCT_INDEX]
                            value[OUT_INDEX] = value[STRUCT_INDEX]
                        else:
                            # Inject the appearance's keys and values
                            key[OUT_INDEX] = key[STYLE_INDEX]
                            value[OUT_INDEX] = value[STYLE_INDEX]

                query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                # Compute the cross attention and apply our contrasting operation
                hidden_states, attn_weight = attention_utils.compute_scaled_dot_product_attention(
                    query, key, value,
                    edit_map=perform_swap and model_self.enable_edit and should_mix,
                    is_cross=is_cross,
                    contrast_strength=model_self.config.contrast_strength,
                )

                # Update attention map for segmentation
                # Done while `step` is one below adain_range.start, i.e. just before the
                # callback asks the segmentor for masks at adain_range.start.
                if model_self.config.use_masked_adain and model_self.step == model_self.config.adain_range.start - 1:
                    model_self.segmentor.update_attention(attn_weight, is_cross)

                hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                # A tensor has a single dtype, so no slice is needed to read it.
                hidden_states = hidden_states.to(query.dtype)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states)
                # dropout
                hidden_states = attn.to_out[1](hidden_states)

                if input_ndim == 4:
                    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                if attn.residual_connection:
                    hidden_states = hidden_states + residual

                hidden_states = hidden_states / attn.rescale_output_factor

                return hidden_states

        def register_recr(net_, count, place_in_unet):
            """Recursively set the processor on ``Attention`` modules; return the updated count."""
            if net_.__class__.__name__ == 'Attention':
                net_.set_processor(AttentionProcessor(place_in_unet + f"_{count + 1}"))
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        # Priority order matters: a child name is matched against "down" first, then "up", then "mid".
        places = ("down", "up", "mid")
        for child_name, child in self.pipe.unet.named_children():
            place = next((p for p in places if p in child_name), None)
            if place is not None:
                register_recr(child, 0, place)