Spaces: Running on Zero

Upload folder using huggingface_hub
Browse files

- .gitattributes +9 -0
- .gitignore +168 -0
- LICENSE +407 -0
- PROMPT_GUIDE.md +91 -0
- README.md +169 -14
- assets/cond_and_image.jpg +3 -0
- assets/examples/id_customization/chenhao/image_0.png +0 -0
- assets/examples/id_customization/chenhao/image_1.png +0 -0
- assets/examples/id_customization/chenhao/image_2.png +0 -0
- assets/onediffusion_appendix_faceid.jpg +3 -0
- assets/onediffusion_appendix_faceid_3.jpg +3 -0
- assets/onediffusion_appendix_multiview.jpg +3 -0
- assets/onediffusion_appendix_multiview_2.jpg +0 -0
- assets/onediffusion_appendix_text2multiview.pdf +3 -0
- assets/onediffusion_editing.jpg +0 -0
- assets/onediffusion_zeroshot.jpg +3 -0
- assets/promptguide_complex.jpg +3 -0
- assets/promptguide_idtask.jpg +0 -0
- assets/subject_driven.jpg +0 -0
- assets/teaser.png +3 -0
- assets/text2image.jpg +0 -0
- assets/text2multiview.jpg +3 -0
- docker/Dockerfile +119 -0
- gradio_demo.py +715 -0
- inference.py +37 -0
- onediffusion/dataset/multitask/multiview.py +277 -0
- onediffusion/dataset/raydiff_utils.py +739 -0
- onediffusion/dataset/transforms.py +133 -0
- onediffusion/dataset/utils.py +175 -0
- onediffusion/diffusion/pipelines/image_processor.py +674 -0
- onediffusion/diffusion/pipelines/onediffusion.py +1080 -0
- onediffusion/models/denoiser/__init__.py +3 -0
- onediffusion/models/denoiser/nextdit/__init__.py +1 -0
- onediffusion/models/denoiser/nextdit/layers.py +132 -0
- onediffusion/models/denoiser/nextdit/modeling_nextdit.py +571 -0
- requirements.txt +27 -6
.gitattributes
CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/cond_and_image.jpg filter=lfs diff=lfs merge=lfs -text
+assets/onediffusion_appendix_faceid.jpg filter=lfs diff=lfs merge=lfs -text
+assets/onediffusion_appendix_faceid_3.jpg filter=lfs diff=lfs merge=lfs -text
+assets/onediffusion_appendix_multiview.jpg filter=lfs diff=lfs merge=lfs -text
+assets/onediffusion_appendix_text2multiview.pdf filter=lfs diff=lfs merge=lfs -text
+assets/onediffusion_zeroshot.jpg filter=lfs diff=lfs merge=lfs -text
+assets/promptguide_complex.jpg filter=lfs diff=lfs merge=lfs -text
+assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+assets/text2multiview.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,168 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
LICENSE
ADDED
@@ -0,0 +1,407 @@
Attribution-NonCommercial 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.

Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. More considerations for licensors: wiki.creativecommons.org/Considerations_for_licensors

Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason--for example, because of any applicable exception or limitation to copyright--then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public: wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.

  d. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.

  g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights under this Public License.

  i. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.

  j. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.

  k. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.

  l. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.

Section 2 -- Scope.

  a. License grant.

     1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:

        a. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and

        b. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.

     2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.

     3. Term. The term of this Public License is specified in Section 6(a).

     4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.

     5. Downstream recipients.

        a. Offer from the Licensor -- Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.

        b. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.

     6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).

  b. Other rights.

     1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.

     2. Patent and trademark rights are not licensed under this Public License.

     3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.

Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the following conditions.

  a. Attribution.

     1. If You Share the Licensed Material (including in modified form), You must:

        a. retain the following if it is supplied by the Licensor with the Licensed Material:

           i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);

           ii. a copyright notice;

           iii. a notice that refers to this Public License;

           iv. a notice that refers to the disclaimer of warranties;

           v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;

        b. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and

        c. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.

     2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.

     3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.

     4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.

Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;

  b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and

  c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.

Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.

Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:

     1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or

     2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.

Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.

Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.

  c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.

  d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.

Creative Commons may be contacted at creativecommons.org.
PROMPT_GUIDE.md
ADDED
@@ -0,0 +1,91 @@
# Prompt Guide

All examples are generated with a CFG of $4.2$ and $50$ steps, and are not cherry-picked unless otherwise stated. The negative prompt is set to:
```
monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation
```
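For reference, these settings correspond to the following pipeline call. This is a minimal sketch that mirrors the quick-start example in the README; only arguments shown there are used, and the prompt is a placeholder.

```python
# Minimal sketch of the guide's default settings, mirroring the README quick start.
import torch
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline

NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, ..."  # the full negative prompt listed above

pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(
    device=torch.device("cuda:0"), dtype=torch.bfloat16
)
output = pipeline(
    prompt="[[text2image]] A detailed, multi-sentence description of the scene goes here",
    negative_prompt=NEGATIVE_PROMPT,
    guidance_scale=4.2,      # CFG used throughout this guide
    num_inference_steps=50,  # steps used throughout this guide
    height=1024,
    width=1024,
)
output.images[0].save("example.jpg")
```
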
## 1. Text-to-Image

### 1.1 Long and detailed prompts give (much) better results

Since our training data consisted of long and detailed prompts, the model is more likely to generate better images when prompts are detailed.

The model shows good text adherence with long and complex prompts, as in the images below. We use the first $20$ prompts from [simoryu's examples](https://cloneofsimo.github.io/compare_aura_sd3/). For the full prompts and results from other models, refer to that link.

<p align="center">
  <img src="assets/promptguide_complex.jpg" alt="Text-to-Image results" width="800">
</p>

### 1.2 Resolution

For text-to-image, the model generally works well with height and width in the range $[768; 1280]$ (both must be divisible by 16). For other tasks, it performs best at resolutions around $512$.

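If you compute the target size programmatically, a small helper like the one below keeps both dimensions valid. This is an illustrative sketch, not part of the repo; the divisible-by-16 constraint is the one stated above.

```python
from typing import Tuple

def snap_to_multiple_of_16(height: int, width: int) -> Tuple[int, int]:
    """Round a requested size down to the nearest multiple of 16 (see Section 1.2)."""
    return (height // 16) * 16, (width // 16) * 16

# e.g. a 1000x780 request becomes 992x768
print(snap_to_multiple_of_16(1000, 780))
```
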
## 2. ID Customization & Subject-driven generation

- The expected length of a source caption is $30$ to $75$ words. Empirically, we find that longer prompts can help preserve the ID better, but they might hinder text adherence for the target caption.

- We find it better to add some descriptions (e.g., taken from the source caption) to the target caption to preserve the identity, especially for complex subjects with delicate details.

<p align="center">
  <img src="assets/promptguide_idtask.jpg" alt="ablation id task" width="800">
</p>

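As an illustration of the caption advice above, here is one way an ID-customization prompt can be assembled from the `[[faceid]]` and `[[imgN]]` tokens documented in Section 5.1; the captions themselves are made-up placeholders.

```python
# Illustrative only: an ID-customization prompt built from the documented tokens
# ([[faceid]], [[img0]]..[[img3]]); the captions are invented placeholders.
target_caption = "A photo of a woman in a red kimono standing in a snowy forest"
source_caption = "A close-up photo of a young woman with long black hair, smiling"

faceid_prompt = f"[[faceid]] [[img0]] {target_caption} [[img1]] {source_caption}"
print(faceid_prompt)
```
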
## 3. Multiview generation

To mitigate multi-faced/Janus problems, we recommend not using captions that describe facial features (e.g., "looking at the camera").

## 4. Image editing

We find it generally better to set the guidance scale to a lower value, e.g., $[3; 3.5]$, to avoid over-saturated results.

## 5. Special tokens and available colors

### 5.1 Task Tokens

| Task | Token | Additional Tokens |
|:---------------------|:---------------------------|:------------------|
| Text to Image | `[[text2image]]` | |
| Deblurring | `[[deblurring]]` | |
| Inpainting | `[[image_inpainting]]` | |
| Canny-edge and Image | `[[canny2image]]` | |
| Depth and Image | `[[depth2image]]` | |
| HED and Image | `[[hed2img]]` | |
| Pose and Image | `[[pose2image]]` | |
| Image editing with Instruction | `[[image_editing]]` | |
| Semantic map and Image | `[[semanticmap2image]]` | `<#00FFFF cyan mask: object/to/segment>` |
| Bounding box and Image | `[[boundingbox2image]]` | `<#00FFFF cyan boundingbox: object/to/detect>` |
| ID customization | `[[faceid]]` | `[[img0]] target/caption [[img1]] caption/of/source/image_1 [[img2]] caption/of/source/image_2 [[img3]] caption/of/source/image_3` |
| Multiview | `[[multiview]]` | |
| Subject-Driven | `[[subject_driven]]` | `<item: name/of/subject> [[img0]] target/caption/goes/here [[img1]] insert/source/caption` |

Note that you can replace the cyan color above with any color from the table below, and you can use multiple additional tokens to detect/segment multiple classes.

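For example, a multi-class segmentation prompt simply chains one color token per class. The snippet below is illustrative only: the class names, the chosen colors, and the placement of the scene caption after the tokens are assumptions for the sake of the example, not a documented format.

```python
# Illustrative only: chaining several <color mask: class> tokens for segmentation.
classes = {"#00FFFF cyan": "dog", "#FF0000 red": "car"}

mask_tokens = " ".join(f"<{color} mask: {name}>" for color, name in classes.items())
semseg_prompt = f"[[semanticmap2image]] {mask_tokens} A dog sitting next to a car on a quiet street"
print(semseg_prompt)
```
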
### 5.2 Available colors

| Hex Code | Color Name |
|:---------|:-----------|
| #FF0000 | <span style="color: #FF0000">red</span> |
| #00FF00 | <span style="color: #00FF00">lime</span> |
| #0000FF | <span style="color: #0000FF">blue</span> |
| #FFFF00 | <span style="color: #FFFF00">yellow</span> |
| #FF00FF | <span style="color: #FF00FF">magenta</span> |
| #00FFFF | <span style="color: #00FFFF">cyan</span> |
| #FFA500 | <span style="color: #FFA500">orange</span> |
| #800080 | <span style="color: #800080">purple</span> |
| #A52A2A | <span style="color: #A52A2A">brown</span> |
| #008000 | <span style="color: #008000">green</span> |
| #FFC0CB | <span style="color: #FFC0CB">pink</span> |
| #008080 | <span style="color: #008080">teal</span> |
| #FF8C00 | <span style="color: #FF8C00">darkorange</span> |
| #8A2BE2 | <span style="color: #8A2BE2">blueviolet</span> |
| #006400 | <span style="color: #006400">darkgreen</span> |
| #FF4500 | <span style="color: #FF4500">orangered</span> |
| #000080 | <span style="color: #000080">navy</span> |
| #FFD700 | <span style="color: #FFD700">gold</span> |
| #40E0D0 | <span style="color: #40E0D0">turquoise</span> |
| #DA70D6 | <span style="color: #DA70D6">orchid</span> |
README.md
CHANGED
@@ -1,14 +1,169 @@
# One Diffusion to Generate Them All

<p align="left">
  <a href="https://lehduong.github.io/OneDiffusion-homepage/">
    <img alt="Build" src="https://img.shields.io/badge/Project%20Page-OneDiffusion-yellow">
  </a>
  <a href="https://arxiv.org/abs/2411.16318">
    <img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-2411.16318-b31b1b.svg">
  </a>
  <a href="https://huggingface.co/spaces/lehduong/OneDiffusion">
    <img alt="License" src="https://img.shields.io/badge/HF%20Demo-🤗-lightblue">
  </a>
  <a href="https://huggingface.co/lehduong/OneDiffusion">
    <img alt="Build" src="https://img.shields.io/badge/HF%20Model-🤗-yellow">
  </a>
</p>

<h4 align="left">
  <p>
    <a href="#news">News</a> |
    <a href="#quick-start">Quick start</a> |
    <a href="https://github.com/lehduong/OneDiffusion/blob/main/PROMPT_GUIDE.md">Prompt guide & Supported tasks</a> |
    <a href="#qualitative-results">Qualitative results</a> |
    <a href="#license">License</a> |
    <a href="#citation">Citation</a>
  </p>
</h4>

<p align="center">
  <img src="assets/teaser.png" alt="Teaser Image" width="800">
</p>

This is the official repo of OneDiffusion, a versatile, large-scale diffusion model that seamlessly supports bidirectional image synthesis and understanding across diverse tasks.

## News
- 📦 2024/12/10: Released model weights.
- 📝 2024/12/06: Added instruction-based image editing.
- ✨ 2024/12/02: Added subject-driven generation.

## Installation
```bash
conda create -n onediffusion_env python=3.8 &&
conda activate onediffusion_env &&
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118 &&
pip install "git+https://github.com/facebookresearch/pytorch3d.git" &&
pip install -r requirements.txt
```

## Quick start

Check `inference.py` for more details. For text-to-image, you can use the code snippet below.

```python
import torch
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline

device = torch.device('cuda:0')

pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16)

NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"

output = pipeline(
    prompt="[[text2image]] A bipedal black cat wearing a huge oversized witch hat, a wizard's robe, casting a spell, in an enchanted forest. The scene is filled with fireflies and moss on surrounding rocks and trees",
    negative_prompt=NEGATIVE_PROMPT,
    num_inference_steps=50,
    guidance_scale=4,
    height=1024,
    width=1024,
)
output.images[0].save('text2image_output.jpg')
```

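To generate several images without reloading the model, you can simply loop over the call above. The sketch below reuses only the `pipeline` and `NEGATIVE_PROMPT` objects and the arguments already shown in the snippet; the prompts are placeholders.

```python
# Reuse the already-loaded pipeline for several prompts.
prompts = [
    "[[text2image]] A lighthouse on a cliff at sunset, crashing waves below, ultra detailed",
    "[[text2image]] A cozy reading nook with warm lamplight, rain on the window, watercolor style",
]

for i, prompt in enumerate(prompts):
    out = pipeline(
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        num_inference_steps=50,
        guidance_scale=4,
        height=1024,
        width=1024,
    )
    out.images[0].save(f"text2image_output_{i}.jpg")
```
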
You can run the gradio demo with:
```bash
python gradio_demo.py --captioner molmo # [molmo, llava, disable]
```
The demo provides guidance and helps format the prompt properly for each task.
- By default, it loads Molmo for captioning source images, which significantly increases memory usage. You generally need a GPU with at least $40$ GB of memory to run the demo.
- Opting for LLaVA can reduce this requirement to $\approx 27$ GB, though the resulting captions may be less accurate in some cases.
- You can also manually provide the caption for each input image and run in `disable` mode. In this mode, the returned caption is an empty string, but you should still press the `Generate Caption` button so that the code formats the input text properly. The memory requirement for this mode is $\approx 12$ GB.

Note that the required memory can change if you use a higher resolution or more input images.

## Qualitative Results

### 1. Text-to-Image
<p align="center">
  <img src="assets/text2image.jpg" alt="Text-to-Image results" width="800">
</p>

### 2. ID customization

<p align="center">
  <img src="assets/onediffusion_appendix_faceid.jpg" alt="ID customization" width="800">
</p>

<p align="center">
  <img src="assets/onediffusion_appendix_faceid_3.jpg" alt="ID customization non-human subject" width="800">
</p>

### 3. Multiview generation

Single image to multiview:

<p align="center">
  <img src="assets/onediffusion_appendix_multiview.jpg" alt="Image to multiview" width="800">
</p>

<p align="center">
  <img src="assets/onediffusion_appendix_multiview_2.jpg" alt="Image to multiview" width="800">
</p>

Text to multiview:

<p align="center">
  <img src="assets/text2multiview.jpg" alt="Text to multiview image" width="800">
</p>

### 4. Condition-to-Image and vice versa
<p align="center">
  <img src="assets/cond_and_image.jpg" alt="Condition and Image" width="800">
</p>

### 5. Subject-driven generation

We finetuned the model on the [Subject-200K](https://huggingface.co/datasets/Yuanshi/Subjects200K) dataset (along with all other tasks) for an additional 40K steps. The model is now capable of subject-driven generation.

<p align="center">
  <img src="assets/subject_driven.jpg" alt="Subject-driven generation" width="800">
</p>

### 6. Text-guided image editing

We finetuned the model on the [OmniEdit](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M) dataset for an additional 30K steps.

<p align="center">
  <img src="assets/onediffusion_editing.jpg" alt="Text-guided editing" width="800">
</p>

### 7. Zero-shot task combinations

We found that the model can handle multiple tasks in a zero-shot setting by combining condition images and task tokens without any fine-tuning, as shown in the examples below. However, its performance on these combined tasks might not be robust, and the model's behavior may change if the order of task tokens or captions is altered. For example, when using image inpainting and ID customization together, the target prompt and the caption of the masked image must be identical. If you plan to use such combinations, we recommend fine-tuning the model on these tasks to achieve better performance and simpler usage.

<p align="center">
  <img src="assets/onediffusion_zeroshot.jpg" alt="Zero-shot task combinations" width="800">
</p>

## License

The model is trained on several non-commercially licensed datasets (e.g., DL3DV, Unsplash); thus, the **model weights** are released under a CC BY-NC license as described in [LICENSE](https://github.com/lehduong/onediffusion/blob/main/LICENSE).

## Citation

```bibtex
@misc{le2024diffusiongenerate,
      title={One Diffusion to Generate Them All},
      author={Duong H. Le and Tuan Pham and Sangho Lee and Christopher Clark and Aniruddha Kembhavi and Stephan Mandt and Ranjay Krishna and Jiasen Lu},
      year={2024},
      eprint={2411.16318},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2411.16318},
}
```
assets/cond_and_image.jpg
ADDED
assets/examples/id_customization/chenhao/image_0.png
ADDED
assets/examples/id_customization/chenhao/image_1.png
ADDED
assets/examples/id_customization/chenhao/image_2.png
ADDED
assets/onediffusion_appendix_faceid.jpg
ADDED
assets/onediffusion_appendix_faceid_3.jpg
ADDED
assets/onediffusion_appendix_multiview.jpg
ADDED
assets/onediffusion_appendix_multiview_2.jpg
ADDED
assets/onediffusion_appendix_text2multiview.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60a945f7c1c92e823dbd3c7876c843d2668d16cb1ae883f2ba4080d324056225
+size 8278287
assets/onediffusion_editing.jpg
ADDED
assets/onediffusion_zeroshot.jpg
ADDED
assets/promptguide_complex.jpg
ADDED
assets/promptguide_idtask.jpg
ADDED
assets/subject_driven.jpg
ADDED
assets/teaser.png
ADDED
assets/text2image.jpg
ADDED
assets/text2multiview.jpg
ADDED
docker/Dockerfile
ADDED
@@ -0,0 +1,119 @@
# Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile
# ARG COMPAT=0
ARG PERSONAL=0
# FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0
FROM nvcr.io/nvidia/pytorch:22.12-py3 as base

ENV HOST docker
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
# https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes
ENV TZ America/Los_Angeles
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

# git for installing dependencies
# tzdata to set time zone
# wget and unzip to download data
# [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment.
# [2021-12-07] TD: openmpi-bin for MPI (multi-node training)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    curl \
    ca-certificates \
    sudo \
    less \
    htop \
    git \
    tzdata \
    wget \
    tmux \
    zip \
    unzip \
    zsh stow subversion fasd \
    && rm -rf /var/lib/apt/lists/*
    # openmpi-bin \

# Allow running runmpi as root
# ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1

# # Create a non-root user and switch to it
# RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
#     && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
# USER user

# All users can use /home/user as their home directory
ENV HOME=/home/user
RUN mkdir -p /home/user && chmod 777 /home/user
WORKDIR /home/user

# Set up personal environment
# FROM base-${COMPAT} as env-0
FROM base as env-0
FROM env-0 as env-1
# Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image
# https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile
ONBUILD COPY dotfiles ./dotfiles
ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami)
# nvcr pytorch image sets SHELL=/bin/bash
ONBUILD ENV SHELL=/bin/zsh

FROM env-${PERSONAL} as packages

# Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
ENV PIP_NO_CACHE_DIR=1

# # apex and pytorch-fast-transformers take a while to compile so we install them first
# TD [2022-04-28] apex is already installed. In case we need a newer commit:
# RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex

# xgboost conflicts with deepspeed
RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.7

# General packages that we don't care about the version
# zstandard to extract the_pile dataset
# psutil to get the number of cpu physical cores
# twine to upload package to PyPI
RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine gdown \
    && python -m spacy download en_core_web_sm
# hydra
RUN pip install hydra-core==1.3.1 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
# Core packages
RUN pip install transformers==4.45.2 datasets==3.0.1 pytorch-lightning==2.2.1 triton==2.3.1 wandb==0.16.3 controlnet_aux==0.0.9 timm==0.6.7 torchmetrics==1.3.2
# torchmetrics 0.11.0 broke hydra's instantiate

# For MLPerf
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0

RUN pip install accelerate==0.34.2

RUN pip install diffusers==0.30.3

RUN pip install deepspeed==0.15.2

RUN pip install sentencepiece==0.1.99

RUN pip install pillow==10.2.0

RUN pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118

# Install FlashAttention
RUN pip install flash-attn==2.6.3

# Install CUDA extensions for fused dense
RUN pip install git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib

RUN pip install jaxtyping mediapipe gradio

RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git"

RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y

RUN pip install opencv-python==4.5.5.64

RUN pip install opencv-python-headless==4.5.5.64

RUN pip install huggingface_hub==0.24

RUN pip install numpy==1.24.4
gradio_demo.py
ADDED
@@ -0,0 +1,715 @@
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import base64
|
4 |
+
import io
|
5 |
+
from PIL import Image
|
6 |
+
from transformers import (
|
7 |
+
LlavaNextProcessor, LlavaNextForConditionalGeneration,
|
8 |
+
T5EncoderModel, T5Tokenizer
|
9 |
+
)
|
10 |
+
from transformers import (
|
11 |
+
AutoProcessor, AutoModelForCausalLM, GenerationConfig,
|
12 |
+
T5EncoderModel, T5Tokenizer
|
13 |
+
)
|
14 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FluxPipeline
|
15 |
+
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
|
16 |
+
from onediffusion.models.denoiser.nextdit import NextDiT
|
17 |
+
from onediffusion.dataset.utils import get_closest_ratio, ASPECT_RATIO_512
|
18 |
+
from typing import List, Optional
|
19 |
+
import matplotlib
|
20 |
+
import numpy as np
|
21 |
+
import cv2
|
22 |
+
import argparse
|
23 |
+
|
24 |
+
# Task-specific tokens
|
25 |
+
TASK2SPECIAL_TOKENS = {
|
26 |
+
"text2image": "[[text2image]]",
|
27 |
+
"deblurring": "[[deblurring]]",
|
28 |
+
"inpainting": "[[image_inpainting]]",
|
29 |
+
"canny": "[[canny2image]]",
|
30 |
+
"depth2image": "[[depth2image]]",
|
31 |
+
"hed2image": "[[hed2img]]",
|
32 |
+
"pose2image": "[[pose2image]]",
|
33 |
+
"semanticmap2image": "[[semanticmap2image]]",
|
34 |
+
"boundingbox2image": "[[boundingbox2image]]",
|
35 |
+
"image_editing": "[[image_editing]]",
|
36 |
+
"faceid": "[[faceid]]",
|
37 |
+
"multiview": "[[multiview]]",
|
38 |
+
"subject_driven": "[[subject_driven]]"
|
39 |
+
}
|
40 |
+
NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
|
41 |
+
|
42 |
+
|
43 |
+
class LlavaCaptionProcessor:
|
44 |
+
def __init__(self):
|
45 |
+
model_name = "llava-hf/llama3-llava-next-8b-hf"
|
46 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
47 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
48 |
+
self.processor = LlavaNextProcessor.from_pretrained(model_name)
|
49 |
+
self.model = LlavaNextForConditionalGeneration.from_pretrained(
|
50 |
+
model_name, torch_dtype=dtype, low_cpu_mem_usage=True
|
51 |
+
).to(device)
|
52 |
+
self.SPECIAL_TOKENS = "assistant\n\n\n"
|
53 |
+
|
54 |
+
def generate_response(self, image: Image.Image, msg: str) -> str:
|
55 |
+
conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": msg}]}]
|
56 |
+
with torch.no_grad():
|
57 |
+
prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
|
58 |
+
inputs = self.processor(prompt, image, return_tensors="pt").to(self.model.device)
|
59 |
+
output = self.model.generate(**inputs, max_new_tokens=250)
|
60 |
+
response = self.processor.decode(output[0], skip_special_tokens=True)
|
61 |
+
return response.split(msg)[-1].strip()[len(self.SPECIAL_TOKENS):]
|
62 |
+
|
63 |
+
def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
|
64 |
+
if msg is None:
|
65 |
+
msg = f"Describe the contents of the photo in 150 words or fewer."
|
66 |
+
try:
|
67 |
+
return [self.generate_response(img, msg) for img in images]
|
68 |
+
except Exception as e:
|
69 |
+
print(f"Error in process: {str(e)}")
|
70 |
+
raise
|
71 |
+
|
72 |
+
|
73 |
+
class MolmoCaptionProcessor:
|
74 |
+
def __init__(self):
|
75 |
+
pretrained_model_name = 'allenai/Molmo-7B-O-0924'
|
76 |
+
self.processor = AutoProcessor.from_pretrained(
|
77 |
+
pretrained_model_name,
|
78 |
+
trust_remote_code=True,
|
79 |
+
torch_dtype='auto',
|
80 |
+
device_map='auto'
|
81 |
+
)
|
82 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
83 |
+
pretrained_model_name,
|
84 |
+
trust_remote_code=True,
|
85 |
+
torch_dtype='auto',
|
86 |
+
device_map='auto'
|
87 |
+
)
|
88 |
+
|
89 |
+
def generate_response(self, image: Image.Image, msg: str) -> str:
|
90 |
+
inputs = self.processor.process(
|
91 |
+
images=[image],
|
92 |
+
text=msg
|
93 |
+
)
|
94 |
+
# Move inputs to the correct device and make a batch of size 1
|
95 |
+
inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
|
96 |
+
|
97 |
+
# Generate output
|
98 |
+
output = self.model.generate_from_batch(
|
99 |
+
inputs,
|
100 |
+
GenerationConfig(max_new_tokens=250, stop_strings="<|endoftext|>"),
|
101 |
+
tokenizer=self.processor.tokenizer
|
102 |
+
)
|
103 |
+
|
104 |
+
# Only get generated tokens and decode them to text
|
105 |
+
generated_tokens = output[0, inputs['input_ids'].size(1):]
|
106 |
+
return self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
|
107 |
+
|
108 |
+
|
109 |
+
def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
|
110 |
+
if msg is None:
|
111 |
+
msg = f"Describe the contents of the photo in 150 words or fewer."
|
112 |
+
try:
|
113 |
+
return [self.generate_response(img, msg) for img in images]
|
114 |
+
except Exception as e:
|
115 |
+
print(f"Error in process: {str(e)}")
|
116 |
+
raise
|
117 |
+
|
118 |
+
|
119 |
+
class PlaceHolderCaptionProcessor:
|
120 |
+
def __init__(self):
|
121 |
+
pass
|
122 |
+
|
123 |
+
def generate_response(self, image: Image.Image, msg: str) -> str:
|
124 |
+
return ""
|
125 |
+
|
126 |
+
def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
|
127 |
+
return [""] * len(images)
|
128 |
+
|
129 |
+
|
130 |
+
def initialize_models(captioner_name):
|
131 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
132 |
+
pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16)
|
133 |
+
if captioner_name == 'molmo':
|
134 |
+
captioner = MolmoCaptionProcessor()
|
135 |
+
elif captioner_name == 'llava':
|
136 |
+
captioner = LlavaCaptionProcessor()
|
137 |
+
else:
|
138 |
+
captioner = PlaceHolderCaptionProcessor()
|
139 |
+
return pipeline, captioner
|
140 |
+
|
141 |
+
def colorize_depth_maps(
|
142 |
+
depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
|
143 |
+
):
|
144 |
+
"""
|
145 |
+
Colorize depth maps with reversed colors.
|
146 |
+
"""
|
147 |
+
assert len(depth_map.shape) >= 2, "Invalid dimension"
|
148 |
+
|
149 |
+
if isinstance(depth_map, torch.Tensor):
|
150 |
+
depth = depth_map.detach().squeeze().numpy()
|
151 |
+
elif isinstance(depth_map, np.ndarray):
|
152 |
+
depth = depth_map.copy().squeeze()
|
153 |
+
# reshape to [ (B,) H, W ]
|
154 |
+
if depth.ndim < 3:
|
155 |
+
depth = depth[np.newaxis, :, :]
|
156 |
+
|
157 |
+
# Normalize depth values to [0, 1]
|
158 |
+
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
|
159 |
+
# Invert the depth values to reverse the colors
|
160 |
+
depth = 1 - depth
|
161 |
+
|
162 |
+
# Use the colormap
|
163 |
+
cm = matplotlib.colormaps[cmap]
|
164 |
+
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # values from 0 to 1
|
165 |
+
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
|
166 |
+
|
167 |
+
if valid_mask is not None:
|
168 |
+
if isinstance(depth_map, torch.Tensor):
|
169 |
+
valid_mask = valid_mask.detach().numpy()
|
170 |
+
valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
|
171 |
+
if valid_mask.ndim < 3:
|
172 |
+
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
|
173 |
+
else:
|
174 |
+
valid_mask = valid_mask[:, np.newaxis, :, :]
|
175 |
+
valid_mask = np.repeat(valid_mask, 3, axis=1)
|
176 |
+
img_colored_np[~valid_mask] = 0
|
177 |
+
|
178 |
+
if isinstance(depth_map, torch.Tensor):
|
179 |
+
img_colored = torch.from_numpy(img_colored_np).float()
|
180 |
+
elif isinstance(depth_map, np.ndarray):
|
181 |
+
img_colored = img_colored_np
|
182 |
+
|
183 |
+
return img_colored
|
184 |
+
|
185 |
+
|
186 |
+
def format_prompt(task_type: str, captions: List[str]) -> str:
|
187 |
+
if not captions:
|
188 |
+
return ""
|
189 |
+
if task_type == "faceid":
|
190 |
+
img_prompts = [f"[[img{i}]] {caption}" for i, caption in enumerate(captions, start=1)]
|
191 |
+
return f"[[faceid]] [[img0]] insert/your/caption/here {' '.join(img_prompts)}"
|
192 |
+
elif task_type == "image_editing":
|
193 |
+
return f"[[image_editing]] insert/your/instruction/here"
|
194 |
+
elif task_type == "semanticmap2image":
|
195 |
+
return f"[[semanticmap2image]] <#00ffff Cyan mask: insert/concept/to/segment/here> {captions[0]}"
|
196 |
+
elif task_type == "boundingbox2image":
|
197 |
+
return f"[[boundingbox2image]] <#00ffff Cyan boundingbox: insert/concept/to/segment/here> {captions[0]}"
|
198 |
+
elif task_type == "multiview":
|
199 |
+
img_prompts = captions[0]
|
200 |
+
return f"[[multiview]] {img_prompts}"
|
201 |
+
elif task_type == "subject_driven":
|
202 |
+
return f"[[subject_driven]] <item: insert/item/here> [[img0]] insert/your/target/caption/here [[img1]] {captions[0]}"
|
203 |
+
else:
|
204 |
+
return f"{TASK2SPECIAL_TOKENS[task_type]} {captions[0]}"
|
205 |
+
|
206 |
+
def update_prompt(images: List[Image.Image], task_type: str, custom_msg: str = None):
|
207 |
+
if not images:
|
208 |
+
return format_prompt(task_type, []), "Please upload at least one image!"
|
209 |
+
try:
|
210 |
+
captions = captioner.process(images, custom_msg)
|
211 |
+
if not captions:
|
212 |
+
return "", "No valid images found!"
|
213 |
+
prompt = format_prompt(task_type, captions)
|
214 |
+
return prompt, f"Generated {len(captions)} captions successfully!"
|
215 |
+
except Exception as e:
|
216 |
+
return "", f"Error generating captions: {str(e)}"
|
217 |
+
|
218 |
+
def generate_image(images: List[Image.Image], prompt: str, negative_prompt: str, num_inference_steps: int, guidance_scale: float,
|
219 |
+
denoise_mask: List[str], task_type: str, azimuth: str, elevation: str, distance: str, focal_length: float,
|
220 |
+
height: int = 1024, width: int = 1024, scale_factor: float = 1.0, scale_watershed: float = 1.0,
|
221 |
+
noise_scale: float = None, progress=gr.Progress()):
|
222 |
+
try:
|
223 |
+
img2img_kwargs = {
|
224 |
+
'prompt': prompt,
|
225 |
+
'negative_prompt': negative_prompt,
|
226 |
+
'num_inference_steps': num_inference_steps,
|
227 |
+
'guidance_scale': guidance_scale,
|
228 |
+
'height': height,
|
229 |
+
'width': width,
|
230 |
+
'forward_kwargs': {
|
231 |
+
'scale_factor': scale_factor,
|
232 |
+
'scale_watershed': scale_watershed
|
233 |
+
},
|
234 |
+
'noise_scale': noise_scale # Added noise_scale here
|
235 |
+
}
|
236 |
+
|
237 |
+
if task_type == 'multiview':
|
238 |
+
# Parse azimuth, elevation, and distance into lists, allowing 'None' values
|
239 |
+
azimuths = [float(a.strip()) if a.strip().lower() != 'none' else None for a in azimuth.split(',')] if azimuth else []
|
240 |
+
elevations = [float(e.strip()) if e.strip().lower() != 'none' else None for e in elevation.split(',')] if elevation else []
|
241 |
+
distances = [float(d.strip()) if d.strip().lower() != 'none' else None for d in distance.split(',')] if distance else []
|
242 |
+
|
243 |
+
num_views = max(len(images), len(azimuths), len(elevations), len(distances))
|
244 |
+
if num_views == 0:
|
245 |
+
return None, "At least one image or camera parameter must be provided."
|
246 |
+
|
247 |
+
total_components = []
|
248 |
+
for i in range(num_views):
|
249 |
+
total_components.append(f"image_{i}")
|
250 |
+
total_components.append(f"camera_pose_{i}")
|
251 |
+
|
252 |
+
denoise_mask_int = [1 if comp in denoise_mask else 0 for comp in total_components]
|
253 |
+
|
254 |
+
if len(denoise_mask_int) != len(total_components):
|
255 |
+
return None, f"Denoise mask length mismatch: expected {len(total_components)} components."
|
256 |
+
|
257 |
+
# Pad the input lists to num_views length
|
258 |
+
images_padded = list(images)  # keep only the provided images; do not pad with None entries
|
259 |
+
azimuths_padded = azimuths + [None] * (num_views - len(azimuths))
|
260 |
+
elevations_padded = elevations + [None] * (num_views - len(elevations))
|
261 |
+
distances_padded = distances + [None] * (num_views - len(distances))
|
262 |
+
|
263 |
+
# Prepare values
|
264 |
+
img2img_kwargs.update({
|
265 |
+
'image': images_padded,
|
266 |
+
'multiview_azimuths': azimuths_padded,
|
267 |
+
'multiview_elevations': elevations_padded,
|
268 |
+
'multiview_distances': distances_padded,
|
269 |
+
'multiview_focal_length': focal_length, # Pass focal_length here
|
270 |
+
'is_multiview': True,
|
271 |
+
'denoise_mask': denoise_mask_int,
|
272 |
+
# 'predict_camera_poses': True,
|
273 |
+
})
|
274 |
+
else:
|
275 |
+
total_components = ["image_0"] + [f"image_{i+1}" for i in range(len(images))]
|
276 |
+
denoise_mask_int = [1 if comp in denoise_mask else 0 for comp in total_components]
|
277 |
+
if len(denoise_mask_int) != len(total_components):
|
278 |
+
return None, f"Denoise mask length mismatch: expected {len(total_components)} components."
|
279 |
+
|
280 |
+
img2img_kwargs.update({
|
281 |
+
'image': images,
|
282 |
+
'denoise_mask': denoise_mask_int
|
283 |
+
})
|
284 |
+
|
285 |
+
progress(0, desc="Generating image...")
|
286 |
+
if task_type == 'text2image':
|
287 |
+
output = pipeline(
|
288 |
+
prompt=prompt,
|
289 |
+
negative_prompt=negative_prompt,
|
290 |
+
num_inference_steps=num_inference_steps,
|
291 |
+
guidance_scale=guidance_scale,
|
292 |
+
height=height,
|
293 |
+
width=width,
|
294 |
+
scale_factor=scale_factor,
|
295 |
+
scale_watershed=scale_watershed,
|
296 |
+
noise_scale=noise_scale # Added noise_scale here
|
297 |
+
)
|
298 |
+
else:
|
299 |
+
output = pipeline.img2img(**img2img_kwargs)
|
300 |
+
progress(1, desc="Done!")
|
301 |
+
|
302 |
+
# Process the output images if task is 'depth2image' and predicting depth
|
303 |
+
if task_type == 'depth2image' and denoise_mask_int[-1] == 1:
|
304 |
+
processed_images = []
|
305 |
+
for img in output.images:
|
306 |
+
depth_map = np.array(img.convert('L')) # Convert to grayscale numpy array
|
307 |
+
min_depth = depth_map.min()
|
308 |
+
max_depth = depth_map.max()
|
309 |
+
colorized = colorize_depth_maps(depth_map, min_depth, max_depth)[0]
|
310 |
+
colorized = np.transpose(colorized, (1, 2, 0))
|
311 |
+
colorized = (colorized * 255).astype(np.uint8)
|
312 |
+
img_colorized = Image.fromarray(colorized)
|
313 |
+
processed_images.append(img_colorized)
|
314 |
+
output_images = processed_images + output.images
|
315 |
+
elif task_type in ['boundingbox2image', 'semanticmap2image'] and denoise_mask_int == [0,1] and images:
|
316 |
+
# Interpolate between input and output images
|
317 |
+
processed_images = []
|
318 |
+
for input_img, output_img in zip(images, output.images):
|
319 |
+
input_img_resized = input_img.resize(output_img.size)
|
320 |
+
blended_img = Image.blend(input_img_resized, output_img, alpha=0.5)
|
321 |
+
processed_images.append(blended_img)
|
322 |
+
output_images = processed_images + output.images
|
323 |
+
else:
|
324 |
+
output_images = output.images
|
325 |
+
|
326 |
+
return output_images, "Generation completed successfully!"
|
327 |
+
|
328 |
+
except Exception as e:
|
329 |
+
return None, f"Error during generation: {str(e)}"
|
330 |
+
|
331 |
+
def update_denoise_checkboxes(images_state: List[Image.Image], task_type: str, azimuth: str, elevation: str, distance: str):
|
332 |
+
if task_type == 'multiview':
|
333 |
+
azimuths = [a.strip() for a in azimuth.split(',')] if azimuth else []
|
334 |
+
elevations = [e.strip() for e in elevation.split(',')] if elevation else []
|
335 |
+
distances = [d.strip() for d in distance.split(',')] if distance else []
|
336 |
+
images_len = len(images_state)
|
337 |
+
|
338 |
+
num_views = max(images_len, len(azimuths), len(elevations), len(distances))
|
339 |
+
if num_views == 0:
|
340 |
+
return gr.update(choices=[], value=[])  # single output component, so return only the update
|
341 |
+
|
342 |
+
# Pad lists to the same length
|
343 |
+
azimuths += ['None'] * (num_views - len(azimuths))
|
344 |
+
elevations += ['None'] * (num_views - len(elevations))
|
345 |
+
distances += ['None'] * (num_views - len(distances))
|
346 |
+
# Do not add None to images_state
|
347 |
+
|
348 |
+
labels = []
|
349 |
+
values = []
|
350 |
+
for i in range(num_views):
|
351 |
+
labels.append(f"image_{i}")
|
352 |
+
labels.append(f"camera_pose_{i}")
|
353 |
+
|
354 |
+
# Default behavior: condition on provided inputs, generate missing ones
|
355 |
+
if i >= images_len:
|
356 |
+
values.append(f"image_{i}")
|
357 |
+
if azimuths[i].lower() == 'none' or elevations[i].lower() == 'none' or distances[i].lower() == 'none':
|
358 |
+
values.append(f"camera_pose_{i}")
|
359 |
+
|
360 |
+
return gr.update(choices=labels, value=values)
|
361 |
+
else:
|
362 |
+
labels = ["image_0"] + [f"image_{i+1}" for i in range(len(images_state))]
|
363 |
+
values = ["image_0"]
|
364 |
+
return gr.update(choices=labels, value=values)
|
365 |
+
|
366 |
+
def apply_mask(images_state):
|
367 |
+
if len(images_state) < 2:
|
368 |
+
return None, "Please upload at least two images: first as the base image, second as the mask."
|
369 |
+
base_img = images_state[0]
|
370 |
+
mask_img = images_state[1]
|
371 |
+
|
372 |
+
# Convert images to arrays
|
373 |
+
base_arr = np.array(base_img)
|
374 |
+
mask_arr = np.array(mask_img)
|
375 |
+
|
376 |
+
# Convert mask to grayscale
|
377 |
+
if mask_arr.ndim == 3:
|
378 |
+
gray_mask = cv2.cvtColor(mask_arr, cv2.COLOR_RGB2GRAY)
|
379 |
+
else:
|
380 |
+
gray_mask = mask_arr
|
381 |
+
|
382 |
+
# Create a binary mask where non-black pixels are True
|
383 |
+
binary_mask = gray_mask > 10
|
384 |
+
|
385 |
+
# Define the gray color
|
386 |
+
gray_color = np.array([128, 128, 128], dtype=np.uint8)
|
387 |
+
|
388 |
+
# Apply gray color where mask is True
|
389 |
+
masked_arr = base_arr.copy()
|
390 |
+
masked_arr[binary_mask] = gray_color
|
391 |
+
|
392 |
+
masked_img = Image.fromarray(masked_arr)
|
393 |
+
return [masked_img], "Mask applied successfully!"
|
394 |
+
|
395 |
+
def process_images_for_task_type(images_state: List[Image.Image], task_type: str):
|
396 |
+
# No changes needed here since we are processing the output images
|
397 |
+
return images_state, images_state
|
398 |
+
|
399 |
+
with gr.Blocks(title="OneDiffusion Demo") as demo:
|
400 |
+
gr.Markdown("""
|
401 |
+
# OneDiffusion Demo
|
402 |
+
|
403 |
+
**Welcome to the OneDiffusion Demo!**
|
404 |
+
|
405 |
+
This application allows you to generate images based on your input prompts for various tasks. Here's how to use it:
|
406 |
+
|
407 |
+
1. **Select Task Type**: Choose the type of task you want to perform from the "Task Type" dropdown menu.
|
408 |
+
|
409 |
+
2. **Upload Images**: Drag and drop images directly onto the upload area, or click to select files from your device.
|
410 |
+
|
411 |
+
3. **Generate Captions**: **If you upload any images**, click the "Generate Captions" button to generate descriptive captions for your uploaded images (depending on the task). You can enter a custom message in the "Custom message for captioner" textbox, e.g., "caption in 30 words" instead of 50 words.
|
412 |
+
|
413 |
+
4. **Configure Generation Settings**: Expand the "Advanced Configuration" section to adjust parameters like the number of inference steps, guidance scale, image size, and more.
|
414 |
+
|
415 |
+
5. **Generate Images**: After setting your preferences, click the "Generate Image" button. The generated images will appear in the "Generated Images" gallery.
|
416 |
+
|
417 |
+
6. **Manage Images**: Use the "Delete Selected Images" or "Delete All Images" buttons to remove unwanted images from the gallery.
|
418 |
+
|
419 |
+
**Notes**:
|
420 |
+
- Check out the [Prompt Guide](https://github.com/lehduong/OneDiffusion/blob/main/PROMPT_GUIDE.md).
|
421 |
+
|
422 |
+
- For text-to-image:
|
423 |
+
+ Simply enter your prompt in the format "[[text2image]] your/prompt/here" and press the "Generate Image" button.
|
424 |
+
|
425 |
+
- For boundingbox2image/semanticmap2image/inpainting and similar tasks:
|
426 |
+
+ To perform condition-to-image generation, such as semantic map to image, follow the steps above.
|
427 |
+
+ For image-to-condition, e.g., image to depth, change the denoise mask checkboxes before generating images: UNCHECK the image_0 box and CHECK the image_1 box.
|
428 |
+
|
429 |
+
- For FaceID tasks:
|
430 |
+
+ Use 3 or 4 images if a single input image does not give satisfactory results.
|
431 |
+
+ All images will be resized and center-cropped to the input height and width. Choose the height and width so that faces in the input images won't be cropped out.
|
432 |
+
+ The model works best with close-up portrait images (both input and output).
|
433 |
+
+ If the model does not follow your text prompt, try using a shorter caption for the source image(s).
|
434 |
+
+ If you have non-human subjects and do not get satisfactory results, try "copying" the part of the source-image captions that describes the subject's properties, e.g., a monster with red eyes, sharp teeth, etc.
|
435 |
+
|
436 |
+
- For Multiview generation:
|
437 |
+
+ The input camera elevation/azimuth ALWAYS starts at 0. For example, to generate images at azimuths 30, 60, 90 and elevations 10, 20, 30 (relative to the input image), enter azimuths `0,30,60,90`, elevations `0,10,20,30`, and camera distances `1.5,1.5,1.5,1.5`.
|
438 |
+
+ Only square images are supported (ideally at 512x512 resolution).
|
439 |
+
+ Ensure the numbers of elevations, azimuths, and distances are equal.
|
440 |
+
+ The model generally works well for 2-5 views (including both input and generated images). Since the model is trained with 3 views at 512x512 resolution, you might try a scale_factor in [1.1, 1.5] and a scale_watershed in [100, 400] for better extrapolation.
|
441 |
+
+ For better results:
|
442 |
+
1) try increasing num_inference_steps to 75-100.
|
443 |
+
2) avoid aggressive changes in target camera poses; for example, to generate a novel view at an azimuth of 180, (simultaneously) generate 4 views at azimuths of 45, 90, 135, 180.
|
444 |
+
|
445 |
+
Enjoy creating images with OneDiffusion!
|
446 |
+
""")
|
447 |
+
|
448 |
+
with gr.Row():
|
449 |
+
with gr.Column():
|
450 |
+
images_state = gr.State([])
|
451 |
+
selected_indices_state = gr.State([])
|
452 |
+
|
453 |
+
with gr.Row():
|
454 |
+
gallery = gr.Gallery(
|
455 |
+
label="Input Images",
|
456 |
+
show_label=True,
|
457 |
+
columns=2,
|
458 |
+
rows=2,
|
459 |
+
height="auto",
|
460 |
+
object_fit="contain"
|
461 |
+
)
|
462 |
+
|
463 |
+
# In the UI section, update the file_output component:
|
464 |
+
file_output = gr.File(
|
465 |
+
file_count="multiple",
|
466 |
+
file_types=["image"],
|
467 |
+
label="Drag and drop images here or click to upload",
|
468 |
+
height=100,
|
469 |
+
scale=2,
|
470 |
+
type="filepath" # Add this parameter
|
471 |
+
)
|
472 |
+
|
473 |
+
with gr.Row():
|
474 |
+
delete_button = gr.Button("Delete Selected Images")
|
475 |
+
delete_all_button = gr.Button("Delete All Images")
|
476 |
+
|
477 |
+
task_type = gr.Dropdown(
|
478 |
+
choices=list(TASK2SPECIAL_TOKENS.keys()),
|
479 |
+
value="text2image",
|
480 |
+
label="Task Type"
|
481 |
+
)
|
482 |
+
|
483 |
+
captioning_message = gr.Textbox(
|
484 |
+
lines=2,
|
485 |
+
value="Describe the contents of the photo in 50 words.",
|
486 |
+
label="Custom message for captioner"
|
487 |
+
)
|
488 |
+
|
489 |
+
auto_caption_btn = gr.Button("Generate Captions")
|
490 |
+
|
491 |
+
with gr.Column():
|
492 |
+
prompt = gr.Textbox(
|
493 |
+
lines=3,
|
494 |
+
placeholder="Enter your prompt here or use auto-caption...",
|
495 |
+
label="Prompt"
|
496 |
+
)
|
497 |
+
negative_prompt = gr.Textbox(
|
498 |
+
lines=3,
|
499 |
+
value=NEGATIVE_PROMPT,
|
500 |
+
placeholder="Enter negative prompt here...",
|
501 |
+
label="Negative Prompt"
|
502 |
+
)
|
503 |
+
caption_status = gr.Textbox(label="Caption Status")
|
504 |
+
|
505 |
+
num_steps = gr.Slider(
|
506 |
+
minimum=1,
|
507 |
+
maximum=200,
|
508 |
+
value=50,
|
509 |
+
step=1,
|
510 |
+
label="Number of Inference Steps"
|
511 |
+
)
|
512 |
+
guidance_scale = gr.Slider(
|
513 |
+
minimum=0.1,
|
514 |
+
maximum=10.0,
|
515 |
+
value=4,
|
516 |
+
step=0.1,
|
517 |
+
label="Guidance Scale"
|
518 |
+
)
|
519 |
+
height = gr.Number(value=1024, label="Height")
|
520 |
+
width = gr.Number(value=1024, label="Width")
|
521 |
+
|
522 |
+
with gr.Accordion("Advanced Configuration", open=False):
|
523 |
+
with gr.Row():
|
524 |
+
denoise_mask_checkbox = gr.CheckboxGroup(
|
525 |
+
label="Denoise Mask",
|
526 |
+
choices=["image_0"],
|
527 |
+
value=["image_0"]
|
528 |
+
)
|
529 |
+
azimuth = gr.Textbox(
|
530 |
+
value="0",
|
531 |
+
label="Azimuths (degrees, comma-separated, 'None' for missing)"
|
532 |
+
)
|
533 |
+
elevation = gr.Textbox(
|
534 |
+
value="0",
|
535 |
+
label="Elevations (degrees, comma-separated, 'None' for missing)"
|
536 |
+
)
|
537 |
+
distance = gr.Textbox(
|
538 |
+
value="1.5",
|
539 |
+
label="Distances (comma-separated, 'None' for missing)"
|
540 |
+
)
|
541 |
+
focal_length = gr.Number(
|
542 |
+
value=1.3887,
|
543 |
+
label="Focal Length of camera for multiview generation"
|
544 |
+
)
|
545 |
+
scale_factor = gr.Number(value=1.0, label="Scale Factor")
|
546 |
+
scale_watershed = gr.Number(value=1.0, label="Scale Watershed")
|
547 |
+
noise_scale = gr.Number(value=1.0, label="Noise Scale") # Added noise_scale input
|
548 |
+
|
549 |
+
output_images = gr.Gallery(
|
550 |
+
label="Generated Images",
|
551 |
+
show_label=True,
|
552 |
+
columns=4,
|
553 |
+
rows=2,
|
554 |
+
height="auto",
|
555 |
+
object_fit="contain"
|
556 |
+
)
|
557 |
+
|
558 |
+
with gr.Column():
|
559 |
+
generate_btn = gr.Button("Generate Image")
|
560 |
+
# apply_mask_btn = gr.Button("Apply Mask")
|
561 |
+
|
562 |
+
status = gr.Textbox(label="Generation Status")
|
563 |
+
|
564 |
+
# Event Handlers
|
565 |
+
def update_gallery(files, images_state):
|
566 |
+
if not files:
|
567 |
+
return images_state, images_state
|
568 |
+
|
569 |
+
new_images = []
|
570 |
+
for file in files:
|
571 |
+
try:
|
572 |
+
# Handle both file paths and file objects
|
573 |
+
if isinstance(file, dict): # For drag and drop files
|
574 |
+
file = file['path']
|
575 |
+
elif hasattr(file, 'name'): # For uploaded files
|
576 |
+
file = file.name
|
577 |
+
|
578 |
+
img = Image.open(file).convert('RGB')
|
579 |
+
new_images.append(img)
|
580 |
+
except Exception as e:
|
581 |
+
print(f"Error loading image: {str(e)}")
|
582 |
+
continue
|
583 |
+
|
584 |
+
images_state.extend(new_images)
|
585 |
+
return images_state, images_state
|
586 |
+
|
587 |
+
def on_image_select(evt: gr.SelectData, selected_indices_state):
|
588 |
+
selected_indices = selected_indices_state or []
|
589 |
+
index = evt.index
|
590 |
+
if index in selected_indices:
|
591 |
+
selected_indices.remove(index)
|
592 |
+
else:
|
593 |
+
selected_indices.append(index)
|
594 |
+
return selected_indices
|
595 |
+
|
596 |
+
def delete_images(selected_indices, images_state):
|
597 |
+
updated_images = [img for i, img in enumerate(images_state) if i not in selected_indices]
|
598 |
+
selected_indices_state = []
|
599 |
+
return updated_images, updated_images, selected_indices_state
|
600 |
+
|
601 |
+
def delete_all_images(images_state):
|
602 |
+
updated_images = []
|
603 |
+
selected_indices_state = []
|
604 |
+
return updated_images, updated_images, selected_indices_state
|
605 |
+
|
606 |
+
def update_height_width(images_state):
|
607 |
+
if images_state:
|
608 |
+
closest_ar = get_closest_ratio(
|
609 |
+
height=images_state[0].size[1],
|
610 |
+
width=images_state[0].size[0],
|
611 |
+
ratios=ASPECT_RATIO_512
|
612 |
+
)
|
613 |
+
height_val, width_val = int(closest_ar[0][0]), int(closest_ar[0][1])
|
614 |
+
else:
|
615 |
+
height_val, width_val = 1024, 1024 # Default values
|
616 |
+
return gr.update(value=height_val), gr.update(value=width_val)
|
617 |
+
|
618 |
+
# Connect events
|
619 |
+
file_output.change(
|
620 |
+
fn=update_gallery,
|
621 |
+
inputs=[file_output, images_state],
|
622 |
+
outputs=[images_state, gallery]
|
623 |
+
).then(
|
624 |
+
fn=update_height_width,
|
625 |
+
inputs=[images_state],
|
626 |
+
outputs=[height, width]
|
627 |
+
).then(
|
628 |
+
fn=update_denoise_checkboxes,
|
629 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
630 |
+
outputs=[denoise_mask_checkbox]
|
631 |
+
)
|
632 |
+
|
633 |
+
gallery.select(
|
634 |
+
fn=on_image_select,
|
635 |
+
inputs=[selected_indices_state],
|
636 |
+
outputs=[selected_indices_state]
|
637 |
+
)
|
638 |
+
|
639 |
+
delete_button.click(
|
640 |
+
fn=delete_images,
|
641 |
+
inputs=[selected_indices_state, images_state],
|
642 |
+
outputs=[images_state, gallery, selected_indices_state]
|
643 |
+
).then(
|
644 |
+
fn=update_denoise_checkboxes,
|
645 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
646 |
+
outputs=[denoise_mask_checkbox]
|
647 |
+
)
|
648 |
+
|
649 |
+
delete_all_button.click(
|
650 |
+
fn=delete_all_images,
|
651 |
+
inputs=[images_state],
|
652 |
+
outputs=[images_state, gallery, selected_indices_state]
|
653 |
+
).then(
|
654 |
+
fn=update_denoise_checkboxes,
|
655 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
656 |
+
outputs=[denoise_mask_checkbox]
|
657 |
+
)
|
658 |
+
|
659 |
+
task_type.change(
|
660 |
+
fn=update_denoise_checkboxes,
|
661 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
662 |
+
outputs=[denoise_mask_checkbox]
|
663 |
+
)
|
664 |
+
|
665 |
+
azimuth.change(
|
666 |
+
fn=update_denoise_checkboxes,
|
667 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
668 |
+
outputs=[denoise_mask_checkbox]
|
669 |
+
)
|
670 |
+
|
671 |
+
elevation.change(
|
672 |
+
fn=update_denoise_checkboxes,
|
673 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
674 |
+
outputs=[denoise_mask_checkbox]
|
675 |
+
)
|
676 |
+
|
677 |
+
distance.change(
|
678 |
+
fn=update_denoise_checkboxes,
|
679 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
680 |
+
outputs=[denoise_mask_checkbox]
|
681 |
+
)
|
682 |
+
|
683 |
+
generate_btn.click(
|
684 |
+
fn=generate_image,
|
685 |
+
inputs=[
|
686 |
+
images_state, prompt, negative_prompt, num_steps, guidance_scale,
|
687 |
+
denoise_mask_checkbox, task_type, azimuth, elevation, distance,
|
688 |
+
focal_length, height, width, scale_factor, scale_watershed, noise_scale # Added noise_scale here
|
689 |
+
],
|
690 |
+
outputs=[output_images, status],
|
691 |
+
concurrency_id="gpu_queue"
|
692 |
+
)
|
693 |
+
|
694 |
+
auto_caption_btn.click(
|
695 |
+
fn=update_prompt,
|
696 |
+
inputs=[images_state, task_type, captioning_message],
|
697 |
+
outputs=[prompt, caption_status],
|
698 |
+
concurrency_id="gpu_queue"
|
699 |
+
)
|
700 |
+
|
701 |
+
# apply_mask_btn.click(
|
702 |
+
# fn=apply_mask,
|
703 |
+
# inputs=[images_state],
|
704 |
+
# outputs=[output_images, status]
|
705 |
+
# )
|
706 |
+
|
707 |
+
if __name__ == "__main__":
|
708 |
+
parser = argparse.ArgumentParser(description='Start the Gradio demo with specified captioner.')
|
709 |
+
parser.add_argument('--captioner', type=str, choices=['molmo', 'llava', 'disable'], default='molmo', help='Captioner to use: molmo, llava, disable.')
|
710 |
+
args = parser.parse_args()
|
711 |
+
|
712 |
+
# Initialize models with the specified captioner
|
713 |
+
pipeline, captioner = initialize_models(args.captioner)
|
714 |
+
|
715 |
+
demo.launch(share=True)
|
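
The denoise-mask convention in the notes above (checked components are generated, unchecked ones are kept as conditions) also applies when calling the pipeline directly. Below is a minimal sketch for image-to-depth prediction, assuming the same `pipeline.img2img` keyword arguments that this demo forwards; the input path, caption, and output index are placeholders/assumptions rather than a documented API.

import torch
from PIL import Image
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline

pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=torch.device("cuda:0"), dtype=torch.bfloat16)

photo = Image.open("photo.jpg")  # placeholder input image
ret = pipeline.img2img(
    image=[photo],
    prompt="[[depth2image]] a short caption of the photo",  # task token + caption, as produced by format_prompt
    denoise_mask=[0, 1],  # keep image_0 (the photo) as condition, generate image_1 (the depth map)
    num_inference_steps=50,
    guidance_scale=4,
    height=512,
    width=512,
)
ret.images[0].save("depth_output.jpg")  # assumed to hold the generated depth map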
inference.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
import torch
|
2 |
+
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
device = torch.device('cuda:0')
|
6 |
+
pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16)
|
7 |
+
|
8 |
+
NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
|
9 |
+
|
10 |
+
## Text-to-image
|
11 |
+
output = pipeline(
|
12 |
+
prompt="[[text2image]] A bipedal black cat wearing a huge oversized witch hat, a wizards robe, casting a spell,in an enchanted forest. The scene is filled with fireflies and moss on surrounding rocks and trees",
|
13 |
+
negative_prompt=NEGATIVE_PROMPT,
|
14 |
+
num_inference_steps=50,
|
15 |
+
guidance_scale=4,
|
16 |
+
height=1024,
|
17 |
+
width=1024,
|
18 |
+
)
|
19 |
+
output.images[0].save('text2image_output.jpg')
|
20 |
+
|
21 |
+
## ID Customization
|
22 |
+
image = [
|
23 |
+
Image.open("assets/examples/id_customization/chenhao/image_0.png"),
|
24 |
+
Image.open("assets/examples/id_customization/chenhao/image_1.png"),
|
25 |
+
Image.open("assets/examples/id_customization/chenhao/image_2.png")
|
26 |
+
]
|
27 |
+
|
28 |
+
# input = [noise, cond_1, cond_2, cond_3]
|
29 |
+
prompt = "[[faceid]] \
|
30 |
+
[[img0]] A woman dressed in traditional attire with intricate headpieces, posing gracefully with a serene expression. \
|
31 |
+
[[img1]] A woman with long dark hair, smiling warmly while wearing a floral dress. \
|
32 |
+
[[img2]] A woman in traditional clothing holding a lace parasol, with her hair styled elegantly. \
|
33 |
+
[[img3]] A woman in elaborate traditional attire and jewelry, with an ornate headdress, looking intently forward. \
|
34 |
+
"
|
35 |
+
|
36 |
+
ret = pipeline.img2img(image=image, num_inference_steps=75, prompt=prompt, denoise_mask=[1, 0, 0, 0], guidance_scale=4)
|
37 |
+
ret.images[0].save("idcustomization_output.jpg")
|
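
The demo also exposes multiview generation. A minimal sketch of the equivalent direct call is shown below, reusing the pipeline, `Image`, and `NEGATIVE_PROMPT` defined above and the keyword arguments that gradio_demo.py forwards to `pipeline.img2img`, with the camera conventions from its notes (angle lists start at 0 for the input view and all lists have equal length). The input path, caption, camera values, and denoise mask layout are illustrative assumptions.

## Multiview generation (sketch)
input_view = Image.open("input_view.png")  # placeholder square input image, ideally 512x512
prompt = "[[multiview]] a ceramic mug on a wooden table"  # task token + scene caption

ret = pipeline.img2img(
    image=[input_view],
    prompt=prompt,
    negative_prompt=NEGATIVE_PROMPT,
    num_inference_steps=75,
    guidance_scale=4,
    height=512,
    width=512,
    multiview_azimuths=[0, 30, 60, 90],        # first entry is the input view
    multiview_elevations=[0, 0, 0, 0],
    multiview_distances=[1.5, 1.5, 1.5, 1.5],
    multiview_focal_length=1.3887,
    is_multiview=True,
    denoise_mask=[0, 0, 1, 0, 1, 0, 1, 0],     # [image_0, pose_0, image_1, pose_1, ...]; 1 = generate, 0 = condition
)
for i, img in enumerate(ret.images):
    img.save(f"multiview_output_{i}.jpg")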
onediffusion/dataset/multitask/multiview.py
ADDED
@@ -0,0 +1,277 @@
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
from typing import List, Tuple, Union
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from torchvision import transforms
|
9 |
+
import torchvision.transforms as T
|
10 |
+
from onediffusion.dataset.utils import *
|
11 |
+
import glob
|
12 |
+
|
13 |
+
from onediffusion.dataset.raydiff_utils import cameras_to_rays, first_camera_transform, normalize_cameras
|
14 |
+
from onediffusion.dataset.transforms import CenterCropResizeImage
|
15 |
+
from pytorch3d.renderer import PerspectiveCameras
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
def _cameras_from_opencv_projection(
|
20 |
+
R: torch.Tensor,
|
21 |
+
tvec: torch.Tensor,
|
22 |
+
camera_matrix: torch.Tensor,
|
23 |
+
image_size: torch.Tensor,
|
24 |
+
do_normalize_cameras,
|
25 |
+
normalize_scale,
|
26 |
+
) -> PerspectiveCameras:
|
27 |
+
focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
|
28 |
+
principal_point = camera_matrix[:, :2, 2]
|
29 |
+
|
30 |
+
# Retype the image_size correctly and flip to width, height.
|
31 |
+
image_size_wh = image_size.to(R).flip(dims=(1,))
|
32 |
+
|
33 |
+
# Screen to NDC conversion:
|
34 |
+
# For non square images, we scale the points such that smallest side
|
35 |
+
# has range [-1, 1] and the largest side has range [-u, u], with u > 1.
|
36 |
+
# This convention is consistent with the PyTorch3D renderer, as well as
|
37 |
+
# the transformation function `get_ndc_to_screen_transform`.
|
38 |
+
scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0
|
39 |
+
scale = scale.expand(-1, 2)
|
40 |
+
c0 = image_size_wh / 2.0
|
41 |
+
|
42 |
+
# Get the PyTorch3D focal length and principal point.
|
43 |
+
focal_pytorch3d = focal_length / scale
|
44 |
+
p0_pytorch3d = -(principal_point - c0) / scale
|
45 |
+
|
46 |
+
# For R, T we flip x, y axes (opencv screen space has an opposite
|
47 |
+
# orientation of screen axes).
|
48 |
+
# We also transpose R (opencv multiplies points from the opposite=left side).
|
49 |
+
R_pytorch3d = R.clone().permute(0, 2, 1)
|
50 |
+
T_pytorch3d = tvec.clone()
|
51 |
+
R_pytorch3d[:, :, :2] *= -1
|
52 |
+
T_pytorch3d[:, :2] *= -1
|
53 |
+
|
54 |
+
cams = PerspectiveCameras(
|
55 |
+
R=R_pytorch3d,
|
56 |
+
T=T_pytorch3d,
|
57 |
+
focal_length=focal_pytorch3d,
|
58 |
+
principal_point=p0_pytorch3d,
|
59 |
+
image_size=image_size,
|
60 |
+
device=R.device,
|
61 |
+
)
|
62 |
+
|
63 |
+
if do_normalize_cameras:
|
64 |
+
cams, _ = normalize_cameras(cams, scale=normalize_scale)
|
65 |
+
|
66 |
+
cams = first_camera_transform(cams, rotation_only=False)
|
67 |
+
return cams
|
68 |
+
|
69 |
+
def calculate_rays(Ks, sizes, Rs, Ts, target_size, use_plucker=True, do_normalize_cameras=False, normalize_scale=1.0):
|
70 |
+
cameras = _cameras_from_opencv_projection(
|
71 |
+
R=Rs,
|
72 |
+
tvec=Ts,
|
73 |
+
camera_matrix=Ks,
|
74 |
+
image_size=sizes,
|
75 |
+
do_normalize_cameras=do_normalize_cameras,
|
76 |
+
normalize_scale=normalize_scale
|
77 |
+
)
|
78 |
+
|
79 |
+
rays_embedding = cameras_to_rays(
|
80 |
+
cameras=cameras,
|
81 |
+
num_patches_x=target_size,
|
82 |
+
num_patches_y=target_size,
|
83 |
+
crop_parameters=None,
|
84 |
+
use_plucker=use_plucker
|
85 |
+
)
|
86 |
+
|
87 |
+
return rays_embedding.rays
|
88 |
+
|
89 |
+
def convert_rgba_to_rgb_white_bg(image):
|
90 |
+
"""Convert RGBA image to RGB with white background"""
|
91 |
+
if image.mode == 'RGBA':
|
92 |
+
# Create a white background
|
93 |
+
background = Image.new('RGBA', image.size, (255, 255, 255, 255))
|
94 |
+
# Composite the image onto the white background
|
95 |
+
return Image.alpha_composite(background, image).convert('RGB')
|
96 |
+
return image.convert('RGB')
|
97 |
+
|
98 |
+
class MultiviewDataset(Dataset):
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
scene_folders: str,
|
102 |
+
samples_per_set: Union[int, Tuple[int, int]],  # number of views per scene, or a (min, max) range; stored as self.samples_range
|
103 |
+
transform=None,
|
104 |
+
caption_keys: Union[str, List] = "caption",
|
105 |
+
multiscale=False,
|
106 |
+
aspect_ratio_type=ASPECT_RATIO_512,
|
107 |
+
c2w_scaling=1.7,
|
108 |
+
default_max_distance=1,  # default max distance from all cameras of a scene
|
109 |
+
do_normalize=True,  # whether to normalize the translation of c2w by max_distance
|
110 |
+
swap_xz=False,  # whether to swap the x and z axes of 3D scenes
|
111 |
+
valid_paths: str = "",
|
112 |
+
frame_sliding_windows: int = None  # restrict all sampled frames to a window of this many frames, so that camera poses do not differ too much
|
113 |
+
):
|
114 |
+
if not isinstance(samples_per_set, tuple) and not isinstance(samples_per_set, list):
|
115 |
+
samples_per_set = (samples_per_set, samples_per_set)
|
116 |
+
self.samples_range = samples_per_set # Tuple of (min_samples, max_samples)
|
117 |
+
self.transform = transform
|
118 |
+
self.caption_keys = caption_keys if isinstance(caption_keys, list) else [caption_keys]
|
119 |
+
self.aspect_ratio = aspect_ratio_type
|
120 |
+
self.scene_folders = sorted(glob.glob(scene_folders))
|
121 |
+
# filter out scene folders that do not have transforms.json
|
122 |
+
self.scene_folders = list(filter(lambda x: os.path.exists(os.path.join(x, "transforms.json")), self.scene_folders))
|
123 |
+
|
124 |
+
# if valid_paths.txt exists, only use paths in that file
|
125 |
+
if os.path.exists(valid_paths):
|
126 |
+
with open(valid_paths, 'r') as f:
|
127 |
+
valid_scene_folders = f.read().splitlines()
|
128 |
+
self.scene_folders = sorted(valid_scene_folders)
|
129 |
+
|
130 |
+
self.c2w_scaling = c2w_scaling
|
131 |
+
self.do_normalize = do_normalize
|
132 |
+
self.default_max_distance = default_max_distance
|
133 |
+
self.swap_xz = swap_xz
|
134 |
+
self.frame_sliding_windows = frame_sliding_windows
|
135 |
+
|
136 |
+
if multiscale:
|
137 |
+
assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880]
|
138 |
+
if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
|
139 |
+
self.interpolate_model = T.InterpolationMode.LANCZOS
|
140 |
+
self.ratio_index = {}
|
141 |
+
self.ratio_nums = {}
|
142 |
+
for k, v in self.aspect_ratio.items():
|
143 |
+
self.ratio_index[float(k)] = [] # used for self.getitem
|
144 |
+
self.ratio_nums[float(k)] = 0 # used for batch-sampler
|
145 |
+
|
146 |
+
def __len__(self):
|
147 |
+
return len(self.scene_folders)
|
148 |
+
|
149 |
+
def __getitem__(self, idx):
|
150 |
+
try:
|
151 |
+
scene_path = self.scene_folders[idx]
|
152 |
+
|
153 |
+
if os.path.exists(os.path.join(scene_path, "images")):
|
154 |
+
image_folder = os.path.join(scene_path, "images")
|
155 |
+
downscale_factor = 1
|
156 |
+
elif os.path.exists(os.path.join(scene_path, "images_4")):
|
157 |
+
image_folder = os.path.join(scene_path, "images_4")
|
158 |
+
downscale_factor = 1 / 4
|
159 |
+
elif os.path.exists(os.path.join(scene_path, "images_8")):
|
160 |
+
image_folder = os.path.join(scene_path, "images_8")
|
161 |
+
downscale_factor = 1 / 8
|
162 |
+
else:
|
163 |
+
raise NotImplementedError
|
164 |
+
|
165 |
+
json_path = os.path.join(scene_path, "transforms.json")
|
166 |
+
caption_path = os.path.join(scene_path, "caption.json")
|
167 |
+
image_files = os.listdir(image_folder)
|
168 |
+
|
169 |
+
with open(json_path, 'r') as f:
|
170 |
+
json_data = json.load(f)
|
171 |
+
height, width = json_data['h'], json_data['w']
|
172 |
+
|
173 |
+
dh, dw = int(height * downscale_factor), int(width * downscale_factor)
|
174 |
+
fl_x, fl_y = json_data['fl_x'] * downscale_factor, json_data['fl_y'] * downscale_factor
|
175 |
+
cx = dw // 2
|
176 |
+
cy = dh // 2
|
177 |
+
|
178 |
+
frame_list = json_data['frames']
|
179 |
+
|
180 |
+
# Randomly select number of samples
|
181 |
+
|
182 |
+
samples_per_set = random.randint(self.samples_range[0], self.samples_range[1])
|
183 |
+
|
184 |
+
# uniformly for all scenes
|
185 |
+
if self.frame_sliding_windows is None:
|
186 |
+
selected_indices = random.sample(range(len(frame_list)), min(samples_per_set, len(frame_list)))
|
187 |
+
# limit the multiview to be in a sliding window (to avoid catastrophic difference in camera angles)
|
188 |
+
else:
|
189 |
+
# Determine the starting index of the sliding window
|
190 |
+
if len(frame_list) <= self.frame_sliding_windows:
|
191 |
+
# If the frame list is smaller than or equal to X, use the entire list
|
192 |
+
window_start = 0
|
193 |
+
window_end = len(frame_list)
|
194 |
+
else:
|
195 |
+
# Randomly select a starting point for the window
|
196 |
+
window_start = random.randint(0, len(frame_list) - self.frame_sliding_windows)
|
197 |
+
window_end = window_start + self.frame_sliding_windows
|
198 |
+
|
199 |
+
# Get the indices within the sliding window
|
200 |
+
window_indices = list(range(window_start, window_end))
|
201 |
+
|
202 |
+
# Randomly sample indices from the window
|
203 |
+
selected_indices = random.sample(window_indices, samples_per_set)
|
204 |
+
|
205 |
+
image_files = [os.path.basename(frame_list[i]['file_path']) for i in selected_indices]
|
206 |
+
image_paths = [os.path.join(image_folder, file) for file in image_files]
|
207 |
+
|
208 |
+
# Load images and convert RGBA to RGB with white background
|
209 |
+
images = [convert_rgba_to_rgb_white_bg(Image.open(image_path)) for image_path in image_paths]
|
210 |
+
|
211 |
+
if self.transform:
|
212 |
+
images = [self.transform(image) for image in images]
|
213 |
+
else:
|
214 |
+
closest_size, closest_ratio = self.aspect_ratio['1.0'], 1.0
|
215 |
+
closest_size = tuple(map(int, closest_size))
|
216 |
+
transform = T.Compose([
|
217 |
+
T.ToTensor(),
|
218 |
+
CenterCropResizeImage(closest_size),
|
219 |
+
T.Normalize([.5], [.5]),
|
220 |
+
])
|
221 |
+
images = [transform(image) for image in images]
|
222 |
+
images = torch.stack(images)
|
223 |
+
|
224 |
+
c2ws = [frame_list[i]['transform_matrix'] for i in selected_indices]
|
225 |
+
c2ws = torch.tensor(c2ws).reshape(-1, 4, 4)
|
226 |
+
# max_distance = json_data.get('max_distance', self.default_max_distance)
|
227 |
+
# if 'max_distance' not in json_data.keys():
|
228 |
+
# print(f"not found `max_distance` in json path: {json_path}")
|
229 |
+
|
230 |
+
if self.swap_xz:
|
231 |
+
swap_xz = torch.tensor([[[0, 0, 1., 0],
|
232 |
+
[0, 1., 0, 0],
|
233 |
+
[-1., 0, 0, 0],
|
234 |
+
[0, 0, 0, 1.]]])
|
235 |
+
c2ws = swap_xz @ c2ws
|
236 |
+
|
237 |
+
# OPENGL to OPENCV
|
238 |
+
c2ws[:, 0:3, 1:3] *= -1
|
239 |
+
c2ws = c2ws[:, [1, 0, 2, 3], :]
|
240 |
+
c2ws[:, 2, :] *= -1
|
241 |
+
|
242 |
+
w2cs = torch.inverse(c2ws)
|
243 |
+
K = torch.tensor([[[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]]).repeat(len(c2ws), 1, 1)
|
244 |
+
Rs = w2cs[:, :3, :3]
|
245 |
+
Ts = w2cs[:, :3, 3]
|
246 |
+
sizes = torch.tensor([[dh, dw]]).repeat(len(c2ws), 1)
|
247 |
+
|
248 |
+
# get ray embedding and padding last dimension to 16 (num channels of VAE)
|
249 |
+
# rays_od = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, use_plucker=False, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
|
250 |
+
rays = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
|
251 |
+
rays = rays.reshape(samples_per_set, closest_size[0] // 8, closest_size[1] // 8, 6)
|
252 |
+
# padding = (0, 10) # pad the last dimension to 16
|
253 |
+
# rays = torch.nn.functional.pad(rays, padding, "constant", 0)
|
254 |
+
rays = torch.cat([rays, rays, rays[..., :4]], dim=-1) * 1.658
|
255 |
+
|
256 |
+
if os.path.exists(caption_path):
|
257 |
+
with open(caption_path, 'r') as f:
|
258 |
+
caption_key = random.choice(self.caption_keys)
|
259 |
+
caption = json.load(f).get(caption_key, "")
|
260 |
+
else:
|
261 |
+
caption = ""
|
262 |
+
|
263 |
+
caption = "[[multiview]] " + caption if caption else "[[multiview]]"
|
264 |
+
|
265 |
+
return {
|
266 |
+
'pixel_values': images,
|
267 |
+
'rays': rays,
|
268 |
+
'aspect_ratio': closest_ratio,
|
269 |
+
'caption': caption,
|
270 |
+
'height': dh,
|
271 |
+
'width': dw,
|
272 |
+
# 'origins': rays_od[..., :3],
|
273 |
+
# 'dirs': rays_od[..., 3:6]
|
274 |
+
}
|
275 |
+
except Exception as e:
|
276 |
+
return self.__getitem__(random.randint(0, len(self.scene_folders) - 1))
|
277 |
+
|
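
For orientation, here is a minimal sketch of constructing and iterating this dataset, based only on the constructor signature and the `__getitem__` return value above; the glob pattern and loader settings are placeholders, not part of the repository.

from torch.utils.data import DataLoader
from onediffusion.dataset.multitask.multiview import MultiviewDataset
from onediffusion.dataset.utils import ASPECT_RATIO_512

dataset = MultiviewDataset(
    scene_folders="data/scenes/*",   # placeholder glob; each folder must contain transforms.json and an images/ dir
    samples_per_set=(2, 4),          # sample between 2 and 4 views per scene
    caption_keys="caption",
    aspect_ratio_type=ASPECT_RATIO_512,
)
# batch_size=1 because the number of sampled views can differ between scenes
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

batch = next(iter(loader))
# pixel_values: (1, num_views, 3, H, W); rays: (1, num_views, H/8, W/8, 16); caption: list with one "[[multiview]] ..." string
print(batch['pixel_values'].shape, batch['rays'].shape, batch['caption'])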
onediffusion/dataset/raydiff_utils.py
ADDED
@@ -0,0 +1,739 @@
1 |
+
|
2 |
+
"""
|
3 |
+
Adapted from code originally written by David Novotny.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from pytorch3d.transforms import Rotate, Translate
|
8 |
+
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from pytorch3d.renderer import PerspectiveCameras, RayBundle
|
13 |
+
|
14 |
+
def intersect_skew_line_groups(p, r, mask):
|
15 |
+
# p, r both of shape (B, N, n_intersected_lines, 3)
|
16 |
+
# mask of shape (B, N, n_intersected_lines)
|
17 |
+
p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
|
18 |
+
if p_intersect is None:
|
19 |
+
return None, None, None, None
|
20 |
+
_, p_line_intersect = point_line_distance(
|
21 |
+
p, r, p_intersect[..., None, :].expand_as(p)
|
22 |
+
)
|
23 |
+
intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
|
24 |
+
dim=-1
|
25 |
+
)
|
26 |
+
return p_intersect, p_line_intersect, intersect_dist_squared, r
|
27 |
+
|
28 |
+
|
29 |
+
def intersect_skew_lines_high_dim(p, r, mask=None):
|
30 |
+
# Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
|
31 |
+
dim = p.shape[-1]
|
32 |
+
# make sure the heading vectors are l2-normed
|
33 |
+
if mask is None:
|
34 |
+
mask = torch.ones_like(p[..., 0])
|
35 |
+
r = torch.nn.functional.normalize(r, dim=-1)
|
36 |
+
|
37 |
+
eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
|
38 |
+
I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
|
39 |
+
sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
|
40 |
+
|
41 |
+
# I_eps = torch.zeros_like(I_min_cov.sum(dim=-3)) + 1e-10
|
42 |
+
# p_intersect = torch.pinverse(I_min_cov.sum(dim=-3) + I_eps).matmul(sum_proj)[..., 0]
|
43 |
+
p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
|
44 |
+
|
45 |
+
# I_min_cov.sum(dim=-3): torch.Size([1, 1, 3, 3])
|
46 |
+
# sum_proj: torch.Size([1, 1, 3, 1])
|
47 |
+
|
48 |
+
# p_intersect = np.linalg.lstsq(I_min_cov.sum(dim=-3).numpy(), sum_proj.numpy(), rcond=None)[0]
|
49 |
+
|
50 |
+
if torch.any(torch.isnan(p_intersect)):
|
51 |
+
print(p_intersect)
|
52 |
+
return None, None
|
53 |
+
# ipdb.set_trace()  # unreachable after the return above; ipdb is not imported
|
54 |
+
# assert False  # unreachable
|
55 |
+
return p_intersect, r
|
56 |
+
|
57 |
+
|
58 |
+
def point_line_distance(p1, r1, p2):
|
59 |
+
df = p2 - p1
|
60 |
+
proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
|
61 |
+
line_pt_nearest = p2 - proj_vector
|
62 |
+
d = (proj_vector).norm(dim=-1)
|
63 |
+
return d, line_pt_nearest
|
64 |
+
|
65 |
+
|
66 |
+
def compute_optical_axis_intersection(cameras):
|
67 |
+
centers = cameras.get_camera_center()
|
68 |
+
principal_points = cameras.principal_point
|
69 |
+
|
70 |
+
one_vec = torch.ones((len(cameras), 1), device=centers.device)
|
71 |
+
optical_axis = torch.cat((principal_points, one_vec), -1)
|
72 |
+
|
73 |
+
# optical_axis = torch.cat(
|
74 |
+
# (principal_points, cameras.focal_length[:, 0].unsqueeze(1)), -1
|
75 |
+
# )
|
76 |
+
|
77 |
+
pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
|
78 |
+
pp2 = torch.diagonal(pp, dim1=0, dim2=1).T
|
79 |
+
|
80 |
+
directions = pp2 - centers
|
81 |
+
centers = centers.unsqueeze(0).unsqueeze(0)
|
82 |
+
directions = directions.unsqueeze(0).unsqueeze(0)
|
83 |
+
|
84 |
+
p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
|
85 |
+
p=centers, r=directions, mask=None
|
86 |
+
)
|
87 |
+
|
88 |
+
if p_intersect is None:
|
89 |
+
dist = None
|
90 |
+
else:
|
91 |
+
p_intersect = p_intersect.squeeze().unsqueeze(0)
|
92 |
+
dist = (p_intersect - centers).norm(dim=-1)
|
93 |
+
|
94 |
+
return p_intersect, dist, p_line_intersect, pp2, r
|
95 |
+
|
96 |
+
|
97 |
+
def normalize_cameras(cameras, scale=1.0):
|
98 |
+
"""
|
99 |
+
Normalizes cameras such that the optical axes point to the origin, the rotation is
|
100 |
+
identity, and the norm of the translation of the first camera is 1.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
cameras (pytorch3d.renderer.cameras.CamerasBase).
|
104 |
+
scale (float): Norm of the translation of the first camera.
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
new_cameras (pytorch3d.renderer.cameras.CamerasBase): Normalized cameras.
|
108 |
+
undo_transform (function): Function that undoes the normalization.
|
109 |
+
"""
|
110 |
+
|
111 |
+
# Let distance from first camera to origin be unit
|
112 |
+
new_cameras = cameras.clone()
|
113 |
+
new_transform = (
|
114 |
+
new_cameras.get_world_to_view_transform()
|
115 |
+
) # potential R is not valid matrix
|
116 |
+
p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
|
117 |
+
cameras
|
118 |
+
)
|
119 |
+
|
120 |
+
if p_intersect is None:
|
121 |
+
print("Warning: optical axes code has a nan. Returning identity cameras.")
|
122 |
+
new_cameras.R[:] = torch.eye(3, device=cameras.R.device, dtype=cameras.R.dtype)
|
123 |
+
new_cameras.T[:] = torch.tensor(
|
124 |
+
[0, 0, 1], device=cameras.T.device, dtype=cameras.T.dtype
|
125 |
+
)
|
126 |
+
return new_cameras, lambda x: x
|
127 |
+
|
128 |
+
d = dist.squeeze(dim=1).squeeze(dim=0)[0]
|
129 |
+
# Degenerate case
|
130 |
+
if d == 0:
|
131 |
+
print(cameras.T)
|
132 |
+
print(new_transform.get_matrix()[:, 3, :3])
|
133 |
+
assert False
|
134 |
+
assert d != 0
|
135 |
+
|
136 |
+
# Can't figure out how to make scale part of the transform too without messing up R.
|
137 |
+
# Ideally, we would just wrap it all in a single Pytorch3D transform so that it
|
138 |
+
# would work with any structure (eg PointClouds, Meshes).
|
139 |
+
tR = Rotate(new_cameras.R[0].unsqueeze(0)).inverse()
|
140 |
+
tT = Translate(p_intersect)
|
141 |
+
t = tR.compose(tT)
|
142 |
+
|
143 |
+
new_transform = t.compose(new_transform)
|
144 |
+
new_cameras.R = new_transform.get_matrix()[:, :3, :3]
|
145 |
+
new_cameras.T = new_transform.get_matrix()[:, 3, :3] / d * scale
|
146 |
+
|
147 |
+
def undo_transform(cameras):
|
148 |
+
cameras_copy = cameras.clone()
|
149 |
+
cameras_copy.T *= d / scale
|
150 |
+
new_t = (
|
151 |
+
t.inverse().compose(cameras_copy.get_world_to_view_transform()).get_matrix()
|
152 |
+
)
|
153 |
+
cameras_copy.R = new_t[:, :3, :3]
|
154 |
+
cameras_copy.T = new_t[:, 3, :3]
|
155 |
+
return cameras_copy
|
156 |
+
|
157 |
+
return new_cameras, undo_transform
|
158 |
+
|
159 |
+
def first_camera_transform(cameras, rotation_only=True):
|
160 |
+
new_cameras = cameras.clone()
|
161 |
+
new_transform = new_cameras.get_world_to_view_transform()
|
162 |
+
tR = Rotate(new_cameras.R[0].unsqueeze(0))
|
163 |
+
if rotation_only:
|
164 |
+
t = tR.inverse()
|
165 |
+
else:
|
166 |
+
tT = Translate(new_cameras.T[0].unsqueeze(0))
|
167 |
+
t = tR.compose(tT).inverse()
|
168 |
+
|
169 |
+
new_transform = t.compose(new_transform)
|
170 |
+
new_cameras.R = new_transform.get_matrix()[:, :3, :3]
|
171 |
+
new_cameras.T = new_transform.get_matrix()[:, 3, :3]
|
172 |
+
|
173 |
+
return new_cameras
|
174 |
+
|
175 |
+
|
176 |
+
def get_identity_cameras_with_intrinsics(cameras):
|
177 |
+
D = len(cameras)
|
178 |
+
device = cameras.R.device
|
179 |
+
|
180 |
+
new_cameras = cameras.clone()
|
181 |
+
new_cameras.R = torch.eye(3, device=device).unsqueeze(0).repeat((D, 1, 1))
|
182 |
+
new_cameras.T = torch.zeros((D, 3), device=device)
|
183 |
+
|
184 |
+
return new_cameras
|
185 |
+
|
186 |
+
|
187 |
+
def normalize_cameras_batch(cameras, scale=1.0, normalize_first_camera=False):
|
188 |
+
new_cameras = []
|
189 |
+
undo_transforms = []
|
190 |
+
for cam in cameras:
|
191 |
+
if normalize_first_camera:
|
192 |
+
# Normalize cameras such that first camera is identity and origin is at
|
193 |
+
# first camera center.
|
194 |
+
normalized_cameras = first_camera_transform(cam, rotation_only=False)
|
195 |
+
undo_transform = None
|
196 |
+
else:
|
197 |
+
normalized_cameras, undo_transform = normalize_cameras(cam, scale=scale)
|
198 |
+
new_cameras.append(normalized_cameras)
|
199 |
+
undo_transforms.append(undo_transform)
|
200 |
+
return new_cameras, undo_transforms
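For orientation, a minimal usage sketch of the batch normalization above. It assumes PyTorch3D is installed, that this file is importable as `onediffusion.dataset.raydiff_utils`, and that the camera poses (a small orbit around the origin) are purely illustrative:

```python
import torch
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform

from onediffusion.dataset.raydiff_utils import normalize_cameras_batch

# Four cameras orbiting the origin, so their optical axes actually intersect.
R, T = look_at_view_transform(dist=2.7, elev=10, azim=[0, 90, 180, 270])
cams = PerspectiveCameras(R=R, T=T, focal_length=[3.453] * 4)

# One scene per list entry; scale=1.0 puts the first camera at unit distance.
new_cams, undo = normalize_cameras_batch([cams], scale=1.0)
print(new_cams[0].T.norm(dim=-1)[0])        # ~1.0 for the first camera
restored = undo[0](new_cams[0])             # maps the normalized cameras back
print((restored.T - cams.T).abs().max())    # ~0 up to numerical error
```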
|
201 |
+
|
202 |
+
|
203 |
+
class Rays(object):
|
204 |
+
def __init__(
|
205 |
+
self,
|
206 |
+
rays=None,
|
207 |
+
origins=None,
|
208 |
+
directions=None,
|
209 |
+
moments=None,
|
210 |
+
is_plucker=False,
|
211 |
+
moments_rescale=1.0,
|
212 |
+
ndc_coordinates=None,
|
213 |
+
crop_parameters=None,
|
214 |
+
num_patches_x=16,
|
215 |
+
num_patches_y=16,
|
216 |
+
):
|
217 |
+
"""
|
218 |
+
Ray class to keep track of current ray representation.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
rays: (..., 6).
|
222 |
+
origins: (..., 3).
|
223 |
+
directions: (..., 3).
|
224 |
+
moments: (..., 3).
|
225 |
+
is_plucker: If True, rays are in plucker coordinates (Default: False).
|
226 |
+
moments_rescale: Rescale the moment component of the rays by a scalar.
|
227 |
+
ndc_coordinates: (..., 2): NDC coordinates of each ray.
|
228 |
+
"""
|
229 |
+
if rays is not None:
|
230 |
+
self.rays = rays
|
231 |
+
self._is_plucker = is_plucker
|
232 |
+
elif origins is not None and directions is not None:
|
233 |
+
self.rays = torch.cat((origins, directions), dim=-1)
|
234 |
+
self._is_plucker = False
|
235 |
+
elif directions is not None and moments is not None:
|
236 |
+
self.rays = torch.cat((directions, moments), dim=-1)
|
237 |
+
self._is_plucker = True
|
238 |
+
else:
|
239 |
+
raise Exception("Invalid combination of arguments")
|
240 |
+
|
241 |
+
if moments_rescale != 1.0:
|
242 |
+
self.rescale_moments(moments_rescale)
|
243 |
+
|
244 |
+
if ndc_coordinates is not None:
|
245 |
+
self.ndc_coordinates = ndc_coordinates
|
246 |
+
elif crop_parameters is not None:
|
247 |
+
# (..., H, W, 2)
|
248 |
+
xy_grid = compute_ndc_coordinates(
|
249 |
+
crop_parameters,
|
250 |
+
num_patches_x=num_patches_x,
|
251 |
+
num_patches_y=num_patches_y,
|
252 |
+
)[..., :2]
|
253 |
+
xy_grid = xy_grid.reshape(*xy_grid.shape[:-3], -1, 2)
|
254 |
+
self.ndc_coordinates = xy_grid
|
255 |
+
else:
|
256 |
+
self.ndc_coordinates = None
|
257 |
+
|
258 |
+
def __getitem__(self, index):
|
259 |
+
return Rays(
|
260 |
+
rays=self.rays[index],
|
261 |
+
is_plucker=self._is_plucker,
|
262 |
+
ndc_coordinates=(
|
263 |
+
self.ndc_coordinates[index]
|
264 |
+
if self.ndc_coordinates is not None
|
265 |
+
else None
|
266 |
+
),
|
267 |
+
)
|
268 |
+
|
269 |
+
def to_spatial(self, include_ndc_coordinates=False):
|
270 |
+
"""
|
271 |
+
Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)
|
272 |
+
|
273 |
+
Returns:
|
274 |
+
torch.Tensor: (..., 6, H, W)
|
275 |
+
"""
|
276 |
+
rays = self.to_plucker().rays
|
277 |
+
*batch_dims, P, D = rays.shape
|
278 |
+
H = W = int(np.sqrt(P))
|
279 |
+
assert H * W == P
|
280 |
+
rays = torch.transpose(rays, -1, -2) # (..., 6, H * W)
|
281 |
+
rays = rays.reshape(*batch_dims, D, H, W)
|
282 |
+
if include_ndc_coordinates:
|
283 |
+
ndc_coords = self.ndc_coordinates.transpose(-1, -2) # (..., 2, H * W)
|
284 |
+
ndc_coords = ndc_coords.reshape(*batch_dims, 2, H, W)
|
285 |
+
rays = torch.cat((rays, ndc_coords), dim=-3)
|
286 |
+
return rays
|
287 |
+
|
288 |
+
def rescale_moments(self, scale):
|
289 |
+
"""
|
290 |
+
Rescale the moment component of the rays by a scalar. Might be desirable since
|
291 |
+
moments may come from a very narrow distribution.
|
292 |
+
|
293 |
+
Note that this modifies in place!
|
294 |
+
"""
|
295 |
+
if self.is_plucker:
|
296 |
+
self.rays[..., 3:] *= scale
|
297 |
+
return self
|
298 |
+
else:
|
299 |
+
return self.to_plucker().rescale_moments(scale)
|
300 |
+
|
301 |
+
@classmethod
|
302 |
+
def from_spatial(cls, rays, moments_rescale=1.0, ndc_coordinates=None):
|
303 |
+
"""
|
304 |
+
Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6)
|
305 |
+
|
306 |
+
Args:
|
307 |
+
rays: (..., 6, H, W)
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
Rays: (..., H * W, 6)
|
311 |
+
"""
|
312 |
+
*batch_dims, D, H, W = rays.shape
|
313 |
+
rays = rays.reshape(*batch_dims, D, H * W)
|
314 |
+
rays = torch.transpose(rays, -1, -2)
|
315 |
+
return cls(
|
316 |
+
rays=rays,
|
317 |
+
is_plucker=True,
|
318 |
+
moments_rescale=moments_rescale,
|
319 |
+
ndc_coordinates=ndc_coordinates,
|
320 |
+
)
|
321 |
+
|
322 |
+
def to_point_direction(self, normalize_moment=True):
|
323 |
+
"""
|
324 |
+
Convert to point direction representation <O, D>.
|
325 |
+
|
326 |
+
Returns:
|
327 |
+
rays: (..., 6).
|
328 |
+
"""
|
329 |
+
if self._is_plucker:
|
330 |
+
direction = torch.nn.functional.normalize(self.rays[..., :3], dim=-1)
|
331 |
+
moment = self.rays[..., 3:]
|
332 |
+
if normalize_moment:
|
333 |
+
c = torch.linalg.norm(direction, dim=-1, keepdim=True)
|
334 |
+
moment = moment / c
|
335 |
+
points = torch.cross(direction, moment, dim=-1)
|
336 |
+
return Rays(
|
337 |
+
rays=torch.cat((points, direction), dim=-1),
|
338 |
+
is_plucker=False,
|
339 |
+
ndc_coordinates=self.ndc_coordinates,
|
340 |
+
)
|
341 |
+
else:
|
342 |
+
return self
|
343 |
+
|
344 |
+
def to_plucker(self):
|
345 |
+
"""
|
346 |
+
Convert to plucker representation <D, OxD>.
|
347 |
+
"""
|
348 |
+
if self.is_plucker:
|
349 |
+
return self
|
350 |
+
else:
|
351 |
+
ray = self.rays.clone()
|
352 |
+
ray_origins = ray[..., :3]
|
353 |
+
ray_directions = ray[..., 3:]
|
354 |
+
# Normalize ray directions to unit vectors
|
355 |
+
ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
|
356 |
+
plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
|
357 |
+
new_ray = torch.cat([ray_directions, plucker_normal], dim=-1)
|
358 |
+
return Rays(
|
359 |
+
rays=new_ray, is_plucker=True, ndc_coordinates=self.ndc_coordinates
|
360 |
+
)
|
361 |
+
|
362 |
+
def get_directions(self, normalize=True):
|
363 |
+
if self.is_plucker:
|
364 |
+
directions = self.rays[..., :3]
|
365 |
+
else:
|
366 |
+
directions = self.rays[..., 3:]
|
367 |
+
if normalize:
|
368 |
+
directions = torch.nn.functional.normalize(directions, dim=-1)
|
369 |
+
return directions
|
370 |
+
|
371 |
+
def get_origins(self):
|
372 |
+
if self.is_plucker:
|
373 |
+
origins = self.to_point_direction().get_origins()
|
374 |
+
else:
|
375 |
+
origins = self.rays[..., :3]
|
376 |
+
return origins
|
377 |
+
|
378 |
+
def get_moments(self):
|
379 |
+
if self.is_plucker:
|
380 |
+
moments = self.rays[..., 3:]
|
381 |
+
else:
|
382 |
+
moments = self.to_plucker().get_moments()
|
383 |
+
return moments
|
384 |
+
|
385 |
+
def get_ndc_coordinates(self):
|
386 |
+
return self.ndc_coordinates
|
387 |
+
|
388 |
+
@property
|
389 |
+
def is_plucker(self):
|
390 |
+
return self._is_plucker
|
391 |
+
|
392 |
+
@property
|
393 |
+
def device(self):
|
394 |
+
return self.rays.device
|
395 |
+
|
396 |
+
def __repr__(self, *args, **kwargs):
|
397 |
+
ray_str = self.rays.__repr__(*args, **kwargs)[6:] # remove "tensor"
|
398 |
+
if self._is_plucker:
|
399 |
+
return "PluRay" + ray_str
|
400 |
+
else:
|
401 |
+
return "DirRay" + ray_str
|
402 |
+
|
403 |
+
def to(self, device):
|
404 |
+
self.rays = self.rays.to(device)
|
405 |
+
|
406 |
+
def clone(self):
|
407 |
+
return Rays(rays=self.rays.clone(), is_plucker=self._is_plucker)
|
408 |
+
|
409 |
+
@property
|
410 |
+
def shape(self):
|
411 |
+
return self.rays.shape
|
412 |
+
|
413 |
+
def visualize(self):
|
414 |
+
directions = torch.nn.functional.normalize(self.get_directions(), dim=-1).cpu()
|
415 |
+
moments = torch.nn.functional.normalize(self.get_moments(), dim=-1).cpu()
|
416 |
+
return (directions + 1) / 2, (moments + 1) / 2
|
417 |
+
|
418 |
+
def to_ray_bundle(self, length=0.3, recenter=True):
|
419 |
+
lengths = torch.ones_like(self.get_origins()[..., :2]) * length
|
420 |
+
lengths[..., 0] = 0
|
421 |
+
if recenter:
|
422 |
+
centers, _ = intersect_skew_lines_high_dim(
|
423 |
+
self.get_origins(), self.get_directions()
|
424 |
+
)
|
425 |
+
centers = centers.unsqueeze(1).repeat(1, lengths.shape[1], 1)
|
426 |
+
else:
|
427 |
+
centers = self.get_origins()
|
428 |
+
return RayBundle(
|
429 |
+
origins=centers,
|
430 |
+
directions=self.get_directions(),
|
431 |
+
lengths=lengths,
|
432 |
+
xys=self.get_directions(),
|
433 |
+
)
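A short sketch of how the `Rays` container is typically moved between representations (all values are arbitrary; 256 rays correspond to a 16x16 patch grid):

```python
import torch
from onediffusion.dataset.raydiff_utils import Rays

origins = torch.zeros(2, 256, 3)                                        # (B, P, 3)
directions = torch.nn.functional.normalize(torch.randn(2, 256, 3), dim=-1)

rays = Rays(origins=origins, directions=directions)   # point-direction form <O, D>
plucker = rays.to_plucker()                           # Plucker form <D, O x D>
spatial = plucker.to_spatial()                        # (B, 6, 16, 16) spatial layout
back = Rays.from_spatial(spatial)                     # back to (B, 256, 6), Plucker

print(rays.shape, spatial.shape, back.is_plucker)
```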
|
434 |
+
|
435 |
+
|
436 |
+
def cameras_to_rays(
|
437 |
+
cameras,
|
438 |
+
crop_parameters,
|
439 |
+
use_half_pix=True,
|
440 |
+
use_plucker=True,
|
441 |
+
num_patches_x=16,
|
442 |
+
num_patches_y=16,
|
443 |
+
):
|
444 |
+
"""
|
445 |
+
Unprojects rays from camera center to grid on image plane.
|
446 |
+
|
447 |
+
Args:
|
448 |
+
cameras: Pytorch3D cameras to unproject. Can be batched.
|
449 |
+
crop_parameters: Crop parameters in NDC (cc_x, cc_y, crop_width, scale).
|
450 |
+
Shape is (B, 4).
|
451 |
+
use_half_pix: If True, use half pixel offset (Default: True).
|
452 |
+
use_plucker: If True, return rays in plucker coordinates (Default: True).

|
453 |
+
num_patches_x: Number of patches in x direction (Default: 16).
|
454 |
+
num_patches_y: Number of patches in y direction (Default: 16).
|
455 |
+
"""
|
456 |
+
unprojected = []
|
457 |
+
crop_parameters_list = (
|
458 |
+
crop_parameters if crop_parameters is not None else [None for _ in cameras]
|
459 |
+
)
|
460 |
+
for camera, crop_param in zip(cameras, crop_parameters_list):
|
461 |
+
xyd_grid = compute_ndc_coordinates(
|
462 |
+
crop_parameters=crop_param,
|
463 |
+
use_half_pix=use_half_pix,
|
464 |
+
num_patches_x=num_patches_x,
|
465 |
+
num_patches_y=num_patches_y,
|
466 |
+
)
|
467 |
+
|
468 |
+
unprojected.append(
|
469 |
+
camera.unproject_points(
|
470 |
+
xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
|
471 |
+
)
|
472 |
+
)
|
473 |
+
unprojected = torch.stack(unprojected, dim=0) # (N, P, 3)
|
474 |
+
origins = cameras.get_camera_center().unsqueeze(1) # (N, 1, 3)
|
475 |
+
origins = origins.repeat(1, num_patches_x * num_patches_y, 1) # (N, P, 3)
|
476 |
+
directions = unprojected - origins
|
477 |
+
|
478 |
+
rays = Rays(
|
479 |
+
origins=origins,
|
480 |
+
directions=directions,
|
481 |
+
crop_parameters=crop_parameters,
|
482 |
+
num_patches_x=num_patches_x,
|
483 |
+
num_patches_y=num_patches_y,
|
484 |
+
)
|
485 |
+
if use_plucker:
|
486 |
+
return rays.to_plucker()
|
487 |
+
return rays
|
488 |
+
|
489 |
+
|
490 |
+
def rays_to_cameras(
|
491 |
+
rays,
|
492 |
+
crop_parameters,
|
493 |
+
num_patches_x=16,
|
494 |
+
num_patches_y=16,
|
495 |
+
use_half_pix=True,
|
496 |
+
sampled_ray_idx=None,
|
497 |
+
cameras=None,
|
498 |
+
focal_length=(3.453,),
|
499 |
+
):
|
500 |
+
"""
|
501 |
+
If cameras are provided, will use those intrinsics. Otherwise will use the provided
|
502 |
+
focal_length(s). Dataset default is 3.32.
|
503 |
+
|
504 |
+
Args:
|
505 |
+
rays (Rays): (N, P, 6)
|
506 |
+
crop_parameters (torch.Tensor): (N, 4)
|
507 |
+
"""
|
508 |
+
device = rays.device
|
509 |
+
origins = rays.get_origins()
|
510 |
+
directions = rays.get_directions()
|
511 |
+
camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)
|
512 |
+
|
513 |
+
# Retrieve target rays
|
514 |
+
if cameras is None:
|
515 |
+
if len(focal_length) == 1:
|
516 |
+
focal_length = focal_length * rays.shape[0]
|
517 |
+
I_camera = PerspectiveCameras(focal_length=focal_length, device=device)
|
518 |
+
else:
|
519 |
+
# Use same intrinsics but reset to identity extrinsics.
|
520 |
+
I_camera = cameras.clone()
|
521 |
+
I_camera.R[:] = torch.eye(3, device=device)
|
522 |
+
I_camera.T[:] = torch.zeros(3, device=device)
|
523 |
+
I_patch_rays = cameras_to_rays(
|
524 |
+
cameras=I_camera,
|
525 |
+
num_patches_x=num_patches_x,
|
526 |
+
num_patches_y=num_patches_y,
|
527 |
+
use_half_pix=use_half_pix,
|
528 |
+
crop_parameters=crop_parameters,
|
529 |
+
).get_directions()
|
530 |
+
|
531 |
+
if sampled_ray_idx is not None:
|
532 |
+
I_patch_rays = I_patch_rays[:, sampled_ray_idx]
|
533 |
+
|
534 |
+
# Compute optimal rotation to align rays
|
535 |
+
R = torch.zeros_like(I_camera.R)
|
536 |
+
for i in range(len(I_camera)):
|
537 |
+
R[i] = compute_optimal_rotation_alignment(
|
538 |
+
I_patch_rays[i],
|
539 |
+
directions[i],
|
540 |
+
)
|
541 |
+
|
542 |
+
# Construct and return rotated camera
|
543 |
+
cam = I_camera.clone()
|
544 |
+
cam.R = R
|
545 |
+
cam.T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
|
546 |
+
return cam
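A round-trip sketch, cameras to per-patch Plucker rays and back, again assuming PyTorch3D and illustrative poses; with matching focal lengths the recovered rotations should agree with the originals up to numerical error:

```python
import torch
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
from onediffusion.dataset.raydiff_utils import cameras_to_rays, rays_to_cameras

R, T = look_at_view_transform(dist=2.7, elev=10, azim=[0, 60, 120])
cams = PerspectiveCameras(R=R, T=T, focal_length=[3.453] * 3)

rays = cameras_to_rays(cams, crop_parameters=None,          # None = full-image crop
                       num_patches_x=16, num_patches_y=16)
recovered = rays_to_cameras(rays, crop_parameters=None,
                            num_patches_x=16, num_patches_y=16,
                            focal_length=(3.453,))
print((recovered.R - cams.R).abs().max())   # small when the focal lengths match
```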
|
547 |
+
|
548 |
+
|
549 |
+
# https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/
|
550 |
+
def ql_decomposition(A):
|
551 |
+
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float()
|
552 |
+
A_tilde = torch.matmul(A, P)
|
553 |
+
Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
|
554 |
+
Q = torch.matmul(Q_tilde, P)
|
555 |
+
L = torch.matmul(torch.matmul(P, R_tilde), P)
|
556 |
+
d = torch.diag(L)
|
557 |
+
Q[:, 0] *= torch.sign(d[0])
|
558 |
+
Q[:, 1] *= torch.sign(d[1])
|
559 |
+
Q[:, 2] *= torch.sign(d[2])
|
560 |
+
L[0] *= torch.sign(d[0])
|
561 |
+
L[1] *= torch.sign(d[1])
|
562 |
+
L[2] *= torch.sign(d[2])
|
563 |
+
return Q, L
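Since the permutation trick above is easy to get wrong, a quick self-check sketch: the factors should multiply back to `A`, with `Q` orthogonal and `L` lower-triangular:

```python
import torch
from onediffusion.dataset.raydiff_utils import ql_decomposition

A = torch.randn(3, 3)
Q, L = ql_decomposition(A)
print(torch.allclose(Q @ L, A, atol=1e-5))               # A is reconstructed
print(torch.allclose(Q.T @ Q, torch.eye(3), atol=1e-5))  # Q is orthogonal
print(L)                                                 # lower-triangular up to float noise
```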
|
564 |
+
|
565 |
+
|
566 |
+
def rays_to_cameras_homography(
|
567 |
+
rays,
|
568 |
+
crop_parameters,
|
569 |
+
num_patches_x=16,
|
570 |
+
num_patches_y=16,
|
571 |
+
use_half_pix=True,
|
572 |
+
sampled_ray_idx=None,
|
573 |
+
reproj_threshold=0.2,
|
574 |
+
):
|
575 |
+
"""
|
576 |
+
Args:
|
577 |
+
rays (Rays): (N, P, 6)
|
578 |
+
crop_parameters (torch.Tensor): (N, 4)
|
579 |
+
"""
|
580 |
+
device = rays.device
|
581 |
+
origins = rays.get_origins()
|
582 |
+
directions = rays.get_directions()
|
583 |
+
camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)
|
584 |
+
|
585 |
+
# Retrieve target rays
|
586 |
+
I_camera = PerspectiveCameras(focal_length=[1] * rays.shape[0], device=device)
|
587 |
+
I_patch_rays = cameras_to_rays(
|
588 |
+
cameras=I_camera,
|
589 |
+
num_patches_x=num_patches_x,
|
590 |
+
num_patches_y=num_patches_y,
|
591 |
+
use_half_pix=use_half_pix,
|
592 |
+
crop_parameters=crop_parameters,
|
593 |
+
).get_directions()
|
594 |
+
|
595 |
+
if sampled_ray_idx is not None:
|
596 |
+
I_patch_rays = I_patch_rays[:, sampled_ray_idx]
|
597 |
+
|
598 |
+
# Compute optimal rotation to align rays
|
599 |
+
Rs = []
|
600 |
+
focal_lengths = []
|
601 |
+
principal_points = []
|
602 |
+
for i in range(rays.shape[-3]):
|
603 |
+
R, f, pp = compute_optimal_rotation_intrinsics(
|
604 |
+
I_patch_rays[i],
|
605 |
+
directions[i],
|
606 |
+
reproj_threshold=reproj_threshold,
|
607 |
+
)
|
608 |
+
Rs.append(R)
|
609 |
+
focal_lengths.append(f)
|
610 |
+
principal_points.append(pp)
|
611 |
+
|
612 |
+
R = torch.stack(Rs)
|
613 |
+
focal_lengths = torch.stack(focal_lengths)
|
614 |
+
principal_points = torch.stack(principal_points)
|
615 |
+
T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
|
616 |
+
return PerspectiveCameras(
|
617 |
+
R=R,
|
618 |
+
T=T,
|
619 |
+
focal_length=focal_lengths,
|
620 |
+
principal_point=principal_points,
|
621 |
+
device=device,
|
622 |
+
)
|
623 |
+
|
624 |
+
|
625 |
+
def compute_optimal_rotation_alignment(A, B):
|
626 |
+
"""
|
627 |
+
Compute optimal R that minimizes: || A - B @ R ||_F
|
628 |
+
|
629 |
+
Args:
|
630 |
+
A (torch.Tensor): (N, 3)
|
631 |
+
B (torch.Tensor): (N, 3)
|
632 |
+
|
633 |
+
Returns:
|
634 |
+
R (torch.tensor): (3, 3)
|
635 |
+
"""
|
636 |
+
# normally with R @ B, this would be A @ B.T
|
637 |
+
H = B.T @ A
|
638 |
+
U, _, Vh = torch.linalg.svd(H, full_matrices=True)
|
639 |
+
s = torch.linalg.det(U @ Vh)
|
640 |
+
S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
|
641 |
+
return U @ S_prime @ Vh
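This is the standard orthogonal-Procrustes solution. A small sanity-check sketch (it borrows `random_rotation` from PyTorch3D, an assumption of the example, not of the module) recovers a known rotation:

```python
import torch
from pytorch3d.transforms import random_rotation
from onediffusion.dataset.raydiff_utils import compute_optimal_rotation_alignment

R_true = random_rotation()                      # ground-truth (3, 3) rotation
B = torch.randn(256, 3)                         # e.g. ray directions from one camera
A = B @ R_true                                  # the same rays after rotating
R_est = compute_optimal_rotation_alignment(A, B)
print((R_est - R_true).abs().max())             # ~1e-6 in float32
```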
|
642 |
+
|
643 |
+
|
644 |
+
def compute_optimal_rotation_intrinsics(
|
645 |
+
rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2
|
646 |
+
):
|
647 |
+
"""
|
648 |
+
Note: for some reason, f seems to be 1/f.
|
649 |
+
|
650 |
+
Args:
|
651 |
+
rays_origin (torch.Tensor): (N, 3)
|
652 |
+
rays_target (torch.Tensor): (N, 3)
|
653 |
+
z_threshold (float): Threshold for z value to be considered valid.
|
654 |
+
|
655 |
+
Returns:
|
656 |
+
R (torch.tensor): (3, 3)
|
657 |
+
focal_length (torch.tensor): (2,)
|
658 |
+
principal_point (torch.tensor): (2,)
|
659 |
+
"""
|
660 |
+
device = rays_origin.device
|
661 |
+
z_mask = torch.logical_and(
|
662 |
+
torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold
|
663 |
+
)[:, 2]
|
664 |
+
rays_target = rays_target[z_mask]
|
665 |
+
rays_origin = rays_origin[z_mask]
|
666 |
+
rays_origin = rays_origin[:, :2] / rays_origin[:, -1:]
|
667 |
+
rays_target = rays_target[:, :2] / rays_target[:, -1:]
|
668 |
+
|
669 |
+
A, _ = cv2.findHomography(
|
670 |
+
rays_origin.cpu().numpy(),
|
671 |
+
rays_target.cpu().numpy(),
|
672 |
+
cv2.RANSAC,
|
673 |
+
reproj_threshold,
|
674 |
+
)
|
675 |
+
A = torch.from_numpy(A).float().to(device)
|
676 |
+
|
677 |
+
if torch.linalg.det(A) < 0:
|
678 |
+
A = -A
|
679 |
+
|
680 |
+
R, L = ql_decomposition(A)
|
681 |
+
L = L / L[2][2]
|
682 |
+
|
683 |
+
f = torch.stack((L[0][0], L[1][1]))
|
684 |
+
pp = torch.stack((L[2][0], L[2][1]))
|
685 |
+
return R, f, pp
|
686 |
+
|
687 |
+
|
688 |
+
def compute_ndc_coordinates(
|
689 |
+
crop_parameters=None,
|
690 |
+
use_half_pix=True,
|
691 |
+
num_patches_x=16,
|
692 |
+
num_patches_y=16,
|
693 |
+
device=None,
|
694 |
+
):
|
695 |
+
"""
|
696 |
+
Computes NDC Grid using crop_parameters. If crop_parameters is not provided,
|
697 |
+
then it assumes that the crop is the entire image (corresponding to an NDC grid
|
698 |
+
where top left corner is (1, 1) and bottom right corner is (-1, -1)).
|
699 |
+
"""
|
700 |
+
if crop_parameters is None:
|
701 |
+
cc_x, cc_y, width = 0, 0, 2
|
702 |
+
else:
|
703 |
+
if len(crop_parameters.shape) > 1:
|
704 |
+
return torch.stack(
|
705 |
+
[
|
706 |
+
compute_ndc_coordinates(
|
707 |
+
crop_parameters=crop_param,
|
708 |
+
use_half_pix=use_half_pix,
|
709 |
+
num_patches_x=num_patches_x,
|
710 |
+
num_patches_y=num_patches_y,
|
711 |
+
)
|
712 |
+
for crop_param in crop_parameters
|
713 |
+
],
|
714 |
+
dim=0,
|
715 |
+
)
|
716 |
+
device = crop_parameters.device
|
717 |
+
cc_x, cc_y, width, _ = crop_parameters
|
718 |
+
|
719 |
+
dx = 1 / num_patches_x
|
720 |
+
dy = 1 / num_patches_y
|
721 |
+
if use_half_pix:
|
722 |
+
min_y = 1 - dy
|
723 |
+
max_y = -min_y
|
724 |
+
min_x = 1 - dx
|
725 |
+
max_x = -min_x
|
726 |
+
else:
|
727 |
+
min_y = min_x = 1
|
728 |
+
max_y = -1 + 2 * dy
|
729 |
+
max_x = -1 + 2 * dx
|
730 |
+
|
731 |
+
y, x = torch.meshgrid(
|
732 |
+
torch.linspace(min_y, max_y, num_patches_y, dtype=torch.float32, device=device),
|
733 |
+
torch.linspace(min_x, max_x, num_patches_x, dtype=torch.float32, device=device),
|
734 |
+
indexing="ij",
|
735 |
+
)
|
736 |
+
x_prime = x * width / 2 - cc_x
|
737 |
+
y_prime = y * width / 2 - cc_y
|
738 |
+
xyd_grid = torch.stack([x_prime, y_prime, torch.ones_like(x)], dim=-1)
|
739 |
+
return xyd_grid
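For reference, a shape sketch of `compute_ndc_coordinates` with and without crop parameters (the grid follows PyTorch3D's NDC convention, with +1 at the top-left):

```python
import torch
from onediffusion.dataset.raydiff_utils import compute_ndc_coordinates

grid = compute_ndc_coordinates(num_patches_x=16, num_patches_y=16)  # full-image crop
print(grid.shape)       # torch.Size([16, 16, 3]); the last channel is the constant depth 1
print(grid[0, 0, :2])   # patch center nearest the top-left corner, ~(0.94, 0.94)

crops = torch.zeros(4, 4)   # batched (cc_x, cc_y, crop_width, scale)
crops[:, 2] = 2.0           # full-width crops centered on the principal point
print(compute_ndc_coordinates(crop_parameters=crops).shape)   # torch.Size([4, 16, 16, 3])
```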
|
onediffusion/dataset/transforms.py
ADDED
@@ -0,0 +1,133 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
def crop(image, i, j, h, w):
|
5 |
+
"""
|
6 |
+
Args:
|
7 |
+
image (torch.tensor): Image to be cropped. Size is (C, H, W)
|
8 |
+
"""
|
9 |
+
if len(image.size()) != 3:
|
10 |
+
raise ValueError("image should be a 3D tensor")
|
11 |
+
return image[..., i : i + h, j : j + w]
|
12 |
+
|
13 |
+
def resize(image, target_size, interpolation_mode):
|
14 |
+
if len(target_size) != 2:
|
15 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
16 |
+
return F.interpolate(image.unsqueeze(0), size=target_size, mode=interpolation_mode, align_corners=False).squeeze(0)
|
17 |
+
|
18 |
+
def resize_scale(image, target_size, interpolation_mode):
|
19 |
+
if len(target_size) != 2:
|
20 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
21 |
+
H, W = image.size(-2), image.size(-1)
|
22 |
+
scale_ = target_size[0] / min(H, W)
|
23 |
+
return F.interpolate(image.unsqueeze(0), scale_factor=scale_, mode=interpolation_mode, align_corners=False).squeeze(0)
|
24 |
+
|
25 |
+
def resized_crop(image, i, j, h, w, size, interpolation_mode="bilinear"):
|
26 |
+
"""
|
27 |
+
Do spatial cropping and resizing to the image
|
28 |
+
Args:
|
29 |
+
image (torch.tensor): Image to be cropped. Size is (C, H, W)
|
30 |
+
i (int): i in (i,j) i.e coordinates of the upper left corner.
|
31 |
+
j (int): j in (i,j) i.e coordinates of the upper left corner.
|
32 |
+
h (int): Height of the cropped region.
|
33 |
+
w (int): Width of the cropped region.
|
34 |
+
size (tuple(int, int)): height and width of resized image
|
35 |
+
Returns:
|
36 |
+
image (torch.tensor): Resized and cropped image. Size is (C, H, W)
|
37 |
+
"""
|
38 |
+
if len(image.size()) != 3:
|
39 |
+
raise ValueError("image should be a 3D torch.tensor")
|
40 |
+
image = crop(image, i, j, h, w)
|
41 |
+
image = resize(image, size, interpolation_mode)
|
42 |
+
return image
|
43 |
+
|
44 |
+
def center_crop(image, crop_size):
|
45 |
+
if len(image.size()) != 3:
|
46 |
+
raise ValueError("image should be a 3D torch.tensor")
|
47 |
+
h, w = image.size(-2), image.size(-1)
|
48 |
+
th, tw = crop_size
|
49 |
+
if h < th or w < tw:
|
50 |
+
raise ValueError("height and width must be no smaller than crop_size")
|
51 |
+
i = int(round((h - th) / 2.0))
|
52 |
+
j = int(round((w - tw) / 2.0))
|
53 |
+
return crop(image, i, j, th, tw)
|
54 |
+
|
55 |
+
def center_crop_using_short_edge(image):
|
56 |
+
if len(image.size()) != 3:
|
57 |
+
raise ValueError("image should be a 3D torch.tensor")
|
58 |
+
h, w = image.size(-2), image.size(-1)
|
59 |
+
if h < w:
|
60 |
+
th, tw = h, h
|
61 |
+
i = 0
|
62 |
+
j = int(round((w - tw) / 2.0))
|
63 |
+
else:
|
64 |
+
th, tw = w, w
|
65 |
+
i = int(round((h - th) / 2.0))
|
66 |
+
j = 0
|
67 |
+
return crop(image, i, j, th, tw)
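A quick sketch tying the helpers above together, square-cropping a landscape tensor from its short edge and then resizing it (sizes are arbitrary):

```python
import torch
from onediffusion.dataset.transforms import center_crop_using_short_edge, resize

img = torch.rand(3, 480, 640)                     # (C, H, W) landscape image
square = center_crop_using_short_edge(img)        # (3, 480, 480), centered horizontally
thumb = resize(square, (224, 224), "bilinear")    # (3, 224, 224)
print(square.shape, thumb.shape)
```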
|
68 |
+
|
69 |
+
class CenterCropResizeImage:
|
70 |
+
"""
|
71 |
+
Center-crop the image to the target aspect ratio, and then resize it to the desired size.
|
72 |
+
The resizing is done such that the area of padding/cropping is minimized.
|
73 |
+
"""
|
74 |
+
def __init__(self, size, interpolation_mode="bilinear"):
|
75 |
+
if isinstance(size, tuple):
|
76 |
+
if len(size) != 2:
|
77 |
+
raise ValueError(f"Size should be a tuple (height, width), instead got {size}")
|
78 |
+
self.size = size
|
79 |
+
else:
|
80 |
+
self.size = (size, size)
|
81 |
+
self.interpolation_mode = interpolation_mode
|
82 |
+
|
83 |
+
def __call__(self, image):
|
84 |
+
"""
|
85 |
+
Args:
|
86 |
+
image (torch.Tensor): Image to be resized and cropped. Size is (C, H, W)
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
torch.Tensor: Resized and cropped image. Size is (C, target_height, target_width)
|
90 |
+
"""
|
91 |
+
target_height, target_width = self.size
|
92 |
+
target_aspect = target_width / target_height
|
93 |
+
|
94 |
+
# Get current image shape and aspect ratio
|
95 |
+
_, height, width = image.shape
|
96 |
+
height, width = float(height), float(width)
|
97 |
+
current_aspect = width / height
|
98 |
+
|
99 |
+
# Calculate crop dimensions
|
100 |
+
if current_aspect > target_aspect:
|
101 |
+
# Image is wider than target, crop width
|
102 |
+
crop_height = height
|
103 |
+
crop_width = height * target_aspect
|
104 |
+
else:
|
105 |
+
# Image is taller than target, crop height
|
106 |
+
crop_height = width / target_aspect
|
107 |
+
crop_width = width
|
108 |
+
|
109 |
+
# Calculate crop coordinates (center crop)
|
110 |
+
y1 = (height - crop_height) / 2
|
111 |
+
x1 = (width - crop_width) / 2
|
112 |
+
|
113 |
+
# Perform the crop
|
114 |
+
cropped_image = crop(image, int(y1), int(x1), int(crop_height), int(crop_width))
|
115 |
+
|
116 |
+
# Resize the cropped image to the target size
|
117 |
+
resized_image = resize(cropped_image, self.size, self.interpolation_mode)
|
118 |
+
|
119 |
+
return resized_image
|
120 |
+
|
121 |
+
# Example usage
|
122 |
+
if __name__ == "__main__":
|
123 |
+
# Create a sample image tensor
|
124 |
+
sample_image = torch.rand(3, 480, 640) # (C, H, W)
|
125 |
+
|
126 |
+
# Initialize the transform
|
127 |
+
transform = CenterCropResizeImage(size=(224, 224), interpolation_mode="bilinear")
|
128 |
+
|
129 |
+
# Apply the transform
|
130 |
+
transformed_image = transform(sample_image)
|
131 |
+
|
132 |
+
print(f"Original image shape: {sample_image.shape}")
|
133 |
+
print(f"Transformed image shape: {transformed_image.shape}")
|
onediffusion/dataset/utils.py
ADDED
@@ -0,0 +1,175 @@
1 |
+
|
2 |
+
ASPECT_RATIO_2880 = {
|
3 |
+
'0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0],
|
4 |
+
'0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0],
|
5 |
+
'0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0],
|
6 |
+
'0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0],
|
7 |
+
'0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0],
|
8 |
+
'1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0],
|
9 |
+
'1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0],
|
10 |
+
'1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0],
|
11 |
+
'2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0],
|
12 |
+
'3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0]
|
13 |
+
}
|
14 |
+
|
15 |
+
ASPECT_RATIO_2048 = {
|
16 |
+
'0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0],
|
17 |
+
'0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
|
18 |
+
'0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
|
19 |
+
'0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
|
20 |
+
'0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
|
21 |
+
'1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
|
22 |
+
'1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
|
23 |
+
'1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
|
24 |
+
'2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0],
|
25 |
+
'3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0]
|
26 |
+
}
|
27 |
+
|
28 |
+
ASPECT_RATIO_1024 = {
|
29 |
+
'0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
|
30 |
+
'0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
|
31 |
+
'0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
|
32 |
+
'0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
|
33 |
+
'0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
|
34 |
+
'1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
|
35 |
+
'1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
|
36 |
+
'1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
|
37 |
+
'2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
|
38 |
+
'3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
|
39 |
+
}
|
40 |
+
|
41 |
+
ASPECT_RATIO_512 = {
|
42 |
+
'0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
|
43 |
+
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
|
44 |
+
'0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
|
45 |
+
'0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
|
46 |
+
'0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
|
47 |
+
'1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
|
48 |
+
'1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
|
49 |
+
'1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
|
50 |
+
'2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
|
51 |
+
'3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
ASPECT_RATIO_384 = {
|
56 |
+
'0.25': [192.0, 768.0],
|
57 |
+
'0.26': [192.0, 736.0],
|
58 |
+
'0.27': [208.0, 768.0],
|
59 |
+
'0.28': [208.0, 736.0],
|
60 |
+
'0.33': [240.0, 720.0],
|
61 |
+
'0.4': [256.0, 640.0],
|
62 |
+
'0.42': [304.0, 720.0],
|
63 |
+
'0.48': [368.0, 768.0],
|
64 |
+
'0.5': [384.0, 768.0],
|
65 |
+
'0.52': [384.0, 736.0],
|
66 |
+
'0.57': [384.0, 672.0],
|
67 |
+
'0.6': [384.0, 640.0],
|
68 |
+
'0.73': [384.0, 528.0],
|
69 |
+
'0.77': [384.0, 496.0],
|
70 |
+
'0.83': [384.0, 464.0],
|
71 |
+
'0.89': [384.0, 432.0],
|
72 |
+
'0.92': [384.0, 416.0],
|
73 |
+
'1.0': [384.0, 384.0],
|
74 |
+
'1.09': [384.0, 352.0],
|
75 |
+
'1.14': [384.0, 336.0],
|
76 |
+
'1.2': [384.0, 320.0],
|
77 |
+
'1.26': [384.0, 304.0],
|
78 |
+
'1.33': [384.0, 288.0],
|
79 |
+
'1.41': [384.0, 272.0],
|
80 |
+
'1.6': [384.0, 240.0],
|
81 |
+
'1.71': [384.0, 224.0],
|
82 |
+
'2.0': [384.0, 192.0],
|
83 |
+
'2.4': [384.0, 160.0],
|
84 |
+
'2.88': [368.0, 128.0],
|
85 |
+
'3.0': [384.0, 128.0],
|
86 |
+
'3.43': [384.0, 112.0],
|
87 |
+
'4.0': [384.0, 96.0]
|
88 |
+
}
|
89 |
+
|
90 |
+
ASPECT_RATIO_256 = {
|
91 |
+
'0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
|
92 |
+
'0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
|
93 |
+
'0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
|
94 |
+
'0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
|
95 |
+
'0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
|
96 |
+
'1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
|
97 |
+
'1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
|
98 |
+
'1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
|
99 |
+
'2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
|
100 |
+
'3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
|
101 |
+
}
|
102 |
+
|
103 |
+
ASPECT_RATIO_256_TEST = {
|
104 |
+
'0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
|
105 |
+
'0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
|
106 |
+
'0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
|
107 |
+
'0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
|
108 |
+
'0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
|
109 |
+
'1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
|
110 |
+
'1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
|
111 |
+
'1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
|
112 |
+
'2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
|
113 |
+
'4.0': [512.0, 128.0]
|
114 |
+
}
|
115 |
+
|
116 |
+
ASPECT_RATIO_512_TEST = {
|
117 |
+
'0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
|
118 |
+
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
|
119 |
+
'0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
|
120 |
+
'0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
|
121 |
+
'0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
|
122 |
+
'1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
|
123 |
+
'1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
|
124 |
+
'1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
|
125 |
+
'2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
|
126 |
+
'4.0': [1024.0, 256.0]
|
127 |
+
}
|
128 |
+
|
129 |
+
ASPECT_RATIO_1024_TEST = {
|
130 |
+
'0.25': [512., 2048.], '0.28': [512., 1856.],
|
131 |
+
'0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
|
132 |
+
'0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
|
133 |
+
'0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
|
134 |
+
'0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
|
135 |
+
'1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
|
136 |
+
'1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
|
137 |
+
'1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
|
138 |
+
'2.5': [1600., 640.], '3.0': [1728., 576.],
|
139 |
+
'4.0': [2048., 512.],
|
140 |
+
}
|
141 |
+
|
142 |
+
ASPECT_RATIO_2048_TEST = {
|
143 |
+
'0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0],
|
144 |
+
'0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
|
145 |
+
'0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
|
146 |
+
'0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
|
147 |
+
'0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
|
148 |
+
'1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
|
149 |
+
'1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
|
150 |
+
'1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
|
151 |
+
'2.5': [3200.0, 1280.0], '3.0': [3456.0, 1152.0],
|
152 |
+
'4.0': [4096.0, 1024.0]
|
153 |
+
}
|
154 |
+
|
155 |
+
ASPECT_RATIO_2880_TEST = {
|
156 |
+
'0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0],
|
157 |
+
'0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0],
|
158 |
+
'0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0],
|
159 |
+
'0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0],
|
160 |
+
'0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0],
|
161 |
+
'1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0],
|
162 |
+
'1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0],
|
163 |
+
'1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0],
|
164 |
+
'2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0],
|
165 |
+
'4.0': [8192.0, 2048.0],
|
166 |
+
}
|
167 |
+
|
168 |
+
def get_chunks(lst, n):
|
169 |
+
for i in range(0, len(lst), n):
|
170 |
+
yield lst[i:i + n]
|
171 |
+
|
172 |
+
def get_closest_ratio(height: float, width: float, ratios: dict):
|
173 |
+
aspect_ratio = height / width
|
174 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
175 |
+
return ratios[closest_ratio], float(closest_ratio)
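Usage sketch for the bucketing helpers: pick the closest training resolution for a raw image size (the 720x1280 input is illustrative):

```python
from onediffusion.dataset.utils import ASPECT_RATIO_512, get_closest_ratio

bucket, ratio = get_closest_ratio(height=720.0, width=1280.0, ratios=ASPECT_RATIO_512)
print(bucket, ratio)   # [384.0, 672.0] 0.57 -> the (H, W) bucket at the 512 base resolution
```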
|
onediffusion/diffusion/pipelines/image_processor.py
ADDED
@@ -0,0 +1,674 @@
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import warnings
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import PIL.Image
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import torchvision.transforms as T
|
24 |
+
from PIL import Image, ImageFilter, ImageOps
|
25 |
+
|
26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
27 |
+
from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
28 |
+
|
29 |
+
from onediffusion.dataset.transforms import CenterCropResizeImage
|
30 |
+
|
31 |
+
PipelineImageInput = Union[
|
32 |
+
PIL.Image.Image,
|
33 |
+
np.ndarray,
|
34 |
+
torch.Tensor,
|
35 |
+
List[PIL.Image.Image],
|
36 |
+
List[np.ndarray],
|
37 |
+
List[torch.Tensor],
|
38 |
+
]
|
39 |
+
|
40 |
+
PipelineDepthInput = PipelineImageInput
|
41 |
+
|
42 |
+
|
43 |
+
def is_valid_image(image):
|
44 |
+
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
|
45 |
+
|
46 |
+
|
47 |
+
def is_valid_image_imagelist(images):
|
48 |
+
# check if the image input is one of the supported formats for image and image list:
|
49 |
+
# it can be either one of below 3
|
50 |
+
# (1) a 4d pytorch tensor or numpy array,
|
51 |
+
# (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
|
52 |
+
# (3) a list of valid image
|
53 |
+
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
|
54 |
+
return True
|
55 |
+
elif is_valid_image(images):
|
56 |
+
return True
|
57 |
+
elif isinstance(images, list):
|
58 |
+
return all(is_valid_image(image) for image in images)
|
59 |
+
return False
|
60 |
+
|
61 |
+
|
62 |
+
class VaeImageProcessorOneDiffuser(ConfigMixin):
|
63 |
+
"""
|
64 |
+
Image processor for VAE.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
68 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
69 |
+
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
70 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
71 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
72 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
73 |
+
Resampling filter to use when resizing the image.
|
74 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
75 |
+
Whether to normalize the image to [-1,1].
|
76 |
+
do_binarize (`bool`, *optional*, defaults to `False`):
|
77 |
+
Whether to binarize the image to 0/1.
|
78 |
+
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
79 |
+
Whether to convert the images to RGB format.
|
80 |
+
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
|
81 |
+
Whether to convert the images to grayscale format.
|
82 |
+
"""
|
83 |
+
|
84 |
+
config_name = CONFIG_NAME
|
85 |
+
|
86 |
+
@register_to_config
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
do_resize: bool = True,
|
90 |
+
vae_scale_factor: int = 8,
|
91 |
+
vae_latent_channels: int = 4,
|
92 |
+
resample: str = "lanczos",
|
93 |
+
do_normalize: bool = True,
|
94 |
+
do_binarize: bool = False,
|
95 |
+
do_convert_rgb: bool = False,
|
96 |
+
do_convert_grayscale: bool = False,
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
if do_convert_rgb and do_convert_grayscale:
|
100 |
+
raise ValueError(
|
101 |
+
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
|
102 |
+
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
|
103 |
+
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
|
104 |
+
)
|
105 |
+
|
106 |
+
@staticmethod
|
107 |
+
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
108 |
+
"""
|
109 |
+
Convert a numpy image or a batch of images to a PIL image.
|
110 |
+
"""
|
111 |
+
if images.ndim == 3:
|
112 |
+
images = images[None, ...]
|
113 |
+
images = (images * 255).round().astype("uint8")
|
114 |
+
if images.shape[-1] == 1:
|
115 |
+
# special case for grayscale (single channel) images
|
116 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
117 |
+
else:
|
118 |
+
pil_images = [Image.fromarray(image) for image in images]
|
119 |
+
|
120 |
+
return pil_images
|
121 |
+
|
122 |
+
@staticmethod
|
123 |
+
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
124 |
+
"""
|
125 |
+
Convert a PIL image or a list of PIL images to NumPy arrays.
|
126 |
+
"""
|
127 |
+
if not isinstance(images, list):
|
128 |
+
images = [images]
|
129 |
+
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
|
130 |
+
images = np.stack(images, axis=0)
|
131 |
+
|
132 |
+
return images
|
133 |
+
|
134 |
+
@staticmethod
|
135 |
+
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
|
136 |
+
"""
|
137 |
+
Convert a NumPy image to a PyTorch tensor.
|
138 |
+
"""
|
139 |
+
if images.ndim == 3:
|
140 |
+
images = images[..., None]
|
141 |
+
|
142 |
+
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
|
143 |
+
return images
|
144 |
+
|
145 |
+
@staticmethod
|
146 |
+
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
|
147 |
+
"""
|
148 |
+
Convert a PyTorch tensor to a NumPy image.
|
149 |
+
"""
|
150 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
151 |
+
return images
|
152 |
+
|
153 |
+
@staticmethod
|
154 |
+
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
155 |
+
"""
|
156 |
+
Normalize an image array to [-1,1].
|
157 |
+
"""
|
158 |
+
return 2.0 * images - 1.0
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
162 |
+
"""
|
163 |
+
Denormalize an image array to [0,1].
|
164 |
+
"""
|
165 |
+
return (images / 2 + 0.5).clamp(0, 1)
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
169 |
+
"""
|
170 |
+
Converts a PIL image to RGB format.
|
171 |
+
"""
|
172 |
+
image = image.convert("RGB")
|
173 |
+
|
174 |
+
return image
|
175 |
+
|
176 |
+
@staticmethod
|
177 |
+
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
|
178 |
+
"""
|
179 |
+
Converts a PIL image to grayscale format.
|
180 |
+
"""
|
181 |
+
image = image.convert("L")
|
182 |
+
|
183 |
+
return image
|
184 |
+
|
185 |
+
@staticmethod
|
186 |
+
def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
|
187 |
+
"""
|
188 |
+
Applies Gaussian blur to an image.
|
189 |
+
"""
|
190 |
+
image = image.filter(ImageFilter.GaussianBlur(blur_factor))
|
191 |
+
|
192 |
+
return image
|
193 |
+
|
194 |
+
@staticmethod
|
195 |
+
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
|
196 |
+
"""
|
197 |
+
Finds a rectangular region that contains all masked areas in an image, and expands the region to match the aspect
|
198 |
+
ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
|
199 |
+
processing are 512x512, the region will be expanded to 128x128.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
mask_image (PIL.Image.Image): Mask image.
|
203 |
+
width (int): Width of the image to be processed.
|
204 |
+
height (int): Height of the image to be processed.
|
205 |
+
pad (int, optional): Padding to be added to the crop region. Defaults to 0.
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
tuple: (x1, y1, x2, y2) representing a rectangular region that contains all masked areas in an image and
|
209 |
+
matches the original aspect ratio.
|
210 |
+
"""
|
211 |
+
|
212 |
+
mask_image = mask_image.convert("L")
|
213 |
+
mask = np.array(mask_image)
|
214 |
+
|
215 |
+
# 1. find a rectangular region that contains all masked areas in an image
|
216 |
+
h, w = mask.shape
|
217 |
+
crop_left = 0
|
218 |
+
for i in range(w):
|
219 |
+
if not (mask[:, i] == 0).all():
|
220 |
+
break
|
221 |
+
crop_left += 1
|
222 |
+
|
223 |
+
crop_right = 0
|
224 |
+
for i in reversed(range(w)):
|
225 |
+
if not (mask[:, i] == 0).all():
|
226 |
+
break
|
227 |
+
crop_right += 1
|
228 |
+
|
229 |
+
crop_top = 0
|
230 |
+
for i in range(h):
|
231 |
+
if not (mask[i] == 0).all():
|
232 |
+
break
|
233 |
+
crop_top += 1
|
234 |
+
|
235 |
+
crop_bottom = 0
|
236 |
+
for i in reversed(range(h)):
|
237 |
+
if not (mask[i] == 0).all():
|
238 |
+
break
|
239 |
+
crop_bottom += 1
|
240 |
+
|
241 |
+
# 2. add padding to the crop region
|
242 |
+
x1, y1, x2, y2 = (
|
243 |
+
int(max(crop_left - pad, 0)),
|
244 |
+
int(max(crop_top - pad, 0)),
|
245 |
+
int(min(w - crop_right + pad, w)),
|
246 |
+
int(min(h - crop_bottom + pad, h)),
|
247 |
+
)
|
248 |
+
|
249 |
+
# 3. expands crop region to match the aspect ratio of the image to be processed
|
250 |
+
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
251 |
+
ratio_processing = width / height
|
252 |
+
|
253 |
+
if ratio_crop_region > ratio_processing:
|
254 |
+
desired_height = (x2 - x1) / ratio_processing
|
255 |
+
desired_height_diff = int(desired_height - (y2 - y1))
|
256 |
+
y1 -= desired_height_diff // 2
|
257 |
+
y2 += desired_height_diff - desired_height_diff // 2
|
258 |
+
if y2 >= mask_image.height:
|
259 |
+
diff = y2 - mask_image.height
|
260 |
+
y2 -= diff
|
261 |
+
y1 -= diff
|
262 |
+
if y1 < 0:
|
263 |
+
y2 -= y1
|
264 |
+
y1 -= y1
|
265 |
+
if y2 >= mask_image.height:
|
266 |
+
y2 = mask_image.height
|
267 |
+
else:
|
268 |
+
desired_width = (y2 - y1) * ratio_processing
|
269 |
+
desired_width_diff = int(desired_width - (x2 - x1))
|
270 |
+
x1 -= desired_width_diff // 2
|
271 |
+
x2 += desired_width_diff - desired_width_diff // 2
|
272 |
+
if x2 >= mask_image.width:
|
273 |
+
diff = x2 - mask_image.width
|
274 |
+
x2 -= diff
|
275 |
+
x1 -= diff
|
276 |
+
if x1 < 0:
|
277 |
+
x2 -= x1
|
278 |
+
x1 -= x1
|
279 |
+
if x2 >= mask_image.width:
|
280 |
+
x2 = mask_image.width
|
281 |
+
|
282 |
+
return x1, y1, x2, y2
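A short sketch of `get_crop_region` on a synthetic inpainting mask (mask contents and padding are illustrative):

```python
import numpy as np
from PIL import Image
from onediffusion.diffusion.pipelines.image_processor import VaeImageProcessorOneDiffuser

mask = np.zeros((512, 512), dtype=np.uint8)
mask[100:200, 300:380] = 255          # "user-painted" region to inpaint
box = VaeImageProcessorOneDiffuser.get_crop_region(Image.fromarray(mask), 512, 512, pad=8)
print(box)                            # (x1, y1, x2, y2), padded and expanded to the 1:1 aspect
```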
|
283 |
+
|
284 |
+
def _resize_and_fill(
|
285 |
+
self,
|
286 |
+
image: PIL.Image.Image,
|
287 |
+
width: int,
|
288 |
+
height: int,
|
289 |
+
) -> PIL.Image.Image:
|
290 |
+
"""
|
291 |
+
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
292 |
+
the image within the dimensions, filling the empty space with data from the image.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
image: The image to resize.
|
296 |
+
width: The width to resize the image to.
|
297 |
+
height: The height to resize the image to.
|
298 |
+
"""
|
299 |
+
|
300 |
+
ratio = width / height
|
301 |
+
src_ratio = image.width / image.height
|
302 |
+
|
303 |
+
src_w = width if ratio < src_ratio else image.width * height // image.height
|
304 |
+
src_h = height if ratio >= src_ratio else image.height * width // image.width
|
305 |
+
|
306 |
+
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
307 |
+
res = Image.new("RGB", (width, height))
|
308 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
309 |
+
|
310 |
+
if ratio < src_ratio:
|
311 |
+
fill_height = height // 2 - src_h // 2
|
312 |
+
if fill_height > 0:
|
313 |
+
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
314 |
+
res.paste(
|
315 |
+
resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
|
316 |
+
box=(0, fill_height + src_h),
|
317 |
+
)
|
318 |
+
elif ratio > src_ratio:
|
319 |
+
fill_width = width // 2 - src_w // 2
|
320 |
+
if fill_width > 0:
|
321 |
+
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
322 |
+
res.paste(
|
323 |
+
resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
|
324 |
+
box=(fill_width + src_w, 0),
|
325 |
+
)
|
326 |
+
|
327 |
+
return res
|
328 |
+
|
329 |
+
def _resize_and_crop(
|
330 |
+
self,
|
331 |
+
image: PIL.Image.Image,
|
332 |
+
width: int,
|
333 |
+
height: int,
|
334 |
+
) -> PIL.Image.Image:
|
335 |
+
"""
|
336 |
+
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
337 |
+
the image within the dimensions, cropping the excess.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
image: The image to resize.
|
341 |
+
width: The width to resize the image to.
|
342 |
+
height: The height to resize the image to.
|
343 |
+
"""
|
344 |
+
ratio = width / height
|
345 |
+
src_ratio = image.width / image.height
|
346 |
+
|
347 |
+
src_w = width if ratio > src_ratio else image.width * height // image.height
|
348 |
+
src_h = height if ratio <= src_ratio else image.height * width // image.width
|
349 |
+
|
350 |
+
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
351 |
+
res = Image.new("RGB", (width, height))
|
352 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
353 |
+
return res
|
354 |
+
|
355 |
+
def resize(
|
356 |
+
self,
|
357 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
358 |
+
height: int,
|
359 |
+
width: int,
|
360 |
+
resize_mode: str = "default", # "default", "fill", "crop"
|
361 |
+
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
362 |
+
"""
|
363 |
+
Resize image.
|
364 |
+
|
365 |
+
Args:
|
366 |
+
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
367 |
+
The image input, can be a PIL image, numpy array or pytorch tensor.
|
368 |
+
height (`int`):
|
369 |
+
The height to resize to.
|
370 |
+
width (`int`):
|
371 |
+
The width to resize to.
|
372 |
+
resize_mode (`str`, *optional*, defaults to `default`):
|
373 |
+
The resize mode to use; can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit
|
374 |
+
within the specified width and height, and it may not maintain the original aspect ratio. If `fill`,
|
375 |
+
will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
|
376 |
+
then center the image within the dimensions, filling the empty space with data from the image. If `crop`, will resize
|
377 |
+
the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
378 |
+
the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
379 |
+
supported for PIL image input.
|
380 |
+
|
381 |
+
Returns:
|
382 |
+
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
383 |
+
The resized image.
|
384 |
+
"""
|
385 |
+
if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
|
386 |
+
raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
|
387 |
+
if isinstance(image, PIL.Image.Image):
|
388 |
+
if resize_mode == "default":
|
389 |
+
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
390 |
+
elif resize_mode == "fill":
|
391 |
+
image = self._resize_and_fill(image, width, height)
|
392 |
+
elif resize_mode == "crop":
|
393 |
+
image = self._resize_and_crop(image, width, height)
|
394 |
+
else:
|
395 |
+
raise ValueError(f"resize_mode {resize_mode} is not supported")
|
396 |
+
|
397 |
+
elif isinstance(image, torch.Tensor):
|
398 |
+
image = torch.nn.functional.interpolate(
|
399 |
+
image,
|
400 |
+
size=(height, width),
|
401 |
+
)
|
402 |
+
elif isinstance(image, np.ndarray):
|
403 |
+
image = self.numpy_to_pt(image)
|
404 |
+
image = torch.nn.functional.interpolate(
|
405 |
+
image,
|
406 |
+
size=(height, width),
|
407 |
+
)
|
408 |
+
image = self.pt_to_numpy(image)
|
409 |
+
return image
|
410 |
+
|
411 |
+
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
412 |
+
"""
|
413 |
+
Create a mask.
|
414 |
+
|
415 |
+
Args:
|
416 |
+
image (`PIL.Image.Image`):
|
417 |
+
The image input, should be a PIL image.
|
418 |
+
|
419 |
+
Returns:
|
420 |
+
`PIL.Image.Image`:
|
421 |
+
The binarized image. Values less than 0.5 are set to 0, values greater than or equal to 0.5 are set to 1.
|
422 |
+
"""
|
423 |
+
image[image < 0.5] = 0
|
424 |
+
image[image >= 0.5] = 1
|
425 |
+
|
426 |
+
return image
|
427 |
+
|
428 |
+
def get_default_height_width(
|
429 |
+
self,
|
430 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
431 |
+
height: Optional[int] = None,
|
432 |
+
width: Optional[int] = None,
|
433 |
+
) -> Tuple[int, int]:
|
434 |
+
"""
|
435 |
+
This function returns the height and width that are downscaled to the nearest integer multiple of
|
436 |
+
`vae_scale_factor`.
|
437 |
+
|
438 |
+
Args:
|
439 |
+
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
440 |
+
The image input, can be a PIL image, numpy array or pytorch tensor. If it is a numpy array, it should have
|
441 |
+
shape `[batch, height, width]` or `[batch, height, width, channel]`; if it is a pytorch tensor, it should
|
442 |
+
have shape `[batch, channel, height, width]`.
|
443 |
+
height (`int`, *optional*, defaults to `None`):
|
444 |
+
The height of the preprocessed image. If `None`, will use the height of the `image` input.
|
445 |
+
width (`int`, *optional*, defaults to `None`):
|
446 |
+
The width of the preprocessed image. If `None`, will use the width of the `image` input.
|
447 |
+
"""
|
448 |
+
|
449 |
+
if height is None:
|
450 |
+
if isinstance(image, PIL.Image.Image):
|
451 |
+
height = image.height
|
452 |
+
elif isinstance(image, torch.Tensor):
|
453 |
+
height = image.shape[2]
|
454 |
+
else:
|
455 |
+
height = image.shape[1]
|
456 |
+
|
457 |
+
if width is None:
|
458 |
+
if isinstance(image, PIL.Image.Image):
|
459 |
+
width = image.width
|
460 |
+
elif isinstance(image, torch.Tensor):
|
461 |
+
width = image.shape[3]
|
462 |
+
else:
|
463 |
+
width = image.shape[2]
|
464 |
+
|
465 |
+
width, height = (
|
466 |
+
x - x % self.config.vae_scale_factor for x in (width, height)
|
467 |
+
) # resize to integer multiple of vae_scale_factor
|
468 |
+
|
469 |
+
return height, width
|
470 |
+
|
471 |
+
def preprocess(
|
472 |
+
self,
|
473 |
+
image: PipelineImageInput,
|
474 |
+
height: Optional[int] = None,
|
475 |
+
width: Optional[int] = None,
|
476 |
+
resize_mode: str = "default", # "default", "fill", "crop"
|
477 |
+
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
478 |
+
do_crop: bool = True,
|
479 |
+
) -> torch.Tensor:
|
480 |
+
"""
|
481 |
+
Preprocess the image input.
|
482 |
+
|
483 |
+
Args:
|
484 |
+
image (`PipelineImageInput`):
|
485 |
+
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
|
486 |
+
supported formats.
|
487 |
+
height (`int`, *optional*, defaults to `None`):
|
488 |
+
The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default
|
489 |
+
height.
|
490 |
+
width (`int`, *optional*, defaults to `None`):
|
491 |
+
The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
|
492 |
+
resize_mode (`str`, *optional*, defaults to `default`):
|
493 |
+
The resize mode, can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit within
|
494 |
+
the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
|
495 |
+
resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
|
496 |
+
center the image within the dimensions, filling the empty space with data from the image. If `crop`, will resize the
|
497 |
+
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
|
498 |
+
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
499 |
+
supported for PIL image input.
|
500 |
+
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
501 |
+
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
502 |
+
"""
|
503 |
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
504 |
+
|
505 |
+
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
|
506 |
+
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
507 |
+
if isinstance(image, torch.Tensor):
|
508 |
+
# if image is a pytorch tensor could have 2 possible shapes:
|
509 |
+
# 1. batch x height x width: we should insert the channel dimension at position 1
|
510 |
+
# 2. channel x height x width: we should insert batch dimension at position 0,
|
511 |
+
# however, since both channel and batch dimension has same size 1, it is same to insert at position 1
|
512 |
+
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
|
513 |
+
image = image.unsqueeze(1)
|
514 |
+
else:
|
515 |
+
# if it is a numpy array, it could have 2 possible shapes:
|
516 |
+
# 1. batch x height x width: insert channel dimension on last position
|
517 |
+
# 2. height x width x channel: insert batch dimension on first position
|
518 |
+
if image.shape[-1] == 1:
|
519 |
+
image = np.expand_dims(image, axis=0)
|
520 |
+
else:
|
521 |
+
image = np.expand_dims(image, axis=-1)
|
522 |
+
|
523 |
+
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
|
524 |
+
warnings.warn(
|
525 |
+
"Passing `image` as a list of 4d np.ndarray is deprecated."
|
526 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
|
527 |
+
FutureWarning,
|
528 |
+
)
|
529 |
+
image = np.concatenate(image, axis=0)
|
530 |
+
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
|
531 |
+
warnings.warn(
|
532 |
+
"Passing `image` as a list of 4d torch.Tensor is deprecated."
|
533 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
|
534 |
+
FutureWarning,
|
535 |
+
)
|
536 |
+
image = torch.cat(image, axis=0)
|
537 |
+
|
538 |
+
if not is_valid_image_imagelist(image):
|
539 |
+
raise ValueError(
|
540 |
+
f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
|
541 |
+
)
|
542 |
+
if not isinstance(image, list):
|
543 |
+
image = [image]
|
544 |
+
|
545 |
+
if isinstance(image[0], PIL.Image.Image):
|
546 |
+
pass
|
547 |
+
elif isinstance(image[0], np.ndarray):
|
548 |
+
image = self.numpy_to_pil(image)
|
549 |
+
elif isinstance(image[0], torch.Tensor):
|
550 |
+
image = self.pt_to_numpy(image)
|
551 |
+
image = self.numpy_to_pil(image)
|
552 |
+
|
553 |
+
if do_crop:
|
554 |
+
transforms = T.Compose([
|
555 |
+
T.Lambda(lambda image: image.convert('RGB')),
|
556 |
+
T.ToTensor(),
|
557 |
+
CenterCropResizeImage((height, width)),
|
558 |
+
T.Normalize([.5], [.5]),
|
559 |
+
])
|
560 |
+
else:
|
561 |
+
transforms = T.Compose([
|
562 |
+
T.Lambda(lambda image: image.convert('RGB')),
|
563 |
+
T.ToTensor(),
|
564 |
+
T.Resize((height, width)),
|
565 |
+
T.Normalize([.5], [.5]),
|
566 |
+
])
|
567 |
+
image = torch.stack([transforms(i) for i in image])
|
568 |
+
|
569 |
+
# expected range [0,1], normalize to [-1,1]
|
570 |
+
do_normalize = self.config.do_normalize
|
571 |
+
if do_normalize and image.min() < 0:
|
572 |
+
warnings.warn(
|
573 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
574 |
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
575 |
+
FutureWarning,
|
576 |
+
)
|
577 |
+
do_normalize = False
|
578 |
+
if do_normalize:
|
579 |
+
image = self.normalize(image)
|
580 |
+
|
581 |
+
if self.config.do_binarize:
|
582 |
+
image = self.binarize(image)
|
583 |
+
|
584 |
+
return image
|
585 |
+
|
586 |
+
def postprocess(
|
587 |
+
self,
|
588 |
+
image: torch.Tensor,
|
589 |
+
output_type: str = "pil",
|
590 |
+
do_denormalize: Optional[List[bool]] = None,
|
591 |
+
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
592 |
+
"""
|
593 |
+
Postprocess the image output from tensor to `output_type`.
|
594 |
+
|
595 |
+
Args:
|
596 |
+
image (`torch.Tensor`):
|
597 |
+
The image input, should be a pytorch tensor with shape `B x C x H x W`.
|
598 |
+
output_type (`str`, *optional*, defaults to `pil`):
|
599 |
+
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
|
600 |
+
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
|
601 |
+
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
|
602 |
+
`VaeImageProcessor` config.
|
603 |
+
|
604 |
+
Returns:
|
605 |
+
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
606 |
+
The postprocessed image.
|
607 |
+
"""
|
608 |
+
if not isinstance(image, torch.Tensor):
|
609 |
+
raise ValueError(
|
610 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
611 |
+
)
|
612 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
613 |
+
deprecation_message = (
|
614 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
615 |
+
"`pil`, `np`, `pt`, `latent`"
|
616 |
+
)
|
617 |
+
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
618 |
+
output_type = "np"
|
619 |
+
|
620 |
+
if output_type == "latent":
|
621 |
+
return image
|
622 |
+
|
623 |
+
if do_denormalize is None:
|
624 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
625 |
+
|
626 |
+
image = torch.stack(
|
627 |
+
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
628 |
+
)
|
629 |
+
|
630 |
+
if output_type == "pt":
|
631 |
+
return image
|
632 |
+
|
633 |
+
image = self.pt_to_numpy(image)
|
634 |
+
|
635 |
+
if output_type == "np":
|
636 |
+
return image
|
637 |
+
|
638 |
+
if output_type == "pil":
|
639 |
+
return self.numpy_to_pil(image)
|
640 |
+
|
641 |
+
def apply_overlay(
|
642 |
+
self,
|
643 |
+
mask: PIL.Image.Image,
|
644 |
+
init_image: PIL.Image.Image,
|
645 |
+
image: PIL.Image.Image,
|
646 |
+
crop_coords: Optional[Tuple[int, int, int, int]] = None,
|
647 |
+
) -> PIL.Image.Image:
|
648 |
+
"""
|
649 |
+
overlay the inpaint output to the original image
|
650 |
+
"""
|
651 |
+
|
652 |
+
width, height = image.width, image.height
|
653 |
+
|
654 |
+
init_image = self.resize(init_image, width=width, height=height)
|
655 |
+
mask = self.resize(mask, width=width, height=height)
|
656 |
+
|
657 |
+
init_image_masked = PIL.Image.new("RGBa", (width, height))
|
658 |
+
init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
|
659 |
+
init_image_masked = init_image_masked.convert("RGBA")
|
660 |
+
|
661 |
+
if crop_coords is not None:
|
662 |
+
x, y, x2, y2 = crop_coords
|
663 |
+
w = x2 - x
|
664 |
+
h = y2 - y
|
665 |
+
base_image = PIL.Image.new("RGBA", (width, height))
|
666 |
+
image = self.resize(image, height=h, width=w, resize_mode="crop")
|
667 |
+
base_image.paste(image, (x, y))
|
668 |
+
image = base_image.convert("RGB")
|
669 |
+
|
670 |
+
image = image.convert("RGBA")
|
671 |
+
image.alpha_composite(init_image_masked)
|
672 |
+
image = image.convert("RGB")
|
673 |
+
|
674 |
+
return image
|
onediffusion/diffusion/pipelines/onediffusion.py
ADDED
@@ -0,0 +1,1080 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import einops
|
2 |
+
import inspect
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import PIL
|
6 |
+
import os
|
7 |
+
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
|
10 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
11 |
+
from diffusers.utils import (
|
12 |
+
CONFIG_NAME,
|
13 |
+
DEPRECATED_REVISION_ARGS,
|
14 |
+
BaseOutput,
|
15 |
+
PushToHubMixin,
|
16 |
+
deprecate,
|
17 |
+
is_accelerate_available,
|
18 |
+
is_accelerate_version,
|
19 |
+
is_torch_npu_available,
|
20 |
+
is_torch_version,
|
21 |
+
logging,
|
22 |
+
numpy_to_pil,
|
23 |
+
replace_example_docstring,
|
24 |
+
)
|
25 |
+
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
|
26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
27 |
+
from diffusers.utils import BaseOutput
|
28 |
+
# from diffusers.image_processor import VaeImageProcessor
|
29 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
30 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
31 |
+
from PIL import Image
|
32 |
+
|
33 |
+
from onediffusion.models.denoiser.nextdit import NextDiT
|
34 |
+
from onediffusion.dataset.utils import *
|
35 |
+
from onediffusion.dataset.multitask.multiview import calculate_rays
|
36 |
+
from onediffusion.diffusion.pipelines.image_processor import VaeImageProcessorOneDiffuser
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
SUPPORTED_DEVICE_MAP = ["balanced"]
|
41 |
+
|
42 |
+
EXAMPLE_DOC_STRING = """
|
43 |
+
Examples:
|
44 |
+
```py
|
45 |
+
>>> import torch
|
46 |
+
>>> from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
|
47 |
+
|
48 |
+
>>> pipe = OneDiffusionPipeline.from_pretrained("path_to_one_diffuser_model")
|
49 |
+
>>> pipe = pipe.to("cuda")
|
50 |
+
|
51 |
+
>>> prompt = "A beautiful sunset over the ocean"
|
52 |
+
>>> image = pipe(prompt).images[0]
|
53 |
+
>>> image.save("beautiful_sunset.png")
|
54 |
+
```
|
55 |
+
"""
|
56 |
+
|
57 |
+
def create_c2w_matrix(azimuth_deg, elevation_deg, distance=1.0, target=np.array([0, 0, 0])):
|
58 |
+
"""
|
59 |
+
Create a Camera-to-World (C2W) matrix from azimuth and elevation angles.
|
60 |
+
|
61 |
+
Parameters:
|
62 |
+
- azimuth_deg: Azimuth angle in degrees.
|
63 |
+
- elevation_deg: Elevation angle in degrees.
|
64 |
+
- distance: Distance from the target point.
|
65 |
+
- target: The point the camera is looking at in world coordinates.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
- C2W: A 4x4 NumPy array representing the Camera-to-World transformation matrix.
|
69 |
+
"""
|
70 |
+
# Convert angles from degrees to radians
|
71 |
+
azimuth = np.deg2rad(azimuth_deg)
|
72 |
+
elevation = np.deg2rad(elevation_deg)
|
73 |
+
|
74 |
+
# Spherical to Cartesian conversion for camera position
|
75 |
+
x = distance * np.cos(elevation) * np.cos(azimuth)
|
76 |
+
y = distance * np.cos(elevation) * np.sin(azimuth)
|
77 |
+
z = distance * np.sin(elevation)
|
78 |
+
camera_position = np.array([x, y, z])
|
79 |
+
|
80 |
+
# Define the forward vector (from camera to target)
|
81 |
+
target = 2*camera_position - target
|
82 |
+
forward = target - camera_position
|
83 |
+
forward /= np.linalg.norm(forward)
|
84 |
+
|
85 |
+
# Define the world up vector
|
86 |
+
world_up = np.array([0, 0, 1])
|
87 |
+
|
88 |
+
# Compute the right vector
|
89 |
+
right = np.cross(world_up, forward)
|
90 |
+
if np.linalg.norm(right) < 1e-6:
|
91 |
+
# Handle the singularity when forward is parallel to world_up
|
92 |
+
world_up = np.array([0, 1, 0])
|
93 |
+
right = np.cross(world_up, forward)
|
94 |
+
right /= np.linalg.norm(right)
|
95 |
+
|
96 |
+
# Recompute the orthogonal up vector
|
97 |
+
up = np.cross(forward, right)
|
98 |
+
|
99 |
+
# Construct the rotation matrix
|
100 |
+
rotation = np.vstack([right, up, forward]).T # 3x3
|
101 |
+
|
102 |
+
# Construct the full C2W matrix
|
103 |
+
C2W = np.eye(4)
|
104 |
+
C2W[:3, :3] = rotation
|
105 |
+
C2W[:3, 3] = camera_position
|
106 |
+
|
107 |
+
return C2W
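A tiny, hypothetical sanity check of the helper above (the import assumes the module layout used elsewhere in this upload): build a pose 1.7 units from the origin at 30° azimuth and 20° elevation and inspect the camera position stored in the last column.

```python
import numpy as np
from onediffusion.diffusion.pipelines.onediffusion import create_c2w_matrix

c2w = create_c2w_matrix(azimuth_deg=30, elevation_deg=20, distance=1.7)
assert c2w.shape == (4, 4)
print(np.round(c2w[:3, 3], 3))  # camera position; its norm equals the requested distance
```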
|
108 |
+
|
109 |
+
@dataclass
|
110 |
+
class OneDiffusionPipelineOutput(BaseOutput):
|
111 |
+
"""
|
112 |
+
Output class for the OneDiffusion pipeline.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
116 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
117 |
+
num_channels)`. PIL images or a numpy array representing the denoised images of the diffusion pipeline.
|
118 |
+
"""
|
119 |
+
|
120 |
+
images: Union[List[Image.Image], np.ndarray]
|
121 |
+
latents: Optional[torch.Tensor] = None
|
122 |
+
|
123 |
+
|
124 |
+
def retrieve_latents(
|
125 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
126 |
+
):
|
127 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
128 |
+
return encoder_output.latent_dist.sample(generator)
|
129 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
130 |
+
return encoder_output.latent_dist.mode()
|
131 |
+
elif hasattr(encoder_output, "latents"):
|
132 |
+
return encoder_output.latents
|
133 |
+
else:
|
134 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
135 |
+
|
136 |
+
|
137 |
+
def calculate_shift(
|
138 |
+
image_seq_len,
|
139 |
+
base_seq_len: int = 256,
|
140 |
+
max_seq_len: int = 4096,
|
141 |
+
base_shift: float = 0.5,
|
142 |
+
max_shift: float = 1.16,
|
143 |
+
# max_clip: float = 1.5,
|
144 |
+
):
|
145 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) # 0.000169270833
|
146 |
+
b = base_shift - m * base_seq_len # 0.5-0.0433333332
|
147 |
+
mu = image_seq_len * m + b
|
148 |
+
# mu = min(mu, max_clip)
|
149 |
+
return mu
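For intuition, the arithmetic below reproduces the linear interpolation above with its default endpoints; a 1024x1024 sample with an 8x VAE factor and a 2x2 patch yields 128 * 128 / 4 = 4096 image tokens, which lands exactly on `max_shift` (illustrative numbers only, not part of the upload):

```python
base_seq_len, max_seq_len = 256, 4096
base_shift, max_shift = 0.5, 1.16

image_seq_len = (1024 // 8 // 2) * (1024 // 8 // 2)          # 4096 tokens
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)  # slope per token
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
print(mu)  # 1.16
```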
|
150 |
+
|
151 |
+
|
152 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
153 |
+
def retrieve_timesteps(
|
154 |
+
scheduler,
|
155 |
+
num_inference_steps: Optional[int] = None,
|
156 |
+
device: Optional[Union[str, torch.device]] = None,
|
157 |
+
timesteps: Optional[List[int]] = None,
|
158 |
+
sigmas: Optional[List[float]] = None,
|
159 |
+
**kwargs,
|
160 |
+
):
|
161 |
+
"""
|
162 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
163 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
scheduler (`SchedulerMixin`):
|
167 |
+
The scheduler to get timesteps from.
|
168 |
+
num_inference_steps (`int`):
|
169 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
170 |
+
must be `None`.
|
171 |
+
device (`str` or `torch.device`, *optional*):
|
172 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
173 |
+
timesteps (`List[int]`, *optional*):
|
174 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
175 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
176 |
+
sigmas (`List[float]`, *optional*):
|
177 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
178 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
182 |
+
second element is the number of inference steps.
|
183 |
+
"""
|
184 |
+
if timesteps is not None and sigmas is not None:
|
185 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
186 |
+
if timesteps is not None:
|
187 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
188 |
+
if not accepts_timesteps:
|
189 |
+
raise ValueError(
|
190 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
191 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
192 |
+
)
|
193 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
194 |
+
timesteps = scheduler.timesteps
|
195 |
+
num_inference_steps = len(timesteps)
|
196 |
+
elif sigmas is not None:
|
197 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
198 |
+
if not accept_sigmas:
|
199 |
+
raise ValueError(
|
200 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
201 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
202 |
+
)
|
203 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
204 |
+
timesteps = scheduler.timesteps
|
205 |
+
num_inference_steps = len(timesteps)
|
206 |
+
else:
|
207 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
208 |
+
timesteps = scheduler.timesteps
|
209 |
+
return timesteps, num_inference_steps
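A hedged sketch of how the pipelines below feed this helper: a linearly spaced sigma schedule plus the resolution-dependent shift `mu`. It assumes a diffusers release whose `FlowMatchEulerDiscreteScheduler.set_timesteps` accepts `sigmas` and `mu` (which the pipeline relies on) and the module path used in this upload:

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler
from onediffusion.diffusion.pipelines.onediffusion import calculate_shift, retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler()
num_inference_steps = 30
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(image_seq_len=4096)

timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps, device="cpu", sigmas=sigmas, mu=mu
)
print(len(timesteps), num_inference_steps)  # 30 steps either way
```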
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
class OneDiffusionPipeline(DiffusionPipeline):
|
214 |
+
r"""
|
215 |
+
Pipeline for text-to-image generation using OneDiffuser.
|
216 |
+
|
217 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
218 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
219 |
+
|
220 |
+
Args:
|
221 |
+
transformer ([`NextDiT`]):
|
222 |
+
Conditional transformer (NextDiT) architecture to denoise the encoded image latents.
|
223 |
+
vae ([`AutoencoderKL`]):
|
224 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
225 |
+
text_encoder ([`T5EncoderModel`]):
|
226 |
+
Frozen text-encoder. OneDiffuser uses the T5 model as text encoder.
|
227 |
+
tokenizer (`T5Tokenizer`):
|
228 |
+
Tokenizer of class T5Tokenizer.
|
229 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
230 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
231 |
+
"""
|
232 |
+
|
233 |
+
def __init__(
|
234 |
+
self,
|
235 |
+
transformer: NextDiT,
|
236 |
+
vae: AutoencoderKL,
|
237 |
+
text_encoder: T5EncoderModel,
|
238 |
+
tokenizer: T5Tokenizer,
|
239 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
240 |
+
):
|
241 |
+
super().__init__()
|
242 |
+
self.register_modules(
|
243 |
+
transformer=transformer,
|
244 |
+
vae=vae,
|
245 |
+
text_encoder=text_encoder,
|
246 |
+
tokenizer=tokenizer,
|
247 |
+
scheduler=scheduler,
|
248 |
+
)
|
249 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
250 |
+
self.image_processor = VaeImageProcessorOneDiffuser(vae_scale_factor=self.vae_scale_factor)
|
251 |
+
|
252 |
+
def enable_vae_slicing(self):
|
253 |
+
self.vae.enable_slicing()
|
254 |
+
|
255 |
+
def disable_vae_slicing(self):
|
256 |
+
self.vae.disable_slicing()
|
257 |
+
|
258 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
259 |
+
if is_accelerate_available():
|
260 |
+
from accelerate import cpu_offload
|
261 |
+
else:
|
262 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
263 |
+
|
264 |
+
device = torch.device(f"cuda:{gpu_id}")
|
265 |
+
|
266 |
+
for cpu_offloaded_model in [self.transformer, self.text_encoder, self.vae]:
|
267 |
+
if cpu_offloaded_model is not None:
|
268 |
+
cpu_offload(cpu_offloaded_model, device)
|
269 |
+
|
270 |
+
@property
|
271 |
+
def _execution_device(self):
|
272 |
+
if self.device != torch.device("meta") or not hasattr(self.transformer, "_hf_hook"):
|
273 |
+
return self.device
|
274 |
+
for module in self.transformer.modules():
|
275 |
+
if (
|
276 |
+
hasattr(module, "_hf_hook")
|
277 |
+
and hasattr(module._hf_hook, "execution_device")
|
278 |
+
and module._hf_hook.execution_device is not None
|
279 |
+
):
|
280 |
+
return torch.device(module._hf_hook.execution_device)
|
281 |
+
return self.device
|
282 |
+
|
283 |
+
def encode_prompt(
|
284 |
+
self,
|
285 |
+
prompt,
|
286 |
+
device,
|
287 |
+
num_images_per_prompt,
|
288 |
+
do_classifier_free_guidance,
|
289 |
+
negative_prompt=None,
|
290 |
+
max_length=300,
|
291 |
+
):
|
292 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
293 |
+
|
294 |
+
text_inputs = self.tokenizer(
|
295 |
+
prompt,
|
296 |
+
padding="max_length",
|
297 |
+
max_length=max_length,
|
298 |
+
truncation=True,
|
299 |
+
add_special_tokens=True,
|
300 |
+
return_tensors="pt",
|
301 |
+
)
|
302 |
+
text_input_ids = text_inputs.input_ids
|
303 |
+
attention_mask = text_inputs.attention_mask
|
304 |
+
|
305 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
306 |
+
|
307 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
308 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
309 |
+
logger.warning(
|
310 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
311 |
+
f" {max_length} tokens: {removed_text}"
|
312 |
+
)
|
313 |
+
|
314 |
+
text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
|
315 |
+
prompt_embeds = text_encoder_output[0].to(torch.float32)
|
316 |
+
|
317 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
318 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
319 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
320 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
321 |
+
|
322 |
+
# duplicate attention mask for each generation per prompt
|
323 |
+
attention_mask = attention_mask.repeat(1, num_images_per_prompt)
|
324 |
+
attention_mask = attention_mask.view(bs_embed * num_images_per_prompt, -1)
|
325 |
+
|
326 |
+
# get unconditional embeddings for classifier free guidance
|
327 |
+
if do_classifier_free_guidance:
|
328 |
+
uncond_tokens: List[str]
|
329 |
+
if negative_prompt is None:
|
330 |
+
uncond_tokens = [""] * batch_size
|
331 |
+
elif type(prompt) is not type(negative_prompt):
|
332 |
+
raise TypeError(
|
333 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
334 |
+
f" {type(prompt)}."
|
335 |
+
)
|
336 |
+
elif isinstance(negative_prompt, str):
|
337 |
+
uncond_tokens = [negative_prompt]
|
338 |
+
elif batch_size != len(negative_prompt):
|
339 |
+
raise ValueError(
|
340 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
341 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
342 |
+
" the batch size of `prompt`."
|
343 |
+
)
|
344 |
+
else:
|
345 |
+
uncond_tokens = negative_prompt
|
346 |
+
|
347 |
+
max_length = text_input_ids.shape[-1]
|
348 |
+
uncond_input = self.tokenizer(
|
349 |
+
uncond_tokens,
|
350 |
+
padding="max_length",
|
351 |
+
max_length=max_length,
|
352 |
+
truncation=True,
|
353 |
+
return_tensors="pt",
|
354 |
+
)
|
355 |
+
|
356 |
+
uncond_encoder_output = self.text_encoder(uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device))
|
357 |
+
negative_prompt_embeds = uncond_encoder_output[0].to(torch.float32)
|
358 |
+
|
359 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
360 |
+
seq_len = negative_prompt_embeds.shape[1]
|
361 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
362 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
363 |
+
|
364 |
+
# duplicate unconditional attention mask for each generation per prompt
|
365 |
+
uncond_attention_mask = uncond_input.attention_mask.repeat(1, num_images_per_prompt)
|
366 |
+
uncond_attention_mask = uncond_attention_mask.view(batch_size * num_images_per_prompt, -1)
|
367 |
+
|
368 |
+
# For classifier free guidance, we need to do two forward passes.
|
369 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
370 |
+
# to avoid doing two forward passes
|
371 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
372 |
+
attention_mask = torch.cat([uncond_attention_mask, attention_mask])
|
373 |
+
|
374 |
+
return prompt_embeds.to(device), attention_mask.to(device)
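Because the embeddings returned above are batched as `[negative, positive]`, the denoising loops below split the model prediction in the same order before blending. A stand-alone illustration of that guidance step, with random tensors as stand-ins:

```python
import torch

guidance_scale = 5.0
noise_pred = torch.randn(2, 4, 64, 64)  # stand-in for a CFG-batched transformer output

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 64, 64])
```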
|
375 |
+
|
376 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
377 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
378 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
379 |
+
raise ValueError(
|
380 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
381 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
382 |
+
)
|
383 |
+
|
384 |
+
if latents is None:
|
385 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
386 |
+
else:
|
387 |
+
latents = latents.to(device)
|
388 |
+
|
389 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
390 |
+
latents = latents * self.scheduler.init_noise_sigma
|
391 |
+
return latents
|
392 |
+
|
393 |
+
@torch.no_grad()
|
394 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
395 |
+
def __call__(
|
396 |
+
self,
|
397 |
+
prompt: Union[str, List[str]] = None,
|
398 |
+
height: Optional[int] = None,
|
399 |
+
width: Optional[int] = None,
|
400 |
+
num_inference_steps: int = 50,
|
401 |
+
guidance_scale: float = 5.0,
|
402 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
403 |
+
num_images_per_prompt: Optional[int] = 1,
|
404 |
+
eta: float = 0.0,
|
405 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
406 |
+
latents: Optional[torch.FloatTensor] = None,
|
407 |
+
output_type: Optional[str] = "pil",
|
408 |
+
return_dict: bool = True,
|
409 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
410 |
+
callback_steps: int = 1,
|
411 |
+
forward_kwargs: Optional[Dict[str, Any]] = {},
|
412 |
+
**kwargs,
|
413 |
+
):
|
414 |
+
r"""
|
415 |
+
Function invoked when calling the pipeline for generation.
|
416 |
+
|
417 |
+
Args:
|
418 |
+
prompt (`str` or `List[str]`, *optional*):
|
419 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
420 |
+
height (`int`, *optional*, defaults to self.transformer.config.sample_size):
|
421 |
+
The height in pixels of the generated image.
|
422 |
+
width (`int`, *optional*, defaults to self.transformer.config.sample_size):
|
423 |
+
The width in pixels of the generated image.
|
424 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
425 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
426 |
+
expense of slower inference.
|
427 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
428 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
429 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
430 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
431 |
+
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
432 |
+
usually at the expense of lower image quality.
|
433 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
434 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
435 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
436 |
+
less than `1`).
|
437 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
438 |
+
The number of images to generate per prompt.
|
439 |
+
eta (`float`, *optional*, defaults to 0.0):
|
440 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
441 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
442 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
443 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
444 |
+
to make generation deterministic.
|
445 |
+
latents (`torch.FloatTensor`, *optional*):
|
446 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
447 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
448 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
449 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
450 |
+
The output format of the generated image. Choose between
|
451 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
452 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
453 |
+
Whether or not to return a [`OneDiffusionPipelineOutput`] instead of a
|
454 |
+
plain tuple.
|
455 |
+
callback (`Callable`, *optional*):
|
456 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
457 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
458 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
459 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
460 |
+
called at every step.
|
461 |
+
|
462 |
+
Examples:
|
463 |
+
|
464 |
+
Returns:
|
465 |
+
[`OneDiffusionPipelineOutput`] or `tuple`:
|
466 |
+
[`OneDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
467 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is
|
468 |
+
`None`; this pipeline does not run a safety checker.
|
469 |
+
|
470 |
+
"""
|
471 |
+
height = height or self.transformer.config.input_size[-2] * 8 # TODO: Hardcoded downscale factor of vae
|
472 |
+
width = width or self.transformer.config.input_size[-1] * 8
|
473 |
+
|
474 |
+
# check inputs. Raise error if not correct
|
475 |
+
self.check_inputs(prompt, height, width, callback_steps)
|
476 |
+
|
477 |
+
# define call parameters
|
478 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
479 |
+
device = self._execution_device
|
480 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
481 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf
|
482 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
483 |
+
|
484 |
+
encoder_hidden_states, encoder_attention_mask = self.encode_prompt(
|
485 |
+
prompt,
|
486 |
+
device,
|
487 |
+
num_images_per_prompt,
|
488 |
+
do_classifier_free_guidance,
|
489 |
+
negative_prompt,
|
490 |
+
)
|
491 |
+
|
492 |
+
# set timesteps
|
493 |
+
# # self.scheduler.set_timesteps(num_inference_steps, device=device)
|
494 |
+
# timesteps = self.scheduler.timesteps
|
495 |
+
timesteps = None
|
496 |
+
|
497 |
+
# prepare latent variables
|
498 |
+
num_channels_latents = self.transformer.config.in_channels
|
499 |
+
latents = self.prepare_latents(
|
500 |
+
batch_size * num_images_per_prompt,
|
501 |
+
num_channels_latents,
|
502 |
+
height,
|
503 |
+
width,
|
504 |
+
self.dtype,
|
505 |
+
device,
|
506 |
+
generator,
|
507 |
+
latents,
|
508 |
+
)
|
509 |
+
|
510 |
+
# prepare extra step kwargs
|
511 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
512 |
+
|
513 |
+
# 5. Prepare timesteps
|
514 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
515 |
+
image_seq_len = latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
|
516 |
+
mu = calculate_shift(
|
517 |
+
image_seq_len,
|
518 |
+
self.scheduler.config.base_image_seq_len,
|
519 |
+
self.scheduler.config.max_image_seq_len,
|
520 |
+
self.scheduler.config.base_shift,
|
521 |
+
self.scheduler.config.max_shift,
|
522 |
+
)
|
523 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
524 |
+
self.scheduler,
|
525 |
+
num_inference_steps,
|
526 |
+
device,
|
527 |
+
timesteps,
|
528 |
+
sigmas,
|
529 |
+
mu=mu,
|
530 |
+
)
|
531 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
532 |
+
self._num_timesteps = len(timesteps)
|
533 |
+
|
534 |
+
# denoising loop
|
535 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
536 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
537 |
+
for i, t in enumerate(timesteps):
|
538 |
+
# expand the latents if we are doing classifier free guidance
|
539 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
540 |
+
# latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
541 |
+
|
542 |
+
# predict the noise residual
|
543 |
+
noise_pred = self.transformer(
|
544 |
+
samples=latent_model_input.to(self.dtype),
|
545 |
+
timesteps=torch.tensor([t] * latent_model_input.shape[0], device=device),
|
546 |
+
encoder_hidden_states=encoder_hidden_states.to(self.dtype),
|
547 |
+
encoder_attention_mask=encoder_attention_mask,
|
548 |
+
**forward_kwargs
|
549 |
+
)
|
550 |
+
|
551 |
+
# perform guidance
|
552 |
+
if do_classifier_free_guidance:
|
553 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
554 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
555 |
+
|
556 |
+
# compute the previous noisy sample x_t -> x_t-1
|
557 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
558 |
+
|
559 |
+
# call the callback, if provided
|
560 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
561 |
+
progress_bar.update()
|
562 |
+
if callback is not None and i % callback_steps == 0:
|
563 |
+
callback(i, t, latents)
|
564 |
+
|
565 |
+
# scale and decode the image latents with vae
|
566 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
567 |
+
if latents.ndim == 5:
|
568 |
+
latents = latents.squeeze(1)
|
569 |
+
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
570 |
+
|
571 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
572 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
573 |
+
|
574 |
+
if output_type == "pil":
|
575 |
+
image = self.numpy_to_pil(image)
|
576 |
+
|
577 |
+
if not return_dict:
|
578 |
+
return (image, None)
|
579 |
+
|
580 |
+
return OneDiffusionPipelineOutput(images=image)
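A hedged usage sketch for the text-to-image path above, extending the docstring example with an explicit resolution and a seeded generator; the checkpoint path is a placeholder:

```python
import torch
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline

pipe = OneDiffusionPipeline.from_pretrained("path_to_one_diffuser_model").to("cuda")
generator = torch.Generator(device="cuda").manual_seed(0)

out = pipe(
    prompt="A beautiful sunset over the ocean",
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=generator,
)
out.images[0].save("sunset_seed0.png")
```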
|
581 |
+
|
582 |
+
@torch.no_grad()
|
583 |
+
def img2img(
|
584 |
+
self,
|
585 |
+
prompt: Union[str, List[str]] = None,
|
586 |
+
image: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
|
587 |
+
height: Optional[int] = None,
|
588 |
+
width: Optional[int] = None,
|
589 |
+
num_inference_steps: int = 50,
|
590 |
+
guidance_scale: float = 5.0,
|
591 |
+
denoise_mask: Optional[List[int]] = [1, 0],
|
592 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
593 |
+
num_images_per_prompt: Optional[int] = 1,
|
594 |
+
eta: float = 0.0,
|
595 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
596 |
+
latents: Optional[torch.FloatTensor] = None,
|
597 |
+
output_type: Optional[str] = "pil",
|
598 |
+
return_dict: bool = True,
|
599 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
600 |
+
callback_steps: int = 1,
|
601 |
+
do_crop: bool = True,
|
602 |
+
is_multiview: bool = False,
|
603 |
+
multiview_azimuths: Optional[List[int]] = [0, 30, 60, 90],
|
604 |
+
multiview_elevations: Optional[List[int]] = [0, 0, 0, 0],
|
605 |
+
multiview_distances: float = 1.7,
|
606 |
+
multiview_c2ws: Optional[List[torch.Tensor]] = None,
|
607 |
+
multiview_intrinsics: Optional[torch.Tensor] = None,
|
608 |
+
multiview_focal_length: float = 1.3887,
|
609 |
+
forward_kwargs: Optional[Dict[str, Any]] = {},
|
610 |
+
noise_scale: float = 1.0,
|
611 |
+
**kwargs,
|
612 |
+
):
|
613 |
+
# Convert single image to list for consistent handling
|
614 |
+
if isinstance(image, PIL.Image.Image):
|
615 |
+
image = [image]
|
616 |
+
|
617 |
+
if height is None or width is None:
|
618 |
+
closest_ar = get_closest_ratio(height=image[0].size[1], width=image[0].size[0], ratios=ASPECT_RATIO_512)
|
619 |
+
height, width = int(closest_ar[0][0]), int(closest_ar[0][1])
|
620 |
+
|
621 |
+
if not isinstance(multiview_distances, list) and not isinstance(multiview_distances, tuple):
|
622 |
+
multiview_distances = [multiview_distances] * len(multiview_azimuths)
|
623 |
+
|
624 |
+
# height = height or self.transformer.config.input_size[-2] * 8 # TODO: Hardcoded downscale factor of vae
|
625 |
+
# width = width or self.transformer.config.input_size[-1] * 8
|
626 |
+
|
627 |
+
# 1. check inputs. Raise error if not correct
|
628 |
+
self.check_inputs(prompt, height, width, callback_steps)
|
629 |
+
|
630 |
+
# Additional input validation for image list
|
631 |
+
if not all(isinstance(img, PIL.Image.Image) for img in image):
|
632 |
+
raise ValueError("All elements in image list must be PIL.Image objects")
|
633 |
+
|
634 |
+
# 2. define call parameters
|
635 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
636 |
+
device = self._execution_device
|
637 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
638 |
+
|
639 |
+
# 3. Encode input prompt
|
640 |
+
encoder_hidden_states, encoder_attention_mask = self.encode_prompt(
|
641 |
+
prompt,
|
642 |
+
device,
|
643 |
+
num_images_per_prompt,
|
644 |
+
do_classifier_free_guidance,
|
645 |
+
negative_prompt,
|
646 |
+
)
|
647 |
+
|
648 |
+
# 4. Preprocess all images
|
649 |
+
if image is not None and len(image) > 0:
|
650 |
+
processed_image = self.image_processor.preprocess(image, height=height, width=width, do_crop=do_crop)
|
651 |
+
else:
|
652 |
+
processed_image = None
|
653 |
+
|
654 |
+
# # Stack processed images along the sequence dimension
|
655 |
+
# if len(processed_images) > 1:
|
656 |
+
# processed_image = torch.cat(processed_images, dim=0)
|
657 |
+
# else:
|
658 |
+
# processed_image = processed_images[0]
|
659 |
+
|
660 |
+
timesteps = None
|
661 |
+
|
662 |
+
# 6. prepare latent variables
|
663 |
+
num_channels_latents = self.transformer.config.in_channels
|
664 |
+
if processed_image is not None:
|
665 |
+
cond_latents = self.prepare_latents(
|
666 |
+
batch_size * num_images_per_prompt,
|
667 |
+
num_channels_latents,
|
668 |
+
height,
|
669 |
+
width,
|
670 |
+
self.dtype,
|
671 |
+
device,
|
672 |
+
generator,
|
673 |
+
latents,
|
674 |
+
image=processed_image,
|
675 |
+
)
|
676 |
+
else:
|
677 |
+
cond_latents = None
|
678 |
+
|
679 |
+
# 7. prepare extra step kwargs
|
680 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
681 |
+
denoise_mask = torch.tensor(denoise_mask, device=device)
|
682 |
+
denoise_indices = torch.where(denoise_mask == 1)[0]
|
683 |
+
cond_indices = torch.where(denoise_mask == 0)[0]
|
684 |
+
seq_length = denoise_mask.shape[0]
|
685 |
+
|
686 |
+
latents = self.prepare_init_latents(
|
687 |
+
batch_size * num_images_per_prompt,
|
688 |
+
seq_length,
|
689 |
+
num_channels_latents,
|
690 |
+
height,
|
691 |
+
width,
|
692 |
+
self.dtype,
|
693 |
+
device,
|
694 |
+
generator,
|
695 |
+
)
|
696 |
+
|
697 |
+
# 5. Prepare timesteps
|
698 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
699 |
+
# image_seq_len = latents.shape[1] * latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
|
700 |
+
image_seq_len = noise_scale * sum(denoise_mask) * latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
|
701 |
+
# image_seq_len = 256
|
702 |
+
mu = calculate_shift(
|
703 |
+
image_seq_len,
|
704 |
+
self.scheduler.config.base_image_seq_len,
|
705 |
+
self.scheduler.config.max_image_seq_len,
|
706 |
+
self.scheduler.config.base_shift,
|
707 |
+
self.scheduler.config.max_shift,
|
708 |
+
)
|
709 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
710 |
+
self.scheduler,
|
711 |
+
num_inference_steps,
|
712 |
+
device,
|
713 |
+
timesteps,
|
714 |
+
sigmas,
|
715 |
+
mu=mu,
|
716 |
+
)
|
717 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
718 |
+
self._num_timesteps = len(timesteps)
|
719 |
+
|
720 |
+
if is_multiview:
|
721 |
+
cond_indices_images = [index // 2 for index in cond_indices if index % 2 == 0]
|
722 |
+
cond_indices_rays = [index // 2 for index in cond_indices if index % 2 == 1]
|
723 |
+
|
724 |
+
multiview_elevations = [element for element in multiview_elevations if element is not None]
|
725 |
+
multiview_azimuths = [element for element in multiview_azimuths if element is not None]
|
726 |
+
multiview_distances = [element for element in multiview_distances if element is not None]
|
727 |
+
|
728 |
+
if multiview_c2ws is None:
|
729 |
+
multiview_c2ws = [
|
730 |
+
torch.tensor(create_c2w_matrix(azimuth, elevation, distance)) for azimuth, elevation, distance in zip(multiview_azimuths, multiview_elevations, multiview_distances)
|
731 |
+
]
|
732 |
+
c2ws = torch.stack(multiview_c2ws).float()
|
733 |
+
else:
|
734 |
+
c2ws = torch.Tensor(multiview_c2ws).float()
|
735 |
+
|
736 |
+
c2ws[:, 0:3, 1:3] *= -1
|
737 |
+
c2ws = c2ws[:, [1, 0, 2, 3], :]
|
738 |
+
c2ws[:, 2, :] *= -1
|
739 |
+
|
740 |
+
w2cs = torch.inverse(c2ws)
|
741 |
+
if multiview_intrinsics is None:
|
742 |
+
multiview_intrinsics = torch.Tensor([[[multiview_focal_length, 0, 0.5], [0, multiview_focal_length, 0.5], [0, 0, 1]]]).repeat(c2ws.shape[0], 1, 1)
|
743 |
+
K = multiview_intrinsics
|
744 |
+
Rs = w2cs[:, :3, :3]
|
745 |
+
Ts = w2cs[:, :3, 3]
|
746 |
+
sizes = torch.Tensor([[1, 1]]).repeat(c2ws.shape[0], 1)
|
747 |
+
|
748 |
+
assert height == width
|
749 |
+
cond_rays = calculate_rays(K, sizes, Rs, Ts, height // 8)
|
750 |
+
cond_rays = cond_rays.reshape(-1, height // 8, width // 8, 6)
|
751 |
+
# padding = (0, 10)
|
752 |
+
# cond_rays = torch.nn.functional.pad(cond_rays, padding, "constant", 0)
|
753 |
+
cond_rays = torch.cat([cond_rays, cond_rays, cond_rays[..., :4]], dim=-1) * 1.658
|
754 |
+
cond_rays = cond_rays[None].repeat(batch_size * num_images_per_prompt, 1, 1, 1, 1)
|
755 |
+
cond_rays = cond_rays.permute(0, 1, 4, 2, 3)
|
756 |
+
cond_rays = cond_rays.to(device, dtype=self.dtype)
|
757 |
+
|
758 |
+
latents = einops.rearrange(latents, "b (f n) c h w -> b f n c h w", n=2)
|
759 |
+
if cond_latents is not None:
|
760 |
+
latents[:, cond_indices_images, 0] = cond_latents
|
761 |
+
latents[:, cond_indices_rays, 1] = cond_rays
|
762 |
+
latents = einops.rearrange(latents, "b f n c h w -> b (f n) c h w")
|
763 |
+
else:
|
764 |
+
if cond_latents is not None:
|
765 |
+
latents[:, cond_indices] = cond_latents
|
766 |
+
|
767 |
+
# denoising loop
|
768 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
769 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
770 |
+
for i, t in enumerate(timesteps):
|
771 |
+
# expand the latents if we are doing classifier free guidance
|
772 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
773 |
+
input_t = torch.broadcast_to(einops.repeat(torch.Tensor([t]).to(device), "1 -> 1 f 1 1 1", f=latent_model_input.shape[1]), latent_model_input.shape).clone()
|
774 |
+
|
775 |
+
if is_multiview:
|
776 |
+
input_t = einops.rearrange(input_t, "b (f n) c h w -> b f n c h w", n=2)
|
777 |
+
input_t[:, cond_indices_images, 0] = self.scheduler.timesteps[-1]
|
778 |
+
input_t[:, cond_indices_rays, 1] = self.scheduler.timesteps[-1]
|
779 |
+
input_t = einops.rearrange(input_t, "b f n c h w -> b (f n) c h w")
|
780 |
+
else:
|
781 |
+
input_t[:, cond_indices] = self.scheduler.timesteps[-1]
|
782 |
+
|
783 |
+
# predict the noise residual
|
784 |
+
noise_pred = self.transformer(
|
785 |
+
samples=latent_model_input.to(self.dtype),
|
786 |
+
timesteps=input_t,
|
787 |
+
encoder_hidden_states=encoder_hidden_states.to(self.dtype),
|
788 |
+
encoder_attention_mask=encoder_attention_mask,
|
789 |
+
**forward_kwargs
|
790 |
+
)
|
791 |
+
|
792 |
+
# perform guidance
|
793 |
+
if do_classifier_free_guidance:
|
794 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
795 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
796 |
+
|
797 |
+
# compute the previous noisy sample x_t -> x_t-1
|
798 |
+
bs, n_frame = noise_pred.shape[:2]
|
799 |
+
noise_pred = einops.rearrange(noise_pred, "b f c h w -> (b f) c h w")
|
800 |
+
latents = einops.rearrange(latents, "b f c h w -> (b f) c h w")
|
801 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
802 |
+
latents = einops.rearrange(latents, "(b f) c h w -> b f c h w", b=bs, f=n_frame)
|
803 |
+
if is_multiview:
|
804 |
+
latents = einops.rearrange(latents, "b (f n) c h w -> b f n c h w", n=2)
|
805 |
+
if cond_latents is not None:
|
806 |
+
latents[:, cond_indices_images, 0] = cond_latents
|
807 |
+
latents[:, cond_indices_rays, 1] = cond_rays
|
808 |
+
latents = einops.rearrange(latents, "b f n c h w -> b (f n) c h w")
|
809 |
+
else:
|
810 |
+
if cond_latents is not None:
|
811 |
+
latents[:, cond_indices] = cond_latents
|
812 |
+
|
813 |
+
# call the callback, if provided
|
814 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
815 |
+
progress_bar.update()
|
816 |
+
if callback is not None and i % callback_steps == 0:
|
817 |
+
callback(i, t, latents)
|
818 |
+
|
819 |
+
decoded_latents = latents / 1.658
|
820 |
+
# scale and decode the image latents with vae
|
821 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
822 |
+
if latents.ndim == 5:
|
823 |
+
latents = latents[:, denoise_indices]
|
824 |
+
latents = einops.rearrange(latents, "b f c h w -> (b f) c h w")
|
825 |
+
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
826 |
+
|
827 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
828 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
829 |
+
|
830 |
+
if output_type == "pil":
|
831 |
+
image = self.numpy_to_pil(image)
|
832 |
+
|
833 |
+
if not return_dict:
|
834 |
+
return (image, None)
|
835 |
+
|
836 |
+
return OneDiffusionPipelineOutput(images=image, latents=decoded_latents)
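A hedged sketch of the image-conditioned path above: one generated view plus one conditioning image, i.e. `denoise_mask=[1, 0]`; the checkpoint path and input file are placeholders:

```python
import torch
from PIL import Image
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline

pipe = OneDiffusionPipeline.from_pretrained("path_to_one_diffuser_model").to("cuda")
cond = Image.open("reference.jpg")  # placeholder conditioning image

out = pipe.img2img(
    prompt="a photo of the same subject at golden hour",
    image=[cond],
    num_inference_steps=50,
    guidance_scale=5.0,
    denoise_mask=[1, 0],  # 1 = position to denoise, 0 = position fed with the conditioning latents
    do_crop=True,
)
out.images[0].save("img2img_result.png")
```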
|
837 |
+
|
838 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
839 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
840 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
841 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
842 |
+
# and should be between [0, 1]
|
843 |
+
|
844 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
845 |
+
extra_step_kwargs = {}
|
846 |
+
if accepts_eta:
|
847 |
+
extra_step_kwargs["eta"] = eta
|
848 |
+
|
849 |
+
# check if the scheduler accepts generator
|
850 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
851 |
+
if accepts_generator:
|
852 |
+
extra_step_kwargs["generator"] = generator
|
853 |
+
return extra_step_kwargs
|
854 |
+
|
855 |
+
def check_inputs(self, prompt, height, width, callback_steps):
|
856 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
857 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
858 |
+
|
859 |
+
if height % 16 != 0 or width % 16 != 0:
|
860 |
+
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
861 |
+
|
862 |
+
if (callback_steps is None) or (
|
863 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
864 |
+
):
|
865 |
+
raise ValueError(
|
866 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
867 |
+
f" {type(callback_steps)}."
|
868 |
+
)
|
869 |
+
|
870 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
871 |
+
# get the original timestep using init_timestep
|
872 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
873 |
+
|
874 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
875 |
+
timesteps = self.scheduler.timesteps[t_start:]
|
876 |
+
|
877 |
+
return timesteps, num_inference_steps - t_start
|
878 |
+
|
879 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, image=None):
|
880 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
881 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
882 |
+
raise ValueError(
|
883 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
884 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
885 |
+
)
|
886 |
+
|
887 |
+
if latents is None:
|
888 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
889 |
+
else:
|
890 |
+
latents = latents.to(device)
|
891 |
+
|
892 |
+
if image is None:
|
893 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
894 |
+
# latents = latents * self.scheduler.init_noise_sigma
|
895 |
+
return latents
|
896 |
+
|
897 |
+
image = image.to(device=device, dtype=dtype)
|
898 |
+
|
899 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
900 |
+
raise ValueError(
|
901 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
902 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
903 |
+
)
|
904 |
+
elif isinstance(generator, list):
|
905 |
+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
|
906 |
+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
|
907 |
+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
|
908 |
+
raise ValueError(
|
909 |
+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
|
910 |
+
)
|
911 |
+
init_latents = [
|
912 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
913 |
+
for i in range(batch_size)
|
914 |
+
]
|
915 |
+
init_latents = torch.cat(init_latents, dim=0)
|
916 |
+
else:
|
917 |
+
init_latents = retrieve_latents(self.vae.encode(image.to(self.vae.dtype)), generator=generator)
|
918 |
+
|
919 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
920 |
+
init_latents = init_latents.to(device=device, dtype=dtype)
|
921 |
+
|
922 |
+
init_latents = einops.rearrange(init_latents, "(bs views) c h w -> bs views c h w", bs=batch_size, views=init_latents.shape[0]//batch_size)
|
923 |
+
# latents = einops.rearrange(latents, "b c h w -> b 1 c h w")
|
924 |
+
# latents = torch.concat([latents, init_latents], dim=1)
|
925 |
+
return init_latents
|
926 |
+
|
927 |
+
def prepare_init_latents(self, batch_size, seq_length, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
928 |
+
shape = (batch_size, seq_length, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
929 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
930 |
+
raise ValueError(
|
931 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
932 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
933 |
+
)
|
934 |
+
|
935 |
+
if latents is None:
|
936 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
937 |
+
else:
|
938 |
+
latents = latents.to(device)
|
939 |
+
|
940 |
+
return latents
|
941 |
+
|
942 |
+
@torch.no_grad()
|
943 |
+
def generate(
|
944 |
+
self,
|
945 |
+
prompt: Union[str, List[str]],
|
946 |
+
num_inference_steps: int = 50,
|
947 |
+
guidance_scale: float = 5.0,
|
948 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
949 |
+
num_images_per_prompt: Optional[int] = 1,
|
950 |
+
height: Optional[int] = None,
|
951 |
+
width: Optional[int] = None,
|
952 |
+
eta: float = 0.0,
|
953 |
+
generator: Optional[torch.Generator] = None,
|
954 |
+
latents: Optional[torch.FloatTensor] = None,
|
955 |
+
output_type: Optional[str] = "pil",
|
956 |
+
return_dict: bool = True,
|
957 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
958 |
+
callback_steps: Optional[int] = 1,
|
959 |
+
):
|
960 |
+
"""
|
961 |
+
Function for image generation using the OneDiffusionPipeline.
|
962 |
+
"""
|
963 |
+
return self(
|
964 |
+
prompt=prompt,
|
965 |
+
num_inference_steps=num_inference_steps,
|
966 |
+
guidance_scale=guidance_scale,
|
967 |
+
negative_prompt=negative_prompt,
|
968 |
+
num_images_per_prompt=num_images_per_prompt,
|
969 |
+
height=height,
|
970 |
+
width=width,
|
971 |
+
eta=eta,
|
972 |
+
generator=generator,
|
973 |
+
latents=latents,
|
974 |
+
output_type=output_type,
|
975 |
+
return_dict=return_dict,
|
976 |
+
callback=callback,
|
977 |
+
callback_steps=callback_steps,
|
978 |
+
)
|
979 |
+
|
980 |
+
@staticmethod
|
981 |
+
def numpy_to_pil(images):
|
982 |
+
"""
|
983 |
+
Convert a numpy image or a batch of images to a PIL image.
|
984 |
+
"""
|
985 |
+
if images.ndim == 3:
|
986 |
+
images = images[None, ...]
|
987 |
+
images = (images * 255).round().astype("uint8")
|
988 |
+
if images.shape[-1] == 1:
|
989 |
+
# special case for grayscale (single channel) images
|
990 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
991 |
+
else:
|
992 |
+
pil_images = [Image.fromarray(image) for image in images]
|
993 |
+
|
994 |
+
return pil_images
|
995 |
+
|
996 |
+
@classmethod
|
997 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
998 |
+
model_path = pretrained_model_name_or_path
|
999 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
1000 |
+
force_download = kwargs.pop("force_download", False)
|
1001 |
+
proxies = kwargs.pop("proxies", None)
|
1002 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
1003 |
+
token = kwargs.pop("token", None)
|
1004 |
+
revision = kwargs.pop("revision", None)
|
1005 |
+
from_flax = kwargs.pop("from_flax", False)
|
1006 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
1007 |
+
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
1008 |
+
custom_revision = kwargs.pop("custom_revision", None)
|
1009 |
+
provider = kwargs.pop("provider", None)
|
1010 |
+
sess_options = kwargs.pop("sess_options", None)
|
1011 |
+
device_map = kwargs.pop("device_map", None)
|
1012 |
+
max_memory = kwargs.pop("max_memory", None)
|
1013 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
1014 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
1015 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
1016 |
+
variant = kwargs.pop("variant", None)
|
1017 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
1018 |
+
use_onnx = kwargs.pop("use_onnx", None)
|
1019 |
+
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
1020 |
+
|
1021 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
1022 |
+
low_cpu_mem_usage = False
|
1023 |
+
logger.warning(
|
1024 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
1025 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
1026 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
1027 |
+
" install accelerate\n```\n."
|
1028 |
+
)
|
1029 |
+
|
1030 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
1031 |
+
raise NotImplementedError(
|
1032 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
1033 |
+
" `low_cpu_mem_usage=False`."
|
1034 |
+
)
|
1035 |
+
|
1036 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
1037 |
+
raise NotImplementedError(
|
1038 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
1039 |
+
" `device_map=None`."
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
if device_map is not None and not is_accelerate_available():
|
1043 |
+
raise NotImplementedError(
|
1044 |
+
"Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
|
1045 |
+
)
|
1046 |
+
|
1047 |
+
if device_map is not None and not isinstance(device_map, str):
|
1048 |
+
raise ValueError("`device_map` must be a string.")
|
1049 |
+
|
1050 |
+
if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
|
1051 |
+
raise NotImplementedError(
|
1052 |
+
f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
|
1053 |
+
)
|
1054 |
+
|
1055 |
+
if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
|
1056 |
+
if is_accelerate_version("<", "0.28.0"):
|
1057 |
+
raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
|
1058 |
+
|
1059 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
1060 |
+
raise ValueError(
|
1061 |
+
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
1062 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
1063 |
+
)
|
1064 |
+
|
1065 |
+
transformer = NextDiT.from_pretrained(f"{model_path}", subfolder="transformer", torch_dtype=torch.float32, cache_dir=cache_dir)
|
1066 |
+
vae = AutoencoderKL.from_pretrained(f"{model_path}", subfolder="vae", cache_dir=cache_dir)
|
1067 |
+
text_encoder = T5EncoderModel.from_pretrained(f"{model_path}", subfolder="text_encoder", torch_dtype=torch.float16, cache_dir=cache_dir)
|
1068 |
+
tokenizer = T5Tokenizer.from_pretrained(model_path, subfolder="tokenizer", cache_dir=cache_dir)
|
1069 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler", cache_dir=cache_dir)
|
1070 |
+
|
1071 |
+
pipeline = cls(
|
1072 |
+
transformer=transformer,
|
1073 |
+
vae=vae,
|
1074 |
+
text_encoder=text_encoder,
|
1075 |
+
tokenizer=tokenizer,
|
1076 |
+
scheduler=scheduler,
|
1077 |
+
**kwargs
|
1078 |
+
)
|
1079 |
+
|
1080 |
+
return pipeline
|
onediffusion/models/denoiser/__init__.py
ADDED
@@ -0,0 +1,3 @@
from . import (
    nextdit
)
onediffusion/models/denoiser/nextdit/__init__.py
ADDED
@@ -0,0 +1 @@
from .modeling_nextdit import NextDiT
onediffusion/models/denoiser/nextdit/layers.py
ADDED
@@ -0,0 +1,132 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Callable, Optional

import warnings

import torch
import torch.nn as nn

try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            """
            Initialize the RMSNorm normalization layer.
            Args:
                dim (int): The dimension of the input tensor.
                eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
            Attributes:
                eps (float): A small value added to the denominator for numerical stability.
                weight (nn.Parameter): Learnable scaling parameter.
            """
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            """
            Apply the RMSNorm normalization to the input tensor.
            Args:
                x (torch.Tensor): The input tensor.
            Returns:
                torch.Tensor: The normalized tensor.
            """
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            """
            Forward pass through the RMSNorm layer.
            Args:
                x (torch.Tensor): The input tensor.
            Returns:
                torch.Tensor: The output tensor after applying RMSNorm.
            """
            output = self._norm(x.float()).type_as(x)
            return output * self.weight


def modulate(x, scale):
    return x * (1 + scale.unsqueeze(1))


class LLamaFeedForward(nn.Module):
    """
    Corresponds to the FeedForward layer in Next DiT.
    """
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float] = None,
        zeros_initialize: bool = True,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.zeros_initialize = zeros_initialize
        self.dtype = dtype

        # Compute hidden_dim based on the given formula
        hidden_dim_calculated = int(2 * self.hidden_dim / 3)
        if self.ffn_dim_multiplier is not None:
            hidden_dim_calculated = int(self.ffn_dim_multiplier * hidden_dim_calculated)
        hidden_dim_calculated = self.multiple_of * ((hidden_dim_calculated + self.multiple_of - 1) // self.multiple_of)

        # Define linear layers
        self.w1 = nn.Linear(self.dim, hidden_dim_calculated, bias=False)
        self.w2 = nn.Linear(hidden_dim_calculated, self.dim, bias=False)
        self.w3 = nn.Linear(self.dim, hidden_dim_calculated, bias=False)

        # Initialize weights
        if self.zeros_initialize:
            nn.init.zeros_(self.w2.weight)
        else:
            nn.init.xavier_uniform_(self.w2.weight)
        nn.init.xavier_uniform_(self.w1.weight)
        nn.init.xavier_uniform_(self.w3.weight)

    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))


class FinalLayer(nn.Module):
    """
    The final layer of Next-DiT.
    """
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.out_channels = out_channels

        # LayerNorm without learnable parameters (elementwise_affine=False)
        self.norm_final = nn.LayerNorm(self.hidden_size, eps=1e-6, elementwise_affine=False)
        self.linear = nn.Linear(self.hidden_size, np.prod(self.patch_size) * self.out_channels, bias=True)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )
        # Initialize the last layer with zeros
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        scale = self.adaLN_modulation(c)
        x = modulate(self.norm_final(x), scale)
        x = self.linear(x)
        return x
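
A quick shape check for the layers above, using made-up dimensions. It assumes the vanilla `RMSNorm` fallback path (apex absent); note that `w2` is zero-initialized by default, so a freshly constructed `LLamaFeedForward` returns zeros until trained.

```python
import torch
from onediffusion.models.denoiser.nextdit.layers import LLamaFeedForward, RMSNorm

x = torch.randn(2, 16, 512)   # (batch, tokens, dim) -- illustrative sizes
norm = RMSNorm(512)
# hidden_dim=2048 -> int(2 * 2048 / 3) = 1365, rounded up to a multiple of 256 -> 1536 internal width
ffn = LLamaFeedForward(dim=512, hidden_dim=4 * 512, multiple_of=256)
out = ffn(norm(x))
print(out.shape)              # torch.Size([2, 16, 512]); all zeros because w2 starts at zero
```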
onediffusion/models/denoiser/nextdit/modeling_nextdit.py
ADDED
@@ -0,0 +1,571 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from typing import Any, Tuple, Optional
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

from .layers import LLamaFeedForward, RMSNorm

# import frasch


def modulate(x, scale):
    return x * (1 + scale)


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.hidden_size = hidden_size
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(self.frequency_embedding_size, self.hidden_size),
            nn.SiLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(
            -np.log(max_period) * torch.arange(0, half, dtype=t.dtype) / half
        ).to(t.device)
        args = t[:, :, None] * freqs[None, :]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_freq = t_freq.to(self.mlp[0].weight.dtype)
        return self.mlp(t_freq)


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, num_patches, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, num_patches * out_channels)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(hidden_size, 1024), hidden_size),
        )

    def forward(self, x, c):
        scale = self.adaLN_modulation(c)
        x = modulate(self.norm_final(x), scale)
        x = self.linear(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        n_heads,
        n_kv_heads=None,
        qk_norm=False,
        y_dim=0,
        base_seqlen=None,
        proportional_attn=False,
        attention_dropout=0.0,
        max_position_embeddings=384,
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads
        self.qk_norm = qk_norm
        self.y_dim = y_dim
        self.base_seqlen = base_seqlen
        self.proportional_attn = proportional_attn
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings

        self.head_dim = dim // n_heads

        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)

        if y_dim > 0:
            self.wk_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias=False)
            self.wv_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias=False)
            self.gate = nn.Parameter(torch.zeros(n_heads))

        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

        if qk_norm:
            self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
            self.k_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim)
            if y_dim > 0:
                self.ky_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim, eps=1e-6)
            else:
                self.ky_norm = nn.Identity()
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
            self.ky_norm = nn.Identity()

    @staticmethod
    def apply_rotary_emb(xq, xk, freqs_cis):
        # xq, xk: [batch_size, seq_len, n_heads, head_dim]
        # freqs_cis: [1, seq_len, 1, head_dim]
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)

        xq_complex = torch.view_as_complex(xq_)
        xk_complex = torch.view_as_complex(xk_)

        freqs_cis = freqs_cis.unsqueeze(2)

        # Apply freqs_cis
        xq_out = xq_complex * freqs_cis
        xk_out = xk_complex * freqs_cis

        # Convert back to real numbers
        xq_out = torch.view_as_real(xq_out).flatten(-2)
        xk_out = torch.view_as_real(xk_out).flatten(-2)

        return xq_out.type_as(xq), xk_out.type_as(xk)

    # copied from huggingface modeling_llama.py
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        def _get_unpad_data(attention_mask):
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return (
                indices,
                cu_seqlens,
                max_seqlen_in_batch,
            )

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def forward(
        self,
        x,
        x_mask,
        freqs_cis,
        y=None,
        y_mask=None,
        init_cache=False,
    ):
        bsz, seqlen, _ = x.size()
        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)

        if x_mask is None:
            x_mask = torch.ones(bsz, seqlen, dtype=torch.bool, device=x.device)
        inp_dtype = xq.dtype

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        if self.n_kv_heads != self.n_heads:
            n_rep = self.n_heads // self.n_kv_heads
            xk = xk.repeat_interleave(n_rep, dim=2)
            xv = xv.repeat_interleave(n_rep, dim=2)

        freqs_cis = freqs_cis.to(xq.device)
        xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis)

        if inp_dtype in [torch.float16, torch.bfloat16]:
            # begin var_len flash attn
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(xq, xk, xv, x_mask, seqlen)

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states.to(inp_dtype),
                key_states.to(inp_dtype),
                value_states.to(inp_dtype),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=0.0,
                causal=False,
                softmax_scale=None,
                softcap=30,
            )
            output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
        else:
            output = (
                F.scaled_dot_product_attention(
                    xq.permute(0, 2, 1, 3),
                    xk.permute(0, 2, 1, 3),
                    xv.permute(0, 2, 1, 3),
                    attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_heads, seqlen, -1),
                    scale=None,
                )
                .permute(0, 2, 1, 3)
                .to(inp_dtype)
            )  # ok

        if hasattr(self, "wk_y"):
            yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_kv_heads, self.head_dim)
            yv = self.wv_y(y).view(bsz, -1, self.n_kv_heads, self.head_dim)
            n_rep = self.n_heads // self.n_kv_heads
            # if n_rep >= 1:
            #     yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            #     yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            if n_rep >= 1:
                yk = einops.repeat(yk, "b l h d -> b l (repeat h) d", repeat=n_rep)
                yv = einops.repeat(yv, "b l h d -> b l (repeat h) d", repeat=n_rep)
            output_y = F.scaled_dot_product_attention(
                xq.permute(0, 2, 1, 3),
                yk.permute(0, 2, 1, 3),
                yv.permute(0, 2, 1, 3),
                y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_heads, seqlen, -1).to(torch.bool),
            ).permute(0, 2, 1, 3)
            output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
            output = output + output_y

        output = output.flatten(-2)
        output = self.wo(output)

        return output.to(inp_dtype)


class TransformerBlock(nn.Module):
    """
    Corresponds to the Transformer block in the JAX code.
    """
    def __init__(
        self,
        dim,
        n_heads,
        n_kv_heads,
        multiple_of,
        ffn_dim_multiplier,
        norm_eps,
        qk_norm,
        y_dim,
        max_position_embeddings,
    ):
        super().__init__()
        self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim=y_dim, max_position_embeddings=max_position_embeddings)
        self.feed_forward = LLamaFeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(dim, 1024), 4 * dim),
        )
        self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps)

    def forward(
        self,
        x,
        x_mask,
        freqs_cis,
        y,
        y_mask,
        adaln_input=None,
    ):
        if adaln_input is not None:
            scales_gates = self.adaLN_modulation(adaln_input)
            # TODO: Duong - check the dimension of chunking
            # scale_msa, gate_msa, scale_mlp, gate_mlp = scales_gates.chunk(4, dim=-1)
            scale_msa, gate_msa, scale_mlp, gate_mlp = scales_gates.chunk(4, dim=-1)
            x = x + torch.tanh(gate_msa) * self.attention_norm2(
                self.attention(
                    modulate(self.attention_norm1(x), scale_msa),  # ok
                    x_mask,
                    freqs_cis,
                    self.attention_y_norm(y),  # ok
                    y_mask,
                )
            )
            x = x + torch.tanh(gate_mlp) * self.ffn_norm2(
                self.feed_forward(
                    modulate(self.ffn_norm1(x), scale_mlp),
                )
            )
        else:
            x = x + self.attention_norm2(
                self.attention(
                    self.attention_norm1(x),
                    x_mask,
                    freqs_cis,
                    self.attention_y_norm(y),
                    y_mask,
                )
            )
            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
        return x


class NextDiT(ModelMixin, ConfigMixin):
    """
    Diffusion model with a Transformer backbone for joint image-video training.
    """
    @register_to_config
    def __init__(
        self,
        input_size=(1, 32, 32),
        patch_size=(1, 2, 2),
        in_channels=16,
        hidden_size=4096,
        depth=32,
        num_heads=32,
        num_kv_heads=None,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        pred_sigma=False,
        caption_channels=4096,
        qk_norm=False,
        norm_type="rms",
        model_max_length=120,
        rotary_max_length=384,
        rotary_max_length_t=None
    ):
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.multiple_of = multiple_of
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.norm_eps = norm_eps
        self.pred_sigma = pred_sigma
        self.caption_channels = caption_channels
        self.qk_norm = qk_norm
        self.norm_type = norm_type
        self.model_max_length = model_max_length
        self.rotary_max_length = rotary_max_length
        self.rotary_max_length_t = rotary_max_length_t
        self.out_channels = in_channels * 2 if pred_sigma else in_channels

        self.x_embedder = nn.Linear(np.prod(self.patch_size) * in_channels, hidden_size)

        self.t_embedder = TimestepEmbedder(min(hidden_size, 1024))
        self.y_embedder = nn.Sequential(
            nn.LayerNorm(caption_channels, eps=1e-6),
            nn.Linear(caption_channels, min(hidden_size, 1024)),
        )

        self.layers = nn.ModuleList([
            TransformerBlock(
                dim=hidden_size,
                n_heads=num_heads,
                n_kv_heads=self.num_kv_heads,
                multiple_of=multiple_of,
                ffn_dim_multiplier=ffn_dim_multiplier,
                norm_eps=norm_eps,
                qk_norm=qk_norm,
                y_dim=caption_channels,
                max_position_embeddings=rotary_max_length,
            )
            for _ in range(depth)
        ])

        self.final_layer = FinalLayer(
            hidden_size=hidden_size,
            num_patches=np.prod(patch_size),
            out_channels=self.out_channels,
        )

        assert (hidden_size // num_heads) % 6 == 0, "3d rope needs head dim to be divisible by 6"

        self.freqs_cis = self.precompute_freqs_cis(
            hidden_size // num_heads,
            self.rotary_max_length,
            end_t=self.rotary_max_length_t
        )

    def to(self, *args, **kwargs):
        self = super().to(*args, **kwargs)
        # self.freqs_cis = self.freqs_cis.to(*args, **kwargs)
        return self

    @staticmethod
    def precompute_freqs_cis(
        dim: int,
        end: int,
        end_t: int = None,
        theta: float = 10000.0,
        scale_factor: float = 1.0,
        scale_watershed: float = 1.0,
        timestep: float = 1.0,
    ):
        if timestep < scale_watershed:
            linear_factor = scale_factor
            ntk_factor = 1.0
        else:
            linear_factor = 1.0
            ntk_factor = scale_factor

        theta = theta * ntk_factor
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 6)[: (dim // 6)] / dim)) / linear_factor

        timestep = torch.arange(end, dtype=torch.float32)
        freqs = torch.outer(timestep, freqs).float()
        freqs_cis = torch.exp(1j * freqs)

        if end_t is not None:
            freqs_t = 1.0 / (theta ** (torch.arange(0, dim, 6)[: (dim // 6)] / dim)) / linear_factor
            timestep_t = torch.arange(end_t, dtype=torch.float32)
            freqs_t = torch.outer(timestep_t, freqs_t).float()
            freqs_cis_t = torch.exp(1j * freqs_t)
            freqs_cis_t = freqs_cis_t.view(end_t, 1, 1, dim // 6).repeat(1, end, end, 1)
        else:
            end_t = end
            freqs_cis_t = freqs_cis.view(end_t, 1, 1, dim // 6).repeat(1, end, end, 1)

        freqs_cis_h = freqs_cis.view(1, end, 1, dim // 6).repeat(end_t, 1, end, 1)
        freqs_cis_w = freqs_cis.view(1, 1, end, dim // 6).repeat(end_t, end, 1, 1)
        freqs_cis = torch.cat([freqs_cis_t, freqs_cis_h, freqs_cis_w], dim=-1).view(end_t, end, end, -1)
        return freqs_cis

    def forward(
        self,
        samples,
        timesteps,
        encoder_hidden_states,
        encoder_attention_mask,
        scale_factor: float = 1.0,  # scale_factor for rotary embedding
        scale_watershed: float = 1.0,  # scale_watershed for rotary embedding
    ):
        if samples.ndim == 4:  # B C H W
            samples = samples[:, None, ...]  # B F C H W

        precomputed_freqs_cis = None
        if scale_factor != 1 or scale_watershed != 1:
            precomputed_freqs_cis = self.precompute_freqs_cis(
                self.hidden_size // self.num_heads,
                self.rotary_max_length,
                end_t=self.rotary_max_length_t,
                scale_factor=scale_factor,
                scale_watershed=scale_watershed,
                timestep=torch.max(timesteps.cpu()).item()
            )

        if len(timesteps.shape) == 5:
            t, *_ = self.patchify(timesteps, precomputed_freqs_cis)
            timesteps = t.mean(dim=-1)
        elif len(timesteps.shape) == 1:
            timesteps = timesteps[:, None, None, None, None].expand_as(samples)
            t, *_ = self.patchify(timesteps, precomputed_freqs_cis)
            timesteps = t.mean(dim=-1)
        samples, T, H, W, freqs_cis = self.patchify(samples, precomputed_freqs_cis)
        samples = self.x_embedder(samples)
        t = self.t_embedder(timesteps)

        encoder_attention_mask_float = encoder_attention_mask[..., None].float()
        encoder_hidden_states_pool = (encoder_hidden_states * encoder_attention_mask_float).sum(dim=1) / (encoder_attention_mask_float.sum(dim=1) + 1e-8)
        encoder_hidden_states_pool = encoder_hidden_states_pool.to(samples.dtype)
        y = self.y_embedder(encoder_hidden_states_pool)
        y = y.unsqueeze(1).expand(-1, samples.size(1), -1)

        adaln_input = t + y

        for block in self.layers:
            samples = block(samples, None, freqs_cis, encoder_hidden_states, encoder_attention_mask, adaln_input)

        samples = self.final_layer(samples, adaln_input)
        samples = self.unpatchify(samples, T, H, W)

        return samples

    def patchify(self, x, precompute_freqs_cis=None):
        # pytorch is C, H, W
        B, T, C, H, W = x.size()
        pT, pH, pW = self.patch_size
        x = x.view(B, T // pT, pT, C, H // pH, pH, W // pW, pW)
        x = x.permute(0, 1, 4, 6, 2, 5, 7, 3)
        x = x.reshape(B, -1, pT * pH * pW * C)
        if precompute_freqs_cis is None:
            freqs_cis = self.freqs_cis[: T // pT, :H // pH, :W // pW].reshape(-1, *self.freqs_cis.shape[3:])[None].to(x.device)
        else:
            freqs_cis = precompute_freqs_cis[: T // pT, :H // pH, :W // pW].reshape(-1, *precompute_freqs_cis.shape[3:])[None].to(x.device)
        return x, T // pT, H // pH, W // pW, freqs_cis

    def unpatchify(self, x, T, H, W):
        B = x.size(0)
        C = self.out_channels
        pT, pH, pW = self.patch_size
        x = x.view(B, T, H, W, pT, pH, pW, C)
        x = x.permute(0, 1, 4, 7, 2, 5, 3, 6)
        x = x.reshape(B, T * pT, C, H * pH, W * pW)
        return x
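
A minimal smoke test for the NextDiT module above, under a made-up tiny configuration (the released checkpoint is far larger). flash-attn must still be importable because it is imported at module level, but the float32 inputs below take the scaled_dot_product_attention path rather than the varlen flash kernel.

```python
import torch
from onediffusion.models.denoiser.nextdit import NextDiT

model = NextDiT(
    input_size=(1, 16, 16),
    patch_size=(1, 2, 2),
    in_channels=4,
    hidden_size=288,        # head_dim = 288 / 4 = 72, divisible by 6 as the 3D RoPE assert requires
    depth=2,
    num_heads=4,
    caption_channels=64,
    rotary_max_length=16,   # keeps the precomputed freqs_cis table tiny for this test
)

latents = torch.randn(1, 4, 16, 16)              # B C H W; forward() adds the frame axis itself
timesteps = torch.tensor([0.5])                  # one timestep per batch element
text_states = torch.randn(1, 8, 64)              # stand-in for T5 encoder hidden states
text_mask = torch.ones(1, 8, dtype=torch.bool)

out = model(latents, timesteps, text_states, text_mask)
print(out.shape)                                 # torch.Size([1, 1, 4, 16, 16])
```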
requirements.txt
CHANGED
@@ -1,6 +1,27 @@
pytest
matplotlib
scikit-learn
scipy
spacy
numpy
einops
einsum
fvcore
h5py
twine
transformers==4.45.2
huggingface_hub==0.24
accelerate==0.34.2
diffusers==0.30.3
pillow==10.2.0
torch==2.3.1
torchvision==0.18.1
torchaudio==2.3.1
flash-attn==2.6.3
git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib
jaxtyping
mediapipe
gradio
git+https://github.com/facebookresearch/pytorch3d.git
opencv-python==4.5.5.64
opencv-python-headless==4.5.5.64