hello10000 committed
Commit 4ab551f · 1 parent: f5718c2

Add folder with files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +229 -0
  2. .python-version +1 -0
  3. LICENSE +201 -0
  4. README-zh.md +164 -0
  5. app.py +106 -0
  6. datasets/make_yolo_images.py +49 -0
  7. example.py +9 -0
  8. ffmpeg/README.md +69 -0
  9. notebooks/imputation.ipynb +0 -0
  10. one-click-portable.md +26 -0
  11. pyproject.toml +39 -0
  12. resources/first_frame.json +0 -0
  13. resources/watermark_template.png +0 -0
  14. sorawm/__init__.py +0 -0
  15. sorawm/configs.py +27 -0
  16. sorawm/core.py +197 -0
  17. sorawm/iopaint/__init__.py +56 -0
  18. sorawm/iopaint/__main__.py +4 -0
  19. sorawm/iopaint/api.py +411 -0
  20. sorawm/iopaint/batch_processing.py +128 -0
  21. sorawm/iopaint/benchmark.py +109 -0
  22. sorawm/iopaint/cli.py +245 -0
  23. sorawm/iopaint/const.py +134 -0
  24. sorawm/iopaint/download.py +314 -0
  25. sorawm/iopaint/file_manager/__init__.py +1 -0
  26. sorawm/iopaint/file_manager/file_manager.py +220 -0
  27. sorawm/iopaint/file_manager/storage_backends.py +46 -0
  28. sorawm/iopaint/file_manager/utils.py +64 -0
  29. sorawm/iopaint/helper.py +411 -0
  30. sorawm/iopaint/installer.py +11 -0
  31. sorawm/iopaint/model/__init__.py +38 -0
  32. sorawm/iopaint/model/anytext/__init__.py +0 -0
  33. sorawm/iopaint/model/anytext/anytext_model.py +73 -0
  34. sorawm/iopaint/model/anytext/anytext_pipeline.py +401 -0
  35. sorawm/iopaint/model/anytext/anytext_sd15.yaml +99 -0
  36. sorawm/iopaint/model/anytext/cldm/__init__.py +0 -0
  37. sorawm/iopaint/model/anytext/cldm/cldm.py +780 -0
  38. sorawm/iopaint/model/anytext/cldm/ddim_hacked.py +486 -0
  39. sorawm/iopaint/model/anytext/cldm/embedding_manager.py +185 -0
  40. sorawm/iopaint/model/anytext/cldm/hack.py +128 -0
  41. sorawm/iopaint/model/anytext/cldm/model.py +41 -0
  42. sorawm/iopaint/model/anytext/cldm/recognizer.py +302 -0
  43. sorawm/iopaint/model/anytext/ldm/__init__.py +0 -0
  44. sorawm/iopaint/model/anytext/ldm/models/__init__.py +0 -0
  45. sorawm/iopaint/model/anytext/ldm/models/autoencoder.py +275 -0
  46. sorawm/iopaint/model/anytext/ldm/models/diffusion/__init__.py +0 -0
  47. sorawm/iopaint/model/anytext/ldm/models/diffusion/ddim.py +525 -0
  48. sorawm/iopaint/model/anytext/ldm/models/diffusion/ddpm.py +2386 -0
  49. sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py +1 -0
  50. sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py +1464 -0
.gitignore ADDED
@@ -0,0 +1,229 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+ output
209
+
210
+ videos
211
+
212
+ datasets/images
213
+ datasets/labels
214
+ datasets/coco8
215
+ .DS_store
216
+ outputs
217
+ yolo11n.pt
218
+ yolo11s.pt
219
+ best.pt
220
+
221
+ .claude
222
+ best.pt
223
+
224
+ runs
225
+ .idea
226
+ working_dir
227
+ data
228
+ upload_to_huggingface.py
229
+ resources/best.pt
.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.12
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README-zh.md ADDED
@@ -0,0 +1,164 @@
1
+ # SoraWatermarkCleaner
2
+
3
+ [English](README.md) | 中文
4
+
5
+ 这个项目提供了一种优雅的方式来移除 Sora2 生成视频中的 Sora 水印。
6
+
7
+
8
+ - 移除水印后
9
+
10
+ https://github.com/user-attachments/assets/8cdc075e-7d15-4d04-8fa2-53dd287e5f4c
11
+
12
+ - 原始视频
13
+
14
+ https://github.com/user-attachments/assets/3c850ff1-b8e3-41af-a46f-2c734406e77d
15
+
16
+ ⭐️:
17
+
18
+ 1. **YOLO 权重已更新** — 请尝试新版本的水印检测模型,效果会更好!
19
+
20
+ 2. **数据集已开源** — 我们已经将标注好的数据集上传到了 Hugging Face,查看[此数据集](https://huggingface.co/datasets/LLinked/sora-watermark-dataset)。欢迎训练你自己的检测模型或改进我们的模型!
21
+
22
+ 3. **一键便携版已发布** — [点击这里下载](#3-一键便携版),Windows 用户无需安装即可使用!
23
+
24
+
25
+ ## 1. 方法
26
+
27
+ SoraWatermarkCleaner(后面我们简称为 `SoraWm`)由两部分组成:
28
+
29
+ - SoraWaterMarkDetector:我们训练了一个 yolov11s 版本来检测 Sora 水印。(感谢 YOLO!)
30
+
31
+ - WaterMarkCleaner:我们参考了 IOPaint 的实现,使用 LAMA 模型进行水印移除。
32
+
33
+ (此代码库来自 https://github.com/Sanster/IOPaint#,感谢他们的出色工作!)
34
+
35
+ 我们的 SoraWm 完全由深度学习驱动,在许多生成的视频中都能产生良好的效果。
36
+
37
+
38
+
39
+ ## 2. 安装
40
+ 视频处理需要 [FFmpeg](https://ffmpeg.org/),请先安装它。我们强烈推荐使用 `uv` 来安装环境:
41
+
42
+ 1. 安装:
43
+
44
+ ```bash
45
+ uv sync
46
+ ```
47
+
48
+ > 现在环境将被安装在 `.venv` 目录下,你可以使用以下命令激活环境:
49
+ >
50
+ > ```bash
51
+ > source .venv/bin/activate
52
+ > ```
53
+
54
+ 2. 下载预训练模型:
55
+
56
+ 训练好的 YOLO 权重将存储在 `resources` 目录中,文件名为 `best.pt`。它将从 https://github.com/linkedlist771/SoraWatermarkCleaner/releases/download/V0.0.1/best.pt 自动下载。`Lama` 模型从 https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt 下载,并将存储在 torch 缓存目录中。两者都是自动下载的,如果失败,请检查你的网络状态。
57
+
58
+ ## 3. 一键便携版
59
+
60
+ 对于不想手动安装的用户,我们提供了**一键便携版本**,包含所有预配置的依赖项,开箱即用。
61
+
62
+ ### 下载链接
63
+
64
+ **Google Drive(谷歌云盘):**
65
+ - [从 Google Drive 下载](https://drive.google.com/file/d/1ujH28aHaCXGgB146g6kyfz3Qxd-wHR1c/view?usp=share_link)
66
+
67
+ **百度网盘(推荐国内用户使用):**
68
+ - 链接:https://pan.baidu.com/s/1i4exYsPvXv0evnGs5MWcYA?pwd=3jr6
69
+ - 提取码:`3jr6`
70
+
71
+ ### 特点
72
+ - ✅ 无需安装
73
+ - ✅ 包含所有依赖
74
+ - ✅ 预配置环境
75
+ - ✅ 开箱即用
76
+
77
+ 只需下载、解压并运行!
78
+
79
+ ## 4. 演示
80
+
81
+ 基本用法,只需尝试 `example.py`:
82
+
83
+ ```python
84
+
85
+ from pathlib import Path
86
+ from sorawm.core import SoraWM
87
+
88
+
89
+ if __name__ == "__main__":
90
+ input_video_path = Path(
91
+ "resources/dog_vs_sam.mp4"
92
+ )
93
+ output_video_path = Path("outputs/sora_watermark_removed.mp4")
94
+ sora_wm = SoraWM()
95
+ sora_wm.run(input_video_path, output_video_path)
96
+
97
+ ```
98
+
99
+ 我们还提供了基于 `streamlit` 的交互式网页界面,使用以下命令尝试:
100
+
101
+ ```bash
102
+ streamlit run app.py
103
+ ```
104
+
105
+ <img src="resources/app.png" style="zoom: 25%;" />
106
+
107
+ ## 5. WebServer
108
+
109
+ 在这里,我们提供了一个基于 FastAPI 的 Web 服务器,可以快速将这个水印清除器转换为服务。
110
+
111
+ 只需运行:
112
+
113
+ ```python
114
+ python start_server.py
115
+ ```
116
+
117
+ Web 服务器将在端口 `5344` 启动,你可以查看 FastAPI [文档](http://localhost:5344/docs) 了解详情,有三个路由:
118
+
119
+ 1. submit_remove_task:
120
+
121
+ > 上传视频后,会返回一个任务 ID,该视频将立即被处理。
122
+
123
+ <img src="resources/53abf3fd-11a9-4dd7-a348-34920775f8ad.png" alt="image" style="zoom: 25%;" />
124
+
125
+ 2. get_results:
126
+
127
+ 你可以使用上面的任务 ID 检索任务状态,它会显示视频处理的百分比。一旦完成,返回的数据中会有下载 URL。
128
+
129
+ 3. downlaod:
130
+
131
+ 你可以使用第2步中的下载 URL 来获取清理后的视频。
132
+
133
+ ## 6. 数据集
134
+
135
+ 我们已经将标注好的数据集上传到了 Hugging Face,请查看 https://huggingface.co/datasets/LLinked/sora-watermark-dataset。欢迎训练你自己的检测模型或改进我们的模型!
136
+
137
+
138
+
139
+ ## 7. API
140
+
141
+ 打包为 Cog 并[发布到 Replicate](https://replicate.com/uglyrobot/sora2-watermark-remover),便于基于 API 的简单使用。
142
+
143
+ ## 8. 许可证
144
+
145
+ Apache License
146
+
147
+
148
+ ## 9. 引用
149
+
150
+ 如果你使用了这个项目,请引用:
151
+
152
+ ```bibtex
153
+ @misc{sorawatermarkcleaner2025,
154
+ author = {linkedlist771},
155
+ title = {SoraWatermarkCleaner},
156
+ year = {2025},
157
+ url = {https://github.com/linkedlist771/SoraWatermarkCleaner}
158
+ }
159
+ ```
160
+
161
+ ## 10. 致谢
162
+
163
+ - [IOPaint](https://github.com/Sanster/IOPaint) 提供的 LAMA 实现
164
+ - [Ultralytics YOLO](https://github.com/ultralytics/ultralytics) 提供的目标检测
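
The Method section above describes a two-stage pipeline: a YOLO detector locates the Sora watermark and a LAMA inpainter (via the vendored IOPaint code) removes it. The sketch below shows how those two stages compose for a single frame; the class paths and call signatures are taken from this commit's `sorawm/core.py`, while the image paths are placeholders.

```python
# Per-frame sketch of the detect -> inpaint pipeline from the Method section.
# Class paths and signatures follow this commit's sorawm/core.py; the
# input/output image paths are hypothetical.
import cv2
import numpy as np

from sorawm.watermark_cleaner import WaterMarkCleaner
from sorawm.watermark_detector import SoraWaterMarkDetector

detector = SoraWaterMarkDetector()
cleaner = WaterMarkCleaner()

frame = cv2.imread("some_frame.png")      # hypothetical BGR frame from a Sora video
result = detector.detect(frame)           # {"detected": bool, "bbox": (x1, y1, x2, y2)}

if result["detected"]:
    x1, y1, x2, y2 = result["bbox"]
    mask = np.zeros(frame.shape[:2], dtype=np.uint8)
    mask[y1:y2, x1:x2] = 255              # white rectangle over the watermark
    frame = cleaner.clean(frame, mask)    # LAMA fills in the masked region

cv2.imwrite("some_frame_cleaned.png", frame)
```
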
app.py ADDED
@@ -0,0 +1,106 @@
1
+ import shutil
2
+ import tempfile
3
+ from pathlib import Path
4
+
5
+ import streamlit as st
6
+
7
+ from sorawm.core import SoraWM
8
+
9
+
10
+ def main():
11
+ st.set_page_config(
12
+ page_title="Sora Watermark Cleaner", page_icon="🎬", layout="centered"
13
+ )
14
+
15
+ st.title("🎬 Sora Watermark Cleaner")
16
+ st.markdown("Remove watermarks from Sora-generated videos with ease")
17
+
18
+ # Initialize SoraWM
19
+ if "sora_wm" not in st.session_state:
20
+ with st.spinner("Loading AI models..."):
21
+ st.session_state.sora_wm = SoraWM()
22
+
23
+ st.markdown("---")
24
+
25
+ # File uploader
26
+ uploaded_file = st.file_uploader(
27
+ "Upload your video",
28
+ type=["mp4", "avi", "mov", "mkv"],
29
+ help="Select a video file to remove watermarks",
30
+ )
31
+
32
+ if uploaded_file is not None:
33
+ # Display video info
34
+ st.success(f"✅ Uploaded: {uploaded_file.name}")
35
+ st.video(uploaded_file)
36
+
37
+ # Process button
38
+ if st.button("🚀 Remove Watermark", type="primary", use_container_width=True):
39
+ with tempfile.TemporaryDirectory() as tmp_dir:
40
+ tmp_path = Path(tmp_dir)
41
+
42
+ # Save uploaded file
43
+ input_path = tmp_path / uploaded_file.name
44
+ with open(input_path, "wb") as f:
45
+ f.write(uploaded_file.read())
46
+
47
+ # Process video
48
+ output_path = tmp_path / f"cleaned_{uploaded_file.name}"
49
+
50
+ try:
51
+ # Create progress bar and status text
52
+ progress_bar = st.progress(0)
53
+ status_text = st.empty()
54
+
55
+ def update_progress(progress: int):
56
+ progress_bar.progress(progress / 100)
57
+ if progress < 50:
58
+ status_text.text(f"🔍 Detecting watermarks... {progress}%")
59
+ elif progress < 95:
60
+ status_text.text(f"🧹 Removing watermarks... {progress}%")
61
+ else:
62
+ status_text.text(f"🎵 Merging audio... {progress}%")
63
+
64
+ # Run the watermark removal with progress callback
65
+ st.session_state.sora_wm.run(
66
+ input_path, output_path, progress_callback=update_progress
67
+ )
68
+
69
+ # Complete the progress bar
70
+ progress_bar.progress(100)
71
+ status_text.text("✅ Processing complete!")
72
+
73
+ st.success("✅ Watermark removed successfully!")
74
+
75
+ # Display result
76
+ st.markdown("### Result")
77
+ st.video(str(output_path))
78
+
79
+ # Download button
80
+ with open(output_path, "rb") as f:
81
+ st.download_button(
82
+ label="⬇️ Download Cleaned Video",
83
+ data=f,
84
+ file_name=f"cleaned_{uploaded_file.name}",
85
+ mime="video/mp4",
86
+ use_container_width=True,
87
+ )
88
+
89
+ except Exception as e:
90
+ st.error(f"❌ Error processing video: {str(e)}")
91
+
92
+ # Footer
93
+ st.markdown("---")
94
+ st.markdown(
95
+ """
96
+ <div style='text-align: center'>
97
+ <p>Built with ❤️ using Streamlit and AI</p>
98
+ <p><a href='https://github.com/linkedlist771/SoraWatermarkCleaner'>GitHub Repository</a></p>
99
+ </div>
100
+ """,
101
+ unsafe_allow_html=True,
102
+ )
103
+
104
+
105
+ if __name__ == "__main__":
106
+ main()
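
`app.py` drives the Streamlit progress bar through the `progress_callback` argument of `SoraWM.run`, which `sorawm/core.py` (later in this diff) types as `Callable[[int], None]` and calls with roughly 10–50% during detection, 50–95% during inpainting, and 95%+ while merging audio. A minimal non-UI sketch of the same hook, using the sample video paths from the README:

```python
# Minimal sketch of the progress hook outside Streamlit: SoraWM.run accepts a
# progress_callback (see sorawm/core.py) that receives an integer percentage.
from pathlib import Path

from sorawm.core import SoraWM


def print_progress(percent: int) -> None:
    # 10-50% ~ detection, 50-95% ~ inpainting, 95%+ ~ audio merge (per core.py)
    print(f"\rprogress: {percent:3d}%", end="", flush=True)


if __name__ == "__main__":
    SoraWM().run(
        Path("resources/dog_vs_sam.mp4"),           # sample video from the README
        Path("outputs/sora_watermark_removed.mp4"),
        progress_callback=print_progress,
    )
```
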
datasets/make_yolo_images.py ADDED
@@ -0,0 +1,49 @@
1
+ from pathlib import Path
2
+
3
+ import cv2
4
+ from tqdm import tqdm
5
+
6
+ from sorawm.configs import ROOT
7
+
8
+ videos_dir = ROOT / "videos"
9
+ datasets_dir = ROOT / "datasets"
10
+ images_dir = datasets_dir / "images"
11
+ images_dir.mkdir(exist_ok=True, parents=True)
12
+
13
+ if __name__ == "__main__":
14
+ fps_save_interval = 1 # Save every 1th frame
15
+
16
+ idx = 0
17
+ for video_path in tqdm(list(videos_dir.rglob("*.mp4"))):
18
+ # Open the video file
19
+ cap = cv2.VideoCapture(str(video_path))
20
+
21
+ if not cap.isOpened():
22
+ print(f"Error opening video: {video_path}")
23
+ continue
24
+
25
+ frame_count = 0
26
+
27
+ while True:
28
+ ret, frame = cap.read()
29
+
30
+ # Break if no more frames
31
+ if not ret:
32
+ break
33
+
34
+ # Save frame at the specified interval
35
+ if frame_count % fps_save_interval == 0:
36
+ # Create filename: image_idx_framecount.jpg
37
+ image_filename = f"image_{idx:06d}_frame_{frame_count:06d}.jpg"
38
+ image_path = images_dir / image_filename
39
+
40
+ # Save the frame
41
+ cv2.imwrite(str(image_path), frame)
42
+
43
+ frame_count += 1
44
+
45
+ # Release the video capture object
46
+ cap.release()
47
+ idx += 1
48
+
49
+ print(f"Processed {idx} videos, extracted frames saved to {images_dir}")
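
`make_yolo_images.py` only extracts frames; turning them into a detector still requires labeling (for example, the annotated set published on Hugging Face) and a fine-tuning run. Below is a minimal Ultralytics training sketch, since `ultralytics` is already a declared dependency; the `data.yaml` path and the hyperparameters are illustrative assumptions, not values taken from this repository.

```python
# Minimal fine-tuning sketch with Ultralytics (a declared dependency).
# "datasets/data.yaml", epochs and imgsz are assumed values for illustration;
# the repo ships the resulting weights as resources/best.pt.
from ultralytics import YOLO

model = YOLO("yolo11s.pt")  # same model family the README mentions

model.train(
    data="datasets/data.yaml",  # hypothetical YOLO dataset config (images/labels + 1 watermark class)
    epochs=100,
    imgsz=640,
)
# The best checkpoint lands under runs/detect/train*/weights/best.pt; copy it
# to resources/best.pt so the detector picks it up (see sorawm/configs.py).
```
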
example.py ADDED
@@ -0,0 +1,9 @@
1
+ from pathlib import Path
2
+
3
+ from sorawm.core import SoraWM
4
+
5
+ if __name__ == "__main__":
6
+ input_video_path = Path("resources/dog_vs_sam.mp4")
7
+ output_video_path = Path("outputs/sora_watermark_removed.mp4")
8
+ sora_wm = SoraWM()
9
+ sora_wm.run(input_video_path, output_video_path)
ffmpeg/README.md ADDED
@@ -0,0 +1,69 @@
1
+ # FFmpeg 可执行文件目录
2
+
3
+ ## 用途
4
+
5
+ 这个目录用于存放 FFmpeg 可执行文件,使项目成为真正的便携版(无需系统安装 FFmpeg)。
6
+
7
+ ## Windows 用户配置步骤
8
+
9
+ ### 1. 下载 FFmpeg
10
+
11
+ 访问 [FFmpeg-Builds Release](https://github.com/BtbN/FFmpeg-Builds/releases) 页面:
12
+
13
+ - 下载最新的 `ffmpeg-master-latest-win64-gpl.zip`(约 120MB)
14
+ - 或者下载特定版本,如 `ffmpeg-n6.1-latest-win64-gpl-6.1.zip`
15
+
16
+ ### 2. 解压并复制文件
17
+
18
+ 1. 解压下载的 zip 文件
19
+ 2. 在解压后的文件夹中找到 `bin` 目录
20
+ 3. 将以下两个文件复制到**当前目录**(`ffmpeg/`):
21
+ - `ffmpeg.exe` - FFmpeg 主程序
22
+ - `ffprobe.exe` - FFmpeg 媒体信息探测工具
23
+
24
+ ### 3. 验证配置
25
+
26
+ 完成后,此目录应包含:
27
+
28
+ ```
29
+ ffmpeg/
30
+ ├── .gitkeep
31
+ ├── README.md
32
+ ├── ffmpeg.exe ← 你复制的文件
33
+ └── ffprobe.exe ← 你复制的文件
34
+ ```
35
+
36
+ ### 4. 测试
37
+
38
+ 运行项目中的测试脚本验证配置:
39
+
40
+ ```bash
41
+ python test_ffmpeg_setup.py
42
+ ```
43
+
44
+ 如果配置正确,你将看到:`✓ 测试通过!FFmpeg已正确配置并可以使用`
45
+
46
+ ## macOS/Linux 用户
47
+
48
+ 如果需要便携版,请:
49
+
50
+ 1. 下载对应平台的 FFmpeg 二进制文件
51
+ 2. 将 `ffmpeg` 和 `ffprobe` 可执行文件放到此目录
52
+ 3. 确保文件有执行权限:`chmod +x ffmpeg ffprobe`
53
+
54
+ ## 注意事项
55
+
56
+ - 这些可执行文件不会被 git 提交(已在 `.gitignore` 中配置)
57
+ - 程序会自动检测并使用此目录下的 FFmpeg
58
+ - 如果此目录没有 FFmpeg,程序会尝试使用系统安装的版本
59
+
60
+ ## 下载链接汇总
61
+
62
+ - **Windows**: https://github.com/BtbN/FFmpeg-Builds/releases
63
+ - **官方网站**: https://ffmpeg.org/download.html
64
+ - **镜像站点**: https://www.gyan.dev/ffmpeg/builds/ (Windows)
65
+
66
+ ## 许可证
67
+
68
+ FFmpeg 使用 GPL 许可证,请遵守相关条款。
69
+
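
The README above says the program prefers FFmpeg binaries placed in the `ffmpeg/` directory and falls back to a system-wide installation. The helper below illustrates one way such a lookup could work; it is a hedged sketch (assuming the script sits at the repository root), not the project's actual resolution code.

```python
# Illustrative FFmpeg resolution: prefer the bundled ffmpeg/ directory, fall
# back to whatever is on PATH. This mirrors the behaviour the README describes
# but is not necessarily how the project implements it.
import platform
import shutil
from pathlib import Path

BUNDLED_DIR = Path(__file__).resolve().parent / "ffmpeg"  # assumed repo-root layout


def resolve_ffmpeg(tool: str = "ffmpeg") -> str:
    exe = f"{tool}.exe" if platform.system() == "Windows" else tool
    bundled = BUNDLED_DIR / exe
    if bundled.exists():
        return str(bundled)        # portable mode: use the copied binary
    system = shutil.which(tool)
    if system:
        return system              # fall back to the system installation
    raise FileNotFoundError(f"{tool} not found; see ffmpeg/README.md")


if __name__ == "__main__":
    print(resolve_ffmpeg("ffmpeg"), resolve_ffmpeg("ffprobe"))
```
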
notebooks/imputation.ipynb ADDED
The diff for this file is too large to render.
 
one-click-portable.md ADDED
@@ -0,0 +1,26 @@
1
+ # One-Click Portable Version | 一键便携版
2
+
3
+ For **Windows** users - No installation required!
4
+
5
+ 适用于 **Windows** 用户 - 无需安装!
6
+
7
+ ## Download | 下载
8
+
9
+ **Google Drive:**
10
+ - https://drive.google.com/file/d/1ujH28aHaCXGgB146g6kyfz3Qxd-wHR1c/view?usp=share_link
11
+
12
+ **Baidu Pan | 百度网盘:**
13
+ - Link | 链接: https://pan.baidu.com/s/1_tdgs-3-dLNn0IbufIM75g?pwd=fiju
14
+ - Extract Code | 提取码: `fiju`
15
+
16
+ ## Usage | 使用方法
17
+
18
+ 1. Download and extract the zip file | 下载并解压 zip 文件
19
+ 2. Double-click `run.bat` | 双击 `run.bat` 文件
20
+ 3. The web service will start automatically! | 网页服务将自动启动!
21
+
22
+ ## Features | 特点
23
+
24
+ - ✅ Zero installation | 无需安装
25
+ - ✅ All dependencies included | 包含所有依赖
26
+ - ✅ Ready to use | 开箱即用
pyproject.toml ADDED
@@ -0,0 +1,39 @@
1
+ [project]
2
+ name = "sorawatermarkcleaner"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "aiofiles>=24.1.0",
9
+ "aiosqlite>=0.21.0",
10
+ "diffusers>=0.35.1",
11
+ "einops>=0.8.1",
12
+ "fastapi==0.108.0",
13
+ "ffmpeg-python>=0.2.0",
14
+ "fire>=0.7.1",
15
+ "httpx>=0.28.1",
16
+ "huggingface-hub>=0.35.3",
17
+ "jupyter>=1.1.1",
18
+ "loguru>=0.7.3",
19
+ "matplotlib>=3.10.6",
20
+ "notebook>=7.4.7",
21
+ "omegaconf>=2.3.0",
22
+ "opencv-python>=4.12.0.88",
23
+ "pandas>=2.3.3",
24
+ "pydantic>=2.11.10",
25
+ "python-multipart>=0.0.20",
26
+ "requests>=2.32.5",
27
+ "ruptures>=1.1.10",
28
+ "scikit-learn>=1.7.2",
29
+ "sqlalchemy>=2.0.43",
30
+ "streamlit>=1.50.0",
31
+ "torch>=2.5.0",
32
+ "torchvision>=0.20.0",
33
+ "tqdm>=4.67.1",
34
+ "transformers>=4.57.0",
35
+ "ultralytics>=8.3.204",
36
+ "uuid>=1.30",
37
+ "uvicorn>=0.35.0",
38
+ ]
39
+
resources/first_frame.json ADDED
The diff for this file is too large to render.
 
resources/watermark_template.png ADDED
sorawm/__init__.py ADDED
File without changes
sorawm/configs.py ADDED
@@ -0,0 +1,27 @@
1
+ from pathlib import Path
2
+
3
+ ROOT = Path(__file__).parent.parent
4
+
5
+
6
+ RESOURCES_DIR = ROOT / "resources"
7
+ WATER_MARK_TEMPLATE_IMAGE_PATH = RESOURCES_DIR / "watermark_template.png"
8
+
9
+ WATER_MARK_DETECT_YOLO_WEIGHTS = RESOURCES_DIR / "best.pt"
10
+
11
+ OUTPUT_DIR = ROOT / "output"
12
+
13
+ OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
14
+
15
+
16
+ DEFAULT_WATERMARK_REMOVE_MODEL = "lama"
17
+
18
+ WORKING_DIR = ROOT / "working_dir"
19
+ WORKING_DIR.mkdir(exist_ok=True, parents=True)
20
+
21
+ LOGS_PATH = ROOT / "logs"
22
+ LOGS_PATH.mkdir(exist_ok=True, parents=True)
23
+
24
+ DATA_PATH = ROOT / "data"
25
+ DATA_PATH.mkdir(exist_ok=True, parents=True)
26
+
27
+ SQLITE_PATH = DATA_PATH / "db.sqlite3"
sorawm/core.py ADDED
@@ -0,0 +1,197 @@
1
+ from pathlib import Path
2
+ from typing import Callable
3
+
4
+ import ffmpeg
5
+ import numpy as np
6
+ from loguru import logger
7
+ from tqdm import tqdm
8
+
9
+ from sorawm.utils.video_utils import VideoLoader
10
+ from sorawm.watermark_cleaner import WaterMarkCleaner
11
+ from sorawm.watermark_detector import SoraWaterMarkDetector
12
+ from sorawm.utils.imputation_utils import (
13
+ find_2d_data_bkps,
14
+ get_interval_average_bbox,
15
+ find_idxs_interval,
16
+ )
17
+
18
+
19
+ class SoraWM:
20
+ def __init__(self):
21
+ self.detector = SoraWaterMarkDetector()
22
+ self.cleaner = WaterMarkCleaner()
23
+
24
+ def run(
25
+ self,
26
+ input_video_path: Path,
27
+ output_video_path: Path,
28
+ progress_callback: Callable[[int], None] | None = None,
29
+ ):
30
+ input_video_loader = VideoLoader(input_video_path)
31
+ output_video_path.parent.mkdir(parents=True, exist_ok=True)
32
+ width = input_video_loader.width
33
+ height = input_video_loader.height
34
+ fps = input_video_loader.fps
35
+ total_frames = input_video_loader.total_frames
36
+
37
+ temp_output_path = output_video_path.parent / f"temp_{output_video_path.name}"
38
+ output_options = {
39
+ "pix_fmt": "yuv420p",
40
+ "vcodec": "libx264",
41
+ "preset": "slow",
42
+ }
43
+
44
+ if input_video_loader.original_bitrate:
45
+ output_options["video_bitrate"] = str(
46
+ int(int(input_video_loader.original_bitrate) * 1.2)
47
+ )
48
+ else:
49
+ output_options["crf"] = "18"
50
+
51
+ process_out = (
52
+ ffmpeg.input(
53
+ "pipe:",
54
+ format="rawvideo",
55
+ pix_fmt="bgr24",
56
+ s=f"{width}x{height}",
57
+ r=fps,
58
+ )
59
+ .output(str(temp_output_path), **output_options)
60
+ .overwrite_output()
61
+ .global_args("-loglevel", "error")
62
+ .run_async(pipe_stdin=True)
63
+ )
64
+
65
+ frame_and_mask = {}
66
+ detect_missed = []
67
+ bbox_centers = []
68
+ bboxes = []
69
+
70
+ logger.debug(
71
+ f"total frames: {total_frames}, fps: {fps}, width: {width}, height: {height}"
72
+ )
73
+ for idx, frame in enumerate(
74
+ tqdm(input_video_loader, total=total_frames, desc="Detect watermarks")
75
+ ):
76
+ detection_result = self.detector.detect(frame)
77
+ if detection_result["detected"]:
78
+ frame_and_mask[idx] = {"frame": frame, "bbox": detection_result["bbox"]}
79
+ x1, y1, x2, y2 = detection_result["bbox"]
80
+ bbox_centers.append((int((x1 + x2) / 2), int((y1 + y2) / 2)))
81
+ bboxes.append((x1, y1, x2, y2))
82
+
83
+ else:
84
+ frame_and_mask[idx] = {"frame": frame, "bbox": None}
85
+ detect_missed.append(idx)
86
+ bbox_centers.append(None)
87
+ bboxes.append(None)
88
+ # 10% - 50%
89
+ if progress_callback and idx % 10 == 0:
90
+ progress = 10 + int((idx / total_frames) * 40)
91
+ progress_callback(progress)
92
+
93
+ logger.debug(f"detect missed frames: {detect_missed}")
94
+ # logger.debug(f"bbox centers: \n{bbox_centers}")
95
+ if detect_missed:
96
+ # 1. find the bkps of the bbox centers
97
+ bkps = find_2d_data_bkps(bbox_centers)
98
+ # add the start and end position, to form the complete interval boundaries
99
+ bkps_full = [0] + bkps + [total_frames]
100
+ # logger.debug(f"bkps intervals: {bkps_full}")
101
+
102
+ # 2. calculate the average bbox of each interval
103
+ interval_bboxes = get_interval_average_bbox(bboxes, bkps_full)
104
+ # logger.debug(f"interval average bboxes: {interval_bboxes}")
105
+
106
+ # 3. find the interval index of each missed frame
107
+ missed_intervals = find_idxs_interval(detect_missed, bkps_full)
108
+ # logger.debug(
109
+ # f"missed frame intervals: {list(zip(detect_missed, missed_intervals))}"
110
+ # )
111
+
112
+ # 4. fill the missed frames with the average bbox of the corresponding interval
113
+ for missed_idx, interval_idx in zip(detect_missed, missed_intervals):
114
+ if (
115
+ interval_idx < len(interval_bboxes)
116
+ and interval_bboxes[interval_idx] is not None
117
+ ):
118
+ frame_and_mask[missed_idx]["bbox"] = interval_bboxes[interval_idx]
119
+ logger.debug(f"Filled missed frame {missed_idx} with bbox:\n"
120
+ f" {interval_bboxes[interval_idx]}")
121
+ else:
122
+ # if the interval has no valid bbox, use the previous and next frame to complete (fallback strategy)
123
+ before = max(missed_idx - 1, 0)
124
+ after = min(missed_idx + 1, total_frames - 1)
125
+ before_box = frame_and_mask[before]["bbox"]
126
+ after_box = frame_and_mask[after]["bbox"]
127
+ if before_box:
128
+ frame_and_mask[missed_idx]["bbox"] = before_box
129
+ elif after_box:
130
+ frame_and_mask[missed_idx]["bbox"] = after_box
131
+ else:
132
+ del bboxes
133
+ del bbox_centers
134
+ del detect_missed
135
+
136
+ for idx in tqdm(range(total_frames), desc="Remove watermarks"):
137
+ frame_info = frame_and_mask[idx]
138
+ frame = frame_info["frame"]
139
+ bbox = frame_info["bbox"]
140
+ if bbox is not None:
141
+ x1, y1, x2, y2 = bbox
142
+ mask = np.zeros((height, width), dtype=np.uint8)
143
+ mask[y1:y2, x1:x2] = 255
144
+ cleaned_frame = self.cleaner.clean(frame, mask)
145
+ else:
146
+ cleaned_frame = frame
147
+ process_out.stdin.write(cleaned_frame.tobytes())
148
+
149
+ # 50% - 95%
150
+ if progress_callback and idx % 10 == 0:
151
+ progress = 50 + int((idx / total_frames) * 45)
152
+ progress_callback(progress)
153
+
154
+ process_out.stdin.close()
155
+ process_out.wait()
156
+
157
+ # 95% - 99%
158
+ if progress_callback:
159
+ progress_callback(95)
160
+
161
+ self.merge_audio_track(input_video_path, temp_output_path, output_video_path)
162
+
163
+ if progress_callback:
164
+ progress_callback(99)
165
+
166
+ def merge_audio_track(
167
+ self, input_video_path: Path, temp_output_path: Path, output_video_path: Path
168
+ ):
169
+ logger.info("Merging audio track...")
170
+ video_stream = ffmpeg.input(str(temp_output_path))
171
+ audio_stream = ffmpeg.input(str(input_video_path)).audio
172
+
173
+ (
174
+ ffmpeg.output(
175
+ video_stream,
176
+ audio_stream,
177
+ str(output_video_path),
178
+ vcodec="copy",
179
+ acodec="aac",
180
+ )
181
+ .overwrite_output()
182
+ .run(quiet=True)
183
+ )
184
+ # Clean up temporary file
185
+ temp_output_path.unlink()
186
+ logger.info(f"Saved no watermark video with audio at: {output_video_path}")
187
+
188
+
189
+ if __name__ == "__main__":
190
+ from pathlib import Path
191
+
192
+ input_video_path = Path(
193
+ "resources/19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4"
194
+ )
195
+ output_video_path = Path("outputs/sora_watermark_removed.mp4")
196
+ sora_wm = SoraWM()
197
+ sora_wm.run(input_video_path, output_video_path)
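
When the detector misses some frames, `SoraWM.run` imputes their boxes by segmenting the bbox-center track into stable intervals (`find_2d_data_bkps`) and reusing each interval's average box. Those helpers live in `sorawm.utils.imputation_utils`, which falls outside this 50-file view; the sketch below shows how the breakpoint and averaging steps could be done with `ruptures` (a declared dependency) and is an illustration of the idea rather than the repository's implementation.

```python
# Sketch of the imputation idea used in SoraWM.run: segment the watermark
# centre track into stable intervals, then fill missed detections with each
# interval's average bbox. The real helpers live in
# sorawm.utils.imputation_utils (not shown in this view); this version uses
# ruptures directly and is illustrative only.
import numpy as np
import ruptures as rpt


def find_center_breakpoints(centers: list[tuple[int, int] | None], penalty: float = 10.0) -> list[int]:
    # Interpolate over missed frames so the change-point search sees a dense signal.
    xs = np.array([np.nan if c is None else c[0] for c in centers], dtype=float)
    ys = np.array([np.nan if c is None else c[1] for c in centers], dtype=float)
    idx = np.arange(len(centers))
    for arr in (xs, ys):
        valid = ~np.isnan(arr)
        arr[~valid] = np.interp(idx[~valid], idx[valid], arr[valid])
    signal = np.column_stack([xs, ys])
    # PELT with an RBF cost flags indices where the watermark jumps to a new position.
    return rpt.Pelt(model="rbf").fit(signal).predict(pen=penalty)


def interval_average_bboxes(bboxes, breakpoints):
    # breakpoints are interval end indices (ruptures convention: last == len(bboxes)).
    out, start = [], 0
    for end in breakpoints:
        valid = [b for b in bboxes[start:end] if b is not None]
        out.append(tuple(int(v) for v in np.mean(valid, axis=0)) if valid else None)
        start = end
    return out
```
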
sorawm/iopaint/__init__.py ADDED
@@ -0,0 +1,56 @@
1
+ import ctypes
2
+ import importlib.util
3
+ import logging
4
+ import os
5
+ import shutil
6
+
7
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
8
+ # https://github.com/pytorch/pytorch/issues/27971#issuecomment-1768868068
9
+ os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1"
10
+ os.environ["LRU_CACHE_CAPACITY"] = "1"
11
+ # prevent CPU memory leak when run model on GPU
12
+ # https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869288431
13
+ # https://github.com/pytorch/pytorch/issues/108334#issuecomment-1752763633
14
+ os.environ["TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT"] = "1"
15
+
16
+ import warnings
17
+
18
+ warnings.simplefilter("ignore", UserWarning)
19
+
20
+
21
+ def fix_window_pytorch():
22
+ # copy from: https://github.com/comfyanonymous/ComfyUI/blob/5cbaa9e07c97296b536f240688f5a19300ecf30d/fix_torch.py#L4
23
+ import platform
24
+
25
+ try:
26
+ if platform.system() != "Windows":
27
+ return
28
+ torch_spec = importlib.util.find_spec("torch")
29
+ for folder in torch_spec.submodule_search_locations:
30
+ lib_folder = os.path.join(folder, "lib")
31
+ test_file = os.path.join(lib_folder, "fbgemm.dll")
32
+ dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
33
+ if os.path.exists(dest):
34
+ break
35
+
36
+ with open(test_file, "rb") as f:
37
+ contents = f.read()
38
+ if b"libomp140.x86_64.dll" not in contents:
39
+ break
40
+ try:
41
+ mydll = ctypes.cdll.LoadLibrary(test_file)
42
+ except FileNotFoundError:
43
+ logging.warning("Detected pytorch version with libomp issue, patching.")
44
+ shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
45
+ except:
46
+ pass
47
+
48
+
49
+ def entry_point():
50
+ # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
51
+ # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
52
+ from sorawm.iopaint.cli import typer_app
53
+
54
+ fix_window_pytorch()
55
+
56
+ typer_app()
sorawm/iopaint/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from sorawm.iopaint import entry_point
2
+
3
+ if __name__ == "__main__":
4
+ entry_point()
sorawm/iopaint/api.py ADDED
@@ -0,0 +1,411 @@
1
+ import asyncio
2
+ import os
3
+ import threading
4
+ import time
5
+ import traceback
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import socketio
12
+ import torch
13
+
14
+ try:
15
+ torch._C._jit_override_can_fuse_on_cpu(False)
16
+ torch._C._jit_override_can_fuse_on_gpu(False)
17
+ torch._C._jit_set_texpr_fuser_enabled(False)
18
+ torch._C._jit_set_nvfuser_enabled(False)
19
+ torch._C._jit_set_profiling_mode(False)
20
+ except:
21
+ pass
22
+
23
+ import uvicorn
24
+ from fastapi import APIRouter, FastAPI, Request, UploadFile
25
+ from fastapi.encoders import jsonable_encoder
26
+ from fastapi.exceptions import HTTPException
27
+ from fastapi.middleware.cors import CORSMiddleware
28
+ from fastapi.responses import FileResponse, JSONResponse, Response
29
+ from fastapi.staticfiles import StaticFiles
30
+ from loguru import logger
31
+ from PIL import Image
32
+ from socketio import AsyncServer
33
+
34
+ from sorawm.iopaint.file_manager import FileManager
35
+ from sorawm.iopaint.helper import (
36
+ adjust_mask,
37
+ concat_alpha_channel,
38
+ decode_base64_to_image,
39
+ gen_frontend_mask,
40
+ load_img,
41
+ numpy_to_bytes,
42
+ pil_to_bytes,
43
+ )
44
+ from sorawm.iopaint.model.utils import torch_gc
45
+ from sorawm.iopaint.model_manager import ModelManager
46
+ from sorawm.iopaint.plugins import InteractiveSeg, RealESRGANUpscaler, build_plugins
47
+ from sorawm.iopaint.plugins.base_plugin import BasePlugin
48
+ from sorawm.iopaint.plugins.remove_bg import RemoveBG
49
+ from sorawm.iopaint.schema import (
50
+ AdjustMaskRequest,
51
+ ApiConfig,
52
+ GenInfoResponse,
53
+ InpaintRequest,
54
+ InteractiveSegModel,
55
+ ModelInfo,
56
+ PluginInfo,
57
+ RealESRGANModel,
58
+ RemoveBGModel,
59
+ RunPluginRequest,
60
+ SDSampler,
61
+ ServerConfigResponse,
62
+ SwitchModelRequest,
63
+ SwitchPluginModelRequest,
64
+ )
65
+
66
+ CURRENT_DIR = Path(__file__).parent.absolute().resolve()
67
+ WEB_APP_DIR = CURRENT_DIR / "web_app"
68
+
69
+
70
+ def api_middleware(app: FastAPI):
71
+ rich_available = False
72
+ try:
73
+ if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
74
+ import anyio # importing just so it can be placed on silent list
75
+ import starlette # importing just so it can be placed on silent list
76
+ from rich.console import Console
77
+
78
+ console = Console()
79
+ rich_available = True
80
+ except Exception:
81
+ pass
82
+
83
+ def handle_exception(request: Request, e: Exception):
84
+ err = {
85
+ "error": type(e).__name__,
86
+ "detail": vars(e).get("detail", ""),
87
+ "body": vars(e).get("body", ""),
88
+ "errors": str(e),
89
+ }
90
+ if not isinstance(
91
+ e, HTTPException
92
+ ): # do not print backtrace on known httpexceptions
93
+ message = f"API error: {request.method}: {request.url} {err}"
94
+ if rich_available:
95
+ print(message)
96
+ console.print_exception(
97
+ show_locals=True,
98
+ max_frames=2,
99
+ extra_lines=1,
100
+ suppress=[anyio, starlette],
101
+ word_wrap=False,
102
+ width=min([console.width, 200]),
103
+ )
104
+ else:
105
+ traceback.print_exc()
106
+ return JSONResponse(
107
+ status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
108
+ )
109
+
110
+ @app.middleware("http")
111
+ async def exception_handling(request: Request, call_next):
112
+ try:
113
+ return await call_next(request)
114
+ except Exception as e:
115
+ return handle_exception(request, e)
116
+
117
+ @app.exception_handler(Exception)
118
+ async def fastapi_exception_handler(request: Request, e: Exception):
119
+ return handle_exception(request, e)
120
+
121
+ @app.exception_handler(HTTPException)
122
+ async def http_exception_handler(request: Request, e: HTTPException):
123
+ return handle_exception(request, e)
124
+
125
+ cors_options = {
126
+ "allow_methods": ["*"],
127
+ "allow_headers": ["*"],
128
+ "allow_origins": ["*"],
129
+ "allow_credentials": True,
130
+ "expose_headers": ["X-Seed"],
131
+ }
132
+ app.add_middleware(CORSMiddleware, **cors_options)
133
+
134
+
135
+ global_sio: AsyncServer = None
136
+
137
+
138
+ def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict = {}):
139
+ # self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
140
+ # logger.info(f"diffusion callback: step={step}, timestep={timestep}")
141
+
142
+ # We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
143
+ # but for now let's just start a separate event loop. It shouldn't make a difference for single person use
144
+ asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
145
+ return {}
146
+
147
+
148
+ class Api:
149
+ def __init__(self, app: FastAPI, config: ApiConfig):
150
+ self.app = app
151
+ self.config = config
152
+ self.router = APIRouter()
153
+ self.queue_lock = threading.Lock()
154
+ api_middleware(self.app)
155
+
156
+ self.file_manager = self._build_file_manager()
157
+ self.plugins = self._build_plugins()
158
+ self.model_manager = self._build_model_manager()
159
+
160
+ # fmt: off
161
+ self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
162
+ self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
163
+ response_model=ServerConfigResponse)
164
+ self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
165
+ self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
166
+ self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
167
+ self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
168
+ self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
169
+ self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
170
+ self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
171
+ self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
172
+ self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
173
+ self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
174
+ self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
175
+ # fmt: on
176
+
177
+ global global_sio
178
+ self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
179
+ self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
180
+ self.app.mount("/ws", self.combined_asgi_app)
181
+ global_sio = self.sio
182
+
183
+ def add_api_route(self, path: str, endpoint, **kwargs):
184
+ return self.app.add_api_route(path, endpoint, **kwargs)
185
+
186
+ def api_save_image(self, file: UploadFile):
187
+ # Sanitize filename to prevent path traversal
188
+ safe_filename = Path(file.filename).name # Get just the filename component
189
+
190
+ # Construct the full path within output_dir
191
+ output_path = self.config.output_dir / safe_filename
192
+
193
+ # Ensure output directory exists
194
+ if not self.config.output_dir or not self.config.output_dir.exists():
195
+ raise HTTPException(
196
+ status_code=400,
197
+ detail="Output directory not configured or doesn't exist",
198
+ )
199
+
200
+ # Read and write the file
201
+ origin_image_bytes = file.file.read()
202
+ with open(output_path, "wb") as fw:
203
+ fw.write(origin_image_bytes)
204
+
205
+ def api_current_model(self) -> ModelInfo:
206
+ return self.model_manager.current_model
207
+
208
+ def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
209
+ if req.name == self.model_manager.name:
210
+ return self.model_manager.current_model
211
+ self.model_manager.switch(req.name)
212
+ return self.model_manager.current_model
213
+
214
+ def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
215
+ if req.plugin_name in self.plugins:
216
+ self.plugins[req.plugin_name].switch_model(req.model_name)
217
+ if req.plugin_name == RemoveBG.name:
218
+ self.config.remove_bg_model = req.model_name
219
+ if req.plugin_name == RealESRGANUpscaler.name:
220
+ self.config.realesrgan_model = req.model_name
221
+ if req.plugin_name == InteractiveSeg.name:
222
+ self.config.interactive_seg_model = req.model_name
223
+ torch_gc()
224
+
225
+ def api_server_config(self) -> ServerConfigResponse:
226
+ plugins = []
227
+ for it in self.plugins.values():
228
+ plugins.append(
229
+ PluginInfo(
230
+ name=it.name,
231
+ support_gen_image=it.support_gen_image,
232
+ support_gen_mask=it.support_gen_mask,
233
+ )
234
+ )
235
+
236
+ return ServerConfigResponse(
237
+ plugins=plugins,
238
+ modelInfos=self.model_manager.scan_models(),
239
+ removeBGModel=self.config.remove_bg_model,
240
+ removeBGModels=RemoveBGModel.values(),
241
+ realesrganModel=self.config.realesrgan_model,
242
+ realesrganModels=RealESRGANModel.values(),
243
+ interactiveSegModel=self.config.interactive_seg_model,
244
+ interactiveSegModels=InteractiveSegModel.values(),
245
+ enableFileManager=self.file_manager is not None,
246
+ enableAutoSaving=self.config.output_dir is not None,
247
+ enableControlnet=self.model_manager.enable_controlnet,
248
+ controlnetMethod=self.model_manager.controlnet_method,
249
+ disableModelSwitch=False,
250
+ isDesktop=False,
251
+ samplers=self.api_samplers(),
252
+ )
253
+
254
+ def api_input_image(self) -> FileResponse:
255
+ if self.config.input is None:
256
+ raise HTTPException(status_code=200, detail="No input image configured")
257
+
258
+ if self.config.input.is_file():
259
+ return FileResponse(self.config.input)
260
+ raise HTTPException(status_code=404, detail="Input image not found")
261
+
262
+ def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
263
+ _, _, info = load_img(file.file.read(), return_info=True)
264
+ parts = info.get("parameters", "").split("Negative prompt: ")
265
+ prompt = parts[0].strip()
266
+ negative_prompt = ""
267
+ if len(parts) > 1:
268
+ negative_prompt = parts[1].split("\n")[0].strip()
269
+ return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
270
+
271
+ def api_inpaint(self, req: InpaintRequest):
272
+ image, alpha_channel, infos, ext = decode_base64_to_image(req.image)
273
+ mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
274
+ logger.info(f"image ext: {ext}")
275
+
276
+ mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
277
+ if image.shape[:2] != mask.shape[:2]:
278
+ raise HTTPException(
279
+ 400,
280
+ detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
281
+ )
282
+
283
+ start = time.time()
284
+ rgb_np_img = self.model_manager(image, mask, req)
285
+ logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
286
+ torch_gc()
287
+
288
+ rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
289
+ rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
290
+
291
+ res_img_bytes = pil_to_bytes(
292
+ Image.fromarray(rgb_res),
293
+ ext=ext,
294
+ quality=self.config.quality,
295
+ infos=infos,
296
+ )
297
+
298
+ asyncio.run(self.sio.emit("diffusion_finish"))
299
+
300
+ return Response(
301
+ content=res_img_bytes,
302
+ media_type=f"image/{ext}",
303
+ headers={"X-Seed": str(req.sd_seed)},
304
+ )
305
+
306
+ def api_run_plugin_gen_image(self, req: RunPluginRequest):
307
+ ext = "png"
308
+ if req.name not in self.plugins:
309
+ raise HTTPException(status_code=422, detail="Plugin not found")
310
+ if not self.plugins[req.name].support_gen_image:
311
+ raise HTTPException(
312
+ status_code=422, detail="Plugin does not support output image"
313
+ )
314
+ rgb_np_img, alpha_channel, infos, _ = decode_base64_to_image(req.image)
315
+ bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
316
+ torch_gc()
317
+
318
+ if bgr_or_rgba_np_img.shape[2] == 4:
319
+ rgba_np_img = bgr_or_rgba_np_img
320
+ else:
321
+ rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
322
+ rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
323
+
324
+ return Response(
325
+ content=pil_to_bytes(
326
+ Image.fromarray(rgba_np_img),
327
+ ext=ext,
328
+ quality=self.config.quality,
329
+ infos=infos,
330
+ ),
331
+ media_type=f"image/{ext}",
332
+ )
333
+
334
+ def api_run_plugin_gen_mask(self, req: RunPluginRequest):
335
+ if req.name not in self.plugins:
336
+ raise HTTPException(status_code=422, detail="Plugin not found")
337
+ if not self.plugins[req.name].support_gen_mask:
338
+ raise HTTPException(
339
+ status_code=422, detail="Plugin does not support output image"
340
+ )
341
+ rgb_np_img, _, _, _ = decode_base64_to_image(req.image)
342
+ bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
343
+ torch_gc()
344
+ res_mask = gen_frontend_mask(bgr_or_gray_mask)
345
+ return Response(
346
+ content=numpy_to_bytes(res_mask, "png"),
347
+ media_type="image/png",
348
+ )
349
+
350
+ def api_samplers(self) -> List[str]:
351
+ return [member.value for member in SDSampler.__members__.values()]
352
+
353
+ def api_adjust_mask(self, req: AdjustMaskRequest):
354
+ mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
355
+ mask = adjust_mask(mask, req.kernel_size, req.operate)
356
+ return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
357
+
358
+ def launch(self):
359
+ self.app.include_router(self.router)
360
+ uvicorn.run(
361
+ self.combined_asgi_app,
362
+ host=self.config.host,
363
+ port=self.config.port,
364
+ timeout_keep_alive=999999999,
365
+ )
366
+
367
+ def _build_file_manager(self) -> Optional[FileManager]:
368
+ if self.config.input and self.config.input.is_dir():
369
+ logger.info(
370
+ f"Input is directory, initialize file manager {self.config.input}"
371
+ )
372
+
373
+ return FileManager(
374
+ app=self.app,
375
+ input_dir=self.config.input,
376
+ mask_dir=self.config.mask_dir,
377
+ output_dir=self.config.output_dir,
378
+ )
379
+ return None
380
+
381
+ def _build_plugins(self) -> Dict[str, BasePlugin]:
382
+ return build_plugins(
383
+ self.config.enable_interactive_seg,
384
+ self.config.interactive_seg_model,
385
+ self.config.interactive_seg_device,
386
+ self.config.enable_remove_bg,
387
+ self.config.remove_bg_device,
388
+ self.config.remove_bg_model,
389
+ self.config.enable_anime_seg,
390
+ self.config.enable_realesrgan,
391
+ self.config.realesrgan_device,
392
+ self.config.realesrgan_model,
393
+ self.config.enable_gfpgan,
394
+ self.config.gfpgan_device,
395
+ self.config.enable_restoreformer,
396
+ self.config.restoreformer_device,
397
+ self.config.no_half,
398
+ )
399
+
400
+ def _build_model_manager(self):
401
+ return ModelManager(
402
+ name=self.config.model,
403
+ device=torch.device(self.config.device),
404
+ no_half=self.config.no_half,
405
+ low_mem=self.config.low_mem,
406
+ disable_nsfw=self.config.disable_nsfw_checker,
407
+ sd_cpu_textencoder=self.config.cpu_textencoder,
408
+ local_files_only=self.config.local_files_only,
409
+ cpu_offload=self.config.cpu_offload,
410
+ callback=diffuser_callback,
411
+ )
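A quick illustration of the "parameters" parsing in api_geninfo above: the text is split on the literal marker "Negative prompt: ", the part before it becomes the prompt, and only the first line after it becomes the negative prompt. A minimal standalone sketch of the same logic, using a made-up sample string:

    info = {"parameters": "a cat on a bench\nNegative prompt: blurry, low quality\nSteps: 20"}
    parts = info.get("parameters", "").split("Negative prompt: ")
    prompt = parts[0].strip()                                      # "a cat on a bench"
    negative = parts[1].split("\n")[0].strip() if len(parts) > 1 else ""
    print(prompt, "|", negative)                                   # a cat on a bench | blurry, low quality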
sorawm/iopaint/batch_processing.py ADDED
@@ -0,0 +1,128 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Dict, Optional
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from loguru import logger
8
+ from PIL import Image
9
+ from rich.console import Console
10
+ from rich.progress import (
11
+ BarColumn,
12
+ MofNCompleteColumn,
13
+ Progress,
14
+ SpinnerColumn,
15
+ TaskProgressColumn,
16
+ TextColumn,
17
+ TimeElapsedColumn,
18
+ )
19
+
20
+ from sorawm.iopaint.helper import pil_to_bytes
21
+ from sorawm.iopaint.model.utils import torch_gc
22
+ from sorawm.iopaint.model_manager import ModelManager
23
+ from sorawm.iopaint.schema import InpaintRequest
24
+
25
+
26
+ def glob_images(path: Path) -> Dict[str, Path]:
27
+ # png/jpg/jpeg
28
+ if path.is_file():
29
+ return {path.stem: path}
30
+ elif path.is_dir():
31
+ res = {}
32
+ for it in path.glob("*.*"):
33
+ if it.suffix.lower() in [".png", ".jpg", ".jpeg"]:
34
+ res[it.stem] = it
35
+ return res
36
+
37
+
38
+ def batch_inpaint(
39
+ model: str,
40
+ device,
41
+ image: Path,
42
+ mask: Path,
43
+ output: Path,
44
+ config: Optional[Path] = None,
45
+ concat: bool = False,
46
+ ):
47
+ if image.is_dir() and output.is_file():
48
+ logger.error(
49
+ "invalid --output: when image is a directory, output should be a directory"
50
+ )
51
+ exit(-1)
52
+ output.mkdir(parents=True, exist_ok=True)
53
+
54
+ image_paths = glob_images(image)
55
+ mask_paths = glob_images(mask)
56
+ if len(image_paths) == 0:
57
+ logger.error("invalid --image: empty image folder")
58
+ exit(-1)
59
+ if len(mask_paths) == 0:
60
+ logger.error("invalid --mask: empty mask folder")
61
+ exit(-1)
62
+
63
+ if config is None:
64
+ inpaint_request = InpaintRequest()
65
+ logger.info(f"Using default config: {inpaint_request}")
66
+ else:
67
+ with open(config, "r", encoding="utf-8") as f:
68
+ inpaint_request = InpaintRequest(**json.load(f))
69
+ logger.info(f"Using config: {inpaint_request}")
70
+
71
+ model_manager = ModelManager(name=model, device=device)
72
+ first_mask = list(mask_paths.values())[0]
73
+
74
+ console = Console()
75
+
76
+ with Progress(
77
+ SpinnerColumn(),
78
+ TextColumn("[progress.description]{task.description}"),
79
+ BarColumn(),
80
+ TaskProgressColumn(),
81
+ MofNCompleteColumn(),
82
+ TimeElapsedColumn(),
83
+ console=console,
84
+ transient=False,
85
+ ) as progress:
86
+ task = progress.add_task("Batch processing...", total=len(image_paths))
87
+ for stem, image_p in image_paths.items():
88
+ if stem not in mask_paths and mask.is_dir():
89
+ progress.log(f"mask for {image_p} not found")
90
+ progress.update(task, advance=1)
91
+ continue
92
+ mask_p = mask_paths.get(stem, first_mask)
93
+
94
+ infos = Image.open(image_p).info
95
+
96
+ img = np.array(Image.open(image_p).convert("RGB"))
97
+ mask_img = np.array(Image.open(mask_p).convert("L"))
98
+
99
+ if mask_img.shape[:2] != img.shape[:2]:
100
+ progress.log(
101
+ f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
102
+ )
103
+ mask_img = cv2.resize(
104
+ mask_img,
105
+ (img.shape[1], img.shape[0]),
106
+ interpolation=cv2.INTER_NEAREST,
107
+ )
108
+ mask_img[mask_img >= 127] = 255
109
+ mask_img[mask_img < 127] = 0
110
+
111
+ # bgr
112
+ inpaint_result = model_manager(img, mask_img, inpaint_request)
113
+ inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
114
+ if concat:
115
+ mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
116
+ inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])
117
+
118
+ img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
119
+ save_p = output / f"{stem}.png"
120
+ with open(save_p, "wb") as fw:
121
+ fw.write(img_bytes)
122
+
123
+ progress.update(task, advance=1)
124
+ torch_gc()
125
+ # pid = psutil.Process().pid
126
+ # memory_info = psutil.Process(pid).memory_info()
127
+ # memory_in_mb = memory_info.rss / (1024 * 1024)
128
+ # print(f"原图大小:{img.shape},当前进程的内存占用:{memory_in_mb}MB")
sorawm/iopaint/benchmark.py ADDED
@@ -0,0 +1,109 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import argparse
4
+ import os
5
+ import time
6
+
7
+ import numpy as np
8
+ import nvidia_smi
9
+ import psutil
10
+ import torch
11
+
12
+ from sorawm.iopaint.model_manager import ModelManager
13
+ from sorawm.iopaint.schema import HDStrategy, InpaintRequest, SDSampler
14
+
15
+ try:
16
+ torch._C._jit_override_can_fuse_on_cpu(False)
17
+ torch._C._jit_override_can_fuse_on_gpu(False)
18
+ torch._C._jit_set_texpr_fuser_enabled(False)
19
+ torch._C._jit_set_nvfuser_enabled(False)
20
+ except:
21
+ pass
22
+
23
+ NUM_THREADS = str(4)
24
+
25
+ os.environ["OMP_NUM_THREADS"] = NUM_THREADS
26
+ os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
27
+ os.environ["MKL_NUM_THREADS"] = NUM_THREADS
28
+ os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
29
+ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
30
+ if os.environ.get("CACHE_DIR"):
31
+ os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
32
+
33
+
34
+ def run_model(model, size):
35
+ # RGB
36
+ image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
37
+ mask = np.random.randint(0, 255, size).astype(np.uint8)
38
+
39
+ config = InpaintRequest(
40
+ ldm_steps=2,
41
+ hd_strategy=HDStrategy.ORIGINAL,
42
+ hd_strategy_crop_margin=128,
43
+ hd_strategy_crop_trigger_size=128,
44
+ hd_strategy_resize_limit=128,
45
+ prompt="a fox is sitting on a bench",
46
+ sd_steps=5,
47
+ sd_sampler=SDSampler.ddim,
48
+ )
49
+ model(image, mask, config)
50
+
51
+
52
+ def benchmark(model, times: int, empty_cache: bool):
53
+ sizes = [(512, 512)]
54
+
55
+ nvidia_smi.nvmlInit()
56
+ device_id = 0
57
+ handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
58
+
59
+ def format(metrics):
60
+ return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
61
+
62
+ process = psutil.Process(os.getpid())
63
+ # Report GPU memory and RAM usage metrics for each size
64
+ for size in sizes:
65
+ torch.cuda.empty_cache()
66
+ time_metrics = []
67
+ cpu_metrics = []
68
+ memory_metrics = []
69
+ gpu_memory_metrics = []
70
+ for _ in range(times):
71
+ start = time.time()
72
+ run_model(model, size)
73
+ torch.cuda.synchronize()
74
+
75
+ # cpu_metrics.append(process.cpu_percent())
76
+ time_metrics.append((time.time() - start) * 1000)
77
+ memory_metrics.append(process.memory_info().rss / 1024 / 1024)
78
+ gpu_memory_metrics.append(
79
+ nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024
80
+ )
81
+
82
+ print(f"size: {size}".center(80, "-"))
83
+ # print(f"cpu: {format(cpu_metrics)}")
84
+ print(f"latency: {format(time_metrics)}ms")
85
+ print(f"memory: {format(memory_metrics)} MB")
86
+ print(f"gpu memory: {format(gpu_memory_metrics)} MB")
87
+
88
+ nvidia_smi.nvmlShutdown()
89
+
90
+
91
+ def get_args_parser():
92
+ parser = argparse.ArgumentParser()
93
+ parser.add_argument("--name")
94
+ parser.add_argument("--device", default="cuda", type=str)
95
+ parser.add_argument("--times", default=10, type=int)
96
+ parser.add_argument("--empty-cache", action="store_true")
97
+ return parser.parse_args()
98
+
99
+
100
+ if __name__ == "__main__":
101
+ args = get_args_parser()
102
+ device = torch.device(args.device)
103
+ model = ModelManager(
104
+ name=args.name,
105
+ device=device,
106
+ disable_nsfw=True,
107
+ sd_cpu_textencoder=True,
108
+ )
109
+ benchmark(model, args.times, args.empty_cache)
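The __main__ block above shows the intended use: construct a ModelManager and hand it to benchmark(). The script assumes an NVIDIA GPU plus the nvidia_smi and psutil packages; the same thing can be done programmatically, with the model name and repetition count below being examples only:

    import torch

    from sorawm.iopaint.benchmark import benchmark
    from sorawm.iopaint.model_manager import ModelManager

    model = ModelManager(
        name="lama",
        device=torch.device("cuda"),
        disable_nsfw=True,
        sd_cpu_textencoder=True,
    )
    benchmark(model, times=5, empty_cache=False)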
sorawm/iopaint/cli.py ADDED
@@ -0,0 +1,245 @@
1
+ import webbrowser
2
+ from contextlib import asynccontextmanager
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import typer
7
+ from fastapi import FastAPI
8
+ from loguru import logger
9
+ from typer import Option
10
+ from typer_config import use_json_config
11
+
12
+ from sorawm.iopaint.const import *
13
+ from sorawm.iopaint.runtime import check_device, dump_environment_info, setup_model_dir
14
+ from sorawm.iopaint.schema import (
15
+ Device,
16
+ InteractiveSegModel,
17
+ RealESRGANModel,
18
+ RemoveBGModel,
19
+ )
20
+
21
+ typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
22
+
23
+
24
+ @typer_app.command(help="Install all plugins dependencies")
25
+ def install_plugins_packages():
26
+ from sorawm.iopaint.installer import install_plugins_package
27
+
28
+ install_plugins_package()
29
+
30
+
31
+ @typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
32
+ def download(
33
+ model: str = Option(
34
+ ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
35
+ ),
36
+ model_dir: Path = Option(
37
+ DEFAULT_MODEL_DIR,
38
+ help=MODEL_DIR_HELP,
39
+ file_okay=False,
40
+ callback=setup_model_dir,
41
+ ),
42
+ ):
43
+ from sorawm.iopaint.download import cli_download_model
44
+
45
+ cli_download_model(model)
46
+
47
+
48
+ @typer_app.command(name="list", help="List downloaded models")
49
+ def list_model(
50
+ model_dir: Path = Option(
51
+ DEFAULT_MODEL_DIR,
52
+ help=MODEL_DIR_HELP,
53
+ file_okay=False,
54
+ callback=setup_model_dir,
55
+ ),
56
+ ):
57
+ from sorawm.iopaint.download import scan_models
58
+
59
+ scanned_models = scan_models()
60
+ for it in scanned_models:
61
+ print(it.name)
62
+
63
+
64
+ @typer_app.command(help="Batch processing images")
65
+ def run(
66
+ model: str = Option("lama"),
67
+ device: Device = Option(Device.cpu),
68
+ image: Path = Option(..., help="Image folders or file path"),
69
+ mask: Path = Option(
70
+ ...,
71
+ help="Mask folders or file path. "
72
+ "If it is a directory, the mask images in the directory should have the same name as the original image."
73
+ "If it is a file, all images will use this mask."
74
+ "Mask will automatically resize to the same size as the original image.",
75
+ ),
76
+ output: Path = Option(..., help="Output directory or file path"),
77
+ config: Path = Option(
78
+ None, help="Config file path. You can use dump command to create a base config."
79
+ ),
80
+ concat: bool = Option(
81
+ False, help="Concat original image, mask and output images into one image"
82
+ ),
83
+ model_dir: Path = Option(
84
+ DEFAULT_MODEL_DIR,
85
+ help=MODEL_DIR_HELP,
86
+ file_okay=False,
87
+ callback=setup_model_dir,
88
+ ),
89
+ ):
90
+ from sorawm.iopaint.download import cli_download_model, scan_models
91
+
92
+ scanned_models = scan_models()
93
+ if model not in [it.name for it in scanned_models]:
94
+ logger.info(f"{model} not found in {model_dir}, try to downloading")
95
+ cli_download_model(model)
96
+
97
+ from sorawm.iopaint.batch_processing import batch_inpaint
98
+
99
+ batch_inpaint(model, device, image, mask, output, config, concat)
100
+
101
+
102
+ @typer_app.command(help="Start IOPaint server")
103
+ @use_json_config()
104
+ def start(
105
+ host: str = Option("127.0.0.1"),
106
+ port: int = Option(8080),
107
+ inbrowser: bool = Option(False, help=INBROWSER_HELP),
108
+ model: str = Option(
109
+ DEFAULT_MODEL,
110
+ help=f"Erase models: [{', '.join(AVAILABLE_MODELS)}].\n"
111
+ f"Diffusion models: [{', '.join(DIFFUSION_MODELS)}] or any SD/SDXL normal/inpainting models on HuggingFace.",
112
+ ),
113
+ model_dir: Path = Option(
114
+ DEFAULT_MODEL_DIR,
115
+ help=MODEL_DIR_HELP,
116
+ dir_okay=True,
117
+ file_okay=False,
118
+ callback=setup_model_dir,
119
+ ),
120
+ low_mem: bool = Option(False, help=LOW_MEM_HELP),
121
+ no_half: bool = Option(False, help=NO_HALF_HELP),
122
+ cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
123
+ disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
124
+ cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
125
+ local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
126
+ device: Device = Option(Device.cpu),
127
+ input: Optional[Path] = Option(None, help=INPUT_HELP),
128
+ mask_dir: Optional[Path] = Option(
129
+ None, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
130
+ ),
131
+ output_dir: Optional[Path] = Option(
132
+ None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
133
+ ),
134
+ quality: int = Option(100, help=QUALITY_HELP),
135
+ enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
136
+ interactive_seg_model: InteractiveSegModel = Option(
137
+ InteractiveSegModel.sam2_1_tiny, help=INTERACTIVE_SEG_MODEL_HELP
138
+ ),
139
+ interactive_seg_device: Device = Option(Device.cpu),
140
+ enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
141
+ remove_bg_device: Device = Option(Device.cpu, help=REMOVE_BG_DEVICE_HELP),
142
+ remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
143
+ enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
144
+ enable_realesrgan: bool = Option(False),
145
+ realesrgan_device: Device = Option(Device.cpu),
146
+ realesrgan_model: RealESRGANModel = Option(RealESRGANModel.realesr_general_x4v3),
147
+ enable_gfpgan: bool = Option(False),
148
+ gfpgan_device: Device = Option(Device.cpu),
149
+ enable_restoreformer: bool = Option(False),
150
+ restoreformer_device: Device = Option(Device.cpu),
151
+ ):
152
+ dump_environment_info()
153
+ device = check_device(device)
154
+ remove_bg_device = check_device(remove_bg_device)
155
+ realesrgan_device = check_device(realesrgan_device)
156
+ gfpgan_device = check_device(gfpgan_device)
157
+
158
+ if input and not input.exists():
159
+ logger.error(f"invalid --input: {input} not exists")
160
+ exit(-1)
161
+ if mask_dir and not mask_dir.exists():
162
+ logger.error(f"invalid --mask-dir: {mask_dir} not exists")
163
+ exit(-1)
164
+ if input and input.is_dir() and not output_dir:
165
+ logger.error(
166
+ "invalid --output-dir: --output-dir must be set when --input is a directory"
167
+ )
168
+ exit(-1)
169
+ if output_dir:
170
+ output_dir = output_dir.expanduser().absolute()
171
+ logger.info(f"Image will be saved to {output_dir}")
172
+ if not output_dir.exists():
173
+ logger.info(f"Create output directory {output_dir}")
174
+ output_dir.mkdir(parents=True)
175
+ if mask_dir:
176
+ mask_dir = mask_dir.expanduser().absolute()
177
+
178
+ model_dir = model_dir.expanduser().absolute()
179
+
180
+ if local_files_only:
181
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
182
+ os.environ["HF_HUB_OFFLINE"] = "1"
183
+
184
+ from sorawm.iopaint.download import cli_download_model, scan_models
185
+
186
+ scanned_models = scan_models()
187
+ if model not in [it.name for it in scanned_models]:
188
+ logger.info(f"{model} not found in {model_dir}, try to downloading")
189
+ cli_download_model(model)
190
+
191
+ from sorawm.iopaint.api import Api
192
+ from sorawm.iopaint.schema import ApiConfig
193
+
194
+ @asynccontextmanager
195
+ async def lifespan(app: FastAPI):
196
+ if inbrowser:
197
+ webbrowser.open(f"http://localhost:{port}", new=0, autoraise=True)
198
+ yield
199
+
200
+ app = FastAPI(lifespan=lifespan)
201
+
202
+ api_config = ApiConfig(
203
+ host=host,
204
+ port=port,
205
+ inbrowser=inbrowser,
206
+ model=model,
207
+ no_half=no_half,
208
+ low_mem=low_mem,
209
+ cpu_offload=cpu_offload,
210
+ disable_nsfw_checker=disable_nsfw_checker,
211
+ local_files_only=local_files_only,
212
+ cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
213
+ device=device,
214
+ input=input,
215
+ mask_dir=mask_dir,
216
+ output_dir=output_dir,
217
+ quality=quality,
218
+ enable_interactive_seg=enable_interactive_seg,
219
+ interactive_seg_model=interactive_seg_model,
220
+ interactive_seg_device=interactive_seg_device,
221
+ enable_remove_bg=enable_remove_bg,
222
+ remove_bg_device=remove_bg_device,
223
+ remove_bg_model=remove_bg_model,
224
+ enable_anime_seg=enable_anime_seg,
225
+ enable_realesrgan=enable_realesrgan,
226
+ realesrgan_device=realesrgan_device,
227
+ realesrgan_model=realesrgan_model,
228
+ enable_gfpgan=enable_gfpgan,
229
+ gfpgan_device=gfpgan_device,
230
+ enable_restoreformer=enable_restoreformer,
231
+ restoreformer_device=restoreformer_device,
232
+ )
233
+ print(api_config.model_dump_json(indent=4))
234
+ api = Api(app, api_config)
235
+ api.launch()
236
+
237
+
238
+ @typer_app.command(help="Start IOPaint web config page")
239
+ def start_web_config(
240
+ config_file: Path = Option("config.json"),
241
+ ):
242
+ dump_environment_info()
243
+ from sorawm.iopaint.web_config import main
244
+
245
+ main(config_file)
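Assuming the four-line sorawm/iopaint/__main__.py added in this commit simply invokes typer_app (it is not shown here), the commands above can be run from a shell or, since a Typer application is callable with an argv list, exercised directly from Python. A hedged sketch with example flag values:

    from sorawm.iopaint.cli import typer_app

    # Equivalent to running the "start" sub-command from the command line, e.g.
    # <entry point> start --model lama --device cpu --port 8080
    typer_app(["start", "--model", "lama", "--device", "cpu", "--port", "8080"])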
sorawm/iopaint/const.py ADDED
@@ -0,0 +1,134 @@
1
+ import os
2
+ from typing import List
3
+
4
+ INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
5
+ KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
6
+ POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
7
+ ANYTEXT_NAME = "Sanster/AnyText"
8
+
9
+ DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
10
+ DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
11
+ DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
12
+ DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
13
+
14
+ MPS_UNSUPPORT_MODELS = [
15
+ "lama",
16
+ "ldm",
17
+ "zits",
18
+ "mat",
19
+ "fcf",
20
+ "cv2",
21
+ "manga",
22
+ ]
23
+
24
+ DEFAULT_MODEL = "lama"
25
+ AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"]
26
+ DIFFUSION_MODELS = [
27
+ "runwayml/stable-diffusion-inpainting",
28
+ "Uminosachi/realisticVisionV51_v51VAE-inpainting",
29
+ "redstonehero/dreamshaper-inpainting",
30
+ "Sanster/anything-4.0-inpainting",
31
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
32
+ "Fantasy-Studio/Paint-by-Example",
33
+ "RunDiffusion/Juggernaut-XI-v11",
34
+ "SG161222/RealVisXL_V5.0",
35
+ "eienmojiki/Anything-XL",
36
+ POWERPAINT_NAME,
37
+ ANYTEXT_NAME,
38
+ ]
39
+
40
+ NO_HALF_HELP = """
41
+ Use the full-precision (fp32) model.
42
+ If your diffusion model always generates black or green results, use this argument.
43
+ """
44
+
45
+ CPU_OFFLOAD_HELP = """
46
+ Offloads the diffusion model's weights to CPU RAM, significantly reducing vRAM usage.
47
+ """
48
+
49
+ LOW_MEM_HELP = "Enable attention slicing and vae tiling to save memory."
50
+
51
+ DISABLE_NSFW_HELP = """
52
+ Disable NSFW checker for diffusion model.
53
+ """
54
+
55
+ CPU_TEXTENCODER_HELP = """
56
+ Run the diffusion model's text encoder on CPU to reduce vRAM usage.
57
+ """
58
+
59
+ SD_CONTROLNET_CHOICES: List[str] = [
60
+ "lllyasviel/control_v11p_sd15_canny",
61
+ # "lllyasviel/control_v11p_sd15_seg",
62
+ "lllyasviel/control_v11p_sd15_openpose",
63
+ "lllyasviel/control_v11p_sd15_inpaint",
64
+ "lllyasviel/control_v11f1p_sd15_depth",
65
+ ]
66
+
67
+ SD_BRUSHNET_CHOICES: List[str] = [
68
+ "Sanster/brushnet_random_mask",
69
+ "Sanster/brushnet_segmentation_mask",
70
+ ]
71
+
72
+ SD2_CONTROLNET_CHOICES = [
73
+ "thibaud/controlnet-sd21-canny-diffusers",
74
+ "thibaud/controlnet-sd21-depth-diffusers",
75
+ "thibaud/controlnet-sd21-openpose-diffusers",
76
+ ]
77
+
78
+ SDXL_CONTROLNET_CHOICES = [
79
+ "thibaud/controlnet-openpose-sdxl-1.0",
80
+ "destitech/controlnet-inpaint-dreamer-sdxl",
81
+ "diffusers/controlnet-canny-sdxl-1.0",
82
+ "diffusers/controlnet-canny-sdxl-1.0-mid",
83
+ "diffusers/controlnet-canny-sdxl-1.0-small",
84
+ "diffusers/controlnet-depth-sdxl-1.0",
85
+ "diffusers/controlnet-depth-sdxl-1.0-mid",
86
+ "diffusers/controlnet-depth-sdxl-1.0-small",
87
+ ]
88
+
89
+ SDXL_BRUSHNET_CHOICES = ["Regulus0725/random_mask_brushnet_ckpt_sdxl_regulus_v1"]
90
+
91
+ LOCAL_FILES_ONLY_HELP = """
92
+ When loading diffusion models, use local files only and do not connect to the HuggingFace server.
93
+ """
94
+
95
+ DEFAULT_MODEL_DIR = os.path.abspath(
96
+ os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
97
+ )
98
+
99
+ MODEL_DIR_HELP = f"""
100
+ Model download directory (set via the XDG_CACHE_HOME environment variable); by default models are downloaded to {DEFAULT_MODEL_DIR}
101
+ """
102
+
103
+ OUTPUT_DIR_HELP = """
104
+ Result images will be saved to output directory automatically.
105
+ """
106
+
107
+ MASK_DIR_HELP = """
108
+ You can view masks in FileManager
109
+ """
110
+
111
+ INPUT_HELP = """
112
+ If input is image, it will be loaded by default.
113
+ If input is directory, you can browse and select image in file manager.
114
+ """
115
+
116
+ GUI_HELP = """
117
+ Launch Lama Cleaner as desktop app
118
+ """
119
+
120
+ QUALITY_HELP = """
121
+ Quality of image encoding, 0-100; higher quality produces a larger file size.
122
+ """
123
+
124
+ INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
125
+ INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
126
+ REMOVE_BG_HELP = "Enable remove background plugin."
127
+ REMOVE_BG_DEVICE_HELP = "Device for remove background plugin. 'cuda' only supports briaai models(briaai/RMBG-1.4 and briaai/RMBG-2.0)"
128
+ ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU"
129
+ REALESRGAN_HELP = "Enable realesrgan super resolution"
130
+ GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
131
+ RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
132
+ GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
133
+
134
+ INBROWSER_HELP = "Automatically launch IOPaint in a new tab on the default browser"
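One detail worth noting: DEFAULT_MODEL_DIR above is resolved once, at import time, from XDG_CACHE_HOME (falling back to ~/.cache), so the variable must be set before anything from sorawm.iopaint is imported. A small sketch with a hypothetical cache path:

    import os

    # Set the cache location before the first sorawm.iopaint import,
    # because const.py reads XDG_CACHE_HOME when the module is loaded.
    os.environ["XDG_CACHE_HOME"] = "/data/iopaint-cache"   # hypothetical path

    from sorawm.iopaint.const import DEFAULT_MODEL_DIR

    print(DEFAULT_MODEL_DIR)   # -> /data/iopaint-cache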
sorawm/iopaint/download.py ADDED
@@ -0,0 +1,314 @@
1
+ import glob
2
+ import json
3
+ import os
4
+ from functools import lru_cache
5
+ from pathlib import Path
6
+ from typing import List, Optional
7
+
8
+ from loguru import logger
9
+
10
+ from sorawm.iopaint.const import (
11
+ ANYTEXT_NAME,
12
+ DEFAULT_MODEL_DIR,
13
+ DIFFUSERS_SD_CLASS_NAME,
14
+ DIFFUSERS_SD_INPAINT_CLASS_NAME,
15
+ DIFFUSERS_SDXL_CLASS_NAME,
16
+ DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
17
+ )
18
+ from sorawm.iopaint.model.original_sd_configs import get_config_files
19
+ from sorawm.iopaint.schema import ModelInfo, ModelType
20
+
21
+
22
+ def cli_download_model(model: str):
23
+ from sorawm.iopaint.model import models
24
+ from sorawm.iopaint.model.utils import handle_from_pretrained_exceptions
25
+
26
+ if model in models and models[model].is_erase_model:
27
+ logger.info(f"Downloading {model}...")
28
+ models[model].download()
29
+ logger.info("Done.")
30
+ elif model == ANYTEXT_NAME:
31
+ logger.info(f"Downloading {model}...")
32
+ models[model].download()
33
+ logger.info("Done.")
34
+ else:
35
+ logger.info(f"Downloading model from Huggingface: {model}")
36
+ from diffusers import DiffusionPipeline
37
+
38
+ downloaded_path = handle_from_pretrained_exceptions(
39
+ DiffusionPipeline.download, pretrained_model_name=model, variant="fp16"
40
+ )
41
+ logger.info(f"Done. Downloaded to {downloaded_path}")
42
+
43
+
44
+ def folder_name_to_show_name(name: str) -> str:
45
+ return name.replace("models--", "").replace("--", "/")
46
+
47
+
48
+ @lru_cache(maxsize=512)
49
+ def get_sd_model_type(model_abs_path: str) -> Optional[ModelType]:
50
+ if "inpaint" in Path(model_abs_path).name.lower():
51
+ model_type = ModelType.DIFFUSERS_SD_INPAINT
52
+ else:
53
+ # load once to check num_in_channels
54
+ from diffusers import StableDiffusionInpaintPipeline
55
+
56
+ try:
57
+ StableDiffusionInpaintPipeline.from_single_file(
58
+ model_abs_path,
59
+ load_safety_checker=False,
60
+ num_in_channels=9,
61
+ original_config_file=get_config_files()["v1"],
62
+ )
63
+ model_type = ModelType.DIFFUSERS_SD_INPAINT
64
+ except ValueError as e:
65
+ if "[320, 4, 3, 3]" in str(e):
66
+ model_type = ModelType.DIFFUSERS_SD
67
+ else:
68
+ logger.info(f"Ignore non sdxl file: {model_abs_path}")
69
+ return
70
+ except Exception as e:
71
+ logger.error(f"Failed to load {model_abs_path}: {e}")
72
+ return
73
+ return model_type
74
+
75
+
76
+ @lru_cache()
77
+ def get_sdxl_model_type(model_abs_path: str) -> Optional[ModelType]:
78
+ if "inpaint" in model_abs_path:
79
+ model_type = ModelType.DIFFUSERS_SDXL_INPAINT
80
+ else:
81
+ # load once to check num_in_channels
82
+ from diffusers import StableDiffusionXLInpaintPipeline
83
+
84
+ try:
85
+ model = StableDiffusionXLInpaintPipeline.from_single_file(
86
+ model_abs_path,
87
+ load_safety_checker=False,
88
+ num_in_channels=9,
89
+ original_config_file=get_config_files()["xl"],
90
+ )
91
+ if model.unet.config.in_channels == 9:
92
+ # https://github.com/huggingface/diffusers/issues/6610
93
+ model_type = ModelType.DIFFUSERS_SDXL_INPAINT
94
+ else:
95
+ model_type = ModelType.DIFFUSERS_SDXL
96
+ except ValueError as e:
97
+ if "[320, 4, 3, 3]" in str(e):
98
+ model_type = ModelType.DIFFUSERS_SDXL
99
+ else:
100
+ logger.info(f"Ignore non sdxl file: {model_abs_path}")
101
+ return
102
+ except Exception as e:
103
+ logger.error(f"Failed to load {model_abs_path}: {e}")
104
+ return
105
+ return model_type
106
+
107
+
108
+ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
109
+ cache_dir = Path(cache_dir)
110
+ stable_diffusion_dir = cache_dir / "stable_diffusion"
111
+ cache_file = stable_diffusion_dir / "iopaint_cache.json"
112
+ model_type_cache = {}
113
+ if cache_file.exists():
114
+ try:
115
+ with open(cache_file, "r", encoding="utf-8") as f:
116
+ model_type_cache = json.load(f)
117
+ assert isinstance(model_type_cache, dict)
118
+ except:
119
+ pass
120
+
121
+ res = []
122
+ for it in stable_diffusion_dir.glob("*.*"):
123
+ if it.suffix not in [".safetensors", ".ckpt"]:
124
+ continue
125
+ model_abs_path = str(it.absolute())
126
+ model_type = model_type_cache.get(it.name)
127
+ if model_type is None:
128
+ model_type = get_sd_model_type(model_abs_path)
129
+ if model_type is None:
130
+ continue
131
+
132
+ model_type_cache[it.name] = model_type
133
+ res.append(
134
+ ModelInfo(
135
+ name=it.name,
136
+ path=model_abs_path,
137
+ model_type=model_type,
138
+ is_single_file_diffusers=True,
139
+ )
140
+ )
141
+ if stable_diffusion_dir.exists():
142
+ with open(cache_file, "w", encoding="utf-8") as fw:
143
+ json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
144
+
145
+ stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
146
+ sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json"
147
+ sdxl_model_type_cache = {}
148
+ if sdxl_cache_file.exists():
149
+ try:
150
+ with open(sdxl_cache_file, "r", encoding="utf-8") as f:
151
+ sdxl_model_type_cache = json.load(f)
152
+ assert isinstance(sdxl_model_type_cache, dict)
153
+ except:
154
+ pass
155
+
156
+ for it in stable_diffusion_xl_dir.glob("*.*"):
157
+ if it.suffix not in [".safetensors", ".ckpt"]:
158
+ continue
159
+ model_abs_path = str(it.absolute())
160
+ model_type = sdxl_model_type_cache.get(it.name)
161
+ if model_type is None:
162
+ model_type = get_sdxl_model_type(model_abs_path)
163
+ if model_type is None:
164
+ continue
165
+
166
+ sdxl_model_type_cache[it.name] = model_type
167
+ if stable_diffusion_xl_dir.exists():
168
+ with open(sdxl_cache_file, "w", encoding="utf-8") as fw:
169
+ json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False)
170
+
171
+ res.append(
172
+ ModelInfo(
173
+ name=it.name,
174
+ path=model_abs_path,
175
+ model_type=model_type,
176
+ is_single_file_diffusers=True,
177
+ )
178
+ )
179
+ return res
180
+
181
+
182
+ def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
183
+ res = []
184
+ from sorawm.iopaint.model import models
185
+
186
+ # logger.info(f"Scanning inpaint models in {model_dir}")
187
+
188
+ for name, m in models.items():
189
+ if m.is_erase_model and m.is_downloaded():
190
+ res.append(
191
+ ModelInfo(
192
+ name=name,
193
+ path=name,
194
+ model_type=ModelType.INPAINT,
195
+ )
196
+ )
197
+ return res
198
+
199
+
200
+ def scan_diffusers_models() -> List[ModelInfo]:
201
+ from huggingface_hub.constants import HF_HUB_CACHE
202
+
203
+ available_models = []
204
+ cache_dir = Path(HF_HUB_CACHE)
205
+ # logger.info(f"Scanning diffusers models in {cache_dir}")
206
+ diffusers_model_names = []
207
+ model_index_files = glob.glob(
208
+ os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True
209
+ )
210
+ for it in model_index_files:
211
+ it = Path(it)
212
+ try:
213
+ with open(it, "r", encoding="utf-8") as f:
214
+ data = json.load(f)
215
+ except:
216
+ continue
217
+
218
+ _class_name = data["_class_name"]
219
+ name = folder_name_to_show_name(it.parent.parent.parent.name)
220
+ if name in diffusers_model_names:
221
+ continue
222
+ if "PowerPaint" in name:
223
+ model_type = ModelType.DIFFUSERS_OTHER
224
+ elif _class_name == DIFFUSERS_SD_CLASS_NAME:
225
+ model_type = ModelType.DIFFUSERS_SD
226
+ elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
227
+ model_type = ModelType.DIFFUSERS_SD_INPAINT
228
+ elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
229
+ model_type = ModelType.DIFFUSERS_SDXL
230
+ elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
231
+ model_type = ModelType.DIFFUSERS_SDXL_INPAINT
232
+ elif _class_name in [
233
+ "StableDiffusionInstructPix2PixPipeline",
234
+ "PaintByExamplePipeline",
235
+ "KandinskyV22InpaintPipeline",
236
+ "AnyText",
237
+ ]:
238
+ model_type = ModelType.DIFFUSERS_OTHER
239
+ else:
240
+ continue
241
+
242
+ diffusers_model_names.append(name)
243
+ available_models.append(
244
+ ModelInfo(
245
+ name=name,
246
+ path=name,
247
+ model_type=model_type,
248
+ )
249
+ )
250
+ return available_models
251
+
252
+
253
+ def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
254
+ cache_dir = Path(cache_dir)
255
+ available_models = []
256
+ diffusers_model_names = []
257
+ model_index_files = glob.glob(
258
+ os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True
259
+ )
260
+ for it in model_index_files:
261
+ it = Path(it)
262
+ with open(it, "r", encoding="utf-8") as f:
263
+ try:
264
+ data = json.load(f)
265
+ except:
266
+ logger.error(
267
+ f"Failed to load {it}, please try revert from original model or fix model_index.json by hand."
268
+ )
269
+ continue
270
+
271
+ _class_name = data["_class_name"]
272
+ name = folder_name_to_show_name(it.parent.name)
273
+ if name in diffusers_model_names:
274
+ continue
275
+ elif _class_name == DIFFUSERS_SD_CLASS_NAME:
276
+ model_type = ModelType.DIFFUSERS_SD
277
+ elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
278
+ model_type = ModelType.DIFFUSERS_SD_INPAINT
279
+ elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
280
+ model_type = ModelType.DIFFUSERS_SDXL
281
+ elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
282
+ model_type = ModelType.DIFFUSERS_SDXL_INPAINT
283
+ else:
284
+ continue
285
+
286
+ diffusers_model_names.append(name)
287
+ available_models.append(
288
+ ModelInfo(
289
+ name=name,
290
+ path=str(it.parent.absolute()),
291
+ model_type=model_type,
292
+ )
293
+ )
294
+ return available_models
295
+
296
+
297
+ def scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
298
+ cache_dir = Path(cache_dir)
299
+ available_models = []
300
+ stable_diffusion_dir = cache_dir / "stable_diffusion"
301
+ stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
302
+ available_models.extend(_scan_converted_diffusers_models(stable_diffusion_dir))
303
+ available_models.extend(_scan_converted_diffusers_models(stable_diffusion_xl_dir))
304
+ return available_models
305
+
306
+
307
+ def scan_models() -> List[ModelInfo]:
308
+ model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
309
+ available_models = []
310
+ available_models.extend(scan_inpaint_models(model_dir))
311
+ available_models.extend(scan_single_file_diffusion_models(model_dir))
312
+ available_models.extend(scan_diffusers_models())
313
+ available_models.extend(scan_converted_diffusers_models(model_dir))
314
+ return available_models
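scan_models aggregates all of the sources above (downloaded erase models, single-file SD/SDXL checkpoints, the HuggingFace hub cache and converted diffusers folders), so listing what is available locally is a short loop:

    from sorawm.iopaint.download import scan_models

    for info in scan_models():
        # ModelInfo carries the display name, on-disk path and detected ModelType.
        print(info.model_type, info.name, info.path)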
sorawm/iopaint/file_manager/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .file_manager import FileManager
sorawm/iopaint/file_manager/file_manager.py ADDED
@@ -0,0 +1,220 @@
1
+ import os
2
+ from io import BytesIO
3
+ from pathlib import Path
4
+ from typing import List
5
+
6
+ from fastapi import FastAPI, HTTPException
7
+ from PIL import Image, ImageOps, PngImagePlugin
8
+ from starlette.responses import FileResponse
9
+
10
+ from ..schema import MediasResponse, MediaTab
11
+
12
+ LARGE_ENOUGH_NUMBER = 100
13
+ PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
14
+ from .storage_backends import FilesystemStorageBackend
15
+ from .utils import aspect_to_string, generate_filename, glob_img
16
+
17
+
18
+ class FileManager:
19
+ def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
20
+ self.app = app
21
+ self.input_dir: Path = input_dir
22
+ self.mask_dir: Path = mask_dir
23
+ self.output_dir: Path = output_dir
24
+
25
+ self.image_dir_filenames = []
26
+ self.output_dir_filenames = []
27
+ if not self.thumbnail_directory.exists():
28
+ self.thumbnail_directory.mkdir(parents=True)
29
+
30
+ # fmt: off
31
+ self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
32
+ self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
33
+ self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
34
+ # fmt: on
35
+
36
+ def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
37
+ img_dir = self._get_dir(tab)
38
+ return self._media_names(img_dir)
39
+
40
+ def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
41
+ file_path = self._get_file(tab, filename)
42
+ return FileResponse(file_path, media_type="image/png")
43
+
44
+ # tab=${tab}?filename=${filename.name}?width=${width}&height=${height}
45
+ def api_media_thumbnail_file(
46
+ self, tab: MediaTab, filename: str, width: int, height: int
47
+ ) -> FileResponse:
48
+ img_dir = self._get_dir(tab)
49
+ thumb_filename, (width, height) = self.get_thumbnail(
50
+ img_dir, filename, width=width, height=height
51
+ )
52
+ thumbnail_filepath = self.thumbnail_directory / thumb_filename
53
+ return FileResponse(
54
+ thumbnail_filepath,
55
+ headers={
56
+ "X-Width": str(width),
57
+ "X-Height": str(height),
58
+ },
59
+ media_type="image/jpeg",
60
+ )
61
+
62
+ def _get_dir(self, tab: MediaTab) -> Path:
63
+ if tab == "input":
64
+ return self.input_dir
65
+ elif tab == "output":
66
+ return self.output_dir
67
+ elif tab == "mask":
68
+ return self.mask_dir
69
+ else:
70
+ raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
71
+
72
+ def _get_file(self, tab: MediaTab, filename: str) -> Path:
73
+ file_path = self._get_dir(tab) / filename
74
+ if not file_path.exists():
75
+ raise HTTPException(status_code=422, detail=f"file not found: {file_path}")
76
+ return file_path
77
+
78
+ @property
79
+ def thumbnail_directory(self) -> Path:
80
+ return self.output_dir / "thumbnails"
81
+
82
+ @staticmethod
83
+ def _media_names(directory: Path) -> List[MediasResponse]:
84
+ if directory is None:
85
+ return []
86
+ names = sorted([it.name for it in glob_img(directory)])
87
+ res = []
88
+ for name in names:
89
+ path = os.path.join(directory, name)
90
+ img = Image.open(path)
91
+ res.append(
92
+ MediasResponse(
93
+ name=name,
94
+ height=img.height,
95
+ width=img.width,
96
+ ctime=os.path.getctime(path),
97
+ mtime=os.path.getmtime(path),
98
+ )
99
+ )
100
+ return res
101
+
102
+ def get_thumbnail(
103
+ self, directory: Path, original_filename: str, width, height, **options
104
+ ):
105
+ directory = Path(directory)
106
+ storage = FilesystemStorageBackend(self.app)
107
+ crop = options.get("crop", "fit")
108
+ background = options.get("background")
109
+ quality = options.get("quality", 90)
110
+
111
+ original_path, original_filename = os.path.split(original_filename)
112
+ original_filepath = os.path.join(directory, original_path, original_filename)
113
+ image = Image.open(BytesIO(storage.read(original_filepath)))
114
+
115
+ # keep ratio resize
116
+ if not width and not height:
117
+ width = 256
118
+
119
+ if width != 0:
120
+ height = int(image.height * width / image.width)
121
+ else:
122
+ width = int(image.width * height / image.height)
123
+
124
+ thumbnail_size = (width, height)
125
+
126
+ thumbnail_filename = generate_filename(
127
+ directory,
128
+ original_filename,
129
+ aspect_to_string(thumbnail_size),
130
+ crop,
131
+ background,
132
+ quality,
133
+ )
134
+
135
+ thumbnail_filepath = os.path.join(
136
+ self.thumbnail_directory, original_path, thumbnail_filename
137
+ )
138
+
139
+ if storage.exists(thumbnail_filepath):
140
+ return thumbnail_filepath, (width, height)
141
+
142
+ try:
143
+ image.load()
144
+ except (IOError, OSError):
145
+ self.app.logger.warning("Thumbnail not load image: %s", original_filepath)
146
+ return thumbnail_filepath, (width, height)
147
+
148
+ # get original image format
149
+ options["format"] = options.get("format", image.format)
150
+
151
+ image = self._create_thumbnail(
152
+ image, thumbnail_size, crop, background=background
153
+ )
154
+
155
+ raw_data = self.get_raw_data(image, **options)
156
+ storage.save(thumbnail_filepath, raw_data)
157
+
158
+ return thumbnail_filepath, (width, height)
159
+
160
+ def get_raw_data(self, image, **options):
161
+ data = {
162
+ "format": self._get_format(image, **options),
163
+ "quality": options.get("quality", 90),
164
+ }
165
+
166
+ _file = BytesIO()
167
+ image.save(_file, **data)
168
+ return _file.getvalue()
169
+
170
+ @staticmethod
171
+ def colormode(image, colormode="RGB"):
172
+ if colormode == "RGB" or colormode == "RGBA":
173
+ if image.mode == "RGBA":
174
+ return image
175
+ if image.mode == "LA":
176
+ return image.convert("RGBA")
177
+ return image.convert(colormode)
178
+
179
+ if colormode == "GRAY":
180
+ return image.convert("L")
181
+
182
+ return image.convert(colormode)
183
+
184
+ @staticmethod
185
+ def background(original_image, color=0xFF):
186
+ size = (max(original_image.size),) * 2
187
+ image = Image.new("L", size, color)
188
+ image.paste(
189
+ original_image,
190
+ tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))),
191
+ )
192
+
193
+ return image
194
+
195
+ def _get_format(self, image, **options):
196
+ if options.get("format"):
197
+ return options.get("format")
198
+ if image.format:
199
+ return image.format
200
+
201
+ return "JPEG"
202
+
203
+ def _create_thumbnail(self, image, size, crop="fit", background=None):
204
+ try:
205
+ resample = Image.Resampling.LANCZOS
206
+ except AttributeError: # pylint: disable=raise-missing-from
207
+ resample = Image.ANTIALIAS
208
+
209
+ if crop == "fit":
210
+ image = ImageOps.fit(image, size, resample)
211
+ else:
212
+ image = image.copy()
213
+ image.thumbnail(size, resample=resample)
214
+
215
+ if background is not None:
216
+ image = self.background(image)
217
+
218
+ image = self.colormode(image)
219
+
220
+ return image
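FileManager registers its three routes on an existing FastAPI app at construction time, so wiring it up only takes the three directories. A minimal sketch with hypothetical paths:

    from pathlib import Path

    from fastapi import FastAPI

    from sorawm.iopaint.file_manager import FileManager

    app = FastAPI()
    FileManager(
        app,
        input_dir=Path("./examples/images"),   # hypothetical folders
        mask_dir=Path("./examples/masks"),
        output_dir=Path("./examples/output"),
    )
    # Adds GET /api/v1/medias, /api/v1/media_file and /api/v1/media_thumbnail_file;
    # thumbnails are cached under <output_dir>/thumbnails.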
sorawm/iopaint/file_manager/storage_backends.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
2
+ import errno
3
+ import os
4
+ from abc import ABC, abstractmethod
5
+
6
+
7
+ class BaseStorageBackend(ABC):
8
+ def __init__(self, app=None):
9
+ self.app = app
10
+
11
+ @abstractmethod
12
+ def read(self, filepath, mode="rb", **kwargs):
13
+ raise NotImplementedError
14
+
15
+ @abstractmethod
16
+ def exists(self, filepath):
17
+ raise NotImplementedError
18
+
19
+ @abstractmethod
20
+ def save(self, filepath, data):
21
+ raise NotImplementedError
22
+
23
+
24
+ class FilesystemStorageBackend(BaseStorageBackend):
25
+ def read(self, filepath, mode="rb", **kwargs):
26
+ with open(filepath, mode) as f: # pylint: disable=unspecified-encoding
27
+ return f.read()
28
+
29
+ def exists(self, filepath):
30
+ return os.path.exists(filepath)
31
+
32
+ def save(self, filepath, data):
33
+ directory = os.path.dirname(filepath)
34
+
35
+ if not os.path.exists(directory):
36
+ try:
37
+ os.makedirs(directory)
38
+ except OSError as e:
39
+ if e.errno != errno.EEXIST:
40
+ raise
41
+
42
+ if not os.path.isdir(directory):
43
+ raise IOError("{} is not a directory".format(directory))
44
+
45
+ with open(filepath, "wb") as f:
46
+ f.write(data)
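FilesystemStorageBackend is the only backend shipped here; save() creates any missing parent directory before writing, and read()/exists() are thin wrappers over the standard library. A small sketch with a hypothetical path:

    from sorawm.iopaint.file_manager.storage_backends import FilesystemStorageBackend

    storage = FilesystemStorageBackend()
    storage.save("/tmp/iopaint-thumbs/example.jpg", b"\xff\xd8\xff")   # creates /tmp/iopaint-thumbs if needed
    print(storage.exists("/tmp/iopaint-thumbs/example.jpg"))           # True
    data = storage.read("/tmp/iopaint-thumbs/example.jpg")             # bytes read back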
sorawm/iopaint/file_manager/utils.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
2
+ import hashlib
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+
7
+ def generate_filename(directory: Path, original_filename, *options) -> str:
8
+ text = str(directory.absolute()) + original_filename
9
+ for v in options:
10
+ text += "%s" % v
11
+ md5_hash = hashlib.md5()
12
+ md5_hash.update(text.encode("utf-8"))
13
+ return md5_hash.hexdigest() + ".jpg"
14
+
15
+
16
+ def parse_size(size):
17
+ if isinstance(size, int):
18
+ # If the size parameter is a single number, assume square aspect.
19
+ return [size, size]
20
+
21
+ if isinstance(size, (tuple, list)):
22
+ if len(size) == 1:
23
+ # If a single-value tuple/list is provided, expand it to two elements
24
+ return size + type(size)(size)
25
+ return size
26
+
27
+ try:
28
+ thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
29
+ except ValueError:
30
+ raise ValueError( # pylint: disable=raise-missing-from
31
+ "Bad thumbnail size format. Valid format is INTxINT."
32
+ )
33
+
34
+ if len(thumbnail_size) == 1:
35
+ # If the size parameter only contains a single integer, assume square aspect.
36
+ thumbnail_size.append(thumbnail_size[0])
37
+
38
+ return thumbnail_size
39
+
40
+
41
+ def aspect_to_string(size):
42
+ if isinstance(size, str):
43
+ return size
44
+
45
+ return "x".join(map(str, size))
46
+
47
+
48
+ IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}
49
+
50
+
51
+ def glob_img(p: Union[Path, str], recursive: bool = False):
52
+ p = Path(p)
53
+ if p.is_file() and p.suffix in IMG_SUFFIX:
54
+ yield p
55
+ else:
56
+ if recursive:
57
+ files = Path(p).glob("**/*.*")
58
+ else:
59
+ files = Path(p).glob("*.*")
60
+
61
+ for it in files:
62
+ if it.suffix not in IMG_SUFFIX:
63
+ continue
64
+ yield it
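The thumbnail cache key produced by generate_filename is simply an MD5 over the source directory, the original file name and the rendering options, always with a .jpg extension. A small sketch (paths are hypothetical):

    from pathlib import Path

    from sorawm.iopaint.file_manager.utils import aspect_to_string, generate_filename, parse_size

    size = parse_size(256)                       # -> [256, 256] (square aspect)
    name = generate_filename(
        Path("./examples/images"),               # hypothetical input directory
        "cat.png",
        aspect_to_string(size),                  # "256x256"
        "fit", None, 90,                         # crop, background, quality
    )
    print(name)                                  # <32-char md5 hex digest>.jpg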
sorawm/iopaint/helper.py ADDED
@@ -0,0 +1,411 @@
1
+ import base64
2
+ import hashlib
3
+ import imghdr
4
+ import io
5
+ import os
6
+ import sys
7
+ from typing import Dict, List, Optional, Tuple
8
+ from urllib.parse import urlparse
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ from loguru import logger
14
+ from PIL import Image, ImageOps, PngImagePlugin
15
+ from torch.hub import download_url_to_file, get_dir
16
+
17
+ from sorawm.iopaint.const import MPS_UNSUPPORT_MODELS
18
+
19
+
20
+ def md5sum(filename):
21
+ md5 = hashlib.md5()
22
+ with open(filename, "rb") as f:
23
+ for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
24
+ md5.update(chunk)
25
+ return md5.hexdigest()
26
+
27
+
28
+ def switch_mps_device(model_name, device):
29
+ if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
30
+ logger.info(f"{model_name} not support mps, switch to cpu")
31
+ return torch.device("cpu")
32
+ return device
33
+
34
+
35
+ def get_cache_path_by_url(url):
36
+ parts = urlparse(url)
37
+ hub_dir = get_dir()
38
+ model_dir = os.path.join(hub_dir, "checkpoints")
39
+ if not os.path.isdir(model_dir):
40
+ os.makedirs(model_dir)
41
+ filename = os.path.basename(parts.path)
42
+ cached_file = os.path.join(model_dir, filename)
43
+ return cached_file
44
+
45
+
46
+ def download_model(url, model_md5: str = None):
47
+ if os.path.exists(url):
48
+ cached_file = url
49
+ else:
50
+ cached_file = get_cache_path_by_url(url)
51
+ if not os.path.exists(cached_file):
52
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
53
+ hash_prefix = None
54
+ download_url_to_file(url, cached_file, hash_prefix, progress=True)
55
+ if model_md5:
56
+ _md5 = md5sum(cached_file)
57
+ if model_md5 == _md5:
58
+ logger.info(f"Download model success, md5: {_md5}")
59
+ else:
60
+ try:
61
+ os.remove(cached_file)
62
+ logger.error(
63
+ f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart sorawm.iopaint."
64
+ f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
65
+ )
66
+ except:
67
+ logger.error(
68
+ f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart sorawm.iopaint."
69
+ )
70
+ exit(-1)
71
+
72
+ return cached_file
73
+
74
+
75
+ def ceil_modulo(x, mod):
76
+ if x % mod == 0:
77
+ return x
78
+ return (x // mod + 1) * mod
79
+
80
+
81
+ def handle_error(model_path, model_md5, e):
82
+ _md5 = md5sum(model_path)
83
+ if _md5 != model_md5:
84
+ try:
85
+ os.remove(model_path)
86
+ logger.error(
87
+ f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart sorawm.iopaint."
88
+ f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
89
+ )
90
+ except:
91
+ logger.error(
92
+ f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart sorawm.iopaint."
93
+ )
94
+ else:
95
+ logger.error(
96
+ f"Failed to load model {model_path},"
97
+ f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
98
+ )
99
+ exit(-1)
100
+
101
+
102
+ def load_jit_model(url_or_path, device, model_md5: str):
103
+ if os.path.exists(url_or_path):
104
+ model_path = url_or_path
105
+ else:
106
+ model_path = download_model(url_or_path, model_md5)
107
+
108
+ logger.info(f"Loading model from: {model_path}")
109
+ try:
110
+ model = torch.jit.load(model_path, map_location="cpu").to(device)
111
+ except Exception as e:
112
+ handle_error(model_path, model_md5, e)
113
+ model.eval()
114
+ return model
115
+
116
+
117
+ def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
118
+ if os.path.exists(url_or_path):
119
+ model_path = url_or_path
120
+ else:
121
+ model_path = download_model(url_or_path, model_md5)
122
+
123
+ try:
124
+ logger.info(f"Loading model from: {model_path}")
125
+ state_dict = torch.load(model_path, map_location="cpu")
126
+ model.load_state_dict(state_dict, strict=True)
127
+ model.to(device)
128
+ except Exception as e:
129
+ handle_error(model_path, model_md5, e)
130
+ model.eval()
131
+ return model
132
+
133
+
134
+ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
135
+ data = cv2.imencode(
136
+ f".{ext}",
137
+ image_numpy,
138
+ [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
139
+ )[1]
140
+ image_bytes = data.tobytes()
141
+ return image_bytes
142
+
143
+
144
+ def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes:
145
+ with io.BytesIO() as output:
146
+ kwargs = {k: v for k, v in infos.items() if v is not None}
147
+ if ext == "jpg":
148
+ ext = "jpeg"
149
+ if "png" == ext.lower() and "parameters" in kwargs:
150
+ pnginfo_data = PngImagePlugin.PngInfo()
151
+ pnginfo_data.add_text("parameters", kwargs["parameters"])
152
+ kwargs["pnginfo"] = pnginfo_data
153
+
154
+ pil_img.save(output, format=ext, quality=quality, **kwargs)
155
+ image_bytes = output.getvalue()
156
+ return image_bytes
157
+
158
+
159
+ def load_img(img_bytes, gray: bool = False, return_info: bool = False):
160
+ alpha_channel = None
161
+ image = Image.open(io.BytesIO(img_bytes))
162
+
163
+ if return_info:
164
+ infos = image.info
165
+
166
+ try:
167
+ image = ImageOps.exif_transpose(image)
168
+ except:
169
+ pass
170
+
171
+ if gray:
172
+ image = image.convert("L")
173
+ np_img = np.array(image)
174
+ else:
175
+ if image.mode == "RGBA":
176
+ np_img = np.array(image)
177
+ alpha_channel = np_img[:, :, -1]
178
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
179
+ else:
180
+ image = image.convert("RGB")
181
+ np_img = np.array(image)
182
+
183
+ if return_info:
184
+ return np_img, alpha_channel, infos
185
+ return np_img, alpha_channel
186
+
187
+
188
+ def norm_img(np_img):
189
+ if len(np_img.shape) == 2:
190
+ np_img = np_img[:, :, np.newaxis]
191
+ np_img = np.transpose(np_img, (2, 0, 1))
192
+ np_img = np_img.astype("float32") / 255
193
+ return np_img
194
+
195
+
196
+ def resize_max_size(
197
+ np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
198
+ ) -> np.ndarray:
199
+ # Resize image's longer size to size_limit if longer size larger than size_limit
200
+ h, w = np_img.shape[:2]
201
+ if max(h, w) > size_limit:
202
+ ratio = size_limit / max(h, w)
203
+ new_w = int(w * ratio + 0.5)
204
+ new_h = int(h * ratio + 0.5)
205
+ return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
206
+ else:
207
+ return np_img
208
+
209
+
210
+ def pad_img_to_modulo(
211
+ img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
212
+ ):
213
+ """
214
+
215
+ Args:
216
+ img: [H, W, C]
217
+ mod:
218
+ square: whether to pad the output to a square
219
+ min_size:
220
+
221
+ Returns:
222
+
223
+ """
224
+ if len(img.shape) == 2:
225
+ img = img[:, :, np.newaxis]
226
+ height, width = img.shape[:2]
227
+ out_height = ceil_modulo(height, mod)
228
+ out_width = ceil_modulo(width, mod)
229
+
230
+ if min_size is not None:
231
+ assert min_size % mod == 0
232
+ out_width = max(min_size, out_width)
233
+ out_height = max(min_size, out_height)
234
+
235
+ if square:
236
+ max_size = max(out_height, out_width)
237
+ out_height = max_size
238
+ out_width = max_size
239
+
240
+ return np.pad(
241
+ img,
242
+ ((0, out_height - height), (0, out_width - width), (0, 0)),
243
+ mode="symmetric",
244
+ )
245
+
246
+
247
+ def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
248
+ """
249
+ Args:
250
+ mask: (h, w, 1) 0~255
251
+
252
+ Returns:
253
+
254
+ """
255
+ height, width = mask.shape[:2]
256
+ _, thresh = cv2.threshold(mask, 127, 255, 0)
257
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
258
+
259
+ boxes = []
260
+ for cnt in contours:
261
+ x, y, w, h = cv2.boundingRect(cnt)
262
+ box = np.array([x, y, x + w, y + h]).astype(int)
263
+
264
+ box[::2] = np.clip(box[::2], 0, width)
265
+ box[1::2] = np.clip(box[1::2], 0, height)
266
+ boxes.append(box)
267
+
268
+ return boxes
269
+
270
+
271
+ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
272
+ """
273
+ Args:
274
+ mask: (h, w) 0~255
275
+
276
+ Returns:
277
+
278
+ """
279
+ _, thresh = cv2.threshold(mask, 127, 255, 0)
280
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
281
+
282
+ max_area = 0
283
+ max_index = -1
284
+ for i, cnt in enumerate(contours):
285
+ area = cv2.contourArea(cnt)
286
+ if area > max_area:
287
+ max_area = area
288
+ max_index = i
289
+
290
+ if max_index != -1:
291
+ new_mask = np.zeros_like(mask)
292
+ return cv2.drawContours(new_mask, contours, max_index, 255, -1)
293
+ else:
294
+ return mask
295
+
296
+
297
+ def is_mac():
298
+ return sys.platform == "darwin"
299
+
300
+
301
+ def get_image_ext(img_bytes):
302
+ w = imghdr.what("", img_bytes)
303
+ if w is None:
304
+ w = "jpeg"
305
+ return w
306
+
307
+
308
+ def decode_base64_to_image(
309
+ encoding: str, gray=False
310
+ ) -> Tuple[np.ndarray, Optional[np.ndarray], Dict, str]:
311
+ if encoding.startswith("data:image/") or encoding.startswith(
312
+ "data:application/octet-stream;base64,"
313
+ ):
314
+ encoding = encoding.split(";")[1].split(",")[1]
315
+ image_bytes = base64.b64decode(encoding)
316
+ ext = get_image_ext(image_bytes)
317
+ image = Image.open(io.BytesIO(image_bytes))
318
+
319
+ alpha_channel = None
320
+ try:
321
+ image = ImageOps.exif_transpose(image)
322
+ except Exception:
323
+ pass
324
+ # exif_transpose removes the EXIF rotation info, so image.info must be read after calling exif_transpose
325
+ infos = image.info
326
+
327
+ if gray:
328
+ image = image.convert("L")
329
+ np_img = np.array(image)
330
+ else:
331
+ if image.mode == "RGBA":
332
+ np_img = np.array(image)
333
+ alpha_channel = np_img[:, :, -1]
334
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
335
+ else:
336
+ image = image.convert("RGB")
337
+ np_img = np.array(image)
338
+
339
+ return np_img, alpha_channel, infos, ext
340
+
341
+
342
+ def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
343
+ img_bytes = pil_to_bytes(
344
+ image,
345
+ "png",
346
+ quality=quality,
347
+ infos=infos,
348
+ )
349
+ return base64.b64encode(img_bytes)
350
+
351
+
352
+ def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
353
+ if alpha_channel is not None:
354
+ if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
355
+ alpha_channel = cv2.resize(
356
+ alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
357
+ )
358
+ rgb_np_img = np.concatenate(
359
+ (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
360
+ )
361
+ return rgb_np_img
362
+
363
+
364
+ def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
365
+ # frontend brush color "ffcc00bb"
366
+ # kernel_size = kernel_size*2+1
367
+ mask[mask >= 127] = 255
368
+ mask[mask < 127] = 0
369
+
370
+ if operate == "reverse":
371
+ mask = 255 - mask
372
+ else:
373
+ kernel = cv2.getStructuringElement(
374
+ cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
375
+ )
376
+ if operate == "expand":
377
+ mask = cv2.dilate(
378
+ mask,
379
+ kernel,
380
+ iterations=1,
381
+ )
382
+ else:
383
+ mask = cv2.erode(
384
+ mask,
385
+ kernel,
386
+ iterations=1,
387
+ )
388
+ res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
389
+ res_mask[mask > 128] = [255, 203, 0, int(255 * 0.73)]
390
+ res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
391
+ return res_mask
392
+
393
+
394
+ def gen_frontend_mask(bgr_or_gray_mask):
395
+ if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
396
+ bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)
397
+
398
+ # frontend brush color "ffcc00bb"
399
+ # TODO: how to set kernel size?
400
+ kernel_size = 9
401
+ bgr_or_gray_mask = cv2.dilate(
402
+ bgr_or_gray_mask,
403
+ np.ones((kernel_size, kernel_size), np.uint8),
404
+ iterations=1,
405
+ )
406
+ res_mask = np.zeros(
407
+ (bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8
408
+ )
409
+ res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
410
+ res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
411
+ return res_mask
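Taken together, these helpers form the usual preprocessing chain for the inpaint models: decode the upload, cap its size, pad to the model's modulo, and derive boxes from the mask. A minimal sketch of that chain, assuming the module path matches the file location and using a hypothetical test.png as input:

import numpy as np
from sorawm.iopaint.helper import (
    load_img, norm_img, resize_max_size, pad_img_to_modulo, boxes_from_mask,
)

with open("test.png", "rb") as f:             # hypothetical input file
    img_bytes = f.read()

rgb, alpha = load_img(img_bytes)              # HWC uint8 RGB plus optional alpha channel
rgb = resize_max_size(rgb, size_limit=2048)   # cap the longer side at 2048 px
padded = pad_img_to_modulo(rgb, mod=8)        # symmetric-pad H and W up to a multiple of 8
chw = norm_img(padded)                        # CHW float32 in [0, 1]

mask = np.zeros(padded.shape[:2], dtype=np.uint8)
mask[10:50, 20:80] = 255                      # 255 = region of interest
print(boxes_from_mask(mask))                  # [array([20, 10, 80, 50])]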
sorawm/iopaint/installer.py ADDED
@@ -0,0 +1,11 @@
1
+ import subprocess
2
+ import sys
3
+
4
+
5
+ def install(package):
6
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
7
+
8
+
9
+ def install_plugins_package():
10
+ install("onnxruntime<=1.19.2")
11
+ install("rembg[cpu]")
sorawm/iopaint/model/__init__.py ADDED
@@ -0,0 +1,38 @@
1
+ from .anytext.anytext_model import AnyText
2
+ from .controlnet import ControlNet
3
+ from .fcf import FcF
4
+ from .instruct_pix2pix import InstructPix2Pix
5
+ from .kandinsky import Kandinsky22
6
+ from .lama import AnimeLaMa, LaMa
7
+ from .ldm import LDM
8
+ from .manga import Manga
9
+ from .mat import MAT
10
+ from .mi_gan import MIGAN
11
+ from .opencv2 import OpenCV2
12
+ from .paint_by_example import PaintByExample
13
+ from .power_paint.power_paint import PowerPaint
14
+ from .sd import SD, SD2, SD15, Anything4, RealisticVision14
15
+ from .sdxl import SDXL
16
+ from .zits import ZITS
17
+
18
+ models = {
19
+ LaMa.name: LaMa,
20
+ AnimeLaMa.name: AnimeLaMa,
21
+ LDM.name: LDM,
22
+ ZITS.name: ZITS,
23
+ MAT.name: MAT,
24
+ FcF.name: FcF,
25
+ OpenCV2.name: OpenCV2,
26
+ Manga.name: Manga,
27
+ MIGAN.name: MIGAN,
28
+ SD15.name: SD15,
29
+ Anything4.name: Anything4,
30
+ RealisticVision14.name: RealisticVision14,
31
+ SD2.name: SD2,
32
+ PaintByExample.name: PaintByExample,
33
+ InstructPix2Pix.name: InstructPix2Pix,
34
+ Kandinsky22.name: Kandinsky22,
35
+ SDXL.name: SDXL,
36
+ PowerPaint.name: PowerPaint,
37
+ AnyText.name: AnyText,
38
+ }
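The dict doubles as a simple name-to-class registry, so a model can be looked up by the name string each class defines. A minimal sketch, assuming only that each entry's key is the lower-case model name (e.g. that LaMa.name is "lama"); constructor arguments differ per class:

from sorawm.iopaint.model import models

print(sorted(models.keys()))      # all registered model names
ModelCls = models.get("lama")     # assumed name for LaMa; None if the key differs
if ModelCls is not None:
    print(ModelCls.__name__)      # "LaMa"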
sorawm/iopaint/model/anytext/__init__.py ADDED
File without changes
sorawm/iopaint/model/anytext/anytext_model.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+ from huggingface_hub import hf_hub_download
3
+
4
+ from sorawm.iopaint.const import ANYTEXT_NAME
5
+ from sorawm.iopaint.model.anytext.anytext_pipeline import AnyTextPipeline
6
+ from sorawm.iopaint.model.base import DiffusionInpaintModel
7
+ from sorawm.iopaint.model.utils import get_torch_dtype, is_local_files_only
8
+ from sorawm.iopaint.schema import InpaintRequest
9
+
10
+
11
+ class AnyText(DiffusionInpaintModel):
12
+ name = ANYTEXT_NAME
13
+ pad_mod = 64
14
+ is_erase_model = False
15
+
16
+ @staticmethod
17
+ def download(local_files_only=False):
18
+ hf_hub_download(
19
+ repo_id=ANYTEXT_NAME,
20
+ filename="model_index.json",
21
+ local_files_only=local_files_only,
22
+ )
23
+ ckpt_path = hf_hub_download(
24
+ repo_id=ANYTEXT_NAME,
25
+ filename="pytorch_model.fp16.safetensors",
26
+ local_files_only=local_files_only,
27
+ )
28
+ font_path = hf_hub_download(
29
+ repo_id=ANYTEXT_NAME,
30
+ filename="SourceHanSansSC-Medium.otf",
31
+ local_files_only=local_files_only,
32
+ )
33
+ return ckpt_path, font_path
34
+
35
+ def init_model(self, device, **kwargs):
36
+ local_files_only = is_local_files_only(**kwargs)
37
+ ckpt_path, font_path = self.download(local_files_only)
38
+ use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
39
+ self.model = AnyTextPipeline(
40
+ ckpt_path=ckpt_path,
41
+ font_path=font_path,
42
+ device=device,
43
+ use_fp16=torch_dtype == torch.float16,
44
+ )
45
+ self.callback = kwargs.pop("callback", None)
46
+
47
+ def forward(self, image, mask, config: InpaintRequest):
48
+ """Input and output images have the same size
49
+ image: [H, W, C] RGB
50
+ mask: [H, W, 1], where 255 marks the area to inpaint
51
+ return: BGR IMAGE
52
+ """
53
+ height, width = image.shape[:2]
54
+ mask = mask.astype("float32") / 255.0
55
+ masked_image = image * (1 - mask)
56
+
57
+ # list of rgb ndarray
58
+ results, rtn_code, rtn_warning = self.model(
59
+ image=image,
60
+ masked_image=masked_image,
61
+ prompt=config.prompt,
62
+ negative_prompt=config.negative_prompt,
63
+ num_inference_steps=config.sd_steps,
64
+ strength=config.sd_strength,
65
+ guidance_scale=config.sd_guidance_scale,
66
+ height=height,
67
+ width=width,
68
+ seed=config.sd_seed,
69
+ sort_priority="y",
70
+ callback=self.callback,
71
+ )
72
+ inpainted_bgr_image = results[0][..., ::-1]  # results are RGB; reverse channels to BGR per the docstring
73
+ return inpainted_bgr_image
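Since forward only reads a handful of fields from the InpaintRequest config, the expected call can be sketched with a stand-in object. A rough sketch under the assumption that the fields used above (prompt, negative_prompt, sd_steps, sd_strength, sd_guidance_scale, sd_seed) are enough; the real schema in sorawm.iopaint.schema may define more:

from types import SimpleNamespace
import numpy as np

config = SimpleNamespace(              # stand-in for InpaintRequest, fields as used in forward()
    prompt='a street sign that says "OPEN"',
    negative_prompt="low quality, blurry",
    sd_steps=20,
    sd_strength=1.0,
    sd_guidance_scale=9.0,
    sd_seed=42,
)

image = np.zeros((512, 512, 3), dtype=np.uint8)   # RGB input, [H, W, C]
mask = np.zeros((512, 512, 1), dtype=np.uint8)
mask[200:260, 100:400] = 255                      # 255 marks the region to rewrite

# model = AnyText(device="cuda")                  # constructor kwargs assumed; weights download on first use
# bgr_out = model.forward(image, mask, config)    # BGR result with the same H and W as the input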
sorawm/iopaint/model/anytext/anytext_pipeline.py ADDED
@@ -0,0 +1,401 @@
1
+ """
2
+ AnyText: Multilingual Visual Text Generation And Editing
3
+ Paper: https://arxiv.org/abs/2311.03054
4
+ Code: https://github.com/tyxsspa/AnyText
5
+ Copyright (c) Alibaba, Inc. and its affiliates.
6
+ """
7
+ import os
8
+ from pathlib import Path
9
+
10
+ from safetensors.torch import load_file
11
+
12
+ from sorawm.iopaint.model.utils import set_seed
13
+
14
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
15
+ import re
16
+
17
+ import cv2
18
+ import einops
19
+ import numpy as np
20
+ import torch
21
+ from PIL import ImageFont
22
+
23
+ from sorawm.iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
24
+ from sorawm.iopaint.model.anytext.cldm.model import create_model, load_state_dict
25
+ from sorawm.iopaint.model.anytext.utils import check_channels, draw_glyph, draw_glyph2
26
+
27
+ BBOX_MAX_NUM = 8
28
+ PLACE_HOLDER = "*"
29
+ max_chars = 20
30
+
31
+ ANYTEXT_CFG = os.path.join(
32
+ os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
33
+ )
34
+
35
+
36
+ def check_limits(tensor):
37
+ float16_min = torch.finfo(torch.float16).min
38
+ float16_max = torch.finfo(torch.float16).max
39
+
40
+ # Check whether any value in the tensor is below the float16 minimum or above the float16 maximum
41
+ is_below_min = (tensor < float16_min).any()
42
+ is_above_max = (tensor > float16_max).any()
43
+
44
+ return is_below_min or is_above_max
45
+
46
+
47
+ class AnyTextPipeline:
48
+ def __init__(self, ckpt_path, font_path, device, use_fp16=True):
49
+ self.cfg_path = ANYTEXT_CFG
50
+ self.font_path = font_path
51
+ self.use_fp16 = use_fp16
52
+ self.device = device
53
+
54
+ self.font = ImageFont.truetype(font_path, size=60)
55
+ self.model = create_model(
56
+ self.cfg_path,
57
+ device=self.device,
58
+ use_fp16=self.use_fp16,
59
+ )
60
+ if self.use_fp16:
61
+ self.model = self.model.half()
62
+ if Path(ckpt_path).suffix == ".safetensors":
63
+ state_dict = load_file(ckpt_path, device="cpu")
64
+ else:
65
+ state_dict = load_state_dict(ckpt_path, location="cpu")
66
+ self.model.load_state_dict(state_dict, strict=False)
67
+ self.model = self.model.eval().to(self.device)
68
+ self.ddim_sampler = DDIMSampler(self.model, device=self.device)
69
+
70
+ def __call__(
71
+ self,
72
+ prompt: str,
73
+ negative_prompt: str,
74
+ image: np.ndarray,
75
+ masked_image: np.ndarray,
76
+ num_inference_steps: int,
77
+ strength: float,
78
+ guidance_scale: float,
79
+ height: int,
80
+ width: int,
81
+ seed: int,
82
+ sort_priority: str = "y",
83
+ callback=None,
84
+ ):
85
+ """
86
+
87
+ Args:
88
+ prompt:
89
+ negative_prompt:
90
+ image:
91
+ masked_image:
92
+ num_inference_steps:
93
+ strength:
94
+ guidance_scale:
95
+ height:
96
+ width:
97
+ seed:
98
+ sort_priority: x: left-right, y: top-down
99
+
100
+ Returns:
101
+ result: list of images in numpy.ndarray format
102
+ rst_code: 0: normal -1: error 1:warning
103
+ rst_info: string of error or warning
104
+
105
+ """
106
+ set_seed(seed)
107
+ str_warning = ""
108
+
109
+ mode = "text-editing"
110
+ revise_pos = False
111
+ img_count = 1
112
+ ddim_steps = num_inference_steps
113
+ w = width
114
+ h = height
115
+ strength = strength
116
+ cfg_scale = guidance_scale
117
+ eta = 0.0
118
+
119
+ prompt, texts = self.modify_prompt(prompt)
120
+ if prompt is None and texts is None:
121
+ return (
122
+ None,
123
+ -1,
124
+ "You have input Chinese prompt but the translator is not loaded!",
125
+ "",
126
+ )
127
+ n_lines = len(texts)
128
+ if mode in ["text-generation", "gen"]:
129
+ edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
130
+ elif mode in ["text-editing", "edit"]:
131
+ if masked_image is None or image is None:
132
+ return (
133
+ None,
134
+ -1,
135
+ "Reference image and position image are needed for text editing!",
136
+ "",
137
+ )
138
+ if isinstance(image, str):
139
+ image = cv2.imread(image)[..., ::-1]
140
+ assert image is not None, f"Can't read ori_image from {image}!"
141
+ elif isinstance(image, torch.Tensor):
142
+ image = image.cpu().numpy()
143
+ else:
144
+ assert isinstance(
145
+ image, np.ndarray
146
+ ), f"Unknown format of ori_image: {type(image)}"
147
+ edit_image = image.clip(1, 255) # for mask reason
148
+ edit_image = check_channels(edit_image)
149
+ # edit_image = resize_image(
150
+ # edit_image, max_length=768
151
+ # ) # make w h multiple of 64, resize if w or h > max_length
152
+ h, w = edit_image.shape[:2] # change h, w by input ref_img
153
+ # preprocess pos_imgs (if numpy, make sure positions are white on a black background)
154
+ if masked_image is None:
155
+ pos_imgs = np.zeros((w, h, 1))
156
+ if isinstance(masked_image, str):
157
+ masked_image = cv2.imread(masked_image)[..., ::-1]
158
+ assert (
159
+ masked_image is not None
160
+ ), f"Can't read draw_pos image from {masked_image}!"
161
+ pos_imgs = 255 - masked_image
162
+ elif isinstance(masked_image, torch.Tensor):
163
+ pos_imgs = masked_image.cpu().numpy()
164
+ else:
165
+ assert isinstance(
166
+ masked_image, np.ndarray
167
+ ), f"Unknown format of draw_pos: {type(masked_image)}"
168
+ pos_imgs = 255 - masked_image
169
+ pos_imgs = pos_imgs[..., 0:1]
170
+ pos_imgs = cv2.convertScaleAbs(pos_imgs)
171
+ _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
172
+ # separate pos_imgs
173
+ pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
174
+ if len(pos_imgs) == 0:
175
+ pos_imgs = [np.zeros((h, w, 1))]
176
+ if len(pos_imgs) < n_lines:
177
+ if n_lines == 1 and texts[0] == " ":
178
+ pass # text-to-image without text
179
+ else:
180
+ raise RuntimeError(
181
+ f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
182
+ )
183
+ elif len(pos_imgs) > n_lines:
184
+ str_warning = f"Warning: found {len(pos_imgs)} positions, more than the {n_lines} needed from the prompt."
185
+ # get pre_pos, poly_list, hint that needed for anytext
186
+ pre_pos = []
187
+ poly_list = []
188
+ for input_pos in pos_imgs:
189
+ if input_pos.mean() != 0:
190
+ input_pos = (
191
+ input_pos[..., np.newaxis]
192
+ if len(input_pos.shape) == 2
193
+ else input_pos
194
+ )
195
+ poly, pos_img = self.find_polygon(input_pos)
196
+ pre_pos += [pos_img / 255.0]
197
+ poly_list += [poly]
198
+ else:
199
+ pre_pos += [np.zeros((h, w, 1))]
200
+ poly_list += [None]
201
+ np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
202
+ # prepare info dict
203
+ info = {}
204
+ info["glyphs"] = []
205
+ info["gly_line"] = []
206
+ info["positions"] = []
207
+ info["n_lines"] = [len(texts)] * img_count
208
+ gly_pos_imgs = []
209
+ for i in range(len(texts)):
210
+ text = texts[i]
211
+ if len(text) > max_chars:
212
+ str_warning = (
213
+ f'"{text}" length > max_chars: {max_chars}, will be cut off...'
214
+ )
215
+ text = text[:max_chars]
216
+ gly_scale = 2
217
+ if pre_pos[i].mean() != 0:
218
+ gly_line = draw_glyph(self.font, text)
219
+ glyphs = draw_glyph2(
220
+ self.font,
221
+ text,
222
+ poly_list[i],
223
+ scale=gly_scale,
224
+ width=w,
225
+ height=h,
226
+ add_space=False,
227
+ )
228
+ gly_pos_img = cv2.drawContours(
229
+ glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
230
+ )
231
+ if revise_pos:
232
+ resize_gly = cv2.resize(
233
+ glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
234
+ )
235
+ new_pos = cv2.morphologyEx(
236
+ (resize_gly * 255).astype(np.uint8),
237
+ cv2.MORPH_CLOSE,
238
+ kernel=np.ones(
239
+ (resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
240
+ dtype=np.uint8,
241
+ ),
242
+ iterations=1,
243
+ )
244
+ new_pos = (
245
+ new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
246
+ )
247
+ contours, _ = cv2.findContours(
248
+ new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
249
+ )
250
+ if len(contours) != 1:
251
+ str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
252
+ else:
253
+ rect = cv2.minAreaRect(contours[0])
254
+ poly = np.int0(cv2.boxPoints(rect))
255
+ pre_pos[i] = (
256
+ cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
257
+ )
258
+ gly_pos_img = cv2.drawContours(
259
+ glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
260
+ )
261
+ gly_pos_imgs += [gly_pos_img] # for show
262
+ else:
263
+ glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
264
+ gly_line = np.zeros((80, 512, 1))
265
+ gly_pos_imgs += [
266
+ np.zeros((h * gly_scale, w * gly_scale, 1))
267
+ ] # for show
268
+ pos = pre_pos[i]
269
+ info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
270
+ info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
271
+ info["positions"] += [self.arr2tensor(pos, img_count)]
272
+ # get masked_x
273
+ masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
274
+ masked_img = np.transpose(masked_img, (2, 0, 1))
275
+ masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
276
+ if self.use_fp16:
277
+ masked_img = masked_img.half()
278
+ encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
279
+ masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
280
+ if self.use_fp16:
281
+ masked_x = masked_x.half()
282
+ info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)
283
+
284
+ hint = self.arr2tensor(np_hint, img_count)
285
+ cond = self.model.get_learned_conditioning(
286
+ dict(
287
+ c_concat=[hint],
288
+ c_crossattn=[[prompt] * img_count],
289
+ text_info=info,
290
+ )
291
+ )
292
+ un_cond = self.model.get_learned_conditioning(
293
+ dict(
294
+ c_concat=[hint],
295
+ c_crossattn=[[negative_prompt] * img_count],
296
+ text_info=info,
297
+ )
298
+ )
299
+ shape = (4, h // 8, w // 8)
300
+ self.model.control_scales = [strength] * 13
301
+ samples, intermediates = self.ddim_sampler.sample(
302
+ ddim_steps,
303
+ img_count,
304
+ shape,
305
+ cond,
306
+ verbose=False,
307
+ eta=eta,
308
+ unconditional_guidance_scale=cfg_scale,
309
+ unconditional_conditioning=un_cond,
310
+ callback=callback,
311
+ )
312
+ if self.use_fp16:
313
+ samples = samples.half()
314
+ x_samples = self.model.decode_first_stage(samples)
315
+ x_samples = (
316
+ (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
317
+ .cpu()
318
+ .numpy()
319
+ .clip(0, 255)
320
+ .astype(np.uint8)
321
+ )
322
+ results = [x_samples[i] for i in range(img_count)]
323
+ # if (
324
+ # mode == "edit" and False
325
+ # ): # replace background in text editing, but not ideal yet
326
+ # results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
327
+ # results = [r.clip(0, 255).astype(np.uint8) for r in results]
328
+ # if len(gly_pos_imgs) > 0 and show_debug:
329
+ # glyph_bs = np.stack(gly_pos_imgs, axis=2)
330
+ # glyph_img = np.sum(glyph_bs, axis=2) * 255
331
+ # glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
332
+ # results += [np.repeat(glyph_img, 3, axis=2)]
333
+ rst_code = 1 if str_warning else 0
334
+ return results, rst_code, str_warning
335
+
336
+ def modify_prompt(self, prompt):
337
+ prompt = prompt.replace("“", '"')
338
+ prompt = prompt.replace("”", '"')
339
+ p = '"(.*?)"'
340
+ strs = re.findall(p, prompt)
341
+ if len(strs) == 0:
342
+ strs = [" "]
343
+ else:
344
+ for s in strs:
345
+ prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
346
+ # if self.is_chinese(prompt):
347
+ # if self.trans_pipe is None:
348
+ # return None, None
349
+ # old_prompt = prompt
350
+ # prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
351
+ # print(f"Translate: {old_prompt} --> {prompt}")
352
+ return prompt, strs
353
+
354
+ # def is_chinese(self, text):
355
+ # text = checker._clean_text(text)
356
+ # for char in text:
357
+ # cp = ord(char)
358
+ # if checker._is_chinese_char(cp):
359
+ # return True
360
+ # return False
361
+
362
+ def separate_pos_imgs(self, img, sort_priority, gap=102):
363
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
364
+ components = []
365
+ for label in range(1, num_labels):
366
+ component = np.zeros_like(img)
367
+ component[labels == label] = 255
368
+ components.append((component, centroids[label]))
369
+ if sort_priority == "y":
370
+ fir, sec = 1, 0 # top-down first
371
+ elif sort_priority == "x":
372
+ fir, sec = 0, 1 # left-right first
373
+ components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
374
+ sorted_components = [c[0] for c in components]
375
+ return sorted_components
376
+
377
+ def find_polygon(self, image, min_rect=False):
378
+ contours, hierarchy = cv2.findContours(
379
+ image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
380
+ )
381
+ max_contour = max(contours, key=cv2.contourArea) # get contour with max area
382
+ if min_rect:
383
+ # get minimum enclosing rectangle
384
+ rect = cv2.minAreaRect(max_contour)
385
+ poly = np.int0(cv2.boxPoints(rect))
386
+ else:
387
+ # get approximate polygon
388
+ epsilon = 0.01 * cv2.arcLength(max_contour, True)
389
+ poly = cv2.approxPolyDP(max_contour, epsilon, True)
390
+ n, _, xy = poly.shape
391
+ poly = poly.reshape(n, xy)
392
+ cv2.drawContours(image, [poly], -1, 255, -1)
393
+ return poly, image
394
+
395
+ def arr2tensor(self, arr, bs):
396
+ arr = np.transpose(arr, (2, 0, 1))
397
+ _arr = torch.from_numpy(arr.copy()).float().to(self.device)
398
+ if self.use_fp16:
399
+ _arr = _arr.half()
400
+ _arr = torch.stack([_arr for _ in range(bs)], dim=0)
401
+ return _arr
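modify_prompt pulls every double-quoted span out of the prompt and swaps it for the * placeholder that the glyph/embedding machinery keys on. A standalone sketch of that behaviour, mirroring the regex and placeholder used above and runnable without loading any weights:

import re

PLACE_HOLDER = "*"

def extract_text_lines(prompt: str):
    # Normalize curly quotes, collect the quoted spans, replace each with the placeholder.
    prompt = prompt.replace("“", '"').replace("”", '"')
    strs = re.findall('"(.*?)"', prompt)
    if not strs:
        return prompt, [" "]          # no quoted text: single blank line, prompt untouched
    for s in strs:
        prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
    return prompt, strs

new_prompt, texts = extract_text_lines('A neon sign that says "OPEN" and "24H"')
print(texts)        # ['OPEN', '24H']
print(new_prompt)   # 'A neon sign that says  *  and  * ' (placeholders padded with spaces)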
sorawm/iopaint/model/anytext/anytext_sd15.yaml ADDED
@@ -0,0 +1,99 @@
1
+ model:
2
+ target: sorawm.iopaint.model.anytext.cldm.cldm.ControlLDM
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ num_timesteps_cond: 1
7
+ log_every_t: 200
8
+ timesteps: 1000
9
+ first_stage_key: "img"
10
+ cond_stage_key: "caption"
11
+ control_key: "hint"
12
+ glyph_key: "glyphs"
13
+ position_key: "positions"
14
+ image_size: 64
15
+ channels: 4
16
+ cond_stage_trainable: true # need be true when embedding_manager is valid
17
+ conditioning_key: crossattn
18
+ monitor: val/loss_simple_ema
19
+ scale_factor: 0.18215
20
+ use_ema: False
21
+ only_mid_control: False
22
+ loss_alpha: 0 # perceptual loss, 0.003
23
+ loss_beta: 0 # ctc loss
24
+ latin_weight: 1.0 # latin text line may need smaller weigth
25
+ with_step_weight: true
26
+ use_vae_upsample: true
27
+ embedding_manager_config:
28
+ target: sorawm.iopaint.model.anytext.cldm.embedding_manager.EmbeddingManager
29
+ params:
30
+ valid: true # v6
31
+ emb_type: ocr # ocr, vit, conv
32
+ glyph_channels: 1
33
+ position_channels: 1
34
+ add_pos: false
35
+ placeholder_string: '*'
36
+
37
+ control_stage_config:
38
+ target: sorawm.iopaint.model.anytext.cldm.cldm.ControlNet
39
+ params:
40
+ image_size: 32 # unused
41
+ in_channels: 4
42
+ model_channels: 320
43
+ glyph_channels: 1
44
+ position_channels: 1
45
+ attention_resolutions: [ 4, 2, 1 ]
46
+ num_res_blocks: 2
47
+ channel_mult: [ 1, 2, 4, 4 ]
48
+ num_heads: 8
49
+ use_spatial_transformer: True
50
+ transformer_depth: 1
51
+ context_dim: 768
52
+ use_checkpoint: True
53
+ legacy: False
54
+
55
+ unet_config:
56
+ target: sorawm.iopaint.model.anytext.cldm.cldm.ControlledUnetModel
57
+ params:
58
+ image_size: 32 # unused
59
+ in_channels: 4
60
+ out_channels: 4
61
+ model_channels: 320
62
+ attention_resolutions: [ 4, 2, 1 ]
63
+ num_res_blocks: 2
64
+ channel_mult: [ 1, 2, 4, 4 ]
65
+ num_heads: 8
66
+ use_spatial_transformer: True
67
+ transformer_depth: 1
68
+ context_dim: 768
69
+ use_checkpoint: True
70
+ legacy: False
71
+
72
+ first_stage_config:
73
+ target: sorawm.iopaint.model.anytext.ldm.models.autoencoder.AutoencoderKL
74
+ params:
75
+ embed_dim: 4
76
+ monitor: val/rec_loss
77
+ ddconfig:
78
+ double_z: true
79
+ z_channels: 4
80
+ resolution: 256
81
+ in_channels: 3
82
+ out_ch: 3
83
+ ch: 128
84
+ ch_mult:
85
+ - 1
86
+ - 2
87
+ - 4
88
+ - 4
89
+ num_res_blocks: 2
90
+ attn_resolutions: []
91
+ dropout: 0.0
92
+ lossconfig:
93
+ target: torch.nn.Identity
94
+
95
+ cond_stage_config:
96
+ target: sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
97
+ params:
98
+ version: openai/clip-vit-large-patch14
99
+ use_vision: false # v6
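Every node in this config follows the ldm-style target/params convention, so the object graph can be built generically rather than by hand. A minimal sketch, under the assumption that OmegaConf is available and that instantiate_from_config imports the dotted target path and passes the params block as keyword arguments (the convention used throughout this vendored ldm code):

from omegaconf import OmegaConf
from sorawm.iopaint.model.anytext.ldm.util import instantiate_from_config

cfg = OmegaConf.load("sorawm/iopaint/model/anytext/anytext_sd15.yaml")
print(cfg.model.target)                        # sorawm.iopaint.model.anytext.cldm.cldm.ControlLDM
# model = instantiate_from_config(cfg.model)   # would build ControlLDM and its sub-modules from the params above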
sorawm/iopaint/model/anytext/cldm/__init__.py ADDED
File without changes
sorawm/iopaint/model/anytext/cldm/cldm.py ADDED
@@ -0,0 +1,780 @@
1
+ import copy
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import einops
6
+ import torch
7
+ import torch as th
8
+ import torch.nn as nn
9
+ from easydict import EasyDict as edict
10
+ from einops import rearrange, repeat
11
+
12
+ from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
13
+ from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddpm import LatentDiffusion
14
+ from sorawm.iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
15
+ from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.openaimodel import (
16
+ AttentionBlock,
17
+ Downsample,
18
+ ResBlock,
19
+ TimestepEmbedSequential,
20
+ UNetModel,
21
+ )
22
+ from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
23
+ conv_nd,
24
+ linear,
25
+ timestep_embedding,
26
+ zero_module,
27
+ )
28
+ from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
29
+ DiagonalGaussianDistribution,
30
+ )
31
+ from sorawm.iopaint.model.anytext.ldm.util import (
32
+ exists,
33
+ instantiate_from_config,
34
+ log_txt_as_img,
35
+ )
36
+
37
+ from .recognizer import TextRecognizer, create_predictor
38
+
39
+ CURRENT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
40
+
41
+
42
+ def count_parameters(model):
43
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
44
+
45
+
46
+ class ControlledUnetModel(UNetModel):
47
+ def forward(
48
+ self,
49
+ x,
50
+ timesteps=None,
51
+ context=None,
52
+ control=None,
53
+ only_mid_control=False,
54
+ **kwargs,
55
+ ):
56
+ hs = []
57
+ with torch.no_grad():
58
+ t_emb = timestep_embedding(
59
+ timesteps, self.model_channels, repeat_only=False
60
+ )
61
+ if self.use_fp16:
62
+ t_emb = t_emb.half()
63
+ emb = self.time_embed(t_emb)
64
+ h = x.type(self.dtype)
65
+ for module in self.input_blocks:
66
+ h = module(h, emb, context)
67
+ hs.append(h)
68
+ h = self.middle_block(h, emb, context)
69
+
70
+ if control is not None:
71
+ h += control.pop()
72
+
73
+ for i, module in enumerate(self.output_blocks):
74
+ if only_mid_control or control is None:
75
+ h = torch.cat([h, hs.pop()], dim=1)
76
+ else:
77
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
78
+ h = module(h, emb, context)
79
+
80
+ h = h.type(x.dtype)
81
+ return self.out(h)
82
+
83
+
84
+ class ControlNet(nn.Module):
85
+ def __init__(
86
+ self,
87
+ image_size,
88
+ in_channels,
89
+ model_channels,
90
+ glyph_channels,
91
+ position_channels,
92
+ num_res_blocks,
93
+ attention_resolutions,
94
+ dropout=0,
95
+ channel_mult=(1, 2, 4, 8),
96
+ conv_resample=True,
97
+ dims=2,
98
+ use_checkpoint=False,
99
+ use_fp16=False,
100
+ num_heads=-1,
101
+ num_head_channels=-1,
102
+ num_heads_upsample=-1,
103
+ use_scale_shift_norm=False,
104
+ resblock_updown=False,
105
+ use_new_attention_order=False,
106
+ use_spatial_transformer=False, # custom transformer support
107
+ transformer_depth=1, # custom transformer support
108
+ context_dim=None, # custom transformer support
109
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
110
+ legacy=True,
111
+ disable_self_attentions=None,
112
+ num_attention_blocks=None,
113
+ disable_middle_self_attn=False,
114
+ use_linear_in_transformer=False,
115
+ ):
116
+ super().__init__()
117
+ if use_spatial_transformer:
118
+ assert (
119
+ context_dim is not None
120
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
121
+
122
+ if context_dim is not None:
123
+ assert (
124
+ use_spatial_transformer
125
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
126
+ from omegaconf.listconfig import ListConfig
127
+
128
+ if type(context_dim) == ListConfig:
129
+ context_dim = list(context_dim)
130
+
131
+ if num_heads_upsample == -1:
132
+ num_heads_upsample = num_heads
133
+
134
+ if num_heads == -1:
135
+ assert (
136
+ num_head_channels != -1
137
+ ), "Either num_heads or num_head_channels has to be set"
138
+
139
+ if num_head_channels == -1:
140
+ assert (
141
+ num_heads != -1
142
+ ), "Either num_heads or num_head_channels has to be set"
143
+ self.dims = dims
144
+ self.image_size = image_size
145
+ self.in_channels = in_channels
146
+ self.model_channels = model_channels
147
+ if isinstance(num_res_blocks, int):
148
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
149
+ else:
150
+ if len(num_res_blocks) != len(channel_mult):
151
+ raise ValueError(
152
+ "provide num_res_blocks either as an int (globally constant) or "
153
+ "as a list/tuple (per-level) with the same length as channel_mult"
154
+ )
155
+ self.num_res_blocks = num_res_blocks
156
+ if disable_self_attentions is not None:
157
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
158
+ assert len(disable_self_attentions) == len(channel_mult)
159
+ if num_attention_blocks is not None:
160
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
161
+ assert all(
162
+ map(
163
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
164
+ range(len(num_attention_blocks)),
165
+ )
166
+ )
167
+ print(
168
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
169
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
170
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
171
+ f"attention will still not be set."
172
+ )
173
+ self.attention_resolutions = attention_resolutions
174
+ self.dropout = dropout
175
+ self.channel_mult = channel_mult
176
+ self.conv_resample = conv_resample
177
+ self.use_checkpoint = use_checkpoint
178
+ self.use_fp16 = use_fp16
179
+ self.dtype = th.float16 if use_fp16 else th.float32
180
+ self.num_heads = num_heads
181
+ self.num_head_channels = num_head_channels
182
+ self.num_heads_upsample = num_heads_upsample
183
+ self.predict_codebook_ids = n_embed is not None
184
+
185
+ time_embed_dim = model_channels * 4
186
+ self.time_embed = nn.Sequential(
187
+ linear(model_channels, time_embed_dim),
188
+ nn.SiLU(),
189
+ linear(time_embed_dim, time_embed_dim),
190
+ )
191
+
192
+ self.input_blocks = nn.ModuleList(
193
+ [
194
+ TimestepEmbedSequential(
195
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
196
+ )
197
+ ]
198
+ )
199
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
200
+
201
+ self.glyph_block = TimestepEmbedSequential(
202
+ conv_nd(dims, glyph_channels, 8, 3, padding=1),
203
+ nn.SiLU(),
204
+ conv_nd(dims, 8, 8, 3, padding=1),
205
+ nn.SiLU(),
206
+ conv_nd(dims, 8, 16, 3, padding=1, stride=2),
207
+ nn.SiLU(),
208
+ conv_nd(dims, 16, 16, 3, padding=1),
209
+ nn.SiLU(),
210
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
211
+ nn.SiLU(),
212
+ conv_nd(dims, 32, 32, 3, padding=1),
213
+ nn.SiLU(),
214
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
215
+ nn.SiLU(),
216
+ conv_nd(dims, 96, 96, 3, padding=1),
217
+ nn.SiLU(),
218
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
219
+ nn.SiLU(),
220
+ )
221
+
222
+ self.position_block = TimestepEmbedSequential(
223
+ conv_nd(dims, position_channels, 8, 3, padding=1),
224
+ nn.SiLU(),
225
+ conv_nd(dims, 8, 8, 3, padding=1),
226
+ nn.SiLU(),
227
+ conv_nd(dims, 8, 16, 3, padding=1, stride=2),
228
+ nn.SiLU(),
229
+ conv_nd(dims, 16, 16, 3, padding=1),
230
+ nn.SiLU(),
231
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
232
+ nn.SiLU(),
233
+ conv_nd(dims, 32, 32, 3, padding=1),
234
+ nn.SiLU(),
235
+ conv_nd(dims, 32, 64, 3, padding=1, stride=2),
236
+ nn.SiLU(),
237
+ )
238
+
239
+ self.fuse_block = zero_module(
240
+ conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1)
241
+ )
242
+
243
+ self._feature_size = model_channels
244
+ input_block_chans = [model_channels]
245
+ ch = model_channels
246
+ ds = 1
247
+ for level, mult in enumerate(channel_mult):
248
+ for nr in range(self.num_res_blocks[level]):
249
+ layers = [
250
+ ResBlock(
251
+ ch,
252
+ time_embed_dim,
253
+ dropout,
254
+ out_channels=mult * model_channels,
255
+ dims=dims,
256
+ use_checkpoint=use_checkpoint,
257
+ use_scale_shift_norm=use_scale_shift_norm,
258
+ )
259
+ ]
260
+ ch = mult * model_channels
261
+ if ds in attention_resolutions:
262
+ if num_head_channels == -1:
263
+ dim_head = ch // num_heads
264
+ else:
265
+ num_heads = ch // num_head_channels
266
+ dim_head = num_head_channels
267
+ if legacy:
268
+ # num_heads = 1
269
+ dim_head = (
270
+ ch // num_heads
271
+ if use_spatial_transformer
272
+ else num_head_channels
273
+ )
274
+ if exists(disable_self_attentions):
275
+ disabled_sa = disable_self_attentions[level]
276
+ else:
277
+ disabled_sa = False
278
+
279
+ if (
280
+ not exists(num_attention_blocks)
281
+ or nr < num_attention_blocks[level]
282
+ ):
283
+ layers.append(
284
+ AttentionBlock(
285
+ ch,
286
+ use_checkpoint=use_checkpoint,
287
+ num_heads=num_heads,
288
+ num_head_channels=dim_head,
289
+ use_new_attention_order=use_new_attention_order,
290
+ )
291
+ if not use_spatial_transformer
292
+ else SpatialTransformer(
293
+ ch,
294
+ num_heads,
295
+ dim_head,
296
+ depth=transformer_depth,
297
+ context_dim=context_dim,
298
+ disable_self_attn=disabled_sa,
299
+ use_linear=use_linear_in_transformer,
300
+ use_checkpoint=use_checkpoint,
301
+ )
302
+ )
303
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
304
+ self.zero_convs.append(self.make_zero_conv(ch))
305
+ self._feature_size += ch
306
+ input_block_chans.append(ch)
307
+ if level != len(channel_mult) - 1:
308
+ out_ch = ch
309
+ self.input_blocks.append(
310
+ TimestepEmbedSequential(
311
+ ResBlock(
312
+ ch,
313
+ time_embed_dim,
314
+ dropout,
315
+ out_channels=out_ch,
316
+ dims=dims,
317
+ use_checkpoint=use_checkpoint,
318
+ use_scale_shift_norm=use_scale_shift_norm,
319
+ down=True,
320
+ )
321
+ if resblock_updown
322
+ else Downsample(
323
+ ch, conv_resample, dims=dims, out_channels=out_ch
324
+ )
325
+ )
326
+ )
327
+ ch = out_ch
328
+ input_block_chans.append(ch)
329
+ self.zero_convs.append(self.make_zero_conv(ch))
330
+ ds *= 2
331
+ self._feature_size += ch
332
+
333
+ if num_head_channels == -1:
334
+ dim_head = ch // num_heads
335
+ else:
336
+ num_heads = ch // num_head_channels
337
+ dim_head = num_head_channels
338
+ if legacy:
339
+ # num_heads = 1
340
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
341
+ self.middle_block = TimestepEmbedSequential(
342
+ ResBlock(
343
+ ch,
344
+ time_embed_dim,
345
+ dropout,
346
+ dims=dims,
347
+ use_checkpoint=use_checkpoint,
348
+ use_scale_shift_norm=use_scale_shift_norm,
349
+ ),
350
+ AttentionBlock(
351
+ ch,
352
+ use_checkpoint=use_checkpoint,
353
+ num_heads=num_heads,
354
+ num_head_channels=dim_head,
355
+ use_new_attention_order=use_new_attention_order,
356
+ )
357
+ if not use_spatial_transformer
358
+ else SpatialTransformer( # always uses a self-attn
359
+ ch,
360
+ num_heads,
361
+ dim_head,
362
+ depth=transformer_depth,
363
+ context_dim=context_dim,
364
+ disable_self_attn=disable_middle_self_attn,
365
+ use_linear=use_linear_in_transformer,
366
+ use_checkpoint=use_checkpoint,
367
+ ),
368
+ ResBlock(
369
+ ch,
370
+ time_embed_dim,
371
+ dropout,
372
+ dims=dims,
373
+ use_checkpoint=use_checkpoint,
374
+ use_scale_shift_norm=use_scale_shift_norm,
375
+ ),
376
+ )
377
+ self.middle_block_out = self.make_zero_conv(ch)
378
+ self._feature_size += ch
379
+
380
+ def make_zero_conv(self, channels):
381
+ return TimestepEmbedSequential(
382
+ zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))
383
+ )
384
+
385
+ def forward(self, x, hint, text_info, timesteps, context, **kwargs):
386
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
387
+ if self.use_fp16:
388
+ t_emb = t_emb.half()
389
+ emb = self.time_embed(t_emb)
390
+
391
+ # guided_hint from text_info
392
+ B, C, H, W = x.shape
393
+ glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
394
+ positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
395
+ enc_glyph = self.glyph_block(glyphs, emb, context)
396
+ enc_pos = self.position_block(positions, emb, context)
397
+ guided_hint = self.fuse_block(
398
+ torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)
399
+ )
400
+
401
+ outs = []
402
+
403
+ h = x.type(self.dtype)
404
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
405
+ if guided_hint is not None:
406
+ h = module(h, emb, context)
407
+ h += guided_hint
408
+ guided_hint = None
409
+ else:
410
+ h = module(h, emb, context)
411
+ outs.append(zero_conv(h, emb, context))
412
+
413
+ h = self.middle_block(h, emb, context)
414
+ outs.append(self.middle_block_out(h, emb, context))
415
+
416
+ return outs
417
+
418
+
419
+ class ControlLDM(LatentDiffusion):
420
+ def __init__(
421
+ self,
422
+ control_stage_config,
423
+ control_key,
424
+ glyph_key,
425
+ position_key,
426
+ only_mid_control,
427
+ loss_alpha=0,
428
+ loss_beta=0,
429
+ with_step_weight=False,
430
+ use_vae_upsample=False,
431
+ latin_weight=1.0,
432
+ embedding_manager_config=None,
433
+ *args,
434
+ **kwargs,
435
+ ):
436
+ self.use_fp16 = kwargs.pop("use_fp16", False)
437
+ super().__init__(*args, **kwargs)
438
+ self.control_model = instantiate_from_config(control_stage_config)
439
+ self.control_key = control_key
440
+ self.glyph_key = glyph_key
441
+ self.position_key = position_key
442
+ self.only_mid_control = only_mid_control
443
+ self.control_scales = [1.0] * 13
444
+ self.loss_alpha = loss_alpha
445
+ self.loss_beta = loss_beta
446
+ self.with_step_weight = with_step_weight
447
+ self.use_vae_upsample = use_vae_upsample
448
+ self.latin_weight = latin_weight
449
+
450
+ if (
451
+ embedding_manager_config is not None
452
+ and embedding_manager_config.params.valid
453
+ ):
454
+ self.embedding_manager = self.instantiate_embedding_manager(
455
+ embedding_manager_config, self.cond_stage_model
456
+ )
457
+ for param in self.embedding_manager.embedding_parameters():
458
+ param.requires_grad = True
459
+ else:
460
+ self.embedding_manager = None
461
+ if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
462
+ if embedding_manager_config.params.emb_type == "ocr":
463
+ self.text_predictor = create_predictor().eval()
464
+ args = edict()
465
+ args.rec_image_shape = "3, 48, 320"
466
+ args.rec_batch_num = 6
467
+ args.rec_char_dict_path = str(
468
+ CURRENT_DIR.parent / "ocr_recog" / "ppocr_keys_v1.txt"
469
+ )
470
+ args.use_fp16 = self.use_fp16
471
+ self.cn_recognizer = TextRecognizer(args, self.text_predictor)
472
+ for param in self.text_predictor.parameters():
473
+ param.requires_grad = False
474
+ if self.embedding_manager:
475
+ self.embedding_manager.recog = self.cn_recognizer
476
+
477
+ @torch.no_grad()
478
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
479
+ if self.embedding_manager is None: # fill in full caption
480
+ self.fill_caption(batch)
481
+ x, c, mx = super().get_input(
482
+ batch, self.first_stage_key, mask_k="masked_img", *args, **kwargs
483
+ )
484
+ control = batch[
485
+ self.control_key
486
+ ] # for log_images and loss_alpha, not real control
487
+ if bs is not None:
488
+ control = control[:bs]
489
+ control = control.to(self.device)
490
+ control = einops.rearrange(control, "b h w c -> b c h w")
491
+ control = control.to(memory_format=torch.contiguous_format).float()
492
+
493
+ inv_mask = batch["inv_mask"]
494
+ if bs is not None:
495
+ inv_mask = inv_mask[:bs]
496
+ inv_mask = inv_mask.to(self.device)
497
+ inv_mask = einops.rearrange(inv_mask, "b h w c -> b c h w")
498
+ inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
499
+
500
+ glyphs = batch[self.glyph_key]
501
+ gly_line = batch["gly_line"]
502
+ positions = batch[self.position_key]
503
+ n_lines = batch["n_lines"]
504
+ language = batch["language"]
505
+ texts = batch["texts"]
506
+ assert len(glyphs) == len(positions)
507
+ for i in range(len(glyphs)):
508
+ if bs is not None:
509
+ glyphs[i] = glyphs[i][:bs]
510
+ gly_line[i] = gly_line[i][:bs]
511
+ positions[i] = positions[i][:bs]
512
+ n_lines = n_lines[:bs]
513
+ glyphs[i] = glyphs[i].to(self.device)
514
+ gly_line[i] = gly_line[i].to(self.device)
515
+ positions[i] = positions[i].to(self.device)
516
+ glyphs[i] = einops.rearrange(glyphs[i], "b h w c -> b c h w")
517
+ gly_line[i] = einops.rearrange(gly_line[i], "b h w c -> b c h w")
518
+ positions[i] = einops.rearrange(positions[i], "b h w c -> b c h w")
519
+ glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
520
+ gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
521
+ positions[i] = (
522
+ positions[i].to(memory_format=torch.contiguous_format).float()
523
+ )
524
+ info = {}
525
+ info["glyphs"] = glyphs
526
+ info["positions"] = positions
527
+ info["n_lines"] = n_lines
528
+ info["language"] = language
529
+ info["texts"] = texts
530
+ info["img"] = batch["img"] # nhwc, (-1,1)
531
+ info["masked_x"] = mx
532
+ info["gly_line"] = gly_line
533
+ info["inv_mask"] = inv_mask
534
+ return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
535
+
536
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
537
+ assert isinstance(cond, dict)
538
+ diffusion_model = self.model.diffusion_model
539
+ _cond = torch.cat(cond["c_crossattn"], 1)
540
+ _hint = torch.cat(cond["c_concat"], 1)
541
+ if self.use_fp16:
542
+ x_noisy = x_noisy.half()
543
+ control = self.control_model(
544
+ x=x_noisy,
545
+ timesteps=t,
546
+ context=_cond,
547
+ hint=_hint,
548
+ text_info=cond["text_info"],
549
+ )
550
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
551
+ eps = diffusion_model(
552
+ x=x_noisy,
553
+ timesteps=t,
554
+ context=_cond,
555
+ control=control,
556
+ only_mid_control=self.only_mid_control,
557
+ )
558
+
559
+ return eps
560
+
561
+ def instantiate_embedding_manager(self, config, embedder):
562
+ model = instantiate_from_config(config, embedder=embedder)
563
+ return model
564
+
565
+ @torch.no_grad()
566
+ def get_unconditional_conditioning(self, N):
567
+ return self.get_learned_conditioning(
568
+ dict(c_crossattn=[[""] * N], text_info=None)
569
+ )
570
+
571
+ def get_learned_conditioning(self, c):
572
+ if self.cond_stage_forward is None:
573
+ if hasattr(self.cond_stage_model, "encode") and callable(
574
+ self.cond_stage_model.encode
575
+ ):
576
+ if self.embedding_manager is not None and c["text_info"] is not None:
577
+ self.embedding_manager.encode_text(c["text_info"])
578
+ if isinstance(c, dict):
579
+ cond_txt = c["c_crossattn"][0]
580
+ else:
581
+ cond_txt = c
582
+ if self.embedding_manager is not None:
583
+ cond_txt = self.cond_stage_model.encode(
584
+ cond_txt, embedding_manager=self.embedding_manager
585
+ )
586
+ else:
587
+ cond_txt = self.cond_stage_model.encode(cond_txt)
588
+ if isinstance(c, dict):
589
+ c["c_crossattn"][0] = cond_txt
590
+ else:
591
+ c = cond_txt
592
+ if isinstance(c, DiagonalGaussianDistribution):
593
+ c = c.mode()
594
+ else:
595
+ c = self.cond_stage_model(c)
596
+ else:
597
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
598
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
599
+ return c
600
+
601
+ def fill_caption(self, batch, place_holder="*"):
602
+ bs = len(batch["n_lines"])
603
+ cond_list = copy.deepcopy(batch[self.cond_stage_key])
604
+ for i in range(bs):
605
+ n_lines = batch["n_lines"][i]
606
+ if n_lines == 0:
607
+ continue
608
+ cur_cap = cond_list[i]
609
+ for j in range(n_lines):
610
+ r_txt = batch["texts"][j][i]
611
+ cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
612
+ cond_list[i] = cur_cap
613
+ batch[self.cond_stage_key] = cond_list
614
+
615
+ @torch.no_grad()
616
+ def log_images(
617
+ self,
618
+ batch,
619
+ N=4,
620
+ n_row=2,
621
+ sample=False,
622
+ ddim_steps=50,
623
+ ddim_eta=0.0,
624
+ return_keys=None,
625
+ quantize_denoised=True,
626
+ inpaint=True,
627
+ plot_denoise_rows=False,
628
+ plot_progressive_rows=True,
629
+ plot_diffusion_rows=False,
630
+ unconditional_guidance_scale=9.0,
631
+ unconditional_guidance_label=None,
632
+ use_ema_scope=True,
633
+ **kwargs,
634
+ ):
635
+ use_ddim = ddim_steps is not None
636
+
637
+ log = dict()
638
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
639
+ if self.cond_stage_trainable:
640
+ with torch.no_grad():
641
+ c = self.get_learned_conditioning(c)
642
+ c_crossattn = c["c_crossattn"][0][:N]
643
+ c_cat = c["c_concat"][0][:N]
644
+ text_info = c["text_info"]
645
+ text_info["glyphs"] = [i[:N] for i in text_info["glyphs"]]
646
+ text_info["gly_line"] = [i[:N] for i in text_info["gly_line"]]
647
+ text_info["positions"] = [i[:N] for i in text_info["positions"]]
648
+ text_info["n_lines"] = text_info["n_lines"][:N]
649
+ text_info["masked_x"] = text_info["masked_x"][:N]
650
+ text_info["img"] = text_info["img"][:N]
651
+
652
+ N = min(z.shape[0], N)
653
+ n_row = min(z.shape[0], n_row)
654
+ log["reconstruction"] = self.decode_first_stage(z)
655
+ log["masked_image"] = self.decode_first_stage(text_info["masked_x"])
656
+ log["control"] = c_cat * 2.0 - 1.0
657
+ log["img"] = text_info["img"].permute(0, 3, 1, 2) # log source image if needed
658
+ # get glyph
659
+ glyph_bs = torch.stack(text_info["glyphs"])
660
+ glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
661
+ log["glyph"] = torch.nn.functional.interpolate(
662
+ glyph_bs,
663
+ size=(512, 512),
664
+ mode="bilinear",
665
+ align_corners=True,
666
+ )
667
+ # fill caption
668
+ if not self.embedding_manager:
669
+ self.fill_caption(batch)
670
+ captions = batch[self.cond_stage_key]
671
+ log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
672
+
673
+ if plot_diffusion_rows:
674
+ # get diffusion row
675
+ diffusion_row = list()
676
+ z_start = z[:n_row]
677
+ for t in range(self.num_timesteps):
678
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
679
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
680
+ t = t.to(self.device).long()
681
+ noise = torch.randn_like(z_start)
682
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
683
+ diffusion_row.append(self.decode_first_stage(z_noisy))
684
+
685
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
686
+ diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
687
+ diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
688
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
689
+ log["diffusion_row"] = diffusion_grid
690
+
691
+ if sample:
692
+ # get denoise row
693
+ samples, z_denoise_row = self.sample_log(
694
+ cond={"c_concat": [c_cat], "c_crossattn": [c], "text_info": text_info},
695
+ batch_size=N,
696
+ ddim=use_ddim,
697
+ ddim_steps=ddim_steps,
698
+ eta=ddim_eta,
699
+ )
700
+ x_samples = self.decode_first_stage(samples)
701
+ log["samples"] = x_samples
702
+ if plot_denoise_rows:
703
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
704
+ log["denoise_row"] = denoise_grid
705
+
706
+ if unconditional_guidance_scale > 1.0:
707
+ uc_cross = self.get_unconditional_conditioning(N)
708
+ uc_cat = c_cat # torch.zeros_like(c_cat)
709
+ uc_full = {
710
+ "c_concat": [uc_cat],
711
+ "c_crossattn": [uc_cross["c_crossattn"][0]],
712
+ "text_info": text_info,
713
+ }
714
+ samples_cfg, tmps = self.sample_log(
715
+ cond={
716
+ "c_concat": [c_cat],
717
+ "c_crossattn": [c_crossattn],
718
+ "text_info": text_info,
719
+ },
720
+ batch_size=N,
721
+ ddim=use_ddim,
722
+ ddim_steps=ddim_steps,
723
+ eta=ddim_eta,
724
+ unconditional_guidance_scale=unconditional_guidance_scale,
725
+ unconditional_conditioning=uc_full,
726
+ )
727
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
728
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
729
+ pred_x0 = False # wether log pred_x0
730
+ if pred_x0:
731
+ for idx in range(len(tmps["pred_x0"])):
732
+ pred_x0 = self.decode_first_stage(tmps["pred_x0"][idx])
733
+ log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0
734
+
735
+ return log
736
+
737
+ @torch.no_grad()
738
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
739
+ ddim_sampler = DDIMSampler(self)
740
+ b, c, h, w = cond["c_concat"][0].shape
741
+ shape = (self.channels, h // 8, w // 8)
742
+ samples, intermediates = ddim_sampler.sample(
743
+ ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs
744
+ )
745
+ return samples, intermediates
746
+
747
+ def configure_optimizers(self):
748
+ lr = self.learning_rate
749
+ params = list(self.control_model.parameters())
750
+ if self.embedding_manager:
751
+ params += list(self.embedding_manager.embedding_parameters())
752
+ if not self.sd_locked:
753
+ # params += list(self.model.diffusion_model.input_blocks.parameters())
754
+ # params += list(self.model.diffusion_model.middle_block.parameters())
755
+ params += list(self.model.diffusion_model.output_blocks.parameters())
756
+ params += list(self.model.diffusion_model.out.parameters())
757
+ if self.unlockKV:
758
+ nCount = 0
759
+ for name, param in self.model.diffusion_model.named_parameters():
760
+ if "attn2.to_k" in name or "attn2.to_v" in name:
761
+ params += [param]
762
+ nCount += 1
763
+ print(
764
+ f"Cross attention is unlocked, and {nCount} Wk or Wv are added to optimizers!!!"
765
+ )
766
+
767
+ opt = torch.optim.AdamW(params, lr=lr)
768
+ return opt
769
+
770
+ def low_vram_shift(self, is_diffusing):
771
+ if is_diffusing:
772
+ self.model = self.model.cuda()
773
+ self.control_model = self.control_model.cuda()
774
+ self.first_stage_model = self.first_stage_model.cpu()
775
+ self.cond_stage_model = self.cond_stage_model.cpu()
776
+ else:
777
+ self.model = self.model.cpu()
778
+ self.control_model = self.control_model.cpu()
779
+ self.first_stage_model = self.first_stage_model.cuda()
780
+ self.cond_stage_model = self.cond_stage_model.cuda()
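The hard-coded control_scales = [1.0] * 13 (and [strength] * 13 back in AnyTextPipeline) matches the number of residuals ControlNet.forward returns: one zero-conv output per input block plus middle_block_out. With the config from anytext_sd15.yaml (channel_mult [1, 2, 4, 4], num_res_blocks 2) that count can be sanity-checked directly:

channel_mult = [1, 2, 4, 4]
num_res_blocks = 2

outputs = 1                                   # zero-conv paired with the initial input conv
for level in range(len(channel_mult)):
    outputs += num_res_blocks                 # one zero-conv per ResBlock at this level
    if level != len(channel_mult) - 1:
        outputs += 1                          # the downsample block also gets a zero-conv
outputs += 1                                  # middle_block_out

print(outputs)  # 13 == len(control_scales)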
sorawm/iopaint/model/anytext/cldm/ddim_hacked.py ADDED
@@ -0,0 +1,486 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import numpy as np
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
8
+ extract_into_tensor,
9
+ make_ddim_sampling_parameters,
10
+ make_ddim_timesteps,
11
+ noise_like,
12
+ )
13
+
14
+
15
+ class DDIMSampler(object):
16
+ def __init__(self, model, device, schedule="linear", **kwargs):
17
+ super().__init__()
18
+ self.device = device
19
+ self.model = model
20
+ self.ddpm_num_timesteps = model.num_timesteps
21
+ self.schedule = schedule
22
+
23
+ def register_buffer(self, name, attr):
24
+ if type(attr) == torch.Tensor:
25
+ if attr.device != torch.device(self.device):
26
+ attr = attr.to(torch.device(self.device))
27
+ setattr(self, name, attr)
28
+
29
+ def make_schedule(
30
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
31
+ ):
32
+ self.ddim_timesteps = make_ddim_timesteps(
33
+ ddim_discr_method=ddim_discretize,
34
+ num_ddim_timesteps=ddim_num_steps,
35
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
36
+ verbose=verbose,
37
+ )
38
+ alphas_cumprod = self.model.alphas_cumprod
39
+ assert (
40
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
41
+ ), "alphas have to be defined for each timestep"
42
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
43
+
44
+ self.register_buffer("betas", to_torch(self.model.betas))
45
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
46
+ self.register_buffer(
47
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
48
+ )
49
+
50
+ # calculations for diffusion q(x_t | x_{t-1}) and others
51
+ self.register_buffer(
52
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
53
+ )
54
+ self.register_buffer(
55
+ "sqrt_one_minus_alphas_cumprod",
56
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
57
+ )
58
+ self.register_buffer(
59
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
60
+ )
61
+ self.register_buffer(
62
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
63
+ )
64
+ self.register_buffer(
65
+ "sqrt_recipm1_alphas_cumprod",
66
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
67
+ )
68
+
69
+ # ddim sampling parameters
70
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
71
+ alphacums=alphas_cumprod.cpu(),
72
+ ddim_timesteps=self.ddim_timesteps,
73
+ eta=ddim_eta,
74
+ verbose=verbose,
75
+ )
76
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
77
+ self.register_buffer("ddim_alphas", ddim_alphas)
78
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
79
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
80
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
81
+ (1 - self.alphas_cumprod_prev)
82
+ / (1 - self.alphas_cumprod)
83
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
84
+ )
85
+ self.register_buffer(
86
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
87
+ )
88
+
89
+ @torch.no_grad()
90
+ def sample(
91
+ self,
92
+ S,
93
+ batch_size,
94
+ shape,
95
+ conditioning=None,
96
+ callback=None,
97
+ normals_sequence=None,
98
+ img_callback=None,
99
+ quantize_x0=False,
100
+ eta=0.0,
101
+ mask=None,
102
+ x0=None,
103
+ temperature=1.0,
104
+ noise_dropout=0.0,
105
+ score_corrector=None,
106
+ corrector_kwargs=None,
107
+ verbose=True,
108
+ x_T=None,
109
+ log_every_t=100,
110
+ unconditional_guidance_scale=1.0,
111
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
112
+ dynamic_threshold=None,
113
+ ucg_schedule=None,
114
+ **kwargs,
115
+ ):
116
+ if conditioning is not None:
117
+ if isinstance(conditioning, dict):
118
+ ctmp = conditioning[list(conditioning.keys())[0]]
119
+ while isinstance(ctmp, list):
120
+ ctmp = ctmp[0]
121
+ cbs = ctmp.shape[0]
122
+ if cbs != batch_size:
123
+ print(
124
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
125
+ )
126
+
127
+ elif isinstance(conditioning, list):
128
+ for ctmp in conditioning:
129
+ if ctmp.shape[0] != batch_size:
130
+ print(
131
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
132
+ )
133
+
134
+ else:
135
+ if conditioning.shape[0] != batch_size:
136
+ print(
137
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
138
+ )
139
+
140
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
141
+ # sampling
142
+ C, H, W = shape
143
+ size = (batch_size, C, H, W)
144
+ print(f"Data shape for DDIM sampling is {size}, eta {eta}")
145
+
146
+ samples, intermediates = self.ddim_sampling(
147
+ conditioning,
148
+ size,
149
+ callback=callback,
150
+ img_callback=img_callback,
151
+ quantize_denoised=quantize_x0,
152
+ mask=mask,
153
+ x0=x0,
154
+ ddim_use_original_steps=False,
155
+ noise_dropout=noise_dropout,
156
+ temperature=temperature,
157
+ score_corrector=score_corrector,
158
+ corrector_kwargs=corrector_kwargs,
159
+ x_T=x_T,
160
+ log_every_t=log_every_t,
161
+ unconditional_guidance_scale=unconditional_guidance_scale,
162
+ unconditional_conditioning=unconditional_conditioning,
163
+ dynamic_threshold=dynamic_threshold,
164
+ ucg_schedule=ucg_schedule,
165
+ )
166
+ return samples, intermediates
167
+
168
+ @torch.no_grad()
169
+ def ddim_sampling(
170
+ self,
171
+ cond,
172
+ shape,
173
+ x_T=None,
174
+ ddim_use_original_steps=False,
175
+ callback=None,
176
+ timesteps=None,
177
+ quantize_denoised=False,
178
+ mask=None,
179
+ x0=None,
180
+ img_callback=None,
181
+ log_every_t=100,
182
+ temperature=1.0,
183
+ noise_dropout=0.0,
184
+ score_corrector=None,
185
+ corrector_kwargs=None,
186
+ unconditional_guidance_scale=1.0,
187
+ unconditional_conditioning=None,
188
+ dynamic_threshold=None,
189
+ ucg_schedule=None,
190
+ ):
191
+ device = self.model.betas.device
192
+ b = shape[0]
193
+ if x_T is None:
194
+ img = torch.randn(shape, device=device)
195
+ else:
196
+ img = x_T
197
+
198
+ if timesteps is None:
199
+ timesteps = (
200
+ self.ddpm_num_timesteps
201
+ if ddim_use_original_steps
202
+ else self.ddim_timesteps
203
+ )
204
+ elif timesteps is not None and not ddim_use_original_steps:
205
+ subset_end = (
206
+ int(
207
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
208
+ * self.ddim_timesteps.shape[0]
209
+ )
210
+ - 1
211
+ )
212
+ timesteps = self.ddim_timesteps[:subset_end]
213
+
214
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
215
+ time_range = (
216
+ reversed(range(0, timesteps))
217
+ if ddim_use_original_steps
218
+ else np.flip(timesteps)
219
+ )
220
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
221
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
222
+
223
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
224
+
225
+ for i, step in enumerate(iterator):
226
+ index = total_steps - i - 1
227
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
228
+
229
+ if mask is not None:
230
+ assert x0 is not None
231
+ img_orig = self.model.q_sample(
232
+ x0, ts
233
+ ) # TODO: deterministic forward pass?
234
+ img = img_orig * mask + (1.0 - mask) * img
235
+
236
+ if ucg_schedule is not None:
237
+ assert len(ucg_schedule) == len(time_range)
238
+ unconditional_guidance_scale = ucg_schedule[i]
239
+
240
+ outs = self.p_sample_ddim(
241
+ img,
242
+ cond,
243
+ ts,
244
+ index=index,
245
+ use_original_steps=ddim_use_original_steps,
246
+ quantize_denoised=quantize_denoised,
247
+ temperature=temperature,
248
+ noise_dropout=noise_dropout,
249
+ score_corrector=score_corrector,
250
+ corrector_kwargs=corrector_kwargs,
251
+ unconditional_guidance_scale=unconditional_guidance_scale,
252
+ unconditional_conditioning=unconditional_conditioning,
253
+ dynamic_threshold=dynamic_threshold,
254
+ )
255
+ img, pred_x0 = outs
256
+ if callback:
257
+ callback(None, i, None, None)
258
+ if img_callback:
259
+ img_callback(pred_x0, i)
260
+
261
+ if index % log_every_t == 0 or index == total_steps - 1:
262
+ intermediates["x_inter"].append(img)
263
+ intermediates["pred_x0"].append(pred_x0)
264
+
265
+ return img, intermediates
266
+
267
+ @torch.no_grad()
268
+ def p_sample_ddim(
269
+ self,
270
+ x,
271
+ c,
272
+ t,
273
+ index,
274
+ repeat_noise=False,
275
+ use_original_steps=False,
276
+ quantize_denoised=False,
277
+ temperature=1.0,
278
+ noise_dropout=0.0,
279
+ score_corrector=None,
280
+ corrector_kwargs=None,
281
+ unconditional_guidance_scale=1.0,
282
+ unconditional_conditioning=None,
283
+ dynamic_threshold=None,
284
+ ):
285
+ b, *_, device = *x.shape, x.device
286
+
287
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
288
+ model_output = self.model.apply_model(x, t, c)
289
+ else:
290
+ model_t = self.model.apply_model(x, t, c)
291
+ model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
292
+ model_output = model_uncond + unconditional_guidance_scale * (
293
+ model_t - model_uncond
294
+ )
295
+
296
+ if self.model.parameterization == "v":
297
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
298
+ else:
299
+ e_t = model_output
300
+
301
+ if score_corrector is not None:
302
+ assert self.model.parameterization == "eps", "not implemented"
303
+ e_t = score_corrector.modify_score(
304
+ self.model, e_t, x, t, c, **corrector_kwargs
305
+ )
306
+
307
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
308
+ alphas_prev = (
309
+ self.model.alphas_cumprod_prev
310
+ if use_original_steps
311
+ else self.ddim_alphas_prev
312
+ )
313
+ sqrt_one_minus_alphas = (
314
+ self.model.sqrt_one_minus_alphas_cumprod
315
+ if use_original_steps
316
+ else self.ddim_sqrt_one_minus_alphas
317
+ )
318
+ sigmas = (
319
+ self.model.ddim_sigmas_for_original_num_steps
320
+ if use_original_steps
321
+ else self.ddim_sigmas
322
+ )
323
+ # select parameters corresponding to the currently considered timestep
324
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
325
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
326
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
327
+ sqrt_one_minus_at = torch.full(
328
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
329
+ )
330
+
331
+ # current prediction for x_0
332
+ if self.model.parameterization != "v":
333
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
334
+ else:
335
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
336
+
337
+ if quantize_denoised:
338
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
339
+
340
+ if dynamic_threshold is not None:
341
+ raise NotImplementedError()
342
+
343
+ # direction pointing to x_t
344
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
345
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
346
+ if noise_dropout > 0.0:
347
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
348
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
349
+ return x_prev, pred_x0
350
+
351
+ @torch.no_grad()
352
+ def encode(
353
+ self,
354
+ x0,
355
+ c,
356
+ t_enc,
357
+ use_original_steps=False,
358
+ return_intermediates=None,
359
+ unconditional_guidance_scale=1.0,
360
+ unconditional_conditioning=None,
361
+ callback=None,
362
+ ):
363
+ timesteps = (
364
+ np.arange(self.ddpm_num_timesteps)
365
+ if use_original_steps
366
+ else self.ddim_timesteps
367
+ )
368
+ num_reference_steps = timesteps.shape[0]
369
+
370
+ assert t_enc <= num_reference_steps
371
+ num_steps = t_enc
372
+
373
+ if use_original_steps:
374
+ alphas_next = self.alphas_cumprod[:num_steps]
375
+ alphas = self.alphas_cumprod_prev[:num_steps]
376
+ else:
377
+ alphas_next = self.ddim_alphas[:num_steps]
378
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
379
+
380
+ x_next = x0
381
+ intermediates = []
382
+ inter_steps = []
383
+ for i in tqdm(range(num_steps), desc="Encoding Image"):
384
+ t = torch.full(
385
+ (x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long
386
+ )
387
+ if unconditional_guidance_scale == 1.0:
388
+ noise_pred = self.model.apply_model(x_next, t, c)
389
+ else:
390
+ assert unconditional_conditioning is not None
391
+ e_t_uncond, noise_pred = torch.chunk(
392
+ self.model.apply_model(
393
+ torch.cat((x_next, x_next)),
394
+ torch.cat((t, t)),
395
+ torch.cat((unconditional_conditioning, c)),
396
+ ),
397
+ 2,
398
+ )
399
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (
400
+ noise_pred - e_t_uncond
401
+ )
402
+
403
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
404
+ weighted_noise_pred = (
405
+ alphas_next[i].sqrt()
406
+ * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
407
+ * noise_pred
408
+ )
409
+ x_next = xt_weighted + weighted_noise_pred
410
+ if (
411
+ return_intermediates
412
+ and i % (num_steps // return_intermediates) == 0
413
+ and i < num_steps - 1
414
+ ):
415
+ intermediates.append(x_next)
416
+ inter_steps.append(i)
417
+ elif return_intermediates and i >= num_steps - 2:
418
+ intermediates.append(x_next)
419
+ inter_steps.append(i)
420
+ if callback:
421
+ callback(i)
422
+
423
+ out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
424
+ if return_intermediates:
425
+ out.update({"intermediates": intermediates})
426
+ return x_next, out
427
+
428
+ @torch.no_grad()
429
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
430
+ # fast, but does not allow for exact reconstruction
431
+ # t serves as an index to gather the correct alphas
432
+ if use_original_steps:
433
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
434
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
435
+ else:
436
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
437
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
438
+
439
+ if noise is None:
440
+ noise = torch.randn_like(x0)
441
+ return (
442
+ extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
443
+ + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
444
+ )
445
+
446
+ @torch.no_grad()
447
+ def decode(
448
+ self,
449
+ x_latent,
450
+ cond,
451
+ t_start,
452
+ unconditional_guidance_scale=1.0,
453
+ unconditional_conditioning=None,
454
+ use_original_steps=False,
455
+ callback=None,
456
+ ):
457
+ timesteps = (
458
+ np.arange(self.ddpm_num_timesteps)
459
+ if use_original_steps
460
+ else self.ddim_timesteps
461
+ )
462
+ timesteps = timesteps[:t_start]
463
+
464
+ time_range = np.flip(timesteps)
465
+ total_steps = timesteps.shape[0]
466
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
467
+
468
+ iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
469
+ x_dec = x_latent
470
+ for i, step in enumerate(iterator):
471
+ index = total_steps - i - 1
472
+ ts = torch.full(
473
+ (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
474
+ )
475
+ x_dec, _ = self.p_sample_ddim(
476
+ x_dec,
477
+ cond,
478
+ ts,
479
+ index=index,
480
+ use_original_steps=use_original_steps,
481
+ unconditional_guidance_scale=unconditional_guidance_scale,
482
+ unconditional_conditioning=unconditional_conditioning,
483
+ )
484
+ if callback:
485
+ callback(i)
486
+ return x_dec
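The update in p_sample_ddim above is the standard DDIM step: x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * noise. A minimal sketch of that single step on plain tensors (names are illustrative, not part of the sampler API):

import torch

def ddim_step(x_t, e_t, a_t, a_prev, sigma_t, noise=None):
    # estimate of the clean latent implied by the eps prediction
    pred_x0 = (x_t - (1.0 - a_t).sqrt() * e_t) / a_t.sqrt()
    # "direction pointing to x_t" term
    dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
    if noise is None:
        noise = torch.randn_like(x_t)
    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + sigma_t * noise
    return x_prev, pred_x0

With sigma_t = 0 (eta = 0) the step is deterministic, which is the default produced by make_schedule above.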
sorawm/iopaint/model/anytext/cldm/embedding_manager.py ADDED
@@ -0,0 +1,185 @@
1
+ """
2
+ Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
11
+ conv_nd,
12
+ linear,
13
+ )
14
+
15
+
16
+ def get_clip_token_for_string(tokenizer, string):
17
+ batch_encoding = tokenizer(
18
+ string,
19
+ truncation=True,
20
+ max_length=77,
21
+ return_length=True,
22
+ return_overflowing_tokens=False,
23
+ padding="max_length",
24
+ return_tensors="pt",
25
+ )
26
+ tokens = batch_encoding["input_ids"]
27
+ assert (
28
+ torch.count_nonzero(tokens - 49407) == 2
29
+ ), f"String '{string}' maps to more than a single token. Please use another string"
30
+ return tokens[0, 1]
31
+
32
+
33
+ def get_bert_token_for_string(tokenizer, string):
34
+ token = tokenizer(string)
35
+ assert (
36
+ torch.count_nonzero(token) == 3
37
+ ), f"String '{string}' maps to more than a single token. Please use another string"
38
+ token = token[0, 1]
39
+ return token
40
+
41
+
42
+ def get_clip_vision_emb(encoder, processor, img):
43
+ _img = img.repeat(1, 3, 1, 1) * 255
44
+ inputs = processor(images=_img, return_tensors="pt")
45
+ inputs["pixel_values"] = inputs["pixel_values"].to(img.device)
46
+ outputs = encoder(**inputs)
47
+ emb = outputs.image_embeds
48
+ return emb
49
+
50
+
51
+ def get_recog_emb(encoder, img_list):
52
+ _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list]
53
+ encoder.predictor.eval()
54
+ _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
55
+ return preds_neck
56
+
57
+
58
+ def pad_H(x):
59
+ _, _, H, W = x.shape
60
+ p_top = (W - H) // 2
61
+ p_bot = W - H - p_top
62
+ return F.pad(x, (0, 0, p_top, p_bot))
63
+
64
+
65
+ class EncodeNet(nn.Module):
66
+ def __init__(self, in_channels, out_channels):
67
+ super(EncodeNet, self).__init__()
68
+ chan = 16
69
+ n_layer = 4 # downsample
70
+
71
+ self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
72
+ self.conv_list = nn.ModuleList([])
73
+ _c = chan
74
+ for i in range(n_layer):
75
+ self.conv_list.append(conv_nd(2, _c, _c * 2, 3, padding=1, stride=2))
76
+ _c *= 2
77
+ self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
78
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
79
+ self.act = nn.SiLU()
80
+
81
+ def forward(self, x):
82
+ x = self.act(self.conv1(x))
83
+ for layer in self.conv_list:
84
+ x = self.act(layer(x))
85
+ x = self.act(self.conv2(x))
86
+ x = self.avgpool(x)
87
+ x = x.view(x.size(0), -1)
88
+ return x
89
+
90
+
91
+ class EmbeddingManager(nn.Module):
92
+ def __init__(
93
+ self,
94
+ embedder,
95
+ valid=True,
96
+ glyph_channels=20,
97
+ position_channels=1,
98
+ placeholder_string="*",
99
+ add_pos=False,
100
+ emb_type="ocr",
101
+ **kwargs,
102
+ ):
103
+ super().__init__()
104
+ if hasattr(embedder, "tokenizer"): # using Stable Diffusion's CLIP encoder
105
+ get_token_for_string = partial(
106
+ get_clip_token_for_string, embedder.tokenizer
107
+ )
108
+ token_dim = 768
109
+ if hasattr(embedder, "vit"):
110
+ assert emb_type == "vit"
111
+ self.get_vision_emb = partial(
112
+ get_clip_vision_emb, embedder.vit, embedder.processor
113
+ )
114
+ self.get_recog_emb = None
115
+ else: # using LDM's BERT encoder
116
+ get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
117
+ token_dim = 1280
118
+ self.token_dim = token_dim
119
+ self.emb_type = emb_type
120
+
121
+ self.add_pos = add_pos
122
+ if add_pos:
123
+ self.position_encoder = EncodeNet(position_channels, token_dim)
124
+ if emb_type == "ocr":
125
+ self.proj = linear(40 * 64, token_dim)
126
+ if emb_type == "conv":
127
+ self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
128
+
129
+ self.placeholder_token = get_token_for_string(placeholder_string)
130
+
131
+ def encode_text(self, text_info):
132
+ if self.get_recog_emb is None and self.emb_type == "ocr":
133
+ self.get_recog_emb = partial(get_recog_emb, self.recog)
134
+
135
+ gline_list = []
136
+ pos_list = []
137
+ for i in range(len(text_info["n_lines"])): # sample index in a batch
138
+ n_lines = text_info["n_lines"][i]
139
+ for j in range(n_lines): # line
140
+ gline_list += [text_info["gly_line"][j][i : i + 1]]
141
+ if self.add_pos:
142
+ pos_list += [text_info["positions"][j][i : i + 1]]
143
+
144
+ if len(gline_list) > 0:
145
+ if self.emb_type == "ocr":
146
+ recog_emb = self.get_recog_emb(gline_list)
147
+ enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
148
+ elif self.emb_type == "vit":
149
+ enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
150
+ elif self.emb_type == "conv":
151
+ enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
152
+ if self.add_pos:
153
+ enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
154
+ enc_glyph = enc_glyph + enc_pos
155
+
156
+ self.text_embs_all = []
157
+ n_idx = 0
158
+ for i in range(len(text_info["n_lines"])): # sample index in a batch
159
+ n_lines = text_info["n_lines"][i]
160
+ text_embs = []
161
+ for j in range(n_lines): # line
162
+ text_embs += [enc_glyph[n_idx : n_idx + 1]]
163
+ n_idx += 1
164
+ self.text_embs_all += [text_embs]
165
+
166
+ def forward(
167
+ self,
168
+ tokenized_text,
169
+ embedded_text,
170
+ ):
171
+ b, device = tokenized_text.shape[0], tokenized_text.device
172
+ for i in range(b):
173
+ idx = tokenized_text[i] == self.placeholder_token.to(device)
174
+ if sum(idx) > 0:
175
+ if i >= len(self.text_embs_all):
176
+ print("truncation for log images...")
177
+ break
178
+ text_emb = torch.cat(self.text_embs_all[i], dim=0)
179
+ if sum(idx) != len(text_emb):
180
+ print("truncation for long caption...")
181
+ embedded_text[i][idx] = text_emb[: sum(idx)]
182
+ return embedded_text
183
+
184
+ def embedding_parameters(self):
185
+ return self.parameters()
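forward() above swaps the text-encoder embedding at every position where the tokenized prompt equals the placeholder token ("*") for the glyph embeddings produced in encode_text(). A toy illustration of that masking pattern (shapes and the placeholder id are made up for the example, not the real tokenizer):

import torch

token_dim = 8
tokenized = torch.tensor([[101, 42, 42, 7]])    # 42 stands in for the "*" placeholder id
embedded = torch.zeros(1, 4, token_dim)          # embeddings from the frozen text encoder
glyph_embs = torch.randn(2, token_dim)           # one embedding per rendered text line

idx = tokenized[0] == 42                         # placeholder positions
embedded[0][idx] = glyph_embs[: int(idx.sum())]  # truncated exactly as in forward()
print(embedded[0, 1])                            # now carries the first glyph embedding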
sorawm/iopaint/model/anytext/cldm/hack.py ADDED
@@ -0,0 +1,128 @@
1
+ import einops
2
+ import torch
3
+ from transformers import logging
4
+
5
+ import sorawm.iopaint.model.anytext.ldm.modules.attention
6
+ import sorawm.iopaint.model.anytext.ldm.modules.encoders.modules
7
+ from sorawm.iopaint.model.anytext.ldm.modules.attention import default
8
+
9
+
10
+ def disable_verbosity():
11
+ logging.set_verbosity_error()
12
+ print("logging improved.")
13
+ return
14
+
15
+
16
+ def enable_sliced_attention():
17
+ sorawm.iopaint.model.anytext.ldm.modules.attention.CrossAttention.forward = (
18
+ _hacked_sliced_attentin_forward
19
+ )
20
+ print("Enabled sliced_attention.")
21
+ return
22
+
23
+
24
+ def hack_everything(clip_skip=0):
25
+ disable_verbosity()
26
+ sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = (
27
+ _hacked_clip_forward
28
+ )
29
+ sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = (
30
+ clip_skip
31
+ )
32
+ print("Enabled clip hacks.")
33
+ return
34
+
35
+
36
+ # Written by Lvmin
37
+ def _hacked_clip_forward(self, text):
38
+ PAD = self.tokenizer.pad_token_id
39
+ EOS = self.tokenizer.eos_token_id
40
+ BOS = self.tokenizer.bos_token_id
41
+
42
+ def tokenize(t):
43
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)[
44
+ "input_ids"
45
+ ]
46
+
47
+ def transformer_encode(t):
48
+ if self.clip_skip > 1:
49
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
50
+ return self.transformer.text_model.final_layer_norm(
51
+ rt.hidden_states[-self.clip_skip]
52
+ )
53
+ else:
54
+ return self.transformer(
55
+ input_ids=t, output_hidden_states=False
56
+ ).last_hidden_state
57
+
58
+ def split(x):
59
+ return x[75 * 0 : 75 * 1], x[75 * 1 : 75 * 2], x[75 * 2 : 75 * 3]
60
+
61
+ def pad(x, p, i):
62
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
63
+
64
+ raw_tokens_list = tokenize(text)
65
+ tokens_list = []
66
+
67
+ for raw_tokens in raw_tokens_list:
68
+ raw_tokens_123 = split(raw_tokens)
69
+ raw_tokens_123 = [
70
+ [BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123
71
+ ]
72
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
73
+ tokens_list.append(raw_tokens_123)
74
+
75
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
76
+
77
+ feed = einops.rearrange(tokens_list, "b f i -> (b f) i")
78
+ y = transformer_encode(feed)
79
+ z = einops.rearrange(y, "(b f) i c -> b (f i) c", f=3)
80
+
81
+ return z
82
+
83
+
84
+ # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
85
+ def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
86
+ h = self.heads
87
+
88
+ q = self.to_q(x)
89
+ context = default(context, x)
90
+ k = self.to_k(context)
91
+ v = self.to_v(context)
92
+ del context, x
93
+
94
+ q, k, v = map(
95
+ lambda t: einops.rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
96
+ )
97
+
98
+ limit = k.shape[0]
99
+ att_step = 1
100
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
101
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
102
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
103
+
104
+ q_chunks.reverse()
105
+ k_chunks.reverse()
106
+ v_chunks.reverse()
107
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
108
+ del k, q, v
109
+ for i in range(0, limit, att_step):
110
+ q_buffer = q_chunks.pop()
111
+ k_buffer = k_chunks.pop()
112
+ v_buffer = v_chunks.pop()
113
+ sim_buffer = (
114
+ torch.einsum("b i d, b j d -> b i j", q_buffer, k_buffer) * self.scale
115
+ )
116
+
117
+ del k_buffer, q_buffer
118
+ # attention, what we cannot get enough of, by chunks
119
+
120
+ sim_buffer = sim_buffer.softmax(dim=-1)
121
+
122
+ sim_buffer = torch.einsum("b i j, b j d -> b i d", sim_buffer, v_buffer)
123
+ del v_buffer
124
+ sim[i : i + att_step, :, :] = sim_buffer
125
+
126
+ del sim_buffer
127
+ sim = einops.rearrange(sim, "(b h) n d -> b n (h d)", h=h)
128
+ return self.to_out(sim)
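These helpers work by monkey-patching the attention and CLIP classes in place, so they only need to be called once, before the model is built. A hedged usage sketch:

from sorawm.iopaint.model.anytext.cldm.hack import (
    disable_verbosity,
    enable_sliced_attention,
    hack_everything,
)

disable_verbosity()            # silence transformers warnings
hack_everything(clip_skip=2)   # patch FrozenCLIPEmbedder.forward and set clip_skip
# enable_sliced_attention()    # optional: chunked CrossAttention.forward for low VRAM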
sorawm/iopaint/model/anytext/cldm/model.py ADDED
@@ -0,0 +1,41 @@
1
+ import os
2
+
3
+ import torch
4
+ from omegaconf import OmegaConf
5
+
6
+ from sorawm.iopaint.model.anytext.ldm.util import instantiate_from_config
7
+
8
+
9
+ def get_state_dict(d):
10
+ return d.get("state_dict", d)
11
+
12
+
13
+ def load_state_dict(ckpt_path, location="cpu"):
14
+ _, extension = os.path.splitext(ckpt_path)
15
+ if extension.lower() == ".safetensors":
16
+ import safetensors.torch
17
+
18
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
19
+ else:
20
+ state_dict = get_state_dict(
21
+ torch.load(ckpt_path, map_location=torch.device(location))
22
+ )
23
+ state_dict = get_state_dict(state_dict)
24
+ print(f"Loaded state_dict from [{ckpt_path}]")
25
+ return state_dict
26
+
27
+
28
+ def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
29
+ config = OmegaConf.load(config_path)
30
+ # if cond_stage_path:
31
+ # config.model.params.cond_stage_config.params.version = (
32
+ # cond_stage_path # use pre-downloaded ckpts, in case blocked
33
+ # )
34
+ config.model.params.cond_stage_config.params.device = str(device)
35
+ if use_fp16:
36
+ config.model.params.use_fp16 = True
37
+ config.model.params.control_stage_config.params.use_fp16 = True
38
+ config.model.params.unet_config.params.use_fp16 = True
39
+ model = instantiate_from_config(config.model).cpu()
40
+ print(f"Loaded model config from [{config_path}]")
41
+ return model
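A hedged sketch of how these two helpers are usually combined; the config and checkpoint paths below are placeholders, not files resolved by this snippet:

from sorawm.iopaint.model.anytext.cldm.model import create_model, load_state_dict

model = create_model("anytext_sd15.yaml", device="cuda", use_fp16=True)
model.load_state_dict(load_state_dict("anytext_v1.1.ckpt", location="cpu"), strict=False)
model = model.cuda().eval()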
sorawm/iopaint/model/anytext/cldm/recognizer.py ADDED
@@ -0,0 +1,302 @@
1
+ """
2
+ Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """
4
+ import math
5
+ import os
6
+ import time
7
+ import traceback
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from easydict import EasyDict as edict
14
+
15
+ from sorawm.iopaint.model.anytext.ocr_recog.RecModel import RecModel
16
+
17
+
18
+ def min_bounding_rect(img):
19
+ ret, thresh = cv2.threshold(img, 127, 255, 0)
20
+ contours, hierarchy = cv2.findContours(
21
+ thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
22
+ )
23
+ if len(contours) == 0:
24
+ print("Bad contours, using fake bbox...")
25
+ return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
26
+ max_contour = max(contours, key=cv2.contourArea)
27
+ rect = cv2.minAreaRect(max_contour)
28
+ box = cv2.boxPoints(rect)
29
+ box = np.intp(box)  # np.int0 alias was removed in NumPy 2.0
30
+ # sort
31
+ x_sorted = sorted(box, key=lambda x: x[0])
32
+ left = x_sorted[:2]
33
+ right = x_sorted[2:]
34
+ left = sorted(left, key=lambda x: x[1])
35
+ (tl, bl) = left
36
+ right = sorted(right, key=lambda x: x[1])
37
+ (tr, br) = right
38
+ if tl[1] > bl[1]:
39
+ (tl, bl) = (bl, tl)
40
+ if tr[1] > br[1]:
41
+ (tr, br) = (br, tr)
42
+ return np.array([tl, tr, br, bl])
43
+
44
+
45
+ def create_predictor(model_dir=None, model_lang="ch", is_onnx=False):
46
+ model_file_path = model_dir
47
+ if model_file_path is not None and not os.path.exists(model_file_path):
48
+ raise ValueError("not find model file path {}".format(model_file_path))
49
+
50
+ if is_onnx:
51
+ import onnxruntime as ort
52
+
53
+ sess = ort.InferenceSession(
54
+ model_file_path, providers=["CPUExecutionProvider"]
55
+ ) # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
56
+ return sess
57
+ else:
58
+ if model_lang == "ch":
59
+ n_class = 6625
60
+ elif model_lang == "en":
61
+ n_class = 97
62
+ else:
63
+ raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
64
+ rec_config = edict(
65
+ in_channels=3,
66
+ backbone=edict(
67
+ type="MobileNetV1Enhance",
68
+ scale=0.5,
69
+ last_conv_stride=[1, 2],
70
+ last_pool_type="avg",
71
+ ),
72
+ neck=edict(
73
+ type="SequenceEncoder",
74
+ encoder_type="svtr",
75
+ dims=64,
76
+ depth=2,
77
+ hidden_dims=120,
78
+ use_guide=True,
79
+ ),
80
+ head=edict(
81
+ type="CTCHead",
82
+ fc_decay=0.00001,
83
+ out_channels=n_class,
84
+ return_feats=True,
85
+ ),
86
+ )
87
+
88
+ rec_model = RecModel(rec_config)
89
+ if model_file_path is not None:
90
+ rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
91
+ rec_model.eval()
92
+ return rec_model.eval()
93
+
94
+
95
+ def _check_image_file(path):
96
+ img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff"}
97
+ return any([path.lower().endswith(e) for e in img_end])
98
+
99
+
100
+ def get_image_file_list(img_file):
101
+ imgs_lists = []
102
+ if img_file is None or not os.path.exists(img_file):
103
+ raise Exception("not found any img file in {}".format(img_file))
104
+ if os.path.isfile(img_file) and _check_image_file(img_file):
105
+ imgs_lists.append(img_file)
106
+ elif os.path.isdir(img_file):
107
+ for single_file in os.listdir(img_file):
108
+ file_path = os.path.join(img_file, single_file)
109
+ if os.path.isfile(file_path) and _check_image_file(file_path):
110
+ imgs_lists.append(file_path)
111
+ if len(imgs_lists) == 0:
112
+ raise Exception("not found any img file in {}".format(img_file))
113
+ imgs_lists = sorted(imgs_lists)
114
+ return imgs_lists
115
+
116
+
117
+ class TextRecognizer(object):
118
+ def __init__(self, args, predictor):
119
+ self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
120
+ self.rec_batch_num = args.rec_batch_num
121
+ self.predictor = predictor
122
+ self.chars = self.get_char_dict(args.rec_char_dict_path)
123
+ self.char2id = {x: i for i, x in enumerate(self.chars)}
124
+ self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
125
+ self.use_fp16 = args.use_fp16
126
+
127
+ # img: CHW
128
+ def resize_norm_img(self, img, max_wh_ratio):
129
+ imgC, imgH, imgW = self.rec_image_shape
130
+ assert imgC == img.shape[0]
131
+ imgW = int((imgH * max_wh_ratio))
132
+
133
+ h, w = img.shape[1:]
134
+ ratio = w / float(h)
135
+ if math.ceil(imgH * ratio) > imgW:
136
+ resized_w = imgW
137
+ else:
138
+ resized_w = int(math.ceil(imgH * ratio))
139
+ resized_image = torch.nn.functional.interpolate(
140
+ img.unsqueeze(0),
141
+ size=(imgH, resized_w),
142
+ mode="bilinear",
143
+ align_corners=True,
144
+ )
145
+ resized_image /= 255.0
146
+ resized_image -= 0.5
147
+ resized_image /= 0.5
148
+ padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
149
+ padding_im[:, :, 0:resized_w] = resized_image[0]
150
+ return padding_im
151
+
152
+ # img_list: list of tensors with shape chw 0-255
153
+ def pred_imglist(self, img_list, show_debug=False, is_ori=False):
154
+ img_num = len(img_list)
155
+ assert img_num > 0
156
+ # Calculate the aspect ratio of all text bars
157
+ width_list = []
158
+ for img in img_list:
159
+ width_list.append(img.shape[2] / float(img.shape[1]))
160
+ # Sorting can speed up the recognition process
161
+ indices = torch.from_numpy(np.argsort(np.array(width_list)))
162
+ batch_num = self.rec_batch_num
163
+ preds_all = [None] * img_num
164
+ preds_neck_all = [None] * img_num
165
+ for beg_img_no in range(0, img_num, batch_num):
166
+ end_img_no = min(img_num, beg_img_no + batch_num)
167
+ norm_img_batch = []
168
+
169
+ imgC, imgH, imgW = self.rec_image_shape[:3]
170
+ max_wh_ratio = imgW / imgH
171
+ for ino in range(beg_img_no, end_img_no):
172
+ h, w = img_list[indices[ino]].shape[1:]
173
+ if h > w * 1.2:
174
+ img = img_list[indices[ino]]
175
+ img = torch.transpose(img, 1, 2).flip(dims=[1])
176
+ img_list[indices[ino]] = img
177
+ h, w = img.shape[1:]
178
+ # wh_ratio = w * 1.0 / h
179
+ # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio
180
+ for ino in range(beg_img_no, end_img_no):
181
+ norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
182
+ if self.use_fp16:
183
+ norm_img = norm_img.half()
184
+ norm_img = norm_img.unsqueeze(0)
185
+ norm_img_batch.append(norm_img)
186
+ norm_img_batch = torch.cat(norm_img_batch, dim=0)
187
+ if show_debug:
188
+ for i in range(len(norm_img_batch)):
189
+ _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
190
+ _img = (_img + 0.5) * 255
191
+ _img = _img[:, :, ::-1]
192
+ file_name = f"{indices[beg_img_no + i]}"
193
+ file_name = file_name + "_ori" if is_ori else file_name
194
+ cv2.imwrite(file_name + ".jpg", _img)
195
+ if self.is_onnx:
196
+ input_dict = {}
197
+ input_dict[self.predictor.get_inputs()[0].name] = (
198
+ norm_img_batch.detach().cpu().numpy()
199
+ )
200
+ outputs = self.predictor.run(None, input_dict)
201
+ preds = {}
202
+ preds["ctc"] = torch.from_numpy(outputs[0])
203
+ preds["ctc_neck"] = [torch.zeros(1)] * img_num
204
+ else:
205
+ preds = self.predictor(norm_img_batch)
206
+ for rno in range(preds["ctc"].shape[0]):
207
+ preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno]
208
+ preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno]
209
+
210
+ return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)
211
+
212
+ def get_char_dict(self, character_dict_path):
213
+ character_str = []
214
+ with open(character_dict_path, "rb") as fin:
215
+ lines = fin.readlines()
216
+ for line in lines:
217
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
218
+ character_str.append(line)
219
+ dict_character = list(character_str)
220
+ dict_character = ["sos"] + dict_character + [" "] # eos is space
221
+ return dict_character
222
+
223
+ def get_text(self, order):
224
+ char_list = [self.chars[text_id] for text_id in order]
225
+ return "".join(char_list)
226
+
227
+ def decode(self, mat):
228
+ text_index = mat.detach().cpu().numpy().argmax(axis=1)
229
+ ignored_tokens = [0]
230
+ selection = np.ones(len(text_index), dtype=bool)
231
+ selection[1:] = text_index[1:] != text_index[:-1]
232
+ for ignored_token in ignored_tokens:
233
+ selection &= text_index != ignored_token
234
+ return text_index[selection], np.where(selection)[0]
235
+
236
+ def get_ctcloss(self, preds, gt_text, weight):
237
+ if not isinstance(weight, torch.Tensor):
238
+ weight = torch.tensor(weight).to(preds.device)
239
+ ctc_loss = torch.nn.CTCLoss(reduction="none")
240
+ log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC
241
+ targets = []
242
+ target_lengths = []
243
+ for t in gt_text:
244
+ targets += [self.char2id.get(i, len(self.chars) - 1) for i in t]
245
+ target_lengths += [len(t)]
246
+ targets = torch.tensor(targets).to(preds.device)
247
+ target_lengths = torch.tensor(target_lengths).to(preds.device)
248
+ input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(
249
+ preds.device
250
+ )
251
+ loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
252
+ loss = loss / input_lengths * weight
253
+ return loss
254
+
255
+
256
+ def main():
257
+ rec_model_dir = "./ocr_weights/ppv3_rec.pth"
258
+ predictor = create_predictor(rec_model_dir)
259
+ args = edict()
260
+ args.rec_image_shape = "3, 48, 320"
261
+ args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt"
262
+ args.rec_batch_num = 6
+ args.use_fp16 = False  # read by TextRecognizer.__init__
263
+ text_recognizer = TextRecognizer(args, predictor)
264
+ image_dir = "./test_imgs_cn"
265
+ gt_text = ["韩国小馆"] * 14
266
+
267
+ image_file_list = get_image_file_list(image_dir)
268
+ valid_image_file_list = []
269
+ img_list = []
270
+
271
+ for image_file in image_file_list:
272
+ img = cv2.imread(image_file)
273
+ if img is None:
274
+ print("error in loading image:{}".format(image_file))
275
+ continue
276
+ valid_image_file_list.append(image_file)
277
+ img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
278
+ try:
279
+ tic = time.time()
280
+ times = []
281
+ for i in range(10):
282
+ preds, _ = text_recognizer.pred_imglist(img_list) # get text
283
+ preds_all = preds.softmax(dim=2)
284
+ times += [(time.time() - tic) * 1000.0]
285
+ tic = time.time()
286
+ print(times)
287
+ print(np.mean(times[1:]) / len(preds_all))
288
+ weight = np.ones(len(gt_text))
289
+ loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
290
+ for i in range(len(valid_image_file_list)):
291
+ pred = preds_all[i]
292
+ order, idx = text_recognizer.decode(pred)
293
+ text = text_recognizer.get_text(order)
294
+ print(
295
+ f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}'
296
+ )
297
+ except Exception as E:
298
+ print(traceback.format_exc(), E)
299
+
300
+
301
+ if __name__ == "__main__":
302
+ main()
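TextRecognizer.decode() above is a greedy CTC decode: take the per-frame argmax, collapse consecutive repeats, then drop the blank token (index 0). The same collapse rule on a toy index sequence:

import numpy as np

frame_argmax = np.array([0, 3, 3, 0, 5, 5, 5, 2])       # per-frame argmax, 0 = CTC blank
selection = np.ones(len(frame_argmax), dtype=bool)
selection[1:] = frame_argmax[1:] != frame_argmax[:-1]   # keep the first of each repeat
selection &= frame_argmax != 0                          # remove blanks
print(frame_argmax[selection])                          # -> [3 5 2]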
sorawm/iopaint/model/anytext/ldm/__init__.py ADDED
File without changes
sorawm/iopaint/model/anytext/ldm/models/__init__.py ADDED
File without changes
sorawm/iopaint/model/anytext/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,275 @@
1
+ from contextlib import contextmanager
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.model import (
7
+ Decoder,
8
+ Encoder,
9
+ )
10
+ from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
11
+ DiagonalGaussianDistribution,
12
+ )
13
+ from sorawm.iopaint.model.anytext.ldm.modules.ema import LitEma
14
+ from sorawm.iopaint.model.anytext.ldm.util import instantiate_from_config
15
+
16
+
17
+ class AutoencoderKL(torch.nn.Module):
18
+ def __init__(
19
+ self,
20
+ ddconfig,
21
+ lossconfig,
22
+ embed_dim,
23
+ ckpt_path=None,
24
+ ignore_keys=[],
25
+ image_key="image",
26
+ colorize_nlabels=None,
27
+ monitor=None,
28
+ ema_decay=None,
29
+ learn_logvar=False,
30
+ ):
31
+ super().__init__()
32
+ self.learn_logvar = learn_logvar
33
+ self.image_key = image_key
34
+ self.encoder = Encoder(**ddconfig)
35
+ self.decoder = Decoder(**ddconfig)
36
+ self.loss = instantiate_from_config(lossconfig)
37
+ assert ddconfig["double_z"]
38
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
39
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
40
+ self.embed_dim = embed_dim
41
+ if colorize_nlabels is not None:
42
+ assert type(colorize_nlabels) == int
43
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
44
+ if monitor is not None:
45
+ self.monitor = monitor
46
+
47
+ self.use_ema = ema_decay is not None
48
+ if self.use_ema:
49
+ self.ema_decay = ema_decay
50
+ assert 0.0 < ema_decay < 1.0
51
+ self.model_ema = LitEma(self, decay=ema_decay)
52
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
53
+
54
+ if ckpt_path is not None:
55
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
56
+
57
+ def init_from_ckpt(self, path, ignore_keys=list()):
58
+ sd = torch.load(path, map_location="cpu")["state_dict"]
59
+ keys = list(sd.keys())
60
+ for k in keys:
61
+ for ik in ignore_keys:
62
+ if k.startswith(ik):
63
+ print("Deleting key {} from state_dict.".format(k))
64
+ del sd[k]
65
+ self.load_state_dict(sd, strict=False)
66
+ print(f"Restored from {path}")
67
+
68
+ @contextmanager
69
+ def ema_scope(self, context=None):
70
+ if self.use_ema:
71
+ self.model_ema.store(self.parameters())
72
+ self.model_ema.copy_to(self)
73
+ if context is not None:
74
+ print(f"{context}: Switched to EMA weights")
75
+ try:
76
+ yield None
77
+ finally:
78
+ if self.use_ema:
79
+ self.model_ema.restore(self.parameters())
80
+ if context is not None:
81
+ print(f"{context}: Restored training weights")
82
+
83
+ def on_train_batch_end(self, *args, **kwargs):
84
+ if self.use_ema:
85
+ self.model_ema(self)
86
+
87
+ def encode(self, x):
88
+ h = self.encoder(x)
89
+ moments = self.quant_conv(h)
90
+ posterior = DiagonalGaussianDistribution(moments)
91
+ return posterior
92
+
93
+ def decode(self, z):
94
+ z = self.post_quant_conv(z)
95
+ dec = self.decoder(z)
96
+ return dec
97
+
98
+ def forward(self, input, sample_posterior=True):
99
+ posterior = self.encode(input)
100
+ if sample_posterior:
101
+ z = posterior.sample()
102
+ else:
103
+ z = posterior.mode()
104
+ dec = self.decode(z)
105
+ return dec, posterior
106
+
107
+ def get_input(self, batch, k):
108
+ x = batch[k]
109
+ if len(x.shape) == 3:
110
+ x = x[..., None]
111
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
112
+ return x
113
+
114
+ def training_step(self, batch, batch_idx, optimizer_idx):
115
+ inputs = self.get_input(batch, self.image_key)
116
+ reconstructions, posterior = self(inputs)
117
+
118
+ if optimizer_idx == 0:
119
+ # train encoder+decoder+logvar
120
+ aeloss, log_dict_ae = self.loss(
121
+ inputs,
122
+ reconstructions,
123
+ posterior,
124
+ optimizer_idx,
125
+ self.global_step,
126
+ last_layer=self.get_last_layer(),
127
+ split="train",
128
+ )
129
+ self.log(
130
+ "aeloss",
131
+ aeloss,
132
+ prog_bar=True,
133
+ logger=True,
134
+ on_step=True,
135
+ on_epoch=True,
136
+ )
137
+ self.log_dict(
138
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
139
+ )
140
+ return aeloss
141
+
142
+ if optimizer_idx == 1:
143
+ # train the discriminator
144
+ discloss, log_dict_disc = self.loss(
145
+ inputs,
146
+ reconstructions,
147
+ posterior,
148
+ optimizer_idx,
149
+ self.global_step,
150
+ last_layer=self.get_last_layer(),
151
+ split="train",
152
+ )
153
+
154
+ self.log(
155
+ "discloss",
156
+ discloss,
157
+ prog_bar=True,
158
+ logger=True,
159
+ on_step=True,
160
+ on_epoch=True,
161
+ )
162
+ self.log_dict(
163
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
164
+ )
165
+ return discloss
166
+
167
+ def validation_step(self, batch, batch_idx):
168
+ log_dict = self._validation_step(batch, batch_idx)
169
+ with self.ema_scope():
170
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
171
+ return log_dict
172
+
173
+ def _validation_step(self, batch, batch_idx, postfix=""):
174
+ inputs = self.get_input(batch, self.image_key)
175
+ reconstructions, posterior = self(inputs)
176
+ aeloss, log_dict_ae = self.loss(
177
+ inputs,
178
+ reconstructions,
179
+ posterior,
180
+ 0,
181
+ self.global_step,
182
+ last_layer=self.get_last_layer(),
183
+ split="val" + postfix,
184
+ )
185
+
186
+ discloss, log_dict_disc = self.loss(
187
+ inputs,
188
+ reconstructions,
189
+ posterior,
190
+ 1,
191
+ self.global_step,
192
+ last_layer=self.get_last_layer(),
193
+ split="val" + postfix,
194
+ )
195
+
196
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
197
+ self.log_dict(log_dict_ae)
198
+ self.log_dict(log_dict_disc)
199
+ return self.log_dict
200
+
201
+ def configure_optimizers(self):
202
+ lr = self.learning_rate
203
+ ae_params_list = (
204
+ list(self.encoder.parameters())
205
+ + list(self.decoder.parameters())
206
+ + list(self.quant_conv.parameters())
207
+ + list(self.post_quant_conv.parameters())
208
+ )
209
+ if self.learn_logvar:
210
+ print(f"{self.__class__.__name__}: Learning logvar")
211
+ ae_params_list.append(self.loss.logvar)
212
+ opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
213
+ opt_disc = torch.optim.Adam(
214
+ self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
215
+ )
216
+ return [opt_ae, opt_disc], []
217
+
218
+ def get_last_layer(self):
219
+ return self.decoder.conv_out.weight
220
+
221
+ @torch.no_grad()
222
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
223
+ log = dict()
224
+ x = self.get_input(batch, self.image_key)
225
+ x = x.to(self.device)
226
+ if not only_inputs:
227
+ xrec, posterior = self(x)
228
+ if x.shape[1] > 3:
229
+ # colorize with random projection
230
+ assert xrec.shape[1] > 3
231
+ x = self.to_rgb(x)
232
+ xrec = self.to_rgb(xrec)
233
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
234
+ log["reconstructions"] = xrec
235
+ if log_ema or self.use_ema:
236
+ with self.ema_scope():
237
+ xrec_ema, posterior_ema = self(x)
238
+ if x.shape[1] > 3:
239
+ # colorize with random projection
240
+ assert xrec_ema.shape[1] > 3
241
+ xrec_ema = self.to_rgb(xrec_ema)
242
+ log["samples_ema"] = self.decode(
243
+ torch.randn_like(posterior_ema.sample())
244
+ )
245
+ log["reconstructions_ema"] = xrec_ema
246
+ log["inputs"] = x
247
+ return log
248
+
249
+ def to_rgb(self, x):
250
+ assert self.image_key == "segmentation"
251
+ if not hasattr(self, "colorize"):
252
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
253
+ x = F.conv2d(x, weight=self.colorize)
254
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
255
+ return x
256
+
257
+
258
+ class IdentityFirstStage(torch.nn.Module):
259
+ def __init__(self, *args, vq_interface=False, **kwargs):
260
+ self.vq_interface = vq_interface
261
+ super().__init__()
262
+
263
+ def encode(self, x, *args, **kwargs):
264
+ return x
265
+
266
+ def decode(self, x, *args, **kwargs):
267
+ return x
268
+
269
+ def quantize(self, x, *args, **kwargs):
270
+ if self.vq_interface:
271
+ return x, None, [None, None, None]
272
+ return x
273
+
274
+ def forward(self, x, *args, **kwargs):
275
+ return x
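For inference only encode() and decode() are exercised; training_step/configure_optimizers still expect the original PyTorch Lightning host class (self.log, self.global_step) and are kept for reference only. A round-trip sketch with an illustrative ddconfig (the real values come from the AnyText yaml, not from this snippet):

import torch

from sorawm.iopaint.model.anytext.ldm.models.autoencoder import AutoencoderKL

ddconfig = dict(
    double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3, ch=128,
    ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0,
)
lossconfig = {"target": "torch.nn.Identity"}   # loss is unused at inference time

vae = AutoencoderKL(ddconfig, lossconfig, embed_dim=4).eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)
    z = vae.encode(x).sample()                 # DiagonalGaussianDistribution -> latent
    x_rec = vae.decode(z)
print(z.shape, x_rec.shape)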
sorawm/iopaint/model/anytext/ldm/models/diffusion/__init__.py ADDED
File without changes
sorawm/iopaint/model/anytext/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,525 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import numpy as np
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
8
+ extract_into_tensor,
9
+ make_ddim_sampling_parameters,
10
+ make_ddim_timesteps,
11
+ noise_like,
12
+ )
13
+
14
+
15
+ class DDIMSampler(object):
16
+ def __init__(self, model, schedule="linear", **kwargs):
17
+ super().__init__()
18
+ self.model = model
19
+ self.ddpm_num_timesteps = model.num_timesteps
20
+ self.schedule = schedule
21
+
22
+ def register_buffer(self, name, attr):
23
+ if type(attr) == torch.Tensor:
24
+ if attr.device != torch.device("cuda"):
25
+ attr = attr.to(torch.device("cuda"))
26
+ setattr(self, name, attr)
27
+
28
+ def make_schedule(
29
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
30
+ ):
31
+ self.ddim_timesteps = make_ddim_timesteps(
32
+ ddim_discr_method=ddim_discretize,
33
+ num_ddim_timesteps=ddim_num_steps,
34
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
35
+ verbose=verbose,
36
+ )
37
+ alphas_cumprod = self.model.alphas_cumprod
38
+ assert (
39
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
40
+ ), "alphas have to be defined for each timestep"
41
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
42
+
43
+ self.register_buffer("betas", to_torch(self.model.betas))
44
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
45
+ self.register_buffer(
46
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
47
+ )
48
+
49
+ # calculations for diffusion q(x_t | x_{t-1}) and others
50
+ self.register_buffer(
51
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
52
+ )
53
+ self.register_buffer(
54
+ "sqrt_one_minus_alphas_cumprod",
55
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
56
+ )
57
+ self.register_buffer(
58
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
59
+ )
60
+ self.register_buffer(
61
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
62
+ )
63
+ self.register_buffer(
64
+ "sqrt_recipm1_alphas_cumprod",
65
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
66
+ )
67
+
68
+ # ddim sampling parameters
69
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
70
+ alphacums=alphas_cumprod.cpu(),
71
+ ddim_timesteps=self.ddim_timesteps,
72
+ eta=ddim_eta,
73
+ verbose=verbose,
74
+ )
75
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
76
+ self.register_buffer("ddim_alphas", ddim_alphas)
77
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
78
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
79
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
80
+ (1 - self.alphas_cumprod_prev)
81
+ / (1 - self.alphas_cumprod)
82
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
83
+ )
84
+ self.register_buffer(
85
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
86
+ )
87
+
88
+ @torch.no_grad()
89
+ def sample(
90
+ self,
91
+ S,
92
+ batch_size,
93
+ shape,
94
+ conditioning=None,
95
+ callback=None,
96
+ normals_sequence=None,
97
+ img_callback=None,
98
+ quantize_x0=False,
99
+ eta=0.0,
100
+ mask=None,
101
+ x0=None,
102
+ temperature=1.0,
103
+ noise_dropout=0.0,
104
+ score_corrector=None,
105
+ corrector_kwargs=None,
106
+ verbose=True,
107
+ x_T=None,
108
+ log_every_t=100,
109
+ unconditional_guidance_scale=1.0,
110
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
111
+ dynamic_threshold=None,
112
+ ucg_schedule=None,
113
+ **kwargs,
114
+ ):
115
+ if conditioning is not None:
116
+ if isinstance(conditioning, dict):
117
+ ctmp = conditioning[list(conditioning.keys())[0]]
118
+ while isinstance(ctmp, list):
119
+ ctmp = ctmp[0]
120
+ cbs = ctmp.shape[0]
121
+ # cbs = len(ctmp[0])
122
+ if cbs != batch_size:
123
+ print(
124
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
125
+ )
126
+
127
+ elif isinstance(conditioning, list):
128
+ for ctmp in conditioning:
129
+ if ctmp.shape[0] != batch_size:
130
+ print(
131
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
132
+ )
133
+
134
+ else:
135
+ if conditioning.shape[0] != batch_size:
136
+ print(
137
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
138
+ )
139
+
140
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
141
+ # sampling
142
+ C, H, W = shape
143
+ size = (batch_size, C, H, W)
144
+ print(f"Data shape for DDIM sampling is {size}, eta {eta}")
145
+
146
+ samples, intermediates = self.ddim_sampling(
147
+ conditioning,
148
+ size,
149
+ callback=callback,
150
+ img_callback=img_callback,
151
+ quantize_denoised=quantize_x0,
152
+ mask=mask,
153
+ x0=x0,
154
+ ddim_use_original_steps=False,
155
+ noise_dropout=noise_dropout,
156
+ temperature=temperature,
157
+ score_corrector=score_corrector,
158
+ corrector_kwargs=corrector_kwargs,
159
+ x_T=x_T,
160
+ log_every_t=log_every_t,
161
+ unconditional_guidance_scale=unconditional_guidance_scale,
162
+ unconditional_conditioning=unconditional_conditioning,
163
+ dynamic_threshold=dynamic_threshold,
164
+ ucg_schedule=ucg_schedule,
165
+ )
166
+ return samples, intermediates
167
+
168
+ @torch.no_grad()
169
+ def ddim_sampling(
170
+ self,
171
+ cond,
172
+ shape,
173
+ x_T=None,
174
+ ddim_use_original_steps=False,
175
+ callback=None,
176
+ timesteps=None,
177
+ quantize_denoised=False,
178
+ mask=None,
179
+ x0=None,
180
+ img_callback=None,
181
+ log_every_t=100,
182
+ temperature=1.0,
183
+ noise_dropout=0.0,
184
+ score_corrector=None,
185
+ corrector_kwargs=None,
186
+ unconditional_guidance_scale=1.0,
187
+ unconditional_conditioning=None,
188
+ dynamic_threshold=None,
189
+ ucg_schedule=None,
190
+ ):
191
+ device = self.model.betas.device
192
+ b = shape[0]
193
+ if x_T is None:
194
+ img = torch.randn(shape, device=device)
195
+ else:
196
+ img = x_T
197
+
198
+ if timesteps is None:
199
+ timesteps = (
200
+ self.ddpm_num_timesteps
201
+ if ddim_use_original_steps
202
+ else self.ddim_timesteps
203
+ )
204
+ elif timesteps is not None and not ddim_use_original_steps:
205
+ subset_end = (
206
+ int(
207
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
208
+ * self.ddim_timesteps.shape[0]
209
+ )
210
+ - 1
211
+ )
212
+ timesteps = self.ddim_timesteps[:subset_end]
213
+
214
+ intermediates = {"x_inter": [img], "pred_x0": [img], "index": [10000]}
215
+ time_range = (
216
+ reversed(range(0, timesteps))
217
+ if ddim_use_original_steps
218
+ else np.flip(timesteps)
219
+ )
220
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
221
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
222
+
223
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
224
+
225
+ for i, step in enumerate(iterator):
226
+ index = total_steps - i - 1
227
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
228
+
229
+ if mask is not None:
230
+ assert x0 is not None
231
+ img_orig = self.model.q_sample(
232
+ x0, ts
233
+ ) # TODO: deterministic forward pass?
234
+ img = img_orig * mask + (1.0 - mask) * img
235
+
236
+ if ucg_schedule is not None:
237
+ assert len(ucg_schedule) == len(time_range)
238
+ unconditional_guidance_scale = ucg_schedule[i]
239
+
240
+ outs = self.p_sample_ddim(
241
+ img,
242
+ cond,
243
+ ts,
244
+ index=index,
245
+ use_original_steps=ddim_use_original_steps,
246
+ quantize_denoised=quantize_denoised,
247
+ temperature=temperature,
248
+ noise_dropout=noise_dropout,
249
+ score_corrector=score_corrector,
250
+ corrector_kwargs=corrector_kwargs,
251
+ unconditional_guidance_scale=unconditional_guidance_scale,
252
+ unconditional_conditioning=unconditional_conditioning,
253
+ dynamic_threshold=dynamic_threshold,
254
+ )
255
+ img, pred_x0 = outs
256
+ if callback:
257
+ callback(i)
258
+ if img_callback:
259
+ img_callback(pred_x0, i)
260
+
261
+ if index % log_every_t == 0 or index == total_steps - 1:
262
+ intermediates["x_inter"].append(img)
263
+ intermediates["pred_x0"].append(pred_x0)
264
+ intermediates["index"].append(index)
265
+
266
+ return img, intermediates
267
+
268
+ @torch.no_grad()
269
+ def p_sample_ddim(
270
+ self,
271
+ x,
272
+ c,
273
+ t,
274
+ index,
275
+ repeat_noise=False,
276
+ use_original_steps=False,
277
+ quantize_denoised=False,
278
+ temperature=1.0,
279
+ noise_dropout=0.0,
280
+ score_corrector=None,
281
+ corrector_kwargs=None,
282
+ unconditional_guidance_scale=1.0,
283
+ unconditional_conditioning=None,
284
+ dynamic_threshold=None,
285
+ ):
286
+ b, *_, device = *x.shape, x.device
287
+
288
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
289
+ model_output = self.model.apply_model(x, t, c)
290
+ else:
291
+ x_in = torch.cat([x] * 2)
292
+ t_in = torch.cat([t] * 2)
293
+ if isinstance(c, dict):
294
+ assert isinstance(unconditional_conditioning, dict)
295
+ c_in = dict()
296
+ for k in c:
297
+ if isinstance(c[k], list):
298
+ c_in[k] = [
299
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
300
+ for i in range(len(c[k]))
301
+ ]
302
+ elif isinstance(c[k], dict):
303
+ c_in[k] = dict()
304
+ for key in c[k]:
305
+ if isinstance(c[k][key], list):
306
+ if not isinstance(c[k][key][0], torch.Tensor):
307
+ continue
308
+ c_in[k][key] = [
309
+ torch.cat(
310
+ [
311
+ unconditional_conditioning[k][key][i],
312
+ c[k][key][i],
313
+ ]
314
+ )
315
+ for i in range(len(c[k][key]))
316
+ ]
317
+ else:
318
+ c_in[k][key] = torch.cat(
319
+ [unconditional_conditioning[k][key], c[k][key]]
320
+ )
321
+
322
+ else:
323
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
324
+ elif isinstance(c, list):
325
+ c_in = list()
326
+ assert isinstance(unconditional_conditioning, list)
327
+ for i in range(len(c)):
328
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
329
+ else:
330
+ c_in = torch.cat([unconditional_conditioning, c])
331
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
332
+ model_output = model_uncond + unconditional_guidance_scale * (
333
+ model_t - model_uncond
334
+ )
335
+
336
+ if self.model.parameterization == "v":
337
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
338
+ else:
339
+ e_t = model_output
340
+
341
+ if score_corrector is not None:
342
+ assert self.model.parameterization == "eps", "not implemented"
343
+ e_t = score_corrector.modify_score(
344
+ self.model, e_t, x, t, c, **corrector_kwargs
345
+ )
346
+
347
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
348
+ alphas_prev = (
349
+ self.model.alphas_cumprod_prev
350
+ if use_original_steps
351
+ else self.ddim_alphas_prev
352
+ )
353
+ sqrt_one_minus_alphas = (
354
+ self.model.sqrt_one_minus_alphas_cumprod
355
+ if use_original_steps
356
+ else self.ddim_sqrt_one_minus_alphas
357
+ )
358
+ sigmas = (
359
+ self.model.ddim_sigmas_for_original_num_steps
360
+ if use_original_steps
361
+ else self.ddim_sigmas
362
+ )
363
+ # select parameters corresponding to the currently considered timestep
364
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
365
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
366
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
367
+ sqrt_one_minus_at = torch.full(
368
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
369
+ )
370
+
371
+ # current prediction for x_0
372
+ if self.model.parameterization != "v":
373
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
374
+ else:
375
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
376
+
377
+ if quantize_denoised:
378
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
379
+
380
+ if dynamic_threshold is not None:
381
+ raise NotImplementedError()
382
+
383
+ # direction pointing to x_t
384
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
385
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
386
+ if noise_dropout > 0.0:
387
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
388
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
389
+ return x_prev, pred_x0
390
+
391
+ @torch.no_grad()
392
+ def encode(
393
+ self,
394
+ x0,
395
+ c,
396
+ t_enc,
397
+ use_original_steps=False,
398
+ return_intermediates=None,
399
+ unconditional_guidance_scale=1.0,
400
+ unconditional_conditioning=None,
401
+ callback=None,
402
+ ):
403
+ num_reference_steps = (
404
+ self.ddpm_num_timesteps
405
+ if use_original_steps
406
+ else self.ddim_timesteps.shape[0]
407
+ )
408
+
409
+ assert t_enc <= num_reference_steps
410
+ num_steps = t_enc
411
+
412
+ if use_original_steps:
413
+ alphas_next = self.alphas_cumprod[:num_steps]
414
+ alphas = self.alphas_cumprod_prev[:num_steps]
415
+ else:
416
+ alphas_next = self.ddim_alphas[:num_steps]
417
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
418
+
419
+ x_next = x0
420
+ intermediates = []
421
+ inter_steps = []
422
+ for i in tqdm(range(num_steps), desc="Encoding Image"):
423
+ t = torch.full(
424
+ (x0.shape[0],), i, device=self.model.device, dtype=torch.long
425
+ )
426
+ if unconditional_guidance_scale == 1.0:
427
+ noise_pred = self.model.apply_model(x_next, t, c)
428
+ else:
429
+ assert unconditional_conditioning is not None
430
+ e_t_uncond, noise_pred = torch.chunk(
431
+ self.model.apply_model(
432
+ torch.cat((x_next, x_next)),
433
+ torch.cat((t, t)),
434
+ torch.cat((unconditional_conditioning, c)),
435
+ ),
436
+ 2,
437
+ )
438
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (
439
+ noise_pred - e_t_uncond
440
+ )
441
+
442
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
443
+ weighted_noise_pred = (
444
+ alphas_next[i].sqrt()
445
+ * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
446
+ * noise_pred
447
+ )
448
+ x_next = xt_weighted + weighted_noise_pred
449
+ if (
450
+ return_intermediates
451
+ and i % (num_steps // return_intermediates) == 0
452
+ and i < num_steps - 1
453
+ ):
454
+ intermediates.append(x_next)
455
+ inter_steps.append(i)
456
+ elif return_intermediates and i >= num_steps - 2:
457
+ intermediates.append(x_next)
458
+ inter_steps.append(i)
459
+ if callback:
460
+ callback(i)
461
+
462
+ out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
463
+ if return_intermediates:
464
+ out.update({"intermediates": intermediates})
465
+ return x_next, out
466
+
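Both `encode` above and `p_sample_ddim` combine the conditional and unconditional predictions with standard classifier-free guidance. A minimal sketch of that combination, with illustrative names that are not part of this file:

def cfg_combine(e_uncond, e_cond, guidance_scale):
    # classifier-free guidance: push the prediction away from the unconditional branch
    return e_uncond + guidance_scale * (e_cond - e_uncond)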
467
+ @torch.no_grad()
468
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
469
+ # fast, but does not allow for exact reconstruction
470
+ # t serves as an index to gather the correct alphas
471
+ if use_original_steps:
472
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
473
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
474
+ else:
475
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
476
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
477
+
478
+ if noise is None:
479
+ noise = torch.randn_like(x0)
480
+ return (
481
+ extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
482
+ + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
483
+ )
484
+
485
+ @torch.no_grad()
486
+ def decode(
487
+ self,
488
+ x_latent,
489
+ cond,
490
+ t_start,
491
+ unconditional_guidance_scale=1.0,
492
+ unconditional_conditioning=None,
493
+ use_original_steps=False,
494
+ callback=None,
495
+ ):
496
+ timesteps = (
497
+ np.arange(self.ddpm_num_timesteps)
498
+ if use_original_steps
499
+ else self.ddim_timesteps
500
+ )
501
+ timesteps = timesteps[:t_start]
502
+
503
+ time_range = np.flip(timesteps)
504
+ total_steps = timesteps.shape[0]
505
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
506
+
507
+ iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
508
+ x_dec = x_latent
509
+ for i, step in enumerate(iterator):
510
+ index = total_steps - i - 1
511
+ ts = torch.full(
512
+ (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
513
+ )
514
+ x_dec, _ = self.p_sample_ddim(
515
+ x_dec,
516
+ cond,
517
+ ts,
518
+ index=index,
519
+ use_original_steps=use_original_steps,
520
+ unconditional_guidance_scale=unconditional_guidance_scale,
521
+ unconditional_conditioning=unconditional_conditioning,
522
+ )
523
+ if callback:
524
+ callback(i)
525
+ return x_dec
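For reference, the update performed by `p_sample_ddim` above is the standard DDIM step. A minimal standalone sketch under the eps-parameterization, with illustrative names and without guidance, quantization or dynamic thresholding:

import numpy as np

def ddim_step(x_t, eps, a_t, a_prev, sigma_t, rng=None):
    # current estimate of the clean latent x_0
    pred_x0 = (x_t - np.sqrt(1.0 - a_t) * eps) / np.sqrt(a_t)
    # direction pointing back towards x_t
    dir_xt = np.sqrt(1.0 - a_prev - sigma_t ** 2) * eps
    noise = sigma_t * rng.standard_normal(x_t.shape) if rng is not None else 0.0
    return np.sqrt(a_prev) * pred_x0 + dir_xt + noise, pred_x0

Setting sigma_t = 0 gives the deterministic DDIM sampler; the `decode` method above simply iterates this step over the reversed timestep schedule.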
sorawm/iopaint/model/anytext/ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,2386 @@
1
+ """
2
+ Part of the implementation is borrowed and modified from ControlNet, publicly available at https://github.com/lllyasviel/ControlNet/blob/main/ldm/models/diffusion/ddpm.py
3
+ """
4
+
5
+ import itertools
6
+ from contextlib import contextmanager, nullcontext
7
+ from functools import partial
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from einops import rearrange, repeat
14
+ from omegaconf import ListConfig
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+ from torchvision.utils import make_grid
17
+ from tqdm import tqdm
18
+
19
+ from sorawm.iopaint.model.anytext.ldm.models.autoencoder import (
20
+ AutoencoderKL,
21
+ IdentityFirstStage,
22
+ )
23
+ from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
24
+ from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
25
+ extract_into_tensor,
26
+ make_beta_schedule,
27
+ noise_like,
28
+ )
29
+ from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
30
+ DiagonalGaussianDistribution,
31
+ normal_kl,
32
+ )
33
+ from sorawm.iopaint.model.anytext.ldm.modules.ema import LitEma
34
+ from sorawm.iopaint.model.anytext.ldm.util import (
35
+ count_params,
36
+ default,
37
+ exists,
38
+ instantiate_from_config,
39
+ isimage,
40
+ ismap,
41
+ log_txt_as_img,
42
+ mean_flat,
43
+ )
44
+
45
+ __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
46
+
47
+ PRINT_DEBUG = False
48
+
49
+
50
+ def print_grad(grad):
51
+ # print('Gradient:', grad)
52
+ # print(grad.shape)
53
+ a = grad.max()
54
+ b = grad.min()
55
+ # print(f'mean={grad.mean():.4f}, max={a:.4f}, min={b:.4f}')
56
+ s = 255.0 / (a - b)
57
+ c = 255 * (-b / (a - b))
58
+ grad = grad * s + c
59
+ # print(f'mean={grad.mean():.4f}, max={grad.max():.4f}, min={grad.min():.4f}')
60
+ img = grad[0].permute(1, 2, 0).detach().cpu().numpy()
61
+ if img.shape[0] == 512:
62
+ cv2.imwrite("grad-img.jpg", img)
63
+ elif img.shape[0] == 64:
64
+ cv2.imwrite("grad-latent.jpg", img)
65
+
66
+
67
+ def disabled_train(self, mode=True):
68
+ """Overwrite model.train with this function to make sure train/eval mode
69
+ does not change anymore."""
70
+ return self
71
+
72
+
73
+ def uniform_on_device(r1, r2, shape, device):
74
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
75
+
76
+
77
+ class DDPM(torch.nn.Module):
78
+ # classic DDPM with Gaussian diffusion, in image space
79
+ def __init__(
80
+ self,
81
+ unet_config,
82
+ timesteps=1000,
83
+ beta_schedule="linear",
84
+ loss_type="l2",
85
+ ckpt_path=None,
86
+ ignore_keys=[],
87
+ load_only_unet=False,
88
+ monitor="val/loss",
89
+ use_ema=True,
90
+ first_stage_key="image",
91
+ image_size=256,
92
+ channels=3,
93
+ log_every_t=100,
94
+ clip_denoised=True,
95
+ linear_start=1e-4,
96
+ linear_end=2e-2,
97
+ cosine_s=8e-3,
98
+ given_betas=None,
99
+ original_elbo_weight=0.0,
100
+ v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
101
+ l_simple_weight=1.0,
102
+ conditioning_key=None,
103
+ parameterization="eps", # all assuming fixed variance schedules
104
+ scheduler_config=None,
105
+ use_positional_encodings=False,
106
+ learn_logvar=False,
107
+ logvar_init=0.0,
108
+ make_it_fit=False,
109
+ ucg_training=None,
110
+ reset_ema=False,
111
+ reset_num_ema_updates=False,
112
+ ):
113
+ super().__init__()
114
+ assert parameterization in [
115
+ "eps",
116
+ "x0",
117
+ "v",
118
+ ], 'currently only supporting "eps", "x0" and "v"'
119
+ self.parameterization = parameterization
120
+ print(
121
+ f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
122
+ )
123
+ self.cond_stage_model = None
124
+ self.clip_denoised = clip_denoised
125
+ self.log_every_t = log_every_t
126
+ self.first_stage_key = first_stage_key
127
+ self.image_size = image_size # try conv?
128
+ self.channels = channels
129
+ self.use_positional_encodings = use_positional_encodings
130
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
131
+ count_params(self.model, verbose=True)
132
+ self.use_ema = use_ema
133
+ if self.use_ema:
134
+ self.model_ema = LitEma(self.model)
135
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
136
+
137
+ self.use_scheduler = scheduler_config is not None
138
+ if self.use_scheduler:
139
+ self.scheduler_config = scheduler_config
140
+
141
+ self.v_posterior = v_posterior
142
+ self.original_elbo_weight = original_elbo_weight
143
+ self.l_simple_weight = l_simple_weight
144
+
145
+ if monitor is not None:
146
+ self.monitor = monitor
147
+ self.make_it_fit = make_it_fit
148
+ if reset_ema:
149
+ assert exists(ckpt_path)
150
+ if ckpt_path is not None:
151
+ self.init_from_ckpt(
152
+ ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
153
+ )
154
+ if reset_ema:
155
+ assert self.use_ema
156
+ print(
157
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
158
+ )
159
+ self.model_ema = LitEma(self.model)
160
+ if reset_num_ema_updates:
161
+ print(
162
+ " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
163
+ )
164
+ assert self.use_ema
165
+ self.model_ema.reset_num_updates()
166
+
167
+ self.register_schedule(
168
+ given_betas=given_betas,
169
+ beta_schedule=beta_schedule,
170
+ timesteps=timesteps,
171
+ linear_start=linear_start,
172
+ linear_end=linear_end,
173
+ cosine_s=cosine_s,
174
+ )
175
+
176
+ self.loss_type = loss_type
177
+
178
+ self.learn_logvar = learn_logvar
179
+ logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
180
+ if self.learn_logvar:
181
+ self.logvar = nn.Parameter(logvar, requires_grad=True)
182
+ else:
183
+ self.register_buffer("logvar", logvar)
184
+
185
+ self.ucg_training = ucg_training or dict()
186
+ if self.ucg_training:
187
+ self.ucg_prng = np.random.RandomState()
188
+
189
+ def register_schedule(
190
+ self,
191
+ given_betas=None,
192
+ beta_schedule="linear",
193
+ timesteps=1000,
194
+ linear_start=1e-4,
195
+ linear_end=2e-2,
196
+ cosine_s=8e-3,
197
+ ):
198
+ if exists(given_betas):
199
+ betas = given_betas
200
+ else:
201
+ betas = make_beta_schedule(
202
+ beta_schedule,
203
+ timesteps,
204
+ linear_start=linear_start,
205
+ linear_end=linear_end,
206
+ cosine_s=cosine_s,
207
+ )
208
+ alphas = 1.0 - betas
209
+ alphas_cumprod = np.cumprod(alphas, axis=0)
210
+ # np.save('1.npy', alphas_cumprod)
211
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
212
+
213
+ (timesteps,) = betas.shape
214
+ self.num_timesteps = int(timesteps)
215
+ self.linear_start = linear_start
216
+ self.linear_end = linear_end
217
+ assert (
218
+ alphas_cumprod.shape[0] == self.num_timesteps
219
+ ), "alphas have to be defined for each timestep"
220
+
221
+ to_torch = partial(torch.tensor, dtype=torch.float32)
222
+
223
+ self.register_buffer("betas", to_torch(betas))
224
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
225
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
226
+
227
+ # calculations for diffusion q(x_t | x_{t-1}) and others
228
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
229
+ self.register_buffer(
230
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
231
+ )
232
+ self.register_buffer(
233
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
234
+ )
235
+ self.register_buffer(
236
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
237
+ )
238
+ self.register_buffer(
239
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
240
+ )
241
+
242
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
243
+ posterior_variance = (1 - self.v_posterior) * betas * (
244
+ 1.0 - alphas_cumprod_prev
245
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
246
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
247
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
248
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
249
+ self.register_buffer(
250
+ "posterior_log_variance_clipped",
251
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
252
+ )
253
+ self.register_buffer(
254
+ "posterior_mean_coef1",
255
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
256
+ )
257
+ self.register_buffer(
258
+ "posterior_mean_coef2",
259
+ to_torch(
260
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
261
+ ),
262
+ )
263
+
264
+ if self.parameterization == "eps":
265
+ lvlb_weights = self.betas**2 / (
266
+ 2
267
+ * self.posterior_variance
268
+ * to_torch(alphas)
269
+ * (1 - self.alphas_cumprod)
270
+ )
271
+ elif self.parameterization == "x0":
272
+ lvlb_weights = (
273
+ 0.5
274
+ * np.sqrt(torch.Tensor(alphas_cumprod))
275
+ / (2.0 * 1 - torch.Tensor(alphas_cumprod))
276
+ )
277
+ elif self.parameterization == "v":
278
+ lvlb_weights = torch.ones_like(
279
+ self.betas**2
280
+ / (
281
+ 2
282
+ * self.posterior_variance
283
+ * to_torch(alphas)
284
+ * (1 - self.alphas_cumprod)
285
+ )
286
+ )
287
+ else:
288
+ raise NotImplementedError("mu not supported")
289
+ lvlb_weights[0] = lvlb_weights[1]
290
+ self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
291
+ assert not torch.isnan(self.lvlb_weights).all()
292
+
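To make the bookkeeping above concrete, here is a minimal sketch of the quantities `register_schedule` derives from a beta schedule (NumPy, assuming the usual LDM convention that the "linear" schedule is linear in sqrt(beta), and v_posterior = 0):

import numpy as np

def make_schedule(timesteps=1000, linear_start=1e-4, linear_end=2e-2):
    betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps) ** 2
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
    # variance of the posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    return betas, alphas_cumprod, alphas_cumprod_prev, posterior_variance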
293
+ @contextmanager
294
+ def ema_scope(self, context=None):
295
+ if self.use_ema:
296
+ self.model_ema.store(self.model.parameters())
297
+ self.model_ema.copy_to(self.model)
298
+ if context is not None:
299
+ print(f"{context}: Switched to EMA weights")
300
+ try:
301
+ yield None
302
+ finally:
303
+ if self.use_ema:
304
+ self.model_ema.restore(self.model.parameters())
305
+ if context is not None:
306
+ print(f"{context}: Restored training weights")
307
+
308
+ @torch.no_grad()
309
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
310
+ sd = torch.load(path, map_location="cpu")
311
+ if "state_dict" in list(sd.keys()):
312
+ sd = sd["state_dict"]
313
+ keys = list(sd.keys())
314
+ for k in keys:
315
+ for ik in ignore_keys:
316
+ if k.startswith(ik):
317
+ print("Deleting key {} from state_dict.".format(k))
318
+ del sd[k]
319
+ if self.make_it_fit:
320
+ n_params = len(
321
+ [
322
+ name
323
+ for name, _ in itertools.chain(
324
+ self.named_parameters(), self.named_buffers()
325
+ )
326
+ ]
327
+ )
328
+ for name, param in tqdm(
329
+ itertools.chain(self.named_parameters(), self.named_buffers()),
330
+ desc="Fitting old weights to new weights",
331
+ total=n_params,
332
+ ):
333
+ if name not in sd:
334
+ continue
335
+ old_shape = sd[name].shape
336
+ new_shape = param.shape
337
+ assert len(old_shape) == len(new_shape)
338
+ if len(new_shape) > 2:
339
+ # we only modify first two axes
340
+ assert new_shape[2:] == old_shape[2:]
341
+ # assumes first axis corresponds to output dim
342
+ if not new_shape == old_shape:
343
+ new_param = param.clone()
344
+ old_param = sd[name]
345
+ if len(new_shape) == 1:
346
+ for i in range(new_param.shape[0]):
347
+ new_param[i] = old_param[i % old_shape[0]]
348
+ elif len(new_shape) >= 2:
349
+ for i in range(new_param.shape[0]):
350
+ for j in range(new_param.shape[1]):
351
+ new_param[i, j] = old_param[
352
+ i % old_shape[0], j % old_shape[1]
353
+ ]
354
+
355
+ n_used_old = torch.ones(old_shape[1])
356
+ for j in range(new_param.shape[1]):
357
+ n_used_old[j % old_shape[1]] += 1
358
+ n_used_new = torch.zeros(new_shape[1])
359
+ for j in range(new_param.shape[1]):
360
+ n_used_new[j] = n_used_old[j % old_shape[1]]
361
+
362
+ n_used_new = n_used_new[None, :]
363
+ while len(n_used_new.shape) < len(new_shape):
364
+ n_used_new = n_used_new.unsqueeze(-1)
365
+ new_param /= n_used_new
366
+
367
+ sd[name] = new_param
368
+
369
+ missing, unexpected = (
370
+ self.load_state_dict(sd, strict=False)
371
+ if not only_model
372
+ else self.model.load_state_dict(sd, strict=False)
373
+ )
374
+ print(
375
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
376
+ )
377
+ if len(missing) > 0:
378
+ print(f"Missing Keys:\n {missing}")
379
+ if len(unexpected) > 0:
380
+ print(f"\nUnexpected Keys:\n {unexpected}")
381
+
382
+ def q_mean_variance(self, x_start, t):
383
+ """
384
+ Get the distribution q(x_t | x_0).
385
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
386
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
387
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
388
+ """
389
+ mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
390
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
391
+ log_variance = extract_into_tensor(
392
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
393
+ )
394
+ return mean, variance, log_variance
395
+
396
+ def predict_start_from_noise(self, x_t, t, noise):
397
+ return (
398
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
399
+ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
400
+ * noise
401
+ )
402
+
403
+ def predict_start_from_z_and_v(self, x_t, t, v):
404
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
405
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
406
+ return (
407
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
408
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
409
+ )
410
+
411
+ def predict_eps_from_z_and_v(self, x_t, t, v):
412
+ return (
413
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
414
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
415
+ * x_t
416
+ )
417
+
418
+ def q_posterior(self, x_start, x_t, t):
419
+ posterior_mean = (
420
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
421
+ + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
422
+ )
423
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
424
+ posterior_log_variance_clipped = extract_into_tensor(
425
+ self.posterior_log_variance_clipped, t, x_t.shape
426
+ )
427
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
428
+
429
+ def p_mean_variance(self, x, t, clip_denoised: bool):
430
+ model_out = self.model(x, t)
431
+ if self.parameterization == "eps":
432
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
433
+ elif self.parameterization == "x0":
434
+ x_recon = model_out
435
+ if clip_denoised:
436
+ x_recon.clamp_(-1.0, 1.0)
437
+
438
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
439
+ x_start=x_recon, x_t=x, t=t
440
+ )
441
+ return model_mean, posterior_variance, posterior_log_variance
442
+
443
+ @torch.no_grad()
444
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
445
+ b, *_, device = *x.shape, x.device
446
+ model_mean, _, model_log_variance = self.p_mean_variance(
447
+ x=x, t=t, clip_denoised=clip_denoised
448
+ )
449
+ noise = noise_like(x.shape, device, repeat_noise)
450
+ # no noise when t == 0
451
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
452
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
453
+
454
+ @torch.no_grad()
455
+ def p_sample_loop(self, shape, return_intermediates=False):
456
+ device = self.betas.device
457
+ b = shape[0]
458
+ img = torch.randn(shape, device=device)
459
+ intermediates = [img]
460
+ for i in tqdm(
461
+ reversed(range(0, self.num_timesteps)),
462
+ desc="Sampling t",
463
+ total=self.num_timesteps,
464
+ ):
465
+ img = self.p_sample(
466
+ img,
467
+ torch.full((b,), i, device=device, dtype=torch.long),
468
+ clip_denoised=self.clip_denoised,
469
+ )
470
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
471
+ intermediates.append(img)
472
+ if return_intermediates:
473
+ return img, intermediates
474
+ return img
475
+
476
+ @torch.no_grad()
477
+ def sample(self, batch_size=16, return_intermediates=False):
478
+ image_size = self.image_size
479
+ channels = self.channels
480
+ return self.p_sample_loop(
481
+ (batch_size, channels, image_size, image_size),
482
+ return_intermediates=return_intermediates,
483
+ )
484
+
485
+ def q_sample(self, x_start, t, noise=None):
486
+ noise = default(noise, lambda: torch.randn_like(x_start))
487
+ return (
488
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
489
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
490
+ * noise
491
+ )
492
+
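`q_sample` implements the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise; the DDIM sampler's `stochastic_encode` earlier in this commit uses the same identity. A minimal sketch with a scalar alpha_bar for clarity:

import torch

def q_sample_ref(x0, alpha_bar_t, noise=None):
    # jump straight from x_0 to x_t without simulating intermediate steps
    noise = torch.randn_like(x0) if noise is None else noise
    return alpha_bar_t ** 0.5 * x0 + (1.0 - alpha_bar_t) ** 0.5 * noise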
493
+ def get_v(self, x, noise, t):
494
+ return (
495
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
496
+ - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
497
+ )
498
+
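The v-parameterization used by `get_v`, `predict_start_from_z_and_v` and `predict_eps_from_z_and_v` ties x_0, the noise and v together through v = sqrt(a_bar) * eps - sqrt(1 - a_bar) * x_0. A small self-consistency check, illustrative only:

import torch

a_bar = torch.tensor(0.8)
x0, eps = torch.randn(4), torch.randn(4)
x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps        # forward process
v = a_bar.sqrt() * eps - (1 - a_bar).sqrt() * x0          # get_v
x0_rec = a_bar.sqrt() * x_t - (1 - a_bar).sqrt() * v      # predict_start_from_z_and_v
eps_rec = a_bar.sqrt() * v + (1 - a_bar).sqrt() * x_t     # predict_eps_from_z_and_v
assert torch.allclose(x0_rec, x0, atol=1e-6) and torch.allclose(eps_rec, eps, atol=1e-6)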
499
+ def get_loss(self, pred, target, mean=True):
500
+ if self.loss_type == "l1":
501
+ loss = (target - pred).abs()
502
+ if mean:
503
+ loss = loss.mean()
504
+ elif self.loss_type == "l2":
505
+ if mean:
506
+ loss = torch.nn.functional.mse_loss(target, pred)
507
+ else:
508
+ loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
509
+ else:
510
+ raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
511
+
512
+ return loss
513
+
514
+ def p_losses(self, x_start, t, noise=None):
515
+ noise = default(noise, lambda: torch.randn_like(x_start))
516
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
517
+ model_out = self.model(x_noisy, t)
518
+
519
+ loss_dict = {}
520
+ if self.parameterization == "eps":
521
+ target = noise
522
+ elif self.parameterization == "x0":
523
+ target = x_start
524
+ elif self.parameterization == "v":
525
+ target = self.get_v(x_start, noise, t)
526
+ else:
527
+ raise NotImplementedError(
528
+ f"Parameterization {self.parameterization} not yet supported"
529
+ )
530
+
531
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
532
+
533
+ log_prefix = "train" if self.training else "val"
534
+
535
+ loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()})
536
+ loss_simple = loss.mean() * self.l_simple_weight
537
+
538
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
539
+ loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb})
540
+
541
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
542
+
543
+ loss_dict.update({f"{log_prefix}/loss": loss})
544
+
545
+ return loss, loss_dict
546
+
547
+ def forward(self, x, *args, **kwargs):
548
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
549
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
550
+ t = torch.randint(
551
+ 0, self.num_timesteps, (x.shape[0],), device=self.device
552
+ ).long()
553
+ return self.p_losses(x, t, *args, **kwargs)
554
+
555
+ def get_input(self, batch, k):
556
+ x = batch[k]
557
+ if len(x.shape) == 3:
558
+ x = x[..., None]
559
+ x = rearrange(x, "b h w c -> b c h w")
560
+ x = x.to(memory_format=torch.contiguous_format).float()
561
+ return x
562
+
563
+ def shared_step(self, batch):
564
+ x = self.get_input(batch, self.first_stage_key)
565
+ loss, loss_dict = self(x)
566
+ return loss, loss_dict
567
+
568
+ def training_step(self, batch, batch_idx):
569
+ for k in self.ucg_training:
570
+ p = self.ucg_training[k]["p"]
571
+ val = self.ucg_training[k]["val"]
572
+ if val is None:
573
+ val = ""
574
+ for i in range(len(batch[k])):
575
+ if self.ucg_prng.choice(2, p=[1 - p, p]):
576
+ batch[k][i] = val
577
+
578
+ loss, loss_dict = self.shared_step(batch)
579
+
580
+ self.log_dict(
581
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True
582
+ )
583
+
584
+ self.log(
585
+ "global_step",
586
+ self.global_step,
587
+ prog_bar=True,
588
+ logger=True,
589
+ on_step=True,
590
+ on_epoch=False,
591
+ )
592
+
593
+ if self.use_scheduler:
594
+ lr = self.optimizers().param_groups[0]["lr"]
595
+ self.log(
596
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
597
+ )
598
+
599
+ return loss
600
+
601
+ @torch.no_grad()
602
+ def validation_step(self, batch, batch_idx):
603
+ _, loss_dict_no_ema = self.shared_step(batch)
604
+ with self.ema_scope():
605
+ _, loss_dict_ema = self.shared_step(batch)
606
+ loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema}
607
+ self.log_dict(
608
+ loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
609
+ )
610
+ self.log_dict(
611
+ loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
612
+ )
613
+
614
+ def on_train_batch_end(self, *args, **kwargs):
615
+ if self.use_ema:
616
+ self.model_ema(self.model)
617
+
618
+ def _get_rows_from_list(self, samples):
619
+ n_imgs_per_row = len(samples)
620
+ denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
621
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
622
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
623
+ return denoise_grid
624
+
625
+ @torch.no_grad()
626
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
627
+ log = dict()
628
+ x = self.get_input(batch, self.first_stage_key)
629
+ N = min(x.shape[0], N)
630
+ n_row = min(x.shape[0], n_row)
631
+ x = x.to(self.device)[:N]
632
+ log["inputs"] = x
633
+
634
+ # get diffusion row
635
+ diffusion_row = list()
636
+ x_start = x[:n_row]
637
+
638
+ for t in range(self.num_timesteps):
639
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
640
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
641
+ t = t.to(self.device).long()
642
+ noise = torch.randn_like(x_start)
643
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
644
+ diffusion_row.append(x_noisy)
645
+
646
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
647
+
648
+ if sample:
649
+ # get denoise row
650
+ with self.ema_scope("Plotting"):
651
+ samples, denoise_row = self.sample(
652
+ batch_size=N, return_intermediates=True
653
+ )
654
+
655
+ log["samples"] = samples
656
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
657
+
658
+ if return_keys:
659
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
660
+ return log
661
+ else:
662
+ return {key: log[key] for key in return_keys}
663
+ return log
664
+
665
+ def configure_optimizers(self):
666
+ lr = self.learning_rate
667
+ params = list(self.model.parameters())
668
+ if self.learn_logvar:
669
+ params = params + [self.logvar]
670
+ opt = torch.optim.AdamW(params, lr=lr)
671
+ return opt
672
+
673
+
674
+ class LatentDiffusion(DDPM):
675
+ """main class"""
676
+
677
+ def __init__(
678
+ self,
679
+ first_stage_config,
680
+ cond_stage_config,
681
+ num_timesteps_cond=None,
682
+ cond_stage_key="image",
683
+ cond_stage_trainable=False,
684
+ concat_mode=True,
685
+ cond_stage_forward=None,
686
+ conditioning_key=None,
687
+ scale_factor=1.0,
688
+ scale_by_std=False,
689
+ force_null_conditioning=False,
690
+ *args,
691
+ **kwargs,
692
+ ):
693
+ self.force_null_conditioning = force_null_conditioning
694
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
695
+ self.scale_by_std = scale_by_std
696
+ assert self.num_timesteps_cond <= kwargs["timesteps"]
697
+ # for backwards compatibility after implementation of DiffusionWrapper
698
+ if conditioning_key is None:
699
+ conditioning_key = "concat" if concat_mode else "crossattn"
700
+ if (
701
+ cond_stage_config == "__is_unconditional__"
702
+ and not self.force_null_conditioning
703
+ ):
704
+ conditioning_key = None
705
+ ckpt_path = kwargs.pop("ckpt_path", None)
706
+ reset_ema = kwargs.pop("reset_ema", False)
707
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
708
+ ignore_keys = kwargs.pop("ignore_keys", [])
709
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
710
+ self.concat_mode = concat_mode
711
+ self.cond_stage_trainable = cond_stage_trainable
712
+ self.cond_stage_key = cond_stage_key
713
+ try:
714
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
715
+ except Exception:
716
+ self.num_downs = 0
717
+ if not scale_by_std:
718
+ self.scale_factor = scale_factor
719
+ else:
720
+ self.register_buffer("scale_factor", torch.tensor(scale_factor))
721
+ self.instantiate_first_stage(first_stage_config)
722
+ self.instantiate_cond_stage(cond_stage_config)
723
+ self.cond_stage_forward = cond_stage_forward
724
+ self.clip_denoised = False
725
+ self.bbox_tokenizer = None
726
+
727
+ self.restarted_from_ckpt = False
728
+ if ckpt_path is not None:
729
+ self.init_from_ckpt(ckpt_path, ignore_keys)
730
+ self.restarted_from_ckpt = True
731
+ if reset_ema:
732
+ assert self.use_ema
733
+ print(
734
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
735
+ )
736
+ self.model_ema = LitEma(self.model)
737
+ if reset_num_ema_updates:
738
+ print(
739
+ " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
740
+ )
741
+ assert self.use_ema
742
+ self.model_ema.reset_num_updates()
743
+
744
+ def make_cond_schedule(
745
+ self,
746
+ ):
747
+ self.cond_ids = torch.full(
748
+ size=(self.num_timesteps,),
749
+ fill_value=self.num_timesteps - 1,
750
+ dtype=torch.long,
751
+ )
752
+ ids = torch.round(
753
+ torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
754
+ ).long()
755
+ self.cond_ids[: self.num_timesteps_cond] = ids
756
+
757
+ @torch.no_grad()
758
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
759
+ # only for very first batch
760
+ if (
761
+ self.scale_by_std
762
+ and self.current_epoch == 0
763
+ and self.global_step == 0
764
+ and batch_idx == 0
765
+ and not self.restarted_from_ckpt
766
+ ):
767
+ assert (
768
+ self.scale_factor == 1.0
769
+ ), "rather not use custom rescaling and std-rescaling simultaneously"
770
+ # set rescale weight to 1./std of encodings
771
+ print("### USING STD-RESCALING ###")
772
+ x = super().get_input(batch, self.first_stage_key)
773
+ x = x.to(self.device)
774
+ encoder_posterior = self.encode_first_stage(x)
775
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
776
+ del self.scale_factor
777
+ self.register_buffer("scale_factor", 1.0 / z.flatten().std())
778
+ print(f"setting self.scale_factor to {self.scale_factor}")
779
+ print("### USING STD-RESCALING ###")
780
+
781
+ def register_schedule(
782
+ self,
783
+ given_betas=None,
784
+ beta_schedule="linear",
785
+ timesteps=1000,
786
+ linear_start=1e-4,
787
+ linear_end=2e-2,
788
+ cosine_s=8e-3,
789
+ ):
790
+ super().register_schedule(
791
+ given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
792
+ )
793
+
794
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
795
+ if self.shorten_cond_schedule:
796
+ self.make_cond_schedule()
797
+
798
+ def instantiate_first_stage(self, config):
799
+ model = instantiate_from_config(config)
800
+ self.first_stage_model = model.eval()
801
+ self.first_stage_model.train = disabled_train
802
+ for param in self.first_stage_model.parameters():
803
+ param.requires_grad = False
804
+
805
+ def instantiate_cond_stage(self, config):
806
+ if not self.cond_stage_trainable:
807
+ if config == "__is_first_stage__":
808
+ print("Using first stage also as cond stage.")
809
+ self.cond_stage_model = self.first_stage_model
810
+ elif config == "__is_unconditional__":
811
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
812
+ self.cond_stage_model = None
813
+ # self.be_unconditional = True
814
+ else:
815
+ model = instantiate_from_config(config)
816
+ self.cond_stage_model = model.eval()
817
+ self.cond_stage_model.train = disabled_train
818
+ for param in self.cond_stage_model.parameters():
819
+ param.requires_grad = False
820
+ else:
821
+ assert config != "__is_first_stage__"
822
+ assert config != "__is_unconditional__"
823
+ model = instantiate_from_config(config)
824
+ self.cond_stage_model = model
825
+
826
+ def _get_denoise_row_from_list(
827
+ self, samples, desc="", force_no_decoder_quantization=False
828
+ ):
829
+ denoise_row = []
830
+ for zd in tqdm(samples, desc=desc):
831
+ denoise_row.append(
832
+ self.decode_first_stage(
833
+ zd.to(self.device), force_not_quantize=force_no_decoder_quantization
834
+ )
835
+ )
836
+ n_imgs_per_row = len(denoise_row)
837
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
838
+ denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
839
+ denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
840
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
841
+ return denoise_grid
842
+
843
+ def get_first_stage_encoding(self, encoder_posterior):
844
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
845
+ z = encoder_posterior.sample()
846
+ elif isinstance(encoder_posterior, torch.Tensor):
847
+ z = encoder_posterior
848
+ else:
849
+ raise NotImplementedError(
850
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
851
+ )
852
+ return self.scale_factor * z
853
+
854
+ def get_learned_conditioning(self, c):
855
+ if self.cond_stage_forward is None:
856
+ if hasattr(self.cond_stage_model, "encode") and callable(
857
+ self.cond_stage_model.encode
858
+ ):
859
+ c = self.cond_stage_model.encode(c)
860
+ if isinstance(c, DiagonalGaussianDistribution):
861
+ c = c.mode()
862
+ else:
863
+ c = self.cond_stage_model(c)
864
+ else:
865
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
866
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
867
+ return c
868
+
869
+ def meshgrid(self, h, w):
870
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
871
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
872
+
873
+ arr = torch.cat([y, x], dim=-1)
874
+ return arr
875
+
876
+ def delta_border(self, h, w):
877
+ """
878
+ :param h: height
879
+ :param w: width
880
+ :return: normalized distance to image border,
881
+ with min distance = 0 at border and max dist = 0.5 at image center
882
+ """
883
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
884
+ arr = self.meshgrid(h, w) / lower_right_corner
885
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
886
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
887
+ edge_dist = torch.min(
888
+ torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1
889
+ )[0]
890
+ return edge_dist
891
+
892
+ def get_weighting(self, h, w, Ly, Lx, device):
893
+ weighting = self.delta_border(h, w)
894
+ weighting = torch.clip(
895
+ weighting,
896
+ self.split_input_params["clip_min_weight"],
897
+ self.split_input_params["clip_max_weight"],
898
+ )
899
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
900
+
901
+ if self.split_input_params["tie_braker"]:
902
+ L_weighting = self.delta_border(Ly, Lx)
903
+ L_weighting = torch.clip(
904
+ L_weighting,
905
+ self.split_input_params["clip_min_tie_weight"],
906
+ self.split_input_params["clip_max_tie_weight"],
907
+ )
908
+
909
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
910
+ weighting = weighting * L_weighting
911
+ return weighting
912
+
913
+ def get_fold_unfold(
914
+ self, x, kernel_size, stride, uf=1, df=1
915
+ ): # todo load once not every time, shorten code
916
+ """
917
+ :param x: img of size (bs, c, h, w)
918
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
919
+ """
920
+ bs, nc, h, w = x.shape
921
+
922
+ # number of crops in image
923
+ Ly = (h - kernel_size[0]) // stride[0] + 1
924
+ Lx = (w - kernel_size[1]) // stride[1] + 1
925
+
926
+ if uf == 1 and df == 1:
927
+ fold_params = dict(
928
+ kernel_size=kernel_size, dilation=1, padding=0, stride=stride
929
+ )
930
+ unfold = torch.nn.Unfold(**fold_params)
931
+
932
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
933
+
934
+ weighting = self.get_weighting(
935
+ kernel_size[0], kernel_size[1], Ly, Lx, x.device
936
+ ).to(x.dtype)
937
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
938
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
939
+
940
+ elif uf > 1 and df == 1:
941
+ fold_params = dict(
942
+ kernel_size=kernel_size, dilation=1, padding=0, stride=stride
943
+ )
944
+ unfold = torch.nn.Unfold(**fold_params)
945
+
946
+ fold_params2 = dict(
947
+ kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
948
+ dilation=1,
949
+ padding=0,
950
+ stride=(stride[0] * uf, stride[1] * uf),
951
+ )
952
+ fold = torch.nn.Fold(
953
+ output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2
954
+ )
955
+
956
+ weighting = self.get_weighting(
957
+ kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device
958
+ ).to(x.dtype)
959
+ normalization = fold(weighting).view(
960
+ 1, 1, h * uf, w * uf
961
+ ) # normalizes the overlap
962
+ weighting = weighting.view(
963
+ (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)
964
+ )
965
+
966
+ elif df > 1 and uf == 1:
967
+ fold_params = dict(
968
+ kernel_size=kernel_size, dilation=1, padding=0, stride=stride
969
+ )
970
+ unfold = torch.nn.Unfold(**fold_params)
971
+
972
+ fold_params2 = dict(
973
+ kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
974
+ dilation=1,
975
+ padding=0,
976
+ stride=(stride[0] // df, stride[1] // df),
977
+ )
978
+ fold = torch.nn.Fold(
979
+ output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2
980
+ )
981
+
982
+ weighting = self.get_weighting(
983
+ kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device
984
+ ).to(x.dtype)
985
+ normalization = fold(weighting).view(
986
+ 1, 1, h // df, w // df
987
+ ) # normalizes the overlap
988
+ weighting = weighting.view(
989
+ (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)
990
+ )
991
+
992
+ else:
993
+ raise NotImplementedError
994
+
995
+ return fold, unfold, normalization, weighting
996
+
997
+ @torch.no_grad()
998
+ def get_input(
999
+ self,
1000
+ batch,
1001
+ k,
1002
+ return_first_stage_outputs=False,
1003
+ force_c_encode=False,
1004
+ cond_key=None,
1005
+ return_original_cond=False,
1006
+ bs=None,
1007
+ return_x=False,
1008
+ mask_k=None,
1009
+ ):
1010
+ x = super().get_input(batch, k)
1011
+ if bs is not None:
1012
+ x = x[:bs]
1013
+ x = x.to(self.device)
1014
+ encoder_posterior = self.encode_first_stage(x)
1015
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
1016
+
1017
+ if mask_k is not None:
1018
+ mx = super().get_input(batch, mask_k)
1019
+ if bs is not None:
1020
+ mx = mx[:bs]
1021
+ mx = mx.to(self.device)
1022
+ encoder_posterior = self.encode_first_stage(mx)
1023
+ mx = self.get_first_stage_encoding(encoder_posterior).detach()
1024
+
1025
+ if self.model.conditioning_key is not None and not self.force_null_conditioning:
1026
+ if cond_key is None:
1027
+ cond_key = self.cond_stage_key
1028
+ if cond_key != self.first_stage_key:
1029
+ if cond_key in ["caption", "coordinates_bbox", "txt"]:
1030
+ xc = batch[cond_key]
1031
+ elif cond_key in ["class_label", "cls"]:
1032
+ xc = batch
1033
+ else:
1034
+ xc = super().get_input(batch, cond_key).to(self.device)
1035
+ else:
1036
+ xc = x
1037
+ if not self.cond_stage_trainable or force_c_encode:
1038
+ if isinstance(xc, dict) or isinstance(xc, list):
1039
+ c = self.get_learned_conditioning(xc)
1040
+ else:
1041
+ c = self.get_learned_conditioning(xc.to(self.device))
1042
+ else:
1043
+ c = xc
1044
+ if bs is not None:
1045
+ c = c[:bs]
1046
+
1047
+ if self.use_positional_encodings:
1048
+ pos_x, pos_y = self.compute_latent_shifts(batch)
1049
+ ckey = __conditioning_keys__[self.model.conditioning_key]
1050
+ c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y}
1051
+
1052
+ else:
1053
+ c = None
1054
+ xc = None
1055
+ if self.use_positional_encodings:
1056
+ pos_x, pos_y = self.compute_latent_shifts(batch)
1057
+ c = {"pos_x": pos_x, "pos_y": pos_y}
1058
+ out = [z, c]
1059
+ if return_first_stage_outputs:
1060
+ xrec = self.decode_first_stage(z)
1061
+ out.extend([x, xrec])
1062
+ if return_x:
1063
+ out.extend([x])
1064
+ if return_original_cond:
1065
+ out.append(xc)
1066
+ if mask_k:
1067
+ out.append(mx)
1068
+ return out
1069
+
1070
+ @torch.no_grad()
1071
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
1072
+ if predict_cids:
1073
+ if z.dim() == 4:
1074
+ z = torch.argmax(z.exp(), dim=1).long()
1075
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
1076
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
1077
+
1078
+ z = 1.0 / self.scale_factor * z
1079
+ return self.first_stage_model.decode(z)
1080
+
1081
+ def decode_first_stage_grad(self, z, predict_cids=False, force_not_quantize=False):
1082
+ if predict_cids:
1083
+ if z.dim() == 4:
1084
+ z = torch.argmax(z.exp(), dim=1).long()
1085
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
1086
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
1087
+
1088
+ z = 1.0 / self.scale_factor * z
1089
+ return self.first_stage_model.decode(z)
1090
+
1091
+ @torch.no_grad()
1092
+ def encode_first_stage(self, x):
1093
+ return self.first_stage_model.encode(x)
1094
+
1095
+ def shared_step(self, batch, **kwargs):
1096
+ x, c = self.get_input(batch, self.first_stage_key)
1097
+ loss = self(x, c)
1098
+ return loss
1099
+
1100
+ def forward(self, x, c, *args, **kwargs):
1101
+ t = torch.randint(
1102
+ 0, self.num_timesteps, (x.shape[0],), device=self.device
1103
+ ).long()
1104
+ # t = torch.randint(500, 501, (x.shape[0],), device=self.device).long()
1105
+ if self.model.conditioning_key is not None:
1106
+ assert c is not None
1107
+ if self.cond_stage_trainable:
1108
+ c = self.get_learned_conditioning(c)
1109
+ if self.shorten_cond_schedule: # TODO: drop this option
1110
+ tc = self.cond_ids[t].to(self.device)
1111
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
1112
+ return self.p_losses(x, c, t, *args, **kwargs)
1113
+
1114
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
1115
+ if isinstance(cond, dict):
1116
+ # hybrid case, cond is expected to be a dict
1117
+ pass
1118
+ else:
1119
+ if not isinstance(cond, list):
1120
+ cond = [cond]
1121
+ key = (
1122
+ "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
1123
+ )
1124
+ cond = {key: cond}
1125
+
1126
+ x_recon = self.model(x_noisy, t, **cond)
1127
+
1128
+ if isinstance(x_recon, tuple) and not return_ids:
1129
+ return x_recon[0]
1130
+ else:
1131
+ return x_recon
1132
+
1133
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1134
+ return (
1135
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
1136
+ - pred_xstart
1137
+ ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1138
+
1139
+ def _prior_bpd(self, x_start):
1140
+ """
1141
+ Get the prior KL term for the variational lower-bound, measured in
1142
+ bits-per-dim.
1143
+ This term can't be optimized, as it only depends on the encoder.
1144
+ :param x_start: the [N x C x ...] tensor of inputs.
1145
+ :return: a batch of [N] KL values (in bits), one per batch element.
1146
+ """
1147
+ batch_size = x_start.shape[0]
1148
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1149
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1150
+ kl_prior = normal_kl(
1151
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
1152
+ )
1153
+ return mean_flat(kl_prior) / np.log(2.0)
1154
+
1155
+ def p_mean_variance(
1156
+ self,
1157
+ x,
1158
+ c,
1159
+ t,
1160
+ clip_denoised: bool,
1161
+ return_codebook_ids=False,
1162
+ quantize_denoised=False,
1163
+ return_x0=False,
1164
+ score_corrector=None,
1165
+ corrector_kwargs=None,
1166
+ ):
1167
+ t_in = t
1168
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1169
+
1170
+ if score_corrector is not None:
1171
+ assert self.parameterization == "eps"
1172
+ model_out = score_corrector.modify_score(
1173
+ self, model_out, x, t, c, **corrector_kwargs
1174
+ )
1175
+
1176
+ if return_codebook_ids:
1177
+ model_out, logits = model_out
1178
+
1179
+ if self.parameterization == "eps":
1180
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1181
+ elif self.parameterization == "x0":
1182
+ x_recon = model_out
1183
+ else:
1184
+ raise NotImplementedError()
1185
+
1186
+ if clip_denoised:
1187
+ x_recon.clamp_(-1.0, 1.0)
1188
+ if quantize_denoised:
1189
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1190
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
1191
+ x_start=x_recon, x_t=x, t=t
1192
+ )
1193
+ if return_codebook_ids:
1194
+ return model_mean, posterior_variance, posterior_log_variance, logits
1195
+ elif return_x0:
1196
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1197
+ else:
1198
+ return model_mean, posterior_variance, posterior_log_variance
1199
+
1200
+ @torch.no_grad()
1201
+ def p_sample(
1202
+ self,
1203
+ x,
1204
+ c,
1205
+ t,
1206
+ clip_denoised=False,
1207
+ repeat_noise=False,
1208
+ return_codebook_ids=False,
1209
+ quantize_denoised=False,
1210
+ return_x0=False,
1211
+ temperature=1.0,
1212
+ noise_dropout=0.0,
1213
+ score_corrector=None,
1214
+ corrector_kwargs=None,
1215
+ ):
1216
+ b, *_, device = *x.shape, x.device
1217
+ outputs = self.p_mean_variance(
1218
+ x=x,
1219
+ c=c,
1220
+ t=t,
1221
+ clip_denoised=clip_denoised,
1222
+ return_codebook_ids=return_codebook_ids,
1223
+ quantize_denoised=quantize_denoised,
1224
+ return_x0=return_x0,
1225
+ score_corrector=score_corrector,
1226
+ corrector_kwargs=corrector_kwargs,
1227
+ )
1228
+ if return_codebook_ids:
1229
+ raise DeprecationWarning("Support dropped.")
1230
+ model_mean, _, model_log_variance, logits = outputs
1231
+ elif return_x0:
1232
+ model_mean, _, model_log_variance, x0 = outputs
1233
+ else:
1234
+ model_mean, _, model_log_variance = outputs
1235
+
1236
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1237
+ if noise_dropout > 0.0:
1238
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1239
+ # no noise when t == 0
1240
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1241
+
1242
+ if return_codebook_ids:
1243
+ return (
1244
+ model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
1245
+ logits.argmax(dim=1),
1246
+ )
1247
+ if return_x0:
1248
+ return (
1249
+ model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
1250
+ x0,
1251
+ )
1252
+ else:
1253
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1254
+
1255
+ @torch.no_grad()
1256
+ def progressive_denoising(
1257
+ self,
1258
+ cond,
1259
+ shape,
1260
+ verbose=True,
1261
+ callback=None,
1262
+ quantize_denoised=False,
1263
+ img_callback=None,
1264
+ mask=None,
1265
+ x0=None,
1266
+ temperature=1.0,
1267
+ noise_dropout=0.0,
1268
+ score_corrector=None,
1269
+ corrector_kwargs=None,
1270
+ batch_size=None,
1271
+ x_T=None,
1272
+ start_T=None,
1273
+ log_every_t=None,
1274
+ ):
1275
+ if not log_every_t:
1276
+ log_every_t = self.log_every_t
1277
+ timesteps = self.num_timesteps
1278
+ if batch_size is not None:
1279
+ b = batch_size if batch_size is not None else shape[0]
1280
+ shape = [batch_size] + list(shape)
1281
+ else:
1282
+ b = batch_size = shape[0]
1283
+ if x_T is None:
1284
+ img = torch.randn(shape, device=self.device)
1285
+ else:
1286
+ img = x_T
1287
+ intermediates = []
1288
+ if cond is not None:
1289
+ if isinstance(cond, dict):
1290
+ cond = {
1291
+ key: cond[key][:batch_size]
1292
+ if not isinstance(cond[key], list)
1293
+ else list(map(lambda x: x[:batch_size], cond[key]))
1294
+ for key in cond
1295
+ }
1296
+ else:
1297
+ cond = (
1298
+ [c[:batch_size] for c in cond]
1299
+ if isinstance(cond, list)
1300
+ else cond[:batch_size]
1301
+ )
1302
+
1303
+ if start_T is not None:
1304
+ timesteps = min(timesteps, start_T)
1305
+ iterator = (
1306
+ tqdm(
1307
+ reversed(range(0, timesteps)),
1308
+ desc="Progressive Generation",
1309
+ total=timesteps,
1310
+ )
1311
+ if verbose
1312
+ else reversed(range(0, timesteps))
1313
+ )
1314
+ if isinstance(temperature, float):
1315
+ temperature = [temperature] * timesteps
1316
+
1317
+ for i in iterator:
1318
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1319
+ if self.shorten_cond_schedule:
1320
+ assert self.model.conditioning_key != "hybrid"
1321
+ tc = self.cond_ids[ts].to(cond.device)
1322
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1323
+
1324
+ img, x0_partial = self.p_sample(
1325
+ img,
1326
+ cond,
1327
+ ts,
1328
+ clip_denoised=self.clip_denoised,
1329
+ quantize_denoised=quantize_denoised,
1330
+ return_x0=True,
1331
+ temperature=temperature[i],
1332
+ noise_dropout=noise_dropout,
1333
+ score_corrector=score_corrector,
1334
+ corrector_kwargs=corrector_kwargs,
1335
+ )
1336
+ if mask is not None:
1337
+ assert x0 is not None
1338
+ img_orig = self.q_sample(x0, ts)
1339
+ img = img_orig * mask + (1.0 - mask) * img
1340
+
1341
+ if i % log_every_t == 0 or i == timesteps - 1:
1342
+ intermediates.append(x0_partial)
1343
+ if callback:
1344
+ callback(i)
1345
+ if img_callback:
1346
+ img_callback(img, i)
1347
+ return img, intermediates
1348
+
1349
+ @torch.no_grad()
1350
+ def p_sample_loop(
1351
+ self,
1352
+ cond,
1353
+ shape,
1354
+ return_intermediates=False,
1355
+ x_T=None,
1356
+ verbose=True,
1357
+ callback=None,
1358
+ timesteps=None,
1359
+ quantize_denoised=False,
1360
+ mask=None,
1361
+ x0=None,
1362
+ img_callback=None,
1363
+ start_T=None,
1364
+ log_every_t=None,
1365
+ ):
1366
+ if not log_every_t:
1367
+ log_every_t = self.log_every_t
1368
+ device = self.betas.device
1369
+ b = shape[0]
1370
+ if x_T is None:
1371
+ img = torch.randn(shape, device=device)
1372
+ else:
1373
+ img = x_T
1374
+
1375
+ intermediates = [img]
1376
+ if timesteps is None:
1377
+ timesteps = self.num_timesteps
1378
+
1379
+ if start_T is not None:
1380
+ timesteps = min(timesteps, start_T)
1381
+ iterator = (
1382
+ tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
1383
+ if verbose
1384
+ else reversed(range(0, timesteps))
1385
+ )
1386
+
1387
+ if mask is not None:
1388
+ assert x0 is not None
1389
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1390
+
1391
+ for i in iterator:
1392
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1393
+ if self.shorten_cond_schedule:
1394
+ assert self.model.conditioning_key != "hybrid"
1395
+ tc = self.cond_ids[ts].to(cond.device)
1396
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1397
+
1398
+ img = self.p_sample(
1399
+ img,
1400
+ cond,
1401
+ ts,
1402
+ clip_denoised=self.clip_denoised,
1403
+ quantize_denoised=quantize_denoised,
1404
+ )
1405
+ if mask is not None:
1406
+ img_orig = self.q_sample(x0, ts)
1407
+ img = img_orig * mask + (1.0 - mask) * img
1408
+
1409
+ if i % log_every_t == 0 or i == timesteps - 1:
1410
+ intermediates.append(img)
1411
+ if callback:
1412
+ callback(i)
1413
+ if img_callback:
1414
+ img_callback(img, i)
1415
+
1416
+ if return_intermediates:
1417
+ return img, intermediates
1418
+ return img
1419
+
1420
+ @torch.no_grad()
1421
+ def sample(
1422
+ self,
1423
+ cond,
1424
+ batch_size=16,
1425
+ return_intermediates=False,
1426
+ x_T=None,
1427
+ verbose=True,
1428
+ timesteps=None,
1429
+ quantize_denoised=False,
1430
+ mask=None,
1431
+ x0=None,
1432
+ shape=None,
1433
+ **kwargs,
1434
+ ):
1435
+ if shape is None:
1436
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1437
+ if cond is not None:
1438
+ if isinstance(cond, dict):
1439
+ cond = {
1440
+ key: cond[key][:batch_size]
1441
+ if not isinstance(cond[key], list)
1442
+ else list(map(lambda x: x[:batch_size], cond[key]))
1443
+ for key in cond
1444
+ }
1445
+ else:
1446
+ cond = (
1447
+ [c[:batch_size] for c in cond]
1448
+ if isinstance(cond, list)
1449
+ else cond[:batch_size]
1450
+ )
1451
+ return self.p_sample_loop(
1452
+ cond,
1453
+ shape,
1454
+ return_intermediates=return_intermediates,
1455
+ x_T=x_T,
1456
+ verbose=verbose,
1457
+ timesteps=timesteps,
1458
+ quantize_denoised=quantize_denoised,
1459
+ mask=mask,
1460
+ x0=x0,
1461
+ )
1462
+
1463
+ @torch.no_grad()
1464
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
1465
+ if ddim:
1466
+ ddim_sampler = DDIMSampler(self)
1467
+ shape = (self.channels, self.image_size, self.image_size)
1468
+ samples, intermediates = ddim_sampler.sample(
1469
+ ddim_steps, batch_size, shape, cond, verbose=False, **kwargs
1470
+ )
1471
+
1472
+ else:
1473
+ samples, intermediates = self.sample(
1474
+ cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs
1475
+ )
1476
+
1477
+ return samples, intermediates
1478
+
1479
+ @torch.no_grad()
1480
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
1481
+ if null_label is not None:
1482
+ xc = null_label
1483
+ if isinstance(xc, ListConfig):
1484
+ xc = list(xc)
1485
+ if isinstance(xc, dict) or isinstance(xc, list):
1486
+ c = self.get_learned_conditioning(xc)
1487
+ else:
1488
+ if hasattr(xc, "to"):
1489
+ xc = xc.to(self.device)
1490
+ c = self.get_learned_conditioning(xc)
1491
+ else:
1492
+ if self.cond_stage_key in ["class_label", "cls"]:
1493
+ xc = self.cond_stage_model.get_unconditional_conditioning(
1494
+ batch_size, device=self.device
1495
+ )
1496
+ return self.get_learned_conditioning(xc)
1497
+ else:
1498
+ raise NotImplementedError("todo")
1499
+ if isinstance(c, list): # in case the encoder gives us a list
1500
+ for i in range(len(c)):
1501
+ c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device)
1502
+ else:
1503
+ c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
1504
+ return c
1505
+
1506
+ @torch.no_grad()
1507
+ def log_images(
1508
+ self,
1509
+ batch,
1510
+ N=8,
1511
+ n_row=4,
1512
+ sample=True,
1513
+ ddim_steps=50,
1514
+ ddim_eta=0.0,
1515
+ return_keys=None,
1516
+ quantize_denoised=True,
1517
+ inpaint=True,
1518
+ plot_denoise_rows=False,
1519
+ plot_progressive_rows=True,
1520
+ plot_diffusion_rows=True,
1521
+ unconditional_guidance_scale=1.0,
1522
+ unconditional_guidance_label=None,
1523
+ use_ema_scope=True,
1524
+ **kwargs,
1525
+ ):
1526
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1527
+ use_ddim = ddim_steps is not None
1528
+
1529
+ log = dict()
1530
+ z, c, x, xrec, xc = self.get_input(
1531
+ batch,
1532
+ self.first_stage_key,
1533
+ return_first_stage_outputs=True,
1534
+ force_c_encode=True,
1535
+ return_original_cond=True,
1536
+ bs=N,
1537
+ )
1538
+ N = min(x.shape[0], N)
1539
+ n_row = min(x.shape[0], n_row)
1540
+ log["inputs"] = x
1541
+ log["reconstruction"] = xrec
1542
+ if self.model.conditioning_key is not None:
1543
+ if hasattr(self.cond_stage_model, "decode"):
1544
+ xc = self.cond_stage_model.decode(c)
1545
+ log["conditioning"] = xc
1546
+ elif self.cond_stage_key in ["caption", "txt"]:
1547
+ xc = log_txt_as_img(
1548
+ (x.shape[2], x.shape[3]),
1549
+ batch[self.cond_stage_key],
1550
+ size=x.shape[2] // 25,
1551
+ )
1552
+ log["conditioning"] = xc
1553
+ elif self.cond_stage_key in ["class_label", "cls"]:
1554
+ try:
1555
+ xc = log_txt_as_img(
1556
+ (x.shape[2], x.shape[3]),
1557
+ batch["human_label"],
1558
+ size=x.shape[2] // 25,
1559
+ )
1560
+ log["conditioning"] = xc
1561
+ except KeyError:
1562
+ # probably no "human_label" in batch
1563
+ pass
1564
+ elif isimage(xc):
1565
+ log["conditioning"] = xc
1566
+ if ismap(xc):
1567
+ log["original_conditioning"] = self.to_rgb(xc)
1568
+
1569
+ if plot_diffusion_rows:
1570
+ # get diffusion row
1571
+ diffusion_row = list()
1572
+ z_start = z[:n_row]
1573
+ for t in range(self.num_timesteps):
1574
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1575
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
1576
+ t = t.to(self.device).long()
1577
+ noise = torch.randn_like(z_start)
1578
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1579
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1580
+
1581
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1582
+ diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
1583
+ diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
1584
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1585
+ log["diffusion_row"] = diffusion_grid
1586
+
1587
+ if sample:
1588
+ # get denoise row
1589
+ with ema_scope("Sampling"):
1590
+ samples, z_denoise_row = self.sample_log(
1591
+ cond=c,
1592
+ batch_size=N,
1593
+ ddim=use_ddim,
1594
+ ddim_steps=ddim_steps,
1595
+ eta=ddim_eta,
1596
+ )
1597
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1598
+ x_samples = self.decode_first_stage(samples)
1599
+ log["samples"] = x_samples
1600
+ if plot_denoise_rows:
1601
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1602
+ log["denoise_row"] = denoise_grid
1603
+
1604
+ if (
1605
+ quantize_denoised
1606
+ and not isinstance(self.first_stage_model, AutoencoderKL)
1607
+ and not isinstance(self.first_stage_model, IdentityFirstStage)
1608
+ ):
1609
+ # also display when quantizing x0 while sampling
1610
+ with ema_scope("Plotting Quantized Denoised"):
1611
+ samples, z_denoise_row = self.sample_log(
1612
+ cond=c,
1613
+ batch_size=N,
1614
+ ddim=use_ddim,
1615
+ ddim_steps=ddim_steps,
1616
+ eta=ddim_eta,
1617
+ quantize_denoised=True,
1618
+ )
1619
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1620
+ # quantize_denoised=True)
1621
+ x_samples = self.decode_first_stage(samples.to(self.device))
1622
+ log["samples_x0_quantized"] = x_samples
1623
+
1624
+ if unconditional_guidance_scale > 1.0:
1625
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1626
+ if self.model.conditioning_key == "crossattn-adm":
1627
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
1628
+ with ema_scope("Sampling with classifier-free guidance"):
1629
+ samples_cfg, _ = self.sample_log(
1630
+ cond=c,
1631
+ batch_size=N,
1632
+ ddim=use_ddim,
1633
+ ddim_steps=ddim_steps,
1634
+ eta=ddim_eta,
1635
+ unconditional_guidance_scale=unconditional_guidance_scale,
1636
+ unconditional_conditioning=uc,
1637
+ )
1638
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1639
+ log[
1640
+ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
1641
+ ] = x_samples_cfg
1642
+
1643
+ if inpaint:
1644
+ # make a simple center square
1645
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1646
+ mask = torch.ones(N, h, w).to(self.device)
1647
+ # zeros will be filled in
1648
+ mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
1649
+ mask = mask[:, None, ...]
1650
+ with ema_scope("Plotting Inpaint"):
1651
+ samples, _ = self.sample_log(
1652
+ cond=c,
1653
+ batch_size=N,
1654
+ ddim=use_ddim,
1655
+ eta=ddim_eta,
1656
+ ddim_steps=ddim_steps,
1657
+ x0=z[:N],
1658
+ mask=mask,
1659
+ )
1660
+ x_samples = self.decode_first_stage(samples.to(self.device))
1661
+ log["samples_inpainting"] = x_samples
1662
+ log["mask"] = mask
1663
+
1664
+ # outpaint
1665
+ mask = 1.0 - mask
1666
+ with ema_scope("Plotting Outpaint"):
1667
+ samples, _ = self.sample_log(
1668
+ cond=c,
1669
+ batch_size=N,
1670
+ ddim=use_ddim,
1671
+ eta=ddim_eta,
1672
+ ddim_steps=ddim_steps,
1673
+ x0=z[:N],
1674
+ mask=mask,
1675
+ )
1676
+ x_samples = self.decode_first_stage(samples.to(self.device))
1677
+ log["samples_outpainting"] = x_samples
1678
+
1679
+ if plot_progressive_rows:
1680
+ with ema_scope("Plotting Progressives"):
1681
+ img, progressives = self.progressive_denoising(
1682
+ c,
1683
+ shape=(self.channels, self.image_size, self.image_size),
1684
+ batch_size=N,
1685
+ )
1686
+ prog_row = self._get_denoise_row_from_list(
1687
+ progressives, desc="Progressive Generation"
1688
+ )
1689
+ log["progressive_row"] = prog_row
1690
+
1691
+ if return_keys:
1692
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1693
+ return log
1694
+ else:
1695
+ return {key: log[key] for key in return_keys}
1696
+ return log
1697
+
1698
+ def configure_optimizers(self):
1699
+ lr = self.learning_rate
1700
+ params = list(self.model.parameters())
1701
+ if self.cond_stage_trainable:
1702
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1703
+ params = params + list(self.cond_stage_model.parameters())
1704
+ if self.learn_logvar:
1705
+ print("Diffusion model optimizing logvar")
1706
+ params.append(self.logvar)
1707
+ opt = torch.optim.AdamW(params, lr=lr)
1708
+ if self.use_scheduler:
1709
+ assert "target" in self.scheduler_config
1710
+ scheduler = instantiate_from_config(self.scheduler_config)
1711
+
1712
+ print("Setting up LambdaLR scheduler...")
1713
+ scheduler = [
1714
+ {
1715
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
1716
+ "interval": "step",
1717
+ "frequency": 1,
1718
+ }
1719
+ ]
1720
+ return [opt], scheduler
1721
+ return opt
1722
+
1723
+ @torch.no_grad()
1724
+ def to_rgb(self, x):
1725
+ x = x.float()
1726
+ if not hasattr(self, "colorize"):
1727
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1728
+ x = nn.functional.conv2d(x, weight=self.colorize)
1729
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
1730
+ return x
1731
+
1732
+
1733
+ class DiffusionWrapper(torch.nn.Module):
1734
+ def __init__(self, diff_model_config, conditioning_key):
1735
+ super().__init__()
1736
+ self.sequential_cross_attn = diff_model_config.pop(
1737
+ "sequential_crossattn", False
1738
+ )
1739
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1740
+ self.conditioning_key = conditioning_key
1741
+ assert self.conditioning_key in [
1742
+ None,
1743
+ "concat",
1744
+ "crossattn",
1745
+ "hybrid",
1746
+ "adm",
1747
+ "hybrid-adm",
1748
+ "crossattn-adm",
1749
+ ]
1750
+
1751
+ def forward(
1752
+ self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None
1753
+ ):
1754
+ if self.conditioning_key is None:
1755
+ out = self.diffusion_model(x, t)
1756
+ elif self.conditioning_key == "concat":
1757
+ xc = torch.cat([x] + c_concat, dim=1)
1758
+ out = self.diffusion_model(xc, t)
1759
+ elif self.conditioning_key == "crossattn":
1760
+ if not self.sequential_cross_attn:
1761
+ cc = torch.cat(c_crossattn, 1)
1762
+ else:
1763
+ cc = c_crossattn
1764
+ out = self.diffusion_model(x, t, context=cc)
1765
+ elif self.conditioning_key == "hybrid":
1766
+ xc = torch.cat([x] + c_concat, dim=1)
1767
+ cc = torch.cat(c_crossattn, 1)
1768
+ out = self.diffusion_model(xc, t, context=cc)
1769
+ elif self.conditioning_key == "hybrid-adm":
1770
+ assert c_adm is not None
1771
+ xc = torch.cat([x] + c_concat, dim=1)
1772
+ cc = torch.cat(c_crossattn, 1)
1773
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
1774
+ elif self.conditioning_key == "crossattn-adm":
1775
+ assert c_adm is not None
1776
+ cc = torch.cat(c_crossattn, 1)
1777
+ out = self.diffusion_model(x, t, context=cc, y=c_adm)
1778
+ elif self.conditioning_key == "adm":
1779
+ cc = c_crossattn[0]
1780
+ out = self.diffusion_model(x, t, y=cc)
1781
+ else:
1782
+ raise NotImplementedError()
1783
+
1784
+ return out
1785
+
1786
+
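The conditioning-key dispatch in DiffusionWrapper.forward above is easiest to see on toy tensors. The sketch below is an editorial illustration, not part of this commit: it mirrors the "hybrid" branch, where concat conditionings are stacked onto the latent along the channel axis and cross-attention conditionings are concatenated along the sequence axis and passed as context. All shapes are made up.

import torch

# Stand-ins for what LatentDiffusion passes in: a latent x, a list of concat
# conditionings (e.g. a resized mask) and a list of cross-attention
# conditionings (e.g. text embeddings). Shapes here are purely illustrative.
x = torch.randn(2, 4, 8, 8)
c_concat = [torch.randn(2, 1, 8, 8)]
c_crossattn = [torch.randn(2, 77, 768)]

# The "hybrid" branch of DiffusionWrapper.forward:
xc = torch.cat([x] + c_concat, dim=1)   # channel concat -> (2, 5, 8, 8)
cc = torch.cat(c_crossattn, dim=1)      # sequence concat -> (2, 77, 768)
# The wrapped UNet would then be called as diffusion_model(xc, t, context=cc).
print(xc.shape, cc.shape)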
1787
+ class LatentUpscaleDiffusion(LatentDiffusion):
1788
+ def __init__(
1789
+ self,
1790
+ *args,
1791
+ low_scale_config,
1792
+ low_scale_key="LR",
1793
+ noise_level_key=None,
1794
+ **kwargs,
1795
+ ):
1796
+ super().__init__(*args, **kwargs)
1797
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
1798
+ assert not self.cond_stage_trainable
1799
+ self.instantiate_low_stage(low_scale_config)
1800
+ self.low_scale_key = low_scale_key
1801
+ self.noise_level_key = noise_level_key
1802
+
1803
+ def instantiate_low_stage(self, config):
1804
+ model = instantiate_from_config(config)
1805
+ self.low_scale_model = model.eval()
1806
+ self.low_scale_model.train = disabled_train
1807
+ for param in self.low_scale_model.parameters():
1808
+ param.requires_grad = False
1809
+
1810
+ @torch.no_grad()
1811
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
1812
+ if not log_mode:
1813
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
1814
+ else:
1815
+ z, c, x, xrec, xc = super().get_input(
1816
+ batch,
1817
+ self.first_stage_key,
1818
+ return_first_stage_outputs=True,
1819
+ force_c_encode=True,
1820
+ return_original_cond=True,
1821
+ bs=bs,
1822
+ )
1823
+ x_low = batch[self.low_scale_key][:bs]
1824
+ x_low = rearrange(x_low, "b h w c -> b c h w")
1825
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
1826
+ zx, noise_level = self.low_scale_model(x_low)
1827
+ if self.noise_level_key is not None:
1828
+ # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
1829
+ raise NotImplementedError("TODO")
1830
+
1831
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
1832
+ if log_mode:
1833
+ # TODO: maybe disable if too expensive
1834
+ x_low_rec = self.low_scale_model.decode(zx)
1835
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
1836
+ return z, all_conds
1837
+
1838
+ @torch.no_grad()
1839
+ def log_images(
1840
+ self,
1841
+ batch,
1842
+ N=8,
1843
+ n_row=4,
1844
+ sample=True,
1845
+ ddim_steps=200,
1846
+ ddim_eta=1.0,
1847
+ return_keys=None,
1848
+ plot_denoise_rows=False,
1849
+ plot_progressive_rows=True,
1850
+ plot_diffusion_rows=True,
1851
+ unconditional_guidance_scale=1.0,
1852
+ unconditional_guidance_label=None,
1853
+ use_ema_scope=True,
1854
+ **kwargs,
1855
+ ):
1856
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1857
+ use_ddim = ddim_steps is not None
1858
+
1859
+ log = dict()
1860
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(
1861
+ batch, self.first_stage_key, bs=N, log_mode=True
1862
+ )
1863
+ N = min(x.shape[0], N)
1864
+ n_row = min(x.shape[0], n_row)
1865
+ log["inputs"] = x
1866
+ log["reconstruction"] = xrec
1867
+ log["x_lr"] = x_low
1868
+ log[
1869
+ f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"
1870
+ ] = x_low_rec
1871
+ if self.model.conditioning_key is not None:
1872
+ if hasattr(self.cond_stage_model, "decode"):
1873
+ xc = self.cond_stage_model.decode(c)
1874
+ log["conditioning"] = xc
1875
+ elif self.cond_stage_key in ["caption", "txt"]:
1876
+ xc = log_txt_as_img(
1877
+ (x.shape[2], x.shape[3]),
1878
+ batch[self.cond_stage_key],
1879
+ size=x.shape[2] // 25,
1880
+ )
1881
+ log["conditioning"] = xc
1882
+ elif self.cond_stage_key in ["class_label", "cls"]:
1883
+ xc = log_txt_as_img(
1884
+ (x.shape[2], x.shape[3]),
1885
+ batch["human_label"],
1886
+ size=x.shape[2] // 25,
1887
+ )
1888
+ log["conditioning"] = xc
1889
+ elif isimage(xc):
1890
+ log["conditioning"] = xc
1891
+ if ismap(xc):
1892
+ log["original_conditioning"] = self.to_rgb(xc)
1893
+
1894
+ if plot_diffusion_rows:
1895
+ # get diffusion row
1896
+ diffusion_row = list()
1897
+ z_start = z[:n_row]
1898
+ for t in range(self.num_timesteps):
1899
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1900
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
1901
+ t = t.to(self.device).long()
1902
+ noise = torch.randn_like(z_start)
1903
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1904
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1905
+
1906
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1907
+ diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
1908
+ diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
1909
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1910
+ log["diffusion_row"] = diffusion_grid
1911
+
1912
+ if sample:
1913
+ # get denoise row
1914
+ with ema_scope("Sampling"):
1915
+ samples, z_denoise_row = self.sample_log(
1916
+ cond=c,
1917
+ batch_size=N,
1918
+ ddim=use_ddim,
1919
+ ddim_steps=ddim_steps,
1920
+ eta=ddim_eta,
1921
+ )
1922
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1923
+ x_samples = self.decode_first_stage(samples)
1924
+ log["samples"] = x_samples
1925
+ if plot_denoise_rows:
1926
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1927
+ log["denoise_row"] = denoise_grid
1928
+
1929
+ if unconditional_guidance_scale > 1.0:
1930
+ uc_tmp = self.get_unconditional_conditioning(
1931
+ N, unconditional_guidance_label
1932
+ )
1933
+ # TODO explore better "unconditional" choices for the other keys
1934
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
1935
+ uc = dict()
1936
+ for k in c:
1937
+ if k == "c_crossattn":
1938
+ assert isinstance(c[k], list) and len(c[k]) == 1
1939
+ uc[k] = [uc_tmp]
1940
+ elif k == "c_adm": # todo: only run with text-based guidance?
1941
+ assert isinstance(c[k], torch.Tensor)
1942
+ # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
1943
+ uc[k] = c[k]
1944
+ elif isinstance(c[k], list):
1945
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
1946
+ else:
1947
+ uc[k] = c[k]
1948
+
1949
+ with ema_scope("Sampling with classifier-free guidance"):
1950
+ samples_cfg, _ = self.sample_log(
1951
+ cond=c,
1952
+ batch_size=N,
1953
+ ddim=use_ddim,
1954
+ ddim_steps=ddim_steps,
1955
+ eta=ddim_eta,
1956
+ unconditional_guidance_scale=unconditional_guidance_scale,
1957
+ unconditional_conditioning=uc,
1958
+ )
1959
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1960
+ log[
1961
+ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
1962
+ ] = x_samples_cfg
1963
+
1964
+ if plot_progressive_rows:
1965
+ with ema_scope("Plotting Progressives"):
1966
+ img, progressives = self.progressive_denoising(
1967
+ c,
1968
+ shape=(self.channels, self.image_size, self.image_size),
1969
+ batch_size=N,
1970
+ )
1971
+ prog_row = self._get_denoise_row_from_list(
1972
+ progressives, desc="Progressive Generation"
1973
+ )
1974
+ log["progressive_row"] = prog_row
1975
+
1976
+ return log
1977
+
1978
+
1979
+ class LatentFinetuneDiffusion(LatentDiffusion):
1980
+ """
1981
+ Basis for different finetunes, such as inpainting or depth2image
1982
+ To disable finetuning mode, set finetune_keys to None
1983
+ """
1984
+
1985
+ def __init__(
1986
+ self,
1987
+ concat_keys: tuple,
1988
+ finetune_keys=(
1989
+ "model.diffusion_model.input_blocks.0.0.weight",
1990
+ "model_ema.diffusion_modelinput_blocks00weight",
1991
+ ),
1992
+ keep_finetune_dims=4,
1993
+ # if model was trained without concat mode before and we would like to keep these channels
1994
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
1995
+ c_concat_log_end=None,
1996
+ *args,
1997
+ **kwargs,
1998
+ ):
1999
+ ckpt_path = kwargs.pop("ckpt_path", None)
2000
+ ignore_keys = kwargs.pop("ignore_keys", list())
2001
+ super().__init__(*args, **kwargs)
2002
+ self.finetune_keys = finetune_keys
2003
+ self.concat_keys = concat_keys
2004
+ self.keep_dims = keep_finetune_dims
2005
+ self.c_concat_log_start = c_concat_log_start
2006
+ self.c_concat_log_end = c_concat_log_end
2007
+ if exists(self.finetune_keys):
2008
+ assert exists(ckpt_path), "can only finetune from a given checkpoint"
2009
+ if exists(ckpt_path):
2010
+ self.init_from_ckpt(ckpt_path, ignore_keys)
2011
+
2012
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
2013
+ sd = torch.load(path, map_location="cpu")
2014
+ if "state_dict" in list(sd.keys()):
2015
+ sd = sd["state_dict"]
2016
+ keys = list(sd.keys())
2017
+ for k in keys:
2018
+ for ik in ignore_keys:
2019
+ if k.startswith(ik):
2020
+ print("Deleting key {} from state_dict.".format(k))
2021
+ del sd[k]
2022
+
2023
+ # make it explicit, finetune by including extra input channels
2024
+ if exists(self.finetune_keys) and k in self.finetune_keys:
2025
+ new_entry = None
2026
+ for name, param in self.named_parameters():
2027
+ if name in self.finetune_keys:
2028
+ print(
2029
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
2030
+ )
2031
+ new_entry = torch.zeros_like(param) # zero init
2032
+ assert exists(new_entry), "did not find matching parameter to modify"
2033
+ new_entry[:, : self.keep_dims, ...] = sd[k]
2034
+ sd[k] = new_entry
2035
+
2036
+ missing, unexpected = (
2037
+ self.load_state_dict(sd, strict=False)
2038
+ if not only_model
2039
+ else self.model.load_state_dict(sd, strict=False)
2040
+ )
2041
+ print(
2042
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
2043
+ )
2044
+ if len(missing) > 0:
2045
+ print(f"Missing Keys: {missing}")
2046
+ if len(unexpected) > 0:
2047
+ print(f"Unexpected Keys: {unexpected}")
2048
+
2049
+ @torch.no_grad()
2050
+ def log_images(
2051
+ self,
2052
+ batch,
2053
+ N=8,
2054
+ n_row=4,
2055
+ sample=True,
2056
+ ddim_steps=200,
2057
+ ddim_eta=1.0,
2058
+ return_keys=None,
2059
+ quantize_denoised=True,
2060
+ inpaint=True,
2061
+ plot_denoise_rows=False,
2062
+ plot_progressive_rows=True,
2063
+ plot_diffusion_rows=True,
2064
+ unconditional_guidance_scale=1.0,
2065
+ unconditional_guidance_label=None,
2066
+ use_ema_scope=True,
2067
+ **kwargs,
2068
+ ):
2069
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
2070
+ use_ddim = ddim_steps is not None
2071
+
2072
+ log = dict()
2073
+ z, c, x, xrec, xc = self.get_input(
2074
+ batch, self.first_stage_key, bs=N, return_first_stage_outputs=True
2075
+ )
2076
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
2077
+ N = min(x.shape[0], N)
2078
+ n_row = min(x.shape[0], n_row)
2079
+ log["inputs"] = x
2080
+ log["reconstruction"] = xrec
2081
+ if self.model.conditioning_key is not None:
2082
+ if hasattr(self.cond_stage_model, "decode"):
2083
+ xc = self.cond_stage_model.decode(c)
2084
+ log["conditioning"] = xc
2085
+ elif self.cond_stage_key in ["caption", "txt"]:
2086
+ xc = log_txt_as_img(
2087
+ (x.shape[2], x.shape[3]),
2088
+ batch[self.cond_stage_key],
2089
+ size=x.shape[2] // 25,
2090
+ )
2091
+ log["conditioning"] = xc
2092
+ elif self.cond_stage_key in ["class_label", "cls"]:
2093
+ xc = log_txt_as_img(
2094
+ (x.shape[2], x.shape[3]),
2095
+ batch["human_label"],
2096
+ size=x.shape[2] // 25,
2097
+ )
2098
+ log["conditioning"] = xc
2099
+ elif isimage(xc):
2100
+ log["conditioning"] = xc
2101
+ if ismap(xc):
2102
+ log["original_conditioning"] = self.to_rgb(xc)
2103
+
2104
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
2105
+ log["c_concat_decoded"] = self.decode_first_stage(
2106
+ c_cat[:, self.c_concat_log_start : self.c_concat_log_end]
2107
+ )
2108
+
2109
+ if plot_diffusion_rows:
2110
+ # get diffusion row
2111
+ diffusion_row = list()
2112
+ z_start = z[:n_row]
2113
+ for t in range(self.num_timesteps):
2114
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
2115
+ t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
2116
+ t = t.to(self.device).long()
2117
+ noise = torch.randn_like(z_start)
2118
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
2119
+ diffusion_row.append(self.decode_first_stage(z_noisy))
2120
+
2121
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
2122
+ diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
2123
+ diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
2124
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
2125
+ log["diffusion_row"] = diffusion_grid
2126
+
2127
+ if sample:
2128
+ # get denoise row
2129
+ with ema_scope("Sampling"):
2130
+ samples, z_denoise_row = self.sample_log(
2131
+ cond={"c_concat": [c_cat], "c_crossattn": [c]},
2132
+ batch_size=N,
2133
+ ddim=use_ddim,
2134
+ ddim_steps=ddim_steps,
2135
+ eta=ddim_eta,
2136
+ )
2137
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
2138
+ x_samples = self.decode_first_stage(samples)
2139
+ log["samples"] = x_samples
2140
+ if plot_denoise_rows:
2141
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
2142
+ log["denoise_row"] = denoise_grid
2143
+
2144
+ if unconditional_guidance_scale > 1.0:
2145
+ uc_cross = self.get_unconditional_conditioning(
2146
+ N, unconditional_guidance_label
2147
+ )
2148
+ uc_cat = c_cat
2149
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
2150
+ with ema_scope("Sampling with classifier-free guidance"):
2151
+ samples_cfg, _ = self.sample_log(
2152
+ cond={"c_concat": [c_cat], "c_crossattn": [c]},
2153
+ batch_size=N,
2154
+ ddim=use_ddim,
2155
+ ddim_steps=ddim_steps,
2156
+ eta=ddim_eta,
2157
+ unconditional_guidance_scale=unconditional_guidance_scale,
2158
+ unconditional_conditioning=uc_full,
2159
+ )
2160
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
2161
+ log[
2162
+ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
2163
+ ] = x_samples_cfg
2164
+
2165
+ return log
2166
+
2167
+
2168
+ class LatentInpaintDiffusion(LatentFinetuneDiffusion):
2169
+ """
2170
+ can either run as pure inpainting model (only concat mode) or with mixed conditionings,
2171
+ e.g. mask as concat and text via cross-attn.
2172
+ To disable finetuning mode, set finetune_keys to None
2173
+ """
2174
+
2175
+ def __init__(
2176
+ self,
2177
+ concat_keys=("mask", "masked_image"),
2178
+ masked_image_key="masked_image",
2179
+ *args,
2180
+ **kwargs,
2181
+ ):
2182
+ super().__init__(concat_keys, *args, **kwargs)
2183
+ self.masked_image_key = masked_image_key
2184
+ assert self.masked_image_key in concat_keys
2185
+
2186
+ @torch.no_grad()
2187
+ def get_input(
2188
+ self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
2189
+ ):
2190
+ # note: restricted to non-trainable encoders currently
2191
+ assert (
2192
+ not self.cond_stage_trainable
2193
+ ), "trainable cond stages not yet supported for inpainting"
2194
+ z, c, x, xrec, xc = super().get_input(
2195
+ batch,
2196
+ self.first_stage_key,
2197
+ return_first_stage_outputs=True,
2198
+ force_c_encode=True,
2199
+ return_original_cond=True,
2200
+ bs=bs,
2201
+ )
2202
+
2203
+ assert exists(self.concat_keys)
2204
+ c_cat = list()
2205
+ for ck in self.concat_keys:
2206
+ cc = (
2207
+ rearrange(batch[ck], "b h w c -> b c h w")
2208
+ .to(memory_format=torch.contiguous_format)
2209
+ .float()
2210
+ )
2211
+ if bs is not None:
2212
+ cc = cc[:bs]
2213
+ cc = cc.to(self.device)
2214
+ bchw = z.shape
2215
+ if ck != self.masked_image_key:
2216
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
2217
+ else:
2218
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
2219
+ c_cat.append(cc)
2220
+ c_cat = torch.cat(c_cat, dim=1)
2221
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
2222
+ if return_first_stage_outputs:
2223
+ return z, all_conds, x, xrec, xc
2224
+ return z, all_conds
2225
+
2226
+ @torch.no_grad()
2227
+ def log_images(self, *args, **kwargs):
2228
+ log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
2229
+ log["masked_image"] = (
2230
+ rearrange(args[0]["masked_image"], "b h w c -> b c h w")
2231
+ .to(memory_format=torch.contiguous_format)
2232
+ .float()
2233
+ )
2234
+ return log
2235
+
2236
+
2237
+ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
2238
+ """
2239
+ condition on monocular depth estimation
2240
+ """
2241
+
2242
+ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
2243
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
2244
+ self.depth_model = instantiate_from_config(depth_stage_config)
2245
+ self.depth_stage_key = concat_keys[0]
2246
+
2247
+ @torch.no_grad()
2248
+ def get_input(
2249
+ self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
2250
+ ):
2251
+ # note: restricted to non-trainable encoders currently
2252
+ assert (
2253
+ not self.cond_stage_trainable
2254
+ ), "trainable cond stages not yet supported for depth2img"
2255
+ z, c, x, xrec, xc = super().get_input(
2256
+ batch,
2257
+ self.first_stage_key,
2258
+ return_first_stage_outputs=True,
2259
+ force_c_encode=True,
2260
+ return_original_cond=True,
2261
+ bs=bs,
2262
+ )
2263
+
2264
+ assert exists(self.concat_keys)
2265
+ assert len(self.concat_keys) == 1
2266
+ c_cat = list()
2267
+ for ck in self.concat_keys:
2268
+ cc = batch[ck]
2269
+ if bs is not None:
2270
+ cc = cc[:bs]
2271
+ cc = cc.to(self.device)
2272
+ cc = self.depth_model(cc)
2273
+ cc = torch.nn.functional.interpolate(
2274
+ cc,
2275
+ size=z.shape[2:],
2276
+ mode="bicubic",
2277
+ align_corners=False,
2278
+ )
2279
+
2280
+ depth_min, depth_max = (
2281
+ torch.amin(cc, dim=[1, 2, 3], keepdim=True),
2282
+ torch.amax(cc, dim=[1, 2, 3], keepdim=True),
2283
+ )
2284
+ cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0
2285
+ c_cat.append(cc)
2286
+ c_cat = torch.cat(c_cat, dim=1)
2287
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
2288
+ if return_first_stage_outputs:
2289
+ return z, all_conds, x, xrec, xc
2290
+ return z, all_conds
2291
+
2292
+ @torch.no_grad()
2293
+ def log_images(self, *args, **kwargs):
2294
+ log = super().log_images(*args, **kwargs)
2295
+ depth = self.depth_model(args[0][self.depth_stage_key])
2296
+ depth_min, depth_max = (
2297
+ torch.amin(depth, dim=[1, 2, 3], keepdim=True),
2298
+ torch.amax(depth, dim=[1, 2, 3], keepdim=True),
2299
+ )
2300
+ log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0
2301
+ return log
2302
+
2303
+
2304
+ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
2305
+ """
2306
+ condition on low-res image (and optionally on some spatial noise augmentation)
2307
+ """
2308
+
2309
+ def __init__(
2310
+ self,
2311
+ concat_keys=("lr",),
2312
+ reshuffle_patch_size=None,
2313
+ low_scale_config=None,
2314
+ low_scale_key=None,
2315
+ *args,
2316
+ **kwargs,
2317
+ ):
2318
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
2319
+ self.reshuffle_patch_size = reshuffle_patch_size
2320
+ self.low_scale_model = None
2321
+ if low_scale_config is not None:
2322
+ print("Initializing a low-scale model")
2323
+ assert exists(low_scale_key)
2324
+ self.instantiate_low_stage(low_scale_config)
2325
+ self.low_scale_key = low_scale_key
2326
+
2327
+ def instantiate_low_stage(self, config):
2328
+ model = instantiate_from_config(config)
2329
+ self.low_scale_model = model.eval()
2330
+ self.low_scale_model.train = disabled_train
2331
+ for param in self.low_scale_model.parameters():
2332
+ param.requires_grad = False
2333
+
2334
+ @torch.no_grad()
2335
+ def get_input(
2336
+ self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
2337
+ ):
2338
+ # note: restricted to non-trainable encoders currently
2339
+ assert (
2340
+ not self.cond_stage_trainable
2341
+ ), "trainable cond stages not yet supported for upscaling-ft"
2342
+ z, c, x, xrec, xc = super().get_input(
2343
+ batch,
2344
+ self.first_stage_key,
2345
+ return_first_stage_outputs=True,
2346
+ force_c_encode=True,
2347
+ return_original_cond=True,
2348
+ bs=bs,
2349
+ )
2350
+
2351
+ assert exists(self.concat_keys)
2352
+ assert len(self.concat_keys) == 1
2353
+ # optionally make spatial noise_level here
2354
+ c_cat = list()
2355
+ noise_level = None
2356
+ for ck in self.concat_keys:
2357
+ cc = batch[ck]
2358
+ cc = rearrange(cc, "b h w c -> b c h w")
2359
+ if exists(self.reshuffle_patch_size):
2360
+ assert isinstance(self.reshuffle_patch_size, int)
2361
+ cc = rearrange(
2362
+ cc,
2363
+ "b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
2364
+ p1=self.reshuffle_patch_size,
2365
+ p2=self.reshuffle_patch_size,
2366
+ )
2367
+ if bs is not None:
2368
+ cc = cc[:bs]
2369
+ cc = cc.to(self.device)
2370
+ if exists(self.low_scale_model) and ck == self.low_scale_key:
2371
+ cc, noise_level = self.low_scale_model(cc)
2372
+ c_cat.append(cc)
2373
+ c_cat = torch.cat(c_cat, dim=1)
2374
+ if exists(noise_level):
2375
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
2376
+ else:
2377
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
2378
+ if return_first_stage_outputs:
2379
+ return z, all_conds, x, xrec, xc
2380
+ return z, all_conds
2381
+
2382
+ @torch.no_grad()
2383
+ def log_images(self, *args, **kwargs):
2384
+ log = super().log_images(*args, **kwargs)
2385
+ log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w")
2386
+ return log
sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sampler import DPMSolverSampler
sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,1464 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from tqdm import tqdm
6
+
7
+
8
+ class NoiseScheduleVP:
9
+ def __init__(
10
+ self,
11
+ schedule="discrete",
12
+ betas=None,
13
+ alphas_cumprod=None,
14
+ continuous_beta_0=0.1,
15
+ continuous_beta_1=20.0,
16
+ ):
17
+ """Create a wrapper class for the forward SDE (VP type).
18
+ ***
19
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
20
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
21
+ ***
22
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
23
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
24
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
25
+ log_alpha_t = self.marginal_log_mean_coeff(t)
26
+ sigma_t = self.marginal_std(t)
27
+ lambda_t = self.marginal_lambda(t)
28
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
29
+ t = self.inverse_lambda(lambda_t)
30
+ ===============================================================
31
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
32
+ 1. For discrete-time DPMs:
33
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
34
+ t_i = (i + 1) / N
35
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
36
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
37
+ Args:
38
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
39
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
40
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
41
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
42
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
43
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
44
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
45
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
46
+ and
47
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
48
+ 2. For continuous-time DPMs:
49
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
50
+ schedule are the default settings in DDPM and improved-DDPM:
51
+ Args:
52
+ beta_min: A `float` number. The smallest beta for the linear schedule.
53
+ beta_max: A `float` number. The largest beta for the linear schedule.
54
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
55
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
56
+ T: A `float` number. The ending time of the forward process.
57
+ ===============================================================
58
+ Args:
59
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
60
+ 'linear' or 'cosine' for continuous-time DPMs.
61
+ Returns:
62
+ A wrapper object of the forward SDE (VP type).
63
+
64
+ ===============================================================
65
+ Example:
66
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
67
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
68
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
69
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
70
+ # For continuous-time DPMs (VPSDE), linear schedule:
71
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
72
+ """
73
+
74
+ if schedule not in ["discrete", "linear", "cosine"]:
75
+ raise ValueError(
76
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
77
+ schedule
78
+ )
79
+ )
80
+
81
+ self.schedule = schedule
82
+ if schedule == "discrete":
83
+ if betas is not None:
84
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
85
+ else:
86
+ assert alphas_cumprod is not None
87
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
88
+ self.total_N = len(log_alphas)
89
+ self.T = 1.0
90
+ self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape(
91
+ (1, -1)
92
+ )
93
+ self.log_alpha_array = log_alphas.reshape(
94
+ (
95
+ 1,
96
+ -1,
97
+ )
98
+ )
99
+ else:
100
+ self.total_N = 1000
101
+ self.beta_0 = continuous_beta_0
102
+ self.beta_1 = continuous_beta_1
103
+ self.cosine_s = 0.008
104
+ self.cosine_beta_max = 999.0
105
+ self.cosine_t_max = (
106
+ math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
107
+ * 2.0
108
+ * (1.0 + self.cosine_s)
109
+ / math.pi
110
+ - self.cosine_s
111
+ )
112
+ self.cosine_log_alpha_0 = math.log(
113
+ math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
114
+ )
115
+ self.schedule = schedule
116
+ if schedule == "cosine":
117
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
118
+ # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
119
+ self.T = 0.9946
120
+ else:
121
+ self.T = 1.0
122
+
123
+ def marginal_log_mean_coeff(self, t):
124
+ """
125
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
126
+ """
127
+ if self.schedule == "discrete":
128
+ return interpolate_fn(
129
+ t.reshape((-1, 1)),
130
+ self.t_array.to(t.device),
131
+ self.log_alpha_array.to(t.device),
132
+ ).reshape((-1))
133
+ elif self.schedule == "linear":
134
+ return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
135
+ elif self.schedule == "cosine":
136
+ log_alpha_fn = lambda s: torch.log(
137
+ torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)
138
+ )
139
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
140
+ return log_alpha_t
141
+
142
+ def marginal_alpha(self, t):
143
+ """
144
+ Compute alpha_t of a given continuous-time label t in [0, T].
145
+ """
146
+ return torch.exp(self.marginal_log_mean_coeff(t))
147
+
148
+ def marginal_std(self, t):
149
+ """
150
+ Compute sigma_t of a given continuous-time label t in [0, T].
151
+ """
152
+ return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))
153
+
154
+ def marginal_lambda(self, t):
155
+ """
156
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
157
+ """
158
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
159
+ log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
160
+ return log_mean_coeff - log_std
161
+
162
+ def inverse_lambda(self, lamb):
163
+ """
164
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
165
+ """
166
+ if self.schedule == "linear":
167
+ tmp = (
168
+ 2.0
169
+ * (self.beta_1 - self.beta_0)
170
+ * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
171
+ )
172
+ Delta = self.beta_0**2 + tmp
173
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
174
+ elif self.schedule == "discrete":
175
+ log_alpha = -0.5 * torch.logaddexp(
176
+ torch.zeros((1,)).to(lamb.device), -2.0 * lamb
177
+ )
178
+ t = interpolate_fn(
179
+ log_alpha.reshape((-1, 1)),
180
+ torch.flip(self.log_alpha_array.to(lamb.device), [1]),
181
+ torch.flip(self.t_array.to(lamb.device), [1]),
182
+ )
183
+ return t.reshape((-1,))
184
+ else:
185
+ log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
186
+ t_fn = (
187
+ lambda log_alpha_t: torch.arccos(
188
+ torch.exp(log_alpha_t + self.cosine_log_alpha_0)
189
+ )
190
+ * 2.0
191
+ * (1.0 + self.cosine_s)
192
+ / math.pi
193
+ - self.cosine_s
194
+ )
195
+ t = t_fn(log_alpha)
196
+ return t
197
+
198
+
199
+ def model_wrapper(
200
+ model,
201
+ noise_schedule,
202
+ model_type="noise",
203
+ model_kwargs={},
204
+ guidance_type="uncond",
205
+ condition=None,
206
+ unconditional_condition=None,
207
+ guidance_scale=1.0,
208
+ classifier_fn=None,
209
+ classifier_kwargs={},
210
+ ):
211
+ """Create a wrapper function for the noise prediction model.
212
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
213
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
214
+ We support four types of the diffusion model by setting `model_type`:
215
+ 1. "noise": noise prediction model. (Trained by predicting noise).
216
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
217
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
218
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
219
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
220
+ arXiv preprint arXiv:2202.00512 (2022).
221
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
222
+ arXiv preprint arXiv:2210.02303 (2022).
223
+
224
+ 4. "score": marginal score function. (Trained by denoising score matching).
225
+ Note that the score function and the noise prediction model follows a simple relationship:
226
+ ```
227
+ noise(x_t, t) = -sigma_t * score(x_t, t)
228
+ ```
229
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
230
+ 1. "uncond": unconditional sampling by DPMs.
231
+ The input `model` has the following format:
232
+ ``
233
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
234
+ ``
235
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
236
+ The input `model` has the following format:
237
+ ``
238
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
239
+ ``
240
+ The input `classifier_fn` has the following format:
241
+ ``
242
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
243
+ ``
244
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
245
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
246
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
247
+ The input `model` has the following format:
248
+ ``
249
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
250
+ ``
251
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
252
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
253
+ arXiv preprint arXiv:2207.12598 (2022).
254
+
255
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
256
+ or continuous-time labels (i.e. epsilon to T).
257
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
258
+ ``
259
+ def model_fn(x, t_continuous) -> noise:
260
+ t_input = get_model_input_time(t_continuous)
261
+ return noise_pred(model, x, t_input, **model_kwargs)
262
+ ``
263
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
264
+ ===============================================================
265
+ Args:
266
+ model: A diffusion model with the corresponding format described above.
267
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
268
+ model_type: A `str`. The parameterization type of the diffusion model.
269
+ "noise" or "x_start" or "v" or "score".
270
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
271
+ guidance_type: A `str`. The type of the guidance for sampling.
272
+ "uncond" or "classifier" or "classifier-free".
273
+ condition: A pytorch tensor. The condition for the guided sampling.
274
+ Only used for "classifier" or "classifier-free" guidance type.
275
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
276
+ Only used for "classifier-free" guidance type.
277
+ guidance_scale: A `float`. The scale for the guided sampling.
278
+ classifier_fn: A classifier function. Only used for the classifier guidance.
279
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
280
+ Returns:
281
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
282
+ """
283
+
284
+ def get_model_input_time(t_continuous):
285
+ """
286
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
287
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
288
+ For continuous-time DPMs, we just use `t_continuous`.
289
+ """
290
+ if noise_schedule.schedule == "discrete":
291
+ return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
292
+ else:
293
+ return t_continuous
294
+
295
+ def noise_pred_fn(x, t_continuous, cond=None):
296
+ if t_continuous.reshape((-1,)).shape[0] == 1:
297
+ t_continuous = t_continuous.expand((x.shape[0]))
298
+ t_input = get_model_input_time(t_continuous)
299
+ if cond is None:
300
+ output = model(x, t_input, **model_kwargs)
301
+ else:
302
+ output = model(x, t_input, cond, **model_kwargs)
303
+ if model_type == "noise":
304
+ return output
305
+ elif model_type == "x_start":
306
+ alpha_t, sigma_t = (
307
+ noise_schedule.marginal_alpha(t_continuous),
308
+ noise_schedule.marginal_std(t_continuous),
309
+ )
310
+ dims = x.dim()
311
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(
312
+ sigma_t, dims
313
+ )
314
+ elif model_type == "v":
315
+ alpha_t, sigma_t = (
316
+ noise_schedule.marginal_alpha(t_continuous),
317
+ noise_schedule.marginal_std(t_continuous),
318
+ )
319
+ dims = x.dim()
320
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
321
+ elif model_type == "score":
322
+ sigma_t = noise_schedule.marginal_std(t_continuous)
323
+ dims = x.dim()
324
+ return -expand_dims(sigma_t, dims) * output
325
+
326
+ def cond_grad_fn(x, t_input):
327
+ """
328
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
329
+ """
330
+ with torch.enable_grad():
331
+ x_in = x.detach().requires_grad_(True)
332
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
333
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
334
+
335
+ def model_fn(x, t_continuous):
336
+ """
337
+ The noise prediction model function that is used for DPM-Solver.
338
+ """
339
+ if t_continuous.reshape((-1,)).shape[0] == 1:
340
+ t_continuous = t_continuous.expand((x.shape[0]))
341
+ if guidance_type == "uncond":
342
+ return noise_pred_fn(x, t_continuous)
343
+ elif guidance_type == "classifier":
344
+ assert classifier_fn is not None
345
+ t_input = get_model_input_time(t_continuous)
346
+ cond_grad = cond_grad_fn(x, t_input)
347
+ sigma_t = noise_schedule.marginal_std(t_continuous)
348
+ noise = noise_pred_fn(x, t_continuous)
349
+ return (
350
+ noise
351
+ - guidance_scale
352
+ * expand_dims(sigma_t, dims=cond_grad.dim())
353
+ * cond_grad
354
+ )
355
+ elif guidance_type == "classifier-free":
356
+ if guidance_scale == 1.0 or unconditional_condition is None:
357
+ return noise_pred_fn(x, t_continuous, cond=condition)
358
+ else:
359
+ x_in = torch.cat([x] * 2)
360
+ t_in = torch.cat([t_continuous] * 2)
361
+ c_in = torch.cat([unconditional_condition, condition])
362
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
363
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
364
+
365
+ assert model_type in ["noise", "x_start", "v"]
366
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
367
+ return model_fn
368
+
369
+
370
+ class DPM_Solver:
371
+ def __init__(
372
+ self,
373
+ model_fn,
374
+ noise_schedule,
375
+ predict_x0=False,
376
+ thresholding=False,
377
+ max_val=1.0,
378
+ ):
379
+ """Construct a DPM-Solver.
380
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
381
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
382
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
383
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
384
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
385
+ Args:
386
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
387
+ ``
388
+ def model_fn(x, t_continuous):
389
+ return noise
390
+ ``
391
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
392
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
393
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
394
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
395
+
396
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
397
+ """
398
+ self.model = model_fn
399
+ self.noise_schedule = noise_schedule
400
+ self.predict_x0 = predict_x0
401
+ self.thresholding = thresholding
402
+ self.max_val = max_val
403
+
404
+ def noise_prediction_fn(self, x, t):
405
+ """
406
+ Return the noise prediction model.
407
+ """
408
+ return self.model(x, t)
409
+
410
+ def data_prediction_fn(self, x, t):
411
+ """
412
+ Return the data prediction model (with thresholding).
413
+ """
414
+ noise = self.noise_prediction_fn(x, t)
415
+ dims = x.dim()
416
+ alpha_t, sigma_t = (
417
+ self.noise_schedule.marginal_alpha(t),
418
+ self.noise_schedule.marginal_std(t),
419
+ )
420
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
421
+ if self.thresholding:
422
+ p = 0.995  # The dynamic-thresholding percentile; a hyperparameter from the Imagen paper [1].
423
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
424
+ s = expand_dims(
425
+ torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims
426
+ )
427
+ x0 = torch.clamp(x0, -s, s) / s
428
+ return x0
429
+
430
+ def model_fn(self, x, t):
431
+ """
432
+ Convert the model to the noise prediction model or the data prediction model.
433
+ """
434
+ if self.predict_x0:
435
+ return self.data_prediction_fn(x, t)
436
+ else:
437
+ return self.noise_prediction_fn(x, t)
438
+
439
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
440
+ """Compute the intermediate time steps for sampling.
441
+ Args:
442
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
443
+ - 'logSNR': uniform logSNR for the time steps.
444
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
445
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
446
+ t_T: A `float`. The starting time of the sampling (default is T).
447
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
448
+ N: An `int`. The number of time-step intervals (the returned grid has N + 1 points).
449
+ device: A torch device.
450
+ Returns:
451
+ A pytorch tensor of the time steps, with the shape (N + 1,).
452
+ """
453
+ if skip_type == "logSNR":
454
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
455
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
456
+ logSNR_steps = torch.linspace(
457
+ lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
458
+ ).to(device)
459
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
460
+ elif skip_type == "time_uniform":
461
+ return torch.linspace(t_T, t_0, N + 1).to(device)
462
+ elif skip_type == "time_quadratic":
463
+ t_order = 2
464
+ t = (
465
+ torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
466
+ .pow(t_order)
467
+ .to(device)
468
+ )
469
+ return t
470
+ else:
471
+ raise ValueError(
472
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
473
+ skip_type
474
+ )
475
+ )
476
+
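+ # Illustration of the three skip types (not used by the solver itself):
+ # 'time_uniform' returns torch.linspace(t_T, t_0, N + 1); 'logSNR' spaces the steps
+ # uniformly in lambda_t = log(alpha_t / sigma_t) and maps back via inverse_lambda;
+ # 'time_quadratic' is uniform in sqrt(t) and then squared, e.g.
+ # >>> import torch
+ # >>> t = torch.linspace(1.0 ** 0.5, 1e-3 ** 0.5, 11).pow(2)   # 10 quadratic intervals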
477
+ def get_orders_and_timesteps_for_singlestep_solver(
478
+ self, steps, order, skip_type, t_T, t_0, device
479
+ ):
480
+ """
481
+ Get the order of each step for sampling by the singlestep DPM-Solver.
482
+ We combine DPM-Solver-1, -2, and -3 to use up all the function evaluations; the combined scheme is named "DPM-Solver-fast".
483
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
484
+ - If order == 1:
485
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
486
+ - If order == 2:
487
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
488
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
489
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
490
+ - If order == 3:
491
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
492
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
493
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
494
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
495
+ ============================================
496
+ Args:
497
+ order: An `int`. The max order for the solver (2 or 3).
498
+ steps: An `int`. The total number of function evaluations (NFE).
499
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
500
+ - 'logSNR': uniform logSNR for the time steps.
501
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
502
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
503
+ t_T: A `float`. The starting time of the sampling (default is T).
504
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
505
+ device: A torch device.
506
+ Returns:
507
+ orders: A list of the solver order of each step.
508
+ """
509
+ if order == 3:
510
+ K = steps // 3 + 1
511
+ if steps % 3 == 0:
+ orders = [3] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3] * (K - 1) + [1]
+ else:
+ orders = [3] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [1] * steps
547
+ else:
548
+ raise ValueError("'order' must be '1' or '2' or '3'.")
549
+ if skip_type == "logSNR":
550
+ # To reproduce the results in the DPM-Solver paper
551
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
552
+ else:
553
+ # torch.cumsum needs an explicit dim; dim=0 gives the cumulative NFE offsets.
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+ torch.cumsum(torch.tensor([0] + orders), dim=0).to(device)
+ ]
563
+ return timesteps_outer, orders
564
+
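+ # Worked example: with steps=12 and order=3, K = 12 // 3 + 1 = 5 and steps % 3 == 0,
+ # so orders == [3, 3, 3, 2, 1] (3 + 3 + 3 + 2 + 1 = 12 model evaluations in total).
+ # With steps=20 and order=2, orders == [2] * 10, i.e. 10 outer steps of DPM-Solver-2.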
565
+ def denoise_to_zero_fn(self, x, s):
566
+ """
567
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity with a first-order discretization.
568
+ """
569
+ return self.data_prediction_fn(x, s)
570
+
571
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
572
+ """
573
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
574
+ Args:
575
+ x: A pytorch tensor. The initial value at time `s`.
576
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
577
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
578
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
579
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
580
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
581
+ Returns:
582
+ x_t: A pytorch tensor. The approximated solution at time `t`.
583
+ """
584
+ ns = self.noise_schedule
585
+ dims = x.dim()
586
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
587
+ h = lambda_t - lambda_s
588
+ log_alpha_s, log_alpha_t = (
589
+ ns.marginal_log_mean_coeff(s),
590
+ ns.marginal_log_mean_coeff(t),
591
+ )
592
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
593
+ alpha_t = torch.exp(log_alpha_t)
594
+
595
+ if self.predict_x0:
596
+ phi_1 = torch.expm1(-h)
597
+ if model_s is None:
598
+ model_s = self.model_fn(x, s)
599
+ x_t = (
600
+ expand_dims(sigma_t / sigma_s, dims) * x
601
+ - expand_dims(alpha_t * phi_1, dims) * model_s
602
+ )
603
+ if return_intermediate:
604
+ return x_t, {"model_s": model_s}
605
+ else:
606
+ return x_t
607
+ else:
608
+ phi_1 = torch.expm1(h)
609
+ if model_s is None:
610
+ model_s = self.model_fn(x, s)
611
+ x_t = (
612
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
613
+ - expand_dims(sigma_t * phi_1, dims) * model_s
614
+ )
615
+ if return_intermediate:
616
+ return x_t, {"model_s": model_s}
617
+ else:
618
+ return x_t
619
+
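+ # For reference, the first-order updates above are (with h = lambda_t - lambda_s):
+ #   noise prediction: x_t = (alpha_t / alpha_s) * x - sigma_t * (e^h - 1) * eps_theta(x, s)
+ #   data prediction:  x_t = (sigma_t / sigma_s) * x - alpha_t * (e^{-h} - 1) * x0_theta(x, s)
+ # i.e. DDIM written in the (lambda, alpha, sigma) parameterization.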
620
+ def singlestep_dpm_solver_second_update(
621
+ self,
622
+ x,
623
+ s,
624
+ t,
625
+ r1=0.5,
626
+ model_s=None,
627
+ return_intermediate=False,
628
+ solver_type="dpm_solver",
629
+ ):
630
+ """
631
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
632
+ Args:
633
+ x: A pytorch tensor. The initial value at time `s`.
634
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
635
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
636
+ r1: A `float`. The hyperparameter of the second-order solver.
637
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
638
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
639
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
640
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
641
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
642
+ Returns:
643
+ x_t: A pytorch tensor. The approximated solution at time `t`.
644
+ """
645
+ if solver_type not in ["dpm_solver", "taylor"]:
646
+ raise ValueError(
647
+ "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
648
+ solver_type
649
+ )
650
+ )
651
+ if r1 is None:
652
+ r1 = 0.5
653
+ ns = self.noise_schedule
654
+ dims = x.dim()
655
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
656
+ h = lambda_t - lambda_s
657
+ lambda_s1 = lambda_s + r1 * h
658
+ s1 = ns.inverse_lambda(lambda_s1)
659
+ log_alpha_s, log_alpha_s1, log_alpha_t = (
660
+ ns.marginal_log_mean_coeff(s),
661
+ ns.marginal_log_mean_coeff(s1),
662
+ ns.marginal_log_mean_coeff(t),
663
+ )
664
+ sigma_s, sigma_s1, sigma_t = (
665
+ ns.marginal_std(s),
666
+ ns.marginal_std(s1),
667
+ ns.marginal_std(t),
668
+ )
669
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
670
+
671
+ if self.predict_x0:
672
+ phi_11 = torch.expm1(-r1 * h)
673
+ phi_1 = torch.expm1(-h)
674
+
675
+ if model_s is None:
676
+ model_s = self.model_fn(x, s)
677
+ x_s1 = (
678
+ expand_dims(sigma_s1 / sigma_s, dims) * x
679
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
680
+ )
681
+ model_s1 = self.model_fn(x_s1, s1)
682
+ if solver_type == "dpm_solver":
683
+ x_t = (
684
+ expand_dims(sigma_t / sigma_s, dims) * x
685
+ - expand_dims(alpha_t * phi_1, dims) * model_s
686
+ - (0.5 / r1)
687
+ * expand_dims(alpha_t * phi_1, dims)
688
+ * (model_s1 - model_s)
689
+ )
690
+ elif solver_type == "taylor":
691
+ x_t = (
692
+ expand_dims(sigma_t / sigma_s, dims) * x
693
+ - expand_dims(alpha_t * phi_1, dims) * model_s
694
+ + (1.0 / r1)
695
+ * expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims)
696
+ * (model_s1 - model_s)
697
+ )
698
+ else:
699
+ phi_11 = torch.expm1(r1 * h)
700
+ phi_1 = torch.expm1(h)
701
+
702
+ if model_s is None:
703
+ model_s = self.model_fn(x, s)
704
+ x_s1 = (
705
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
706
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
707
+ )
708
+ model_s1 = self.model_fn(x_s1, s1)
709
+ if solver_type == "dpm_solver":
710
+ x_t = (
711
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
712
+ - expand_dims(sigma_t * phi_1, dims) * model_s
713
+ - (0.5 / r1)
714
+ * expand_dims(sigma_t * phi_1, dims)
715
+ * (model_s1 - model_s)
716
+ )
717
+ elif solver_type == "taylor":
718
+ x_t = (
719
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
720
+ - expand_dims(sigma_t * phi_1, dims) * model_s
721
+ - (1.0 / r1)
722
+ * expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims)
723
+ * (model_s1 - model_s)
724
+ )
725
+ if return_intermediate:
726
+ return x_t, {"model_s": model_s, "model_s1": model_s1}
727
+ else:
728
+ return x_t
729
+
730
+ def singlestep_dpm_solver_third_update(
731
+ self,
732
+ x,
733
+ s,
734
+ t,
735
+ r1=1.0 / 3.0,
736
+ r2=2.0 / 3.0,
737
+ model_s=None,
738
+ model_s1=None,
739
+ return_intermediate=False,
740
+ solver_type="dpm_solver",
741
+ ):
742
+ """
743
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
744
+ Args:
745
+ x: A pytorch tensor. The initial value at time `s`.
746
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
747
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
748
+ r1: A `float`. The hyperparameter of the third-order solver.
749
+ r2: A `float`. The hyperparameter of the third-order solver.
750
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
751
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
752
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
753
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
754
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
755
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
756
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
757
+ Returns:
758
+ x_t: A pytorch tensor. The approximated solution at time `t`.
759
+ """
760
+ if solver_type not in ["dpm_solver", "taylor"]:
761
+ raise ValueError(
762
+ "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
763
+ solver_type
764
+ )
765
+ )
766
+ if r1 is None:
767
+ r1 = 1.0 / 3.0
768
+ if r2 is None:
769
+ r2 = 2.0 / 3.0
770
+ ns = self.noise_schedule
771
+ dims = x.dim()
772
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
773
+ h = lambda_t - lambda_s
774
+ lambda_s1 = lambda_s + r1 * h
775
+ lambda_s2 = lambda_s + r2 * h
776
+ s1 = ns.inverse_lambda(lambda_s1)
777
+ s2 = ns.inverse_lambda(lambda_s2)
778
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
779
+ ns.marginal_log_mean_coeff(s),
780
+ ns.marginal_log_mean_coeff(s1),
781
+ ns.marginal_log_mean_coeff(s2),
782
+ ns.marginal_log_mean_coeff(t),
783
+ )
784
+ sigma_s, sigma_s1, sigma_s2, sigma_t = (
785
+ ns.marginal_std(s),
786
+ ns.marginal_std(s1),
787
+ ns.marginal_std(s2),
788
+ ns.marginal_std(t),
789
+ )
790
+ alpha_s1, alpha_s2, alpha_t = (
791
+ torch.exp(log_alpha_s1),
792
+ torch.exp(log_alpha_s2),
793
+ torch.exp(log_alpha_t),
794
+ )
795
+
796
+ if self.predict_x0:
797
+ phi_11 = torch.expm1(-r1 * h)
798
+ phi_12 = torch.expm1(-r2 * h)
799
+ phi_1 = torch.expm1(-h)
800
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
801
+ phi_2 = phi_1 / h + 1.0
802
+ phi_3 = phi_2 / h - 0.5
803
+
804
+ if model_s is None:
805
+ model_s = self.model_fn(x, s)
806
+ if model_s1 is None:
807
+ x_s1 = (
808
+ expand_dims(sigma_s1 / sigma_s, dims) * x
809
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
810
+ )
811
+ model_s1 = self.model_fn(x_s1, s1)
812
+ x_s2 = (
813
+ expand_dims(sigma_s2 / sigma_s, dims) * x
814
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
815
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
816
+ )
817
+ model_s2 = self.model_fn(x_s2, s2)
818
+ if solver_type == "dpm_solver":
819
+ x_t = (
820
+ expand_dims(sigma_t / sigma_s, dims) * x
821
+ - expand_dims(alpha_t * phi_1, dims) * model_s
822
+ + (1.0 / r2)
823
+ * expand_dims(alpha_t * phi_2, dims)
824
+ * (model_s2 - model_s)
825
+ )
826
+ elif solver_type == "taylor":
827
+ D1_0 = (1.0 / r1) * (model_s1 - model_s)
828
+ D1_1 = (1.0 / r2) * (model_s2 - model_s)
829
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
830
+ D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
831
+ x_t = (
832
+ expand_dims(sigma_t / sigma_s, dims) * x
833
+ - expand_dims(alpha_t * phi_1, dims) * model_s
834
+ + expand_dims(alpha_t * phi_2, dims) * D1
835
+ - expand_dims(alpha_t * phi_3, dims) * D2
836
+ )
837
+ else:
838
+ phi_11 = torch.expm1(r1 * h)
839
+ phi_12 = torch.expm1(r2 * h)
840
+ phi_1 = torch.expm1(h)
841
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
842
+ phi_2 = phi_1 / h - 1.0
843
+ phi_3 = phi_2 / h - 0.5
844
+
845
+ if model_s is None:
846
+ model_s = self.model_fn(x, s)
847
+ if model_s1 is None:
848
+ x_s1 = (
849
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
850
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
851
+ )
852
+ model_s1 = self.model_fn(x_s1, s1)
853
+ x_s2 = (
854
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
855
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
856
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
857
+ )
858
+ model_s2 = self.model_fn(x_s2, s2)
859
+ if solver_type == "dpm_solver":
860
+ x_t = (
861
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
862
+ - expand_dims(sigma_t * phi_1, dims) * model_s
863
+ - (1.0 / r2)
864
+ * expand_dims(sigma_t * phi_2, dims)
865
+ * (model_s2 - model_s)
866
+ )
867
+ elif solver_type == "taylor":
868
+ D1_0 = (1.0 / r1) * (model_s1 - model_s)
869
+ D1_1 = (1.0 / r2) * (model_s2 - model_s)
870
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
871
+ D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
872
+ x_t = (
873
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
874
+ - expand_dims(sigma_t * phi_1, dims) * model_s
875
+ - expand_dims(sigma_t * phi_2, dims) * D1
876
+ - expand_dims(sigma_t * phi_3, dims) * D2
877
+ )
878
+
879
+ if return_intermediate:
880
+ return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
881
+ else:
882
+ return x_t
883
+
884
+ def multistep_dpm_solver_second_update(
885
+ self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"
886
+ ):
887
+ """
888
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
889
+ Args:
890
+ x: A pytorch tensor. The initial value at time `s`.
891
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
892
+ t_prev_list: A list of pytorch tensors. The previous times, each with shape (x.shape[0],).
893
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
894
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
895
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
896
+ Returns:
897
+ x_t: A pytorch tensor. The approximated solution at time `t`.
898
+ """
899
+ if solver_type not in ["dpm_solver", "taylor"]:
900
+ raise ValueError(
901
+ "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
902
+ solver_type
903
+ )
904
+ )
905
+ ns = self.noise_schedule
906
+ dims = x.dim()
907
+ model_prev_1, model_prev_0 = model_prev_list
908
+ t_prev_1, t_prev_0 = t_prev_list
909
+ lambda_prev_1, lambda_prev_0, lambda_t = (
910
+ ns.marginal_lambda(t_prev_1),
911
+ ns.marginal_lambda(t_prev_0),
912
+ ns.marginal_lambda(t),
913
+ )
914
+ log_alpha_prev_0, log_alpha_t = (
915
+ ns.marginal_log_mean_coeff(t_prev_0),
916
+ ns.marginal_log_mean_coeff(t),
917
+ )
918
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
919
+ alpha_t = torch.exp(log_alpha_t)
920
+
921
+ h_0 = lambda_prev_0 - lambda_prev_1
922
+ h = lambda_t - lambda_prev_0
923
+ r0 = h_0 / h
924
+ D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
925
+ if self.predict_x0:
926
+ if solver_type == "dpm_solver":
927
+ x_t = (
928
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
929
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
930
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0
931
+ )
932
+ elif solver_type == "taylor":
933
+ x_t = (
934
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
935
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
936
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims)
937
+ * D1_0
938
+ )
939
+ else:
940
+ if solver_type == "dpm_solver":
941
+ x_t = (
942
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
943
+ - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
944
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0
945
+ )
946
+ elif solver_type == "taylor":
947
+ x_t = (
948
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
949
+ - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
950
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims)
951
+ * D1_0
952
+ )
953
+ return x_t
954
+
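+ # Intuition for the multistep update above: D1_0 is a scaled finite difference of the two
+ # previous model outputs (an estimate of h * d(model)/d(lambda)), so the second-order
+ # correction reuses history and costs no extra model evaluation per step.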
955
+ def multistep_dpm_solver_third_update(
956
+ self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"
957
+ ):
958
+ """
959
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
960
+ Args:
961
+ x: A pytorch tensor. The initial value at time `s`.
962
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
963
+ t_prev_list: A list of pytorch tensors. The previous times, each with shape (x.shape[0],).
964
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
965
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
966
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
967
+ Returns:
968
+ x_t: A pytorch tensor. The approximated solution at time `t`.
969
+ """
970
+ ns = self.noise_schedule
971
+ dims = x.dim()
972
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
973
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
974
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
975
+ ns.marginal_lambda(t_prev_2),
976
+ ns.marginal_lambda(t_prev_1),
977
+ ns.marginal_lambda(t_prev_0),
978
+ ns.marginal_lambda(t),
979
+ )
980
+ log_alpha_prev_0, log_alpha_t = (
981
+ ns.marginal_log_mean_coeff(t_prev_0),
982
+ ns.marginal_log_mean_coeff(t),
983
+ )
984
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
985
+ alpha_t = torch.exp(log_alpha_t)
986
+
987
+ h_1 = lambda_prev_1 - lambda_prev_2
988
+ h_0 = lambda_prev_0 - lambda_prev_1
989
+ h = lambda_t - lambda_prev_0
990
+ r0, r1 = h_0 / h, h_1 / h
991
+ D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
992
+ D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2)
993
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
994
+ D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1)
995
+ if self.predict_x0:
996
+ x_t = (
997
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
998
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
999
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
1000
+ - expand_dims(
1001
+ alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims
1002
+ )
1003
+ * D2
1004
+ )
1005
+ else:
1006
+ x_t = (
1007
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
1008
+ - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
1009
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
1010
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims)
1011
+ * D2
1012
+ )
1013
+ return x_t
1014
+
1015
+ def singlestep_dpm_solver_update(
1016
+ self,
1017
+ x,
1018
+ s,
1019
+ t,
1020
+ order,
1021
+ return_intermediate=False,
1022
+ solver_type="dpm_solver",
1023
+ r1=None,
1024
+ r2=None,
1025
+ ):
1026
+ """
1027
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
1028
+ Args:
1029
+ x: A pytorch tensor. The initial value at time `s`.
1030
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
1031
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
1032
+ order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
1033
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
1034
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
1035
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
1036
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
1037
+ r2: A `float`. The hyperparameter of the third-order solver.
1038
+ Returns:
1039
+ x_t: A pytorch tensor. The approximated solution at time `t`.
1040
+ """
1041
+ if order == 1:
1042
+ return self.dpm_solver_first_update(
1043
+ x, s, t, return_intermediate=return_intermediate
1044
+ )
1045
+ elif order == 2:
1046
+ return self.singlestep_dpm_solver_second_update(
1047
+ x,
1048
+ s,
1049
+ t,
1050
+ return_intermediate=return_intermediate,
1051
+ solver_type=solver_type,
1052
+ r1=r1,
1053
+ )
1054
+ elif order == 3:
1055
+ return self.singlestep_dpm_solver_third_update(
1056
+ x,
1057
+ s,
1058
+ t,
1059
+ return_intermediate=return_intermediate,
1060
+ solver_type=solver_type,
1061
+ r1=r1,
1062
+ r2=r2,
1063
+ )
1064
+ else:
1065
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
1066
+
1067
+ def multistep_dpm_solver_update(
1068
+ self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"
1069
+ ):
1070
+ """
1071
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
1072
+ Args:
1073
+ x: A pytorch tensor. The initial value at time `s`.
1074
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
1075
+ t_prev_list: A list of pytorch tensors. The previous times, each with shape (x.shape[0],).
1076
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
1077
+ order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
1078
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
1079
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
1080
+ Returns:
1081
+ x_t: A pytorch tensor. The approximated solution at time `t`.
1082
+ """
1083
+ if order == 1:
1084
+ return self.dpm_solver_first_update(
1085
+ x, t_prev_list[-1], t, model_s=model_prev_list[-1]
1086
+ )
1087
+ elif order == 2:
1088
+ return self.multistep_dpm_solver_second_update(
1089
+ x, model_prev_list, t_prev_list, t, solver_type=solver_type
1090
+ )
1091
+ elif order == 3:
1092
+ return self.multistep_dpm_solver_third_update(
1093
+ x, model_prev_list, t_prev_list, t, solver_type=solver_type
1094
+ )
1095
+ else:
1096
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
1097
+
1098
+ def dpm_solver_adaptive(
1099
+ self,
1100
+ x,
1101
+ order,
1102
+ t_T,
1103
+ t_0,
1104
+ h_init=0.05,
1105
+ atol=0.0078,
1106
+ rtol=0.05,
1107
+ theta=0.9,
1108
+ t_err=1e-5,
1109
+ solver_type="dpm_solver",
1110
+ ):
1111
+ """
1112
+ The adaptive step size solver based on singlestep DPM-Solver.
1113
+ Args:
1114
+ x: A pytorch tensor. The initial value at time `t_T`.
1115
+ order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
1116
+ t_T: A `float`. The starting time of the sampling (default is T).
1117
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
1118
+ h_init: A `float`. The initial step size (for logSNR).
1119
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
1120
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
1121
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
1122
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
1123
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
1124
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
1125
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
1126
+ Returns:
1127
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
1128
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
1129
+ """
1130
+ ns = self.noise_schedule
1131
+ s = t_T * torch.ones((x.shape[0],)).to(x)
1132
+ lambda_s = ns.marginal_lambda(s)
1133
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
1134
+ h = h_init * torch.ones_like(s).to(x)
1135
+ x_prev = x
1136
+ nfe = 0
1137
+ if order == 2:
1138
+ r1 = 0.5
1139
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(
1140
+ x, s, t, return_intermediate=True
1141
+ )
1142
+ higher_update = (
1143
+ lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
1144
+ x, s, t, r1=r1, solver_type=solver_type, **kwargs
1145
+ )
1146
+ )
1147
+ elif order == 3:
1148
+ r1, r2 = 1.0 / 3.0, 2.0 / 3.0
1149
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
1150
+ x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
1151
+ )
1152
+ higher_update = (
1153
+ lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
1154
+ x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
1155
+ )
1156
+ )
1157
+ else:
1158
+ raise ValueError(
1159
+ "For adaptive step size solver, order must be 2 or 3, got {}".format(
1160
+ order
1161
+ )
1162
+ )
1163
+ while torch.abs((s - t_0)).mean() > t_err:
1164
+ t = ns.inverse_lambda(lambda_s + h)
1165
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
1166
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
1167
+ delta = torch.max(
1168
+ torch.ones_like(x).to(x) * atol,
1169
+ rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)),
1170
+ )
1171
+ norm_fn = lambda v: torch.sqrt(
1172
+ torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)
1173
+ )
1174
+ E = norm_fn((x_higher - x_lower) / delta).max()
1175
+ if torch.all(E <= 1.0):
1176
+ x = x_higher
1177
+ s = t
1178
+ x_prev = x_lower
1179
+ lambda_s = ns.marginal_lambda(s)
1180
+ h = torch.min(
1181
+ theta * h * torch.float_power(E, -1.0 / order).float(),
1182
+ lambda_0 - lambda_s,
1183
+ )
1184
+ nfe += order
1185
+ print("adaptive solver nfe", nfe)
1186
+ return x
1187
+
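+ # The adaptive loop above is an embedded lower/higher-order pair: a step is accepted when
+ # the scaled error E <= 1, and the next step size is theta * h * E ** (-1 / order),
+ # capped by lambda_0 - lambda_s so the integration never overshoots t_0.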
1188
+ def sample(
1189
+ self,
1190
+ x,
1191
+ steps=20,
1192
+ t_start=None,
1193
+ t_end=None,
1194
+ order=3,
1195
+ skip_type="time_uniform",
1196
+ method="singlestep",
1197
+ lower_order_final=True,
1198
+ denoise_to_zero=False,
1199
+ solver_type="dpm_solver",
1200
+ atol=0.0078,
1201
+ rtol=0.05,
1202
+ ):
1203
+ """
1204
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1205
+ =====================================================
1206
+ We support the following algorithms for both noise prediction model and data prediction model:
1207
+ - 'singlestep':
1208
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1209
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1210
+ The total number of function evaluations (NFE) == `steps`.
1211
+ Given a fixed NFE == `steps`, the sampling procedure is:
1212
+ - If `order` == 1:
1213
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1214
+ - If `order` == 2:
1215
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1216
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1217
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1218
+ - If `order` == 3:
1219
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1220
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1221
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1222
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1223
+ - 'multistep':
1224
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1225
+ We initialize the first `order` values by lower order multistep solvers.
1226
+ Given a fixed NFE == `steps`, the sampling procedure is:
1227
+ Denote K = steps.
1228
+ - If `order` == 1:
1229
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
1230
+ - If `order` == 2:
1231
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
1232
+ - If `order` == 3:
1233
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
1234
+ - 'singlestep_fixed':
1235
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1236
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1237
+ - 'adaptive':
1238
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1239
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1240
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
1241
+ (NFE) and the sample quality.
1242
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1243
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1244
+ =====================================================
1245
+ Some advice for choosing the algorithm:
1246
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1247
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
1248
+ e.g.
1249
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
1250
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1251
+ skip_type='time_uniform', method='singlestep')
1252
+ - For **guided sampling with large guidance scale** by DPMs:
1253
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
1254
+ e.g.
1255
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
1256
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1257
+ skip_type='time_uniform', method='multistep')
1258
+ We support three types of `skip_type`:
1259
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
1260
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
1261
+ - 'time_quadratic': quadratic time for the time steps.
1262
+ =====================================================
1263
+ Args:
1264
+ x: A pytorch tensor. The initial value at time `t_start`
1265
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1266
+ steps: An `int`. The total number of function evaluations (NFE).
1267
+ t_start: A `float`. The starting time of the sampling.
1268
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
1269
+ t_end: A `float`. The ending time of the sampling.
1270
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1271
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1272
+ For discrete-time DPMs:
1273
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1274
+ For continuous-time DPMs:
1275
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1276
+ order: An `int`. The order of DPM-Solver.
1277
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1278
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1279
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1280
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1281
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1282
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
1283
+ sampling diffusion models with diffusion SDEs on low-resolution images
1284
+ (such as CIFAR-10). However, we observed that it does not matter for
1285
+ high-resolution images. As it needs an additional NFE, we do not recommend
1286
+ it for high-resolution images.
1287
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1288
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1289
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1290
+ (especially for steps <= 10). So we recommend setting it to `True`.
1291
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1292
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1293
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1294
+ Returns:
1295
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1296
+ """
1297
+ t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
1298
+ t_T = self.noise_schedule.T if t_start is None else t_start
1299
+ device = x.device
1300
+ if method == "adaptive":
1301
+ with torch.no_grad():
1302
+ x = self.dpm_solver_adaptive(
1303
+ x,
1304
+ order=order,
1305
+ t_T=t_T,
1306
+ t_0=t_0,
1307
+ atol=atol,
1308
+ rtol=rtol,
1309
+ solver_type=solver_type,
1310
+ )
1311
+ elif method == "multistep":
1312
+ assert steps >= order
1313
+ timesteps = self.get_time_steps(
1314
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
1315
+ )
1316
+ assert timesteps.shape[0] - 1 == steps
1317
+ with torch.no_grad():
1318
+ vec_t = timesteps[0].expand((x.shape[0]))
1319
+ model_prev_list = [self.model_fn(x, vec_t)]
1320
+ t_prev_list = [vec_t]
1321
+ # Init the first `order` values by lower order multistep DPM-Solver.
1322
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
1323
+ vec_t = timesteps[init_order].expand(x.shape[0])
1324
+ x = self.multistep_dpm_solver_update(
1325
+ x,
1326
+ model_prev_list,
1327
+ t_prev_list,
1328
+ vec_t,
1329
+ init_order,
1330
+ solver_type=solver_type,
1331
+ )
1332
+ model_prev_list.append(self.model_fn(x, vec_t))
1333
+ t_prev_list.append(vec_t)
1334
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1335
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1336
+ vec_t = timesteps[step].expand(x.shape[0])
1337
+ if lower_order_final and steps < 15:
1338
+ step_order = min(order, steps + 1 - step)
1339
+ else:
1340
+ step_order = order
1341
+ x = self.multistep_dpm_solver_update(
1342
+ x,
1343
+ model_prev_list,
1344
+ t_prev_list,
1345
+ vec_t,
1346
+ step_order,
1347
+ solver_type=solver_type,
1348
+ )
1349
+ for i in range(order - 1):
1350
+ t_prev_list[i] = t_prev_list[i + 1]
1351
+ model_prev_list[i] = model_prev_list[i + 1]
1352
+ t_prev_list[-1] = vec_t
1353
+ # We do not need to evaluate the final model value.
1354
+ if step < steps:
1355
+ model_prev_list[-1] = self.model_fn(x, vec_t)
1356
+ elif method in ["singlestep", "singlestep_fixed"]:
1357
+ if method == "singlestep":
1358
+ (
1359
+ timesteps_outer,
1360
+ orders,
1361
+ ) = self.get_orders_and_timesteps_for_singlestep_solver(
1362
+ steps=steps,
1363
+ order=order,
1364
+ skip_type=skip_type,
1365
+ t_T=t_T,
1366
+ t_0=t_0,
1367
+ device=device,
1368
+ )
1369
+ elif method == "singlestep_fixed":
1370
+ K = steps // order
1371
+ orders = [order] * K
1374
+ timesteps_outer = self.get_time_steps(
1375
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device
1376
+ )
1377
+ for i, order in enumerate(orders):
1378
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1379
+ timesteps_inner = self.get_time_steps(
1380
+ skip_type=skip_type,
1381
+ t_T=t_T_inner.item(),
1382
+ t_0=t_0_inner.item(),
1383
+ N=order,
1384
+ device=device,
1385
+ )
1386
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1387
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1388
+ h = lambda_inner[-1] - lambda_inner[0]
1389
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1390
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1391
+ x = self.singlestep_dpm_solver_update(
1392
+ x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2
1393
+ )
1394
+ if denoise_to_zero:
1395
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1396
+ return x
1397
+
1398
+
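+ # Minimal end-to-end usage sketch (illustrative). It assumes a discrete-time noise
+ # prediction network `unet` with beta schedule `betas`, and that the NoiseScheduleVP /
+ # model_wrapper helpers defined earlier in this file keep their usual DPM-Solver arguments:
+ # >>> ns = NoiseScheduleVP(schedule="discrete", betas=betas)
+ # >>> fn = model_wrapper(unet, ns, model_type="noise", guidance_type="uncond")
+ # >>> solver = DPM_Solver(fn, ns, predict_x0=True)
+ # >>> x_T = torch.randn(4, 3, 64, 64)
+ # >>> x_0 = solver.sample(x_T, steps=20, order=2, skip_type="time_uniform", method="multistep")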
1399
+ #############################################################
1400
+ # other utility functions
1401
+ #############################################################
1402
+
1403
+
1404
+ def interpolate_fn(x, xp, yp):
1405
+ """
1406
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1407
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1408
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1409
+ Args:
1410
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1411
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1412
+ yp: PyTorch tensor with shape [C, K].
1413
+ Returns:
1414
+ The function values f(x), with shape [N, C].
1415
+ """
1416
+ N, K = x.shape[0], xp.shape[1]
1417
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1418
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1419
+ x_idx = torch.argmin(x_indices, dim=2)
1420
+ cand_start_idx = x_idx - 1
1421
+ start_idx = torch.where(
1422
+ torch.eq(x_idx, 0),
1423
+ torch.tensor(1, device=x.device),
1424
+ torch.where(
1425
+ torch.eq(x_idx, K),
1426
+ torch.tensor(K - 2, device=x.device),
1427
+ cand_start_idx,
1428
+ ),
1429
+ )
1430
+ end_idx = torch.where(
1431
+ torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
1432
+ )
1433
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1434
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1435
+ start_idx2 = torch.where(
1436
+ torch.eq(x_idx, 0),
1437
+ torch.tensor(0, device=x.device),
1438
+ torch.where(
1439
+ torch.eq(x_idx, K),
1440
+ torch.tensor(K - 2, device=x.device),
1441
+ cand_start_idx,
1442
+ ),
1443
+ )
1444
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1445
+ start_y = torch.gather(
1446
+ y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
1447
+ ).squeeze(2)
1448
+ end_y = torch.gather(
1449
+ y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
1450
+ ).squeeze(2)
1451
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1452
+ return cand
1453
+
1454
+
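+ # Quick check of interpolate_fn (shapes per the docstring: x is [N, C], xp and yp are [C, K]):
+ # >>> xp = torch.tensor([[0.0, 1.0, 2.0]]); yp = torch.tensor([[0.0, 10.0, 20.0]])
+ # >>> interpolate_fn(torch.tensor([[0.5], [3.0]]), xp, yp)   # -> [[5.0], [30.0]], linear beyond xp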
1455
+ def expand_dims(v, dims):
1456
+ """
1457
+ Expand the tensor `v` to the dim `dims`.
1458
+ Args:
1459
+ `v`: a PyTorch tensor with shape [N].
1460
+ `dim`: a `int`.
1461
+ Returns:
1462
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1463
+ """
1464
+ return v[(...,) + (None,) * (dims - 1)]
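+ # e.g. expand_dims(torch.ones(4), 4).shape == torch.Size([4, 1, 1, 1]), which lets per-sample
+ # scalars broadcast against [N, C, H, W] batches in the updates above.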