#!/usr/bin/env python3
"""
Benchmark for Flux2-klein 4B and 9B models on AWS Neuron.
Usage:
torchrun --nproc_per_node=4 flux2-klein/benchmark.py \\
--no-random-weights --model-id black-forest-labs/FLUX.2-klein-9B \\
--num-runs 3 --num-steps 4
Results (trn2.3xlarge, 4 NeuronCores, 512×512, 4 steps, bfloat16):

FLUX.2-klein-4B (3.88B params) — eager mode

    Run  Type  step01  step02  step03  step04   total
    1    COLD  9.348s  0.844s  0.835s  0.860s  11.888s
    2    WARM  0.831s  0.835s  0.838s  0.837s   3.342s
    3    WARM  0.830s  0.835s  0.831s  0.834s   3.330s
    4    WARM  0.836s  0.831s  0.840s  0.838s   3.345s

    Cold first call (XLA compilation): 9.348s
    Warm avg/step: 0.835s | 1.198 steps/s | speedup vs cold: 11.2×

FLUX.2-klein-9B (9.08B params) — eager mode

    Run  Type  step01    step02  step03  step04   total
    1    COLD  129.651s  1.276s  1.264s  1.270s  133.461s
    2    WARM  1.277s    1.264s  1.267s  1.264s    5.071s
    3    WARM  1.265s    1.262s  1.270s  1.263s    5.061s
    4    WARM  1.258s    1.274s  1.267s  1.266s    5.065s

    Cold first call (XLA compilation): 129.651s
    Warm avg/step: 1.266s | 0.790 steps/s | speedup vs cold: 102.4×

FLUX.2-klein-9B (9.08B params) — compile mode (torch.compile, Dynamo+NEFF)

    Run  Type  step01    step02  step03  step04   total
    1    COLD  264.514s  5.677s  5.675s  5.673s  281.539s
    2    WARM  5.676s    5.677s  5.677s  5.673s   22.703s
    3    WARM  5.672s    5.676s  5.679s  5.676s   22.702s
    4    WARM  5.671s    5.673s  5.673s  5.677s   22.695s

    Cold first call (Dynamo+NEFF compilation): 264.514s
    Warm avg/step: 5.675s | 0.176 steps/s
Comparison — FLUX.2-klein-9B warm throughput:

    eager:   1.284s/step (0.779 steps/s) ← 4.4× faster
    compile: 5.675s/step (0.176 steps/s)

Note: compile mode is slower because the torch.compile/Dynamo path uses the
NKI flash-attention decomposition (the training=True path) and does not
benefit from the XLA-level fusions that the lazy-XLA path applies
automatically.
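
To exercise the NKI kernels (wired in via apply_tp_flux2_transformer in
pipeline.py), pass --fused-qkv and/or --flash-attn; both default to off.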
"""
import argparse
import gc
import logging
import os
import sys
import time
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from diffusers import Flux2Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.flux2.pipeline_flux2_klein import (
Flux2KleinPipeline,
compute_empirical_mu,
)
# Import loading/TP helpers from pipeline.py in the same directory
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from pipeline import ( # noqa: E402
apply_tp_flux2_transformer,
apply_tp_text_encoder,
_encode_prompt_tp,
load_text_encoder,
load_transformer,
)
import torch_neuronx # noqa: F401, E402 — registers neuron backend
from torch_neuronx.neuron_dynamo_backend import set_model_name
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logger = logging.getLogger(__name__)
DEFAULT_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"
# ---------------------------------------------------------------------------
# Latent / position ID preparation
# ---------------------------------------------------------------------------
def _prepare_inputs(transformer, height, width, batch_size, text_seq_len, device, seed):
"""
Compute initial latents, latent position IDs, and text position IDs.
Returns (latents_dev, latent_ids_dev, text_ids_dev, latents_cpu) where
latents_cpu is kept to reset latents to their original values each run.
"""
generator = torch.Generator().manual_seed(seed)
vae_scale = 8
lh = 2 * (height // (vae_scale * 2))
lw = 2 * (width // (vae_scale * 2))
seq_len = (lh // 2) * (lw // 2)
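    # Worked example at the 512x512 default: lh = lw = 2 * (512 // 16) = 64,
    # so seq_len = 32 * 32 = 1024 packed latent tokens per image.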
latents_cpu = torch.randn(
batch_size, seq_len, transformer.config.in_channels,
dtype=torch.bfloat16, generator=generator,
)
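    # Position IDs are 4-component indices built with cartesian_prod: for the
    # latents the middle two axes enumerate the 2x2-packed latent grid, while
    # for the text IDs below only the last axis varies (token position).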
latent_ids_cpu = (
torch.cartesian_prod(
torch.arange(1), torch.arange(lh // 2),
torch.arange(lw // 2), torch.arange(1),
)
.unsqueeze(0).expand(batch_size, -1, -1).contiguous().float()
)
text_ids_cpu = (
torch.cartesian_prod(
torch.arange(1), torch.arange(1),
torch.arange(1), torch.arange(text_seq_len),
)
.unsqueeze(0).expand(batch_size, -1, -1).contiguous().float()
)
return (
latents_cpu.to(device),
latent_ids_cpu.to(device),
text_ids_cpu.to(device),
latents_cpu, # kept on CPU for resetting between runs
)
# ---------------------------------------------------------------------------
# Single denoising run
# ---------------------------------------------------------------------------
def _run_one(
run_idx, num_runs, transformer, scheduler,
prompt_embeds, latents_init_cpu, latent_ids_dev, text_ids_dev,
ts_tensor, num_steps, batch_size, device, rank,
):
"""
Execute one complete denoising loop and return per-step wall-clock times.
Latents are always reset to `latents_init_cpu` at the start so every run
is independent. Scheduler step is on rank 0 CPU; updated latents are
broadcast to all ranks.
Returns:
step_times: list[float] — elapsed seconds for each transformer forward.
"""
is_cold = (run_idx == 0)
label = f"Run {run_idx + 1}/{num_runs} ({'COLD' if is_cold else 'WARM':4s})"
latents_dev = latents_init_cpu.to(device)
# Reset scheduler's internal step counter (avoids IndexError on run 2+)
if rank == 0:
scheduler._step_index = None
step_times = []
if rank == 0:
logger.info(f" --- {label} ---")
dist.barrier()
t_run = time.time()
with torch.no_grad():
for step_idx in range(num_steps):
t_val = ts_tensor[step_idx]
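            # Scheduler timesteps live on a 0-1000 scale; the transformer
            # expects values in [0, 1], hence the division by 1000.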
timestep = t_val.expand(batch_size).to(torch.bfloat16).to(device) / 1000.0
dist.barrier()
t0 = time.time()
noise_pred = transformer(
hidden_states=latents_dev,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
img_ids=latent_ids_dev,
txt_ids=text_ids_dev,
guidance=None,
return_dict=False,
)[0]
dist.barrier()
elapsed = time.time() - t0
step_times.append(elapsed)
if rank == 0:
logger.info(
f" step {step_idx + 1:2d}/{num_steps}"
f" t={t_val.item():7.1f}"
f" elapsed={elapsed:.3f}s"
)
if rank == 0:
lat_new = scheduler.step(
noise_pred.to("cpu"), t_val.cpu(), latents_dev.to("cpu"),
return_dict=False,
)[0]
latents_dev.copy_(lat_new.to(device))
dist.broadcast(latents_dev, src=0)
if rank == 0:
total = time.time() - t_run
logger.info(f" run {run_idx + 1} total: {total:.3f}s")
return step_times
# ---------------------------------------------------------------------------
# Summary reporting
# ---------------------------------------------------------------------------
def _print_summary(mode, model_id, height, width, num_steps, num_runs, all_step_times):
"""Print a formatted latency table and key metrics to the log."""
SEP = "=" * 72
HSEP = "-" * 72
cold_label = (
"Dynamo trace + NEFF compilation" if mode == "compile"
else "XLA compilation"
)
logger.info(SEP)
logger.info(f"BENCHMARK RESULTS | {model_id} | mode={mode}")
logger.info(f" {height}x{width} · {num_steps} steps/run · {num_runs} runs")
logger.info(HSEP)
step_hdrs = " ".join(f"step{i + 1:02d}" for i in range(num_steps))
logger.info(f"{'Run':<5} {'Type':<5} {step_hdrs} total")
logger.info(HSEP)
for run_idx, times in enumerate(all_step_times):
rtype = "COLD" if run_idx == 0 else "WARM"
cells = " ".join(f"{t:6.3f}s" for t in times)
logger.info(f"{run_idx + 1:<5} {rtype:<5} {cells} {sum(times):.3f}s")
logger.info(HSEP)
cold_step1 = all_step_times[0][0]
logger.info(f" Cold first call (incl. {cold_label}): {cold_step1:.3f}s")
if num_steps > 1:
cold_rest = all_step_times[0][1:]
avg_cold_rest = sum(cold_rest) / len(cold_rest)
logger.info(
f" Cold run steps 2-{num_steps} avg : {avg_cold_rest:.3f}s/step"
)
if num_runs > 1:
warm_times = [t for times in all_step_times[1:] for t in times]
avg_warm = sum(warm_times) / len(warm_times)
warm_step1_times = [times[0] for times in all_step_times[1:]]
avg_warm_step1 = sum(warm_step1_times) / len(warm_step1_times)
logger.info(
f" Warm runs — first step avg : {avg_warm_step1:.3f}s/step"
)
logger.info(
f" Warm runs — all steps avg : {avg_warm:.3f}s/step"
)
logger.info(
f" Throughput (warm, all steps) : {1.0 / avg_warm:.3f} steps/s"
)
logger.info(
f" Speedup vs cold first call : {cold_step1 / avg_warm:.1f}x"
)
logger.info(SEP)
# ---------------------------------------------------------------------------
# Main benchmark entry point
# ---------------------------------------------------------------------------
def benchmark(
mode, model_id, prompt, height, width, num_steps, batch_size,
num_runs, random_weights, seed, fuse_qkv=False, flash_attn=False,
):
assert mode in ("eager", "compile"), f"--mode must be 'eager' or 'compile', got {mode!r}"
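    # Rank and world size come from the torchrun environment; torch_neuronx
    # (imported above) registers the "neuron" backend used here.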
dist.init_process_group(backend="neuron")
world_size = dist.get_world_size()
rank = dist.get_rank()
device = torch.neuron.current_device()
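    # One-dimensional device mesh over all ranks; this is the tensor-parallel
    # group consumed by the TP helpers from pipeline.py.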
tp_mesh = DeviceMesh("neuron", list(range(world_size)))
if rank == 0:
logger.info(f"{'=' * 72}")
logger.info(f"Flux2-klein benchmark | {model_id} | mode={mode}")
logger.info(
f" {height}x{width} · {num_steps} steps · {num_runs} runs "
f"· batch={batch_size} · random_weights={random_weights}"
)
logger.info(f"{'=' * 72}")
xfmr_cfg = Flux2Transformer2DModel.load_config(model_id, subfolder="transformer")
joint_attention_dim = xfmr_cfg["joint_attention_dim"]
text_seq_len = 512
# ------------------------------------------------------------------
# 1. Text encoder: all ranks load & TP-encode, then free
# ------------------------------------------------------------------
if not random_weights:
t0 = time.time()
text_encoder, tokenizer = load_text_encoder(model_id, random_weights=False)
logger.info(
f"Rank {rank}: text encoder loaded in {time.time() - t0:.1f}s "
f"({sum(p.numel() for p in text_encoder.parameters()) / 1e9:.2f}B params)"
)
text_encoder = apply_tp_text_encoder(text_encoder, tp_mesh)
text_encoder = text_encoder.to(device)
text_encoder.eval()
if mode == "compile":
set_model_name(f"qwen3_text_encoder_rank{rank}")
            # Pre-install the output-capturing hooks so that
            # _output_capturing_hooks_installed is already True; the early
            # return in maybe_install_capturing_hooks then fires before it
            # reaches the threading.Lock that torch.compile(fullgraph=True)
            # cannot trace. See pipeline.py for the full note.
from transformers.utils.output_capturing import install_all_output_capturing_hooks
install_all_output_capturing_hooks(text_encoder)
text_encoder = torch.compile(text_encoder, backend="neuron", fullgraph=True)
logger.info(f"Rank {rank}: text encoder compiled")
gc.collect()
prompt_embeds = _encode_prompt_tp(
text_encoder, tokenizer, prompt, batch_size, device)
if rank == 0:
logger.info(f"Prompt encoded shape={prompt_embeds.shape}")
del text_encoder, tokenizer
gc.collect()
else:
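        # Random-weights path: skip the text encoder entirely. Rank 0 draws a
        # random stand-in for the prompt embeddings and broadcasts it so every
        # rank conditions on identical values.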
prompt_embeds = torch.zeros(
batch_size, text_seq_len, joint_attention_dim,
dtype=torch.bfloat16, device=device,
)
if rank == 0:
prompt_embeds.copy_(
torch.randn(batch_size, text_seq_len, joint_attention_dim,
dtype=torch.bfloat16).to(device))
dist.broadcast(prompt_embeds, src=0)
# ------------------------------------------------------------------
# 2. Transformer: all ranks load, TP, move to Neuron [+ compile]
# ------------------------------------------------------------------
t0 = time.time()
transformer = load_transformer(model_id, random_weights)
logger.info(
f"Rank {rank}: transformer loaded in {time.time() - t0:.1f}s "
f"({sum(p.numel() for p in transformer.parameters()) / 1e9:.2f}B params)"
)
transformer = apply_tp_flux2_transformer(transformer, tp_mesh,
fuse_qkv=fuse_qkv, flash_attn=flash_attn)
transformer = transformer.to(device)
transformer.eval()
if mode == "compile":
set_model_name(f"flux2_transformer_rank{rank}")
transformer = torch.compile(transformer, backend="neuron", fullgraph=True)
logger.info(f"Rank {rank}: transformer compiled (NEFF will build on first call)")
gc.collect()
# ------------------------------------------------------------------
# 3. Scheduler timesteps (computed once, reused for all runs)
# ------------------------------------------------------------------
vae_scale = 8
lh = 2 * (height // (vae_scale * 2))
lw = 2 * (width // (vae_scale * 2))
image_seq_len = (lh // 2) * (lw // 2)
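    # compute_empirical_mu picks the flow-matching shift (mu) from the image
    # token count; set_timesteps uses it to shift the sigma schedule. Only
    # rank 0 computes the timesteps; they are broadcast to all ranks below.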
if rank == 0:
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_steps)
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_steps, mu=mu)
ts_float = scheduler.timesteps.float()
logger.info(f"Timesteps: {scheduler.timesteps.tolist()}")
else:
scheduler = FlowMatchEulerDiscreteScheduler()
ts_float = torch.zeros(num_steps, dtype=torch.float32)
ts_dev = ts_float.to(device)
dist.broadcast(ts_dev, src=0)
# ------------------------------------------------------------------
# 4. Initial latents and position IDs
# ------------------------------------------------------------------
latents_dev, latent_ids_dev, text_ids_dev, latents_init_cpu = _prepare_inputs(
transformer, height, width, batch_size, text_seq_len, device, seed,
)
# ------------------------------------------------------------------
# 5. Benchmark loop
# Run 1 (COLD): triggers compilation (XLA or Dynamo+NEFF)
# Runs 2+ (WARM): reuse compiled graph
# ------------------------------------------------------------------
dist.barrier()
if rank == 0:
compile_note = " (run 1 triggers Dynamo+NEFF compile)" if mode == "compile" else ""
logger.info(
f"Starting {num_runs} benchmark runs ({num_steps} steps each){compile_note} ..."
)
all_step_times = []
for run_idx in range(num_runs):
step_times = _run_one(
run_idx, num_runs, transformer, scheduler,
prompt_embeds, latents_init_cpu, latent_ids_dev, text_ids_dev,
ts_dev, num_steps, batch_size, device, rank,
)
all_step_times.append(step_times)
# ------------------------------------------------------------------
# 6. Summary (rank 0 only)
# ------------------------------------------------------------------
if rank == 0:
_print_summary(mode, model_id, height, width, num_steps, num_runs, all_step_times)
dist.barrier()
dist.destroy_process_group()
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args():
p = argparse.ArgumentParser(
description="Flux2-klein latency benchmark (4B / 9B) on Neuron",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
p.add_argument("--mode", choices=["eager", "compile"], default="eager",
help="eager: lazy-XLA path. compile: torch.compile Dynamo path.")
p.add_argument("--model-id", default=DEFAULT_MODEL_ID,
help="HuggingFace model ID (4B or 9B variant)")
p.add_argument("--prompt", default="a cat sitting on a Neuron chip, photorealistic")
p.add_argument("--height", type=int, default=512)
p.add_argument("--width", type=int, default=512)
p.add_argument("--num-steps", type=int, default=4,
help="Denoising steps per run")
p.add_argument("--num-runs", type=int, default=4,
help="Total runs: run 1=COLD (compilation), runs 2+=WARM (benchmarked)")
p.add_argument("--batch-size", type=int, default=1)
p.add_argument("--seed", type=int, default=42)
p.add_argument("--random-weights", action="store_true", default=True)
p.add_argument("--no-random-weights", action="store_false", dest="random_weights")
p.add_argument("--fused-qkv", action="store_true", default=False,
help="Use NKI fused QKV kernel for double-stream blocks.")
p.add_argument("--flash-attn", action="store_true", default=False,
help="Use NKI flash attention kernel for all blocks.")
p.add_argument(
"--cache-dir",
default=None,
help=(
"Persistent NEFF cache directory (sets TORCH_NEURONX_NEFF_CACHE_DIR). "
"Applies to both eager and compile modes. "
"NEFFs saved on first run, reloaded on subsequent runs. "
"Example: --cache-dir /home/ubuntu/neff_cache"
),
)
return p.parse_args()
if __name__ == "__main__":
args = parse_args()
# Always set the NEFF cache dir regardless of mode — both eager (lazy-XLA)
# and compile (Dynamo) paths use TORCH_NEURONX_NEFF_CACHE_DIR to persist
# compiled NEFFs across runs. Default /tmp/neff_cache is lost on reboot.
cache_dir = args.cache_dir or os.environ.get("TORCH_NEURONX_NEFF_CACHE_DIR", "/tmp/neff_cache")
os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir
os.makedirs(cache_dir, exist_ok=True)
logger.info(f"NEFF cache dir: {cache_dir}")
benchmark(
mode=args.mode,
model_id=args.model_id,
prompt=args.prompt,
height=args.height,
width=args.width,
num_steps=args.num_steps,
batch_size=args.batch_size,
num_runs=args.num_runs,
random_weights=args.random_weights,
seed=args.seed,
fuse_qkv=args.fused_qkv,
flash_attn=args.flash_attn,
)