<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# DeepFloyd IF
## Overview
DeepFloyd IF is a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding.
The model is modular, composed of a frozen text encoder and three cascaded pixel diffusion modules:
- Stage 1: a base model that generates a 64x64 px image based on a text prompt,
- Stage 2: a 64x64 px => 256x256 px super-resolution model, and
- Stage 3: a 256x256 px => 1024x1024 px super-resolution model
Stage 1 and Stage 2 utilize a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture enhanced with cross-attention and attention pooling.
Stage 3 is [Stability AI's x4 Upscaling model](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler).
The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset.
Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis.
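The resolution progression through the cascade can be sketched in a few lines of plain Python. This is only an illustration of the arithmetic in the stage list above; the names `BASE_SIDE`, `UPSCALE_FACTORS`, and `cascade_sides` are made up for this example and are not part of `diffusers`.

```python
# Side length (in px) of the base model output and the upscaling factor of
# each super-resolution stage, taken from the stage list in the overview.
BASE_SIDE = 64
UPSCALE_FACTORS = [4, 4]  # stage 2: 64 -> 256, stage 3: 256 -> 1024


def cascade_sides(base_side=BASE_SIDE, factors=UPSCALE_FACTORS):
    """Return the output side length after each stage of the cascade."""
    sides = [base_side]
    for factor in factors:
        sides.append(sides[-1] * factor)
    return sides


sides = cascade_sides()
print(sides)  # [64, 256, 1024]
```

Each super-resolution stage multiplies the side length by 4, so the pixel count grows by a factor of 16 per stage.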
## Usage
Before you can use IF, you need to accept its usage conditions. To do so:
1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be logged in.
2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0). Accepting the license on the stage I model card automatically accepts it for the other IF models.
3. Make sure to login locally. Install `huggingface_hub`:
```sh
pip install huggingface_hub --upgrade
```
run the login function in a Python shell:
```py
from huggingface_hub import login
login()
```
and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).
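In non-interactive environments such as scripts or CI jobs, typing the token into a prompt may not be practical. The following is a minimal sketch that reads the token from an environment variable and passes it to `login(token=...)`. The helper names `read_hub_token` and `login_from_env`, and the choice of `HF_TOKEN` as the variable name, are assumptions made for this example.

```python
import os


def read_hub_token(env_var="HF_TOKEN"):
    """Return the token stored in `env_var`, or None if it is unset or blank."""
    token = os.environ.get(env_var, "").strip()
    return token or None


def login_from_env(env_var="HF_TOKEN"):
    """Log in with the token from `env_var`; return False if no token is set."""
    token = read_hub_token(env_var)
    if token is None:
        return False
    # Imported lazily so the helper can be defined without the package installed.
    from huggingface_hub import login

    login(token=token)
    return True


# Usage (performs a network request, so it is left commented out here):
# login_from_env()
```

Keeping the token out of the source file avoids committing it to version control by accident.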
Next we install `diffusers` and dependencies:
```sh
pip install -q diffusers accelerate transformers
```
The following sections give more detailed examples of how to use IF. Specifically:
- [Text-to-Image Generation](#text-to-image-generation)
- [Image-to-Image Generation](#text-guided-image-to-image-generation)
- [Inpainting](#text-guided-inpainting-generation)
- [Reusing model weights](#converting-between-different-pipelines)
- [Speed optimization](#optimizing-for-speed)
- [Memory optimization](#optimizing-for-memory)
**Available checkpoints**
- *Stage-1*
  - [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)
  - [DeepFloyd/IF-I-L-v1.0](https://huggingface.co/DeepFloyd/IF-I-L-v1.0)
  - [DeepFloyd/IF-I-M-v1.0](https://huggingface.co/DeepFloyd/IF-I-M-v1.0)
- *Stage-2*
  - [DeepFloyd/IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)
  - [DeepFloyd/IF-II-M-v1.0](https://huggingface.co/DeepFloyd/IF-II-M-v1.0)
- *Stage-3*
  - [stabilityai/stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler)
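When switching between model sizes, it can be convenient to keep the checkpoint names in one place. The lookup below is a small sketch built only from the checkpoint list above; `CHECKPOINTS`, `STAGE_3`, and `checkpoint_id` are hypothetical names for this example, not a `diffusers` API.

```python
# (stage, size) -> Hub repository id, copied from the checkpoint list above.
CHECKPOINTS = {
    ("I", "XL"): "DeepFloyd/IF-I-XL-v1.0",
    ("I", "L"): "DeepFloyd/IF-I-L-v1.0",
    ("I", "M"): "DeepFloyd/IF-I-M-v1.0",
    ("II", "L"): "DeepFloyd/IF-II-L-v1.0",
    ("II", "M"): "DeepFloyd/IF-II-M-v1.0",
}
# Stage 3 has a single checkpoint, so it needs no size.
STAGE_3 = "stabilityai/stable-diffusion-x4-upscaler"


def checkpoint_id(stage, size=None):
    """Return the repository id for a stage ("I", "II", "III") and size."""
    if stage == "III":
        return STAGE_3
    try:
        return CHECKPOINTS[(stage, size)]
    except KeyError:
        raise ValueError(f"no checkpoint for stage {stage!r} with size {size!r}") from None


print(checkpoint_id("I", "M"))  # DeepFloyd/IF-I-M-v1.0
```

The returned string can be passed as the first argument to `from_pretrained` in the examples that follow.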
**Google Colab**
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
### Text-to-Image Generation
The example below uses [model cpu offloading](../../optimization/memory#model-offloading), which allows the whole IF pipeline to run with as little as 14 GB of VRAM.
```python
from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil, make_image_grid
import torch
# stage 1
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_model_cpu_offload()
# stage 2
stage_2 = DiffusionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_model_cpu_offload()
# stage 3
safety_modules = {
"feature_extractor": stage_1.feature_extractor,
"safety_checker": stage_1.safety_checker,
"watermarker": stage_1.watermarker,
}
stage_3 = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
)
stage_3.enable_model_cpu_offload()
prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
generator = torch.manual_seed(1)
# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
# stage 1
stage_1_output = stage_1(
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
).images
#pt_to_pil(stage_1_output)[0].save("./if_stage_I.png")
# stage 2
stage_2_output = stage_2(
image=stage_1_output,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
generator=generator,
output_type="pt",
).images
#pt_to_pil(stage_2_output)[0].save("./if_stage_II.png")
# stage 3
stage_3_output = stage_3(prompt=prompt, image=stage_2_output, noise_level=100, generator=generator).images
#stage_3_output[0].save("./if_stage_III.png")
make_image_grid([pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0], stage_3_output[0]], rows=1, cols=3)
```
### Text Guided Image-to-Image Generation
The same IF model weights can be used for text-guided image-to-image translation or image variation.
In this case just make sure to load the weights using the [`IFImg2ImgPipeline`] and [`IFImg2ImgSuperResolutionPipeline`] pipelines.
**Note**: You can also directly move the weights of the text-to-image pipelines to the image-to-image pipelines
without loading them twice by making use of the [`~DiffusionPipeline.components`] property as explained [here](#converting-between-different-pipelines).
```python
from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
from diffusers.utils import pt_to_pil, load_image, make_image_grid
import torch
# download image
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
original_image = load_image(url)
original_image = original_image.resize((768, 512))
# stage 1
stage_1 = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_model_cpu_offload()
# stage 2
stage_2 = IFImg2ImgSuperResolutionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_model_cpu_offload()
# stage 3
safety_modules = {
"feature_extractor": stage_1.feature_extractor,
"safety_checker": stage_1.safety_checker,
"watermarker": stage_1.watermarker,
}
stage_3 = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
)
stage_3.enable_model_cpu_offload()
prompt = "A fantasy landscape in style minecraft"
generator = torch.manual_seed(1)
# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
# stage 1
stage_1_output = stage_1(
image=original_image,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
generator=generator,
output_type="pt",
).images
#pt_to_pil(stage_1_output)[0].save("./if_stage_I.png")
# stage 2
stage_2_output = stage_2(
image=stage_1_output,
original_image=original_image,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
generator=generator,
output_type="pt",
).images
#pt_to_pil(stage_2_output)[0].save("./if_stage_II.png")
# stage 3
stage_3_output = stage_3(prompt=prompt, image=stage_2_output, generator=generator, noise_level=100).images
#stage_3_output[0].save("./if_stage_III.png")
make_image_grid([original_image, pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0], stage_3_output[0]], rows=1, cols=4)
```
### Text Guided Inpainting Generation
The same IF model weights can be used for text-guided inpainting.
In this case just make sure to load the weights using the [`IFInpaintingPipeline`] and [`IFInpaintingSuperResolutionPipeline`] pipelines.
**Note**: You can also directly move the weights of the text-to-image pipelines to the inpainting pipelines
without loading them twice by making use of the [`~DiffusionPipeline.components`] property as explained [here](#converting-between-different-pipelines).
```python
from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline
from diffusers.utils import pt_to_pil, load_image, make_image_grid
import torch
# download image
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png"
original_image = load_image(url)
# download mask
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png"
mask_image = load_image(url)
# stage 1
stage_1 = IFInpaintingPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
stage_1.enable_model_cpu_offload()
# stage 2
stage_2 = IFInpaintingSuperResolutionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
stage_2.enable_model_cpu_offload()
# stage 3
safety_modules = {
"feature_extractor": stage_1.feature_extractor,
"safety_checker": stage_1.safety_checker,
"watermarker": stage_1.watermarker,
}
stage_3 = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
)
stage_3.enable_model_cpu_offload()
prompt = "blue sunglasses"
generator = torch.manual_seed(1)
# text embeds
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
# stage 1
stage_1_output = stage_1(
image=original_image,
mask_image=mask_image,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
generator=generator,
output_type="pt",
).images
#pt_to_pil(stage_1_output)[0].save("./if_stage_I.png")
# stage 2
stage_2_output = stage_2(
image=stage_1_output,
original_image=original_image,
mask_image=mask_image,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
generator=generator,
output_type="pt",
).images
#pt_to_pil(stage_2_output)[0].save("./if_stage_II.png")
# stage 3
stage_3_output = stage_3(prompt=prompt, image=stage_2_output, generator=generator, noise_level=100).images
#stage_3_output[0].save("./if_stage_III.png")
make_image_grid([original_image, mask_image, pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0], stage_3_output[0]], rows=1, cols=5)
```
### Converting between different pipelines
In addition to being loaded with `from_pretrained`, pipelines can also be created directly from each other's components, which reuses the already loaded weights.
```python
from diffusers import IFPipeline, IFSuperResolutionPipeline
pipe_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0")
pipe_2 = IFSuperResolutionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0")
from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline
pipe_1 = IFImg2ImgPipeline(**pipe_1.components)
pipe_2 = IFImg2ImgSuperResolutionPipeline(**pipe_2.components)
from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline
pipe_1 = IFInpaintingPipeline(**pipe_1.components)
pipe_2 = IFInpaintingSuperResolutionPipeline(**pipe_2.components)
```
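To see why this avoids loading the weights twice, here is a small sketch with stand-in classes. The `Toy*` classes are not `diffusers` classes; they only mimic the idea of a `components` dictionary that maps constructor argument names to objects that are already loaded.

```python
class _ToyPipeline:
    """Stand-in for a pipeline class; NOT a diffusers class."""

    def __init__(self, text_encoder, unet):
        self.text_encoder = text_encoder
        self.unet = unet

    @property
    def components(self):
        # Maps constructor argument names to the already-loaded objects.
        return {"text_encoder": self.text_encoder, "unet": self.unet}


class ToyText2Img(_ToyPipeline):
    pass


class ToyImg2Img(_ToyPipeline):
    pass


# `object()` stands in for large model weights.
text2img = ToyText2Img(text_encoder=object(), unet=object())
img2img = ToyImg2Img(**text2img.components)

# Both pipelines point at the very same objects; nothing was copied.
print(img2img.unet is text2img.unet)  # True
```

Because the objects are shared rather than copied, changes made to a component through one pipeline (for example moving it to another device) are visible through the other one as well.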
### Optimizing for speed
The simplest optimization to run IF faster is to move all model components to the GPU.
```py
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.to("cuda")
```
You can also run the diffusion process for fewer timesteps.
This can either be done with the `num_inference_steps` argument:
```py
pipe("<prompt>", num_inference_steps=30)
```
Or with the `timesteps` argument:
```py
from diffusers.pipelines.deepfloyd_if import fast27_timesteps
pipe("<prompt>", timesteps=fast27_timesteps)
```
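The `timesteps` argument takes a list of integer timesteps in descending order. As a rough sketch of what such a list looks like, the helper below builds an evenly spaced schedule. It is a hypothetical example, it does not reproduce `fast27_timesteps`, and the default of 1000 training timesteps is an assumption about the scheduler configuration.

```python
def evenly_spaced_timesteps(num_steps, num_train_timesteps=1000):
    """Return `num_steps` integer timesteps in descending order, ending at 0."""
    if not 1 <= num_steps <= num_train_timesteps:
        raise ValueError("num_steps must be between 1 and num_train_timesteps")
    last = num_train_timesteps - 1
    if num_steps == 1:
        return [last]
    # Spread the steps uniformly over [0, last], highest timestep first.
    return [round(last * i / (num_steps - 1)) for i in range(num_steps - 1, -1, -1)]


print(evenly_spaced_timesteps(4))  # [999, 666, 333, 0]

# Possible usage with a loaded pipeline (not run here):
# pipe("<prompt>", timesteps=evenly_spaced_timesteps(27))
```

Image quality with a hand-made schedule is not guaranteed; the predefined schedules shipped with the pipeline are the safer starting point.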
When doing image variation or inpainting, you can also decrease the number of timesteps
with the `strength` argument. It sets the amount of noise added to the input image, which in turn determines how many steps of the denoising process are run.
A smaller value changes the image less and runs faster.
```py
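from diffusers import IFImg2ImgPipeline
import torch

# `image` is the input image, loaded beforehand (for example with `load_image`)
pipe = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)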
pipe.to("cuda")
image = pipe(image=image, prompt="<prompt>", strength=0.3).images
```
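The relation between `strength` and the number of denoising steps can be sketched as follows. This assumes the convention commonly used by image-to-image pipelines in diffusers, where roughly `int(num_inference_steps * strength)` steps are run; the helper name `steps_actually_run` is made up for this example, and the exact behavior of a given pipeline version should be checked in its source.

```python
def steps_actually_run(num_inference_steps, strength):
    """Approximate number of denoising steps executed for a given strength.

    Assumes the pipeline skips the first (1 - strength) share of the schedule
    and only denoises over the remaining timesteps.
    """
    if not 0.0 <= strength <= 1.0:
        raise ValueError("strength must be in [0, 1]")
    return min(int(num_inference_steps * strength), num_inference_steps)


print(steps_actually_run(100, 0.5))  # 50
print(steps_actually_run(40, 0.25))  # 10
```

Under this assumption, halving `strength` roughly halves the denoising time, at the cost of staying closer to the input image.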
You can also use [`torch.compile`](../../optimization/torch2.0). Note that we have not exhaustively tested `torch.compile`
with IF and it might not give expected results.
```py
from diffusers import DiffusionPipeline
import torch
pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.to("cuda")
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
### Optimizing for memory
When optimizing for GPU memory, we can use the standard diffusers CPU offloading APIs.
Either the model-based CPU offloading,
```py
pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
```
or the more aggressive layer-based CPU offloading.
```py
pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.enable_sequential_cpu_offload()
```
Additionally, T5 can be loaded in 8-bit precision (this requires the `bitsandbytes` package):
```py
from transformers import T5EncoderModel
text_encoder = T5EncoderModel.from_pretrained(
"DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit"
)
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
"DeepFloyd/IF-I-XL-v1.0",
text_encoder=text_encoder, # pass the previously instantiated 8bit text encoder
unet=None,
device_map="auto",
)
prompt_embeds, negative_embeds = pipe.encode_prompt("<prompt>")
```
On machines with limited CPU RAM, such as the Google Colab free tier, not all model components fit in CPU memory at once. In that case we can load the pipeline with
only the text encoder or only the UNet, depending on which component is needed at that point.
```py
from diffusers import DiffusionPipeline, IFPipeline, IFSuperResolutionPipeline
import torch
import gc
from transformers import T5EncoderModel
from diffusers.utils import pt_to_pil, make_image_grid
text_encoder = T5EncoderModel.from_pretrained(
"DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit"
)
# text to image
pipe = DiffusionPipeline.from_pretrained(
"DeepFloyd/IF-I-XL-v1.0",
text_encoder=text_encoder, # pass the previously instantiated 8bit text encoder
unet=None,
device_map="auto",
)
prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
# Remove the pipeline so we can re-load the pipeline with the unet
del text_encoder
del pipe
gc.collect()
torch.cuda.empty_cache()
pipe = IFPipeline.from_pretrained(
"DeepFloyd/IF-I-XL-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, device_map="auto"
)
generator = torch.Generator().manual_seed(0)
stage_1_output = pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
output_type="pt",
generator=generator,
).images
#pt_to_pil(stage_1_output)[0].save("./if_stage_I.png")
# Remove the pipeline so we can load the super-resolution pipeline
del pipe
gc.collect()
torch.cuda.empty_cache()
# First super resolution
pipe = IFSuperResolutionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, device_map="auto"
)
generator = torch.Generator().manual_seed(0)
stage_2_output = pipe(
image=stage_1_output,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
output_type="pt",
generator=generator,
).images
#pt_to_pil(stage_2_output)[0].save("./if_stage_II.png")
make_image_grid([pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0]], rows=1, cols=2)
```
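The `gc.collect()` and `torch.cuda.empty_cache()` pair is repeated after every stage in the example above, so it can be wrapped in a small helper. This is a minimal sketch; `flush` is a hypothetical name, and the helper only frees memory once the caller has dropped its own references (for example with `del pipe`).

```python
import gc


def flush():
    """Run the garbage collector and, if CUDA is usable, release cached GPU memory.

    Returns the number of unreachable objects found by the garbage collector.
    """
    collected = gc.collect()
    try:
        import torch
    except ImportError:
        # Without PyTorch there is no CUDA cache to clear.
        return collected
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return collected


# Typical use between stages:
# del pipe
# flush()
```

Calling `flush()` without deleting the pipeline variable first has little effect, because the model weights are still referenced.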
## Available Pipelines
| Pipeline | Tasks | Colab |
|---|---|:---:|
| [pipeline_if.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py) | *Text-to-Image Generation* | - |
| [pipeline_if_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py) | *Text-to-Image Generation* | - |
| [pipeline_if_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py) | *Image-to-Image Generation* | - |
| [pipeline_if_img2img_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py) | *Image-to-Image Generation* | - |
| [pipeline_if_inpainting.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py) | *Inpainting* | - |
| [pipeline_if_inpainting_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py) | *Inpainting* | - |
## IFPipeline
[[autodoc]] IFPipeline
- all
- __call__
## IFSuperResolutionPipeline
[[autodoc]] IFSuperResolutionPipeline
- all
- __call__
## IFImg2ImgPipeline
[[autodoc]] IFImg2ImgPipeline
- all
- __call__
## IFImg2ImgSuperResolutionPipeline
[[autodoc]] IFImg2ImgSuperResolutionPipeline
- all
- __call__
## IFInpaintingPipeline
[[autodoc]] IFInpaintingPipeline
- all
- __call__
## IFInpaintingSuperResolutionPipeline
[[autodoc]] IFInpaintingSuperResolutionPipeline
- all
- __call__