wchai commited on
Commit
4c9c42b
1 Parent(s): bd087dc
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +160 -0
  2. README.md +1 -1
  3. annotator/canny/__init__.py +6 -0
  4. annotator/ckpts/dpt_hybrid-midas-501f0c75.pt +3 -0
  5. annotator/midas/__init__.py +38 -0
  6. annotator/midas/api.py +169 -0
  7. annotator/midas/midas/__init__.py +0 -0
  8. annotator/midas/midas/base_model.py +16 -0
  9. annotator/midas/midas/blocks.py +342 -0
  10. annotator/midas/midas/dpt_depth.py +109 -0
  11. annotator/midas/midas/midas_net.py +76 -0
  12. annotator/midas/midas/midas_net_custom.py +128 -0
  13. annotator/midas/midas/transforms.py +234 -0
  14. annotator/midas/midas/vit.py +491 -0
  15. annotator/midas/utils.py +189 -0
  16. annotator/util.py +38 -0
  17. app.py +447 -0
  18. ckpt/cldm_v15.yaml +79 -0
  19. ckpt/control_sd15_canny.pth +3 -0
  20. ckpt/control_sd15_depth.pth +3 -0
  21. ckpt/dpt_hybrid-midas-501f0c75.pt +3 -0
  22. cldm/cldm.py +429 -0
  23. cldm/hack.py +111 -0
  24. cldm/logger.py +76 -0
  25. cldm/model.py +28 -0
  26. data/bear/bear.mp4 +0 -0
  27. data/bear/bear/00000.jpg +0 -0
  28. data/bear/bear/00001.jpg +0 -0
  29. data/bear/bear/00002.jpg +0 -0
  30. data/bear/bear/00003.jpg +0 -0
  31. data/bear/bear/00004.jpg +0 -0
  32. data/bear/bear/00005.jpg +0 -0
  33. data/bear/bear/00006.jpg +0 -0
  34. data/bear/bear/00007.jpg +0 -0
  35. data/bear/bear/00008.jpg +0 -0
  36. data/bear/bear/00009.jpg +0 -0
  37. data/bear/bear/00010.jpg +0 -0
  38. data/bear/bear/00011.jpg +0 -0
  39. data/bear/bear/00012.jpg +0 -0
  40. data/bear/bear/00013.jpg +0 -0
  41. data/bear/bear/00014.jpg +0 -0
  42. data/bear/bear/00015.jpg +0 -0
  43. data/bear/bear/00016.jpg +0 -0
  44. data/bear/bear/00017.jpg +0 -0
  45. data/bear/bear/00018.jpg +0 -0
  46. data/bear/bear/00019.jpg +0 -0
  47. data/bear/bear/00020.jpg +0 -0
  48. data/bear/bear/00021.jpg +0 -0
  49. data/bear/bear/00022.jpg +0 -0
  50. data/bear/bear/00023.jpg +0 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
README.md CHANGED
@@ -6,7 +6,7 @@ colorTo: green
6
  sdk: gradio
7
  sdk_version: 3.41.2
8
  app_file: app.py
9
- pinned: false
10
  license: apache-2.0
11
  ---
12
 
 
6
  sdk: gradio
7
  sdk_version: 3.41.2
8
  app_file: app.py
9
+ pinned: true
10
  license: apache-2.0
11
  ---
12
 
annotator/canny/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+
4
+ class CannyDetector:
5
+ def __call__(self, img, low_threshold, high_threshold):
6
+ return cv2.Canny(img, low_threshold, high_threshold)
annotator/ckpts/dpt_hybrid-midas-501f0c75.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:501f0c75b3bca7daec6b3682c5054c09b366765aef6fa3a09d03a5cb4b230853
3
+ size 492757791
annotator/midas/__init__.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+
5
+ from einops import rearrange
6
+ from .api import MiDaSInference
7
+
8
+
9
+ class MidasDetector:
10
+ def __init__(self):
11
+ self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
12
+
13
+ def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
14
+ assert input_image.ndim == 3
15
+ image_depth = input_image
16
+ with torch.no_grad():
17
+ image_depth = torch.from_numpy(image_depth).float().cuda()
18
+ image_depth = image_depth / 127.5 - 1.0
19
+ image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
20
+ depth = self.model(image_depth)[0]
21
+
22
+ depth_pt = depth.clone()
23
+ depth_pt -= torch.min(depth_pt)
24
+ depth_pt /= torch.max(depth_pt)
25
+ depth_pt = depth_pt.cpu().numpy()
26
+ depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
27
+
28
+ depth_np = depth.cpu().numpy()
29
+ x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
30
+ y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
31
+ z = np.ones_like(x) * a
32
+ x[depth_pt < bg_th] = 0
33
+ y[depth_pt < bg_th] = 0
34
+ normal = np.stack([x, y, z], axis=2)
35
+ normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
36
+ normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
37
+
38
+ return depth_image, normal_image
annotator/midas/api.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://github.com/isl-org/MiDaS
2
+
3
+ import cv2
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision.transforms import Compose
8
+
9
+ from .midas.dpt_depth import DPTDepthModel
10
+ from .midas.midas_net import MidasNet
11
+ from .midas.midas_net_custom import MidasNet_small
12
+ from .midas.transforms import Resize, NormalizeImage, PrepareForNet
13
+ from annotator.util import annotator_ckpts_path
14
+
15
+
16
+ ISL_PATHS = {
17
+ "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
18
+ "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
19
+ "midas_v21": "",
20
+ "midas_v21_small": "",
21
+ }
22
+
23
+ remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
24
+
25
+
26
+ def disabled_train(self, mode=True):
27
+ """Overwrite model.train with this function to make sure train/eval mode
28
+ does not change anymore."""
29
+ return self
30
+
31
+
32
+ def load_midas_transform(model_type):
33
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
34
+ # load transform only
35
+ if model_type == "dpt_large": # DPT-Large
36
+ net_w, net_h = 384, 384
37
+ resize_mode = "minimal"
38
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
39
+
40
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
41
+ net_w, net_h = 384, 384
42
+ resize_mode = "minimal"
43
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
44
+
45
+ elif model_type == "midas_v21":
46
+ net_w, net_h = 384, 384
47
+ resize_mode = "upper_bound"
48
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
49
+
50
+ elif model_type == "midas_v21_small":
51
+ net_w, net_h = 256, 256
52
+ resize_mode = "upper_bound"
53
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
54
+
55
+ else:
56
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
57
+
58
+ transform = Compose(
59
+ [
60
+ Resize(
61
+ net_w,
62
+ net_h,
63
+ resize_target=None,
64
+ keep_aspect_ratio=True,
65
+ ensure_multiple_of=32,
66
+ resize_method=resize_mode,
67
+ image_interpolation_method=cv2.INTER_CUBIC,
68
+ ),
69
+ normalization,
70
+ PrepareForNet(),
71
+ ]
72
+ )
73
+
74
+ return transform
75
+
76
+
77
+ def load_model(model_type):
78
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
79
+ # load network
80
+ model_path = ISL_PATHS[model_type]
81
+ if model_type == "dpt_large": # DPT-Large
82
+ model = DPTDepthModel(
83
+ path=model_path,
84
+ backbone="vitl16_384",
85
+ non_negative=True,
86
+ )
87
+ net_w, net_h = 384, 384
88
+ resize_mode = "minimal"
89
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
90
+
91
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
92
+ if not os.path.exists(model_path):
93
+ from basicsr.utils.download_util import load_file_from_url
94
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
95
+
96
+ model = DPTDepthModel(
97
+ path=model_path,
98
+ backbone="vitb_rn50_384",
99
+ non_negative=True,
100
+ )
101
+ net_w, net_h = 384, 384
102
+ resize_mode = "minimal"
103
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
104
+
105
+ elif model_type == "midas_v21":
106
+ model = MidasNet(model_path, non_negative=True)
107
+ net_w, net_h = 384, 384
108
+ resize_mode = "upper_bound"
109
+ normalization = NormalizeImage(
110
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
111
+ )
112
+
113
+ elif model_type == "midas_v21_small":
114
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
115
+ non_negative=True, blocks={'expand': True})
116
+ net_w, net_h = 256, 256
117
+ resize_mode = "upper_bound"
118
+ normalization = NormalizeImage(
119
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
120
+ )
121
+
122
+ else:
123
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
124
+ assert False
125
+
126
+ transform = Compose(
127
+ [
128
+ Resize(
129
+ net_w,
130
+ net_h,
131
+ resize_target=None,
132
+ keep_aspect_ratio=True,
133
+ ensure_multiple_of=32,
134
+ resize_method=resize_mode,
135
+ image_interpolation_method=cv2.INTER_CUBIC,
136
+ ),
137
+ normalization,
138
+ PrepareForNet(),
139
+ ]
140
+ )
141
+
142
+ return model.eval(), transform
143
+
144
+
145
+ class MiDaSInference(nn.Module):
146
+ MODEL_TYPES_TORCH_HUB = [
147
+ "DPT_Large",
148
+ "DPT_Hybrid",
149
+ "MiDaS_small"
150
+ ]
151
+ MODEL_TYPES_ISL = [
152
+ "dpt_large",
153
+ "dpt_hybrid",
154
+ "midas_v21",
155
+ "midas_v21_small",
156
+ ]
157
+
158
+ def __init__(self, model_type):
159
+ super().__init__()
160
+ assert (model_type in self.MODEL_TYPES_ISL)
161
+ model, _ = load_model(model_type)
162
+ self.model = model
163
+ self.model.train = disabled_train
164
+
165
+ def forward(self, x):
166
+ with torch.no_grad():
167
+ prediction = self.model(x)
168
+ return prediction
169
+
annotator/midas/midas/__init__.py ADDED
File without changes
annotator/midas/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device('cpu'))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
annotator/midas/midas/blocks.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12
+ if backbone == "vitl16_384":
13
+ pretrained = _make_pretrained_vitl16_384(
14
+ use_pretrained, hooks=hooks, use_readout=use_readout
15
+ )
16
+ scratch = _make_scratch(
17
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
18
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
19
+ elif backbone == "vitb_rn50_384":
20
+ pretrained = _make_pretrained_vitb_rn50_384(
21
+ use_pretrained,
22
+ hooks=hooks,
23
+ use_vit_only=use_vit_only,
24
+ use_readout=use_readout,
25
+ )
26
+ scratch = _make_scratch(
27
+ [256, 512, 768, 768], features, groups=groups, expand=expand
28
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
29
+ elif backbone == "vitb16_384":
30
+ pretrained = _make_pretrained_vitb16_384(
31
+ use_pretrained, hooks=hooks, use_readout=use_readout
32
+ )
33
+ scratch = _make_scratch(
34
+ [96, 192, 384, 768], features, groups=groups, expand=expand
35
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
36
+ elif backbone == "resnext101_wsl":
37
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
39
+ elif backbone == "efficientnet_lite3":
40
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42
+ else:
43
+ print(f"Backbone '{backbone}' not implemented")
44
+ assert False
45
+
46
+ return pretrained, scratch
47
+
48
+
49
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50
+ scratch = nn.Module()
51
+
52
+ out_shape1 = out_shape
53
+ out_shape2 = out_shape
54
+ out_shape3 = out_shape
55
+ out_shape4 = out_shape
56
+ if expand==True:
57
+ out_shape1 = out_shape
58
+ out_shape2 = out_shape*2
59
+ out_shape3 = out_shape*4
60
+ out_shape4 = out_shape*8
61
+
62
+ scratch.layer1_rn = nn.Conv2d(
63
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64
+ )
65
+ scratch.layer2_rn = nn.Conv2d(
66
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67
+ )
68
+ scratch.layer3_rn = nn.Conv2d(
69
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70
+ )
71
+ scratch.layer4_rn = nn.Conv2d(
72
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73
+ )
74
+
75
+ return scratch
76
+
77
+
78
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79
+ efficientnet = torch.hub.load(
80
+ "rwightman/gen-efficientnet-pytorch",
81
+ "tf_efficientnet_lite3",
82
+ pretrained=use_pretrained,
83
+ exportable=exportable
84
+ )
85
+ return _make_efficientnet_backbone(efficientnet)
86
+
87
+
88
+ def _make_efficientnet_backbone(effnet):
89
+ pretrained = nn.Module()
90
+
91
+ pretrained.layer1 = nn.Sequential(
92
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93
+ )
94
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97
+
98
+ return pretrained
99
+
100
+
101
+ def _make_resnet_backbone(resnet):
102
+ pretrained = nn.Module()
103
+ pretrained.layer1 = nn.Sequential(
104
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105
+ )
106
+
107
+ pretrained.layer2 = resnet.layer2
108
+ pretrained.layer3 = resnet.layer3
109
+ pretrained.layer4 = resnet.layer4
110
+
111
+ return pretrained
112
+
113
+
114
+ def _make_pretrained_resnext101_wsl(use_pretrained):
115
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116
+ return _make_resnet_backbone(resnet)
117
+
118
+
119
+
120
+ class Interpolate(nn.Module):
121
+ """Interpolation module.
122
+ """
123
+
124
+ def __init__(self, scale_factor, mode, align_corners=False):
125
+ """Init.
126
+
127
+ Args:
128
+ scale_factor (float): scaling
129
+ mode (str): interpolation mode
130
+ """
131
+ super(Interpolate, self).__init__()
132
+
133
+ self.interp = nn.functional.interpolate
134
+ self.scale_factor = scale_factor
135
+ self.mode = mode
136
+ self.align_corners = align_corners
137
+
138
+ def forward(self, x):
139
+ """Forward pass.
140
+
141
+ Args:
142
+ x (tensor): input
143
+
144
+ Returns:
145
+ tensor: interpolated data
146
+ """
147
+
148
+ x = self.interp(
149
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150
+ )
151
+
152
+ return x
153
+
154
+
155
+ class ResidualConvUnit(nn.Module):
156
+ """Residual convolution module.
157
+ """
158
+
159
+ def __init__(self, features):
160
+ """Init.
161
+
162
+ Args:
163
+ features (int): number of features
164
+ """
165
+ super().__init__()
166
+
167
+ self.conv1 = nn.Conv2d(
168
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
169
+ )
170
+
171
+ self.conv2 = nn.Conv2d(
172
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
173
+ )
174
+
175
+ self.relu = nn.ReLU(inplace=True)
176
+
177
+ def forward(self, x):
178
+ """Forward pass.
179
+
180
+ Args:
181
+ x (tensor): input
182
+
183
+ Returns:
184
+ tensor: output
185
+ """
186
+ out = self.relu(x)
187
+ out = self.conv1(out)
188
+ out = self.relu(out)
189
+ out = self.conv2(out)
190
+
191
+ return out + x
192
+
193
+
194
+ class FeatureFusionBlock(nn.Module):
195
+ """Feature fusion block.
196
+ """
197
+
198
+ def __init__(self, features):
199
+ """Init.
200
+
201
+ Args:
202
+ features (int): number of features
203
+ """
204
+ super(FeatureFusionBlock, self).__init__()
205
+
206
+ self.resConfUnit1 = ResidualConvUnit(features)
207
+ self.resConfUnit2 = ResidualConvUnit(features)
208
+
209
+ def forward(self, *xs):
210
+ """Forward pass.
211
+
212
+ Returns:
213
+ tensor: output
214
+ """
215
+ output = xs[0]
216
+
217
+ if len(xs) == 2:
218
+ output += self.resConfUnit1(xs[1])
219
+
220
+ output = self.resConfUnit2(output)
221
+
222
+ output = nn.functional.interpolate(
223
+ output, scale_factor=2, mode="bilinear", align_corners=True
224
+ )
225
+
226
+ return output
227
+
228
+
229
+
230
+
231
+ class ResidualConvUnit_custom(nn.Module):
232
+ """Residual convolution module.
233
+ """
234
+
235
+ def __init__(self, features, activation, bn):
236
+ """Init.
237
+
238
+ Args:
239
+ features (int): number of features
240
+ """
241
+ super().__init__()
242
+
243
+ self.bn = bn
244
+
245
+ self.groups=1
246
+
247
+ self.conv1 = nn.Conv2d(
248
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249
+ )
250
+
251
+ self.conv2 = nn.Conv2d(
252
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253
+ )
254
+
255
+ if self.bn==True:
256
+ self.bn1 = nn.BatchNorm2d(features)
257
+ self.bn2 = nn.BatchNorm2d(features)
258
+
259
+ self.activation = activation
260
+
261
+ self.skip_add = nn.quantized.FloatFunctional()
262
+
263
+ def forward(self, x):
264
+ """Forward pass.
265
+
266
+ Args:
267
+ x (tensor): input
268
+
269
+ Returns:
270
+ tensor: output
271
+ """
272
+
273
+ out = self.activation(x)
274
+ out = self.conv1(out)
275
+ if self.bn==True:
276
+ out = self.bn1(out)
277
+
278
+ out = self.activation(out)
279
+ out = self.conv2(out)
280
+ if self.bn==True:
281
+ out = self.bn2(out)
282
+
283
+ if self.groups > 1:
284
+ out = self.conv_merge(out)
285
+
286
+ return self.skip_add.add(out, x)
287
+
288
+ # return out + x
289
+
290
+
291
+ class FeatureFusionBlock_custom(nn.Module):
292
+ """Feature fusion block.
293
+ """
294
+
295
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296
+ """Init.
297
+
298
+ Args:
299
+ features (int): number of features
300
+ """
301
+ super(FeatureFusionBlock_custom, self).__init__()
302
+
303
+ self.deconv = deconv
304
+ self.align_corners = align_corners
305
+
306
+ self.groups=1
307
+
308
+ self.expand = expand
309
+ out_features = features
310
+ if self.expand==True:
311
+ out_features = features//2
312
+
313
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314
+
315
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317
+
318
+ self.skip_add = nn.quantized.FloatFunctional()
319
+
320
+ def forward(self, *xs):
321
+ """Forward pass.
322
+
323
+ Returns:
324
+ tensor: output
325
+ """
326
+ output = xs[0]
327
+
328
+ if len(xs) == 2:
329
+ res = self.resConfUnit1(xs[1])
330
+ output = self.skip_add.add(output, res)
331
+ # output += res
332
+
333
+ output = self.resConfUnit2(output)
334
+
335
+ output = nn.functional.interpolate(
336
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337
+ )
338
+
339
+ output = self.out_conv(output)
340
+
341
+ return output
342
+
annotator/midas/midas/dpt_depth.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+
14
+
15
+ def _make_fusion_block(features, use_bn):
16
+ return FeatureFusionBlock_custom(
17
+ features,
18
+ nn.ReLU(False),
19
+ deconv=False,
20
+ bn=use_bn,
21
+ expand=False,
22
+ align_corners=True,
23
+ )
24
+
25
+
26
+ class DPT(BaseModel):
27
+ def __init__(
28
+ self,
29
+ head,
30
+ features=256,
31
+ backbone="vitb_rn50_384",
32
+ readout="project",
33
+ channels_last=False,
34
+ use_bn=False,
35
+ ):
36
+
37
+ super(DPT, self).__init__()
38
+
39
+ self.channels_last = channels_last
40
+
41
+ hooks = {
42
+ "vitb_rn50_384": [0, 1, 8, 11],
43
+ "vitb16_384": [2, 5, 8, 11],
44
+ "vitl16_384": [5, 11, 17, 23],
45
+ }
46
+
47
+ # Instantiate backbone and reassemble blocks
48
+ self.pretrained, self.scratch = _make_encoder(
49
+ backbone,
50
+ features,
51
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
52
+ groups=1,
53
+ expand=False,
54
+ exportable=False,
55
+ hooks=hooks[backbone],
56
+ use_readout=readout,
57
+ )
58
+
59
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63
+
64
+ self.scratch.output_conv = head
65
+
66
+
67
+ def forward(self, x):
68
+ if self.channels_last == True:
69
+ x.contiguous(memory_format=torch.channels_last)
70
+
71
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72
+
73
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
74
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
75
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
76
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
77
+
78
+ path_4 = self.scratch.refinenet4(layer_4_rn)
79
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82
+
83
+ out = self.scratch.output_conv(path_1)
84
+
85
+ return out
86
+
87
+
88
+ class DPTDepthModel(DPT):
89
+ def __init__(self, path=None, non_negative=True, **kwargs):
90
+ features = kwargs["features"] if "features" in kwargs else 256
91
+
92
+ head = nn.Sequential(
93
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96
+ nn.ReLU(True),
97
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98
+ nn.ReLU(True) if non_negative else nn.Identity(),
99
+ nn.Identity(),
100
+ )
101
+
102
+ super().__init__(head, **kwargs)
103
+
104
+ if path is not None:
105
+ self.load(path)
106
+
107
+ def forward(self, x):
108
+ return super().forward(x).squeeze(dim=1)
109
+
annotator/midas/midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=256, non_negative=True):
17
+ """Init.
18
+
19
+ Args:
20
+ path (str, optional): Path to saved model. Defaults to None.
21
+ features (int, optional): Number of features. Defaults to 256.
22
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23
+ """
24
+ print("Loading weights: ", path)
25
+
26
+ super(MidasNet, self).__init__()
27
+
28
+ use_pretrained = False if path is None else True
29
+
30
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31
+
32
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
33
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
36
+
37
+ self.scratch.output_conv = nn.Sequential(
38
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39
+ Interpolate(scale_factor=2, mode="bilinear"),
40
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41
+ nn.ReLU(True),
42
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43
+ nn.ReLU(True) if non_negative else nn.Identity(),
44
+ )
45
+
46
+ if path:
47
+ self.load(path)
48
+
49
+ def forward(self, x):
50
+ """Forward pass.
51
+
52
+ Args:
53
+ x (tensor): input data (image)
54
+
55
+ Returns:
56
+ tensor: depth
57
+ """
58
+
59
+ layer_1 = self.pretrained.layer1(x)
60
+ layer_2 = self.pretrained.layer2(layer_1)
61
+ layer_3 = self.pretrained.layer3(layer_2)
62
+ layer_4 = self.pretrained.layer4(layer_3)
63
+
64
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
65
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
66
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
67
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
68
+
69
+ path_4 = self.scratch.refinenet4(layer_4_rn)
70
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73
+
74
+ out = self.scratch.output_conv(path_1)
75
+
76
+ return torch.squeeze(out, dim=1)
annotator/midas/midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_small(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17
+ blocks={'expand': True}):
18
+ """Init.
19
+
20
+ Args:
21
+ path (str, optional): Path to saved model. Defaults to None.
22
+ features (int, optional): Number of features. Defaults to 256.
23
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
24
+ """
25
+ print("Loading weights: ", path)
26
+
27
+ super(MidasNet_small, self).__init__()
28
+
29
+ use_pretrained = False if path else True
30
+
31
+ self.channels_last = channels_last
32
+ self.blocks = blocks
33
+ self.backbone = backbone
34
+
35
+ self.groups = 1
36
+
37
+ features1=features
38
+ features2=features
39
+ features3=features
40
+ features4=features
41
+ self.expand = False
42
+ if "expand" in self.blocks and self.blocks['expand'] == True:
43
+ self.expand = True
44
+ features1=features
45
+ features2=features*2
46
+ features3=features*4
47
+ features4=features*8
48
+
49
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50
+
51
+ self.scratch.activation = nn.ReLU(False)
52
+
53
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57
+
58
+
59
+ self.scratch.output_conv = nn.Sequential(
60
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61
+ Interpolate(scale_factor=2, mode="bilinear"),
62
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63
+ self.scratch.activation,
64
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65
+ nn.ReLU(True) if non_negative else nn.Identity(),
66
+ nn.Identity(),
67
+ )
68
+
69
+ if path:
70
+ self.load(path)
71
+
72
+
73
+ def forward(self, x):
74
+ """Forward pass.
75
+
76
+ Args:
77
+ x (tensor): input data (image)
78
+
79
+ Returns:
80
+ tensor: depth
81
+ """
82
+ if self.channels_last==True:
83
+ print("self.channels_last = ", self.channels_last)
84
+ x.contiguous(memory_format=torch.channels_last)
85
+
86
+
87
+ layer_1 = self.pretrained.layer1(x)
88
+ layer_2 = self.pretrained.layer2(layer_1)
89
+ layer_3 = self.pretrained.layer3(layer_2)
90
+ layer_4 = self.pretrained.layer4(layer_3)
91
+
92
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
93
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
94
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
95
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
96
+
97
+
98
+ path_4 = self.scratch.refinenet4(layer_4_rn)
99
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102
+
103
+ out = self.scratch.output_conv(path_1)
104
+
105
+ return torch.squeeze(out, dim=1)
106
+
107
+
108
+
109
+ def fuse_model(m):
110
+ prev_previous_type = nn.Identity()
111
+ prev_previous_name = ''
112
+ previous_type = nn.Identity()
113
+ previous_name = ''
114
+ for name, module in m.named_modules():
115
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116
+ # print("FUSED ", prev_previous_name, previous_name, name)
117
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119
+ # print("FUSED ", prev_previous_name, previous_name)
120
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122
+ # print("FUSED ", previous_name, name)
123
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124
+
125
+ prev_previous_type = previous_type
126
+ prev_previous_name = previous_name
127
+ previous_type = type(module)
128
+ previous_name = name
annotator/midas/midas/transforms.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ width,
55
+ height,
56
+ resize_target=True,
57
+ keep_aspect_ratio=False,
58
+ ensure_multiple_of=1,
59
+ resize_method="lower_bound",
60
+ image_interpolation_method=cv2.INTER_AREA,
61
+ ):
62
+ """Init.
63
+
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
+ Defaults to "lower_bound".
84
+ """
85
+ self.__width = width
86
+ self.__height = height
87
+
88
+ self.__resize_target = resize_target
89
+ self.__keep_aspect_ratio = keep_aspect_ratio
90
+ self.__multiple_of = ensure_multiple_of
91
+ self.__resize_method = resize_method
92
+ self.__image_interpolation_method = image_interpolation_method
93
+
94
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
+
97
+ if max_val is not None and y > max_val:
98
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if y < min_val:
101
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ return y
104
+
105
+ def get_size(self, width, height):
106
+ # determine new height and width
107
+ scale_height = self.__height / height
108
+ scale_width = self.__width / width
109
+
110
+ if self.__keep_aspect_ratio:
111
+ if self.__resize_method == "lower_bound":
112
+ # scale such that output size is lower bound
113
+ if scale_width > scale_height:
114
+ # fit width
115
+ scale_height = scale_width
116
+ else:
117
+ # fit height
118
+ scale_width = scale_height
119
+ elif self.__resize_method == "upper_bound":
120
+ # scale such that output size is upper bound
121
+ if scale_width < scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "minimal":
128
+ # scale as least as possbile
129
+ if abs(1 - scale_width) < abs(1 - scale_height):
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ else:
136
+ raise ValueError(
137
+ f"resize_method {self.__resize_method} not implemented"
138
+ )
139
+
140
+ if self.__resize_method == "lower_bound":
141
+ new_height = self.constrain_to_multiple_of(
142
+ scale_height * height, min_val=self.__height
143
+ )
144
+ new_width = self.constrain_to_multiple_of(
145
+ scale_width * width, min_val=self.__width
146
+ )
147
+ elif self.__resize_method == "upper_bound":
148
+ new_height = self.constrain_to_multiple_of(
149
+ scale_height * height, max_val=self.__height
150
+ )
151
+ new_width = self.constrain_to_multiple_of(
152
+ scale_width * width, max_val=self.__width
153
+ )
154
+ elif self.__resize_method == "minimal":
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(
164
+ sample["image"].shape[1], sample["image"].shape[0]
165
+ )
166
+
167
+ # resize sample
168
+ sample["image"] = cv2.resize(
169
+ sample["image"],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if "disparity" in sample:
176
+ sample["disparity"] = cv2.resize(
177
+ sample["disparity"],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if "depth" in sample:
183
+ sample["depth"] = cv2.resize(
184
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
+ )
186
+
187
+ sample["mask"] = cv2.resize(
188
+ sample["mask"].astype(np.float32),
189
+ (width, height),
190
+ interpolation=cv2.INTER_NEAREST,
191
+ )
192
+ sample["mask"] = sample["mask"].astype(bool)
193
+
194
+ return sample
195
+
196
+
197
+ class NormalizeImage(object):
198
+ """Normlize image by given mean and std.
199
+ """
200
+
201
+ def __init__(self, mean, std):
202
+ self.__mean = mean
203
+ self.__std = std
204
+
205
+ def __call__(self, sample):
206
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
207
+
208
+ return sample
209
+
210
+
211
+ class PrepareForNet(object):
212
+ """Prepare sample for usage as network input.
213
+ """
214
+
215
+ def __init__(self):
216
+ pass
217
+
218
+ def __call__(self, sample):
219
+ image = np.transpose(sample["image"], (2, 0, 1))
220
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
+
222
+ if "mask" in sample:
223
+ sample["mask"] = sample["mask"].astype(np.float32)
224
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
225
+
226
+ if "disparity" in sample:
227
+ disparity = sample["disparity"].astype(np.float32)
228
+ sample["disparity"] = np.ascontiguousarray(disparity)
229
+
230
+ if "depth" in sample:
231
+ depth = sample["depth"].astype(np.float32)
232
+ sample["depth"] = np.ascontiguousarray(depth)
233
+
234
+ return sample
annotator/midas/midas/vit.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Slice(nn.Module):
10
+ def __init__(self, start_index=1):
11
+ super(Slice, self).__init__()
12
+ self.start_index = start_index
13
+
14
+ def forward(self, x):
15
+ return x[:, self.start_index :]
16
+
17
+
18
+ class AddReadout(nn.Module):
19
+ def __init__(self, start_index=1):
20
+ super(AddReadout, self).__init__()
21
+ self.start_index = start_index
22
+
23
+ def forward(self, x):
24
+ if self.start_index == 2:
25
+ readout = (x[:, 0] + x[:, 1]) / 2
26
+ else:
27
+ readout = x[:, 0]
28
+ return x[:, self.start_index :] + readout.unsqueeze(1)
29
+
30
+
31
+ class ProjectReadout(nn.Module):
32
+ def __init__(self, in_features, start_index=1):
33
+ super(ProjectReadout, self).__init__()
34
+ self.start_index = start_index
35
+
36
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37
+
38
+ def forward(self, x):
39
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
+ features = torch.cat((x[:, self.start_index :], readout), -1)
41
+
42
+ return self.project(features)
43
+
44
+
45
+ class Transpose(nn.Module):
46
+ def __init__(self, dim0, dim1):
47
+ super(Transpose, self).__init__()
48
+ self.dim0 = dim0
49
+ self.dim1 = dim1
50
+
51
+ def forward(self, x):
52
+ x = x.transpose(self.dim0, self.dim1)
53
+ return x
54
+
55
+
56
+ def forward_vit(pretrained, x):
57
+ b, c, h, w = x.shape
58
+
59
+ glob = pretrained.model.forward_flex(x)
60
+
61
+ layer_1 = pretrained.activations["1"]
62
+ layer_2 = pretrained.activations["2"]
63
+ layer_3 = pretrained.activations["3"]
64
+ layer_4 = pretrained.activations["4"]
65
+
66
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70
+
71
+ unflatten = nn.Sequential(
72
+ nn.Unflatten(
73
+ 2,
74
+ torch.Size(
75
+ [
76
+ h // pretrained.model.patch_size[1],
77
+ w // pretrained.model.patch_size[0],
78
+ ]
79
+ ),
80
+ )
81
+ )
82
+
83
+ if layer_1.ndim == 3:
84
+ layer_1 = unflatten(layer_1)
85
+ if layer_2.ndim == 3:
86
+ layer_2 = unflatten(layer_2)
87
+ if layer_3.ndim == 3:
88
+ layer_3 = unflatten(layer_3)
89
+ if layer_4.ndim == 3:
90
+ layer_4 = unflatten(layer_4)
91
+
92
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96
+
97
+ return layer_1, layer_2, layer_3, layer_4
98
+
99
+
100
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
101
+ posemb_tok, posemb_grid = (
102
+ posemb[:, : self.start_index],
103
+ posemb[0, self.start_index :],
104
+ )
105
+
106
+ gs_old = int(math.sqrt(len(posemb_grid)))
107
+
108
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111
+
112
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113
+
114
+ return posemb
115
+
116
+
117
+ def forward_flex(self, x):
118
+ b, c, h, w = x.shape
119
+
120
+ pos_embed = self._resize_pos_embed(
121
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122
+ )
123
+
124
+ B = x.shape[0]
125
+
126
+ if hasattr(self.patch_embed, "backbone"):
127
+ x = self.patch_embed.backbone(x)
128
+ if isinstance(x, (list, tuple)):
129
+ x = x[-1] # last feature if backbone outputs list/tuple of features
130
+
131
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132
+
133
+ if getattr(self, "dist_token", None) is not None:
134
+ cls_tokens = self.cls_token.expand(
135
+ B, -1, -1
136
+ ) # stole cls_tokens impl from Phil Wang, thanks
137
+ dist_token = self.dist_token.expand(B, -1, -1)
138
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
139
+ else:
140
+ cls_tokens = self.cls_token.expand(
141
+ B, -1, -1
142
+ ) # stole cls_tokens impl from Phil Wang, thanks
143
+ x = torch.cat((cls_tokens, x), dim=1)
144
+
145
+ x = x + pos_embed
146
+ x = self.pos_drop(x)
147
+
148
+ for blk in self.blocks:
149
+ x = blk(x)
150
+
151
+ x = self.norm(x)
152
+
153
+ return x
154
+
155
+
156
+ activations = {}
157
+
158
+
159
+ def get_activation(name):
160
+ def hook(model, input, output):
161
+ activations[name] = output
162
+
163
+ return hook
164
+
165
+
166
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
167
+ if use_readout == "ignore":
168
+ readout_oper = [Slice(start_index)] * len(features)
169
+ elif use_readout == "add":
170
+ readout_oper = [AddReadout(start_index)] * len(features)
171
+ elif use_readout == "project":
172
+ readout_oper = [
173
+ ProjectReadout(vit_features, start_index) for out_feat in features
174
+ ]
175
+ else:
176
+ assert (
177
+ False
178
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
179
+
180
+ return readout_oper
181
+
182
+
183
+ def _make_vit_b16_backbone(
184
+ model,
185
+ features=[96, 192, 384, 768],
186
+ size=[384, 384],
187
+ hooks=[2, 5, 8, 11],
188
+ vit_features=768,
189
+ use_readout="ignore",
190
+ start_index=1,
191
+ ):
192
+ pretrained = nn.Module()
193
+
194
+ pretrained.model = model
195
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199
+
200
+ pretrained.activations = activations
201
+
202
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203
+
204
+ # 32, 48, 136, 384
205
+ pretrained.act_postprocess1 = nn.Sequential(
206
+ readout_oper[0],
207
+ Transpose(1, 2),
208
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209
+ nn.Conv2d(
210
+ in_channels=vit_features,
211
+ out_channels=features[0],
212
+ kernel_size=1,
213
+ stride=1,
214
+ padding=0,
215
+ ),
216
+ nn.ConvTranspose2d(
217
+ in_channels=features[0],
218
+ out_channels=features[0],
219
+ kernel_size=4,
220
+ stride=4,
221
+ padding=0,
222
+ bias=True,
223
+ dilation=1,
224
+ groups=1,
225
+ ),
226
+ )
227
+
228
+ pretrained.act_postprocess2 = nn.Sequential(
229
+ readout_oper[1],
230
+ Transpose(1, 2),
231
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232
+ nn.Conv2d(
233
+ in_channels=vit_features,
234
+ out_channels=features[1],
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0,
238
+ ),
239
+ nn.ConvTranspose2d(
240
+ in_channels=features[1],
241
+ out_channels=features[1],
242
+ kernel_size=2,
243
+ stride=2,
244
+ padding=0,
245
+ bias=True,
246
+ dilation=1,
247
+ groups=1,
248
+ ),
249
+ )
250
+
251
+ pretrained.act_postprocess3 = nn.Sequential(
252
+ readout_oper[2],
253
+ Transpose(1, 2),
254
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255
+ nn.Conv2d(
256
+ in_channels=vit_features,
257
+ out_channels=features[2],
258
+ kernel_size=1,
259
+ stride=1,
260
+ padding=0,
261
+ ),
262
+ )
263
+
264
+ pretrained.act_postprocess4 = nn.Sequential(
265
+ readout_oper[3],
266
+ Transpose(1, 2),
267
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268
+ nn.Conv2d(
269
+ in_channels=vit_features,
270
+ out_channels=features[3],
271
+ kernel_size=1,
272
+ stride=1,
273
+ padding=0,
274
+ ),
275
+ nn.Conv2d(
276
+ in_channels=features[3],
277
+ out_channels=features[3],
278
+ kernel_size=3,
279
+ stride=2,
280
+ padding=1,
281
+ ),
282
+ )
283
+
284
+ pretrained.model.start_index = start_index
285
+ pretrained.model.patch_size = [16, 16]
286
+
287
+ # We inject this function into the VisionTransformer instances so that
288
+ # we can use it with interpolated position embeddings without modifying the library source.
289
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290
+ pretrained.model._resize_pos_embed = types.MethodType(
291
+ _resize_pos_embed, pretrained.model
292
+ )
293
+
294
+ return pretrained
295
+
296
+
297
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299
+
300
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
301
+ return _make_vit_b16_backbone(
302
+ model,
303
+ features=[256, 512, 1024, 1024],
304
+ hooks=hooks,
305
+ vit_features=1024,
306
+ use_readout=use_readout,
307
+ )
308
+
309
+
310
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312
+
313
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
314
+ return _make_vit_b16_backbone(
315
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316
+ )
317
+
318
+
319
+ def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321
+
322
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
323
+ return _make_vit_b16_backbone(
324
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325
+ )
326
+
327
+
328
+ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329
+ model = timm.create_model(
330
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331
+ )
332
+
333
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
334
+ return _make_vit_b16_backbone(
335
+ model,
336
+ features=[96, 192, 384, 768],
337
+ hooks=hooks,
338
+ use_readout=use_readout,
339
+ start_index=2,
340
+ )
341
+
342
+
343
+ def _make_vit_b_rn50_backbone(
344
+ model,
345
+ features=[256, 512, 768, 768],
346
+ size=[384, 384],
347
+ hooks=[0, 1, 8, 11],
348
+ vit_features=768,
349
+ use_vit_only=False,
350
+ use_readout="ignore",
351
+ start_index=1,
352
+ ):
353
+ pretrained = nn.Module()
354
+
355
+ pretrained.model = model
356
+
357
+ if use_vit_only == True:
358
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360
+ else:
361
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362
+ get_activation("1")
363
+ )
364
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365
+ get_activation("2")
366
+ )
367
+
368
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370
+
371
+ pretrained.activations = activations
372
+
373
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374
+
375
+ if use_vit_only == True:
376
+ pretrained.act_postprocess1 = nn.Sequential(
377
+ readout_oper[0],
378
+ Transpose(1, 2),
379
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380
+ nn.Conv2d(
381
+ in_channels=vit_features,
382
+ out_channels=features[0],
383
+ kernel_size=1,
384
+ stride=1,
385
+ padding=0,
386
+ ),
387
+ nn.ConvTranspose2d(
388
+ in_channels=features[0],
389
+ out_channels=features[0],
390
+ kernel_size=4,
391
+ stride=4,
392
+ padding=0,
393
+ bias=True,
394
+ dilation=1,
395
+ groups=1,
396
+ ),
397
+ )
398
+
399
+ pretrained.act_postprocess2 = nn.Sequential(
400
+ readout_oper[1],
401
+ Transpose(1, 2),
402
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403
+ nn.Conv2d(
404
+ in_channels=vit_features,
405
+ out_channels=features[1],
406
+ kernel_size=1,
407
+ stride=1,
408
+ padding=0,
409
+ ),
410
+ nn.ConvTranspose2d(
411
+ in_channels=features[1],
412
+ out_channels=features[1],
413
+ kernel_size=2,
414
+ stride=2,
415
+ padding=0,
416
+ bias=True,
417
+ dilation=1,
418
+ groups=1,
419
+ ),
420
+ )
421
+ else:
422
+ pretrained.act_postprocess1 = nn.Sequential(
423
+ nn.Identity(), nn.Identity(), nn.Identity()
424
+ )
425
+ pretrained.act_postprocess2 = nn.Sequential(
426
+ nn.Identity(), nn.Identity(), nn.Identity()
427
+ )
428
+
429
+ pretrained.act_postprocess3 = nn.Sequential(
430
+ readout_oper[2],
431
+ Transpose(1, 2),
432
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433
+ nn.Conv2d(
434
+ in_channels=vit_features,
435
+ out_channels=features[2],
436
+ kernel_size=1,
437
+ stride=1,
438
+ padding=0,
439
+ ),
440
+ )
441
+
442
+ pretrained.act_postprocess4 = nn.Sequential(
443
+ readout_oper[3],
444
+ Transpose(1, 2),
445
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446
+ nn.Conv2d(
447
+ in_channels=vit_features,
448
+ out_channels=features[3],
449
+ kernel_size=1,
450
+ stride=1,
451
+ padding=0,
452
+ ),
453
+ nn.Conv2d(
454
+ in_channels=features[3],
455
+ out_channels=features[3],
456
+ kernel_size=3,
457
+ stride=2,
458
+ padding=1,
459
+ ),
460
+ )
461
+
462
+ pretrained.model.start_index = start_index
463
+ pretrained.model.patch_size = [16, 16]
464
+
465
+ # We inject this function into the VisionTransformer instances so that
466
+ # we can use it with interpolated position embeddings without modifying the library source.
467
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468
+
469
+ # We inject this function into the VisionTransformer instances so that
470
+ # we can use it with interpolated position embeddings without modifying the library source.
471
+ pretrained.model._resize_pos_embed = types.MethodType(
472
+ _resize_pos_embed, pretrained.model
473
+ )
474
+
475
+ return pretrained
476
+
477
+
478
+ def _make_pretrained_vitb_rn50_384(
479
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480
+ ):
481
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482
+
483
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
484
+ return _make_vit_b_rn50_backbone(
485
+ model,
486
+ features=[256, 512, 768, 768],
487
+ size=[384, 384],
488
+ hooks=hooks,
489
+ use_vit_only=use_vit_only,
490
+ use_readout=use_readout,
491
+ )
annotator/midas/utils.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utils for monoDepth."""
2
+ import sys
3
+ import re
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+
8
+
9
+ def read_pfm(path):
10
+ """Read pfm file.
11
+
12
+ Args:
13
+ path (str): path to file
14
+
15
+ Returns:
16
+ tuple: (data, scale)
17
+ """
18
+ with open(path, "rb") as file:
19
+
20
+ color = None
21
+ width = None
22
+ height = None
23
+ scale = None
24
+ endian = None
25
+
26
+ header = file.readline().rstrip()
27
+ if header.decode("ascii") == "PF":
28
+ color = True
29
+ elif header.decode("ascii") == "Pf":
30
+ color = False
31
+ else:
32
+ raise Exception("Not a PFM file: " + path)
33
+
34
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35
+ if dim_match:
36
+ width, height = list(map(int, dim_match.groups()))
37
+ else:
38
+ raise Exception("Malformed PFM header.")
39
+
40
+ scale = float(file.readline().decode("ascii").rstrip())
41
+ if scale < 0:
42
+ # little-endian
43
+ endian = "<"
44
+ scale = -scale
45
+ else:
46
+ # big-endian
47
+ endian = ">"
48
+
49
+ data = np.fromfile(file, endian + "f")
50
+ shape = (height, width, 3) if color else (height, width)
51
+
52
+ data = np.reshape(data, shape)
53
+ data = np.flipud(data)
54
+
55
+ return data, scale
56
+
57
+
58
+ def write_pfm(path, image, scale=1):
59
+ """Write pfm file.
60
+
61
+ Args:
62
+ path (str): pathto file
63
+ image (array): data
64
+ scale (int, optional): Scale. Defaults to 1.
65
+ """
66
+
67
+ with open(path, "wb") as file:
68
+ color = None
69
+
70
+ if image.dtype.name != "float32":
71
+ raise Exception("Image dtype must be float32.")
72
+
73
+ image = np.flipud(image)
74
+
75
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
76
+ color = True
77
+ elif (
78
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79
+ ): # greyscale
80
+ color = False
81
+ else:
82
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83
+
84
+ file.write("PF\n" if color else "Pf\n".encode())
85
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86
+
87
+ endian = image.dtype.byteorder
88
+
89
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
90
+ scale = -scale
91
+
92
+ file.write("%f\n".encode() % scale)
93
+
94
+ image.tofile(file)
95
+
96
+
97
+ def read_image(path):
98
+ """Read image and output RGB image (0-1).
99
+
100
+ Args:
101
+ path (str): path to file
102
+
103
+ Returns:
104
+ array: RGB image (0-1)
105
+ """
106
+ img = cv2.imread(path)
107
+
108
+ if img.ndim == 2:
109
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110
+
111
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112
+
113
+ return img
114
+
115
+
116
+ def resize_image(img):
117
+ """Resize image and make it fit for network.
118
+
119
+ Args:
120
+ img (array): image
121
+
122
+ Returns:
123
+ tensor: data ready for network
124
+ """
125
+ height_orig = img.shape[0]
126
+ width_orig = img.shape[1]
127
+
128
+ if width_orig > height_orig:
129
+ scale = width_orig / 384
130
+ else:
131
+ scale = height_orig / 384
132
+
133
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135
+
136
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137
+
138
+ img_resized = (
139
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140
+ )
141
+ img_resized = img_resized.unsqueeze(0)
142
+
143
+ return img_resized
144
+
145
+
146
+ def resize_depth(depth, width, height):
147
+ """Resize depth map and bring to CPU (numpy).
148
+
149
+ Args:
150
+ depth (tensor): depth
151
+ width (int): image width
152
+ height (int): image height
153
+
154
+ Returns:
155
+ array: processed depth
156
+ """
157
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158
+
159
+ depth_resized = cv2.resize(
160
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161
+ )
162
+
163
+ return depth_resized
164
+
165
+ def write_depth(path, depth, bits=1):
166
+ """Write depth map to pfm and png file.
167
+
168
+ Args:
169
+ path (str): filepath without extension
170
+ depth (array): depth
171
+ """
172
+ write_pfm(path + ".pfm", depth.astype(np.float32))
173
+
174
+ depth_min = depth.min()
175
+ depth_max = depth.max()
176
+
177
+ max_val = (2**(8*bits))-1
178
+
179
+ if depth_max - depth_min > np.finfo("float").eps:
180
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
181
+ else:
182
+ out = np.zeros(depth.shape, dtype=depth.type)
183
+
184
+ if bits == 1:
185
+ cv2.imwrite(path + ".png", out.astype("uint8"))
186
+ elif bits == 2:
187
+ cv2.imwrite(path + ".png", out.astype("uint16"))
188
+
189
+ return
annotator/util.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import os
4
+
5
+
6
+ annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
7
+
8
+
9
+ def HWC3(x):
10
+ assert x.dtype == np.uint8
11
+ if x.ndim == 2:
12
+ x = x[:, :, None]
13
+ assert x.ndim == 3
14
+ H, W, C = x.shape
15
+ assert C == 1 or C == 3 or C == 4
16
+ if C == 3:
17
+ return x
18
+ if C == 1:
19
+ return np.concatenate([x, x, x], axis=2)
20
+ if C == 4:
21
+ color = x[:, :, 0:3].astype(np.float32)
22
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
23
+ y = color * alpha + 255.0 * (1.0 - alpha)
24
+ y = y.clip(0, 255).astype(np.uint8)
25
+ return y
26
+
27
+
28
+ def resize_image(input_image, resolution):
29
+ H, W, C = input_image.shape
30
+ H = float(H)
31
+ W = float(W)
32
+ k = float(resolution) / min(H, W)
33
+ H *= k
34
+ W *= k
35
+ H = int(np.round(H / 64.0)) * 64
36
+ W = int(np.round(W / 64.0)) * 64
37
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
38
+ return img
app.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import cv2
4
+ import einops
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ import torch.optim as optim
9
+ import random
10
+ import imageio
11
+ from torchvision import transforms
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from PIL import Image
15
+ import time
16
+ import scipy.interpolate
17
+ from tqdm import tqdm
18
+
19
+ from pytorch_lightning import seed_everything
20
+ from annotator.util import resize_image, HWC3
21
+ from annotator.canny import CannyDetector
22
+ from annotator.midas import MidasDetector
23
+ from cldm.model import create_model, load_state_dict
24
+ from ldm.models.diffusion.ddim import DDIMSampler
25
+ from stablevideo.atlas_data import AtlasData
26
+ from stablevideo.atlas_utils import get_grid_indices, get_atlas_bounding_box
27
+ from stablevideo.aggnet import AGGNet
28
+
29
+
30
+ class StableVideo:
31
+ def __init__(self, base_cfg, canny_model_cfg, depth_model_cfg, save_memory=False):
32
+ self.base_cfg = base_cfg
33
+ self.canny_model_cfg = canny_model_cfg
34
+ self.depth_model_cfg = depth_model_cfg
35
+ self.img2img_model = None
36
+ self.canny_model = None
37
+ self.depth_model = None
38
+ self.b_atlas = None
39
+ self.f_atlas = None
40
+ self.data = None
41
+ self.crops = None
42
+ self.save_memory = save_memory
43
+
44
+ def load_canny_model(
45
+ self,
46
+ base_cfg='ckpt/cldm_v15.yaml',
47
+ canny_model_cfg='ckpt/control_sd15_canny.pth',
48
+ ):
49
+ self.apply_canny = CannyDetector()
50
+ canny_model = create_model(base_cfg).cpu()
51
+ canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='cuda'), strict=False)
52
+ self.canny_ddim_sampler = DDIMSampler(canny_model)
53
+ self.canny_model = canny_model
54
+
55
+ def load_depth_model(
56
+ self,
57
+ base_cfg='ckpt/cldm_v15.yaml',
58
+ depth_model_cfg='ckpt/control_sd15_depth.pth',
59
+ ):
60
+ self.apply_midas = MidasDetector()
61
+ depth_model = create_model(base_cfg).cpu()
62
+ depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='cuda'), strict=False)
63
+ self.depth_ddim_sampler = DDIMSampler(depth_model)
64
+ self.depth_model = depth_model
65
+
66
+ def load_video(self, video_name):
67
+ self.data = AtlasData(video_name)
68
+ save_name = f"data/{video_name}/{video_name}.mp4"
69
+ if not os.path.exists(save_name):
70
+ imageio.mimwrite(save_name, self.data.original_video.cpu().permute(0, 2, 3, 1))
71
+ print("original video saved.")
72
+ toIMG = transforms.ToPILImage()
73
+ self.f_atlas_origin = toIMG(self.data.cropped_foreground_atlas[0])
74
+ self.b_atlas_origin = toIMG(self.data.background_grid_atlas[0])
75
+ return save_name, self.f_atlas_origin, self.b_atlas_origin
76
+
77
+ @torch.no_grad()
78
+ def depth_edit(self, input_image=None,
79
+ prompt="",
80
+ a_prompt="best quality, extremely detailed",
81
+ n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
82
+ image_resolution=512,
83
+ detect_resolution=384,
84
+ ddim_steps=20,
85
+ scale=9,
86
+ seed=-1,
87
+ eta=0,
88
+ num_samples=1):
89
+
90
+ size = input_image.size
91
+ model = self.depth_model
92
+ ddim_sampler = self.depth_ddim_sampler
93
+ apply_midas = self.apply_midas
94
+
95
+ input_image = np.array(input_image)
96
+ input_image = HWC3(input_image)
97
+ detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
98
+ detected_map = HWC3(detected_map)
99
+ img = resize_image(input_image, image_resolution)
100
+ H, W, C = img.shape
101
+
102
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
103
+
104
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
105
+ control = torch.stack([control for _ in range(1)], dim=0)
106
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
107
+
108
+ if seed == -1:
109
+ seed = random.randint(0, 65535)
110
+ seed_everything(seed)
111
+
112
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
113
+ un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
114
+ shape = (4, H // 8, W // 8)
115
+
116
+
117
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
118
+ shape, cond, verbose=False, eta=eta,
119
+ unconditional_guidance_scale=scale,
120
+ unconditional_conditioning=un_cond)
121
+
122
+ x_samples = model.decode_first_stage(samples)
123
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
124
+
125
+ results = [x_samples[i] for i in range(num_samples)]
126
+ self.b_atlas = Image.fromarray(results[0]).resize(size)
127
+ return self.b_atlas
128
+
129
+ @torch.no_grad()
130
+ def edit_background(self, *args, **kwargs):
131
+ self.depth_model = self.depth_model.cuda()
132
+
133
+ input_image = self.b_atlas_origin
134
+ self.depth_edit(input_image, *args, **kwargs)
135
+
136
+ if self.save_memory:
137
+ self.depth_model = self.depth_model.cpu()
138
+ return self.b_atlas
139
+
140
+ @torch.no_grad()
141
+ def advanced_edit_foreground(self,
142
+ keyframes="0",
143
+ res=2000,
144
+ prompt="",
145
+ a_prompt="best quality, extremely detailed",
146
+ n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
147
+ image_resolution=512,
148
+ low_threshold=100,
149
+ high_threshold=200,
150
+ ddim_steps=20,
151
+ s=0.9,
152
+ scale=9,
153
+ seed=-1,
154
+ eta=0,
155
+ if_net=False,
156
+ num_samples=1):
157
+
158
+ self.canny_model = self.canny_model.cuda()
159
+
160
+ keyframes = [int(x) for x in keyframes.split(",")]
161
+ if self.data is None:
162
+ raise ValueError("Please load video first")
163
+ self.crops = self.data.get_global_crops_multi(keyframes, res)
164
+ n_keyframes = len(keyframes)
165
+ indices = get_grid_indices(0, 0, res, res)
166
+ f_atlas = torch.zeros(size=(n_keyframes, res, res, 3,)).to("cuda")
167
+
168
+ img_list = [transforms.ToPILImage()(i[0]) for i in self.crops['original_foreground_crops']]
169
+ result_list = []
170
+
171
+ # initial setting
172
+ if seed == -1:
173
+ seed = random.randint(0, 65535)
174
+ seed_everything(seed)
175
+
176
+ self.canny_ddim_sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=eta, verbose=False)
177
+ c_crossattn = [self.canny_model.get_learned_conditioning([prompt + ', ' + a_prompt])]
178
+ uc_crossattn = [self.canny_model.get_learned_conditioning([n_prompt])]
179
+
180
+ for i in range(n_keyframes):
181
+ # get current keyframe
182
+ current_img = img_list[i]
183
+ img = resize_image(HWC3(np.array(current_img)), image_resolution)
184
+ H, W, C = img.shape
185
+ shape = (4, H // 8, W // 8)
186
+ # get canny control
187
+ detected_map = self.apply_canny(img, low_threshold, high_threshold)
188
+ detected_map = HWC3(detected_map)
189
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
190
+ control = einops.rearrange(control.unsqueeze(0), 'b h w c -> b c h w').clone()
191
+
192
+ cond = {"c_concat": [control], "c_crossattn": c_crossattn}
193
+ un_cond = {"c_concat": [control], "c_crossattn": uc_crossattn}
194
+
195
+
196
+ # if not the key frame, calculate the mapping from last atlas
197
+ if i == 0:
198
+ latent = torch.randn((1, 4, H // 8, W // 8)).cuda()
199
+ samples, _ = self.canny_ddim_sampler.sample(ddim_steps, num_samples,
200
+ shape, cond, verbose=False, eta=eta,
201
+ unconditional_guidance_scale=scale,
202
+ unconditional_conditioning=un_cond,
203
+ x_T=latent)
204
+ else:
205
+ last_atlas = f_atlas[i-1:i].permute(0, 3, 2, 1)
206
+ mapped_img = F.grid_sample(last_atlas, self.crops['foreground_uvs'][i].reshape(1, -1, 1, 2), mode="bilinear", align_corners=self.data.config["align_corners"]).clamp(min=0.0, max=1.0).reshape((3, current_img.size[1], current_img.size[0]))
207
+ mapped_img = transforms.ToPILImage()(mapped_img)
208
+
209
+ mapped_img = mapped_img.resize((W, H))
210
+ mapped_img = np.array(mapped_img).astype(np.float32) / 255.0
211
+ mapped_img = mapped_img[None].transpose(0, 3, 1, 2)
212
+ mapped_img = torch.from_numpy(mapped_img).cuda()
213
+ mapped_img = 2. * mapped_img - 1.
214
+ latent = self.canny_model.get_first_stage_encoding(self.canny_model.encode_first_stage(mapped_img))
215
+
216
+ t_enc = int(ddim_steps * s)
217
+ latent = self.canny_ddim_sampler.stochastic_encode(latent, torch.tensor([t_enc]).to("cuda"))
218
+ samples = self.canny_ddim_sampler.decode(x_latent=latent,
219
+ cond=cond,
220
+ t_start=t_enc,
221
+ unconditional_guidance_scale=scale,
222
+ unconditional_conditioning=un_cond)
223
+
224
+ x_samples = self.canny_model.decode_first_stage(samples)
225
+ result = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
226
+ result = Image.fromarray(result[0])
227
+
228
+ result = result.resize(current_img.size)
229
+ result = transforms.ToTensor()(result)
230
+ # times alpha
231
+ alpha = self.crops['foreground_alpha'][i][0].cpu()
232
+ result = alpha * result
233
+
234
+ # buffer for training
235
+ result_copy = result.clone().cuda()
236
+ result_copy.requires_grad = True
237
+ result_list.append(result_copy)
238
+
239
+ # map to atlas
240
+ uv = (self.crops['foreground_uvs'][i].reshape(-1, 2) * 0.5 + 0.5) * res
241
+ for c in range(3):
242
+ interpolated = scipy.interpolate.griddata(
243
+ points=uv.cpu().numpy(),
244
+ values=result[c].reshape(-1, 1).cpu().numpy(),
245
+ xi=indices.reshape(-1, 2).cpu().numpy(),
246
+ method="linear",
247
+ ).reshape(res, res)
248
+ interpolated = torch.from_numpy(interpolated).float()
249
+ interpolated[interpolated.isnan()] = 0.0
250
+ f_atlas[i, :, :, c] = interpolated
251
+
252
+ f_atlas = f_atlas.permute(0, 3, 2, 1)
253
+
254
+ # aggregate via simple median as begining
255
+ agg_atlas, _ = torch.median(f_atlas, dim=0)
256
+
257
+ if if_net == True:
258
+ #####################################
259
+ # aggregate net #
260
+ #####################################
261
+ lr, n_epoch = 1e-3, 500
262
+ agg_net = AGGNet().cuda()
263
+ loss_fn = nn.L1Loss()
264
+ optimizer = optim.SGD(agg_net.parameters(), lr=lr, momentum=0.9)
265
+ for _ in range(n_epoch):
266
+ loss = 0.
267
+ for i in range(n_keyframes):
268
+ e_img = result_list[i]
269
+ temp_agg_atlas = agg_net(agg_atlas)
270
+ rec_img = F.grid_sample(temp_agg_atlas[None],
271
+ self.crops['foreground_uvs'][i].reshape(1, -1, 1, 2),
272
+ mode="bilinear",
273
+ align_corners=self.data.config["align_corners"])
274
+ rec_img = rec_img.clamp(min=0.0, max=1.0).reshape(e_img.shape)
275
+ loss += loss_fn(rec_img, e_img)
276
+ optimizer.zero_grad()
277
+ loss.backward()
278
+ optimizer.step()
279
+ agg_atlas = agg_net(agg_atlas)
280
+ #####################################
281
+
282
+ agg_atlas, _ = get_atlas_bounding_box(self.data.mask_boundaries, agg_atlas, self.data.foreground_all_uvs)
283
+ self.f_atlas = transforms.ToPILImage()(agg_atlas)
284
+
285
+ if self.save_memory:
286
+ self.canny_model = self.canny_model.cpu()
287
+
288
+ return self.f_atlas
289
+
290
+ @torch.no_grad()
291
+ def render(self, f_atlas, b_atlas):
292
+ # foreground
293
+ if f_atlas == None:
294
+ f_atlas = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0).cuda()
295
+ else:
296
+ f_atlas, mask = f_atlas["image"], f_atlas["mask"]
297
+ f_atlas_origin = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0).cuda()
298
+ f_atlas = transforms.ToTensor()(f_atlas).unsqueeze(0).cuda()
299
+ mask = transforms.ToTensor()(mask).unsqueeze(0).cuda()
300
+ if f_atlas.shape != mask.shape:
301
+ print("Warning: truncating mask to atlas shape {}".format(f_atlas.shape))
302
+ mask = mask[:f_atlas.shape[0], :f_atlas.shape[1], :f_atlas.shape[2], :f_atlas.shape[3]]
303
+ f_atlas = f_atlas * (1 - mask) + f_atlas_origin * mask
304
+
305
+ f_atlas = torch.nn.functional.pad(
306
+ f_atlas,
307
+ pad=(
308
+ self.data.foreground_atlas_bbox[1],
309
+ self.data.foreground_grid_atlas.shape[-1] - (self.data.foreground_atlas_bbox[1] + self.data.foreground_atlas_bbox[3]),
310
+ self.data.foreground_atlas_bbox[0],
311
+ self.data.foreground_grid_atlas.shape[-2] - (self.data.foreground_atlas_bbox[0] + self.data.foreground_atlas_bbox[2]),
312
+ ),
313
+ mode="replicate",
314
+ )
315
+ foreground_edit = F.grid_sample(
316
+ f_atlas, self.data.scaled_foreground_uvs, mode="bilinear", align_corners=self.data.config["align_corners"]
317
+ ).clamp(min=0.0, max=1.0)
318
+
319
+ foreground_edit = foreground_edit.squeeze().t() # shape (batch, 3)
320
+ foreground_edit = (
321
+ foreground_edit.reshape(self.data.config["maximum_number_of_frames"], self.data.config["resy"], self.data.config["resx"], 3)
322
+ .permute(0, 3, 1, 2)
323
+ .clamp(min=0.0, max=1.0)
324
+ )
325
+ # background
326
+ if b_atlas == None:
327
+ b_atlas = self.b_atlas_origin
328
+
329
+ b_atlas = transforms.ToTensor()(b_atlas).unsqueeze(0).cuda()
330
+ background_edit = F.grid_sample(
331
+ b_atlas, self.data.scaled_background_uvs, mode="bilinear", align_corners=self.data.config["align_corners"]
332
+ ).clamp(min=0.0, max=1.0)
333
+ background_edit = background_edit.squeeze().t() # shape (batch, 3)
334
+ background_edit = (
335
+ background_edit.reshape(self.data.config["maximum_number_of_frames"], self.data.config["resy"], self.data.config["resx"], 3)
336
+ .permute(0, 3, 1, 2)
337
+ .clamp(min=0.0, max=1.0)
338
+ )
339
+
340
+ output_video = (
341
+ self.data.all_alpha * foreground_edit
342
+ + (1 - self.data.all_alpha) * background_edit
343
+ )
344
+ id = time.time()
345
+ os.mkdir(f"log/{id}")
346
+ save_name = f"log/{id}/video.mp4"
347
+ imageio.mimwrite(save_name, (255 * output_video.detach().cpu()).to(torch.uint8).permute(0, 2, 3, 1))
348
+
349
+ return save_name
350
+
351
+ if __name__ == '__main__':
352
+ with torch.cuda.amp.autocast():
353
+ stablevideo = StableVideo(base_cfg="ckpt/cldm_v15.yaml",
354
+ canny_model_cfg="ckpt/control_sd15_canny.pth",
355
+ depth_model_cfg="ckpt/control_sd15_depth.pth",
356
+ save_memory=True)
357
+ stablevideo.load_canny_model()
358
+ stablevideo.load_depth_model()
359
+
360
+ block = gr.Blocks().queue()
361
+ with block:
362
+ with gr.Row():
363
+ gr.Markdown("## StableVideo")
364
+ with gr.Row():
365
+ with gr.Column():
366
+ original_video = gr.Video(label="Original Video", interactive=False)
367
+ with gr.Row():
368
+ foreground_atlas = gr.Image(label="Foreground Atlas", type="pil")
369
+ background_atlas = gr.Image(label="Background Atlas", type="pil")
370
+ gr.Markdown("### Step 1. select one example video and click **Load Video** buttom and wait for 10 sec.")
371
+ avail_video = [f.name for f in os.scandir("data") if f.is_dir()]
372
+ video_name = gr.Radio(choices=avail_video,
373
+ label="Select Example Videos",
374
+ value="car-turn")
375
+ load_video_button = gr.Button("Load Video")
376
+ gr.Markdown("### Step 2. write text prompt and advanced options for background and foreground.")
377
+ with gr.Row():
378
+ f_prompt = gr.Textbox(label="Foreground Prompt", value="a picture of an orange suv")
379
+ b_prompt = gr.Textbox(label="Background Prompt", value="winter scene, snowy scene, beautiful snow")
380
+ with gr.Row():
381
+ with gr.Accordion("Advanced Foreground Options", open=False):
382
+ adv_keyframes = gr.Textbox(label="keyframe", value="20, 40, 60")
383
+ adv_atlas_resolution = gr.Slider(label="Atlas Resolution", minimum=1000, maximum=3000, value=2000, step=100)
384
+ adv_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
385
+ adv_low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
386
+ adv_high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
387
+ adv_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
388
+ adv_s = gr.Slider(label="Noise Scale", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
389
+ adv_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=15.0, value=9.0, step=0.1)
390
+ adv_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
391
+ adv_eta = gr.Number(label="eta (DDIM)", value=0.0)
392
+ adv_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, no background')
393
+ adv_n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
394
+ adv_if_net = gr.gradio.Checkbox(label="if use agg net", value=False)
395
+
396
+ with gr.Accordion("Background Options", open=False):
397
+ b_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
398
+ b_detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=512, step=1)
399
+ b_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
400
+ b_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
401
+ b_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
402
+ b_eta = gr.Number(label="eta (DDIM)", value=0.0)
403
+ b_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
404
+ b_n_prompt = gr.Textbox(label="Negative Prompt",
405
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
406
+ gr.Markdown("### Step 3. edit each one and render.")
407
+ with gr.Row():
408
+ f_advance_run_button = gr.Button("Advanced Edit Foreground (slower, better)")
409
+ b_run_button = gr.Button("Edit Background")
410
+ run_button = gr.Button("Render")
411
+ with gr.Column():
412
+ output_video = gr.Video(label="Output Video", interactive=False)
413
+ # output_foreground_atlas = gr.Image(label="Output Foreground Atlas", type="pil", interactive=False)
414
+ output_foreground_atlas = gr.ImageMask(label="Editable Output Foreground Atlas", type="pil", tool="sketch", interactive=True)
415
+ output_background_atlas = gr.Image(label="Output Background Atlas", type="pil", interactive=False)
416
+
417
+ # edit param
418
+ f_adv_edit_param = [adv_keyframes,
419
+ adv_atlas_resolution,
420
+ f_prompt,
421
+ adv_a_prompt,
422
+ adv_n_prompt,
423
+ adv_image_resolution,
424
+ adv_low_threshold,
425
+ adv_high_threshold,
426
+ adv_ddim_steps,
427
+ adv_s,
428
+ adv_scale,
429
+ adv_seed,
430
+ adv_eta,
431
+ adv_if_net]
432
+ b_edit_param = [b_prompt,
433
+ b_a_prompt,
434
+ b_n_prompt,
435
+ b_image_resolution,
436
+ b_detect_resolution,
437
+ b_ddim_steps,
438
+ b_scale,
439
+ b_seed,
440
+ b_eta]
441
+ # action
442
+ load_video_button.click(fn=stablevideo.load_video, inputs=video_name, outputs=[original_video, foreground_atlas, background_atlas])
443
+ f_advance_run_button.click(fn=stablevideo.advanced_edit_foreground, inputs=f_adv_edit_param, outputs=[output_foreground_atlas])
444
+ b_run_button.click(fn=stablevideo.edit_background, inputs=b_edit_param, outputs=[output_background_atlas])
445
+ run_button.click(fn=stablevideo.render, inputs=[output_foreground_atlas, output_background_atlas], outputs=[output_video])
446
+
447
+ block.launch(share=True)
ckpt/cldm_v15.yaml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: cldm.cldm.ControlLDM
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.0120
6
+ num_timesteps_cond: 1
7
+ log_every_t: 200
8
+ timesteps: 1000
9
+ first_stage_key: "jpg"
10
+ cond_stage_key: "txt"
11
+ control_key: "hint"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+ only_mid_control: False
20
+
21
+ control_stage_config:
22
+ target: cldm.cldm.ControlNet
23
+ params:
24
+ image_size: 32 # unused
25
+ in_channels: 4
26
+ hint_channels: 3
27
+ model_channels: 320
28
+ attention_resolutions: [ 4, 2, 1 ]
29
+ num_res_blocks: 2
30
+ channel_mult: [ 1, 2, 4, 4 ]
31
+ num_heads: 8
32
+ use_spatial_transformer: True
33
+ transformer_depth: 1
34
+ context_dim: 768
35
+ use_checkpoint: True
36
+ legacy: False
37
+
38
+ unet_config:
39
+ target: cldm.cldm.ControlledUnetModel
40
+ params:
41
+ image_size: 32 # unused
42
+ in_channels: 4
43
+ out_channels: 4
44
+ model_channels: 320
45
+ attention_resolutions: [ 4, 2, 1 ]
46
+ num_res_blocks: 2
47
+ channel_mult: [ 1, 2, 4, 4 ]
48
+ num_heads: 8
49
+ use_spatial_transformer: True
50
+ transformer_depth: 1
51
+ context_dim: 768
52
+ use_checkpoint: True
53
+ legacy: False
54
+
55
+ first_stage_config:
56
+ target: ldm.models.autoencoder.AutoencoderKL
57
+ params:
58
+ embed_dim: 4
59
+ monitor: val/rec_loss
60
+ ddconfig:
61
+ double_z: true
62
+ z_channels: 4
63
+ resolution: 256
64
+ in_channels: 3
65
+ out_ch: 3
66
+ ch: 128
67
+ ch_mult:
68
+ - 1
69
+ - 2
70
+ - 4
71
+ - 4
72
+ num_res_blocks: 2
73
+ attn_resolutions: []
74
+ dropout: 0.0
75
+ lossconfig:
76
+ target: torch.nn.Identity
77
+
78
+ cond_stage_config:
79
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
ckpt/control_sd15_canny.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4de384b16bc2d7a1fb258ca0cbd941d7dd0a721ae996aff89f905299d6923f45
3
+ size 5710753329
ckpt/control_sd15_depth.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:726cd0b472c4b5c0341b01afcb7fdc4a7b4ab7c37fe797fd394c9805cbef60bf
3
+ size 5710753329
ckpt/dpt_hybrid-midas-501f0c75.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:501f0c75b3bca7daec6b3682c5054c09b366765aef6fa3a09d03a5cb4b230853
3
+ size 492757791
cldm/cldm.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import einops
2
+ import torch
3
+ import torch as th
4
+ import torch.nn as nn
5
+
6
+ from ldm.modules.diffusionmodules.util import (
7
+ conv_nd,
8
+ linear,
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from einops import rearrange, repeat
14
+ from torchvision.utils import make_grid
15
+ from ldm.modules.attention import SpatialTransformer
16
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
17
+ from ldm.models.diffusion.ddpm import LatentDiffusion
18
+ from ldm.util import log_txt_as_img, exists, instantiate_from_config
19
+ from ldm.models.diffusion.ddim import DDIMSampler
20
+
21
+
22
+ class ControlledUnetModel(UNetModel):
23
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
24
+ hs = []
25
+ with torch.no_grad():
26
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
27
+ emb = self.time_embed(t_emb)
28
+ h = x.type(self.dtype)
29
+ for module in self.input_blocks:
30
+ h = module(h, emb, context)
31
+ hs.append(h)
32
+ h = self.middle_block(h, emb, context)
33
+
34
+ h += control.pop()
35
+
36
+ for i, module in enumerate(self.output_blocks):
37
+ if only_mid_control:
38
+ h = torch.cat([h, hs.pop()], dim=1)
39
+ else:
40
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
41
+ h = module(h, emb, context)
42
+
43
+ h = h.type(x.dtype)
44
+ return self.out(h)
45
+
46
+
47
+ class ControlNet(nn.Module):
48
+ def __init__(
49
+ self,
50
+ image_size,
51
+ in_channels,
52
+ model_channels,
53
+ hint_channels,
54
+ num_res_blocks,
55
+ attention_resolutions,
56
+ dropout=0,
57
+ channel_mult=(1, 2, 4, 8),
58
+ conv_resample=True,
59
+ dims=2,
60
+ use_checkpoint=False,
61
+ use_fp16=False,
62
+ num_heads=-1,
63
+ num_head_channels=-1,
64
+ num_heads_upsample=-1,
65
+ use_scale_shift_norm=False,
66
+ resblock_updown=False,
67
+ use_new_attention_order=False,
68
+ use_spatial_transformer=False, # custom transformer support
69
+ transformer_depth=1, # custom transformer support
70
+ context_dim=None, # custom transformer support
71
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
72
+ legacy=True,
73
+ disable_self_attentions=None,
74
+ num_attention_blocks=None,
75
+ disable_middle_self_attn=False,
76
+ use_linear_in_transformer=False,
77
+ ):
78
+ super().__init__()
79
+ if use_spatial_transformer:
80
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
81
+
82
+ if context_dim is not None:
83
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
84
+ from omegaconf.listconfig import ListConfig
85
+ if type(context_dim) == ListConfig:
86
+ context_dim = list(context_dim)
87
+
88
+ if num_heads_upsample == -1:
89
+ num_heads_upsample = num_heads
90
+
91
+ if num_heads == -1:
92
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
93
+
94
+ if num_head_channels == -1:
95
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
96
+
97
+ self.dims = dims
98
+ self.image_size = image_size
99
+ self.in_channels = in_channels
100
+ self.model_channels = model_channels
101
+ if isinstance(num_res_blocks, int):
102
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
103
+ else:
104
+ if len(num_res_blocks) != len(channel_mult):
105
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
106
+ "as a list/tuple (per-level) with the same length as channel_mult")
107
+ self.num_res_blocks = num_res_blocks
108
+ if disable_self_attentions is not None:
109
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
110
+ assert len(disable_self_attentions) == len(channel_mult)
111
+ if num_attention_blocks is not None:
112
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
113
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
114
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
115
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
116
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
117
+ f"attention will still not be set.")
118
+
119
+ self.attention_resolutions = attention_resolutions
120
+ self.dropout = dropout
121
+ self.channel_mult = channel_mult
122
+ self.conv_resample = conv_resample
123
+ self.use_checkpoint = use_checkpoint
124
+ self.dtype = th.float16 if use_fp16 else th.float32
125
+ self.num_heads = num_heads
126
+ self.num_head_channels = num_head_channels
127
+ self.num_heads_upsample = num_heads_upsample
128
+ self.predict_codebook_ids = n_embed is not None
129
+
130
+ time_embed_dim = model_channels * 4
131
+ self.time_embed = nn.Sequential(
132
+ linear(model_channels, time_embed_dim),
133
+ nn.SiLU(),
134
+ linear(time_embed_dim, time_embed_dim),
135
+ )
136
+
137
+ self.input_blocks = nn.ModuleList(
138
+ [
139
+ TimestepEmbedSequential(
140
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
141
+ )
142
+ ]
143
+ )
144
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
145
+
146
+ self.input_hint_block = TimestepEmbedSequential(
147
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
148
+ nn.SiLU(),
149
+ conv_nd(dims, 16, 16, 3, padding=1),
150
+ nn.SiLU(),
151
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
152
+ nn.SiLU(),
153
+ conv_nd(dims, 32, 32, 3, padding=1),
154
+ nn.SiLU(),
155
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
156
+ nn.SiLU(),
157
+ conv_nd(dims, 96, 96, 3, padding=1),
158
+ nn.SiLU(),
159
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
160
+ nn.SiLU(),
161
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
162
+ )
163
+
164
+ self._feature_size = model_channels
165
+ input_block_chans = [model_channels]
166
+ ch = model_channels
167
+ ds = 1
168
+ for level, mult in enumerate(channel_mult):
169
+ for nr in range(self.num_res_blocks[level]):
170
+ layers = [
171
+ ResBlock(
172
+ ch,
173
+ time_embed_dim,
174
+ dropout,
175
+ out_channels=mult * model_channels,
176
+ dims=dims,
177
+ use_checkpoint=use_checkpoint,
178
+ use_scale_shift_norm=use_scale_shift_norm,
179
+ )
180
+ ]
181
+ ch = mult * model_channels
182
+ if ds in attention_resolutions:
183
+ if num_head_channels == -1:
184
+ dim_head = ch // num_heads
185
+ else:
186
+ num_heads = ch // num_head_channels
187
+ dim_head = num_head_channels
188
+ if legacy:
189
+ #num_heads = 1
190
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
191
+ if exists(disable_self_attentions):
192
+ disabled_sa = disable_self_attentions[level]
193
+ else:
194
+ disabled_sa = False
195
+
196
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
197
+ layers.append(
198
+ AttentionBlock(
199
+ ch,
200
+ use_checkpoint=use_checkpoint,
201
+ num_heads=num_heads,
202
+ num_head_channels=dim_head,
203
+ use_new_attention_order=use_new_attention_order,
204
+ ) if not use_spatial_transformer else SpatialTransformer(
205
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
206
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
207
+ use_checkpoint=use_checkpoint
208
+ )
209
+ )
210
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
211
+ self.zero_convs.append(self.make_zero_conv(ch))
212
+ self._feature_size += ch
213
+ input_block_chans.append(ch)
214
+ if level != len(channel_mult) - 1:
215
+ out_ch = ch
216
+ self.input_blocks.append(
217
+ TimestepEmbedSequential(
218
+ ResBlock(
219
+ ch,
220
+ time_embed_dim,
221
+ dropout,
222
+ out_channels=out_ch,
223
+ dims=dims,
224
+ use_checkpoint=use_checkpoint,
225
+ use_scale_shift_norm=use_scale_shift_norm,
226
+ down=True,
227
+ )
228
+ if resblock_updown
229
+ else Downsample(
230
+ ch, conv_resample, dims=dims, out_channels=out_ch
231
+ )
232
+ )
233
+ )
234
+ ch = out_ch
235
+ input_block_chans.append(ch)
236
+ self.zero_convs.append(self.make_zero_conv(ch))
237
+ ds *= 2
238
+ self._feature_size += ch
239
+
240
+ if num_head_channels == -1:
241
+ dim_head = ch // num_heads
242
+ else:
243
+ num_heads = ch // num_head_channels
244
+ dim_head = num_head_channels
245
+ if legacy:
246
+ #num_heads = 1
247
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
248
+ self.middle_block = TimestepEmbedSequential(
249
+ ResBlock(
250
+ ch,
251
+ time_embed_dim,
252
+ dropout,
253
+ dims=dims,
254
+ use_checkpoint=use_checkpoint,
255
+ use_scale_shift_norm=use_scale_shift_norm,
256
+ ),
257
+ AttentionBlock(
258
+ ch,
259
+ use_checkpoint=use_checkpoint,
260
+ num_heads=num_heads,
261
+ num_head_channels=dim_head,
262
+ use_new_attention_order=use_new_attention_order,
263
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
264
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
265
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
266
+ use_checkpoint=use_checkpoint
267
+ ),
268
+ ResBlock(
269
+ ch,
270
+ time_embed_dim,
271
+ dropout,
272
+ dims=dims,
273
+ use_checkpoint=use_checkpoint,
274
+ use_scale_shift_norm=use_scale_shift_norm,
275
+ ),
276
+ )
277
+ self.middle_block_out = self.make_zero_conv(ch)
278
+ self._feature_size += ch
279
+
280
+ def make_zero_conv(self, channels):
281
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
282
+
283
+ def forward(self, x, hint, timesteps, context, **kwargs):
284
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
285
+ emb = self.time_embed(t_emb)
286
+
287
+ guided_hint = self.input_hint_block(hint, emb, context)
288
+
289
+ outs = []
290
+
291
+ h = x.type(self.dtype)
292
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
293
+ if guided_hint is not None:
294
+ h = module(h, emb, context)
295
+ h += guided_hint
296
+ guided_hint = None
297
+ else:
298
+ h = module(h, emb, context)
299
+ outs.append(zero_conv(h, emb, context))
300
+
301
+ h = self.middle_block(h, emb, context)
302
+ outs.append(self.middle_block_out(h, emb, context))
303
+
304
+ return outs
305
+
306
+
307
+ class ControlLDM(LatentDiffusion):
308
+
309
+ def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
310
+ super().__init__(*args, **kwargs)
311
+ self.control_model = instantiate_from_config(control_stage_config)
312
+ self.control_key = control_key
313
+ self.only_mid_control = only_mid_control
314
+
315
+ @torch.no_grad()
316
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
317
+ x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
318
+ control = batch[self.control_key]
319
+ if bs is not None:
320
+ control = control[:bs]
321
+ control = control.to(self.device)
322
+ control = einops.rearrange(control, 'b h w c -> b c h w')
323
+ control = control.to(memory_format=torch.contiguous_format).float()
324
+ return x, dict(c_crossattn=[c], c_concat=[control])
325
+
326
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
327
+ assert isinstance(cond, dict)
328
+ diffusion_model = self.model.diffusion_model
329
+ cond_txt = torch.cat(cond['c_crossattn'], 1)
330
+ cond_hint = torch.cat(cond['c_concat'], 1)
331
+
332
+ control = self.control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt)
333
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
334
+
335
+ return eps
336
+
337
+ @torch.no_grad()
338
+ def get_unconditional_conditioning(self, N):
339
+ return self.get_learned_conditioning([""] * N)
340
+
341
+ @torch.no_grad()
342
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
343
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
344
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
345
+ use_ema_scope=True,
346
+ **kwargs):
347
+ use_ddim = ddim_steps is not None
348
+
349
+ log = dict()
350
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
351
+ c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
352
+ N = min(z.shape[0], N)
353
+ n_row = min(z.shape[0], n_row)
354
+ log["reconstruction"] = self.decode_first_stage(z)
355
+ log["control"] = c_cat * 2.0 - 1.0
356
+ log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
357
+
358
+ if plot_diffusion_rows:
359
+ # get diffusion row
360
+ diffusion_row = list()
361
+ z_start = z[:n_row]
362
+ for t in range(self.num_timesteps):
363
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
364
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
365
+ t = t.to(self.device).long()
366
+ noise = torch.randn_like(z_start)
367
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
368
+ diffusion_row.append(self.decode_first_stage(z_noisy))
369
+
370
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
371
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
372
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
373
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
374
+ log["diffusion_row"] = diffusion_grid
375
+
376
+ if sample:
377
+ # get denoise row
378
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
379
+ batch_size=N, ddim=use_ddim,
380
+ ddim_steps=ddim_steps, eta=ddim_eta)
381
+ x_samples = self.decode_first_stage(samples)
382
+ log["samples"] = x_samples
383
+ if plot_denoise_rows:
384
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
385
+ log["denoise_row"] = denoise_grid
386
+
387
+ if unconditional_guidance_scale > 1.0:
388
+ uc_cross = self.get_unconditional_conditioning(N)
389
+ uc_cat = c_cat # torch.zeros_like(c_cat)
390
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
391
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
392
+ batch_size=N, ddim=use_ddim,
393
+ ddim_steps=ddim_steps, eta=ddim_eta,
394
+ unconditional_guidance_scale=unconditional_guidance_scale,
395
+ unconditional_conditioning=uc_full,
396
+ )
397
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
398
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
399
+
400
+ return log
401
+
402
+ @torch.no_grad()
403
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
404
+ ddim_sampler = DDIMSampler(self)
405
+ b, c, h, w = cond["c_concat"][0].shape
406
+ shape = (self.channels, h // 8, w // 8)
407
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
408
+ return samples, intermediates
409
+
410
+ def configure_optimizers(self):
411
+ lr = self.learning_rate
412
+ params = list(self.control_model.parameters())
413
+ if not self.sd_locked:
414
+ params += list(self.model.diffusion_model.output_blocks.parameters())
415
+ params += list(self.model.diffusion_model.out.parameters())
416
+ opt = torch.optim.AdamW(params, lr=lr)
417
+ return opt
418
+
419
+ def low_vram_shift(self, is_diffusing):
420
+ if is_diffusing:
421
+ self.model = self.model.cuda()
422
+ self.control_model = self.control_model.cuda()
423
+ self.first_stage_model = self.first_stage_model.cpu()
424
+ self.cond_stage_model = self.cond_stage_model.cpu()
425
+ else:
426
+ self.model = self.model.cpu()
427
+ self.control_model = self.control_model.cpu()
428
+ self.first_stage_model = self.first_stage_model.cuda()
429
+ self.cond_stage_model = self.cond_stage_model.cuda()
cldm/hack.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import einops
3
+
4
+ import ldm.modules.encoders.modules
5
+ import ldm.modules.attention
6
+
7
+ from transformers import logging
8
+ from ldm.modules.attention import default
9
+
10
+
11
+ def disable_verbosity():
12
+ logging.set_verbosity_error()
13
+ print('logging improved.')
14
+ return
15
+
16
+
17
+ def enable_sliced_attention():
18
+ ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
19
+ print('Enabled sliced_attention.')
20
+ return
21
+
22
+
23
+ def hack_everything(clip_skip=0):
24
+ disable_verbosity()
25
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
26
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
27
+ print('Enabled clip hacks.')
28
+ return
29
+
30
+
31
+ # Written by Lvmin
32
+ def _hacked_clip_forward(self, text):
33
+ PAD = self.tokenizer.pad_token_id
34
+ EOS = self.tokenizer.eos_token_id
35
+ BOS = self.tokenizer.bos_token_id
36
+
37
+ def tokenize(t):
38
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
39
+
40
+ def transformer_encode(t):
41
+ if self.clip_skip > 1:
42
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
43
+ return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
44
+ else:
45
+ return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
46
+
47
+ def split(x):
48
+ return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
49
+
50
+ def pad(x, p, i):
51
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
52
+
53
+ raw_tokens_list = tokenize(text)
54
+ tokens_list = []
55
+
56
+ for raw_tokens in raw_tokens_list:
57
+ raw_tokens_123 = split(raw_tokens)
58
+ raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
59
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
60
+ tokens_list.append(raw_tokens_123)
61
+
62
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
63
+
64
+ feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
65
+ y = transformer_encode(feed)
66
+ z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
67
+
68
+ return z
69
+
70
+
71
+ # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
72
+ def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
73
+ h = self.heads
74
+
75
+ q = self.to_q(x)
76
+ context = default(context, x)
77
+ k = self.to_k(context)
78
+ v = self.to_v(context)
79
+ del context, x
80
+
81
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
82
+
83
+ limit = k.shape[0]
84
+ att_step = 1
85
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
86
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
87
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
88
+
89
+ q_chunks.reverse()
90
+ k_chunks.reverse()
91
+ v_chunks.reverse()
92
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
93
+ del k, q, v
94
+ for i in range(0, limit, att_step):
95
+ q_buffer = q_chunks.pop()
96
+ k_buffer = k_chunks.pop()
97
+ v_buffer = v_chunks.pop()
98
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
99
+
100
+ del k_buffer, q_buffer
101
+ # attention, what we cannot get enough of, by chunks
102
+
103
+ sim_buffer = sim_buffer.softmax(dim=-1)
104
+
105
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
106
+ del v_buffer
107
+ sim[i:i + att_step, :, :] = sim_buffer
108
+
109
+ del sim_buffer
110
+ sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
111
+ return self.to_out(sim)
cldm/logger.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchvision
6
+ from PIL import Image
7
+ from pytorch_lightning.callbacks import Callback
8
+ from pytorch_lightning.utilities.distributed import rank_zero_only
9
+
10
+
11
+ class ImageLogger(Callback):
12
+ def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
13
+ rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
14
+ log_images_kwargs=None):
15
+ super().__init__()
16
+ self.rescale = rescale
17
+ self.batch_freq = batch_frequency
18
+ self.max_images = max_images
19
+ if not increase_log_steps:
20
+ self.log_steps = [self.batch_freq]
21
+ self.clamp = clamp
22
+ self.disabled = disabled
23
+ self.log_on_batch_idx = log_on_batch_idx
24
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
25
+ self.log_first_step = log_first_step
26
+
27
+ @rank_zero_only
28
+ def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
29
+ root = os.path.join(save_dir, "image_log", split)
30
+ for k in images:
31
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
32
+ if self.rescale:
33
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
34
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
35
+ grid = grid.numpy()
36
+ grid = (grid * 255).astype(np.uint8)
37
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
38
+ path = os.path.join(root, filename)
39
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
40
+ Image.fromarray(grid).save(path)
41
+
42
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
43
+ check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
44
+ if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
45
+ hasattr(pl_module, "log_images") and
46
+ callable(pl_module.log_images) and
47
+ self.max_images > 0):
48
+ logger = type(pl_module.logger)
49
+
50
+ is_train = pl_module.training
51
+ if is_train:
52
+ pl_module.eval()
53
+
54
+ with torch.no_grad():
55
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
56
+
57
+ for k in images:
58
+ N = min(images[k].shape[0], self.max_images)
59
+ images[k] = images[k][:N]
60
+ if isinstance(images[k], torch.Tensor):
61
+ images[k] = images[k].detach().cpu()
62
+ if self.clamp:
63
+ images[k] = torch.clamp(images[k], -1., 1.)
64
+
65
+ self.log_local(pl_module.logger.save_dir, split, images,
66
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
67
+
68
+ if is_train:
69
+ pl_module.train()
70
+
71
+ def check_frequency(self, check_idx):
72
+ return check_idx % self.batch_freq == 0
73
+
74
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
75
+ if not self.disabled:
76
+ self.log_img(pl_module, batch, batch_idx, split="train")
cldm/model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ from omegaconf import OmegaConf
5
+ from ldm.util import instantiate_from_config
6
+
7
+
8
+ def get_state_dict(d):
9
+ return d.get('state_dict', d)
10
+
11
+
12
+ def load_state_dict(ckpt_path, location='cpu'):
13
+ _, extension = os.path.splitext(ckpt_path)
14
+ if extension.lower() == ".safetensors":
15
+ import safetensors.torch
16
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
17
+ else:
18
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
19
+ state_dict = get_state_dict(state_dict)
20
+ print(f'Loaded state_dict from [{ckpt_path}]')
21
+ return state_dict
22
+
23
+
24
+ def create_model(config_path):
25
+ config = OmegaConf.load(config_path)
26
+ model = instantiate_from_config(config.model).cpu()
27
+ print(f'Loaded model config from [{config_path}]')
28
+ return model
data/bear/bear.mp4 ADDED
Binary file (771 kB). View file
 
data/bear/bear/00000.jpg ADDED
data/bear/bear/00001.jpg ADDED
data/bear/bear/00002.jpg ADDED
data/bear/bear/00003.jpg ADDED
data/bear/bear/00004.jpg ADDED
data/bear/bear/00005.jpg ADDED
data/bear/bear/00006.jpg ADDED
data/bear/bear/00007.jpg ADDED
data/bear/bear/00008.jpg ADDED
data/bear/bear/00009.jpg ADDED
data/bear/bear/00010.jpg ADDED
data/bear/bear/00011.jpg ADDED
data/bear/bear/00012.jpg ADDED
data/bear/bear/00013.jpg ADDED
data/bear/bear/00014.jpg ADDED
data/bear/bear/00015.jpg ADDED
data/bear/bear/00016.jpg ADDED
data/bear/bear/00017.jpg ADDED
data/bear/bear/00018.jpg ADDED
data/bear/bear/00019.jpg ADDED
data/bear/bear/00020.jpg ADDED
data/bear/bear/00021.jpg ADDED
data/bear/bear/00022.jpg ADDED
data/bear/bear/00023.jpg ADDED