LatteTransformer3DModel
A Diffusion Transformer model for 3D data from Latte.
class diffusers.LatteTransformer3DModel
(
    num_attention_heads: int = 16,
    attention_head_dim: int = 88,
    in_channels: typing.Optional[int] = None,
    out_channels: typing.Optional[int] = None,
    num_layers: int = 1,
    dropout: float = 0.0,
    cross_attention_dim: typing.Optional[int] = None,
    attention_bias: bool = False,
    sample_size: int = 64,
    patch_size: typing.Optional[int] = None,
    activation_fn: str = 'geglu',
    num_embeds_ada_norm: typing.Optional[int] = None,
    norm_type: str = 'layer_norm',
    norm_elementwise_affine: bool = True,
    norm_eps: float = 1e-05,
    caption_channels: int = None,
    video_length: int = 16
)
forward
(
    hidden_states: Tensor,
    timestep: typing.Optional[torch.LongTensor] = None,
    encoder_hidden_states: typing.Optional[torch.Tensor] = None,
    encoder_attention_mask: typing.Optional[torch.Tensor] = None,
    enable_temporal_attentions: bool = True,
    return_dict: bool = True
)
Parameters
- hidden_states (shape `(batch size, channel, num_frame, height, width)`): Input `hidden_states`.
- timestep (`torch.LongTensor`, optional): Used to indicate the denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, optional): Conditional embeddings for the cross-attention layer. If not given, cross-attention defaults to self-attention.
- encoder_attention_mask (`torch.Tensor`, optional): Cross-attention mask applied to `encoder_hidden_states`. Two formats are supported:
  - Mask `(batch, sequence_length)`: True = keep, False = discard.
  - Bias `(batch, 1, sequence_length)`: 0 = keep, -10000 = discard.
  If `ndim == 2`, the input is interpreted as a mask and then converted into a bias consistent with the format above. This bias is added to the cross-attention scores.
- enable_temporal_attentions (`bool`, optional, defaults to `True`): Whether to enable temporal attentions.
- return_dict (`bool`, optional, defaults to `True`): Whether or not to return a `~models.unet_2d_condition.UNet2DConditionOutput` instead of a plain tuple.
The LatteTransformer3DModel forward method.
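The two `encoder_attention_mask` formats can be easier to follow with concrete numbers. The sketch below is a dependency-free illustration of the documented semantics, not the library's implementation: it uses nested Python lists as stand-ins for tensors, and the helper names `mask_to_bias` and `softmax`, as well as the raw attention scores, are hypothetical. It converts a `(batch, sequence_length)` boolean mask into a `(batch, 1, sequence_length)` additive bias and shows that adding the bias to the scores drives the discarded token's attention weight to zero.

```python
import math

KEEP_BIAS = 0.0         # added to scores of tokens that should be attended to
DISCARD_BIAS = -10000.0  # added to scores of tokens that should be ignored


def mask_to_bias(mask):
    """Convert a (batch, sequence_length) boolean mask into a
    (batch, 1, sequence_length) additive bias: True -> 0, False -> -10000."""
    return [[[KEEP_BIAS if keep else DISCARD_BIAS for keep in row]] for row in mask]


def softmax(scores):
    """Numerically stable softmax over a flat list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# One batch item with three encoder tokens; the last one is padding.
mask = [[True, True, False]]
bias = mask_to_bias(mask)

# Hypothetical raw cross-attention scores of one query against the three tokens.
scores = [1.0, 1.0, 1.0]
biased = [s + b for s, b in zip(scores, bias[0][0])]
weights = softmax(biased)

print(bias)     # bias for the single batch item, with the extra middle axis
print(weights)  # attention weights after the bias is applied
```

The extra middle axis of size 1 in the bias is what lets it broadcast over every query position when it is added to the attention scores.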