Diffusers documentation

ParaAttention

You are viewing main version, which requires installation from source. If you'd like regular pip install, checkout the latest stable version (v0.35.1).
Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

ParaAttention

大型图像和视频生成模型,如 FLUX.1-devHunyuanVideo,由于其规模,可能对实时应用和部署构成推理挑战。

ParaAttention 是一个实现了上下文并行第一块缓存的库,可以与其他技术(如 torch.compile、fp8 动态量化)结合使用,以加速推理。

本指南将展示如何在 NVIDIA L20 GPU 上对 FLUX.1-dev 和 HunyuanVideo 应用 ParaAttention。 在我们的基线基准测试中,除了 HunyuanVideo 为避免内存不足错误外,未应用任何优化。

我们的基线基准测试显示,FLUX.1-dev 能够在 28 步中生成 1024x1024 分辨率图像,耗时 26.36 秒;HunyuanVideo 能够在 30 步中生成 129 帧 720p 分辨率视频,耗时 3675.71 秒。

对于更快的上下文并行推理,请尝试使用支持 NVLink 的 NVIDIA A100 或 H100 GPU(如果可用),尤其是在 GPU 数量较多时。

第一块缓存

缓存模型中 transformer 块的输出并在后续推理步骤中重用它们,可以降低计算成本并加速推理。

然而,很难决定何时重用缓存以确保生成图像或视频的质量。ParaAttention 直接使用第一个 transformer 块输出的残差差异来近似模型输出之间的差异。当差异足够小时,重用先前推理步骤的残差差异。换句话说,跳过去噪步骤。

这在 FLUX.1-dev 和 HunyuanVideo 推理上实现了 2 倍加速,且质量非常好。

Cache in Diffusion Transformer
AdaCache 的工作原理,第一块缓存是其变体
FLUX-1.dev
HunyuanVideo

要在 FLUX.1-dev 上应用第一块缓存,请调用 apply_cache_on_pipe,如下所示。0.08 是 FLUX 模型的默认残差差异值。

import time
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(pipe, residual_diff_thre
shold=0.08)

# 启用内存节省
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()

begin = time.time()
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=28,
).images[0]
end = time.time()
print(f"Time: {end - begin:.2f}s")

print("Saving image to flux.png")
image.save("flux.png")
优化 原始 FBCache rdt=0.06 FBCache rdt=0.08 FBCache rdt=0.10 FBCache rdt=0.12
预览 Original FBCache rdt=0.06 FBCache rdt=0.08 FBCache rdt=0.10 FBCache rdt=0.12
墙时间 (s) 26.36 21.83 17.01 16.00 13.78

First Block Cache 将推理速度降低到 17.01 秒,与基线相比,或快 1.55 倍,同时保持几乎零质量损失。

fp8 量化

fp8 动态量化进一步加速推理并减少内存使用。为了使用 8 位 NVIDIA Tensor Cores,必须对激活和权重进行量化。

使用 float8_weight_onlyfloat8_dynamic_activation_float8_weight 来量化文本编码器和变换器模型。

默认量化方法是逐张量量化,但如果您的 GPU 支持逐行量化,您也可以尝试它以获得更好的准确性。

使用以下命令安装 torchao

pip3 install -U torch torchao

torch.compile 使用 mode="max-autotune-no-cudagraphs"mode="max-autotune" 选择最佳内核以获得性能。如果是第一次调用模型,编译可能会花费很长时间,但一旦模型编译完成,这是值得的。

此示例仅量化变换器模型,但您也可以量化文本编码器以进一步减少内存使用。

动态量化可能会显著改变模型输出的分布,因此您需要将 residual_diff_threshold 设置为更大的值以使其生效。

FLUX-1.dev
HunyuanVideo
import time
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.12,  # 使用更大的值以使缓存生效
)

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only

quantize_(pipe.text_encoder, float8_weight_only())
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
pipe.transformer = torch.compile(
   pipe.transformer, mode="max-autotune-no-cudagraphs",
)

# 启用内存节省
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()

for i in range(2):
    begin = time.time()
    image = pipe(
        "A cat holding a sign that says hello world",
        num_inference_steps=28,
    ).images[0]
    end = time.time()
    if i == 0:
        print(f"预热时间: {end - begin:.2f}s")
    else:
        print(f"时间: {end - begin:.2f}s")

print("保存图像到 flux.png")
image.save("flux.png")

fp8 动态量化和 torch.compile 将推理速度降低至 7.56 秒,相比基线快了 3.48 倍。

上下文并行性

上下文并行性并行化推理并随多个 GPU 扩展。ParaAttention 组合设计允许您将上下文并行性与第一块缓存和动态量化结合使用。

请参考 ParaAttention 仓库获取详细说明和如何使用多个 GPU 扩展推理的示例。

如果推理过程需要持久化和可服务,建议使用 torch.multiprocessing 编写您自己的推理处理器。这可以消除启动进程以及加载和重新编译模型的开销。

FLUX-1.dev
HunyuanVideo

以下代码示例结合了第一块缓存、fp8动态量化、torch.compile和上下文并行,以实现最快的推理速度。

import time
import torch
import torch.distributed as dist
from diffusers import FluxPipeline

dist.init_process_group()

torch.cuda.set_device(dist.get_rank())

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
from para_attn.parallel_vae.diffusers_adapters import parallelize_vae

mesh = init_context_parallel_mesh(
    pipe.device.type,
    max_ring_dim_size=2,
)
parallelize_pipe(
    pipe,
    mesh=mesh,
)
parallelize_vae(pipe.vae, mesh=mesh._flatten())

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.12,  # 使用较大的值以使缓存生效
)

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only

quantize_(pipe.text_encoder, float8_weight_only())
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(
   pipe.transformer, mode="max-autotune-no-cudagraphs",
)

# 启用内存节省
# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())

for i in range(2):
    begin = time.time()
    image = pipe(
        "A cat holding a sign that says hello world",
        num_inference_steps=28,
        output_type="pil" if dist.get_rank() == 0 else "pt",
    ).images[0]
    end = time.time()
    if dist.get_rank() == 0:
        if i == 0:
            print(f"预热时间: {end - begin:.2f}s")
        else:
            print(f"时间: {end - begin:.2f}s")

if dist.get_rank() == 0:
    print("将图像保存到flux.png")
    image.save("flux.png")

dist.destroy_process_group()

保存到run_flux.py并使用torchrun启动。

# 使用--nproc_per_node指定GPU数量
torchrun --nproc_per_node=2 run_flux.py

推理速度降至8.20秒,相比基线快了3.21倍,使用2个NVIDIA L20 GPU。在4个L20上,推理速度为3.90秒,快了6.75倍。

基准测试

FLUX-1.dev
HunyuanVideo
GPU 类型 GPU 数量 优化 墙钟时间 (s) 加速比
NVIDIA L20 1 基线 26.36 1.00x
NVIDIA L20 1 FBCache (rdt=0.08) 17.01 1.55x
NVIDIA L20 1 FP8 DQ 13.40 1.96x
NVIDIA L20 1 FBCache (rdt=0.12) + FP8 DQ 7.56 3.48x
NVIDIA L20 2 FBCache (rdt=0.12) + FP8 DQ + CP 4.92 5.35x
NVIDIA L20 4 FBCache (rdt=0.12) + FP8 DQ + CP 3.90 6.75x
< > Update on GitHub