Torch 2.0 support in Diffusers

Starting from version 0.13.0, Diffusers supports the latest optimizations from the upcoming PyTorch 2.0 release. These include:

  1. Support for native flash and memory-efficient attention without any extra dependencies.
  2. torch.compile support for compiling individual models for an extra performance boost.

Installation

To benefit from the native efficient attention and `torch.compile`, you will need to install the nightly version of PyTorch, as the stable version has not been released yet. The first step is to install CUDA 11.7 or CUDA 11.8, as PyTorch 2.0 does not support earlier versions. Once CUDA is installed, the PyTorch nightly build can be installed using:
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117
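
If you want to quickly confirm that the nightly build exposes the new attention function, a small check along these lines can help (this snippet is only illustrative and not part of the official installation steps):

import torch
import torch.nn.functional as F

# Nightly PyTorch 2.0 builds report a version such as "2.0.0.dev...".
print(torch.__version__)

# The native efficient attention entry point used by Diffusers.
print(hasattr(F, "scaled_dot_product_attention"))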

Using efficient attention and torch.compile

  1. Efficient Attention

    Efficient attention is implemented via the torch.nn.functional.scaled_dot_product_attention function, which automatically enables flash attention or memory-efficient attention depending on the inputs and the GPU type. This is the same as the memory_efficient_attention from xFormers, but built natively into PyTorch.
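
    As a rough standalone illustration of the function itself (not something you need to call when using Diffusers), it can be applied directly to query/key/value tensors; the shapes below are arbitrary and only for the example:

    import torch
    import torch.nn.functional as F
    
    # Arbitrary example shapes: (batch, heads, sequence length, head dimension)
    query = torch.randn(2, 8, 64, 40, dtype=torch.float16, device="cuda")
    key = torch.randn(2, 8, 64, 40, dtype=torch.float16, device="cuda")
    value = torch.randn(2, 8, 64, 40, dtype=torch.float16, device="cuda")
    
    # PyTorch dispatches to a flash or memory-efficient kernel depending on the inputs and GPU.
    out = F.scaled_dot_product_attention(query, key, value)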

    Efficient attention is enabled by default in Diffusers if PyTorch 2.0 is installed and torch.nn.functional.scaled_dot_product_attention is available. To use it, install PyTorch 2.0 as suggested above and simply use the pipeline. For example:

    import torch
    from diffusers import StableDiffusionPipeline
    
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    
    prompt = "a photo of an astronaut riding a horse on mars"
    image = pipe(prompt).images[0]

    If you want to enable it explicitly (which is not required), you can do so as shown below.

    import torch
    from diffusers import StableDiffusionPipeline
    from diffusers.models.cross_attention import AttnProcessor2_0
    
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    pipe.unet.set_attn_processor(AttnProcessor2_0())
    
    prompt = "a photo of an astronaut riding a horse on mars"
    image = pipe(prompt).images[0]

    This should be as fast and memory efficient as xFormers.
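
    If you want to compare against xFormers on PyTorch 1.13.1, its memory-efficient attention can be enabled on the same pipeline (this assumes the xformers package is installed):

    import torch
    from diffusers import StableDiffusionPipeline
    
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    
    # Use xFormers' memory_efficient_attention instead of the native PyTorch 2.0 path.
    pipe.enable_xformers_memory_efficient_attention()
    
    prompt = "a photo of an astronaut riding a horse on mars"
    image = pipe(prompt).images[0]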

  2. torch.compile

    To get an additional speedup, we can use the new torch.compile feature. To do so, we wrap the UNet with torch.compile. For more information and the available options, refer to the torch.compile docs.

    import torch
    from diffusers import StableDiffusionPipeline
    
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
        "cuda"
    )
    pipe.unet = torch.compile(pipe.unet)
    
    batch_size = 10
    prompt = "A photo of an astronaut riding a horse on Mars."
    # 50 denoising steps, matching the benchmark setting below
    images = pipe(prompt, num_inference_steps=50, num_images_per_prompt=batch_size).images

    Depending on the type of GPU, this can give a 2-9% speed-up over efficient attention alone. Note, however, that as of now the speed-up is mostly noticeable on the more recent GPU architectures, such as the A100.

    Note that compilation will also take some time to complete, so it is best suited for situations where you need to prepare your pipeline once and then perform the same type of inference operations multiple times.
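
    torch.compile also accepts a mode argument that trades compilation time for runtime speed. The snippet below is a sketch of that usage; the default mode is usually a good starting point, so treat "reduce-overhead" as an optional experiment:

    import torch
    from diffusers import StableDiffusionPipeline
    
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    
    # "reduce-overhead" uses CUDA graphs to cut per-call Python overhead;
    # compilation takes longer, but repeated inference can run faster.
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
    
    prompt = "a photo of an astronaut riding a horse on mars"
    image = pipe(prompt).images[0]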

Benchmark

We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, torch.nn.functional.scaled_dot_product_attention, and torch.compile combined with torch.nn.functional.scaled_dot_product_attention. For the benchmark we used the stable-diffusion-v1-4 model with 50 steps. The xFormers benchmark was run with torch==1.13.1. The tables below summarize the results. The Speed over xformers (%) column denotes the speed-up gained over xFormers by using torch.compile together with torch.nn.functional.scaled_dot_product_attention.
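
The exact harness is not reproduced here, but a minimal timing loop along the following lines gives comparable measurements. The warm-up call and torch.cuda.synchronize() are important for fair GPU timings; the prompt and batch size are placeholders:

import time

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

# Warm-up run (this also triggers compilation when torch.compile is used).
pipe(prompt, num_inference_steps=50)

torch.cuda.synchronize()
start = time.time()
pipe(prompt, num_inference_steps=50, num_images_per_prompt=10)
torch.cuda.synchronize()
print(f"time: {time.time() - start:.2f} s")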

FP16 benchmark

The table below shows the benchmark results for inference using fp16. As we can see, torch.nn.functional.scaled_dot_product_attention is as fast as xFormers (sometimes slightly faster or slower) on all the GPUs we tested. Using torch.compile on top of it gives a further speed-up of up to 10% over xFormers, which is mostly noticeable on the A100 GPU.

The time reported is in seconds.

| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch 2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) |
| --- | --- | --- | --- | --- | --- | --- |
| A100 | 10 | 12.02 | 8.7 | 8.79 | 7.89 | 9.31 |
| A100 | 16 | 18.95 | 13.57 | 13.67 | 12.25 | 9.73 |
| A100 | 32 (1) | OOM | 26.56 | 26.68 | 24.08 | 9.34 |
| A100 | 64 (2) | | 52.51 | 53.03 | 47.81 | 8.95 |
| A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
| A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
| A10 | 10 | 33.69 | 23.53 | 24.19 | 22.52 | 4.29 |
| A10 | 16 | OOM | 37.55 | 38.31 | 36.81 | 1.97 |
| A10 | 32 (1) | | 77.19 | 78.43 | 76.64 | 0.71 |
| A10 | 64 (1) | | 173.59 | 158.99 | 155.14 | 10.63 |
| T4 | 4 | 38.81 | 30.09 | 29.74 | 27.55 | 8.44 |
| T4 | 8 | OOM | 55.71 | 55.99 | 53.85 | 3.34 |
| T4 | 10 | OOM | 68.96 | 69.86 | 65.35 | 5.23 |
| T4 | 16 | OOM | 111.47 | 113.26 | 106.93 | 4.07 |
| V100 | 4 | 9.84 | 8.16 | 8.09 | 7.65 | 6.25 |
| V100 | 8 | OOM | 15.62 | 15.44 | 14.59 | 6.59 |
| V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 |
| V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 |
| 3090 | 4 | 10.04 | 7.82 | 7.89 | 7.47 | 4.48 |
| 3090 | 8 | 19.27 | 14.97 | 15.04 | 14.22 | 5.01 |
| 3090 | 10 | 24.08 | 18.7 | 18.7 | 17.69 | 5.40 |
| 3090 | 16 | OOM | 29.06 | 29.06 | 28.2 | 2.96 |
| 3090 | 32 (1) | | 58.05 | 58 | 54.88 | 5.46 |
| 3090 | 64 (1) | | 126.54 | 126.03 | 117.33 | 7.28 |
| 3090 Ti | 4 | 9.07 | 7.14 | 7.15 | 6.81 | 4.62 |
| 3090 Ti | 8 | 17.51 | 13.65 | 13.72 | 12.99 | 4.84 |
| 3090 Ti | 10 (2) | 21.79 | 16.85 | 16.93 | 16.02 | 4.93 |
| 3090 Ti | 16 | OOM | 26.1 | 26.28 | 25.46 | 2.45 |
| 3090 Ti | 32 (1) | | 51.78 | 52.04 | 49.15 | 5.08 |
| 3090 Ti | 64 (1) | | 112.02 | 112.33 | 103.91 | 7.24 |

FP32 benchmark

The table below shows the benchmark results for inference using fp32. As we can see, torch.nn.functional.scaled_dot_product_attention is as fast as xFormers (sometimes slightly faster or slower) on all the GPUs we tested. Using torch.compile with efficient attention gives up to an 18% performance improvement over xFormers on Ampere cards, and up to 20% over vanilla attention.

| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch 2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | Speed over vanilla (%) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| A100 | 4 | 16.56 | 12.42 | 12.2 | 11.84 | 4.67 | 28.50 |
| A100 | 10 | OOM | 29.93 | 29.44 | 28.5 | 4.78 | |
| A100 | 16 | | 47.08 | 46.27 | 44.8 | 4.84 | |
| A100 | 32 | | 92.89 | 91.34 | 88.35 | 4.89 | |
| A100 | 64 | | 185.3 | 182.71 | 176.48 | 4.76 | |
| A10 | 1 | 10.59 | 8.81 | 7.51 | 7.35 | 16.57 | 30.59 |
| A10 | 4 | 34.77 | 27.63 | 22.77 | 22.07 | 20.12 | 36.53 |
| A10 | 8 | | 56.19 | 43.53 | 43.86 | 21.94 | |
| A10 | 16 | | 116.49 | 88.56 | 86.64 | 25.62 | |
| A10 | 32 | | 221.95 | 175.74 | 168.18 | 24.23 | |
| A10 | 48 | | 333.23 | | 264.84 | 20.52 | |
| T4 | 1 | 28.2 | 24.49 | 23.93 | 23.56 | 3.80 | 16.45 |
| T4 | 2 | 52.77 | 45.7 | 45.88 | 45.06 | 1.40 | 14.61 |
| T4 | 4 | OOM | 85.72 | 85.78 | 84.48 | 1.45 | |
| T4 | 8 | | 149.64 | 150.75 | 148.4 | 0.83 | |
| V100 | 1 | 7.4 | 6.84 | 6.8 | 6.66 | 2.63 | 10.00 |
| V100 | 2 | 13.85 | 12.81 | 12.66 | 12.35 | 3.59 | 10.83 |
| V100 | 4 | OOM | 25.73 | 25.31 | 24.78 | 3.69 | |
| V100 | 8 | | 43.95 | 43.37 | 42.25 | 3.87 | |
| V100 | 16 | | 84.99 | 84.73 | 82.55 | 2.87 | |
| 3090 | 1 | 7.09 | 6.78 | 6.11 | 6.03 | 11.06 | 14.95 |
| 3090 | 4 | 22.69 | 21.45 | 18.67 | 18.09 | 15.66 | 20.27 |
| 3090 | 8 (2) | | 42.59 | 36.75 | 35.59 | 16.44 | |
| 3090 | 16 | | 85.35 | 72.37 | 70.25 | 17.69 | |
| 3090 | 32 (1) | | 162.05 | 138.99 | 134.53 | 16.98 | |
| 3090 | 48 | | 241.91 | | 207.75 | 14.12 | |
| 3090 Ti | 1 | 6.45 | 6.19 | 5.64 | 5.49 | 11.31 | 14.88 |
| 3090 Ti | 4 | 20.32 | 19.31 | 16.9 | 16.37 | 15.23 | 19.44 |
| 3090 Ti | 8 (2) | | 37.93 | 33.05 | 31.99 | 15.66 | |
| 3090 Ti | 16 | | 75.37 | 65.25 | 64.32 | 14.66 | |
| 3090 Ti | 32 (1) | | 142.55 | 124.44 | 120.74 | 15.30 | |
| 3090 Ti | 48 | | 213.19 | | 186.55 | 12.50 | |
| 4090 | 1 | 5.54 | 4.99 | 4.51 | | | |
| 4090 | 4 | 13.67 | 11.4 | 10.3 | | | |
| 4090 | 8 (2) | | 19.79 | 17.13 | | | |
| 4090 | 16 | | 38.62 | 33.14 | | | |
| 4090 | 32 (1) | | 76.57 | 65.96 | | | |
| 4090 | 48 | | 114.44 | 98.78 | | | |

(1) Batch Size >= 32 requires enable_vae_slicing() because of https://github.com/pytorch/pytorch/issues/81665. This is required for PyTorch 1.13.1, and also for PyTorch 2.0 with a batch size of 64.
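
For reference, VAE slicing is enabled on the pipeline itself; a minimal example (the model and batch size here are only placeholders):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

# Decode the latents one image at a time to avoid the large-batch VAE memory spike.
pipe.enable_vae_slicing()

images = pipe("a photo of an astronaut riding a horse on mars", num_images_per_prompt=32).images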

For more details about how this benchmark was run, please refer to this PR.