Spaces:
Runtime error
Runtime error
Commit
•
3aff77a
1
Parent(s):
073f81a
Upload folder using huggingface_hub
Browse files- .gitattributes +11 -0
- .gitignore +162 -0
- README.md +46 -12
- app_ctrlx.py +412 -0
- assets/images/bear_avocado__spatext.jpg +0 -0
- assets/images/bedroom__sketch.jpg +0 -0
- assets/images/cat__mesh.jpg +0 -0
- assets/images/cat__point_cloud.jpg +0 -0
- assets/images/dog__sketch.jpg +0 -0
- assets/images/fruit_bowl.jpg +0 -0
- assets/images/grapes.jpg +0 -0
- assets/images/horse.jpg +0 -0
- assets/images/horse__point_cloud.jpg +0 -0
- assets/images/knight__humanoid.jpg +0 -0
- assets/images/library__mesh.jpg +0 -0
- assets/images/living_room__seg.jpg +0 -0
- assets/images/living_room_modern.jpg +0 -0
- assets/images/man_park.jpg +0 -0
- assets/images/person__mesh.jpg +0 -0
- assets/images/running__pose.jpg +0 -0
- assets/images/squirrel.jpg +0 -0
- assets/images/tiger.jpg +0 -0
- assets/images/van_gogh.jpg +0 -0
- ctrl_x/__init__.py +0 -0
- ctrl_x/pipelines/__init__.py +0 -0
- ctrl_x/pipelines/pipeline_sdxl.py +665 -0
- ctrl_x/utils/__init__.py +3 -0
- ctrl_x/utils/feature.py +79 -0
- ctrl_x/utils/media.py +21 -0
- ctrl_x/utils/sdxl.py +274 -0
- ctrl_x/utils/utils.py +88 -0
- docs/assets/bootstrap.min.css +0 -0
- docs/assets/cross_image_attention.jpg +3 -0
- docs/assets/ctrl-x.jpg +3 -0
- docs/assets/font.css +37 -0
- docs/assets/freecontrol.jpg +3 -0
- docs/assets/genforce.png +0 -0
- docs/assets/pipeline.jpg +3 -0
- docs/assets/results_animatediff.mp4 +3 -0
- docs/assets/results_multi_subject.jpg +3 -0
- docs/assets/results_struct+app.jpg +3 -0
- docs/assets/results_struct+app_2.jpg +3 -0
- docs/assets/results_struct+prompt.jpg +3 -0
- docs/assets/style.css +139 -0
- docs/assets/teaser_github.jpg +3 -0
- docs/assets/teaser_small.jpg +3 -0
- docs/index.html +186 -0
- environment.yaml +125 -0
.gitattributes
CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
docs/assets/cross_image_attention.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
docs/assets/ctrl-x.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
docs/assets/freecontrol.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
docs/assets/pipeline.jpg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
docs/assets/results_animatediff.mp4 filter=lfs diff=lfs merge=lfs -text
|
41 |
+
docs/assets/results_multi_subject.jpg filter=lfs diff=lfs merge=lfs -text
|
42 |
+
docs/assets/results_struct+app.jpg filter=lfs diff=lfs merge=lfs -text
|
43 |
+
docs/assets/results_struct+app_2.jpg filter=lfs diff=lfs merge=lfs -text
|
44 |
+
docs/assets/results_struct+prompt.jpg filter=lfs diff=lfs merge=lfs -text
|
45 |
+
docs/assets/teaser_github.jpg filter=lfs diff=lfs merge=lfs -text
|
46 |
+
docs/assets/teaser_small.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
110 |
+
.pdm.toml
|
111 |
+
.pdm-python
|
112 |
+
.pdm-build/
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
|
133 |
+
# Spyder project settings
|
134 |
+
.spyderproject
|
135 |
+
.spyproject
|
136 |
+
|
137 |
+
# Rope project settings
|
138 |
+
.ropeproject
|
139 |
+
|
140 |
+
# mkdocs documentation
|
141 |
+
/site
|
142 |
+
|
143 |
+
# mypy
|
144 |
+
.mypy_cache/
|
145 |
+
.dmypy.json
|
146 |
+
dmypy.json
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
150 |
+
|
151 |
+
# pytype static type analyzer
|
152 |
+
.pytype/
|
153 |
+
|
154 |
+
# Cython debug symbols
|
155 |
+
cython_debug/
|
156 |
+
|
157 |
+
# PyCharm
|
158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
#.idea/
|
README.md
CHANGED
@@ -1,12 +1,46 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ctrl-X: Controlling Structure and Appearance for Text-To-Image Generation Without Guidance (NeurIPS 2024)
|
2 |
+
|
3 |
+
<a href="https://arxiv.org/abs/2406.07540"><img src="https://img.shields.io/badge/arXiv-Paper-red"></a>
|
4 |
+
<a href="https://genforce.github.io/ctrl-x"><img src="https://img.shields.io/badge/Project-Page-yellow"></a>
|
5 |
+
[![GitHub](https://img.shields.io/github/stars/genforce/ctrl-x?style=social)](https://github.com/genforce/ctrl-x)
|
6 |
+
|
7 |
+
[Kuan Heng Lin](https://kuanhenglin.github.io)<sup>1*</sup>, [Sicheng Mo](https://sichengmo.github.io/)<sup>1*</sup>, [Ben Klingher](https://bklingher.github.io)<sup>1</sup>, [Fangzhou Mu](https://pages.cs.wisc.edu/~fmu/)<sup>2</sup>, [Bolei Zhou](https://boleizhou.github.io/)<sup>1</sup> <br>
|
8 |
+
<sup>1</sup>UCLA <sup>2</sup>NVIDIA <br>
|
9 |
+
<sup>*</sup>Equal contribution <br>
|
10 |
+
|
11 |
+
![Ctrl-X teaser figure](docs/assets/teaser_github.jpg)
|
12 |
+
|
13 |
+
## Getting started
|
14 |
+
|
15 |
+
### Environment setup
|
16 |
+
|
17 |
+
Our code is built on top of [`diffusers v0.28.0`](https://github.com/huggingface/diffusers). To set up the environment, please run the following.
|
18 |
+
```
|
19 |
+
conda env create -f environment.yaml
|
20 |
+
conda activate ctrlx
|
21 |
+
```
|
22 |
+
|
23 |
+
### Gradio demo
|
24 |
+
|
25 |
+
We provide a user interface for testing our method. Running the following command starts the demo.
|
26 |
+
```
|
27 |
+
python3 app_ctrlx.py
|
28 |
+
```
|
29 |
+
Have fun playing around! :D
|
30 |
+
|
31 |
+
## Contact
|
32 |
+
|
33 |
+
For any questions, thoughts, discussions, and any other things you want to reach out for, please contact [Kuan Heng (Jordan) Lin](https://kuanhenglin.github.io) (kuanhenglin@ucla.edu).
|
34 |
+
|
35 |
+
## Reference
|
36 |
+
|
37 |
+
If you use our code in your research, please cite the following work.
|
38 |
+
|
39 |
+
```bibtex
|
40 |
+
@inproceedings{lin2024ctrlx,
|
41 |
+
author = {Lin, {Kuan Heng} and Mo, Sicheng and Klingher, Ben and Mu, Fangzhou and Zhou, Bolei},
|
42 |
+
booktitle = {Advances in Neural Information Processing Systems},
|
43 |
+
title = {Ctrl-X: Controlling Structure and Appearance for Text-To-Image Generation Without Guidance},
|
44 |
+
year = {2024}
|
45 |
+
}
|
46 |
+
```
|
app_ctrlx.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import ArgumentParser
|
2 |
+
|
3 |
+
from diffusers import DDIMScheduler, StableDiffusionXLImg2ImgPipeline
|
4 |
+
import gradio as gr
|
5 |
+
import torch
|
6 |
+
import yaml
|
7 |
+
|
8 |
+
from ctrl_x.pipelines.pipeline_sdxl import CtrlXStableDiffusionXLPipeline
|
9 |
+
from ctrl_x.utils import *
|
10 |
+
from ctrl_x.utils.sdxl import *
|
11 |
+
|
12 |
+
|
13 |
+
parser = ArgumentParser()
|
14 |
+
parser.add_argument("-m", "--model", type=str, default=None) # Optionally, load model checkpoint from single file
|
15 |
+
args = parser.parse_args()
|
16 |
+
|
17 |
+
torch.backends.cudnn.enabled = False # Sometimes necessary to suppress CUDNN_STATUS_NOT_SUPPORTED
|
18 |
+
|
19 |
+
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
20 |
+
|
21 |
+
model_id_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
22 |
+
refiner_id_or_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
|
23 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
24 |
+
variant = "fp16" if device == "cuda" else "fp32"
|
25 |
+
|
26 |
+
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler") # TODO: Support other schedulers
|
27 |
+
if args.model is None:
|
28 |
+
pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
|
29 |
+
model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype, variant=variant, use_safetensors=True
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
print(f"Using weights {args.model} for SDXL base model.")
|
33 |
+
pipe = CtrlXStableDiffusionXLPipeline.from_single_file(args.model, scheduler=scheduler, torch_dtype=torch_dtype)
|
34 |
+
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
35 |
+
refiner_id_or_path, scheduler=scheduler, text_encoder_2=pipe.text_encoder_2, vae=pipe.vae,
|
36 |
+
torch_dtype=torch_dtype, variant=variant, use_safetensors=True,
|
37 |
+
)
|
38 |
+
|
39 |
+
if torch.cuda.is_available():
|
40 |
+
pipe = pipe.to("cuda")
|
41 |
+
refiner = refiner.to("cuda")
|
42 |
+
|
43 |
+
|
44 |
+
def get_control_config(structure_schedule, appearance_schedule):
|
45 |
+
s = structure_schedule
|
46 |
+
a = appearance_schedule
|
47 |
+
|
48 |
+
control_config =\
|
49 |
+
f"""control_schedule:
|
50 |
+
# structure_conv structure_attn appearance_attn conv/attn
|
51 |
+
encoder: # (num layers)
|
52 |
+
0: [[ ], [ ], [ ]] # 2/0
|
53 |
+
1: [[ ], [ ], [{a}, {a} ]] # 2/2
|
54 |
+
2: [[ ], [ ], [{a}, {a} ]] # 2/2
|
55 |
+
middle: [[ ], [ ], [ ]] # 2/1
|
56 |
+
decoder:
|
57 |
+
0: [[{s} ], [{s}, {s}, {s}], [0.0, {a}, {a}]] # 3/3
|
58 |
+
1: [[ ], [ ], [{a}, {a} ]] # 3/3
|
59 |
+
2: [[ ], [ ], [ ]] # 3/0
|
60 |
+
|
61 |
+
control_target:
|
62 |
+
- [output_tensor] # structure_conv choices: {{hidden_states, output_tensor}}
|
63 |
+
- [query, key] # structure_attn choices: {{query, key, value}}
|
64 |
+
- [before] # appearance_attn choices: {{before, value, after}}
|
65 |
+
|
66 |
+
self_recurrence_schedule:
|
67 |
+
- [0.1, 0.5, 2] # format: [start, end, num_recurrence]"""
|
68 |
+
|
69 |
+
return control_config
|
70 |
+
|
71 |
+
|
72 |
+
css = """
|
73 |
+
.config textarea {font-family: monospace; font-size: 80%; white-space: pre}
|
74 |
+
.mono {font-family: monospace}
|
75 |
+
"""
|
76 |
+
|
77 |
+
title = """
|
78 |
+
<div style="display: flex; align-items: center; justify-content: center;margin-bottom: -15px">
|
79 |
+
<h1 style="margin-left: 12px;text-align: center;display: inline-block">
|
80 |
+
Ctrl-X: Controlling Structure and Appearance for Text-To-Image Generation Without Guidance
|
81 |
+
</h1>
|
82 |
+
<h3 style="display: inline-block; margin-left: 10px; margin-top: 7.5px; font-weight: 500">
|
83 |
+
SDXL v1.0
|
84 |
+
</h3>
|
85 |
+
</div>
|
86 |
+
<div style="display: flex; align-items: center; justify-content: center;margin-bottom: 25px">
|
87 |
+
<h3 style="text-align: center">
|
88 |
+
[<a href="https://genforce.github.io/ctrl-x/">Page</a>]
|
89 |
+
|
90 |
+
[<a href="https://arxiv.org/abs/2406.07540">Paper</a>]
|
91 |
+
|
92 |
+
[<a href="https://github.com/genforce/ctrl-x">Code</a>]
|
93 |
+
</h3>
|
94 |
+
</div>
|
95 |
+
<div>
|
96 |
+
<p>
|
97 |
+
<b>Ctrl-X</b> is a simple training-free and guidance-free framework for text-to-image (T2I) generation with
|
98 |
+
structure and appearance control. Given structure and appearance images, Ctrl-X designs feedforward structure
|
99 |
+
control to enable structure alignment with the arbitrary structure image and semantic-aware appearance transfer
|
100 |
+
to facilitate the appearance transfer from the appearance image.
|
101 |
+
</p>
|
102 |
+
<p>
|
103 |
+
Here are some notes and tips for this demo:
|
104 |
+
</p>
|
105 |
+
<ul>
|
106 |
+
<li> On input images:
|
107 |
+
<ul>
|
108 |
+
<li>
|
109 |
+
If both the structure and appearance images are provided, then Ctrl-X does <i>structure and
|
110 |
+
appearance</i> control.
|
111 |
+
</li>
|
112 |
+
<li>
|
113 |
+
If only the structure image is provided, then Ctrl-X does <i>structure-only</i> control and the
|
114 |
+
appearance image is jointly generated with the output image.
|
115 |
+
</li>
|
116 |
+
<li>
|
117 |
+
Similarly, if only the appearance image is provided, then Ctrl-X does <i>appearance-only</i>
|
118 |
+
control.
|
119 |
+
</li>
|
120 |
+
</ul>
|
121 |
+
</li>
|
122 |
+
<li> On prompts:
|
123 |
+
<ul>
|
124 |
+
<li>
|
125 |
+
Though the output prompt can affect the output image to a noticeable extent, the "accuracy" of the
|
126 |
+
structure and appearance prompts are not impactful to the final image.
|
127 |
+
</li>
|
128 |
+
<li>
|
129 |
+
If the structure or appearance prompt is left blank, then it uses the (non-optional) output prompt
|
130 |
+
by default.
|
131 |
+
</li>
|
132 |
+
</ul>
|
133 |
+
</li>
|
134 |
+
<li> On control schedules:
|
135 |
+
<ul>
|
136 |
+
<li>
|
137 |
+
When "Use advanced config" is <b>OFF</b>, the demo uses the structure guidance
|
138 |
+
(<span class="mono">structure_conv</span> and <span class="mono">structure_attn</span>
|
139 |
+
in the advanced config) and appearance guidance (<span class="mono">appearance_attn</span> in the
|
140 |
+
advanced config) sliders to change the control schedules.
|
141 |
+
</li>
|
142 |
+
<li>
|
143 |
+
Otherwise, the demo uses "Advanced control config," which allows per-layer structure and
|
144 |
+
appearance schedule control, along with self-recurrence control. <i>This should be used
|
145 |
+
carefully</i>, and we recommend switching "Use advanced config" <b>OFF</b> in most cases. (For the
|
146 |
+
examples provided at the bottom of the demo, the advanced config uses the default schedules that
|
147 |
+
may not be the best settings for these examples.)
|
148 |
+
</li>
|
149 |
+
</ul>
|
150 |
+
</li>
|
151 |
+
</ul>
|
152 |
+
<p>
|
153 |
+
Have fun! :D
|
154 |
+
</p>
|
155 |
+
</div>
|
156 |
+
"""
|
157 |
+
|
158 |
+
|
159 |
+
def inference(
|
160 |
+
structure_image, appearance_image,
|
161 |
+
prompt, structure_prompt, appearance_prompt,
|
162 |
+
positive_prompt, negative_prompt,
|
163 |
+
guidance_scale, structure_guidance_scale, appearance_guidance_scale,
|
164 |
+
num_inference_steps, eta, seed,
|
165 |
+
width, height,
|
166 |
+
structure_schedule, appearance_schedule, use_advanced_config,
|
167 |
+
control_config,
|
168 |
+
):
|
169 |
+
torch.manual_seed(seed)
|
170 |
+
|
171 |
+
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
|
172 |
+
timesteps = pipe.scheduler.timesteps
|
173 |
+
|
174 |
+
print(f"\nUsing the following control config (use_advanced_config={use_advanced_config}):")
|
175 |
+
if not use_advanced_config:
|
176 |
+
control_config = get_control_config(structure_schedule, appearance_schedule)
|
177 |
+
print(control_config, end="\n\n")
|
178 |
+
|
179 |
+
config = yaml.safe_load(control_config)
|
180 |
+
register_control(
|
181 |
+
model = pipe,
|
182 |
+
timesteps = timesteps,
|
183 |
+
control_schedule = config["control_schedule"],
|
184 |
+
control_target = config["control_target"],
|
185 |
+
)
|
186 |
+
|
187 |
+
pipe.safety_checker = None
|
188 |
+
pipe.requires_safety_checker = False
|
189 |
+
|
190 |
+
self_recurrence_schedule = get_self_recurrence_schedule(config["self_recurrence_schedule"], num_inference_steps)
|
191 |
+
|
192 |
+
pipe.set_progress_bar_config(desc="Ctrl-X inference")
|
193 |
+
refiner.set_progress_bar_config(desc="Refiner")
|
194 |
+
|
195 |
+
result, structure, appearance = pipe(
|
196 |
+
prompt = prompt,
|
197 |
+
structure_prompt = structure_prompt,
|
198 |
+
appearance_prompt = appearance_prompt,
|
199 |
+
structure_image = structure_image,
|
200 |
+
appearance_image = appearance_image,
|
201 |
+
num_inference_steps = num_inference_steps,
|
202 |
+
negative_prompt = negative_prompt,
|
203 |
+
positive_prompt = positive_prompt,
|
204 |
+
height = height,
|
205 |
+
width = width,
|
206 |
+
guidance_scale = guidance_scale,
|
207 |
+
structure_guidance_scale = structure_guidance_scale,
|
208 |
+
appearance_guidance_scale = appearance_guidance_scale,
|
209 |
+
eta = eta,
|
210 |
+
output_type = "pil",
|
211 |
+
return_dict = False,
|
212 |
+
control_schedule = config["control_schedule"],
|
213 |
+
self_recurrence_schedule = self_recurrence_schedule,
|
214 |
+
)
|
215 |
+
|
216 |
+
result_refiner = refiner(
|
217 |
+
image = pipe.refiner_args["latents"],
|
218 |
+
prompt = pipe.refiner_args["prompt"],
|
219 |
+
negative_prompt = pipe.refiner_args["negative_prompt"],
|
220 |
+
height = height,
|
221 |
+
width = width,
|
222 |
+
num_inference_steps = num_inference_steps,
|
223 |
+
guidance_scale = guidance_scale,
|
224 |
+
guidance_rescale = 0.7,
|
225 |
+
num_images_per_prompt = 1,
|
226 |
+
eta = eta,
|
227 |
+
output_type = "pil",
|
228 |
+
).images
|
229 |
+
del pipe.refiner_args
|
230 |
+
|
231 |
+
return [result[0], result_refiner[0], structure[0], appearance[0]]
|
232 |
+
|
233 |
+
|
234 |
+
with gr.Blocks(theme=gr.themes.Default(), css=css, title="Ctrl-X (SDXL v1.0)") as app:
|
235 |
+
gr.HTML(title)
|
236 |
+
|
237 |
+
with gr.Row():
|
238 |
+
|
239 |
+
with gr.Column(scale=55):
|
240 |
+
with gr.Group():
|
241 |
+
kwargs = {} # {"width": 400, "height": 400}
|
242 |
+
with gr.Row():
|
243 |
+
result = gr.Image(label="Output image", format="jpg", **kwargs)
|
244 |
+
result_refiner = gr.Image(label="Output image w/ refiner", format="jpg", **kwargs)
|
245 |
+
with gr.Row():
|
246 |
+
structure_recon = gr.Image(label="Structure image", format="jpg", **kwargs)
|
247 |
+
appearance_recon = gr.Image(label="Style image", format="jpg", **kwargs)
|
248 |
+
with gr.Row():
|
249 |
+
structure_image = gr.Image(label="Upload structure image (optional)", type="pil", **kwargs)
|
250 |
+
appearance_image = gr.Image(label="Upload appearance image (optional)", type="pil", **kwargs)
|
251 |
+
|
252 |
+
with gr.Column(scale=45):
|
253 |
+
with gr.Group():
|
254 |
+
with gr.Row():
|
255 |
+
structure_prompt = gr.Textbox(label="Structure prompt (optional)", placeholder="Prompt which describes the structure image")
|
256 |
+
appearance_prompt = gr.Textbox(label="Appearance prompt (optional)", placeholder="Prompt which describes the style image")
|
257 |
+
with gr.Row():
|
258 |
+
prompt = gr.Textbox(label="Output prompt", placeholder="Prompt which describes the output image")
|
259 |
+
with gr.Row():
|
260 |
+
positive_prompt = gr.Textbox(label="Positive prompt", value="high quality", placeholder="")
|
261 |
+
negative_prompt = gr.Textbox(label="Negative prompt", value="ugly, blurry, dark, low res, unrealistic", placeholder="")
|
262 |
+
with gr.Row():
|
263 |
+
guidance_scale = gr.Slider(label="Target guidance scale", value=5.0, minimum=1, maximum=10)
|
264 |
+
structure_guidance_scale = gr.Slider(label="Structure guidance scale", value=5.0, minimum=1, maximum=10)
|
265 |
+
appearance_guidance_scale = gr.Slider(label="Appearance guidance scale", value=5.0, minimum=1, maximum=10)
|
266 |
+
with gr.Row():
|
267 |
+
num_inference_steps = gr.Slider(label="# inference steps", value=50, minimum=1, maximum=200, step=1)
|
268 |
+
eta = gr.Slider(label="Eta (noise)", value=1.0, minimum=0, maximum=1.0, step=0.01)
|
269 |
+
seed = gr.Slider(0, 2147483647, label="Seed", value=90095, step=1)
|
270 |
+
with gr.Row():
|
271 |
+
width = gr.Slider(label="Width", value=1024, minimum=256, maximum=2048, step=pipe.vae_scale_factor)
|
272 |
+
height = gr.Slider(label="Height", value=1024, minimum=256, maximum=2048, step=pipe.vae_scale_factor)
|
273 |
+
with gr.Row():
|
274 |
+
structure_schedule = gr.Slider(label="Structure schedule", value=0.6, minimum=0.0, maximum=1.0, step=0.01, scale=2)
|
275 |
+
appearance_schedule = gr.Slider(label="Appearance schedule", value=0.6, minimum=0.0, maximum=1.0, step=0.01, scale=2)
|
276 |
+
use_advanced_config = gr.Checkbox(label="Use advanced config", value=False, scale=1)
|
277 |
+
with gr.Row():
|
278 |
+
control_config = gr.Textbox(
|
279 |
+
label="Advanced control config", lines=20, value=get_control_config(0.6, 0.6), elem_classes=["config"], visible=False,
|
280 |
+
)
|
281 |
+
use_advanced_config.change(
|
282 |
+
fn=lambda value: gr.update(visible=value), inputs=use_advanced_config, outputs=control_config,
|
283 |
+
)
|
284 |
+
with gr.Row():
|
285 |
+
generate = gr.Button(value="Run")
|
286 |
+
|
287 |
+
inputs = [
|
288 |
+
structure_image, appearance_image,
|
289 |
+
prompt, structure_prompt, appearance_prompt,
|
290 |
+
positive_prompt, negative_prompt,
|
291 |
+
guidance_scale, structure_guidance_scale, appearance_guidance_scale,
|
292 |
+
num_inference_steps, eta, seed,
|
293 |
+
width, height,
|
294 |
+
structure_schedule, appearance_schedule, use_advanced_config,
|
295 |
+
control_config,
|
296 |
+
]
|
297 |
+
outputs = [result, result_refiner, structure_recon, appearance_recon]
|
298 |
+
|
299 |
+
generate.click(inference, inputs=inputs, outputs=outputs)
|
300 |
+
|
301 |
+
examples = gr.Examples(
|
302 |
+
[
|
303 |
+
[
|
304 |
+
"assets/images/horse__point_cloud.jpg",
|
305 |
+
"assets/images/horse.jpg",
|
306 |
+
"a 3D point cloud of a horse",
|
307 |
+
"",
|
308 |
+
"a photo of a horse standing on grass",
|
309 |
+
0.6, 0.6,
|
310 |
+
],
|
311 |
+
[
|
312 |
+
"assets/images/cat__mesh.jpg",
|
313 |
+
"assets/images/tiger.jpg",
|
314 |
+
"a 3D mesh of a cat",
|
315 |
+
"",
|
316 |
+
"a photo of a tiger standing on snow",
|
317 |
+
0.6, 0.6,
|
318 |
+
],
|
319 |
+
[
|
320 |
+
"assets/images/dog__sketch.jpg",
|
321 |
+
"assets/images/squirrel.jpg",
|
322 |
+
"a sketch of a dog",
|
323 |
+
"",
|
324 |
+
"a photo of a squirrel",
|
325 |
+
0.6, 0.6,
|
326 |
+
],
|
327 |
+
[
|
328 |
+
"assets/images/living_room__seg.jpg",
|
329 |
+
"assets/images/van_gogh.jpg",
|
330 |
+
"a segmentation map of a living room",
|
331 |
+
"",
|
332 |
+
"a Van Gogh painting of a living room",
|
333 |
+
0.6, 0.6,
|
334 |
+
],
|
335 |
+
[
|
336 |
+
"assets/images/bedroom__sketch.jpg",
|
337 |
+
"assets/images/living_room_modern.jpg",
|
338 |
+
"a sketch of a bedroom",
|
339 |
+
"",
|
340 |
+
"a photo of a modern bedroom during sunset",
|
341 |
+
0.6, 0.6,
|
342 |
+
],
|
343 |
+
[
|
344 |
+
"assets/images/running__pose.jpg",
|
345 |
+
"assets/images/man_park.jpg",
|
346 |
+
"a pose image of a person running",
|
347 |
+
"",
|
348 |
+
"a photo of a man running in a park",
|
349 |
+
0.4, 0.6,
|
350 |
+
],
|
351 |
+
[
|
352 |
+
"assets/images/fruit_bowl.jpg",
|
353 |
+
"assets/images/grapes.jpg",
|
354 |
+
"a photo of a bowl of fruits",
|
355 |
+
"",
|
356 |
+
"a photo of a bowl of grapes in the trees",
|
357 |
+
0.6, 0.6,
|
358 |
+
],
|
359 |
+
[
|
360 |
+
"assets/images/bear_avocado__spatext.jpg",
|
361 |
+
None,
|
362 |
+
"a segmentation map of a bear and an avocado",
|
363 |
+
"",
|
364 |
+
"a realistic photo of a bear and an avocado in a forest",
|
365 |
+
0.6, 0.6,
|
366 |
+
],
|
367 |
+
[
|
368 |
+
"assets/images/cat__point_cloud.jpg",
|
369 |
+
None,
|
370 |
+
"a 3D point cloud of a cat",
|
371 |
+
"",
|
372 |
+
"an embroidery of a white cat sitting on a rock under the night sky",
|
373 |
+
0.6, 0.6,
|
374 |
+
],
|
375 |
+
[
|
376 |
+
"assets/images/library__mesh.jpg",
|
377 |
+
None,
|
378 |
+
"a 3D mesh of a library",
|
379 |
+
"",
|
380 |
+
"a Polaroid photo of an old library, sunlight streaming in",
|
381 |
+
0.6, 0.6,
|
382 |
+
],
|
383 |
+
[
|
384 |
+
"assets/images/knight__humanoid.jpg",
|
385 |
+
None,
|
386 |
+
"a 3D model of a person holding a sword and shield",
|
387 |
+
"",
|
388 |
+
"a photo of a medieval soldier standing on a barren field, raining",
|
389 |
+
0.6, 0.6,
|
390 |
+
],
|
391 |
+
[
|
392 |
+
"assets/images/person__mesh.jpg",
|
393 |
+
None,
|
394 |
+
"a 3D mesh of a person",
|
395 |
+
"",
|
396 |
+
"a photo of a Karate man performing in a cyberpunk city at night",
|
397 |
+
0.5, 0.6,
|
398 |
+
],
|
399 |
+
],
|
400 |
+
[
|
401 |
+
structure_image,
|
402 |
+
appearance_image,
|
403 |
+
structure_prompt,
|
404 |
+
appearance_prompt,
|
405 |
+
prompt,
|
406 |
+
structure_schedule,
|
407 |
+
appearance_schedule,
|
408 |
+
],
|
409 |
+
examples_per_page=50,
|
410 |
+
)
|
411 |
+
|
412 |
+
app.launch(debug=False, share=False)
|
assets/images/bear_avocado__spatext.jpg
ADDED
assets/images/bedroom__sketch.jpg
ADDED
assets/images/cat__mesh.jpg
ADDED
assets/images/cat__point_cloud.jpg
ADDED
assets/images/dog__sketch.jpg
ADDED
assets/images/fruit_bowl.jpg
ADDED
assets/images/grapes.jpg
ADDED
assets/images/horse.jpg
ADDED
assets/images/horse__point_cloud.jpg
ADDED
assets/images/knight__humanoid.jpg
ADDED
assets/images/library__mesh.jpg
ADDED
assets/images/living_room__seg.jpg
ADDED
assets/images/living_room_modern.jpg
ADDED
assets/images/man_park.jpg
ADDED
assets/images/person__mesh.jpg
ADDED
assets/images/running__pose.jpg
ADDED
assets/images/squirrel.jpg
ADDED
assets/images/tiger.jpg
ADDED
assets/images/van_gogh.jpg
ADDED
ctrl_x/__init__.py
ADDED
File without changes
|
ctrl_x/pipelines/__init__.py
ADDED
File without changes
|
ctrl_x/pipelines/pipeline_sdxl.py
ADDED
@@ -0,0 +1,665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
from diffusers import StableDiffusionXLPipeline
|
6 |
+
from diffusers.image_processor import PipelineImageInput
|
7 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import\
|
8 |
+
rescale_noise_cfg, retrieve_latents, retrieve_timesteps
|
9 |
+
from diffusers.utils import BaseOutput, deprecate
|
10 |
+
from diffusers.utils.torch_utils import randn_tensor
|
11 |
+
import numpy as np
|
12 |
+
import PIL
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from ..utils import *
|
16 |
+
from ..utils.sdxl import *
|
17 |
+
|
18 |
+
|
19 |
+
BATCH_ORDER = [
|
20 |
+
"structure_uncond", "appearance_uncond", "uncond", "structure_cond", "appearance_cond", "cond",
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
def get_last_control_i(control_schedule, num_inference_steps):
|
25 |
+
if control_schedule is None:
|
26 |
+
return num_inference_steps, num_inference_steps
|
27 |
+
|
28 |
+
def max_(l):
|
29 |
+
if len(l) == 0:
|
30 |
+
return 0.0
|
31 |
+
return max(l)
|
32 |
+
|
33 |
+
structure_max = 0.0
|
34 |
+
appearance_max = 0.0
|
35 |
+
for block in control_schedule.values():
|
36 |
+
if isinstance(block, list): # Handling mid_block
|
37 |
+
block = {0: block}
|
38 |
+
for layer in block.values():
|
39 |
+
structure_max = max(structure_max, max_(layer[0] + layer[1]))
|
40 |
+
appearance_max = max(appearance_max, max_(layer[2]))
|
41 |
+
|
42 |
+
structure_i = round(num_inference_steps * structure_max)
|
43 |
+
appearance_i = round(num_inference_steps * appearance_max)
|
44 |
+
return structure_i, appearance_i
|
45 |
+
|
46 |
+
|
47 |
+
@dataclass
|
48 |
+
class CtrlXStableDiffusionXLPipelineOutput(BaseOutput):
|
49 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
50 |
+
structures = Union[List[PIL.Image.Image], np.ndarray]
|
51 |
+
appearances = Union[List[PIL.Image.Image], np.ndarray]
|
52 |
+
|
53 |
+
|
54 |
+
class CtrlXStableDiffusionXLPipeline(StableDiffusionXLPipeline): # diffusers==0.28.0
|
55 |
+
|
56 |
+
def prepare_latents(
|
57 |
+
self, image, batch_size, num_images_per_prompt, num_channels_latents, height, width,
|
58 |
+
dtype, device, generator=None, noise=None,
|
59 |
+
):
|
60 |
+
batch_size = batch_size * num_images_per_prompt
|
61 |
+
|
62 |
+
if noise is None:
|
63 |
+
shape = (
|
64 |
+
batch_size,
|
65 |
+
num_channels_latents,
|
66 |
+
height // self.vae_scale_factor,
|
67 |
+
width // self.vae_scale_factor
|
68 |
+
)
|
69 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
70 |
+
noise = noise * self.scheduler.init_noise_sigma # Starting noise, need to scale
|
71 |
+
else:
|
72 |
+
noise = noise.to(device)
|
73 |
+
|
74 |
+
if image is None:
|
75 |
+
return noise, None
|
76 |
+
|
77 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
78 |
+
raise ValueError(
|
79 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
80 |
+
)
|
81 |
+
|
82 |
+
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
83 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
84 |
+
self.text_encoder_2.to("cpu")
|
85 |
+
torch.cuda.empty_cache()
|
86 |
+
|
87 |
+
image = image.to(device=device, dtype=dtype)
|
88 |
+
|
89 |
+
if image.shape[1] == 4: # Image already in latents form
|
90 |
+
init_latents = image
|
91 |
+
|
92 |
+
else:
|
93 |
+
# Make sure the VAE is in float32 mode, as it overflows in float16
|
94 |
+
if self.vae.config.force_upcast:
|
95 |
+
image = image.to(torch.float32)
|
96 |
+
self.vae.to(torch.float32)
|
97 |
+
|
98 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
99 |
+
raise ValueError(
|
100 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
101 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
102 |
+
)
|
103 |
+
elif isinstance(generator, list):
|
104 |
+
init_latents = [
|
105 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
106 |
+
for i in range(batch_size)
|
107 |
+
]
|
108 |
+
init_latents = torch.cat(init_latents, dim=0)
|
109 |
+
else:
|
110 |
+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
111 |
+
|
112 |
+
if self.vae.config.force_upcast:
|
113 |
+
self.vae.to(dtype)
|
114 |
+
|
115 |
+
init_latents = init_latents.to(dtype)
|
116 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
117 |
+
|
118 |
+
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
119 |
+
# Expand init_latents for batch_size
|
120 |
+
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
121 |
+
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
122 |
+
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
123 |
+
raise ValueError(
|
124 |
+
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
init_latents = torch.cat([init_latents], dim=0)
|
128 |
+
|
129 |
+
return noise, init_latents
|
130 |
+
|
131 |
+
@property
|
132 |
+
def structure_guidance_scale(self):
|
133 |
+
return self._guidance_scale if self._structure_guidance_scale is None else self._structure_guidance_scale
|
134 |
+
|
135 |
+
@property
|
136 |
+
def appearance_guidance_scale(self):
|
137 |
+
return self._guidance_scale if self._appearance_guidance_scale is None else self._appearance_guidance_scale
|
138 |
+
|
139 |
+
@torch.no_grad()
|
140 |
+
def __call__(
|
141 |
+
self,
|
142 |
+
prompt: Union[str, List[str]] = None, # TODO: Support prompt_2 and negative_prompt_2
|
143 |
+
structure_prompt: Optional[Union[str, List[str]]] = None,
|
144 |
+
appearance_prompt: Optional[Union[str, List[str]]] = None,
|
145 |
+
structure_image: Optional[PipelineImageInput] = None,
|
146 |
+
appearance_image: Optional[PipelineImageInput] = None,
|
147 |
+
num_inference_steps: int = 50,
|
148 |
+
timesteps: List[int] = None,
|
149 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
150 |
+
positive_prompt: Optional[Union[str, List[str]]] = None,
|
151 |
+
height: Optional[int] = None,
|
152 |
+
width: Optional[int] = None,
|
153 |
+
guidance_scale: float = 5.0,
|
154 |
+
structure_guidance_scale: Optional[float] = None,
|
155 |
+
appearance_guidance_scale: Optional[float] = None,
|
156 |
+
num_images_per_prompt: Optional[int] = 1,
|
157 |
+
eta: float = 0.0,
|
158 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
159 |
+
latents: Optional[torch.Tensor] = None,
|
160 |
+
structure_latents: Optional[torch.Tensor] = None,
|
161 |
+
appearance_latents: Optional[torch.Tensor] = None,
|
162 |
+
prompt_embeds: Optional[torch.Tensor] = None, # Positive prompt is concatenated with prompt, so no embeddings
|
163 |
+
structure_prompt_embeds: Optional[torch.Tensor] = None,
|
164 |
+
appearance_prompt_embeds: Optional[torch.Tensor] = None,
|
165 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
166 |
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
167 |
+
structure_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
168 |
+
appearance_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
169 |
+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
170 |
+
control_schedule: Optional[Dict] = None,
|
171 |
+
self_recurrence_schedule: Optional[List[int]] = [], # Format: [(start, end, num_repeat)]
|
172 |
+
decode_structure: Optional[bool] = True,
|
173 |
+
decode_appearance: Optional[bool] = True,
|
174 |
+
output_type: Optional[str] = "pil",
|
175 |
+
return_dict: bool = True,
|
176 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
177 |
+
guidance_rescale: float = 0.0,
|
178 |
+
original_size: Tuple[int, int] = None,
|
179 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
180 |
+
target_size: Tuple[int, int] = None,
|
181 |
+
clip_skip: Optional[int] = None,
|
182 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
183 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
184 |
+
**kwargs,
|
185 |
+
):
|
186 |
+
# TODO: Add function argument documentation
|
187 |
+
|
188 |
+
callback = kwargs.pop("callback", None)
|
189 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
190 |
+
|
191 |
+
if callback is not None:
|
192 |
+
deprecate(
|
193 |
+
"callback",
|
194 |
+
"1.0.0",
|
195 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
196 |
+
)
|
197 |
+
if callback_steps is not None:
|
198 |
+
deprecate(
|
199 |
+
"callback_steps",
|
200 |
+
"1.0.0",
|
201 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
202 |
+
)
|
203 |
+
|
204 |
+
# 0. Default height and width to U-Net
|
205 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
206 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
207 |
+
original_size = original_size or (height, width)
|
208 |
+
target_size = target_size or (height, width)
|
209 |
+
|
210 |
+
# 1. Check inputs. Raise error if not correct
|
211 |
+
self.check_inputs( # TODO: Custom check_inputs for our method
|
212 |
+
prompt,
|
213 |
+
None, # prompt_2
|
214 |
+
height,
|
215 |
+
width,
|
216 |
+
callback_steps,
|
217 |
+
negative_prompt = negative_prompt,
|
218 |
+
negative_prompt_2 = None, # negative_prompt_2
|
219 |
+
prompt_embeds = prompt_embeds,
|
220 |
+
negative_prompt_embeds = negative_prompt_embeds,
|
221 |
+
pooled_prompt_embeds = pooled_prompt_embeds,
|
222 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds,
|
223 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end_tensor_inputs,
|
224 |
+
)
|
225 |
+
|
226 |
+
self._guidance_scale = guidance_scale
|
227 |
+
self._structure_guidance_scale = structure_guidance_scale
|
228 |
+
self._appearance_guidance_scale = appearance_guidance_scale
|
229 |
+
self._guidance_rescale = guidance_rescale
|
230 |
+
self._clip_skip = clip_skip
|
231 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
232 |
+
self._denoising_end = None # denoising_end
|
233 |
+
self._denoising_start = None # denoising_start
|
234 |
+
self._interrupt = False
|
235 |
+
|
236 |
+
# 2. Define call parameters
|
237 |
+
if prompt is not None and isinstance(prompt, str):
|
238 |
+
batch_size = 1
|
239 |
+
elif prompt is not None and isinstance(prompt, list):
|
240 |
+
batch_size = len(prompt)
|
241 |
+
else:
|
242 |
+
batch_size = prompt_embeds.shape[0]
|
243 |
+
|
244 |
+
if batch_size * num_images_per_prompt != 1:
|
245 |
+
raise ValueError(
|
246 |
+
f"Pipeline currently does not support batch_size={batch_size} and num_images_per_prompt=1. "
|
247 |
+
"Effective batch size (batch_size * num_images_per_prompt) must be 1."
|
248 |
+
)
|
249 |
+
|
250 |
+
device = self._execution_device
|
251 |
+
|
252 |
+
# 3. Encode input prompt
|
253 |
+
text_encoder_lora_scale = (
|
254 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
255 |
+
)
|
256 |
+
|
257 |
+
if positive_prompt is not None and positive_prompt != "":
|
258 |
+
prompt = prompt + ", " + positive_prompt # Add positive prompt with comma
|
259 |
+
# By default, only add positive prompt to the appearance prompt and not the structure prompt
|
260 |
+
if appearance_prompt is not None and appearance_prompt != "":
|
261 |
+
appearance_prompt = appearance_prompt + ", " + positive_prompt
|
262 |
+
|
263 |
+
(
|
264 |
+
prompt_embeds_,
|
265 |
+
negative_prompt_embeds,
|
266 |
+
pooled_prompt_embeds_,
|
267 |
+
negative_pooled_prompt_embeds,
|
268 |
+
) = self.encode_prompt(
|
269 |
+
prompt = prompt,
|
270 |
+
prompt_2 = None, # prompt_2
|
271 |
+
device = device,
|
272 |
+
num_images_per_prompt = num_images_per_prompt,
|
273 |
+
do_classifier_free_guidance = True, # self.do_classifier_free_guidance, TODO: Support no CFG
|
274 |
+
negative_prompt = negative_prompt,
|
275 |
+
negative_prompt_2 = None, # negative_prompt_2
|
276 |
+
prompt_embeds = prompt_embeds,
|
277 |
+
negative_prompt_embeds = negative_prompt_embeds,
|
278 |
+
pooled_prompt_embeds = pooled_prompt_embeds,
|
279 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds,
|
280 |
+
lora_scale = text_encoder_lora_scale,
|
281 |
+
clip_skip = self.clip_skip,
|
282 |
+
)
|
283 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds_], dim=0).to(device)
|
284 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_], dim=0).to(device)
|
285 |
+
|
286 |
+
# 3.1. Structure prompt embeddings
|
287 |
+
if structure_prompt is not None and structure_prompt != "":
|
288 |
+
(
|
289 |
+
structure_prompt_embeds,
|
290 |
+
negative_structure_prompt_embeds,
|
291 |
+
structure_pooled_prompt_embeds,
|
292 |
+
negative_structure_pooled_prompt_embeds,
|
293 |
+
) = self.encode_prompt(
|
294 |
+
prompt = structure_prompt,
|
295 |
+
prompt_2 = None, # prompt_2
|
296 |
+
device = device,
|
297 |
+
num_images_per_prompt = num_images_per_prompt,
|
298 |
+
do_classifier_free_guidance = True, # self.do_classifier_free_guidance, TODO: Support no CFG
|
299 |
+
negative_prompt = negative_prompt if structure_image is None else "",
|
300 |
+
negative_prompt_2 = None, # negative_prompt_2
|
301 |
+
prompt_embeds = structure_prompt_embeds,
|
302 |
+
negative_prompt_embeds = None, # negative_prompt_embeds
|
303 |
+
pooled_prompt_embeds = structure_pooled_prompt_embeds,
|
304 |
+
negative_pooled_prompt_embeds = None, # negative_pooled_prompt_embeds
|
305 |
+
lora_scale = text_encoder_lora_scale,
|
306 |
+
clip_skip = self.clip_skip,
|
307 |
+
)
|
308 |
+
structure_prompt_embeds = torch.cat(
|
309 |
+
[negative_structure_prompt_embeds, structure_prompt_embeds], dim=0
|
310 |
+
).to(device)
|
311 |
+
structure_add_text_embeds = torch.cat(
|
312 |
+
[negative_structure_pooled_prompt_embeds, structure_pooled_prompt_embeds], dim=0
|
313 |
+
).to(device)
|
314 |
+
else:
|
315 |
+
structure_prompt_embeds = prompt_embeds
|
316 |
+
structure_add_text_embeds = add_text_embeds
|
317 |
+
|
318 |
+
# 3.2. Appearance prompt embeddings
|
319 |
+
if appearance_prompt is not None and appearance_prompt != "":
|
320 |
+
(
|
321 |
+
appearance_prompt_embeds,
|
322 |
+
negative_appearance_prompt_embeds,
|
323 |
+
appearance_pooled_prompt_embeds,
|
324 |
+
negative_appearance_pooled_prompt_embeds,
|
325 |
+
) = self.encode_prompt(
|
326 |
+
prompt = appearance_prompt,
|
327 |
+
prompt_2 = None, # prompt_2
|
328 |
+
device = device,
|
329 |
+
num_images_per_prompt = num_images_per_prompt,
|
330 |
+
do_classifier_free_guidance = True, # self.do_classifier_free_guidance, TODO: Support no CFG
|
331 |
+
negative_prompt = negative_prompt if appearance_image is None else "",
|
332 |
+
negative_prompt_2 = None, # negative_prompt_2
|
333 |
+
prompt_embeds = appearance_prompt_embeds,
|
334 |
+
negative_prompt_embeds = None, # negative_prompt_embeds
|
335 |
+
pooled_prompt_embeds = appearance_pooled_prompt_embeds, # pooled_prompt_embeds
|
336 |
+
negative_pooled_prompt_embeds = None, # negative_pooled_prompt_embeds
|
337 |
+
lora_scale = text_encoder_lora_scale,
|
338 |
+
clip_skip = self.clip_skip,
|
339 |
+
)
|
340 |
+
appearance_prompt_embeds = torch.cat(
|
341 |
+
[negative_appearance_prompt_embeds, appearance_prompt_embeds], dim=0
|
342 |
+
).to(device)
|
343 |
+
appearance_add_text_embeds = torch.cat(
|
344 |
+
[negative_appearance_pooled_prompt_embeds, appearance_pooled_prompt_embeds], dim=0
|
345 |
+
).to(device)
|
346 |
+
else:
|
347 |
+
appearance_prompt_embeds = prompt_embeds
|
348 |
+
appearance_add_text_embeds = add_text_embeds
|
349 |
+
|
350 |
+
# 3.3. Prepare added time ids & embeddings, TODO: Support no CFG
|
351 |
+
if self.text_encoder_2 is None:
|
352 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
353 |
+
else:
|
354 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
355 |
+
|
356 |
+
add_time_ids = self._get_add_time_ids(
|
357 |
+
original_size,
|
358 |
+
crops_coords_top_left,
|
359 |
+
target_size,
|
360 |
+
dtype = prompt_embeds.dtype,
|
361 |
+
text_encoder_projection_dim = text_encoder_projection_dim,
|
362 |
+
)
|
363 |
+
negative_add_time_ids = add_time_ids
|
364 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0).to(device)
|
365 |
+
|
366 |
+
# 4. Prepare timesteps
|
367 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
368 |
+
|
369 |
+
# 5. Prepare latent variables
|
370 |
+
num_channels_latents = self.unet.config.in_channels
|
371 |
+
|
372 |
+
latents, _ = self.prepare_latents(
|
373 |
+
None, batch_size, num_images_per_prompt, num_channels_latents, height, width,
|
374 |
+
prompt_embeds.dtype, device, generator, latents
|
375 |
+
)
|
376 |
+
|
377 |
+
if structure_image is not None:
|
378 |
+
structure_image = preprocess( # Center crop + resize
|
379 |
+
structure_image, self.image_processor, height=height, width=width, resize_mode="crop"
|
380 |
+
)
|
381 |
+
_, clean_structure_latents = self.prepare_latents(
|
382 |
+
structure_image, batch_size, num_images_per_prompt, num_channels_latents, height, width,
|
383 |
+
prompt_embeds.dtype, device, generator, structure_latents,
|
384 |
+
)
|
385 |
+
else:
|
386 |
+
clean_structure_latents = None
|
387 |
+
structure_latents = latents if structure_latents is None else structure_latents
|
388 |
+
|
389 |
+
if appearance_image is not None:
|
390 |
+
appearance_image = preprocess( # Center crop + resize
|
391 |
+
appearance_image, self.image_processor, height=height, width=width, resize_mode="crop"
|
392 |
+
)
|
393 |
+
_, clean_appearance_latents = self.prepare_latents(
|
394 |
+
appearance_image, batch_size, num_images_per_prompt, num_channels_latents, height, width,
|
395 |
+
prompt_embeds.dtype, device, generator, appearance_latents,
|
396 |
+
)
|
397 |
+
else:
|
398 |
+
clean_appearance_latents = None
|
399 |
+
appearance_latents = latents if appearance_latents is None else appearance_latents
|
400 |
+
|
401 |
+
# 6. Prepare extra step kwargs
|
402 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
403 |
+
|
404 |
+
# 7. Denoising loop
|
405 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
406 |
+
|
407 |
+
# 7.1 Apply denoising_end
|
408 |
+
def denoising_value_valid(dnv):
|
409 |
+
return isinstance(self.denoising_end, float) and 0 < dnv < 1
|
410 |
+
|
411 |
+
if (
|
412 |
+
self.denoising_end is not None
|
413 |
+
and self.denoising_start is not None
|
414 |
+
and denoising_value_valid(self.denoising_end)
|
415 |
+
and denoising_value_valid(self.denoising_start)
|
416 |
+
and self.denoising_start >= self.denoising_end
|
417 |
+
):
|
418 |
+
raise ValueError(
|
419 |
+
f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
|
420 |
+
+ f" {self.denoising_end} when using type float."
|
421 |
+
)
|
422 |
+
elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
|
423 |
+
discrete_timestep_cutoff = int(
|
424 |
+
round(
|
425 |
+
self.scheduler.config.num_train_timesteps
|
426 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
427 |
+
)
|
428 |
+
)
|
429 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
430 |
+
timesteps = timesteps[:num_inference_steps]
|
431 |
+
|
432 |
+
# 7.2 Optionally get guidance scale embedding
|
433 |
+
timestep_cond = None
|
434 |
+
if self.unet.config.time_cond_proj_dim is not None: # TODO: Make guidance scale embedding work with batch_order
|
435 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
436 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
437 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
438 |
+
).to(device=device, dtype=latents.dtype)
|
439 |
+
|
440 |
+
# 7.3 Get batch order
|
441 |
+
batch_order = deepcopy(BATCH_ORDER)
|
442 |
+
if structure_image is not None: # If image is provided, not generating, so no CFG needed
|
443 |
+
batch_order.remove("structure_uncond")
|
444 |
+
if appearance_image is not None:
|
445 |
+
batch_order.remove("appearance_uncond")
|
446 |
+
|
447 |
+
structure_control_stop_i, appearance_control_stop_i = get_last_control_i(control_schedule, num_inference_steps)
|
448 |
+
if self_recurrence_schedule is None:
|
449 |
+
self_recurrence_schedule = [0] * num_inference_steps
|
450 |
+
|
451 |
+
self._num_timesteps = len(timesteps)
|
452 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
453 |
+
for i, t in enumerate(timesteps):
|
454 |
+
if self.interrupt:
|
455 |
+
continue
|
456 |
+
|
457 |
+
if i == structure_control_stop_i: # If not generating structure/appearance, drop after last control
|
458 |
+
if "structure_uncond" not in batch_order:
|
459 |
+
batch_order.remove("structure_cond")
|
460 |
+
if i == appearance_control_stop_i:
|
461 |
+
if "appearance_uncond" not in batch_order:
|
462 |
+
batch_order.remove("appearance_cond")
|
463 |
+
|
464 |
+
register_attr(self, t=t.item(), do_control=True, batch_order=batch_order)
|
465 |
+
|
466 |
+
# TODO: For now, assume we are doing classifier-free guidance, support no CF-guidance later
|
467 |
+
latent_model_input = self.scheduler.scale_model_input(latents, t)
|
468 |
+
structure_latent_model_input = self.scheduler.scale_model_input(structure_latents, t)
|
469 |
+
appearance_latent_model_input = self.scheduler.scale_model_input(appearance_latents, t)
|
470 |
+
|
471 |
+
all_latent_model_input = {
|
472 |
+
"structure_uncond": structure_latent_model_input[0:1],
|
473 |
+
"appearance_uncond": appearance_latent_model_input[0:1],
|
474 |
+
"uncond": latent_model_input[0:1],
|
475 |
+
"structure_cond": structure_latent_model_input[0:1],
|
476 |
+
"appearance_cond": appearance_latent_model_input[0:1],
|
477 |
+
"cond": latent_model_input[0:1],
|
478 |
+
}
|
479 |
+
all_prompt_embeds = {
|
480 |
+
"structure_uncond": structure_prompt_embeds[0:1],
|
481 |
+
"appearance_uncond": appearance_prompt_embeds[0:1],
|
482 |
+
"uncond": prompt_embeds[0:1],
|
483 |
+
"structure_cond": structure_prompt_embeds[1:2],
|
484 |
+
"appearance_cond": appearance_prompt_embeds[1:2],
|
485 |
+
"cond": prompt_embeds[1:2],
|
486 |
+
}
|
487 |
+
all_add_text_embeds = {
|
488 |
+
"structure_uncond": structure_add_text_embeds[0:1],
|
489 |
+
"appearance_uncond": appearance_add_text_embeds[0:1],
|
490 |
+
"uncond": add_text_embeds[0:1],
|
491 |
+
"structure_cond": structure_add_text_embeds[1:2],
|
492 |
+
"appearance_cond": appearance_add_text_embeds[1:2],
|
493 |
+
"cond": add_text_embeds[1:2],
|
494 |
+
}
|
495 |
+
all_time_ids = {
|
496 |
+
"structure_uncond": add_time_ids[0:1],
|
497 |
+
"appearance_uncond": add_time_ids[0:1],
|
498 |
+
"uncond": add_time_ids[0:1],
|
499 |
+
"structure_cond": add_time_ids[1:2],
|
500 |
+
"appearance_cond": add_time_ids[1:2],
|
501 |
+
"cond": add_time_ids[1:2],
|
502 |
+
}
|
503 |
+
|
504 |
+
concat_latent_model_input = batch_dict_to_tensor(all_latent_model_input, batch_order)
|
505 |
+
concat_prompt_embeds = batch_dict_to_tensor(all_prompt_embeds, batch_order)
|
506 |
+
concat_add_text_embeds = batch_dict_to_tensor(all_add_text_embeds, batch_order)
|
507 |
+
concat_add_time_ids = batch_dict_to_tensor(all_time_ids, batch_order)
|
508 |
+
|
509 |
+
# Predict the noise residual
|
510 |
+
added_cond_kwargs = {"text_embeds": concat_add_text_embeds, "time_ids": concat_add_time_ids}
|
511 |
+
|
512 |
+
concat_noise_pred = self.unet(
|
513 |
+
concat_latent_model_input,
|
514 |
+
t,
|
515 |
+
encoder_hidden_states = concat_prompt_embeds,
|
516 |
+
timestep_cond = timestep_cond,
|
517 |
+
cross_attention_kwargs = self.cross_attention_kwargs,
|
518 |
+
added_cond_kwargs = added_cond_kwargs,
|
519 |
+
).sample
|
520 |
+
all_noise_pred = batch_tensor_to_dict(concat_noise_pred, batch_order)
|
521 |
+
|
522 |
+
# Classifier-free guidance, TODO: Support no CFG
|
523 |
+
noise_pred = all_noise_pred["uncond"] +\
|
524 |
+
self.guidance_scale * (all_noise_pred["cond"] - all_noise_pred["uncond"])
|
525 |
+
|
526 |
+
structure_noise_pred = all_noise_pred["structure_cond"]\
|
527 |
+
if "structure_cond" in batch_order else noise_pred
|
528 |
+
if "structure_uncond" in all_noise_pred:
|
529 |
+
structure_noise_pred = all_noise_pred["structure_uncond"] +\
|
530 |
+
self.structure_guidance_scale * (structure_noise_pred - all_noise_pred["structure_uncond"])
|
531 |
+
|
532 |
+
appearance_noise_pred = all_noise_pred["appearance_cond"]\
|
533 |
+
if "appearance_cond" in batch_order else noise_pred
|
534 |
+
if "appearance_uncond" in all_noise_pred:
|
535 |
+
appearance_noise_pred = all_noise_pred["appearance_uncond"] +\
|
536 |
+
self.appearance_guidance_scale * (appearance_noise_pred - all_noise_pred["appearance_uncond"])
|
537 |
+
|
538 |
+
if self.guidance_rescale > 0.0:
|
539 |
+
noise_pred = rescale_noise_cfg(
|
540 |
+
noise_pred, all_noise_pred["cond"], guidance_rescale=self.guidance_rescale
|
541 |
+
)
|
542 |
+
if "structure_uncond" in all_noise_pred:
|
543 |
+
structure_noise_pred = rescale_noise_cfg(
|
544 |
+
structure_noise_pred, all_noise_pred["structure_cond"],
|
545 |
+
guidance_rescale=self.guidance_rescale
|
546 |
+
)
|
547 |
+
if "appearance_uncond" in all_noise_pred:
|
548 |
+
appearance_noise_pred = rescale_noise_cfg(
|
549 |
+
appearance_noise_pred, all_noise_pred["appearance_cond"],
|
550 |
+
guidance_rescale=self.guidance_rescale
|
551 |
+
)
|
552 |
+
|
553 |
+
# Compute the previous noisy sample x_t -> x_t-1
|
554 |
+
concat_noise_pred = torch.cat(
|
555 |
+
[structure_noise_pred, appearance_noise_pred, noise_pred], dim=0,
|
556 |
+
)
|
557 |
+
concat_latents = torch.cat(
|
558 |
+
[structure_latents, appearance_latents, latents], dim=0,
|
559 |
+
)
|
560 |
+
structure_latents, appearance_latents, latents = self.scheduler.step(
|
561 |
+
concat_noise_pred, t, concat_latents, **extra_step_kwargs,
|
562 |
+
).prev_sample.chunk(3)
|
563 |
+
|
564 |
+
if clean_structure_latents is not None:
|
565 |
+
structure_latents = noise_prev(self.scheduler, t, clean_structure_latents)
|
566 |
+
if clean_appearance_latents is not None:
|
567 |
+
appearance_latents = noise_prev(self.scheduler, t, clean_appearance_latents)
|
568 |
+
|
569 |
+
# Self-recurrence
|
570 |
+
for _ in range(self_recurrence_schedule[i]):
|
571 |
+
if hasattr(self.scheduler, "_step_index"): # For fancier schedulers
|
572 |
+
self.scheduler._step_index -= 1 # TODO: Does this actually work?
|
573 |
+
|
574 |
+
t_prev = 0 if i + 1 >= num_inference_steps else timesteps[i + 1]
|
575 |
+
latents = noise_t2t(self.scheduler, t_prev, t, latents)
|
576 |
+
latent_model_input = torch.cat([latents] * 2)
|
577 |
+
|
578 |
+
register_attr(self, t=t.item(), do_control=False, batch_order=["uncond", "cond"])
|
579 |
+
|
580 |
+
# Predict the noise residual
|
581 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
582 |
+
noise_pred_uncond, noise_pred_ = self.unet(
|
583 |
+
latent_model_input,
|
584 |
+
t,
|
585 |
+
encoder_hidden_states = prompt_embeds,
|
586 |
+
timestep_cond = timestep_cond,
|
587 |
+
cross_attention_kwargs = self.cross_attention_kwargs,
|
588 |
+
added_cond_kwargs = added_cond_kwargs,
|
589 |
+
).sample.chunk(2)
|
590 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_ - noise_pred_uncond)
|
591 |
+
|
592 |
+
if self.guidance_rescale > 0.0:
|
593 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_, guidance_rescale=self.guidance_rescale)
|
594 |
+
|
595 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
596 |
+
|
597 |
+
# Callbacks
|
598 |
+
if callback_on_step_end is not None:
|
599 |
+
callback_kwargs = {}
|
600 |
+
for k in callback_on_step_end_tensor_inputs:
|
601 |
+
callback_kwargs[k] = locals()[k]
|
602 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
603 |
+
|
604 |
+
latents = callback_outputs.pop("latents", latents)
|
605 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
606 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
607 |
+
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
608 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
609 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
610 |
+
)
|
611 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
612 |
+
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
|
613 |
+
|
614 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
615 |
+
progress_bar.update()
|
616 |
+
if callback is not None and i % callback_steps == 0:
|
617 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
618 |
+
callback(step_idx, t, latents)
|
619 |
+
|
620 |
+
# "Reconstruction"
|
621 |
+
if clean_structure_latents is not None:
|
622 |
+
structure_latents = clean_structure_latents
|
623 |
+
if clean_appearance_latents is not None:
|
624 |
+
appearance_latents = clean_appearance_latents
|
625 |
+
|
626 |
+
# For passing important information onto the refiner
|
627 |
+
self.refiner_args = {"latents": latents.detach(), "prompt": prompt, "negative_prompt": negative_prompt}
|
628 |
+
|
629 |
+
if not output_type == "latent":
|
630 |
+
# Make sure the VAE is in float32 mode, as it overflows in float16
|
631 |
+
if self.vae.config.force_upcast:
|
632 |
+
self.vae.to(torch.float32) # self.upcast_vae() is buggy
|
633 |
+
latents = latents.to(torch.float32)
|
634 |
+
structure_latents = structure_latents.to(torch.float32)
|
635 |
+
appearance_latents = appearance_latents.to(torch.float32)
|
636 |
+
|
637 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
638 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
639 |
+
if decode_structure:
|
640 |
+
structure = self.vae.decode(structure_latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
641 |
+
structure = self.image_processor.postprocess(structure, output_type=output_type)
|
642 |
+
else:
|
643 |
+
structure = structure_latents
|
644 |
+
if decode_appearance:
|
645 |
+
appearance = self.vae.decode(appearance_latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
646 |
+
appearance = self.image_processor.postprocess(appearance, output_type=output_type)
|
647 |
+
else:
|
648 |
+
appearance = appearance_latents
|
649 |
+
|
650 |
+
# Cast back to fp16 if needed
|
651 |
+
if self.vae.config.force_upcast:
|
652 |
+
self.vae.to(dtype=torch.float16)
|
653 |
+
|
654 |
+
else:
|
655 |
+
return CtrlXStableDiffusionXLPipelineOutput(
|
656 |
+
images=latents, structures=structure_latents, appearances=appearance_latents
|
657 |
+
)
|
658 |
+
|
659 |
+
# Offload all models
|
660 |
+
self.maybe_free_model_hooks()
|
661 |
+
|
662 |
+
if not return_dict:
|
663 |
+
return (image, structure, appearance)
|
664 |
+
|
665 |
+
return CtrlXStableDiffusionXLPipelineOutput(images=image, structures=structure, appearances=appearance)
|
ctrl_x/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .feature import *
|
2 |
+
from .media import *
|
3 |
+
from .utils import *
|
ctrl_x/utils/feature.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .utils import *
|
6 |
+
|
7 |
+
|
8 |
+
def get_schedule(timesteps, schedule):
|
9 |
+
end = round(len(timesteps) * schedule)
|
10 |
+
timesteps = timesteps[:end]
|
11 |
+
return timesteps
|
12 |
+
|
13 |
+
|
14 |
+
def get_elem(l, i, default=0.0):
|
15 |
+
if i >= len(l):
|
16 |
+
return default
|
17 |
+
return l[i]
|
18 |
+
|
19 |
+
|
20 |
+
def pad_list(l_1, l_2, pad=0.0):
|
21 |
+
max_len = max(len(l_1), len(l_2))
|
22 |
+
l_1 = l_1 + [pad] * (max_len - len(l_1))
|
23 |
+
l_2 = l_2 + [pad] * (max_len - len(l_2))
|
24 |
+
return l_1, l_2
|
25 |
+
|
26 |
+
|
27 |
+
def normalize(x, dim):
|
28 |
+
x_mean = x.mean(dim=dim, keepdim=True)
|
29 |
+
x_std = x.std(dim=dim, keepdim=True)
|
30 |
+
x_normalized = (x - x_mean) / x_std
|
31 |
+
return x_normalized
|
32 |
+
|
33 |
+
|
34 |
+
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
35 |
+
def appearance_mean_std(q_c_normed, k_s_normed, v_s): # c: content, s: style
|
36 |
+
q_c = q_c_normed # q_c and k_s must be projected from normalized features
|
37 |
+
k_s = k_s_normed
|
38 |
+
scale_factor = 1 / math.sqrt(q_c.shape[-1])
|
39 |
+
|
40 |
+
# My notation below is very jank: D = (H W) is number of tokens, and C is token dimension
|
41 |
+
# Horrible notation coming from how self-attention dimensions work in Stable Diffusion
|
42 |
+
A = q_c @ k_s.mT # (B H D C/H) (B H C/H D)^T -> (B H D D)
|
43 |
+
A = F.softmax(A * scale_factor, dim=-1) # Softmax on last D in (B H D D)
|
44 |
+
mean = A @ v_s # (B H D D) (B H D C/H) -> (B H D C/H)
|
45 |
+
std = (A @ v_s.square() - mean.square()).relu().sqrt()
|
46 |
+
|
47 |
+
return mean, std
|
48 |
+
|
49 |
+
|
50 |
+
def feature_injection(features, batch_order):
|
51 |
+
assert features.shape[0] % len(batch_order) == 0
|
52 |
+
features_dict = batch_tensor_to_dict(features, batch_order)
|
53 |
+
features_dict["cond"] = features_dict["structure_cond"]
|
54 |
+
features = batch_dict_to_tensor(features_dict, batch_order)
|
55 |
+
return features
|
56 |
+
|
57 |
+
|
58 |
+
def appearance_transfer(features, q_normed, k_normed, batch_order, v=None, reshape_fn=None):
|
59 |
+
assert features.shape[0] % len(batch_order) == 0
|
60 |
+
|
61 |
+
features_dict = batch_tensor_to_dict(features, batch_order)
|
62 |
+
q_normed_dict = batch_tensor_to_dict(q_normed, batch_order)
|
63 |
+
k_normed_dict = batch_tensor_to_dict(k_normed, batch_order)
|
64 |
+
v_dict = features_dict
|
65 |
+
if v is not None:
|
66 |
+
v_dict = batch_tensor_to_dict(v, batch_order)
|
67 |
+
|
68 |
+
mean_cond, std_cond = appearance_mean_std(
|
69 |
+
q_normed_dict["cond"], k_normed_dict["appearance_cond"], v_dict["appearance_cond"],
|
70 |
+
)
|
71 |
+
|
72 |
+
if reshape_fn is not None:
|
73 |
+
mean_cond = reshape_fn(mean_cond)
|
74 |
+
std_cond = reshape_fn(std_cond)
|
75 |
+
|
76 |
+
features_dict["cond"] = std_cond * normalize(features_dict["cond"], dim=-2) + mean_cond
|
77 |
+
|
78 |
+
features = batch_dict_to_tensor(features_dict, batch_order)
|
79 |
+
return features
|
ctrl_x/utils/media.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torchvision.transforms.functional as vF
|
4 |
+
import PIL
|
5 |
+
|
6 |
+
|
7 |
+
JPEG_QUALITY = 95
|
8 |
+
|
9 |
+
|
10 |
+
def preprocess(image, processor, **kwargs):
|
11 |
+
if isinstance(image, PIL.Image.Image):
|
12 |
+
pass
|
13 |
+
elif isinstance(image, np.ndarray):
|
14 |
+
image = PIL.Image.fromarray(image)
|
15 |
+
elif isinstance(image, torch.Tensor):
|
16 |
+
image = vF.to_pil_image(image)
|
17 |
+
else:
|
18 |
+
raise TypeError(f"Image must be of type PIL.Image, np.ndarray, or torch.Tensor, got {type(image)} instead.")
|
19 |
+
|
20 |
+
image = processor.preprocess(image, **kwargs)
|
21 |
+
return image
|
ctrl_x/utils/sdxl.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from types import MethodType
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
from diffusers.models.attention_processor import Attention
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from .feature import *
|
9 |
+
from .utils import *
|
10 |
+
|
11 |
+
|
12 |
+
def convolution_forward( # From <class 'diffusers.models.resnet.ResnetBlock2D'>, forward (diffusers==0.28.0)
|
13 |
+
self,
|
14 |
+
input_tensor: torch.Tensor,
|
15 |
+
temb: torch.Tensor,
|
16 |
+
*args,
|
17 |
+
**kwargs,
|
18 |
+
) -> torch.Tensor:
|
19 |
+
do_structure_control = self.do_control and self.t in self.structure_schedule
|
20 |
+
|
21 |
+
hidden_states = input_tensor
|
22 |
+
|
23 |
+
hidden_states = self.norm1(hidden_states)
|
24 |
+
hidden_states = self.nonlinearity(hidden_states)
|
25 |
+
|
26 |
+
if self.upsample is not None:
|
27 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
28 |
+
if hidden_states.shape[0] >= 64:
|
29 |
+
input_tensor = input_tensor.contiguous()
|
30 |
+
hidden_states = hidden_states.contiguous()
|
31 |
+
input_tensor = self.upsample(input_tensor)
|
32 |
+
hidden_states = self.upsample(hidden_states)
|
33 |
+
elif self.downsample is not None:
|
34 |
+
input_tensor = self.downsample(input_tensor)
|
35 |
+
hidden_states = self.downsample(hidden_states)
|
36 |
+
|
37 |
+
hidden_states = self.conv1(hidden_states)
|
38 |
+
|
39 |
+
if self.time_emb_proj is not None:
|
40 |
+
if not self.skip_time_act:
|
41 |
+
temb = self.nonlinearity(temb)
|
42 |
+
temb = self.time_emb_proj(temb)[:, :, None, None]
|
43 |
+
|
44 |
+
if self.time_embedding_norm == "default":
|
45 |
+
if temb is not None:
|
46 |
+
hidden_states = hidden_states + temb
|
47 |
+
hidden_states = self.norm2(hidden_states)
|
48 |
+
elif self.time_embedding_norm == "scale_shift":
|
49 |
+
if temb is None:
|
50 |
+
raise ValueError(
|
51 |
+
f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
|
52 |
+
)
|
53 |
+
time_scale, time_shift = torch.chunk(temb, 2, dim=1)
|
54 |
+
hidden_states = self.norm2(hidden_states)
|
55 |
+
hidden_states = hidden_states * (1 + time_scale) + time_shift
|
56 |
+
else:
|
57 |
+
hidden_states = self.norm2(hidden_states)
|
58 |
+
|
59 |
+
hidden_states = self.nonlinearity(hidden_states)
|
60 |
+
|
61 |
+
hidden_states = self.dropout(hidden_states)
|
62 |
+
hidden_states = self.conv2(hidden_states)
|
63 |
+
|
64 |
+
# Feature injection and AdaIN (hidden_states)
|
65 |
+
if do_structure_control and "hidden_states" in self.structure_target:
|
66 |
+
hidden_states = feature_injection(hidden_states, batch_order=self.batch_order)
|
67 |
+
|
68 |
+
if self.conv_shortcut is not None:
|
69 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
70 |
+
|
71 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
72 |
+
|
73 |
+
# Feature injection and AdaIN (output_tensor)
|
74 |
+
if do_structure_control and "output_tensor" in self.structure_target:
|
75 |
+
output_tensor = feature_injection(output_tensor, batch_order=self.batch_order)
|
76 |
+
|
77 |
+
return output_tensor
|
78 |
+
|
79 |
+
|
80 |
+
class AttnProcessor2_0: # From <class 'diffusers.models.attention_processor.AttnProcessor2_0'> (diffusers==0.28.0)
|
81 |
+
|
82 |
+
def __init__(self):
|
83 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
84 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
85 |
+
|
86 |
+
def __call__(
|
87 |
+
self,
|
88 |
+
attn: Attention,
|
89 |
+
hidden_states: torch.FloatTensor,
|
90 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
91 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
92 |
+
temb: Optional[torch.FloatTensor] = None,
|
93 |
+
*args,
|
94 |
+
**kwargs,
|
95 |
+
) -> torch.FloatTensor:
|
96 |
+
do_structure_control = attn.do_control and attn.t in attn.structure_schedule
|
97 |
+
do_appearance_control = attn.do_control and attn.t in attn.appearance_schedule
|
98 |
+
|
99 |
+
residual = hidden_states
|
100 |
+
if attn.spatial_norm is not None:
|
101 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
102 |
+
|
103 |
+
input_ndim = hidden_states.ndim
|
104 |
+
|
105 |
+
if input_ndim == 4:
|
106 |
+
batch_size, channel, height, width = hidden_states.shape
|
107 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
108 |
+
|
109 |
+
batch_size, sequence_length, _ = (
|
110 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
111 |
+
)
|
112 |
+
|
113 |
+
if attention_mask is not None:
|
114 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
115 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
116 |
+
# (batch, heads, source_length, target_length)
|
117 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
118 |
+
|
119 |
+
if attn.group_norm is not None:
|
120 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
121 |
+
|
122 |
+
no_encoder_hidden_states = encoder_hidden_states is None
|
123 |
+
if no_encoder_hidden_states:
|
124 |
+
encoder_hidden_states = hidden_states
|
125 |
+
elif attn.norm_cross:
|
126 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
127 |
+
|
128 |
+
if do_appearance_control: # Assume we only have this for self attention
|
129 |
+
hidden_states_normed = normalize(hidden_states, dim=-2) # B H D C
|
130 |
+
encoder_hidden_states_normed = normalize(encoder_hidden_states, dim=-2)
|
131 |
+
|
132 |
+
query_normed = attn.to_q(hidden_states_normed)
|
133 |
+
key_normed = attn.to_k(encoder_hidden_states_normed)
|
134 |
+
|
135 |
+
inner_dim = key_normed.shape[-1]
|
136 |
+
head_dim = inner_dim // attn.heads
|
137 |
+
query_normed = query_normed.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
138 |
+
key_normed = key_normed.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
139 |
+
|
140 |
+
# Match query and key injection with structure injection (if injection is happening this layer)
|
141 |
+
if do_structure_control:
|
142 |
+
if "query" in attn.structure_target:
|
143 |
+
query_normed = feature_injection(query_normed, batch_order=attn.batch_order)
|
144 |
+
if "key" in attn.structure_target:
|
145 |
+
key_normed = feature_injection(key_normed, batch_order=attn.batch_order)
|
146 |
+
|
147 |
+
# Appearance transfer (before)
|
148 |
+
if do_appearance_control and "before" in attn.appearance_target:
|
149 |
+
hidden_states = hidden_states.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
150 |
+
hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)
|
151 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
152 |
+
|
153 |
+
if no_encoder_hidden_states:
|
154 |
+
encoder_hidden_states = hidden_states
|
155 |
+
elif attn.norm_cross:
|
156 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
157 |
+
|
158 |
+
query = attn.to_q(hidden_states)
|
159 |
+
|
160 |
+
key = attn.to_k(encoder_hidden_states)
|
161 |
+
value = attn.to_v(encoder_hidden_states)
|
162 |
+
|
163 |
+
inner_dim = key.shape[-1]
|
164 |
+
head_dim = inner_dim // attn.heads
|
165 |
+
|
166 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
167 |
+
|
168 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
169 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
170 |
+
|
171 |
+
# Feature injection (query, key, and/or value)
|
172 |
+
if do_structure_control:
|
173 |
+
if "query" in attn.structure_target:
|
174 |
+
query = feature_injection(query, batch_order=attn.batch_order)
|
175 |
+
if "key" in attn.structure_target:
|
176 |
+
key = feature_injection(key, batch_order=attn.batch_order)
|
177 |
+
if "value" in attn.structure_target:
|
178 |
+
value = feature_injection(value, batch_order=attn.batch_order)
|
179 |
+
|
180 |
+
# Appearance transfer (value)
|
181 |
+
if do_appearance_control and "value" in attn.appearance_target:
|
182 |
+
value = appearance_transfer(value, query_normed, key_normed, batch_order=attn.batch_order)
|
183 |
+
|
184 |
+
# The output of sdp = (batch, num_heads, seq_len, head_dim)
|
185 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
186 |
+
hidden_states = F.scaled_dot_product_attention(
|
187 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
188 |
+
)
|
189 |
+
|
190 |
+
# Appearance transfer (after)
|
191 |
+
if do_appearance_control and "after" in attn.appearance_target:
|
192 |
+
hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)
|
193 |
+
|
194 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
195 |
+
hidden_states = hidden_states.to(query.dtype)
|
196 |
+
|
197 |
+
# Linear projection
|
198 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
199 |
+
# Dropout
|
200 |
+
hidden_states = attn.to_out[1](hidden_states)
|
201 |
+
|
202 |
+
if input_ndim == 4:
|
203 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
204 |
+
|
205 |
+
if attn.residual_connection:
|
206 |
+
hidden_states = hidden_states + residual
|
207 |
+
|
208 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
209 |
+
|
210 |
+
return hidden_states
|
211 |
+
|
212 |
+
|
213 |
+
def register_control(
|
214 |
+
model,
|
215 |
+
timesteps,
|
216 |
+
control_schedule, # structure_conv, structure_attn, appearance_attn
|
217 |
+
control_target = [["output_tensor"], ["query", "key"], ["before"]],
|
218 |
+
):
|
219 |
+
# Assume timesteps in reverse order (T -> 0)
|
220 |
+
for block_type in ["encoder", "decoder", "middle"]:
|
221 |
+
blocks = {
|
222 |
+
"encoder": model.unet.down_blocks,
|
223 |
+
"decoder": model.unet.up_blocks,
|
224 |
+
"middle": [model.unet.mid_block],
|
225 |
+
}[block_type]
|
226 |
+
|
227 |
+
control_schedule_block = control_schedule[block_type]
|
228 |
+
if block_type == "middle":
|
229 |
+
control_schedule_block = [control_schedule_block]
|
230 |
+
|
231 |
+
for layer in range(len(control_schedule_block)):
|
232 |
+
# Convolution
|
233 |
+
num_blocks = len(blocks[layer].resnets) if hasattr(blocks[layer], "resnets") else 0
|
234 |
+
for block in range(num_blocks):
|
235 |
+
convolution = blocks[layer].resnets[block]
|
236 |
+
convolution.structure_target = control_target[0]
|
237 |
+
convolution.structure_schedule = get_schedule(
|
238 |
+
timesteps, get_elem(control_schedule_block[layer][0], block)
|
239 |
+
)
|
240 |
+
convolution.forward = MethodType(convolution_forward, convolution)
|
241 |
+
|
242 |
+
# Self-attention
|
243 |
+
num_blocks = len(blocks[layer].attentions) if hasattr(blocks[layer], "attentions") else 0
|
244 |
+
for block in range(num_blocks):
|
245 |
+
for transformer_block in blocks[layer].attentions[block].transformer_blocks:
|
246 |
+
attention = transformer_block.attn1
|
247 |
+
attention.structure_target = control_target[1]
|
248 |
+
attention.structure_schedule = get_schedule(
|
249 |
+
timesteps, get_elem(control_schedule_block[layer][1], block)
|
250 |
+
)
|
251 |
+
attention.appearance_target = control_target[2]
|
252 |
+
attention.appearance_schedule = get_schedule(
|
253 |
+
timesteps, get_elem(control_schedule_block[layer][2], block)
|
254 |
+
)
|
255 |
+
attention.processor = AttnProcessor2_0()
|
256 |
+
|
257 |
+
|
258 |
+
def register_attr(model, t, do_control, batch_order):
|
259 |
+
for layer_type in ["encoder", "decoder", "middle"]:
|
260 |
+
blocks = {"encoder": model.unet.down_blocks, "decoder": model.unet.up_blocks,
|
261 |
+
"middle": [model.unet.mid_block]}[layer_type]
|
262 |
+
for layer in blocks:
|
263 |
+
# Convolution
|
264 |
+
for module in layer.resnets:
|
265 |
+
module.t = t
|
266 |
+
module.do_control = do_control
|
267 |
+
module.batch_order = batch_order
|
268 |
+
# Self-attention
|
269 |
+
if hasattr(layer, "attentions"):
|
270 |
+
for block in layer.attentions:
|
271 |
+
for module in block.transformer_blocks:
|
272 |
+
module.attn1.t = t
|
273 |
+
module.attn1.do_control = do_control
|
274 |
+
module.attn1.batch_order = batch_order
|
ctrl_x/utils/utils.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
JPEG_QUALITY = 95
|
5 |
+
|
6 |
+
|
7 |
+
def exists(x):
|
8 |
+
return x is not None
|
9 |
+
|
10 |
+
|
11 |
+
def get(x, default):
|
12 |
+
if exists(x):
|
13 |
+
return x
|
14 |
+
return default
|
15 |
+
|
16 |
+
|
17 |
+
def get_self_recurrence_schedule(schedule, num_inference_steps):
|
18 |
+
self_recurrence_schedule = [0] * num_inference_steps
|
19 |
+
for schedule_current in reversed(schedule):
|
20 |
+
if schedule_current is None or len(schedule_current) == 0:
|
21 |
+
continue
|
22 |
+
[start, end, repeat] = schedule_current
|
23 |
+
start_i = round(num_inference_steps * start)
|
24 |
+
end_i = round(num_inference_steps * end)
|
25 |
+
for i in range(start_i, end_i):
|
26 |
+
self_recurrence_schedule[i] = repeat
|
27 |
+
return self_recurrence_schedule
|
28 |
+
|
29 |
+
|
30 |
+
def batch_dict_to_tensor(batch_dict, batch_order):
|
31 |
+
batch_tensor = []
|
32 |
+
for batch_type in batch_order:
|
33 |
+
batch_tensor.append(batch_dict[batch_type])
|
34 |
+
batch_tensor = torch.cat(batch_tensor, dim=0)
|
35 |
+
return batch_tensor
|
36 |
+
|
37 |
+
|
38 |
+
def batch_tensor_to_dict(batch_tensor, batch_order):
|
39 |
+
batch_tensor_chunk = batch_tensor.chunk(len(batch_order))
|
40 |
+
batch_dict = {}
|
41 |
+
for i, batch_type in enumerate(batch_order):
|
42 |
+
batch_dict[batch_type] = batch_tensor_chunk[i]
|
43 |
+
return batch_dict
|
44 |
+
|
45 |
+
|
46 |
+
def noise_prev(scheduler, timestep, x_0, noise=None):
|
47 |
+
if scheduler.num_inference_steps is None:
|
48 |
+
raise ValueError(
|
49 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
50 |
+
)
|
51 |
+
|
52 |
+
if noise is None:
|
53 |
+
noise = torch.randn_like(x_0).to(x_0)
|
54 |
+
|
55 |
+
# From DDIMScheduler step function (hopefully this works)
|
56 |
+
timestep_i = (scheduler.timesteps == timestep).nonzero(as_tuple=True)[0][0].item()
|
57 |
+
if timestep_i + 1 >= scheduler.timesteps.shape[0]: # We are at t = 0 (ish)
|
58 |
+
return x_0
|
59 |
+
prev_timestep = scheduler.timesteps[timestep_i + 1:timestep_i + 2] # Make sure t is not 0-dim
|
60 |
+
|
61 |
+
x_t_prev = scheduler.add_noise(x_0, noise, prev_timestep)
|
62 |
+
return x_t_prev
|
63 |
+
|
64 |
+
|
65 |
+
def noise_t2t(scheduler, timestep, timestep_target, x_t, noise=None):
|
66 |
+
assert timestep_target >= timestep
|
67 |
+
if noise is None:
|
68 |
+
noise = torch.randn_like(x_t).to(x_t)
|
69 |
+
|
70 |
+
alphas_cumprod = scheduler.alphas_cumprod.to(device=x_t.device, dtype=x_t.dtype)
|
71 |
+
|
72 |
+
timestep = timestep.to(torch.long)
|
73 |
+
timestep_target = timestep_target.to(torch.long)
|
74 |
+
|
75 |
+
alpha_prod_t = alphas_cumprod[timestep]
|
76 |
+
alpha_prod_tt = alphas_cumprod[timestep_target]
|
77 |
+
alpha_prod = alpha_prod_tt / alpha_prod_t
|
78 |
+
|
79 |
+
sqrt_alpha_prod = (alpha_prod ** 0.5).flatten()
|
80 |
+
while len(sqrt_alpha_prod.shape) < len(x_t.shape):
|
81 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
82 |
+
|
83 |
+
sqrt_one_minus_alpha_prod = ((1 - alpha_prod) ** 0.5).flatten()
|
84 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(x_t.shape):
|
85 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
86 |
+
|
87 |
+
x_tt = sqrt_alpha_prod * x_t + sqrt_one_minus_alpha_prod * noise
|
88 |
+
return x_tt
|
docs/assets/bootstrap.min.css
ADDED
The diff for this file is too large to render.
See raw diff
|
|
docs/assets/cross_image_attention.jpg
ADDED
Git LFS Details
|
docs/assets/ctrl-x.jpg
ADDED
Git LFS Details
|
docs/assets/font.css
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Homepage Font */
|
2 |
+
|
3 |
+
/* latin-ext */
|
4 |
+
@font-face {
|
5 |
+
font-family: 'Lato';
|
6 |
+
font-style: normal;
|
7 |
+
font-weight: 400;
|
8 |
+
src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2');
|
9 |
+
unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;
|
10 |
+
}
|
11 |
+
|
12 |
+
/* latin */
|
13 |
+
@font-face {
|
14 |
+
font-family: 'Lato';
|
15 |
+
font-style: normal;
|
16 |
+
font-weight: 400;
|
17 |
+
src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2');
|
18 |
+
unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
|
19 |
+
}
|
20 |
+
|
21 |
+
/* latin-ext */
|
22 |
+
@font-face {
|
23 |
+
font-family: 'Lato';
|
24 |
+
font-style: normal;
|
25 |
+
font-weight: 700;
|
26 |
+
src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2');
|
27 |
+
unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF;
|
28 |
+
}
|
29 |
+
|
30 |
+
/* latin */
|
31 |
+
@font-face {
|
32 |
+
font-family: 'Lato';
|
33 |
+
font-style: normal;
|
34 |
+
font-weight: 700;
|
35 |
+
src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2');
|
36 |
+
unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
|
37 |
+
}
|
docs/assets/freecontrol.jpg
ADDED
Git LFS Details
|
docs/assets/genforce.png
ADDED
docs/assets/pipeline.jpg
ADDED
Git LFS Details
|
docs/assets/results_animatediff.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:43e29629924da2f368048016b2bb4ee973d0d38dc6f868098b0d9fbd6ac2e8ea
|
3 |
+
size 20573323
|
docs/assets/results_multi_subject.jpg
ADDED
Git LFS Details
|
docs/assets/results_struct+app.jpg
ADDED
Git LFS Details
|
docs/assets/results_struct+app_2.jpg
ADDED
Git LFS Details
|
docs/assets/results_struct+prompt.jpg
ADDED
Git LFS Details
|
docs/assets/style.css
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* Body */
|
2 |
+
body {
|
3 |
+
background: #e3e5e8;
|
4 |
+
color: #ffffff;
|
5 |
+
font-family: 'Lato', Verdana, Helvetica, sans-serif;
|
6 |
+
font-weight: 300;
|
7 |
+
font-size: 14pt;
|
8 |
+
}
|
9 |
+
|
10 |
+
/* Hyperlinks */
|
11 |
+
a {text-decoration: none;}
|
12 |
+
a:link {color: #1772d0;}
|
13 |
+
a:visited {color: #1772d0;}
|
14 |
+
a:active {color: red;}
|
15 |
+
a:hover {color: #f09228;}
|
16 |
+
|
17 |
+
/* Pre-formatted Text */
|
18 |
+
pre {
|
19 |
+
margin: 5pt 0;
|
20 |
+
border: 0;
|
21 |
+
font-size: 12pt;
|
22 |
+
background: #fcfcfc;
|
23 |
+
}
|
24 |
+
|
25 |
+
/* Project Page Style */
|
26 |
+
/* Section */
|
27 |
+
.section {
|
28 |
+
width: 768pt;
|
29 |
+
min-height: 100pt;
|
30 |
+
margin: 15pt auto;
|
31 |
+
padding: 20pt 30pt;
|
32 |
+
border: 1pt hidden #000;
|
33 |
+
text-align: justify;
|
34 |
+
color: #000000;
|
35 |
+
background: #ffffff;
|
36 |
+
}
|
37 |
+
|
38 |
+
/* Header (Title and Logo) */
|
39 |
+
.section .header {
|
40 |
+
min-height: 80pt;
|
41 |
+
margin-top: 30pt;
|
42 |
+
}
|
43 |
+
.section .header .logo {
|
44 |
+
width: 80pt;
|
45 |
+
margin-left: 10pt;
|
46 |
+
float: left;
|
47 |
+
}
|
48 |
+
.section .header .logo img {
|
49 |
+
width: 80pt;
|
50 |
+
object-fit: cover;
|
51 |
+
}
|
52 |
+
.section .header .title {
|
53 |
+
margin: 0 120pt;
|
54 |
+
text-align: center;
|
55 |
+
font-size: 22pt;
|
56 |
+
}
|
57 |
+
|
58 |
+
/* Author */
|
59 |
+
.section .author {
|
60 |
+
margin: 5pt 0;
|
61 |
+
text-align: center;
|
62 |
+
font-size: 16pt;
|
63 |
+
}
|
64 |
+
|
65 |
+
/* Institution */
|
66 |
+
.section .institution {
|
67 |
+
margin: 5pt 0;
|
68 |
+
text-align: center;
|
69 |
+
font-size: 16pt;
|
70 |
+
}
|
71 |
+
|
72 |
+
/* Note */
|
73 |
+
.section .note {
|
74 |
+
margin: 5pt 0;
|
75 |
+
text-align: center;
|
76 |
+
font-size: 12pt;
|
77 |
+
}
|
78 |
+
|
79 |
+
/* Hyperlink (such as Paper and Code) */
|
80 |
+
.section .link {
|
81 |
+
margin: 5pt 0;
|
82 |
+
text-align: center;
|
83 |
+
font-size: 16pt;
|
84 |
+
}
|
85 |
+
|
86 |
+
/* Teaser */
|
87 |
+
.section .teaser {
|
88 |
+
margin: 20pt 0;
|
89 |
+
text-align: center;
|
90 |
+
}
|
91 |
+
|
92 |
+
/* Section Title */
|
93 |
+
.section .title {
|
94 |
+
text-align: center;
|
95 |
+
font-size: 22pt;
|
96 |
+
margin: 5pt 0 15pt 0; /* top right bottom left */
|
97 |
+
}
|
98 |
+
|
99 |
+
/* Section Body */
|
100 |
+
.section .body {
|
101 |
+
margin-bottom: 15pt;
|
102 |
+
text-align: justify;
|
103 |
+
font-size: 14pt;
|
104 |
+
}
|
105 |
+
|
106 |
+
/* BibTeX */
|
107 |
+
.section .bibtex {
|
108 |
+
margin: 5pt 0;
|
109 |
+
text-align: left;
|
110 |
+
font-size: 22pt;
|
111 |
+
}
|
112 |
+
|
113 |
+
/* Related Work */
|
114 |
+
.section .ref {
|
115 |
+
margin: 20pt 0 10pt 0; /* top right bottom left */
|
116 |
+
text-align: left;
|
117 |
+
font-size: 18pt;
|
118 |
+
font-weight: bold;
|
119 |
+
}
|
120 |
+
|
121 |
+
/* Citation */
|
122 |
+
.section .citation {
|
123 |
+
min-height: 60pt;
|
124 |
+
margin: 10pt 0;
|
125 |
+
}
|
126 |
+
.section .citation .image {
|
127 |
+
width: 120pt;
|
128 |
+
float: left;
|
129 |
+
}
|
130 |
+
.section .citation .image img {
|
131 |
+
max-height: 60pt;
|
132 |
+
width: 120pt;
|
133 |
+
object-fit: cover;
|
134 |
+
}
|
135 |
+
.section .citation .comment{
|
136 |
+
margin-left: 130pt;
|
137 |
+
text-align: left;
|
138 |
+
font-size: 14pt;
|
139 |
+
}
|
docs/assets/teaser_github.jpg
ADDED
Git LFS Details
|
docs/assets/teaser_small.jpg
ADDED
Git LFS Details
|
docs/index.html
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!doctype html>
|
2 |
+
<html lang="en">
|
3 |
+
|
4 |
+
|
5 |
+
<!-- === Header Starts === -->
|
6 |
+
<head>
|
7 |
+
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
|
8 |
+
|
9 |
+
<title>Ctrl-X</title>
|
10 |
+
|
11 |
+
<link href="./assets/bootstrap.min.css" rel="stylesheet">
|
12 |
+
<link href="./assets/font.css" rel="stylesheet" type="text/css">
|
13 |
+
<link href="./assets/style.css" rel="stylesheet" type="text/css">
|
14 |
+
</head>
|
15 |
+
<!-- === Header Ends === -->
|
16 |
+
|
17 |
+
|
18 |
+
<body>
|
19 |
+
|
20 |
+
|
21 |
+
<!-- === Home Section Starts === -->
|
22 |
+
<div class="section">
|
23 |
+
<!-- === Title Starts === -->
|
24 |
+
<div class="header">
|
25 |
+
<div class="logo">
|
26 |
+
<a href="https://genforce.github.io/" target="_blank"><img src="./assets/genforce.png"></a>
|
27 |
+
</div>
|
28 |
+
<div class="title", style="padding-top: 25pt;"> <!-- Set padding as 10 if title is with two lines. -->
|
29 |
+
Ctrl-X: Controlling Structure and Appearance for Text-To-Image Generation Without Guidance
|
30 |
+
</div>
|
31 |
+
</div>
|
32 |
+
<!-- === Title Ends === -->
|
33 |
+
<div class="author">
|
34 |
+
<a href="https://kuanhenglin.github.io" target="_blank">Kuan Heng Lin</a><sup>1</sup>*
|
35 |
+
<a href="https://sichengmo.github.io/" target="_blank">Sicheng Mo</a><sup>1</sup>*
|
36 |
+
<a href="https://bklingher.github.io" target="_blank">Ben Klingher</a><sup>1</sup>
|
37 |
+
<a href="https://pages.cs.wisc.edu/~fmu/" target="_blank">Fangzhou Mu</a><sup>2</sup>
|
38 |
+
<a href="https://boleizhou.github.io/" target="_blank">Bolei Zhou</a><sup>1</sup>
|
39 |
+
</div>
|
40 |
+
<div class="institution">
|
41 |
+
<sup>1</sup>UCLA
|
42 |
+
<sup>2</sup>NVIDIA
|
43 |
+
</div>
|
44 |
+
<div class="note">
|
45 |
+
*Equal contribution
|
46 |
+
</div>
|
47 |
+
<div class="title" style="font-size: 18pt;margin: 15pt 0 15pt 0">
|
48 |
+
NeurIPS 2024
|
49 |
+
</div>
|
50 |
+
<div class="link">
|
51 |
+
[<a href="https://arxiv.org/abs/2406.07540" target="_blank">Paper</a>]
|
52 |
+
[<a href="https://github.com/genforce/ctrl-x" target="_blank">Code</a>]
|
53 |
+
</div>
|
54 |
+
<div class="teaser">
|
55 |
+
<img src="assets/ctrl-x.jpg" width="85%">
|
56 |
+
</div>
|
57 |
+
</div>
|
58 |
+
<!-- === Home Section Ends === -->
|
59 |
+
|
60 |
+
|
61 |
+
<!-- === Overview Section Starts === -->
|
62 |
+
<div class="section">
|
63 |
+
<div class="title">Overview</div>
|
64 |
+
<div class="body">
|
65 |
+
We present <b>Ctrl-X</b>, a simple <i>training-free</i> and <i>guidance-free</i> framework for text-to-image (T2I) generation with structure and appearance control. Given user-provided structure and appearance images, Ctrl-X designs feedforward structure control to enable structure alignment with the structure image and semantic-aware appearance transfer to facilitate the appearance transfer from the appearance image. Ctrl-X supports novel structure control with arbitrary condition images of any modality, is significantly faster than prior training-free appearance transfer methods, and provides instant plug-and-play to any T2I and text-to-video (T2V) diffusion model.
|
66 |
+
<table width="100%" style="margin: 20pt 0; text-align: center;">
|
67 |
+
<tr>
|
68 |
+
<td><img src="assets/pipeline.jpg" width="85%"></td>
|
69 |
+
</tr>
|
70 |
+
</table>
|
71 |
+
|
72 |
+
<b>How does it work?</b> Given clean structure and appearance latents, we first obtain noised structure and appearance latents via the diffusion forward process, then extracting their U-Net features from a pretrained T2I diffusion model. When denoising the output latent, we inject convolution and self-attention features from the structure latent and leverage self-attention correspondence to transfer spatially-aware appearance statistics from the appearance latent to achieve structure and appearance control. We name our method "Ctrl-X" because we reformulate the controllable generation problem by 'cutting' (and 'pasting') structure preservation and semantic-aware stylization together.
|
73 |
+
</div>
|
74 |
+
</div>
|
75 |
+
<!-- === Overview Section Ends === -->
|
76 |
+
|
77 |
+
|
78 |
+
<!-- === Result Section Starts === -->
|
79 |
+
<div class="section">
|
80 |
+
<div class="title">Results: Structure and appearance control</div>
|
81 |
+
<div class="body">
|
82 |
+
Results of training-free and guidance-free T2I diffusion with structure and appearance control, where Ctrl-X supports a diverse variety of structure images, including natural images, ControlNet-supported conditions (e.g., canny maps, normal maps), and in-the-wild conditions (e.g., wireframes, 3D meshes). The base model here is <a href="https://arxiv.org/abs/2307.01952" target="_blank">Stable Diffusion XL v1.0</a>.
|
83 |
+
|
84 |
+
<!-- Adjust the number of rows and columns (EVERY project differs). -->
|
85 |
+
<table width="100%" style="margin: 20pt 0; text-align: center;">
|
86 |
+
<tr>
|
87 |
+
<td><img src="assets/results_struct+app.jpg" width="100%"></td>
|
88 |
+
</tr>
|
89 |
+
</table>
|
90 |
+
<table width="100%" style="margin: 20pt 0; text-align: center;">
|
91 |
+
<tr>
|
92 |
+
<td><img src="assets/results_struct+app_2.jpg" width="85%"></td>
|
93 |
+
</tr>
|
94 |
+
</table>
|
95 |
+
</div>
|
96 |
+
</div>
|
97 |
+
|
98 |
+
<div class="section">
|
99 |
+
<div class="title">Results: Multi-subject structure and appearance control</div>
|
100 |
+
<div class="body">
|
101 |
+
Ctrl-X is capable of multi-subject generation with semantic correspondence between appearance and structure images across both subjects and backgrounds. In comparison, <a href="https://arxiv.org/abs/2302.05543" target="_blank">ControlNet</a> + <a href="https://arxiv.org/abs/2308.06721" target="_blank">IP-Adapter</a> often fails at transferring all subject and background appearances.
|
102 |
+
|
103 |
+
<!-- Adjust the number of rows and columns (EVERY project differs). -->
|
104 |
+
<table width="100%" style="margin: 20pt 0; text-align: center;">
|
105 |
+
<tr>
|
106 |
+
<td><img src="assets/results_multi_subject.jpg" width="90%"></td>
|
107 |
+
</tr>
|
108 |
+
</table>
|
109 |
+
</div>
|
110 |
+
</div>
|
111 |
+
|
112 |
+
<div class="section">
|
113 |
+
<div class="title">Results: Prompt-driven conditional generation</div>
|
114 |
+
<div class="body">
|
115 |
+
Ctrl-X also supports prompt-driven conditional generation, where it generates an output image complying with the given text prompt while aligning with the structure of the structure image. Ctrl-X continues to support any structure image/condition type here as well. The base model here is <a href="https://arxiv.org/abs/2307.01952" target="_blank">Stable Diffusion XL v1.0</a>.
|
116 |
+
|
117 |
+
<!-- Adjust the number of rows and columns (EVERY project differs). -->
|
118 |
+
<table width="100%" style="margin: 20pt 0; text-align: center;">
|
119 |
+
<tr>
|
120 |
+
<td><img src="assets/results_struct+prompt.jpg" width="100%"></td>
|
121 |
+
</tr>
|
122 |
+
</table>
|
123 |
+
</div>
|
124 |
+
</div>
|
125 |
+
|
126 |
+
<div class="section">
|
127 |
+
<div class="title">Results: Extension to video generation</div>
|
128 |
+
<div class="body">
|
129 |
+
We can directly apply Ctrl-X to text-to-video (T2V) models. We show results of <a href="https://animatediff.github.io/" target="_blank">AnimateDiff v1.5.3</a> (with base model <a href="https://huggingface.co/SG161222/Realistic_Vision_V5.1_noVAE" target="_blank">Realistic Vision v5.1</a>) here.
|
130 |
+
|
131 |
+
<!-- Demo video here. Adjust the frame size based on the demo (EVERY project differs). -->
|
132 |
+
<div style="position: relative; padding-top: 50%; margin: 20pt 0; text-align: center;">
|
133 |
+
<iframe src="assets/results_animatediff.mp4" frameborder=0
|
134 |
+
style="position: absolute; top: 2.5%; left: 0%; width: 100%; height: 100%;"
|
135 |
+
allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture"
|
136 |
+
allowfullscreen></iframe>
|
137 |
+
</div>
|
138 |
+
</div>
|
139 |
+
</div>
|
140 |
+
|
141 |
+
<!-- === Result Section Ends === -->
|
142 |
+
|
143 |
+
|
144 |
+
<!-- === Reference Section Starts === -->
|
145 |
+
<div class="section">
|
146 |
+
<div class="bibtex">BibTeX</div>
|
147 |
+
<pre>
|
148 |
+
@inproceedings{lin2024ctrlx,
|
149 |
+
author = {Lin, {Kuan Heng} and Mo, Sicheng and Klingher, Ben and Mu, Fangzhou and Zhou, Bolei},
|
150 |
+
booktitle = {Advances in Neural Information Processing Systems},
|
151 |
+
title = {Ctrl-X: Controlling Structure and Appearance for Text-To-Image Generation Without Guidance},
|
152 |
+
year = {2024}
|
153 |
+
}
|
154 |
+
</pre>
|
155 |
+
|
156 |
+
<!-- BZ: we should give other related work enough credits, -->
|
157 |
+
<!-- so please include some most relevant work and leave some comment to summarize work and the difference. -->
|
158 |
+
<div class="ref">Related Work</div>
|
159 |
+
<div class="citation">
|
160 |
+
<div class="image"><img src="assets/freecontrol.jpg"></div>
|
161 |
+
<div class="comment">
|
162 |
+
<a href="https://genforce.github.io/freecontrol/" target="_blank">
|
163 |
+
Sicheng Mo, Fangzhou Mu, Kuan Heng Lin, Yanli Liu, Bochen Guan, Yin Li, Bolei Zhou.
|
164 |
+
FreeControl: Training-Free Spatial Control of Any Text-to-Image Diffusion Model with Any Condition.
|
165 |
+
CVPR 2024.</a><br>
|
166 |
+
<b>Comment:</b>
|
167 |
+
Training-free conditional generation by guidance in diffusion U-Net subspaces for structure control and appearance regularization.
|
168 |
+
</div>
|
169 |
+
</div>
|
170 |
+
<div class="citation">
|
171 |
+
<div class="image"><img src="assets/cross_image_attention.jpg"></div>
|
172 |
+
<div class="comment">
|
173 |
+
<a href="https://garibida.github.io/cross-image-attention/" target="_blank">
|
174 |
+
Yuval Alaluf, Daniel Garibi, Or Patashnik, Hadar Averbuch-Elor, Daniel Cohen-Or.
|
175 |
+
Cross-Image Attention for Zero-Shot Appearance Transfer.
|
176 |
+
SIGGRAPH 2024.</a><br>
|
177 |
+
<b>Comment:</b>
|
178 |
+
Guidance-free appearance transfer to natural images with self-attention key + value swaps via cross-image correspondence.
|
179 |
+
</div>
|
180 |
+
</div>
|
181 |
+
</div>
|
182 |
+
<!-- === Reference Section Ends === -->
|
183 |
+
|
184 |
+
|
185 |
+
</body>
|
186 |
+
</html>
|
environment.yaml
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: ctrlx
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
dependencies:
|
5 |
+
- _libgcc_mutex=0.1=main
|
6 |
+
- _openmp_mutex=5.1=1_gnu
|
7 |
+
- bzip2=1.0.8=h5eee18b_6
|
8 |
+
- ca-certificates=2024.3.11=h06a4308_0
|
9 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
10 |
+
- libffi=3.4.4=h6a678d5_1
|
11 |
+
- libgcc-ng=11.2.0=h1234567_1
|
12 |
+
- libgomp=11.2.0=h1234567_1
|
13 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
14 |
+
- libuuid=1.41.5=h5eee18b_0
|
15 |
+
- ncurses=6.4=h6a678d5_0
|
16 |
+
- openssl=3.0.13=h7f8727e_2
|
17 |
+
- pip=24.0=py310h06a4308_0
|
18 |
+
- python=3.10.14=h955ad1f_1
|
19 |
+
- readline=8.2=h5eee18b_0
|
20 |
+
- setuptools=69.5.1=py310h06a4308_0
|
21 |
+
- sqlite=3.45.3=h5eee18b_0
|
22 |
+
- tk=8.6.14=h39e8969_0
|
23 |
+
- wheel=0.43.0=py310h06a4308_0
|
24 |
+
- xz=5.4.6=h5eee18b_1
|
25 |
+
- zlib=1.2.13=h5eee18b_1
|
26 |
+
- pip:
|
27 |
+
- aiofiles==23.2.1
|
28 |
+
- altair==5.3.0
|
29 |
+
- annotated-types==0.7.0
|
30 |
+
- anyio==4.4.0
|
31 |
+
- attrs==23.2.0
|
32 |
+
- certifi==2024.2.2
|
33 |
+
- charset-normalizer==3.3.2
|
34 |
+
- click==8.1.7
|
35 |
+
- contourpy==1.2.1
|
36 |
+
- cycler==0.12.1
|
37 |
+
- diffusers==0.28.0
|
38 |
+
- dnspython==2.6.1
|
39 |
+
- einops==0.8.0
|
40 |
+
- email-validator==2.1.1
|
41 |
+
- exceptiongroup==1.2.1
|
42 |
+
- fastapi==0.111.0
|
43 |
+
- fastapi-cli==0.0.4
|
44 |
+
- ffmpy==0.3.2
|
45 |
+
- filelock==3.14.0
|
46 |
+
- fonttools==4.52.4
|
47 |
+
- fsspec==2024.5.0
|
48 |
+
- gradio==4.31.5
|
49 |
+
- gradio-client==0.16.4
|
50 |
+
- h11==0.14.0
|
51 |
+
- httpcore==1.0.5
|
52 |
+
- httptools==0.6.1
|
53 |
+
- httpx==0.27.0
|
54 |
+
- huggingface-hub==0.23.2
|
55 |
+
- idna==3.7
|
56 |
+
- importlib-metadata==7.1.0
|
57 |
+
- importlib-resources==6.4.0
|
58 |
+
- jinja2==3.1.4
|
59 |
+
- jsonschema==4.22.0
|
60 |
+
- jsonschema-specifications==2023.12.1
|
61 |
+
- kiwisolver==1.4.5
|
62 |
+
- markdown-it-py==3.0.0
|
63 |
+
- markupsafe==2.1.5
|
64 |
+
- matplotlib==3.9.0
|
65 |
+
- mdurl==0.1.2
|
66 |
+
- mpmath==1.3.0
|
67 |
+
- networkx==3.3
|
68 |
+
- numpy==1.26.4
|
69 |
+
- nvidia-cublas-cu12==12.1.3.1
|
70 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
71 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
72 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
73 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
74 |
+
- nvidia-cufft-cu12==11.0.2.54
|
75 |
+
- nvidia-curand-cu12==10.3.2.106
|
76 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
77 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
78 |
+
- nvidia-nccl-cu12==2.20.5
|
79 |
+
- nvidia-nvjitlink-cu12==12.5.40
|
80 |
+
- nvidia-nvtx-cu12==12.1.105
|
81 |
+
- orjson==3.10.3
|
82 |
+
- packaging==24.0
|
83 |
+
- pandas==2.2.2
|
84 |
+
- pillow==10.3.0
|
85 |
+
- pydantic==2.7.2
|
86 |
+
- pydantic-core==2.18.3
|
87 |
+
- pydub==0.25.1
|
88 |
+
- pygments==2.18.0
|
89 |
+
- pyparsing==3.1.2
|
90 |
+
- python-dateutil==2.9.0.post0
|
91 |
+
- python-dotenv==1.0.1
|
92 |
+
- python-multipart==0.0.9
|
93 |
+
- pytz==2024.1
|
94 |
+
- pyyaml==6.0.1
|
95 |
+
- referencing==0.35.1
|
96 |
+
- regex==2024.5.15
|
97 |
+
- requests==2.32.2
|
98 |
+
- rich==13.7.1
|
99 |
+
- rpds-py==0.18.1
|
100 |
+
- ruff==0.4.6
|
101 |
+
- safetensors==0.4.3
|
102 |
+
- semantic-version==2.10.0
|
103 |
+
- shellingham==1.5.4
|
104 |
+
- six==1.16.0
|
105 |
+
- sniffio==1.3.1
|
106 |
+
- starlette==0.37.2
|
107 |
+
- sympy==1.12
|
108 |
+
- tokenizers==0.19.1
|
109 |
+
- tomlkit==0.12.0
|
110 |
+
- toolz==0.12.1
|
111 |
+
- torch==2.3.0
|
112 |
+
- torchvision==0.18.0
|
113 |
+
- tqdm==4.66.4
|
114 |
+
- transformers==4.41.1
|
115 |
+
- triton==2.3.0
|
116 |
+
- typer==0.12.3
|
117 |
+
- typing-extensions==4.12.0
|
118 |
+
- tzdata==2024.1
|
119 |
+
- ujson==5.10.0
|
120 |
+
- urllib3==2.2.1
|
121 |
+
- uvicorn==0.30.0
|
122 |
+
- uvloop==0.19.0
|
123 |
+
- watchfiles==0.22.0
|
124 |
+
- websockets==11.0.3
|
125 |
+
- zipp==3.19.0
|