FluxTransformer2DModel

A Transformer model for image-like data from Flux.

FluxTransformer2DModel

class diffusers.FluxTransformer2DModel

( patch_size: int = 1 in_channels: int = 64 num_layers: int = 19 num_single_layers: int = 38 attention_head_dim: int = 128 num_attention_heads: int = 24 joint_attention_dim: int = 4096 pooled_projection_dim: int = 768 guidance_embeds: bool = False axes_dims_rope: List = [16, 56, 56] )

Parameters

  • patch_size (int, optional, defaults to 1) — Patch size to turn the input data into small patches.
  • in_channels (int, optional, defaults to 64) — The number of channels in the input.
  • num_layers (int, optional, defaults to 19) — The number of layers of MMDiT blocks to use.
  • num_single_layers (int, optional, defaults to 38) — The number of layers of single DiT blocks to use.
  • attention_head_dim (int, optional, defaults to 128) — The number of channels in each head.
  • num_attention_heads (int, optional, defaults to 24) — The number of heads to use for multi-head attention.
  • joint_attention_dim (int, optional, defaults to 4096) — The number of encoder_hidden_states dimensions to use.
  • pooled_projection_dim (int, optional, defaults to 768) — Number of dimensions to use when projecting the pooled_projections.
  • guidance_embeds (bool, optional, defaults to False) — Whether to use guidance embeddings.
  • axes_dims_rope (List[int], optional, defaults to [16, 56, 56]) — The dimensions to use for the rotary positional embeddings.

The Transformer model introduced in Flux.

Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
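
As a rough illustration of how the constructor defaults relate to one another, the sketch below derives the transformer's hidden width from the head settings and checks the rotary-embedding axis split. It assumes that the inner dimension equals num_attention_heads * attention_head_dim and that the entries of axes_dims_rope sum to attention_head_dim; both are assumptions about the implementation rather than statements from this page. The FluxConfig dataclass and its helpers are hypothetical names used only for this example and are not part of Diffusers.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FluxConfig:
    """Hypothetical container mirroring the constructor defaults in the signature."""

    patch_size: int = 1
    in_channels: int = 64
    num_layers: int = 19
    num_single_layers: int = 38
    attention_head_dim: int = 128
    num_attention_heads: int = 24
    joint_attention_dim: int = 4096
    pooled_projection_dim: int = 768
    guidance_embeds: bool = False
    axes_dims_rope: List[int] = field(default_factory=lambda: [16, 56, 56])

    @property
    def inner_dim(self) -> int:
        # Assumption: hidden width = number of heads * channels per head.
        return self.num_attention_heads * self.attention_head_dim

    def rope_axes_match_head_dim(self) -> bool:
        # Assumption: the per-axis rotary dims must add up to one head's width.
        return sum(self.axes_dims_rope) == self.attention_head_dim


config = FluxConfig()
print(config.inner_dim)                   # 24 * 128 = 3072
print(config.rope_axes_match_head_dim())  # 16 + 56 + 56 == 128, so True
```

A check like this can be handy when building a smaller test configuration: changing attention_head_dim without adjusting axes_dims_rope would make the second check fail.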

forward

( hidden_states: Tensor encoder_hidden_states: Tensor = None pooled_projections: Tensor = None timestep: LongTensor = None img_ids: Tensor = None txt_ids: Tensor = None guidance: Tensor = None joint_attention_kwargs: Optional = None return_dict: bool = True )

Parameters

  • hidden_states (torch.FloatTensor of shape (batch_size, image_sequence_length, in_channels)) — Input hidden_states.
  • encoder_hidden_states (torch.FloatTensor of shape (batch_size, sequence_len, embed_dims)) — Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
  • pooled_projections (torch.FloatTensor of shape (batch_size, projection_dim)) — Embeddings projected from the embeddings of input conditions.
  • timestep (torch.LongTensor) — Used to indicate denoising step.
  • block_controlnet_hidden_states (list of torch.Tensor) — A list of tensors that, if specified, are added to the residuals of transformer blocks.
  • joint_attention_kwargs (dict, optional) — A kwargs dictionary that if specified is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.
  • return_dict (bool, optional, defaults to True) — Whether or not to return a Transformer2DModelOutput instead of a plain tuple.

The FluxTransformer2DModel forward method.
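
To make the tensor shapes of the main forward arguments more concrete, here is a minimal sketch that computes them for a given latent size. It assumes the image latents arrive as a packed token sequence, with each 2x2 patch of a 16-channel latent flattened into one token of 64 channels (which would match the in_channels default), and that timestep carries one value per batch element. These are assumptions about how a calling pipeline prepares its inputs, not guarantees from this page. The function expected_forward_shapes is a hypothetical helper, and img_ids, txt_ids, and guidance are left out because their layout is not described here.

```python
from typing import Dict, Tuple


def expected_forward_shapes(
    batch_size: int,
    latent_height: int,
    latent_width: int,
    text_seq_len: int,
    latent_channels: int = 16,  # assumption: 16-channel VAE latents
    joint_attention_dim: int = 4096,
    pooled_projection_dim: int = 768,
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: shapes of the main tensors passed to forward()."""
    if latent_height % 2 or latent_width % 2:
        raise ValueError("latent height and width must be even for 2x2 packing")
    # Assumption: each 2x2 latent patch becomes one token with 4x the channels.
    image_seq_len = (latent_height // 2) * (latent_width // 2)
    packed_channels = latent_channels * 4
    return {
        "hidden_states": (batch_size, image_seq_len, packed_channels),
        "encoder_hidden_states": (batch_size, text_seq_len, joint_attention_dim),
        "pooled_projections": (batch_size, pooled_projection_dim),
        "timestep": (batch_size,),
    }


shapes = expected_forward_shapes(
    batch_size=1, latent_height=128, latent_width=128, text_seq_len=512
)
for name, shape in shapes.items():
    print(name, shape)
```

With a 128x128 latent, the sketch yields 64 * 64 = 4096 image tokens of 64 channels each.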
