FLUX.2-klein-4B 30 FPS Optimizer

Optimized inference for black-forest-labs/FLUX.2-klein-4B targeting maximum FPS.

The Reality Check

FLUX.2-klein-4B is already distilled for 4-step inference and runs at roughly 1 image/sec at 1024×1024 on an RTX 4090. At 30 fps the budget is about 33 ms per frame, so reaching it requires roughly a 30× speedup over stock inference.
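The arithmetic behind that target, as a small sketch. The helper names are hypothetical (not part of flux_klein_30fps.py), and the 1 fps baseline is the rough figure quoted above, not a new measurement.

```python
# Hypothetical helpers for the frame-budget arithmetic; not part of flux_klein_30fps.py.


def frame_budget_ms(target_fps: float) -> float:
    """Wall-clock time available per frame at a given frame rate."""
    return 1000.0 / target_fps


def required_speedup(current_fps: float, target_fps: float) -> float:
    """How many times faster the pipeline must run to reach target_fps."""
    return target_fps / current_fps


if __name__ == "__main__":
    budget = frame_budget_ms(30)          # about 33.3 ms per frame
    speedup = required_speedup(1.0, 30)   # stock: roughly 1 image/sec at 1024x1024
    print(f"budget: {budget:.1f} ms/frame, speedup needed: {speedup:.0f}x")
```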

Realistic FPS Targets

| Resolution | Steps | Hardware + optimizations | Expected FPS |
|---|---|---|---|
| 1024×1024 | 4 | RTX 4090, stock | ~1 |
| 512×512 | 4 | RTX 4090 + torch.compile | ~2-5 |
| 256×256 | 4 | RTX 4090 + torch.compile | ~8-15 |
| 256×256 | 1 | RTX 4090 + torch.compile | ~10-20 |
| 256×256 | 4 | H100 + torch.compile + FP8 | ~15-25 |
| 256×256 | 1 | H100 + TensorRT | ~20-30 |

To hit true 30fps, you need all of the following simultaneously:

  • 256×256 resolution (or smaller)
  • 1-step distilled checkpoint (would need custom distillation)
  • TensorRT / ONNX Runtime compiled engine (beyond torch.compile)
  • FlashInfer / SageAttention backend
  • Fast VAE decode optimization
  • Optional: batch processing across multiple GPUs

Quick Start

# Install dependencies
pip install -r requirements.txt

# Benchmark at 256x256 (fastest config)
python flux_klein_30fps.py --benchmark --frames 30 --resolution 256 --steps 4

# Try 1-step for maximum speed (lower quality)
python flux_klein_30fps.py --benchmark --frames 30 --resolution 256 --steps 1

# Single image at 512x512
python flux_klein_30fps.py --resolution 512 --prompt "a cyberpunk street at night"

# With FP8 quantization (H100 / Ada GPUs only)
python flux_klein_30fps.py --quant fp8 --resolution 256 --benchmark --frames 10

Optimizations Applied

1. torch.compile (2-3× speedup)

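import torch

# `pipe` is the already-loaded FLUX.2-klein pipeline (BF16, on CUDA).
pipe.transformer = torch.compile(
    pipe.transformer,
    mode="max-autotune",  # autotune kernels; also enables CUDA graphs on GPU
    fullgraph=False,      # tolerate graph breaks instead of raising an error
    dynamic=False,        # specialize on fixed shapes (one resolution per run)
)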

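This compiles the transformer into fused Triton kernels, removes most Python overhead, and in max-autotune mode benchmarks kernel configurations and uses CUDA graphs to cut kernel-launch overhead. The first call is slow (tracing, compilation, and autotuning); subsequent calls at the same shapes are fast. Because dynamic=False, switching to a new resolution triggers a recompile.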
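Since the first compiled calls are slow, FPS numbers are only meaningful if those calls are excluded. A minimal sketch of a warmup-aware timer, written as a hypothetical helper (benchmark_fps is not part of flux_klein_30fps.py); it assumes each frame is produced by one call to a callable you supply.

```python
import time
from typing import Callable, Optional


# Hypothetical helper; not part of flux_klein_30fps.py.
def benchmark_fps(
    run_frame: Callable[[], object],
    frames: int = 30,
    warmup: int = 3,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return steady-state frames per second, excluding warmup calls.

    The first calls after torch.compile trigger compilation and autotuning,
    so they are run but not timed. On GPU, pass sync=torch.cuda.synchronize
    so queued kernels finish before the clock is read.
    """
    for _ in range(warmup):
        run_frame()                      # compile / autotune happens here
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(frames):
        run_frame()
    if sync is not None:
        sync()                           # wait for outstanding GPU work
    elapsed = time.perf_counter() - start
    return frames / elapsed if elapsed > 0 else float("inf")
```

With the pipeline, the callable could wrap something like pipe("a cyberpunk street at night", height=256, width=256, num_inference_steps=4), assuming the FLUX.2-klein pipeline accepts the usual diffusers height, width, and num_inference_steps arguments.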

2. bfloat16 Precision

Relative to FP32, BF16 halves the bytes moved for weights and activations and uses the BF16 Tensor Core paths on Ampere, Ada, and Hopper GPUs.
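A rough sketch of what the weights alone occupy in each precision, using the 4B parameter count from the model card as an approximation. The helper is hypothetical, and the result excludes activations, the text encoder, the VAE, and framework overhead, which is why the VRAM figures elsewhere in this README are higher.

```python
# Hypothetical helper; weights only, decimal GB (1 GB = 1e9 bytes).
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp8": 1}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Memory needed to hold num_params weights in the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


if __name__ == "__main__":
    for dtype in ("fp32", "bf16", "fp8"):
        print(dtype, weight_memory_gb(4e9, dtype), "GB")  # 16.0, 8.0, 4.0
```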

3. Resolution Reduction

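Attention cost grows quadratically with the number of image tokens, and the token count scales with height × width. Dropping from 512×512 to 256×256 cuts the image tokens by 4×, so attention FLOPs fall by up to ~16× and the linear/MLP layers by ~4×. Text tokens and fixed per-frame costs (such as the VAE decode) keep the end-to-end gain smaller.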
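A sketch of the token arithmetic. It assumes 16 pixels per token side (8× VAE downsampling followed by 2×2 patchification, the FLUX.1 layout; treat this as an assumption for FLUX.2), and it counts image tokens only. The helper names are hypothetical.

```python
from typing import Tuple


# Hypothetical helpers; assumes 16 pixels per token side (8x VAE + 2x2 patchify).
def image_tokens(height: int, width: int, pixels_per_token: int = 16) -> int:
    """Number of image tokens the transformer attends over."""
    return (height // pixels_per_token) * (width // pixels_per_token)


def relative_cost(res_a: int, res_b: int) -> Tuple[float, float]:
    """(attention, linear) cost ratio of square res_a over res_b, image tokens only."""
    ratio = image_tokens(res_a, res_a) / image_tokens(res_b, res_b)
    return ratio ** 2, ratio   # attention is quadratic in tokens, linear layers linear


if __name__ == "__main__":
    for res in (1024, 512, 256):
        print(res, image_tokens(res, res))   # 4096, 1024, 256
    print(relative_cost(512, 256))           # (16.0, 4.0)
```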

4. Fewer Steps

Klein is already distilled to 4 steps. Pushing to 1 step is possible but quality degrades noticeably.
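Cutting steps only shrinks the transformer part of each frame; work done once per frame (text encoding, VAE decode, launch overhead) stays the same, so going from 4 steps to 1 gives less than a 4× FPS gain. A minimal latency model; the helper is hypothetical and the timings are placeholders chosen to show the shape of the trade-off, not measurements.

```python
# Hypothetical latency model; the timings below are placeholders, not measurements.


def frame_latency_ms(steps: int, step_ms: float, fixed_ms: float) -> float:
    """Per-frame latency: `steps` transformer calls plus fixed per-frame work."""
    return fixed_ms + steps * step_ms


def fps(steps: int, step_ms: float, fixed_ms: float) -> float:
    """Frames per second implied by the latency model."""
    return 1000.0 / frame_latency_ms(steps, step_ms, fixed_ms)


if __name__ == "__main__":
    step_ms, fixed_ms = 20.0, 20.0           # hypothetical values
    print(fps(4, step_ms, fixed_ms))         # 10.0
    print(fps(1, step_ms, fixed_ms))         # 25.0, i.e. 2.5x, not 4x
```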

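5. FP8 Quantization (H100/Ada only)

Layerwise casting stores the transformer weights in FP8 and upcasts each layer to BF16 just before it runs, so it mainly cuts weight memory; the matmuls themselves still run in BF16.

import torch
from diffusers.hooks import apply_layerwise_casting

# skip_modules_classes takes module classes, not strings. The default
# skip_modules_pattern="auto" also leaves modules whose names match
# patterns such as "norm" in their original precision.
apply_layerwise_casting(
    pipe.transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_classes=(torch.nn.LayerNorm, torch.nn.GroupNorm),
)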
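To see what FP8 storage costs in precision, here is a pure-Python sketch that rounds a value to the nearest number in the E4M3 format named by torch.float8_e4m3fn (4 exponent bits with bias 7, 3 mantissa bits, largest finite value 448). It is a hypothetical illustration of the format, not how PyTorch implements the cast; saturating out-of-range values to ±448 is a simplification. With only 3 mantissa bits, neighbouring values are about 6-12% apart, which is one reason sensitive layers such as norms are usually left in higher precision.

```python
import math

# Hypothetical illustration of float8 E4M3 ("fn" variant) rounding.
E4M3_MAX = 448.0      # largest finite value
E4M3_MIN_EXP = -6     # smallest normal exponent (bias 7)
MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even, saturating at +/-448)."""
    if x == 0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    mag = abs(x)
    if mag >= E4M3_MAX:
        return sign * E4M3_MAX
    # frexp gives mag = m * 2**e with m in [0.5, 1), so floor(log2(mag)) = e - 1.
    _, e = math.frexp(mag)
    exp = max(e - 1, E4M3_MIN_EXP)        # below 2**-6 the spacing is fixed (subnormals)
    step = 2.0 ** (exp - MANTISSA_BITS)   # gap between adjacent representable values
    q = round(mag / step) * step          # Python's round() breaks ties to even
    return sign * min(q, E4M3_MAX)


if __name__ == "__main__":
    print(round_to_e4m3(0.3))     # 0.3125
    print(round_to_e4m3(1000.0))  # 448.0
```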

6. FlashAttention-2 / SDPA

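PyTorch's scaled_dot_product_attention automatically dispatches to a fused FlashAttention or memory-efficient kernel when the GPU, dtype, and tensor shapes allow it. Installing the flash-attn package provides FlashAttention-2 as an explicit backend.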
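For reference, this is the math every one of these backends computes: softmax(QK^T / sqrt(d)) V. The pure-Python version below is a hypothetical single-head sketch without masking or dropout; fused kernels like FlashAttention produce the same result without materializing the full score matrix, which is where their speed and memory savings come from.

```python
import math
from typing import List

Matrix = List[List[float]]


# Hypothetical reference implementation; single head, no mask, no dropout.
def sdpa_reference(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d)) V with q: [n_q][d], k: [n_k][d], v: [n_k][d_v]."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out: Matrix = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)                             # subtract max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * vj[c] for w, vj in zip(weights, v)) / total
            for c in range(len(v[0]))
        ])
    return out
```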

Model Architecture

  • Params: 4B
  • Layers: 5 double + 20 single transformer layers
  • Attention heads: 24 × 128 dim
  • Text encoder: Qwen3ForCausalLM
  • VAE: Standard FLUX VAE (~160MB)
  • License: Apache 2.0
  • VRAM: ~13GB at 1024×1024 (BF16)
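The listed dimensions are enough for a back-of-the-envelope count of the attention matmul FLOPs per transformer pass. This hypothetical sketch counts a multiply-add as 2 FLOPs, covers only QK^T and the scores-times-V product (projections, MLPs, softmax, and norms are excluded), and treats each of the 25 layers as one joint attention over `tokens`, which is the caller's total of text plus image tokens.

```python
# Hypothetical estimate; defaults follow the architecture list above.
def attention_matmul_flops(
    tokens: int, heads: int = 24, head_dim: int = 128, layers: int = 25
) -> int:
    """FLOPs for Q K^T and scores @ V in one forward pass of the transformer.

    Each matmul costs 2 * tokens * tokens * head_dim per head, so one layer
    costs 4 * tokens^2 * heads * head_dim.
    """
    return layers * 4 * tokens * tokens * heads * head_dim


if __name__ == "__main__":
    print(attention_matmul_flops(256))    # 20132659200, about 20 GFLOPs per step
    print(attention_matmul_flops(4096))   # 256x more
```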

Hardware Requirements

| Config | Min GPU | VRAM |
|---|---|---|
| 1024×1024, 4 steps | RTX 3090 / 4070 | 13GB |
| 512×512, 4 steps + compile | RTX 4090 / A100 | 10GB |
| 256×256, 4 steps + compile | RTX 4090 / A100 | 8GB |
| 256×256, FP8 | H100 / RTX 4090 (Ada) | 6GB |

Next Steps for 30fps

If you need true 30fps:

  1. Distill a 1-step model from FLUX.2-klein, for example with consistency distillation
  2. Compile the transformer to a TensorRT engine with torch_tensorrt.compile() instead of relying on torch.compile alone
  3. Use the small decoder FLUX.2-small-decoder if compatible
  4. Batch frames across multiple GPUs for video generation (this raises aggregate throughput, not single-frame latency)
  5. Use SageAttention or FlashInfer for faster attention kernels
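Multi-GPU batching raises throughput rather than latency: each frame still takes as long as it would on one GPU. A hypothetical sketch of round-robin frame assignment and the ideal aggregate rate, assuming frames are independent (no frame depends on the previous one's output).

```python
from typing import Dict, List


# Hypothetical helpers; assume independent frames and identical GPUs.
def assign_frames(num_frames: int, num_gpus: int) -> Dict[int, List[int]]:
    """Round-robin frame indices across GPUs (frame i goes to GPU i % num_gpus)."""
    plan: Dict[int, List[int]] = {g: [] for g in range(num_gpus)}
    for i in range(num_frames):
        plan[i % num_gpus].append(i)
    return plan


def aggregate_fps(per_gpu_fps: float, num_gpus: int) -> float:
    """Ideal combined throughput; per-frame latency does not change."""
    return per_gpu_fps * num_gpus


if __name__ == "__main__":
    print(assign_frames(5, 2))     # {0: [0, 2, 4], 1: [1, 3]}
    print(aggregate_fps(10.0, 3))  # 30.0
```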
