init
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +160 -0
- README.md +1 -1
- annotator/canny/__init__.py +6 -0
- annotator/ckpts/dpt_hybrid-midas-501f0c75.pt +3 -0
- annotator/midas/__init__.py +38 -0
- annotator/midas/api.py +169 -0
- annotator/midas/midas/__init__.py +0 -0
- annotator/midas/midas/base_model.py +16 -0
- annotator/midas/midas/blocks.py +342 -0
- annotator/midas/midas/dpt_depth.py +109 -0
- annotator/midas/midas/midas_net.py +76 -0
- annotator/midas/midas/midas_net_custom.py +128 -0
- annotator/midas/midas/transforms.py +234 -0
- annotator/midas/midas/vit.py +491 -0
- annotator/midas/utils.py +189 -0
- annotator/util.py +38 -0
- app.py +447 -0
- ckpt/cldm_v15.yaml +79 -0
- ckpt/control_sd15_canny.pth +3 -0
- ckpt/control_sd15_depth.pth +3 -0
- ckpt/dpt_hybrid-midas-501f0c75.pt +3 -0
- cldm/cldm.py +429 -0
- cldm/hack.py +111 -0
- cldm/logger.py +76 -0
- cldm/model.py +28 -0
- data/bear/bear.mp4 +0 -0
- data/bear/bear/00000.jpg +0 -0
- data/bear/bear/00001.jpg +0 -0
- data/bear/bear/00002.jpg +0 -0
- data/bear/bear/00003.jpg +0 -0
- data/bear/bear/00004.jpg +0 -0
- data/bear/bear/00005.jpg +0 -0
- data/bear/bear/00006.jpg +0 -0
- data/bear/bear/00007.jpg +0 -0
- data/bear/bear/00008.jpg +0 -0
- data/bear/bear/00009.jpg +0 -0
- data/bear/bear/00010.jpg +0 -0
- data/bear/bear/00011.jpg +0 -0
- data/bear/bear/00012.jpg +0 -0
- data/bear/bear/00013.jpg +0 -0
- data/bear/bear/00014.jpg +0 -0
- data/bear/bear/00015.jpg +0 -0
- data/bear/bear/00016.jpg +0 -0
- data/bear/bear/00017.jpg +0 -0
- data/bear/bear/00018.jpg +0 -0
- data/bear/bear/00019.jpg +0 -0
- data/bear/bear/00020.jpg +0 -0
- data/bear/bear/00021.jpg +0 -0
- data/bear/bear/00022.jpg +0 -0
- data/bear/bear/00023.jpg +0 -0
.gitignore
ADDED
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
README.md
CHANGED
@@ -6,7 +6,7 @@ colorTo: green
 sdk: gradio
 sdk_version: 3.41.2
 app_file: app.py
-pinned:
+pinned: true
 license: apache-2.0
 ---
annotator/canny/__init__.py
ADDED
@@ -0,0 +1,6 @@
+import cv2
+
+
+class CannyDetector:
+    def __call__(self, img, low_threshold, high_threshold):
+        return cv2.Canny(img, low_threshold, high_threshold)
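For reference, a minimal usage sketch of the detector above (not part of the commit; it assumes opencv-python and numpy are installed, and test.png is a placeholder path):

import cv2

from annotator.canny import CannyDetector

apply_canny = CannyDetector()

# cv2.Canny takes an 8-bit image plus the two hysteresis thresholds.
img = cv2.imread("test.png")           # placeholder path; BGR uint8, shape (H, W, 3)
edge_map = apply_canny(img, 100, 200)  # uint8 edge mask, shape (H, W)
cv2.imwrite("edges.png", edge_map)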
annotator/ckpts/dpt_hybrid-midas-501f0c75.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:501f0c75b3bca7daec6b3682c5054c09b366765aef6fa3a09d03a5cb4b230853
+size 492757791
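This entry is a Git LFS pointer, not the 492 MB checkpoint itself. If the LFS object is missing locally, one way to fetch the same weights is through the Hugging Face Hub (a sketch, assuming the huggingface_hub package; the repo and filename mirror the remote_model_path used in annotator/midas/api.py below):

from huggingface_hub import hf_hub_download

# Downloads the DPT-Hybrid MiDaS checkpoint into the local HF cache and returns its path.
ckpt_path = hf_hub_download(
    repo_id="lllyasviel/ControlNet",
    filename="annotator/ckpts/dpt_hybrid-midas-501f0c75.pt",
)
print(ckpt_path)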
annotator/midas/__init__.py
ADDED
@@ -0,0 +1,38 @@
+import cv2
+import numpy as np
+import torch
+
+from einops import rearrange
+from .api import MiDaSInference
+
+
+class MidasDetector:
+    def __init__(self):
+        self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
+
+    def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
+        assert input_image.ndim == 3
+        image_depth = input_image
+        with torch.no_grad():
+            image_depth = torch.from_numpy(image_depth).float().cuda()
+            image_depth = image_depth / 127.5 - 1.0
+            image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+            depth = self.model(image_depth)[0]
+
+            depth_pt = depth.clone()
+            depth_pt -= torch.min(depth_pt)
+            depth_pt /= torch.max(depth_pt)
+            depth_pt = depth_pt.cpu().numpy()
+            depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+
+            depth_np = depth.cpu().numpy()
+            x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
+            y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
+            z = np.ones_like(x) * a
+            x[depth_pt < bg_th] = 0
+            y[depth_pt < bg_th] = 0
+            normal = np.stack([x, y, z], axis=2)
+            normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
+            normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
+
+        return depth_image, normal_image
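A usage sketch for the detector above (not part of the commit; it needs a CUDA GPU because the module calls .cuda(), the dpt_hybrid checkpoint, and an RGB uint8 image whose sides are multiples of 32; input.png is a placeholder path):

import cv2
import numpy as np

from annotator.midas import MidasDetector

midas = MidasDetector()                     # builds DPT-Hybrid and moves it to the GPU

img = cv2.imread("input.png")[:, :, ::-1]   # placeholder path; BGR -> RGB, HWC uint8
img = np.ascontiguousarray(img)

# Returns an 8-bit relative depth map and an 8-bit pseudo normal map derived from it.
depth_image, normal_image = midas(img)
cv2.imwrite("depth.png", depth_image)
cv2.imwrite("normal.png", normal_image[:, :, ::-1])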
annotator/midas/api.py
ADDED
@@ -0,0 +1,169 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import os
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from .midas.dpt_depth import DPTDepthModel
+from .midas.midas_net import MidasNet
+from .midas.midas_net_custom import MidasNet_small
+from .midas.transforms import Resize, NormalizeImage, PrepareForNet
+from annotator.util import annotator_ckpts_path
+
+
+ISL_PATHS = {
+    "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
+    "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
+    "midas_v21": "",
+    "midas_v21_small": "",
+}
+
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def load_midas_transform(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load transform only
+    if model_type == "dpt_large":  # DPT-Large
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    elif model_type == "midas_v21_small":
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    else:
+        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return transform
+
+
+def load_model(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load network
+    model_path = ISL_PATHS[model_type]
+    if model_type == "dpt_large":  # DPT-Large
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitl16_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        if not os.path.exists(model_path):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitb_rn50_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        model = MidasNet(model_path, non_negative=True)
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    elif model_type == "midas_v21_small":
+        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+                               non_negative=True, blocks={'expand': True})
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    else:
+        print(f"model_type '{model_type}' not implemented, use: --model_type large")
+        assert False
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+    MODEL_TYPES_TORCH_HUB = [
+        "DPT_Large",
+        "DPT_Hybrid",
+        "MiDaS_small"
+    ]
+    MODEL_TYPES_ISL = [
+        "dpt_large",
+        "dpt_hybrid",
+        "midas_v21",
+        "midas_v21_small",
+    ]
+
+    def __init__(self, model_type):
+        super().__init__()
+        assert (model_type in self.MODEL_TYPES_ISL)
+        model, _ = load_model(model_type)
+        self.model = model
+        self.model.train = disabled_train
+
+    def forward(self, x):
+        with torch.no_grad():
+            prediction = self.model(x)
+        return prediction
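A sketch of how load_midas_transform and MiDaSInference fit together (mirroring the MiDaS run.py convention of a float RGB image in [0, 1]; the random array stands in for a real photo, and constructing MiDaSInference will download the checkpoint via basicsr if it is not already under annotator/ckpts):

import numpy as np
import torch

from annotator.midas.api import MiDaSInference, load_midas_transform

transform = load_midas_transform("dpt_hybrid")

img = np.random.rand(480, 640, 3).astype(np.float32)   # stand-in for a [0, 1] RGB image
sample = transform({"image": img})                      # Resize -> NormalizeImage -> PrepareForNet
x = torch.from_numpy(sample["image"]).unsqueeze(0)      # (1, 3, H', W'), sides multiples of 32

model = MiDaSInference(model_type="dpt_hybrid")
with torch.no_grad():
    depth = model(x)                                    # (1, H', W') relative inverse depth

Note that MidasDetector in annotator/midas/__init__.py skips this transform and instead feeds the raw image scaled to [-1, 1].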
annotator/midas/midas/__init__.py
ADDED
File without changes
annotator/midas/midas/base_model.py
ADDED
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+    def load(self, path):
+        """Load model from file.
+
+        Args:
+            path (str): file path
+        """
+        parameters = torch.load(path, map_location=torch.device('cpu'))
+
+        if "optimizer" in parameters:
+            parameters = parameters["model"]
+
+        self.load_state_dict(parameters)
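A small sketch of the two checkpoint layouts BaseModel.load accepts, using a throwaway subclass (TinyNet and the file names are hypothetical, for illustration only):

import torch
import torch.nn as nn

from annotator.midas.midas.base_model import BaseModel

class TinyNet(BaseModel):              # illustrative subclass only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

net = TinyNet()

# A bare state_dict loads directly...
torch.save(net.state_dict(), "plain.pt")
net.load("plain.pt")

# ...while a training-style checkpoint is unwrapped via its "model" entry first.
torch.save({"model": net.state_dict(), "optimizer": {}}, "wrapped.pt")
net.load("wrapped.pt")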
annotator/midas/midas/blocks.py
ADDED
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+    _make_pretrained_vitb_rn50_384,
+    _make_pretrained_vitl16_384,
+    _make_pretrained_vitb16_384,
+    forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+    if backbone == "vitl16_384":
+        pretrained = _make_pretrained_vitl16_384(
+            use_pretrained, hooks=hooks, use_readout=use_readout
+        )
+        scratch = _make_scratch(
+            [256, 512, 1024, 1024], features, groups=groups, expand=expand
+        )  # ViT-L/16 - 85.0% Top1 (backbone)
+    elif backbone == "vitb_rn50_384":
+        pretrained = _make_pretrained_vitb_rn50_384(
+            use_pretrained,
+            hooks=hooks,
+            use_vit_only=use_vit_only,
+            use_readout=use_readout,
+        )
+        scratch = _make_scratch(
+            [256, 512, 768, 768], features, groups=groups, expand=expand
+        )  # ViT-H/16 - 85.0% Top1 (backbone)
+    elif backbone == "vitb16_384":
+        pretrained = _make_pretrained_vitb16_384(
+            use_pretrained, hooks=hooks, use_readout=use_readout
+        )
+        scratch = _make_scratch(
+            [96, 192, 384, 768], features, groups=groups, expand=expand
+        )  # ViT-B/16 - 84.6% Top1 (backbone)
+    elif backbone == "resnext101_wsl":
+        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # efficientnet_lite3
+    elif backbone == "efficientnet_lite3":
+        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3
+    else:
+        print(f"Backbone '{backbone}' not implemented")
+        assert False
+
+    return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+    scratch = nn.Module()
+
+    out_shape1 = out_shape
+    out_shape2 = out_shape
+    out_shape3 = out_shape
+    out_shape4 = out_shape
+    if expand==True:
+        out_shape1 = out_shape
+        out_shape2 = out_shape*2
+        out_shape3 = out_shape*4
+        out_shape4 = out_shape*8
+
+    scratch.layer1_rn = nn.Conv2d(
+        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer2_rn = nn.Conv2d(
+        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer3_rn = nn.Conv2d(
+        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer4_rn = nn.Conv2d(
+        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+
+    return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+    efficientnet = torch.hub.load(
+        "rwightman/gen-efficientnet-pytorch",
+        "tf_efficientnet_lite3",
+        pretrained=use_pretrained,
+        exportable=exportable
+    )
+    return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+    pretrained = nn.Module()
+
+    pretrained.layer1 = nn.Sequential(
+        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+    )
+    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+    return pretrained
+
+
+def _make_resnet_backbone(resnet):
+    pretrained = nn.Module()
+    pretrained.layer1 = nn.Sequential(
+        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+    )
+
+    pretrained.layer2 = resnet.layer2
+    pretrained.layer3 = resnet.layer3
+    pretrained.layer4 = resnet.layer4
+
+    return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+    return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+    """Interpolation module.
+    """
+
+    def __init__(self, scale_factor, mode, align_corners=False):
+        """Init.
+
+        Args:
+            scale_factor (float): scaling
+            mode (str): interpolation mode
+        """
+        super(Interpolate, self).__init__()
+
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: interpolated data
+        """
+
+        x = self.interp(
+            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+        )
+
+        return x
+
+
+class ResidualConvUnit(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True
+        )
+
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True
+        )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        out = self.relu(x)
+        out = self.conv1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+
+        return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock, self).__init__()
+
+        self.resConfUnit1 = ResidualConvUnit(features)
+        self.resConfUnit2 = ResidualConvUnit(features)
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            output += self.resConfUnit1(xs[1])
+
+        output = self.resConfUnit2(output)
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=True
+        )
+
+        return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features, activation, bn):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.bn = bn
+
+        self.groups=1
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+
+        if self.bn==True:
+            self.bn1 = nn.BatchNorm2d(features)
+            self.bn2 = nn.BatchNorm2d(features)
+
+        self.activation = activation
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+
+        out = self.activation(x)
+        out = self.conv1(out)
+        if self.bn==True:
+            out = self.bn1(out)
+
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn==True:
+            out = self.bn2(out)
+
+        if self.groups > 1:
+            out = self.conv_merge(out)
+
+        return self.skip_add.add(out, x)
+
+        # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock_custom, self).__init__()
+
+        self.deconv = deconv
+        self.align_corners = align_corners
+
+        self.groups=1
+
+        self.expand = expand
+        out_features = features
+        if self.expand==True:
+            out_features = features//2
+
+        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            res = self.resConfUnit1(xs[1])
+            output = self.skip_add.add(output, res)
+            # output += res
+
+        output = self.resConfUnit2(output)
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+        )
+
+        output = self.out_conv(output)
+
+        return output
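A quick shape check that exercises FeatureFusionBlock_custom without any pretrained backbone (a sketch; the tensor sizes are arbitrary):

import torch
import torch.nn as nn

from annotator.midas.midas.blocks import FeatureFusionBlock_custom

block = FeatureFusionBlock_custom(64, nn.ReLU(False), deconv=False, bn=False,
                                  expand=False, align_corners=True)

coarse = torch.randn(1, 64, 12, 12)   # output of the previous (coarser) refinement stage
skip = torch.randn(1, 64, 12, 12)     # lateral feature map at the same scale

out = block(coarse, skip)             # residual units + add, then 2x bilinear upsampling
print(out.shape)                      # torch.Size([1, 64, 24, 24])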
annotator/midas/midas/dpt_depth.py
ADDED
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+    FeatureFusionBlock,
+    FeatureFusionBlock_custom,
+    Interpolate,
+    _make_encoder,
+    forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+    return FeatureFusionBlock_custom(
+        features,
+        nn.ReLU(False),
+        deconv=False,
+        bn=use_bn,
+        expand=False,
+        align_corners=True,
+    )
+
+
+class DPT(BaseModel):
+    def __init__(
+        self,
+        head,
+        features=256,
+        backbone="vitb_rn50_384",
+        readout="project",
+        channels_last=False,
+        use_bn=False,
+    ):
+
+        super(DPT, self).__init__()
+
+        self.channels_last = channels_last
+
+        hooks = {
+            "vitb_rn50_384": [0, 1, 8, 11],
+            "vitb16_384": [2, 5, 8, 11],
+            "vitl16_384": [5, 11, 17, 23],
+        }
+
+        # Instantiate backbone and reassemble blocks
+        self.pretrained, self.scratch = _make_encoder(
+            backbone,
+            features,
+            False, # Set to true of you want to train from scratch, uses ImageNet weights
+            groups=1,
+            expand=False,
+            exportable=False,
+            hooks=hooks[backbone],
+            use_readout=readout,
+        )
+
+        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+        self.scratch.output_conv = head
+
+
+    def forward(self, x):
+        if self.channels_last == True:
+            x.contiguous(memory_format=torch.channels_last)
+
+        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return out
+
+
+class DPTDepthModel(DPT):
+    def __init__(self, path=None, non_negative=True, **kwargs):
+        features = kwargs["features"] if "features" in kwargs else 256
+
+        head = nn.Sequential(
+            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+
+        super().__init__(head, **kwargs)
+
+        if path is not None:
+            self.load(path)
+
+    def forward(self, x):
+        return super().forward(x).squeeze(dim=1)
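An end-to-end shape check of the DPT decoder with randomly initialized weights (a sketch; it needs a timm release that still provides the vit_base_resnet50_384 model name, and path=None skips loading the MiDaS checkpoint):

import torch

from annotator.midas.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, backbone="vitb_rn50_384", non_negative=True).eval()

x = torch.randn(1, 3, 384, 384)   # H and W should be multiples of the 16-pixel patch grid
with torch.no_grad():
    depth = model(x)
print(depth.shape)                # torch.Size([1, 384, 384])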
annotator/midas/midas/midas_net.py
ADDED
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=256, non_negative=True):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+            backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet, self).__init__()
+
+        use_pretrained = False if path is None else True
+
+        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+        self.scratch.refinenet4 = FeatureFusionBlock(features)
+        self.scratch.refinenet3 = FeatureFusionBlock(features)
+        self.scratch.refinenet2 = FeatureFusionBlock(features)
+        self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+        )
+
+        if path:
+            self.load(path)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
annotator/midas/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+                 blocks={'expand': True}):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+            backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet_small, self).__init__()
+
+        use_pretrained = False if path else True
+
+        self.channels_last = channels_last
+        self.blocks = blocks
+        self.backbone = backbone
+
+        self.groups = 1
+
+        features1=features
+        features2=features
+        features3=features
+        features4=features
+        self.expand = False
+        if "expand" in self.blocks and self.blocks['expand'] == True:
+            self.expand = True
+            features1=features
+            features2=features*2
+            features3=features*4
+            features4=features*8
+
+        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+        self.scratch.activation = nn.ReLU(False)
+
+        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+            self.scratch.activation,
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+
+        if path:
+            self.load(path)
+
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+        if self.channels_last==True:
+            print("self.channels_last = ", self.channels_last)
+            x.contiguous(memory_format=torch.channels_last)
+
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
+    prev_previous_type = nn.Identity()
+    prev_previous_name = ''
+    previous_type = nn.Identity()
+    previous_name = ''
+    for name, module in m.named_modules():
+        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+            # print("FUSED ", prev_previous_name, previous_name, name)
+            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+            # print("FUSED ", prev_previous_name, previous_name)
+            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+        #    print("FUSED ", previous_name, name)
+        #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+        prev_previous_type = previous_type
+        prev_previous_name = previous_name
+        previous_type = type(module)
+        previous_name = name
annotator/midas/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+    Args:
+        sample (dict): sample
+        size (tuple): image size
+
+    Returns:
+        tuple: new size
+    """
+    shape = list(sample["disparity"].shape)
+
+    if shape[0] >= size[0] and shape[1] >= size[1]:
+        return sample
+
+    scale = [0, 0]
+    scale[0] = size[0] / shape[0]
+    scale[1] = size[1] / shape[1]
+
+    scale = max(scale)
+
+    shape[0] = math.ceil(scale * shape[0])
+    shape[1] = math.ceil(scale * shape[1])
+
+    # resize
+    sample["image"] = cv2.resize(
+        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+    )
+
+    sample["disparity"] = cv2.resize(
+        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+    )
+    sample["mask"] = cv2.resize(
+        sample["mask"].astype(np.float32),
+        tuple(shape[::-1]),
+        interpolation=cv2.INTER_NEAREST,
+    )
+    sample["mask"] = sample["mask"].astype(bool)
+
+    return tuple(shape)
+
+
+class Resize(object):
+    """Resize sample to given size (width, height).
+    """
+
+    def __init__(
+        self,
+        width,
+        height,
+        resize_target=True,
+        keep_aspect_ratio=False,
+        ensure_multiple_of=1,
+        resize_method="lower_bound",
+        image_interpolation_method=cv2.INTER_AREA,
+    ):
+        """Init.
+
+        Args:
+            width (int): desired output width
+            height (int): desired output height
+            resize_target (bool, optional):
+                True: Resize the full sample (image, mask, target).
+                False: Resize image only.
+                Defaults to True.
+            keep_aspect_ratio (bool, optional):
+                True: Keep the aspect ratio of the input sample.
+                Output sample might not have the given width and height, and
+                resize behaviour depends on the parameter 'resize_method'.
+                Defaults to False.
+            ensure_multiple_of (int, optional):
+                Output width and height is constrained to be multiple of this parameter.
+                Defaults to 1.
+            resize_method (str, optional):
+                "lower_bound": Output will be at least as large as the given size.
+                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+                Defaults to "lower_bound".
+        """
+        self.__width = width
+        self.__height = height
+
+        self.__resize_target = resize_target
+        self.__keep_aspect_ratio = keep_aspect_ratio
+        self.__multiple_of = ensure_multiple_of
+        self.__resize_method = resize_method
+        self.__image_interpolation_method = image_interpolation_method
+
+    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if max_val is not None and y > max_val:
+            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if y < min_val:
+            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        return y
+
+    def get_size(self, width, height):
+        # determine new height and width
+        scale_height = self.__height / height
+        scale_width = self.__width / width
+
+        if self.__keep_aspect_ratio:
+            if self.__resize_method == "lower_bound":
+                # scale such that output size is lower bound
+                if scale_width > scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "upper_bound":
+                # scale such that output size is upper bound
+                if scale_width < scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "minimal":
+                # scale as least as possbile
+                if abs(1 - scale_width) < abs(1 - scale_height):
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            else:
+                raise ValueError(
+                    f"resize_method {self.__resize_method} not implemented"
+                )
+
+        if self.__resize_method == "lower_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, min_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, min_val=self.__width
+            )
+        elif self.__resize_method == "upper_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, max_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, max_val=self.__width
+            )
+        elif self.__resize_method == "minimal":
+            new_height = self.constrain_to_multiple_of(scale_height * height)
+            new_width = self.constrain_to_multiple_of(scale_width * width)
+        else:
+            raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+        return (new_width, new_height)
+
+    def __call__(self, sample):
+        width, height = self.get_size(
+            sample["image"].shape[1], sample["image"].shape[0]
+        )
+
+        # resize sample
+        sample["image"] = cv2.resize(
+            sample["image"],
+            (width, height),
+            interpolation=self.__image_interpolation_method,
+        )
+
+        if self.__resize_target:
+            if "disparity" in sample:
+                sample["disparity"] = cv2.resize(
+                    sample["disparity"],
+                    (width, height),
+                    interpolation=cv2.INTER_NEAREST,
+                )
+
+            if "depth" in sample:
+                sample["depth"] = cv2.resize(
+                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+                )
+
+            sample["mask"] = cv2.resize(
+                sample["mask"].astype(np.float32),
+                (width, height),
+                interpolation=cv2.INTER_NEAREST,
+            )
+            sample["mask"] = sample["mask"].astype(bool)
+
+        return sample
+
+
+class NormalizeImage(object):
+    """Normlize image by given mean and std.
+    """
+
+    def __init__(self, mean, std):
+        self.__mean = mean
+        self.__std = std
+
+    def __call__(self, sample):
+        sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+        return sample
+
+
+class PrepareForNet(object):
+    """Prepare sample for usage as network input.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, sample):
+        image = np.transpose(sample["image"], (2, 0, 1))
+        sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+        if "mask" in sample:
+            sample["mask"] = sample["mask"].astype(np.float32)
+            sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+        if "disparity" in sample:
+            disparity = sample["disparity"].astype(np.float32)
+            sample["disparity"] = np.ascontiguousarray(disparity)
+
+        if "depth" in sample:
+            depth = sample["depth"].astype(np.float32)
+            sample["depth"] = np.ascontiguousarray(depth)
+
+        return sample
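A sketch of the preprocessing pipeline these classes form, matching the Compose built in annotator/midas/api.py (the random array stands in for a float RGB image in [0, 1]):

import cv2
import numpy as np
from torchvision.transforms import Compose

from annotator.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(384, 384, resize_target=None, keep_aspect_ratio=True, ensure_multiple_of=32,
           resize_method="minimal", image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    PrepareForNet(),
])

img = np.random.rand(481, 641, 3).astype(np.float32)   # arbitrary, non-multiple-of-32 size
out = transform({"image": img})["image"]
print(out.shape, out.dtype)   # (3, H', W') float32 with H' and W' multiples of 32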
annotator/midas/midas/vit.py
ADDED
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+    def __init__(self, start_index=1):
+        super(Slice, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+    def __init__(self, start_index=1):
+        super(AddReadout, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        if self.start_index == 2:
+            readout = (x[:, 0] + x[:, 1]) / 2
+        else:
+            readout = x[:, 0]
+        return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+    def __init__(self, in_features, start_index=1):
+        super(ProjectReadout, self).__init__()
+        self.start_index = start_index
+
+        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+    def forward(self, x):
+        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+        features = torch.cat((x[:, self.start_index :], readout), -1)
+
+        return self.project(features)
+
+
+class Transpose(nn.Module):
+    def __init__(self, dim0, dim1):
+        super(Transpose, self).__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x):
+        x = x.transpose(self.dim0, self.dim1)
+        return x
+
+
+def forward_vit(pretrained, x):
+    b, c, h, w = x.shape
+
+    glob = pretrained.model.forward_flex(x)
+
+    layer_1 = pretrained.activations["1"]
+    layer_2 = pretrained.activations["2"]
+    layer_3 = pretrained.activations["3"]
+    layer_4 = pretrained.activations["4"]
+
+    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+    layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+    unflatten = nn.Sequential(
+        nn.Unflatten(
+            2,
+            torch.Size(
+                [
+                    h // pretrained.model.patch_size[1],
+                    w // pretrained.model.patch_size[0],
+                ]
+            ),
+        )
+    )
+
+    if layer_1.ndim == 3:
+        layer_1 = unflatten(layer_1)
+    if layer_2.ndim == 3:
+        layer_2 = unflatten(layer_2)
+    if layer_3.ndim == 3:
+        layer_3 = unflatten(layer_3)
+    if layer_4.ndim == 3:
+        layer_4 = unflatten(layer_4)
+
+    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+    return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+    posemb_tok, posemb_grid = (
+        posemb[:, : self.start_index],
+        posemb[0, self.start_index :],
+    )
+
+    gs_old = int(math.sqrt(len(posemb_grid)))
+
+    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+    return posemb
+
+
+def forward_flex(self, x):
+    b, c, h, w = x.shape
+
+    pos_embed = self._resize_pos_embed(
+        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+    )
+
+    B = x.shape[0]
+
+    if hasattr(self.patch_embed, "backbone"):
+        x = self.patch_embed.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+
+    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+    if getattr(self, "dist_token", None) is not None:
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        dist_token = self.dist_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, dist_token, x), dim=1)
+    else:
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+
+    x = x + pos_embed
+    x = self.pos_drop(x)
+
+    for blk in self.blocks:
+        x = blk(x)
+
+    x = self.norm(x)
+
+    return x
+
+
+activations = {}
+
+
+def get_activation(name):
+    def hook(model, input, output):
+        activations[name] = output
+
+    return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+    if use_readout == "ignore":
+        readout_oper = [Slice(start_index)] * len(features)
+    elif use_readout == "add":
+        readout_oper = [AddReadout(start_index)] * len(features)
+    elif use_readout == "project":
+        readout_oper = [
+            ProjectReadout(vit_features, start_index) for out_feat in features
+        ]
+    else:
+        assert (
+            False
+        ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+    return readout_oper
+
+
+def _make_vit_b16_backbone(
+    model,
+    features=[96, 192, 384, 768],
+    size=[384, 384],
+    hooks=[2, 5, 8, 11],
+    vit_features=768,
+    use_readout="ignore",
+    start_index=1,
+):
+    pretrained = nn.Module()
+
+    pretrained.model = model
+    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+    pretrained.activations = activations
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    # 32, 48, 136, 384
+    pretrained.act_postprocess1 = nn.Sequential(
+        readout_oper[0],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[0],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[0],
|
218 |
+
out_channels=features[0],
|
219 |
+
kernel_size=4,
|
220 |
+
stride=4,
|
221 |
+
padding=0,
|
222 |
+
bias=True,
|
223 |
+
dilation=1,
|
224 |
+
groups=1,
|
225 |
+
),
|
226 |
+
)
|
227 |
+
|
228 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
229 |
+
readout_oper[1],
|
230 |
+
Transpose(1, 2),
|
231 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
232 |
+
nn.Conv2d(
|
233 |
+
in_channels=vit_features,
|
234 |
+
out_channels=features[1],
|
235 |
+
kernel_size=1,
|
236 |
+
stride=1,
|
237 |
+
padding=0,
|
238 |
+
),
|
239 |
+
nn.ConvTranspose2d(
|
240 |
+
in_channels=features[1],
|
241 |
+
out_channels=features[1],
|
242 |
+
kernel_size=2,
|
243 |
+
stride=2,
|
244 |
+
padding=0,
|
245 |
+
bias=True,
|
246 |
+
dilation=1,
|
247 |
+
groups=1,
|
248 |
+
),
|
249 |
+
)
|
250 |
+
|
251 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
252 |
+
readout_oper[2],
|
253 |
+
Transpose(1, 2),
|
254 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
255 |
+
nn.Conv2d(
|
256 |
+
in_channels=vit_features,
|
257 |
+
out_channels=features[2],
|
258 |
+
kernel_size=1,
|
259 |
+
stride=1,
|
260 |
+
padding=0,
|
261 |
+
),
|
262 |
+
)
|
263 |
+
|
264 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
265 |
+
readout_oper[3],
|
266 |
+
Transpose(1, 2),
|
267 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
268 |
+
nn.Conv2d(
|
269 |
+
in_channels=vit_features,
|
270 |
+
out_channels=features[3],
|
271 |
+
kernel_size=1,
|
272 |
+
stride=1,
|
273 |
+
padding=0,
|
274 |
+
),
|
275 |
+
nn.Conv2d(
|
276 |
+
in_channels=features[3],
|
277 |
+
out_channels=features[3],
|
278 |
+
kernel_size=3,
|
279 |
+
stride=2,
|
280 |
+
padding=1,
|
281 |
+
),
|
282 |
+
)
|
283 |
+
|
284 |
+
pretrained.model.start_index = start_index
|
285 |
+
pretrained.model.patch_size = [16, 16]
|
286 |
+
|
287 |
+
# We inject this function into the VisionTransformer instances so that
|
288 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
289 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
290 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
291 |
+
_resize_pos_embed, pretrained.model
|
292 |
+
)
|
293 |
+
|
294 |
+
return pretrained
|
295 |
+
|
296 |
+
|
297 |
+
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
298 |
+
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
299 |
+
|
300 |
+
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
301 |
+
return _make_vit_b16_backbone(
|
302 |
+
model,
|
303 |
+
features=[256, 512, 1024, 1024],
|
304 |
+
hooks=hooks,
|
305 |
+
vit_features=1024,
|
306 |
+
use_readout=use_readout,
|
307 |
+
)
|
308 |
+
|
309 |
+
|
310 |
+
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
311 |
+
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
312 |
+
|
313 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
314 |
+
return _make_vit_b16_backbone(
|
315 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
316 |
+
)
|
317 |
+
|
318 |
+
|
319 |
+
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
320 |
+
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
321 |
+
|
322 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
323 |
+
return _make_vit_b16_backbone(
|
324 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
325 |
+
)
|
326 |
+
|
327 |
+
|
328 |
+
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
329 |
+
model = timm.create_model(
|
330 |
+
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
331 |
+
)
|
332 |
+
|
333 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
334 |
+
return _make_vit_b16_backbone(
|
335 |
+
model,
|
336 |
+
features=[96, 192, 384, 768],
|
337 |
+
hooks=hooks,
|
338 |
+
use_readout=use_readout,
|
339 |
+
start_index=2,
|
340 |
+
)
|
341 |
+
|
342 |
+
|
343 |
+
def _make_vit_b_rn50_backbone(
|
344 |
+
model,
|
345 |
+
features=[256, 512, 768, 768],
|
346 |
+
size=[384, 384],
|
347 |
+
hooks=[0, 1, 8, 11],
|
348 |
+
vit_features=768,
|
349 |
+
use_vit_only=False,
|
350 |
+
use_readout="ignore",
|
351 |
+
start_index=1,
|
352 |
+
):
|
353 |
+
pretrained = nn.Module()
|
354 |
+
|
355 |
+
pretrained.model = model
|
356 |
+
|
357 |
+
if use_vit_only == True:
|
358 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
359 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
360 |
+
else:
|
361 |
+
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
362 |
+
get_activation("1")
|
363 |
+
)
|
364 |
+
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
365 |
+
get_activation("2")
|
366 |
+
)
|
367 |
+
|
368 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
369 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
370 |
+
|
371 |
+
pretrained.activations = activations
|
372 |
+
|
373 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
374 |
+
|
375 |
+
if use_vit_only == True:
|
376 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
377 |
+
readout_oper[0],
|
378 |
+
Transpose(1, 2),
|
379 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
380 |
+
nn.Conv2d(
|
381 |
+
in_channels=vit_features,
|
382 |
+
out_channels=features[0],
|
383 |
+
kernel_size=1,
|
384 |
+
stride=1,
|
385 |
+
padding=0,
|
386 |
+
),
|
387 |
+
nn.ConvTranspose2d(
|
388 |
+
in_channels=features[0],
|
389 |
+
out_channels=features[0],
|
390 |
+
kernel_size=4,
|
391 |
+
stride=4,
|
392 |
+
padding=0,
|
393 |
+
bias=True,
|
394 |
+
dilation=1,
|
395 |
+
groups=1,
|
396 |
+
),
|
397 |
+
)
|
398 |
+
|
399 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
400 |
+
readout_oper[1],
|
401 |
+
Transpose(1, 2),
|
402 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
403 |
+
nn.Conv2d(
|
404 |
+
in_channels=vit_features,
|
405 |
+
out_channels=features[1],
|
406 |
+
kernel_size=1,
|
407 |
+
stride=1,
|
408 |
+
padding=0,
|
409 |
+
),
|
410 |
+
nn.ConvTranspose2d(
|
411 |
+
in_channels=features[1],
|
412 |
+
out_channels=features[1],
|
413 |
+
kernel_size=2,
|
414 |
+
stride=2,
|
415 |
+
padding=0,
|
416 |
+
bias=True,
|
417 |
+
dilation=1,
|
418 |
+
groups=1,
|
419 |
+
),
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
423 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
424 |
+
)
|
425 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
426 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
427 |
+
)
|
428 |
+
|
429 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
430 |
+
readout_oper[2],
|
431 |
+
Transpose(1, 2),
|
432 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
433 |
+
nn.Conv2d(
|
434 |
+
in_channels=vit_features,
|
435 |
+
out_channels=features[2],
|
436 |
+
kernel_size=1,
|
437 |
+
stride=1,
|
438 |
+
padding=0,
|
439 |
+
),
|
440 |
+
)
|
441 |
+
|
442 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
443 |
+
readout_oper[3],
|
444 |
+
Transpose(1, 2),
|
445 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
446 |
+
nn.Conv2d(
|
447 |
+
in_channels=vit_features,
|
448 |
+
out_channels=features[3],
|
449 |
+
kernel_size=1,
|
450 |
+
stride=1,
|
451 |
+
padding=0,
|
452 |
+
),
|
453 |
+
nn.Conv2d(
|
454 |
+
in_channels=features[3],
|
455 |
+
out_channels=features[3],
|
456 |
+
kernel_size=3,
|
457 |
+
stride=2,
|
458 |
+
padding=1,
|
459 |
+
),
|
460 |
+
)
|
461 |
+
|
462 |
+
pretrained.model.start_index = start_index
|
463 |
+
pretrained.model.patch_size = [16, 16]
|
464 |
+
|
465 |
+
# We inject this function into the VisionTransformer instances so that
|
466 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
467 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
468 |
+
|
469 |
+
# We inject this function into the VisionTransformer instances so that
|
470 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
471 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
472 |
+
_resize_pos_embed, pretrained.model
|
473 |
+
)
|
474 |
+
|
475 |
+
return pretrained
|
476 |
+
|
477 |
+
|
478 |
+
def _make_pretrained_vitb_rn50_384(
|
479 |
+
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
|
480 |
+
):
|
481 |
+
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
482 |
+
|
483 |
+
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
484 |
+
return _make_vit_b_rn50_backbone(
|
485 |
+
model,
|
486 |
+
features=[256, 512, 768, 768],
|
487 |
+
size=[384, 384],
|
488 |
+
hooks=hooks,
|
489 |
+
use_vit_only=use_vit_only,
|
490 |
+
use_readout=use_readout,
|
491 |
+
)
|
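A minimal usage sketch (not part of the commit) of how these backbone builders are typically wired together. Only the helper names come from this file; the input size, the choice of pretrained=False, and the assumption that the installed timm version still exposes "vit_base_resnet50_384" are mine.

# Hypothetical usage sketch, not part of this commit.
import torch
from annotator.midas.midas.vit import _make_pretrained_vitb_rn50_384, forward_vit

# Build the hybrid ResNet50+ViT backbone with forward hooks on four stages.
backbone = _make_pretrained_vitb_rn50_384(pretrained=False, use_readout="project")
dummy = torch.randn(1, 3, 384, 384)  # assumed 384x384 input, matching size=[384, 384]
# forward_vit runs forward_flex and returns the four hooked, post-processed feature maps.
layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, dummy)
print([t.shape for t in (layer_1, layer_2, layer_3, layer_4)])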
annotator/midas/utils.py
ADDED
@@ -0,0 +1,189 @@
"""Utils for monoDepth."""
import sys
import re
import numpy as np
import cv2
import torch


def read_pfm(path):
    """Read pfm file.

    Args:
        path (str): path to file

    Returns:
        tuple: (data, scale)
    """
    with open(path, "rb") as file:

        color = None
        width = None
        height = None
        scale = None
        endian = None

        header = file.readline().rstrip()
        if header.decode("ascii") == "PF":
            color = True
        elif header.decode("ascii") == "Pf":
            color = False
        else:
            raise Exception("Not a PFM file: " + path)

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception("Malformed PFM header.")

        scale = float(file.readline().decode("ascii").rstrip())
        if scale < 0:
            # little-endian
            endian = "<"
            scale = -scale
        else:
            # big-endian
            endian = ">"

        data = np.fromfile(file, endian + "f")
        shape = (height, width, 3) if color else (height, width)

        data = np.reshape(data, shape)
        data = np.flipud(data)

        return data, scale


def write_pfm(path, image, scale=1):
    """Write pfm file.

    Args:
        path (str): path to file
        image (array): data
        scale (int, optional): Scale. Defaults to 1.
    """

    with open(path, "wb") as file:
        color = None

        if image.dtype.name != "float32":
            raise Exception("Image dtype must be float32.")

        image = np.flipud(image)

        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
            color = True
        elif (
            len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
        ):  # greyscale
            color = False
        else:
            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

        file.write("PF\n" if color else "Pf\n".encode())
        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))

        endian = image.dtype.byteorder

        if endian == "<" or endian == "=" and sys.byteorder == "little":
            scale = -scale

        file.write("%f\n".encode() % scale)

        image.tofile(file)


def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image (0-1)
    """
    img = cv2.imread(path)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    return img


def resize_image(img):
    """Resize image and make it fit for network.

    Args:
        img (array): image

    Returns:
        tensor: data ready for network
    """
    height_orig = img.shape[0]
    width_orig = img.shape[1]

    if width_orig > height_orig:
        scale = width_orig / 384
    else:
        scale = height_orig / 384

    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)

    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    img_resized = (
        torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
    )
    img_resized = img_resized.unsqueeze(0)

    return img_resized


def resize_depth(depth, width, height):
    """Resize depth map and bring to CPU (numpy).

    Args:
        depth (tensor): depth
        width (int): image width
        height (int): image height

    Returns:
        array: processed depth
    """
    depth = torch.squeeze(depth[0, :, :, :]).to("cpu")

    depth_resized = cv2.resize(
        depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
    )

    return depth_resized


def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    Args:
        path (str): filepath without extension
        depth (array): depth
    """
    write_pfm(path + ".pfm", depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    max_val = (2**(8*bits))-1

    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        out = np.zeros(depth.shape, dtype=depth.type)

    if bits == 1:
        cv2.imwrite(path + ".png", out.astype("uint8"))
    elif bits == 2:
        cv2.imwrite(path + ".png", out.astype("uint16"))

    return
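A brief usage sketch (not part of the commit) of the I/O helpers above. The file paths are placeholders and the depth map is a random stand-in for a network prediction.

# Hypothetical usage sketch, not part of this commit.
import numpy as np
from annotator.midas.utils import read_image, resize_image, write_depth

img = read_image("input.jpg")           # placeholder path; returns an RGB float array in [0, 1]
batch = resize_image(img)               # 1 x 3 x H x W tensor with sides rounded to multiples of 32
depth = np.random.rand(img.shape[0], img.shape[1]).astype(np.float32)  # stand-in depth prediction
write_depth("output", depth, bits=2)    # writes output.pfm plus a 16-bit output.png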
annotator/util.py
ADDED
@@ -0,0 +1,38 @@
import numpy as np
import cv2
import os


annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def resize_image(input_image, resolution):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img
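A short usage sketch (not part of the commit) for the two helpers above. The input path is a placeholder; the resolution value simply mirrors the default used elsewhere in this commit.

# Hypothetical usage sketch, not part of this commit.
import cv2
from annotator.util import HWC3, resize_image

img = cv2.imread("frame.png")             # placeholder path, BGR uint8
img = HWC3(img)                           # guarantees an H x W x 3 uint8 array
img = resize_image(img, resolution=512)   # shorter side near 512, both dims rounded to multiples of 64
print(img.shape)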
app.py
ADDED
@@ -0,0 +1,447 @@
import os
import shutil
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import torch.optim as optim
import random
import imageio
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import time
import scipy.interpolate
from tqdm import tqdm

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from annotator.midas import MidasDetector
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler
from stablevideo.atlas_data import AtlasData
from stablevideo.atlas_utils import get_grid_indices, get_atlas_bounding_box
from stablevideo.aggnet import AGGNet


class StableVideo:
    def __init__(self, base_cfg, canny_model_cfg, depth_model_cfg, save_memory=False):
        self.base_cfg = base_cfg
        self.canny_model_cfg = canny_model_cfg
        self.depth_model_cfg = depth_model_cfg
        self.img2img_model = None
        self.canny_model = None
        self.depth_model = None
        self.b_atlas = None
        self.f_atlas = None
        self.data = None
        self.crops = None
        self.save_memory = save_memory

    def load_canny_model(
        self,
        base_cfg='ckpt/cldm_v15.yaml',
        canny_model_cfg='ckpt/control_sd15_canny.pth',
    ):
        self.apply_canny = CannyDetector()
        canny_model = create_model(base_cfg).cpu()
        canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='cuda'), strict=False)
        self.canny_ddim_sampler = DDIMSampler(canny_model)
        self.canny_model = canny_model

    def load_depth_model(
        self,
        base_cfg='ckpt/cldm_v15.yaml',
        depth_model_cfg='ckpt/control_sd15_depth.pth',
    ):
        self.apply_midas = MidasDetector()
        depth_model = create_model(base_cfg).cpu()
        depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='cuda'), strict=False)
        self.depth_ddim_sampler = DDIMSampler(depth_model)
        self.depth_model = depth_model

    def load_video(self, video_name):
        self.data = AtlasData(video_name)
        save_name = f"data/{video_name}/{video_name}.mp4"
        if not os.path.exists(save_name):
            imageio.mimwrite(save_name, self.data.original_video.cpu().permute(0, 2, 3, 1))
            print("original video saved.")
        toIMG = transforms.ToPILImage()
        self.f_atlas_origin = toIMG(self.data.cropped_foreground_atlas[0])
        self.b_atlas_origin = toIMG(self.data.background_grid_atlas[0])
        return save_name, self.f_atlas_origin, self.b_atlas_origin

    @torch.no_grad()
    def depth_edit(self, input_image=None,
                   prompt="",
                   a_prompt="best quality, extremely detailed",
                   n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                   image_resolution=512,
                   detect_resolution=384,
                   ddim_steps=20,
                   scale=9,
                   seed=-1,
                   eta=0,
                   num_samples=1):

        size = input_image.size
        model = self.depth_model
        ddim_sampler = self.depth_ddim_sampler
        apply_midas = self.apply_midas

        input_image = np.array(input_image)
        input_image = HWC3(input_image)
        detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
        detected_map = HWC3(detected_map)
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(1)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)

        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)

        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        results = [x_samples[i] for i in range(num_samples)]
        self.b_atlas = Image.fromarray(results[0]).resize(size)
        return self.b_atlas

    @torch.no_grad()
    def edit_background(self, *args, **kwargs):
        self.depth_model = self.depth_model.cuda()

        input_image = self.b_atlas_origin
        self.depth_edit(input_image, *args, **kwargs)

        if self.save_memory:
            self.depth_model = self.depth_model.cpu()
        return self.b_atlas

    @torch.no_grad()
    def advanced_edit_foreground(self,
                                 keyframes="0",
                                 res=2000,
                                 prompt="",
                                 a_prompt="best quality, extremely detailed",
                                 n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                                 image_resolution=512,
                                 low_threshold=100,
                                 high_threshold=200,
                                 ddim_steps=20,
                                 s=0.9,
                                 scale=9,
                                 seed=-1,
                                 eta=0,
                                 if_net=False,
                                 num_samples=1):

        self.canny_model = self.canny_model.cuda()

        keyframes = [int(x) for x in keyframes.split(",")]
        if self.data is None:
            raise ValueError("Please load video first")
        self.crops = self.data.get_global_crops_multi(keyframes, res)
        n_keyframes = len(keyframes)
        indices = get_grid_indices(0, 0, res, res)
        f_atlas = torch.zeros(size=(n_keyframes, res, res, 3,)).to("cuda")

        img_list = [transforms.ToPILImage()(i[0]) for i in self.crops['original_foreground_crops']]
        result_list = []

        # initial setting
        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        self.canny_ddim_sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=eta, verbose=False)
        c_crossattn = [self.canny_model.get_learned_conditioning([prompt + ', ' + a_prompt])]
        uc_crossattn = [self.canny_model.get_learned_conditioning([n_prompt])]

        for i in range(n_keyframes):
            # get current keyframe
            current_img = img_list[i]
            img = resize_image(HWC3(np.array(current_img)), image_resolution)
            H, W, C = img.shape
            shape = (4, H // 8, W // 8)
            # get canny control
            detected_map = self.apply_canny(img, low_threshold, high_threshold)
            detected_map = HWC3(detected_map)
            control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
            control = einops.rearrange(control.unsqueeze(0), 'b h w c -> b c h w').clone()

            cond = {"c_concat": [control], "c_crossattn": c_crossattn}
            un_cond = {"c_concat": [control], "c_crossattn": uc_crossattn}

            # if not the key frame, calculate the mapping from last atlas
            if i == 0:
                latent = torch.randn((1, 4, H // 8, W // 8)).cuda()
                samples, _ = self.canny_ddim_sampler.sample(ddim_steps, num_samples,
                                                            shape, cond, verbose=False, eta=eta,
                                                            unconditional_guidance_scale=scale,
                                                            unconditional_conditioning=un_cond,
                                                            x_T=latent)
            else:
                last_atlas = f_atlas[i-1:i].permute(0, 3, 2, 1)
                mapped_img = F.grid_sample(last_atlas, self.crops['foreground_uvs'][i].reshape(1, -1, 1, 2), mode="bilinear", align_corners=self.data.config["align_corners"]).clamp(min=0.0, max=1.0).reshape((3, current_img.size[1], current_img.size[0]))
                mapped_img = transforms.ToPILImage()(mapped_img)

                mapped_img = mapped_img.resize((W, H))
                mapped_img = np.array(mapped_img).astype(np.float32) / 255.0
                mapped_img = mapped_img[None].transpose(0, 3, 1, 2)
                mapped_img = torch.from_numpy(mapped_img).cuda()
                mapped_img = 2. * mapped_img - 1.
                latent = self.canny_model.get_first_stage_encoding(self.canny_model.encode_first_stage(mapped_img))

                t_enc = int(ddim_steps * s)
                latent = self.canny_ddim_sampler.stochastic_encode(latent, torch.tensor([t_enc]).to("cuda"))
                samples = self.canny_ddim_sampler.decode(x_latent=latent,
                                                         cond=cond,
                                                         t_start=t_enc,
                                                         unconditional_guidance_scale=scale,
                                                         unconditional_conditioning=un_cond)

            x_samples = self.canny_model.decode_first_stage(samples)
            result = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            result = Image.fromarray(result[0])

            result = result.resize(current_img.size)
            result = transforms.ToTensor()(result)
            # times alpha
            alpha = self.crops['foreground_alpha'][i][0].cpu()
            result = alpha * result

            # buffer for training
            result_copy = result.clone().cuda()
            result_copy.requires_grad = True
            result_list.append(result_copy)

            # map to atlas
            uv = (self.crops['foreground_uvs'][i].reshape(-1, 2) * 0.5 + 0.5) * res
            for c in range(3):
                interpolated = scipy.interpolate.griddata(
                    points=uv.cpu().numpy(),
                    values=result[c].reshape(-1, 1).cpu().numpy(),
                    xi=indices.reshape(-1, 2).cpu().numpy(),
                    method="linear",
                ).reshape(res, res)
                interpolated = torch.from_numpy(interpolated).float()
                interpolated[interpolated.isnan()] = 0.0
                f_atlas[i, :, :, c] = interpolated

        f_atlas = f_atlas.permute(0, 3, 2, 1)

        # aggregate via simple median as beginning
        agg_atlas, _ = torch.median(f_atlas, dim=0)

        if if_net == True:
            #####################################
            #           aggregate net           #
            #####################################
            lr, n_epoch = 1e-3, 500
            agg_net = AGGNet().cuda()
            loss_fn = nn.L1Loss()
            optimizer = optim.SGD(agg_net.parameters(), lr=lr, momentum=0.9)
            for _ in range(n_epoch):
                loss = 0.
                for i in range(n_keyframes):
                    e_img = result_list[i]
                    temp_agg_atlas = agg_net(agg_atlas)
                    rec_img = F.grid_sample(temp_agg_atlas[None],
                                            self.crops['foreground_uvs'][i].reshape(1, -1, 1, 2),
                                            mode="bilinear",
                                            align_corners=self.data.config["align_corners"])
                    rec_img = rec_img.clamp(min=0.0, max=1.0).reshape(e_img.shape)
                    loss += loss_fn(rec_img, e_img)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            agg_atlas = agg_net(agg_atlas)
            #####################################

        agg_atlas, _ = get_atlas_bounding_box(self.data.mask_boundaries, agg_atlas, self.data.foreground_all_uvs)
        self.f_atlas = transforms.ToPILImage()(agg_atlas)

        if self.save_memory:
            self.canny_model = self.canny_model.cpu()

        return self.f_atlas

    @torch.no_grad()
    def render(self, f_atlas, b_atlas):
        # foreground
        if f_atlas == None:
            f_atlas = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0).cuda()
        else:
            f_atlas, mask = f_atlas["image"], f_atlas["mask"]
            f_atlas_origin = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0).cuda()
            f_atlas = transforms.ToTensor()(f_atlas).unsqueeze(0).cuda()
            mask = transforms.ToTensor()(mask).unsqueeze(0).cuda()
            if f_atlas.shape != mask.shape:
                print("Warning: truncating mask to atlas shape {}".format(f_atlas.shape))
                mask = mask[:f_atlas.shape[0], :f_atlas.shape[1], :f_atlas.shape[2], :f_atlas.shape[3]]
            f_atlas = f_atlas * (1 - mask) + f_atlas_origin * mask

        f_atlas = torch.nn.functional.pad(
            f_atlas,
            pad=(
                self.data.foreground_atlas_bbox[1],
                self.data.foreground_grid_atlas.shape[-1] - (self.data.foreground_atlas_bbox[1] + self.data.foreground_atlas_bbox[3]),
                self.data.foreground_atlas_bbox[0],
                self.data.foreground_grid_atlas.shape[-2] - (self.data.foreground_atlas_bbox[0] + self.data.foreground_atlas_bbox[2]),
            ),
            mode="replicate",
        )
        foreground_edit = F.grid_sample(
            f_atlas, self.data.scaled_foreground_uvs, mode="bilinear", align_corners=self.data.config["align_corners"]
        ).clamp(min=0.0, max=1.0)

        foreground_edit = foreground_edit.squeeze().t()  # shape (batch, 3)
        foreground_edit = (
            foreground_edit.reshape(self.data.config["maximum_number_of_frames"], self.data.config["resy"], self.data.config["resx"], 3)
            .permute(0, 3, 1, 2)
            .clamp(min=0.0, max=1.0)
        )
        # background
        if b_atlas == None:
            b_atlas = self.b_atlas_origin

        b_atlas = transforms.ToTensor()(b_atlas).unsqueeze(0).cuda()
        background_edit = F.grid_sample(
            b_atlas, self.data.scaled_background_uvs, mode="bilinear", align_corners=self.data.config["align_corners"]
        ).clamp(min=0.0, max=1.0)
        background_edit = background_edit.squeeze().t()  # shape (batch, 3)
        background_edit = (
            background_edit.reshape(self.data.config["maximum_number_of_frames"], self.data.config["resy"], self.data.config["resx"], 3)
            .permute(0, 3, 1, 2)
            .clamp(min=0.0, max=1.0)
        )

        output_video = (
            self.data.all_alpha * foreground_edit
            + (1 - self.data.all_alpha) * background_edit
        )
        id = time.time()
        os.mkdir(f"log/{id}")
        save_name = f"log/{id}/video.mp4"
        imageio.mimwrite(save_name, (255 * output_video.detach().cpu()).to(torch.uint8).permute(0, 2, 3, 1))

        return save_name

if __name__ == '__main__':
    with torch.cuda.amp.autocast():
        stablevideo = StableVideo(base_cfg="ckpt/cldm_v15.yaml",
                                  canny_model_cfg="ckpt/control_sd15_canny.pth",
                                  depth_model_cfg="ckpt/control_sd15_depth.pth",
                                  save_memory=True)
        stablevideo.load_canny_model()
        stablevideo.load_depth_model()

    block = gr.Blocks().queue()
    with block:
        with gr.Row():
            gr.Markdown("## StableVideo")
        with gr.Row():
            with gr.Column():
                original_video = gr.Video(label="Original Video", interactive=False)
                with gr.Row():
                    foreground_atlas = gr.Image(label="Foreground Atlas", type="pil")
                    background_atlas = gr.Image(label="Background Atlas", type="pil")
                gr.Markdown("### Step 1. select one example video and click **Load Video** button and wait for 10 sec.")
                avail_video = [f.name for f in os.scandir("data") if f.is_dir()]
                video_name = gr.Radio(choices=avail_video,
                                      label="Select Example Videos",
                                      value="car-turn")
                load_video_button = gr.Button("Load Video")
                gr.Markdown("### Step 2. write text prompt and advanced options for background and foreground.")
                with gr.Row():
                    f_prompt = gr.Textbox(label="Foreground Prompt", value="a picture of an orange suv")
                    b_prompt = gr.Textbox(label="Background Prompt", value="winter scene, snowy scene, beautiful snow")
                with gr.Row():
                    with gr.Accordion("Advanced Foreground Options", open=False):
                        adv_keyframes = gr.Textbox(label="keyframe", value="20, 40, 60")
                        adv_atlas_resolution = gr.Slider(label="Atlas Resolution", minimum=1000, maximum=3000, value=2000, step=100)
                        adv_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                        adv_low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
                        adv_high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
                        adv_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                        adv_s = gr.Slider(label="Noise Scale", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
                        adv_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=15.0, value=9.0, step=0.1)
                        adv_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                        adv_eta = gr.Number(label="eta (DDIM)", value=0.0)
                        adv_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, no background')
                        adv_n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
                        adv_if_net = gr.gradio.Checkbox(label="if use agg net", value=False)

                    with gr.Accordion("Background Options", open=False):
                        b_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                        b_detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=512, step=1)
                        b_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                        b_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                        b_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                        b_eta = gr.Number(label="eta (DDIM)", value=0.0)
                        b_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                        b_n_prompt = gr.Textbox(label="Negative Prompt",
                                                value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
                gr.Markdown("### Step 3. edit each one and render.")
                with gr.Row():
                    f_advance_run_button = gr.Button("Advanced Edit Foreground (slower, better)")
                    b_run_button = gr.Button("Edit Background")
                    run_button = gr.Button("Render")
            with gr.Column():
                output_video = gr.Video(label="Output Video", interactive=False)
                # output_foreground_atlas = gr.Image(label="Output Foreground Atlas", type="pil", interactive=False)
                output_foreground_atlas = gr.ImageMask(label="Editable Output Foreground Atlas", type="pil", tool="sketch", interactive=True)
                output_background_atlas = gr.Image(label="Output Background Atlas", type="pil", interactive=False)

        # edit param
        f_adv_edit_param = [adv_keyframes,
                            adv_atlas_resolution,
                            f_prompt,
                            adv_a_prompt,
                            adv_n_prompt,
                            adv_image_resolution,
                            adv_low_threshold,
                            adv_high_threshold,
                            adv_ddim_steps,
                            adv_s,
                            adv_scale,
                            adv_seed,
                            adv_eta,
                            adv_if_net]
        b_edit_param = [b_prompt,
                        b_a_prompt,
                        b_n_prompt,
                        b_image_resolution,
                        b_detect_resolution,
                        b_ddim_steps,
                        b_scale,
                        b_seed,
                        b_eta]
        # action
        load_video_button.click(fn=stablevideo.load_video, inputs=video_name, outputs=[original_video, foreground_atlas, background_atlas])
        f_advance_run_button.click(fn=stablevideo.advanced_edit_foreground, inputs=f_adv_edit_param, outputs=[output_foreground_atlas])
        b_run_button.click(fn=stablevideo.edit_background, inputs=b_edit_param, outputs=[output_background_atlas])
        run_button.click(fn=stablevideo.render, inputs=[output_foreground_atlas, output_background_atlas], outputs=[output_video])

    block.launch(share=True)
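A compressed, headless sketch (not part of the commit) of the same flow the Gradio UI above drives. The checkpoint paths and example video name mirror the defaults in app.py; everything else (a CUDA device, an existing log/ directory, the pre-decomposed atlas data for "car-turn") is assumed.

# Hypothetical usage sketch, not part of this commit.
import torch
from app import StableVideo

sv = StableVideo(base_cfg="ckpt/cldm_v15.yaml",
                 canny_model_cfg="ckpt/control_sd15_canny.pth",
                 depth_model_cfg="ckpt/control_sd15_depth.pth",
                 save_memory=True)
sv.load_canny_model()
sv.load_depth_model()
sv.load_video("car-turn")                                   # requires data/car-turn/ with atlas data
sv.advanced_edit_foreground(keyframes="20,40,60", prompt="a picture of an orange suv")
sv.edit_background(prompt="winter scene, snowy scene")
# The foreground argument of render() expects the {"image", "mask"} dict the UI provides,
# so None falls back to the original foreground atlas here; the edited background is passed directly.
video_path = sv.render(None, sv.b_atlas)                    # writes log/<timestamp>/video.mp4
print(video_path)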
ckpt/cldm_v15.yaml
ADDED
@@ -0,0 +1,79 @@
model:
  target: cldm.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        image_size: 32 # unused
        in_channels: 4
        hint_channels: 3
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
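A small sketch (not part of the commit) of how a config like this is turned into a model, mirroring the way app.py in this commit calls the helpers from cldm/model.py; the checkpoint path below is a placeholder.

# Hypothetical usage sketch, not part of this commit.
from cldm.model import create_model, load_state_dict

# Instantiate the ControlLDM described by this YAML, then load ControlNet weights on top of it.
model = create_model("ckpt/cldm_v15.yaml").cpu()
model.load_state_dict(load_state_dict("ckpt/control_sd15_canny.pth", location="cpu"), strict=False)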
ckpt/control_sd15_canny.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4de384b16bc2d7a1fb258ca0cbd941d7dd0a721ae996aff89f905299d6923f45
size 5710753329
ckpt/control_sd15_depth.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:726cd0b472c4b5c0341b01afcb7fdc4a7b4ab7c37fe797fd394c9805cbef60bf
size 5710753329
ckpt/dpt_hybrid-midas-501f0c75.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:501f0c75b3bca7daec6b3682c5054c09b366765aef6fa3a09d03a5cb4b230853
size 492757791
cldm/cldm.py
ADDED
@@ -0,0 +1,429 @@
import einops
import torch
import torch as th
import torch.nn as nn

from ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    zero_module,
    timestep_embedding,
)

from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


class ControlledUnetModel(UNetModel):
    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
        hs = []
        with torch.no_grad():
            t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
            emb = self.time_embed(t_emb)
            h = x.type(self.dtype)
            for module in self.input_blocks:
                h = module(h, emb, context)
                hs.append(h)
            h = self.middle_block(h, emb, context)

        h += control.pop()

        for i, module in enumerate(self.output_blocks):
            if only_mid_control:
                h = torch.cat([h, hs.pop()], dim=1)
            else:
                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
            h = module(h, emb, context)

        h = h.type(x.dtype)
        return self.out(h)


class ControlNet(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, context, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs


class ControlLDM(LatentDiffusion):

    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.control_model = instantiate_from_config(control_stage_config)
        self.control_key = control_key
        self.only_mid_control = only_mid_control

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
        control = batch[self.control_key]
        if bs is not None:
            control = control[:bs]
|
321 |
+
control = control.to(self.device)
|
322 |
+
control = einops.rearrange(control, 'b h w c -> b c h w')
|
323 |
+
control = control.to(memory_format=torch.contiguous_format).float()
|
324 |
+
return x, dict(c_crossattn=[c], c_concat=[control])
|
325 |
+
|
326 |
+
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
|
327 |
+
assert isinstance(cond, dict)
|
328 |
+
diffusion_model = self.model.diffusion_model
|
329 |
+
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
330 |
+
cond_hint = torch.cat(cond['c_concat'], 1)
|
331 |
+
|
332 |
+
control = self.control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt)
|
333 |
+
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
|
334 |
+
|
335 |
+
return eps
|
336 |
+
|
337 |
+
@torch.no_grad()
|
338 |
+
def get_unconditional_conditioning(self, N):
|
339 |
+
return self.get_learned_conditioning([""] * N)
|
340 |
+
|
341 |
+
@torch.no_grad()
|
342 |
+
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
343 |
+
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
344 |
+
plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
|
345 |
+
use_ema_scope=True,
|
346 |
+
**kwargs):
|
347 |
+
use_ddim = ddim_steps is not None
|
348 |
+
|
349 |
+
log = dict()
|
350 |
+
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
351 |
+
c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
|
352 |
+
N = min(z.shape[0], N)
|
353 |
+
n_row = min(z.shape[0], n_row)
|
354 |
+
log["reconstruction"] = self.decode_first_stage(z)
|
355 |
+
log["control"] = c_cat * 2.0 - 1.0
|
356 |
+
log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
|
357 |
+
|
358 |
+
if plot_diffusion_rows:
|
359 |
+
# get diffusion row
|
360 |
+
diffusion_row = list()
|
361 |
+
z_start = z[:n_row]
|
362 |
+
for t in range(self.num_timesteps):
|
363 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
364 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
365 |
+
t = t.to(self.device).long()
|
366 |
+
noise = torch.randn_like(z_start)
|
367 |
+
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
368 |
+
diffusion_row.append(self.decode_first_stage(z_noisy))
|
369 |
+
|
370 |
+
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
371 |
+
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
372 |
+
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
373 |
+
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
374 |
+
log["diffusion_row"] = diffusion_grid
|
375 |
+
|
376 |
+
if sample:
|
377 |
+
# get denoise row
|
378 |
+
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
379 |
+
batch_size=N, ddim=use_ddim,
|
380 |
+
ddim_steps=ddim_steps, eta=ddim_eta)
|
381 |
+
x_samples = self.decode_first_stage(samples)
|
382 |
+
log["samples"] = x_samples
|
383 |
+
if plot_denoise_rows:
|
384 |
+
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
385 |
+
log["denoise_row"] = denoise_grid
|
386 |
+
|
387 |
+
if unconditional_guidance_scale > 1.0:
|
388 |
+
uc_cross = self.get_unconditional_conditioning(N)
|
389 |
+
uc_cat = c_cat # torch.zeros_like(c_cat)
|
390 |
+
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
|
391 |
+
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
392 |
+
batch_size=N, ddim=use_ddim,
|
393 |
+
ddim_steps=ddim_steps, eta=ddim_eta,
|
394 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
395 |
+
unconditional_conditioning=uc_full,
|
396 |
+
)
|
397 |
+
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
398 |
+
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
399 |
+
|
400 |
+
return log
|
401 |
+
|
402 |
+
@torch.no_grad()
|
403 |
+
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
404 |
+
ddim_sampler = DDIMSampler(self)
|
405 |
+
b, c, h, w = cond["c_concat"][0].shape
|
406 |
+
shape = (self.channels, h // 8, w // 8)
|
407 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
408 |
+
return samples, intermediates
|
409 |
+
|
410 |
+
def configure_optimizers(self):
|
411 |
+
lr = self.learning_rate
|
412 |
+
params = list(self.control_model.parameters())
|
413 |
+
if not self.sd_locked:
|
414 |
+
params += list(self.model.diffusion_model.output_blocks.parameters())
|
415 |
+
params += list(self.model.diffusion_model.out.parameters())
|
416 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
417 |
+
return opt
|
418 |
+
|
419 |
+
def low_vram_shift(self, is_diffusing):
|
420 |
+
if is_diffusing:
|
421 |
+
self.model = self.model.cuda()
|
422 |
+
self.control_model = self.control_model.cuda()
|
423 |
+
self.first_stage_model = self.first_stage_model.cpu()
|
424 |
+
self.cond_stage_model = self.cond_stage_model.cpu()
|
425 |
+
else:
|
426 |
+
self.model = self.model.cpu()
|
427 |
+
self.control_model = self.control_model.cpu()
|
428 |
+
self.first_stage_model = self.first_stage_model.cuda()
|
429 |
+
self.cond_stage_model = self.cond_stage_model.cuda()
|
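For orientation, here is a minimal sketch of the conditioning format `ControlLDM.apply_model` expects. The 512x512 resolution, the 4-channel 64x64 latent, and the (77, 768) text-embedding shape are illustrative assumptions for a Stable Diffusion 1.5 setup, not values fixed by this file.

import torch

# apply_model receives cond as a dict of lists and concatenates each list along dim 1:
#   cond["c_crossattn"]: text embeddings, e.g. (B, 77, 768) for SD 1.5 (assumed shape)
#   cond["c_concat"]:    the control hint after get_input's 'b h w c -> b c h w'
#                        rearrange, i.e. (B, 3, H, W) floats in [0, 1]
B = 1
x_noisy = torch.randn(B, 4, 64, 64)            # noisy latent for a 512x512 image (H/8, W/8)
t = torch.randint(0, 1000, (B,))
cond = {
    "c_crossattn": [torch.randn(B, 77, 768)],  # placeholder text conditioning
    "c_concat": [torch.rand(B, 3, 512, 512)],  # placeholder hint (e.g. canny or depth map)
}
# eps = model.apply_model(x_noisy, t, cond)    # model: a loaded ControlLDM instance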
cldm/hack.py
ADDED
@@ -0,0 +1,111 @@
import torch
import einops

import ldm.modules.encoders.modules
import ldm.modules.attention

from transformers import logging
from ldm.modules.attention import default


def disable_verbosity():
    logging.set_verbosity_error()
    print('logging improved.')
    return


def enable_sliced_attention():
    ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
    print('Enabled sliced_attention.')
    return


def hack_everything(clip_skip=0):
    disable_verbosity()
    ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
    ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
    print('Enabled clip hacks.')
    return


# Written by Lvmin
def _hacked_clip_forward(self, text):
    PAD = self.tokenizer.pad_token_id
    EOS = self.tokenizer.eos_token_id
    BOS = self.tokenizer.bos_token_id

    def tokenize(t):
        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]

    def transformer_encode(t):
        if self.clip_skip > 1:
            rt = self.transformer(input_ids=t, output_hidden_states=True)
            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
        else:
            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state

    def split(x):
        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]

    def pad(x, p, i):
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    raw_tokens_list = tokenize(text)
    tokens_list = []

    for raw_tokens in raw_tokens_list:
        raw_tokens_123 = split(raw_tokens)
        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
        tokens_list.append(raw_tokens_123)

    tokens_list = torch.IntTensor(tokens_list).to(self.device)

    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
    y = transformer_encode(feed)
    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)

    return z


# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    limit = k.shape[0]
    att_step = 1
    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))

    q_chunks.reverse()
    k_chunks.reverse()
    v_chunks.reverse()
    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    del k, q, v
    for i in range(0, limit, att_step):
        q_buffer = q_chunks.pop()
        k_buffer = k_chunks.pop()
        v_buffer = v_chunks.pop()
        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale

        del k_buffer, q_buffer
        # attention, what we cannot get enough of, by chunks

        sim_buffer = sim_buffer.softmax(dim=-1)

        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
        del v_buffer
        sim[i:i + att_step, :, :] = sim_buffer

        del sim_buffer
    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(sim)
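A small standalone sketch of the chunking scheme used by `_hacked_clip_forward` above: the raw token stream is cut into three windows of 75 tokens, each wrapped in BOS/EOS and padded to 77, so prompts of up to roughly 225 tokens survive instead of being truncated at 77. The special-token ids below are the usual CLIP values and are shown only for illustration; the real code reads them from the tokenizer.

# Standalone illustration of the 3 x 75-token chunking in _hacked_clip_forward.
BOS, EOS, PAD = 49406, 49407, 49407  # typical CLIP ids (PAD == EOS for SD's tokenizer)

def split(x):
    return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]

def pad(x, p, i):
    return x[:i] if len(x) >= i else x + [p] * (i - len(x))

raw_tokens = list(range(1000, 1090))  # a pretend 90-token prompt
chunks = [pad([BOS] + c + [EOS], PAD, 77) for c in split(raw_tokens)]
assert [len(c) for c in chunks] == [77, 77, 77]
# chunk 0 holds tokens 0..74, chunk 1 holds tokens 75..89 (then padding), chunk 2 is all padding;
# the three encodings are later concatenated along the sequence axis: (b, 3*77, channels).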
cldm/logger.py
ADDED
@@ -0,0 +1,76 @@
import os

import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only


class ImageLogger(Callback):
    def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "image_log", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx  # if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        return check_idx % self.batch_freq == 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled:
            self.log_img(pl_module, batch, batch_idx, split="train")
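A hedged sketch of how this callback would typically be attached to a pytorch_lightning (pre-2.0) training run; the trainer arguments and the dataloader are placeholders, not taken from this repository.

import pytorch_lightning as pl
from cldm.logger import ImageLogger

# Save a grid per logged key ("control", "samples", ...) every 300 training batches.
image_logger = ImageLogger(batch_frequency=300, max_images=4)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[image_logger])
# trainer.fit(model, train_dataloader)  # model: a LightningModule exposing log_images()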
cldm/model.py
ADDED
@@ -0,0 +1,28 @@
import os
import torch

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config


def get_state_dict(d):
    return d.get('state_dict', d)


def load_state_dict(ckpt_path, location='cpu'):
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        import safetensors.torch
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
    state_dict = get_state_dict(state_dict)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict


def create_model(config_path):
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    print(f'Loaded model config from [{config_path}]')
    return model
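For context, a brief usage sketch of these two helpers; the config and checkpoint paths are assumptions about this Space's ckpt/ layout and may need adjusting.

from cldm.model import create_model, load_state_dict

# Build the ControlLDM from its YAML config, then load a ControlNet checkpoint into it.
model = create_model('./ckpt/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./ckpt/control_sd15_canny.pth', location='cpu'))
model = model.cuda()  # move to GPU before sampling (assumes CUDA is available)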
data/bear/bear.mp4
ADDED
Binary file (771 kB). View file
data/bear/bear/00000.jpg
ADDED
data/bear/bear/00001.jpg
ADDED
data/bear/bear/00002.jpg
ADDED
data/bear/bear/00003.jpg
ADDED
data/bear/bear/00004.jpg
ADDED
data/bear/bear/00005.jpg
ADDED
data/bear/bear/00006.jpg
ADDED
data/bear/bear/00007.jpg
ADDED
data/bear/bear/00008.jpg
ADDED
data/bear/bear/00009.jpg
ADDED
data/bear/bear/00010.jpg
ADDED
data/bear/bear/00011.jpg
ADDED
data/bear/bear/00012.jpg
ADDED
data/bear/bear/00013.jpg
ADDED
data/bear/bear/00014.jpg
ADDED
data/bear/bear/00015.jpg
ADDED
data/bear/bear/00016.jpg
ADDED
data/bear/bear/00017.jpg
ADDED
data/bear/bear/00018.jpg
ADDED
data/bear/bear/00019.jpg
ADDED
data/bear/bear/00020.jpg
ADDED
data/bear/bear/00021.jpg
ADDED
data/bear/bear/00022.jpg
ADDED
data/bear/bear/00023.jpg
ADDED