Würstchen
Wuerstchen: An Efficient Architecture for Large-Scale Text-to-Image Diffusion Models is by Pablo Pernias, Dominic Rampas, Mats L. Richter, Christopher Pal, and Marc Aubreville.
The abstract from the paper is:
We introduce Würstchen, a novel architecture for text-to-image synthesis that combines competitive performance with unprecedented cost-effectiveness for large-scale text-to-image diffusion models. A key contribution of our work is to develop a latent diffusion technique in which we learn a detailed but extremely compact semantic image representation used to guide the diffusion process. This highly compressed representation of an image provides much more detailed guidance compared to latent representations of language and this significantly reduces the computational requirements to achieve state-of-the-art results. Our approach also improves the quality of text-conditioned image generation based on our user preference study. The training requirements of our approach consists of 24,602 A100-GPU hours - compared to Stable Diffusion 2.1’s 200,000 GPU hours. Our approach also requires less training data to achieve these results. Furthermore, our compact latent representations allows us to perform inference over twice as fast, slashing the usual costs and carbon footprint of a state-of-the-art (SOTA) diffusion model significantly, without compromising the end performance. In a broader comparison against SOTA models our approach is substantially more efficient and compares favorably in terms of image quality. We believe that this work motivates more emphasis on the prioritization of both performance and computational accessibility.
Würstchen Overview
Würstchen is a diffusion model whose text-conditional component works in a highly compressed latent space of images. Why is this important? Compressing data can reduce computational costs for both training and inference by orders of magnitude. Training on 1024x1024 images is far more expensive than training on 32x32 images. Other works typically use relatively small compression, in the range of 4x to 8x spatial compression. Würstchen takes this to an extreme: through its novel design, it achieves 42x spatial compression. This had not been achieved before because common methods fail to faithfully reconstruct detailed images beyond 16x spatial compression. Würstchen employs a two-stage compression, which we call Stage A and Stage B. Stage A is a VQGAN, and Stage B is a Diffusion Autoencoder (more details can be found in the paper). A third model, Stage C, is learned in that highly compressed latent space. This training requires a fraction of the compute used for current top-performing models, while also allowing cheaper and faster inference.
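To make the compression argument concrete, the snippet below counts latent grid positions for a 1024x1024 image. It is a back-of-the-envelope sketch. The 42.67 factor is the resolution_multiple default listed for WuerstchenPriorPipeline in the API reference. The 8x factor stands in for the typical compression used by other works. The sketch counts spatial positions only and ignores channel counts and model size.

```python
import math


def latent_side(pixels: int, factor: float) -> int:
    """Side length of a square latent grid after spatial compression by `factor`."""
    return math.ceil(pixels / factor)


image_side = 1024
typical = latent_side(image_side, 8)         # assumed "typical" 8x compression
wuerstchen = latent_side(image_side, 42.67)  # Stage C compression factor

typical_positions = typical * typical
wuerstchen_positions = wuerstchen * wuerstchen

print(f"8x:     {typical}x{typical} = {typical_positions} latent positions")
print(f"42.67x: {wuerstchen}x{wuerstchen} = {wuerstchen_positions} latent positions")
print(f"{typical_positions / wuerstchen_positions:.1f}x fewer positions")
```

At 1024x1024 this compares a 128x128 grid with a 24x24 grid, which gives roughly 28x fewer positions. Self-attention cost grows faster than linearly with the number of positions, so the compute saving in attention layers can be larger still.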
Würstchen v2 comes to Diffusers
After the initial paper release, we improved numerous aspects of the architecture, training, and sampling. These changes make Würstchen competitive with current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. The improvements include:
- Higher resolution (1024x1024 up to 2048x2048)
- Faster inference
- Multi Aspect Resolution Sampling
- Better quality
We are releasing three checkpoints for the text-conditional image generation model (Stage C):
- v2-base
- v2-aesthetic
- (default) v2-interpolated (50% interpolation between v2-base and v2-aesthetic)
We recommend using v2-interpolated, as it balances photorealism and aesthetics. Use v2-base for fine-tuning, as it does not have a style bias, and use v2-aesthetic for very artistic generations. A comparison can be seen here:
Text-to-Image Generation
For ease of use, Würstchen can be run through a single pipeline, as follows:
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")
caption = "Anthropomorphic cat dressed as a fire fighter"
images = pipe(
caption,
width=1024,
height=1536,
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
prior_guidance_scale=4.0,
num_images_per_prompt=2,
).images
For explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of three stages: Stage C, Stage B, and Stage A. Each stage has a different job, and they only work together. When generating text-conditional images, Stage C first generates the latents in a very compressed latent space. This is what happens in the prior_pipeline. Afterwards, the generated latents are passed to Stage B, which decompresses them into the bigger latent space of a VQGAN. Stage A, which is a VQGAN, then decodes these latents into pixel space. Stage B and Stage A are both encapsulated in the decoder_pipeline. For more details, take a look at the paper.
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2
prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
"warp-ai/wuerstchen-prior", torch_dtype=dtype
).to(device)
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
"warp-ai/wuerstchen", torch_dtype=dtype
).to(device)
caption = "Anthropomorphic cat dressed as a fire fighter"
negative_prompt = ""
prior_output = prior_pipeline(
prompt=caption,
height=1024,
width=1536,
timesteps=DEFAULT_STAGE_C_TIMESTEPS,
negative_prompt=negative_prompt,
guidance_scale=4.0,
num_images_per_prompt=num_images_per_prompt,
)
decoder_output = decoder_pipeline(
image_embeddings=prior_output.image_embeddings,
prompt=caption,
negative_prompt=negative_prompt,
guidance_scale=0.0,
output_type="pil",
).images[0]
decoder_output
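The three-stage flow described above can be traced through spatial sizes alone. The sketch below rests on two assumptions. First, Stage C's latent side is taken to be ceil(pixels / resolution_multiple). Second, Stage A (the VQGAN) is taken to upsample by 4x. The Stage B step follows the latent_dim_scale description in the WuerstchenDecoderPipeline reference. The constants 42.67 and 10.67 are the constructor defaults listed in the API reference. stage_sizes is a hypothetical helper, not a diffusers function.

```python
import math

RESOLUTION_MULTIPLE = 42.67  # WuerstchenPriorPipeline default (resolution_multiple)
LATENT_DIM_SCALE = 10.67     # WuerstchenDecoderPipeline default (latent_dim_scale)
VQGAN_UPSCALE = 4            # assumption: Stage A upsamples its latents 4x


def stage_sizes(height: int, width: int) -> dict:
    """Hypothetical helper: (height, width) at each stage for a requested image size."""
    # Stage C: the text-conditional prior works on a tiny latent grid (assumed ceil).
    c = (math.ceil(height / RESOLUTION_MULTIPLE), math.ceil(width / RESOLUTION_MULTIPLE))
    # Stage B: decompresses into the larger VQGAN latent space.
    b = (int(c[0] * LATENT_DIM_SCALE), int(c[1] * LATENT_DIM_SCALE))
    # Stage A: the VQGAN decodes latents into pixel space.
    a = (b[0] * VQGAN_UPSCALE, b[1] * VQGAN_UPSCALE)
    return {"stage_c": c, "stage_b": b, "stage_a": a}


print(stage_sizes(1024, 1536))
```

For the height=1024, width=1536 request in the example above, this yields a 24x36 Stage C latent, a 256x384 Stage B latent, and a 1024x1536 image. The ratio 1024 / 24 ≈ 42.7 matches the 42x compression mentioned in the overview.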
Speed-Up Inference
You can use the torch.compile function to gain a speed-up of about 2-3x:
prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)
Limitations
- Due to the high compression employed by Würstchen, generations can lack fine detail. This is especially noticeable to the human eye in faces, hands, etc.
- Images can only be generated in 128-pixel steps, e.g. the next higher resolution after 1024x1024 is 1152x1152.
- The model lacks the ability to render correct text in images
- The model often does not achieve photorealism
- Difficult compositional prompts are hard for the model
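To pick resolutions that respect the 128-pixel step, a small helper can round a requested size up to the next valid value. snap_to_128 is a hypothetical name, not a diffusers function. Rounding up rather than to the nearest step is a design choice that avoids shrinking the requested canvas.

```python
def snap_to_128(pixels: int, step: int = 128) -> int:
    """Hypothetical helper: round a requested side length up to a multiple of `step`."""
    if pixels <= 0:
        raise ValueError("pixels must be positive")
    # Ceiling division via negated floor division, then scale back to pixels.
    return -(-pixels // step) * step


for requested in (1000, 1024, 1025, 1500):
    print(requested, "->", snap_to_128(requested))
```

For example, snap_to_128(1025) returns 1152, which matches the "next higher resolution after 1024" given in the list above.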
The original codebase, as well as experimental ideas, can be found at dome272/Wuerstchen.
WuerstchenCombinedPipeline
class diffusers.WuerstchenCombinedPipeline
( tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, decoder: WuerstchenDiffNeXt, scheduler: DDPMWuerstchenScheduler, vqgan: PaellaVQModel, prior_tokenizer: CLIPTokenizer, prior_text_encoder: CLIPTextModel, prior_prior: WuerstchenPrior, prior_scheduler: DDPMWuerstchenScheduler )
Parameters
- tokenizer (CLIPTokenizer) — The decoder tokenizer to be used for text inputs.
- text_encoder (CLIPTextModel) — The decoder text encoder to be used for text inputs.
- decoder (WuerstchenDiffNeXt) — The decoder model to be used for the decoder image generation pipeline.
- scheduler (DDPMWuerstchenScheduler) — The scheduler to be used for the decoder image generation pipeline.
- vqgan (PaellaVQModel) — The VQGAN model to be used for the decoder image generation pipeline.
- prior_tokenizer (CLIPTokenizer) — The prior tokenizer to be used for text inputs.
- prior_text_encoder (CLIPTextModel) — The prior text encoder to be used for text inputs.
- prior_prior (WuerstchenPrior) — The prior model to be used for the prior pipeline.
- prior_scheduler (DDPMWuerstchenScheduler) — The scheduler to be used for the prior pipeline.
Combined Pipeline for text-to-image generation using Wuerstchen
This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
__call__
( prompt: Union = None, height: int = 512, width: int = 512, prior_num_inference_steps: int = 60, prior_timesteps: Optional = None, prior_guidance_scale: float = 4.0, num_inference_steps: int = 12, decoder_timesteps: Optional = None, decoder_guidance_scale: float = 0.0, negative_prompt: Union = None, prompt_embeds: Optional = None, negative_prompt_embeds: Optional = None, num_images_per_prompt: int = 1, generator: Union = None, latents: Optional = None, output_type: Optional = 'pil', return_dict: bool = True, prior_callback_on_step_end: Optional = None, prior_callback_on_step_end_tensor_inputs: List = ['latents'], callback_on_step_end: Optional = None, callback_on_step_end_tensor_inputs: List = ['latents'], **kwargs )
Parameters
- prompt (str or List[str]) — The prompt or prompts to guide the image generation for the prior and decoder.
- negative_prompt (str or List[str], optional) — The prompt or prompts not to guide the image generation. Ignored when guidance is disabled (a guidance scale of 1 or less; see prior_guidance_scale and decoder_guidance_scale).
- prompt_embeds (torch.FloatTensor, optional) — Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated from the prompt input argument.
- negative_prompt_embeds (torch.FloatTensor, optional) — Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, negative_prompt_embeds will be generated from the negative_prompt input argument.
- num_images_per_prompt (int, optional, defaults to 1) — The number of images to generate per prompt.
- height (int, optional, defaults to 512) — The height in pixels of the generated image.
- width (int, optional, defaults to 512) — The width in pixels of the generated image.
- prior_guidance_scale (float, optional, defaults to 4.0) — Guidance scale as defined in Classifier-Free Diffusion Guidance. prior_guidance_scale is defined as w of equation 2 of the Imagen paper. Guidance is enabled by setting prior_guidance_scale > 1. A higher guidance scale encourages images that are more closely linked to the text prompt, usually at the expense of image quality.
- prior_num_inference_steps (int, optional, defaults to 60) — The number of prior denoising steps. More denoising steps usually lead to a higher-quality image at the expense of slower inference. For more specific timestep spacing, you can pass custom prior_timesteps.
- num_inference_steps (int, optional, defaults to 12) — The number of decoder denoising steps. More denoising steps usually lead to a higher-quality image at the expense of slower inference. For more specific timestep spacing, you can pass custom decoder_timesteps.
- prior_timesteps (List[float], optional) — Custom timesteps to use for the denoising process for the prior. If not defined, prior_num_inference_steps equally spaced timesteps are used. Must be in descending order.
- decoder_timesteps (List[float], optional) — Custom timesteps to use for the denoising process for the decoder. If not defined, num_inference_steps equally spaced timesteps are used. Must be in descending order.
- decoder_guidance_scale (float, optional, defaults to 0.0) — Guidance scale as defined in Classifier-Free Diffusion Guidance. decoder_guidance_scale is defined as w of equation 2 of the Imagen paper. Guidance is enabled by setting decoder_guidance_scale > 1. A higher guidance scale encourages images that are more closely linked to the text prompt, usually at the expense of image quality.
- generator (torch.Generator or List[torch.Generator], optional) — One or a list of torch generators to make generation deterministic.
- latents (torch.FloatTensor, optional) — Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random generator.
- output_type (str, optional, defaults to "pil") — The output format of the generated image. Choose between "pil" (PIL.Image.Image), "np" (np.array), or "pt" (torch.Tensor).
- return_dict (bool, optional, defaults to True) — Whether or not to return an ImagePipelineOutput instead of a plain tuple.
- prior_callback_on_step_end (Callable, optional) — A function that is called at the end of each prior denoising step during inference. The function is called with the following arguments: prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict).
- prior_callback_on_step_end_tensor_inputs (List, optional) — The list of tensor inputs for the prior_callback_on_step_end function. The tensors specified in the list will be passed as the callback_kwargs argument. You will only be able to include variables listed in the ._callback_tensor_inputs attribute of your pipeline class.
- callback_on_step_end (Callable, optional) — A function that is called at the end of each decoder denoising step during inference. The function is called with the following arguments: callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict). callback_kwargs will include all tensors specified by callback_on_step_end_tensor_inputs.
- callback_on_step_end_tensor_inputs (List, optional) — The list of tensor inputs for the callback_on_step_end function. The tensors specified in the list will be passed as the callback_kwargs argument. You will only be able to include variables listed in the ._callback_tensor_inputs attribute of your pipeline class.
Function invoked when calling the pipeline for generation.
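Both callback parameters take a function with the signature documented above. The sketch below defines a callable that records the step, timestep, and latent shape at each step. StepLogger is a hypothetical name. Returning callback_kwargs reflects the usual diffusers convention, in which the pipeline reads tensors back from the dict the callback returns. The sketch runs without loading a model. The commented lines show how it might be passed to the combined pipeline.

```python
from typing import Any, Dict, List, Optional, Tuple


class StepLogger:
    """Hypothetical step-end callback that records what happened at each step."""

    def __init__(self) -> None:
        self.history: List[Tuple[int, float, Optional[tuple]]] = []

    def __call__(
        self, pipe: Any, step: int, timestep: Any, callback_kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        # "latents" is only present if it is listed in the *_tensor_inputs argument.
        latents = callback_kwargs.get("latents")
        shape = tuple(latents.shape) if latents is not None else None
        self.history.append((step, float(timestep), shape))
        # Return the (unchanged) dict so the pipeline can read the tensors back.
        return callback_kwargs


# With a loaded WuerstchenCombinedPipeline, usage could look like:
# prior_logger, decoder_logger = StepLogger(), StepLogger()
# pipe(caption, prior_callback_on_step_end=prior_logger, callback_on_step_end=decoder_logger)

# Stand-alone check without any model:
logger = StepLogger()
out = logger(None, 0, 1.0, {"latents": None})
print(logger.history, out)
```

Separate logger instances keep the prior and decoder histories apart, since the two stages use different step counts and timesteps.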
enable_model_cpu_offload
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to enable_sequential_cpu_offload, this method moves one whole model at a time to the GPU when its forward method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with enable_sequential_cpu_offload, but performance is much better due to the iterative execution of the unet.
enable_sequential_cpu_offload
Offloads all models (unet, text_encoder, vae, and safety checker state dicts) to CPU using 🤗 Accelerate, significantly reducing memory usage. Models are moved to a torch.device('meta') and loaded on a GPU only when their specific submodule's forward method is called. Offloading happens on a submodule basis. Memory savings are higher than with enable_model_cpu_offload, but performance is lower.
WuerstchenPriorPipeline
class diffusers.WuerstchenPriorPipeline
( tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, prior: WuerstchenPrior, scheduler: DDPMWuerstchenScheduler, latent_mean: float = 42.0, latent_std: float = 1.0, resolution_multiple: float = 42.67 )
Parameters
- prior (WuerstchenPrior) — The Würstchen prior (Stage C) used to approximate the image embedding from the text embedding.
- text_encoder (CLIPTextModel) — Frozen text encoder.
- tokenizer (CLIPTokenizer) — Tokenizer of class CLIPTokenizer.
- scheduler (DDPMWuerstchenScheduler) — A scheduler to be used in combination with prior to generate the image embedding.
- latent_mean (float, optional, defaults to 42.0) — Mean value used to rescale the prior latents.
- latent_std (float, optional, defaults to 1.0) — Standard deviation value used to rescale the prior latents.
- resolution_multiple (float, optional, defaults to 42.67) — Factor by which the requested height and width are divided to obtain the spatial size of the image embeddings.
Pipeline for generating image prior for Wuerstchen.
This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- load_lora_weights() for loading LoRA weights
- save_lora_weights() for saving LoRA weights
__call__
( prompt: Union = None, height: int = 1024, width: int = 1024, num_inference_steps: int = 60, timesteps: List = None, guidance_scale: float = 8.0, negative_prompt: Union = None, prompt_embeds: Optional = None, negative_prompt_embeds: Optional = None, num_images_per_prompt: Optional = 1, generator: Union = None, latents: Optional = None, output_type: Optional = 'pt', return_dict: bool = True, callback_on_step_end: Optional = None, callback_on_step_end_tensor_inputs: List = ['latents'], **kwargs )
Parameters
- prompt (str or List[str]) — The prompt or prompts to guide the image generation.
- height (int, optional, defaults to 1024) — The height in pixels of the generated image.
- width (int, optional, defaults to 1024) — The width in pixels of the generated image.
- num_inference_steps (int, optional, defaults to 60) — The number of denoising steps. More denoising steps usually lead to a higher-quality image at the expense of slower inference.
- timesteps (List[float], optional) — Custom timesteps to use for the denoising process. If not defined, num_inference_steps equally spaced timesteps are used. Must be in descending order.
- guidance_scale (float, optional, defaults to 8.0) — Guidance scale as defined in Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation 2 of the Imagen paper. Guidance is enabled by setting guidance_scale > 1. A higher guidance scale encourages images that are more closely linked to the text prompt, usually at the expense of image quality.
- negative_prompt (str or List[str], optional) — The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if guidance_scale is 1 or less).
- prompt_embeds (torch.FloatTensor, optional) — Pre-generated text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated from the prompt input argument.
- negative_prompt_embeds (torch.FloatTensor, optional) — Pre-generated negative text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, negative_prompt_embeds will be generated from the negative_prompt input argument.
- num_images_per_prompt (int, optional, defaults to 1) — The number of images to generate per prompt.
- generator (torch.Generator or List[torch.Generator], optional) — One or a list of torch generators to make generation deterministic.
- latents (torch.FloatTensor, optional) — Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random generator.
- output_type (str, optional, defaults to "pt") — The output format of the generated image embeddings. Choose between "np" (np.array) or "pt" (torch.Tensor).
- return_dict (bool, optional, defaults to True) — Whether or not to return a WuerstchenPriorPipelineOutput instead of a plain tuple.
- callback_on_step_end (Callable, optional) — A function that is called at the end of each denoising step during inference. The function is called with the following arguments: callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict). callback_kwargs will include all tensors specified by callback_on_step_end_tensor_inputs.
- callback_on_step_end_tensor_inputs (List, optional) — The list of tensor inputs for the callback_on_step_end function. The tensors specified in the list will be passed as the callback_kwargs argument. You will only be able to include variables listed in the ._callback_tensor_inputs attribute of your pipeline class.
Function invoked when calling the pipeline for generation.
Examples:
>>> import torch
>>> from diffusers import WuerstchenPriorPipeline
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
... ).to("cuda")
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> prior_output = prior_pipe(prompt)
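The guidance_scale parameter is the weight w of classifier-free guidance. The guided prediction w * cond + (1 - w) * uncond can be rewritten equivalently as uncond + w * (cond - uncond). The sketch below shows this on plain Python lists, whereas real pipelines operate on batched tensors. apply_cfg is a hypothetical helper. Its "only when w > 1" branch mirrors the rule stated in the parameter docs above, not any specific pipeline code. Each guided step needs both a conditional and an unconditional prediction, which is why enabling guidance adds compute.

```python
from typing import List


def apply_cfg(cond: List[float], uncond: List[float], w: float) -> List[float]:
    """Hypothetical classifier-free guidance on per-element predictions."""
    if w <= 1.0:
        # Guidance disabled: use the text-conditioned prediction as-is.
        return list(cond)
    # Equivalent to w * c + (1 - w) * u, written as an extrapolation from u.
    return [u + w * (c - u) for c, u in zip(cond, uncond)]


cond = [1.0, 2.0]
uncond = [0.5, 0.5]
print(apply_cfg(cond, uncond, 4.0))  # [2.5, 6.5]
print(apply_cfg(cond, uncond, 1.0))  # [1.0, 2.0]
```

With w = 4.0 (the combined pipeline's prior default), the result moves four times as far from the unconditional prediction as the conditional one does. This is the "closer to the prompt" effect described in the parameter docs.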
WuerstchenPriorPipelineOutput
class diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput
( image_embeddings: Union )
Output class for WuerstchenPriorPipeline.
WuerstchenDecoderPipeline
class diffusers.WuerstchenDecoderPipeline
( tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, decoder: WuerstchenDiffNeXt, scheduler: DDPMWuerstchenScheduler, vqgan: PaellaVQModel, latent_dim_scale: float = 10.67 )
Parameters
- tokenizer (CLIPTokenizer) — The CLIP tokenizer.
- text_encoder (CLIPTextModel) — The CLIP text encoder.
- decoder (WuerstchenDiffNeXt) — The WuerstchenDiffNeXt unet decoder.
- vqgan (PaellaVQModel) — The VQGAN model.
- scheduler (DDPMWuerstchenScheduler) — A scheduler to be used in combination with decoder to generate images.
- latent_dim_scale (float, optional, defaults to 10.67) — Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings have height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and width=int(24*10.67)=256 in order to match the training conditions.
Pipeline for generating images from the Wuerstchen model.
This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
__call__
( image_embeddings: Union, prompt: Union = None, num_inference_steps: int = 12, timesteps: Optional = None, guidance_scale: float = 0.0, negative_prompt: Union = None, num_images_per_prompt: int = 1, generator: Union = None, latents: Optional = None, output_type: Optional = 'pil', return_dict: bool = True, callback_on_step_end: Optional = None, callback_on_step_end_tensor_inputs: List = ['latents'], **kwargs )
Parameters
- image_embeddings (torch.FloatTensor or List[torch.FloatTensor]) — Image embeddings either extracted from an image or generated by a prior model.
- prompt (str or List[str]) — The prompt or prompts to guide the image generation.
- num_inference_steps (int, optional, defaults to 12) — The number of denoising steps. More denoising steps usually lead to a higher-quality image at the expense of slower inference.
- timesteps (List[float], optional) — Custom timesteps to use for the denoising process. If not defined, num_inference_steps equally spaced timesteps are used. Must be in descending order.
- guidance_scale (float, optional, defaults to 0.0) — Guidance scale as defined in Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation 2 of the Imagen paper. Guidance is enabled by setting guidance_scale > 1. A higher guidance scale encourages images that are more closely linked to the text prompt, usually at the expense of image quality.
- negative_prompt (str or List[str], optional) — The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if guidance_scale is 1 or less).
- num_images_per_prompt (int, optional, defaults to 1) — The number of images to generate per prompt.
- generator (torch.Generator or List[torch.Generator], optional) — One or a list of torch generators to make generation deterministic.
- latents (torch.FloatTensor, optional) — Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random generator.
- output_type (str, optional, defaults to "pil") — The output format of the generated image. Choose between "pil" (PIL.Image.Image), "np" (np.array), or "pt" (torch.Tensor).
- return_dict (bool, optional, defaults to True) — Whether or not to return an ImagePipelineOutput instead of a plain tuple.
- callback_on_step_end (Callable, optional) — A function that is called at the end of each denoising step during inference. The function is called with the following arguments: callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict). callback_kwargs will include all tensors specified by callback_on_step_end_tensor_inputs.
- callback_on_step_end_tensor_inputs (List, optional) — The list of tensor inputs for the callback_on_step_end function. The tensors specified in the list will be passed as the callback_kwargs argument. You will only be able to include variables listed in the ._callback_tensor_inputs attribute of your pipeline class.
Function invoked when calling the pipeline for generation.
Examples:
>>> import torch
>>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline
>>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
... ).to("cuda")
>>> gen_pipe = WuerstchenDecoderPipeline.from_pretrained(
...     "warp-ai/wuerstchen", torch_dtype=torch.float16
... ).to("cuda")
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> prior_output = prior_pipe(prompt)
>>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt).images
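Custom timesteps must be supplied in descending order, so it can help to build and check a list before passing it to a pipeline. The sketch below makes several assumptions. Timesteps are taken to be floats running from 1.0 (pure noise) down to 0.0, as in DEFAULT_STAGE_C_TIMESTEPS. The two-phase shape, with dense steps first and sparser steps afterwards, is an illustrative choice and not a library default. The helper names are hypothetical.

```python
from typing import List


def linspace(start: float, stop: float, num: int) -> List[float]:
    """Evenly spaced floats from start to stop, inclusive (pure-Python stand-in)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def two_phase_timesteps(n_early: int = 10, n_late: int = 5, split: float = 0.5) -> List[float]:
    """Hypothetical schedule: dense steps from 1.0 to `split`, then sparser steps to 0.0."""
    # Drop the first element of the second phase so `split` is not repeated.
    return linspace(1.0, split, n_early) + linspace(split, 0.0, n_late)[1:]


def check_descending(ts: List[float]) -> None:
    """Raise if the timesteps are not strictly descending."""
    if any(b >= a for a, b in zip(ts, ts[1:])):
        raise ValueError("timesteps must be in strictly descending order")


ts = two_phase_timesteps()
check_descending(ts)
print(len(ts), ts[0], ts[-1])  # 14 1.0 0.0
```

A list built this way could then be passed as timesteps=ts to either pipeline, or as prior_timesteps/decoder_timesteps to the combined pipeline. Whether a given schedule improves quality would need to be checked empirically.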
Citation
@misc{pernias2023wuerstchen,
title={Wuerstchen: An Efficient Architecture for Large-Scale Text-to-Image Diffusion Models},
author={Pablo Pernias and Dominic Rampas and Mats L. Richter and Christopher J. Pal and Marc Aubreville},
year={2023},
eprint={2306.00637},
archivePrefix={arXiv},
primaryClass={cs.CV}
}