from __future__ import annotations

import logging
from typing import Any, Dict, Tuple, Union, Optional

from einops import rearrange, repeat
from torch import nn
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin, load_state_dict

from ..data.data_util import align_repeat_tensor_single_dim

from .unet_3d_condition import UNet3DConditionModel
from .referencenet import ReferenceNet2D
from ip_adapter.ip_adapter import ImageProjModel

logger = logging.getLogger(__name__)


class SuperUNet3DConditionModel(nn.Module):
    """Super-model that wraps the sub-models (unet, referencenet, controlnet, vae,
    text_encoder, tokenizer, text_emb_extractor, clip_vision_extractor,
    ip_adapter_image_proj), much like a diffusers pipeline, except that this is a
    model definition rather than an inference pipeline.

    Main purposes:
    1. wrap the controlnet- and referencenet-aware computation in one place so the
       calling code stays concise;
    2. make distributed training with accelerate straightforward.
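
    Example (a minimal sketch; only ``unet_params["sample"]`` is used directly by
    this wrapper, so the constructor arguments and the ``timestep`` key below are
    assumptions, not the exact API):

        unet = UNet3DConditionModel(...)
        referencenet = ReferenceNet2D(...)
        model = SuperUNet3DConditionModel(unet=unet, referencenet=referencenet)
        noise_pred = model(
            unet_params={"sample": noisy_latents, "timestep": timesteps},
            encoder_hidden_states=text_emb,  # b t n d
        )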
    """

    _supports_gradient_checkpointing = True
    # emit the verbose debug logging only on the first forward pass
    print_idx = 0

    def __init__(
        self,
        unet: nn.Module,
        referencenet: Optional[nn.Module] = None,
        controlnet: Optional[nn.Module] = None,
        vae: Optional[nn.Module] = None,
        text_encoder: Optional[nn.Module] = None,
        tokenizer: Optional[nn.Module] = None,
        text_emb_extractor: Optional[nn.Module] = None,
        clip_vision_extractor: Optional[nn.Module] = None,
        ip_adapter_image_proj: Optional[nn.Module] = None,
    ) -> None:
        """Keep references to the sub-models on a single ``nn.Module``.

        Args:
            unet (nn.Module): the main denoising 3D unet.
            referencenet (nn.Module, optional): extracts reference features consumed
                by the unet. Defaults to None.
            controlnet (nn.Module, optional): provides additional down/mid-block
                residuals for the unet. Defaults to None.
            vae (nn.Module, optional): latent encoder/decoder, kept for pipeline-style
                use; not called in ``forward``. Defaults to None.
            text_encoder (nn.Module, optional): text encoder; not called in
                ``forward``. Defaults to None.
            tokenizer (nn.Module, optional): tokenizer; not called in ``forward``.
                Defaults to None.
            text_emb_extractor (nn.Module, optional): wrap text_encoder and tokenizer
                for str2emb. Defaults to None.
            clip_vision_extractor (nn.Module, optional): clip vision encoder for image
                prompts; not called in ``forward``. Defaults to None.
            ip_adapter_image_proj (nn.Module, optional): projects the clip vision
                embedding into ip_adapter image-prompt tokens. Defaults to None.
        """
        super().__init__()
        self.unet = unet
        self.referencenet = referencenet
        self.controlnet = controlnet
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.text_emb_extractor = text_emb_extractor
        self.clip_vision_extractor = clip_vision_extractor
        self.ip_adapter_image_proj = ip_adapter_image_proj

    def forward(
        self,
        unet_params: Dict,
        encoder_hidden_states: torch.Tensor,
        referencenet_params: Optional[Dict] = None,
        controlnet_params: Optional[Dict] = None,
        controlnet_scale: float = 1.0,
        vision_clip_emb: Union[torch.Tensor, None] = None,
        prompt_only_use_image_prompt: bool = False,
    ):
        """Run referencenet and controlnet (when provided), then the unet.

        Args:
            unet_params (Dict): keyword arguments forwarded to ``self.unet``; must
                contain ``sample`` (noisy latent, batch at dim 0, time at dim 2).
            encoder_hidden_states (torch.Tensor): text prompt embedding, b t n d.
            referencenet_params (Dict, optional): keyword arguments forwarded to
                ``self.referencenet``; must contain ``num_frames``. Defaults to None.
            controlnet_params (Dict, optional): keyword arguments forwarded to
                ``self.controlnet``. Defaults to None.
            controlnet_scale (float, optional): scale applied to the controlnet
                residuals. Defaults to 1.0.
            vision_clip_emb (Union[torch.Tensor, None], optional): clip vision
                embedding of the image prompt, b t d. Defaults to None.
            prompt_only_use_image_prompt (bool, optional): if True, condition the unet
                only on the image prompt (``vision_clip_emb``). Defaults to False.

        Returns:
            the output of ``self.unet``.
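
        Example (illustrative; apart from ``sample``, the keys passed through
        ``unet_params`` depend on the wrapped unet and are assumptions here):

            out = model(
                unet_params={"sample": noisy_latents, "timestep": timesteps},
                encoder_hidden_states=text_emb,  # b t n d
                vision_clip_emb=image_emb,       # b t d
            )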
        """
        batch_size = unet_params["sample"].shape[0]
        time_size = unet_params["sample"].shape[2]

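        # Prepare the image-prompt embedding: optionally project the clip vision
        # embedding with ip_adapter_image_proj, then align it to
        # (batch, time, tokens, dim).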
        if vision_clip_emb is not None:
            if self.print_idx == 0:
                logger.debug(
                    f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
                )
            if vision_clip_emb.ndim == 3:
                vision_clip_emb = rearrange(vision_clip_emb, "b t d -> b t 1 d")
            if self.ip_adapter_image_proj is not None:
                vision_clip_emb = rearrange(vision_clip_emb, "b t n d -> (b t) n d")
                vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb)
                if self.print_idx == 0:
                    logger.debug(
                        f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
                    )
                if vision_clip_emb.ndim == 2:
                    vision_clip_emb = rearrange(vision_clip_emb, "b d -> b 1 d")
                vision_clip_emb = rearrange(
                    vision_clip_emb, "(b t) n d -> b t n d", b=batch_size
                )
            vision_clip_emb = align_repeat_tensor_single_dim(
                vision_clip_emb, target_length=time_size, dim=1
            )
            if self.print_idx == 0:
                logger.debug(
                    f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}"
                )

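        # Fall back to whichever prompt embedding is available when only one of the
        # text prompt and the image prompt is provided.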
        if vision_clip_emb is None and encoder_hidden_states is not None:
            vision_clip_emb = encoder_hidden_states
        if vision_clip_emb is not None and encoder_hidden_states is None:
            encoder_hidden_states = vision_clip_emb

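        # Run referencenet (if any) to get reference features that the unet consumes
        # in its down/mid blocks and self-attention layers.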
        if self.referencenet is not None and referencenet_params is not None:
            referencenet_encoder_hidden_states = align_repeat_tensor_single_dim(
                vision_clip_emb,
                target_length=referencenet_params["num_frames"],
                dim=1,
            )
            referencenet_params["encoder_hidden_states"] = rearrange(
                referencenet_encoder_hidden_states, "b t n d -> (b t) n d"
            )
            referencenet_out = self.referencenet(**referencenet_params)
            (
                down_block_refer_embs,
                mid_block_refer_emb,
                refer_self_attn_emb,
            ) = referencenet_out
            if down_block_refer_embs is not None:
                if self.print_idx == 0:
                    logger.debug(
                        f"len(down_block_refer_embs)={len(down_block_refer_embs)}"
                    )
                for i, down_emb in enumerate(down_block_refer_embs):
                    if self.print_idx == 0:
                        logger.debug(
                            f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}"
                        )
            else:
                if self.print_idx == 0:
                    logger.debug("down_block_refer_embs is None")
            if mid_block_refer_emb is not None:
                if self.print_idx == 0:
                    logger.debug(
                        f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}"
                    )
            else:
                if self.print_idx == 0:
                    logger.debug("mid_block_refer_emb is None")
            if refer_self_attn_emb is not None:
                if self.print_idx == 0:
                    logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}")
                for i, self_attn_emb in enumerate(refer_self_attn_emb):
                    if self.print_idx == 0:
                        logger.debug(
                            f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}"
                        )
            else:
                if self.print_idx == 0:
                    logger.debug("refer_self_attn_emb is None")
        else:
            down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb = (
                None,
                None,
                None,
            )

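        # Run controlnet (if any) to get additional residuals for the unet down/mid
        # blocks, optionally scaled by controlnet_scale.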
        if self.controlnet is not None and controlnet_params is not None:
            controlnet_encoder_hidden_states = align_repeat_tensor_single_dim(
                encoder_hidden_states,
                target_length=unet_params["sample"].shape[2],
                dim=1,
            )
            controlnet_params["encoder_hidden_states"] = rearrange(
                controlnet_encoder_hidden_states, "b t n d -> (b t) n d"
            )
            (
                down_block_additional_residuals,
                mid_block_additional_residual,
            ) = self.controlnet(**controlnet_params)
            if controlnet_scale != 1.0:
                down_block_additional_residuals = [
                    x * controlnet_scale for x in down_block_additional_residuals
                ]
                mid_block_additional_residual = (
                    mid_block_additional_residual * controlnet_scale
                )
            for i, down_block_additional_residual in enumerate(
                down_block_additional_residuals
            ):
                if self.print_idx == 0:
                    logger.debug(
                        f"{i}, down_block_additional_residual mean={torch.mean(down_block_additional_residual)}"
                    )

            if self.print_idx == 0:
                logger.debug(
                    f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}"
                )
        else:
            down_block_additional_residuals = None
            mid_block_additional_residual = None

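        # When requested, condition the unet cross-attention only on the image prompt.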
        if prompt_only_use_image_prompt and vision_clip_emb is not None:
            encoder_hidden_states = vision_clip_emb

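        # Run the unet with all the conditioning gathered above.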
        out = self.unet(
            **unet_params,
            down_block_refer_embs=down_block_refer_embs,
            mid_block_refer_emb=mid_block_refer_emb,
            refer_self_attn_emb=refer_self_attn_emb,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
            encoder_hidden_states=encoder_hidden_states,
            vision_clip_emb=vision_clip_emb,
        )
        self.print_idx += 1
        return out

    def _set_gradient_checkpointing(self, module, value=False):
        # enable/disable gradient checkpointing on the sub-models that support it
        if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)):
            module.gradient_checkpointing = value