madebyollin committed on
Commit
8ab04db
Parent: 2fa4db2

Update README.md

Files changed (1):
  README.md +23 −2
README.md CHANGED
@@ -7,5 +7,26 @@ inference: false
  ---
  # SDXL-VAE-FP16-Fix
 
- SDXL-VAE-FP16-Fix is a version of the SDXL VAE decoder which was modified to work in fp16 precision without generating NaNs.
- SDXL-VAE-FP16-Fix is potentially useful for running the SDXL VAE on platforms where bf16 is not available.
+ SDXL-VAE-FP16-Fix is the [SDXL VAE](https://huggingface.co/stabilityai/sdxl-vae), but modified to run in fp16 precision without generating NaNs.
+
+ ```python
+ from diffusers import DiffusionPipeline, AutoencoderKL
+ import torch
+
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16").to("cuda")
+ fixed_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix").half().to("cuda")
+
+ prompt = "An astronaut riding a green horse"
+ latents = pipe(prompt=prompt, output_type="latent").images
+
+ for vae in (pipe.vae, fixed_vae):
+     for dtype in (torch.float32, torch.float16):
+         with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16, enabled=(dtype == torch.float16)):
+             print(dtype, "sdxl-vae" if vae == pipe.vae else "sdxl-vae-fp16-fix")
+             display(pipe.image_processor.postprocess(vae.decode(latents / vae.config.scaling_factor).sample)[0])
+ ```
+
+ | VAE               | Decoding in `float32` precision | Decoding in `float16` precision |
+ | ----------------- | ------------------------------- | ------------------------------- |
+ | SDXL-VAE          | ✅ ![](./images/orig-fp32.png)  | ⚠️ ![](./images/orig-fp16.png)  |
+ | SDXL-VAE-FP16-Fix | ✅ ![](./images/fix-fp32.png)   | ✅ ![](./images/fix-fp16.png)   |
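
As background on why a fix like this is needed (this reasoning is not stated in the diff above, so treat it as an assumption): fp16 can only represent magnitudes up to 65504, so any intermediate activation larger than that overflows to `inf`, and subsequent operations such as `inf - inf` then yield NaN. A minimal sketch of that failure mode in plain PyTorch:

```python
import torch

# fp16's largest finite value is 65504; anything bigger overflows to inf
x = torch.tensor([70000.0]).half()
print(x)       # tensor([inf], dtype=torch.float16)

# inf - inf (as can arise in subsequent arithmetic) yields NaN
print(x - x)   # tensor([nan], dtype=torch.float16)

# the same value is perfectly representable in fp32
print(torch.tensor([70000.0]))  # tensor([70000.])
```

Running the decoder in fp32 or bf16 (which has the same exponent range as fp32) avoids the overflow, which is consistent with the `float32` column of the table above succeeding for both VAEs.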