VQDiffusion
Overview
Vector Quantized Diffusion Model for Text-to-Image Synthesis by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo
The abstract of the paper is the following:
We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.
The original codebase can be found here.
Available Pipelines:
Pipeline | Tasks | Colab |
---|---|---|
pipeline_vq_diffusion.py | Text-to-Image Generation | - |
VQDiffusionPipeline
class diffusers.VQDiffusionPipeline
( vqvae: VQModel, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, transformer: Transformer2DModel, scheduler: VQDiffusionScheduler, learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings )
Parameters
- vqvae (VQModel) — Vector Quantized Variational Auto-Encoder (VQ-VAE) model to encode and decode images to and from latent representations.
- text_encoder (CLIPTextModel) — Frozen text encoder. VQ Diffusion uses the text portion of CLIP, specifically the clip-vit-base-patch32 variant.
- tokenizer (CLIPTokenizer) — Tokenizer of class CLIPTokenizer.
- transformer (Transformer2DModel) — Conditional transformer to denoise the encoded image latents.
- scheduler (VQDiffusionScheduler) — A scheduler to be used in combination with transformer to denoise the encoded image latents.
Pipeline for text-to-image generation using VQ Diffusion
This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
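A minimal usage sketch follows. It assumes that torch and diffusers are installed and that a VQ Diffusion checkpoint is available under the Hub id microsoft/vq-diffusion-ithq. The checkpoint id, the helper name generate_images, and the seed handling are illustrative choices that do not come from this page.

```python
def generate_images(prompt, num_inference_steps=100, guidance_scale=5.0,
                    truncation_rate=1.0, seed=0,
                    checkpoint="microsoft/vq-diffusion-ithq"):
    """Run VQDiffusionPipeline once and return the generated PIL images.

    generate_images is a hypothetical helper written for this example.
    """
    # Imported lazily so that defining this helper does not require
    # torch or diffusers to be installed.
    import torch
    from diffusers import VQDiffusionPipeline

    # Assumption: `checkpoint` names a VQ Diffusion checkpoint on the Hub.
    pipe = VQDiffusionPipeline.from_pretrained(checkpoint)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = pipe.to(device)

    # A seeded generator on the same device makes sampling reproducible.
    generator = torch.Generator(device=device).manual_seed(seed)

    output = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        truncation_rate=truncation_rate,
        generator=generator,
    )
    return output.images


# Example call (downloads the checkpoint on first use):
# images = generate_images("teddy bear playing in the pool")
# images[0].save("teddy_bear.png")
```

The defaults mirror the __call__ signature documented on this page, so calling the helper with only a prompt should behave like calling the pipeline with its default arguments.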
__call__
(
prompt: typing.Union[str, typing.List[str]],
num_inference_steps: int = 100,
guidance_scale: float = 5.0,
truncation_rate: float = 1.0,
num_images_per_prompt: int = 1,
generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None,
latents: typing.Optional[torch.FloatTensor] = None,
output_type: typing.Optional[str] = 'pil',
return_dict: bool = True,
callback: typing.Union[typing.Callable[[int, int, torch.FloatTensor], NoneType], NoneType] = None,
callback_steps: typing.Optional[int] = 1
) → ImagePipelineOutput or tuple
Parameters
- prompt (str or List[str]) — The prompt or prompts to guide the image generation.
- num_inference_steps (int, optional, defaults to 100) — The number of denoising steps. More denoising steps usually lead to a higher-quality image at the expense of slower inference.
- guidance_scale (float, optional, defaults to 5.0) — Guidance scale as defined in Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation 2 of the Imagen paper. Guidance is enabled by setting guidance_scale > 1. A higher guidance scale encourages the model to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.
- truncation_rate (float, optional, defaults to 1.0, which is equivalent to no truncation) — Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at most truncation_rate. The lowest probabilities that would increase the cumulative probability above truncation_rate are set to zero.
- num_images_per_prompt (int, optional, defaults to 1) — The number of images to generate per prompt.
- generator (torch.Generator, optional) — One or a list of torch generator(s) to make generation deterministic.
- latents (torch.FloatTensor of shape (batch), optional) — Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor of completely masked latent pixels is generated.
- output_type (str, optional, defaults to "pil") — The output format of the generated image. Choose between PIL (PIL.Image.Image) and np.array.
- return_dict (bool, optional, defaults to True) — Whether or not to return an ImagePipelineOutput instead of a plain tuple.
- callback (Callable, optional) — A function that will be called every callback_steps steps during inference. The function will be called with the following arguments: callback(step: int, timestep: int, latents: torch.FloatTensor).
- callback_steps (int, optional, defaults to 1) — The frequency at which the callback function will be called. If not specified, the callback will be called at every step.
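To make the effect of guidance_scale concrete, the sketch below applies the classifier-free guidance rule to the class log-probabilities of a single latent pixel. It assumes that guidance is applied in log space and that the result is renormalized afterwards. The function names log_softmax and apply_guidance are hypothetical, and this is a pure-Python illustration, not the pipeline's tensor code.

```python
import math


def log_softmax(values):
    """Shift log-scores so that their exponentials sum to 1."""
    peak = max(values)  # subtract the max for numerical stability
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in values))
    return [v - log_norm for v in values]


def apply_guidance(log_p_text, log_p_uncond, guidance_scale):
    """Combine conditional and unconditional predictions.

    Uses uncond + w * (text - uncond), which is the same as
    w * text + (1 - w) * uncond, then renormalizes.
    Assumes all inputs are finite log-probabilities.
    """
    mixed = [
        u + guidance_scale * (t - u)
        for t, u in zip(log_p_text, log_p_uncond)
    ]
    return log_softmax(mixed)


# Three candidate classes for one latent pixel (illustrative numbers).
log_p_text = [math.log(p) for p in (0.6, 0.3, 0.1)]
log_p_uncond = [math.log(1 / 3)] * 3

unguided = apply_guidance(log_p_text, log_p_uncond, guidance_scale=1.0)
guided = apply_guidance(log_p_text, log_p_uncond, guidance_scale=5.0)
```

With guidance_scale = 1.0 the text-conditioned distribution is returned unchanged. Larger values push probability mass toward the classes that the prompt favors over the unconditional prediction.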
Returns
ImagePipelineOutput or tuple
pipeline_utils.ImagePipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.
Function invoked when calling the pipeline for generation.
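The callback and callback_steps arguments can be illustrated without loading a model. The sketch below builds a callback with the documented signature and drives it from a stand-in loop. simulate_denoising_loop and make_step_logger are hypothetical names. The loop assumes that the callback fires whenever the step index is a multiple of callback_steps and that timesteps count down, which may not match the pipeline's internals exactly.

```python
def make_step_logger(history):
    """Return a callback that records (step, timestep) pairs in `history`."""
    def callback(step, timestep, latents):
        # `latents` is ignored here; a real callback could inspect or decode it.
        history.append((step, timestep))
    return callback


def simulate_denoising_loop(num_inference_steps, callback=None, callback_steps=1):
    """Hypothetical stand-in for the pipeline's denoising loop."""
    timesteps = range(num_inference_steps - 1, -1, -1)  # assumed to count down
    for step, timestep in enumerate(timesteps):
        latents = None  # placeholder for the latent tensor at this step
        if callback is not None and step % callback_steps == 0:
            callback(step, timestep, latents)


history = []
simulate_denoising_loop(10, callback=make_step_logger(history), callback_steps=4)
# history now holds the steps at which the callback fired.
```

With a real pipeline, the same callback would be passed as pipe(prompt, callback=make_step_logger(history), callback_steps=4).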
Truncates log_p_x_0 such that for each column vector, the total cumulative probability is truncation_rate. The lowest probabilities that would increase the cumulative probability above truncation_rate are set to zero.
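The truncation described above can be sketched for a single column of class log-probabilities. This is a simplified pure-Python reading of the description, not the library's batched tensor implementation. The name truncate_column, the tolerance argument, and the choice to always keep the most likely class are assumptions, and the library's handling of the boundary class may differ.

```python
import math


def truncate_column(log_probs, truncation_rate, tolerance=1e-9):
    """Keep the most likely classes until truncation_rate would be exceeded.

    Dropped classes get log-probability -inf, that is, probability zero.
    """
    # Visit class indices from most to least likely.
    order = sorted(range(len(log_probs)), key=lambda i: log_probs[i], reverse=True)
    truncated = [-math.inf] * len(log_probs)
    cumulative = 0.0
    for rank, index in enumerate(order):
        probability = math.exp(log_probs[index])
        # Always keep the top class (assumption); stop once adding the next
        # class would push the cumulative probability above truncation_rate.
        if rank > 0 and cumulative + probability > truncation_rate + tolerance:
            break
        truncated[index] = log_probs[index]
        cumulative += probability
    return truncated


column = [math.log(0.2), math.log(0.5), math.log(0.3)]
kept = truncate_column(column, truncation_rate=0.85)
untouched = truncate_column(column, truncation_rate=1.0)
```

In this example the classes with probability 0.5 and 0.3 are kept and the class with probability 0.2 is zeroed out, because including it would raise the cumulative probability above 0.85. With truncation_rate = 1.0 nothing is removed.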