Spaces:
Runtime error
kingpreyansh committed
Commit ecef837
1 Parent(s): ef2baa2
Upload 437 files
This view is limited to 50 files because it contains too many changes.
- extensions/sd-webui-controlnet/.github/ISSUE_TEMPLATE/bug_report.yml +84 -0
- extensions/sd-webui-controlnet/.github/ISSUE_TEMPLATE/config.yml +1 -0
- extensions/sd-webui-controlnet/.gitignore +166 -0
- extensions/sd-webui-controlnet/LICENSE +21 -0
- extensions/sd-webui-controlnet/README.md +160 -0
- extensions/sd-webui-controlnet/annotator/binary/__init__.py +14 -0
- extensions/sd-webui-controlnet/annotator/canny/__init__.py +5 -0
- extensions/sd-webui-controlnet/annotator/clip/__init__.py +23 -0
- extensions/sd-webui-controlnet/annotator/color/__init__.py +6 -0
- extensions/sd-webui-controlnet/annotator/hed/__init__.py +148 -0
- extensions/sd-webui-controlnet/annotator/informative/__init__.py +131 -0
- extensions/sd-webui-controlnet/annotator/keypose/__init__.py +212 -0
- extensions/sd-webui-controlnet/annotator/keypose/faster_rcnn_r50_fpn_coco.py +182 -0
- extensions/sd-webui-controlnet/annotator/keypose/hrnet_w48_coco_256x192.py +169 -0
- extensions/sd-webui-controlnet/annotator/leres/__init__.py +115 -0
- extensions/sd-webui-controlnet/annotator/leres/leres/LICENSE +23 -0
- extensions/sd-webui-controlnet/annotator/leres/leres/Resnet.py +199 -0
- extensions/sd-webui-controlnet/annotator/leres/leres/Resnext_torch.py +247 -0
- extensions/sd-webui-controlnet/annotator/leres/leres/depthmap.py +546 -0
- extensions/sd-webui-controlnet/annotator/leres/leres/multi_depth_model_woauxi.py +34 -0
- extensions/sd-webui-controlnet/annotator/leres/leres/net_tools.py +53 -0
- extensions/sd-webui-controlnet/annotator/leres/leres/network_auxi.py +417 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/LICENSE +19 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/models/__init__.py +67 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/models/base_model.py +241 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/models/base_model_hg.py +58 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/models/networks.py +623 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/models/pix2pix4depth_model.py +155 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/options/__init__.py +1 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/options/base_options.py +156 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/options/test_options.py +22 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/__init__.py +1 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/get_data.py +110 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/guidedfilter.py +47 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/html.py +86 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/image_pool.py +54 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/util.py +105 -0
- extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/visualizer.py +166 -0
- extensions/sd-webui-controlnet/annotator/midas/__init__.py +49 -0
- extensions/sd-webui-controlnet/annotator/midas/api.py +181 -0
- extensions/sd-webui-controlnet/annotator/midas/midas/__init__.py +0 -0
- extensions/sd-webui-controlnet/annotator/midas/midas/base_model.py +16 -0
- extensions/sd-webui-controlnet/annotator/midas/midas/blocks.py +342 -0
- extensions/sd-webui-controlnet/annotator/midas/midas/dpt_depth.py +109 -0
- extensions/sd-webui-controlnet/annotator/midas/midas/midas_net.py +76 -0
- extensions/sd-webui-controlnet/annotator/midas/midas/midas_net_custom.py +128 -0
- extensions/sd-webui-controlnet/annotator/midas/midas/transforms.py +234 -0
- extensions/sd-webui-controlnet/annotator/midas/midas/vit.py +491 -0
- extensions/sd-webui-controlnet/annotator/midas/utils.py +189 -0
- extensions/sd-webui-controlnet/annotator/mlsd/__init__.py +49 -0
extensions/sd-webui-controlnet/.github/ISSUE_TEMPLATE/bug_report.yml
ADDED
@@ -0,0 +1,84 @@
name: Bug Report
description: Create a report
title: "[Bug]: "
labels: ["bug-report"]

body:
  - type: checkboxes
    attributes:
      label: Is there an existing issue for this?
      description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
      options:
        - label: I have searched the existing issues and checked the recent builds/commits of both this extension and the webui
          required: true
  - type: markdown
    attributes:
      value: |
        **Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and provide screenshots if possible**
  - type: textarea
    id: what-did
    attributes:
      label: What happened?
      description: Tell us what happened in a very clear and simple way
    validations:
      required: true
  - type: textarea
    id: steps
    attributes:
      label: Steps to reproduce the problem
      description: Please provide us with precise step by step information on how to reproduce the bug
      value: |
        1. Go to ....
        2. Press ....
        3. ...
    validations:
      required: true
  - type: textarea
    id: what-should
    attributes:
      label: What should have happened?
      description: Tell what you think the normal behavior should be
    validations:
      required: true
  - type: textarea
    id: commits
    attributes:
      label: Commit where the problem happens
      description: Which commit of the extension are you running on? Please include the commit of both the extension and the webui (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
      value: |
        webui:
        controlnet:
    validations:
      required: true
  - type: dropdown
    id: browsers
    attributes:
      label: What browsers do you use to access the UI ?
      multiple: true
      options:
        - Mozilla Firefox
        - Google Chrome
        - Brave
        - Apple Safari
        - Microsoft Edge
  - type: textarea
    id: cmdargs
    attributes:
      label: Command Line Arguments
      description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
      render: Shell
    validations:
      required: true
  - type: textarea
    id: logs
    attributes:
      label: Console logs
      description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
      render: Shell
    validations:
      required: true
  - type: textarea
    id: misc
    attributes:
      label: Additional information
      description: Please provide us with any relevant additional info or context.
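For anyone editing this issue form, a quick way to catch indentation or schema slips before pushing is to parse it locally. This is a minimal sketch (not part of the commit), assuming PyYAML is installed and the script is run from the repository root:

# sketch: sanity-check the issue form YAML (assumes `pip install pyyaml`)
import yaml

path = "extensions/sd-webui-controlnet/.github/ISSUE_TEMPLATE/bug_report.yml"
with open(path, "r", encoding="utf-8") as f:
    form = yaml.safe_load(f)

# top-level keys GitHub expects for issue forms
assert form["name"] == "Bug Report"
assert isinstance(form["body"], list) and len(form["body"]) > 0
# every block needs a type and attributes
for block in form["body"]:
    assert "type" in block and "attributes" in block
print(f"{len(form['body'])} form blocks parsed OK")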
extensions/sd-webui-controlnet/.github/ISSUE_TEMPLATE/config.yml
ADDED
@@ -0,0 +1 @@
blank_issues_enabled: true
extensions/sd-webui-controlnet/.gitignore
ADDED
@@ -0,0 +1,166 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea
*.pt
*.pth
*.ckpt
*.safetensors
models/control_sd15_scribble.pth
detected_maps/
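The last few patterns are the ControlNet-specific ones: model weights (*.pt, *.pth, *.ckpt, *.safetensors) and the detected_maps/ output folder stay out of version control. A small hedged sketch for checking which rule catches a given path, assuming git is on PATH and the script runs from the extension root (the example paths are illustrative):

# sketch: ask git which ignore rule (if any) matches a path
import subprocess

paths = ["models/control_sd15_scribble.pth", "detected_maps/foo.png", "scripts/cldm.py"]
for p in paths:
    # `git check-ignore -v` prints the matching rule, or exits non-zero if the path is not ignored
    result = subprocess.run(["git", "check-ignore", "-v", p],
                            capture_output=True, text=True)
    status = result.stdout.strip() if result.returncode == 0 else "not ignored"
    print(f"{p}: {status}")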
extensions/sd-webui-controlnet/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Kakigōri Maker

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
extensions/sd-webui-controlnet/README.md
ADDED
@@ -0,0 +1,160 @@
## sd-webui-controlnet
(WIP) WebUI extension for ControlNet and T2I-Adapter

This extension is for AUTOMATIC1111's [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui); it allows the Web UI to add [ControlNet](https://github.com/lllyasviel/ControlNet) to the original Stable Diffusion model to generate images. The addition is on-the-fly, and no merging is required.

ControlNet is a neural network structure to control diffusion models by adding extra conditions.

Thanks & Inspired by: kohya-ss/sd-webui-additional-networks

### Limits

* Dragging a large file onto the Web UI may freeze the entire page. It is better to use the upload file option instead.
* Just like WebUI's [hijack](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/3715ece0adce7bf7c5e9c5ab3710b2fdc3848f39/modules/sd_hijack_unet.py#L27), we use some interpolation to accept arbitrary size configurations (see `scripts/cldm.py`).

### Install

1. Open the "Extensions" tab.
2. Open the "Install from URL" tab within it.
3. Enter the URL of this repo in "URL for extension's git repository".
4. Press the "Install" button.
5. Reload/Restart the Web UI.

Upgrade gradio if any UI issues occur: `pip install gradio==3.16.2`

### Usage

1. Put the ControlNet models (`.pt`, `.pth`, `.ckpt` or `.safetensors`) inside the `models/ControlNet` folder.
2. Open the "txt2img" or "img2img" tab and write your prompts.
3. Press "Refresh models" and select the model you want to use. (If nothing appears, try reloading/restarting the webui.)
4. Upload your image and select a preprocessor; done.

Currently it supports both full models and trimmed models. Use `extract_controlnet.py` to extract a ControlNet from an original `.pth` file.

Pretrained models: https://huggingface.co/lllyasviel/ControlNet/tree/main/models

### Extraction

Two methods can be used to reduce the model's file size:

1. Directly extract the ControlNet from the original `.pth` file using `extract_controlnet.py`.

2. Transfer control from the original checkpoint by computing the difference using `extract_controlnet_diff.py`.

All types of models can be correctly recognized and loaded. The results of the different extraction methods are discussed in https://github.com/lllyasviel/ControlNet/discussions/12 and https://github.com/Mikubill/sd-webui-controlnet/issues/73.

Pre-extracted model: https://huggingface.co/webui/ControlNet-modules-safetensors

Pre-extracted difference model: https://huggingface.co/kohya-ss/ControlNet-diff-modules

### T2I-Adapter Support

Note that the implementation is experimental; results may differ from the original repo. See `Adapter Examples` for reference.

To use T2I-Adapter models:
1. Download files from https://huggingface.co/TencentARC/T2I-Adapter
2. Copy the corresponding config file and rename it to the same name as the model - see the list below.
3. It's better to use a slightly lower strength (t) when generating images with the sketch model, such as 0.6-0.8. (ref: [ldm/models/diffusion/plms.py](https://github.com/TencentARC/T2I-Adapter/blob/5f41a0e38fc6eac90d04bc4cede85a2bc4570653/ldm/models/diffusion/plms.py#L158))

| Adapter | Config |
|:-------------------------:|:-------------------------:|
| t2iadapter_canny_sd14v1.pth | sketch_adapter_v14.yaml |
| t2iadapter_sketch_sd14v1.pth | sketch_adapter_v14.yaml |
| t2iadapter_seg_sd14v1.pth | image_adapter_v14.yaml |
| t2iadapter_keypose_sd14v1.pth | image_adapter_v14.yaml |
| t2iadapter_openpose_sd14v1.pth | image_adapter_v14.yaml |
| t2iadapter_color_sd14v1.pth | t2iadapter_color_sd14v1.yaml |
| t2iadapter_style_sd14v1.pth | t2iadapter_style_sd14v1.yaml |

### Tips

* Don't forget to add some negative prompts; the default negative prompt in the ControlNet repo is "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality".
* Regarding canvas height/width: they are designed for canvas generation. If you want to upload images directly, you can safely ignore them.

### Examples

| Source | Input | Output |
|:-------------------------:|:-------------------------:|:-------------------------:|
| (no preprocessor) | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/bal-source.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/bal-gen.png?raw=true"> |
| (no preprocessor) | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/dog_rel.jpg?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/dog_rel.png?raw=true"> |
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_input.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro_canny.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/mahiro-out.png?raw=true"> |
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_source.jpg?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_hed.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/evt_gen.png?raw=true"> |
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/an-source.jpg?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/an-pose.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/an-gen.png?raw=true"> |
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/sk-b-src.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/sk-b-dep.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/sk-b-out.png?raw=true"> |
|<img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/nm-src.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/nm-gen.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/nm-out.png?raw=true"> |

### Adapter Examples

| Source | Input | Output |
|:-------------------------:|:-------------------------:|:-------------------------:|
| (no preprocessor) | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/dog_sk-2.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/dog_out-2.png?raw=true"> |
| (no preprocessor) | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/cat_sk-2.png?raw=true"> | <img width="256" alt="" src="https://github.com/Mikubill/sd-webui-controlnet/blob/main/samples/cat_out-2.png?raw=true"> |
| (no preprocessor) | <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/222967315-dc50406d-2930-47c5-8027-f76b95969f2b.png"> | <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/222967311-724d9531-4b93-4770-8409-cd9480434112.png"> |
| (no preprocessor) | <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/222966824-8f6c36f1-525b-40c2-ae9e-d3f5d148b5c9.png"> | <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/222966821-110541a4-5014-4cee-90f8-758edf540eae.png"> |
| <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/222947416-ec9e52a4-a1d0-48d8-bb81-736bf636145e.jpeg"> | <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/222947435-1164e7d8-d857-42f9-ab10-2d4a4b25f33a.png"> | <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/222947557-5520d5f8-88b4-474d-a576-5c9cd3acac3a.png"> |
| <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/222947416-ec9e52a4-a1d0-48d8-bb81-736bf636145e.jpeg"> | (clip, non-image) | <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/222965711-7b884c9e-7095-45cb-a91c-e50d296ba3a2.png"> |

Examples by catboxanon, with no tweaking or cherry-picking. (Color Guidance)

| Image | Disabled | Enabled |
|:-------------------------:|:-------------------------:|:-------------------------:|
| <img width="256" alt="" src="https://user-images.githubusercontent.com/122327233/222869104-0830feab-a0a1-448e-8bcd-add54b219cba.png"> | <img width="256" alt="" src="https://user-images.githubusercontent.com/122327233/222869047-d0111979-0ef7-4152-8523-8a45c47217c0.png"> | <img width="256" alt="" src="https://user-images.githubusercontent.com/122327233/222869079-7e5a62e0-fffe-4a19-8875-cba4c68b9428.png"> |
| <img width="256" alt="" src="https://user-images.githubusercontent.com/122327233/222869253-44f94dfa-5448-48b2-85be-73db867bdbbb.png"> | <img width="256" alt="" src="https://user-images.githubusercontent.com/122327233/222869261-92e598d0-2950-4874-8b6c-c159bda38365.png"> | <img width="256" alt="" src="https://user-images.githubusercontent.com/122327233/222869272-a4883524-7804-4013-addd-4d1ac56c5d0d.png"> |

### Minimum Requirements

* (Windows) (NVIDIA: Ampere) 4 GB - with `--xformers` enabled and `Low VRAM` mode ticked in the UI, goes up to 768x832.

### CFG Based ControlNet (Experimental)

The original ControlNet applies control to both the conditional (cond) and unconditional (uncond) parts. Enabling this option makes the control apply only to the cond part. Some experiments indicate that this approach improves image quality.

To enable this option, tick `Enable CFG-Based guidance for ControlNet` in the settings.

Note that you need to use a low CFG scale/guidance scale (such as 3-5) and proper weight tuning to get good results.

### Guess Mode (Non-Prompt Mode, Experimental)

Guess Mode is CFG-Based ControlNet + exponential decay in weighting.

See issue https://github.com/Mikubill/sd-webui-controlnet/issues/236 for more details.

Original introduction from ControlNet:

The "guess mode" (also called non-prompt mode) will completely unleash all the power of the very powerful ControlNet encoder.

In this mode, you can just remove all prompts, and then the ControlNet encoder will recognize the content of the input control map, like a depth map, edge map, scribbles, etc.

This mode is very suitable for comparing different methods of controlling stable diffusion because the non-prompted generation task is significantly more difficult than the prompted task. In this mode, different methods' performance will be very salient.

For this mode, we recommend **using 50 steps and a guidance scale between 3 and 5.**

### Multi-ControlNet / Joint Conditioning (Experimental)

This option allows multiple ControlNet inputs for a single generation. To enable it, change `Multi ControlNet: Max models amount (requires restart)` in the settings. Note that you will need to restart the WebUI for changes to take effect.

* Guess Mode will apply to all ControlNets if any of them are enabled.

| Source A | Source B | Output |
|:-------------------------:|:-------------------------:|:-------------------------:|
| <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/220448620-cd3ede92-8d3f-43d5-b771-32dd8417618f.png"> | <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/220448619-beed9bdb-f6bb-41c2-a7df-aa3ef1f653c5.png"> | <img width="256" alt="" src="https://user-images.githubusercontent.com/31246794/220448613-c99a9e04-0450-40fd-bc73-a9122cefaa2c.png"> |

### Weight and Guidance Strength/Start/End

Weight is the weight of the ControlNet "influence". It's analogous to prompt attention/emphasis, e.g. (myprompt: 1.2). Technically, it's the factor by which to multiply the ControlNet outputs before merging them with the original SD Unet.

Guidance Start/End is the percentage of total steps over which the ControlNet applies (guidance strength = guidance end). It's analogous to prompt editing/shifting, e.g. \[myprompt::0.8\] (it applies from the beginning until 80% of total steps).

### API/Script Access

This extension can accept txt2img or img2img tasks via the API or an external extension call. Note that you may need to enable `Allow other scripts to control this extension` in settings for external calls.

To use the API: start the WebUI with the argument `--api` and go to `http://webui-address/docs` for the documentation, or check out the [examples](https://github.com/Mikubill/sd-webui-controlnet/blob/main/example/api_txt2img.ipynb).

To use an external call: check out the [Wiki](https://github.com/Mikubill/sd-webui-controlnet/wiki/API).

### MacOS Support

Tested with pytorch nightly: https://github.com/Mikubill/sd-webui-controlnet/pull/143#issuecomment-1435058285

To use this extension with mps and normal pytorch, you currently may need to start the WebUI with `--no-half`.
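As a companion to the API note in the README above, here is a minimal sketch of a plain txt2img call against a locally running WebUI started with `--api`. The endpoint, the prompt/steps fields, and the base64 `images` response are the standard WebUI API; the ControlNet-specific parameters are deliberately left out because their exact schema should be taken from `http://webui-address/docs` or the linked example notebook. The host/port and output file name are assumptions.

# sketch: plain txt2img request to a local WebUI launched with --api
# (assumes the `requests` package and a server at 127.0.0.1:7860)
import base64
import requests

payload = {
    "prompt": "a house in the forest, best quality",
    "negative_prompt": "lowres, bad anatomy, worst quality, low quality",
    "steps": 20,
    "width": 512,
    "height": 512,
}
# ControlNet fields would be added here following the schema shown at /docs
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
resp.raise_for_status()
image_b64 = resp.json()["images"][0]
with open("out.png", "wb") as f:
    f.write(base64.b64decode(image_b64))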
extensions/sd-webui-controlnet/annotator/binary/__init__.py
ADDED
@@ -0,0 +1,14 @@
import cv2


def apply_binary(img, bin_threshold):
    img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    if bin_threshold == 0 or bin_threshold == 255:
        # Otsu's threshold
        otsu_threshold, img_bin = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        print("Otsu threshold:", otsu_threshold)
    else:
        _, img_bin = cv2.threshold(img_gray, bin_threshold, 255, cv2.THRESH_BINARY_INV)

    return cv2.cvtColor(img_bin, cv2.COLOR_GRAY2RGB)
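A small usage sketch for the binary annotator above, assuming it runs alongside that module (or with `apply_binary` imported). It only needs OpenCV and NumPy; the input path is hypothetical. Passing 0 (or 255) triggers the Otsu branch, while any other value is used as a fixed threshold.

# sketch: run apply_binary standalone (hypothetical input image path)
import cv2
import numpy as np

img_bgr = cv2.imread("input.png")                   # OpenCV loads BGR
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)  # the annotator expects RGB

auto = apply_binary(img_rgb, 0)      # Otsu threshold picked automatically
fixed = apply_binary(img_rgb, 128)   # fixed threshold at 128

assert auto.shape == img_rgb.shape and auto.dtype == np.uint8
cv2.imwrite("binary_otsu.png", cv2.cvtColor(auto, cv2.COLOR_RGB2BGR))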
extensions/sd-webui-controlnet/annotator/canny/__init__.py
ADDED
@@ -0,0 +1,5 @@
import cv2


def apply_canny(img, low_threshold, high_threshold):
    return cv2.Canny(img, low_threshold, high_threshold)
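The canny annotator is a one-liner around `cv2.Canny`. A typical call on an RGB uint8 image looks like this; the 100/200 thresholds are just common defaults for illustration, not values mandated by the extension, and the input path is hypothetical.

# sketch: edge map from apply_canny (single-channel uint8 output)
import cv2

img = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)
edges = apply_canny(img, low_threshold=100, high_threshold=200)
print(edges.shape, edges.dtype)  # (H, W) uint8, values 0 or 255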
extensions/sd-webui-controlnet/annotator/clip/__init__.py
ADDED
@@ -0,0 +1,23 @@
from transformers import CLIPProcessor, CLIPVisionModel
from modules import devices

version = 'openai/clip-vit-large-patch14'
clip_proc = None
clip_vision_model = None

def apply_clip(img):
    global clip_proc, clip_vision_model

    if clip_vision_model is None:
        clip_proc = CLIPProcessor.from_pretrained(version)
        clip_vision_model = CLIPVisionModel.from_pretrained(version)

    clip_vision_model = clip_vision_model.to(devices.get_device_for("controlnet"))
    style_for_clip = clip_proc(images=img, return_tensors="pt")['pixel_values']
    style_feat = clip_vision_model(style_for_clip.to(devices.get_device_for("controlnet")))['last_hidden_state']
    return style_feat

def unload_clip_model():
    global clip_proc, clip_vision_model
    if clip_vision_model is not None:
        clip_vision_model.cpu()
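The clip annotator caches a CLIP ViT-L/14 vision model and returns its last hidden state as a "style" feature. Because the module above depends on the WebUI's `modules.devices`, a standalone sketch of the same idea using only `transformers` and `Pillow` looks like this (the image path and the implicit CPU device are assumptions):

# sketch: the same style-feature extraction outside the WebUI (hypothetical image path)
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPVisionModel

version = "openai/clip-vit-large-patch14"
proc = CLIPProcessor.from_pretrained(version)
model = CLIPVisionModel.from_pretrained(version).eval()

img = Image.open("style.png").convert("RGB")
pixels = proc(images=img, return_tensors="pt")["pixel_values"]
with torch.no_grad():
    feat = model(pixels)["last_hidden_state"]
print(feat.shape)  # torch.Size([1, 257, 1024]) for ViT-L/14: 1 CLS token + 16x16 patches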
extensions/sd-webui-controlnet/annotator/color/__init__.py
ADDED
@@ -0,0 +1,6 @@
import cv2

def apply_color(img, res=512):
    input_img_color = cv2.resize(img, (res//64, res//64), interpolation=cv2.INTER_CUBIC)
    input_img_color = cv2.resize(input_img_color, (res, res), interpolation=cv2.INTER_NEAREST)
    return input_img_color
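The color annotator downsamples the image to a coarse grid and scales it back up with nearest-neighbour interpolation, so the output is a mosaic of flat colour blocks. For the default `res=512` the intermediate grid is 512//64 = 8x8, i.e. 64-pixel cells. A minimal sketch, assuming it runs alongside the module above and with a hypothetical input path:

# sketch: colour-palette map with 64 px cells at res=512
import cv2

img = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)
palette = apply_color(img, res=512)
print(palette.shape)  # (512, 512, 3): an 8x8 grid of flat 64x64 colour blocks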
extensions/sd-webui-controlnet/annotator/hed/__init__.py
ADDED
@@ -0,0 +1,148 @@
from distutils import extension
import numpy as np
import cv2
import torch
from einops import rearrange

import os
from modules import devices
from modules.paths import models_path

class Network(torch.nn.Module):
    def __init__(self, model_path):
        super().__init__()

        self.netVggOne = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggTwo = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggThr = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggFou = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggFiv = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)

        self.netCombine = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
            torch.nn.Sigmoid()
        )

        self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
    # end

    def forward(self, tenInput):
        tenInput = tenInput * 255.0
        tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)

        tenVggOne = self.netVggOne(tenInput)
        tenVggTwo = self.netVggTwo(tenVggOne)
        tenVggThr = self.netVggThr(tenVggTwo)
        tenVggFou = self.netVggFou(tenVggThr)
        tenVggFiv = self.netVggFiv(tenVggFou)

        tenScoreOne = self.netScoreOne(tenVggOne)
        tenScoreTwo = self.netScoreTwo(tenVggTwo)
        tenScoreThr = self.netScoreThr(tenVggThr)
        tenScoreFou = self.netScoreFou(tenVggFou)
        tenScoreFiv = self.netScoreFiv(tenVggFiv)

        tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)

        return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1))
    # end
# end

netNetwork = None
remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth"
modeldir = os.path.join(models_path, "hed")
old_modeldir = os.path.dirname(os.path.realpath(__file__))

def apply_hed(input_image):
    global netNetwork
    if netNetwork is None:
        modelpath = os.path.join(modeldir, "network-bsds500.pth")
        old_modelpath = os.path.join(old_modeldir, "network-bsds500.pth")
        if os.path.exists(old_modelpath):
            modelpath = old_modelpath
        elif not os.path.exists(modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=modeldir)
        netNetwork = Network(modelpath)
    netNetwork.to(devices.get_device_for("controlnet")).eval()

    assert input_image.ndim == 3
    input_image = input_image[:, :, ::-1].copy()
    with torch.no_grad():
        image_hed = torch.from_numpy(input_image).float().to(devices.get_device_for("controlnet"))
        image_hed = image_hed / 255.0
        image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
        edge = netNetwork(image_hed)[0]
        edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
        return edge[0]

def unload_hed_model():
    global netNetwork
    if netNetwork is not None:
        netNetwork.cpu()

def nms(x, t, s):
    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)

    y = np.zeros_like(x)

    for f in [f1, f2, f3, f4]:
        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)

    z = np.zeros_like(y, dtype=np.uint8)
    z[y > t] = 255
    return z
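Besides the HED network itself, the module ends with a small `nms` helper that blurs an edge map, keeps only directional local maxima (via the four line-shaped dilation kernels), and binarizes the result, which turns soft edges into thin scribble-like lines. A rough standalone sketch of that thinning step alone, using a synthetic edge map so it needs no model weights and assuming it runs alongside the module above; the threshold and sigma values are purely illustrative:

# sketch: thin a soft edge map into binary scribble-like lines with nms()
import numpy as np
import cv2

# synthetic "edge map": a thick diagonal line, values in 0..255
edge = np.zeros((256, 256), dtype=np.uint8)
cv2.line(edge, (10, 10), (240, 240), color=255, thickness=5)

thin = nms(edge, t=60, s=3.0)        # illustrative threshold 60, blur sigma 3.0
print(thin.dtype, np.unique(thin))   # uint8, values {0, 255}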
extensions/sd-webui-controlnet/annotator/informative/__init__.py
ADDED
@@ -0,0 +1,131 @@
from distutils import extension
import numpy as np
import cv2
import torch
from einops import rearrange

import os
from modules import devices
from modules.paths import models_path

class Network(torch.nn.Module):
    def __init__(self, model_path):
        super().__init__()

        self.netVggOne = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggTwo = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggThr = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggFou = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggFiv = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)

        self.netCombine = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
            torch.nn.Sigmoid()
        )

        self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
    # end

    def forward(self, tenInput):
        tenInput = tenInput * 255.0
        tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)

        tenVggOne = self.netVggOne(tenInput)
        tenVggTwo = self.netVggTwo(tenVggOne)
        tenVggThr = self.netVggThr(tenVggTwo)
        tenVggFou = self.netVggFou(tenVggThr)
        tenVggFiv = self.netVggFiv(tenVggFou)

        tenScoreOne = self.netScoreOne(tenVggOne)
        tenScoreTwo = self.netScoreTwo(tenVggTwo)
        tenScoreThr = self.netScoreThr(tenVggThr)
        tenScoreFou = self.netScoreFou(tenVggFou)
        tenScoreFiv = self.netScoreFiv(tenVggFiv)

        tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)

        return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1))
    # end
# end

netNetwork = None
remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth"
modeldir = os.path.join(models_path, "hed")
old_modeldir = os.path.dirname(os.path.realpath(__file__))

def apply_hed(input_image):
    global netNetwork
    if netNetwork is None:
        modelpath = os.path.join(modeldir, "network-bsds500.pth")
        old_modelpath = os.path.join(old_modeldir, "network-bsds500.pth")
        if os.path.exists(old_modelpath):
            modelpath = old_modelpath
        elif not os.path.exists(modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=modeldir)
        netNetwork = Network(modelpath)
    netNetwork.to(devices.get_device_for("controlnet")).eval()

    assert input_image.ndim == 3
    input_image = input_image[:, :, ::-1].copy()
    with torch.no_grad():
        image_hed = torch.from_numpy(input_image).float().to(devices.get_device_for("controlnet"))
        image_hed = image_hed / 255.0
        image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
        edge = netNetwork(image_hed)[0]
        edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
        return edge[0]

def unload_hed_model():
    global netNetwork
    if netNetwork is not None:
        netNetwork.cpu()
extensions/sd-webui-controlnet/annotator/keypose/__init__.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import os
|
6 |
+
from modules import devices
|
7 |
+
from modules.paths import models_path
|
8 |
+
|
9 |
+
import mmcv
|
10 |
+
from mmdet.apis import inference_detector, init_detector
|
11 |
+
from mmpose.apis import inference_top_down_pose_model
|
12 |
+
from mmpose.apis import init_pose_model, process_mmdet_results, vis_pose_result
|
13 |
+
|
14 |
+
|
15 |
+
def preprocessing(image, device):
|
16 |
+
# Resize
|
17 |
+
scale = 640 / max(image.shape[:2])
|
18 |
+
image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
|
19 |
+
raw_image = image.astype(np.uint8)
|
20 |
+
|
21 |
+
# Subtract mean values
|
22 |
+
image = image.astype(np.float32)
|
23 |
+
image -= np.array(
|
24 |
+
[
|
25 |
+
float(104.008),
|
26 |
+
float(116.669),
|
27 |
+
float(122.675),
|
28 |
+
]
|
29 |
+
)
|
30 |
+
|
31 |
+
# Convert to torch.Tensor and add "batch" axis
|
32 |
+
image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
|
33 |
+
image = image.to(device)
|
34 |
+
|
35 |
+
return image, raw_image
|
36 |
+
|
37 |
+
|
38 |
+
def imshow_keypoints(img,
|
39 |
+
pose_result,
|
40 |
+
skeleton=None,
|
41 |
+
kpt_score_thr=0.1,
|
42 |
+
pose_kpt_color=None,
|
43 |
+
pose_link_color=None,
|
44 |
+
radius=4,
|
45 |
+
thickness=1):
|
46 |
+
"""Draw keypoints and links on an image.
|
47 |
+
Args:
|
48 |
+
img (ndarry): The image to draw poses on.
|
49 |
+
pose_result (list[kpts]): The poses to draw. Each element kpts is
|
50 |
+
a set of K keypoints as an Kx3 numpy.ndarray, where each
|
51 |
+
keypoint is represented as x, y, score.
|
52 |
+
kpt_score_thr (float, optional): Minimum score of keypoints
|
53 |
+
to be shown. Default: 0.3.
|
54 |
+
pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
|
55 |
+
the keypoint will not be drawn.
|
56 |
+
pose_link_color (np.array[Mx3]): Color of M links. If None, the
|
57 |
+
links will not be drawn.
|
58 |
+
thickness (int): Thickness of lines.
|
59 |
+
"""
|
60 |
+
|
61 |
+
img_h, img_w, _ = img.shape
|
62 |
+
img = np.zeros(img.shape)
|
63 |
+
|
64 |
+
for idx, kpts in enumerate(pose_result):
|
65 |
+
if idx > 1:
|
66 |
+
continue
|
67 |
+
kpts = kpts['keypoints']
|
68 |
+
# print(kpts)
|
69 |
+
kpts = np.array(kpts, copy=False)
|
70 |
+
|
71 |
+
# draw each point on image
|
72 |
+
if pose_kpt_color is not None:
|
73 |
+
assert len(pose_kpt_color) == len(kpts)
|
74 |
+
|
75 |
+
for kid, kpt in enumerate(kpts):
|
76 |
+
x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
|
77 |
+
|
78 |
+
if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
|
79 |
+
# skip the point that should not be drawn
|
80 |
+
continue
|
81 |
+
|
82 |
+
color = tuple(int(c) for c in pose_kpt_color[kid])
|
83 |
+
cv2.circle(img, (int(x_coord), int(y_coord)),
|
84 |
+
radius, color, -1)
|
85 |
+
|
86 |
+
# draw links
|
87 |
+
if skeleton is not None and pose_link_color is not None:
|
88 |
+
assert len(pose_link_color) == len(skeleton)
|
89 |
+
|
90 |
+
for sk_id, sk in enumerate(skeleton):
|
91 |
+
pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
|
92 |
+
pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
|
93 |
+
|
94 |
+
if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
|
95 |
+
or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
|
96 |
+
or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
|
97 |
+
# skip the link that should not be drawn
|
98 |
+
continue
|
99 |
+
color = tuple(int(c) for c in pose_link_color[sk_id])
|
100 |
+
cv2.line(img, pos1, pos2, color, thickness=thickness)
|
101 |
+
|
102 |
+
return img
|
103 |
+
|
104 |
+
|
105 |
+
human_det, pose_model = None, None
|
106 |
+
det_model_path = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
|
107 |
+
pose_model_path = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
|
108 |
+
|
109 |
+
modeldir = os.path.join(models_path, "keypose")
|
110 |
+
old_modeldir = os.path.dirname(os.path.realpath(__file__))
|
111 |
+
|
112 |
+
det_config = 'faster_rcnn_r50_fpn_coco.py'
|
113 |
+
pose_config = 'hrnet_w48_coco_256x192.py'
|
114 |
+
|
115 |
+
det_checkpoint = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
|
pose_checkpoint = 'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
det_cat_id = 1
bbox_thr = 0.2

skeleton = [
    [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8],
    [7, 9], [8, 10],
    [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]
]

pose_kpt_color = [
    [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
    [0, 255, 0],
    [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0],
    [255, 128, 0],
    [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]
]

pose_link_color = [
    [0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
    [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
    [255, 128, 0],
    [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
    [51, 153, 255],
    [51, 153, 255], [51, 153, 255], [51, 153, 255]
]

def find_download_model(checkpoint, remote_path):
    modelpath = os.path.join(modeldir, checkpoint)
    old_modelpath = os.path.join(old_modeldir, checkpoint)

    if os.path.exists(old_modelpath):
        modelpath = old_modelpath
    elif not os.path.exists(modelpath):
        from basicsr.utils.download_util import load_file_from_url
        load_file_from_url(remote_path, model_dir=modeldir)

    return modelpath

def apply_keypose(input_image):
    global human_det, pose_model
    if netNetwork is None:
        det_model_local = find_download_model(det_checkpoint, det_model_path)
        hrnet_model_local = find_download_model(pose_checkpoint, pose_model_path)
        det_config_mmcv = mmcv.Config.fromfile(det_config)
        pose_config_mmcv = mmcv.Config.fromfile(pose_config)
        human_det = init_detector(det_config_mmcv, det_model_local, device=devices.get_device_for("controlnet"))
        pose_model = init_pose_model(pose_config_mmcv, hrnet_model_local, device=devices.get_device_for("controlnet"))

    assert input_image.ndim == 3
    input_image = input_image.copy()
    with torch.no_grad():
        image = torch.from_numpy(input_image).float().to(devices.get_device_for("controlnet"))
        image = image / 255.0
        mmdet_results = inference_detector(human_det, image)

        # keep the person class bounding boxes.
        person_results = process_mmdet_results(mmdet_results, det_cat_id)

        return_heatmap = False
        dataset = pose_model.cfg.data['test']['type']

        # e.g. use ('backbone', ) to return backbone feature
        output_layer_names = None
        pose_results, _ = inference_top_down_pose_model(
            pose_model,
            image,
            person_results,
            bbox_thr=bbox_thr,
            format='xyxy',
            dataset=dataset,
            dataset_info=None,
            return_heatmap=return_heatmap,
            outputs=output_layer_names
        )

        im_keypose_out = imshow_keypoints(
            image,
            pose_results,
            skeleton=skeleton,
            pose_kpt_color=pose_kpt_color,
            pose_link_color=pose_link_color,
            radius=2,
            thickness=2
        )
        im_keypose_out = im_keypose_out.astype(np.uint8)

        # image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
        # edge = netNetwork(image_hed)[0]
        # edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
        return im_keypose_out


def unload_hed_model():
    global netNetwork
    if netNetwork is not None:
        netNetwork.cpu()
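For orientation, a rough usage sketch of the annotator above. This is not part of the upload: the import path, image file name, and PIL round-trip are illustrative assumptions; apply_keypose expects an HxWx3 uint8 numpy array and returns an HxWx3 uint8 skeleton rendering.

    import numpy as np
    from PIL import Image
    from annotator.keypose import apply_keypose   # assumes the extension root is importable

    img = np.array(Image.open("person.jpg").convert("RGB"))   # HxWx3 uint8
    pose_map = apply_keypose(img)                              # keypoint/skeleton visualization
    Image.fromarray(pose_map).save("person_pose.png")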
extensions/sd-webui-controlnet/annotator/keypose/faster_rcnn_r50_fpn_coco.py
ADDED
@@ -0,0 +1,182 @@
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
total_epochs = 12

model = dict(
    type='FasterRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)
        # soft-nms is also supported for rcnn testing
        # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
    ))

dataset_type = 'CocoDataset'
data_root = 'data/coco'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=f'{data_root}/annotations/instances_train2017.json',
        img_prefix=f'{data_root}/train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=f'{data_root}/annotations/instances_val2017.json',
        img_prefix=f'{data_root}/val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=f'{data_root}/annotations/instances_val2017.json',
        img_prefix=f'{data_root}/val2017/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
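For reference, this detector config is not imported as a Python module; apply_keypose above parses it with mmcv and passes it to the mmdet builders. A minimal sketch under that assumption (mmcv-full and mmdet installed; the checkpoint path and image name are illustrative):

    import mmcv
    from mmdet.apis import init_detector, inference_detector

    cfg = mmcv.Config.fromfile("faster_rcnn_r50_fpn_coco.py")        # parse this file into a Config object
    detector = init_detector(cfg, "faster_rcnn_r50_fpn_1x_coco.pth",  # weights obtained separately
                             device="cuda:0")
    results = inference_detector(detector, "person.jpg")              # per-class arrays of [x1, y1, x2, y2, score]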
extensions/sd-webui-controlnet/annotator/keypose/hrnet_w48_coco_256x192.py
ADDED
@@ -0,0 +1,169 @@
# _base_ = [
#     '../../../../_base_/default_runtime.py',
#     '../../../../_base_/datasets/coco.py'
# ]
evaluation = dict(interval=10, metric='mAP', save_best='AP')

optimizer = dict(
    type='Adam',
    lr=5e-4,
)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[170, 200])
total_epochs = 210
channel_cfg = dict(
    num_output_channels=17,
    dataset_joints=17,
    dataset_channel=[
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    ],
    inference_channel=[
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
    ])

# model settings
model = dict(
    type='TopDown',
    pretrained='https://download.openmmlab.com/mmpose/'
    'pretrain_models/hrnet_w48-8ef0771d.pth',
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(48, 96)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(48, 96, 192)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(48, 96, 192, 384))),
    ),
    keypoint_head=dict(
        type='TopdownHeatmapSimpleHead',
        in_channels=48,
        out_channels=channel_cfg['num_output_channels'],
        num_deconv_layers=0,
        extra=dict(final_conv_kernel=1, ),
        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
    train_cfg=dict(),
    test_cfg=dict(
        flip_test=True,
        post_process='default',
        shift_heatmap=True,
        modulate_kernel=11))

data_cfg = dict(
    image_size=[192, 256],
    heatmap_size=[48, 64],
    num_output_channels=channel_cfg['num_output_channels'],
    num_joints=channel_cfg['dataset_joints'],
    dataset_channel=channel_cfg['dataset_channel'],
    inference_channel=channel_cfg['inference_channel'],
    soft_nms=False,
    nms_thr=1.0,
    oks_thr=0.9,
    vis_thr=0.2,
    use_gt_bbox=False,
    det_bbox_thr=0.0,
    bbox_file='data/coco/person_detection_results/'
    'COCO_val2017_detections_AP_H_56_person.json',
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='TopDownGetBboxCenterScale', padding=1.25),
    dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
    dict(type='TopDownRandomFlip', flip_prob=0.5),
    dict(
        type='TopDownHalfBodyTransform',
        num_joints_half_body=8,
        prob_half_body=0.3),
    dict(
        type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
    dict(type='TopDownAffine'),
    dict(type='ToTensor'),
    dict(
        type='NormalizeTensor',
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
    dict(type='TopDownGenerateTarget', sigma=2),
    dict(
        type='Collect',
        keys=['img', 'target', 'target_weight'],
        meta_keys=[
            'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
            'rotation', 'bbox_score', 'flip_pairs'
        ]),
]

val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='TopDownGetBboxCenterScale', padding=1.25),
    dict(type='TopDownAffine'),
    dict(type='ToTensor'),
    dict(
        type='NormalizeTensor',
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'image_file', 'center', 'scale', 'rotation', 'bbox_score',
            'flip_pairs'
        ]),
]

test_pipeline = val_pipeline

data_root = 'data/coco'
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    val_dataloader=dict(samples_per_gpu=32),
    test_dataloader=dict(samples_per_gpu=32),
    train=dict(
        type='TopDownCocoDataset',
        ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
        img_prefix=f'{data_root}/train2017/',
        data_cfg=data_cfg,
        pipeline=train_pipeline,
        dataset_info={{_base_.dataset_info}}),
    val=dict(
        type='TopDownCocoDataset',
        ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
        img_prefix=f'{data_root}/val2017/',
        data_cfg=data_cfg,
        pipeline=val_pipeline,
        dataset_info={{_base_.dataset_info}}),
    test=dict(
        type='TopDownCocoDataset',
        ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
        img_prefix=f'{data_root}/val2017/',
        data_cfg=data_cfg,
        pipeline=test_pipeline,
        dataset_info={{_base_.dataset_info}}),
)
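The skeleton, pose_kpt_color, and pose_link_color tables in keypose/__init__.py index into the 17 COCO keypoints this model predicts. The standard COCO ordering is listed below purely as a reading aid (it is not part of the uploaded file):

    # COCO keypoint indices used by TopDownCocoDataset
    COCO_KEYPOINTS = [
        "nose", "left_eye", "right_eye", "left_ear", "right_ear",
        "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
        "left_wrist", "right_wrist", "left_hip", "right_hip",
        "left_knee", "right_knee", "left_ankle", "right_ankle",
    ]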
extensions/sd-webui-controlnet/annotator/leres/__init__.py
ADDED
@@ -0,0 +1,115 @@
import cv2
import numpy as np
import torch
import os
from modules import devices, shared
from modules.paths import models_path
from torchvision.transforms import transforms

# AdelaiDepth/LeReS imports
from .leres.depthmap import estimateleres, estimateboost
from .leres.multi_depth_model_woauxi import RelDepthModel
from .leres.net_tools import strip_prefix_if_present

# pix2pix/merge net imports
from .pix2pix.options.test_options import TestOptions
from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel

base_model_path = os.path.join(models_path, "leres")
old_modeldir = os.path.dirname(os.path.realpath(__file__))

remote_model_path_leres = "https://cloudstor.aarnet.edu.au/plus/s/lTIJF4vrvHCAI31/download"
remote_model_path_pix2pix = "https://sfu.ca/~yagiz/CVPR21/latest_net_G.pth"

model = None
pix2pixmodel = None

def unload_leres_model():
    global model, pix2pixmodel
    if model is not None:
        model = model.cpu()
    if pix2pixmodel is not None:
        pix2pixmodel = pix2pixmodel.unload_network('G')

def apply_leres(input_image, thr_a, thr_b):
    global model, pix2pixmodel
    boost = shared.opts.data.get("control_net_monocular_depth_optim", False)

    if model is None:
        model_path = os.path.join(base_model_path, "res101.pth")
        old_model_path = os.path.join(old_modeldir, "res101.pth")

        if os.path.exists(old_model_path):
            model_path = old_model_path
        elif not os.path.exists(model_path):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path_leres, model_dir=base_model_path)
            os.rename(os.path.join(base_model_path, 'download'), model_path)

        if torch.cuda.is_available():
            checkpoint = torch.load(model_path)
        else:
            checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

        model = RelDepthModel(backbone='resnext101')
        model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
        del checkpoint

    if boost and pix2pixmodel is None:
        pix2pixmodel_path = os.path.join(base_model_path, "latest_net_G.pth")
        if not os.path.exists(pix2pixmodel_path):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path_pix2pix, model_dir=base_model_path)

        opt = TestOptions().parse()
        if not torch.cuda.is_available():
            opt.gpu_ids = []  # cpu mode
        pix2pixmodel = Pix2Pix4DepthModel(opt)
        pix2pixmodel.save_dir = base_model_path
        pix2pixmodel.load_networks('latest')
        pix2pixmodel.eval()

    if devices.get_device_for("controlnet").type != 'mps':
        model = model.to(devices.get_device_for("controlnet"))

    assert input_image.ndim == 3
    height, width, dim = input_image.shape

    with torch.no_grad():

        if boost:
            depth = estimateboost(input_image, model, 0, pix2pixmodel, max(width, height))
        else:
            depth = estimateleres(input_image, model, width, height)

        numbytes = 2
        depth_min = depth.min()
        depth_max = depth.max()
        max_val = (2**(8*numbytes))-1

        # check output before normalizing and mapping to 16 bit
        if depth_max - depth_min > np.finfo("float").eps:
            out = max_val * (depth - depth_min) / (depth_max - depth_min)
        else:
            out = np.zeros(depth.shape)

        # single channel, 16 bit image
        depth_image = out.astype("uint16")

        # convert to uint8
        depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0/65535.0))

        # remove near
        if thr_a != 0:
            thr_a = ((thr_a/100)*255)
            depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1]

        # invert image
        depth_image = cv2.bitwise_not(depth_image)

        # remove bg
        if thr_b != 0:
            thr_b = ((thr_b/100)*255)
            depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1]

        return depth_image
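A rough usage sketch for the LeReS annotator, again not part of the upload: the file name is an assumption, thr_a/thr_b are the near/background clipping thresholds in percent, and the result is a single-channel uint8 depth visualization.

    import numpy as np
    from PIL import Image
    from annotator.leres import apply_leres   # assumes the extension root is importable

    img = np.array(Image.open("room.jpg").convert("RGB"))
    depth = apply_leres(img, thr_a=0, thr_b=0)   # HxW uint8, no clipping applied
    Image.fromarray(depth).save("room_depth.png")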
extensions/sd-webui-controlnet/annotator/leres/leres/LICENSE
ADDED
@@ -0,0 +1,23 @@
https://github.com/thygate/stable-diffusion-webui-depthmap-script

MIT License

Copyright (c) 2023 Bob Thiry

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
extensions/sd-webui-controlnet/annotator/leres/leres/Resnet.py
ADDED
@@ -0,0 +1,199 @@
import torch.nn as nn
import torch.nn as NN

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = NN.BatchNorm2d(planes)  #NN.BatchNorm2d
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = NN.BatchNorm2d(planes)  #NN.BatchNorm2d
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = NN.BatchNorm2d(planes)  #NN.BatchNorm2d
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = NN.BatchNorm2d(planes)  #NN.BatchNorm2d
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = NN.BatchNorm2d(planes * self.expansion)  #NN.BatchNorm2d
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = NN.BatchNorm2d(64)  #NN.BatchNorm2d
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        #self.avgpool = nn.AvgPool2d(7, stride=1)
        #self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                NN.BatchNorm2d(planes * block.expansion),  #NN.BatchNorm2d
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        features = []

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        features.append(x)
        x = self.layer2(x)
        features.append(x)
        x = self.layer3(x)
        features.append(x)
        x = self.layer4(x)
        features.append(x)

        return features


def resnet18(pretrained=True, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def resnet34(pretrained=True, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def resnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)

    return model


def resnet101(pretrained=True, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)

    return model


def resnet152(pretrained=True, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model
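This ResNet variant drops the classifier head and returns the four stage feature maps that the LeReS decoder consumes (note the `pretrained` flag is accepted but never acts; the weights come from the LeReS checkpoint instead). A quick shape check under that assumption, with the import path assuming the extension root is on sys.path:

    import torch
    from annotator.leres.leres.Resnet import resnet50

    net = resnet50(pretrained=False)                 # constructor never downloads ImageNet weights
    feats = net(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])
    # [(1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)]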
extensions/sd-webui-controlnet/annotator/leres/leres/Resnext_torch.py
ADDED
@@ -0,0 +1,247 @@
#!/usr/bin/env python
# coding: utf-8
import torch.nn as nn

try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve

__all__ = ['resnext101_32x8d']


model_urls = {
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        #self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        features = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        features.append(x)

        x = self.layer2(x)
        features.append(x)

        x = self.layer3(x)
        features.append(x)

        x = self.layer4(x)
        features.append(x)

        #x = self.avgpool(x)
        #x = torch.flatten(x, 1)
        #x = self.fc(x)

        return features

    def forward(self, x):
        return self._forward_impl(x)



def resnext101_32x8d(pretrained=True, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8

    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model



if __name__ == '__main__':
    import torch
    model = resnext101_32x8d(True).cuda()

    rgb = torch.rand((2, 3, 256, 256)).cuda()
    out = model(rgb)
    print(len(out))
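Note that model_urls is declared above but the constructor never downloads anything itself; in this extension the backbone weights arrive via the LeReS res101.pth checkpoint loaded in annotator/leres/__init__.py. If one wanted to seed the backbone from torchvision's ImageNet weights instead, a sketch (an assumption, not part of the upload) could look like:

    import torch
    from annotator.leres.leres.Resnext_torch import resnext101_32x8d, model_urls

    net = resnext101_32x8d()
    state = torch.hub.load_state_dict_from_url(model_urls['resnext101_32x8d'], progress=True)
    net.load_state_dict(state, strict=False)   # strict=False: the fc/avgpool weights have no target here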
extensions/sd-webui-controlnet/annotator/leres/leres/depthmap.py
ADDED
@@ -0,0 +1,546 @@
# Author: thygate
# https://github.com/thygate/stable-diffusion-webui-depthmap-script

from modules import devices
from modules.shared import opts
from torchvision.transforms import transforms
from operator import getitem

import torch, gc
import cv2
import numpy as np
import skimage.measure

whole_size_threshold = 1600  # R_max from the paper
pix2pixsize = 1024

def scale_torch(img):
    """
    Scale the image and output it in torch.tensor.
    :param img: input rgb is in shape [H, W, C], input depth/disp is in shape [H, W]
    :param scale: the scale factor. float
    :return: img. [C, H, W]
    """
    if len(img.shape) == 2:
        img = img[np.newaxis, :, :]
    if img.shape[2] == 3:
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406) , (0.229, 0.224, 0.225) )])
        img = transform(img.astype(np.float32))
    else:
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
    return img

def estimateleres(img, model, w, h):
    # leres transform input
    rgb_c = img[:, :, ::-1].copy()
    A_resize = cv2.resize(rgb_c, (w, h))
    img_torch = scale_torch(A_resize)[None, :, :, :]

    # compute
    with torch.no_grad():
        img_torch = img_torch.to(devices.get_device_for("controlnet"))
        prediction = model.depth_model(img_torch)

    prediction = prediction.squeeze().cpu().numpy()
    prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)

    return prediction

def generatemask(size):
    # Generates a Guassian mask
    mask = np.zeros(size, dtype=np.float32)
    sigma = int(size[0]/16)
    k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1)
    mask[int(0.15*size[0]):size[0] - int(0.15*size[0]), int(0.15*size[1]): size[1] - int(0.15*size[1])] = 1
    mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma)
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask.astype(np.float32)
    return mask

def resizewithpool(img, size):
    i_size = img.shape[0]
    n = int(np.floor(i_size/size))

    out = skimage.measure.block_reduce(img, (n, n), np.max)
    return out

def rgb2gray(rgb):
    # Converts rgb to gray
    return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])

def calculateprocessingres(img, basesize, confidence=0.1, scale_threshold=3, whole_size_threshold=3000):
    # Returns the R_x resolution described in section 5 of the main paper.

    # Parameters:
    #    img :input rgb image
    #    basesize : size the dilation kernel which is equal to receptive field of the network.
    #    confidence: value of x in R_x; allowed percentage of pixels that are not getting any contextual cue.
    #    scale_threshold: maximum allowed upscaling on the input image ; it has been set to 3.
    #    whole_size_threshold: maximum allowed resolution. (R_max from section 6 of the main paper)

    # Returns:
    #    outputsize_scale*speed_scale :The computed R_x resolution
    #    patch_scale: K parameter from section 6 of the paper

    # speed scale parameter is to process every image in a smaller size to accelerate the R_x resolution search
    speed_scale = 32
    image_dim = int(min(img.shape[0:2]))

    gray = rgb2gray(img)
    grad = np.abs(cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)) + np.abs(cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3))
    grad = cv2.resize(grad, (image_dim, image_dim), cv2.INTER_AREA)

    # thresholding the gradient map to generate the edge-map as a proxy of the contextual cues
    m = grad.min()
    M = grad.max()
    middle = m + (0.4 * (M - m))
    grad[grad < middle] = 0
    grad[grad >= middle] = 1

    # dilation kernel with size of the receptive field
    kernel = np.ones((int(basesize/speed_scale), int(basesize/speed_scale)), float)
    # dilation kernel with size of the a quarter of receptive field used to compute k
    # as described in section 6 of main paper
    kernel2 = np.ones((int(basesize / (4*speed_scale)), int(basesize / (4*speed_scale))), float)

    # Output resolution limit set by the whole_size_threshold and scale_threshold.
    threshold = min(whole_size_threshold, scale_threshold * max(img.shape[:2]))

    outputsize_scale = basesize / speed_scale
    for p_size in range(int(basesize/speed_scale), int(threshold/speed_scale), int(basesize / (2*speed_scale))):
        grad_resized = resizewithpool(grad, p_size)
        grad_resized = cv2.resize(grad_resized, (p_size, p_size), cv2.INTER_NEAREST)
        grad_resized[grad_resized >= 0.5] = 1
        grad_resized[grad_resized < 0.5] = 0

        dilated = cv2.dilate(grad_resized, kernel, iterations=1)
        meanvalue = (1-dilated).mean()
        if meanvalue > confidence:
            break
        else:
            outputsize_scale = p_size

    grad_region = cv2.dilate(grad_resized, kernel2, iterations=1)
    patch_scale = grad_region.mean()

    return int(outputsize_scale*speed_scale), patch_scale

# Generate a double-input depth estimation
def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel):
    # Generate the low resolution estimation
    estimate1 = singleestimate(img, size1, model, net_type)
    # Resize to the inference size of merge network.
    estimate1 = cv2.resize(estimate1, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)

    # Generate the high resolution estimation
    estimate2 = singleestimate(img, size2, model, net_type)
    # Resize to the inference size of merge network.
    estimate2 = cv2.resize(estimate2, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)

    # Inference on the merge model
    pix2pixmodel.set_input(estimate1, estimate2)
    pix2pixmodel.test()
    visuals = pix2pixmodel.get_current_visuals()
    prediction_mapped = visuals['fake_B']
    prediction_mapped = (prediction_mapped+1)/2
    prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)) / (
        torch.max(prediction_mapped) - torch.min(prediction_mapped))
    prediction_mapped = prediction_mapped.squeeze().cpu().numpy()

    return prediction_mapped

# Generate a single-input depth estimation
def singleestimate(img, msize, model, net_type):
    # if net_type == 0:
    return estimateleres(img, model, msize, msize)
    # else:
    #     return estimatemidasBoost(img, model, msize, msize)

def applyGridpatch(blsize, stride, img, box):
    # Extract a simple grid patch.
    counter1 = 0
    patch_bound_list = {}
    for k in range(blsize, img.shape[1] - blsize, stride):
        for j in range(blsize, img.shape[0] - blsize, stride):
            patch_bound_list[str(counter1)] = {}
            patchbounds = [j - blsize, k - blsize, j - blsize + 2 * blsize, k - blsize + 2 * blsize]
            patch_bound = [box[0] + patchbounds[1], box[1] + patchbounds[0], patchbounds[3] - patchbounds[1],
                           patchbounds[2] - patchbounds[0]]
            patch_bound_list[str(counter1)]['rect'] = patch_bound
            patch_bound_list[str(counter1)]['size'] = patch_bound[2]
            counter1 = counter1 + 1
    return patch_bound_list

# Generating local patches to perform the local refinement described in section 6 of the main paper.
def generatepatchs(img, base_size):

    # Compute the gradients as a proxy of the contextual cues.
    img_gray = rgb2gray(img)
    whole_grad = np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)) +\
        np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3))

    threshold = whole_grad[whole_grad > 0].mean()
    whole_grad[whole_grad < threshold] = 0

    # We use the integral image to speed-up the evaluation of the amount of gradients for each patch.
    gf = whole_grad.sum()/len(whole_grad.reshape(-1))
    grad_integral_image = cv2.integral(whole_grad)

    # Variables are selected such that the initial patch size would be the receptive field size
    # and the stride is set to 1/3 of the receptive field size.
    blsize = int(round(base_size/2))
    stride = int(round(blsize*0.75))

    # Get initial Grid
    patch_bound_list = applyGridpatch(blsize, stride, img, [0, 0, 0, 0])

    # Refine initial Grid of patches by discarding the flat (in terms of gradients of the rgb image) ones. Refine
    # each patch size to ensure that there will be enough depth cues for the network to generate a consistent depth map.
    print("Selecting patches ...")
    patch_bound_list = adaptiveselection(grad_integral_image, patch_bound_list, gf)

    # Sort the patch list to make sure the merging operation will be done with the correct order: starting from biggest
    # patch
    patchset = sorted(patch_bound_list.items(), key=lambda x: getitem(x[1], 'size'), reverse=True)
    return patchset

def getGF_fromintegral(integralimage, rect):
    # Computes the gradient density of a given patch from the gradient integral image.
    x1 = rect[1]
    x2 = rect[1]+rect[3]
    y1 = rect[0]
    y2 = rect[0]+rect[2]
    value = integralimage[x2, y2]-integralimage[x1, y2]-integralimage[x2, y1]+integralimage[x1, y1]
    return value

# Adaptively select patches
def adaptiveselection(integral_grad, patch_bound_list, gf):
    patchlist = {}
    count = 0
    height, width = integral_grad.shape

    search_step = int(32/factor)

    # Go through all patches
    for c in range(len(patch_bound_list)):
        # Get patch
        bbox = patch_bound_list[str(c)]['rect']

        # Compute the amount of gradients present in the patch from the integral image.
        cgf = getGF_fromintegral(integral_grad, bbox)/(bbox[2]*bbox[3])

        # Check if patching is beneficial by comparing the gradient density of the patch to
        # the gradient density of the whole image
        if cgf >= gf:
            bbox_test = bbox.copy()
            patchlist[str(count)] = {}

            # Enlarge each patch until the gradient density of the patch is equal
            # to the whole image gradient density
            while True:

                bbox_test[0] = bbox_test[0] - int(search_step/2)
                bbox_test[1] = bbox_test[1] - int(search_step/2)

                bbox_test[2] = bbox_test[2] + search_step
                bbox_test[3] = bbox_test[3] + search_step

                # Check if we are still within the image
                if bbox_test[0] < 0 or bbox_test[1] < 0 or bbox_test[1] + bbox_test[3] >= height \
                        or bbox_test[0] + bbox_test[2] >= width:
                    break

                # Compare gradient density
                cgf = getGF_fromintegral(integral_grad, bbox_test)/(bbox_test[2]*bbox_test[3])
                if cgf < gf:
                    break
                bbox = bbox_test.copy()

            # Add patch to selected patches
            patchlist[str(count)]['rect'] = bbox
            patchlist[str(count)]['size'] = bbox[2]
            count = count + 1

    # Return selected patches
    return patchlist

def impatch(image, rect):
    # Extract the given patch pixels from a given image.
    w1 = rect[0]
    h1 = rect[1]
    w2 = w1 + rect[2]
    h2 = h1 + rect[3]
    image_patch = image[h1:h2, w1:w2]
    return image_patch

class ImageandPatchs:
    def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
        self.root_dir = root_dir
        self.patchsinfo = patchsinfo
        self.name = name
        self.patchs = patchsinfo
        self.scale = scale

        self.rgb_image = cv2.resize(rgb_image, (round(rgb_image.shape[1]*scale), round(rgb_image.shape[0]*scale)),
                                    interpolation=cv2.INTER_CUBIC)

        self.do_have_estimate = False
        self.estimation_updated_image = None
        self.estimation_base_image = None

    def __len__(self):
        return len(self.patchs)

    def set_base_estimate(self, est):
        self.estimation_base_image = est
        if self.estimation_updated_image is not None:
            self.do_have_estimate = True

    def set_updated_estimate(self, est):
        self.estimation_updated_image = est
        if self.estimation_base_image is not None:
            self.do_have_estimate = True

    def __getitem__(self, index):
        patch_id = int(self.patchs[index][0])
        rect = np.array(self.patchs[index][1]['rect'])
        msize = self.patchs[index][1]['size']

        ## applying scale to rect:
        rect = np.round(rect * self.scale)
        rect = rect.astype('int')
        msize = round(msize * self.scale)

        patch_rgb = impatch(self.rgb_image, rect)
        if self.do_have_estimate:
            patch_whole_estimate_base = impatch(self.estimation_base_image, rect)
            patch_whole_estimate_updated = impatch(self.estimation_updated_image, rect)
            return {'patch_rgb': patch_rgb, 'patch_whole_estimate_base': patch_whole_estimate_base,
                    'patch_whole_estimate_updated': patch_whole_estimate_updated, 'rect': rect,
                    'size': msize, 'id': patch_id}
        else:
            return {'patch_rgb': patch_rgb, 'rect': rect, 'size': msize, 'id': patch_id}

    def print_options(self, opt):
        """Print and save options

        It will print both current options and default values(if different).
        It will save options into a text file / [checkpoints_dir] / opt.txt
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        """
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')
        """

    def parse(self):
        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
        opt = self.gather_options()
        opt.isTrain = self.isTrain   # train or test

        # process opt.suffix
        if opt.suffix:
            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
            opt.name = opt.name + suffix

        #self.print_options(opt)

        # set gpu ids
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                opt.gpu_ids.append(id)
        #if len(opt.gpu_ids) > 0:
        #    torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt


def estimateboost(img, model, model_type, pix2pixmodel, max_res=512):
    global whole_size_threshold

    # get settings
    if hasattr(opts, 'depthmap_script_boost_rmax'):
        whole_size_threshold = opts.depthmap_script_boost_rmax

    if model_type == 0: #leres
        net_receptive_field_size = 448
        patch_netsize = 2 * net_receptive_field_size
    elif model_type == 1: #dpt_beit_large_512
        net_receptive_field_size = 512
        patch_netsize = 2 * net_receptive_field_size
    else: #other midas
        net_receptive_field_size = 384
        patch_netsize = 2 * net_receptive_field_size

    gc.collect()
    devices.torch_gc()

    # Generate mask used to smoothly blend the local pathc estimations to the base estimate.
    # It is arbitrarily large to avoid artifacts during rescaling for each crop.
    mask_org = generatemask((3000, 3000))
    mask = mask_org.copy()

    # Value x of R_x defined in the section 5 of the main paper.
    r_threshold_value = 0.2
    #if R0:
    #    r_threshold_value = 0

    input_resolution = img.shape
    scale_threshold = 3  # Allows up-scaling with a scale up to 3

    # Find the best input resolution R-x. The resolution search described in section 5-double estimation of the main paper and section B of the
    # supplementary material.
    whole_image_optimal_size, patch_scale = calculateprocessingres(img, net_receptive_field_size, r_threshold_value, scale_threshold, whole_size_threshold)

    # print('wholeImage being processed in :', whole_image_optimal_size)

    # Generate the base estimate using the double estimation.
    whole_estimate = doubleestimate(img, net_receptive_field_size, whole_image_optimal_size, pix2pixsize, model, model_type, pix2pixmodel)

    # Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select
    # small high-density regions of the image.
    global factor
    factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2)
    # print('Adjust factor is:', 1/factor)

    # Check if Local boosting is beneficial.
    if max_res < whole_image_optimal_size:
        # print("No Local boosting. Specified Max Res is smaller than R20, Returning doubleestimate result")
        return cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)

    # Compute the default target resolution.
    if img.shape[0] > img.shape[1]:
        a = 2 * whole_image_optimal_size
        b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0])
    else:
        a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1])
        b = 2 * whole_image_optimal_size
    b = int(round(b / factor))
    a = int(round(a / factor))

    """
    # recompute a, b and saturate to max res.
|
443 |
+
if max(a,b) > max_res:
|
444 |
+
print('Default Res is higher than max-res: Reducing final resolution')
|
445 |
+
if img.shape[0] > img.shape[1]:
|
446 |
+
a = max_res
|
447 |
+
b = round(max_res * img.shape[1] / img.shape[0])
|
448 |
+
else:
|
449 |
+
a = round(max_res * img.shape[0] / img.shape[1])
|
450 |
+
b = max_res
|
451 |
+
b = int(b)
|
452 |
+
a = int(a)
|
453 |
+
"""
|
454 |
+
|
455 |
+
img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC)
|
456 |
+
|
457 |
+
# Extract selected patches for local refinement
|
458 |
+
base_size = net_receptive_field_size * 2
|
459 |
+
patchset = generatepatchs(img, base_size)
|
460 |
+
|
461 |
+
# print('Target resolution: ', img.shape)
|
462 |
+
|
463 |
+
# Computing a scale in case user prompted to generate the results as the same resolution of the input.
|
464 |
+
# Notice that our method output resolution is independent of the input resolution and this parameter will only
|
465 |
+
# enable a scaling operation during the local patch merge implementation to generate results with the same resolution
|
466 |
+
# as the input.
|
467 |
+
"""
|
468 |
+
if output_resolution == 1:
|
469 |
+
mergein_scale = input_resolution[0] / img.shape[0]
|
470 |
+
print('Dynamicly change merged-in resolution; scale:', mergein_scale)
|
471 |
+
else:
|
472 |
+
mergein_scale = 1
|
473 |
+
"""
|
474 |
+
# always rescale to input res for now
|
475 |
+
mergein_scale = input_resolution[0] / img.shape[0]
|
476 |
+
|
477 |
+
imageandpatchs = ImageandPatchs('', '', patchset, img, mergein_scale)
|
478 |
+
whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1]*mergein_scale),
|
479 |
+
round(img.shape[0]*mergein_scale)), interpolation=cv2.INTER_CUBIC)
|
480 |
+
imageandpatchs.set_base_estimate(whole_estimate_resized.copy())
|
481 |
+
imageandpatchs.set_updated_estimate(whole_estimate_resized.copy())
|
482 |
+
|
483 |
+
print('Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2])
|
484 |
+
print('Patches to process: '+str(len(imageandpatchs)))
|
485 |
+
|
486 |
+
# Enumerate through all patches, generate their estimations and refining the base estimate.
|
487 |
+
for patch_ind in range(len(imageandpatchs)):
|
488 |
+
|
489 |
+
# Get patch information
|
490 |
+
patch = imageandpatchs[patch_ind] # patch object
|
491 |
+
patch_rgb = patch['patch_rgb'] # rgb patch
|
492 |
+
patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base
|
493 |
+
rect = patch['rect'] # patch size and location
|
494 |
+
patch_id = patch['id'] # patch ID
|
495 |
+
org_size = patch_whole_estimate_base.shape # the original size from the unscaled input
|
496 |
+
print('\t Processing patch', patch_ind, '/', len(imageandpatchs)-1, '|', rect)
|
497 |
+
|
498 |
+
# We apply double estimation for patches. The high resolution value is fixed to twice the receptive
|
499 |
+
# field size of the network for patches to accelerate the process.
|
500 |
+
patch_estimation = doubleestimate(patch_rgb, net_receptive_field_size, patch_netsize, pix2pixsize, model, model_type, pix2pixmodel)
|
501 |
+
patch_estimation = cv2.resize(patch_estimation, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
|
502 |
+
patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
|
503 |
+
|
504 |
+
# Merging the patch estimation into the base estimate using our merge network:
|
505 |
+
# We feed the patch estimation and the same region from the updated base estimate to the merge network
|
506 |
+
# to generate the target estimate for the corresponding region.
|
507 |
+
pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation)
|
508 |
+
|
509 |
+
# Run merging network
|
510 |
+
pix2pixmodel.test()
|
511 |
+
visuals = pix2pixmodel.get_current_visuals()
|
512 |
+
|
513 |
+
prediction_mapped = visuals['fake_B']
|
514 |
+
prediction_mapped = (prediction_mapped+1)/2
|
515 |
+
prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
|
516 |
+
|
517 |
+
mapped = prediction_mapped
|
518 |
+
|
519 |
+
# We use a simple linear polynomial to make sure the result of the merge network would match the values of
|
520 |
+
# base estimate
|
521 |
+
p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1)
|
522 |
+
merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape)
|
523 |
+
|
524 |
+
merged = cv2.resize(merged, (org_size[1],org_size[0]), interpolation=cv2.INTER_CUBIC)
|
525 |
+
|
526 |
+
# Get patch size and location
|
527 |
+
w1 = rect[0]
|
528 |
+
h1 = rect[1]
|
529 |
+
w2 = w1 + rect[2]
|
530 |
+
h2 = h1 + rect[3]
|
531 |
+
|
532 |
+
# To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size
|
533 |
+
# and resize it to our needed size while merging the patches.
|
534 |
+
if mask.shape != org_size:
|
535 |
+
mask = cv2.resize(mask_org, (org_size[1],org_size[0]), interpolation=cv2.INTER_LINEAR)
|
536 |
+
|
537 |
+
tobemergedto = imageandpatchs.estimation_updated_image
|
538 |
+
|
539 |
+
# Update the whole estimation:
|
540 |
+
# We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless
|
541 |
+
# blending at the boundaries of the patch region.
|
542 |
+
tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask)
|
543 |
+
imageandpatchs.set_updated_estimate(tobemergedto)
|
544 |
+
|
545 |
+
# output
|
546 |
+
return cv2.resize(imageandpatchs.estimation_updated_image, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
|
extensions/sd-webui-controlnet/annotator/leres/leres/multi_depth_model_woauxi.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import network_auxi as network
|
2 |
+
from .net_tools import get_func
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from modules import devices
|
6 |
+
|
7 |
+
class RelDepthModel(nn.Module):
|
8 |
+
def __init__(self, backbone='resnet50'):
|
9 |
+
super(RelDepthModel, self).__init__()
|
10 |
+
if backbone == 'resnet50':
|
11 |
+
encoder = 'resnet50_stride32'
|
12 |
+
elif backbone == 'resnext101':
|
13 |
+
encoder = 'resnext101_stride32x8d'
|
14 |
+
self.depth_model = DepthModel(encoder)
|
15 |
+
|
16 |
+
def inference(self, rgb):
|
17 |
+
with torch.no_grad():
|
18 |
+
input = rgb.to(self.depth_model.device)
|
19 |
+
depth = self.depth_model(input)
|
20 |
+
#pred_depth_out = depth - depth.min() + 0.01
|
21 |
+
return depth #pred_depth_out
|
22 |
+
|
23 |
+
|
24 |
+
class DepthModel(nn.Module):
|
25 |
+
def __init__(self, encoder):
|
26 |
+
super(DepthModel, self).__init__()
|
27 |
+
backbone = network.__name__.split('.')[-1] + '.' + encoder
|
28 |
+
self.encoder_modules = get_func(backbone)()
|
29 |
+
self.decoder_modules = network.Decoder()
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
lateral_out = self.encoder_modules(x)
|
33 |
+
out_logit = self.decoder_modules(lateral_out)
|
34 |
+
return out_logit
|
extensions/sd-webui-controlnet/annotator/leres/leres/net_tools.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
|
7 |
+
def get_func(func_name):
|
8 |
+
"""Helper to return a function object by name. func_name must identify a
|
9 |
+
function in this module or the path to a function relative to the base
|
10 |
+
'modeling' module.
|
11 |
+
"""
|
12 |
+
if func_name == '':
|
13 |
+
return None
|
14 |
+
try:
|
15 |
+
parts = func_name.split('.')
|
16 |
+
# Refers to a function in this module
|
17 |
+
if len(parts) == 1:
|
18 |
+
return globals()[parts[0]]
|
19 |
+
# Otherwise, assume we're referencing a module under modeling
|
20 |
+
module_name = 'annotator.leres.leres.' + '.'.join(parts[:-1])
|
21 |
+
module = importlib.import_module(module_name)
|
22 |
+
return getattr(module, parts[-1])
|
23 |
+
except Exception:
|
24 |
+
print('Failed to f1ind function: %s', func_name)
|
25 |
+
raise
|
26 |
+
|
27 |
+
def load_ckpt(args, depth_model, shift_model, focal_model):
|
28 |
+
"""
|
29 |
+
Load checkpoint.
|
30 |
+
"""
|
31 |
+
if os.path.isfile(args.load_ckpt):
|
32 |
+
print("loading checkpoint %s" % args.load_ckpt)
|
33 |
+
checkpoint = torch.load(args.load_ckpt)
|
34 |
+
if shift_model is not None:
|
35 |
+
shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'),
|
36 |
+
strict=True)
|
37 |
+
if focal_model is not None:
|
38 |
+
focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'),
|
39 |
+
strict=True)
|
40 |
+
depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."),
|
41 |
+
strict=True)
|
42 |
+
del checkpoint
|
43 |
+
torch.cuda.empty_cache()
|
44 |
+
|
45 |
+
|
46 |
+
def strip_prefix_if_present(state_dict, prefix):
|
47 |
+
keys = sorted(state_dict.keys())
|
48 |
+
if not all(key.startswith(prefix) for key in keys):
|
49 |
+
return state_dict
|
50 |
+
stripped_state_dict = OrderedDict()
|
51 |
+
for key, value in state_dict.items():
|
52 |
+
stripped_state_dict[key.replace(prefix, "")] = value
|
53 |
+
return stripped_state_dict
|
extensions/sd-webui-controlnet/annotator/leres/leres/network_auxi.py
ADDED
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.init as init
|
4 |
+
|
5 |
+
from . import Resnet, Resnext_torch
|
6 |
+
|
7 |
+
|
8 |
+
def resnet50_stride32():
|
9 |
+
return DepthNet(backbone='resnet', depth=50, upfactors=[2, 2, 2, 2])
|
10 |
+
|
11 |
+
def resnext101_stride32x8d():
|
12 |
+
return DepthNet(backbone='resnext101_32x8d', depth=101, upfactors=[2, 2, 2, 2])
|
13 |
+
|
14 |
+
|
15 |
+
class Decoder(nn.Module):
|
16 |
+
def __init__(self):
|
17 |
+
super(Decoder, self).__init__()
|
18 |
+
self.inchannels = [256, 512, 1024, 2048]
|
19 |
+
self.midchannels = [256, 256, 256, 512]
|
20 |
+
self.upfactors = [2,2,2,2]
|
21 |
+
self.outchannels = 1
|
22 |
+
|
23 |
+
self.conv = FTB(inchannels=self.inchannels[3], midchannels=self.midchannels[3])
|
24 |
+
self.conv1 = nn.Conv2d(in_channels=self.midchannels[3], out_channels=self.midchannels[2], kernel_size=3, padding=1, stride=1, bias=True)
|
25 |
+
self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True)
|
26 |
+
|
27 |
+
self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2], outchannels = self.midchannels[2], upfactor=self.upfactors[2])
|
28 |
+
self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1], outchannels = self.midchannels[1], upfactor=self.upfactors[1])
|
29 |
+
self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0], outchannels = self.midchannels[0], upfactor=self.upfactors[0])
|
30 |
+
|
31 |
+
self.outconv = AO(inchannels=self.midchannels[0], outchannels=self.outchannels, upfactor=2)
|
32 |
+
self._init_params()
|
33 |
+
|
34 |
+
def _init_params(self):
|
35 |
+
for m in self.modules():
|
36 |
+
if isinstance(m, nn.Conv2d):
|
37 |
+
init.normal_(m.weight, std=0.01)
|
38 |
+
if m.bias is not None:
|
39 |
+
init.constant_(m.bias, 0)
|
40 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
41 |
+
init.normal_(m.weight, std=0.01)
|
42 |
+
if m.bias is not None:
|
43 |
+
init.constant_(m.bias, 0)
|
44 |
+
elif isinstance(m, nn.BatchNorm2d): #NN.BatchNorm2d
|
45 |
+
init.constant_(m.weight, 1)
|
46 |
+
init.constant_(m.bias, 0)
|
47 |
+
elif isinstance(m, nn.Linear):
|
48 |
+
init.normal_(m.weight, std=0.01)
|
49 |
+
if m.bias is not None:
|
50 |
+
init.constant_(m.bias, 0)
|
51 |
+
|
52 |
+
def forward(self, features):
|
53 |
+
x_32x = self.conv(features[3]) # 1/32
|
54 |
+
x_32 = self.conv1(x_32x)
|
55 |
+
x_16 = self.upsample(x_32) # 1/16
|
56 |
+
|
57 |
+
x_8 = self.ffm2(features[2], x_16) # 1/8
|
58 |
+
x_4 = self.ffm1(features[1], x_8) # 1/4
|
59 |
+
x_2 = self.ffm0(features[0], x_4) # 1/2
|
60 |
+
#-----------------------------------------
|
61 |
+
x = self.outconv(x_2) # original size
|
62 |
+
return x
|
63 |
+
|
64 |
+
class DepthNet(nn.Module):
|
65 |
+
__factory = {
|
66 |
+
18: Resnet.resnet18,
|
67 |
+
34: Resnet.resnet34,
|
68 |
+
50: Resnet.resnet50,
|
69 |
+
101: Resnet.resnet101,
|
70 |
+
152: Resnet.resnet152
|
71 |
+
}
|
72 |
+
def __init__(self,
|
73 |
+
backbone='resnet',
|
74 |
+
depth=50,
|
75 |
+
upfactors=[2, 2, 2, 2]):
|
76 |
+
super(DepthNet, self).__init__()
|
77 |
+
self.backbone = backbone
|
78 |
+
self.depth = depth
|
79 |
+
self.pretrained = False
|
80 |
+
self.inchannels = [256, 512, 1024, 2048]
|
81 |
+
self.midchannels = [256, 256, 256, 512]
|
82 |
+
self.upfactors = upfactors
|
83 |
+
self.outchannels = 1
|
84 |
+
|
85 |
+
# Build model
|
86 |
+
if self.backbone == 'resnet':
|
87 |
+
if self.depth not in DepthNet.__factory:
|
88 |
+
raise KeyError("Unsupported depth:", self.depth)
|
89 |
+
self.encoder = DepthNet.__factory[depth](pretrained=self.pretrained)
|
90 |
+
elif self.backbone == 'resnext101_32x8d':
|
91 |
+
self.encoder = Resnext_torch.resnext101_32x8d(pretrained=self.pretrained)
|
92 |
+
else:
|
93 |
+
self.encoder = Resnext_torch.resnext101(pretrained=self.pretrained)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
x = self.encoder(x) # 1/32, 1/16, 1/8, 1/4
|
97 |
+
return x
|
98 |
+
|
99 |
+
|
100 |
+
class FTB(nn.Module):
|
101 |
+
def __init__(self, inchannels, midchannels=512):
|
102 |
+
super(FTB, self).__init__()
|
103 |
+
self.in1 = inchannels
|
104 |
+
self.mid = midchannels
|
105 |
+
self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1,
|
106 |
+
bias=True)
|
107 |
+
# NN.BatchNorm2d
|
108 |
+
self.conv_branch = nn.Sequential(nn.ReLU(inplace=True), \
|
109 |
+
nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3,
|
110 |
+
padding=1, stride=1, bias=True), \
|
111 |
+
nn.BatchNorm2d(num_features=self.mid), \
|
112 |
+
nn.ReLU(inplace=True), \
|
113 |
+
nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3,
|
114 |
+
padding=1, stride=1, bias=True))
|
115 |
+
self.relu = nn.ReLU(inplace=True)
|
116 |
+
|
117 |
+
self.init_params()
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
x = self.conv1(x)
|
121 |
+
x = x + self.conv_branch(x)
|
122 |
+
x = self.relu(x)
|
123 |
+
|
124 |
+
return x
|
125 |
+
|
126 |
+
def init_params(self):
|
127 |
+
for m in self.modules():
|
128 |
+
if isinstance(m, nn.Conv2d):
|
129 |
+
init.normal_(m.weight, std=0.01)
|
130 |
+
if m.bias is not None:
|
131 |
+
init.constant_(m.bias, 0)
|
132 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
133 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
134 |
+
init.normal_(m.weight, std=0.01)
|
135 |
+
# init.xavier_normal_(m.weight)
|
136 |
+
if m.bias is not None:
|
137 |
+
init.constant_(m.bias, 0)
|
138 |
+
elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
|
139 |
+
init.constant_(m.weight, 1)
|
140 |
+
init.constant_(m.bias, 0)
|
141 |
+
elif isinstance(m, nn.Linear):
|
142 |
+
init.normal_(m.weight, std=0.01)
|
143 |
+
if m.bias is not None:
|
144 |
+
init.constant_(m.bias, 0)
|
145 |
+
|
146 |
+
|
147 |
+
class ATA(nn.Module):
|
148 |
+
def __init__(self, inchannels, reduction=8):
|
149 |
+
super(ATA, self).__init__()
|
150 |
+
self.inchannels = inchannels
|
151 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
152 |
+
self.fc = nn.Sequential(nn.Linear(self.inchannels * 2, self.inchannels // reduction),
|
153 |
+
nn.ReLU(inplace=True),
|
154 |
+
nn.Linear(self.inchannels // reduction, self.inchannels),
|
155 |
+
nn.Sigmoid())
|
156 |
+
self.init_params()
|
157 |
+
|
158 |
+
def forward(self, low_x, high_x):
|
159 |
+
n, c, _, _ = low_x.size()
|
160 |
+
x = torch.cat([low_x, high_x], 1)
|
161 |
+
x = self.avg_pool(x)
|
162 |
+
x = x.view(n, -1)
|
163 |
+
x = self.fc(x).view(n, c, 1, 1)
|
164 |
+
x = low_x * x + high_x
|
165 |
+
|
166 |
+
return x
|
167 |
+
|
168 |
+
def init_params(self):
|
169 |
+
for m in self.modules():
|
170 |
+
if isinstance(m, nn.Conv2d):
|
171 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
172 |
+
# init.normal(m.weight, std=0.01)
|
173 |
+
init.xavier_normal_(m.weight)
|
174 |
+
if m.bias is not None:
|
175 |
+
init.constant_(m.bias, 0)
|
176 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
177 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
178 |
+
# init.normal_(m.weight, std=0.01)
|
179 |
+
init.xavier_normal_(m.weight)
|
180 |
+
if m.bias is not None:
|
181 |
+
init.constant_(m.bias, 0)
|
182 |
+
elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
|
183 |
+
init.constant_(m.weight, 1)
|
184 |
+
init.constant_(m.bias, 0)
|
185 |
+
elif isinstance(m, nn.Linear):
|
186 |
+
init.normal_(m.weight, std=0.01)
|
187 |
+
if m.bias is not None:
|
188 |
+
init.constant_(m.bias, 0)
|
189 |
+
|
190 |
+
|
191 |
+
class FFM(nn.Module):
|
192 |
+
def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
|
193 |
+
super(FFM, self).__init__()
|
194 |
+
self.inchannels = inchannels
|
195 |
+
self.midchannels = midchannels
|
196 |
+
self.outchannels = outchannels
|
197 |
+
self.upfactor = upfactor
|
198 |
+
|
199 |
+
self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
|
200 |
+
# self.ata = ATA(inchannels = self.midchannels)
|
201 |
+
self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
|
202 |
+
|
203 |
+
self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
|
204 |
+
|
205 |
+
self.init_params()
|
206 |
+
|
207 |
+
def forward(self, low_x, high_x):
|
208 |
+
x = self.ftb1(low_x)
|
209 |
+
x = x + high_x
|
210 |
+
x = self.ftb2(x)
|
211 |
+
x = self.upsample(x)
|
212 |
+
|
213 |
+
return x
|
214 |
+
|
215 |
+
def init_params(self):
|
216 |
+
for m in self.modules():
|
217 |
+
if isinstance(m, nn.Conv2d):
|
218 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
219 |
+
init.normal_(m.weight, std=0.01)
|
220 |
+
# init.xavier_normal_(m.weight)
|
221 |
+
if m.bias is not None:
|
222 |
+
init.constant_(m.bias, 0)
|
223 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
224 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
225 |
+
init.normal_(m.weight, std=0.01)
|
226 |
+
# init.xavier_normal_(m.weight)
|
227 |
+
if m.bias is not None:
|
228 |
+
init.constant_(m.bias, 0)
|
229 |
+
elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d
|
230 |
+
init.constant_(m.weight, 1)
|
231 |
+
init.constant_(m.bias, 0)
|
232 |
+
elif isinstance(m, nn.Linear):
|
233 |
+
init.normal_(m.weight, std=0.01)
|
234 |
+
if m.bias is not None:
|
235 |
+
init.constant_(m.bias, 0)
|
236 |
+
|
237 |
+
|
238 |
+
class AO(nn.Module):
|
239 |
+
# Adaptive output module
|
240 |
+
def __init__(self, inchannels, outchannels, upfactor=2):
|
241 |
+
super(AO, self).__init__()
|
242 |
+
self.inchannels = inchannels
|
243 |
+
self.outchannels = outchannels
|
244 |
+
self.upfactor = upfactor
|
245 |
+
|
246 |
+
self.adapt_conv = nn.Sequential(
|
247 |
+
nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels // 2, kernel_size=3, padding=1,
|
248 |
+
stride=1, bias=True), \
|
249 |
+
nn.BatchNorm2d(num_features=self.inchannels // 2), \
|
250 |
+
nn.ReLU(inplace=True), \
|
251 |
+
nn.Conv2d(in_channels=self.inchannels // 2, out_channels=self.outchannels, kernel_size=3, padding=1,
|
252 |
+
stride=1, bias=True), \
|
253 |
+
nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True))
|
254 |
+
|
255 |
+
self.init_params()
|
256 |
+
|
257 |
+
def forward(self, x):
|
258 |
+
x = self.adapt_conv(x)
|
259 |
+
return x
|
260 |
+
|
261 |
+
def init_params(self):
|
262 |
+
for m in self.modules():
|
263 |
+
if isinstance(m, nn.Conv2d):
|
264 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
265 |
+
init.normal_(m.weight, std=0.01)
|
266 |
+
# init.xavier_normal_(m.weight)
|
267 |
+
if m.bias is not None:
|
268 |
+
init.constant_(m.bias, 0)
|
269 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
270 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
271 |
+
init.normal_(m.weight, std=0.01)
|
272 |
+
# init.xavier_normal_(m.weight)
|
273 |
+
if m.bias is not None:
|
274 |
+
init.constant_(m.bias, 0)
|
275 |
+
elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d
|
276 |
+
init.constant_(m.weight, 1)
|
277 |
+
init.constant_(m.bias, 0)
|
278 |
+
elif isinstance(m, nn.Linear):
|
279 |
+
init.normal_(m.weight, std=0.01)
|
280 |
+
if m.bias is not None:
|
281 |
+
init.constant_(m.bias, 0)
|
282 |
+
|
283 |
+
|
284 |
+
|
285 |
+
# ==============================================================================================================
|
286 |
+
|
287 |
+
|
288 |
+
class ResidualConv(nn.Module):
|
289 |
+
def __init__(self, inchannels):
|
290 |
+
super(ResidualConv, self).__init__()
|
291 |
+
# NN.BatchNorm2d
|
292 |
+
self.conv = nn.Sequential(
|
293 |
+
# nn.BatchNorm2d(num_features=inchannels),
|
294 |
+
nn.ReLU(inplace=False),
|
295 |
+
# nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
|
296 |
+
# nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
|
297 |
+
nn.Conv2d(in_channels=inchannels, out_channels=inchannels / 2, kernel_size=3, padding=1, stride=1,
|
298 |
+
bias=False),
|
299 |
+
nn.BatchNorm2d(num_features=inchannels / 2),
|
300 |
+
nn.ReLU(inplace=False),
|
301 |
+
nn.Conv2d(in_channels=inchannels / 2, out_channels=inchannels, kernel_size=3, padding=1, stride=1,
|
302 |
+
bias=False)
|
303 |
+
)
|
304 |
+
self.init_params()
|
305 |
+
|
306 |
+
def forward(self, x):
|
307 |
+
x = self.conv(x) + x
|
308 |
+
return x
|
309 |
+
|
310 |
+
def init_params(self):
|
311 |
+
for m in self.modules():
|
312 |
+
if isinstance(m, nn.Conv2d):
|
313 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
314 |
+
init.normal_(m.weight, std=0.01)
|
315 |
+
# init.xavier_normal_(m.weight)
|
316 |
+
if m.bias is not None:
|
317 |
+
init.constant_(m.bias, 0)
|
318 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
319 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
320 |
+
init.normal_(m.weight, std=0.01)
|
321 |
+
# init.xavier_normal_(m.weight)
|
322 |
+
if m.bias is not None:
|
323 |
+
init.constant_(m.bias, 0)
|
324 |
+
elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
|
325 |
+
init.constant_(m.weight, 1)
|
326 |
+
init.constant_(m.bias, 0)
|
327 |
+
elif isinstance(m, nn.Linear):
|
328 |
+
init.normal_(m.weight, std=0.01)
|
329 |
+
if m.bias is not None:
|
330 |
+
init.constant_(m.bias, 0)
|
331 |
+
|
332 |
+
|
333 |
+
class FeatureFusion(nn.Module):
|
334 |
+
def __init__(self, inchannels, outchannels):
|
335 |
+
super(FeatureFusion, self).__init__()
|
336 |
+
self.conv = ResidualConv(inchannels=inchannels)
|
337 |
+
# NN.BatchNorm2d
|
338 |
+
self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
|
339 |
+
nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,
|
340 |
+
stride=2, padding=1, output_padding=1),
|
341 |
+
nn.BatchNorm2d(num_features=outchannels),
|
342 |
+
nn.ReLU(inplace=True))
|
343 |
+
|
344 |
+
def forward(self, lowfeat, highfeat):
|
345 |
+
return self.up(highfeat + self.conv(lowfeat))
|
346 |
+
|
347 |
+
def init_params(self):
|
348 |
+
for m in self.modules():
|
349 |
+
if isinstance(m, nn.Conv2d):
|
350 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
351 |
+
init.normal_(m.weight, std=0.01)
|
352 |
+
# init.xavier_normal_(m.weight)
|
353 |
+
if m.bias is not None:
|
354 |
+
init.constant_(m.bias, 0)
|
355 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
356 |
+
# init.kaiming_normal_(m.weight, mode='fan_out')
|
357 |
+
init.normal_(m.weight, std=0.01)
|
358 |
+
# init.xavier_normal_(m.weight)
|
359 |
+
if m.bias is not None:
|
360 |
+
init.constant_(m.bias, 0)
|
361 |
+
elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
|
362 |
+
init.constant_(m.weight, 1)
|
363 |
+
init.constant_(m.bias, 0)
|
364 |
+
elif isinstance(m, nn.Linear):
|
365 |
+
init.normal_(m.weight, std=0.01)
|
366 |
+
if m.bias is not None:
|
367 |
+
init.constant_(m.bias, 0)
|
368 |
+
|
369 |
+
|
370 |
+
class SenceUnderstand(nn.Module):
|
371 |
+
def __init__(self, channels):
|
372 |
+
super(SenceUnderstand, self).__init__()
|
373 |
+
self.channels = channels
|
374 |
+
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
|
375 |
+
nn.ReLU(inplace=True))
|
376 |
+
self.pool = nn.AdaptiveAvgPool2d(8)
|
377 |
+
self.fc = nn.Sequential(nn.Linear(512 * 8 * 8, self.channels),
|
378 |
+
nn.ReLU(inplace=True))
|
379 |
+
self.conv2 = nn.Sequential(
|
380 |
+
nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
|
381 |
+
nn.ReLU(inplace=True))
|
382 |
+
self.initial_params()
|
383 |
+
|
384 |
+
def forward(self, x):
|
385 |
+
n, c, h, w = x.size()
|
386 |
+
x = self.conv1(x)
|
387 |
+
x = self.pool(x)
|
388 |
+
x = x.view(n, -1)
|
389 |
+
x = self.fc(x)
|
390 |
+
x = x.view(n, self.channels, 1, 1)
|
391 |
+
x = self.conv2(x)
|
392 |
+
x = x.repeat(1, 1, h, w)
|
393 |
+
return x
|
394 |
+
|
395 |
+
def initial_params(self, dev=0.01):
|
396 |
+
for m in self.modules():
|
397 |
+
if isinstance(m, nn.Conv2d):
|
398 |
+
# print torch.sum(m.weight)
|
399 |
+
m.weight.data.normal_(0, dev)
|
400 |
+
if m.bias is not None:
|
401 |
+
m.bias.data.fill_(0)
|
402 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
403 |
+
# print torch.sum(m.weight)
|
404 |
+
m.weight.data.normal_(0, dev)
|
405 |
+
if m.bias is not None:
|
406 |
+
m.bias.data.fill_(0)
|
407 |
+
elif isinstance(m, nn.Linear):
|
408 |
+
m.weight.data.normal_(0, dev)
|
409 |
+
|
410 |
+
|
411 |
+
if __name__ == '__main__':
|
412 |
+
net = DepthNet(depth=50, pretrained=True)
|
413 |
+
print(net)
|
414 |
+
inputs = torch.ones(4,3,128,128)
|
415 |
+
out = net(inputs)
|
416 |
+
print(out.size())
|
417 |
+
|
extensions/sd-webui-controlnet/annotator/leres/pix2pix/LICENSE
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://github.com/compphoto/BoostingMonocularDepth
|
2 |
+
|
3 |
+
Copyright 2021, Seyed Mahdi Hosseini Miangoleh, Sebastian Dille, Computational Photography Laboratory. All rights reserved.
|
4 |
+
|
5 |
+
This software is for academic use only. A redistribution of this
|
6 |
+
software, with or without modifications, has to be for academic
|
7 |
+
use only, while giving the appropriate credit to the original
|
8 |
+
authors of the software. The methods implemented as a part of
|
9 |
+
this software may be covered under patents or patent applications.
|
10 |
+
|
11 |
+
THIS SOFTWARE IS PROVIDED BY THE AUTHOR ''AS IS'' AND ANY EXPRESS OR IMPLIED
|
12 |
+
WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
13 |
+
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR
|
14 |
+
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
15 |
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
16 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
17 |
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
18 |
+
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
|
19 |
+
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
extensions/sd-webui-controlnet/annotator/leres/pix2pix/models/__init__.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
2 |
+
|
3 |
+
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
4 |
+
You need to implement the following five functions:
|
5 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
6 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
7 |
+
-- <forward>: produce intermediate results.
|
8 |
+
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
9 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
10 |
+
|
11 |
+
In the function <__init__>, you need to define four lists:
|
12 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
13 |
+
-- self.model_names (str list): define networks used in our training.
|
14 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
15 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
|
16 |
+
|
17 |
+
Now you can use the model class by specifying flag '--model dummy'.
|
18 |
+
See our template model class 'template_model.py' for more details.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import importlib
|
22 |
+
from .base_model import BaseModel
|
23 |
+
|
24 |
+
|
25 |
+
def find_model_using_name(model_name):
|
26 |
+
"""Import the module "models/[model_name]_model.py".
|
27 |
+
|
28 |
+
In the file, the class called DatasetNameModel() will
|
29 |
+
be instantiated. It has to be a subclass of BaseModel,
|
30 |
+
and it is case-insensitive.
|
31 |
+
"""
|
32 |
+
model_filename = "annotator.leres.pix2pix.models." + model_name + "_model"
|
33 |
+
modellib = importlib.import_module(model_filename)
|
34 |
+
model = None
|
35 |
+
target_model_name = model_name.replace('_', '') + 'model'
|
36 |
+
for name, cls in modellib.__dict__.items():
|
37 |
+
if name.lower() == target_model_name.lower() \
|
38 |
+
and issubclass(cls, BaseModel):
|
39 |
+
model = cls
|
40 |
+
|
41 |
+
if model is None:
|
42 |
+
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
43 |
+
exit(0)
|
44 |
+
|
45 |
+
return model
|
46 |
+
|
47 |
+
|
48 |
+
def get_option_setter(model_name):
|
49 |
+
"""Return the static method <modify_commandline_options> of the model class."""
|
50 |
+
model_class = find_model_using_name(model_name)
|
51 |
+
return model_class.modify_commandline_options
|
52 |
+
|
53 |
+
|
54 |
+
def create_model(opt):
|
55 |
+
"""Create a model given the option.
|
56 |
+
|
57 |
+
This function warps the class CustomDatasetDataLoader.
|
58 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
59 |
+
|
60 |
+
Example:
|
61 |
+
>>> from models import create_model
|
62 |
+
>>> model = create_model(opt)
|
63 |
+
"""
|
64 |
+
model = find_model_using_name(opt.model)
|
65 |
+
instance = model(opt)
|
66 |
+
print("model [%s] was created" % type(instance).__name__)
|
67 |
+
return instance
|
extensions/sd-webui-controlnet/annotator/leres/pix2pix/models/base_model.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch, gc
|
3 |
+
from modules import devices
|
4 |
+
from collections import OrderedDict
|
5 |
+
from abc import ABC, abstractmethod
|
6 |
+
from . import networks
|
7 |
+
|
8 |
+
|
9 |
+
class BaseModel(ABC):
|
10 |
+
"""This class is an abstract base class (ABC) for models.
|
11 |
+
To create a subclass, you need to implement the following five functions:
|
12 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
13 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
14 |
+
-- <forward>: produce intermediate results.
|
15 |
+
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
|
16 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, opt):
|
20 |
+
"""Initialize the BaseModel class.
|
21 |
+
|
22 |
+
Parameters:
|
23 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
24 |
+
|
25 |
+
When creating your custom class, you need to implement your own initialization.
|
26 |
+
In this function, you should first call <BaseModel.__init__(self, opt)>
|
27 |
+
Then, you need to define four lists:
|
28 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
29 |
+
-- self.model_names (str list): define networks used in our training.
|
30 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
31 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
32 |
+
"""
|
33 |
+
self.opt = opt
|
34 |
+
self.gpu_ids = opt.gpu_ids
|
35 |
+
self.isTrain = opt.isTrain
|
36 |
+
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
37 |
+
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
|
38 |
+
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
|
39 |
+
torch.backends.cudnn.benchmark = True
|
40 |
+
self.loss_names = []
|
41 |
+
self.model_names = []
|
42 |
+
self.visual_names = []
|
43 |
+
self.optimizers = []
|
44 |
+
self.image_paths = []
|
45 |
+
self.metric = 0 # used for learning rate policy 'plateau'
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
def modify_commandline_options(parser, is_train):
|
49 |
+
"""Add new model-specific options, and rewrite default values for existing options.
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
parser -- original option parser
|
53 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
the modified parser.
|
57 |
+
"""
|
58 |
+
return parser
|
59 |
+
|
60 |
+
@abstractmethod
|
61 |
+
def set_input(self, input):
|
62 |
+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
63 |
+
|
64 |
+
Parameters:
|
65 |
+
input (dict): includes the data itself and its metadata information.
|
66 |
+
"""
|
67 |
+
pass
|
68 |
+
|
69 |
+
@abstractmethod
|
70 |
+
def forward(self):
|
71 |
+
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
72 |
+
pass
|
73 |
+
|
74 |
+
@abstractmethod
|
75 |
+
def optimize_parameters(self):
|
76 |
+
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
77 |
+
pass
|
78 |
+
|
79 |
+
def setup(self, opt):
|
80 |
+
"""Load and print networks; create schedulers
|
81 |
+
|
82 |
+
Parameters:
|
83 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
84 |
+
"""
|
85 |
+
if self.isTrain:
|
86 |
+
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
|
87 |
+
if not self.isTrain or opt.continue_train:
|
88 |
+
load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
|
89 |
+
self.load_networks(load_suffix)
|
90 |
+
self.print_networks(opt.verbose)
|
91 |
+
|
92 |
+
def eval(self):
|
93 |
+
"""Make models eval mode during test time"""
|
94 |
+
for name in self.model_names:
|
95 |
+
if isinstance(name, str):
|
96 |
+
net = getattr(self, 'net' + name)
|
97 |
+
net.eval()
|
98 |
+
|
99 |
+
def test(self):
|
100 |
+
"""Forward function used in test time.
|
101 |
+
|
102 |
+
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
|
103 |
+
It also calls <compute_visuals> to produce additional visualization results
|
104 |
+
"""
|
105 |
+
with torch.no_grad():
|
106 |
+
self.forward()
|
107 |
+
self.compute_visuals()
|
108 |
+
|
109 |
+
def compute_visuals(self):
|
110 |
+
"""Calculate additional output images for visdom and HTML visualization"""
|
111 |
+
pass
|
112 |
+
|
113 |
+
def get_image_paths(self):
|
114 |
+
""" Return image paths that are used to load current data"""
|
115 |
+
return self.image_paths
|
116 |
+
|
117 |
+
def update_learning_rate(self):
|
118 |
+
"""Update learning rates for all the networks; called at the end of every epoch"""
|
119 |
+
old_lr = self.optimizers[0].param_groups[0]['lr']
|
120 |
+
for scheduler in self.schedulers:
|
121 |
+
if self.opt.lr_policy == 'plateau':
|
122 |
+
scheduler.step(self.metric)
|
123 |
+
else:
|
124 |
+
scheduler.step()
|
125 |
+
|
126 |
+
lr = self.optimizers[0].param_groups[0]['lr']
|
127 |
+
print('learning rate %.7f -> %.7f' % (old_lr, lr))
|
128 |
+
|
129 |
+
def get_current_visuals(self):
|
130 |
+
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
|
131 |
+
visual_ret = OrderedDict()
|
132 |
+
for name in self.visual_names:
|
133 |
+
if isinstance(name, str):
|
134 |
+
visual_ret[name] = getattr(self, name)
|
135 |
+
return visual_ret
|
136 |
+
|
137 |
+
def get_current_losses(self):
|
138 |
+
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
|
139 |
+
errors_ret = OrderedDict()
|
140 |
+
for name in self.loss_names:
|
141 |
+
if isinstance(name, str):
|
142 |
+
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
|
143 |
+
return errors_ret
|
144 |
+
|
145 |
+
def save_networks(self, epoch):
|
146 |
+
"""Save all the networks to the disk.
|
147 |
+
|
148 |
+
Parameters:
|
149 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
150 |
+
"""
|
151 |
+
for name in self.model_names:
|
152 |
+
if isinstance(name, str):
|
153 |
+
save_filename = '%s_net_%s.pth' % (epoch, name)
|
154 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
155 |
+
net = getattr(self, 'net' + name)
|
156 |
+
|
157 |
+
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
158 |
+
torch.save(net.module.cpu().state_dict(), save_path)
|
159 |
+
net.cuda(self.gpu_ids[0])
|
160 |
+
else:
|
161 |
+
torch.save(net.cpu().state_dict(), save_path)
|
162 |
+
|
163 |
+
def unload_network(self, name):
|
164 |
+
"""Unload network and gc.
|
165 |
+
"""
|
166 |
+
if isinstance(name, str):
|
167 |
+
net = getattr(self, 'net' + name)
|
168 |
+
del net
|
169 |
+
gc.collect()
|
170 |
+
devices.torch_gc()
|
171 |
+
return None
|
172 |
+
|
173 |
+
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
|
174 |
+
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
|
175 |
+
key = keys[i]
|
176 |
+
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
|
177 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
178 |
+
(key == 'running_mean' or key == 'running_var'):
|
179 |
+
if getattr(module, key) is None:
|
180 |
+
state_dict.pop('.'.join(keys))
|
181 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
182 |
+
(key == 'num_batches_tracked'):
|
183 |
+
state_dict.pop('.'.join(keys))
|
184 |
+
else:
|
185 |
+
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
|
186 |
+
|
187 |
+
def load_networks(self, epoch):
|
188 |
+
"""Load all the networks from the disk.
|
189 |
+
|
190 |
+
Parameters:
|
191 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
192 |
+
"""
|
193 |
+
for name in self.model_names:
|
194 |
+
if isinstance(name, str):
|
195 |
+
load_filename = '%s_net_%s.pth' % (epoch, name)
|
196 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
197 |
+
net = getattr(self, 'net' + name)
|
198 |
+
if isinstance(net, torch.nn.DataParallel):
|
199 |
+
net = net.module
|
200 |
+
# print('Loading depth boost model from %s' % load_path)
|
201 |
+
# if you are using PyTorch newer than 0.4 (e.g., built from
|
202 |
+
# GitHub source), you can remove str() on self.device
|
203 |
+
state_dict = torch.load(load_path, map_location=str(self.device))
|
204 |
+
if hasattr(state_dict, '_metadata'):
|
205 |
+
del state_dict._metadata
|
206 |
+
|
207 |
+
# patch InstanceNorm checkpoints prior to 0.4
|
208 |
+
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
|
209 |
+
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
|
210 |
+
net.load_state_dict(state_dict)
|
211 |
+
|
212 |
+
def print_networks(self, verbose):
|
213 |
+
"""Print the total number of parameters in the network and (if verbose) network architecture
|
214 |
+
|
215 |
+
Parameters:
|
216 |
+
verbose (bool) -- if verbose: print the network architecture
|
217 |
+
"""
|
218 |
+
print('---------- Networks initialized -------------')
|
219 |
+
for name in self.model_names:
|
220 |
+
if isinstance(name, str):
|
221 |
+
net = getattr(self, 'net' + name)
|
222 |
+
num_params = 0
|
223 |
+
for param in net.parameters():
|
224 |
+
num_params += param.numel()
|
225 |
+
if verbose:
|
226 |
+
print(net)
|
227 |
+
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
|
228 |
+
print('-----------------------------------------------')
|
229 |
+
|
230 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
231 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
232 |
+
Parameters:
|
233 |
+
nets (network list) -- a list of networks
|
234 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
235 |
+
"""
|
236 |
+
if not isinstance(nets, list):
|
237 |
+
nets = [nets]
|
238 |
+
for net in nets:
|
239 |
+
if net is not None:
|
240 |
+
for param in net.parameters():
|
241 |
+
param.requires_grad = requires_grad
|
extensions/sd-webui-controlnet/annotator/leres/pix2pix/models/base_model_hg.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class BaseModelHG():
|
5 |
+
def name(self):
|
6 |
+
return 'BaseModel'
|
7 |
+
|
8 |
+
def initialize(self, opt):
|
9 |
+
self.opt = opt
|
10 |
+
self.gpu_ids = opt.gpu_ids
|
11 |
+
self.isTrain = opt.isTrain
|
12 |
+
self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
|
13 |
+
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
14 |
+
|
15 |
+
def set_input(self, input):
|
16 |
+
self.input = input
|
17 |
+
|
18 |
+
def forward(self):
|
19 |
+
pass
|
20 |
+
|
21 |
+
# used in test time, no backprop
|
22 |
+
def test(self):
|
23 |
+
pass
|
24 |
+
|
25 |
+
def get_image_paths(self):
|
26 |
+
pass
|
27 |
+
|
28 |
+
def optimize_parameters(self):
|
29 |
+
pass
|
30 |
+
|
31 |
+
def get_current_visuals(self):
|
32 |
+
return self.input
|
33 |
+
|
34 |
+
def get_current_errors(self):
|
35 |
+
return {}
|
36 |
+
|
37 |
+
def save(self, label):
|
38 |
+
pass
|
39 |
+
|
40 |
+
# helper saving function that can be used by subclasses
|
41 |
+
def save_network(self, network, network_label, epoch_label, gpu_ids):
|
42 |
+
save_filename = '_%s_net_%s.pth' % (epoch_label, network_label)
|
43 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
44 |
+
torch.save(network.cpu().state_dict(), save_path)
|
45 |
+
if len(gpu_ids) and torch.cuda.is_available():
|
46 |
+
network.cuda(device_id=gpu_ids[0])
|
47 |
+
|
48 |
+
# helper loading function that can be used by subclasses
|
49 |
+
def load_network(self, network, network_label, epoch_label):
|
50 |
+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
51 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
52 |
+
print(save_path)
|
53 |
+
model = torch.load(save_path)
|
54 |
+
return model
|
55 |
+
# network.load_state_dict(torch.load(save_path))
|
56 |
+
|
57 |
+
def update_learning_rate():
|
58 |
+
pass
|
extensions/sd-webui-controlnet/annotator/leres/pix2pix/models/networks.py
ADDED
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import init
|
4 |
+
import functools
|
5 |
+
from torch.optim import lr_scheduler
|
6 |
+
|
7 |
+
|
8 |
+
###############################################################################
|
9 |
+
# Helper Functions
|
10 |
+
###############################################################################
|
11 |
+
|
12 |
+
|
13 |
+
class Identity(nn.Module):
|
14 |
+
def forward(self, x):
|
15 |
+
return x
|
16 |
+
|
17 |
+
|
18 |
+
def get_norm_layer(norm_type='instance'):
|
19 |
+
"""Return a normalization layer
|
20 |
+
|
21 |
+
Parameters:
|
22 |
+
norm_type (str) -- the name of the normalization layer: batch | instance | none
|
23 |
+
|
24 |
+
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
|
25 |
+
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
|
26 |
+
"""
|
27 |
+
if norm_type == 'batch':
|
28 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
29 |
+
elif norm_type == 'instance':
|
30 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
31 |
+
elif norm_type == 'none':
|
32 |
+
def norm_layer(x): return Identity()
|
33 |
+
else:
|
34 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
35 |
+
return norm_layer
|
36 |
+
|
37 |
+
|
38 |
+
def get_scheduler(optimizer, opt):
|
39 |
+
"""Return a learning rate scheduler
|
40 |
+
|
41 |
+
Parameters:
|
42 |
+
optimizer -- the optimizer of the network
|
43 |
+
opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
|
44 |
+
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
|
45 |
+
|
46 |
+
For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
|
47 |
+
and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
|
48 |
+
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
|
49 |
+
See https://pytorch.org/docs/stable/optim.html for more details.
|
50 |
+
"""
|
51 |
+
if opt.lr_policy == 'linear':
|
52 |
+
def lambda_rule(epoch):
|
53 |
+
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
|
54 |
+
return lr_l
|
55 |
+
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
|
56 |
+
elif opt.lr_policy == 'step':
|
57 |
+
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
|
58 |
+
elif opt.lr_policy == 'plateau':
|
59 |
+
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
|
60 |
+
elif opt.lr_policy == 'cosine':
|
61 |
+
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
|
62 |
+
else:
|
63 |
+
return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
|
64 |
+
return scheduler


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    # print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        gain (float)       -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a generator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the number of filters in the last conv layer
        netG (str)         -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- if use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator

    Our current implementation provides two types of generators:
        U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
        The original U-Net paper: https://arxiv.org/abs/1505.04597

        Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
        Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
        We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).

    The generator has been initialized by <init_net>. It uses RELU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netG == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netG == 'resnet_12blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=12)
    elif netG == 'unet_128':
        net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_256':
        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_672':
        net = UnetGenerator(input_nc, output_nc, 5, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_960':
        net = UnetGenerator(input_nc, output_nc, 6, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_1024':
        net = UnetGenerator(input_nc, output_nc, 10, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)


def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: basic | n_layers | pixel
        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator

    Our current implementation provides three types of discriminators:
        [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
        It can classify whether 70×70 overlapping patches are real or fake.
        Such a patch-level discriminator architecture has fewer parameters
        than a full-image discriminator and can work on arbitrarily-sized images
        in a fully convolutional fashion.

        [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
        with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)

        [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
        It encourages greater color diversity but has no effect on spatial statistics.

    The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':       # default PatchGAN classifier
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':  # more options
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':     # classify if each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)


##############################################################################
# Classes
##############################################################################
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (bool) - - label for a real image
            target_fake_label (bool) - - label of a fake image

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) - - typically the prediction from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """

        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss


def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)           -- discriminator network
        real_data (tensor array) -- real images
        fake_data (tensor array) -- generated images from the generator
        device (str)             -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        type (str)               -- if we mix real and fake data or not [real | fake | mixed].
        constant (float)         -- the constant used in formula ( ||gradient||_2 - constant)^2
        lambda_gp (float)        -- weight for this loss

    Returns the gradient penalty loss
    """
    if lambda_gp > 0.0:
        if type == 'real':   # either use real images, fake images, or a linear interpolation of two.
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1, device=device)
            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                        create_graph=True, retain_graph=True, only_inputs=True)
        gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # added eps
        return gradient_penalty, gradients
    else:
        return 0.0, None


class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)     -- the number of channels in input images
            output_nc (int)    -- the number of channels in output images
            ngf (int)          -- the number of filters in the last conv layer
            norm_layer         -- normalization layer
            use_dropout (bool) -- if use dropout layers
            n_blocks (int)     -- the number of ResNet blocks
            padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):        # add ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
382 |
+
|
383 |
+
|
384 |
+
class ResnetBlock(nn.Module):
|
385 |
+
"""Define a Resnet block"""
|
386 |
+
|
387 |
+
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
388 |
+
"""Initialize the Resnet block
|
389 |
+
|
390 |
+
A resnet block is a conv block with skip connections
|
391 |
+
We construct a conv block with build_conv_block function,
|
392 |
+
and implement skip connections in <forward> function.
|
393 |
+
Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
|
394 |
+
"""
|
395 |
+
super(ResnetBlock, self).__init__()
|
396 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
|
397 |
+
|
398 |
+
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
399 |
+
"""Construct a convolutional block.
|
400 |
+
|
401 |
+
Parameters:
|
402 |
+
dim (int) -- the number of channels in the conv layer.
|
403 |
+
padding_type (str) -- the name of padding layer: reflect | replicate | zero
|
404 |
+
norm_layer -- normalization layer
|
405 |
+
use_dropout (bool) -- if use dropout layers.
|
406 |
+
use_bias (bool) -- if the conv layer uses bias or not
|
407 |
+
|
408 |
+
Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
|
409 |
+
"""
|
410 |
+
conv_block = []
|
411 |
+
p = 0
|
412 |
+
if padding_type == 'reflect':
|
413 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
414 |
+
elif padding_type == 'replicate':
|
415 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
416 |
+
elif padding_type == 'zero':
|
417 |
+
p = 1
|
418 |
+
else:
|
419 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
420 |
+
|
421 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
|
422 |
+
if use_dropout:
|
423 |
+
conv_block += [nn.Dropout(0.5)]
|
424 |
+
|
425 |
+
p = 0
|
426 |
+
if padding_type == 'reflect':
|
427 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
428 |
+
elif padding_type == 'replicate':
|
429 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
430 |
+
elif padding_type == 'zero':
|
431 |
+
p = 1
|
432 |
+
else:
|
433 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
434 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
|
435 |
+
|
436 |
+
return nn.Sequential(*conv_block)
|
437 |
+
|
438 |
+
def forward(self, x):
|
439 |
+
"""Forward function (with skip connections)"""
|
440 |
+
out = x + self.conv_block(x) # add skip connections
|
441 |
+
return out


class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)   -- if this module is the outermost module
            innermost (bool)   -- if this module is the innermost module
            norm_layer         -- normalization layer
            use_dropout (bool) -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int)      -- the number of filters in the last conv layer
            n_layers (int) -- the number of conv layers in the discriminator
            norm_layer     -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)


class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int)      -- the number of filters in the last conv layer
            norm_layer     -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        """Standard forward."""
        return self.net(input)
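
Rough usage sketch for the factory functions above (not part of the uploaded file; the flat import path is illustrative, the real module lives under annotator/leres/pix2pix/models, and the channel counts simply mirror the defaults of the depth-merge model in the next file):

    import torch
    from networks import define_G, GANLoss  # illustrative import path

    # 2 input channels (two depth estimates), 1 output channel, no normalization layers.
    netG = define_G(input_nc=2, output_nc=1, ngf=64, netG='unet_1024', norm='none', gpu_ids=[])
    merged = netG(torch.randn(1, 2, 1024, 1024))        # -> tensor of shape (1, 1, 1024, 1024)

    criterion = GANLoss('lsgan')
    loss = criterion(torch.randn(1, 1, 30, 30), True)   # label tensor is expanded to the prediction's shape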
extensions/sd-webui-controlnet/annotator/leres/pix2pix/models/pix2pix4depth_model.py
ADDED
@@ -0,0 +1,155 @@
import torch
from .base_model import BaseModel
from . import networks


class Pix2Pix4DepthModel(BaseModel):
    """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.

    The model training requires '--dataset_mode aligned' dataset.
    By default, it uses a '--netG unet256' U-Net generator,
    a '--netD basic' discriminator (PatchGAN),
    and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the original GAN paper).

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For pix2pix, we do not use image buffer.
        The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
        By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
        """
        # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
        parser.set_defaults(input_nc=2, output_nc=1, norm='none', netG='unet_1024', dataset_mode='depthmerge')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            parser.add_argument('--lambda_L1', type=float, default=1000, help='weight for L1 loss')
        return parser

    def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # self.loss_names = ['G_L1']

        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        if self.isTrain:
            self.visual_names = ['outer', 'inner', 'fake_B', 'real_B']
        else:
            self.visual_names = ['fake_B']

        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']

        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, 64, 'unet_1024', 'none',
                                      False, 'normal', 0.02, self.gpu_ids)

        if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; therefore, #channels for D is input_nc + output_nc
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=1e-4, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=2e-06, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input_train(self, input):
        self.outer = input['data_outer'].to(self.device)
        self.outer = torch.nn.functional.interpolate(self.outer, (1024, 1024), mode='bilinear', align_corners=False)

        self.inner = input['data_inner'].to(self.device)
        self.inner = torch.nn.functional.interpolate(self.inner, (1024, 1024), mode='bilinear', align_corners=False)

        self.image_paths = input['image_path']

        if self.isTrain:
            self.gtfake = input['data_gtfake'].to(self.device)
            self.gtfake = torch.nn.functional.interpolate(self.gtfake, (1024, 1024), mode='bilinear', align_corners=False)
            self.real_B = self.gtfake

        self.real_A = torch.cat((self.outer, self.inner), 1)

    def set_input(self, outer, inner):
        inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0)
        outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0)

        inner = (inner - torch.min(inner)) / (torch.max(inner) - torch.min(inner))
        outer = (outer - torch.min(outer)) / (torch.max(outer) - torch.min(outer))

        inner = self.normalize(inner)
        outer = self.normalize(outer)

        self.real_A = torch.cat((outer, inner), 1).to(self.device)

    def normalize(self, input):
        input = input * 2
        input = input - 1
        return input

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_L1 + self.loss_G_GAN
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()                             # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)    # enable backprop for D
        self.optimizer_D.zero_grad()               # set D's gradients to zero
        self.backward_D()                          # calculate gradients for D
        self.optimizer_D.step()                    # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)   # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()               # set G's gradients to zero
        self.backward_G()                          # calculate gradients for G
        self.optimizer_G.step()                    # update G's weights
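
At inference time this model is driven through set_input/forward rather than a dataloader. A minimal sketch, assuming `model` is an already-constructed Pix2Pix4DepthModel with weights loaded (the input arrays and their names are placeholders; the real pipeline resizes both estimates to the pix2pix resolution before this call):

    import numpy as np
    import torch

    low_res_depth = np.random.rand(1024, 1024).astype(np.float32)   # stand-in for the whole-image estimate
    high_res_depth = np.random.rand(1024, 1024).astype(np.float32)  # stand-in for the patch estimate
    model.set_input(low_res_depth, high_res_depth)  # rescales both to [-1, 1] and stacks them into real_A (1, 2, H, W)
    with torch.no_grad():
        model.forward()                             # fake_B = netG(real_A)
    merged = model.fake_B[0, 0].cpu().numpy()       # merged depth prediction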
extensions/sd-webui-controlnet/annotator/leres/pix2pix/options/__init__.py
ADDED
@@ -0,0 +1 @@
"""This package includes option modules: training options, test options, and basic options (used in both training and test)."""
extensions/sd-webui-controlnet/annotator/leres/pix2pix/options/base_options.py
ADDED
@@ -0,0 +1,156 @@
import argparse
import os
from ...pix2pix.util import util
# import torch
from ...pix2pix import models
# import pix2pix.data
import numpy as np


class BaseOptions():
    """This class defines options used during both training and test time.

    It also implements several helper functions such as parsing, printing, and saving the options.
    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initialized"""
        self.initialized = False

    def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataroot', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        parser.add_argument('--name', type=str, default='void', help='mahdi_unet_new, scaled_unet')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./pix2pix/checkpoints', help='models are saved here')
        # model parameters
        parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
        parser.add_argument('--input_nc', type=int, default=2, help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
        parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
        parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--load_size', type=int, default=672, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=672, help='then crop to this size')
        parser.add_argument('--max_dataset_size', type=int, default=10000, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')

        parser.add_argument('--data_dir', type=str, required=False,
                            help='input files directory; images can be .png .jpg .tiff')
        parser.add_argument('--output_dir', type=str, required=False,
                            help='result dir. result depth will be png. videos are JMPG as avi')
        parser.add_argument('--savecrops', type=int, required=False)
        parser.add_argument('--savewholeest', type=int, required=False)
        parser.add_argument('--output_resolution', type=int, required=False,
                            help='0 for no restriction 1 for resize to input size')
        parser.add_argument('--net_receptive_field_size', type=int, required=False)
        parser.add_argument('--pix2pixsize', type=int, required=False)
        parser.add_argument('--generatevideo', type=int, required=False)
        parser.add_argument('--depthNet', type=int, required=False, help='0: midas 1: strurturedRL')
        parser.add_argument('--R0', action='store_true')
        parser.add_argument('--R20', action='store_true')
        parser.add_argument('--Final', action='store_true')
        parser.add_argument('--colorize_results', action='store_true')
        parser.add_argument('--max_res', type=float, default=np.inf)

        self.initialized = True
        return parser

    def gather_options(self):
        """Initialize our parser with basic options (only once).
        Add additional model-specific and dataset-specific options.
        These options are defined in the <modify_commandline_options> function
        in model and dataset classes.
        """
        if not self.initialized:  # check if it has been initialized
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        # dataset_name = opt.dataset_mode
        # dataset_option_setter = pix2pix.data.get_option_setter(dataset_name)
        # parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        # return parser.parse_args()  # EVIL
        return opt

    def print_options(self, opt):
        """Print and save options

        It will print both current options and default values (if different).
        It will save options into a text file / [checkpoints_dir] / opt.txt
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')

    def parse(self):
        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # process opt.suffix
        if opt.suffix:
            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
            opt.name = opt.name + suffix

        # self.print_options(opt)

        # set gpu ids
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                opt.gpu_ids.append(id)
        # if len(opt.gpu_ids) > 0:
        #     torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt
extensions/sd-webui-controlnet/annotator/leres/pix2pix/options/test_options.py
ADDED
@@ -0,0 +1,22 @@
from .base_options import BaseOptions


class TestOptions(BaseOptions):
    """This class includes test options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and Batchnorm behave differently during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
        # rewrite default values
        parser.set_defaults(model='pix2pix4depth')
        # To avoid cropping, the load_size should be the same as crop_size
        parser.set_defaults(load_size=parser.get_default('crop_size'))
        self.isTrain = False
        return parser
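
For context, a sketch of how these option classes are typically combined when the merge network is used at inference time (simplified and illustrative only; in this extension the flags are set programmatically rather than from the command line, and the model-construction step is merely indicated):

    opt = TestOptions().parse()             # BaseOptions defaults merged with the pix2pix4depth overrides
    print(opt.netG, opt.norm, opt.isTrain)  # -> unet_1024 none False
    # model = Pix2Pix4DepthModel(opt)       # indicative; checkpoint loading goes through the BaseModel helpers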
extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/__init__.py
ADDED
@@ -0,0 +1 @@
"""This package includes a miscellaneous collection of useful helper functions."""
extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/get_data.py
ADDED
@@ -0,0 +1,110 @@
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename


class GetData(object):
    """A Python script for downloading CycleGAN or pix2pix datasets.

    Parameters:
        technique (str) -- One of: 'cyclegan' or 'pix2pix'.
        verbose (bool)  -- If True, print additional information.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.

    Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'
    and 'scripts/download_cyclegan_model.sh'.
    """

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        self.url = url_dict.get(technique.lower())
        self._verbose = verbose

    def _print(self, text):
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        soup = BeautifulSoup(r.text, 'lxml')
        options = [h.text for h in soup.find_all('a', href=True)
                   if h.text.endswith(('.zip', 'tar.gz'))]
        return options

    def _present_options(self):
        r = requests.get(self.url)
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        return options[int(choice)]

    def _download_data(self, dataset_url, save_path):
        if not isdir(save_path):
            os.makedirs(save_path)

        base = basename(dataset_url)
        temp_save_path = join(save_path, base)

        with open(temp_save_path, "wb") as f:
            r = requests.get(dataset_url)
            f.write(r.content)

        if base.endswith('.tar.gz'):
            obj = tarfile.open(temp_save_path)
        elif base.endswith('.zip'):
            obj = ZipFile(temp_save_path, 'r')
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))

        self._print("Unpacking Data...")
        obj.extractall(save_path)
        obj.close()
        os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """
        Download a dataset.

        Parameters:
            save_path (str) -- A directory to save the data to.
            dataset (str)   -- (optional). A specific dataset to download.
                               Note: this must include the file extension.
                               If None, options will be presented for you
                               to choose from.

        Returns:
            save_path_full (str) -- the absolute path to the downloaded data.
        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        save_path_full = join(save_path, selected_dataset.split('.')[0])

        if isdir(save_path_full):
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = "{0}/{1}".format(self.url, selected_dataset)
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)
extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/guidedfilter.py
ADDED
@@ -0,0 +1,47 @@
import numpy as np


class GuidedFilter():
    def __init__(self, source, reference, r=64, eps=0.05**2):
        self.source = source
        self.reference = reference
        self.r = r
        self.eps = eps

        self.smooth = self.guidedfilter(self.source, self.reference, self.r, self.eps)

    def boxfilter(self, img, r):
        (rows, cols) = img.shape
        imDst = np.zeros_like(img)

        imCum = np.cumsum(img, 0)
        imDst[0: r+1, :] = imCum[r: 2*r+1, :]
        imDst[r+1: rows-r, :] = imCum[2*r+1: rows, :] - imCum[0: rows-2*r-1, :]
        imDst[rows-r: rows, :] = np.tile(imCum[rows-1, :], [r, 1]) - imCum[rows-2*r-1: rows-r-1, :]

        imCum = np.cumsum(imDst, 1)
        imDst[:, 0: r+1] = imCum[:, r: 2*r+1]
        imDst[:, r+1: cols-r] = imCum[:, 2*r+1: cols] - imCum[:, 0: cols-2*r-1]
        imDst[:, cols-r: cols] = np.tile(imCum[:, cols-1], [r, 1]).T - imCum[:, cols-2*r-1: cols-r-1]

        return imDst

    def guidedfilter(self, I, p, r, eps):
        (rows, cols) = I.shape
        N = self.boxfilter(np.ones([rows, cols]), r)

        meanI = self.boxfilter(I, r) / N
        meanP = self.boxfilter(p, r) / N
        meanIp = self.boxfilter(I * p, r) / N
        covIp = meanIp - meanI * meanP

        meanII = self.boxfilter(I * I, r) / N
        varI = meanII - meanI * meanI

        a = covIp / (varI + eps)
        b = meanP - a * meanI

        meanA = self.boxfilter(a, r) / N
        meanB = self.boxfilter(b, r) / N

        q = meanA * I + meanB
        return q
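
A quick illustration of the class above (purely synthetic inputs; in the LeReS merge pipeline the filter is fed depth maps rather than random arrays):

    import numpy as np

    src = np.random.rand(256, 256)              # guidance image, passed as I inside the class
    ref = np.random.rand(256, 256)              # image p to be smoothed
    gf = GuidedFilter(src, ref, r=32, eps=1e-3)
    print(gf.smooth.shape)                      # (256, 256): edge-aware smoothing of `ref` guided by `src`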
extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/html.py
ADDED
@@ -0,0 +1,86 @@
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
import os


class HTML:
    """This HTML class allows us to save images and write texts into a single HTML file.

    It consists of functions such as <add_header> (add a text header to the HTML file),
    <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
    It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
    """

    def __init__(self, web_dir, title, refresh=0):
        """Initialize the HTML classes

        Parameters:
            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
            title (str)   -- the webpage name
            refresh (int) -- how often the website refreshes itself; if 0, no refreshing
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        """Return the directory that stores images"""
        return self.img_dir

    def add_header(self, text):
        """Insert a header to the HTML file

        Parameters:
            text (str) -- the header text
        """
        with self.doc:
            h3(text)

    def add_images(self, ims, txts, links, width=400):
        """add images to the HTML file

        Parameters:
            ims (str list)   -- a list of image paths
            txts (str list)  -- a list of image names shown on the website
            links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
        """
        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
        self.doc.add(self.t)
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        """save the current content to the HTML file"""
        html_file = '%s/index.html' % self.web_dir
        f = open(html_file, 'wt')
        f.write(self.doc.render())
        f.close()


if __name__ == '__main__':  # we show an example usage here.
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims, txts, links = [], [], []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()
extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/image_pool.py
ADDED
@@ -0,0 +1,54 @@
import random
import torch


class ImagePool():
    """This class implements an image buffer that stores previously generated images.

    This buffer enables us to update discriminators using a history of generated images
    rather than the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """Initialize the ImagePool class

        Parameters:
            pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return an image from the pool.

        Parameters:
            images: the latest generated images from the generator

        Returns images from the buffer.

        With probability 0.5, the buffer returns the input images.
        With probability 0.5, the buffer returns images previously stored in the buffer
        and inserts the current images into the buffer.
        """
        if self.pool_size == 0:  # if the buffer size is 0, do nothing
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:  # if the buffer is not full; keep inserting current images to the buffer
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:  # by another 50% chance, the buffer will return the current image
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)  # collect all the images and return
        return return_images
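Note: a minimal sketch of how a buffer like this is normally wired into discriminator updates. The 1x1-conv generator/discriminator stand-ins and the pool size are illustrative assumptions, not part of this diff.

import torch
from torch import nn

netG = nn.Conv2d(3, 3, kernel_size=1)   # hypothetical generator stand-in
netD = nn.Conv2d(3, 1, kernel_size=1)   # hypothetical discriminator stand-in
pool = ImagePool(pool_size=50)

real = torch.rand(4, 3, 32, 32)
fake = netG(real)

# The discriminator trains on a mix of fresh and historical fakes; that is the
# purpose of the buffer and tends to stabilise GAN training.
fake_for_D = pool.query(fake.detach())   # same shape as `fake`: 4 x 3 x 32 x 32
d_out = netD(fake_for_D)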
extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/util.py
ADDED
@@ -0,0 +1,105 @@
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os


def tensor2im(input_image, imtype=np.uint16):
    """Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = torch.squeeze(image_tensor).cpu().numpy()  # convert it into a numpy array
        image_numpy = (image_numpy + 1) / 2.0 * (2**16 - 1)  # scale from [-1, 1] to the 16-bit range
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)


def diagnose_network(net, name='network'):
    """Calculate and print the mean of average absolute(gradients)

    Parameters:
        net (torch network) -- Torch network
        name (str) -- the name of the network
    """
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)


def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str)          -- the path of the image
    """
    image_pil = Image.fromarray(image_numpy)

    image_pil = image_pil.convert('I;16')

    # image_pil = Image.fromarray(image_numpy)
    # h, w, _ = image_numpy.shape
    #
    # if aspect_ratio > 1.0:
    #     image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
    # if aspect_ratio < 1.0:
    #     image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)

    image_pil.save(image_path)


def print_numpy(x, val=True, shp=False):
    """Print the mean, min, max, median, std, and size of a numpy array

    Parameters:
        val (bool) -- if print the values of the numpy array
        shp (bool) -- if print the shape of the numpy array
    """
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))


def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths
    """
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """create a single empty directory if it didn't exist

    Parameters:
        path (str) -- a single directory path
    """
    if not os.path.exists(path):
        os.makedirs(path)
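Note: a short sketch of the round trip these helpers support: tensor2im maps a tensor in [-1, 1] to a 16-bit array, and save_image writes it as an 'I;16' PNG. The output path is illustrative.

import numpy as np
import torch

depth = torch.rand(1, 1, 64, 64) * 2 - 1       # dummy network output in [-1, 1]
depth_u16 = tensor2im(depth)                    # (64, 64) np.uint16 in [0, 65535]
assert depth_u16.dtype == np.uint16

mkdir('./results')                              # helper defined above
save_image(depth_u16, './results/depth_example.png')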
extensions/sd-webui-controlnet/annotator/leres/pix2pix/util/visualizer.py
ADDED
@@ -0,0 +1,166 @@
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import ntpath
|
5 |
+
import time
|
6 |
+
from . import util, html
|
7 |
+
from subprocess import Popen, PIPE
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
if sys.version_info[0] == 2:
|
12 |
+
VisdomExceptionBase = Exception
|
13 |
+
else:
|
14 |
+
VisdomExceptionBase = ConnectionError
|
15 |
+
|
16 |
+
|
17 |
+
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
|
18 |
+
"""Save images to the disk.
|
19 |
+
|
20 |
+
Parameters:
|
21 |
+
webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
|
22 |
+
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
|
23 |
+
image_path (str) -- the string is used to create image paths
|
24 |
+
aspect_ratio (float) -- the aspect ratio of saved images
|
25 |
+
width (int) -- the images will be resized to width x width
|
26 |
+
|
27 |
+
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
|
28 |
+
"""
|
29 |
+
image_dir = webpage.get_image_dir()
|
30 |
+
short_path = ntpath.basename(image_path[0])
|
31 |
+
name = os.path.splitext(short_path)[0]
|
32 |
+
|
33 |
+
webpage.add_header(name)
|
34 |
+
ims, txts, links = [], [], []
|
35 |
+
|
36 |
+
for label, im_data in visuals.items():
|
37 |
+
im = util.tensor2im(im_data)
|
38 |
+
image_name = '%s_%s.png' % (name, label)
|
39 |
+
save_path = os.path.join(image_dir, image_name)
|
40 |
+
util.save_image(im, save_path, aspect_ratio=aspect_ratio)
|
41 |
+
ims.append(image_name)
|
42 |
+
txts.append(label)
|
43 |
+
links.append(image_name)
|
44 |
+
webpage.add_images(ims, txts, links, width=width)
|
45 |
+
|
46 |
+
|
47 |
+
class Visualizer():
|
48 |
+
"""This class includes several functions that can display/save images and print/save logging information.
|
49 |
+
|
50 |
+
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(self, opt):
|
54 |
+
"""Initialize the Visualizer class
|
55 |
+
|
56 |
+
Parameters:
|
57 |
+
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
58 |
+
Step 1: Cache the training/test options
|
59 |
+
Step 2: connect to a visdom server
|
60 |
+
Step 3: create an HTML object for saveing HTML filters
|
61 |
+
Step 4: create a logging file to store training losses
|
62 |
+
"""
|
63 |
+
self.opt = opt # cache the option
|
64 |
+
self.display_id = opt.display_id
|
65 |
+
self.use_html = opt.isTrain and not opt.no_html
|
66 |
+
self.win_size = opt.display_winsize
|
67 |
+
self.name = opt.name
|
68 |
+
self.port = opt.display_port
|
69 |
+
self.saved = False
|
70 |
+
|
71 |
+
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
|
72 |
+
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
|
73 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
74 |
+
print('create web directory %s...' % self.web_dir)
|
75 |
+
util.mkdirs([self.web_dir, self.img_dir])
|
76 |
+
# create a logging file to store training losses
|
77 |
+
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
|
78 |
+
with open(self.log_name, "a") as log_file:
|
79 |
+
now = time.strftime("%c")
|
80 |
+
log_file.write('================ Training Loss (%s) ================\n' % now)
|
81 |
+
|
82 |
+
def reset(self):
|
83 |
+
"""Reset the self.saved status"""
|
84 |
+
self.saved = False
|
85 |
+
|
86 |
+
def create_visdom_connections(self):
|
87 |
+
"""If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
|
88 |
+
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
|
89 |
+
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
|
90 |
+
print('Command: %s' % cmd)
|
91 |
+
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
|
92 |
+
|
93 |
+
def display_current_results(self, visuals, epoch, save_result):
|
94 |
+
"""Display current results on visdom; save current results to an HTML file.
|
95 |
+
|
96 |
+
Parameters:
|
97 |
+
visuals (OrderedDict) - - dictionary of images to display or save
|
98 |
+
epoch (int) - - the current epoch
|
99 |
+
save_result (bool) - - if save the current results to an HTML file
|
100 |
+
"""
|
101 |
+
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
|
102 |
+
self.saved = True
|
103 |
+
# save images to the disk
|
104 |
+
for label, image in visuals.items():
|
105 |
+
image_numpy = util.tensor2im(image)
|
106 |
+
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
|
107 |
+
util.save_image(image_numpy, img_path)
|
108 |
+
|
109 |
+
# update website
|
110 |
+
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
|
111 |
+
for n in range(epoch, 0, -1):
|
112 |
+
webpage.add_header('epoch [%d]' % n)
|
113 |
+
ims, txts, links = [], [], []
|
114 |
+
|
115 |
+
for label, image_numpy in visuals.items():
|
116 |
+
# image_numpy = util.tensor2im(image)
|
117 |
+
img_path = 'epoch%.3d_%s.png' % (n, label)
|
118 |
+
ims.append(img_path)
|
119 |
+
txts.append(label)
|
120 |
+
links.append(img_path)
|
121 |
+
webpage.add_images(ims, txts, links, width=self.win_size)
|
122 |
+
webpage.save()
|
123 |
+
|
124 |
+
# def plot_current_losses(self, epoch, counter_ratio, losses):
|
125 |
+
# """display the current losses on visdom display: dictionary of error labels and values
|
126 |
+
#
|
127 |
+
# Parameters:
|
128 |
+
# epoch (int) -- current epoch
|
129 |
+
# counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
|
130 |
+
# losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
131 |
+
# """
|
132 |
+
# if not hasattr(self, 'plot_data'):
|
133 |
+
# self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
|
134 |
+
# self.plot_data['X'].append(epoch + counter_ratio)
|
135 |
+
# self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
|
136 |
+
# try:
|
137 |
+
# self.vis.line(
|
138 |
+
# X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
|
139 |
+
# Y=np.array(self.plot_data['Y']),
|
140 |
+
# opts={
|
141 |
+
# 'title': self.name + ' loss over time',
|
142 |
+
# 'legend': self.plot_data['legend'],
|
143 |
+
# 'xlabel': 'epoch',
|
144 |
+
# 'ylabel': 'loss'},
|
145 |
+
# win=self.display_id)
|
146 |
+
# except VisdomExceptionBase:
|
147 |
+
# self.create_visdom_connections()
|
148 |
+
|
149 |
+
# losses: same format as |losses| of plot_current_losses
|
150 |
+
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
|
151 |
+
"""print current losses on console; also save the losses to the disk
|
152 |
+
|
153 |
+
Parameters:
|
154 |
+
epoch (int) -- current epoch
|
155 |
+
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
|
156 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
157 |
+
t_comp (float) -- computational time per data point (normalized by batch_size)
|
158 |
+
t_data (float) -- data loading time per data point (normalized by batch_size)
|
159 |
+
"""
|
160 |
+
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
|
161 |
+
for k, v in losses.items():
|
162 |
+
message += '%s: %.3f ' % (k, v)
|
163 |
+
|
164 |
+
print(message) # print the message
|
165 |
+
with open(self.log_name, "a") as log_file:
|
166 |
+
log_file.write('%s\n' % message) # save the message
|
extensions/sd-webui-controlnet/annotator/midas/__init__.py
ADDED
@@ -0,0 +1,49 @@
import cv2
import numpy as np
import torch

from einops import rearrange
from .api import MiDaSInference
from modules import devices

model = None

def unload_midas_model():
    global model
    if model is not None:
        model = model.cpu()

def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1):
    global model
    if model is None:
        model = MiDaSInference(model_type="dpt_hybrid")
        if devices.get_device_for("controlnet").type != 'mps':
            model = model.to(devices.get_device_for("controlnet"))

    assert input_image.ndim == 3
    image_depth = input_image
    with torch.no_grad():
        image_depth = torch.from_numpy(image_depth).float()
        if devices.get_device_for("controlnet").type != 'mps':
            image_depth = image_depth.to(devices.get_device_for("controlnet"))
        image_depth = image_depth / 127.5 - 1.0
        image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
        depth = model(image_depth)[0]

        depth_pt = depth.clone()
        depth_pt -= torch.min(depth_pt)
        depth_pt /= torch.max(depth_pt)
        depth_pt = depth_pt.cpu().numpy()
        depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)

        depth_np = depth.cpu().numpy()
        x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
        y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
        z = np.ones_like(x) * a
        x[depth_pt < bg_th] = 0
        y[depth_pt < bg_th] = 0
        normal = np.stack([x, y, z], axis=2)
        normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
        normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)

    return depth_image, normal_image
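Note: a rough usage sketch for the annotator above. It assumes the code runs inside the webui process (so `modules.devices` resolves) and that 'input.png' is an RGB image on disk; both are assumptions made for illustration.

import cv2

bgr = cv2.imread('input.png')                          # assumed input file
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)             # HWC uint8, as apply_midas expects
depth_image, normal_image = apply_midas(rgb)           # both returned as uint8 maps
cv2.imwrite('depth.png', depth_image)
cv2.imwrite('normal.png', cv2.cvtColor(normal_image, cv2.COLOR_RGB2BGR))
unload_midas_model()                                   # moves the cached model back to CPU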
extensions/sd-webui-controlnet/annotator/midas/api.py
ADDED
@@ -0,0 +1,181 @@
1 |
+
# based on https://github.com/isl-org/MiDaS
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import os
|
7 |
+
from modules.paths import models_path
|
8 |
+
|
9 |
+
from torchvision.transforms import Compose
|
10 |
+
|
11 |
+
from .midas.dpt_depth import DPTDepthModel
|
12 |
+
from .midas.midas_net import MidasNet
|
13 |
+
from .midas.midas_net_custom import MidasNet_small
|
14 |
+
from .midas.transforms import Resize, NormalizeImage, PrepareForNet
|
15 |
+
|
16 |
+
base_model_path = os.path.join(models_path, "midas")
|
17 |
+
old_modeldir = os.path.dirname(os.path.realpath(__file__))
|
18 |
+
remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
|
19 |
+
|
20 |
+
ISL_PATHS = {
|
21 |
+
"dpt_large": os.path.join(base_model_path, "dpt_large-midas-2f21e586.pt"),
|
22 |
+
"dpt_hybrid": os.path.join(base_model_path, "dpt_hybrid-midas-501f0c75.pt"),
|
23 |
+
"midas_v21": "",
|
24 |
+
"midas_v21_small": "",
|
25 |
+
}
|
26 |
+
|
27 |
+
OLD_ISL_PATHS = {
|
28 |
+
"dpt_large": os.path.join(old_modeldir, "dpt_large-midas-2f21e586.pt"),
|
29 |
+
"dpt_hybrid": os.path.join(old_modeldir, "dpt_hybrid-midas-501f0c75.pt"),
|
30 |
+
"midas_v21": "",
|
31 |
+
"midas_v21_small": "",
|
32 |
+
}
|
33 |
+
|
34 |
+
|
35 |
+
def disabled_train(self, mode=True):
|
36 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
37 |
+
does not change anymore."""
|
38 |
+
return self
|
39 |
+
|
40 |
+
|
41 |
+
def load_midas_transform(model_type):
|
42 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
43 |
+
# load transform only
|
44 |
+
if model_type == "dpt_large": # DPT-Large
|
45 |
+
net_w, net_h = 384, 384
|
46 |
+
resize_mode = "minimal"
|
47 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
48 |
+
|
49 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
50 |
+
net_w, net_h = 384, 384
|
51 |
+
resize_mode = "minimal"
|
52 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
53 |
+
|
54 |
+
elif model_type == "midas_v21":
|
55 |
+
net_w, net_h = 384, 384
|
56 |
+
resize_mode = "upper_bound"
|
57 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
58 |
+
|
59 |
+
elif model_type == "midas_v21_small":
|
60 |
+
net_w, net_h = 256, 256
|
61 |
+
resize_mode = "upper_bound"
|
62 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
63 |
+
|
64 |
+
else:
|
65 |
+
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
66 |
+
|
67 |
+
transform = Compose(
|
68 |
+
[
|
69 |
+
Resize(
|
70 |
+
net_w,
|
71 |
+
net_h,
|
72 |
+
resize_target=None,
|
73 |
+
keep_aspect_ratio=True,
|
74 |
+
ensure_multiple_of=32,
|
75 |
+
resize_method=resize_mode,
|
76 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
77 |
+
),
|
78 |
+
normalization,
|
79 |
+
PrepareForNet(),
|
80 |
+
]
|
81 |
+
)
|
82 |
+
|
83 |
+
return transform
|
84 |
+
|
85 |
+
|
86 |
+
def load_model(model_type):
|
87 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
88 |
+
# load network
|
89 |
+
model_path = ISL_PATHS[model_type]
|
90 |
+
old_model_path = OLD_ISL_PATHS[model_type]
|
91 |
+
if model_type == "dpt_large": # DPT-Large
|
92 |
+
model = DPTDepthModel(
|
93 |
+
path=model_path,
|
94 |
+
backbone="vitl16_384",
|
95 |
+
non_negative=True,
|
96 |
+
)
|
97 |
+
net_w, net_h = 384, 384
|
98 |
+
resize_mode = "minimal"
|
99 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
100 |
+
|
101 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
102 |
+
if os.path.exists(old_model_path):
|
103 |
+
model_path = old_model_path
|
104 |
+
elif not os.path.exists(model_path):
|
105 |
+
from basicsr.utils.download_util import load_file_from_url
|
106 |
+
load_file_from_url(remote_model_path, model_dir=base_model_path)
|
107 |
+
|
108 |
+
model = DPTDepthModel(
|
109 |
+
path=model_path,
|
110 |
+
backbone="vitb_rn50_384",
|
111 |
+
non_negative=True,
|
112 |
+
)
|
113 |
+
net_w, net_h = 384, 384
|
114 |
+
resize_mode = "minimal"
|
115 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
116 |
+
|
117 |
+
elif model_type == "midas_v21":
|
118 |
+
model = MidasNet(model_path, non_negative=True)
|
119 |
+
net_w, net_h = 384, 384
|
120 |
+
resize_mode = "upper_bound"
|
121 |
+
normalization = NormalizeImage(
|
122 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
123 |
+
)
|
124 |
+
|
125 |
+
elif model_type == "midas_v21_small":
|
126 |
+
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
127 |
+
non_negative=True, blocks={'expand': True})
|
128 |
+
net_w, net_h = 256, 256
|
129 |
+
resize_mode = "upper_bound"
|
130 |
+
normalization = NormalizeImage(
|
131 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
132 |
+
)
|
133 |
+
|
134 |
+
else:
|
135 |
+
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
136 |
+
assert False
|
137 |
+
|
138 |
+
transform = Compose(
|
139 |
+
[
|
140 |
+
Resize(
|
141 |
+
net_w,
|
142 |
+
net_h,
|
143 |
+
resize_target=None,
|
144 |
+
keep_aspect_ratio=True,
|
145 |
+
ensure_multiple_of=32,
|
146 |
+
resize_method=resize_mode,
|
147 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
148 |
+
),
|
149 |
+
normalization,
|
150 |
+
PrepareForNet(),
|
151 |
+
]
|
152 |
+
)
|
153 |
+
|
154 |
+
return model.eval(), transform
|
155 |
+
|
156 |
+
|
157 |
+
class MiDaSInference(nn.Module):
|
158 |
+
MODEL_TYPES_TORCH_HUB = [
|
159 |
+
"DPT_Large",
|
160 |
+
"DPT_Hybrid",
|
161 |
+
"MiDaS_small"
|
162 |
+
]
|
163 |
+
MODEL_TYPES_ISL = [
|
164 |
+
"dpt_large",
|
165 |
+
"dpt_hybrid",
|
166 |
+
"midas_v21",
|
167 |
+
"midas_v21_small",
|
168 |
+
]
|
169 |
+
|
170 |
+
def __init__(self, model_type):
|
171 |
+
super().__init__()
|
172 |
+
assert (model_type in self.MODEL_TYPES_ISL)
|
173 |
+
model, _ = load_model(model_type)
|
174 |
+
self.model = model
|
175 |
+
self.model.train = disabled_train
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
with torch.no_grad():
|
179 |
+
prediction = self.model(x)
|
180 |
+
return prediction
|
181 |
+
|
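Note: a sketch of driving MiDaSInference directly with the matching preprocessing from load_midas_transform. It assumes the dpt_hybrid checkpoint is already available on disk and that the webui `modules` package is importable; the random image is a dummy stand-in.

import numpy as np
import torch

transform = load_midas_transform("dpt_hybrid")          # same preprocessing as above
img = np.random.rand(480, 640, 3).astype(np.float32)    # dummy HWC image in [0, 1]
sample = transform({"image": img})

midas = MiDaSInference(model_type="dpt_hybrid")
x = torch.from_numpy(sample["image"]).unsqueeze(0)      # 1 x 3 x H x W
depth = midas(x)                                         # 1 x H x W inverse-depth map
print(depth.shape)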
extensions/sd-webui-controlnet/annotator/midas/midas/__init__.py
ADDED
File without changes
|
extensions/sd-webui-controlnet/annotator/midas/midas/base_model.py
ADDED
@@ -0,0 +1,16 @@
import torch


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'))

        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)
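Note: load() above accepts either a bare state_dict or a training checkpoint that wraps it under "model"; a quick sketch of the wrapped case (file name and layout are illustrative):

import torch

net = BaseModel()                       # in practice a subclass such as DPTDepthModel is used
torch.save({"model": net.state_dict(), "optimizer": {}}, "checkpoint.pt")
net.load("checkpoint.pt")               # unwrapped automatically because "optimizer" is present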
extensions/sd-webui-controlnet/annotator/midas/midas/blocks.py
ADDED
@@ -0,0 +1,342 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .vit import (
|
5 |
+
_make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384,
|
7 |
+
_make_pretrained_vitb16_384,
|
8 |
+
forward_vit,
|
9 |
+
)
|
10 |
+
|
11 |
+
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
12 |
+
if backbone == "vitl16_384":
|
13 |
+
pretrained = _make_pretrained_vitl16_384(
|
14 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
15 |
+
)
|
16 |
+
scratch = _make_scratch(
|
17 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
18 |
+
) # ViT-L/16 - 85.0% Top1 (backbone)
|
19 |
+
elif backbone == "vitb_rn50_384":
|
20 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
21 |
+
use_pretrained,
|
22 |
+
hooks=hooks,
|
23 |
+
use_vit_only=use_vit_only,
|
24 |
+
use_readout=use_readout,
|
25 |
+
)
|
26 |
+
scratch = _make_scratch(
|
27 |
+
[256, 512, 768, 768], features, groups=groups, expand=expand
|
28 |
+
) # ViT-H/16 - 85.0% Top1 (backbone)
|
29 |
+
elif backbone == "vitb16_384":
|
30 |
+
pretrained = _make_pretrained_vitb16_384(
|
31 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
32 |
+
)
|
33 |
+
scratch = _make_scratch(
|
34 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
35 |
+
) # ViT-B/16 - 84.6% Top1 (backbone)
|
36 |
+
elif backbone == "resnext101_wsl":
|
37 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
38 |
+
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
39 |
+
elif backbone == "efficientnet_lite3":
|
40 |
+
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
41 |
+
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
42 |
+
else:
|
43 |
+
print(f"Backbone '{backbone}' not implemented")
|
44 |
+
assert False
|
45 |
+
|
46 |
+
return pretrained, scratch
|
47 |
+
|
48 |
+
|
49 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
50 |
+
scratch = nn.Module()
|
51 |
+
|
52 |
+
out_shape1 = out_shape
|
53 |
+
out_shape2 = out_shape
|
54 |
+
out_shape3 = out_shape
|
55 |
+
out_shape4 = out_shape
|
56 |
+
if expand==True:
|
57 |
+
out_shape1 = out_shape
|
58 |
+
out_shape2 = out_shape*2
|
59 |
+
out_shape3 = out_shape*4
|
60 |
+
out_shape4 = out_shape*8
|
61 |
+
|
62 |
+
scratch.layer1_rn = nn.Conv2d(
|
63 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
64 |
+
)
|
65 |
+
scratch.layer2_rn = nn.Conv2d(
|
66 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
67 |
+
)
|
68 |
+
scratch.layer3_rn = nn.Conv2d(
|
69 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
70 |
+
)
|
71 |
+
scratch.layer4_rn = nn.Conv2d(
|
72 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
73 |
+
)
|
74 |
+
|
75 |
+
return scratch
|
76 |
+
|
77 |
+
|
78 |
+
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
79 |
+
efficientnet = torch.hub.load(
|
80 |
+
"rwightman/gen-efficientnet-pytorch",
|
81 |
+
"tf_efficientnet_lite3",
|
82 |
+
pretrained=use_pretrained,
|
83 |
+
exportable=exportable
|
84 |
+
)
|
85 |
+
return _make_efficientnet_backbone(efficientnet)
|
86 |
+
|
87 |
+
|
88 |
+
def _make_efficientnet_backbone(effnet):
|
89 |
+
pretrained = nn.Module()
|
90 |
+
|
91 |
+
pretrained.layer1 = nn.Sequential(
|
92 |
+
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
93 |
+
)
|
94 |
+
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
95 |
+
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
96 |
+
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
97 |
+
|
98 |
+
return pretrained
|
99 |
+
|
100 |
+
|
101 |
+
def _make_resnet_backbone(resnet):
|
102 |
+
pretrained = nn.Module()
|
103 |
+
pretrained.layer1 = nn.Sequential(
|
104 |
+
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
105 |
+
)
|
106 |
+
|
107 |
+
pretrained.layer2 = resnet.layer2
|
108 |
+
pretrained.layer3 = resnet.layer3
|
109 |
+
pretrained.layer4 = resnet.layer4
|
110 |
+
|
111 |
+
return pretrained
|
112 |
+
|
113 |
+
|
114 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
115 |
+
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
116 |
+
return _make_resnet_backbone(resnet)
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
class Interpolate(nn.Module):
|
121 |
+
"""Interpolation module.
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
125 |
+
"""Init.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scale_factor (float): scaling
|
129 |
+
mode (str): interpolation mode
|
130 |
+
"""
|
131 |
+
super(Interpolate, self).__init__()
|
132 |
+
|
133 |
+
self.interp = nn.functional.interpolate
|
134 |
+
self.scale_factor = scale_factor
|
135 |
+
self.mode = mode
|
136 |
+
self.align_corners = align_corners
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
"""Forward pass.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
x (tensor): input
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
tensor: interpolated data
|
146 |
+
"""
|
147 |
+
|
148 |
+
x = self.interp(
|
149 |
+
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
150 |
+
)
|
151 |
+
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
class ResidualConvUnit(nn.Module):
|
156 |
+
"""Residual convolution module.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, features):
|
160 |
+
"""Init.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
features (int): number of features
|
164 |
+
"""
|
165 |
+
super().__init__()
|
166 |
+
|
167 |
+
self.conv1 = nn.Conv2d(
|
168 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
169 |
+
)
|
170 |
+
|
171 |
+
self.conv2 = nn.Conv2d(
|
172 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
173 |
+
)
|
174 |
+
|
175 |
+
self.relu = nn.ReLU(inplace=True)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
"""Forward pass.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
x (tensor): input
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
tensor: output
|
185 |
+
"""
|
186 |
+
out = self.relu(x)
|
187 |
+
out = self.conv1(out)
|
188 |
+
out = self.relu(out)
|
189 |
+
out = self.conv2(out)
|
190 |
+
|
191 |
+
return out + x
|
192 |
+
|
193 |
+
|
194 |
+
class FeatureFusionBlock(nn.Module):
|
195 |
+
"""Feature fusion block.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self, features):
|
199 |
+
"""Init.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
features (int): number of features
|
203 |
+
"""
|
204 |
+
super(FeatureFusionBlock, self).__init__()
|
205 |
+
|
206 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
207 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
208 |
+
|
209 |
+
def forward(self, *xs):
|
210 |
+
"""Forward pass.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
tensor: output
|
214 |
+
"""
|
215 |
+
output = xs[0]
|
216 |
+
|
217 |
+
if len(xs) == 2:
|
218 |
+
output += self.resConfUnit1(xs[1])
|
219 |
+
|
220 |
+
output = self.resConfUnit2(output)
|
221 |
+
|
222 |
+
output = nn.functional.interpolate(
|
223 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
224 |
+
)
|
225 |
+
|
226 |
+
return output
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
class ResidualConvUnit_custom(nn.Module):
|
232 |
+
"""Residual convolution module.
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(self, features, activation, bn):
|
236 |
+
"""Init.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
features (int): number of features
|
240 |
+
"""
|
241 |
+
super().__init__()
|
242 |
+
|
243 |
+
self.bn = bn
|
244 |
+
|
245 |
+
self.groups=1
|
246 |
+
|
247 |
+
self.conv1 = nn.Conv2d(
|
248 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
249 |
+
)
|
250 |
+
|
251 |
+
self.conv2 = nn.Conv2d(
|
252 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
253 |
+
)
|
254 |
+
|
255 |
+
if self.bn==True:
|
256 |
+
self.bn1 = nn.BatchNorm2d(features)
|
257 |
+
self.bn2 = nn.BatchNorm2d(features)
|
258 |
+
|
259 |
+
self.activation = activation
|
260 |
+
|
261 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
"""Forward pass.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
x (tensor): input
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
tensor: output
|
271 |
+
"""
|
272 |
+
|
273 |
+
out = self.activation(x)
|
274 |
+
out = self.conv1(out)
|
275 |
+
if self.bn==True:
|
276 |
+
out = self.bn1(out)
|
277 |
+
|
278 |
+
out = self.activation(out)
|
279 |
+
out = self.conv2(out)
|
280 |
+
if self.bn==True:
|
281 |
+
out = self.bn2(out)
|
282 |
+
|
283 |
+
if self.groups > 1:
|
284 |
+
out = self.conv_merge(out)
|
285 |
+
|
286 |
+
return self.skip_add.add(out, x)
|
287 |
+
|
288 |
+
# return out + x
|
289 |
+
|
290 |
+
|
291 |
+
class FeatureFusionBlock_custom(nn.Module):
|
292 |
+
"""Feature fusion block.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
296 |
+
"""Init.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
features (int): number of features
|
300 |
+
"""
|
301 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
302 |
+
|
303 |
+
self.deconv = deconv
|
304 |
+
self.align_corners = align_corners
|
305 |
+
|
306 |
+
self.groups=1
|
307 |
+
|
308 |
+
self.expand = expand
|
309 |
+
out_features = features
|
310 |
+
if self.expand==True:
|
311 |
+
out_features = features//2
|
312 |
+
|
313 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
314 |
+
|
315 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
316 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
317 |
+
|
318 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
319 |
+
|
320 |
+
def forward(self, *xs):
|
321 |
+
"""Forward pass.
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
tensor: output
|
325 |
+
"""
|
326 |
+
output = xs[0]
|
327 |
+
|
328 |
+
if len(xs) == 2:
|
329 |
+
res = self.resConfUnit1(xs[1])
|
330 |
+
output = self.skip_add.add(output, res)
|
331 |
+
# output += res
|
332 |
+
|
333 |
+
output = self.resConfUnit2(output)
|
334 |
+
|
335 |
+
output = nn.functional.interpolate(
|
336 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
337 |
+
)
|
338 |
+
|
339 |
+
output = self.out_conv(output)
|
340 |
+
|
341 |
+
return output
|
342 |
+
|
extensions/sd-webui-controlnet/annotator/midas/midas/dpt_depth.py
ADDED
@@ -0,0 +1,109 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .base_model import BaseModel
|
6 |
+
from .blocks import (
|
7 |
+
FeatureFusionBlock,
|
8 |
+
FeatureFusionBlock_custom,
|
9 |
+
Interpolate,
|
10 |
+
_make_encoder,
|
11 |
+
forward_vit,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def _make_fusion_block(features, use_bn):
|
16 |
+
return FeatureFusionBlock_custom(
|
17 |
+
features,
|
18 |
+
nn.ReLU(False),
|
19 |
+
deconv=False,
|
20 |
+
bn=use_bn,
|
21 |
+
expand=False,
|
22 |
+
align_corners=True,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class DPT(BaseModel):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
head,
|
30 |
+
features=256,
|
31 |
+
backbone="vitb_rn50_384",
|
32 |
+
readout="project",
|
33 |
+
channels_last=False,
|
34 |
+
use_bn=False,
|
35 |
+
):
|
36 |
+
|
37 |
+
super(DPT, self).__init__()
|
38 |
+
|
39 |
+
self.channels_last = channels_last
|
40 |
+
|
41 |
+
hooks = {
|
42 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
43 |
+
"vitb16_384": [2, 5, 8, 11],
|
44 |
+
"vitl16_384": [5, 11, 17, 23],
|
45 |
+
}
|
46 |
+
|
47 |
+
# Instantiate backbone and reassemble blocks
|
48 |
+
self.pretrained, self.scratch = _make_encoder(
|
49 |
+
backbone,
|
50 |
+
features,
|
51 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
52 |
+
groups=1,
|
53 |
+
expand=False,
|
54 |
+
exportable=False,
|
55 |
+
hooks=hooks[backbone],
|
56 |
+
use_readout=readout,
|
57 |
+
)
|
58 |
+
|
59 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
60 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
61 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
62 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
63 |
+
|
64 |
+
self.scratch.output_conv = head
|
65 |
+
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
if self.channels_last == True:
|
69 |
+
x.contiguous(memory_format=torch.channels_last)
|
70 |
+
|
71 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
72 |
+
|
73 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
74 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
75 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
76 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
77 |
+
|
78 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
79 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
80 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
81 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
82 |
+
|
83 |
+
out = self.scratch.output_conv(path_1)
|
84 |
+
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
class DPTDepthModel(DPT):
|
89 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
90 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
91 |
+
|
92 |
+
head = nn.Sequential(
|
93 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
94 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
95 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
96 |
+
nn.ReLU(True),
|
97 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
98 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
99 |
+
nn.Identity(),
|
100 |
+
)
|
101 |
+
|
102 |
+
super().__init__(head, **kwargs)
|
103 |
+
|
104 |
+
if path is not None:
|
105 |
+
self.load(path)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return super().forward(x).squeeze(dim=1)
|
109 |
+
|
extensions/sd-webui-controlnet/annotator/midas/midas/midas_net.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
17 |
+
"""Init.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str, optional): Path to saved model. Defaults to None.
|
21 |
+
features (int, optional): Number of features. Defaults to 256.
|
22 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
23 |
+
"""
|
24 |
+
print("Loading weights: ", path)
|
25 |
+
|
26 |
+
super(MidasNet, self).__init__()
|
27 |
+
|
28 |
+
use_pretrained = False if path is None else True
|
29 |
+
|
30 |
+
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
31 |
+
|
32 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
33 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
34 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
35 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
36 |
+
|
37 |
+
self.scratch.output_conv = nn.Sequential(
|
38 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
39 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
40 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
41 |
+
nn.ReLU(True),
|
42 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
43 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
44 |
+
)
|
45 |
+
|
46 |
+
if path:
|
47 |
+
self.load(path)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
"""Forward pass.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
x (tensor): input data (image)
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tensor: depth
|
57 |
+
"""
|
58 |
+
|
59 |
+
layer_1 = self.pretrained.layer1(x)
|
60 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
61 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
62 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
63 |
+
|
64 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
65 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
66 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
67 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
68 |
+
|
69 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
70 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
71 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
72 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
73 |
+
|
74 |
+
out = self.scratch.output_conv(path_1)
|
75 |
+
|
76 |
+
return torch.squeeze(out, dim=1)
|
extensions/sd-webui-controlnet/annotator/midas/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,128 @@
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet_small(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
17 |
+
blocks={'expand': True}):
|
18 |
+
"""Init.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
path (str, optional): Path to saved model. Defaults to None.
|
22 |
+
features (int, optional): Number of features. Defaults to 256.
|
23 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
24 |
+
"""
|
25 |
+
print("Loading weights: ", path)
|
26 |
+
|
27 |
+
super(MidasNet_small, self).__init__()
|
28 |
+
|
29 |
+
use_pretrained = False if path else True
|
30 |
+
|
31 |
+
self.channels_last = channels_last
|
32 |
+
self.blocks = blocks
|
33 |
+
self.backbone = backbone
|
34 |
+
|
35 |
+
self.groups = 1
|
36 |
+
|
37 |
+
features1=features
|
38 |
+
features2=features
|
39 |
+
features3=features
|
40 |
+
features4=features
|
41 |
+
self.expand = False
|
42 |
+
if "expand" in self.blocks and self.blocks['expand'] == True:
|
43 |
+
self.expand = True
|
44 |
+
features1=features
|
45 |
+
features2=features*2
|
46 |
+
features3=features*4
|
47 |
+
features4=features*8
|
48 |
+
|
49 |
+
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
50 |
+
|
51 |
+
self.scratch.activation = nn.ReLU(False)
|
52 |
+
|
53 |
+
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
54 |
+
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
55 |
+
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
56 |
+
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
57 |
+
|
58 |
+
|
59 |
+
self.scratch.output_conv = nn.Sequential(
|
60 |
+
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
61 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
62 |
+
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
63 |
+
self.scratch.activation,
|
64 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
65 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
66 |
+
nn.Identity(),
|
67 |
+
)
|
68 |
+
|
69 |
+
if path:
|
70 |
+
self.load(path)
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
"""Forward pass.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
x (tensor): input data (image)
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
tensor: depth
|
81 |
+
"""
|
82 |
+
if self.channels_last==True:
|
83 |
+
print("self.channels_last = ", self.channels_last)
|
84 |
+
x.contiguous(memory_format=torch.channels_last)
|
85 |
+
|
86 |
+
|
87 |
+
layer_1 = self.pretrained.layer1(x)
|
88 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
89 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
90 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
91 |
+
|
92 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
93 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
94 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
95 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
96 |
+
|
97 |
+
|
98 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
99 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
100 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
101 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
102 |
+
|
103 |
+
out = self.scratch.output_conv(path_1)
|
104 |
+
|
105 |
+
return torch.squeeze(out, dim=1)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def fuse_model(m):
|
110 |
+
prev_previous_type = nn.Identity()
|
111 |
+
prev_previous_name = ''
|
112 |
+
previous_type = nn.Identity()
|
113 |
+
previous_name = ''
|
114 |
+
for name, module in m.named_modules():
|
115 |
+
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
116 |
+
# print("FUSED ", prev_previous_name, previous_name, name)
|
117 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
118 |
+
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
119 |
+
# print("FUSED ", prev_previous_name, previous_name)
|
120 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
121 |
+
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
122 |
+
# print("FUSED ", previous_name, name)
|
123 |
+
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
124 |
+
|
125 |
+
prev_previous_type = previous_type
|
126 |
+
prev_previous_name = previous_name
|
127 |
+
previous_type = type(module)
|
128 |
+
previous_name = name
|
extensions/sd-webui-controlnet/annotator/midas/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
7 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
sample (dict): sample
|
11 |
+
size (tuple): image size
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
tuple: new size
|
15 |
+
"""
|
16 |
+
shape = list(sample["disparity"].shape)
|
17 |
+
|
18 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
19 |
+
return sample
|
20 |
+
|
21 |
+
scale = [0, 0]
|
22 |
+
scale[0] = size[0] / shape[0]
|
23 |
+
scale[1] = size[1] / shape[1]
|
24 |
+
|
25 |
+
scale = max(scale)
|
26 |
+
|
27 |
+
shape[0] = math.ceil(scale * shape[0])
|
28 |
+
shape[1] = math.ceil(scale * shape[1])
|
29 |
+
|
30 |
+
# resize
|
31 |
+
sample["image"] = cv2.resize(
|
32 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
33 |
+
)
|
34 |
+
|
35 |
+
sample["disparity"] = cv2.resize(
|
36 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
37 |
+
)
|
38 |
+
sample["mask"] = cv2.resize(
|
39 |
+
sample["mask"].astype(np.float32),
|
40 |
+
tuple(shape[::-1]),
|
41 |
+
interpolation=cv2.INTER_NEAREST,
|
42 |
+
)
|
43 |
+
sample["mask"] = sample["mask"].astype(bool)
|
44 |
+
|
45 |
+
return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height).
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height are constrained to be a multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as least as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            sample["mask"] = cv2.resize(
                sample["mask"].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample["mask"] = sample["mask"].astype(bool)

        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std.
    """

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "disparity" in sample:
            disparity = sample["disparity"].astype(np.float32)
            sample["disparity"] = np.ascontiguousarray(disparity)

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        return sample

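
A minimal usage sketch for the three transforms above, assuming torchvision's Compose and ImageNet-style normalization constants (both illustrative, not taken from this commit):

# Hypothetical pipeline: chain Resize, NormalizeImage and PrepareForNet the way
# MiDaS-style preprocessing usually does.
from torchvision.transforms import Compose

transform = Compose(
    [
        Resize(
            384,
            384,
            resize_target=False,          # only the image is resized here
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method="upper_bound",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ]
)
# sample = transform({"image": img})  # img: H x W x 3 float RGB in [0, 1]
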
extensions/sd-webui-controlnet/annotator/midas/midas/vit.py
ADDED
@@ -0,0 +1,491 @@
import torch
import torch.nn as nn
import timm
import types
import math
import torch.nn.functional as F


class Slice(nn.Module):
    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index :]


class AddReadout(nn.Module):
    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index :] + readout.unsqueeze(1)


class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)

        return self.project(features)


class Transpose(nn.Module):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x


def forward_vit(pretrained, x):
    b, c, h, w = x.shape

    glob = pretrained.model.forward_flex(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4


def _resize_pos_embed(self, posemb, gs_h, gs_w):
    posemb_tok, posemb_grid = (
        posemb[:, : self.start_index],
        posemb[0, self.start_index :],
    )

    gs_old = int(math.sqrt(len(posemb_grid)))

    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb


def forward_flex(self, x):
    b, c, h, w = x.shape

    pos_embed = self._resize_pos_embed(
        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
    )

    B = x.shape[0]

    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    if getattr(self, "dist_token", None) is not None:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x


activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output

    return hook


def get_readout_oper(vit_features, features, use_readout, start_index=1):
    if use_readout == "ignore":
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == "add":
        readout_oper = [AddReadout(start_index)] * len(features)
    elif use_readout == "project":
        readout_oper = [
            ProjectReadout(vit_features, start_index) for out_feat in features
        ]
    else:
        assert (
            False
        ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"

    return readout_oper


def _make_vit_b16_backbone(
    model,
    features=[96, 192, 384, 768],
    size=[384, 384],
    hooks=[2, 5, 8, 11],
    vit_features=768,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )


def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model(
        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
    )

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        start_index=2,
    )


def _make_vit_b_rn50_backbone(
    model,
    features=[256, 512, 768, 768],
    size=[384, 384],
    hooks=[0, 1, 8, 11],
    vit_features=768,
    use_vit_only=False,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model

    if use_vit_only == True:
        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    else:
        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
            get_activation("1")
        )
        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
            get_activation("2")
        )

    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    if use_vit_only == True:
        pretrained.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        pretrained.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )
    else:
        pretrained.act_postprocess1 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )
        pretrained.act_postprocess2 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitb_rn50_384(
    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    hooks = [0, 1, 8, 11] if hooks == None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )

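
A hedged sketch of how these builders are typically driven, assuming a timm release contemporary with this extension that still registers "vit_base_resnet50_384"; the dummy input and readout mode are illustrative only:

# Hypothetical usage: build the hybrid ViT backbone and pull out the four
# intermediate feature maps that a DPT-style decoder consumes.
backbone = _make_pretrained_vitb_rn50_384(
    pretrained=False,          # set True to download the timm weights
    use_readout="project",
    use_vit_only=False,
)

dummy = torch.zeros(1, 3, 384, 384)   # B x C x H x W, H and W multiples of the 16-pixel patch
with torch.no_grad():
    layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, dummy)
# layer_1 .. layer_4 are progressively coarser feature maps for the fusion decoder.
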
extensions/sd-webui-controlnet/annotator/midas/utils.py
ADDED
@@ -0,0 +1,189 @@
"""Utils for monoDepth."""
import sys
import re
import numpy as np
import cv2
import torch


def read_pfm(path):
    """Read pfm file.

    Args:
        path (str): path to file

    Returns:
        tuple: (data, scale)
    """
    with open(path, "rb") as file:

        color = None
        width = None
        height = None
        scale = None
        endian = None

        header = file.readline().rstrip()
        if header.decode("ascii") == "PF":
            color = True
        elif header.decode("ascii") == "Pf":
            color = False
        else:
            raise Exception("Not a PFM file: " + path)

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception("Malformed PFM header.")

        scale = float(file.readline().decode("ascii").rstrip())
        if scale < 0:
            # little-endian
            endian = "<"
            scale = -scale
        else:
            # big-endian
            endian = ">"

        data = np.fromfile(file, endian + "f")
        shape = (height, width, 3) if color else (height, width)

        data = np.reshape(data, shape)
        data = np.flipud(data)

        return data, scale


def write_pfm(path, image, scale=1):
    """Write pfm file.

    Args:
        path (str): path to file
        image (array): data
        scale (int, optional): Scale. Defaults to 1.
    """

    with open(path, "wb") as file:
        color = None

        if image.dtype.name != "float32":
            raise Exception("Image dtype must be float32.")

        image = np.flipud(image)

        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
            color = True
        elif (
            len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
        ):  # greyscale
            color = False
        else:
            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

        # encode the whole header line, not just the greyscale branch
        file.write(("PF\n" if color else "Pf\n").encode())
        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))

        endian = image.dtype.byteorder

        if endian == "<" or endian == "=" and sys.byteorder == "little":
            scale = -scale

        file.write("%f\n".encode() % scale)

        image.tofile(file)


def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image (0-1)
    """
    img = cv2.imread(path)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    return img


def resize_image(img):
    """Resize image and make it fit for network.

    Args:
        img (array): image

    Returns:
        tensor: data ready for network
    """
    height_orig = img.shape[0]
    width_orig = img.shape[1]

    if width_orig > height_orig:
        scale = width_orig / 384
    else:
        scale = height_orig / 384

    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)

    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    img_resized = (
        torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
    )
    img_resized = img_resized.unsqueeze(0)

    return img_resized


def resize_depth(depth, width, height):
    """Resize depth map and bring to CPU (numpy).

    Args:
        depth (tensor): depth
        width (int): image width
        height (int): image height

    Returns:
        array: processed depth
    """
    depth = torch.squeeze(depth[0, :, :, :]).to("cpu")

    depth_resized = cv2.resize(
        depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
    )

    return depth_resized


def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    Args:
        path (str): filepath without extension
        depth (array): depth
    """
    write_pfm(path + ".pfm", depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    max_val = (2 ** (8 * bits)) - 1

    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        out = np.zeros(depth.shape, dtype=depth.dtype)

    if bits == 1:
        cv2.imwrite(path + ".png", out.astype("uint8"))
    elif bits == 2:
        cv2.imwrite(path + ".png", out.astype("uint16"))

    return

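
A hedged round-trip sketch using the helpers above; the synthetic image and the fake depth tensor are stand-ins, since actual depth inference happens outside this file:

# Hypothetical usage: push a synthetic RGB image through the resize helper,
# pretend a network produced a depth map, and write it to PFM + 16-bit PNG.
img = np.random.rand(480, 640, 3).astype(np.float32)           # stand-in for read_image("input.jpg")
batch = resize_image(img)                                       # 1 x 3 x H' x W' tensor ready for the network
fake_depth = torch.rand(1, 1, batch.shape[2], batch.shape[3])   # stand-in for the model output
depth = resize_depth(fake_depth, img.shape[1], img.shape[0])    # back to the original resolution as numpy
write_depth("depth_out", depth, bits=2)                         # writes depth_out.pfm and depth_out.png
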
extensions/sd-webui-controlnet/annotator/mlsd/__init__.py
ADDED
@@ -0,0 +1,49 @@
import cv2
import numpy as np
import torch
import os

from einops import rearrange
from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
from .utils import pred_lines
from modules import devices
from modules.paths import models_path

mlsdmodel = None
remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth"
old_modeldir = os.path.dirname(os.path.realpath(__file__))
modeldir = os.path.join(models_path, "mlsd")

def unload_mlsd_model():
    global mlsdmodel
    if mlsdmodel is not None:
        mlsdmodel = mlsdmodel.cpu()

def apply_mlsd(input_image, thr_v, thr_d):
    global modelpath, mlsdmodel
    if mlsdmodel is None:
        modelpath = os.path.join(modeldir, "mlsd_large_512_fp32.pth")
        old_modelpath = os.path.join(old_modeldir, "mlsd_large_512_fp32.pth")
        if os.path.exists(old_modelpath):
            modelpath = old_modelpath
        elif not os.path.exists(modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=modeldir)
        mlsdmodel = MobileV2_MLSD_Large()
        mlsdmodel.load_state_dict(torch.load(modelpath), strict=True)
        mlsdmodel = mlsdmodel.to(devices.get_device_for("controlnet")).eval()

    model = mlsdmodel
    assert input_image.ndim == 3
    img = input_image
    img_output = np.zeros_like(img)
    try:
        with torch.no_grad():
            lines = pred_lines(img, model, [img.shape[0], img.shape[1]], thr_v, thr_d)
            for line in lines:
                x_start, y_start, x_end, y_end = [int(val) for val in line]
                cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
    except Exception as e:
        pass
    return img_output[:, :, 0]
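
A hedged sketch of calling the annotator, assuming it runs inside the webui process (so `modules.devices` and the model directory resolve); the thresholds and the synthetic input are illustrative values only:

# Hypothetical usage: run MLSD on an RGB uint8 image and get a single-channel line map.
image = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)  # stand-in for a real room photo
line_map = apply_mlsd(image, thr_v=0.1, thr_d=0.1)            # H x W uint8, detected segments drawn in white
unload_mlsd_model()                                           # move the model back to CPU when finished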