AuraFlowTransformer2DModel
A Transformer model for image-like data from AuraFlow.
AuraFlowTransformer2DModel
class diffusers.AuraFlowTransformer2DModel
( sample_size: int = 64, patch_size: int = 2, in_channels: int = 4, num_mmdit_layers: int = 4, num_single_dit_layers: int = 32, attention_head_dim: int = 256, num_attention_heads: int = 12, joint_attention_dim: int = 2048, caption_projection_dim: int = 3072, out_channels: int = 4, pos_embed_max_size: int = 1024 )
Parameters
- sample_size (int, defaults to 64) — The width of the latent images. This is fixed during training since it is used to learn a number of position embeddings.
- patch_size (int, defaults to 2) — Patch size to turn the input data into small patches.
- in_channels (int, optional, defaults to 4) — The number of channels in the input.
- num_mmdit_layers (int, optional, defaults to 4) — The number of layers of MMDiT Transformer blocks to use.
- num_single_dit_layers (int, optional, defaults to 32) — The number of layers of single DiT Transformer blocks to use. These blocks use concatenated image and text representations.
- attention_head_dim (int, optional, defaults to 256) — The number of channels in each head.
- num_attention_heads (int, optional, defaults to 12) — The number of heads to use for multi-head attention.
- joint_attention_dim (int, optional, defaults to 2048) — The number of encoder_hidden_states dimensions to use.
- caption_projection_dim (int, defaults to 3072) — Number of dimensions to use when projecting the encoder_hidden_states.
- out_channels (int, defaults to 4) — Number of output channels.
- pos_embed_max_size (int, defaults to 1024) — Maximum positions to embed from the image latents.
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
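For orientation, here is a minimal loading sketch. The fal/AuraFlow repo id and the transformer subfolder are assumptions about the checkpoint layout on the Hub, and the constraint that caption_projection_dim equals num_attention_heads * attention_head_dim is inferred from the defaults (12 * 256 = 3072).

```python
import torch
from diffusers import AuraFlowTransformer2DModel

# Load pretrained weights; the repo id and subfolder are assumptions
# about how the AuraFlow checkpoint is laid out on the Hub.
transformer = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow", subfolder="transformer", torch_dtype=torch.float16
)

# Or build a small randomly initialized variant for experimentation.
# caption_projection_dim is kept equal to num_attention_heads *
# attention_head_dim (4 * 64 = 256), mirroring the defaults above.
tiny = AuraFlowTransformer2DModel(
    sample_size=32,
    num_mmdit_layers=1,
    num_single_dit_layers=2,
    attention_head_dim=64,
    num_attention_heads=4,
    caption_projection_dim=256,
)
```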
fuse_qkv_projections
( )
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
This API is 🧪 experimental.
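A one-line usage sketch, assuming the transformer instance from the loading example above:

```python
# Fuse the query/key/value projection matrices in place (experimental API).
transformer.fuse_qkv_projections()
```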
set_attn_processor
( processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] )
Parameters
- processor (dict of AttentionProcessor or only AttentionProcessor) — The instantiated processor class or a dictionary of processor classes that will be set as the processor for all Attention layers. If processor is a dict, the key needs to define the path to the corresponding cross-attention processor. This is strongly recommended when setting trainable attention processors.
Sets the attention processor to use to compute attention.
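The sketch below shows both calling conventions. AuraFlowAttnProcessor2_0 is assumed to be the processor class diffusers ships for this model; any compatible AttentionProcessor works the same way.

```python
from diffusers.models.attention_processor import AuraFlowAttnProcessor2_0

# Set a single processor instance for every Attention layer.
transformer.set_attn_processor(AuraFlowAttnProcessor2_0())

# Or pass a dict keyed by layer path to target layers individually.
processors = {name: AuraFlowAttnProcessor2_0() for name in transformer.attn_processors}
transformer.set_attn_processor(processors)
```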
unfuse_qkv_projections
( )
Disables the fused QKV projection if enabled.
This API is 🧪 experimental.
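A matching sketch to undo the fusion from the earlier example, assuming the same transformer instance:

```python
# Restore the separate query/key/value projection matrices.
transformer.unfuse_qkv_projections()
```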