Commit 4ab551f · Parent: f5718c2
Add folder with files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.
- .gitignore +229 -0
- .python-version +1 -0
- LICENSE +201 -0
- README-zh.md +164 -0
- app.py +106 -0
- datasets/make_yolo_images.py +49 -0
- example.py +9 -0
- ffmpeg/README.md +69 -0
- notebooks/imputation.ipynb +0 -0
- one-click-portable.md +26 -0
- pyproject.toml +39 -0
- resources/first_frame.json +0 -0
- resources/watermark_template.png +0 -0
- sorawm/__init__.py +0 -0
- sorawm/configs.py +27 -0
- sorawm/core.py +197 -0
- sorawm/iopaint/__init__.py +56 -0
- sorawm/iopaint/__main__.py +4 -0
- sorawm/iopaint/api.py +411 -0
- sorawm/iopaint/batch_processing.py +128 -0
- sorawm/iopaint/benchmark.py +109 -0
- sorawm/iopaint/cli.py +245 -0
- sorawm/iopaint/const.py +134 -0
- sorawm/iopaint/download.py +314 -0
- sorawm/iopaint/file_manager/__init__.py +1 -0
- sorawm/iopaint/file_manager/file_manager.py +220 -0
- sorawm/iopaint/file_manager/storage_backends.py +46 -0
- sorawm/iopaint/file_manager/utils.py +64 -0
- sorawm/iopaint/helper.py +411 -0
- sorawm/iopaint/installer.py +11 -0
- sorawm/iopaint/model/__init__.py +38 -0
- sorawm/iopaint/model/anytext/__init__.py +0 -0
- sorawm/iopaint/model/anytext/anytext_model.py +73 -0
- sorawm/iopaint/model/anytext/anytext_pipeline.py +401 -0
- sorawm/iopaint/model/anytext/anytext_sd15.yaml +99 -0
- sorawm/iopaint/model/anytext/cldm/__init__.py +0 -0
- sorawm/iopaint/model/anytext/cldm/cldm.py +780 -0
- sorawm/iopaint/model/anytext/cldm/ddim_hacked.py +486 -0
- sorawm/iopaint/model/anytext/cldm/embedding_manager.py +185 -0
- sorawm/iopaint/model/anytext/cldm/hack.py +128 -0
- sorawm/iopaint/model/anytext/cldm/model.py +41 -0
- sorawm/iopaint/model/anytext/cldm/recognizer.py +302 -0
- sorawm/iopaint/model/anytext/ldm/__init__.py +0 -0
- sorawm/iopaint/model/anytext/ldm/models/__init__.py +0 -0
- sorawm/iopaint/model/anytext/ldm/models/autoencoder.py +275 -0
- sorawm/iopaint/model/anytext/ldm/models/diffusion/__init__.py +0 -0
- sorawm/iopaint/model/anytext/ldm/models/diffusion/ddim.py +525 -0
- sorawm/iopaint/model/anytext/ldm/models/diffusion/ddpm.py +2386 -0
- sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py +1 -0
- sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py +1464 -0
.gitignore
ADDED
@@ -0,0 +1,229 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+#poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+#pdm.lock
+#pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+#pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Cursor
+# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+# refer to https://docs.cursor.com/context/ignore-files
+.cursorignore
+.cursorindexingignore
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+output
+
+videos
+
+datasets/images
+datasets/labels
+datasets/coco8
+.DS_store
+outputs
+yolo11n.pt
+yolo11s.pt
+best.pt
+
+.claude
+best.pt
+
+runs
+.idea
+working_dir
+data
+upload_to_huggingface.py
+resources/best.pt
.python-version
ADDED
@@ -0,0 +1 @@
+3.12
LICENSE
ADDED
@@ -0,0 +1,201 @@
+(The standard, unmodified text of the Apache License, Version 2.0, January 2004: definitions, copyright and patent grants, redistribution conditions, warranty disclaimer, liability limitation, and the appendix boilerplate notice. Full text at http://www.apache.org/licenses/LICENSE-2.0.)
README-zh.md
ADDED
@@ -0,0 +1,164 @@
+# SoraWatermarkCleaner
+
+[English](README.md) | 中文
+
+This project provides an elegant way to remove the Sora watermark from videos generated by Sora2.
+
+- After watermark removal
+
+https://github.com/user-attachments/assets/8cdc075e-7d15-4d04-8fa2-53dd287e5f4c
+
+- Original video
+
+https://github.com/user-attachments/assets/3c850ff1-b8e3-41af-a46f-2c734406e77d
+
+⭐️ Highlights:
+
+1. **YOLO weights updated**: try the new watermark-detection model, it performs noticeably better!
+
+2. **Dataset open-sourced**: we have uploaded the annotated dataset to Hugging Face, see [this dataset](https://huggingface.co/datasets/LLinked/sora-watermark-dataset). You are welcome to train your own detection model or improve ours!
+
+3. **One-click portable build released**: [download it here](#3-one-click-portable-build); Windows users can run it without installing anything!
+
+## 1. Method
+
+SoraWatermarkCleaner (abbreviated below as `SoraWm`) consists of two parts:
+
+- SoraWaterMarkDetector: a yolov11s model we trained to detect the Sora watermark. (Thanks, YOLO!)
+
+- WaterMarkCleaner: watermark removal with the LAMA model, following the IOPaint implementation.
+
+(That code comes from https://github.com/Sanster/IOPaint#; thanks for their excellent work!)
+
+Our SoraWm is fully driven by deep learning and produces good results on many generated videos.
+
+## 2. Installation
+
+Video processing requires [FFmpeg](https://ffmpeg.org/); please install it first. We strongly recommend using `uv` to set up the environment:
+
+1. Install:
+
+```bash
+uv sync
+```
+
+> The environment is installed under `.venv`; activate it with:
+>
+> ```bash
+> source .venv/bin/activate
+> ```
+
+2. Download the pretrained models:
+
+The trained YOLO weights are stored in the `resources` directory as `best.pt` and are downloaded automatically from https://github.com/linkedlist771/SoraWatermarkCleaner/releases/download/V0.0.1/best.pt. The `Lama` model is downloaded from https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt and stored in the torch cache directory. Both downloads are automatic; if they fail, check your network connection.
+
+## 3. One-Click Portable Build
+
+For users who prefer not to install anything manually, we provide a **one-click portable build** with all dependencies preconfigured, ready to use out of the box.
+
+### Download Links
+
+**Google Drive:**
+- [Download from Google Drive](https://drive.google.com/file/d/1ujH28aHaCXGgB146g6kyfz3Qxd-wHR1c/view?usp=share_link)
+
+**Baidu Pan (recommended for users in mainland China):**
+- Link: https://pan.baidu.com/s/1i4exYsPvXv0evnGs5MWcYA?pwd=3jr6
+- Extraction code: `3jr6`
+
+### Features
+- ✅ No installation required
+- ✅ All dependencies included
+- ✅ Preconfigured environment
+- ✅ Works out of the box
+
+Just download, extract, and run!
+
+## 4. Demo
+
+For basic usage, just try `example.py`:
+
+```python
+from pathlib import Path
+from sorawm.core import SoraWM
+
+
+if __name__ == "__main__":
+    input_video_path = Path(
+        "resources/dog_vs_sam.mp4"
+    )
+    output_video_path = Path("outputs/sora_watermark_removed.mp4")
+    sora_wm = SoraWM()
+    sora_wm.run(input_video_path, output_video_path)
+```
+
+We also provide an interactive web UI built on `streamlit`; try it with:
+
+```bash
+streamlit run app.py
+```
+
+<img src="resources/app.png" style="zoom: 25%;" />
+
+## 5. WebServer
+
+We also provide a FastAPI-based web server that quickly turns the watermark cleaner into a service.
+
+Just run:
+
+```bash
+python start_server.py
+```
+
+The web server starts on port `5344`; see the FastAPI [docs](http://localhost:5344/docs) for details. There are three routes:
+
+1. submit_remove_task:
+
+> After you upload a video, a task ID is returned and the video is processed immediately.
+
+<img src="resources/53abf3fd-11a9-4dd7-a348-34920775f8ad.png" alt="image" style="zoom: 25%;" />
+
+2. get_results:
+
+You can use the task ID from above to retrieve the task status, which reports the percentage of the video processed. Once processing completes, the returned data includes a download URL.
+
+3. download:
+
+Use the download URL from step 2 to fetch the cleaned video.
+
+## 6. Dataset
+
+We have uploaded the annotated dataset to Hugging Face, see https://huggingface.co/datasets/LLinked/sora-watermark-dataset. You are welcome to train your own detection model or improve ours!
+
+## 7. API
+
+Packaged as a Cog and [published on Replicate](https://replicate.com/uglyrobot/sora2-watermark-remover) for simple API-based usage.
+
+## 8. License
+
+Apache License
+
+## 9. Citation
+
+If you use this project, please cite:
+
+```bibtex
+@misc{sorawatermarkcleaner2025,
+  author = {linkedlist771},
+  title = {SoraWatermarkCleaner},
+  year = {2025},
+  url = {https://github.com/linkedlist771/SoraWatermarkCleaner}
+}
+```
+
+## 10. Acknowledgements
+
+- [IOPaint](https://github.com/Sanster/IOPaint) for the LAMA implementation
+- [Ultralytics YOLO](https://github.com/ultralytics/ultralytics) for object detection
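For reference, here is a minimal client for the three routes described in the README's WebServer section. This is a hedged sketch: the exact route paths, request field names, and response keys (`task_id`, `download_url`) are assumptions inferred from the README text, since `start_server.py` is not part of this diff.

```python
import time

import requests  # assumed available; not a declared project dependency

BASE = "http://localhost:5344"

# 1. submit_remove_task: upload a video and receive a task ID.
with open("resources/dog_vs_sam.mp4", "rb") as f:
    resp = requests.post(f"{BASE}/submit_remove_task", files={"file": f})
task_id = resp.json()["task_id"]  # field name is an assumption

# 2. get_results: poll until the response carries a download URL.
while True:
    data = requests.get(f"{BASE}/get_results", params={"task_id": task_id}).json()
    if data.get("download_url"):  # key name is an assumption
        break
    time.sleep(2)

# 3. download: fetch the cleaned video via the returned URL.
video = requests.get(BASE + data["download_url"])
with open("outputs/cleaned.mp4", "wb") as out:
    out.write(video.content)
```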
app.py
ADDED
@@ -0,0 +1,106 @@
+import shutil
+import tempfile
+from pathlib import Path
+
+import streamlit as st
+
+from sorawm.core import SoraWM
+
+
+def main():
+    st.set_page_config(
+        page_title="Sora Watermark Cleaner", page_icon="🎬", layout="centered"
+    )
+
+    st.title("🎬 Sora Watermark Cleaner")
+    st.markdown("Remove watermarks from Sora-generated videos with ease")
+
+    # Initialize SoraWM
+    if "sora_wm" not in st.session_state:
+        with st.spinner("Loading AI models..."):
+            st.session_state.sora_wm = SoraWM()
+
+    st.markdown("---")
+
+    # File uploader
+    uploaded_file = st.file_uploader(
+        "Upload your video",
+        type=["mp4", "avi", "mov", "mkv"],
+        help="Select a video file to remove watermarks",
+    )
+
+    if uploaded_file is not None:
+        # Display video info
+        st.success(f"✅ Uploaded: {uploaded_file.name}")
+        st.video(uploaded_file)
+
+        # Process button
+        if st.button("🚀 Remove Watermark", type="primary", use_container_width=True):
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                tmp_path = Path(tmp_dir)
+
+                # Save uploaded file
+                input_path = tmp_path / uploaded_file.name
+                with open(input_path, "wb") as f:
+                    f.write(uploaded_file.read())
+
+                # Process video
+                output_path = tmp_path / f"cleaned_{uploaded_file.name}"
+
+                try:
+                    # Create progress bar and status text
+                    progress_bar = st.progress(0)
+                    status_text = st.empty()
+
+                    def update_progress(progress: int):
+                        progress_bar.progress(progress / 100)
+                        if progress < 50:
+                            status_text.text(f"🔍 Detecting watermarks... {progress}%")
+                        elif progress < 95:
+                            status_text.text(f"🧹 Removing watermarks... {progress}%")
+                        else:
+                            status_text.text(f"🎵 Merging audio... {progress}%")
+
+                    # Run the watermark removal with progress callback
+                    st.session_state.sora_wm.run(
+                        input_path, output_path, progress_callback=update_progress
+                    )
+
+                    # Complete the progress bar
+                    progress_bar.progress(100)
+                    status_text.text("✅ Processing complete!")
+
+                    st.success("✅ Watermark removed successfully!")
+
+                    # Display result
+                    st.markdown("### Result")
+                    st.video(str(output_path))
+
+                    # Download button
+                    with open(output_path, "rb") as f:
+                        st.download_button(
+                            label="⬇️ Download Cleaned Video",
+                            data=f,
+                            file_name=f"cleaned_{uploaded_file.name}",
+                            mime="video/mp4",
+                            use_container_width=True,
+                        )
+
+                except Exception as e:
+                    st.error(f"❌ Error processing video: {str(e)}")
+
+    # Footer
+    st.markdown("---")
+    st.markdown(
+        """
+        <div style='text-align: center'>
+            <p>Built with ❤️ using Streamlit and AI</p>
+            <p><a href='https://github.com/linkedlist771/SoraWatermarkCleaner'>GitHub Repository</a></p>
+        </div>
+        """,
+        unsafe_allow_html=True,
+    )
+
+
+if __name__ == "__main__":
+    main()
datasets/make_yolo_images.py
ADDED
@@ -0,0 +1,49 @@
+from pathlib import Path
+
+import cv2
+from tqdm import tqdm
+
+from sorawm.configs import ROOT
+
+videos_dir = ROOT / "videos"
+datasets_dir = ROOT / "datasets"
+images_dir = datasets_dir / "images"
+images_dir.mkdir(exist_ok=True, parents=True)
+
+if __name__ == "__main__":
+    fps_save_interval = 1  # Save every frame (interval of 1)
+
+    idx = 0
+    for video_path in tqdm(list(videos_dir.rglob("*.mp4"))):
+        # Open the video file
+        cap = cv2.VideoCapture(str(video_path))
+
+        if not cap.isOpened():
+            print(f"Error opening video: {video_path}")
+            continue
+
+        frame_count = 0
+
+        while True:
+            ret, frame = cap.read()
+
+            # Break if no more frames
+            if not ret:
+                break
+
+            # Save frame at the specified interval
+            if frame_count % fps_save_interval == 0:
+                # Create filename: image_idx_framecount.jpg
+                image_filename = f"image_{idx:06d}_frame_{frame_count:06d}.jpg"
+                image_path = images_dir / image_filename
+
+                # Save the frame
+                cv2.imwrite(str(image_path), frame)
+
+            frame_count += 1
+
+        # Release the video capture object
+        cap.release()
+        idx += 1
+
+    print(f"Processed {idx} videos, extracted frames saved to {images_dir}")
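The script above only extracts frames; they still need YOLO-format labels (the `.gitignore` entries `datasets/images` and `datasets/labels` suggest the layout) before a detector can be trained. A minimal training sketch using the `ultralytics` dependency follows; `datasets/data.yaml` and the hyperparameters are illustrative assumptions, not files or settings from this commit.

```python
from ultralytics import YOLO

# Hypothetical training run: "datasets/data.yaml" is an assumed dataset config
# pointing at datasets/images and datasets/labels; epochs/imgsz are placeholders.
model = YOLO("yolo11s.pt")  # same base checkpoint family the README mentions
model.train(data="datasets/data.yaml", epochs=100, imgsz=640)
model.val()
# Ultralytics writes the best checkpoint to runs/detect/train*/weights/best.pt,
# matching the best.pt filename that sorawm/configs.py expects under resources/.
```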
example.py
ADDED
@@ -0,0 +1,9 @@
+from pathlib import Path
+
+from sorawm.core import SoraWM
+
+if __name__ == "__main__":
+    input_video_path = Path("resources/dog_vs_sam.mp4")
+    output_video_path = Path("outputs/sora_watermark_removed.mp4")
+    sora_wm = SoraWM()
+    sora_wm.run(input_video_path, output_video_path)
ffmpeg/README.md
ADDED
@@ -0,0 +1,69 @@
+# FFmpeg Executable Directory
+
+## Purpose
+
+This directory holds the FFmpeg executables so the project can run as a truly portable build (no system-wide FFmpeg installation required).
+
+## Setup Steps for Windows Users
+
+### 1. Download FFmpeg
+
+Visit the [FFmpeg-Builds releases](https://github.com/BtbN/FFmpeg-Builds/releases) page:
+
+- Download the latest `ffmpeg-master-latest-win64-gpl.zip` (about 120 MB)
+- Or download a specific version, e.g. `ffmpeg-n6.1-latest-win64-gpl-6.1.zip`
+
+### 2. Extract and Copy the Files
+
+1. Extract the downloaded zip file
+2. Locate the `bin` directory inside the extracted folder
+3. Copy the following two files into **this directory** (`ffmpeg/`):
+   - `ffmpeg.exe`: the main FFmpeg program
+   - `ffprobe.exe`: the FFmpeg media-inspection tool
+
+### 3. Verify the Setup
+
+When done, this directory should contain:
+
+```
+ffmpeg/
+├── .gitkeep
+├── README.md
+├── ffmpeg.exe ← the file you copied
+└── ffprobe.exe ← the file you copied
+```
+
+### 4. Test
+
+Run the project's test script to verify the setup:
+
+```bash
+python test_ffmpeg_setup.py
+```
+
+If everything is configured correctly, you will see: `✓ 测试通过!FFmpeg已正确配置并可以使用`
+
+## macOS/Linux Users
+
+For a portable setup:
+
+1. Download the FFmpeg binaries for your platform
+2. Place the `ffmpeg` and `ffprobe` executables in this directory
+3. Make sure the files are executable: `chmod +x ffmpeg ffprobe`
+
+## Notes
+
+- These executables are not committed to git (already covered by `.gitignore`)
+- The program automatically detects and uses the FFmpeg in this directory
+- If this directory has no FFmpeg, the program falls back to the system-installed version
+
+## Download Links
+
+- **Windows**: https://github.com/BtbN/FFmpeg-Builds/releases
+- **Official site**: https://ffmpeg.org/download.html
+- **Mirror**: https://www.gyan.dev/ffmpeg/builds/ (Windows)
+
+## License
+
+FFmpeg is distributed under the GPL (for these builds); please comply with its terms.
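The auto-detection described in the notes above (prefer the bundled binary, fall back to the system installation) is implemented elsewhere in the project and is not part of this diff. Here is a minimal sketch of how such resolution could work; the helper name and its placement are assumptions.

```python
import shutil
from pathlib import Path

FFMPEG_DIR = Path(__file__).parent  # this ffmpeg/ directory

def resolve_ffmpeg(name: str = "ffmpeg") -> str:
    """Prefer the bundled executable, then fall back to the system PATH."""
    # Check for a Windows .exe first, then a bare Unix binary in this folder.
    for candidate in (FFMPEG_DIR / f"{name}.exe", FFMPEG_DIR / name):
        if candidate.is_file():
            return str(candidate)
    system = shutil.which(name)  # search the PATH like a shell would
    if system:
        return system
    raise FileNotFoundError(f"{name} not found in {FFMPEG_DIR} or on PATH")
```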
notebooks/imputation.ipynb
ADDED
The diff for this file is too large to render. See raw diff
one-click-portable.md
ADDED
@@ -0,0 +1,26 @@
+# One-Click Portable Version | 一键便携版
+
+For **Windows** users - No installation required!
+
+适用于 **Windows** 用户 - 无需安装!
+
+## Download | 下载
+
+**Google Drive:**
+- https://drive.google.com/file/d/1ujH28aHaCXGgB146g6kyfz3Qxd-wHR1c/view?usp=share_link
+
+**Baidu Pan | 百度网盘:**
+- Link | 链接: https://pan.baidu.com/s/1_tdgs-3-dLNn0IbufIM75g?pwd=fiju
+- Extract Code | 提取码: `fiju`
+
+## Usage | 使用方法
+
+1. Download and extract the zip file | 下载并解压 zip 文件
+2. Double-click `run.bat` | 双击 `run.bat` 文件
+3. The web service will start automatically! | 网页服务将自动启动!
+
+## Features | 特点
+
+- ✅ Zero installation | 无需安装
+- ✅ All dependencies included | 包含所有依赖
+- ✅ Ready to use | 开箱即用
pyproject.toml
ADDED
@@ -0,0 +1,39 @@
+[project]
+name = "sorawatermarkcleaner"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "aiofiles>=24.1.0",
+    "aiosqlite>=0.21.0",
+    "diffusers>=0.35.1",
+    "einops>=0.8.1",
+    "fastapi==0.108.0",
+    "ffmpeg-python>=0.2.0",
+    "fire>=0.7.1",
+    "httpx>=0.28.1",
+    "huggingface-hub>=0.35.3",
+    "jupyter>=1.1.1",
+    "loguru>=0.7.3",
+    "matplotlib>=3.10.6",
+    "notebook>=7.4.7",
+    "omegaconf>=2.3.0",
+    "opencv-python>=4.12.0.88",
+    "pandas>=2.3.3",
+    "pydantic>=2.11.10",
+    "python-multipart>=0.0.20",
+    "requests>=2.32.5",
+    "ruptures>=1.1.10",
+    "scikit-learn>=1.7.2",
+    "sqlalchemy>=2.0.43",
+    "streamlit>=1.50.0",
+    "torch>=2.5.0",
+    "torchvision>=0.20.0",
+    "tqdm>=4.67.1",
+    "transformers>=4.57.0",
+    "ultralytics>=8.3.204",
+    "uuid>=1.30",
+    "uvicorn>=0.35.0",
+]
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
resources/watermark_template.png
ADDED
|
sorawm/__init__.py
ADDED
|
File without changes
|
sorawm/configs.py
ADDED
@@ -0,0 +1,27 @@
+from pathlib import Path
+
+ROOT = Path(__file__).parent.parent
+
+
+RESOURCES_DIR = ROOT / "resources"
+WATER_MARK_TEMPLATE_IMAGE_PATH = RESOURCES_DIR / "watermark_template.png"
+
+WATER_MARK_DETECT_YOLO_WEIGHTS = RESOURCES_DIR / "best.pt"
+
+OUTPUT_DIR = ROOT / "output"
+
+OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
+
+
+DEFAULT_WATERMARK_REMOVE_MODEL = "lama"
+
+WORKING_DIR = ROOT / "working_dir"
+WORKING_DIR.mkdir(exist_ok=True, parents=True)
+
+LOGS_PATH = ROOT / "logs"
+LOGS_PATH.mkdir(exist_ok=True, parents=True)
+
+DATA_PATH = ROOT / "data"
+DATA_PATH.mkdir(exist_ok=True, parents=True)
+
+SQLITE_PATH = DATA_PATH / "db.sqlite3"
sorawm/core.py
ADDED
@@ -0,0 +1,197 @@
+from pathlib import Path
+from typing import Callable
+
+import ffmpeg
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from sorawm.utils.video_utils import VideoLoader
+from sorawm.watermark_cleaner import WaterMarkCleaner
+from sorawm.watermark_detector import SoraWaterMarkDetector
+from sorawm.utils.imputation_utils import (
+    find_2d_data_bkps,
+    get_interval_average_bbox,
+    find_idxs_interval,
+)
+
+
+class SoraWM:
+    def __init__(self):
+        self.detector = SoraWaterMarkDetector()
+        self.cleaner = WaterMarkCleaner()
+
+    def run(
+        self,
+        input_video_path: Path,
+        output_video_path: Path,
+        progress_callback: Callable[[int], None] | None = None,
+    ):
+        input_video_loader = VideoLoader(input_video_path)
+        output_video_path.parent.mkdir(parents=True, exist_ok=True)
+        width = input_video_loader.width
+        height = input_video_loader.height
+        fps = input_video_loader.fps
+        total_frames = input_video_loader.total_frames
+
+        temp_output_path = output_video_path.parent / f"temp_{output_video_path.name}"
+        output_options = {
+            "pix_fmt": "yuv420p",
+            "vcodec": "libx264",
+            "preset": "slow",
+        }
+
+        if input_video_loader.original_bitrate:
+            output_options["video_bitrate"] = str(
+                int(int(input_video_loader.original_bitrate) * 1.2)
+            )
+        else:
+            output_options["crf"] = "18"
+
+        process_out = (
+            ffmpeg.input(
+                "pipe:",
+                format="rawvideo",
+                pix_fmt="bgr24",
+                s=f"{width}x{height}",
+                r=fps,
+            )
+            .output(str(temp_output_path), **output_options)
+            .overwrite_output()
+            .global_args("-loglevel", "error")
+            .run_async(pipe_stdin=True)
+        )
+
+        frame_and_mask = {}
+        detect_missed = []
+        bbox_centers = []
+        bboxes = []
+
+        logger.debug(
+            f"total frames: {total_frames}, fps: {fps}, width: {width}, height: {height}"
+        )
+        for idx, frame in enumerate(
+            tqdm(input_video_loader, total=total_frames, desc="Detect watermarks")
+        ):
+            detection_result = self.detector.detect(frame)
+            if detection_result["detected"]:
+                frame_and_mask[idx] = {"frame": frame, "bbox": detection_result["bbox"]}
+                x1, y1, x2, y2 = detection_result["bbox"]
+                bbox_centers.append((int((x1 + x2) / 2), int((y1 + y2) / 2)))
+                bboxes.append((x1, y1, x2, y2))
+            else:
+                frame_and_mask[idx] = {"frame": frame, "bbox": None}
+                detect_missed.append(idx)
+                bbox_centers.append(None)
+                bboxes.append(None)
+            # 10% - 50%
+            if progress_callback and idx % 10 == 0:
+                progress = 10 + int((idx / total_frames) * 40)
+                progress_callback(progress)
+
+        logger.debug(f"detect missed frames: {detect_missed}")
+        # logger.debug(f"bbox centers: \n{bbox_centers}")
+        if detect_missed:
+            # 1. find the bkps (change points) of the bbox centers
+            bkps = find_2d_data_bkps(bbox_centers)
+            # add the start and end position, to form the complete interval boundaries
+            bkps_full = [0] + bkps + [total_frames]
+            # logger.debug(f"bkps intervals: {bkps_full}")
+
+            # 2. calculate the average bbox of each interval
+            interval_bboxes = get_interval_average_bbox(bboxes, bkps_full)
+            # logger.debug(f"interval average bboxes: {interval_bboxes}")
+
+            # 3. find the interval index of each missed frame
+            missed_intervals = find_idxs_interval(detect_missed, bkps_full)
+            # logger.debug(
+            #     f"missed frame intervals: {list(zip(detect_missed, missed_intervals))}"
+            # )
+
+            # 4. fill the missed frames with the average bbox of the corresponding interval
+            for missed_idx, interval_idx in zip(detect_missed, missed_intervals):
+                if (
+                    interval_idx < len(interval_bboxes)
+                    and interval_bboxes[interval_idx] is not None
+                ):
+                    frame_and_mask[missed_idx]["bbox"] = interval_bboxes[interval_idx]
+                    logger.debug(
+                        f"Filled missed frame {missed_idx} with bbox:\n"
+                        f" {interval_bboxes[interval_idx]}"
+                    )
+                else:
+                    # if the interval has no valid bbox, use the previous and next frame to complete (fallback strategy)
+                    before = max(missed_idx - 1, 0)
+                    after = min(missed_idx + 1, total_frames - 1)
+                    before_box = frame_and_mask[before]["bbox"]
+                    after_box = frame_and_mask[after]["bbox"]
+                    if before_box:
+                        frame_and_mask[missed_idx]["bbox"] = before_box
+                    elif after_box:
+                        frame_and_mask[missed_idx]["bbox"] = after_box
+                    else:
+                        # no neighboring bbox available; leave this frame's bbox as None
+                        pass
+
+        del bboxes
+        del bbox_centers
+        del detect_missed
+
+        for idx in tqdm(range(total_frames), desc="Remove watermarks"):
+            frame_info = frame_and_mask[idx]
+            frame = frame_info["frame"]
+            bbox = frame_info["bbox"]
+            if bbox is not None:
+                x1, y1, x2, y2 = bbox
+                mask = np.zeros((height, width), dtype=np.uint8)
+                mask[y1:y2, x1:x2] = 255
+                cleaned_frame = self.cleaner.clean(frame, mask)
+            else:
+                cleaned_frame = frame
+            process_out.stdin.write(cleaned_frame.tobytes())
+
+            # 50% - 95%
+            if progress_callback and idx % 10 == 0:
+                progress = 50 + int((idx / total_frames) * 45)
+                progress_callback(progress)
+
+        process_out.stdin.close()
+        process_out.wait()
+
+        # 95% - 99%
+        if progress_callback:
+            progress_callback(95)
+
+        self.merge_audio_track(input_video_path, temp_output_path, output_video_path)
+
+        if progress_callback:
+            progress_callback(99)
+
+    def merge_audio_track(
+        self, input_video_path: Path, temp_output_path: Path, output_video_path: Path
+    ):
+        logger.info("Merging audio track...")
+        video_stream = ffmpeg.input(str(temp_output_path))
+        audio_stream = ffmpeg.input(str(input_video_path)).audio
+
+        (
+            ffmpeg.output(
+                video_stream,
+                audio_stream,
+                str(output_video_path),
+                vcodec="copy",
+                acodec="aac",
+            )
+            .overwrite_output()
+            .run(quiet=True)
+        )
+        # Clean up temporary file
+        temp_output_path.unlink()
+        logger.info(f"Saved no watermark video with audio at: {output_video_path}")
+
+
+if __name__ == "__main__":
+    from pathlib import Path
+
+    input_video_path = Path(
+        "resources/19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4"
+    )
+    output_video_path = Path("outputs/sora_watermark_removed.mp4")
+    sora_wm = SoraWM()
+    sora_wm.run(input_video_path, output_video_path)
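`find_2d_data_bkps`, `get_interval_average_bbox`, and `find_idxs_interval` are imported from `sorawm/utils/imputation_utils.py`, which is not included in this diff. Since the Sora watermark jumps between a few fixed positions, the bbox-center track is roughly piecewise-constant, a natural fit for change-point detection. Below is a hedged sketch of the change-point step using `ruptures` (a declared dependency); the actual implementation may differ.

```python
import numpy as np
import ruptures as rpt

def find_2d_data_bkps_sketch(centers, pen=5):
    # centers: one (x, y) tuple per frame, or None for frames with no detection.
    # Carry the last seen center forward so the signal stays frame-aligned.
    filled, last = [], (0, 0)
    for c in centers:
        last = c if c is not None else last
        filled.append(last)
    signal = np.asarray(filled, dtype=float)  # shape (n_frames, 2)
    # PELT with an RBF cost finds where the center track jumps.
    algo = rpt.Pelt(model="rbf").fit(signal)
    # predict() returns segment end indices, including n_frames itself;
    # drop the final entry to keep only interior breakpoints.
    return algo.predict(pen=pen)[:-1]
```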
sorawm/iopaint/__init__.py
ADDED
@@ -0,0 +1,56 @@
+import ctypes
+import importlib.util
+import logging
+import os
+import shutil
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+# https://github.com/pytorch/pytorch/issues/27971#issuecomment-1768868068
+os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1"
+os.environ["LRU_CACHE_CAPACITY"] = "1"
+# prevent CPU memory leak when run model on GPU
+# https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869288431
+# https://github.com/pytorch/pytorch/issues/108334#issuecomment-1752763633
+os.environ["TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT"] = "1"
+
+import warnings
+
+warnings.simplefilter("ignore", UserWarning)
+
+
+def fix_window_pytorch():
+    # copied from: https://github.com/comfyanonymous/ComfyUI/blob/5cbaa9e07c97296b536f240688f5a19300ecf30d/fix_torch.py#L4
+    import platform
+
+    try:
+        if platform.system() != "Windows":
+            return
+        torch_spec = importlib.util.find_spec("torch")
+        for folder in torch_spec.submodule_search_locations:
+            lib_folder = os.path.join(folder, "lib")
+            test_file = os.path.join(lib_folder, "fbgemm.dll")
+            dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
+            if os.path.exists(dest):
+                break
+
+            with open(test_file, "rb") as f:
+                contents = f.read()
+                if b"libomp140.x86_64.dll" not in contents:
+                    break
+            try:
+                mydll = ctypes.cdll.LoadLibrary(test_file)
+            except FileNotFoundError:
+                logging.warning("Detected pytorch version with libomp issue, patching.")
+                shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
+    except:
+        pass
+
+
+def entry_point():
+    # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
+    # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
+    from sorawm.iopaint.cli import typer_app
+
+    fix_window_pytorch()
+
+    typer_app()
sorawm/iopaint/__main__.py
ADDED
@@ -0,0 +1,4 @@
+from sorawm.iopaint import entry_point  # import from the vendored package path
+
+if __name__ == "__main__":
+    entry_point()
sorawm/iopaint/api.py
ADDED
@@ -0,0 +1,411 @@
+import asyncio
+import os
+import threading
+import time
+import traceback
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import cv2
+import numpy as np
+import socketio
+import torch
+
+try:
+    torch._C._jit_override_can_fuse_on_cpu(False)
+    torch._C._jit_override_can_fuse_on_gpu(False)
+    torch._C._jit_set_texpr_fuser_enabled(False)
+    torch._C._jit_set_nvfuser_enabled(False)
+    torch._C._jit_set_profiling_mode(False)
+except:
+    pass
+
+import uvicorn
+from fastapi import APIRouter, FastAPI, Request, UploadFile
+from fastapi.encoders import jsonable_encoder
+from fastapi.exceptions import HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, JSONResponse, Response
+from fastapi.staticfiles import StaticFiles
+from loguru import logger
+from PIL import Image
+from socketio import AsyncServer
+
+from sorawm.iopaint.file_manager import FileManager
+from sorawm.iopaint.helper import (
+    adjust_mask,
+    concat_alpha_channel,
+    decode_base64_to_image,
+    gen_frontend_mask,
+    load_img,
+    numpy_to_bytes,
+    pil_to_bytes,
+)
+from sorawm.iopaint.model.utils import torch_gc
+from sorawm.iopaint.model_manager import ModelManager
+from sorawm.iopaint.plugins import InteractiveSeg, RealESRGANUpscaler, build_plugins
+from sorawm.iopaint.plugins.base_plugin import BasePlugin
+from sorawm.iopaint.plugins.remove_bg import RemoveBG
+from sorawm.iopaint.schema import (
+    AdjustMaskRequest,
+    ApiConfig,
+    GenInfoResponse,
+    InpaintRequest,
+    InteractiveSegModel,
+    ModelInfo,
+    PluginInfo,
+    RealESRGANModel,
+    RemoveBGModel,
+    RunPluginRequest,
+    SDSampler,
+    ServerConfigResponse,
+    SwitchModelRequest,
+    SwitchPluginModelRequest,
+)
+
+CURRENT_DIR = Path(__file__).parent.absolute().resolve()
+WEB_APP_DIR = CURRENT_DIR / "web_app"
+
+
+def api_middleware(app: FastAPI):
+    rich_available = False
+    try:
+        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
+            import anyio  # importing just so it can be placed on silent list
+            import starlette  # importing just so it can be placed on silent list
+            from rich.console import Console
+
+            console = Console()
+            rich_available = True
+    except Exception:
+        pass
+
+    def handle_exception(request: Request, e: Exception):
+        err = {
+            "error": type(e).__name__,
+            "detail": vars(e).get("detail", ""),
+            "body": vars(e).get("body", ""),
+            "errors": str(e),
+        }
+        if not isinstance(
+            e, HTTPException
+        ):  # do not print backtrace on known httpexceptions
+            message = f"API error: {request.method}: {request.url} {err}"
+            if rich_available:
+                print(message)
+                console.print_exception(
+                    show_locals=True,
+                    max_frames=2,
+                    extra_lines=1,
+                    suppress=[anyio, starlette],
+                    word_wrap=False,
+                    width=min([console.width, 200]),
+                )
+            else:
+                traceback.print_exc()
+        return JSONResponse(
+            status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
+        )
+
+    @app.middleware("http")
+    async def exception_handling(request: Request, call_next):
+        try:
+            return await call_next(request)
+        except Exception as e:
+            return handle_exception(request, e)
+
+    @app.exception_handler(Exception)
+    async def fastapi_exception_handler(request: Request, e: Exception):
+        return handle_exception(request, e)
+
+    @app.exception_handler(HTTPException)
+    async def http_exception_handler(request: Request, e: HTTPException):
+        return handle_exception(request, e)
+
+    cors_options = {
+        "allow_methods": ["*"],
+        "allow_headers": ["*"],
+        "allow_origins": ["*"],
+        "allow_credentials": True,
+        "expose_headers": ["X-Seed"],
+    }
+    app.add_middleware(CORSMiddleware, **cors_options)
+
+
+global_sio: AsyncServer = None
+
+
+def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict = {}):
+    # self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
+    # logger.info(f"diffusion callback: step={step}, timestep={timestep}")
+
+    # We use asyncio loops for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
+    # but for now let's just start a separate event loop. It shouldn't make a difference for single person use
+    asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
+    return {}
+
+
+class Api:
+    def __init__(self, app: FastAPI, config: ApiConfig):
+        self.app = app
+        self.config = config
+        self.router = APIRouter()
+        self.queue_lock = threading.Lock()
+        api_middleware(self.app)
+
+        self.file_manager = self._build_file_manager()
+        self.plugins = self._build_plugins()
+        self.model_manager = self._build_model_manager()
+
+        # fmt: off
+        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
+        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
+                           response_model=ServerConfigResponse)
+        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
+        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
+        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
+        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
+        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
+        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
+        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
+        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
+        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
+        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
+        self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
+        # fmt: on
+
+        global global_sio
+        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
+        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
+        self.app.mount("/ws", self.combined_asgi_app)
+        global_sio = self.sio
+
+    def add_api_route(self, path: str, endpoint, **kwargs):
+        return self.app.add_api_route(path, endpoint, **kwargs)
+
+    def api_save_image(self, file: UploadFile):
+        # Sanitize filename to prevent path traversal
+        safe_filename = Path(file.filename).name  # Get just the filename component
+
+        # Ensure output directory is configured and exists
+        if not self.config.output_dir or not self.config.output_dir.exists():
+            raise HTTPException(
+                status_code=400,
+                detail="Output directory not configured or doesn't exist",
+            )
+
+        # Construct the full path within output_dir
+        output_path = self.config.output_dir / safe_filename
+
+        # Read and write the file
+        origin_image_bytes = file.file.read()
+        with open(output_path, "wb") as fw:
+            fw.write(origin_image_bytes)
+
+    def api_current_model(self) -> ModelInfo:
+        return self.model_manager.current_model
+
+    def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
+        if req.name == self.model_manager.name:
+            return self.model_manager.current_model
+        self.model_manager.switch(req.name)
+        return self.model_manager.current_model
+
+    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
+        if req.plugin_name in self.plugins:
+            self.plugins[req.plugin_name].switch_model(req.model_name)
+            if req.plugin_name == RemoveBG.name:
+                self.config.remove_bg_model = req.model_name
+            if req.plugin_name == RealESRGANUpscaler.name:
+                self.config.realesrgan_model = req.model_name
+            if req.plugin_name == InteractiveSeg.name:
+                self.config.interactive_seg_model = req.model_name
+            torch_gc()
+
+    def api_server_config(self) -> ServerConfigResponse:
+        plugins = []
+        for it in self.plugins.values():
+            plugins.append(
+                PluginInfo(
+                    name=it.name,
+                    support_gen_image=it.support_gen_image,
+                    support_gen_mask=it.support_gen_mask,
+                )
+            )
+
+        return ServerConfigResponse(
+            plugins=plugins,
+            modelInfos=self.model_manager.scan_models(),
+            removeBGModel=self.config.remove_bg_model,
+            removeBGModels=RemoveBGModel.values(),
+            realesrganModel=self.config.realesrgan_model,
+            realesrganModels=RealESRGANModel.values(),
+            interactiveSegModel=self.config.interactive_seg_model,
+            interactiveSegModels=InteractiveSegModel.values(),
+            enableFileManager=self.file_manager is not None,
+            enableAutoSaving=self.config.output_dir is not None,
+            enableControlnet=self.model_manager.enable_controlnet,
+            controlnetMethod=self.model_manager.controlnet_method,
+            disableModelSwitch=False,
+            isDesktop=False,
+            samplers=self.api_samplers(),
+        )
+
+    def api_input_image(self) -> FileResponse:
+        if self.config.input is None:
+            raise HTTPException(status_code=200, detail="No input image configured")
+
+        if self.config.input.is_file():
+            return FileResponse(self.config.input)
+        raise HTTPException(status_code=404, detail="Input image not found")
+
+    def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
+        _, _, info = load_img(file.file.read(), return_info=True)
+        parts = info.get("parameters", "").split("Negative prompt: ")
+        prompt = parts[0].strip()
+        negative_prompt = ""
+        if len(parts) > 1:
+            negative_prompt = parts[1].split("\n")[0].strip()
+        return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
+
+    def api_inpaint(self, req: InpaintRequest):
+        image, alpha_channel, infos, ext = decode_base64_to_image(req.image)
+        mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
+        logger.info(f"image ext: {ext}")
+
+        mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
+        if image.shape[:2] != mask.shape[:2]:
+            raise HTTPException(
+                400,
+                detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
+            )
+
+        start = time.time()
+        rgb_np_img = self.model_manager(image, mask, req)
+        logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
+        torch_gc()
+
+        rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
+        rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
+
+        res_img_bytes = pil_to_bytes(
+            Image.fromarray(rgb_res),
+            ext=ext,
+            quality=self.config.quality,
+            infos=infos,
+        )
+
+        asyncio.run(self.sio.emit("diffusion_finish"))
+
+        return Response(
+            content=res_img_bytes,
+            media_type=f"image/{ext}",
+            headers={"X-Seed": str(req.sd_seed)},
+        )
+
+    def api_run_plugin_gen_image(self, req: RunPluginRequest):
+        ext = "png"
+        if req.name not in self.plugins:
+            raise HTTPException(status_code=422, detail="Plugin not found")
+        if not self.plugins[req.name].support_gen_image:
+            raise HTTPException(
+                status_code=422, detail="Plugin does not support output image"
+            )
+        rgb_np_img, alpha_channel, infos, _ = decode_base64_to_image(req.image)
+        bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
+        torch_gc()
+
+        if bgr_or_rgba_np_img.shape[2] == 4:
+            rgba_np_img = bgr_or_rgba_np_img
+        else:
+            rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
+            rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
+
+        return Response(
+            content=pil_to_bytes(
+                Image.fromarray(rgba_np_img),
+                ext=ext,
+                quality=self.config.quality,
+                infos=infos,
+            ),
+            media_type=f"image/{ext}",
+        )
+
+    def api_run_plugin_gen_mask(self, req: RunPluginRequest):
+        if req.name not in self.plugins:
+            raise HTTPException(status_code=422, detail="Plugin not found")
+        if not self.plugins[req.name].support_gen_mask:
+            raise HTTPException(
+                status_code=422, detail="Plugin does not support output image"
+            )
+        rgb_np_img, _, _, _ = decode_base64_to_image(req.image)
+        bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
+        torch_gc()
+        res_mask = gen_frontend_mask(bgr_or_gray_mask)
+        return Response(
+            content=numpy_to_bytes(res_mask, "png"),
+            media_type="image/png",
+        )
+
+    def api_samplers(self) -> List[str]:
+        return [member.value for member in SDSampler.__members__.values()]
+
+    def api_adjust_mask(self, req: AdjustMaskRequest):
+        mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
+        mask = adjust_mask(mask, req.kernel_size, req.operate)
+        return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
+
+    def launch(self):
+        self.app.include_router(self.router)
+        uvicorn.run(
+            self.combined_asgi_app,
+            host=self.config.host,
+            port=self.config.port,
+            timeout_keep_alive=999999999,
+        )
+
+    def _build_file_manager(self) -> Optional[FileManager]:
+        if self.config.input and self.config.input.is_dir():
+            logger.info(
+                f"Input is directory, initialize file manager {self.config.input}"
+            )
+
+            return FileManager(
+                app=self.app,
+                input_dir=self.config.input,
+                mask_dir=self.config.mask_dir,
+                output_dir=self.config.output_dir,
+            )
+        return None
+
+    def _build_plugins(self) -> Dict[str, BasePlugin]:
+        return build_plugins(
+            self.config.enable_interactive_seg,
+            self.config.interactive_seg_model,
+            self.config.interactive_seg_device,
+            self.config.enable_remove_bg,
+            self.config.remove_bg_device,
+            self.config.remove_bg_model,
+            self.config.enable_anime_seg,
+            self.config.enable_realesrgan,
+            self.config.realesrgan_device,
+            self.config.realesrgan_model,
+            self.config.enable_gfpgan,
+            self.config.gfpgan_device,
+            self.config.enable_restoreformer,
+            self.config.restoreformer_device,
+            self.config.no_half,
+        )
+
+    def _build_model_manager(self):
+        return ModelManager(
+            name=self.config.model,
+            device=torch.device(self.config.device),
+            no_half=self.config.no_half,
+            low_mem=self.config.low_mem,
+            disable_nsfw=self.config.disable_nsfw_checker,
+            sd_cpu_textencoder=self.config.cpu_textencoder,
+            local_files_only=self.config.local_files_only,
+            cpu_offload=self.config.cpu_offload,
+            callback=diffuser_callback,
+        )
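For reference, a minimal client sketch for the /api/v1/inpaint route above. Assumptions: a server is already running on the default 127.0.0.1:8080, the `requests` package is installed on the client side (it is not part of this repo), and decode_base64_to_image accepts plain base64 payloads; all other InpaintRequest fields fall back to their schema defaults.

import base64

import requests  # assumed client-side dependency


def to_b64(path: str) -> str:
    # base64-encode an image file for the JSON payload
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode()


resp = requests.post(
    "http://127.0.0.1:8080/api/v1/inpaint",
    json={"image": to_b64("input.png"), "mask": to_b64("mask.png")},
)
resp.raise_for_status()
# the route returns the inpainted image bytes directly
with open("result.png", "wb") as fw:
    fw.write(resp.content)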
sorawm/iopaint/batch_processing.py
ADDED
@@ -0,0 +1,128 @@
+import json
+from pathlib import Path
+from typing import Dict, Optional
+
+import cv2
+import numpy as np
+from loguru import logger
+from PIL import Image
+from rich.console import Console
+from rich.progress import (
+    BarColumn,
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TaskProgressColumn,
+    TextColumn,
+    TimeElapsedColumn,
+)
+
+from sorawm.iopaint.helper import pil_to_bytes
+from sorawm.iopaint.model.utils import torch_gc
+from sorawm.iopaint.model_manager import ModelManager
+from sorawm.iopaint.schema import InpaintRequest
+
+
+def glob_images(path: Path) -> Dict[str, Path]:
+    # png/jpg/jpeg
+    if path.is_file():
+        return {path.stem: path}
+    elif path.is_dir():
+        res = {}
+        for it in path.glob("*.*"):
+            if it.suffix.lower() in [".png", ".jpg", ".jpeg"]:
+                res[it.stem] = it
+        return res
+
+
+def batch_inpaint(
+    model: str,
+    device,
+    image: Path,
+    mask: Path,
+    output: Path,
+    config: Optional[Path] = None,
+    concat: bool = False,
+):
+    if image.is_dir() and output.is_file():
+        logger.error(
+            "invalid --output: when image is a directory, output should be a directory"
+        )
+        exit(-1)
+    output.mkdir(parents=True, exist_ok=True)
+
+    image_paths = glob_images(image)
+    mask_paths = glob_images(mask)
+    if len(image_paths) == 0:
+        logger.error("invalid --image: empty image folder")
+        exit(-1)
+    if len(mask_paths) == 0:
+        logger.error("invalid --mask: empty mask folder")
+        exit(-1)
+
+    if config is None:
+        inpaint_request = InpaintRequest()
+        logger.info(f"Using default config: {inpaint_request}")
+    else:
+        with open(config, "r", encoding="utf-8") as f:
+            inpaint_request = InpaintRequest(**json.load(f))
+        logger.info(f"Using config: {inpaint_request}")
+
+    model_manager = ModelManager(name=model, device=device)
+    first_mask = list(mask_paths.values())[0]
+
+    console = Console()
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        BarColumn(),
+        TaskProgressColumn(),
+        MofNCompleteColumn(),
+        TimeElapsedColumn(),
+        console=console,
+        transient=False,
+    ) as progress:
+        task = progress.add_task("Batch processing...", total=len(image_paths))
+        for stem, image_p in image_paths.items():
+            if stem not in mask_paths and mask.is_dir():
+                progress.log(f"mask for {image_p} not found")
+                progress.update(task, advance=1)
+                continue
+            mask_p = mask_paths.get(stem, first_mask)
+
+            infos = Image.open(image_p).info
+
+            img = np.array(Image.open(image_p).convert("RGB"))
+            mask_img = np.array(Image.open(mask_p).convert("L"))
+
+            if mask_img.shape[:2] != img.shape[:2]:
+                progress.log(
+                    f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
+                )
+                mask_img = cv2.resize(
+                    mask_img,
+                    (img.shape[1], img.shape[0]),
+                    interpolation=cv2.INTER_NEAREST,
+                )
+            mask_img[mask_img >= 127] = 255
+            mask_img[mask_img < 127] = 0
+
+            # bgr
+            inpaint_result = model_manager(img, mask_img, inpaint_request)
+            inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
+            if concat:
+                mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
+                inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])
+
+            img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
+            save_p = output / f"{stem}.png"
+            with open(save_p, "wb") as fw:
+                fw.write(img_bytes)
+
+            progress.update(task, advance=1)
+            torch_gc()
+            # pid = psutil.Process().pid
+            # memory_info = psutil.Process(pid).memory_info()
+            # memory_in_mb = memory_info.rss / (1024 * 1024)
+            # print(f"original image size: {img.shape}, current process memory usage: {memory_in_mb}MB")
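batch_inpaint can also be driven directly from Python rather than through the CLI `run` command. A sketch, assuming the "lama" model is already downloaded and ./images and ./masks hold same-named image/mask pairs (all paths illustrative):

from pathlib import Path

import torch

from sorawm.iopaint.batch_processing import batch_inpaint

batch_inpaint(
    model="lama",
    device=torch.device("cpu"),
    image=Path("./images"),
    mask=Path("./masks"),
    output=Path("./results"),
    concat=False,  # True saves image | mask | result side by side
)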
sorawm/iopaint/benchmark.py
ADDED
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+
+import argparse
+import os
+import time
+
+import numpy as np
+import nvidia_smi
+import psutil
+import torch
+
+from sorawm.iopaint.model_manager import ModelManager
+from sorawm.iopaint.schema import HDStrategy, InpaintRequest, SDSampler
+
+try:
+    torch._C._jit_override_can_fuse_on_cpu(False)
+    torch._C._jit_override_can_fuse_on_gpu(False)
+    torch._C._jit_set_texpr_fuser_enabled(False)
+    torch._C._jit_set_nvfuser_enabled(False)
+except:
+    pass
+
+NUM_THREADS = str(4)
+
+os.environ["OMP_NUM_THREADS"] = NUM_THREADS
+os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
+os.environ["MKL_NUM_THREADS"] = NUM_THREADS
+os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
+os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
+if os.environ.get("CACHE_DIR"):
+    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
+
+
+def run_model(model, size):
+    # RGB
+    image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
+    mask = np.random.randint(0, 255, size).astype(np.uint8)
+
+    config = InpaintRequest(
+        ldm_steps=2,
+        hd_strategy=HDStrategy.ORIGINAL,
+        hd_strategy_crop_margin=128,
+        hd_strategy_crop_trigger_size=128,
+        hd_strategy_resize_limit=128,
+        prompt="a fox is sitting on a bench",
+        sd_steps=5,
+        sd_sampler=SDSampler.ddim,
+    )
+    model(image, mask, config)
+
+
+def benchmark(model, times: int, empty_cache: bool):
+    sizes = [(512, 512)]
+
+    nvidia_smi.nvmlInit()
+    device_id = 0
+    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
+
+    def format(metrics):
+        return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
+
+    process = psutil.Process(os.getpid())
+    # report RAM and GPU memory usage metrics for each size
+    for size in sizes:
+        torch.cuda.empty_cache()
+        time_metrics = []
+        cpu_metrics = []
+        memory_metrics = []
+        gpu_memory_metrics = []
+        for _ in range(times):
+            start = time.time()
+            run_model(model, size)
+            torch.cuda.synchronize()
+
+            # cpu_metrics.append(process.cpu_percent())
+            time_metrics.append((time.time() - start) * 1000)
+            memory_metrics.append(process.memory_info().rss / 1024 / 1024)
+            gpu_memory_metrics.append(
+                nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024
+            )
+
+        print(f"size: {size}".center(80, "-"))
+        # print(f"cpu: {format(cpu_metrics)}")
+        print(f"latency: {format(time_metrics)}ms")
+        print(f"memory: {format(memory_metrics)} MB")
+        print(f"gpu memory: {format(gpu_memory_metrics)} MB")
+
+    nvidia_smi.nvmlShutdown()
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--name")
+    parser.add_argument("--device", default="cuda", type=str)
+    parser.add_argument("--times", default=10, type=int)
+    parser.add_argument("--empty-cache", action="store_true")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args_parser()
+    device = torch.device(args.device)
+    model = ModelManager(
+        name=args.name,
+        device=device,
+        disable_nsfw=True,
+        sd_cpu_textencoder=True,
+    )
+    benchmark(model, args.times, args.empty_cache)
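The benchmark is intended to run as a standalone script on a CUDA machine (it calls torch.cuda.synchronize() and NVML unconditionally) and additionally needs the psutil and nvidia_smi packages. A usage sketch: `python -m sorawm.iopaint.benchmark --name lama --device cuda --times 10`.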
sorawm/iopaint/cli.py
ADDED
@@ -0,0 +1,245 @@
+import webbrowser
+from contextlib import asynccontextmanager
+from pathlib import Path
+from typing import Optional
+
+import typer
+from fastapi import FastAPI
+from loguru import logger
+from typer import Option
+from typer_config import use_json_config
+
+from sorawm.iopaint.const import *
+from sorawm.iopaint.runtime import check_device, dump_environment_info, setup_model_dir
+from sorawm.iopaint.schema import (
+    Device,
+    InteractiveSegModel,
+    RealESRGANModel,
+    RemoveBGModel,
+)
+
+typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
+
+
+@typer_app.command(help="Install all plugins dependencies")
+def install_plugins_packages():
+    from sorawm.iopaint.installer import install_plugins_package
+
+    install_plugins_package()
+
+
+@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
+def download(
+    model: str = Option(
+        ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
+    ),
+    model_dir: Path = Option(
+        DEFAULT_MODEL_DIR,
+        help=MODEL_DIR_HELP,
+        file_okay=False,
+        callback=setup_model_dir,
+    ),
+):
+    from sorawm.iopaint.download import cli_download_model
+
+    cli_download_model(model)
+
+
+@typer_app.command(name="list", help="List downloaded models")
+def list_model(
+    model_dir: Path = Option(
+        DEFAULT_MODEL_DIR,
+        help=MODEL_DIR_HELP,
+        file_okay=False,
+        callback=setup_model_dir,
+    ),
+):
+    from sorawm.iopaint.download import scan_models
+
+    scanned_models = scan_models()
+    for it in scanned_models:
+        print(it.name)
+
+
+@typer_app.command(help="Batch processing images")
+def run(
+    model: str = Option("lama"),
+    device: Device = Option(Device.cpu),
+    image: Path = Option(..., help="Image folders or file path"),
+    mask: Path = Option(
+        ...,
+        help="Mask folders or file path. "
+        "If it is a directory, the mask images in the directory should have the same name as the original image. "
+        "If it is a file, all images will use this mask. "
+        "Mask will automatically resize to the same size as the original image.",
+    ),
+    output: Path = Option(..., help="Output directory or file path"),
+    config: Path = Option(
+        None, help="Config file path. You can use dump command to create a base config."
+    ),
+    concat: bool = Option(
+        False, help="Concat original image, mask and output images into one image"
+    ),
+    model_dir: Path = Option(
+        DEFAULT_MODEL_DIR,
+        help=MODEL_DIR_HELP,
+        file_okay=False,
+        callback=setup_model_dir,
+    ),
+):
+    from sorawm.iopaint.download import cli_download_model, scan_models
+
+    scanned_models = scan_models()
+    if model not in [it.name for it in scanned_models]:
+        logger.info(f"{model} not found in {model_dir}, trying to download it")
+        cli_download_model(model)
+
+    from sorawm.iopaint.batch_processing import batch_inpaint
+
+    batch_inpaint(model, device, image, mask, output, config, concat)
+
+
+@typer_app.command(help="Start IOPaint server")
+@use_json_config()
+def start(
+    host: str = Option("127.0.0.1"),
+    port: int = Option(8080),
+    inbrowser: bool = Option(False, help=INBROWSER_HELP),
+    model: str = Option(
+        DEFAULT_MODEL,
+        help=f"Erase models: [{', '.join(AVAILABLE_MODELS)}].\n"
+        f"Diffusion models: [{', '.join(DIFFUSION_MODELS)}] or any SD/SDXL normal/inpainting models on HuggingFace.",
+    ),
+    model_dir: Path = Option(
+        DEFAULT_MODEL_DIR,
+        help=MODEL_DIR_HELP,
+        dir_okay=True,
+        file_okay=False,
+        callback=setup_model_dir,
+    ),
+    low_mem: bool = Option(False, help=LOW_MEM_HELP),
+    no_half: bool = Option(False, help=NO_HALF_HELP),
+    cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
+    disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
+    cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
+    local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
+    device: Device = Option(Device.cpu),
+    input: Optional[Path] = Option(None, help=INPUT_HELP),
+    mask_dir: Optional[Path] = Option(
+        None, help=MASK_DIR_HELP, dir_okay=True, file_okay=False
+    ),
+    output_dir: Optional[Path] = Option(
+        None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
+    ),
+    quality: int = Option(100, help=QUALITY_HELP),
+    enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
+    interactive_seg_model: InteractiveSegModel = Option(
+        InteractiveSegModel.sam2_1_tiny, help=INTERACTIVE_SEG_MODEL_HELP
+    ),
+    interactive_seg_device: Device = Option(Device.cpu),
+    enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
+    remove_bg_device: Device = Option(Device.cpu, help=REMOVE_BG_DEVICE_HELP),
+    remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
+    enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
+    enable_realesrgan: bool = Option(False),
+    realesrgan_device: Device = Option(Device.cpu),
+    realesrgan_model: RealESRGANModel = Option(RealESRGANModel.realesr_general_x4v3),
+    enable_gfpgan: bool = Option(False),
+    gfpgan_device: Device = Option(Device.cpu),
+    enable_restoreformer: bool = Option(False),
+    restoreformer_device: Device = Option(Device.cpu),
+):
+    dump_environment_info()
+    device = check_device(device)
+    remove_bg_device = check_device(remove_bg_device)
+    realesrgan_device = check_device(realesrgan_device)
+    gfpgan_device = check_device(gfpgan_device)
+
+    if input and not input.exists():
+        logger.error(f"invalid --input: {input} does not exist")
+        exit(-1)
+    if mask_dir and not mask_dir.exists():
+        logger.error(f"invalid --mask-dir: {mask_dir} does not exist")
+        exit(-1)
+    if input and input.is_dir() and not output_dir:
+        logger.error(
+            "invalid --output-dir: --output-dir must be set when --input is a directory"
+        )
+        exit(-1)
+    if output_dir:
+        output_dir = output_dir.expanduser().absolute()
+        logger.info(f"Image will be saved to {output_dir}")
+        if not output_dir.exists():
+            logger.info(f"Create output directory {output_dir}")
+            output_dir.mkdir(parents=True)
+    if mask_dir:
+        mask_dir = mask_dir.expanduser().absolute()
+
+    model_dir = model_dir.expanduser().absolute()
+
+    if local_files_only:
+        os.environ["TRANSFORMERS_OFFLINE"] = "1"
+        os.environ["HF_HUB_OFFLINE"] = "1"
+
+    from sorawm.iopaint.download import cli_download_model, scan_models
+
+    scanned_models = scan_models()
+    if model not in [it.name for it in scanned_models]:
+        logger.info(f"{model} not found in {model_dir}, trying to download it")
+        cli_download_model(model)
+
+    from sorawm.iopaint.api import Api
+    from sorawm.iopaint.schema import ApiConfig
+
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        if inbrowser:
+            webbrowser.open(f"http://localhost:{port}", new=0, autoraise=True)
+        yield
+
+    app = FastAPI(lifespan=lifespan)
+
+    api_config = ApiConfig(
+        host=host,
+        port=port,
+        inbrowser=inbrowser,
+        model=model,
+        no_half=no_half,
+        low_mem=low_mem,
+        cpu_offload=cpu_offload,
+        disable_nsfw_checker=disable_nsfw_checker,
+        local_files_only=local_files_only,
+        cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
+        device=device,
+        input=input,
+        mask_dir=mask_dir,
+        output_dir=output_dir,
+        quality=quality,
+        enable_interactive_seg=enable_interactive_seg,
+        interactive_seg_model=interactive_seg_model,
+        interactive_seg_device=interactive_seg_device,
+        enable_remove_bg=enable_remove_bg,
+        remove_bg_device=remove_bg_device,
+        remove_bg_model=remove_bg_model,
+        enable_anime_seg=enable_anime_seg,
+        enable_realesrgan=enable_realesrgan,
+        realesrgan_device=realesrgan_device,
+        realesrgan_model=realesrgan_model,
+        enable_gfpgan=enable_gfpgan,
+        gfpgan_device=gfpgan_device,
+        enable_restoreformer=enable_restoreformer,
+        restoreformer_device=restoreformer_device,
+    )
+    print(api_config.model_dump_json(indent=4))
+    api = Api(app, api_config)
+    api.launch()
+
+
+@typer_app.command(help="Start IOPaint web config page")
+def start_web_config(
+    config_file: Path = Option("config.json"),
+):
+    dump_environment_info()
+    from sorawm.iopaint.web_config import main
+
+    main(config_file)
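Because start() is decorated with @use_json_config(), typer-config adds a --config option that pre-fills the command's options from a JSON file. A sketch of producing such a file (field names mirror the start() parameters above; the --config flag name is the typer-config default, and values are illustrative):

import json

cfg = {
    "model": "lama",
    "device": "cpu",
    "port": 8080,
    "inbrowser": True,
}
with open("config.json", "w") as f:
    json.dump(cfg, f, indent=4)

# then, e.g.: python -m sorawm.iopaint start --config config.json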
sorawm/iopaint/const.py
ADDED
@@ -0,0 +1,134 @@
+import os
+from typing import List
+
+INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
+KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
+POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
+ANYTEXT_NAME = "Sanster/AnyText"
+
+DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
+DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
+DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
+DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
+
+MPS_UNSUPPORT_MODELS = [
+    "lama",
+    "ldm",
+    "zits",
+    "mat",
+    "fcf",
+    "cv2",
+    "manga",
+]
+
+DEFAULT_MODEL = "lama"
+AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"]
+DIFFUSION_MODELS = [
+    "runwayml/stable-diffusion-inpainting",
+    "Uminosachi/realisticVisionV51_v51VAE-inpainting",
+    "redstonehero/dreamshaper-inpainting",
+    "Sanster/anything-4.0-inpainting",
+    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+    "Fantasy-Studio/Paint-by-Example",
+    "RunDiffusion/Juggernaut-XI-v11",
+    "SG161222/RealVisXL_V5.0",
+    "eienmojiki/Anything-XL",
+    POWERPAINT_NAME,
+    ANYTEXT_NAME,
+]
+
+NO_HALF_HELP = """
+Use full precision (fp32) model.
+If your diffusion model's results are always black or green, use this argument.
+"""
+
+CPU_OFFLOAD_HELP = """
+Offload the diffusion model's weights to CPU RAM, significantly reducing vRAM usage.
+"""
+
+LOW_MEM_HELP = "Enable attention slicing and vae tiling to save memory."
+
+DISABLE_NSFW_HELP = """
+Disable NSFW checker for diffusion model.
+"""
+
+CPU_TEXTENCODER_HELP = """
+Run the diffusion model's text encoder on CPU to reduce vRAM usage.
+"""
+
+SD_CONTROLNET_CHOICES: List[str] = [
+    "lllyasviel/control_v11p_sd15_canny",
+    # "lllyasviel/control_v11p_sd15_seg",
+    "lllyasviel/control_v11p_sd15_openpose",
+    "lllyasviel/control_v11p_sd15_inpaint",
+    "lllyasviel/control_v11f1p_sd15_depth",
+]
+
+SD_BRUSHNET_CHOICES: List[str] = [
+    "Sanster/brushnet_random_mask",
+    "Sanster/brushnet_segmentation_mask",
+]
+
+SD2_CONTROLNET_CHOICES = [
+    "thibaud/controlnet-sd21-canny-diffusers",
+    "thibaud/controlnet-sd21-depth-diffusers",
+    "thibaud/controlnet-sd21-openpose-diffusers",
+]
+
+SDXL_CONTROLNET_CHOICES = [
+    "thibaud/controlnet-openpose-sdxl-1.0",
+    "destitech/controlnet-inpaint-dreamer-sdxl",
+    "diffusers/controlnet-canny-sdxl-1.0",
+    "diffusers/controlnet-canny-sdxl-1.0-mid",
+    "diffusers/controlnet-canny-sdxl-1.0-small",
+    "diffusers/controlnet-depth-sdxl-1.0",
+    "diffusers/controlnet-depth-sdxl-1.0-mid",
+    "diffusers/controlnet-depth-sdxl-1.0-small",
+]
+
+SDXL_BRUSHNET_CHOICES = ["Regulus0725/random_mask_brushnet_ckpt_sdxl_regulus_v1"]
+
+LOCAL_FILES_ONLY_HELP = """
+When loading diffusion models, use local files only and do not connect to the HuggingFace server.
+"""
+
+DEFAULT_MODEL_DIR = os.path.abspath(
+    os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
+)
+
+MODEL_DIR_HELP = f"""
+Model download directory (set via the XDG_CACHE_HOME environment variable); by default models download to {DEFAULT_MODEL_DIR}
+"""
+
+OUTPUT_DIR_HELP = """
+Result images will be saved to the output directory automatically.
+"""
+
+MASK_DIR_HELP = """
+You can view masks in FileManager
+"""
+
+INPUT_HELP = """
+If input is an image, it will be loaded by default.
+If input is a directory, you can browse and select an image in the file manager.
+"""
+
+GUI_HELP = """
+Launch Lama Cleaner as a desktop app
+"""
+
+QUALITY_HELP = """
+Quality of image encoding, 0-100. Higher quality generates larger file sizes.
+"""
+
+INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
+INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. A bigger model gives better segmentation but runs slower."
+REMOVE_BG_HELP = "Enable remove background plugin."
+REMOVE_BG_DEVICE_HELP = "Device for remove background plugin. 'cuda' only supports briaai models (briaai/RMBG-1.4 and briaai/RMBG-2.0)"
+ANIMESEG_HELP = "Enable anime segmentation plugin. Always runs on CPU"
+REALESRGAN_HELP = "Enable RealESRGAN super resolution"
+GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan"
+RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To also enhance background, use with --enable-realesrgan"
+GIF_HELP = "Enable GIF plugin. Makes a GIF comparing the original and cleaned image"
+
+INBROWSER_HELP = "Automatically launch IOPaint in a new tab on the default browser"
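DEFAULT_MODEL_DIR is resolved once at import time from XDG_CACHE_HOME, so redirecting model downloads means setting the variable before this module is imported. A small sketch (the path is illustrative):

import os

os.environ["XDG_CACHE_HOME"] = "/data/iopaint-models"  # must happen before the import below

from sorawm.iopaint.const import DEFAULT_MODEL_DIR

print(DEFAULT_MODEL_DIR)  # /data/iopaint-models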
sorawm/iopaint/download.py
ADDED
@@ -0,0 +1,314 @@
+import glob
+import json
+import os
+from functools import lru_cache
+from pathlib import Path
+from typing import List, Optional
+
+from loguru import logger
+
+from sorawm.iopaint.const import (
+    ANYTEXT_NAME,
+    DEFAULT_MODEL_DIR,
+    DIFFUSERS_SD_CLASS_NAME,
+    DIFFUSERS_SD_INPAINT_CLASS_NAME,
+    DIFFUSERS_SDXL_CLASS_NAME,
+    DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
+)
+from sorawm.iopaint.model.original_sd_configs import get_config_files
+from sorawm.iopaint.schema import ModelInfo, ModelType
+
+
+def cli_download_model(model: str):
+    from sorawm.iopaint.model import models
+    from sorawm.iopaint.model.utils import handle_from_pretrained_exceptions
+
+    if model in models and models[model].is_erase_model:
+        logger.info(f"Downloading {model}...")
+        models[model].download()
+        logger.info("Done.")
+    elif model == ANYTEXT_NAME:
+        logger.info(f"Downloading {model}...")
+        models[model].download()
+        logger.info("Done.")
+    else:
+        logger.info(f"Downloading model from Huggingface: {model}")
+        from diffusers import DiffusionPipeline
+
+        downloaded_path = handle_from_pretrained_exceptions(
+            DiffusionPipeline.download, pretrained_model_name=model, variant="fp16"
+        )
+        logger.info(f"Done. Downloaded to {downloaded_path}")
+
+
+def folder_name_to_show_name(name: str) -> str:
+    return name.replace("models--", "").replace("--", "/")
+
+
+@lru_cache(maxsize=512)
+def get_sd_model_type(model_abs_path: str) -> Optional[ModelType]:
+    if "inpaint" in Path(model_abs_path).name.lower():
+        model_type = ModelType.DIFFUSERS_SD_INPAINT
+    else:
+        # load once to check num_in_channels
+        from diffusers import StableDiffusionInpaintPipeline
+
+        try:
+            StableDiffusionInpaintPipeline.from_single_file(
+                model_abs_path,
+                load_safety_checker=False,
+                num_in_channels=9,
+                original_config_file=get_config_files()["v1"],
+            )
+            model_type = ModelType.DIFFUSERS_SD_INPAINT
+        except ValueError as e:
+            if "[320, 4, 3, 3]" in str(e):
+                model_type = ModelType.DIFFUSERS_SD
+            else:
+                logger.info(f"Ignore non sd file: {model_abs_path}")
+                return
+        except Exception as e:
+            logger.error(f"Failed to load {model_abs_path}: {e}")
+            return
+    return model_type
+
+
+@lru_cache()
+def get_sdxl_model_type(model_abs_path: str) -> Optional[ModelType]:
+    if "inpaint" in model_abs_path:
+        model_type = ModelType.DIFFUSERS_SDXL_INPAINT
+    else:
+        # load once to check num_in_channels
+        from diffusers import StableDiffusionXLInpaintPipeline
+
+        try:
+            model = StableDiffusionXLInpaintPipeline.from_single_file(
+                model_abs_path,
+                load_safety_checker=False,
+                num_in_channels=9,
+                original_config_file=get_config_files()["xl"],
+            )
+            if model.unet.config.in_channels == 9:
+                # https://github.com/huggingface/diffusers/issues/6610
+                model_type = ModelType.DIFFUSERS_SDXL_INPAINT
+            else:
+                model_type = ModelType.DIFFUSERS_SDXL
+        except ValueError as e:
+            if "[320, 4, 3, 3]" in str(e):
+                model_type = ModelType.DIFFUSERS_SDXL
+            else:
+                logger.info(f"Ignore non sdxl file: {model_abs_path}")
+                return
+        except Exception as e:
+            logger.error(f"Failed to load {model_abs_path}: {e}")
+            return
+    return model_type
+
+
+def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
+    cache_dir = Path(cache_dir)
+    stable_diffusion_dir = cache_dir / "stable_diffusion"
+    cache_file = stable_diffusion_dir / "iopaint_cache.json"
+    model_type_cache = {}
+    if cache_file.exists():
+        try:
+            with open(cache_file, "r", encoding="utf-8") as f:
+                model_type_cache = json.load(f)
+                assert isinstance(model_type_cache, dict)
+        except:
+            pass
+
+    res = []
+    for it in stable_diffusion_dir.glob("*.*"):
+        if it.suffix not in [".safetensors", ".ckpt"]:
+            continue
+        model_abs_path = str(it.absolute())
+        model_type = model_type_cache.get(it.name)
+        if model_type is None:
+            model_type = get_sd_model_type(model_abs_path)
+        if model_type is None:
+            continue
+
+        model_type_cache[it.name] = model_type
+        res.append(
+            ModelInfo(
+                name=it.name,
+                path=model_abs_path,
+                model_type=model_type,
+                is_single_file_diffusers=True,
+            )
+        )
+    if stable_diffusion_dir.exists():
+        with open(cache_file, "w", encoding="utf-8") as fw:
+            json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
+
+    stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
+    sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json"
+    sdxl_model_type_cache = {}
+    if sdxl_cache_file.exists():
+        try:
+            with open(sdxl_cache_file, "r", encoding="utf-8") as f:
+                sdxl_model_type_cache = json.load(f)
+                assert isinstance(sdxl_model_type_cache, dict)
+        except:
+            pass
+
+    for it in stable_diffusion_xl_dir.glob("*.*"):
+        if it.suffix not in [".safetensors", ".ckpt"]:
+            continue
+        model_abs_path = str(it.absolute())
+        model_type = sdxl_model_type_cache.get(it.name)
+        if model_type is None:
+            model_type = get_sdxl_model_type(model_abs_path)
+        if model_type is None:
+            continue
+
+        sdxl_model_type_cache[it.name] = model_type
+        if stable_diffusion_xl_dir.exists():
+            with open(sdxl_cache_file, "w", encoding="utf-8") as fw:
+                json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False)
+
+        res.append(
+            ModelInfo(
+                name=it.name,
+                path=model_abs_path,
+                model_type=model_type,
+                is_single_file_diffusers=True,
+            )
+        )
+    return res
+
+
+def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
+    res = []
+    from sorawm.iopaint.model import models
+
+    # logger.info(f"Scanning inpaint models in {model_dir}")
+
+    for name, m in models.items():
+        if m.is_erase_model and m.is_downloaded():
+            res.append(
+                ModelInfo(
+                    name=name,
+                    path=name,
+                    model_type=ModelType.INPAINT,
+                )
+            )
+    return res
+
+
+def scan_diffusers_models() -> List[ModelInfo]:
+    from huggingface_hub.constants import HF_HUB_CACHE
+
+    available_models = []
+    cache_dir = Path(HF_HUB_CACHE)
+    # logger.info(f"Scanning diffusers models in {cache_dir}")
+    diffusers_model_names = []
+    model_index_files = glob.glob(
+        os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True
+    )
+    for it in model_index_files:
+        it = Path(it)
+        try:
+            with open(it, "r", encoding="utf-8") as f:
+                data = json.load(f)
+        except:
+            continue
+
+        _class_name = data["_class_name"]
+        name = folder_name_to_show_name(it.parent.parent.parent.name)
+        if name in diffusers_model_names:
+            continue
+        if "PowerPaint" in name:
+            model_type = ModelType.DIFFUSERS_OTHER
+        elif _class_name == DIFFUSERS_SD_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SD
+        elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SD_INPAINT
+        elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SDXL
+        elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SDXL_INPAINT
+        elif _class_name in [
+            "StableDiffusionInstructPix2PixPipeline",
+            "PaintByExamplePipeline",
+            "KandinskyV22InpaintPipeline",
+            "AnyText",
+        ]:
+            model_type = ModelType.DIFFUSERS_OTHER
+        else:
+            continue
+
+        diffusers_model_names.append(name)
+        available_models.append(
+            ModelInfo(
+                name=name,
+                path=name,
+                model_type=model_type,
+            )
+        )
+    return available_models
+
+
+def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
+    cache_dir = Path(cache_dir)
+    available_models = []
+    diffusers_model_names = []
+    model_index_files = glob.glob(
+        os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True
+    )
+    for it in model_index_files:
+        it = Path(it)
+        with open(it, "r", encoding="utf-8") as f:
+            try:
+                data = json.load(f)
+            except:
+                logger.error(
+                    f"Failed to load {it}, please try reverting to the original model or fix model_index.json by hand."
+                )
+                continue
+
+        _class_name = data["_class_name"]
+        name = folder_name_to_show_name(it.parent.name)
+        if name in diffusers_model_names:
+            continue
+        elif _class_name == DIFFUSERS_SD_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SD
+        elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SD_INPAINT
+        elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SDXL
+        elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
| 283 |
+
else:
|
| 284 |
+
continue
|
| 285 |
+
|
| 286 |
+
diffusers_model_names.append(name)
|
| 287 |
+
available_models.append(
|
| 288 |
+
ModelInfo(
|
| 289 |
+
name=name,
|
| 290 |
+
path=str(it.parent.absolute()),
|
| 291 |
+
model_type=model_type,
|
| 292 |
+
)
|
| 293 |
+
)
|
| 294 |
+
return available_models
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
|
| 298 |
+
cache_dir = Path(cache_dir)
|
| 299 |
+
available_models = []
|
| 300 |
+
stable_diffusion_dir = cache_dir / "stable_diffusion"
|
| 301 |
+
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
|
| 302 |
+
available_models.extend(_scan_converted_diffusers_models(stable_diffusion_dir))
|
| 303 |
+
available_models.extend(_scan_converted_diffusers_models(stable_diffusion_xl_dir))
|
| 304 |
+
return available_models
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def scan_models() -> List[ModelInfo]:
|
| 308 |
+
model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
|
| 309 |
+
available_models = []
|
| 310 |
+
available_models.extend(scan_inpaint_models(model_dir))
|
| 311 |
+
available_models.extend(scan_single_file_diffusion_models(model_dir))
|
| 312 |
+
available_models.extend(scan_diffusers_models())
|
| 313 |
+
available_models.extend(scan_converted_diffusers_models(model_dir))
|
| 314 |
+
return available_models
|
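The scanners above compose into a single scan_models() entry point; a minimal usage sketch, assuming the package in this diff is installed and some models have already been downloaded:

# Sketch: enumerate every model visible to the scanners above.
from sorawm.iopaint.download import scan_models

for info in scan_models():
    # ModelInfo fields used in this file: name, path, model_type
    print(info.name, info.model_type, info.path)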
sorawm/iopaint/file_manager/__init__.py
ADDED
@@ -0,0 +1 @@
from .file_manager import FileManager
sorawm/iopaint/file_manager/file_manager.py
ADDED
@@ -0,0 +1,220 @@
import os
from io import BytesIO
from pathlib import Path
from typing import List

from fastapi import FastAPI, HTTPException
from PIL import Image, ImageOps, PngImagePlugin
from starlette.responses import FileResponse

from ..schema import MediasResponse, MediaTab

LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
from .storage_backends import FilesystemStorageBackend
from .utils import aspect_to_string, generate_filename, glob_img


class FileManager:
    def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
        self.app = app
        self.input_dir: Path = input_dir
        self.mask_dir: Path = mask_dir
        self.output_dir: Path = output_dir

        self.image_dir_filenames = []
        self.output_dir_filenames = []
        if not self.thumbnail_directory.exists():
            self.thumbnail_directory.mkdir(parents=True)

        # fmt: off
        self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
        self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
        self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
        # fmt: on

    def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
        img_dir = self._get_dir(tab)
        return self._media_names(img_dir)

    def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
        file_path = self._get_file(tab, filename)
        return FileResponse(file_path, media_type="image/png")

    # tab=${tab}?filename=${filename.name}?width=${width}&height=${height}
    def api_media_thumbnail_file(
        self, tab: MediaTab, filename: str, width: int, height: int
    ) -> FileResponse:
        img_dir = self._get_dir(tab)
        thumb_filename, (width, height) = self.get_thumbnail(
            img_dir, filename, width=width, height=height
        )
        thumbnail_filepath = self.thumbnail_directory / thumb_filename
        return FileResponse(
            thumbnail_filepath,
            headers={
                "X-Width": str(width),
                "X-Height": str(height),
            },
            media_type="image/jpeg",
        )

    def _get_dir(self, tab: MediaTab) -> Path:
        if tab == "input":
            return self.input_dir
        elif tab == "output":
            return self.output_dir
        elif tab == "mask":
            return self.mask_dir
        else:
            raise HTTPException(status_code=422, detail=f"tab not found: {tab}")

    def _get_file(self, tab: MediaTab, filename: str) -> Path:
        file_path = self._get_dir(tab) / filename
        if not file_path.exists():
            raise HTTPException(status_code=422, detail=f"file not found: {file_path}")
        return file_path

    @property
    def thumbnail_directory(self) -> Path:
        return self.output_dir / "thumbnails"

    @staticmethod
    def _media_names(directory: Path) -> List[MediasResponse]:
        if directory is None:
            return []
        names = sorted([it.name for it in glob_img(directory)])
        res = []
        for name in names:
            path = os.path.join(directory, name)
            img = Image.open(path)
            res.append(
                MediasResponse(
                    name=name,
                    height=img.height,
                    width=img.width,
                    ctime=os.path.getctime(path),
                    mtime=os.path.getmtime(path),
                )
            )
        return res

    def get_thumbnail(
        self, directory: Path, original_filename: str, width, height, **options
    ):
        directory = Path(directory)
        storage = FilesystemStorageBackend(self.app)
        crop = options.get("crop", "fit")
        background = options.get("background")
        quality = options.get("quality", 90)

        original_path, original_filename = os.path.split(original_filename)
        original_filepath = os.path.join(directory, original_path, original_filename)
        image = Image.open(BytesIO(storage.read(original_filepath)))

        # keep-ratio resize
        if not width and not height:
            width = 256

        if width != 0:
            height = int(image.height * width / image.width)
        else:
            width = int(image.width * height / image.height)

        thumbnail_size = (width, height)

        thumbnail_filename = generate_filename(
            directory,
            original_filename,
            aspect_to_string(thumbnail_size),
            crop,
            background,
            quality,
        )

        thumbnail_filepath = os.path.join(
            self.thumbnail_directory, original_path, thumbnail_filename
        )

        if storage.exists(thumbnail_filepath):
            return thumbnail_filepath, (width, height)

        try:
            image.load()
        except (IOError, OSError):
            self.app.logger.warning("Thumbnail could not load image: %s", original_filepath)
            return thumbnail_filepath, (width, height)

        # get original image format
        options["format"] = options.get("format", image.format)

        image = self._create_thumbnail(
            image, thumbnail_size, crop, background=background
        )

        raw_data = self.get_raw_data(image, **options)
        storage.save(thumbnail_filepath, raw_data)

        return thumbnail_filepath, (width, height)

    def get_raw_data(self, image, **options):
        data = {
            "format": self._get_format(image, **options),
            "quality": options.get("quality", 90),
        }

        _file = BytesIO()
        image.save(_file, **data)
        return _file.getvalue()

    @staticmethod
    def colormode(image, colormode="RGB"):
        if colormode == "RGB" or colormode == "RGBA":
            if image.mode == "RGBA":
                return image
            if image.mode == "LA":
                return image.convert("RGBA")
            return image.convert(colormode)

        if colormode == "GRAY":
            return image.convert("L")

        return image.convert(colormode)

    @staticmethod
    def background(original_image, color=0xFF):
        size = (max(original_image.size),) * 2
        image = Image.new("L", size, color)
        image.paste(
            original_image,
            # PIL's paste requires integer offsets, so use floor division here.
            tuple(map(lambda x: (x[0] - x[1]) // 2, zip(size, original_image.size))),
        )

        return image

    def _get_format(self, image, **options):
        if options.get("format"):
            return options.get("format")
        if image.format:
            return image.format

        return "JPEG"

    def _create_thumbnail(self, image, size, crop="fit", background=None):
        try:
            resample = Image.Resampling.LANCZOS
        except AttributeError:  # pylint: disable=raise-missing-from
            resample = Image.ANTIALIAS

        if crop == "fit":
            image = ImageOps.fit(image, size, resample)
        else:
            image = image.copy()
            image.thumbnail(size, resample=resample)

        if background is not None:
            image = self.background(image)

        image = self.colormode(image)

        return image
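A sketch of how FileManager is wired into an application, inferred from its constructor and routes above; the three directories are hypothetical placeholders:

# Sketch: mount the media-browsing routes on a FastAPI app.
# ./inputs, ./masks and ./outputs are placeholder paths.
from pathlib import Path

from fastapi import FastAPI

from sorawm.iopaint.file_manager import FileManager

app = FastAPI()
FileManager(
    app,
    input_dir=Path("./inputs"),
    mask_dir=Path("./masks"),
    output_dir=Path("./outputs"),
)
# GET /api/v1/medias?tab=input now lists the images under ./inputs.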
sorawm/iopaint/file_manager/storage_backends.py
ADDED
@@ -0,0 +1,46 @@
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
import errno
import os
from abc import ABC, abstractmethod


class BaseStorageBackend(ABC):
    def __init__(self, app=None):
        self.app = app

    @abstractmethod
    def read(self, filepath, mode="rb", **kwargs):
        raise NotImplementedError

    @abstractmethod
    def exists(self, filepath):
        raise NotImplementedError

    @abstractmethod
    def save(self, filepath, data):
        raise NotImplementedError


class FilesystemStorageBackend(BaseStorageBackend):
    def read(self, filepath, mode="rb", **kwargs):
        with open(filepath, mode) as f:  # pylint: disable=unspecified-encoding
            return f.read()

    def exists(self, filepath):
        return os.path.exists(filepath)

    def save(self, filepath, data):
        directory = os.path.dirname(filepath)

        if not os.path.exists(directory):
            try:
                os.makedirs(directory)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        if not os.path.isdir(directory):
            raise IOError("{} is not a directory".format(directory))

        with open(filepath, "wb") as f:
            f.write(data)
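Because BaseStorageBackend is a three-method ABC, alternative backends can be swapped in wherever FilesystemStorageBackend is used; a minimal in-memory variant (illustrative only, not part of this repo) could look like:

# Illustrative: an in-memory backend implementing the same interface.
from sorawm.iopaint.file_manager.storage_backends import BaseStorageBackend


class MemoryStorageBackend(BaseStorageBackend):
    def __init__(self, app=None):
        super().__init__(app)
        self._blobs = {}  # filepath -> bytes

    def read(self, filepath, mode="rb", **kwargs):
        return self._blobs[filepath]

    def exists(self, filepath):
        return filepath in self._blobs

    def save(self, filepath, data):
        self._blobs[filepath] = data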
sorawm/iopaint/file_manager/utils.py
ADDED
@@ -0,0 +1,64 @@
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
import hashlib
from pathlib import Path
from typing import Union


def generate_filename(directory: Path, original_filename, *options) -> str:
    text = str(directory.absolute()) + original_filename
    for v in options:
        text += "%s" % v
    md5_hash = hashlib.md5()
    md5_hash.update(text.encode("utf-8"))
    return md5_hash.hexdigest() + ".jpg"


def parse_size(size):
    if isinstance(size, int):
        # If the size parameter is a single number, assume square aspect.
        return [size, size]

    if isinstance(size, (tuple, list)):
        if len(size) == 1:
            # If a single-value tuple/list is provided, expand it to two elements.
            return size + type(size)(size)
        return size

    try:
        thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
    except ValueError:
        raise ValueError(  # pylint: disable=raise-missing-from
            "Bad thumbnail size format. Valid format is INTxINT."
        )

    if len(thumbnail_size) == 1:
        # If the size parameter only contains a single integer, assume square aspect.
        thumbnail_size.append(thumbnail_size[0])

    return thumbnail_size


def aspect_to_string(size):
    if isinstance(size, str):
        return size

    return "x".join(map(str, size))


IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}


def glob_img(p: Union[Path, str], recursive: bool = False):
    p = Path(p)
    if p.is_file() and p.suffix in IMG_SUFFIX:
        yield p
    else:
        if recursive:
            files = Path(p).glob("**/*.*")
        else:
            files = Path(p).glob("*.*")

        for it in files:
            if it.suffix not in IMG_SUFFIX:
                continue
            yield it
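These helpers are what FileManager.get_thumbnail uses to build its cache key; a small example of how they fit together (the output depends on the machine only through the absolute path):

# The thumbnail cache filename is an md5 over the directory, the original
# name, and the render options, so changing any option busts the cache.
from pathlib import Path

from sorawm.iopaint.file_manager.utils import (
    aspect_to_string,
    generate_filename,
    parse_size,
)

size = parse_size(256)  # -> [256, 256]
key = generate_filename(
    Path("./outputs"), "cat.png", aspect_to_string(size), "fit", None, 90
)
print(key)  # deterministic "<md5 hexdigest>.jpg"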
sorawm/iopaint/helper.py
ADDED
@@ -0,0 +1,411 @@
import base64
import hashlib
import imghdr
import io
import os
import sys
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse

import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image, ImageOps, PngImagePlugin
from torch.hub import download_url_to_file, get_dir

from sorawm.iopaint.const import MPS_UNSUPPORT_MODELS


def md5sum(filename):
    md5 = hashlib.md5()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def switch_mps_device(model_name, device):
    if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
        logger.info(f"{model_name} does not support mps, switching to cpu")
        return torch.device("cpu")
    return device


def get_cache_path_by_url(url):
    parts = urlparse(url)
    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, "checkpoints")
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    return cached_file


def download_model(url, model_md5: str = None):
    if os.path.exists(url):
        cached_file = url
    else:
        cached_file = get_cache_path_by_url(url)
        if not os.path.exists(cached_file):
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            hash_prefix = None
            download_url_to_file(url, cached_file, hash_prefix, progress=True)
            if model_md5:
                _md5 = md5sum(cached_file)
                if model_md5 == _md5:
                    logger.info(f"Download model success, md5: {_md5}")
                else:
                    try:
                        os.remove(cached_file)
                        logger.error(
                            f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart sorawm.iopaint. "
                            f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
                        )
                    except:
                        logger.error(
                            f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart sorawm.iopaint."
                        )
                    exit(-1)

    return cached_file


def ceil_modulo(x, mod):
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def handle_error(model_path, model_md5, e):
    _md5 = md5sum(model_path)
    if _md5 != model_md5:
        try:
            os.remove(model_path)
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart sorawm.iopaint. "
                f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
            )
        except:
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart sorawm.iopaint."
            )
    else:
        logger.error(
            f"Failed to load model {model_path}, "
            f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
        )
    exit(-1)


def load_jit_model(url_or_path, device, model_md5: str):
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path, model_md5)

    logger.info(f"Loading model from: {model_path}")
    try:
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        handle_error(model_path, model_md5, e)
    model.eval()
    return model


def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path, model_md5)

    try:
        logger.info(f"Loading model from: {model_path}")
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
    except Exception as e:
        handle_error(model_path, model_md5, e)
    model.eval()
    return model


def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    data = cv2.imencode(
        f".{ext}",
        image_numpy,
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )[1]
    image_bytes = data.tobytes()
    return image_bytes


def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes:
    with io.BytesIO() as output:
        kwargs = {k: v for k, v in infos.items() if v is not None}
        if ext == "jpg":
            ext = "jpeg"
        if "png" == ext.lower() and "parameters" in kwargs:
            pnginfo_data = PngImagePlugin.PngInfo()
            pnginfo_data.add_text("parameters", kwargs["parameters"])
            kwargs["pnginfo"] = pnginfo_data

        pil_img.save(output, format=ext, quality=quality, **kwargs)
        image_bytes = output.getvalue()
    return image_bytes


def load_img(img_bytes, gray: bool = False, return_info: bool = False):
    alpha_channel = None
    image = Image.open(io.BytesIO(img_bytes))

    if return_info:
        infos = image.info

    try:
        image = ImageOps.exif_transpose(image)
    except:
        pass

    if gray:
        image = image.convert("L")
        np_img = np.array(image)
    else:
        if image.mode == "RGBA":
            np_img = np.array(image)
            alpha_channel = np_img[:, :, -1]
            np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
        else:
            image = image.convert("RGB")
            np_img = np.array(image)

    if return_info:
        return np_img, alpha_channel, infos
    return np_img, alpha_channel


def norm_img(np_img):
    if len(np_img.shape) == 2:
        np_img = np_img[:, :, np.newaxis]
    np_img = np.transpose(np_img, (2, 0, 1))
    np_img = np_img.astype("float32") / 255
    return np_img


def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    # Resize the image's longer side down to size_limit if it exceeds size_limit.
    h, w = np_img.shape[:2]
    if max(h, w) > size_limit:
        ratio = size_limit / max(h, w)
        new_w = int(w * ratio + 0.5)
        new_h = int(h * ratio + 0.5)
        return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
    else:
        return np_img


def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """

    Args:
        img: [H, W, C]
        mod:
        square: whether to pad the image to a square
        min_size:

    Returns:

    """
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)

    if min_size is not None:
        assert min_size % mod == 0
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)

    if square:
        max_size = max(out_height, out_width)
        out_height = max_size
        out_width = max_size

    return np.pad(
        img,
        ((0, out_height - height), (0, out_width - width), (0, 0)),
        mode="symmetric",
    )


def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """
    Args:
        mask: (h, w, 1) 0~255

    Returns:

    """
    height, width = mask.shape[:2]
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        box = np.array([x, y, x + w, y + h]).astype(int)

        box[::2] = np.clip(box[::2], 0, width)
        box[1::2] = np.clip(box[1::2], 0, height)
        boxes.append(box)

    return boxes


def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
    """
    Args:
        mask: (h, w) 0~255

    Returns:

    """
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    max_area = 0
    max_index = -1
    for i, cnt in enumerate(contours):
        area = cv2.contourArea(cnt)
        if area > max_area:
            max_area = area
            max_index = i

    if max_index != -1:
        new_mask = np.zeros_like(mask)
        return cv2.drawContours(new_mask, contours, max_index, 255, -1)
    else:
        return mask


def is_mac():
    return sys.platform == "darwin"


def get_image_ext(img_bytes):
    w = imghdr.what("", img_bytes)
    if w is None:
        w = "jpeg"
    return w


def decode_base64_to_image(
    encoding: str, gray=False
) -> Tuple[np.array, Optional[np.array], Dict, str]:
    if encoding.startswith("data:image/") or encoding.startswith(
        "data:application/octet-stream;base64,"
    ):
        encoding = encoding.split(";")[1].split(",")[1]
    image_bytes = base64.b64decode(encoding)
    ext = get_image_ext(image_bytes)
    image = Image.open(io.BytesIO(image_bytes))

    alpha_channel = None
    try:
        image = ImageOps.exif_transpose(image)
    except:
        pass
    # exif_transpose strips the EXIF rotation info, so image.info must be read
    # after calling it.
    infos = image.info

    if gray:
        image = image.convert("L")
        np_img = np.array(image)
    else:
        if image.mode == "RGBA":
            np_img = np.array(image)
            alpha_channel = np_img[:, :, -1]
            np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
        else:
            image = image.convert("RGB")
            np_img = np.array(image)

    return np_img, alpha_channel, infos, ext


def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
    img_bytes = pil_to_bytes(
        image,
        "png",
        quality=quality,
        infos=infos,
    )
    return base64.b64encode(img_bytes)


def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
    if alpha_channel is not None:
        if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
            )
        rgb_np_img = np.concatenate(
            (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
        )
    return rgb_np_img


def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
    # frontend brush color "ffcc00bb"
    # kernel_size = kernel_size*2+1
    mask[mask >= 127] = 255
    mask[mask < 127] = 0

    if operate == "reverse":
        mask = 255 - mask
    else:
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
        )
        if operate == "expand":
            mask = cv2.dilate(
                mask,
                kernel,
                iterations=1,
            )
        else:
            mask = cv2.erode(
                mask,
                kernel,
                iterations=1,
            )
    res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
    res_mask[mask > 128] = [255, 203, 0, int(255 * 0.73)]
    res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
    return res_mask


def gen_frontend_mask(bgr_or_gray_mask):
    if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
        bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)

    # frontend brush color "ffcc00bb"
    # TODO: how to set kernel size?
    kernel_size = 9
    bgr_or_gray_mask = cv2.dilate(
        bgr_or_gray_mask,
        np.ones((kernel_size, kernel_size), np.uint8),
        iterations=1,
    )
    res_mask = np.zeros(
        (bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8
    )
    res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
    res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
    return res_mask
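Two quick checks of the contracts documented above, runnable with only numpy and cv2 (the shapes and coordinates are arbitrary illustration values):

# pad_img_to_modulo rounds H and W up to the next multiple of mod;
# boxes_from_mask returns [x1, y1, x2, y2] per connected white region.
import numpy as np

from sorawm.iopaint.helper import boxes_from_mask, pad_img_to_modulo

img = np.zeros((100, 50, 3), dtype=np.uint8)
padded = pad_img_to_modulo(img, mod=8)
assert padded.shape[:2] == (104, 56)  # next multiples of 8

mask = np.zeros((100, 50, 1), dtype=np.uint8)
mask[10:30, 5:25] = 255
print(boxes_from_mask(mask))  # [array([ 5, 10, 25, 30])]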
sorawm/iopaint/installer.py
ADDED
@@ -0,0 +1,11 @@
import subprocess
import sys


def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


def install_plugins_package():
    install("onnxruntime<=1.19.2")
    install("rembg[cpu]")
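Both helpers shell out to pip inside the running interpreter, so enabling the optional background-removal dependencies is a single call:

# Installs onnxruntime<=1.19.2 and rembg[cpu] into the current environment.
from sorawm.iopaint.installer import install_plugins_package

install_plugins_package()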
sorawm/iopaint/model/__init__.py
ADDED
@@ -0,0 +1,38 @@
from .anytext.anytext_model import AnyText
from .controlnet import ControlNet
from .fcf import FcF
from .instruct_pix2pix import InstructPix2Pix
from .kandinsky import Kandinsky22
from .lama import AnimeLaMa, LaMa
from .ldm import LDM
from .manga import Manga
from .mat import MAT
from .mi_gan import MIGAN
from .opencv2 import OpenCV2
from .paint_by_example import PaintByExample
from .power_paint.power_paint import PowerPaint
from .sd import SD, SD2, SD15, Anything4, RealisticVision14
from .sdxl import SDXL
from .zits import ZITS

models = {
    LaMa.name: LaMa,
    AnimeLaMa.name: AnimeLaMa,
    LDM.name: LDM,
    ZITS.name: ZITS,
    MAT.name: MAT,
    FcF.name: FcF,
    OpenCV2.name: OpenCV2,
    Manga.name: Manga,
    MIGAN.name: MIGAN,
    SD15.name: SD15,
    Anything4.name: Anything4,
    RealisticVision14.name: RealisticVision14,
    SD2.name: SD2,
    PaintByExample.name: PaintByExample,
    InstructPix2Pix.name: InstructPix2Pix,
    Kandinsky22.name: Kandinsky22,
    SDXL.name: SDXL,
    PowerPaint.name: PowerPaint,
    AnyText.name: AnyText,
}
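The registry keys are each class's name attribute, so model dispatch elsewhere in the codebase can be a plain dict lookup; a sketch (the "lama" key and the constructor signature are assumptions, since the class bodies are not shown in this diff):

# Sketch of dict-based dispatch; the .name values and the device kwarg are
# assumptions -- only the registry shape is defined in this file.
import torch

from sorawm.iopaint.model import models

name = "lama"  # hypothetical value of LaMa.name
if name in models:
    model_cls = models[name]
    model = model_cls(device=torch.device("cpu"))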
sorawm/iopaint/model/anytext/__init__.py
ADDED
File without changes
sorawm/iopaint/model/anytext/anytext_model.py
ADDED
@@ -0,0 +1,73 @@
import torch
from huggingface_hub import hf_hub_download

from sorawm.iopaint.const import ANYTEXT_NAME
from sorawm.iopaint.model.anytext.anytext_pipeline import AnyTextPipeline
from sorawm.iopaint.model.base import DiffusionInpaintModel
from sorawm.iopaint.model.utils import get_torch_dtype, is_local_files_only
from sorawm.iopaint.schema import InpaintRequest


class AnyText(DiffusionInpaintModel):
    name = ANYTEXT_NAME
    pad_mod = 64
    is_erase_model = False

    @staticmethod
    def download(local_files_only=False):
        hf_hub_download(
            repo_id=ANYTEXT_NAME,
            filename="model_index.json",
            local_files_only=local_files_only,
        )
        ckpt_path = hf_hub_download(
            repo_id=ANYTEXT_NAME,
            filename="pytorch_model.fp16.safetensors",
            local_files_only=local_files_only,
        )
        font_path = hf_hub_download(
            repo_id=ANYTEXT_NAME,
            filename="SourceHanSansSC-Medium.otf",
            local_files_only=local_files_only,
        )
        return ckpt_path, font_path

    def init_model(self, device, **kwargs):
        local_files_only = is_local_files_only(**kwargs)
        ckpt_path, font_path = self.download(local_files_only)
        use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
        self.model = AnyTextPipeline(
            ckpt_path=ckpt_path,
            font_path=font_path,
            device=device,
            use_fp16=torch_dtype == torch.float16,
        )
        self.callback = kwargs.pop("callback", None)

    def forward(self, image, mask, config: InpaintRequest):
        """Input and output images have the same size.
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 marks the area to inpaint
        return: BGR IMAGE
        """
        height, width = image.shape[:2]
        mask = mask.astype("float32") / 255.0
        masked_image = image * (1 - mask)

        # list of rgb ndarray
        results, rtn_code, rtn_warning = self.model(
            image=image,
            masked_image=masked_image,
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            num_inference_steps=config.sd_steps,
            strength=config.sd_strength,
            guidance_scale=config.sd_guidance_scale,
            height=height,
            width=width,
            seed=config.sd_seed,
            sort_priority="y",
            callback=self.callback,
        )
        inpainted_rgb_image = results[0][..., ::-1]
        return inpainted_rgb_image
sorawm/iopaint/model/anytext/anytext_pipeline.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AnyText: Multilingual Visual Text Generation And Editing
|
| 3 |
+
Paper: https://arxiv.org/abs/2311.03054
|
| 4 |
+
Code: https://github.com/tyxsspa/AnyText
|
| 5 |
+
Copyright (c) Alibaba, Inc. and its affiliates.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from safetensors.torch import load_file
|
| 11 |
+
|
| 12 |
+
from sorawm.iopaint.model.utils import set_seed
|
| 13 |
+
|
| 14 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
| 15 |
+
import re
|
| 16 |
+
|
| 17 |
+
import cv2
|
| 18 |
+
import einops
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
from PIL import ImageFont
|
| 22 |
+
|
| 23 |
+
from sorawm.iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
|
| 24 |
+
from sorawm.iopaint.model.anytext.cldm.model import create_model, load_state_dict
|
| 25 |
+
from sorawm.iopaint.model.anytext.utils import check_channels, draw_glyph, draw_glyph2
|
| 26 |
+
|
| 27 |
+
BBOX_MAX_NUM = 8
|
| 28 |
+
PLACE_HOLDER = "*"
|
| 29 |
+
max_chars = 20
|
| 30 |
+
|
| 31 |
+
ANYTEXT_CFG = os.path.join(
|
| 32 |
+
os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def check_limits(tensor):
|
| 37 |
+
float16_min = torch.finfo(torch.float16).min
|
| 38 |
+
float16_max = torch.finfo(torch.float16).max
|
| 39 |
+
|
| 40 |
+
# 检查张量中是否有值小于float16的最小值或大于float16的最大值
|
| 41 |
+
is_below_min = (tensor < float16_min).any()
|
| 42 |
+
is_above_max = (tensor > float16_max).any()
|
| 43 |
+
|
| 44 |
+
return is_below_min or is_above_max
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class AnyTextPipeline:
|
| 48 |
+
def __init__(self, ckpt_path, font_path, device, use_fp16=True):
|
| 49 |
+
self.cfg_path = ANYTEXT_CFG
|
| 50 |
+
self.font_path = font_path
|
| 51 |
+
self.use_fp16 = use_fp16
|
| 52 |
+
self.device = device
|
| 53 |
+
|
| 54 |
+
self.font = ImageFont.truetype(font_path, size=60)
|
| 55 |
+
self.model = create_model(
|
| 56 |
+
self.cfg_path,
|
| 57 |
+
device=self.device,
|
| 58 |
+
use_fp16=self.use_fp16,
|
| 59 |
+
)
|
| 60 |
+
if self.use_fp16:
|
| 61 |
+
self.model = self.model.half()
|
| 62 |
+
if Path(ckpt_path).suffix == ".safetensors":
|
| 63 |
+
state_dict = load_file(ckpt_path, device="cpu")
|
| 64 |
+
else:
|
| 65 |
+
state_dict = load_state_dict(ckpt_path, location="cpu")
|
| 66 |
+
self.model.load_state_dict(state_dict, strict=False)
|
| 67 |
+
self.model = self.model.eval().to(self.device)
|
| 68 |
+
self.ddim_sampler = DDIMSampler(self.model, device=self.device)
|
| 69 |
+
|
| 70 |
+
def __call__(
|
| 71 |
+
self,
|
| 72 |
+
prompt: str,
|
| 73 |
+
negative_prompt: str,
|
| 74 |
+
image: np.ndarray,
|
| 75 |
+
masked_image: np.ndarray,
|
| 76 |
+
num_inference_steps: int,
|
| 77 |
+
strength: float,
|
| 78 |
+
guidance_scale: float,
|
| 79 |
+
height: int,
|
| 80 |
+
width: int,
|
| 81 |
+
seed: int,
|
| 82 |
+
sort_priority: str = "y",
|
| 83 |
+
callback=None,
|
| 84 |
+
):
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
prompt:
|
| 89 |
+
negative_prompt:
|
| 90 |
+
image:
|
| 91 |
+
masked_image:
|
| 92 |
+
num_inference_steps:
|
| 93 |
+
strength:
|
| 94 |
+
guidance_scale:
|
| 95 |
+
height:
|
| 96 |
+
width:
|
| 97 |
+
seed:
|
| 98 |
+
sort_priority: x: left-right, y: top-down
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
result: list of images in numpy.ndarray format
|
| 102 |
+
rst_code: 0: normal -1: error 1:warning
|
| 103 |
+
rst_info: string of error or warning
|
| 104 |
+
|
| 105 |
+
"""
|
| 106 |
+
set_seed(seed)
|
| 107 |
+
str_warning = ""
|
| 108 |
+
|
| 109 |
+
mode = "text-editing"
|
| 110 |
+
revise_pos = False
|
| 111 |
+
img_count = 1
|
| 112 |
+
ddim_steps = num_inference_steps
|
| 113 |
+
w = width
|
| 114 |
+
h = height
|
| 115 |
+
strength = strength
|
| 116 |
+
cfg_scale = guidance_scale
|
| 117 |
+
eta = 0.0
|
| 118 |
+
|
| 119 |
+
prompt, texts = self.modify_prompt(prompt)
|
| 120 |
+
if prompt is None and texts is None:
|
| 121 |
+
return (
|
| 122 |
+
None,
|
| 123 |
+
-1,
|
| 124 |
+
"You have input Chinese prompt but the translator is not loaded!",
|
| 125 |
+
"",
|
| 126 |
+
)
|
| 127 |
+
n_lines = len(texts)
|
| 128 |
+
if mode in ["text-generation", "gen"]:
|
| 129 |
+
edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
|
| 130 |
+
elif mode in ["text-editing", "edit"]:
|
| 131 |
+
if masked_image is None or image is None:
|
| 132 |
+
return (
|
| 133 |
+
None,
|
| 134 |
+
-1,
|
| 135 |
+
"Reference image and position image are needed for text editing!",
|
| 136 |
+
"",
|
| 137 |
+
)
|
| 138 |
+
if isinstance(image, str):
|
| 139 |
+
image = cv2.imread(image)[..., ::-1]
|
| 140 |
+
assert image is not None, f"Can't read ori_image image from{image}!"
|
| 141 |
+
elif isinstance(image, torch.Tensor):
|
| 142 |
+
image = image.cpu().numpy()
|
| 143 |
+
else:
|
| 144 |
+
assert isinstance(
|
| 145 |
+
image, np.ndarray
|
| 146 |
+
), f"Unknown format of ori_image: {type(image)}"
|
| 147 |
+
edit_image = image.clip(1, 255) # for mask reason
|
| 148 |
+
edit_image = check_channels(edit_image)
|
| 149 |
+
# edit_image = resize_image(
|
| 150 |
+
# edit_image, max_length=768
|
| 151 |
+
# ) # make w h multiple of 64, resize if w or h > max_length
|
| 152 |
+
h, w = edit_image.shape[:2] # change h, w by input ref_img
|
| 153 |
+
# preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
|
| 154 |
+
if masked_image is None:
|
| 155 |
+
pos_imgs = np.zeros((w, h, 1))
|
| 156 |
+
if isinstance(masked_image, str):
|
| 157 |
+
masked_image = cv2.imread(masked_image)[..., ::-1]
|
| 158 |
+
assert (
|
| 159 |
+
masked_image is not None
|
| 160 |
+
), f"Can't read draw_pos image from{masked_image}!"
|
| 161 |
+
pos_imgs = 255 - masked_image
|
| 162 |
+
elif isinstance(masked_image, torch.Tensor):
|
| 163 |
+
pos_imgs = masked_image.cpu().numpy()
|
| 164 |
+
else:
|
| 165 |
+
assert isinstance(
|
| 166 |
+
masked_image, np.ndarray
|
| 167 |
+
), f"Unknown format of draw_pos: {type(masked_image)}"
|
| 168 |
+
pos_imgs = 255 - masked_image
|
| 169 |
+
pos_imgs = pos_imgs[..., 0:1]
|
| 170 |
+
pos_imgs = cv2.convertScaleAbs(pos_imgs)
|
| 171 |
+
_, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
|
| 172 |
+
# seprate pos_imgs
|
| 173 |
+
pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
|
| 174 |
+
if len(pos_imgs) == 0:
|
| 175 |
+
pos_imgs = [np.zeros((h, w, 1))]
|
| 176 |
+
if len(pos_imgs) < n_lines:
|
| 177 |
+
if n_lines == 1 and texts[0] == " ":
|
| 178 |
+
pass # text-to-image without text
|
| 179 |
+
else:
|
| 180 |
+
raise RuntimeError(
|
| 181 |
+
f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
|
| 182 |
+
)
|
| 183 |
+
elif len(pos_imgs) > n_lines:
|
| 184 |
+
str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
|
| 185 |
+
# get pre_pos, poly_list, hint that needed for anytext
|
| 186 |
+
pre_pos = []
|
| 187 |
+
poly_list = []
|
| 188 |
+
for input_pos in pos_imgs:
|
| 189 |
+
if input_pos.mean() != 0:
|
| 190 |
+
input_pos = (
|
| 191 |
+
input_pos[..., np.newaxis]
|
| 192 |
+
if len(input_pos.shape) == 2
|
| 193 |
+
else input_pos
|
| 194 |
+
)
|
| 195 |
+
poly, pos_img = self.find_polygon(input_pos)
|
| 196 |
+
pre_pos += [pos_img / 255.0]
|
| 197 |
+
poly_list += [poly]
|
| 198 |
+
else:
|
| 199 |
+
pre_pos += [np.zeros((h, w, 1))]
|
| 200 |
+
poly_list += [None]
|
| 201 |
+
np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
|
| 202 |
+
# prepare info dict
|
| 203 |
+
info = {}
|
| 204 |
+
info["glyphs"] = []
|
| 205 |
+
info["gly_line"] = []
|
| 206 |
+
info["positions"] = []
|
| 207 |
+
info["n_lines"] = [len(texts)] * img_count
|
| 208 |
+
gly_pos_imgs = []
|
| 209 |
+
for i in range(len(texts)):
|
| 210 |
+
text = texts[i]
|
| 211 |
+
if len(text) > max_chars:
|
| 212 |
+
str_warning = (
|
| 213 |
+
f'"{text}" length > max_chars: {max_chars}, will be cut off...'
|
| 214 |
+
)
|
| 215 |
+
text = text[:max_chars]
|
| 216 |
+
gly_scale = 2
|
| 217 |
+
if pre_pos[i].mean() != 0:
|
| 218 |
+
gly_line = draw_glyph(self.font, text)
|
| 219 |
+
glyphs = draw_glyph2(
|
| 220 |
+
self.font,
|
| 221 |
+
text,
|
| 222 |
+
poly_list[i],
|
| 223 |
+
scale=gly_scale,
|
| 224 |
+
width=w,
|
| 225 |
+
height=h,
|
| 226 |
+
add_space=False,
|
| 227 |
+
)
|
| 228 |
+
gly_pos_img = cv2.drawContours(
|
| 229 |
+
glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
|
| 230 |
+
)
|
| 231 |
+
if revise_pos:
|
| 232 |
+
resize_gly = cv2.resize(
|
| 233 |
+
glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
|
| 234 |
+
)
|
| 235 |
+
new_pos = cv2.morphologyEx(
|
| 236 |
+
(resize_gly * 255).astype(np.uint8),
|
| 237 |
+
cv2.MORPH_CLOSE,
|
| 238 |
+
kernel=np.ones(
|
| 239 |
+
                            (resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
                            dtype=np.uint8,
                        ),
                        iterations=1,
                    )
                    new_pos = (
                        new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
                    )
                    contours, _ = cv2.findContours(
                        new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
                    )
                    if len(contours) != 1:
                        str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
                    else:
                        rect = cv2.minAreaRect(contours[0])
                        poly = np.int0(cv2.boxPoints(rect))
                        pre_pos[i] = (
                            cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
                        )
                        gly_pos_img = cv2.drawContours(
                            glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
                        )
                gly_pos_imgs += [gly_pos_img]  # for show
            else:
                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
                gly_line = np.zeros((80, 512, 1))
                gly_pos_imgs += [
                    np.zeros((h * gly_scale, w * gly_scale, 1))
                ]  # for show
            pos = pre_pos[i]
            info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
            info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
            info["positions"] += [self.arr2tensor(pos, img_count)]
        # get masked_x
        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
        masked_img = np.transpose(masked_img, (2, 0, 1))
        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
        if self.use_fp16:
            masked_img = masked_img.half()
        encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
        masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
        if self.use_fp16:
            masked_x = masked_x.half()
        info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)

        hint = self.arr2tensor(np_hint, img_count)
        cond = self.model.get_learned_conditioning(
            dict(
                c_concat=[hint],
                c_crossattn=[[prompt] * img_count],
                text_info=info,
            )
        )
        un_cond = self.model.get_learned_conditioning(
            dict(
                c_concat=[hint],
                c_crossattn=[[negative_prompt] * img_count],
                text_info=info,
            )
        )
        shape = (4, h // 8, w // 8)
        self.model.control_scales = [strength] * 13
        samples, intermediates = self.ddim_sampler.sample(
            ddim_steps,
            img_count,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg_scale,
            unconditional_conditioning=un_cond,
            callback=callback,
        )
        if self.use_fp16:
            samples = samples.half()
        x_samples = self.model.decode_first_stage(samples)
        x_samples = (
            (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
            .cpu()
            .numpy()
            .clip(0, 255)
            .astype(np.uint8)
        )
        results = [x_samples[i] for i in range(img_count)]
        # if (
        #     mode == "edit" and False
        # ):  # replace background in text editing but not ideal yet
        #     results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
        #     results = [r.clip(0, 255).astype(np.uint8) for r in results]
        # if len(gly_pos_imgs) > 0 and show_debug:
        #     glyph_bs = np.stack(gly_pos_imgs, axis=2)
        #     glyph_img = np.sum(glyph_bs, axis=2) * 255
        #     glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
        #     results += [np.repeat(glyph_img, 3, axis=2)]
        rst_code = 1 if str_warning else 0
        return results, rst_code, str_warning

    def modify_prompt(self, prompt):
        prompt = prompt.replace("“", '"')
        prompt = prompt.replace("”", '"')
        p = '"(.*?)"'
        strs = re.findall(p, prompt)
        if len(strs) == 0:
            strs = [" "]
        else:
            for s in strs:
                prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
        # if self.is_chinese(prompt):
        #     if self.trans_pipe is None:
        #         return None, None
        #     old_prompt = prompt
        #     prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
        #     print(f"Translate: {old_prompt} --> {prompt}")
        return prompt, strs

    # def is_chinese(self, text):
    #     text = checker._clean_text(text)
    #     for char in text:
    #         cp = ord(char)
    #         if checker._is_chinese_char(cp):
    #             return True
    #     return False

    def separate_pos_imgs(self, img, sort_priority, gap=102):
        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
        components = []
        for label in range(1, num_labels):
            component = np.zeros_like(img)
            component[labels == label] = 255
            components.append((component, centroids[label]))
        if sort_priority == "y":
            fir, sec = 1, 0  # top-down first
        elif sort_priority == "x":
            fir, sec = 0, 1  # left-right first
        components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
        sorted_components = [c[0] for c in components]
        return sorted_components

    def find_polygon(self, image, min_rect=False):
        contours, hierarchy = cv2.findContours(
            image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        max_contour = max(contours, key=cv2.contourArea)  # get contour with max area
        if min_rect:
            # get minimum enclosing rectangle
            rect = cv2.minAreaRect(max_contour)
            poly = np.int0(cv2.boxPoints(rect))
        else:
            # get approximate polygon
            epsilon = 0.01 * cv2.arcLength(max_contour, True)
            poly = cv2.approxPolyDP(max_contour, epsilon, True)
            n, _, xy = poly.shape
            poly = poly.reshape(n, xy)
        cv2.drawContours(image, [poly], -1, 255, -1)
        return poly, image

    def arr2tensor(self, arr, bs):
        arr = np.transpose(arr, (2, 0, 1))
        _arr = torch.from_numpy(arr.copy()).float().to(self.device)
        if self.use_fp16:
            _arr = _arr.half()
        _arr = torch.stack([_arr for _ in range(bs)], dim=0)
        return _arr
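A minimal usage sketch (not part of this commit) of the arr2tensor helper just above: it turns a single HWC numpy map into a batch of CHW tensors, which is how the glyph, glyph-line, and position maps are fed to the model. The shapes and batch size here are illustrative:

import numpy as np
import torch

mask = np.zeros((512, 512, 1), dtype=np.float32)       # HWC, like the position maps above
chw = torch.from_numpy(np.transpose(mask, (2, 0, 1)))  # -> (1, 512, 512)
batched = torch.stack([chw.float()] * 2, dim=0)        # -> (2, 1, 512, 512), one copy per sample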
sorawm/iopaint/model/anytext/anytext_sd15.yaml
ADDED
@@ -0,0 +1,99 @@
model:
  target: sorawm.iopaint.model.anytext.cldm.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "img"
    cond_stage_key: "caption"
    control_key: "hint"
    glyph_key: "glyphs"
    position_key: "positions"
    image_size: 64
    channels: 4
    cond_stage_trainable: true  # need be true when embedding_manager is valid
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False
    loss_alpha: 0  # perceptual loss, 0.003
    loss_beta: 0  # ctc loss
    latin_weight: 1.0  # latin text line may need smaller weight
    with_step_weight: true
    use_vae_upsample: true
    embedding_manager_config:
      target: sorawm.iopaint.model.anytext.cldm.embedding_manager.EmbeddingManager
      params:
        valid: true  # v6
        emb_type: ocr  # ocr, vit, conv
        glyph_channels: 1
        position_channels: 1
        add_pos: false
        placeholder_string: '*'

    control_stage_config:
      target: sorawm.iopaint.model.anytext.cldm.cldm.ControlNet
      params:
        image_size: 32  # unused
        in_channels: 4
        model_channels: 320
        glyph_channels: 1
        position_channels: 1
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    unet_config:
      target: sorawm.iopaint.model.anytext.cldm.cldm.ControlledUnetModel
      params:
        image_size: 32  # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: sorawm.iopaint.model.anytext.ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
      params:
        version: openai/clip-vit-large-patch14
        use_vision: false  # v6
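A minimal sketch (an assumption, not from the commit) of how a config in this shape is typically consumed: cldm.py below imports instantiate_from_config, which builds the object named by `target` with the keyword arguments in `params`, and omegaconf is already used in that module. The yaml path here is illustrative:

from omegaconf import OmegaConf

from sorawm.iopaint.model.anytext.ldm.util import instantiate_from_config

config = OmegaConf.load("anytext_sd15.yaml")   # illustrative path
model = instantiate_from_config(config.model)  # -> ControlLDM wired up per the params above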
sorawm/iopaint/model/anytext/cldm/__init__.py
ADDED
File without changes
sorawm/iopaint/model/anytext/cldm/cldm.py
ADDED
@@ -0,0 +1,780 @@
import copy
import os
from pathlib import Path

import einops
import torch
import torch as th
import torch.nn as nn
from easydict import EasyDict as edict
from einops import rearrange, repeat
from torchvision.utils import make_grid  # used by log_images when plot_diffusion_rows is enabled

from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddpm import LatentDiffusion
from sorawm.iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.openaimodel import (
    AttentionBlock,
    Downsample,
    ResBlock,
    TimestepEmbedSequential,
    UNetModel,
)
from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    timestep_embedding,
    zero_module,
)
from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)
from sorawm.iopaint.model.anytext.ldm.util import (
    exists,
    instantiate_from_config,
    log_txt_as_img,
)

from .recognizer import TextRecognizer, create_predictor

CURRENT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class ControlledUnetModel(UNetModel):
    def forward(
        self,
        x,
        timesteps=None,
        context=None,
        control=None,
        only_mid_control=False,
        **kwargs,
    ):
        hs = []
        with torch.no_grad():
            t_emb = timestep_embedding(
                timesteps, self.model_channels, repeat_only=False
            )
            if self.use_fp16:
                t_emb = t_emb.half()
            emb = self.time_embed(t_emb)
            h = x.type(self.dtype)
            for module in self.input_blocks:
                h = module(h, emb, context)
                hs.append(h)
            h = self.middle_block(h, emb, context)

        if control is not None:
            h += control.pop()

        for i, module in enumerate(self.output_blocks):
            if only_mid_control or control is None:
                h = torch.cat([h, hs.pop()], dim=1)
            else:
                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
            h = module(h, emb, context)

        h = h.type(x.dtype)
        return self.out(h)


class ControlNet(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        glyph_channels,
        position_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            from omegaconf.listconfig import ListConfig

            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"
        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.use_fp16 = use_fp16
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        self.glyph_block = TimestepEmbedSequential(
            conv_nd(dims, glyph_channels, 8, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 8, 8, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
        )

        self.position_block = TimestepEmbedSequential(
            conv_nd(dims, position_channels, 8, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 8, 8, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 64, 3, padding=1, stride=2),
            nn.SiLU(),
        )

        self.fuse_block = zero_module(
            conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1)
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if (
                        not exists(num_attention_blocks)
                        or nr < num_attention_blocks[level]
                    ):
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )
            if not use_spatial_transformer
            else SpatialTransformer(  # always uses a self-attn
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(
            zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))
        )

    def forward(self, x, hint, text_info, timesteps, context, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        if self.use_fp16:
            t_emb = t_emb.half()
        emb = self.time_embed(t_emb)

        # guided_hint from text_info
        B, C, H, W = x.shape
        glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
        positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
        enc_glyph = self.glyph_block(glyphs, emb, context)
        enc_pos = self.position_block(positions, emb, context)
        guided_hint = self.fuse_block(
            torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)
        )

        outs = []

        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs


class ControlLDM(LatentDiffusion):
    def __init__(
        self,
        control_stage_config,
        control_key,
        glyph_key,
        position_key,
        only_mid_control,
        loss_alpha=0,
        loss_beta=0,
        with_step_weight=False,
        use_vae_upsample=False,
        latin_weight=1.0,
        embedding_manager_config=None,
        *args,
        **kwargs,
    ):
        self.use_fp16 = kwargs.pop("use_fp16", False)
        super().__init__(*args, **kwargs)
        self.control_model = instantiate_from_config(control_stage_config)
        self.control_key = control_key
        self.glyph_key = glyph_key
        self.position_key = position_key
        self.only_mid_control = only_mid_control
        self.control_scales = [1.0] * 13
        self.loss_alpha = loss_alpha
        self.loss_beta = loss_beta
        self.with_step_weight = with_step_weight
        self.use_vae_upsample = use_vae_upsample
        self.latin_weight = latin_weight

        if (
            embedding_manager_config is not None
            and embedding_manager_config.params.valid
        ):
            self.embedding_manager = self.instantiate_embedding_manager(
                embedding_manager_config, self.cond_stage_model
            )
            for param in self.embedding_manager.embedding_parameters():
                param.requires_grad = True
        else:
            self.embedding_manager = None
        if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
            if embedding_manager_config.params.emb_type == "ocr":
                self.text_predictor = create_predictor().eval()
                args = edict()
                args.rec_image_shape = "3, 48, 320"
                args.rec_batch_num = 6
                args.rec_char_dict_path = str(
                    CURRENT_DIR.parent / "ocr_recog" / "ppocr_keys_v1.txt"
                )
                args.use_fp16 = self.use_fp16
                self.cn_recognizer = TextRecognizer(args, self.text_predictor)
                for param in self.text_predictor.parameters():
                    param.requires_grad = False
                if self.embedding_manager:
                    self.embedding_manager.recog = self.cn_recognizer

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        if self.embedding_manager is None:  # fill in full caption
            self.fill_caption(batch)
        x, c, mx = super().get_input(
            batch, self.first_stage_key, mask_k="masked_img", *args, **kwargs
        )
        control = batch[
            self.control_key
        ]  # for log_images and loss_alpha, not real control
        if bs is not None:
            control = control[:bs]
        control = control.to(self.device)
        control = einops.rearrange(control, "b h w c -> b c h w")
        control = control.to(memory_format=torch.contiguous_format).float()

        inv_mask = batch["inv_mask"]
        if bs is not None:
            inv_mask = inv_mask[:bs]
        inv_mask = inv_mask.to(self.device)
        inv_mask = einops.rearrange(inv_mask, "b h w c -> b c h w")
        inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()

        glyphs = batch[self.glyph_key]
        gly_line = batch["gly_line"]
        positions = batch[self.position_key]
        n_lines = batch["n_lines"]
        language = batch["language"]
        texts = batch["texts"]
        assert len(glyphs) == len(positions)
        for i in range(len(glyphs)):
            if bs is not None:
                glyphs[i] = glyphs[i][:bs]
                gly_line[i] = gly_line[i][:bs]
                positions[i] = positions[i][:bs]
                n_lines = n_lines[:bs]
            glyphs[i] = glyphs[i].to(self.device)
            gly_line[i] = gly_line[i].to(self.device)
            positions[i] = positions[i].to(self.device)
            glyphs[i] = einops.rearrange(glyphs[i], "b h w c -> b c h w")
            gly_line[i] = einops.rearrange(gly_line[i], "b h w c -> b c h w")
            positions[i] = einops.rearrange(positions[i], "b h w c -> b c h w")
            glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
            gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
            positions[i] = (
                positions[i].to(memory_format=torch.contiguous_format).float()
            )
        info = {}
        info["glyphs"] = glyphs
        info["positions"] = positions
        info["n_lines"] = n_lines
        info["language"] = language
        info["texts"] = texts
        info["img"] = batch["img"]  # nhwc, (-1,1)
        info["masked_x"] = mx
        info["gly_line"] = gly_line
        info["inv_mask"] = inv_mask
        return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model
        _cond = torch.cat(cond["c_crossattn"], 1)
        _hint = torch.cat(cond["c_concat"], 1)
        if self.use_fp16:
            x_noisy = x_noisy.half()
        control = self.control_model(
            x=x_noisy,
            timesteps=t,
            context=_cond,
            hint=_hint,
            text_info=cond["text_info"],
        )
        control = [c * scale for c, scale in zip(control, self.control_scales)]
        eps = diffusion_model(
            x=x_noisy,
            timesteps=t,
            context=_cond,
            control=control,
            only_mid_control=self.only_mid_control,
        )

        return eps

    def instantiate_embedding_manager(self, config, embedder):
        model = instantiate_from_config(config, embedder=embedder)
        return model

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        return self.get_learned_conditioning(
            dict(c_crossattn=[[""] * N], text_info=None)
        )

    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, "encode") and callable(
                self.cond_stage_model.encode
            ):
                if self.embedding_manager is not None and c["text_info"] is not None:
                    self.embedding_manager.encode_text(c["text_info"])
                if isinstance(c, dict):
                    cond_txt = c["c_crossattn"][0]
                else:
                    cond_txt = c
                if self.embedding_manager is not None:
                    cond_txt = self.cond_stage_model.encode(
                        cond_txt, embedding_manager=self.embedding_manager
                    )
                else:
                    cond_txt = self.cond_stage_model.encode(cond_txt)
                if isinstance(c, dict):
                    c["c_crossattn"][0] = cond_txt
                else:
                    c = cond_txt
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c

    def fill_caption(self, batch, place_holder="*"):
        bs = len(batch["n_lines"])
        cond_list = copy.deepcopy(batch[self.cond_stage_key])
        for i in range(bs):
            n_lines = batch["n_lines"][i]
            if n_lines == 0:
                continue
            cur_cap = cond_list[i]
            for j in range(n_lines):
                r_txt = batch["texts"][j][i]
                cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
            cond_list[i] = cur_cap
        batch[self.cond_stage_key] = cond_list

    @torch.no_grad()
    def log_images(
        self,
        batch,
        N=4,
        n_row=2,
        sample=False,
        ddim_steps=50,
        ddim_eta=0.0,
        return_keys=None,
        quantize_denoised=True,
        inpaint=True,
        plot_denoise_rows=False,
        plot_progressive_rows=True,
        plot_diffusion_rows=False,
        unconditional_guidance_scale=9.0,
        unconditional_guidance_label=None,
        use_ema_scope=True,
        **kwargs,
    ):
        use_ddim = ddim_steps is not None

        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        if self.cond_stage_trainable:
            with torch.no_grad():
                c = self.get_learned_conditioning(c)
        c_crossattn = c["c_crossattn"][0][:N]
        c_cat = c["c_concat"][0][:N]
        text_info = c["text_info"]
        text_info["glyphs"] = [i[:N] for i in text_info["glyphs"]]
        text_info["gly_line"] = [i[:N] for i in text_info["gly_line"]]
        text_info["positions"] = [i[:N] for i in text_info["positions"]]
        text_info["n_lines"] = text_info["n_lines"][:N]
        text_info["masked_x"] = text_info["masked_x"][:N]
        text_info["img"] = text_info["img"][:N]

        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        log["masked_image"] = self.decode_first_stage(text_info["masked_x"])
        log["control"] = c_cat * 2.0 - 1.0
        log["img"] = text_info["img"].permute(0, 3, 1, 2)  # log source image if needed
        # get glyph
        glyph_bs = torch.stack(text_info["glyphs"])
        glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
        log["glyph"] = torch.nn.functional.interpolate(
            glyph_bs,
            size=(512, 512),
            mode="bilinear",
            align_corners=True,
        )
        # fill caption
        if not self.embedding_manager:
            self.fill_caption(batch)
        captions = batch[self.cond_stage_key]
        log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            samples, z_denoise_row = self.sample_log(
                cond={"c_concat": [c_cat], "c_crossattn": [c], "text_info": text_info},
                batch_size=N,
                ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=ddim_eta,
            )
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(N)
            uc_cat = c_cat  # torch.zeros_like(c_cat)
            uc_full = {
                "c_concat": [uc_cat],
                "c_crossattn": [uc_cross["c_crossattn"][0]],
                "text_info": text_info,
            }
            samples_cfg, tmps = self.sample_log(
                cond={
                    "c_concat": [c_cat],
                    "c_crossattn": [c_crossattn],
                    "text_info": text_info,
                },
                batch_size=N,
                ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=ddim_eta,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=uc_full,
            )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
            pred_x0 = False  # whether to log pred_x0
            if pred_x0:
                for idx in range(len(tmps["pred_x0"])):
                    pred_x0 = self.decode_first_stage(tmps["pred_x0"][idx])
                    log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0

        return log

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        ddim_sampler = DDIMSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(
            ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs
        )
        return samples, intermediates

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.control_model.parameters())
        if self.embedding_manager:
            params += list(self.embedding_manager.embedding_parameters())
        if not self.sd_locked:
            # params += list(self.model.diffusion_model.input_blocks.parameters())
            # params += list(self.model.diffusion_model.middle_block.parameters())
            params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
        if self.unlockKV:
            nCount = 0
            for name, param in self.model.diffusion_model.named_parameters():
                if "attn2.to_k" in name or "attn2.to_v" in name:
                    params += [param]
                    nCount += 1
            print(
                f"Cross attention is unlocked, and {nCount} Wk or Wv are added to optimizers!!!"
            )

        opt = torch.optim.AdamW(params, lr=lr)
        return opt

    def low_vram_shift(self, is_diffusing):
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()
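A note on the zero convolutions above: make_zero_conv wraps a 1x1 conv in zero_module, which zero-initializes all of its parameters, so every ControlNet output starts as exactly zero and the pretrained UNet's behavior is unchanged at the start of training. A small self-contained illustration (not from the commit) of that property:

import torch
import torch.nn as nn

conv = nn.Conv2d(320, 320, kernel_size=1, padding=0)
nn.init.zeros_(conv.weight)  # what zero_module does to each parameter
nn.init.zeros_(conv.bias)
x = torch.randn(1, 320, 64, 64)
assert conv(x).abs().sum().item() == 0.0  # contributes nothing until trained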
sorawm/iopaint/model/anytext/cldm/ddim_hacked.py
ADDED
@@ -0,0 +1,486 @@
"""SAMPLING ONLY."""

import numpy as np
import torch
from tqdm import tqdm

from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
    extract_into_tensor,
    make_ddim_sampling_parameters,
    make_ddim_timesteps,
    noise_like,
)


class DDIMSampler(object):
    def __init__(self, model, device, schedule="linear", **kwargs):
        super().__init__()
        self.device = device
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device(self.device):
                attr = attr.to(torch.device(self.device))
        setattr(self, name, attr)

    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )

    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        dynamic_threshold=None,
        ucg_schedule=None,
        **kwargs,
    ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(
                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
                    )

            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(
                            f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}"
                        )

            else:
                if conditioning.shape[0] != batch_size:
                    print(
                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
                    )

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f"Data shape for DDIM sampling is {size}, eta {eta}")

        samples, intermediates = self.ddim_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            dynamic_threshold=dynamic_threshold,
            ucg_schedule=ucg_schedule,
        )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(
        self,
        cond,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
        ucg_schedule=None,
    ):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img]}
        time_range = (
            reversed(range(0, timesteps))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(
                    x0, ts
                )  # TODO: deterministic forward pass?
                img = img_orig * mask + (1.0 - mask) * img

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                dynamic_threshold=dynamic_threshold,
            )
            img, pred_x0 = outs
            if callback:
                callback(None, i, None, None)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates["x_inter"].append(img)
                intermediates["pred_x0"].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
    ):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
            model_output = self.model.apply_model(x, t, c)
        else:
            model_t = self.model.apply_model(x, t, c)
            model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
            model_output = model_uncond + unconditional_guidance_scale * (
                model_t - model_uncond
            )

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", "not implemented"
            e_t = score_corrector.modify_score(
                self.model, e_t, x, t, c, **corrector_kwargs
            )

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(
            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
        )

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(
        self,
        x0,
        c,
        t_enc,
        use_original_steps=False,
        return_intermediates=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        callback=None,
    ):
        timesteps = (
            np.arange(self.ddpm_num_timesteps)
            if use_original_steps
            else self.ddim_timesteps
        )
        num_reference_steps = timesteps.shape[0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc="Encoding Image"):
            t = torch.full(
                (x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long
            )
            if unconditional_guidance_scale == 1.0:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
+
self.model.apply_model(
|
| 393 |
+
torch.cat((x_next, x_next)),
|
| 394 |
+
torch.cat((t, t)),
|
| 395 |
+
torch.cat((unconditional_conditioning, c)),
|
| 396 |
+
),
|
| 397 |
+
2,
|
| 398 |
+
)
|
| 399 |
+
noise_pred = e_t_uncond + unconditional_guidance_scale * (
|
| 400 |
+
noise_pred - e_t_uncond
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
| 404 |
+
weighted_noise_pred = (
|
| 405 |
+
alphas_next[i].sqrt()
|
| 406 |
+
* ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
|
| 407 |
+
* noise_pred
|
| 408 |
+
)
|
| 409 |
+
x_next = xt_weighted + weighted_noise_pred
|
| 410 |
+
if (
|
| 411 |
+
return_intermediates
|
| 412 |
+
and i % (num_steps // return_intermediates) == 0
|
| 413 |
+
and i < num_steps - 1
|
| 414 |
+
):
|
| 415 |
+
intermediates.append(x_next)
|
| 416 |
+
inter_steps.append(i)
|
| 417 |
+
elif return_intermediates and i >= num_steps - 2:
|
| 418 |
+
intermediates.append(x_next)
|
| 419 |
+
inter_steps.append(i)
|
| 420 |
+
if callback:
|
| 421 |
+
callback(i)
|
| 422 |
+
|
| 423 |
+
out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
|
| 424 |
+
if return_intermediates:
|
| 425 |
+
out.update({"intermediates": intermediates})
|
| 426 |
+
return x_next, out
|
| 427 |
+
|
| 428 |
+
@torch.no_grad()
|
| 429 |
+
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
| 430 |
+
# fast, but does not allow for exact reconstruction
|
| 431 |
+
# t serves as an index to gather the correct alphas
|
| 432 |
+
if use_original_steps:
|
| 433 |
+
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
| 434 |
+
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
| 435 |
+
else:
|
| 436 |
+
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
| 437 |
+
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
| 438 |
+
|
| 439 |
+
if noise is None:
|
| 440 |
+
noise = torch.randn_like(x0)
|
| 441 |
+
return (
|
| 442 |
+
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
| 443 |
+
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
@torch.no_grad()
|
| 447 |
+
def decode(
|
| 448 |
+
self,
|
| 449 |
+
x_latent,
|
| 450 |
+
cond,
|
| 451 |
+
t_start,
|
| 452 |
+
unconditional_guidance_scale=1.0,
|
| 453 |
+
unconditional_conditioning=None,
|
| 454 |
+
use_original_steps=False,
|
| 455 |
+
callback=None,
|
| 456 |
+
):
|
| 457 |
+
timesteps = (
|
| 458 |
+
np.arange(self.ddpm_num_timesteps)
|
| 459 |
+
if use_original_steps
|
| 460 |
+
else self.ddim_timesteps
|
| 461 |
+
)
|
| 462 |
+
timesteps = timesteps[:t_start]
|
| 463 |
+
|
| 464 |
+
time_range = np.flip(timesteps)
|
| 465 |
+
total_steps = timesteps.shape[0]
|
| 466 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 467 |
+
|
| 468 |
+
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
|
| 469 |
+
x_dec = x_latent
|
| 470 |
+
for i, step in enumerate(iterator):
|
| 471 |
+
index = total_steps - i - 1
|
| 472 |
+
ts = torch.full(
|
| 473 |
+
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
|
| 474 |
+
)
|
| 475 |
+
x_dec, _ = self.p_sample_ddim(
|
| 476 |
+
x_dec,
|
| 477 |
+
cond,
|
| 478 |
+
ts,
|
| 479 |
+
index=index,
|
| 480 |
+
use_original_steps=use_original_steps,
|
| 481 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 482 |
+
unconditional_conditioning=unconditional_conditioning,
|
| 483 |
+
)
|
| 484 |
+
if callback:
|
| 485 |
+
callback(i)
|
| 486 |
+
return x_dec
|
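For reference, the update assembled at the end of p_sample_ddim above is the standard DDIM step. Writing pred_x0 as \hat{x}_0 and the (guided) noise estimate e_t as \epsilon_\theta, the code computes

    x_{t-1} = \sqrt{\alpha_{t-1}} \, \hat{x}_0 + \sqrt{1 - \alpha_{t-1} - \sigma_t^2} \, \epsilon_\theta + \sigma_t \, z, \qquad z \sim \mathcal{N}(0, I)

With eta = 0 every \sigma_t is zero and sampling is fully deterministic; when mask and x0 are given, the known region is re-injected after each step (the inpainting path).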
sorawm/iopaint/model/anytext/cldm/embedding_manager.py
ADDED
@@ -0,0 +1,185 @@
"""
Copyright (c) Alibaba, Inc. and its affiliates.
"""
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
)


def get_clip_token_for_string(tokenizer, string):
    batch_encoding = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    tokens = batch_encoding["input_ids"]
    assert (
        torch.count_nonzero(tokens - 49407) == 2
    ), f"String '{string}' maps to more than a single token. Please use another string"
    return tokens[0, 1]


def get_bert_token_for_string(tokenizer, string):
    token = tokenizer(string)
    assert (
        torch.count_nonzero(token) == 3
    ), f"String '{string}' maps to more than a single token. Please use another string"
    token = token[0, 1]
    return token


def get_clip_vision_emb(encoder, processor, img):
    _img = img.repeat(1, 3, 1, 1) * 255
    inputs = processor(images=_img, return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].to(img.device)
    outputs = encoder(**inputs)
    emb = outputs.image_embeds
    return emb


def get_recog_emb(encoder, img_list):
    _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list]
    encoder.predictor.eval()
    _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
    return preds_neck


def pad_H(x):
    _, _, H, W = x.shape
    p_top = (W - H) // 2
    p_bot = W - H - p_top
    return F.pad(x, (0, 0, p_top, p_bot))


class EncodeNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncodeNet, self).__init__()
        chan = 16
        n_layer = 4  # downsample

        self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
        self.conv_list = nn.ModuleList([])
        _c = chan
        for i in range(n_layer):
            self.conv_list.append(conv_nd(2, _c, _c * 2, 3, padding=1, stride=2))
            _c *= 2
        self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.act(self.conv1(x))
        for layer in self.conv_list:
            x = self.act(layer(x))
        x = self.act(self.conv2(x))
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x


class EmbeddingManager(nn.Module):
    def __init__(
        self,
        embedder,
        valid=True,
        glyph_channels=20,
        position_channels=1,
        placeholder_string="*",
        add_pos=False,
        emb_type="ocr",
        **kwargs,
    ):
        super().__init__()
        if hasattr(embedder, "tokenizer"):  # using Stable Diffusion's CLIP encoder
            get_token_for_string = partial(
                get_clip_token_for_string, embedder.tokenizer
            )
            token_dim = 768
            if hasattr(embedder, "vit"):
                assert emb_type == "vit"
                self.get_vision_emb = partial(
                    get_clip_vision_emb, embedder.vit, embedder.processor
                )
            self.get_recog_emb = None
        else:  # using LDM's BERT encoder
            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
            token_dim = 1280
        self.token_dim = token_dim
        self.emb_type = emb_type

        self.add_pos = add_pos
        if add_pos:
            self.position_encoder = EncodeNet(position_channels, token_dim)
        if emb_type == "ocr":
            self.proj = linear(40 * 64, token_dim)
        if emb_type == "conv":
            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)

        self.placeholder_token = get_token_for_string(placeholder_string)

    def encode_text(self, text_info):
        if self.get_recog_emb is None and self.emb_type == "ocr":
            self.get_recog_emb = partial(get_recog_emb, self.recog)

        gline_list = []
        pos_list = []
        for i in range(len(text_info["n_lines"])):  # sample index in a batch
            n_lines = text_info["n_lines"][i]
            for j in range(n_lines):  # line
                gline_list += [text_info["gly_line"][j][i : i + 1]]
                if self.add_pos:
                    pos_list += [text_info["positions"][j][i : i + 1]]

        if len(gline_list) > 0:
            if self.emb_type == "ocr":
                recog_emb = self.get_recog_emb(gline_list)
                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
            elif self.emb_type == "vit":
                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
            elif self.emb_type == "conv":
                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
            if self.add_pos:
                enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
                enc_glyph = enc_glyph + enc_pos

        self.text_embs_all = []
        n_idx = 0
        for i in range(len(text_info["n_lines"])):  # sample index in a batch
            n_lines = text_info["n_lines"][i]
            text_embs = []
            for j in range(n_lines):  # line
                text_embs += [enc_glyph[n_idx : n_idx + 1]]
                n_idx += 1
            self.text_embs_all += [text_embs]

    def forward(
        self,
        tokenized_text,
        embedded_text,
    ):
        b, device = tokenized_text.shape[0], tokenized_text.device
        for i in range(b):
            idx = tokenized_text[i] == self.placeholder_token.to(device)
            if sum(idx) > 0:
                if i >= len(self.text_embs_all):
                    print("truncation for log images...")
                    break
                text_emb = torch.cat(self.text_embs_all[i], dim=0)
                if sum(idx) != len(text_emb):
                    print("truncation for long caption...")
                embedded_text[i][idx] = text_emb[: sum(idx)]
        return embedded_text

    def embedding_parameters(self):
        return self.parameters()
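The core of EmbeddingManager.forward above is a masked assignment that swaps the precomputed glyph embeddings in wherever the placeholder token occurs. A minimal sketch of that splice with toy tensors (token id 99, token_dim 4, and all values are illustrative only, not constants from this repo):

import torch

tokenized = torch.tensor([0, 99, 7, 99, 1])    # 99 stands in for the "*" placeholder id
embedded = torch.zeros(5, 4)                   # per-token text embeddings (token_dim = 4)
glyph_embs = torch.arange(8.0).reshape(2, 4)   # one glyph/OCR embedding per text line

idx = tokenized == 99                          # positions of the placeholder token
embedded[idx] = glyph_embs[: idx.sum()]        # the same masked write as forward()
print(embedded)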
sorawm/iopaint/model/anytext/cldm/hack.py
ADDED
@@ -0,0 +1,128 @@
import einops
import torch
from transformers import logging

import sorawm.iopaint.model.anytext.ldm.modules.attention
import sorawm.iopaint.model.anytext.ldm.modules.encoders.modules
from sorawm.iopaint.model.anytext.ldm.modules.attention import default


def disable_verbosity():
    logging.set_verbosity_error()
    print("logging improved.")
    return


def enable_sliced_attention():
    sorawm.iopaint.model.anytext.ldm.modules.attention.CrossAttention.forward = (
        _hacked_sliced_attentin_forward
    )
    print("Enabled sliced_attention.")
    return


def hack_everything(clip_skip=0):
    disable_verbosity()
    sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = (
        _hacked_clip_forward
    )
    sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = (
        clip_skip
    )
    print("Enabled clip hacks.")
    return


# Written by Lvmin
def _hacked_clip_forward(self, text):
    PAD = self.tokenizer.pad_token_id
    EOS = self.tokenizer.eos_token_id
    BOS = self.tokenizer.bos_token_id

    def tokenize(t):
        return self.tokenizer(t, truncation=False, add_special_tokens=False)[
            "input_ids"
        ]

    def transformer_encode(t):
        if self.clip_skip > 1:
            rt = self.transformer(input_ids=t, output_hidden_states=True)
            return self.transformer.text_model.final_layer_norm(
                rt.hidden_states[-self.clip_skip]
            )
        else:
            return self.transformer(
                input_ids=t, output_hidden_states=False
            ).last_hidden_state

    def split(x):
        return x[75 * 0 : 75 * 1], x[75 * 1 : 75 * 2], x[75 * 2 : 75 * 3]

    def pad(x, p, i):
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    raw_tokens_list = tokenize(text)
    tokens_list = []

    for raw_tokens in raw_tokens_list:
        raw_tokens_123 = split(raw_tokens)
        raw_tokens_123 = [
            [BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123
        ]
        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
        tokens_list.append(raw_tokens_123)

    tokens_list = torch.IntTensor(tokens_list).to(self.device)

    feed = einops.rearrange(tokens_list, "b f i -> (b f) i")
    y = transformer_encode(feed)
    z = einops.rearrange(y, "(b f) i c -> b (f i) c", f=3)

    return z


# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(
        lambda t: einops.rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
    )

    limit = k.shape[0]
    att_step = 1
    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))

    q_chunks.reverse()
    k_chunks.reverse()
    v_chunks.reverse()
    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    del k, q, v
    for i in range(0, limit, att_step):
        q_buffer = q_chunks.pop()
        k_buffer = k_chunks.pop()
        v_buffer = v_chunks.pop()
        sim_buffer = (
            torch.einsum("b i d, b j d -> b i j", q_buffer, k_buffer) * self.scale
        )

        del k_buffer, q_buffer
        # attention, what we cannot get enough of, by chunks

        sim_buffer = sim_buffer.softmax(dim=-1)

        sim_buffer = torch.einsum("b i j, b j d -> b i d", sim_buffer, v_buffer)
        del v_buffer
        sim[i : i + att_step, :, :] = sim_buffer

        del sim_buffer
    sim = einops.rearrange(sim, "(b h) n d -> b n (h d)", h=h)
    return self.to_out(sim)
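_hacked_sliced_attentin_forward trades speed for peak memory: rather than materializing the full (batch*heads, n, n) similarity tensor at once, it computes attention one slice of the fused batch-head dimension at a time (att_step = 1). A toy-shape sketch showing that the sliced computation agrees with full attention (the self.scale factor is omitted on both sides, so the equivalence still holds):

import torch

torch.manual_seed(0)
q, k, v = (torch.randn(4, 5, 8) for _ in range(3))  # (batch*heads, tokens, head_dim)

sim = torch.einsum("bid,bjd->bij", q, k).softmax(dim=-1)
full = torch.einsum("bij,bjd->bid", sim, v)

sliced = torch.zeros_like(full)
for i in range(q.shape[0]):  # one batch-head slice at a time, as in the hack above
    s = torch.einsum("id,jd->ij", q[i], k[i]).softmax(dim=-1)
    sliced[i] = torch.einsum("ij,jd->id", s, v[i])

assert torch.allclose(full, sliced, atol=1e-6)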
sorawm/iopaint/model/anytext/cldm/model.py
ADDED
@@ -0,0 +1,41 @@
import os

import torch
from omegaconf import OmegaConf

from sorawm.iopaint.model.anytext.ldm.util import instantiate_from_config


def get_state_dict(d):
    return d.get("state_dict", d)


def load_state_dict(ckpt_path, location="cpu"):
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        import safetensors.torch

        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = get_state_dict(
            torch.load(ckpt_path, map_location=torch.device(location))
        )
    state_dict = get_state_dict(state_dict)
    print(f"Loaded state_dict from [{ckpt_path}]")
    return state_dict


def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
    config = OmegaConf.load(config_path)
    # if cond_stage_path:
    #     config.model.params.cond_stage_config.params.version = (
    #         cond_stage_path  # use pre-downloaded ckpts, in case blocked
    #     )
    config.model.params.cond_stage_config.params.device = str(device)
    if use_fp16:
        config.model.params.use_fp16 = True
        config.model.params.control_stage_config.params.use_fp16 = True
        config.model.params.unet_config.params.use_fp16 = True
    model = instantiate_from_config(config.model).cpu()
    print(f"Loaded model config from [{config_path}]")
    return model
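One detail worth noting in load_state_dict above: get_state_dict is applied to the torch.load result and then once more to its own output. The helper is idempotent on a plain weights dict and unwraps one {"state_dict": ...} layer otherwise, so the double application tolerates both plain and (possibly nested) Lightning-style checkpoints. A standalone toy check (the helper is re-declared locally so the snippet runs on its own):

def get_state_dict(d):  # same one-liner as in model.py above
    return d.get("state_dict", d)

plain = {"w": 0.0}
wrapped = {"state_dict": plain}
double_wrapped = {"state_dict": wrapped}

# Applying the helper twice normalizes all three layouts to the plain dict.
assert get_state_dict(get_state_dict(plain)) == plain
assert get_state_dict(get_state_dict(wrapped)) == plain
assert get_state_dict(get_state_dict(double_wrapped)) == plain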
sorawm/iopaint/model/anytext/cldm/recognizer.py
ADDED
@@ -0,0 +1,302 @@
"""
Copyright (c) Alibaba, Inc. and its affiliates.
"""
import math
import os
import time
import traceback

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from easydict import EasyDict as edict

from sorawm.iopaint.model.anytext.ocr_recog.RecModel import RecModel


def min_bounding_rect(img):
    ret, thresh = cv2.threshold(img, 127, 255, 0)
    contours, hierarchy = cv2.findContours(
        thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if len(contours) == 0:
        print("Bad contours, using fake bbox...")
        return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
    max_contour = max(contours, key=cv2.contourArea)
    rect = cv2.minAreaRect(max_contour)
    box = cv2.boxPoints(rect)
    box = np.intp(box)  # np.int0 is an alias of np.intp that was removed in NumPy 2.x
    # sort
    x_sorted = sorted(box, key=lambda x: x[0])
    left = x_sorted[:2]
    right = x_sorted[2:]
    left = sorted(left, key=lambda x: x[1])
    (tl, bl) = left
    right = sorted(right, key=lambda x: x[1])
    (tr, br) = right
    if tl[1] > bl[1]:
        (tl, bl) = (bl, tl)
    if tr[1] > br[1]:
        (tr, br) = (br, tr)
    return np.array([tl, tr, br, bl])


def create_predictor(model_dir=None, model_lang="ch", is_onnx=False):
    model_file_path = model_dir
    if model_file_path is not None and not os.path.exists(model_file_path):
        raise ValueError("not find model file path {}".format(model_file_path))

    if is_onnx:
        import onnxruntime as ort

        sess = ort.InferenceSession(
            model_file_path, providers=["CPUExecutionProvider"]
        )  # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
        return sess
    else:
        if model_lang == "ch":
            n_class = 6625
        elif model_lang == "en":
            n_class = 97
        else:
            raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
        rec_config = edict(
            in_channels=3,
            backbone=edict(
                type="MobileNetV1Enhance",
                scale=0.5,
                last_conv_stride=[1, 2],
                last_pool_type="avg",
            ),
            neck=edict(
                type="SequenceEncoder",
                encoder_type="svtr",
                dims=64,
                depth=2,
                hidden_dims=120,
                use_guide=True,
            ),
            head=edict(
                type="CTCHead",
                fc_decay=0.00001,
                out_channels=n_class,
                return_feats=True,
            ),
        )

        rec_model = RecModel(rec_config)
        if model_file_path is not None:
            rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
            rec_model.eval()
        return rec_model.eval()


def _check_image_file(path):
    img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff"}
    return any([path.lower().endswith(e) for e in img_end])


def get_image_file_list(img_file):
    imgs_lists = []
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))
    if os.path.isfile(img_file) and _check_image_file(img_file):
        imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        for single_file in os.listdir(img_file):
            file_path = os.path.join(img_file, single_file)
            if os.path.isfile(file_path) and _check_image_file(file_path):
                imgs_lists.append(file_path)
    if len(imgs_lists) == 0:
        raise Exception("not found any img file in {}".format(img_file))
    imgs_lists = sorted(imgs_lists)
    return imgs_lists


class TextRecognizer(object):
    def __init__(self, args, predictor):
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.rec_batch_num = args.rec_batch_num
        self.predictor = predictor
        self.chars = self.get_char_dict(args.rec_char_dict_path)
        self.char2id = {x: i for i, x in enumerate(self.chars)}
        self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
        self.use_fp16 = args.use_fp16

    # img: CHW
    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[0]
        imgW = int((imgH * max_wh_ratio))

        h, w = img.shape[1:]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = torch.nn.functional.interpolate(
            img.unsqueeze(0),
            size=(imgH, resized_w),
            mode="bilinear",
            align_corners=True,
        )
        resized_image /= 255.0
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
        padding_im[:, :, 0:resized_w] = resized_image[0]
        return padding_im

    # img_list: list of tensors with shape chw 0-255
    def pred_imglist(self, img_list, show_debug=False, is_ori=False):
        img_num = len(img_list)
        assert img_num > 0
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[2] / float(img.shape[1]))
        # Sorting can speed up the recognition process
        indices = torch.from_numpy(np.argsort(np.array(width_list)))
        batch_num = self.rec_batch_num
        preds_all = [None] * img_num
        preds_neck_all = [None] * img_num
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []

            imgC, imgH, imgW = self.rec_image_shape[:3]
            max_wh_ratio = imgW / imgH
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[1:]
                if h > w * 1.2:
                    img = img_list[indices[ino]]
                    img = torch.transpose(img, 1, 2).flip(dims=[1])
                    img_list[indices[ino]] = img
                    h, w = img.shape[1:]
                # wh_ratio = w * 1.0 / h
                # max_wh_ratio = max(max_wh_ratio, wh_ratio)  # comment to not use different ratio
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
                if self.use_fp16:
                    norm_img = norm_img.half()
                norm_img = norm_img.unsqueeze(0)
                norm_img_batch.append(norm_img)
            norm_img_batch = torch.cat(norm_img_batch, dim=0)
            if show_debug:
                for i in range(len(norm_img_batch)):
                    _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
                    _img = (_img + 0.5) * 255
                    _img = _img[:, :, ::-1]
                    file_name = f"{indices[beg_img_no + i]}"
                    file_name = file_name + "_ori" if is_ori else file_name
                    cv2.imwrite(file_name + ".jpg", _img)
            if self.is_onnx:
                input_dict = {}
                input_dict[self.predictor.get_inputs()[0].name] = (
                    norm_img_batch.detach().cpu().numpy()
                )
                outputs = self.predictor.run(None, input_dict)
                preds = {}
                preds["ctc"] = torch.from_numpy(outputs[0])
                preds["ctc_neck"] = [torch.zeros(1)] * img_num
            else:
                preds = self.predictor(norm_img_batch)
            for rno in range(preds["ctc"].shape[0]):
                preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno]
                preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno]

        return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)

    def get_char_dict(self, character_dict_path):
        character_str = []
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            for line in lines:
                line = line.decode("utf-8").strip("\n").strip("\r\n")
                character_str.append(line)
        dict_character = list(character_str)
        dict_character = ["sos"] + dict_character + [" "]  # eos is space
        return dict_character

    def get_text(self, order):
        char_list = [self.chars[text_id] for text_id in order]
        return "".join(char_list)

    def decode(self, mat):
        text_index = mat.detach().cpu().numpy().argmax(axis=1)
        ignored_tokens = [0]
        selection = np.ones(len(text_index), dtype=bool)
        selection[1:] = text_index[1:] != text_index[:-1]
        for ignored_token in ignored_tokens:
            selection &= text_index != ignored_token
        return text_index[selection], np.where(selection)[0]

    def get_ctcloss(self, preds, gt_text, weight):
        if not isinstance(weight, torch.Tensor):
            weight = torch.tensor(weight).to(preds.device)
        ctc_loss = torch.nn.CTCLoss(reduction="none")
        log_probs = preds.log_softmax(dim=2).permute(1, 0, 2)  # NTC-->TNC
        targets = []
        target_lengths = []
        for t in gt_text:
            targets += [self.char2id.get(i, len(self.chars) - 1) for i in t]
            target_lengths += [len(t)]
        targets = torch.tensor(targets).to(preds.device)
        target_lengths = torch.tensor(target_lengths).to(preds.device)
        input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(
            preds.device
        )
        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
        loss = loss / input_lengths * weight
        return loss


def main():
    rec_model_dir = "./ocr_weights/ppv3_rec.pth"
    predictor = create_predictor(rec_model_dir)
    args = edict()
    args.rec_image_shape = "3, 48, 320"
    args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt"
    args.rec_batch_num = 6
    args.use_fp16 = False  # TextRecognizer.__init__ reads this attribute
    text_recognizer = TextRecognizer(args, predictor)
    image_dir = "./test_imgs_cn"
    gt_text = ["韩国小馆"] * 14

    image_file_list = get_image_file_list(image_dir)
    valid_image_file_list = []
    img_list = []

    for image_file in image_file_list:
        img = cv2.imread(image_file)
        if img is None:
            print("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
    try:
        tic = time.time()
        times = []
        for i in range(10):
            preds, _ = text_recognizer.pred_imglist(img_list)  # get text
            preds_all = preds.softmax(dim=2)
            times += [(time.time() - tic) * 1000.0]
            tic = time.time()
        print(times)
        print(np.mean(times[1:]) / len(preds_all))
        weight = np.ones(len(gt_text))
        loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
        for i in range(len(valid_image_file_list)):
            pred = preds_all[i]
            order, idx = text_recognizer.decode(pred)
            text = text_recognizer.get_text(order)
            print(
                f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}'
            )
    except Exception as E:
        print(traceback.format_exc(), E)


if __name__ == "__main__":
    main()
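TextRecognizer.decode above is a greedy CTC decode: take the argmax class per frame, collapse consecutive repeats, then drop the blank token (index 0). A self-contained toy run of the same three lines of NumPy logic:

import numpy as np

text_index = np.array([0, 3, 3, 0, 5, 5, 5, 2])    # per-frame argmax ids; 0 is the CTC blank
selection = np.ones(len(text_index), dtype=bool)
selection[1:] = text_index[1:] != text_index[:-1]  # keep only the first frame of each run
selection &= text_index != 0                       # drop blanks
print(text_index[selection])                       # -> [3 5 2]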
sorawm/iopaint/model/anytext/ldm/__init__.py
ADDED
File without changes
sorawm/iopaint/model/anytext/ldm/models/__init__.py
ADDED
File without changes
sorawm/iopaint/model/anytext/ldm/models/autoencoder.py
ADDED
@@ -0,0 +1,275 @@
from contextlib import contextmanager

import torch
import torch.nn.functional as F

from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.model import (
    Decoder,
    Encoder,
)
from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)
from sorawm.iopaint.model.anytext.ldm.modules.ema import LitEma
from sorawm.iopaint.model.anytext.ldm.util import instantiate_from_config


class AutoencoderKL(torch.nn.Module):
    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=[],
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        ema_decay=None,
        learn_logvar=False,
    ):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0.0 < ema_decay < 1.0
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log(
                "aeloss",
                aeloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )
            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )

            self.log(
                "discloss",
                discloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix=""):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        ae_params_list = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters())
        )
        if self.learn_logvar:
            print(f"{self.__class__.__name__}: Learning logvar")
            ae_params_list.append(self.loss.logvar)
        opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
        )
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log["samples_ema"] = self.decode(
                        torch.randn_like(posterior_ema.sample())
                    )
                    log["reconstructions_ema"] = xrec_ema
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
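AutoencoderKL.encode wraps the quant_conv output in a DiagonalGaussianDistribution. Assuming the standard LDM implementation of that class (channel-wise chunk into mean and logvar, then a reparameterized sample; the real class also clamps logvar), posterior.sample() in forward reduces to this toy computation:

import torch

moments = torch.randn(1, 8, 4, 4)  # stand-in for quant_conv(h): 2 * embed_dim channels
mean, logvar = torch.chunk(moments, 2, dim=1)
z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)  # reparameterized sample
print(z.shape)  # torch.Size([1, 4, 4, 4]) -- an embed_dim-channel latent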
sorawm/iopaint/model/anytext/ldm/models/diffusion/__init__.py
ADDED
File without changes
sorawm/iopaint/model/anytext/ldm/models/diffusion/ddim.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
|
| 8 |
+
extract_into_tensor,
|
| 9 |
+
make_ddim_sampling_parameters,
|
| 10 |
+
make_ddim_timesteps,
|
| 11 |
+
noise_like,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DDIMSampler(object):
|
| 16 |
+
def __init__(self, model, schedule="linear", **kwargs):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.model = model
|
| 19 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
| 20 |
+
self.schedule = schedule
|
| 21 |
+
|
| 22 |
+
def register_buffer(self, name, attr):
|
| 23 |
+
if type(attr) == torch.Tensor:
|
| 24 |
+
if attr.device != torch.device("cuda"):
|
| 25 |
+
attr = attr.to(torch.device("cuda"))
|
| 26 |
+
setattr(self, name, attr)
|
| 27 |
+
|
| 28 |
+
def make_schedule(
|
| 29 |
+
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
| 30 |
+
):
|
| 31 |
+
self.ddim_timesteps = make_ddim_timesteps(
|
| 32 |
+
ddim_discr_method=ddim_discretize,
|
| 33 |
+
num_ddim_timesteps=ddim_num_steps,
|
| 34 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
| 35 |
+
verbose=verbose,
|
| 36 |
+
)
|
| 37 |
+
alphas_cumprod = self.model.alphas_cumprod
|
| 38 |
+
assert (
|
| 39 |
+
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
| 40 |
+
), "alphas have to be defined for each timestep"
|
| 41 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
| 42 |
+
|
| 43 |
+
self.register_buffer("betas", to_torch(self.model.betas))
|
| 44 |
+
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
| 45 |
+
self.register_buffer(
|
| 46 |
+
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 50 |
+
self.register_buffer(
|
| 51 |
+
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
| 52 |
+
)
|
| 53 |
+
self.register_buffer(
|
| 54 |
+
"sqrt_one_minus_alphas_cumprod",
|
| 55 |
+
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
| 56 |
+
)
|
| 57 |
+
self.register_buffer(
|
| 58 |
+
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
| 59 |
+
)
|
| 60 |
+
self.register_buffer(
|
| 61 |
+
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
| 62 |
+
)
|
| 63 |
+
self.register_buffer(
|
| 64 |
+
"sqrt_recipm1_alphas_cumprod",
|
| 65 |
+
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# ddim sampling parameters
|
| 69 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
| 70 |
+
alphacums=alphas_cumprod.cpu(),
|
| 71 |
+
ddim_timesteps=self.ddim_timesteps,
|
| 72 |
+
eta=ddim_eta,
|
| 73 |
+
verbose=verbose,
|
| 74 |
+
)
|
| 75 |
+
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
| 76 |
+
self.register_buffer("ddim_alphas", ddim_alphas)
|
| 77 |
+
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
| 78 |
+
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
| 79 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
| 80 |
+
(1 - self.alphas_cumprod_prev)
|
| 81 |
+
/ (1 - self.alphas_cumprod)
|
| 82 |
+
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
| 83 |
+
)
|
| 84 |
+
self.register_buffer(
|
| 85 |
+
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,  # must come in the same format as the conditioning, e.g. as encoded tokens
        dynamic_threshold=None,
        ucg_schedule=None,
        **kwargs,
    ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(
                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
                    )

            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(
                            f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}"
                        )

            else:
                if conditioning.shape[0] != batch_size:
                    print(
                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
                    )

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f"Data shape for DDIM sampling is {size}, eta {eta}")

        samples, intermediates = self.ddim_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            dynamic_threshold=dynamic_threshold,
            ucg_schedule=ucg_schedule,
        )
        return samples, intermediates

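    # Illustrative call, not part of the original file (`model`, `cond`, and
    # `uncond` are placeholders for a wrapped LatentDiffusion instance and its
    # conditioning/unconditional embeddings):
    #
    #     sampler = DDIMSampler(model)
    #     samples, intermediates = sampler.sample(
    #         S=50,                              # number of DDIM steps
    #         batch_size=1,
    #         shape=(4, 64, 64),                 # latent (C, H, W); batch dim added inside
    #         conditioning=cond,
    #         unconditional_guidance_scale=7.5,
    #         unconditional_conditioning=uncond,
    #         eta=0.0,                           # eta=0 -> deterministic DDIM
    #     )
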
    @torch.no_grad()
    def ddim_sampling(
        self,
        cond,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
        ucg_schedule=None,
    ):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img], "index": [10000]}
        time_range = (
            reversed(range(0, timesteps))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(
                    x0, ts
                )  # TODO: deterministic forward pass?
                img = img_orig * mask + (1.0 - mask) * img

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                dynamic_threshold=dynamic_threshold,
            )
            img, pred_x0 = outs
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates["x_inter"].append(img)
                intermediates["pred_x0"].append(pred_x0)
                intermediates["index"].append(index)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
    ):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
            model_output = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    elif isinstance(c[k], dict):
                        c_in[k] = dict()
                        for key in c[k]:
                            if isinstance(c[k][key], list):
                                if not isinstance(c[k][key][0], torch.Tensor):
                                    continue
                                c_in[k][key] = [
                                    torch.cat(
                                        [
                                            unconditional_conditioning[k][key][i],
                                            c[k][key][i],
                                        ]
                                    )
                                    for i in range(len(c[k][key]))
                                ]
                            else:
                                c_in[k][key] = torch.cat(
                                    [unconditional_conditioning[k][key], c[k][key]]
                                )

                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            elif isinstance(c, list):
                c_in = list()
                assert isinstance(unconditional_conditioning, list)
                for i in range(len(c)):
                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (
                model_t - model_uncond
            )

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", "not implemented"
            e_t = score_corrector.modify_score(
                self.model, e_t, x, t, c, **corrector_kwargs
            )

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(
            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
        )

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

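    # How eta controls stochasticity: make_ddim_sampling_parameters (assumed to
    # follow the upstream LDM util) sets, per step,
    #
    #     sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
    #
    # With eta=0 the noise term above vanishes and the update reduces to the
    # deterministic DDIM rule x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev) * e_t,
    # while eta=1 recovers DDPM-like variance. Quick numeric check:
    #
    #     a_t, a_prev, eta = 0.5, 0.7, 1.0
    #     sigma_t = eta * ((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)) ** 0.5
    #     # sigma_t ~= 0.414
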
    @torch.no_grad()
    def encode(
        self,
        x0,
        c,
        t_enc,
        use_original_steps=False,
        return_intermediates=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        callback=None,
    ):
        num_reference_steps = (
            self.ddpm_num_timesteps
            if use_original_steps
            else self.ddim_timesteps.shape[0]
        )

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc="Encoding Image"):
            t = torch.full(
                (x0.shape[0],), i, device=self.model.device, dtype=torch.long
            )
            if unconditional_guidance_scale == 1.0:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(
                        torch.cat((x_next, x_next)),
                        torch.cat((t, t)),
                        torch.cat((unconditional_conditioning, c)),
                    ),
                    2,
                )
                noise_pred = e_t_uncond + unconditional_guidance_scale * (
                    noise_pred - e_t_uncond
                )

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = (
                alphas_next[i].sqrt()
                * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
                * noise_pred
            )
            x_next = xt_weighted + weighted_noise_pred
            if (
                return_intermediates
                and i % (num_steps // return_intermediates) == 0
                and i < num_steps - 1
            ):
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback:
                callback(i)

        out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
        if return_intermediates:
            out.update({"intermediates": intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
        )

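    # stochastic_encode and decode together form a minimal img2img loop (hedged
    # sketch; `init_latent`, `cond`, `uncond` are placeholders for a first-stage
    # encoding and conditioning tensors):
    #
    #     sampler.make_schedule(ddim_num_steps=50, ddim_eta=0.0)
    #     t_enc = int(0.75 * 50)   # denoising "strength" of 0.75
    #     z_t = sampler.stochastic_encode(
    #         init_latent, torch.tensor([t_enc], device=init_latent.device)
    #     )
    #     z_0 = sampler.decode(
    #         z_t,
    #         cond,
    #         t_enc,
    #         unconditional_guidance_scale=7.5,
    #         unconditional_conditioning=uncond,
    #     )
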
    @torch.no_grad()
    def decode(
        self,
        x_latent,
        cond,
        t_start,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
        callback=None,
    ):
        timesteps = (
            np.arange(self.ddpm_num_timesteps)
            if use_original_steps
            else self.ddim_timesteps
        )
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full(
                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
            )
            x_dec, _ = self.p_sample_ddim(
                x_dec,
                cond,
                ts,
                index=index,
                use_original_steps=use_original_steps,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
            )
            if callback:
                callback(i)
        return x_dec

sorawm/iopaint/model/anytext/ldm/models/diffusion/ddpm.py
ADDED
@@ -0,0 +1,2386 @@
"""
Part of the implementation is borrowed and modified from ControlNet, publicly available at https://github.com/lllyasviel/ControlNet/blob/main/ldm/models/diffusion/ddpm.py
"""

import itertools
from contextlib import contextmanager, nullcontext
from functools import partial

import cv2
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from torch.optim.lr_scheduler import LambdaLR
from torchvision.utils import make_grid
from tqdm import tqdm

from sorawm.iopaint.model.anytext.ldm.models.autoencoder import (
    AutoencoderKL,
    IdentityFirstStage,
)
from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
    extract_into_tensor,
    make_beta_schedule,
    noise_like,
)
from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
    normal_kl,
)
from sorawm.iopaint.model.anytext.ldm.modules.ema import LitEma
from sorawm.iopaint.model.anytext.ldm.util import (
    count_params,
    default,
    exists,
    instantiate_from_config,
    isimage,
    ismap,
    log_txt_as_img,
    mean_flat,
)

__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}

PRINT_DEBUG = False


def print_grad(grad):
    # print('Gradient:', grad)
    # print(grad.shape)
    a = grad.max()
    b = grad.min()
    # print(f'mean={grad.mean():.4f}, max={a:.4f}, min={b:.4f}')
    s = 255.0 / (a - b)
    c = 255 * (-b / (a - b))
    grad = grad * s + c
    # print(f'mean={grad.mean():.4f}, max={grad.max():.4f}, min={grad.min():.4f}')
    img = grad[0].permute(1, 2, 0).detach().cpu().numpy()
    if img.shape[0] == 512:
        cv2.imwrite("grad-img.jpg", img)
    elif img.shape[0] == 64:
        cv2.imwrite("grad-latent.jpg", img)


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def uniform_on_device(r1, r2, shape, device):
    return (r1 - r2) * torch.rand(*shape, device=device) + r2


class DDPM(torch.nn.Module):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(
        self,
        unet_config,
        timesteps=1000,
        beta_schedule="linear",
        loss_type="l2",
        ckpt_path=None,
        ignore_keys=[],
        load_only_unet=False,
        monitor="val/loss",
        use_ema=True,
        first_stage_key="image",
        image_size=256,
        channels=3,
        log_every_t=100,
        clip_denoised=True,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        given_betas=None,
        original_elbo_weight=0.0,
        v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
        l_simple_weight=1.0,
        conditioning_key=None,
        parameterization="eps",  # all assuming fixed variance schedules
        scheduler_config=None,
        use_positional_encodings=False,
        learn_logvar=False,
        logvar_init=0.0,
        make_it_fit=False,
        ucg_training=None,
        reset_ema=False,
        reset_num_ema_updates=False,
    ):
        super().__init__()
        assert parameterization in [
            "eps",
            "x0",
            "v",
        ], 'currently only supporting "eps" and "x0" and "v"'
        self.parameterization = parameterization
        print(
            f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
        )
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapper(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        self.make_it_fit = make_it_fit
        if reset_ema:
            assert exists(ckpt_path)
        if ckpt_path is not None:
            self.init_from_ckpt(
                ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
            )
            if reset_ema:
                assert self.use_ema
                print(
                    "Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
                )
                self.model_ema = LitEma(self.model)
        if reset_num_ema_updates:
            print(
                " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
            )
            assert self.use_ema
            self.model_ema.reset_num_updates()

        self.register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(logvar, requires_grad=True)
        else:
            self.register_buffer("logvar", logvar)

        self.ucg_training = ucg_training or dict()
        if self.ucg_training:
            self.ucg_prng = np.random.RandomState()

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # np.save('1.npy', alphas_cumprod)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

        if self.parameterization == "eps":
            lvlb_weights = self.betas**2 / (
                2
                * self.posterior_variance
                * to_torch(alphas)
                * (1 - self.alphas_cumprod)
            )
        elif self.parameterization == "x0":
            lvlb_weights = (
                0.5
                * np.sqrt(torch.Tensor(alphas_cumprod))
                / (2.0 * 1 - torch.Tensor(alphas_cumprod))
            )
        elif self.parameterization == "v":
            lvlb_weights = torch.ones_like(
                self.betas**2
                / (
                    2
                    * self.posterior_variance
                    * to_torch(alphas)
                    * (1 - self.alphas_cumprod)
                )
            )
        else:
            raise NotImplementedError("mu not supported")
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

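    # Orientation for the default schedule registered above (assumption: the
    # "linear" branch of make_beta_schedule interpolates sqrt(beta) between
    # linear_start and linear_end, as in the upstream LDM util):
    #
    #     import numpy as np
    #     betas = np.linspace(1e-4 ** 0.5, 2e-2 ** 0.5, 1000) ** 2
    #     a_bar = np.cumprod(1.0 - betas)
    #     # a_bar[0] ~= 0.9999, a_bar[-1] ~= 4e-05
    #
    # so by t=1000 almost no signal survives, which is what makes x_T ~ N(0, I)
    # a valid starting point for sampling.
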
    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    @torch.no_grad()
    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        if self.make_it_fit:
            n_params = len(
                [
                    name
                    for name, _ in itertools.chain(
                        self.named_parameters(), self.named_buffers()
                    )
                ]
            )
            for name, param in tqdm(
                itertools.chain(self.named_parameters(), self.named_buffers()),
                desc="Fitting old weights to new weights",
                total=n_params,
            ):
                if name not in sd:
                    continue
                old_shape = sd[name].shape
                new_shape = param.shape
                assert len(old_shape) == len(new_shape)
                if len(new_shape) > 2:
                    # we only modify first two axes
                    assert new_shape[2:] == old_shape[2:]
                # assumes first axis corresponds to output dim
                if not new_shape == old_shape:
                    new_param = param.clone()
                    old_param = sd[name]
                    if len(new_shape) == 1:
                        for i in range(new_param.shape[0]):
                            new_param[i] = old_param[i % old_shape[0]]
                    elif len(new_shape) >= 2:
                        for i in range(new_param.shape[0]):
                            for j in range(new_param.shape[1]):
                                new_param[i, j] = old_param[
                                    i % old_shape[0], j % old_shape[1]
                                ]

                        n_used_old = torch.ones(old_shape[1])
                        for j in range(new_param.shape[1]):
                            n_used_old[j % old_shape[1]] += 1
                        n_used_new = torch.zeros(new_shape[1])
                        for j in range(new_param.shape[1]):
                            n_used_new[j] = n_used_old[j % old_shape[1]]

                        n_used_new = n_used_new[None, :]
                        while len(n_used_new.shape) < len(new_shape):
                            n_used_new = n_used_new.unsqueeze(-1)
                        new_param /= n_used_new

                    sd[name] = new_param

        missing, unexpected = (
            self.load_state_dict(sd, strict=False)
            if not only_model
            else self.model.load_state_dict(sd, strict=False)
        )
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys:\n {missing}")
        if len(unexpected) > 0:
            print(f"\nUnexpected Keys:\n {unexpected}")

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            * x_t
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool):
        model_out = self.model(x, t)
        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        if clip_denoised:
            x_recon.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(
            x=x, t=t, clip_denoised=clip_denoised
        )
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_intermediates=False):
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)
        intermediates = [img]
        for i in tqdm(
            reversed(range(0, self.num_timesteps)),
            desc="Sampling t",
            total=self.num_timesteps,
        ):
            img = self.p_sample(
                img,
                torch.full((b,), i, device=device, dtype=torch.long),
                clip_denoised=self.clip_denoised,
            )
            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
                intermediates.append(img)
        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, batch_size=16, return_intermediates=False):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop(
            (batch_size, channels, image_size, image_size),
            return_intermediates=return_intermediates,
        )

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

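    # q_sample is the closed-form forward process
    #
    #     x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps,   eps ~ N(0, I)
    #
    # so any training timestep is reachable in one shot. Standalone toy check
    # that the marginal variance stays ~1 for unit-variance data:
    #
    #     import torch
    #     a_bar_t = torch.tensor(0.3)
    #     x0, eps = torch.randn(100000), torch.randn(100000)
    #     x_t = a_bar_t.sqrt() * x0 + (1 - a_bar_t).sqrt() * eps
    #     # x_t.var() ~= 1.0, i.e. 0.3 + 0.7
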
    def get_v(self, x, noise, t):
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == "l1":
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == "l2":
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "v":
            target = self.get_v(x_start, noise, t)
        else:
            raise NotImplementedError(
                f"Parameterization {self.parameterization} not yet supported"
            )

        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = "train" if self.training else "val"

        loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f"{log_prefix}/loss": loss})

        return loss, loss_dict

    def forward(self, x, *args, **kwargs):
        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(
            0, self.num_timesteps, (x.shape[0],), device=self.device
        ).long()
        return self.p_losses(x, t, *args, **kwargs)

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, "b h w c -> b c h w")
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def shared_step(self, batch):
        x = self.get_input(batch, self.first_stage_key)
        loss, loss_dict = self(x)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        for k in self.ucg_training:
            p = self.ucg_training[k]["p"]
            val = self.ucg_training[k]["val"]
            if val is None:
                val = ""
            for i in range(len(batch[k])):
                if self.ucg_prng.choice(2, p=[1 - p, p]):
                    batch[k][i] = val

        loss, loss_dict = self.shared_step(batch)

        self.log_dict(
            loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True
        )

        self.log(
            "global_step",
            self.global_step,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]["lr"]
            self.log(
                "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
            )

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        _, loss_dict_no_ema = self.shared_step(batch)
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(
            loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
        )
        self.log_dict(
            loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
        )

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    def _get_rows_from_list(self, samples):
        n_imgs_per_row = len(samples)
        denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
        denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        log = dict()
        x = self.get_input(batch, self.first_stage_key)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        log["inputs"] = x

        # get diffusion row
        diffusion_row = list()
        x_start = x[:n_row]

        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(x_start)
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
                diffusion_row.append(x_noisy)

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, denoise_row = self.sample(
                    batch_size=N, return_intermediates=True
                )

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.learn_logvar:
            params = params + [self.logvar]
        opt = torch.optim.AdamW(params, lr=lr)
        return opt
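
# What a DDPM training step above boils down to for the default "eps"
# parameterization (standalone toy sketch, not part of the original file;
# `pred` stands in for the UNet output self.model(x_noisy, t)):
#
#     import torch
#     x0 = torch.randn(8, 3, 32, 32)        # clean batch
#     t = torch.randint(0, 1000, (8,))      # one random timestep per sample
#     noise = torch.randn_like(x0)          # the regression target
#     pred = torch.zeros_like(noise)        # a (bad) constant predictor
#     loss = torch.nn.functional.mse_loss(pred, noise)   # ~1.0 here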
|
| 673 |
+
|
class LatentDiffusion(DDPM):
    """main class"""

    def __init__(
        self,
        first_stage_config,
        cond_stage_config,
        num_timesteps_cond=None,
        cond_stage_key="image",
        cond_stage_trainable=False,
        concat_mode=True,
        cond_stage_forward=None,
        conditioning_key=None,
        scale_factor=1.0,
        scale_by_std=False,
        force_null_conditioning=False,
        *args,
        **kwargs,
    ):
        self.force_null_conditioning = force_null_conditioning
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs["timesteps"]
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = "concat" if concat_mode else "crossattn"
        if (
            cond_stage_config == "__is_unconditional__"
            and not self.force_null_conditioning
        ):
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        reset_ema = kwargs.pop("reset_ema", False)
        reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer("scale_factor", torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
            if reset_ema:
                assert self.use_ema
                print(
                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
                )
                self.model_ema = LitEma(self.model)
        if reset_num_ema_updates:
            print(
                " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
            )
            assert self.use_ema
            self.model_ema.reset_num_updates()

    def make_cond_schedule(
        self,
    ):
        self.cond_ids = torch.full(
            size=(self.num_timesteps,),
            fill_value=self.num_timesteps - 1,
            dtype=torch.long,
        )
        ids = torch.round(
            torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
        ).long()
        self.cond_ids[: self.num_timesteps_cond] = ids

    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # only for very first batch
        if (
            self.scale_by_std
            and self.current_epoch == 0
            and self.global_step == 0
            and batch_idx == 0
            and not self.restarted_from_ckpt
        ):
            assert (
                self.scale_factor == 1.0
            ), "rather not use custom rescaling and std-rescaling simultaneously"
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
            x = super().get_input(batch, self.first_stage_key)
            x = x.to(self.device)
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            del self.scale_factor
            self.register_buffer("scale_factor", 1.0 / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        super().register_schedule(
            given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
        )

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def instantiate_cond_stage(self, config):
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != "__is_first_stage__"
            assert config != "__is_unconditional__"
            model = instantiate_from_config(config)
            self.cond_stage_model = model

    def _get_denoise_row_from_list(
        self, samples, desc="", force_no_decoder_quantization=False
    ):
        denoise_row = []
        for zd in tqdm(samples, desc=desc):
            denoise_row.append(
                self.decode_first_stage(
                    zd.to(self.device), force_not_quantize=force_no_decoder_quantization
                )
            )
        n_imgs_per_row = len(denoise_row)
        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
        denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
        denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, "encode") and callable(
                self.cond_stage_model.encode
            ):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c

    def meshgrid(self, h, w):
        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)

        arr = torch.cat([y, x], dim=-1)
        return arr

    def delta_border(self, h, w):
        """
        :param h: height
        :param w: width
        :return: normalized distance to image border,
         with min distance = 0 at border and max dist = 0.5 at image center
        """
        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
        arr = self.meshgrid(h, w) / lower_right_corner
        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
        edge_dist = torch.min(
            torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1
        )[0]
        return edge_dist

    def get_weighting(self, h, w, Ly, Lx, device):
        weighting = self.delta_border(h, w)
        weighting = torch.clip(
            weighting,
            self.split_input_params["clip_min_weight"],
            self.split_input_params["clip_max_weight"],
        )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(
                L_weighting,
                self.split_input_params["clip_min_tie_weight"],
                self.split_input_params["clip_max_tie_weight"],
            )

            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting

    def get_fold_unfold(
        self, x, kernel_size, stride, uf=1, df=1
    ):  # todo load once not every time, shorten code
        """
        :param x: img of size (bs, c, h, w)
        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
        """
        bs, nc, h, w = x.shape

        # number of crops in image
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            fold_params = dict(
                kernel_size=kernel_size, dilation=1, padding=0, stride=stride
            )
            unfold = torch.nn.Unfold(**fold_params)

            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

            weighting = self.get_weighting(
                kernel_size[0], kernel_size[1], Ly, Lx, x.device
            ).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        elif uf > 1 and df == 1:
            fold_params = dict(
                kernel_size=kernel_size, dilation=1, padding=0, stride=stride
            )
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(
                kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
                dilation=1,
                padding=0,
                stride=(stride[0] * uf, stride[1] * uf),
            )
            fold = torch.nn.Fold(
                output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2
            )

            weighting = self.get_weighting(
                kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device
            ).to(x.dtype)
            normalization = fold(weighting).view(
                1, 1, h * uf, w * uf
            )  # normalizes the overlap
            weighting = weighting.view(
                (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)
            )

        elif df > 1 and uf == 1:
            fold_params = dict(
                kernel_size=kernel_size, dilation=1, padding=0, stride=stride
            )
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(
                kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
                dilation=1,
                padding=0,
                stride=(stride[0] // df, stride[1] // df),
            )
            fold = torch.nn.Fold(
                output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2
            )

            weighting = self.get_weighting(
                kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device
            ).to(x.dtype)
            normalization = fold(weighting).view(
                1, 1, h // df, w // df
            )  # normalizes the overlap
            weighting = weighting.view(
                (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)
            )

        else:
            raise NotImplementedError

        return fold, unfold, normalization, weighting

    @torch.no_grad()
    def get_input(
        self,
        batch,
        k,
        return_first_stage_outputs=False,
        force_c_encode=False,
        cond_key=None,
        return_original_cond=False,
        bs=None,
        return_x=False,
        mask_k=None,
    ):
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()

        if mask_k is not None:
            mx = super().get_input(batch, mask_k)
            if bs is not None:
                mx = mx[:bs]
            mx = mx.to(self.device)
            encoder_posterior = self.encode_first_stage(mx)
            mx = self.get_first_stage_encoding(encoder_posterior).detach()

        if self.model.conditioning_key is not None and not self.force_null_conditioning:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key != self.first_stage_key:
                if cond_key in ["caption", "coordinates_bbox", "txt"]:
                    xc = batch[cond_key]
                elif cond_key in ["class_label", "cls"]:
                    xc = batch
                else:
                    xc = super().get_input(batch, cond_key).to(self.device)
            else:
                xc = x
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, dict) or isinstance(xc, list):
                    c = self.get_learned_conditioning(xc)
                else:
                    c = self.get_learned_conditioning(xc.to(self.device))
            else:
                c = xc
            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y}

        else:
            c = None
            xc = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {"pos_x": pos_x, "pos_y": pos_y}
        out = [z, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z)
            out.extend([x, xrec])
        if return_x:
            out.extend([x])
        if return_original_cond:
            out.append(xc)
        if mask_k:
            out.append(mx)
        return out

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, "b h w c -> b c h w").contiguous()

        z = 1.0 / self.scale_factor * z
        return self.first_stage_model.decode(z)

    def decode_first_stage_grad(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, "b h w c -> b c h w").contiguous()

        z = 1.0 / self.scale_factor * z
        return self.first_stage_model.decode(z)

    @torch.no_grad()
    def encode_first_stage(self, x):
        return self.first_stage_model.encode(x)

    def shared_step(self, batch, **kwargs):
        x, c = self.get_input(batch, self.first_stage_key)
        loss = self(x, c)
        return loss

    def forward(self, x, c, *args, **kwargs):
        t = torch.randint(
            0, self.num_timesteps, (x.shape[0],), device=self.device
        ).long()
        # t = torch.randint(500, 501, (x.shape[0],), device=self.device).long()
        if self.model.conditioning_key is not None:
            assert c is not None
            if self.cond_stage_trainable:
                c = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        return self.p_losses(x, c, t, *args, **kwargs)

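    # apply_model normalizes whatever conditioning it receives into the dict
    # form DiffusionWrapper expects ({"c_concat": [...]} or
    # {"c_crossattn": [...]}); a dict passed in is forwarded unchanged, which
    # is the hybrid case.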
    def apply_model(self, x_noisy, t, cond, return_ids=False):
        if isinstance(cond, dict):
            # hybrid case, cond is expected to be a dict
            pass
        else:
            if not isinstance(cond, list):
                cond = [cond]
            key = (
                "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
            )
            cond = {key: cond}

        x_recon = self.model(x_noisy, t, **cond)

        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.
        This term can't be optimized, as it only depends on the encoder.
        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def p_mean_variance(
        self,
        x,
        c,
        t,
        clip_denoised: bool,
        return_codebook_ids=False,
        quantize_denoised=False,
        return_x0=False,
        score_corrector=None,
        corrector_kwargs=None,
    ):
        t_in = t
        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

        if score_corrector is not None:
            assert self.parameterization == "eps"
            model_out = score_corrector.modify_score(
                self, model_out, x, t, c, **corrector_kwargs
            )

        if return_codebook_ids:
            model_out, logits = model_out

        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError()

        if clip_denoised:
            x_recon.clamp_(-1.0, 1.0)
        if quantize_denoised:
            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        if return_codebook_ids:
            return model_mean, posterior_variance, posterior_log_variance, logits
        elif return_x0:
            return model_mean, posterior_variance, posterior_log_variance, x_recon
        else:
            return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(
        self,
        x,
        c,
        t,
        clip_denoised=False,
        repeat_noise=False,
        return_codebook_ids=False,
        quantize_denoised=False,
        return_x0=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
    ):
        b, *_, device = *x.shape, x.device
        outputs = self.p_mean_variance(
            x=x,
            c=c,
            t=t,
            clip_denoised=clip_denoised,
            return_codebook_ids=return_codebook_ids,
            quantize_denoised=quantize_denoised,
            return_x0=return_x0,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
        )
        if return_codebook_ids:
            raise DeprecationWarning("Support dropped.")
            model_mean, _, model_log_variance, logits = outputs
        elif return_x0:
            model_mean, _, model_log_variance, x0 = outputs
        else:
            model_mean, _, model_log_variance = outputs

        noise = noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        if return_codebook_ids:
            return (
                model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
                logits.argmax(dim=1),
            )
        if return_x0:
            return (
                model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
                x0,
            )
        else:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def progressive_denoising(
        self,
        cond,
        shape,
        verbose=True,
        callback=None,
        quantize_denoised=False,
        img_callback=None,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        batch_size=None,
        x_T=None,
        start_T=None,
        log_every_t=None,
    ):
        if not log_every_t:
            log_every_t = self.log_every_t
        timesteps = self.num_timesteps
        if batch_size is not None:
            b = batch_size if batch_size is not None else shape[0]
            shape = [batch_size] + list(shape)
        else:
            b = batch_size = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=self.device)
        else:
            img = x_T
        intermediates = []
        if cond is not None:
            if isinstance(cond, dict):
                cond = {
                    key: cond[key][:batch_size]
                    if not isinstance(cond[key], list)
                    else list(map(lambda x: x[:batch_size], cond[key]))
                    for key in cond
                }
            else:
                cond = (
                    [c[:batch_size] for c in cond]
                    if isinstance(cond, list)
                    else cond[:batch_size]
                )

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = (
            tqdm(
                reversed(range(0, timesteps)),
                desc="Progressive Generation",
                total=timesteps,
            )
            if verbose
            else reversed(range(0, timesteps))
        )
        if type(temperature) == float:
            temperature = [temperature] * timesteps

        for i in iterator:
            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != "hybrid"
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img, x0_partial = self.p_sample(
                img,
                cond,
                ts,
                clip_denoised=self.clip_denoised,
                quantize_denoised=quantize_denoised,
                return_x0=True,
                temperature=temperature[i],
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
            )
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1.0 - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(x0_partial)
            if callback:
                callback(i)
            if img_callback:
                img_callback(img, i)
        return img, intermediates

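    # p_sample_loop is the plain ancestral sampler: it walks t = T-1, ..., 0
    # and calls p_sample at every step. When mask/x0 are supplied, the known
    # region is re-noised from x0 and blended back in at each step, which is
    # the mechanism the inpainting/outpainting logging below relies on.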
    @torch.no_grad()
    def p_sample_loop(
        self,
        cond,
        shape,
        return_intermediates=False,
        x_T=None,
        verbose=True,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        start_T=None,
        log_every_t=None,
    ):
        if not log_every_t:
            log_every_t = self.log_every_t
        device = self.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        intermediates = [img]
        if timesteps is None:
            timesteps = self.num_timesteps

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = (
            tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
            if verbose
            else reversed(range(0, timesteps))
        )

        if mask is not None:
            assert x0 is not None
            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

        for i in iterator:
            ts = torch.full((b,), i, device=device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != "hybrid"
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img = self.p_sample(
                img,
                cond,
                ts,
                clip_denoised=self.clip_denoised,
                quantize_denoised=quantize_denoised,
            )
            if mask is not None:
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1.0 - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(img)
            if callback:
                callback(i)
            if img_callback:
                img_callback(img, i)

        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(
        self,
        cond,
        batch_size=16,
        return_intermediates=False,
        x_T=None,
        verbose=True,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        shape=None,
        **kwargs,
    ):
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                cond = {
                    key: cond[key][:batch_size]
                    if not isinstance(cond[key], list)
                    else list(map(lambda x: x[:batch_size], cond[key]))
                    for key in cond
                }
            else:
                cond = (
                    [c[:batch_size] for c in cond]
                    if isinstance(cond, list)
                    else cond[:batch_size]
                )
        return self.p_sample_loop(
            cond,
            shape,
            return_intermediates=return_intermediates,
            x_T=x_T,
            verbose=verbose,
            timesteps=timesteps,
            quantize_denoised=quantize_denoised,
            mask=mask,
            x0=x0,
        )

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        if ddim:
            ddim_sampler = DDIMSampler(self)
            shape = (self.channels, self.image_size, self.image_size)
            samples, intermediates = ddim_sampler.sample(
                ddim_steps, batch_size, shape, cond, verbose=False, **kwargs
            )

        else:
            samples, intermediates = self.sample(
                cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs
            )

        return samples, intermediates

    @torch.no_grad()
    def get_unconditional_conditioning(self, batch_size, null_label=None):
        if null_label is not None:
            xc = null_label
            if isinstance(xc, ListConfig):
                xc = list(xc)
            if isinstance(xc, dict) or isinstance(xc, list):
                c = self.get_learned_conditioning(xc)
            else:
                if hasattr(xc, "to"):
                    xc = xc.to(self.device)
                c = self.get_learned_conditioning(xc)
        else:
            if self.cond_stage_key in ["class_label", "cls"]:
                xc = self.cond_stage_model.get_unconditional_conditioning(
                    batch_size, device=self.device
                )
                return self.get_learned_conditioning(xc)
            else:
                raise NotImplementedError("todo")
        if isinstance(c, list):  # in case the encoder gives us a list
            for i in range(len(c)):
                c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device)
        else:
            c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
        return c

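    # log_images assembles a dict of diagnostic grids for the training logger:
    # inputs, first-stage reconstructions, a forward-diffusion row, samples
    # (optionally with classifier-free guidance), and center-square
    # inpaint/outpaint demos. It has no effect on training itself.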
    @torch.no_grad()
    def log_images(
        self,
        batch,
        N=8,
        n_row=4,
        sample=True,
        ddim_steps=50,
        ddim_eta=0.0,
        return_keys=None,
        quantize_denoised=True,
        inpaint=True,
        plot_denoise_rows=False,
        plot_progressive_rows=True,
        plot_diffusion_rows=True,
        unconditional_guidance_scale=1.0,
        unconditional_guidance_label=None,
        use_ema_scope=True,
        **kwargs,
    ):
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc = self.get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=N,
        )
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img(
                    (x.shape[2], x.shape[3]),
                    batch[self.cond_stage_key],
                    size=x.shape[2] // 25,
                )
                log["conditioning"] = xc
            elif self.cond_stage_key in ["class_label", "cls"]:
                try:
                    xc = log_txt_as_img(
                        (x.shape[2], x.shape[3]),
                        batch["human_label"],
                        size=x.shape[2] // 25,
                    )
                    log["conditioning"] = xc
                except KeyError:
                    # probably no "human_label" in batch
                    pass
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                )
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if (
                quantize_denoised
                and not isinstance(self.first_stage_model, AutoencoderKL)
                and not isinstance(self.first_stage_model, IdentityFirstStage)
            ):
                # also display when quantizing x0 while sampling
                with ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(
                        cond=c,
                        batch_size=N,
                        ddim=use_ddim,
                        ddim_steps=ddim_steps,
                        eta=ddim_eta,
                        quantize_denoised=True,
                    )
                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
                    #                                      quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

        if unconditional_guidance_scale > 1.0:
            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            if self.model.conditioning_key == "crossattn-adm":
                uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=uc,
                )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[
                    f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
                ] = x_samples_cfg

        if inpaint:
            # make a simple center square
            b, h, w = z.shape[0], z.shape[2], z.shape[3]
            mask = torch.ones(N, h, w).to(self.device)
            # zeros will be filled in
            mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
            mask = mask[:, None, ...]
            with ema_scope("Plotting Inpaint"):
                samples, _ = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    eta=ddim_eta,
                    ddim_steps=ddim_steps,
                    x0=z[:N],
                    mask=mask,
                )
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_inpainting"] = x_samples
            log["mask"] = mask

            # outpaint
            mask = 1.0 - mask
            with ema_scope("Plotting Outpaint"):
                samples, _ = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    eta=ddim_eta,
                    ddim_steps=ddim_steps,
                    x0=z[:N],
                    mask=mask,
                )
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(
                    c,
                    shape=(self.channels, self.image_size, self.image_size),
                    batch_size=N,
                )
            prog_row = self._get_denoise_row_from_list(
                progressives, desc="Progressive Generation"
            )
            log["progressive_row"] = prog_row

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())
        if self.learn_logvar:
            print("Diffusion model optimizing logvar")
            params.append(self.logvar)
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert "target" in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [opt], scheduler
        return opt

    @torch.no_grad()
    def to_rgb(self, x):
        x = x.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = nn.functional.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x

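# DiffusionWrapper is the thin routing layer around the denoising U-Net:
# conditioning_key selects whether conditioning enters via channel
# concatenation ("concat"), cross-attention ("crossattn"), a class/vector
# embedding ("adm"), or the hybrid combinations of those.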
class DiffusionWrapper(torch.nn.Module):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.sequential_cross_attn = diff_model_config.pop(
            "sequential_crossattn", False
        )
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [
            None,
            "concat",
            "crossattn",
            "hybrid",
            "adm",
            "hybrid-adm",
            "crossattn-adm",
        ]

    def forward(
        self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None
    ):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == "concat":
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == "crossattn":
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
            else:
                cc = c_crossattn
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == "hybrid":
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == "hybrid-adm":
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == "crossattn-adm":
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm)
        elif self.conditioning_key == "adm":
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out

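# LatentUpscaleDiffusion additionally conditions on a degraded low-resolution
# image: a frozen low_scale_model encodes (and noises) the LR input, which is
# concatenated to the latent, while the applied noise level is routed to the
# U-Net through the "c_adm" key.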
class LatentUpscaleDiffusion(LatentDiffusion):
    def __init__(
        self,
        *args,
        low_scale_config,
        low_scale_key="LR",
        noise_level_key=None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
        assert not self.cond_stage_trainable
        self.instantiate_low_stage(low_scale_config)
        self.low_scale_key = low_scale_key
        self.noise_level_key = noise_level_key

    def instantiate_low_stage(self, config):
        model = instantiate_from_config(config)
        self.low_scale_model = model.eval()
        self.low_scale_model.train = disabled_train
        for param in self.low_scale_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
        if not log_mode:
            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
        else:
            z, c, x, xrec, xc = super().get_input(
                batch,
                self.first_stage_key,
                return_first_stage_outputs=True,
                force_c_encode=True,
                return_original_cond=True,
                bs=bs,
            )
        x_low = batch[self.low_scale_key][:bs]
        x_low = rearrange(x_low, "b h w c -> b c h w")
        x_low = x_low.to(memory_format=torch.contiguous_format).float()
        zx, noise_level = self.low_scale_model(x_low)
        if self.noise_level_key is not None:
            # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
            raise NotImplementedError("TODO")

        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
        if log_mode:
            # TODO: maybe disable if too expensive
            x_low_rec = self.low_scale_model.decode(zx)
            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
        return z, all_conds

    @torch.no_grad()
    def log_images(
        self,
        batch,
        N=8,
        n_row=4,
        sample=True,
        ddim_steps=200,
        ddim_eta=1.0,
        return_keys=None,
        plot_denoise_rows=False,
        plot_progressive_rows=True,
        plot_diffusion_rows=True,
        unconditional_guidance_scale=1.0,
        unconditional_guidance_label=None,
        use_ema_scope=True,
        **kwargs,
    ):
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(
            batch, self.first_stage_key, bs=N, log_mode=True
        )
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        log["x_lr"] = x_low
        log[
            f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"
        ] = x_low_rec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img(
                    (x.shape[2], x.shape[3]),
                    batch[self.cond_stage_key],
                    size=x.shape[2] // 25,
                )
                log["conditioning"] = xc
            elif self.cond_stage_key in ["class_label", "cls"]:
                xc = log_txt_as_img(
                    (x.shape[2], x.shape[3]),
                    batch["human_label"],
                    size=x.shape[2] // 25,
                )
                log["conditioning"] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                )
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_tmp = self.get_unconditional_conditioning(
                N, unconditional_guidance_label
            )
            # TODO explore better "unconditional" choices for the other keys
            # maybe guide away from empty text label and highest noise level and maximally degraded zx?
            uc = dict()
            for k in c:
                if k == "c_crossattn":
                    assert isinstance(c[k], list) and len(c[k]) == 1
                    uc[k] = [uc_tmp]
                elif k == "c_adm":  # todo: only run with text-based guidance?
                    assert isinstance(c[k], torch.Tensor)
                    # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
                    uc[k] = c[k]
                elif isinstance(c[k], list):
                    uc[k] = [c[k][i] for i in range(len(c[k]))]
                else:
                    uc[k] = c[k]

            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=uc,
                )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[
                    f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
                ] = x_samples_cfg

        if plot_progressive_rows:
            with ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(
                    c,
                    shape=(self.channels, self.image_size, self.image_size),
                    batch_size=N,
                )
            prog_row = self._get_denoise_row_from_list(
                progressives, desc="Progressive Generation"
            )
            log["progressive_row"] = prog_row

        return log

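# LatentFinetuneDiffusion restores weights from a checkpoint and, for the
# input-convolution keys listed in finetune_keys, zero-initializes the extra
# input channels while copying the first keep_finetune_dims channels from the
# checkpoint, so a model trained without concat conditioning can be finetuned
# with it.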
class LatentFinetuneDiffusion(LatentDiffusion):
    """
    Basis for different finetunes, such as inpainting or depth2image
    To disable finetuning mode, set finetune_keys to None
    """

    def __init__(
        self,
        concat_keys: tuple,
        finetune_keys=(
            "model.diffusion_model.input_blocks.0.0.weight",
            "model_ema.diffusion_modelinput_blocks00weight",
        ),
        keep_finetune_dims=4,
        # if model was trained without concat mode before and we would like to keep these channels
        c_concat_log_start=None,  # to log reconstruction of c_concat codes
        c_concat_log_end=None,
        *args,
        **kwargs,
    ):
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", list())
        super().__init__(*args, **kwargs)
        self.finetune_keys = finetune_keys
        self.concat_keys = concat_keys
        self.keep_dims = keep_finetune_dims
        self.c_concat_log_start = c_concat_log_start
        self.c_concat_log_end = c_concat_log_end
        if exists(self.finetune_keys):
            assert exists(ckpt_path), "can only finetune from a given checkpoint"
        if exists(ckpt_path):
            self.init_from_ckpt(ckpt_path, ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]

            # make it explicit, finetune by including extra input channels
            if exists(self.finetune_keys) and k in self.finetune_keys:
                new_entry = None
                for name, param in self.named_parameters():
                    if name in self.finetune_keys:
                        print(
                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
                        )
                        new_entry = torch.zeros_like(param)  # zero init
                assert exists(new_entry), "did not find matching parameter to modify"
                new_entry[:, : self.keep_dims, ...] = sd[k]
                sd[k] = new_entry

        missing, unexpected = (
            self.load_state_dict(sd, strict=False)
            if not only_model
            else self.model.load_state_dict(sd, strict=False)
        )
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    @torch.no_grad()
    def log_images(
        self,
        batch,
        N=8,
        n_row=4,
        sample=True,
        ddim_steps=200,
        ddim_eta=1.0,
        return_keys=None,
        quantize_denoised=True,
        inpaint=True,
        plot_denoise_rows=False,
        plot_progressive_rows=True,
        plot_diffusion_rows=True,
        unconditional_guidance_scale=1.0,
        unconditional_guidance_label=None,
        use_ema_scope=True,
        **kwargs,
    ):
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc = self.get_input(
            batch, self.first_stage_key, bs=N, return_first_stage_outputs=True
        )
        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img(
                    (x.shape[2], x.shape[3]),
                    batch[self.cond_stage_key],
                    size=x.shape[2] // 25,
                )
                log["conditioning"] = xc
            elif self.cond_stage_key in ["class_label", "cls"]:
                xc = log_txt_as_img(
                    (x.shape[2], x.shape[3]),
                    batch["human_label"],
                    size=x.shape[2] // 25,
                )
                log["conditioning"] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
            log["c_concat_decoded"] = self.decode_first_stage(
                c_cat[:, self.c_concat_log_start : self.c_concat_log_end]
            )

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(
                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                )
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(
                N, unconditional_guidance_label
            )
            uc_cat = c_cat
            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(
                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=uc_full,
                )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[
                    f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
                ] = x_samples_cfg

        return log


class LatentInpaintDiffusion(LatentFinetuneDiffusion):
    """
    can either run as pure inpainting model (only concat mode) or with mixed conditionings,
    e.g. mask as concat and text via cross-attn.
    To disable finetuning mode, set finetune_keys to None
    """

    def __init__(
        self,
        concat_keys=("mask", "masked_image"),
        masked_image_key="masked_image",
        *args,
        **kwargs,
    ):
        super().__init__(concat_keys, *args, **kwargs)
        self.masked_image_key = masked_image_key
        assert self.masked_image_key in concat_keys

    @torch.no_grad()
    def get_input(
        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
    ):
        # note: restricted to non-trainable encoders currently
        assert (
            not self.cond_stage_trainable
        ), "trainable cond stages not yet supported for inpainting"
        z, c, x, xrec, xc = super().get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=bs,
        )

        assert exists(self.concat_keys)
        c_cat = list()
        for ck in self.concat_keys:
            cc = (
                rearrange(batch[ck], "b h w c -> b c h w")
                .to(memory_format=torch.contiguous_format)
|
| 2209 |
+
.float()
|
| 2210 |
+
)
|
| 2211 |
+
if bs is not None:
|
| 2212 |
+
cc = cc[:bs]
|
| 2213 |
+
cc = cc.to(self.device)
|
| 2214 |
+
bchw = z.shape
|
| 2215 |
+
if ck != self.masked_image_key:
|
| 2216 |
+
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
| 2217 |
+
else:
|
| 2218 |
+
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
|
| 2219 |
+
c_cat.append(cc)
|
| 2220 |
+
c_cat = torch.cat(c_cat, dim=1)
|
| 2221 |
+
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
|
| 2222 |
+
if return_first_stage_outputs:
|
| 2223 |
+
return z, all_conds, x, xrec, xc
|
| 2224 |
+
return z, all_conds
|
| 2225 |
+
|
| 2226 |
+
@torch.no_grad()
|
| 2227 |
+
def log_images(self, *args, **kwargs):
|
| 2228 |
+
log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
|
| 2229 |
+
log["masked_image"] = (
|
| 2230 |
+
rearrange(args[0]["masked_image"], "b h w c -> b c h w")
|
| 2231 |
+
.to(memory_format=torch.contiguous_format)
|
| 2232 |
+
.float()
|
| 2233 |
+
)
|
| 2234 |
+
return log
|
| 2235 |
+
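
# Editor's sketch (not part of the original file): how LatentInpaintDiffusion.get_input
# lays out the concat conditioning above. The "mask" entry is only resized to the latent
# resolution, while "masked_image" goes through the first-stage encoder, so with the
# usual 4-channel SD autoencoder (an assumption, not fixed by this file):
#
#     batch = {"mask": torch.rand(2, 512, 512, 1),
#              "masked_image": torch.rand(2, 512, 512, 3), ...}
#     z, all_conds = model.get_input(batch, model.first_stage_key)
#     all_conds["c_concat"][0].shape  # (2, 1 + 4, latent_H, latent_W)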


class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
    """
    condition on monocular depth estimation
    """

    def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
        super().__init__(concat_keys=concat_keys, *args, **kwargs)
        self.depth_model = instantiate_from_config(depth_stage_config)
        self.depth_stage_key = concat_keys[0]

    @torch.no_grad()
    def get_input(
        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
    ):
        # note: restricted to non-trainable encoders currently
        assert (
            not self.cond_stage_trainable
        ), "trainable cond stages not yet supported for depth2img"
        z, c, x, xrec, xc = super().get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=bs,
        )

        assert exists(self.concat_keys)
        assert len(self.concat_keys) == 1
        c_cat = list()
        for ck in self.concat_keys:
            cc = batch[ck]
            if bs is not None:
                cc = cc[:bs]
                cc = cc.to(self.device)
            cc = self.depth_model(cc)
            cc = torch.nn.functional.interpolate(
                cc,
                size=z.shape[2:],
                mode="bicubic",
                align_corners=False,
            )

            depth_min, depth_max = (
                torch.amin(cc, dim=[1, 2, 3], keepdim=True),
                torch.amax(cc, dim=[1, 2, 3], keepdim=True),
            )
            cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, *args, **kwargs):
        log = super().log_images(*args, **kwargs)
        depth = self.depth_model(args[0][self.depth_stage_key])
        depth_min, depth_max = (
            torch.amin(depth, dim=[1, 2, 3], keepdim=True),
            torch.amax(depth, dim=[1, 2, 3], keepdim=True),
        )
        log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0
        return log
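
# Editor's sketch (not part of the original file): the per-sample min-max normalization
# used in get_input above maps each depth map into [-1, 1] regardless of its absolute scale:
#
#     d = torch.rand(2, 1, 64, 64) * 100.0                          # arbitrary depth units
#     d_min = torch.amin(d, dim=[1, 2, 3], keepdim=True)
#     d_max = torch.amax(d, dim=[1, 2, 3], keepdim=True)
#     d_norm = 2.0 * (d - d_min) / (d_max - d_min + 0.001) - 1.0    # per-sample, ~[-1, 1]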


class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
    """
    condition on low-res image (and optionally on some spatial noise augmentation)
    """

    def __init__(
        self,
        concat_keys=("lr",),
        reshuffle_patch_size=None,
        low_scale_config=None,
        low_scale_key=None,
        *args,
        **kwargs,
    ):
        super().__init__(concat_keys=concat_keys, *args, **kwargs)
        self.reshuffle_patch_size = reshuffle_patch_size
        self.low_scale_model = None
        if low_scale_config is not None:
            print("Initializing a low-scale model")
            assert exists(low_scale_key)
            self.instantiate_low_stage(low_scale_config)
            self.low_scale_key = low_scale_key

    def instantiate_low_stage(self, config):
        model = instantiate_from_config(config)
        self.low_scale_model = model.eval()
        self.low_scale_model.train = disabled_train
        for param in self.low_scale_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def get_input(
        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
    ):
        # note: restricted to non-trainable encoders currently
        assert (
            not self.cond_stage_trainable
        ), "trainable cond stages not yet supported for upscaling-ft"
        z, c, x, xrec, xc = super().get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=bs,
        )

        assert exists(self.concat_keys)
        assert len(self.concat_keys) == 1
        # optionally make spatial noise_level here
        c_cat = list()
        noise_level = None
        for ck in self.concat_keys:
            cc = batch[ck]
            cc = rearrange(cc, "b h w c -> b c h w")
            if exists(self.reshuffle_patch_size):
                assert isinstance(self.reshuffle_patch_size, int)
                cc = rearrange(
                    cc,
                    "b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
                    p1=self.reshuffle_patch_size,
                    p2=self.reshuffle_patch_size,
                )
            if bs is not None:
                cc = cc[:bs]
                cc = cc.to(self.device)
            if exists(self.low_scale_model) and ck == self.low_scale_key:
                cc, noise_level = self.low_scale_model(cc)
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        if exists(noise_level):
            all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
        else:
            all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, *args, **kwargs):
        log = super().log_images(*args, **kwargs)
        log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w")
        return log
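
# Editor's sketch (not part of the original file): the optional patch reshuffle in
# LatentUpscaleFinetuneDiffusion.get_input trades spatial size for channels via einops;
# with reshuffle_patch_size=2:
#
#     lr = torch.rand(2, 3, 64, 64)
#     out = rearrange(lr, "b c (p1 h) (p2 w) -> b (p1 p2 c) h w", p1=2, p2=2)
#     out.shape  # (2, 12, 32, 32)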
sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py
ADDED
@@ -0,0 +1 @@
from .sampler import DPMSolverSampler

sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py
ADDED
@@ -0,0 +1,1464 @@
import math

import torch
import torch.nn.functional as F
from tqdm import tqdm


class NoiseScheduleVP:
    def __init__(
        self,
        schedule="discrete",
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.0,
    ):
        """Create a wrapper class for the forward SDE (VP type).
        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
        We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
        ***
        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)
        Moreover, as lambda(t) is an invertible function, we also support its inverse function:
            t = self.inverse_lambda(lambda_t)
        ===============================================================
        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
        1. For discrete-time DPMs:
            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
            **Important**: Please pay special attention to the argument `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
        2. For continuous-time DPMs:
            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
            schedule are the default settings in DDPM and improved-DDPM:
            Args:
                beta_min: A `float` number. The smallest beta for the linear schedule.
                beta_max: A `float` number. The largest beta for the linear schedule.
                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
                T: A `float` number. The ending time of the forward process.
        ===============================================================
        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).

        ===============================================================
        Example:
        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)
        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
        """

        if schedule not in ["discrete", "linear", "cosine"]:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
                    schedule
                )
            )

        self.schedule = schedule
        if schedule == "discrete":
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.0
            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape(
                (1, -1)
            )
            self.log_alpha_array = log_alphas.reshape(
                (
                    1,
                    -1,
                )
            )
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.0
            self.cosine_t_max = (
                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            self.cosine_log_alpha_0 = math.log(
                math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
            )
            self.schedule = schedule
            if schedule == "cosine":
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.0

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == "discrete":
            return interpolate_fn(
                t.reshape((-1, 1)),
                self.t_array.to(t.device),
                self.log_alpha_array.to(t.device),
            ).reshape((-1))
        elif self.schedule == "linear":
            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == "cosine":
            log_alpha_fn = lambda s: torch.log(
                torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)
            )
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == "linear":
            tmp = (
                2.0
                * (self.beta_1 - self.beta_0)
                * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            )
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == "discrete":
            log_alpha = -0.5 * torch.logaddexp(
                torch.zeros((1,)).to(lamb.device), -2.0 * lamb
            )
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            t_fn = (
                lambda log_alpha_t: torch.arccos(
                    torch.exp(log_alpha_t + self.cosine_log_alpha_0)
                )
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            t = t_fn(log_alpha)
            return t
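
# Editor's sketch (not part of the original file), following the docstring's own example:
# build a discrete schedule from a linear beta ramp and check the VP identity
# alpha_t**2 + sigma_t**2 == 1.
#
#     betas = torch.linspace(1e-4, 2e-2, 1000)
#     ns = NoiseScheduleVP("discrete", betas=betas)
#     t = torch.tensor([0.5])
#     alpha, sigma = ns.marginal_alpha(t), ns.marginal_std(t)
#     assert torch.allclose(alpha**2 + sigma**2, torch.ones(1), atol=1e-5)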


def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs={},
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.0,
    classifier_fn=None,
    classifier_kwargs={},
):
    """Create a wrapper function for the noise prediction model.
    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    first wrap the model function to a noise prediction model that accepts the continuous time as the input.
    We support four types of the diffusion model by setting `model_type`:
        1. "noise": noise prediction model. (Trained by predicting noise).
        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The "v" prediction is derived in detail in Appendix D of [1], and is used in Imagen-Video [2].
            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).

        4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follow a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```
    We support three types of guided sampling by DPMs by setting `guidance_type`:
        1. "uncond": unconditional sampling by DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``
        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``
            The input `classifier_fn` has the following format:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``
            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
            ``
            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).

    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).
    We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
    ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
    ===============================================================
    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict`. A dict for the other inputs of the model function.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
                    Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
                    Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    """

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = (
                noise_schedule.marginal_alpha(t_continuous),
                noise_schedule.marginal_std(t_continuous),
            )
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(
                sigma_t, dims
            )
        elif model_type == "v":
            alpha_t, sigma_t = (
                noise_schedule.marginal_alpha(t_continuous),
                noise_schedule.marginal_std(t_continuous),
            )
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return (
                noise
                - guidance_scale
                * expand_dims(sigma_t, dims=cond_grad.dim())
                * cond_grad
            )
        elif guidance_type == "classifier-free":
            if guidance_scale == 1.0 or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    assert model_type in ["noise", "x_start", "v"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn
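
# Editor's sketch (not part of the original file): wrapping a conditional noise model for
# classifier-free guidance, per the docstring above. `unet`, `cond_emb` and `uncond_emb`
# are placeholders for any model/tensors matching model(x, t_input, cond) -> noise.
#
#     model_fn = model_wrapper(
#         unet, ns, model_type="noise",
#         guidance_type="classifier-free",
#         condition=cond_emb, unconditional_condition=uncond_emb,
#         guidance_scale=7.5,
#     )
#     eps = model_fn(x_t, t_continuous)  # guidance is already folded in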


class DPM_Solver:
    def __init__(
        self,
        model_fn,
        noise_schedule,
        predict_x0=False,
        thresholding=False,
        max_val=1.0,
    ):
        """Construct a DPM-Solver.
        We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
        If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
        If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
            In that case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
            The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
        Args:
            model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
                ``
                def model_fn(x, t_continuous):
                    return noise
                ``
            noise_schedule: A noise schedule object, such as NoiseScheduleVP.
            predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
            thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
            max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.

        [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
        """
        self.model = model_fn
        self.noise_schedule = noise_schedule
        self.predict_x0 = predict_x0
        self.thresholding = thresholding
        self.max_val = max_val

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with thresholding).
        """
        noise = self.noise_prediction_fn(x, t)
        dims = x.dim()
        alpha_t, sigma_t = (
            self.noise_schedule.marginal_alpha(t),
            self.noise_schedule.marginal_std(t),
        )
        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
        if self.thresholding:
            p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
            s = expand_dims(
                torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims
            )
            x0 = torch.clamp(x0, -s, s) / s
        return x0
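
    # Editor's sketch (not part of the original file): dynamic thresholding in isolation.
    # Each sample is clamped at its own 99.5th percentile of |x0| (but at least max_val)
    # and rescaled, which keeps large-guidance-scale samples inside the data range:
    #
    #     x0 = torch.randn(4, 3, 8, 8) * 3.0
    #     s = torch.quantile(x0.abs().reshape(4, -1), 0.995, dim=1)
    #     s = torch.maximum(s, torch.ones_like(s)).reshape(4, 1, 1, 1)
    #     x0 = torch.clamp(x0, -s, s) / s  # now within [-1, 1]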

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.
        Args:
            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: An `int`. The total number of the spacing of the time steps.
            device: A torch device.
        Returns:
            A pytorch tensor of the time steps, with the shape (N + 1,).
        """
        if skip_type == "logSNR":
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(
                lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
            ).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == "time_uniform":
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == "time_quadratic":
            t_order = 2
            t = (
                torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
                .pow(t_order)
                .to(device)
            )
            return t
        else:
            raise ValueError(
                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
                    skip_type
                )
            )
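
    # Editor's sketch (not part of the original file): what the spacings above look like
    # for a toy N = 4. 'time_uniform' is a plain linspace; 'time_quadratic' squares a
    # linspace taken in sqrt-time, which clusters steps near t_0:
    #
    #     torch.linspace(1.0, 0.001, 5)                   # time_uniform
    #     torch.linspace(1.0**0.5, 0.001**0.5, 5) ** 2    # time_quadratic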

    def get_orders_and_timesteps_for_singlestep_solver(
        self, steps, order, skip_type, t_T, t_0, device
    ):
        """
        Get the order of each step for sampling by the singlestep DPM-Solver.
        We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is called "DPM-Solver-fast".
        Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
            - If order == 1:
                We take `steps` of DPM-Solver-1 (i.e. DDIM).
            - If order == 2:
                - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
                - If steps % 2 == 0, we use K steps of DPM-Solver-2.
                - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
            - If order == 3:
                - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
                - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
                - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
        ============================================
        Args:
            order: An `int`. The max order for the solver (2 or 3).
            steps: An `int`. The total number of function evaluations (NFE).
            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            device: A torch device.
        Returns:
            orders: A list of the solver order of each step.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [
                    3,
                ] * (
                    K - 2
                ) + [2, 1]
            elif steps % 3 == 1:
                orders = [
                    3,
                ] * (
                    K - 1
                ) + [1]
            else:
                orders = [
                    3,
                ] * (
                    K - 1
                ) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [
                    2,
                ] * K
            else:
                K = steps // 2 + 1
                orders = [
                    2,
                ] * (
                    K - 1
                ) + [1]
        elif order == 1:
            K = 1
            orders = [
                1,
            ] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == "logSNR":
            # To reproduce the results in DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
                torch.cumsum(
                    torch.tensor(
                        [
                            0,
                        ]
                        + orders
                    ),
                    0,
                ).to(device)
            ]
        return timesteps_outer, orders
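
    # Editor's sketch (not part of the original file): the splitting rule above worked out
    # for steps=20, order=3: K = 20 // 3 + 1 = 7 and 20 % 3 == 2, so
    # orders == [3, 3, 3, 3, 3, 3, 2] -- six 3rd-order steps plus one 2nd-order step,
    # i.e. 6 * 3 + 2 == 20 function evaluations in total.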

    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
        """
        DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s`.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        dims = x.dim()
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        log_alpha_s, log_alpha_t = (
            ns.marginal_log_mean_coeff(s),
            ns.marginal_log_mean_coeff(t),
        )
        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        if self.predict_x0:
            phi_1 = torch.expm1(-h)
            if model_s is None:
                model_s = self.model_fn(x, s)
            x_t = (
                expand_dims(sigma_t / sigma_s, dims) * x
                - expand_dims(alpha_t * phi_1, dims) * model_s
            )
            if return_intermediate:
                return x_t, {"model_s": model_s}
            else:
                return x_t
        else:
            phi_1 = torch.expm1(h)
            if model_s is None:
                model_s = self.model_fn(x, s)
            x_t = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                - expand_dims(sigma_t * phi_1, dims) * model_s
            )
            if return_intermediate:
                return x_t, {"model_s": model_s}
            else:
                return x_t
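
    # Editor's sketch (not part of the original file): in the noise-prediction branch the
    # update above is exactly
    #
    #     x_t = (alpha_t / alpha_s) * x - sigma_t * (exp(h) - 1) * model_s,   h = lambda_t - lambda_s,
    #
    # which is the eta = 0 DDIM step, hence "DPM-Solver-1 (equivalent to DDIM)" in the docstring.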

    def singlestep_dpm_solver_second_update(
        self,
        x,
        s,
        t,
        r1=0.5,
        model_s=None,
        return_intermediate=False,
        solver_type="dpm_solver",
    ):
        """
        Singlestep solver DPM-Solver-2 from time `s` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            r1: A `float`. The hyperparameter of the second-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ["dpm_solver", "taylor"]:
            raise ValueError(
                "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
                    solver_type
                )
            )
        if r1 is None:
            r1 = 0.5
        ns = self.noise_schedule
        dims = x.dim()
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        s1 = ns.inverse_lambda(lambda_s1)
        log_alpha_s, log_alpha_s1, log_alpha_t = (
            ns.marginal_log_mean_coeff(s),
            ns.marginal_log_mean_coeff(s1),
            ns.marginal_log_mean_coeff(t),
        )
        sigma_s, sigma_s1, sigma_t = (
            ns.marginal_std(s),
            ns.marginal_std(s1),
            ns.marginal_std(t),
        )
        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)

        if self.predict_x0:
            phi_11 = torch.expm1(-r1 * h)
            phi_1 = torch.expm1(-h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            x_s1 = (
                expand_dims(sigma_s1 / sigma_s, dims) * x
                - expand_dims(alpha_s1 * phi_11, dims) * model_s
            )
            model_s1 = self.model_fn(x_s1, s1)
            if solver_type == "dpm_solver":
                x_t = (
                    expand_dims(sigma_t / sigma_s, dims) * x
                    - expand_dims(alpha_t * phi_1, dims) * model_s
                    - (0.5 / r1)
                    * expand_dims(alpha_t * phi_1, dims)
                    * (model_s1 - model_s)
                )
            elif solver_type == "taylor":
                x_t = (
                    expand_dims(sigma_t / sigma_s, dims) * x
                    - expand_dims(alpha_t * phi_1, dims) * model_s
                    + (1.0 / r1)
                    * expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims)
                    * (model_s1 - model_s)
                )
        else:
            phi_11 = torch.expm1(r1 * h)
            phi_1 = torch.expm1(h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            x_s1 = (
                expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
                - expand_dims(sigma_s1 * phi_11, dims) * model_s
            )
            model_s1 = self.model_fn(x_s1, s1)
            if solver_type == "dpm_solver":
                x_t = (
                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                    - expand_dims(sigma_t * phi_1, dims) * model_s
                    - (0.5 / r1)
                    * expand_dims(sigma_t * phi_1, dims)
                    * (model_s1 - model_s)
                )
            elif solver_type == "taylor":
                x_t = (
                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                    - expand_dims(sigma_t * phi_1, dims) * model_s
                    - (1.0 / r1)
                    * expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims)
                    * (model_s1 - model_s)
                )
        if return_intermediate:
            return x_t, {"model_s": model_s, "model_s1": model_s1}
        else:
            return x_t
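
    # Editor's note (not part of the original file): with the default r1 = 0.5 the step
    # above first jumps to the half-logSNR midpoint s1 (lambda_s1 = lambda_s + 0.5 * h),
    # evaluates the model there, and uses (model_s1 - model_s) as a first-derivative
    # correction -- two model evaluations per step instead of one.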
    def singlestep_dpm_solver_third_update(
        self,
        x,
        s,
        t,
        r1=1.0 / 3.0,
        r2=2.0 / 3.0,
        model_s=None,
        model_s1=None,
        return_intermediate=False,
        solver_type="dpm_solver",
    ):
        """
        Singlestep solver DPM-Solver-3 from time `s` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            r1: A `float`. The hyperparameter of the third-order solver.
            r2: A `float`. The hyperparameter of the third-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model values at times `s`, `s1` and `s2` (the intermediate times).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend the 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ["dpm_solver", "taylor"]:
            raise ValueError(
                "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
                    solver_type
                )
            )
        if r1 is None:
            r1 = 1.0 / 3.0
        if r2 is None:
            r2 = 2.0 / 3.0
        ns = self.noise_schedule
        dims = x.dim()
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        lambda_s2 = lambda_s + r2 * h
        s1 = ns.inverse_lambda(lambda_s1)
        s2 = ns.inverse_lambda(lambda_s2)
        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
            ns.marginal_log_mean_coeff(s),
            ns.marginal_log_mean_coeff(s1),
            ns.marginal_log_mean_coeff(s2),
            ns.marginal_log_mean_coeff(t),
        )
        sigma_s, sigma_s1, sigma_s2, sigma_t = (
            ns.marginal_std(s),
            ns.marginal_std(s1),
            ns.marginal_std(s2),
            ns.marginal_std(t),
        )
        alpha_s1, alpha_s2, alpha_t = (
            torch.exp(log_alpha_s1),
            torch.exp(log_alpha_s2),
            torch.exp(log_alpha_t),
        )

        if self.predict_x0:
            phi_11 = torch.expm1(-r1 * h)
            phi_12 = torch.expm1(-r2 * h)
            phi_1 = torch.expm1(-h)
            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
            phi_2 = phi_1 / h + 1.0
            phi_3 = phi_2 / h - 0.5

            if model_s is None:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                x_s1 = (
                    expand_dims(sigma_s1 / sigma_s, dims) * x
                    - expand_dims(alpha_s1 * phi_11, dims) * model_s
                )
                model_s1 = self.model_fn(x_s1, s1)
            x_s2 = (
                expand_dims(sigma_s2 / sigma_s, dims) * x
                - expand_dims(alpha_s2 * phi_12, dims) * model_s
                + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
            )
            model_s2 = self.model_fn(x_s2, s2)
            if solver_type == "dpm_solver":
                x_t = (
                    expand_dims(sigma_t / sigma_s, dims) * x
                    - expand_dims(alpha_t * phi_1, dims) * model_s
                    + (1.0 / r2)
                    * expand_dims(alpha_t * phi_2, dims)
                    * (model_s2 - model_s)
                )
            elif solver_type == "taylor":
                D1_0 = (1.0 / r1) * (model_s1 - model_s)
                D1_1 = (1.0 / r2) * (model_s2 - model_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
                    expand_dims(sigma_t / sigma_s, dims) * x
                    - expand_dims(alpha_t * phi_1, dims) * model_s
                    + expand_dims(alpha_t * phi_2, dims) * D1
                    - expand_dims(alpha_t * phi_3, dims) * D2
                )
        else:
            phi_11 = torch.expm1(r1 * h)
            phi_12 = torch.expm1(r2 * h)
            phi_1 = torch.expm1(h)
            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
            phi_2 = phi_1 / h - 1.0
            phi_3 = phi_2 / h - 0.5

            if model_s is None:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                x_s1 = (
                    expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
                    - expand_dims(sigma_s1 * phi_11, dims) * model_s
                )
                model_s1 = self.model_fn(x_s1, s1)
            x_s2 = (
                expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
                - expand_dims(sigma_s2 * phi_12, dims) * model_s
                - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
            )
            model_s2 = self.model_fn(x_s2, s2)
            if solver_type == "dpm_solver":
                x_t = (
                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                    - expand_dims(sigma_t * phi_1, dims) * model_s
                    - (1.0 / r2)
                    * expand_dims(sigma_t * phi_2, dims)
                    * (model_s2 - model_s)
                )
            elif solver_type == "taylor":
                D1_0 = (1.0 / r1) * (model_s1 - model_s)
                D1_1 = (1.0 / r2) * (model_s2 - model_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
                    - expand_dims(sigma_t * phi_1, dims) * model_s
                    - expand_dims(sigma_t * phi_2, dims) * D1
                    - expand_dims(sigma_t * phi_3, dims) * D2
                )

        if return_intermediate:
            return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
        else:
            return x_t

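    # Editor's note (illustrative, not part of the original solver): the
    # 'taylor' branch above builds divided-difference estimates of the scaled
    # derivatives of the model output along lambda. With nodes at 0, r1*h and
    # r2*h, and f denoting the model output as a function of the lambda offset:
    #
    #     D1_0 = (f(r1*h) - f(0)) / r1            # each ~ h*f'(0) + O(h^2)
    #     D1_1 = (f(r2*h) - f(0)) / r2
    #     D1   = (r2*D1_0 - r1*D1_1) / (r2 - r1)  # ~ h*f'(0)    + O(h^3)
    #     D2   = 2*(D1_1 - D1_0) / (r2 - r1)      # ~ h^2*f''(0) + O(h^3)
    #
    # which is exactly the D1/D2 combination used in both branches above.
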
    def multistep_dpm_solver_second_update(
        self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"
    ):
        """
        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
            model_prev_list: A list of pytorch tensors. The previously computed model values.
            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend the 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ["dpm_solver", "taylor"]:
            raise ValueError(
                "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
                    solver_type
                )
            )
        ns = self.noise_schedule
        dims = x.dim()
        model_prev_1, model_prev_0 = model_prev_list
        t_prev_1, t_prev_0 = t_prev_list
        lambda_prev_1, lambda_prev_0, lambda_t = (
            ns.marginal_lambda(t_prev_1),
            ns.marginal_lambda(t_prev_0),
            ns.marginal_lambda(t),
        )
        log_alpha_prev_0, log_alpha_t = (
            ns.marginal_log_mean_coeff(t_prev_0),
            ns.marginal_log_mean_coeff(t),
        )
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0 = h_0 / h
        D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
        if self.predict_x0:
            if solver_type == "dpm_solver":
                x_t = (
                    expand_dims(sigma_t / sigma_prev_0, dims) * x
                    - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
                    - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0
                )
            elif solver_type == "taylor":
                x_t = (
                    expand_dims(sigma_t / sigma_prev_0, dims) * x
                    - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
                    + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims)
                    * D1_0
                )
        else:
            if solver_type == "dpm_solver":
                x_t = (
                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                    - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
                    - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0
                )
            elif solver_type == "taylor":
                x_t = (
                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                    - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
                    - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims)
                    * D1_0
                )
        return x_t

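    # Editor's note (illustrative, not part of the original solver): unlike
    # the singlestep variants, this multistep update costs a single new model
    # evaluation per step in the sampling loop, because `model_prev_list`
    # caches earlier evaluations. D1_0 = (model_prev_0 - model_prev_1) / r0 is
    # a backward difference along lambda: to leading order it equals
    # h * f'(lambda_prev_0), the same quantity the singlestep solver spends an
    # extra evaluation to estimate.
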
    def multistep_dpm_solver_third_update(
        self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"
    ):
        """
        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
            model_prev_list: A list of pytorch tensors. The previously computed model values.
            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend the 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        dims = x.dim()
        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
            ns.marginal_lambda(t_prev_2),
            ns.marginal_lambda(t_prev_1),
            ns.marginal_lambda(t_prev_0),
            ns.marginal_lambda(t),
        )
        log_alpha_prev_0, log_alpha_t = (
            ns.marginal_log_mean_coeff(t_prev_0),
            ns.marginal_log_mean_coeff(t),
        )
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        h_1 = lambda_prev_1 - lambda_prev_2
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0, r1 = h_0 / h, h_1 / h
        D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
        D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2)
        D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
        D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1)
        if self.predict_x0:
            x_t = (
                expand_dims(sigma_t / sigma_prev_0, dims) * x
                - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
                + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
                - expand_dims(
                    alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims
                )
                * D2
            )
        else:
            x_t = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
                - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
                - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims)
                * D2
            )
        return x_t

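    # Editor's note (illustrative, not part of the original solver): this
    # third-order update reuses two cached evaluations and combines
    # Newton-style backward differences on the nonuniform grid
    # (lambda_prev_2, lambda_prev_1, lambda_prev_0). To leading order in the
    # step size,
    #     D1 ~ h * f'(lambda_prev_0)   and   D2 ~ (h^2 / 2) * f''(lambda_prev_0),
    # which the r0/r1 weighting above arranges regardless of how uneven the
    # previous steps were.
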
    def singlestep_dpm_solver_update(
        self,
        x,
        s,
        t,
        order,
        return_intermediate=False,
        solver_type="dpm_solver",
        r1=None,
        r2=None,
    ):
        """
        Singlestep DPM-Solver with order `order` from time `s` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
            return_intermediate: A `bool`. If true, also return the model values at times `s`, `s1` and `s2` (the intermediate times).
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend the 'dpm_solver' type.
            r1: A `float`. The hyperparameter of the second-order or third-order solver.
            r2: A `float`. The hyperparameter of the third-order solver.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if order == 1:
            return self.dpm_solver_first_update(
                x, s, t, return_intermediate=return_intermediate
            )
        elif order == 2:
            return self.singlestep_dpm_solver_second_update(
                x,
                s,
                t,
                return_intermediate=return_intermediate,
                solver_type=solver_type,
                r1=r1,
            )
        elif order == 3:
            return self.singlestep_dpm_solver_third_update(
                x,
                s,
                t,
                return_intermediate=return_intermediate,
                solver_type=solver_type,
                r1=r1,
                r2=r2,
            )
        else:
            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))

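    # Editor's sketch (hypothetical call, assuming a solver instance built as
    # in the `sample` docstring below): one third-order single step from a
    # batched start time `s` to `t`, with the default intermediate ratios:
    #
    #     x_t = solver.singlestep_dpm_solver_update(x, s, t, order=3)
    #
    # `r1` and `r2` default to None here and fall back to 1/3 and 2/3 inside
    # `singlestep_dpm_solver_third_update`.
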
    def multistep_dpm_solver_update(
        self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"
    ):
        """
        Multistep DPM-Solver with order `order` from time `t_prev_list[-1]` to time `t`.
        Args:
            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
            model_prev_list: A list of pytorch tensors. The previously computed model values.
            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend the 'dpm_solver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if order == 1:
            return self.dpm_solver_first_update(
                x, t_prev_list[-1], t, model_s=model_prev_list[-1]
            )
        elif order == 2:
            return self.multistep_dpm_solver_second_update(
                x, model_prev_list, t_prev_list, t, solver_type=solver_type
            )
        elif order == 3:
            return self.multistep_dpm_solver_third_update(
                x, model_prev_list, t_prev_list, t, solver_type=solver_type
            )
        else:
            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))

    def dpm_solver_adaptive(
        self,
        x,
        order,
        t_T,
        t_0,
        h_init=0.05,
        atol=0.0078,
        rtol=0.05,
        theta=0.9,
        t_err=1e-5,
        solver_type="dpm_solver",
    ):
        """
        The adaptive step size solver based on singlestep DPM-Solver.
        Args:
            x: A pytorch tensor. The initial value at time `t_T`.
            order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            h_init: A `float`. The initial step size (for logSNR).
            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend the 'dpm_solver' type.
        Returns:
            x_0: A pytorch tensor. The approximated solution at time `t_0`.
        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
        """
        ns = self.noise_schedule
        s = t_T * torch.ones((x.shape[0],)).to(x)
        lambda_s = ns.marginal_lambda(s)
        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
        h = h_init * torch.ones_like(s).to(x)
        x_prev = x
        nfe = 0
        if order == 2:
            r1 = 0.5
            lower_update = lambda x, s, t: self.dpm_solver_first_update(
                x, s, t, return_intermediate=True
            )
            higher_update = (
                lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
                    x, s, t, r1=r1, solver_type=solver_type, **kwargs
                )
            )
        elif order == 3:
            r1, r2 = 1.0 / 3.0, 2.0 / 3.0
            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
                x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
            )
            higher_update = (
                lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
                    x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
                )
            )
        else:
            raise ValueError(
                "For adaptive step size solver, order must be 2 or 3, got {}".format(
                    order
                )
            )
        while torch.abs((s - t_0)).mean() > t_err:
            t = ns.inverse_lambda(lambda_s + h)
            x_lower, lower_noise_kwargs = lower_update(x, s, t)
            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
            delta = torch.max(
                torch.ones_like(x).to(x) * atol,
                rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)),
            )
            norm_fn = lambda v: torch.sqrt(
                torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)
            )
            E = norm_fn((x_higher - x_lower) / delta).max()
            if torch.all(E <= 1.0):
                x = x_higher
                s = t
                x_prev = x_lower
                lambda_s = ns.marginal_lambda(s)
            h = torch.min(
                theta * h * torch.float_power(E, -1.0 / order).float(),
                lambda_0 - lambda_s,
            )
            nfe += order
        print("adaptive solver nfe", nfe)
        return x

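    # Editor's note (illustrative): the adaptive loop above is a standard
    # embedded-pair step size controller. The error estimate E compares the
    # lower- and higher-order proposals against the mixed tolerance
    #     delta = max(atol, rtol * max(|x_lower|, |x_prev|)),
    # accepts the step when E <= 1, and rescales the logSNR step size as
    #     h <- min(theta * h * E**(-1/order), lambda_0 - lambda_s),
    # so the final step never overshoots lambda_0.
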
    def sample(
        self,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=3,
        skip_type="time_uniform",
        method="singlestep",
        lower_order_final=True,
        denoise_to_zero=False,
        solver_type="dpm_solver",
        atol=0.0078,
        rtol=0.05,
    ):
        """
        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
        =====================================================
        We support the following algorithms for both the noise prediction model and the data prediction model:
        - 'singlestep':
            Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
            We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
            The total number of function evaluations (NFE) == `steps`.
            Given a fixed NFE == `steps`, the sampling procedure is:
                - If `order` == 1:
                    - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
                - If `order` == 2:
                    - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
                    - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
                    - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                - If `order` == 3:
                    - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                    - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                    - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
                    - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
        - 'multistep':
            Multistep DPM-Solver with order `order`. The total number of function evaluations (NFE) == `steps`.
            We initialize the first `order` values by lower order multistep solvers.
            Given a fixed NFE == `steps`, the sampling procedure is:
                Denote K = steps.
                - If `order` == 1:
                    - We use K steps of DPM-Solver-1 (i.e. DDIM).
                - If `order` == 2:
                    - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
                - If `order` == 3:
                    - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
        - 'singlestep_fixed':
            Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
            We use singlestep DPM-Solver-`order` for `order` = 1 or 2 or 3, with a total of [`steps` // `order`] * `order` NFE.
        - 'adaptive':
            Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
            We ignore `steps` and use the adaptive step size DPM-Solver whose higher order is `order`.
            You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
            (NFE) and the sample quality.
                - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
                - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
        =====================================================
        Some advice on choosing the algorithm:
            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
                Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
                e.g.
                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                        skip_type='time_uniform', method='singlestep')
            - For **guided sampling with large guidance scale** by DPMs:
                Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
                e.g.
                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
                        skip_type='time_uniform', method='multistep')
        We support three types of `skip_type`:
            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
            - 'time_quadratic': quadratic time for the time steps.
        =====================================================
        Args:
            x: A pytorch tensor. The initial value at time `t_start`,
                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
            steps: An `int`. The total number of function evaluations (NFE).
            t_start: A `float`. The starting time of the sampling.
                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
            t_end: A `float`. The ending time of the sampling.
                If `t_end` is None, we use 1. / self.noise_schedule.total_N,
                e.g. if total_N == 1000, we have `t_end` == 1e-3.
                For discrete-time DPMs:
                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
                For continuous-time DPMs:
                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
            order: An `int`. The order of DPM-Solver.
            skip_type: A `str`. The type for the spacing of the time steps: 'time_uniform' or 'logSNR' or 'time_quadratic'.
            method: A `str`. The method for sampling: 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
                This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
                score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
                sampling from diffusion models via diffusion SDEs for low-resolution images
                (such as CIFAR-10). However, we observed that it does not matter for
                high-resolution images. As it needs an additional NFE, we do not recommend
                it for high-resolution images.
            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
                Only valid for `method=multistep` and `steps < 15`. We empirically find that
                this trick is key to stabilizing sampling by DPM-Solver with very few steps
                (especially for steps <= 10), so we recommend setting it to `True`.
            solver_type: A `str`. The Taylor expansion type for the solver: `dpm_solver` or `taylor`. We recommend `dpm_solver`.
            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
        Returns:
            x_end: A pytorch tensor. The approximated solution at time `t_end`.
        """
        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        if method == "adaptive":
            with torch.no_grad():
                x = self.dpm_solver_adaptive(
                    x,
                    order=order,
                    t_T=t_T,
                    t_0=t_0,
                    atol=atol,
                    rtol=rtol,
                    solver_type=solver_type,
                )
        elif method == "multistep":
            assert steps >= order
            timesteps = self.get_time_steps(
                skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
            )
            assert timesteps.shape[0] - 1 == steps
            with torch.no_grad():
                vec_t = timesteps[0].expand((x.shape[0]))
                model_prev_list = [self.model_fn(x, vec_t)]
                t_prev_list = [vec_t]
                # Init the first `order` values by lower order multistep DPM-Solver.
                for init_order in tqdm(range(1, order), desc="DPM init order"):
                    vec_t = timesteps[init_order].expand(x.shape[0])
                    x = self.multistep_dpm_solver_update(
                        x,
                        model_prev_list,
                        t_prev_list,
                        vec_t,
                        init_order,
                        solver_type=solver_type,
                    )
                    model_prev_list.append(self.model_fn(x, vec_t))
                    t_prev_list.append(vec_t)
                # Compute the remaining values by `order`-th order multistep DPM-Solver.
                for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
                    vec_t = timesteps[step].expand(x.shape[0])
                    if lower_order_final and steps < 15:
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    x = self.multistep_dpm_solver_update(
                        x,
                        model_prev_list,
                        t_prev_list,
                        vec_t,
                        step_order,
                        solver_type=solver_type,
                    )
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = vec_t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        model_prev_list[-1] = self.model_fn(x, vec_t)
        elif method in ["singlestep", "singlestep_fixed"]:
            if method == "singlestep":
                (
                    timesteps_outer,
                    orders,
                ) = self.get_orders_and_timesteps_for_singlestep_solver(
                    steps=steps,
                    order=order,
                    skip_type=skip_type,
                    t_T=t_T,
                    t_0=t_0,
                    device=device,
                )
            elif method == "singlestep_fixed":
                K = steps // order
                orders = [
                    order,
                ] * K
                timesteps_outer = self.get_time_steps(
                    skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device
                )
            for i, order in enumerate(orders):
                t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
                timesteps_inner = self.get_time_steps(
                    skip_type=skip_type,
                    t_T=t_T_inner.item(),
                    t_0=t_0_inner.item(),
                    N=order,
                    device=device,
                )
                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
                vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
                h = lambda_inner[-1] - lambda_inner[0]
                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
                x = self.singlestep_dpm_solver_update(
                    x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2
                )
        if denoise_to_zero:
            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
        return x

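# Editor's sketch of end-to-end usage (hypothetical; it assumes the
# `NoiseScheduleVP` and `model_wrapper` helpers defined earlier in this file
# and a trained noise prediction network `unet`, neither of which is shown
# here):
#
#     ns = NoiseScheduleVP(schedule="discrete", betas=betas)
#     model_fn = model_wrapper(unet, ns, model_type="noise")
#     solver = DPM_Solver(model_fn, ns, predict_x0=True)
#     x_T = torch.randn(16, 3, 64, 64)
#     x_0 = solver.sample(
#         x_T, steps=20, order=2, skip_type="time_uniform", method="multistep"
#     )
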
#############################################################
# other utility functions
#############################################################


def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K),
            torch.tensor(K - 2, device=x.device),
            cand_start_idx,
        ),
    )
    end_idx = torch.where(
        torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
    )
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K),
            torch.tensor(K - 2, device=x.device),
            cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(
        y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
    ).squeeze(2)
    end_y = torch.gather(
        y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
    ).squeeze(2)
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand

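# Editor's example (hypothetical values): with one channel and keypoints
# xp = [[0., 1., 2.]], yp = [[0., 10., 20.]]:
#
#     >>> x = torch.tensor([[0.5], [1.5], [3.0]])
#     >>> interpolate_fn(x, torch.tensor([[0., 1., 2.]]),
#     ...                torch.tensor([[0., 10., 20.]]))
#     tensor([[ 5.],
#             [15.],
#             [30.]])
#
# The last row shows the linear extrapolation from the outermost segment.
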
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dimension `dims`.
    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] whose total dimension is `dims`.
    """
    return v[(...,) + (None,) * (dims - 1)]
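
# Editor's example (hypothetical values): broadcasting a per-sample scalar
# over an image batch of shape [N, C, H, W]:
#
#     >>> v = torch.tensor([1., 2.])
#     >>> expand_dims(v, 4).shape
#     torch.Size([2, 1, 1, 1])
#
# Equivalent to v.view(-1, 1, 1, 1); the trailing singleton dims let `v`
# broadcast against a 4-D tensor.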