FluxTransformer2DModel
A Transformer model for image-like data from Flux.
FluxTransformer2DModel
class diffusers.FluxTransformer2DModel
( patch_size: int = 1 in_channels: int = 64 num_layers: int = 19 num_single_layers: int = 38 attention_head_dim: int = 128 num_attention_heads: int = 24 joint_attention_dim: int = 4096 pooled_projection_dim: int = 768 guidance_embeds: bool = False axes_dims_rope: Tuple[int] = (16, 56, 56) )
Parameters
- patch_size (int, defaults to 1) — Patch size to turn the input data into small patches.
- in_channels (int, optional, defaults to 64) — The number of channels in the input.
- num_layers (int, optional, defaults to 19) — The number of layers of MMDiT blocks to use.
- num_single_layers (int, optional, defaults to 38) — The number of layers of single DiT blocks to use.
- attention_head_dim (int, optional, defaults to 128) — The number of channels in each head.
- num_attention_heads (int, optional, defaults to 24) — The number of heads to use for multi-head attention.
- joint_attention_dim (int, optional, defaults to 4096) — The number of encoder_hidden_states dimensions to use.
- pooled_projection_dim (int, defaults to 768) — The number of dimensions to use when projecting the pooled_projections.
- guidance_embeds (bool, defaults to False) — Whether to use guidance embeddings.
- axes_dims_rope (Tuple[int], defaults to (16, 56, 56)) — The dimensions of each axis of the rotary positional embeddings; they must sum to attention_head_dim.
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
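For example, the transformer can be loaded on its own from a full Flux checkpoint. This sketch assumes access to the gated black-forest-labs/FLUX.1-dev repository on the Hub:

```python
import torch
from diffusers import FluxTransformer2DModel

# Load only the transformer weights from a full Flux checkpoint.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
```

The loaded module can then be passed to FluxPipeline through its transformer argument.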
forward
( hidden_states: Tensor encoder_hidden_states: Tensor = None pooled_projections: Tensor = None timestep: LongTensor = None img_ids: Tensor = None txt_ids: Tensor = None guidance: Tensor = None joint_attention_kwargs: Optional[Dict[str, Any]] = None controlnet_block_samples = None controlnet_single_block_samples = None return_dict: bool = True controlnet_blocks_repeat: bool = False )
Parameters
- hidden_states (torch.FloatTensor of shape (batch_size, image_sequence_length, in_channels)) — Input hidden_states (packed image latents).
- encoder_hidden_states (torch.FloatTensor of shape (batch_size, sequence_len, embed_dims)) — Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (torch.FloatTensor of shape (batch_size, projection_dim)) — Embeddings projected from the embeddings of input conditions.
- timestep (torch.LongTensor) — Used to indicate the denoising step.
- img_ids (torch.Tensor) — Positional ids of the image tokens, consumed by the rotary position embedding.
- txt_ids (torch.Tensor) — Positional ids of the text tokens, consumed by the rotary position embedding.
- guidance (torch.Tensor, optional) — The guidance scale to embed when guidance_embeds is True.
- controlnet_block_samples (list of torch.Tensor, optional) — A list of tensors that, if specified, are added to the residuals of the transformer blocks.
- joint_attention_kwargs (dict, optional) — A kwargs dictionary that, if specified, is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.
- return_dict (bool, optional, defaults to True) — Whether or not to return a ~models.transformer_2d.Transformer2DModelOutput instead of a plain tuple.
The FluxTransformer2DModel forward method.
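As a minimal sketch of the expected shapes, the following builds a deliberately tiny, hypothetical configuration and runs one forward pass. The 2-D img_ids/txt_ids layout assumes a recent diffusers release (older versions expected a leading batch dimension on the ids):

```python
import torch
from diffusers import FluxTransformer2DModel

# Tiny, hypothetical config; real checkpoints use the defaults documented above.
model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=2,
    num_single_layers=2,
    attention_head_dim=16,
    num_attention_heads=4,
    joint_attention_dim=32,
    pooled_projection_dim=16,
    axes_dims_rope=(4, 6, 6),  # axis dims must sum to attention_head_dim
)

batch, img_seq, txt_seq = 1, 16, 8
out = model(
    hidden_states=torch.randn(batch, img_seq, 4),           # packed image latents
    encoder_hidden_states=torch.randn(batch, txt_seq, 32),  # text encoder states
    pooled_projections=torch.randn(batch, 16),              # pooled text embedding
    timestep=torch.ones(batch),                             # current denoising step
    img_ids=torch.zeros(img_seq, 3),                        # RoPE ids for image tokens
    txt_ids=torch.zeros(txt_seq, 3),                        # RoPE ids for text tokens
)
print(out.sample.shape)  # torch.Size([1, 16, 4])
```

The output sample has the same sequence length and channel count as the input hidden_states, so the denoised latents can be unpacked back into an image grid.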
fuse_qkv_projections
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
This API is 🧪 experimental.
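A minimal sketch of the fuse/unfuse round trip, reusing the tiny hypothetical configuration from the forward example above (unfuse_qkv_projections is documented below):

```python
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel(
    in_channels=4, num_layers=1, num_single_layers=1,
    attention_head_dim=16, num_attention_heads=4,
    joint_attention_dim=32, pooled_projection_dim=16, axes_dims_rope=(4, 6, 6),
)

model.fuse_qkv_projections()    # q, k, v weights are concatenated into one projection
# ... run inference with the fused projections ...
model.unfuse_qkv_projections()  # restore the original, unfused projections
```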
set_attn_processor
( processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] )
Parameters
- processor (dict of AttentionProcessor or only AttentionProcessor) — The instantiated processor class or a dictionary of processor classes that will be set as the processor for all Attention layers. If processor is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
Sets the attention processor to use to compute attention.
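A minimal sketch, assuming FluxAttnProcessor2_0 from diffusers.models.attention_processor (present in recent diffusers releases) and the same tiny hypothetical configuration as above:

```python
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxAttnProcessor2_0

model = FluxTransformer2DModel(
    in_channels=4, num_layers=1, num_single_layers=1,
    attention_head_dim=16, num_attention_heads=4,
    joint_attention_dim=32, pooled_projection_dim=16, axes_dims_rope=(4, 6, 6),
)

# A single processor instance is applied to every attention layer ...
model.set_attn_processor(FluxAttnProcessor2_0())

# ... or pass a dict keyed by each layer's path to mix processors per layer.
model.set_attn_processor({name: FluxAttnProcessor2_0() for name in model.attn_processors})
```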
unfuse_qkv_projections
Disables the fused QKV projection if enabled.
This API is 🧪 experimental.