|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Any, Dict, Optional |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
from diffusers.models.embeddings import ImagePositionalEmbeddings |
|
from diffusers.utils import BaseOutput, deprecate, maybe_allow_in_graph |
|
from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention |
|
from diffusers.models.embeddings import PatchEmbed |
|
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear |
|
from diffusers.models.modeling_utils import ModelMixin |
|
from diffusers.utils.import_utils import is_xformers_available |
|
|
|
from einops import rearrange, repeat |
|
import pdb |
|
import random |
|
|
|
|
|
if is_xformers_available(): |
|
import xformers |
|
import xformers.ops |
|
else: |
|
xformers = None |
|
|
|
def my_repeat(tensor, num_repeats):
    """
    Repeat every entry of ``tensor`` ``num_repeats`` times along the leading (batch) dimension.

    Entries are repeated consecutively, so ``[a, b]`` becomes ``[a, a, b, b]`` for ``num_repeats=2``.
    This is the same result as ``einops.repeat(tensor, "b ... -> (b v) ...", v=num_repeats)``, but it works for
    tensors of any rank >= 1 (the previous implementation only handled rank 3 and 4 and returned ``None``
    otherwise).

    Args:
        tensor (`torch.Tensor`): Tensor with at least one dimension.
        num_repeats (`int`): How many times each entry along dim 0 is repeated.

    Returns:
        `torch.Tensor`: Tensor whose first dimension is ``tensor.shape[0] * num_repeats``; other dims are unchanged.
    """
    return tensor.repeat_interleave(num_repeats, dim=0)
|
|
|
|
|
@dataclass
class TransformerMV2DModelOutput(BaseOutput):
    """
    The output of [`TransformerMV2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`TransformerMV2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns
            log-probability distributions (log-softmax over the class axis) for the unnoised latent pixels.
    """

    sample: torch.FloatTensor
|
|
|
|
|
class TransformerMV2DModel(ModelMixin, ConfigMixin):
    """
    A multi-view 2D Transformer model for image-like data.

    Structurally this follows diffusers' `Transformer2DModel` (continuous, vectorized/discrete or patched input),
    but every layer is a [`BasicMVTransformerBlock`], whose self-attention can attend across the `num_views`
    views that are packed consecutively along the batch dimension.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous** or **patched**).
        out_channels (`int`, *optional*):
            Number of output channels; defaults to `in_channels`. Only used to size the patched-input output
            projection — continuous input is always projected back to `in_channels`.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32): Groups of the input `GroupNorm` (continuous input only).
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete** or
            **patched**). This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*): Patch size; together with `in_channels` it selects **patched** input.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Use linear layers instead of 1x1 convolutions for the continuous-input projections.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Make the first attention layer of each block attend to `encoder_hidden_states`.
        upcast_attention (`bool`, *optional*, defaults to `False`): Passed to the attention layers.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            Block normalization: `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Passed to the blocks' `LayerNorm`s.
        num_views (`int`, *optional*, defaults to 1): Number of views packed consecutively along the batch axis.
        cd_attention_last (`bool`, *optional*, defaults to `False`):
            Add a joint attention layer after the feed-forward of every block.
        cd_attention_mid (`bool`, *optional*, defaults to `False`):
            Add a joint attention layer between self-attention and cross-attention of every block.
        multiview_attention (`bool`, *optional*, defaults to `True`): Let self-attention attend across views.
        sparse_mv_attention (`bool`, *optional*, defaults to `False`):
            Restrict multi-view self-attention to each view's own tokens plus those of the first view.
        mvcd_attention (`bool`, *optional*, defaults to `False`): Forwarded to the self-attention processor.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        num_views: int = 1,
        cd_attention_last: bool = False,
        cd_attention_mid: bool = False,
        multiview_attention: bool = True,
        sparse_mv_attention: bool = False,
        mvcd_attention: bool = False,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Decide which of the three input modes is configured.
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 3. Multi-view transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicMVTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    num_views=num_views,
                    cd_attention_last=cd_attention_last,
                    cd_attention_mid=cd_attention_mid,
                    multiview_attention=multiview_attention,
                    sparse_mv_attention=sparse_mv_attention,
                    mvcd_attention=mvcd_attention,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # Continuous input is projected back to `in_channels` (not `out_channels`) so that the residual
            # added in `forward` has a matching channel count.
            if use_linear_projection:
                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
            else:
                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerMV2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`. With multi-view attention the batch axis holds `num_views` consecutive
                views per sample.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs (`dict`, *optional*):
                Extra keyword arguments forwarded to the attention layers of every block.
            attention_mask (`torch.Tensor`, *optional*):
                Self-attention mask. A 2-D mask is converted to an additive bias here, but the multi-view blocks
                used by this model reject any non-`None` mask.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`TransformerMV2DModelOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a `tuple` where the
            first element is the sample tensor.
        """
        # Turn 2-D keep/discard masks (1 = keep, 0 = discard) into additive biases of shape (batch, 1, seq_len),
        # which broadcast over the query axis of the attention scores.
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                # `inner_dim` is read *before* the projection here, so it is the input channel count; the same
                # value is reused below to un-flatten the output of `proj_out`, which maps back to `in_channels`.
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # Move the class axis to dim 1: (batch, num_vector_embeds - 1, num_latent_pixels).
            logits = logits.permute(0, 2, 1)

            # Log-probabilities over classes; computed in float64 and cast back to float32.
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # NOTE(review): `norm1.emb` presumably only exists on `AdaLayerNormZero`, so this path likely requires
            # `norm_type="ada_norm_zero"` — confirm.
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # Unpatchify: (N, h*w, p*p*C) -> (N, C, h*p, w*p); assumes a square patch grid.
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return TransformerMV2DModelOutput(sample=output)
|
|
|
|
|
@maybe_allow_in_graph
class BasicMVTransformerBlock(nn.Module):
    r"""
    A basic Transformer block whose self-attention can attend across multiple views.

    In `forward` the sub-layers run in this order, each added back through a residual connection:
    multi-view self-attention (`attn1`) -> optional joint attention (`cd_attention_mid`) ->
    optional cross-attention (`attn2`) -> feed-forward -> optional joint attention (`cd_attention_last`).

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `TransformerMV2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*, defaults to `False`): Passed to the attention layers.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Passed to the `LayerNorm`s.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"` (the latter two require `num_embeds_ada_norm`).
        final_dropout (`bool`, *optional*, defaults to `False`): Passed to `FeedForward`.
        num_views (`int`, *optional*, defaults to 1): Views packed consecutively along the batch axis.
        cd_attention_last (`bool`, *optional*, defaults to `False`): Add joint attention after the feed-forward.
        cd_attention_mid (`bool`, *optional*, defaults to `False`):
            Add joint attention between self- and cross-attention.
        multiview_attention (`bool`, *optional*, defaults to `True`): Let `attn1` attend across views.
        sparse_mv_attention (`bool`, *optional*, defaults to `False`): Sparse variant of multi-view attention.
        mvcd_attention (`bool`, *optional*, defaults to `False`): Forwarded to the `attn1` processor.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        num_views: int = 1,
        cd_attention_last: bool = False,
        cd_attention_mid: bool = False,
        multiview_attention: bool = True,
        sparse_mv_attention: bool = False,
        mvcd_attention: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # 1. Self-attention; multi-view behavior comes from its processor.
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        self.multiview_attention = multiview_attention
        self.sparse_mv_attention = sparse_mv_attention
        self.mvcd_attention = mvcd_attention

        self.attn1 = CustomAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            processor=MVAttnProcessor(),
        )

        # 2. Cross-attention (or a second self-attention when `double_self_attention`).
        if cross_attention_dim is not None or double_self_attention:
            # Only `ada_norm` conditions this norm on the timestep; `ada_norm_zero` falls back to LayerNorm.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # Feed-forward chunking stays off until `set_chunk_feed_forward` is called.
        self._chunk_size = None
        self._chunk_dim = 0

        self.num_views = num_views

        # 4. Optional joint attention layers; their output projection weight starts at zero.
        self.cd_attention_last = cd_attention_last

        if self.cd_attention_last:
            self.attn_joint_last = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor(),
            )
            nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
            self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        self.cd_attention_mid = cd_attention_mid

        if self.cd_attention_mid:
            self.attn_joint_mid = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor(),
            )
            nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
            self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Run the feed-forward in chunks of `chunk_size` along `dim` (pass `None` to disable chunking)."""
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """
        Apply the block.

        Args:
            hidden_states: Input tokens; the batch axis holds `num_views` consecutive views per sample.
            attention_mask: Must be `None` — self-attention masks are not supported by this block.
            encoder_hidden_states: Context for cross-attention (and for `attn1` when `only_cross_attention`).
            encoder_attention_mask: Mask/bias for cross-attention.
            timestep: Timestep for `AdaLayerNorm` / `AdaLayerNormZero`.
            cross_attention_kwargs: Extra keyword arguments forwarded to `attn1` and `attn2`.
            class_labels: Class labels for `AdaLayerNormZero`.

        Returns:
            Tensor of the same shape as `hidden_states`.

        Raises:
            ValueError: If `attention_mask` is not `None`.
        """
        if attention_mask is not None:
            raise ValueError("`BasicMVTransformerBlock` does not support `attention_mask`; pass `None`.")

        # 1. Self-attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            num_views=self.num_views,
            multiview_attention=self.multiview_attention,
            sparse_mv_attention=self.sparse_mv_attention,
            mvcd_attention=self.mvcd_attention,
            **cross_attention_kwargs,
        )

        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Optional joint attention between self- and cross-attention
        if self.cd_attention_mid:
            norm_hidden_states = (
                self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
            )
            hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states

        # 3. Cross-attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # Chunking trades speed for peak memory; the chunked dimension must divide evenly.
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        # 5. Optional joint attention after the feed-forward
        if self.cd_attention_last:
            norm_hidden_states = (
                self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
            )
            hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states

        return hidden_states
|
|
|
|
|
class CustomAttention(Attention):
    """
    `Attention` layer used for multi-view self-attention.

    Overrides the xFormers toggle so that it switches between this module's two multi-view processors:
    `XFormersMVAttnProcessor` when enabled and `MVAttnProcessor` when disabled.
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """
        Install the xFormers or the plain multi-view attention processor.

        Args:
            use_memory_efficient_attention_xformers (`bool`):
                `True` installs `XFormersMVAttnProcessor`; `False` restores `MVAttnProcessor`.
                (Previously the flag was ignored and the xFormers processor was always installed, so disabling
                xFormers left a processor that needs `xformers.ops`.)
            *args, **kwargs: Accepted for signature compatibility with `Attention` and ignored.

        Raises:
            ModuleNotFoundError: If xFormers is requested but not installed.
        """
        if use_memory_efficient_attention_xformers:
            if xformers is None:
                raise ModuleNotFoundError(
                    "xformers is not installed; cannot enable memory efficient attention.", name="xformers"
                )
            processor = XFormersMVAttnProcessor()
        else:
            processor = MVAttnProcessor()

        self.set_processor(processor)
|
|
|
|
|
|
|
class CustomJointAttention(Attention):
    """
    `Attention` layer used for joint attention across the two halves of the batch.

    Overrides the xFormers toggle so that it switches between this module's two joint processors:
    `XFormersJointAttnProcessor` when enabled and `JointAttnProcessor` when disabled.
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """
        Install the xFormers or the plain joint attention processor.

        Args:
            use_memory_efficient_attention_xformers (`bool`):
                `True` installs `XFormersJointAttnProcessor`; `False` restores `JointAttnProcessor`.
                (Previously the flag was ignored and the xFormers processor was always installed.)
            *args, **kwargs: Accepted for signature compatibility with `Attention` and ignored.

        Raises:
            ModuleNotFoundError: If xFormers is requested but not installed.
        """
        if use_memory_efficient_attention_xformers:
            if xformers is None:
                raise ModuleNotFoundError(
                    "xformers is not installed; cannot enable memory efficient attention.", name="xformers"
                )
            processor = XFormersJointAttnProcessor()
        else:
            processor = JointAttnProcessor()

        self.set_processor(processor)
|
|
|
|
|
class MVAttnProcessor:
    r"""
    Plain (non-xFormers) attention processor with optional multi-view self-attention.

    The batch axis of `hidden_states` is treated as `(b t)` with `t = num_views` as the inner factor, i.e. the views
    of one sample are consecutive. It accepts the same keyword arguments as `XFormersMVAttnProcessor`, so
    `BasicMVTransformerBlock` can call either processor the same way.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        sparse_mv_attention=False,
        mvcd_attention=False,
    ):
        """
        Compute attention for `attn`.

        Args:
            attn: The attention module providing projections and helpers.
            hidden_states: Query tokens, `(batch, seq, dim)` or `(batch, channel, height, width)`.
            encoder_hidden_states: Optional key/value source; defaults to `hidden_states`.
            attention_mask: Optional mask. NOTE(review): with multi-view attention the key length becomes larger
                than `sequence_length`, so a mask prepared here would presumably not match — callers pass `None`.
            temb: Optional embedding for `attn.spatial_norm`.
            num_views: Number of consecutive views per sample along the batch axis.
            multiview_attention: Gather keys/values across views before attention.
            sparse_mv_attention: Attend only to the view itself plus the first view of the sample, instead of all views.
            mvcd_attention: Accepted for API compatibility with `XFormersMVAttnProcessor`; currently unused.

        Returns:
            Attention output with the same layout as `hidden_states`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if multiview_attention:
            if not sparse_mv_attention:
                # Every view attends to the tokens of all views of its sample: (b t) d c -> (b t) (t d) c.
                key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
                value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
            else:
                # Every view attends to the first view of its sample plus itself: (b t) d c -> (b t) (2 d) c.
                key_front = rearrange(key, "(b t) d c -> b t d c", t=num_views)[:, 0].repeat_interleave(
                    num_views, dim=0
                )
                value_front = rearrange(value, "(b t) d c -> b t d c", t=num_views)[:, 0].repeat_interleave(
                    num_views, dim=0
                )
                key = torch.cat([key_front, key], dim=1)
                value = torch.cat([value_front, value], dim=1)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear projection, then dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
class XFormersMVAttnProcessor:
    r"""
    xFormers memory-efficient attention processor with optional multi-view self-attention.

    The batch axis of `hidden_states` is treated as `(b t)` with `t = num_views` as the inner factor, i.e. the views
    of one sample are consecutive.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        sparse_mv_attention=False,
        mvcd_attention=False,
    ):
        """
        Compute attention for `attn` with `xformers.ops.memory_efficient_attention`.

        Args:
            attn: The attention module providing projections and helpers.
            hidden_states: Query tokens, `(batch, seq, dim)` or `(batch, channel, height, width)`.
            encoder_hidden_states: Optional key/value source; defaults to `hidden_states`.
            attention_mask: Optional mask. NOTE(review): with multi-view attention the key length becomes larger
                than `sequence_length`, so a mask prepared here would presumably not match — callers pass `None`.
            temb: Optional embedding for `attn.spatial_norm`.
            num_views: Number of consecutive views per sample along the batch axis (an integer count; the
                previous default was the float `1.`).
            multiview_attention: Gather keys/values across views before attention.
            sparse_mv_attention: Attend only to the view itself plus the first view of the sample, instead of all views.
            mvcd_attention: NOTE(review): accepted but currently ignored by this processor.

        Returns:
            Attention output with the same layout as `hidden_states`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attention_mask is not None:
            # Expand the mask's singleton query dimension to the number of query tokens so it can be passed to
            # xFormers as an additive bias.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        if multiview_attention:
            if not sparse_mv_attention:
                # Every view attends to the tokens of all views of its sample: (b t) d c -> (b t) (t d) c.
                key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
                value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
            else:
                # Every view attends to the first view of its sample plus itself: (b t) d c -> (b t) (2 d) c.
                key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
                value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
                key = torch.cat([key_front, key_raw], dim=1)
                value = torch.cat([value_front, value_raw], dim=1)
        else:
            key = key_raw
            value = value_raw

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear projection, then dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
|
class XFormersJointAttnProcessor:
    r"""
    xFormers memory-efficient processor for joint attention across two task halves of the batch.

    The batch is split into two equal halves along dim 0. Entry `i` of each half attends to the keys/values of
    entry `i` of *both* halves, which are concatenated along the token axis.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        """
        Compute joint attention for `attn` with `xformers.ops.memory_efficient_attention`.

        Args:
            attn: The attention module providing projections and helpers.
            hidden_states: Query tokens, `(batch, seq, dim)` or `(batch, channel, height, width)`.
            encoder_hidden_states: Optional key/value source; defaults to `hidden_states`.
            attention_mask: Optional mask. NOTE(review): keys are twice as long as `sequence_length` here, so a
                prepared mask would presumably not match — callers pass `None`.
            temb: Optional embedding for `attn.spatial_norm`.
            num_tasks: Number of batch halves to join; only `2` is supported.

        Returns:
            Attention output with the same layout as `hidden_states`.

        Raises:
            ValueError: If `num_tasks != 2` or the batch size is odd.
        """
        if num_tasks != 2:
            raise ValueError(f"Joint attention only supports `num_tasks=2`, got {num_tasks}.")

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attention_mask is not None:
            # Expand the mask's singleton query dimension to the number of query tokens so it can be passed to
            # xFormers as an additive bias.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if key.shape[0] % 2 != 0:
            raise ValueError(f"Joint attention expects an even batch size (two task halves), got {key.shape[0]}.")

        # Join the two halves along the token axis, then duplicate so both halves see the joint keys/values:
        # (2B, d, c) -> (B, 2d, c) -> (2B, 2d, c).
        key_0, key_1 = torch.chunk(key, dim=0, chunks=2)
        value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
        key = torch.cat([key_0, key_1], dim=1)
        value = torch.cat([value_0, value_1], dim=1)
        key = torch.cat([key] * 2, dim=0)
        value = torch.cat([value] * 2, dim=0)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear projection, then dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
class JointAttnProcessor:
    r"""
    Plain (non-xFormers) processor for joint attention across two task halves of the batch.

    The batch is split into two equal halves along dim 0. Entry `i` of each half attends to the keys/values of
    entry `i` of *both* halves, which are concatenated along the token axis.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        """
        Compute joint attention for `attn`.

        Args:
            attn: The attention module providing projections and helpers.
            hidden_states: Query tokens, `(batch, seq, dim)` or `(batch, channel, height, width)`.
            encoder_hidden_states: Optional key/value source; defaults to `hidden_states`.
            attention_mask: Optional mask. NOTE(review): keys are twice as long as `sequence_length` here, so a
                prepared mask would presumably not match — callers pass `None`.
            temb: Optional embedding for `attn.spatial_norm`.
            num_tasks: Number of batch halves to join; only `2` is supported.

        Returns:
            Attention output with the same layout as `hidden_states`.

        Raises:
            ValueError: If `num_tasks != 2` or the batch size is odd.
        """
        if num_tasks != 2:
            raise ValueError(f"Joint attention only supports `num_tasks=2`, got {num_tasks}.")

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if key.shape[0] % 2 != 0:
            raise ValueError(f"Joint attention expects an even batch size (two task halves), got {key.shape[0]}.")

        # Join the two halves along the token axis, then duplicate so both halves see the joint keys/values:
        # (2B, d, c) -> (B, 2d, c) -> (2B, 2d, c).
        key_0, key_1 = torch.chunk(key, dim=0, chunks=2)
        value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
        key = torch.cat([key_0, key_1], dim=1)
        value = torch.cat([value_0, value_1], dim=1)
        key = torch.cat([key] * 2, dim=0)
        value = torch.cat([value] * 2, dim=0)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear projection, then dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states