from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin
from diffusers.models.controlnet import zero_module
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput, is_torch_version

from pixcell_transformer_2d import PixCellTransformer2DModel


@dataclass
class PixCellControlNetOutput(BaseOutput):
    """Output of `PixCellControlNet`: one residual tensor per transformer block."""

    controlnet_block_samples: Tuple[torch.Tensor]


class PixCellControlNet(ModelMixin, ConfigMixin):
    """ControlNet adapter for `PixCellTransformer2DModel`.

    Wraps (an optionally truncated copy of) the base transformer, adds a
    zero-initialized patch embedding for the conditioning input, and projects
    each block's hidden states through a zero-initialized linear layer.
    """

    def __init__(
        self,
        base_transformer: PixCellTransformer2DModel,
        n_blocks: Optional[int] = None,
    ):
        super().__init__()

        self.n_blocks = n_blocks
        self.transformer = base_transformer

        # Patch-embed the conditioning input at the same resolution and patch
        # size as the base transformer. Zero-initialized so that, at the start
        # of training, the conditioning branch is a no-op.
        interpolation_scale = (
            self.transformer.config.interpolation_scale
            if self.transformer.config.interpolation_scale is not None
            else max(self.transformer.config.sample_size // 64, 1)
        )
        self.cond_pos_embed = zero_module(
            PatchEmbed(
                height=self.transformer.config.sample_size,
                width=self.transformer.config.sample_size,
                patch_size=self.transformer.config.patch_size,
                in_channels=self.transformer.config.in_channels,
                embed_dim=self.transformer.inner_dim,
                interpolation_scale=interpolation_scale,
            )
        )

        # Optionally keep only the first `n_blocks` transformer blocks.
        if self.n_blocks is not None:
            self.transformer.transformer_blocks = self.transformer.transformer_blocks[: self.n_blocks]

        # One zero-initialized linear projection per (kept) transformer block,
        # so every residual starts out as zero.
        self.controlnet_blocks = nn.ModuleList(
            [
                zero_module(nn.Linear(self.transformer.inner_dim, self.transformer.inner_dim))
                for _ in range(len(self.transformer.transformer_blocks))
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        conditioning: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        conditioning_scale: float = 1.0,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """Compute per-block ControlNet residuals for the base transformer."""
        if self.transformer.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError(
                "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
            )

        # Convert 2D key-padding masks of shape (batch, key_len) with 1 = keep
        # and 0 = mask into additive attention biases of shape
        # (batch, 1, key_len): kept positions become 0.0, masked positions a
        # large negative number that vanishes after the softmax.
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # Patchify the noisy latents with the base transformer's positional
        # patch embedding, then add the zero-initialized embedding of the
        # conditioning input so the control signal enters as a residual.
        batch_size = hidden_states.shape[0]
        height, width = (
            hidden_states.shape[-2] // self.transformer.config.patch_size,
            hidden_states.shape[-1] // self.transformer.config.patch_size,
        )  # patch-grid size (unused below; kept for parity with the base forward)
        hidden_states = self.transformer.pos_embed(hidden_states)
        hidden_states = hidden_states + self.cond_pos_embed(conditioning)

        # Timestep (plus optional resolution/aspect-ratio) embeddings from the
        # base transformer's adaLN-single module.
        timestep, embedded_timestep = self.transformer.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        # Project the conditioning embeddings to the transformer width.
        if self.transformer.caption_projection is not None:
            if self.transformer.y_pos_embed is not None:
                encoder_hidden_states = self.transformer.y_pos_embed(encoder_hidden_states)
            encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # Run the (possibly truncated) stack of transformer blocks, collecting
        # the hidden states after every block.
        block_outputs = ()
        for block in self.transformer.transformer_blocks:
            if torch.is_grad_enabled() and self.transformer.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    None,  # class_labels
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=None,
                )

            block_outputs = block_outputs + (hidden_states,)

        # Project each block's hidden states through its zero-initialized
        # linear head and scale the residuals by `conditioning_scale`.
        controlnet_outputs = ()
        for t_output, controlnet_block in zip(block_outputs, self.controlnet_blocks):
            controlnet_outputs = controlnet_outputs + (controlnet_block(t_output),)

        controlnet_outputs = [sample * conditioning_scale for sample in controlnet_outputs]

        if not return_dict:
            return (controlnet_outputs,)

        return PixCellControlNetOutput(controlnet_block_samples=controlnet_outputs)
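

# Minimal usage sketch (hedged): everything below is illustrative. The
# checkpoint path, the conditioning-embedding shape, and the
# `added_cond_kwargs` placeholders are assumptions about the PixCell setup,
# not confirmed API; adjust them to the real base model's config.
if __name__ == "__main__":
    base = PixCellTransformer2DModel.from_pretrained("path/to/pixcell")  # hypothetical path
    controlnet = PixCellControlNet(base_transformer=base, n_blocks=14)  # illustrative depth

    cfg = base.config
    latents = torch.randn(1, cfg.in_channels, cfg.sample_size, cfg.sample_size)
    conditioning = torch.randn_like(latents)  # guidance input, already in latent space
    cond_embeds = torch.randn(1, 1, cfg.cross_attention_dim)  # assumed embedding shape
    timestep = torch.tensor([500], dtype=torch.long)

    # `resolution`/`aspect_ratio` are placeholders; they matter only when the
    # base transformer was built with `use_additional_conditions`.
    out = controlnet(
        hidden_states=latents,
        conditioning=conditioning,
        encoder_hidden_states=cond_embeds,
        timestep=timestep,
        added_cond_kwargs={"resolution": None, "aspect_ratio": None},
    )
    for i, sample in enumerate(out.controlnet_block_samples):
        print(f"block {i}: residual of shape {tuple(sample.shape)}")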