import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDPMScheduler

from leffa.diffusion_model.unet_ref import (
    UNet2DConditionModel as ReferenceUNet,
)
from leffa.diffusion_model.unet_gen import (
    UNet2DConditionModel as GenerativeUNet,
)

logger: logging.Logger = logging.getLogger(__name__)


class LeffaModel(nn.Module):
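    """Leffa model: a reference UNet plus a generative UNet over a shared VAE.

    All submodules are instantiated from the configs found under
    `pretrained_model_name_or_path`; trained weights are loaded from the
    optional `pretrained_model` state-dict path at the end of `build_models`.
    """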
    def __init__(
        self,
        pretrained_model_name_or_path: str = "",
        pretrained_model: str = "",
        new_in_channels: int = 12,
        height: int = 1024,
        width: int = 768,
    ):
        super().__init__()

        self.height = height
        self.width = width

        self.build_models(
            pretrained_model_name_or_path,
            pretrained_model,
            new_in_channels,
        )

    def build_models(
        self,
        pretrained_model_name_or_path: str = "",
        pretrained_model: str = "",
        new_in_channels: int = 12,
    ):
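        """Instantiate the scheduler, VAE, reference UNet, and generative UNet.

        Components are built from the config files under
        `pretrained_model_name_or_path`; trained weights are only loaded at the
        end, from `pretrained_model`, if such a path is given.
        """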
        diffusion_model_type = ""
        if "stable-diffusion-inpainting" in pretrained_model_name_or_path:
            diffusion_model_type = "sd15"
        elif "stable-diffusion-xl-1.0-inpainting-0.1" in pretrained_model_name_or_path:
            diffusion_model_type = "sdxl"

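        # Scheduler: zero-terminal-SNR beta rescaling is enabled for the SDXL
        # inpainting base but not for the SD1.5 inpainting base.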
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="scheduler",
            rescale_betas_zero_snr=False if diffusion_model_type == "sd15" else True,
        )

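        # VAE: built from its config only; `from_config` does not load weights.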
        vae_config, vae_kwargs = AutoencoderKL.load_config(
            pretrained_model_name_or_path,
            subfolder="vae",
            return_unused_kwargs=True,
        )
        self.vae = AutoencoderKL.from_config(vae_config, **vae_kwargs)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

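        # Reference UNet: encodes the reference image; SDXL-style addition
        # embeddings are disabled.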
        unet_config, unet_kwargs = ReferenceUNet.load_config(
            pretrained_model_name_or_path,
            subfolder="unet",
            return_unused_kwargs=True,
        )
        self.unet_encoder = ReferenceUNet.from_config(unet_config, **unet_kwargs)
        self.unet_encoder.config.addition_embed_type = None

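        # Generative UNet: denoises the target latents; addition embeddings are
        # likewise disabled.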
        unet_config, unet_kwargs = GenerativeUNet.load_config(
            pretrained_model_name_or_path,
            subfolder="unet",
            return_unused_kwargs=True,
        )
        self.unet = GenerativeUNet.from_config(unet_config, **unet_kwargs)
        self.unet.config.addition_embed_type = None

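        # Widen the generative UNet's conv_in to `new_in_channels` and match its
        # conv_out to the VAE latent channel count. (The default of 12 input
        # channels suggests the denoiser input concatenates several latent/mask
        # tensors; the exact composition depends on the pipeline.)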
        unet_conv_in_channel_changed = self.unet.config.in_channels != new_in_channels
        if unet_conv_in_channel_changed:
            self.unet.conv_in = self.replace_conv_in_layer(self.unet, new_in_channels)
            self.unet.config.in_channels = new_in_channels
        unet_conv_out_channel_changed = (
            self.unet.config.out_channels != self.vae.config.latent_channels
        )
        if unet_conv_out_channel_changed:
            self.unet.conv_out = self.replace_conv_out_layer(
                self.unet, self.vae.config.latent_channels
            )
            self.unet.config.out_channels = self.vae.config.latent_channels

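        # Match the reference UNet's conv_in and conv_out to the VAE latent
        # channel count as well.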
        unet_encoder_conv_in_channel_changed = (
            self.unet_encoder.config.in_channels != self.vae.config.latent_channels
        )
        if unet_encoder_conv_in_channel_changed:
            self.unet_encoder.conv_in = self.replace_conv_in_layer(
                self.unet_encoder, self.vae.config.latent_channels
            )
            self.unet_encoder.config.in_channels = self.vae.config.latent_channels
        unet_encoder_conv_out_channel_changed = (
            self.unet_encoder.config.out_channels != self.vae.config.latent_channels
        )
        if unet_encoder_conv_out_channel_changed:
            self.unet_encoder.conv_out = self.replace_conv_out_layer(
                self.unet_encoder, self.vae.config.latent_channels
            )
            self.unet_encoder.config.out_channels = self.vae.config.latent_channels

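        # Replace the cross-attention (attn2) processors in both UNets with
        # identity processors, effectively removing text conditioning.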
        remove_cross_attention(self.unet)
        remove_cross_attention(self.unet_encoder, model_type="unet_encoder")

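        # Optionally load trained weights for the full model from a local checkpoint.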
        if pretrained_model is not None and pretrained_model != "":
            self.load_state_dict(torch.load(pretrained_model, map_location="cpu"))
            logger.info("Loaded pretrained model from {}".format(pretrained_model))

    def replace_conv_in_layer(self, unet_model, new_in_channels):
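        """Return a new conv_in with `new_in_channels` input channels.

        Weights are zero-initialized, then the original kernel weights (and
        bias) are copied into the overlapping input channels.
        """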
        original_conv_in = unet_model.conv_in

        if original_conv_in.in_channels == new_in_channels:
            return original_conv_in

        new_conv_in = torch.nn.Conv2d(
            in_channels=new_in_channels,
            out_channels=original_conv_in.out_channels,
            kernel_size=original_conv_in.kernel_size,
            padding=1,
        )
        new_conv_in.weight.data.zero_()
        new_conv_in.bias.data = original_conv_in.bias.data.clone()
        if original_conv_in.in_channels < new_in_channels:
            new_conv_in.weight.data[:, : original_conv_in.in_channels] = (
                original_conv_in.weight.data
            )
        else:
            new_conv_in.weight.data[:, :new_in_channels] = original_conv_in.weight.data[
                :, :new_in_channels
            ]
        return new_conv_in

    def replace_conv_out_layer(self, unet_model, new_out_channels):
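        """Return a new conv_out with `new_out_channels` output channels.

        Weights are zero-initialized, then the original kernel weights and bias
        are copied into the overlapping output channels.
        """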
        original_conv_out = unet_model.conv_out

        if original_conv_out.out_channels == new_out_channels:
            return original_conv_out

        new_conv_out = torch.nn.Conv2d(
            in_channels=original_conv_out.in_channels,
            out_channels=new_out_channels,
            kernel_size=original_conv_out.kernel_size,
            padding=1,
        )
        new_conv_out.weight.data.zero_()
        if original_conv_out.out_channels < new_out_channels:
            new_conv_out.bias.data[: original_conv_out.out_channels] = (
                original_conv_out.bias.data.clone()
            )
            new_conv_out.weight.data[: original_conv_out.out_channels] = (
                original_conv_out.weight.data
            )
        else:
            new_conv_out.bias.data = original_conv_out.bias.data[
                :new_out_channels
            ].clone()
            new_conv_out.weight.data[:new_out_channels] = original_conv_out.weight.data[
                :new_out_channels
            ]
        return new_conv_out

    def vae_encode(self, pixel_values):
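        """Encode images into scaled VAE latents (no gradient tracking)."""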
        pixel_values = pixel_values.to(device=self.vae.device, dtype=self.vae.dtype)
        with torch.no_grad():
            latent = self.vae.encode(pixel_values).latent_dist.sample()
        latent = latent * self.vae.config.scaling_factor
        return latent


class SkipAttnProcessor(torch.nn.Module):
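    """Attention processor that skips attention entirely and returns the input
    hidden states unchanged; used to disable cross-attention layers."""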
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        return hidden_states


def remove_cross_attention(
    unet,
    cross_attn_cls=SkipAttnProcessor,
    self_attn_cls=None,
    cross_attn_dim=None,
    **kwargs,
):
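    """Swap the attention processors on `unet`.

    Cross-attention (attn2) layers get `cross_attn_cls` (SkipAttnProcessor by
    default, i.e. they become no-ops); self-attention (attn1) layers keep a
    standard scaled-dot-product processor unless `self_attn_cls` is provided.
    Returns the new processors wrapped in an nn.ModuleList.
    """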
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim
    attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = (
            None if name.endswith("attn1.processor") else cross_attn_dim
        )
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            if self_attn_cls is not None:
                attn_procs[name] = self_attn_cls(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    **kwargs,
                )
            else:
                # Keep standard scaled-dot-product attention for self-attention.
                attn_procs[name] = AttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    layer_name=name,
                    **kwargs,
                )
        else:
            attn_procs[name] = cross_attn_cls(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                **kwargs,
            )

    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    return adapter_modules


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self, hidden_size=None, cross_attention_dim=None, layer_name=None, **kwargs
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        self.layer_name = layer_name
        self.model_type = kwargs.get("model_type", "none")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
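        # Standard diffusers-style SDPA attention: project to query/key/value,
        # run F.scaled_dot_product_attention, then apply the output projection
        # and the residual connection.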
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects the attention mask shape to
            # be (batch, heads, source_length, target_length).
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # The output of SDPA has shape (batch, num_heads, seq_len, head_dim).
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # Output projection.
        hidden_states = attn.to_out[0](hidden_states)
        # Dropout.
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states