---
license: mit
---
<img src="https://cdn-uploads.huggingface.co/production/uploads/634cb5eefb80cc6bcaf63c3e/i-DYpDHw8Pwiy7QBKZVR5.jpeg" width=1500>

## Würstchen - Overview
Würstchen is a diffusion model whose text-conditional component works in a highly compressed latent space of images. Why is this important? Compressing data can reduce
computational costs for both training and inference by orders of magnitude: training on 1024x1024 images is far more expensive than training on 32x32 images. Other works
usually use a relatively small compression, in the range of 4x to 8x spatial compression. Würstchen takes this to an extreme. Through its novel design, we achieve a 42x spatial
compression. This had not been achieved before, because common methods fail to faithfully reconstruct detailed images after 16x spatial compression. Würstchen employs a
two-stage compression, which we call Stage A and Stage B. Stage A is a VQGAN, and Stage B is a Diffusion Autoencoder (more details can be found in the [paper](https://arxiv.org/abs/2306.00637)).
A third model, Stage C, is learned in that highly compressed latent space. This training requires a fraction of the compute used for current top-performing models, and it
also allows cheaper and faster inference.
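To make the size of that difference concrete, the sketch below counts how many latent spatial positions a 1024x1024 image is reduced to at several per-side compression factors, including the 4x to 8x range mentioned above and Würstchen's 42x. It only does arithmetic on those factors and ignores channel counts, so treat it as an illustration of the spatial savings rather than a cost model of the actual stages. `latent_positions` is a hypothetical helper.

```python
def latent_positions(height: int, width: int, factor: float) -> float:
    """Approximate number of latent spatial positions for a per-side compression factor."""
    return (height / factor) * (width / factor)


pixels = 1024 * 1024
for factor in (4, 8, 16, 42):
    positions = latent_positions(1024, 1024, factor)
    # Compressing each side by `factor` shrinks the spatial grid by factor**2.
    print(f"{factor:>2}x per side: ~{positions:,.0f} positions, "
          f"{pixels / positions:,.0f}x fewer than pixels")
```

Because the reduction is quadratic in the per-side factor, going from 8x to 42x per side shrinks the grid by roughly another 27x (1764 / 64).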

## Würstchen - Prior
The Prior is what we refer to as "Stage C". It is the text-conditional model, operating in the small latent space that Stage A and Stage B encode images into. During
inference, its job is to generate image latents from text. Stages A and B then decode these latents into pixel space.

### Prior - Model - Interpolated
The interpolated model is our current best Prior (Stage C) checkpoint. It is an interpolation between our [base model](https://huggingface.co/warp-ai/wuerstchen-prior-model-base) and the [finetuned model](https://huggingface.co/warp-ai/wuerstchen-prior-model-finetuned).
We created this interpolation because the finetuned model had become too artistic and often generated only artistic images, whereas the base model is usually very photorealistic.
We therefore combined both by interpolating their weights at 50%, i.e. taking the midpoint between the base and finetuned model (`0.5 * base_weights + 0.5 * finetuned_weights`).
You can also interpolate the [base model](https://huggingface.co/warp-ai/wuerstchen-prior-model-base) and the [finetuned model](https://huggingface.co/warp-ai/wuerstchen-prior-model-finetuned)
with a different weighting and may find an interpolation that fits your needs better than this checkpoint.
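A minimal sketch of that weight interpolation, assuming both checkpoints have already been loaded as PyTorch `state_dict`s with identical keys and shapes (loading them is not shown). `interpolate_state_dicts` and `alpha` are hypothetical names: `alpha=0.5` reproduces `0.5 * base_weights + 0.5 * finetuned_weights`, and other values let you try the custom mixes suggested above. Toy tensors stand in for real weights.

```python
import torch


def interpolate_state_dicts(base, finetuned, alpha=0.5):
    """Linearly interpolate two state dicts with identical keys and shapes.

    alpha=0.0 returns the base weights, alpha=1.0 the finetuned weights,
    alpha=0.5 the midpoint described above.
    """
    if base.keys() != finetuned.keys():
        raise ValueError("state dicts have different keys")
    merged = {}
    for name, base_w in base.items():
        ft_w = finetuned[name]
        if base_w.shape != ft_w.shape:
            raise ValueError(f"shape mismatch for {name}")
        if base_w.is_floating_point():
            merged[name] = (1.0 - alpha) * base_w + alpha * ft_w
        else:
            # Assumption: integer buffers (e.g. counters) are not averaged; keep the base copy.
            merged[name] = base_w.clone()
    return merged


# Toy example standing in for the base and finetuned Prior weights.
base = {"w": torch.tensor([0.0, 2.0]), "b": torch.tensor([1.0])}
finetuned = {"w": torch.tensor([2.0, 4.0]), "b": torch.tensor([3.0])}
merged = interpolate_state_dicts(base, finetuned, alpha=0.5)
print(merged["w"], merged["b"])
```

With real checkpoints, the merged dict could then be applied with `model.load_state_dict(merged)` on a model of the same architecture. Copying non-floating buffers from the base model is our own choice for this sketch, not something the authors describe.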

### Image Sizes
Würstchen was trained on image resolutions between 1024x1024 and 1536x1536. We sometimes also observe good outputs at resolutions like 1024x2048, so feel free to try them out.
We also observed that the Prior (Stage C) adapts extremely fast to new resolutions, so finetuning it at 2048x2048 should be computationally cheap.
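To see what these resolutions mean for the Prior, the sketch below converts pixel sizes into Stage C latent grid sizes. It assumes a per-side factor of 42.67 (the "42x" compression), which to our understanding matches the default `resolution_multiple` of the diffusers Würstchen Prior pipeline; check the pipeline config before relying on the exact numbers. `latent_grid` is a hypothetical helper.

```python
import math

# Assumption: per-side compression factor of 42.67 (the "42x" above).
RESOLUTION_MULTIPLE = 42.67


def latent_grid(height: int, width: int, multiple: float = RESOLUTION_MULTIPLE) -> tuple[int, int]:
    """Return (rows, cols) of the Stage C latent grid for a pixel resolution."""
    return math.ceil(height / multiple), math.ceil(width / multiple)


for h, w in [(1024, 1024), (1536, 1536), (1024, 2048), (2048, 2048)]:
    rows, cols = latent_grid(h, w)
    print(f"{h}x{w} -> {rows}x{cols} latent grid")
```

Under this assumption, even 2048x2048 corresponds to a small latent grid, which is consistent with the observation that adapting the Prior to larger resolutions is cheap.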
<img src="https://cdn-uploads.huggingface.co/production/uploads/634cb5eefb80cc6bcaf63c3e/IfVsUDcP15OY-5wyLYKnQ.jpeg" width=1000>

## How to run
This Prior pipeline should be run together with the Stage A & B decoder at https://huggingface.co/warp-ai/wuerstchen:

```py
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2

# Stage C: the text-conditional Prior (this checkpoint)
prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=dtype
).to(device)
# Stages A & B: the decoder that maps Stage C latents to pixels
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=dtype
).to(device)

caption = "Anthropomorphic cat dressed as a fire fighter"
negative_prompt = ""

# Generate the compressed image latents (image embeddings) from the text prompt
prior_output = prior_pipeline(
    prompt=caption,
    height=1024,
    width=1536,
    timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=num_images_per_prompt,
)
# Decode the latents into PIL images
images = decoder_pipeline(
    image_embeddings=prior_output.image_embeddings,
    prompt=caption,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
).images

for i, image in enumerate(images):
    image.save(f"wuerstchen_{i}.png")
```

## Model Details
- **Developed by:** Pablo Pernias, Dominic Rampas
- **Model type:** Diffusion-based text-to-image generation model
- **Language(s):** English
- **License:** MIT
- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a Diffusion model in the style of Stage C from the [Würstchen paper](https://arxiv.org/abs/2306.00637) that uses a fixed, pretrained text encoder ([CLIP ViT-bigG/14](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
- **Resources for more information:** [GitHub Repository](https://github.com/dome272/Wuerstchen), [Paper](https://arxiv.org/abs/2306.00637).
- **Cite as:**

      @misc{pernias2023wuerstchen,
            title={Wuerstchen: Efficient Pretraining of Text-to-Image Models}, 
            author={Pablo Pernias and Dominic Rampas and Marc Aubreville},
            year={2023},
            eprint={2306.00637},
            archivePrefix={arXiv},
            primaryClass={cs.CV}
      }

## Environmental Impact

**Würstchen v2 Estimated Emissions**

Based on the information below, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were used to estimate the carbon impact.

- **Hardware Type:** A100 PCIe 40GB
- **Hours used:** 24602
- **Cloud Provider:** AWS
- **Compute Region:** US-east
- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 2275.68 kg CO2 eq.
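The last bullet names the formula (power consumption x time x grid carbon intensity). As a rough consistency check, the sketch below multiplies it out. Two inputs are our own assumptions, not stated in this card: a 250 W draw per A100 PCIe 40GB (its rated TDP) and a grid intensity of 0.37 kg CO2 eq. per kWh, which we back-solved from the reported total rather than took from the calculator. `estimated_emissions_kg` is a hypothetical helper.

```python
def estimated_emissions_kg(hours: float, power_kw: float, kg_co2_per_kwh: float) -> float:
    """Power consumption x time x grid carbon intensity, in kg CO2 eq."""
    energy_kwh = hours * power_kw
    return energy_kwh * kg_co2_per_kwh


HOURS = 24602      # GPU hours, from the card
POWER_KW = 0.25    # assumption: 250 W rated TDP of an A100 PCIe 40GB
INTENSITY = 0.37   # assumption: back-solved from the reported 2275.68 kg CO2 eq.

print(f"{estimated_emissions_kg(HOURS, POWER_KW, INTENSITY):.2f} kg CO2 eq.")
```

Under these assumptions the run consumed about 6,150 kWh, and the reported figure corresponds to roughly 0.37 kg CO2 eq. per kWh.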