| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from typing import Any, Dict, Optional, Union, Tuple |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers.utils import is_torch_version, logging |
| | from diffusers.models.attention import BasicTransformerBlock |
| | from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 |
| | from diffusers.models.embeddings import PatchEmbed |
| | from diffusers.models.modeling_utils import ModelMixin |
| | from diffusers.models.normalization import AdaLayerNormSingle |
| | from diffusers.models.activations import deprecate, FP32SiLU |
| |
|
| | from diffusers.models.controlnet import zero_module |
| | from diffusers.models.embeddings import PatchEmbed |
| | from dataclasses import dataclass |
| |
|
# Module-level logger obtained from diffusers' logging utilities (`diffusers.utils.logging`).
logger = logging.get_logger(__name__)
| |
|
| |
|
| | |
def pixcell_get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token=False,
    extra_tokens=0,
    interpolation_scale=1.0,
    base_size=16,
    device: Optional[torch.device] = None,
    phase=0,
    output_type: str = "np",
):
    """
    Creates 2D sinusoidal positional embeddings.

    Args:
        embed_dim (`int`):
            The embedding dimension.
        grid_size (`int` or `Tuple[int, int]`):
            The size of the grid height and width. An `int` is used for both sides.
        cls_token (`bool`, defaults to `False`):
            Whether or not to prepend zero embeddings for extra tokens.
        extra_tokens (`int`, defaults to `0`):
            The number of extra (zero) tokens to prepend when `cls_token` is set.
        interpolation_scale (`float`, defaults to `1.0`):
            The scale of the interpolation. Positions are divided by this value.
        base_size (`int`, defaults to `16`):
            Positions along each axis are rescaled so that the grid spans `base_size` units
            (before the division by `interpolation_scale`), independently of the grid size.
        device (`torch.device`, *optional*):
            Device on which the position grid (and hence the embeddings) is created.
        phase (defaults to `0`):
            Constant offset added to every position before the sin/cos encoding.
        output_type (`str`, defaults to `"np"`):
            Only `"pt"` is supported; `"np"` raises a `ValueError`.

    Returns:
        pos_embed (`torch.Tensor`):
            Shape is `[grid_h * grid_w, embed_dim]`, or `[extra_tokens + grid_h * grid_w, embed_dim]` when
            `cls_token` is set and `extra_tokens > 0`.

    Raises:
        ValueError: If `output_type == "np"`.
    """
    if output_type == "np":
        deprecation_message = (
            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt' to use the new version now."
        )
        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
        raise ValueError("Not supported")
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    # Positions along each axis are rescaled so that every grid spans `base_size / interpolation_scale` units.
    grid_h = (
        torch.arange(grid_size[0], device=device, dtype=torch.float32)
        / (grid_size[0] / base_size)
        / interpolation_scale
    )
    grid_w = (
        torch.arange(grid_size[1], device=device, dtype=torch.float32)
        / (grid_size[1] / base_size)
        / interpolation_scale
    )
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)

    # NOTE(review): for non-square grids this is a reshape of the stacked (2, H, W) planes, not a transpose;
    # for square grids it only adds a singleton axis. Confirm this is intended before relying on non-square input.
    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=phase, output_type=output_type)
    if cls_token and extra_tokens > 0:
        # Build the padding on the same device / dtype as the table so the concat also works off-CPU.
        extra_embed = torch.zeros([extra_tokens, embed_dim], device=pos_embed.device, dtype=pos_embed.dtype)
        pos_embed = torch.concat([extra_embed, pos_embed], dim=0)
    return pos_embed
| |
|
| |
|
def pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=0, output_type="np"):
    r"""
    Build 2D sine-cosine positional embeddings from a stacked coordinate grid.

    The channels are split in two equal halves. The first half encodes the coordinates in `grid[0]` and the
    second half those in `grid[1]`. Each half is produced by `pixcell_get_1d_sincos_pos_embed_from_grid`.

    Args:
        embed_dim (`int`): Total embedding dimension; must be even.
        grid (`torch.Tensor`): Two stacked coordinate planes. `grid[0]` and `grid[1]` are each flattened to the
            same number `M` of positions (e.g. shape `(2, 1, W, H)` as built by `pixcell_get_2d_sincos_pos_embed`).
        phase (defaults to `0`): Offset added to every position before encoding.
        output_type (`str`, defaults to `"np"`): Only `"pt"` is supported; `"np"` raises.

    Returns:
        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(M, embed_dim)`.

    Raises:
        ValueError: If `output_type == "np"` or `embed_dim` is odd.
    """
    if output_type == "np":
        deprecate(
            "output_type=='np'",
            "0.33.0",
            (
                "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
                " `from_numpy` is no longer required."
                " Pass `output_type='pt' to use the new version now."
            ),
            standard_warn=False,
        )
        raise ValueError("Not supported")
    if embed_dim % 2:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    halves = [
        pixcell_get_1d_sincos_pos_embed_from_grid(half_dim, plane, phase=phase, output_type=output_type)
        for plane in (grid[0], grid[1])
    ]
    return torch.concat(halves, dim=1)
| |
|
| |
|
def pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim, pos, phase=0, output_type="np"):
    """
    Encode 1D positions with sine/cosine features.

    Every position `p` (shifted by `phase`) is multiplied by the `embed_dim // 2` frequencies
    `1 / 10000 ** (2 * i / embed_dim)`. The output holds the sines of those products in the first half of the
    channels and the cosines in the second half.

    Args:
        embed_dim (`int`): Output dimension `D`; must be even.
        pos (`torch.Tensor`): Positions; flattened to `M` entries.
        phase (defaults to `0`): Offset added to every position before encoding.
        output_type (`str`, defaults to `"np"`): Only `"pt"` is supported; `"np"` raises.

    Returns:
        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)` on `pos.device`. The frequencies are
        computed in float64, so float32 positions yield a float64 result.

    Raises:
        ValueError: If `output_type == "np"` or `embed_dim` is odd.
    """
    if output_type == "np":
        deprecate(
            "output_type=='np'",
            "0.34.0",
            (
                "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
                " `from_numpy` is no longer required."
                " Pass `output_type='pt' to use the new version now."
            ),
            standard_warn=False,
        )
        raise ValueError("Not supported")
    if embed_dim % 2:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    # Exponents 2i / D for i in [0, D/2), kept in float64 for precision.
    exponents = torch.arange(half_dim, device=pos.device, dtype=torch.float64) / (embed_dim / 2.0)
    inv_freq = 1.0 / 10000**exponents

    angles = torch.outer(pos.reshape(-1) + phase, inv_freq)
    return torch.concat([angles.sin(), angles.cos()], dim=1)
| |
|
| |
|
class PixcellUNIProjection(nn.Module):
    """
    Two-layer MLP projecting UNI embeddings to the transformer hidden size.

    The module also registers a random `uncond_embedding` buffer of shape `(num_tokens, in_features)`. `forward`
    does not use it and performs no dropout. The buffer presumably serves as the unconditional UNI embedding for
    classifier-free guidance elsewhere — confirm against callers.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py

    Args:
        in_features (`int`): Dimension of the incoming UNI embeddings.
        hidden_size (`int`): Width of the first linear layer.
        out_features (`int`, *optional*): Output width; defaults to `hidden_size`.
        act_fn (`str`, defaults to `"gelu_tanh"`): One of `"gelu_tanh"`, `"silu"`, `"silu_fp32"`.
        num_tokens (`int`, defaults to `1`): Number of rows of the `uncond_embedding` buffer.

    Raises:
        ValueError: If `act_fn` is not one of the supported names.
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", num_tokens=1):
        super().__init__()
        out_dim = hidden_size if out_features is None else out_features

        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)

        # Lazy factories: an activation is only instantiated for the requested name.
        activation_factories = {
            "gelu_tanh": lambda: nn.GELU(approximate="tanh"),
            "silu": nn.SiLU,
            "silu_fp32": lambda: FP32SiLU(),
        }
        if act_fn not in tuple(activation_factories):
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.act_1 = activation_factories[act_fn]()

        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_dim, bias=True)

        # Registered as a buffer (not a trainable parameter) despite the nn.Parameter wrapper.
        init_scale = in_features ** 0.5
        self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / init_scale))

    def forward(self, caption):
        """Apply `linear_1 -> act_1 -> linear_2` to `caption`."""
        return self.linear_2(self.act_1(self.linear_1(caption)))
| |
|
class UNIPosEmbed(nn.Module):
    """
    Adds fixed 2D sine-cosine positional embeddings to a sequence of UNI condition tokens.

    The tokens are laid out on a square grid of side `int(sqrt(height * width))`. The embedding table is computed
    once at construction and stored in the `y_pos_embed` buffer with shape `(1, side * side, embed_dim)`.

    Args:
        height (`int`, defaults to `1`): Token-grid height (only the product `height * width` is used).
        width (`int`, defaults to `1`): Token-grid width (only the product `height * width` is used).
        base_size (`int`, defaults to `16`): Reference size; grid positions are rescaled to span
            `base_size / interpolation_scale` units. Also used to derive the phase `base_size // (height * width)`.
        embed_dim (`int`, defaults to `768`): Dimension of the UNI embeddings (and of the positional table).
        interpolation_scale (`float`, defaults to `1`): Divides the grid positions.
        pos_embed_type (`str`, defaults to `"sincos"`): Only `"sincos"` is supported.

    Raises:
        ValueError: If `pos_embed_type` is not `"sincos"`.
    """

    def __init__(
        self,
        height=1,
        width=1,
        base_size=16,
        embed_dim=768,
        interpolation_scale=1,
        pos_embed_type="sincos",
    ):
        super().__init__()

        token_count = height * width
        side = int(token_count ** 0.5)

        if pos_embed_type != "sincos":
            raise ValueError("`pos_embed_type` not supported")

        table = pixcell_get_2d_sincos_pos_embed(
            embed_dim,
            side,
            base_size=base_size,
            interpolation_scale=interpolation_scale,
            output_type="pt",
            phase=base_size // token_count,
        )
        # Leading batch axis so the table broadcasts over the batch in `forward`.
        self.register_buffer("y_pos_embed", table.float().unsqueeze(0))

    def forward(self, uni_embeds):
        """Return `uni_embeds + y_pos_embed`, cast back to the dtype of `uni_embeds`."""
        summed = uni_embeds + self.y_pos_embed
        return summed.to(uni_embeds.dtype)
| |
|
| | from diffusers.utils import BaseOutput, is_torch_version |
@dataclass
class PixCellControlNetOutput(BaseOutput):
    """
    Output of `PixCellControlNet.forward`.

    Args:
        controlnet_block_samples (sequence of `torch.Tensor`):
            One residual per ControlNet block: the output of the block's zero-initialised linear projection applied
            to the matching transformer block's hidden states, scaled by `conditioning_scale`.
    """

    controlnet_block_samples: Tuple[torch.Tensor, ...]
| |
|
class PixCellControlNet(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer ControlNet model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
    https://arxiv.org/abs/2403.04692). Modified for the pathology domain.

    Forward pass:
    - The noisy latent is embedded with `pos_embed`.
    - A `zero_module`-wrapped embedding of the control latent (`cond_pos_embed`) is added to it.
    - The result runs through the (optionally truncated) stack of transformer blocks.
    - Each block's output is passed through a `zero_module`-wrapped linear layer.
    - These residuals are scaled by `conditioning_scale` and returned.

    Parameters:
        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
        in_channels (int, defaults to 4): The number of channels in the input.
        out_channels (int, optional, defaults to 8):
            Output channels of `proj_out`. Falls back to `in_channels` when `None`.
        num_layers (int, optional, defaults to 28): The number of Transformer blocks built (before truncation).
        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
        norm_num_groups (int, optional, defaults to 32):
            Recorded in the config; not used by this module's code.
        cross_attention_dim (int, optional, defaults to 1152):
            The dimensionality for cross-attention layers in the Transformer blocks.
        attention_bias (bool, optional, defaults to True):
            Configure if the Transformer blocks' attention should contain a bias parameter.
        sample_size (int, defaults to 128):
            The width (and height) of the latent images. This parameter is fixed during training.
        patch_size (int, defaults to 2):
            Size of the patches used by the latent and control patch embeddings.
        activation_fn (str, optional, defaults to "gelu-approximate"):
            Activation function to use in feed-forward networks within Transformer blocks.
        num_embeds_ada_norm (int, optional, defaults to 1000):
            Number of embeddings for AdaLayerNorm; must not be `None`.
        upcast_attention (bool, optional, defaults to False):
            Forwarded to the Transformer blocks' attention.
        norm_type (str, optional, defaults to "ada_norm_single"):
            Only "ada_norm_single" is supported; anything else raises `NotImplementedError`.
        norm_elementwise_affine (bool, optional, defaults to False):
            If true, enables element-wise affine parameters in the normalization layers.
        norm_eps (float, optional, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        interpolation_scale (int, optional):
            Scale factor for the patch position embeddings. Defaults to `max(sample_size // 64, 1)` when `None`.
        use_additional_conditions (bool, optional):
            Whether `adaln_single` uses additional conditions. When `None`, enabled only for `sample_size == 128`.
        caption_channels (int, optional, defaults to None):
            Dimension of the UNI embeddings. When set, a `PixcellUNIProjection` maps them to the inner dimension.
        caption_num_tokens (int, defaults to 1):
            Number of UNI tokens per sample. Values other than 1 add a sin-cos positional embedding over a square
            token grid of side `int(sqrt(caption_num_tokens))`.
        attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
        n_controlnet_blocks (int, optional, defaults to 28):
            Number of leading Transformer blocks kept; one ControlNet output projection is created per kept block.
            `None` keeps all `num_layers` blocks.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 72,
        in_channels: int = 4,
        out_channels: Optional[int] = 8,
        num_layers: int = 28,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = 1152,
        attention_bias: bool = True,
        sample_size: int = 128,
        patch_size: int = 2,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_single",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: Optional[int] = None,
        use_additional_conditions: Optional[bool] = None,
        caption_channels: Optional[int] = None,
        caption_num_tokens: int = 1,
        attention_type: Optional[str] = "default",
        n_controlnet_blocks: Optional[int] = 28,
    ):
        super().__init__()

        # Only the AdaLN-single conditioning path is implemented.
        if norm_type != "ada_norm_single":
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        if num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )

        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.out_channels = in_channels if out_channels is None else out_channels
        if use_additional_conditions is None:
            # Default: additional conditions are enabled only for sample_size == 128.
            use_additional_conditions = sample_size == 128
        self.use_additional_conditions = use_additional_conditions

        self.gradient_checkpointing = False

        self.height = self.config.sample_size
        self.width = self.config.sample_size

        interpolation_scale = (
            self.config.interpolation_scale
            if self.config.interpolation_scale is not None
            else max(self.config.sample_size // 64, 1)
        )
        # Patch embedding of the noisy latent.
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            interpolation_scale=interpolation_scale,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    cross_attention_dim=self.config.cross_attention_dim,
                    activation_fn=self.config.activation_fn,
                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                    attention_type=self.config.attention_type,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # Positional embedding for the UNI tokens; unnecessary for a single token.
        if self.config.caption_num_tokens == 1:
            self.y_pos_embed = None
        else:
            self.uni_height = int(self.config.caption_num_tokens ** 0.5)
            self.uni_width = int(self.config.caption_num_tokens ** 0.5)

            self.y_pos_embed = UNIPosEmbed(
                height=self.uni_height,
                width=self.uni_width,
                base_size=self.config.sample_size // self.config.patch_size,
                embed_dim=self.config.caption_channels,
                interpolation_scale=2,
                pos_embed_type="sincos",
            )

        # Output head: not used by `forward`. Presumably kept so the parameter layout matches the base
        # transformer (e.g. to initialise from its weights) — confirm before removing.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
        self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=self.use_additional_conditions
        )
        self.caption_projection = None
        if self.config.caption_channels is not None:
            self.caption_projection = PixcellUNIProjection(
                in_features=self.config.caption_channels, hidden_size=self.inner_dim, num_tokens=self.config.caption_num_tokens,
            )

        # Patch embedding of the control latent, wrapped in `zero_module` so its contribution starts at zero.
        self.cond_pos_embed = zero_module(PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            interpolation_scale=interpolation_scale,
        ))

        # Keep only the leading `n_controlnet_blocks` transformer blocks.
        self.n_controlnet_blocks = n_controlnet_blocks
        if self.n_controlnet_blocks is not None:
            self.transformer_blocks = self.transformer_blocks[:self.n_controlnet_blocks]

        # One zero-initialised projection per kept transformer block.
        self.controlnet_blocks = nn.ModuleList(
            [zero_module(nn.Linear(self.inner_dim, self.inner_dim)) for _ in range(len(self.transformer_blocks))]
        )

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `module.gradient_checkpointing = value` when the module has that attribute.

        Presumably invoked by `ModelMixin` when toggling gradient checkpointing — confirm against the installed
        diffusers version's hook signature.
        """
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            keyed by `"<module path>.processor"`.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. The caller's
                dict is not modified.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        if isinstance(processor, dict):
            # Entries are popped as they are assigned below. Work on a copy so the caller's dict stays intact
            # (e.g. `self.original_attn_processors`, which `unfuse_qkv_projections` passes in).
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.

        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
        """
        self.set_attn_processor(AttnProcessor())

    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        The current processors are saved in `original_attn_processors` so `unfuse_qkv_projections` can restore them.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Raises:
            ValueError: If any processor class name contains "Added" (added KV projections).
        """
        self.original_attn_processors = None

        for attn_processor in self.attn_processors.values():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        Restores the processors saved by `fuse_qkv_projections`. It is a no-op if fusing was never enabled.
        It does not touch the fused projection weights on the `Attention` modules.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # `original_attn_processors` only exists after `fuse_qkv_projections` has been called.
        original_processors = getattr(self, "original_attn_processors", None)
        if original_processors is not None:
            self.set_attn_processor(original_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        conditioning: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        conditioning_scale: float = 1.0,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Run the ControlNet branch and return one scaled residual per ControlNet block.

        Args:
            hidden_states: Noisy latent, embedded with `pos_embed` (a `PatchEmbed`); presumably
                `(batch, in_channels, sample_size, sample_size)` — confirm.
            conditioning: Control latent embedded with `cond_pos_embed` and added to the latent embedding; it must
                produce the same token layout as `hidden_states`.
            encoder_hidden_states: UNI conditioning tokens. When `caption_projection` exists, they receive the
                optional UNI positional embedding, are projected, and are reshaped to
                `(batch, -1, hidden_states.shape[-1])`.
            timestep: Diffusion timestep passed to `adaln_single`.
            conditioning_scale: Multiplier applied to every ControlNet residual.
            added_cond_kwargs: Extra conditions for `adaln_single`; required when `use_additional_conditions`.
            cross_attention_kwargs: Forwarded to every transformer block.
            attention_mask: Optional self-attention mask. A 2D mask with 1 = keep and 0 = discard is turned into
                an additive bias (0 / -10000) with a singleton query axis inserted at dim 1.
            encoder_attention_mask: Cross-attention mask; same conversion as `attention_mask`.
            return_dict: When `True`, return a `PixCellControlNetOutput`; otherwise return a 1-tuple.

        Returns:
            `PixCellControlNetOutput` or `tuple`: The list of ControlNet residuals, one per kept block.

        Raises:
            ValueError: If `use_additional_conditions` is set and `added_cond_kwargs` is `None`.
        """
        if self.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")

        # Convert 2D keep/discard masks (1 = keep, 0 = discard) into additive attention biases.
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        batch_size = hidden_states.shape[0]
        hidden_states = self.pos_embed(hidden_states)

        # Inject the control signal.
        hidden_states = hidden_states + self.cond_pos_embed(conditioning)

        timestep, embedded_timestep = self.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        if self.caption_projection is not None:
            if self.y_pos_embed is not None:
                encoder_hidden_states = self.y_pos_embed(encoder_hidden_states)
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # Checkpointing mode and its kwargs do not change while iterating over the blocks.
        use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {}
        if use_checkpointing and is_torch_version(">=", "1.11.0"):
            ckpt_kwargs["use_reentrant"] = False

        block_outputs = []
        for block in self.transformer_blocks:
            if use_checkpointing:
                # Positional order matches the block's keyword call below; the trailing None is class_labels.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    None,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=None,
                )
            block_outputs.append(hidden_states)

        controlnet_outputs = [
            controlnet_block(block_output) * conditioning_scale
            for block_output, controlnet_block in zip(block_outputs, self.controlnet_blocks)
        ]

        if not return_dict:
            return (controlnet_outputs,)

        return PixCellControlNetOutput(controlnet_block_samples=controlnet_outputs)
| |
|
| |
|