
Flux

Flux is a series of text-to-image generation models based on diffusion transformers.

We recommend using an inf2.24xlarge instance with a tensor parallel size of 8 for model compilation and inference.

Export to Neuron

  • Option 1: CLI
optimum-cli export neuron --model black-forest-labs/FLUX.1-dev --tensor_parallel_size 8 --batch_size 1 --height 1024 --width 1024 --num_images_per_prompt 1 --torch_dtype bfloat16 flux_dev_neuron/
  • Option 2: Python API
import torch

from optimum.neuron import NeuronFluxPipeline

if __name__ == "__main__":
    compiler_args = {"auto_cast": "none"}
    input_shapes = {"batch_size": 1, "height": 1024, "width": 1024}

    pipe = NeuronFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
        export=True,
        tensor_parallel_size=8,
        **compiler_args,
        **input_shapes
    )

    # Save locally
    pipe.save_pretrained("flux_dev_neuron_1024_tp8/")

    # Upload to the Hugging Face Hub
    pipe.push_to_hub(
        "flux_dev_neuron_1024_tp8/",
        repository_id="Jingya/FLUX.1-dev-neuronx-1024x1024-tp8",  # Replace with your repo id on the Hugging Face Hub
    )
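
Once pushed, the precompiled artifacts can be loaded straight back from the Hub, so compilation does not need to be repeated on every host. A minimal sketch, assuming the example repo id pushed above (replace it with your own):

from optimum.neuron import NeuronFluxPipeline

# Reload the already-compiled pipeline from the Hub; no recompilation is needed
# because the Neuron artifacts are stored in the repository.
pipe = NeuronFluxPipeline.from_pretrained("Jingya/FLUX.1-dev-neuronx-1024x1024-tp8")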

Guidance-distilled

  • The guidance-distilled variant takes about 50 sampling steps for good-quality generation.
import torch

from optimum.neuron import NeuronFluxPipeline

pipe = NeuronFluxPipeline.from_pretrained("flux_dev_neuron_1024_tp8/")
prompt = "A cat holding a sign that says hello world"
out = pipe(
    prompt,
    guidance_scale=3.5,
    num_inference_steps=50,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
out.save("flux_optimum.png")
Flux dev generated image.

Timestep-distilled

  • max_sequence_length cannot exceed 256.
  • guidance_scale must be set to 0.
  • As a timestep-distilled model, it needs only a few sampling steps (4 in the example below).
optimum-cli export neuron --model black-forest-labs/FLUX.1-schnell --tensor_parallel_size 8 --batch_size 1 --height 1024 --width 1024 --num_images_per_prompt 1 --sequence_length 256 --torch_dtype bfloat16 flux_schnell_neuron_1024_tp8/
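
The same export can also be done through the Python API, mirroring the dev example above. The sketch below assumes that sequence_length is accepted among the input shapes, matching the --sequence_length CLI flag:

import torch

from optimum.neuron import NeuronFluxPipeline

if __name__ == "__main__":
    compiler_args = {"auto_cast": "none"}
    # Assumption: sequence_length is passed as an input shape, capped at 256 for the timestep-distilled checkpoint
    input_shapes = {"batch_size": 1, "height": 1024, "width": 1024, "sequence_length": 256}

    pipe = NeuronFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16,
        export=True,
        tensor_parallel_size=8,
        **compiler_args,
        **input_shapes
    )
    pipe.save_pretrained("flux_schnell_neuron_1024_tp8/")

Whichever way the checkpoint was compiled, inference then runs as follows: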
import torch
from optimum.neuron import NeuronFluxPipeline

pipe = NeuronFluxPipeline.from_pretrained("flux_schnell_neuron_1024_tp8")
prompt = "A cat holding a sign that says hello world"
out = pipe(prompt, guidance_scale=0.0, max_sequence_length=256, num_inference_steps=4).images[0]
out.save("flux_schnell_optimum.png")
Flux schnell generated image.

NeuronFluxPipeline

The Flux pipeline for text-to-image generation.

class optimum.neuron.NeuronFluxPipeline


( config: dict[str, typing.Any]
  configs: dict[str, 'PretrainedConfig']
  neuron_configs: dict[str, 'NeuronDefaultConfig']
  data_parallel_mode: typing.Literal['none', 'unet', 'transformer', 'all']
  scheduler: diffusers.schedulers.scheduling_utils.SchedulerMixin | None
  vae_decoder: torch.jit._script.ScriptModule | NeuronModelVaeDecoder
  text_encoder: torch.jit._script.ScriptModule | NeuronModelTextEncoder | None = None
  text_encoder_2: torch.jit._script.ScriptModule | NeuronModelTextEncoder | None = None
  unet: torch.jit._script.ScriptModule | NeuronModelUnet | None = None
  transformer: torch.jit._script.ScriptModule | NeuronModelTransformer | None = None
  vae_encoder: torch.jit._script.ScriptModule | NeuronModelVaeEncoder | None = None
  image_encoder: torch.jit._script.ScriptModule | None = None
  safety_checker: torch.jit._script.ScriptModule | None = None
  tokenizer: transformers.models.clip.tokenization_clip.CLIPTokenizer | transformers.models.t5.tokenization_t5.T5Tokenizer | None = None
  tokenizer_2: transformers.models.clip.tokenization_clip.CLIPTokenizer | None = None
  feature_extractor: transformers.models.clip.feature_extraction_clip.CLIPFeatureExtractor | None = None
  controlnet: torch.jit._script.ScriptModule | list[torch.jit._script.ScriptModule] | NeuronControlNetModel | NeuronMultiControlNetModel | None = None
  requires_aesthetics_score: bool = False
  force_zeros_for_empty_prompt: bool = True
  add_watermarker: bool | None = None
  model_save_dir: str | pathlib.Path | tempfile.TemporaryDirectory | None = None
  model_and_config_save_paths: dict[str, tuple[str, pathlib.Path]] | None = None )

__call__


( *args **kwargs )
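
__call__ forwards its arguments to the underlying Flux text-to-image pipeline, so the usual generation parameters apply. A brief sketch, reusing the compiled dev checkpoint and generation arguments from the examples above:

import torch

from optimum.neuron import NeuronFluxPipeline

pipe = NeuronFluxPipeline.from_pretrained("flux_dev_neuron_1024_tp8/")
# Standard text-to-image call: prompt plus diffusers-style keyword arguments
image = pipe(
    "A cat holding a sign that says hello world",
    guidance_scale=3.5,
    num_inference_steps=50,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]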

Are there any other diffusion features that you would like us to support in 🤗 Optimum Neuron? Please open an issue in the Optimum Neuron GitHub repository or discuss it with us on Hugging Face's community forum, cheers 🤗!