# cross-image-attention / appearance_transfer_model.py
# yuvalalaluf: initial commit (82ef366)
from typing import List, Optional, Callable
import torch
import torch.nn.functional as F
from config import RunConfig
from constants import OUT_INDEX, STRUCT_INDEX, STYLE_INDEX
from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
from utils import attention_utils
from utils.adain import masked_adain
from utils.model_utils import get_stable_diffusion_model
from utils.segmentation import Segmentor
class AppearanceTransferModel:
    """Stable Diffusion pipeline wrapper with cross-image attention and AdaIN hooks.

    On construction a custom attention processor is installed on every
    ``Attention`` module found under the UNet children whose names contain
    "down", "up" or "mid". When editing is enabled, that processor can overwrite
    the keys and values at ``OUT_INDEX`` with those at ``STRUCT_INDEX`` or
    ``STYLE_INDEX`` of the same batch.
    """

    # Schedule for the key/value injection performed by the attention processor:
    # the structure keys/values are injected on every
    # ``_STRUCT_INJECTION_INTERVAL``-th step while the step is below
    # ``_STRUCT_INJECTION_STEP_LIMIT``; on all other steps the appearance
    # keys/values are injected instead.
    _STRUCT_INJECTION_INTERVAL = 5
    _STRUCT_INJECTION_STEP_LIMIT = 40

    def __init__(self,
                 config: "RunConfig",
                 pipe: Optional["CrossImageAttentionStableDiffusionPipeline"] = None):
        """Store the configuration, install the attention processors and reset state.

        Args:
            config: Run configuration. This class reads ``prompt``, ``object_noun``,
                ``adain_range``, ``use_masked_adain`` and ``contrast_strength`` from it.
            pipe: Pipeline to hook into. When ``None``, one is created with
                ``get_stable_diffusion_model()``.
        """
        self.config = config
        self.pipe = get_stable_diffusion_model() if pipe is None else pipe
        self.register_attention_control()
        self.segmentor = Segmentor(prompt=config.prompt, object_nouns=[config.object_noun])
        self.latents_app, self.latents_struct = None, None
        self.zs_app, self.zs_struct = None, None
        self.image_app_mask_32, self.image_app_mask_64 = None, None
        self.image_struct_mask_32, self.image_struct_mask_64 = None, None
        # The attention processors only swap keys/values while this flag is set.
        self.enable_edit = False
        # Current denoising step, updated by the callback from `get_adain_callback`.
        self.step = 0

    def set_latents(self, latents_app: torch.Tensor, latents_struct: torch.Tensor) -> None:
        """Store the latents of the appearance and structure images."""
        self.latents_app = latents_app
        self.latents_struct = latents_struct

    def set_noise(self, zs_app: torch.Tensor, zs_struct: torch.Tensor) -> None:
        """Store the noise tensors of the appearance and structure images."""
        self.zs_app = zs_app
        self.zs_struct = zs_struct

    def set_masks(self, masks: List[torch.Tensor]) -> None:
        """Store the four object masks.

        Args:
            masks: Exactly four items, in the order
                ``(app_32, struct_32, app_64, struct_64)``.
        """
        (self.image_app_mask_32,
         self.image_struct_mask_32,
         self.image_app_mask_64,
         self.image_struct_mask_64) = masks

    def get_adain_callback(self) -> Callable[[int, int, torch.FloatTensor], None]:
        """Build the per-step callback that tracks the step and applies masked AdaIN.

        Returns:
            A function ``callback(st, timestep, latents)``. It records ``st`` in
            ``self.step``, modifies ``latents`` in place and returns ``None``.
        """

        def callback(st: int, timestep: int, latents: torch.FloatTensor) -> None:
            self.step = st
            adain_range = self.config.adain_range

            # Compute the masks using prompt mixing self-segmentation and use the masks for AdaIN operation
            # NOTE(review): the attention processor only feeds attention maps to the
            # segmentor when `config.use_masked_adain` is set, while the masks are
            # requested here unconditionally - confirm `use_masked_adain=False` is supported.
            if self.step == adain_range.start:
                self.set_masks(self.segmentor.get_object_masks())

            # Apply AdaIN operation using the computed masks
            if adain_range.start <= self.step < adain_range.end:
                latents[0] = masked_adain(
                    latents[0],
                    latents[1],
                    self.image_struct_mask_64,
                    self.image_app_mask_64,
                )

        return callback

    def register_attention_control(self) -> None:
        """Install the cross-image attention processor on the UNet's ``Attention`` modules.

        Each processor is labelled ``"<region>_<k>"``, where ``<region>`` is
        "down", "up" or "mid" and ``<k>`` counts the ``Attention`` modules of one
        top-level UNet child, starting at 1.
        """
        model_self = self

        class AttentionProcessor:
            """Attention processor that can swap in another image's keys and values."""

            def __init__(self, place_in_unet: str):
                """Remember the label of the layer this processor is attached to.

                Raises:
                    ImportError: If ``torch.nn.functional`` lacks
                        ``scaled_dot_product_attention``.
                """
                self.place_in_unet = place_in_unet
                if not hasattr(F, "scaled_dot_product_attention"):
                    raise ImportError("AttnProcessor2_0 requires torch 2.0, to use it, please upgrade torch to 2.0.")

            def __call__(self,
                         attn,
                         hidden_states: torch.Tensor,
                         encoder_hidden_states: Optional[torch.Tensor] = None,
                         attention_mask=None,
                         temb=None,
                         perform_swap: bool = False):
                """Run attention for ``attn``, optionally swapping keys/values across images.

                Args:
                    attn: Attention module providing the projections, norms and head count.
                    hidden_states: Input features; a 4-D input is flattened over its
                        last two dimensions and restored before returning.
                    encoder_hidden_states: Context for cross-attention. ``None``
                        means self-attention over ``hidden_states``.
                    attention_mask: Optional mask. NOTE(review): it is prepared below
                        but not passed on to the attention helper - confirm this is intended.
                    temb: Passed to ``attn.spatial_norm`` when that norm exists.
                    perform_swap: Allows the key/value swap for this call. The swap
                        also requires self-attention, an "up" layer and
                        ``model_self.enable_edit``.

                Returns:
                    The attended hidden states, after the output projection, the
                    optional residual connection and the rescaling.
                """
                residual = hidden_states

                if attn.spatial_norm is not None:
                    hidden_states = attn.spatial_norm(hidden_states, temb)

                input_ndim = hidden_states.ndim
                if input_ndim == 4:
                    batch_size, channel, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                batch_size, sequence_length, _ = (
                    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                )

                if attention_mask is not None:
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                    attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                if attn.group_norm is not None:
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                query = attn.to_q(hidden_states)

                is_cross = encoder_hidden_states is not None
                if not is_cross:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)

                inner_dim = key.shape[-1]
                head_dim = inner_dim // attn.heads
                should_mix = False

                # Potentially apply our cross image attention operation
                # To do so, we need to be in a self-attention layer in the decoder part of the denoising network
                if perform_swap and not is_cross and "up" in self.place_in_unet and model_self.enable_edit:
                    if attention_utils.should_mix_keys_and_values(model_self, hidden_states):
                        should_mix = True
                        inject_structure = (
                            model_self.step % model_self._STRUCT_INJECTION_INTERVAL == 0
                            and model_self.step < model_self._STRUCT_INJECTION_STEP_LIMIT
                        )
                        if inject_structure:
                            # Inject the structure's keys and values
                            key[OUT_INDEX] = key[STRUCT_INDEX]
                            value[OUT_INDEX] = value[STRUCT_INDEX]
                        else:
                            # Inject the appearance's keys and values
                            key[OUT_INDEX] = key[STYLE_INDEX]
                            value[OUT_INDEX] = value[STYLE_INDEX]

                query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                # Compute the cross attention and apply our contrasting operation
                # `should_mix` can only be True when `perform_swap` and
                # `model_self.enable_edit` are both truthy, so it is a sufficient flag.
                hidden_states, attn_weight = attention_utils.compute_scaled_dot_product_attention(
                    query, key, value,
                    edit_map=should_mix,
                    is_cross=is_cross,
                    contrast_strength=model_self.config.contrast_strength,
                )

                # Update attention map for segmentation
                # This happens on the step right before the AdaIN range starts.
                if model_self.config.use_masked_adain and model_self.step == model_self.config.adain_range.start - 1:
                    model_self.segmentor.update_attention(attn_weight, is_cross)

                hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                hidden_states = hidden_states.to(query[OUT_INDEX].dtype)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states)
                # dropout
                hidden_states = attn.to_out[1](hidden_states)

                if input_ndim == 4:
                    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                if attn.residual_connection:
                    hidden_states = hidden_states + residual

                hidden_states = hidden_states / attn.rescale_output_factor

                return hidden_states

        def register_recr(net_, count, place_in_unet):
            """Install processors below ``net_`` and return the updated ``Attention`` count."""
            if net_.__class__.__name__ == 'Attention':
                net_.set_processor(AttentionProcessor(place_in_unet + f"_{count + 1}"))
                return count + 1
            if hasattr(net_, 'children'):
                for child in net_.children():
                    count = register_recr(child, count, place_in_unet)
            return count

        # The first matching region wins, in the order down -> up -> mid.
        # Children whose name matches none of the regions are left untouched.
        for child_name, child in self.pipe.unet.named_children():
            for region in ("down", "up", "mid"):
                if region in child_name:
                    register_recr(child, 0, region)
                    break