Diffusers documentation

Scalable Diffusion Models with Transformers (DiT)

Overview

Scalable Diffusion Models with Transformers (DiT) is by William Peebles and Saining Xie.

The abstract of the paper is the following:

We explore a new class of diffusion models based on the transformer architecture. We train latent diffusion models of images, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops — through increased transformer depth/width or increased number of input tokens — consistently have lower FID. In addition to possessing good scalability properties, our largest DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512x512 and 256x256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.
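The abstract ties compute to the number of input tokens. As a worked example, assume the setup described in the DiT paper: a VAE that shrinks each spatial side by a factor of 8 (so a 256x256 image becomes a 32x32x4 latent), followed by a patchify step that cuts the latent into non-overlapping p x p patches, one token per patch (the "/2" in DiT-XL/2 is p = 2). The helper name dit_num_tokens is hypothetical and only illustrates the arithmetic.

```python
def dit_num_tokens(image_size: int, vae_downsample: int = 8, patch_size: int = 2) -> int:
    """Number of transformer input tokens for a square image (hypothetical helper).

    Assumes the VAE shrinks each spatial side by `vae_downsample` and that the
    latent is split into non-overlapping `patch_size` x `patch_size` patches.
    """
    if image_size % vae_downsample != 0:
        raise ValueError("image_size must be divisible by vae_downsample")
    latent_size = image_size // vae_downsample  # e.g. 256 -> 32
    if latent_size % patch_size != 0:
        raise ValueError("latent size must be divisible by patch_size")
    patches_per_side = latent_size // patch_size  # e.g. 32 // 2 -> 16
    return patches_per_side ** 2


for size in (256, 512):
    for p in (2, 4, 8):
        print(f"{size}x{size} image, patch size {p}: {dit_num_tokens(size, patch_size=p)} tokens")
```

Halving the patch size quadruples the token count (256 tokens at p = 2 versus 64 at p = 4 for a 256x256 image), which is the "increased number of input tokens" route to higher Gflops that the abstract mentions.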

The original codebase of this paper can be found at facebookresearch/dit.

Available Pipelines:

Pipeline          Tasks                          Colab
pipeline_dit.py   Conditional Image Generation   -

Usage example

from diffusers import DiTPipeline, DPMSolverMultistepScheduler
import torch

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# pipe.labels maps each available ImageNet label string to its class id
print(list(pipe.labels))  # all label strings accepted by get_label_ids

# pick words that exist in the ImageNet class labels
words = ["white shark", "umbrella"]

class_ids = pipe.get_label_ids(words)

generator = torch.manual_seed(33)
output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)

image = output.images[0]  # label 'white shark'
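The pipeline generates one image per class id, in the same order as class_labels, so output.images lines up with words. A minimal sketch for saving every result under a label-based filename follows; save_with_labels is a hypothetical helper, not part of diffusers, and it only assumes each image has a PIL-style save(path) method.

```python
from pathlib import Path
from typing import List, Sequence


def save_with_labels(images: Sequence, words: Sequence[str], out_dir: str = ".") -> List[str]:
    """Save each image as <label>.png in out_dir (hypothetical helper, not a diffusers API)."""
    if len(images) != len(words):
        raise ValueError("expected exactly one image per label")
    target = Path(out_dir)
    target.mkdir(parents=True, exist_ok=True)  # PIL's save() does not create folders
    paths = []
    for image, word in zip(images, words):
        # "white shark" -> "white_shark.png"
        path = str(target / f"{word.replace(' ', '_')}.png")
        image.save(path)  # PIL infers the file format from the .png extension
        paths.append(path)
    return paths
```

With the usage example above, save_with_labels(output.images, words) would write white_shark.png and umbrella.png to the current directory.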

DiTPipeline

class diffusers.DiTPipeline

(transformer: Transformer2DModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, id2label: Optional[Dict[int, str]] = None)

Parameters

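  • transformer (Transformer2DModel) — A class-conditioned transformer used to denoise the encoded image latents.
  • vae (AutoencoderKL) — Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
  • scheduler (KarrasDiffusionSchedulers) — A scheduler, such as DDIMScheduler, used in combination with transformer to denoise the encoded image latents.
  • id2label (Dict[int, str], optional) — Mapping from class ids to label strings, used to build the label lookup (pipe.labels) that get_label_ids relies on.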

This pipeline inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

__call__

(class_labels: List[int], guidance_scale: float = 4.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, num_inference_steps: int = 50, output_type: Optional[str] = 'pil', return_dict: bool = True)

Parameters

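  • class_labels (List[int]) — List of ImageNet class ids for the images to be generated; one image is generated per id.
  • guidance_scale (float, optional, defaults to 4.0) — Classifier-free guidance scale. Guidance is applied only when guidance_scale > 1; higher values push samples more strongly toward the requested class.
  • generator (torch.Generator or List[torch.Generator], optional) — A torch generator to make generation deterministic.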
  • num_inference_steps (int, optional, defaults to 50) — The number of denoising steps. More denoising steps usually lead to a higher-quality image at the expense of slower inference.
  • output_type (str, optional, defaults to "pil") — The output format of the generated image. Choose between "pil" (PIL.Image.Image) and "np" (np.array).
  • return_dict (bool, optional, defaults to True) — Whether or not to return an ImagePipelineOutput instead of a plain tuple.

Function invoked when calling the pipeline for generation.
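When guidance_scale > 1, the diffusers implementation (as we read it) runs the transformer on the requested classes and on an unconditional "null" class in one doubled batch, which roughly doubles the work per denoising step, and then combines the two noise predictions as uncond + guidance_scale * (cond - uncond). The sketch below shows only that combination step on plain Python lists; guided_noise is a hypothetical name, and the real pipeline operates on torch tensors.

```python
from typing import List


def guided_noise(cond_eps: List[float], uncond_eps: List[float], guidance_scale: float) -> List[float]:
    """Classifier-free guidance: push the unconditional prediction toward (and past) the conditional one."""
    if len(cond_eps) != len(uncond_eps):
        raise ValueError("predictions must have the same length")
    if guidance_scale <= 1:
        # Mirrors the pipeline's behavior: no guidance branch, the conditional prediction is used as-is.
        return list(cond_eps)
    return [u + guidance_scale * (c - u) for c, u in zip(cond_eps, uncond_eps)]


print(guided_noise([0.5], [0.25], 4.0))  # prints [1.25]
```

With the default guidance_scale of 4.0 the result extrapolates well beyond the conditional prediction, which is what strengthens the class signal.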

get_label_ids

(label: Union[str, List[str]]) → List[int]

Parameters

  • label (str or List[str]) — Label strings to be mapped to class ids.

Returns

list of int

Class ids to be processed by the pipeline.

Map label strings, e.g. from ImageNet, to corresponding class ids.
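The lookup can be pictured as follows. The pipeline builds pipe.labels from id2label, where one ImageNet id may carry several comma-separated synonyms, so "white shark" and "great white shark" resolve to the same id. This stand-alone sketch mirrors that idea with hypothetical function names and a toy two-class mapping; it is not the pipeline's source. It wraps a bare string in a list; passing a list, as in the usage example, avoids depending on how a given diffusers version treats a single string.

```python
from typing import Dict, List, Union


def build_label_lookup(id2label: Dict[int, str]) -> Dict[str, int]:
    """Map every comma-separated synonym to its class id, sorted by label (hypothetical helper)."""
    labels: Dict[str, int] = {}
    for class_id, names in id2label.items():
        for name in names.split(","):
            labels[name.strip()] = int(class_id)
    return dict(sorted(labels.items()))


def lookup_label_ids(labels: Dict[str, int], label: Union[str, List[str]]) -> List[int]:
    """Return the class id for each label string, failing loudly on unknown labels (hypothetical helper)."""
    names = [label] if isinstance(label, str) else list(label)
    unknown = [name for name in names if name not in labels]
    if unknown:
        raise ValueError(f"unknown labels {unknown}; pick from: {', '.join(labels)}")
    return [labels[name] for name in names]


# Toy mapping for illustration only; a real checkpoint ships the full ImageNet label set.
toy_id2label = {2: "great white shark, white shark", 879: "umbrella"}
lookup = build_label_lookup(toy_id2label)
print(lookup_label_ids(lookup, ["white shark", "umbrella"]))  # prints [2, 879]
```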