nvan15 committed on
Commit b03742a · verified · 1 Parent(s): 390ddaa

Batch upload part 1

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +18 -0
  2. .gitignore +229 -0
  3. LICENSE +201 -0
  4. README.md +198 -0
  5. ablation_qkv.py +170 -0
  6. assets/book.jpg +0 -0
  7. assets/cartoon_boy.png +3 -0
  8. assets/clock.jpg +3 -0
  9. assets/coffee.png +0 -0
  10. assets/demo/art1.png +3 -0
  11. assets/demo/art2.png +3 -0
  12. assets/demo/book_omini.jpg +0 -0
  13. assets/demo/clock_omini.jpg +0 -0
  14. assets/demo/demo_this_is_omini_control.jpg +3 -0
  15. assets/demo/dreambooth_res.jpg +3 -0
  16. assets/demo/man_omini.jpg +0 -0
  17. assets/demo/monalisa_omini.jpg +3 -0
  18. assets/demo/oranges_omini.jpg +0 -0
  19. assets/demo/panda_omini.jpg +0 -0
  20. assets/demo/penguin_omini.jpg +0 -0
  21. assets/demo/rc_car_omini.jpg +0 -0
  22. assets/demo/room_corner_canny.jpg +0 -0
  23. assets/demo/room_corner_coloring.jpg +0 -0
  24. assets/demo/room_corner_deblurring.jpg +0 -0
  25. assets/demo/room_corner_depth.jpg +0 -0
  26. assets/demo/scene_variation.jpg +3 -0
  27. assets/demo/shirt_omini.jpg +0 -0
  28. assets/demo/try_on.jpg +3 -0
  29. assets/monalisa.jpg +3 -0
  30. assets/ominicontrol_art/DistractedBoyfriend.webp +3 -0
  31. assets/ominicontrol_art/PulpFiction.jpg +3 -0
  32. assets/ominicontrol_art/breakingbad.jpg +3 -0
  33. assets/ominicontrol_art/oiiai.png +3 -0
  34. assets/oranges.jpg +0 -0
  35. assets/penguin.jpg +0 -0
  36. assets/rc_car.jpg +3 -0
  37. assets/room_corner.jpg +3 -0
  38. assets/test_in.jpg +0 -0
  39. assets/test_out.jpg +0 -0
  40. assets/tshirt.jpg +3 -0
  41. assets/vase.jpg +0 -0
  42. assets/vase_hq.jpg +3 -0
  43. evaluation.py +169 -0
  44. evaluation_coco.py +250 -0
  45. evaluation_coco_baseline.py +222 -0
  46. evaluation_subject_driven.py +362 -0
  47. examples/combine_with_style_lora.ipynb +235 -0
  48. examples/inpainting.ipynb +135 -0
  49. examples/ominicontrol_art.ipynb +218 -0
  50. examples/spatial.ipynb +191 -0
.gitattributes CHANGED
@@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/cartoon_boy.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/clock.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/monalisa.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/rc_car.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/room_corner.jpg filter=lfs diff=lfs merge=lfs -text
41
+ assets/tshirt.jpg filter=lfs diff=lfs merge=lfs -text
42
+ assets/vase_hq.jpg filter=lfs diff=lfs merge=lfs -text
43
+ assets/demo/art1.png filter=lfs diff=lfs merge=lfs -text
44
+ assets/demo/art2.png filter=lfs diff=lfs merge=lfs -text
45
+ assets/demo/demo_this_is_omini_control.jpg filter=lfs diff=lfs merge=lfs -text
46
+ assets/demo/dreambooth_res.jpg filter=lfs diff=lfs merge=lfs -text
47
+ assets/demo/monalisa_omini.jpg filter=lfs diff=lfs merge=lfs -text
48
+ assets/demo/scene_variation.jpg filter=lfs diff=lfs merge=lfs -text
49
+ assets/demo/try_on.jpg filter=lfs diff=lfs merge=lfs -text
50
+ assets/ominicontrol_art/DistractedBoyfriend.webp filter=lfs diff=lfs merge=lfs -text
51
+ assets/ominicontrol_art/PulpFiction.jpg filter=lfs diff=lfs merge=lfs -text
52
+ assets/ominicontrol_art/breakingbad.jpg filter=lfs diff=lfs merge=lfs -text
53
+ assets/ominicontrol_art/oiiai.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,229 @@
1
+ wandb/*
2
+ runs/*
3
+
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[codz]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py.cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ # Pipfile.lock
100
+
101
+ # UV
102
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # uv.lock
106
+
107
+ # poetry
108
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
110
+ # commonly ignored for libraries.
111
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112
+ # poetry.lock
113
+ # poetry.toml
114
+
115
+ # pdm
116
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
118
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
119
+ # pdm.lock
120
+ # pdm.toml
121
+ .pdm-python
122
+ .pdm-build/
123
+
124
+ # pixi
125
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
126
+ # pixi.lock
127
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
128
+ # in the .venv directory. It is recommended not to include this directory in version control.
129
+ .pixi
130
+
131
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
132
+ __pypackages__/
133
+
134
+ # Celery stuff
135
+ celerybeat-schedule
136
+ celerybeat.pid
137
+
138
+ # Redis
139
+ *.rdb
140
+ *.aof
141
+ *.pid
142
+
143
+ # RabbitMQ
144
+ mnesia/
145
+ rabbitmq/
146
+ rabbitmq-data/
147
+
148
+ # ActiveMQ
149
+ activemq-data/
150
+
151
+ # SageMath parsed files
152
+ *.sage.py
153
+
154
+ # Environments
155
+ .env
156
+ .envrc
157
+ .venv
158
+ env/
159
+ venv/
160
+ ENV/
161
+ env.bak/
162
+ venv.bak/
163
+
164
+ # Spyder project settings
165
+ .spyderproject
166
+ .spyproject
167
+
168
+ # Rope project settings
169
+ .ropeproject
170
+
171
+ # mkdocs documentation
172
+ /site
173
+
174
+ # mypy
175
+ .mypy_cache/
176
+ .dmypy.json
177
+ dmypy.json
178
+
179
+ # Pyre type checker
180
+ .pyre/
181
+
182
+ # pytype static type analyzer
183
+ .pytype/
184
+
185
+ # Cython debug symbols
186
+ cython_debug/
187
+
188
+ # PyCharm
189
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
190
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
191
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
192
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
193
+ # .idea/
194
+
195
+ # Abstra
196
+ # Abstra is an AI-powered process automation framework.
197
+ # Ignore directories containing user credentials, local state, and settings.
198
+ # Learn more at https://abstra.io/docs
199
+ .abstra/
200
+
201
+ # Visual Studio Code
202
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
203
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
204
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
205
+ # you could uncomment the following to ignore the entire vscode folder
206
+ # .vscode/
207
+
208
+ # Ruff stuff:
209
+ .ruff_cache/
210
+
211
+ # PyPI configuration file
212
+ .pypirc
213
+
214
+ # Marimo
215
+ marimo/_static/
216
+ marimo/_lsp/
217
+ __marimo__/
218
+
219
+ # Streamlit
220
+ .streamlit/secrets.toml
221
+
222
+
223
+ # exps/
224
+ # wandb/
225
+ # *.ipynb
226
+ # glue_exp/
227
+ # logs_hyper/
228
+ # # grid/
229
+ # glue22_ex
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2024] [Zhenxiong Tan]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,198 @@
1
+ # OminiControl
2
+
3
+
4
+ <img src='./assets/demo/demo_this_is_omini_control.jpg' width='100%' />
5
+ <br>
6
+
7
+ <a href="https://huggingface.co/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
8
+ <a href="https://huggingface.co/spaces/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗_HuggingFace-Demo-ffbd45.svg" alt="HuggingFace"></a>
9
+ <a href="https://huggingface.co/spaces/Yuanshi/OminiControl_Art"><img src="https://img.shields.io/badge/🤗_HuggingFace-Demo2-ffbd45.svg" alt="HuggingFace"></a>
10
+ <a href="https://github.com/Yuanshi9815/Subjects200K"><img src="https://img.shields.io/badge/GitHub-Dataset-blue.svg?logo=github&" alt="GitHub"></a>
11
+ <a href="https://huggingface.co/datasets/Yuanshi/Subjects200K"><img src="https://img.shields.io/badge/🤗_HuggingFace-Dataset-ffbd45.svg" alt="HuggingFace"></a>
12
+ <br>
13
+ <a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ariXv-OminiControl-A42C25.svg" alt="arXiv"></a>
14
+ <a href="https://arxiv.org/abs/2503.08280"><img src="https://img.shields.io/badge/ariXv-OminiControl2-A42C25.svg" alt="arXiv"></a>
15
+
16
+ > **OminiControl: Minimal and Universal Control for Diffusion Transformer**
17
+ > <br>
18
+ > Zhenxiong Tan,
19
+ > [Songhua Liu](http://121.37.94.87/),
20
+ > [Xingyi Yang](https://adamdad.github.io/),
21
+ > Qiaochu Xue,
22
+ > and
23
+ > [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
24
+ > <br>
25
+ > [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore
26
+ > <br>
27
+
28
+ > **OminiControl2: Efficient Conditioning for Diffusion Transformers**
29
+ > <br>
30
+ > Zhenxiong Tan,
31
+ > Qiaochu Xue,
32
+ > [Xingyi Yang](https://adamdad.github.io/),
33
+ > [Songhua Liu](http://121.37.94.87/),
34
+ > and
35
+ > [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
36
+ > <br>
37
+ > [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore
38
+ > <br>
39
+
40
+
41
+
42
+ ## Features
43
+
44
+ OminiControl is a minimal yet powerful universal control framework for Diffusion Transformer models like [FLUX](https://github.com/black-forest-labs/flux).
45
+
46
+ * **Universal Control 🌐**: A unified control framework that supports both subject-driven control and spatial control (such as edge-guided and in-painting generation).
47
+
48
+ * **Minimal Design 🚀**: Injects control signals while preserving the original model structure, introducing only 0.1% additional parameters to the base model.
49
+
50
+ ## News
51
+ - **2025-05-12**: ⭐️ The code of [OminiControl2](https://arxiv.org/abs/2503.08280) is released. It introduces a new efficient conditioning method for diffusion transformers. (Check out the training code [here](./train)).
52
+ - **2025-05-12**: Added support for custom style LoRAs. (Check out the [example](./examples/combine_with_style_lora.ipynb)).
53
+ - **2025-04-09**: ⭐️ [OminiControl Art](https://huggingface.co/spaces/Yuanshi/OminiControl_Art) is released. It can stylize any image with an artistic style. (Check out the [demo](https://huggingface.co/spaces/Yuanshi/OminiControl_Art) and [inference examples](./examples/ominicontrol_art.ipynb)).
54
+ - **2024-12-26**: Training code is released. You can now create your own OminiControl model by customizing it for any control task (3D, multi-view, pose-guided, try-on, etc.) with the FLUX model. Check the [training folder](./train) for more details.
55
+
56
+ ## Quick Start
57
+ ### Setup (Optional)
58
+ 1. **Environment setup**
59
+ ```bash
60
+ conda create -n omini python=3.12
61
+ conda activate omini
62
+ ```
63
+ 2. **Requirements installation**
64
+ ```bash
65
+ pip install -r requirements.txt
66
+ ```
67
+ ### Usage example
68
+ 1. Subject-driven generation: `examples/subject.ipynb`
69
+ 2. In-painting: `examples/inpainting.ipynb`
70
+ 3. Canny edge to image, depth to image, colorization, deblurring: `examples/spatial.ipynb` (a minimal sketch of this flow is shown below)
71
+
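+ A minimal sketch of the condition-and-generate flow used by these notebooks and by the evaluation scripts in this repository; it assumes the control weights for the chosen task have already been loaded into `pipe` (the notebooks show the exact loading step):
+ ```python
+ import torch
+ from PIL import Image
+ from diffusers.pipelines import FluxPipeline
+ from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition
+
+ pipe = FluxPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ ).to("cuda")
+ # ... load the control weights for the chosen spatial task here (see the notebooks) ...
+
+ image = Image.open("assets/coffee.png").convert("RGB").resize((512, 512))
+ condition = Condition(convert_to_condition("canny", image), "canny")  # edge-map condition
+
+ seed_everything(42)
+ result = generate(
+     pipe,
+     prompt="A cup of coffee with some beans on the side, on a dark wooden table.",
+     conditions=[condition],
+ ).images[0]
+ result.save("result.png")
+ ```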
72
+
73
+ ### Guidelines for subject-driven generation
74
+ 1. Input images are automatically center-cropped and resized to 512x512 resolution (see the preprocessing sketch after this list).
75
+ 2. When writing prompts, refer to the subject using phrases like `this item`, `the object`, or `it`. e.g.
76
+ 1. *A close up view of this item. It is placed on a wooden table.*
77
+ 2. *A young lady is wearing this shirt.*
78
+ 3. Currently, the model primarily works with objects rather than human subjects, due to the absence of human data in training.
79
+
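+ The sketch below shows the same center-crop-and-resize preprocessing that the evaluation scripts in this repository apply, together with a prompt written in the recommended style (the image path is only an example):
+ ```python
+ from PIL import Image
+
+ image = Image.open("assets/vase_hq.jpg").convert("RGB")
+
+ # Center-crop to a square, then resize to the 512x512 working resolution.
+ w, h = image.size
+ min_dim = min(w, h)
+ image = image.crop(
+     ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)
+ ).resize((512, 512))
+
+ # Refer to the subject with phrases like "this item" or "it" in the prompt.
+ prompt = "A close up view of this item. It is placed on a wooden table."
+ ```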
80
+ ## Generated samples
81
+ ### Subject-driven generation
82
+ <a href="https://huggingface.co/spaces/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace"></a>
83
+
84
+ **Demos** (Left: condition image; Right: generated image)
85
+
86
+ <div float="left">
87
+ <img src='./assets/demo/oranges_omini.jpg' width='48%'/>
88
+ <img src='./assets/demo/rc_car_omini.jpg' width='48%' />
89
+ <img src='./assets/demo/clock_omini.jpg' width='48%' />
90
+ <img src='./assets/demo/shirt_omini.jpg' width='48%' />
91
+ </div>
92
+
93
+ <details>
94
+ <summary>Text Prompts</summary>
95
+
96
+ - Prompt1: *A close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!.'*
97
+ - Prompt2: *A film style shot. On the moon, this item drives across the moon surface. A flag on it reads 'Omini'. The background is that Earth looms large in the foreground.*
98
+ - Prompt3: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.*
99
+ - Prompt4: *On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.*
100
+ </details>
101
+ <details>
102
+ <summary>More results</summary>
103
+
104
+ * Try on:
105
+ <img src='./assets/demo/try_on.jpg'/>
106
+ * Scene variations:
107
+ <img src='./assets/demo/scene_variation.jpg'/>
108
+ * Dreambooth dataset:
109
+ <img src='./assets/demo/dreambooth_res.jpg'/>
110
+ * Oye-cartoon finetune:
111
+ <div float="left">
112
+ <img src='./assets/demo/man_omini.jpg' width='48%' />
113
+ <img src='./assets/demo/panda_omini.jpg' width='48%' />
114
+ </div>
115
+ </details>
116
+
117
+ ### Spatially aligned control
118
+ 1. **Image Inpainting** (Left: original image; Center: masked image; Right: filled image)
119
+ - Prompt: *The Mona Lisa is wearing a white VR headset with 'Omini' written on it.*
120
+ </br>
121
+ <img src='./assets/demo/monalisa_omini.jpg' width='700px' />
122
+ - Prompt: *A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.*
123
+ </br>
124
+ <img src='./assets/demo/book_omini.jpg' width='700px' />
125
+ 2. **Other spatially aligned tasks** (Canny edge to image, depth to image, colorization, deblurring)
126
+ </br>
127
+ <details>
128
+ <summary>Click to show</summary>
129
+ <div float="left">
130
+ <img src='./assets/demo/room_corner_canny.jpg' width='48%'/>
131
+ <img src='./assets/demo/room_corner_depth.jpg' width='48%' />
132
+ <img src='./assets/demo/room_corner_coloring.jpg' width='48%' />
133
+ <img src='./assets/demo/room_corner_deblurring.jpg' width='48%' />
134
+ </div>
135
+
136
+ Prompt: *A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow. A white side table sits next to the sofa, topped with a white adjustable desk lamp and some books. Dark hardwood flooring contrasts with the pale walls and furniture.*
137
+ </details>
138
+
139
+ ### Stylize images
140
+ <a href="https://huggingface.co/spaces/Yuanshi/OminiControl_Art"><img src="https://img.shields.io/badge/🤗_HuggingFace-Demo2-ffbd45.svg" alt="HuggingFace"></a>
141
+ </br>
142
+ <img src='./assets/demo/art1.png' width='600px' />
143
+ <img src='./assets/demo/art2.png' width='600px' />
144
+ </br>
145
+
146
+
147
+
148
+ ## Models
149
+
150
+ **Subject-driven control:**
151
+ | Model | Base model | Description | Resolution |
152
+ | ------------------------------------------------------------------------------------------------ | -------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------ |
153
+ | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `subject` | FLUX.1-schnell | The model used in the paper. | (512, 512) |
154
+ | [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_512` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset. | (512, 512) |
155
+ | [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_1024` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset and accommodates higher resolution. | (1024, 1024) |
156
+ | [`oye-cartoon`](https://huggingface.co/saquiboye/oye-cartoon) | FLUX.1-dev | The model has been fine-tuned on [oye-cartoon](https://huggingface.co/datasets/saquiboye/oye-cartoon) dataset by [@saquib764](https://github.com/Saquib764) | (512, 512) |
157
+
158
+ **Spatial aligned control:**
159
+ | Model | Base model | Description | Resolution |
160
+ | --------------------------------------------------------------------------------------------------------- | ---------- | -------------------------------------------------------------------------- | ------------ |
161
+ | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `<task_name>` | FLUX.1 | Canny edge to image, depth to image, colorization, deblurring, in-painting | (512, 512) |
162
+
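+ A hedged sketch of loading one of the released weights with `diffusers`; the exact weight file name below (`omini/subject_512.safetensors`) is an assumption based on the folder layout above and should be checked against the [model repository](https://huggingface.co/Yuanshi/OminiControl):
+ ```python
+ import torch
+ from diffusers.pipelines import FluxPipeline
+
+ pipe = FluxPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
+ ).to("cuda")
+ # Assumed file name; verify it against the model repository before use.
+ pipe.load_lora_weights(
+     "Yuanshi/OminiControl",
+     weight_name="omini/subject_512.safetensors",
+     adapter_name="subject",
+ )
+ ```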
163
+ ## Community Extensions
164
+ - [ComfyUI-Diffusers-OminiControl](https://github.com/Macoron/ComfyUI-Diffusers-OminiControl) - ComfyUI integration by [@Macoron](https://github.com/Macoron)
165
+ - [ComfyUI_RH_OminiControl](https://github.com/HM-RunningHub/ComfyUI_RH_OminiControl) - ComfyUI integration by [@HM-RunningHub](https://github.com/HM-RunningHub)
166
+
167
+ ## Limitations
168
+ 1. The model's subject-driven generation primarily works with objects rather than human subjects due to the absence of human data in training.
169
+ 2. The subject-driven generation model may not work well with `FLUX.1-dev`.
170
+ 3. The released model only supports a resolution of 512x512.
171
+
172
+ ## Training
173
+ Training instructions can be found in this [folder](./train).
174
+
175
+
176
+ ## To-do
177
+ - [x] Release the training code.
178
+ - [x] Release the model for higher resolution (1024x1024).
179
+
180
+ ## Acknowledgment
181
+ We would like to acknowledge that the computational work involved in this research is partially supported by NUS IT’s Research Computing group under grant number NUSREC-HPC-00001.
182
+
183
+ ## Citation
184
+ ```
185
+ @inproceedings{tan2025ominicontrol,
186
+ title={OminiControl: Minimal and Universal Control for Diffusion Transformer},
187
+ author={Tan, Zhenxiong and Liu, Songhua and Yang, Xingyi and Xue, Qiaochu and Wang, Xinchao},
188
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
189
+ year={2025}
190
+ }
191
+
192
+ @article{tan2025ominicontrol2,
193
+ title={OminiControl2: Efficient Conditioning for Diffusion Transformers},
194
+ author={Tan, Zhenxiong and Xue, Qiaochu and Yang, Xingyi and Liu, Songhua and Wang, Xinchao},
195
+ journal={arXiv preprint arXiv:2503.08280},
196
+ year={2025}
197
+ }
198
+ ```
ablation_qkv.py ADDED
@@ -0,0 +1,170 @@
1
+ import torch
2
+
3
+ from diffusers.pipelines import FluxPipeline
4
+ from omini.pipeline.flux_omini_ablate_qkv import Condition, generate, seed_everything, convert_to_condition
5
+ from omini.rotation import RotationConfig, RotationTuner
6
+ from PIL import Image
7
+
8
+
9
+ def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False):
10
+ """
11
+ Load rotation adapter weights.
12
+
13
+ Args:
14
+ path: Directory containing the saved adapter weights
15
+ adapter_name: Name of the adapter to load
16
+ strict: Whether to strictly match all keys
17
+ """
18
+ from safetensors.torch import load_file
19
+ import os
20
+ import yaml
21
+
22
+ device = transformer.device
23
+ print(f"device for loading: {device}")
24
+
25
+ # Try to load safetensors first, then fallback to .pth
26
+ safetensors_path = os.path.join(path, f"{adapter_name}.safetensors")
27
+ pth_path = os.path.join(path, f"{adapter_name}.pth")
28
+
29
+ if os.path.exists(safetensors_path):
30
+ state_dict = load_file(safetensors_path)
31
+ print(f"Loaded rotation adapter from {safetensors_path}")
32
+ elif os.path.exists(pth_path):
33
+ state_dict = torch.load(pth_path, map_location=device)
34
+ print(f"Loaded rotation adapter from {pth_path}")
35
+ else:
36
+ raise FileNotFoundError(
37
+ f"No adapter weights found for '{adapter_name}' in {path}\n"
38
+ f"Looking for: {safetensors_path} or {pth_path}"
39
+ )
40
+
41
+ # # Get the device and dtype of the transformer
42
+ transformer_device = next(transformer.parameters()).device
43
+ transformer_dtype = next(transformer.parameters()).dtype
44
+
45
+
46
+
47
+ state_dict_with_adapter = {}
48
+ for k, v in state_dict.items():
49
+ # Reconstruct the full key with adapter name
50
+ new_key = k.replace(".rotation.", f".rotation.{adapter_name}.")
51
+ if "_adapter_config" in new_key:
52
+ print(f"adapter_config key: {new_key}")
53
+
54
+
55
+ # Move to target device and dtype
56
+ # Check if this parameter should keep its original dtype (e.g., indices, masks)
57
+ if v.dtype in [torch.long, torch.int, torch.int32, torch.int64, torch.bool]:
58
+ # Keep integer/boolean dtypes, only move device
59
+ state_dict_with_adapter[new_key] = v.to(device=transformer_device)
60
+ else:
61
+ # Convert floating point tensors to target dtype and device
62
+ state_dict_with_adapter[new_key] = v.to(device=transformer_device, dtype=transformer_dtype)
63
+
64
+ # state_dict_with_adapter already carries the adapter name in its keys and holds its
+ # tensors on the target device/dtype (handled in the loop above), so it is used as-is.
69
+
70
+
71
+ # Load into the model
72
+ missing, unexpected = transformer.load_state_dict(
73
+ state_dict_with_adapter,
74
+ strict=strict
75
+ )
76
+
77
+ if missing:
78
+ print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
79
+ if unexpected:
80
+ print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
81
+
82
+ # Load config if available
83
+ config_path = os.path.join(path, f"{adapter_name}_config.yaml")
84
+ if os.path.exists(config_path):
85
+ with open(config_path, 'r') as f:
86
+ config = yaml.safe_load(f)
87
+ print(f"Loaded config: {config}")
88
+
89
+ total_params = sum(p.numel() for p in state_dict.values())
90
+ print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")
91
+
92
+ return state_dict
93
+
94
+
95
+ # prepare input image and prompt
96
+ image = Image.open("assets/coffee.png").convert("RGB")
97
+
98
+ w, h, min_dim = image.size + (min(image.size),)
99
+ image = image.crop(
100
+ ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)
101
+ ).resize((512, 512))
102
+
103
+ prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table."
104
+
105
+ canny_image = convert_to_condition("canny", image)
106
+ condition = Condition(canny_image, "canny")
107
+
108
+ seed_everything()
109
+
110
+
111
+
112
+ pipe = FluxPipeline.from_pretrained(
113
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
114
+ )
115
+
116
+
117
+ # add adapter to the transformer
118
+ transformer = pipe.transformer
119
+
120
+ adapter_name = "default"
121
+ transformer._hf_peft_config_loaded = True
122
+
123
+ rotation_adapter_config = {
124
+ "r": 4,
125
+ "num_rotations": 4,
126
+ "target_modules": "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)",
127
+ }
128
+
129
+ config = RotationConfig(**rotation_adapter_config)
130
+ rotation_tuner = RotationTuner(
131
+ transformer,
132
+ config,
133
+ adapter_name=adapter_name,
134
+ )
135
+ # Convert rotation tuner to bfloat16
136
+ transformer = transformer.to(torch.bfloat16)
137
+ transformer.set_adapter(adapter_name)
138
+
139
+ # load adapter weights
140
+ load_rotation(
141
+ transformer,
142
+ path="runs/20251110-191859-canny/ckpt/25000",
143
+ adapter_name=adapter_name,
144
+ strict=False,
145
+ )
146
+
147
+ # alter T for Query, Key, Value projections
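+ # Sweep global_T for the key and value projections from 0.05 to 1.25 in steps of
+ # 0.05 (25 runs in the loop below); the corresponding query sweep (global_T_Q) is
+ # left commented out in the generate() call.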
148
+
149
+
150
+ pipe = pipe.to("cuda")
151
+ for i in range(0, 25):
152
+ seed_everything()
153
+ result_img = generate(
154
+ pipe,
155
+ prompt=prompt,
156
+ conditions=[condition],
157
+ # global_T_Q=float(i + 1) /20.,
158
+ global_T_K=float(i + 1) /20.,
159
+ global_T_V=float(i + 1) /20.,
160
+ ).images[0]
161
+
162
+ concat_image = Image.new("RGB", (1536, 512))
163
+ concat_image.paste(image, (0, 0))
164
+ concat_image.paste(condition.condition, (512, 0))
165
+ concat_image.paste(result_img, (1024, 0))
166
+
167
+ # Save images
168
+ result_img.save(f"result_{i+1}.png")
169
+ concat_image.save(f"result_concat_{i+1}.png")
170
+ print(f"Saved result_{i+1}.png and result_concat_{i+1}.png")
assets/book.jpg ADDED
assets/cartoon_boy.png ADDED

Git LFS Details

  • SHA256: d4a82c0f9ed09b9468bded7d901beffaf29addc30ed5f72ad72451e1b6344b1c
  • Pointer size: 131 Bytes
  • Size of remote file: 429 kB
assets/clock.jpg ADDED

Git LFS Details

  • SHA256: 41235973f26152ac92d32bfc166fb5f9f1e352c5e16807920238473316ec462b
  • Pointer size: 131 Bytes
  • Size of remote file: 289 kB
assets/coffee.png ADDED
assets/demo/art1.png ADDED

Git LFS Details

  • SHA256: 08a70e1c48306adb48adc6596af979326f7f0d9fac69a4bf7f46922f85ffd111
  • Pointer size: 131 Bytes
  • Size of remote file: 547 kB
assets/demo/art2.png ADDED

Git LFS Details

  • SHA256: 9f4436f3c53f693826b35f667c55968605ace9edb092ea85e17c47c2f89373eb
  • Pointer size: 131 Bytes
  • Size of remote file: 536 kB
assets/demo/book_omini.jpg ADDED
assets/demo/clock_omini.jpg ADDED
assets/demo/demo_this_is_omini_control.jpg ADDED

Git LFS Details

  • SHA256: 798b7c25be6be118dc0de97c444c840869afca633a0d48f99d940aec040a7518
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
assets/demo/dreambooth_res.jpg ADDED

Git LFS Details

  • SHA256: ba36bd861989564dc679acf3b5e56f382f1a11b1596e6f611ea0bd7d81b89680
  • Pointer size: 132 Bytes
  • Size of remote file: 1.94 MB
assets/demo/man_omini.jpg ADDED
assets/demo/monalisa_omini.jpg ADDED

Git LFS Details

  • SHA256: e5ca6c2bf44f19d216b2eb16dcc67d19f11d87220d3ee80f5e5e1ad98a5536dc
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
assets/demo/oranges_omini.jpg ADDED
assets/demo/panda_omini.jpg ADDED
assets/demo/penguin_omini.jpg ADDED
assets/demo/rc_car_omini.jpg ADDED
assets/demo/room_corner_canny.jpg ADDED
assets/demo/room_corner_coloring.jpg ADDED
assets/demo/room_corner_deblurring.jpg ADDED
assets/demo/room_corner_depth.jpg ADDED
assets/demo/scene_variation.jpg ADDED

Git LFS Details

  • SHA256: 39e4e16d2eeb58b3775b6d34c8b3e125d0d19cc36fa90b07c6c8d57624ad4333
  • Pointer size: 131 Bytes
  • Size of remote file: 958 kB
assets/demo/shirt_omini.jpg ADDED
assets/demo/try_on.jpg ADDED

Git LFS Details

  • SHA256: 6adce5194329a83f0109b4375e00667c341879e64fb55831c70ea3f3b2f99f7e
  • Pointer size: 131 Bytes
  • Size of remote file: 774 kB
assets/monalisa.jpg ADDED

Git LFS Details

  • SHA256: 188b8b6499e4541f9dfef2a9daf6f1eb920079c9208f587fd97566d6aa4a9719
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
assets/ominicontrol_art/DistractedBoyfriend.webp ADDED

Git LFS Details

  • SHA256: 467731bd39620ada5af31cbd57dab4d813c533942f7bdde263c6fb238d5af8be
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
assets/ominicontrol_art/PulpFiction.jpg ADDED

Git LFS Details

  • SHA256: bc03d4009a2b787ba36e8ad4eefc4bacc4398320b9740925d67d75b60f922115
  • Pointer size: 132 Bytes
  • Size of remote file: 1.69 MB
assets/ominicontrol_art/breakingbad.jpg ADDED

Git LFS Details

  • SHA256: 4988d0d0940bf6ef10f7d4b27b49445d8c546b92d77f92610945d644dfdcdc69
  • Pointer size: 131 Bytes
  • Size of remote file: 214 kB
assets/ominicontrol_art/oiiai.png ADDED

Git LFS Details

  • SHA256: cead6fef8a29b390a2c7b0c46d4e515c9108d6981e1e5bed9d71c1dbb41504b9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.65 MB
assets/oranges.jpg ADDED
assets/penguin.jpg ADDED
assets/rc_car.jpg ADDED

Git LFS Details

  • SHA256: ae8aed11029fa3b084deb286c07a8cab5056840c9c123816fe2b504e94233e95
  • Pointer size: 131 Bytes
  • Size of remote file: 254 kB
assets/room_corner.jpg ADDED

Git LFS Details

  • SHA256: f97bd63df05f5f15ad5dd1a2ccef803e74e12caadd8fe145493fd6d5219045e7
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
assets/test_in.jpg ADDED
assets/test_out.jpg ADDED
assets/tshirt.jpg ADDED

Git LFS Details

  • SHA256: cb1803315765302113a9e7a64dedd4ecba2672028cf093cbc33ef2edd2247c39
  • Pointer size: 131 Bytes
  • Size of remote file: 301 kB
assets/vase.jpg ADDED
assets/vase_hq.jpg ADDED

Git LFS Details

  • SHA256: 279905e32116792f118802d23b0d96629d98ccbdac9e704e65eaf2e98c752679
  • Pointer size: 132 Bytes
  • Size of remote file: 2.9 MB
evaluation.py ADDED
@@ -0,0 +1,169 @@
1
+ import torch
2
+
3
+ from diffusers.pipelines import FluxPipeline
4
+ from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition
5
+ from omini.rotation import RotationConfig, RotationTuner
6
+ from PIL import Image
7
+
8
+
9
+ def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False):
10
+ """
11
+ Load rotation adapter weights.
12
+
13
+ Args:
14
+ path: Directory containing the saved adapter weights
15
+ adapter_name: Name of the adapter to load
16
+ strict: Whether to strictly match all keys
17
+ """
18
+ from safetensors.torch import load_file
19
+ import os
20
+ import yaml
21
+
22
+ device = transformer.device
23
+ print(f"device for loading: {device}")
24
+
25
+ # Try to load safetensors first, then fallback to .pth
26
+ safetensors_path = os.path.join(path, f"{adapter_name}.safetensors")
27
+ pth_path = os.path.join(path, f"{adapter_name}.pth")
28
+
29
+ if os.path.exists(safetensors_path):
30
+ state_dict = load_file(safetensors_path)
31
+ print(f"Loaded rotation adapter from {safetensors_path}")
32
+ elif os.path.exists(pth_path):
33
+ state_dict = torch.load(pth_path, map_location=device)
34
+ print(f"Loaded rotation adapter from {pth_path}")
35
+ else:
36
+ raise FileNotFoundError(
37
+ f"No adapter weights found for '{adapter_name}' in {path}\n"
38
+ f"Looking for: {safetensors_path} or {pth_path}"
39
+ )
40
+
41
+ # # Get the device and dtype of the transformer
42
+ transformer_device = next(transformer.parameters()).device
43
+ transformer_dtype = next(transformer.parameters()).dtype
44
+
45
+
46
+
47
+ state_dict_with_adapter = {}
48
+ for k, v in state_dict.items():
49
+ # Reconstruct the full key with adapter name
50
+ new_key = k.replace(".rotation.", f".rotation.{adapter_name}.")
51
+ if "_adapter_config" in new_key:
52
+ print(f"adapter_config key: {new_key}")
53
+
54
+
55
+ # Move to target device and dtype
56
+ # Check if this parameter should keep its original dtype (e.g., indices, masks)
57
+ if v.dtype in [torch.long, torch.int, torch.int32, torch.int64, torch.bool]:
58
+ # Keep integer/boolean dtypes, only move device
59
+ state_dict_with_adapter[new_key] = v.to(device=transformer_device)
60
+ else:
61
+ # Convert floating point tensors to target dtype and device
62
+ state_dict_with_adapter[new_key] = v.to(device=transformer_device, dtype=transformer_dtype)
63
+
64
+ # state_dict_with_adapter already carries the adapter name in its keys and holds its
+ # tensors on the target device/dtype (handled in the loop above), so it is used as-is.
69
+
70
+
71
+ # Load into the model
72
+ missing, unexpected = transformer.load_state_dict(
73
+ state_dict_with_adapter,
74
+ strict=strict
75
+ )
76
+
77
+ if missing:
78
+ print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
79
+ if unexpected:
80
+ print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
81
+
82
+ # Load config if available
83
+ config_path = os.path.join(path, f"{adapter_name}_config.yaml")
84
+ if os.path.exists(config_path):
85
+ with open(config_path, 'r') as f:
86
+ config = yaml.safe_load(f)
87
+ print(f"Loaded config: {config}")
88
+
89
+ total_params = sum(p.numel() for p in state_dict.values())
90
+ print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")
91
+
92
+ return state_dict
93
+
94
+
95
+ # prepare input image and prompt
96
+ image = Image.open("assets/coffee.png").convert("RGB")
97
+
98
+ w, h, min_dim = image.size + (min(image.size),)
99
+ image = image.crop(
100
+ ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)
101
+ ).resize((512, 512))
102
+
103
+ prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table."
104
+
105
+ canny_image = convert_to_condition("canny", image)
106
+ condition = Condition(canny_image, "canny")
107
+
108
+ seed_everything()
109
+
110
+
111
+
112
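+ # Rebuild the pipeline for each setting and sweep config.T from 2.05 to 3.0 in
+ # steps of 0.05 (i = 40..59).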
+ for i in range(40, 60):
113
+ pipe = FluxPipeline.from_pretrained(
114
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
115
+ )
116
+
117
+
118
+ # add adapter to the transformer
119
+ transformer = pipe.transformer
120
+
121
+ adapter_name = "default"
122
+ transformer._hf_peft_config_loaded = True
123
+
124
+ rotation_adapter_config = {
125
+ "r": 4,
126
+ "num_rotations": 4,
127
+ "target_modules": "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)",
128
+ }
129
+
130
+ config = RotationConfig(**rotation_adapter_config)
131
+ config.T = float(i + 1) / 20
132
+ rotation_tuner = RotationTuner(
133
+ transformer,
134
+ config,
135
+ adapter_name=adapter_name,
136
+ )
137
+ # Convert rotation tuner to bfloat16
138
+ transformer = transformer.to(torch.bfloat16)
139
+ transformer.set_adapter(adapter_name)
140
+
141
+ # load adapter weights
142
+ load_rotation(
143
+ transformer,
144
+ path="runs/20251110-191859/ckpt/4000",
145
+ adapter_name=adapter_name,
146
+ strict=False,
147
+ )
148
+
149
+ pipe = pipe.to("cuda")
150
+
151
+
152
+
153
+
154
+
155
+ result_img = generate(
156
+ pipe,
157
+ prompt=prompt,
158
+ conditions=[condition],
159
+ ).images[0]
160
+
161
+ concat_image = Image.new("RGB", (1536, 512))
162
+ concat_image.paste(image, (0, 0))
163
+ concat_image.paste(condition.condition, (512, 0))
164
+ concat_image.paste(result_img, (1024, 0))
165
+
166
+ # Save images
167
+ result_img.save(f"result_{i+1}.png")
168
+ concat_image.save(f"result_concat_{i+1}.png")
169
+ print(f"Saved result_{i+1}.png and result_concat_{i+1}.png")
evaluation_coco.py ADDED
@@ -0,0 +1,250 @@
1
+ # example images:depth-41888,15254,16228, 24144,37777, 22192
2
+
3
+ # ablate image: 87038
4
+ import json
5
+ from PIL import Image
6
+ import os
7
+ import argparse
8
+ from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition
9
+ from omini.rotation import RotationConfig, RotationTuner
10
+ from tqdm import tqdm
11
+ import torch
12
+ from diffusers.pipelines import FluxPipeline
13
+
14
+ def evaluate(pipe,
15
+ condition_type: str, # e.g., "canny"
16
+ caption_file: str,
17
+ image_dir: str,
18
+ save_root_dir: str,
19
+ num_images: int = 5000,
20
+ start_index: int = 0):
21
+ """
22
+ Evaluate the model on a subset of the COCO dataset.
23
+
24
+ Args:
25
+ pipe: The flux pipeline to use for generation
+ condition_type: Type of condition to apply (e.g., "canny", "depth", "deblurring")
+ caption_file: Path to the COCO captions JSON file
+ image_dir: Directory containing COCO images
+ save_root_dir: Root directory for the generated, resized, and condition images
+ num_images: Number of images to evaluate on
+ start_index: Index of the first image in the evaluation subset
+ """
30
+
31
+ os.makedirs(os.path.join(save_root_dir, "generated"), exist_ok=True)
32
+ os.makedirs(os.path.join(save_root_dir, "resized"), exist_ok=True)
33
+ os.makedirs(os.path.join(save_root_dir, condition_type), exist_ok=True)
34
+
35
+ # Load data
36
+ with open(caption_file, "r") as f:
37
+ coco_data = json.load(f)
38
+
39
+ # Build a mapping: image_id → (filename, captions)
40
+ id_to_filename = {img["id"]: img["file_name"] for img in coco_data["images"]}
41
+ captions_by_image = {}
42
+
43
+ for ann in coco_data["annotations"]:
44
+ img_id = ann["image_id"]
45
+ captions_by_image.setdefault(img_id, []).append(ann["caption"])
46
+
47
+ # Take first 5000 images
48
+ image_ids = list(id_to_filename.keys())[:5000]
49
+
50
+ # Collect data
51
+ captions_subset = [
52
+ {
53
+ "image_id": img_id,
54
+ "file_name": id_to_filename[img_id],
55
+ "captions": captions_by_image.get(img_id, [])
56
+ }
57
+ for img_id in image_ids
58
+ ]
59
+
60
+ for item in tqdm(captions_subset[start_index:start_index+num_images]):
61
+
62
+ image_id = item["image_id"]
63
+ image_path = os.path.join(image_dir, item["file_name"])
64
+ image = Image.open(image_path).convert("RGB")
65
+
66
+ # Resize and center-crop to 512x512
67
+ w, h, min_dim = image.size + (min(image.size),)
68
+ image = image.crop(
69
+ ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)
70
+ ).resize((512, 512))
71
+
72
+ condition_image = convert_to_condition(condition_type, image)
73
+ condition = Condition(condition_image, condition_type)
74
+
75
+ prompt = item["captions"][0] if item["captions"] else "No caption available."
76
+
77
+ seed_everything(42)
78
+
79
+ # generate image
80
+ result_img = generate(
81
+ pipe,
82
+ prompt=prompt,
83
+ conditions=[condition],
84
+ ).images[0]
85
+
86
+
87
+ result_img.save(os.path.join(save_root_dir, "generated", f"{image_id}.jpg"))
88
+ image.save(os.path.join(save_root_dir, "resized", f"{image_id}.jpg"))
89
+ condition.condition.save(os.path.join(save_root_dir, condition_type, f"{image_id}.jpg"))
90
+
91
+
92
+
93
+
94
+
95
+ def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False):
96
+ """
97
+ Load rotation adapter weights.
98
+
99
+ Args:
100
+ path: Directory containing the saved adapter weights
101
+ adapter_name: Name of the adapter to load
102
+ strict: Whether to strictly match all keys
103
+ """
104
+ from safetensors.torch import load_file
105
+ import os
106
+ import yaml
107
+
108
+ device = transformer.device
109
+ print(f"device for loading: {device}")
110
+
111
+ # Try to load safetensors first, then fallback to .pth
112
+ safetensors_path = os.path.join(path, f"{adapter_name}.safetensors")
113
+ pth_path = os.path.join(path, f"{adapter_name}.pth")
114
+
115
+ if os.path.exists(safetensors_path):
116
+ state_dict = load_file(safetensors_path)
117
+ print(f"Loaded rotation adapter from {safetensors_path}")
118
+ elif os.path.exists(pth_path):
119
+ state_dict = torch.load(pth_path, map_location=device)
120
+ print(f"Loaded rotation adapter from {pth_path}")
121
+ else:
122
+ raise FileNotFoundError(
123
+ f"No adapter weights found for '{adapter_name}' in {path}\n"
124
+ f"Looking for: {safetensors_path} or {pth_path}"
125
+ )
126
+
127
+ # # Get the device and dtype of the transformer
128
+ transformer_device = next(transformer.parameters()).device
129
+ transformer_dtype = next(transformer.parameters()).dtype
130
+
131
+
132
+
133
+ state_dict_with_adapter = {}
134
+ for k, v in state_dict.items():
135
+ # Reconstruct the full key with adapter name
136
+ new_key = k.replace(".rotation.", f".rotation.{adapter_name}.")
137
+
138
+ # Move to target device and dtype
139
+ # Check if this parameter should keep its original dtype (e.g., indices, masks)
140
+ if v.dtype in [torch.long, torch.int, torch.int32, torch.int64, torch.bool]:
141
+ # Keep integer/boolean dtypes, only move device
142
+ state_dict_with_adapter[new_key] = v.to(device=transformer_device)
143
+ else:
144
+ # Convert floating point tensors to target dtype and device
145
+ state_dict_with_adapter[new_key] = v.to(device=transformer_device, dtype=transformer_dtype)
146
+
147
152
+
153
+
154
+ # Load into the model
155
+ missing, unexpected = transformer.load_state_dict(
156
+ state_dict_with_adapter,
157
+ strict=strict
158
+ )
159
+
160
+ if missing:
161
+ print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
162
+ if unexpected:
163
+ print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
164
+
165
+ # Load config if available
166
+ config_path = os.path.join(path, f"{adapter_name}_config.yaml")
167
+ if os.path.exists(config_path):
168
+ with open(config_path, 'r') as f:
169
+ config = yaml.safe_load(f)
170
+ print(f"Loaded config: {config}")
171
+
172
+ total_params = sum(p.numel() for p in state_dict.values())
173
+ print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")
174
+
175
+ return state_dict
176
+
177
+
178
+
179
+
180
+ if __name__ == "__main__":
181
+ parser = argparse.ArgumentParser(description="Evaluate OminiControl on COCO dataset")
182
+ parser.add_argument("--start_index", type=int, default=0, help="Starting index for evaluation")
183
+ parser.add_argument("--num_images", type=int, default=500, help="Number of images to evaluate")
184
+ parser.add_argument("--condition_type", type=str, default="deblurring", help="Type of condition (e.g., 'deblurring', 'canny', 'depth')")
185
+ parser.add_argument("--adapter_path", type=str, default="runs/20251111-212406-deblurring/ckpt/25000", help="Path to the adapter checkpoint")
186
+ args = parser.parse_args()
187
+
188
+ START_INDEX = args.start_index
189
+ NUM_IMAGES = args.num_images
190
+
191
+ # Path to your captions file (change if needed)
192
+ CAPTION_FILE = "/home/work/koopman/oft/data/coco/annotations/captions_val2017.json"
193
+ IMAGE_DIR = "/home/work/koopman/oft/data/coco/images/val2017/"
194
+ CONDITION_TYPE = args.condition_type
195
+ SAVE_ROOT_DIR = f"./coco/results_{CONDITION_TYPE}_1000/"
196
+ ADAPTER_PATH = args.adapter_path
197
+
198
+ # Load your Flux pipeline
199
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float32) # Replace with your model path
200
+
201
+ # add adapter to the transformer
202
+ transformer = pipe.transformer
203
+
204
+ adapter_name = "default"
205
+ transformer._hf_peft_config_loaded = True
206
+
207
+ # make sure this is the same with your config.yaml used in training
208
+ rotation_adapter_config = {
209
+ "r": 1,
210
+ "num_rotations": 8,
211
+ "target_modules": "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)",
212
+ }
213
+
214
+ config = RotationConfig(**rotation_adapter_config)
215
+ rotation_tuner = RotationTuner(
216
+ transformer,
217
+ config,
218
+ adapter_name=adapter_name,
219
+ )
220
+
221
+ transformer.set_adapter(adapter_name)
222
+
223
+ load_rotation(
224
+ transformer,
225
+ path=ADAPTER_PATH,
226
+ adapter_name=adapter_name,
227
+ strict=False,
228
+ )
229
+
230
+
231
+ pipe = pipe.to("cuda")
232
+
233
+ # # Prepare for inference
234
+ rotation_tuner.merge_adapter(["default"])
235
+
236
+ # Convert the pipeline (with the merged rotation weights) to bfloat16
237
+ pipe = pipe.to(torch.bfloat16)
238
+ pipe.transformer.eval()
239
+
240
+
241
+ # Evaluate on COCO
242
+ evaluate(
243
+ pipe,
244
+ condition_type=CONDITION_TYPE,
245
+ caption_file=CAPTION_FILE,
246
+ image_dir=IMAGE_DIR,
247
+ save_root_dir=SAVE_ROOT_DIR,
248
+ num_images=NUM_IMAGES,
249
+ start_index=START_INDEX,
250
+ )
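
Once a run finishes, the `generated/` and `resized/` folders written above can be scored directly. A minimal sketch of an FID computation over those folders, assuming torchmetrics and torchvision are installed; the root path is an assumption meant to match SAVE_ROOT_DIR above:

import os
from PIL import Image
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance

root = "./coco/results_deblurring_1000"  # assumed to match SAVE_ROOT_DIR above

# FID expects uint8 image tensors of shape (N, 3, H, W) with the default normalize=False.
to_tensor = transforms.Compose([transforms.Resize((512, 512)), transforms.PILToTensor()])

def iter_images(folder):
    for name in sorted(os.listdir(folder)):
        if name.endswith(".jpg"):
            yield to_tensor(Image.open(os.path.join(folder, name)).convert("RGB")).unsqueeze(0)

fid = FrechetInceptionDistance(feature=2048)
for img in iter_images(os.path.join(root, "resized")):
    fid.update(img, real=True)    # ground-truth (center-cropped) COCO images
for img in iter_images(os.path.join(root, "generated")):
    fid.update(img, real=False)   # images generated from the condition
print("FID:", fid.compute().item())
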
evaluation_coco_baseline.py ADDED
@@ -0,0 +1,222 @@
1
+ # example images:depth-41888,15254,16228, 24144,37777, 22192
2
+
3
+ # ablate image: 87038
4
+ import json
5
+ from PIL import Image
6
+ import os
7
+ import argparse
8
+ from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition
9
+ from omini.rotation import RotationConfig, RotationTuner
10
+ from tqdm import tqdm
11
+ import torch
12
+ from diffusers.pipelines import FluxPipeline
13
+
14
+ def evaluate(pipe,
15
+ condition_type: str, # e.g., "canny"
16
+ caption_file: str,
17
+ image_dir: str,
18
+ save_root_dir: str,
19
+ num_images: int = 5000,
20
+ start_index: int = 0):
21
+ """
22
+ Evaluate the model on a subset of the COCO dataset.
23
+
24
+ Args:
25
+ pipe: The flux pipeline to use for generation
26
+ caption_file: Path to the COCO captions JSON file
27
+ image_dir: Directory containing COCO images
28
+ num_images: Number of images to evaluate on
29
+ """
30
+
31
+ os.makedirs(os.path.join(save_root_dir, "generated"), exist_ok=True)
32
+ os.makedirs(os.path.join(save_root_dir, "resized"), exist_ok=True)
33
+ os.makedirs(os.path.join(save_root_dir, condition_type), exist_ok=True)
34
+
35
+ # Load data
36
+ with open(caption_file, "r") as f:
37
+ coco_data = json.load(f)
38
+
39
+ # Build a mapping: image_id → (filename, captions)
40
+ id_to_filename = {img["id"]: img["file_name"] for img in coco_data["images"]}
41
+ captions_by_image = {}
42
+
43
+ for ann in coco_data["annotations"]:
44
+ img_id = ann["image_id"]
45
+ captions_by_image.setdefault(img_id, []).append(ann["caption"])
46
+
47
+ # Take first 5000 images
48
+ image_ids = list(id_to_filename.keys())[:5000]
49
+
50
+ # Collect data
51
+ captions_subset = [
52
+ {
53
+ "image_id": img_id,
54
+ "file_name": id_to_filename[img_id],
55
+ "captions": captions_by_image.get(img_id, [])
56
+ }
57
+ for img_id in image_ids
58
+ ]
59
+
60
+ for item in tqdm(captions_subset[start_index:start_index+num_images]):
61
+
62
+ image_id = item["image_id"]
63
+ image_path = os.path.join(image_dir, item["file_name"])
64
+ image = Image.open(image_path).convert("RGB")
65
+
66
+ # Center-crop to a square, then resize to 512x512
67
+ w, h, min_dim = image.size + (min(image.size),)
68
+ image = image.crop(
69
+ ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)
70
+ ).resize((512, 512))
71
+
72
+ condition_image = convert_to_condition(condition_type, image)
73
+ condition = Condition(condition_image, condition_type)
74
+
75
+ prompt = item["captions"][0] if item["captions"] else "No caption available."
76
+
77
+ seed_everything(42)
78
+
79
+ # generate image
80
+ result_img = generate(
81
+ pipe,
82
+ prompt=prompt,
83
+ conditions=[condition],
84
+ ).images[0]
85
+
86
+
87
+ result_img.save(os.path.join(save_root_dir, "generated", f"{image_id}.jpg"))
88
+ image.save(os.path.join(save_root_dir, "resized", f"{image_id}.jpg"))
89
+ condition.condition.save(os.path.join(save_root_dir, condition_type, f"{image_id}.jpg"))
90
+
91
+
92
+
93
+
94
+
95
+ def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False):
96
+ """
97
+ Load rotation adapter weights.
98
+
99
+ Args:
100
+ path: Directory containing the saved adapter weights
101
+ adapter_name: Name of the adapter to load
102
+ strict: Whether to strictly match all keys
103
+ """
104
+ from safetensors.torch import load_file
105
+ import os
106
+ import yaml
107
+
108
+ device = transformer.device
109
+ print(f"device for loading: {device}")
110
+
111
+ # Try to load safetensors first, then fallback to .pth
112
+ safetensors_path = os.path.join(path, f"{adapter_name}.safetensors")
113
+ pth_path = os.path.join(path, f"{adapter_name}.pth")
114
+
115
+ if os.path.exists(safetensors_path):
116
+ state_dict = load_file(safetensors_path)
117
+ print(f"Loaded rotation adapter from {safetensors_path}")
118
+ elif os.path.exists(pth_path):
119
+ state_dict = torch.load(pth_path, map_location=device)
120
+ print(f"Loaded rotation adapter from {pth_path}")
121
+ else:
122
+ raise FileNotFoundError(
123
+ f"No adapter weights found for '{adapter_name}' in {path}\n"
124
+ f"Looking for: {safetensors_path} or {pth_path}"
125
+ )
126
+
127
+ # # Get the device and dtype of the transformer
128
+ transformer_device = next(transformer.parameters()).device
129
+ transformer_dtype = next(transformer.parameters()).dtype
130
+
131
+
132
+
133
+ state_dict_with_adapter = {}
134
+ for k, v in state_dict.items():
135
+ # Reconstruct the full key with adapter name
136
+ new_key = k.replace(".rotation.", f".rotation.{adapter_name}.")
137
+
138
+ # Move to target device and dtype
139
+ # Check if this parameter should keep its original dtype (e.g., indices, masks)
140
+ if v.dtype in [torch.long, torch.int, torch.int32, torch.int64, torch.bool]:
141
+ # Keep integer/boolean dtypes, only move device
142
+ state_dict_with_adapter[new_key] = v.to(device=transformer_device)
143
+ else:
144
+ # Convert floating point tensors to target dtype and device
145
+ state_dict_with_adapter[new_key] = v.to(device=transformer_device, dtype=transformer_dtype)
146
+
147
152
+
153
+
154
+ # Load into the model
155
+ missing, unexpected = transformer.load_state_dict(
156
+ state_dict_with_adapter,
157
+ strict=strict
158
+ )
159
+
160
+ if missing:
161
+ print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
162
+ if unexpected:
163
+ print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
164
+
165
+ # Load config if available
166
+ config_path = os.path.join(path, f"{adapter_name}_config.yaml")
167
+ if os.path.exists(config_path):
168
+ with open(config_path, 'r') as f:
169
+ config = yaml.safe_load(f)
170
+ print(f"Loaded config: {config}")
171
+
172
+ total_params = sum(p.numel() for p in state_dict.values())
173
+ print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")
174
+
175
+ return state_dict
176
+
177
+
178
+
179
+
180
+ if __name__ == "__main__":
181
+ parser = argparse.ArgumentParser(description="Evaluate OminiControl on COCO dataset")
182
+ parser.add_argument("--start_index", type=int, default=0, help="Starting index for evaluation")
183
+ parser.add_argument("--num_images", type=int, default=500, help="Number of images to evaluate")
184
+ parser.add_argument("--condition_type", type=str, default="deblurring", help="Type of condition (e.g., 'deblurring', 'canny', 'depth')")
185
+ args = parser.parse_args()
186
+
187
+ START_INDEX = args.start_index
188
+ NUM_IMAGES = args.num_images
189
+
190
+ # Path to your captions file (change if needed)
191
+ CAPTION_FILE = "/home/work/koopman/oft/data/coco/annotations/captions_val2017.json"
192
+ IMAGE_DIR = "/home/work/koopman/oft/data/coco/images/val2017/"
193
+ CONDITION_TYPE = args.condition_type
194
+ SAVE_ROOT_DIR = f"./coco_baseline/results_{CONDITION_TYPE}_1000/"
195
+
196
+ # Load your Flux pipeline
197
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16) # Replace with your model path
198
+
199
+ ### FOR OMINI
200
+
201
+ pipe.load_lora_weights(
202
+ "Yuanshi/OminiControl",
203
+ weight_name=f"experimental/{CONDITION_TYPE}.safetensors",
204
+ adapter_name=CONDITION_TYPE,
205
+ )
206
+ pipe.fuse_lora(lora_scale=1.0)
207
+ pipe.unload_lora_weights()
208
+
209
+ # pipe.set_adapters([CONDITION_TYPE])
210
+ pipe = pipe.to("cuda")
211
+
212
+
213
+ # Evaluate on COCO
214
+ evaluate(
215
+ pipe,
216
+ condition_type=CONDITION_TYPE,
217
+ caption_file=CAPTION_FILE,
218
+ image_dir=IMAGE_DIR,
219
+ save_root_dir=SAVE_ROOT_DIR,
220
+ num_images=NUM_IMAGES,
221
+ start_index=START_INDEX,
222
+ )
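
For the reconstruction-style conditions (deblurring, coloring), the baseline outputs can also be compared against the resized ground truth with per-image metrics. A minimal sketch, assuming torchmetrics is installed and that both folders contain matching `<image_id>.jpg` files as saved above (the root path mirrors SAVE_ROOT_DIR):

import os
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

root = "./coco_baseline/results_deblurring_1000"  # assumed to match SAVE_ROOT_DIR above
psnr = PeakSignalNoiseRatio(data_range=1.0)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

for name in sorted(os.listdir(os.path.join(root, "generated"))):
    gen = pil_to_tensor(Image.open(os.path.join(root, "generated", name)).convert("RGB")).float() / 255.0
    ref = pil_to_tensor(Image.open(os.path.join(root, "resized", name)).convert("RGB")).float() / 255.0
    psnr.update(gen.unsqueeze(0), ref.unsqueeze(0))  # both metrics accumulate over the folder
    ssim.update(gen.unsqueeze(0), ref.unsqueeze(0))

print("PSNR:", psnr.compute().item(), "SSIM:", ssim.compute().item())
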
evaluation_subject_driven.py ADDED
@@ -0,0 +1,362 @@
1
+ import openai
2
+ import base64
3
+ from pathlib import Path
4
+ import random
5
+ import os
6
+
7
+
8
+
9
+ evaluation_prompts = {
10
+ "identity": """
11
+ Compare the original subject image with the generated image.
12
+ Rate on a scale of 1-5 how well the essential identifying features
13
+ are preserved (logos, brand marks, distinctive patterns).
14
+ Score: [1-5]
15
+ Reasoning: [explanation]
16
+ """,
17
+
18
+ "material": """
19
+ Evaluate the material quality and surface characteristics.
20
+ Rate on a scale of 1-5 how accurately materials are represented
21
+ (textures, reflections, surface properties).
22
+ Score: [1-5]
23
+ Reasoning: [explanation]
24
+ """,
25
+
26
+ "color": """
27
+ Assess color fidelity in regions NOT specified for modification.
28
+ Rate on a scale of 1-5 how consistent colors remain.
29
+ Score: [1-5]
30
+ Reasoning: [explanation]
31
+ """,
32
+
33
+ "appearance": """
34
+ Evaluate the overall realism and coherence of the generated image.
35
+ Rate on a scale of 1-5 how realistic and natural it appears.
36
+ Score: [1-5]
37
+ Reasoning: [explanation]
38
+ """,
39
+
40
+ "modification": """
41
+ Given the text prompt: "{prompt}"
42
+ Rate on a scale of 1-5 how well the specified changes are executed.
43
+ Score: [1-5]
44
+ Reasoning: [explanation]
45
+ """
46
+ }
47
+
48
+
49
+ def encode_image(image_path):
50
+ with open(image_path, "rb") as image_file:
51
+ return base64.b64encode(image_file.read()).decode('utf-8')
52
+
53
+ def evaluate_subject_driven_generation(
54
+ original_image_path,
55
+ generated_image_path,
56
+ text_prompt,
57
+ client
58
+ ):
59
+ """
60
+ Evaluate a subject-driven generation using GPT-4o vision
61
+ """
62
+
63
+ # Encode images
64
+ original_img = encode_image(original_image_path)
65
+ generated_img = encode_image(generated_image_path)
66
+
67
+ results = {}
68
+
69
+ # 1. Identity Preservation
70
+ response = client.chat.completions.create(
71
+ model="gpt-4o",
72
+ messages=[{
73
+ "role": "user",
74
+ "content": [
75
+ {"type": "text", "text": "Original subject image:"},
76
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{original_img}"}},
77
+ {"type": "text", "text": "Generated image:"},
78
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}},
79
+ {"type": "text", "text": evaluation_prompts["identity"]}
80
+ ]
81
+ }],
82
+ max_tokens=300
83
+ )
84
+ results['identity'] = parse_score(response.choices[0].message.content)
85
+
86
+ # 2. Material Quality
87
+ response = client.chat.completions.create(
88
+ model="gpt-4o",
89
+ messages=[{
90
+ "role": "user",
91
+ "content": [
92
+ {"type": "text", "text": "Evaluate this generated image:"},
93
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}},
94
+ {"type": "text", "text": evaluation_prompts["material"]}
95
+ ]
96
+ }],
97
+ max_tokens=300
98
+ )
99
+ results['material'] = parse_score(response.choices[0].message.content)
100
+
101
+ # 3. Color Fidelity
102
+ response = client.chat.completions.create(
103
+ model="gpt-4o",
104
+ messages=[{
105
+ "role": "user",
106
+ "content": [
107
+ {"type": "text", "text": "Original:"},
108
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{original_img}"}},
109
+ {"type": "text", "text": "Generated:"},
110
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}},
111
+ {"type": "text", "text": evaluation_prompts["color"]}
112
+ ]
113
+ }],
114
+ max_tokens=300
115
+ )
116
+ results['color'] = parse_score(response.choices[0].message.content)
117
+
118
+ # 4. Natural Appearance
119
+ response = client.chat.completions.create(
120
+ model="gpt-4o",
121
+ messages=[{
122
+ "role": "user",
123
+ "content": [
124
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}},
125
+ {"type": "text", "text": evaluation_prompts["appearance"]}
126
+ ]
127
+ }],
128
+ max_tokens=300
129
+ )
130
+ results['appearance'] = parse_score(response.choices[0].message.content)
131
+
132
+ # 5. Modification Accuracy
133
+ response = client.chat.completions.create(
134
+ model="gpt-4o",
135
+ messages=[{
136
+ "role": "user",
137
+ "content": [
138
+ {"type": "text", "text": f"Text prompt: {text_prompt}"},
139
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}},
140
+ {"type": "text", "text": evaluation_prompts["modification"].format(prompt=text_prompt)}
141
+ ]
142
+ }],
143
+ max_tokens=300
144
+ )
145
+ results['modification'] = parse_score(response.choices[0].message.content)
146
+
147
+ return results
148
+
149
+ def parse_score(response_text):
150
+ """Extract score from GPT-4o response"""
151
+ # Look for "Score: X" pattern
152
+ import re
153
+ match = re.search(r'Score:\s*(\d+)', response_text)
154
+ if match:
155
+ return int(match.group(1))
156
+ return None
157
+
158
+ subject_names = [
159
+ "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can",
160
+ "candle", "cat", "cat2", "clock", "colorful_sneaker",
161
+ "dog", "dog2", "dog3", "dog5", "dog6",
162
+ "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie",
163
+ "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon",
164
+ "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie"
165
+ ]
166
+
167
+
168
+ def get_prompt(subject_id, prompt_id):
169
+
+ subject_name = subject_names[subject_id]
+
+ # subject in 0|1|2|3|4|5|8|9|17|18|19|20|21|22|23|24|25|26|27|28|29
+ if subject_id in [0,1,2,3,4,5,8,9,17,18,19,20,21,22,23,24,25,26,27,28,29]:
173
+ prompts = [
174
+ f"a {subject_name} in the jungle",
175
+ f"a {subject_name} in the snow",
176
+ f"a {subject_name} on the beach",
177
+ f"a {subject_name} on a cobblestone street",
178
+ f"a {subject_name} on top of pink fabric",
179
+ f"a {subject_name} on top of a wooden floor",
180
+ f"a {subject_name} with a city in the background",
181
+ f"a {subject_name} with a mountain in the background",
182
+ f"a {subject_name} with a blue house in the background",
183
+ f"a {subject_name} on top of a purple rug in a forest",
184
+ f"a {subject_name} with a wheat field in the background",
185
+ f"a {subject_name} with a tree and autumn leaves in the background",
186
+ f"a {subject_name} with the Eiffel Tower in the background",
187
+ f"a {subject_name} floating on top of water",
188
+ f"a {subject_name} floating in an ocean of milk",
189
+ f"a {subject_name} on top of green grass with sunflowers around it",
190
+ f"a {subject_name} on top of a mirror",
191
+ f"a {subject_name} on top of the sidewalk in a crowded street",
192
+ f"a {subject_name} on top of a dirt road",
193
+ f"a {subject_name} on top of a white rug",
194
+ f"a red {subject_name}",
195
+ f"a purple {subject_name}",
196
+ f"a shiny {subject_name}",
197
+ f"a wet {subject_name}",
198
+ f"a cube shaped {subject_name}"
199
+ ]
200
+
201
+ else:
202
+ prompts = [
203
+ f"a {subject_name} in the jungle",
204
+ f"a {subject_name} in the snow",
205
+ f"a {subject_name} on the beach",
206
+ f"a {subject_name} on a cobblestone street",
207
+ f"a {subject_name} on top of pink fabric",
208
+ f"a {subject_name} on top of a wooden floor",
209
+ f"a {subject_name} with a city in the background",
210
+ f"a {subject_name} with a mountain in the background",
211
+ f"a {subject_name} with a blue house in the background",
212
+ f"a {subject_name} on top of a purple rug in a forest",
213
+ f"a {subject_name} wearing a red hat",
214
+ f"a {subject_name} wearing a santa hat",
215
+ f"a {subject_name} wearing a rainbow scarf",
216
+ f"a {subject_name} wearing a black top hat and a monocle",
217
+ f"a {subject_name} in a chef outfit",
218
+ f"a {subject_name} in a firefighter outfit",
219
+ f"a {subject_name} in a police outfit",
220
+ f"a {subject_name} wearing pink glasses",
221
+ f"a {subject_name} wearing a yellow shirt",
222
+ f"a {subject_name} in a purple wizard outfit",
223
+ f"a red {subject_name}",
224
+ f"a purple {subject_name}",
225
+ f"a shiny {subject_name}",
226
+ f"a wet {subject_name}",
227
+ f"a cube shaped {subject_name}"
228
+ ]
229
+
230
+ return prompts[prompt_id]
231
+
232
+
233
+
234
+
235
+
236
+ def batch_evaluate_dreambooth(client, generate_fn, dataset_path, output_csv):
237
+ """
238
+ Evaluate 750 image pairs with 5 seeds each
239
+ """
240
+ import pandas as pd
241
+
242
+ results_list = []
243
+
244
+ # Iterate through DreamBooth dataset
245
+ for subject_id in range(30): # 30 subjects
246
+ subject_name = subject_names[subject_id]
247
+ for prompt_id in range(25): # 25 prompts per subject
248
+ original = f"{dataset_path}/{subject_name}"
249
+ # take the first reference image in this folder
250
+ original_files = sorted(Path(original).glob("*.jpg"))
251
+ if len(original_files) == 0:
252
+ raise ValueError(f"No original images found in {original}")
253
+
254
+ original = str(original_files[0])
255
+
256
+
257
+ for seed in range(5): # 5 different seeds
258
+ # take random file in the folder
259
+ prompt = get_prompt(subject_id, prompt_id)
260
+
261
+ # generated image path
262
+ generated_folder = f"{dataset_path}/{subject_name}/generated/"
263
+ os.makedirs(generated_folder, exist_ok=True)
264
+ generated = f"{generated_folder}/gen_seed{seed}_prompt{prompt_id}.png"
265
+
266
+ generate_fn(
267
+ prompt=prompt,
268
+ image_path=original,
269
+ output_path=generated,
270
+ seed=seed
271
+ )
272
+
273
+ scores = evaluate_subject_driven_generation(
274
+ original, generated, prompt, client
275
+ )
276
+
277
+ results_list.append({
278
+ 'subject_id': subject_id,
279
+ 'subject_name': subject_name,
280
+ 'prompt_id': prompt_id,
281
+ 'seed': seed,
282
+ 'prompt': prompt,
283
+
284
+ **scores
285
+ })
286
+
287
+ # Save results
288
+ df = pd.DataFrame(results_list)
289
+ df.to_csv(output_csv, index=False)
290
+
291
+ # Calculate statistics
292
+ print(df.groupby('subject_id').mean(numeric_only=True))
293
+ print(f"\nOverall averages:")
294
+ print(df[['identity', 'material', 'color', 'appearance', 'modification']].mean())
295
+
296
+
297
+ def evaluate_omini_control():
298
+
299
+ import torch
300
+ from diffusers.pipelines import FluxPipeline
301
+ from PIL import Image
302
+
303
+ from omini.pipeline.flux_omini import Condition, generate, seed_everything
304
+
305
+ pipe = FluxPipeline.from_pretrained(
306
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
307
+ )
308
+
309
+ pipe = pipe.to("cuda")
310
+ pipe.load_lora_weights(
311
+ "Yuanshi/OminiControl",
312
+ weight_name=f"omini/subject_512.safetensors",
313
+ adapter_name="subject",
314
+ )
315
+
316
+ def generate_fn(image_path, prompt, seed, output_path):
317
+ seed_everything(seed)
318
+
319
+ image = Image.open(image_path).convert("RGB").resize((512, 512))
320
+ condition = Condition(image, "subject", position_delta=(0, 32))
324
+
325
+ result_img = generate(
326
+ pipe,
327
+ prompt=prompt,
328
+ conditions=[condition],
329
+ ).images[0]
330
+
331
+ result_img.save(output_path)
332
+
333
+ return generate_fn
334
+
335
+
336
+ if __name__ == "__main__":
337
+
338
+
339
+
340
+ openai.api_key = os.getenv("OPENAI_API_KEY")
341
+ # client = openai.Client()
342
+
343
+ # generate_fn = evaluate_omini_control()
344
+
345
+ # dataset_path = "data/dreambooth"
346
+ # output_csv = "evaluation_subject_driven_omini_control.csv"
347
+
348
+ # batch_evaluate_dreambooth(
349
+ # client,
350
+ # generate_fn,
351
+ # dataset_path,
352
+ # output_csv
353
+ # )
354
+
355
+ result = evaluate_subject_driven_generation(
356
+ "data/dreambooth/backpack/00.jpg",
357
+ "data/dreambooth/backpack/01.jpg",
358
+ "a backpack in the jungle",
359
+ openai.Client()
360
+ )
361
+
362
+ print(result)
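
Alongside the GPT-4o rubric above, subject fidelity can be cross-checked with an embedding-based score. A minimal sketch of CLIP image-image similarity between two reference photos, assuming the transformers library is installed; the checkpoint name and file paths are illustrative assumptions matching the example paths used above:

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

@torch.no_grad()
def clip_image_similarity(path_a, path_b):
    images = [Image.open(p).convert("RGB") for p in (path_a, path_b)]
    inputs = processor(images=images, return_tensors="pt")
    feats = model.get_image_features(**inputs)
    feats = feats / feats.norm(dim=-1, keepdim=True)  # cosine similarity of the two embeddings
    return (feats[0] @ feats[1]).item()

print(clip_image_similarity(
    "data/dreambooth/backpack/00.jpg",
    "data/dreambooth/backpack/01.jpg",
))
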
examples/combine_with_style_lora.ipynb ADDED
@@ -0,0 +1,235 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from PIL import Image\n",
23
+ "\n",
24
+ "from omini.pipeline.flux_omini import Condition, generate, seed_everything"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "\n",
34
+ "pipe = FluxPipeline.from_pretrained(\n",
35
+ " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
36
+ ")\n",
37
+ "pipe = pipe.to(\"cuda\")"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "pipe.unload_lora_weights()\n",
47
+ "\n",
48
+ "pipe.load_lora_weights(\n",
49
+ " \"Yuanshi/OminiControl\",\n",
50
+ " weight_name=f\"omini/subject_512.safetensors\",\n",
51
+ " adapter_name=\"subject\",\n",
52
+ ")\n",
53
+ "pipe.load_lora_weights(\"XLabs-AI/flux-RealismLora\", adapter_name=\"realism\")\n",
54
+ "\n",
55
+ "pipe.set_adapters([\"subject\", \"realism\"])"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
65
+ "\n",
66
+ "# For this model, the position_delta is (0, 32).\n",
67
+ "# For more details of position_delta, please refer to:\n",
68
+ "# https://github.com/Yuanshi9815/OminiControl/issues/89#issuecomment-2827080344\n",
69
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
70
+ "\n",
71
+ "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
72
+ "\n",
73
+ "\n",
74
+ "seed_everything(0)\n",
75
+ "\n",
76
+ "result_img = generate(\n",
77
+ " pipe,\n",
78
+ " prompt=prompt,\n",
79
+ " conditions=[condition],\n",
80
+ " num_inference_steps=8,\n",
81
+ " height=512,\n",
82
+ " width=512,\n",
83
+ " main_adapter=\"realism\"\n",
84
+ ").images[0]\n",
85
+ "\n",
86
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
87
+ "concat_image.paste(image, (0, 0))\n",
88
+ "concat_image.paste(result_img, (512, 0))\n",
89
+ "concat_image"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
99
+ "\n",
100
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
101
+ "\n",
102
+ "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
103
+ "\n",
104
+ "\n",
105
+ "seed_everything()\n",
106
+ "\n",
107
+ "result_img = generate(\n",
108
+ " pipe,\n",
109
+ " prompt=prompt,\n",
110
+ " conditions=[condition],\n",
111
+ " num_inference_steps=8,\n",
112
+ " height=512,\n",
113
+ " width=512,\n",
114
+ " main_adapter=\"realism\"\n",
115
+ ").images[0]\n",
116
+ "\n",
117
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
118
+ "concat_image.paste(condition.condition, (0, 0))\n",
119
+ "concat_image.paste(result_img, (512, 0))\n",
120
+ "concat_image"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
130
+ "\n",
131
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
132
+ "\n",
133
+ "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
134
+ "\n",
135
+ "seed_everything()\n",
136
+ "\n",
137
+ "result_img = generate(\n",
138
+ " pipe,\n",
139
+ " prompt=prompt,\n",
140
+ " conditions=[condition],\n",
141
+ " num_inference_steps=8,\n",
142
+ " height=512,\n",
143
+ " width=512,\n",
144
+ " main_adapter=\"realism\"\n",
145
+ ").images[0]\n",
146
+ "\n",
147
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
148
+ "concat_image.paste(condition.condition, (0, 0))\n",
149
+ "concat_image.paste(result_img, (512, 0))\n",
150
+ "concat_image"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "metadata": {},
157
+ "outputs": [],
158
+ "source": [
159
+ "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
160
+ "\n",
161
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
162
+ "\n",
163
+ "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
164
+ "\n",
165
+ "seed_everything()\n",
166
+ "\n",
167
+ "result_img = generate(\n",
168
+ " pipe,\n",
169
+ " prompt=prompt,\n",
170
+ " conditions=[condition],\n",
171
+ " num_inference_steps=8,\n",
172
+ " height=512,\n",
173
+ " width=512,\n",
174
+ " main_adapter=\"realism\"\n",
175
+ ").images[0]\n",
176
+ "\n",
177
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
178
+ "concat_image.paste(condition.condition, (0, 0))\n",
179
+ "concat_image.paste(result_img, (512, 0))\n",
180
+ "concat_image"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": null,
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
190
+ "\n",
191
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
192
+ "\n",
193
+ "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
194
+ "\n",
195
+ "seed_everything()\n",
196
+ "\n",
197
+ "result_img = generate(\n",
198
+ " pipe,\n",
199
+ " prompt=prompt,\n",
200
+ " conditions=[condition],\n",
201
+ " num_inference_steps=8,\n",
202
+ " height=512,\n",
203
+ " width=512,\n",
204
+ " main_adapter=\"realism\"\n",
205
+ ").images[0]\n",
206
+ "\n",
207
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
208
+ "concat_image.paste(condition.condition, (0, 0))\n",
209
+ "concat_image.paste(result_img, (512, 0))\n",
210
+ "concat_image"
211
+ ]
212
+ }
213
+ ],
214
+ "metadata": {
215
+ "kernelspec": {
216
+ "display_name": "Python 3 (ipykernel)",
217
+ "language": "python",
218
+ "name": "python3"
219
+ },
220
+ "language_info": {
221
+ "codemirror_mode": {
222
+ "name": "ipython",
223
+ "version": 3
224
+ },
225
+ "file_extension": ".py",
226
+ "mimetype": "text/x-python",
227
+ "name": "python",
228
+ "nbconvert_exporter": "python",
229
+ "pygments_lexer": "ipython3",
230
+ "version": "3.9.21"
231
+ }
232
+ },
233
+ "nbformat": 4,
234
+ "nbformat_minor": 2
235
+ }
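
When the subject adapter is combined with a style LoRA as in the notebook above, the relative strengths can be tuned instead of running both at full weight. A minimal sketch using diffusers' per-adapter weighting; the 0.8 value is only an illustrative assumption:

import torch
from diffusers.pipelines import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

pipe.load_lora_weights(
    "Yuanshi/OminiControl", weight_name="omini/subject_512.safetensors", adapter_name="subject"
)
pipe.load_lora_weights("XLabs-AI/flux-RealismLora", adapter_name="realism")

# Down-weight the style LoRA relative to the subject adapter (0.8 is an arbitrary example).
pipe.set_adapters(["subject", "realism"], adapter_weights=[1.0, 0.8])
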
examples/inpainting.ipynb ADDED
@@ -0,0 +1,135 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from PIL import Image\n",
23
+ "\n",
24
+ "from omini.pipeline.flux_omini import Condition, generate, seed_everything"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "pipe = FluxPipeline.from_pretrained(\n",
34
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
35
+ ")\n",
36
+ "pipe = pipe.to(\"cuda\")"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "pipe.load_lora_weights(\n",
46
+ " \"Yuanshi/OminiControl\",\n",
47
+ " weight_name=f\"experimental/fill.safetensors\",\n",
48
+ " adapter_name=\"fill\",\n",
49
+ ")"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "image = Image.open(\"assets/monalisa.jpg\").convert(\"RGB\").resize((512, 512))\n",
59
+ "\n",
60
+ "masked_image = image.copy()\n",
61
+ "masked_image.paste((0, 0, 0), (128, 100, 384, 220))\n",
62
+ "\n",
63
+ "condition = Condition(masked_image, \"fill\")\n",
64
+ "\n",
65
+ "seed_everything()\n",
66
+ "result_img = generate(\n",
67
+ " pipe,\n",
68
+ " prompt=\"The Mona Lisa is wearing a white VR headset with 'Omini' written on it.\",\n",
69
+ " conditions=[condition],\n",
70
+ ").images[0]\n",
71
+ "\n",
72
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
73
+ "concat_image.paste(image, (0, 0))\n",
74
+ "concat_image.paste(condition.condition, (512, 0))\n",
75
+ "concat_image.paste(result_img, (1024, 0))\n",
76
+ "concat_image"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "image = Image.open(\"assets/book.jpg\").convert(\"RGB\").resize((512, 512))\n",
86
+ "\n",
87
+ "w, h, min_dim = image.size + (min(image.size),)\n",
88
+ "image = image.crop(\n",
89
+ " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
90
+ ").resize((512, 512))\n",
91
+ "\n",
92
+ "\n",
93
+ "masked_image = image.copy()\n",
94
+ "masked_image.paste((0, 0, 0), (150, 150, 350, 250))\n",
95
+ "masked_image.paste((0, 0, 0), (200, 380, 320, 420))\n",
96
+ "\n",
97
+ "condition = Condition(masked_image, \"fill\")\n",
98
+ "\n",
99
+ "seed_everything()\n",
100
+ "result_img = generate(\n",
101
+ " pipe,\n",
102
+ " prompt=\"A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.\",\n",
103
+ " conditions=[condition],\n",
104
+ ").images[0]\n",
105
+ "\n",
106
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
107
+ "concat_image.paste(image, (0, 0))\n",
108
+ "concat_image.paste(condition.condition, (512, 0))\n",
109
+ "concat_image.paste(result_img, (1024, 0))\n",
110
+ "concat_image"
111
+ ]
112
+ }
113
+ ],
114
+ "metadata": {
115
+ "kernelspec": {
116
+ "display_name": "base",
117
+ "language": "python",
118
+ "name": "python3"
119
+ },
120
+ "language_info": {
121
+ "codemirror_mode": {
122
+ "name": "ipython",
123
+ "version": 3
124
+ },
125
+ "file_extension": ".py",
126
+ "mimetype": "text/x-python",
127
+ "name": "python",
128
+ "nbconvert_exporter": "python",
129
+ "pygments_lexer": "ipython3",
130
+ "version": "3.12.3"
131
+ }
132
+ },
133
+ "nbformat": 4,
134
+ "nbformat_minor": 2
135
+ }
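
The inpainting cells above paste black rectangles by hand; the same "fill" condition can be built from an arbitrary binary mask. A minimal sketch, assuming a single-channel mask image (white = region to repaint) is available at a path of your choosing; the mask file name is hypothetical:

from PIL import Image
from omini.pipeline.flux_omini import Condition

image = Image.open("assets/monalisa.jpg").convert("RGB").resize((512, 512))
mask = Image.open("mask.png").convert("L").resize((512, 512))  # hypothetical mask file

# Black out the masked region: composite picks the first image where the mask is white.
black = Image.new("RGB", image.size, (0, 0, 0))
masked_image = Image.composite(black, image, mask)

condition = Condition(masked_image, "fill")
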
examples/ominicontrol_art.ipynb ADDED
@@ -0,0 +1,218 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from PIL import Image\n",
23
+ "\n",
24
+ "from omini.pipeline.flux_omini import Condition, generate, seed_everything"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "pipe = FluxPipeline.from_pretrained(\n",
34
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
35
+ ")\n",
36
+ "pipe = pipe.to(\"cuda\")"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "pipe.unload_lora_weights()\n",
46
+ "\n",
47
+ "for style_type in [\"ghibli\", \"irasutoya\", \"simpsons\", \"snoopy\"]:\n",
48
+ " pipe.load_lora_weights(\n",
49
+ " \"Yuanshi/OminiControlArt\",\n",
50
+ " weight_name=f\"v0/{style_type}.safetensors\",\n",
51
+ " adapter_name=style_type,\n",
52
+ " )\n",
53
+ "\n",
54
+ "pipe.set_adapters([\"ghibli\", \"irasutoya\", \"simpsons\", \"snoopy\"])"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "def resize(img, factor=16):\n",
64
+ " # Resize the image to be divisible by the factor\n",
65
+ " w, h = img.size\n",
66
+ " new_w, new_h = w // factor * factor, h // factor * factor\n",
67
+ " padding_w, padding_h = (w - new_w) // 2, (h - new_h) // 2\n",
68
+ " img = img.crop((padding_w, padding_h, new_w + padding_w, new_h + padding_h))\n",
69
+ " return img\n",
70
+ "\n",
71
+ "\n",
72
+ "def bound_image(image):\n",
73
+ " factor = 512 / max(image.size)\n",
74
+ " image = resize(\n",
75
+ " image.resize(\n",
76
+ " (int(image.size[0] * factor), int(image.size[1] * factor)),\n",
77
+ " Image.LANCZOS,\n",
78
+ " )\n",
79
+ " )\n",
80
+ " delta = (0, -image.size[0] // 16)\n",
81
+ " return image, delta\n",
82
+ "\n",
83
+ "sizes = {\n",
84
+ " \"2:3\": (640, 960),\n",
85
+ " \"1:1\": (640, 640),\n",
86
+ " \"3:2\": (960, 640),\n",
87
+ "}"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "image = Image.open(\"assets/ominicontrol_art/DistractedBoyfriend.webp\").convert(\"RGB\")\n",
97
+ "image, delta = bound_image(image)\n",
98
+ "condition = Condition(image, \"ghibli\", position_delta=delta)\n",
99
+ "\n",
100
+ "seed_everything()\n",
101
+ "\n",
102
+ "size = sizes[\"3:2\"]\n",
103
+ "\n",
104
+ "result_img = generate(\n",
105
+ " pipe,\n",
106
+ " prompt=\"\",\n",
107
+ " conditions=[condition],\n",
108
+ " max_sequence_length=32,\n",
109
+ " width=size[0],\n",
110
+ " height=size[1],\n",
111
+ " image_guidance_scale=1.5,\n",
112
+ ").images[0]\n"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "image = Image.open(\"assets/ominicontrol_art/oiiai.png\").convert(\"RGB\")\n",
122
+ "image, delta = bound_image(image)\n",
123
+ "condition = Condition(image, \"irasutoya\", position_delta=delta)\n",
124
+ "\n",
125
+ "seed_everything()\n",
126
+ "\n",
127
+ "size = sizes[\"1:1\"]\n",
128
+ "\n",
129
+ "result_img = generate(\n",
130
+ " pipe,\n",
131
+ " prompt=\"\",\n",
132
+ " conditions=[condition],\n",
133
+ " max_sequence_length=32,\n",
134
+ " width=size[0],\n",
135
+ " height=size[1],\n",
136
+ " image_guidance_scale=1.5,\n",
137
+ ").images[0]\n",
138
+ "\n",
139
+ "result_img"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": null,
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "image = Image.open(\"assets/ominicontrol_art/breakingbad.jpg\").convert(\"RGB\")\n",
149
+ "image, delta = bound_image(image)\n",
150
+ "condition = Condition(image, \"simpsons\", position_delta=delta)\n",
151
+ "\n",
152
+ "seed_everything()\n",
153
+ "\n",
154
+ "size = sizes[\"3:2\"]\n",
155
+ "\n",
156
+ "result_img = generate(\n",
157
+ " pipe,\n",
158
+ " prompt=\"\",\n",
159
+ " conditions=[condition],\n",
160
+ " max_sequence_length=32,\n",
161
+ " width=size[0],\n",
162
+ " height=size[1],\n",
163
+ " image_guidance_scale=1.5,\n",
164
+ ").images[0]\n",
165
+ "\n",
166
+ "result_img"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "image = Image.open(\"assets/ominicontrol_art/PulpFiction.jpg\").convert(\"RGB\")\n",
176
+ "image, delta = bound_image(image)\n",
177
+ "condition = Condition(image, \"snoopy\", position_delta=delta)\n",
178
+ "\n",
179
+ "seed_everything()\n",
180
+ "\n",
181
+ "size = sizes[\"3:2\"]\n",
182
+ "\n",
183
+ "result_img = generate(\n",
184
+ " pipe,\n",
185
+ " prompt=\"\",\n",
186
+ " conditions=[condition],\n",
187
+ " max_sequence_length=32,\n",
188
+ " width=size[0],\n",
189
+ " height=size[1],\n",
190
+ " image_guidance_scale=1.5,\n",
191
+ ").images[0]\n",
192
+ "\n",
193
+ "result_img"
194
+ ]
195
+ }
196
+ ],
197
+ "metadata": {
198
+ "kernelspec": {
199
+ "display_name": "Python 3 (ipykernel)",
200
+ "language": "python",
201
+ "name": "python3"
202
+ },
203
+ "language_info": {
204
+ "codemirror_mode": {
205
+ "name": "ipython",
206
+ "version": 3
207
+ },
208
+ "file_extension": ".py",
209
+ "mimetype": "text/x-python",
210
+ "name": "python",
211
+ "nbconvert_exporter": "python",
212
+ "pygments_lexer": "ipython3",
213
+ "version": "3.9.21"
214
+ }
215
+ },
216
+ "nbformat": 4,
217
+ "nbformat_minor": 2
218
+ }
examples/spatial.ipynb ADDED
@@ -0,0 +1,191 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "\n",
11
+ "os.chdir(\"..\")"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import torch\n",
21
+ "from diffusers.pipelines import FluxPipeline\n",
22
+ "from PIL import Image\n",
23
+ "\n",
24
+ "from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "pipe = FluxPipeline.from_pretrained(\n",
34
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
35
+ ")\n",
36
+ "pipe = pipe.to(\"cuda\")"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "pipe.unload_lora_weights()\n",
46
+ "\n",
47
+ "for condition_type in [\"canny\", \"depth\", \"coloring\", \"deblurring\"]:\n",
48
+ " pipe.load_lora_weights(\n",
49
+ " \"Yuanshi/OminiControl\",\n",
50
+ " weight_name=f\"experimental/{condition_type}.safetensors\",\n",
51
+ " adapter_name=condition_type,\n",
52
+ " )\n",
53
+ "\n",
54
+ "pipe.set_adapters([\"canny\", \"depth\", \"coloring\", \"deblurring\"])"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "image = Image.open(\"assets/coffee.png\").convert(\"RGB\")\n",
64
+ "\n",
65
+ "w, h, min_dim = image.size + (min(image.size),)\n",
66
+ "image = image.crop(\n",
67
+ " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
68
+ ").resize((512, 512))\n",
69
+ "\n",
70
+ "prompt = \"In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table.\""
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "canny_image = convert_to_condition(\"canny\", image)\n",
80
+ "condition = Condition(canny_image, \"canny\")\n",
81
+ "\n",
82
+ "seed_everything()\n",
83
+ "\n",
84
+ "result_img = generate(\n",
85
+ " pipe,\n",
86
+ " prompt=prompt,\n",
87
+ " conditions=[condition],\n",
88
+ ").images[0]\n",
89
+ "\n",
90
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
91
+ "concat_image.paste(image, (0, 0))\n",
92
+ "concat_image.paste(condition.condition, (512, 0))\n",
93
+ "concat_image.paste(result_img, (1024, 0))\n",
94
+ "concat_image"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "depth_image = convert_to_condition(\"depth\", image)\n",
104
+ "condition = Condition(depth_image, \"depth\")\n",
105
+ "\n",
106
+ "seed_everything()\n",
107
+ "\n",
108
+ "result_img = generate(\n",
109
+ " pipe,\n",
110
+ " prompt=prompt,\n",
111
+ " conditions=[condition],\n",
112
+ ").images[0]\n",
113
+ "\n",
114
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
115
+ "concat_image.paste(image, (0, 0))\n",
116
+ "concat_image.paste(condition.condition, (512, 0))\n",
117
+ "concat_image.paste(result_img, (1024, 0))\n",
118
+ "concat_image"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "blur_image = convert_to_condition(\"deblurring\", image)\n",
128
+ "condition = Condition(blur_image, \"deblurring\")\n",
129
+ "\n",
130
+ "seed_everything()\n",
131
+ "\n",
132
+ "result_img = generate(\n",
133
+ " pipe,\n",
134
+ " prompt=prompt,\n",
135
+ " conditions=[condition],\n",
136
+ ").images[0]\n",
137
+ "\n",
138
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
139
+ "concat_image.paste(image, (0, 0))\n",
140
+ "concat_image.paste(condition.condition, (512, 0))\n",
141
+ "concat_image.paste(result_img, (1024, 0))\n",
142
+ "concat_image"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": null,
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "condition_image = convert_to_condition(\"coloring\", image)\n",
152
+ "condition = Condition(condition_image, \"coloring\")\n",
153
+ "\n",
154
+ "seed_everything()\n",
155
+ "\n",
156
+ "result_img = generate(\n",
157
+ " pipe,\n",
158
+ " prompt=prompt,\n",
159
+ " conditions=[condition],\n",
160
+ ").images[0]\n",
161
+ "\n",
162
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
163
+ "concat_image.paste(image, (0, 0))\n",
164
+ "concat_image.paste(condition.condition, (512, 0))\n",
165
+ "concat_image.paste(result_img, (1024, 0))\n",
166
+ "concat_image"
167
+ ]
168
+ }
169
+ ],
170
+ "metadata": {
171
+ "kernelspec": {
172
+ "display_name": "base",
173
+ "language": "python",
174
+ "name": "python3"
175
+ },
176
+ "language_info": {
177
+ "codemirror_mode": {
178
+ "name": "ipython",
179
+ "version": 3
180
+ },
181
+ "file_extension": ".py",
182
+ "mimetype": "text/x-python",
183
+ "name": "python",
184
+ "nbconvert_exporter": "python",
185
+ "pygments_lexer": "ipython3",
186
+ "version": "3.12.3"
187
+ }
188
+ },
189
+ "nbformat": 4,
190
+ "nbformat_minor": 2
191
+ }