from typing import List, Optional, Callable

import torch
import torch.nn.functional as F

from config import RunConfig
from constants import OUT_INDEX, STRUCT_INDEX, STYLE_INDEX
from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
from utils import attention_utils
from utils.adain import masked_adain
from utils.model_utils import get_stable_diffusion_model
from utils.segmentation import Segmentor


class AppearanceTransferModel:
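    """ Transfers the appearance of one image onto the structure of another.

    Holds the inverted latents and per-step noise for both input images, installs a
    cross-image attention processor on the pipeline's UNet, and exposes a denoising
    callback that applies masked AdaIN over the configured step range.
    """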

    def __init__(self, config: RunConfig, pipe: Optional[CrossImageAttentionStableDiffusionPipeline] = None):
        self.config = config
        self.pipe = get_stable_diffusion_model() if pipe is None else pipe
        self.register_attention_control()
        self.segmentor = Segmentor(prompt=config.prompt, object_nouns=[config.object_noun])
        # Inverted latents and per-step noise for the appearance and structure images
        self.latents_app, self.latents_struct = None, None
        self.zs_app, self.zs_struct = None, None
        # Object masks at the 32x32 and 64x64 attention resolutions
        self.image_app_mask_32, self.image_app_mask_64 = None, None
        self.image_struct_mask_32, self.image_struct_mask_64 = None, None
        self.enable_edit = False  # toggles the cross-image attention swap
        self.step = 0  # current denoising step, updated by the AdaIN callback

    def set_latents(self, latents_app: torch.Tensor, latents_struct: torch.Tensor):
        self.latents_app = latents_app
        self.latents_struct = latents_struct

    def set_noise(self, zs_app: torch.Tensor, zs_struct: torch.Tensor):
        self.zs_app = zs_app
        self.zs_struct = zs_struct

    def set_masks(self, masks: List[torch.Tensor]):
        # Expects masks in the order: appearance 32x32, structure 32x32, appearance 64x64, structure 64x64
        self.image_app_mask_32, self.image_struct_mask_32, self.image_app_mask_64, self.image_struct_mask_64 = masks

    def get_adain_callback(self) -> Callable:

        def callback(st: int, timestep: int, latents: torch.FloatTensor) -> None:
            self.step = st
            # Compute the masks using prompt-mixing self-segmentation and use them for the AdaIN operation
            if self.step == self.config.adain_range.start:
                masks = self.segmentor.get_object_masks()
                self.set_masks(masks)
            # Apply the masked AdaIN operation within the configured step range
            if self.config.adain_range.start <= self.step < self.config.adain_range.end:
                latents[0] = masked_adain(latents[0], latents[1], self.image_struct_mask_64, self.image_app_mask_64)

        return callback
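
    # Note: the inner `callback` follows the diffusers-style signature
    # `callback(step_index, timestep, latents)` and is assumed to be invoked
    # once per denoising step by the pipeline.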

    def register_attention_control(self):
        model_self = self  # captured by the processor defined below

        class AttentionProcessor:

            def __init__(self, place_in_unet: str):
                self.place_in_unet = place_in_unet
                if not hasattr(F, "scaled_dot_product_attention"):
                    raise ImportError("AttentionProcessor requires PyTorch 2.0; please upgrade torch to use it.")

            def __call__(self,
                         attn,
                         hidden_states: torch.Tensor,
                         encoder_hidden_states: Optional[torch.Tensor] = None,
                         attention_mask=None,
                         temb=None,
                         perform_swap: bool = False):
                residual = hidden_states

                if attn.spatial_norm is not None:
                    hidden_states = attn.spatial_norm(hidden_states, temb)

                # Flatten spatial inputs of shape (batch, channel, height, width) to (batch, seq_len, channel)
                input_ndim = hidden_states.ndim
                if input_ndim == 4:
                    batch_size, channel, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                batch_size, sequence_length, _ = (
                    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                )

                if attention_mask is not None:
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                    attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                if attn.group_norm is not None:
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                query = attn.to_q(hidden_states)

                is_cross = encoder_hidden_states is not None
                if not is_cross:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)

                inner_dim = key.shape[-1]
                head_dim = inner_dim // attn.heads
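
                # The batch is assumed to be ordered so that OUT_INDEX addresses the
                # generated output while STYLE_INDEX and STRUCT_INDEX address the
                # appearance and structure images (see constants.py).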
                should_mix = False

                # Potentially apply our cross-image attention operation.
                # To do so, we need to be in a self-attention layer in the decoder part of the denoising network.
                if perform_swap and not is_cross and "up" in self.place_in_unet and model_self.enable_edit:
                    if attention_utils.should_mix_keys_and_values(model_self, hidden_states):
                        should_mix = True
                        if model_self.step % 5 == 0 and model_self.step < 40:
                            # Inject the structure's keys and values
                            key[OUT_INDEX] = key[STRUCT_INDEX]
                            value[OUT_INDEX] = value[STRUCT_INDEX]
                        else:
                            # Inject the appearance's keys and values
                            key[OUT_INDEX] = key[STYLE_INDEX]
                            value[OUT_INDEX] = value[STYLE_INDEX]

                query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                # Compute scaled dot-product attention and, when editing, apply our contrast operation
                hidden_states, attn_weight = attention_utils.compute_scaled_dot_product_attention(
                    query, key, value,
                    edit_map=perform_swap and model_self.enable_edit and should_mix,
                    is_cross=is_cross,
                    contrast_strength=model_self.config.contrast_strength,
                )

                # Save the attention maps the segmentor will use to compute the object masks
                if model_self.config.use_masked_adain and model_self.step == model_self.config.adain_range.start - 1:
                    model_self.segmentor.update_attention(attn_weight, is_cross)

                hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                hidden_states = hidden_states.to(query[OUT_INDEX].dtype)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states)
                # dropout
                hidden_states = attn.to_out[1](hidden_states)

                if input_ndim == 4:
                    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                if attn.residual_connection:
                    hidden_states = hidden_states + residual

                hidden_states = hidden_states / attn.rescale_output_factor

                return hidden_states

        def register_recr(net_, count, place_in_unet):
            # Recursively walk the UNet and install our processor on every Attention block
            if net_.__class__.__name__ == 'Attention':
                net_.set_processor(AttentionProcessor(place_in_unet + f"_{count + 1}"))
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        cross_att_count = 0
        for name, module in self.pipe.unet.named_children():
            if "down" in name:
                cross_att_count += register_recr(module, 0, "down")
            elif "up" in name:
                cross_att_count += register_recr(module, 0, "up")
            elif "mid" in name:
                cross_att_count += register_recr(module, 0, "mid")
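
# A minimal usage sketch (assumes `config` carries the fields referenced above
# and that the latents and noise come from an inversion step, as elsewhere in
# this repo; the pipeline call signature shown here is an assumption):
#
#     model = AppearanceTransferModel(config)
#     model.set_latents(latents_app, latents_struct)
#     model.set_noise(zs_app, zs_struct)
#     model.enable_edit = True  # turn on the cross-image attention swap
#     images = model.pipe(prompt=[config.prompt] * 3,
#                         callback=model.get_adain_callback(),
#                         ...).images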