Haoxin Chen commited on
Commit
58ef683
1 Parent(s): 2eaa3f6

add adapter code

Browse files
extralibs/midas/__init__.py ADDED
File without changes
extralibs/midas/api.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://github.com/isl-org/MiDaS
2
+
3
+ import cv2
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision.transforms import Compose
7
+
8
+ from extralibs.midas.midas.dpt_depth import DPTDepthModel
9
+ from extralibs.midas.midas.midas_net import MidasNet
10
+ from extralibs.midas.midas.midas_net_custom import MidasNet_small
11
+ from extralibs.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12
+
13
+
14
+ ISL_PATHS = {
15
+ "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16
+ "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17
+ "midas_v21": "",
18
+ "midas_v21_small": "",
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ def load_midas_transform(model_type):
29
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
30
+ # load transform only
31
+ if model_type == "dpt_large": # DPT-Large
32
+ net_w, net_h = 384, 384
33
+ resize_mode = "minimal"
34
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35
+
36
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
37
+ net_w, net_h = 384, 384
38
+ resize_mode = "minimal"
39
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40
+
41
+ elif model_type == "midas_v21":
42
+ net_w, net_h = 384, 384
43
+ resize_mode = "upper_bound"
44
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45
+
46
+ elif model_type == "midas_v21_small":
47
+ net_w, net_h = 256, 256
48
+ resize_mode = "upper_bound"
49
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50
+
51
+ else:
52
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53
+
54
+ transform = Compose(
55
+ [
56
+ Resize(
57
+ net_w,
58
+ net_h,
59
+ resize_target=None,
60
+ keep_aspect_ratio=True,
61
+ ensure_multiple_of=32,
62
+ resize_method=resize_mode,
63
+ image_interpolation_method=cv2.INTER_CUBIC,
64
+ ),
65
+ normalization,
66
+ PrepareForNet(),
67
+ ]
68
+ )
69
+
70
+ return transform
71
+
72
+
73
+ def load_model(model_type, model_path=None):
74
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
75
+ # load network
76
+ if model_path is None:
77
+ model_path = ISL_PATHS[model_type]
78
+ if model_type == "dpt_large": # DPT-Large
79
+ model = DPTDepthModel(
80
+ path=model_path,
81
+ backbone="vitl16_384",
82
+ non_negative=True,
83
+ )
84
+ net_w, net_h = 384, 384
85
+ resize_mode = "minimal"
86
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
87
+
88
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
89
+ model = DPTDepthModel(
90
+ path=model_path,
91
+ backbone="vitb_rn50_384",
92
+ non_negative=True,
93
+ )
94
+ net_w, net_h = 384, 384
95
+ resize_mode = "minimal"
96
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
97
+
98
+ elif model_type == "midas_v21":
99
+ model = MidasNet(model_path, non_negative=True)
100
+ net_w, net_h = 384, 384
101
+ resize_mode = "upper_bound"
102
+ normalization = NormalizeImage(
103
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
104
+ )
105
+
106
+ elif model_type == "midas_v21_small":
107
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
108
+ non_negative=True, blocks={'expand': True})
109
+ net_w, net_h = 256, 256
110
+ resize_mode = "upper_bound"
111
+ normalization = NormalizeImage(
112
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
113
+ )
114
+
115
+ else:
116
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
117
+ assert False
118
+
119
+ transform = Compose(
120
+ [
121
+ Resize(
122
+ net_w,
123
+ net_h,
124
+ resize_target=None,
125
+ keep_aspect_ratio=True,
126
+ ensure_multiple_of=32,
127
+ resize_method=resize_mode,
128
+ image_interpolation_method=cv2.INTER_CUBIC,
129
+ ),
130
+ normalization,
131
+ PrepareForNet(),
132
+ ]
133
+ )
134
+
135
+ return model.eval(), transform
136
+
137
+
138
+ class MiDaSInference(nn.Module):
139
+ MODEL_TYPES_TORCH_HUB = [
140
+ "DPT_Large",
141
+ "DPT_Hybrid",
142
+ "MiDaS_small"
143
+ ]
144
+ MODEL_TYPES_ISL = [
145
+ "dpt_large",
146
+ "dpt_hybrid",
147
+ "midas_v21",
148
+ "midas_v21_small",
149
+ ]
150
+
151
+ def __init__(self, model_type, model_path):
152
+ super().__init__()
153
+ assert (model_type in self.MODEL_TYPES_ISL)
154
+ model, _ = load_model(model_type, model_path)
155
+ self.model = model
156
+ self.model.train = disabled_train
157
+
158
+ def forward(self, x):
159
+ # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
160
+ # NOTE: we expect that the correct transform has been called during dataloading.
161
+ with torch.no_grad():
162
+ prediction = self.model(x)
163
+ prediction = torch.nn.functional.interpolate(
164
+ prediction.unsqueeze(1),
165
+ size=x.shape[2:],
166
+ mode="bicubic",
167
+ align_corners=False,
168
+ )
169
+ assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
170
+ return prediction
171
+
extralibs/midas/midas/__init__.py ADDED
File without changes
extralibs/midas/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device('cpu'))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
extralibs/midas/midas/blocks.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12
+ if backbone == "vitl16_384":
13
+ pretrained = _make_pretrained_vitl16_384(
14
+ use_pretrained, hooks=hooks, use_readout=use_readout
15
+ )
16
+ scratch = _make_scratch(
17
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
18
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
19
+ elif backbone == "vitb_rn50_384":
20
+ pretrained = _make_pretrained_vitb_rn50_384(
21
+ use_pretrained,
22
+ hooks=hooks,
23
+ use_vit_only=use_vit_only,
24
+ use_readout=use_readout,
25
+ )
26
+ scratch = _make_scratch(
27
+ [256, 512, 768, 768], features, groups=groups, expand=expand
28
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
29
+ elif backbone == "vitb16_384":
30
+ pretrained = _make_pretrained_vitb16_384(
31
+ use_pretrained, hooks=hooks, use_readout=use_readout
32
+ )
33
+ scratch = _make_scratch(
34
+ [96, 192, 384, 768], features, groups=groups, expand=expand
35
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
36
+ elif backbone == "resnext101_wsl":
37
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
39
+ elif backbone == "efficientnet_lite3":
40
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42
+ else:
43
+ print(f"Backbone '{backbone}' not implemented")
44
+ assert False
45
+
46
+ return pretrained, scratch
47
+
48
+
49
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50
+ scratch = nn.Module()
51
+
52
+ out_shape1 = out_shape
53
+ out_shape2 = out_shape
54
+ out_shape3 = out_shape
55
+ out_shape4 = out_shape
56
+ if expand==True:
57
+ out_shape1 = out_shape
58
+ out_shape2 = out_shape*2
59
+ out_shape3 = out_shape*4
60
+ out_shape4 = out_shape*8
61
+
62
+ scratch.layer1_rn = nn.Conv2d(
63
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64
+ )
65
+ scratch.layer2_rn = nn.Conv2d(
66
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67
+ )
68
+ scratch.layer3_rn = nn.Conv2d(
69
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70
+ )
71
+ scratch.layer4_rn = nn.Conv2d(
72
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73
+ )
74
+
75
+ return scratch
76
+
77
+
78
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79
+ efficientnet = torch.hub.load(
80
+ "rwightman/gen-efficientnet-pytorch",
81
+ "tf_efficientnet_lite3",
82
+ pretrained=use_pretrained,
83
+ exportable=exportable
84
+ )
85
+ return _make_efficientnet_backbone(efficientnet)
86
+
87
+
88
+ def _make_efficientnet_backbone(effnet):
89
+ pretrained = nn.Module()
90
+
91
+ pretrained.layer1 = nn.Sequential(
92
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93
+ )
94
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97
+
98
+ return pretrained
99
+
100
+
101
+ def _make_resnet_backbone(resnet):
102
+ pretrained = nn.Module()
103
+ pretrained.layer1 = nn.Sequential(
104
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105
+ )
106
+
107
+ pretrained.layer2 = resnet.layer2
108
+ pretrained.layer3 = resnet.layer3
109
+ pretrained.layer4 = resnet.layer4
110
+
111
+ return pretrained
112
+
113
+
114
+ def _make_pretrained_resnext101_wsl(use_pretrained):
115
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116
+ return _make_resnet_backbone(resnet)
117
+
118
+
119
+
120
+ class Interpolate(nn.Module):
121
+ """Interpolation module.
122
+ """
123
+
124
+ def __init__(self, scale_factor, mode, align_corners=False):
125
+ """Init.
126
+
127
+ Args:
128
+ scale_factor (float): scaling
129
+ mode (str): interpolation mode
130
+ """
131
+ super(Interpolate, self).__init__()
132
+
133
+ self.interp = nn.functional.interpolate
134
+ self.scale_factor = scale_factor
135
+ self.mode = mode
136
+ self.align_corners = align_corners
137
+
138
+ def forward(self, x):
139
+ """Forward pass.
140
+
141
+ Args:
142
+ x (tensor): input
143
+
144
+ Returns:
145
+ tensor: interpolated data
146
+ """
147
+
148
+ x = self.interp(
149
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150
+ )
151
+
152
+ return x
153
+
154
+
155
+ class ResidualConvUnit(nn.Module):
156
+ """Residual convolution module.
157
+ """
158
+
159
+ def __init__(self, features):
160
+ """Init.
161
+
162
+ Args:
163
+ features (int): number of features
164
+ """
165
+ super().__init__()
166
+
167
+ self.conv1 = nn.Conv2d(
168
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
169
+ )
170
+
171
+ self.conv2 = nn.Conv2d(
172
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
173
+ )
174
+
175
+ self.relu = nn.ReLU(inplace=True)
176
+
177
+ def forward(self, x):
178
+ """Forward pass.
179
+
180
+ Args:
181
+ x (tensor): input
182
+
183
+ Returns:
184
+ tensor: output
185
+ """
186
+ out = self.relu(x)
187
+ out = self.conv1(out)
188
+ out = self.relu(out)
189
+ out = self.conv2(out)
190
+
191
+ return out + x
192
+
193
+
194
+ class FeatureFusionBlock(nn.Module):
195
+ """Feature fusion block.
196
+ """
197
+
198
+ def __init__(self, features):
199
+ """Init.
200
+
201
+ Args:
202
+ features (int): number of features
203
+ """
204
+ super(FeatureFusionBlock, self).__init__()
205
+
206
+ self.resConfUnit1 = ResidualConvUnit(features)
207
+ self.resConfUnit2 = ResidualConvUnit(features)
208
+
209
+ def forward(self, *xs):
210
+ """Forward pass.
211
+
212
+ Returns:
213
+ tensor: output
214
+ """
215
+ output = xs[0]
216
+
217
+ if len(xs) == 2:
218
+ output += self.resConfUnit1(xs[1])
219
+
220
+ output = self.resConfUnit2(output)
221
+
222
+ output = nn.functional.interpolate(
223
+ output, scale_factor=2, mode="bilinear", align_corners=True
224
+ )
225
+
226
+ return output
227
+
228
+
229
+
230
+
231
+ class ResidualConvUnit_custom(nn.Module):
232
+ """Residual convolution module.
233
+ """
234
+
235
+ def __init__(self, features, activation, bn):
236
+ """Init.
237
+
238
+ Args:
239
+ features (int): number of features
240
+ """
241
+ super().__init__()
242
+
243
+ self.bn = bn
244
+
245
+ self.groups=1
246
+
247
+ self.conv1 = nn.Conv2d(
248
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249
+ )
250
+
251
+ self.conv2 = nn.Conv2d(
252
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253
+ )
254
+
255
+ if self.bn==True:
256
+ self.bn1 = nn.BatchNorm2d(features)
257
+ self.bn2 = nn.BatchNorm2d(features)
258
+
259
+ self.activation = activation
260
+
261
+ self.skip_add = nn.quantized.FloatFunctional()
262
+
263
+ def forward(self, x):
264
+ """Forward pass.
265
+
266
+ Args:
267
+ x (tensor): input
268
+
269
+ Returns:
270
+ tensor: output
271
+ """
272
+
273
+ out = self.activation(x)
274
+ out = self.conv1(out)
275
+ if self.bn==True:
276
+ out = self.bn1(out)
277
+
278
+ out = self.activation(out)
279
+ out = self.conv2(out)
280
+ if self.bn==True:
281
+ out = self.bn2(out)
282
+
283
+ if self.groups > 1:
284
+ out = self.conv_merge(out)
285
+
286
+ return self.skip_add.add(out, x)
287
+
288
+ # return out + x
289
+
290
+
291
+ class FeatureFusionBlock_custom(nn.Module):
292
+ """Feature fusion block.
293
+ """
294
+
295
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296
+ """Init.
297
+
298
+ Args:
299
+ features (int): number of features
300
+ """
301
+ super(FeatureFusionBlock_custom, self).__init__()
302
+
303
+ self.deconv = deconv
304
+ self.align_corners = align_corners
305
+
306
+ self.groups=1
307
+
308
+ self.expand = expand
309
+ out_features = features
310
+ if self.expand==True:
311
+ out_features = features//2
312
+
313
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314
+
315
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317
+
318
+ self.skip_add = nn.quantized.FloatFunctional()
319
+
320
+ def forward(self, *xs):
321
+ """Forward pass.
322
+
323
+ Returns:
324
+ tensor: output
325
+ """
326
+ output = xs[0]
327
+
328
+ if len(xs) == 2:
329
+ res = self.resConfUnit1(xs[1])
330
+ output = self.skip_add.add(output, res)
331
+ # output += res
332
+
333
+ output = self.resConfUnit2(output)
334
+
335
+ output = nn.functional.interpolate(
336
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337
+ )
338
+
339
+ output = self.out_conv(output)
340
+
341
+ return output
342
+
extralibs/midas/midas/dpt_depth.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+
14
+
15
+ def _make_fusion_block(features, use_bn):
16
+ return FeatureFusionBlock_custom(
17
+ features,
18
+ nn.ReLU(False),
19
+ deconv=False,
20
+ bn=use_bn,
21
+ expand=False,
22
+ align_corners=True,
23
+ )
24
+
25
+
26
+ class DPT(BaseModel):
27
+ def __init__(
28
+ self,
29
+ head,
30
+ features=256,
31
+ backbone="vitb_rn50_384",
32
+ readout="project",
33
+ channels_last=False,
34
+ use_bn=False,
35
+ ):
36
+
37
+ super(DPT, self).__init__()
38
+
39
+ self.channels_last = channels_last
40
+
41
+ hooks = {
42
+ "vitb_rn50_384": [0, 1, 8, 11],
43
+ "vitb16_384": [2, 5, 8, 11],
44
+ "vitl16_384": [5, 11, 17, 23],
45
+ }
46
+
47
+ # Instantiate backbone and reassemble blocks
48
+ self.pretrained, self.scratch = _make_encoder(
49
+ backbone,
50
+ features,
51
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
52
+ groups=1,
53
+ expand=False,
54
+ exportable=False,
55
+ hooks=hooks[backbone],
56
+ use_readout=readout,
57
+ )
58
+
59
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63
+
64
+ self.scratch.output_conv = head
65
+
66
+
67
+ def forward(self, x):
68
+ if self.channels_last == True:
69
+ x.contiguous(memory_format=torch.channels_last)
70
+
71
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72
+
73
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
74
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
75
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
76
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
77
+
78
+ path_4 = self.scratch.refinenet4(layer_4_rn)
79
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82
+
83
+ out = self.scratch.output_conv(path_1)
84
+
85
+ return out
86
+
87
+
88
+ class DPTDepthModel(DPT):
89
+ def __init__(self, path=None, non_negative=True, **kwargs):
90
+ features = kwargs["features"] if "features" in kwargs else 256
91
+
92
+ head = nn.Sequential(
93
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96
+ nn.ReLU(True),
97
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98
+ nn.ReLU(True) if non_negative else nn.Identity(),
99
+ nn.Identity(),
100
+ )
101
+
102
+ super().__init__(head, **kwargs)
103
+
104
+ if path is not None:
105
+ self.load(path)
106
+ print("Midas depth estimation model loaded.")
107
+
108
+ def forward(self, x):
109
+ return super().forward(x).squeeze(dim=1)
110
+
extralibs/midas/midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=256, non_negative=True):
17
+ """Init.
18
+
19
+ Args:
20
+ path (str, optional): Path to saved model. Defaults to None.
21
+ features (int, optional): Number of features. Defaults to 256.
22
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23
+ """
24
+ print("Loading weights: ", path)
25
+
26
+ super(MidasNet, self).__init__()
27
+
28
+ use_pretrained = False if path is None else True
29
+
30
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31
+
32
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
33
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
36
+
37
+ self.scratch.output_conv = nn.Sequential(
38
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39
+ Interpolate(scale_factor=2, mode="bilinear"),
40
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41
+ nn.ReLU(True),
42
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43
+ nn.ReLU(True) if non_negative else nn.Identity(),
44
+ )
45
+
46
+ if path:
47
+ self.load(path)
48
+
49
+ def forward(self, x):
50
+ """Forward pass.
51
+
52
+ Args:
53
+ x (tensor): input data (image)
54
+
55
+ Returns:
56
+ tensor: depth
57
+ """
58
+
59
+ layer_1 = self.pretrained.layer1(x)
60
+ layer_2 = self.pretrained.layer2(layer_1)
61
+ layer_3 = self.pretrained.layer3(layer_2)
62
+ layer_4 = self.pretrained.layer4(layer_3)
63
+
64
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
65
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
66
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
67
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
68
+
69
+ path_4 = self.scratch.refinenet4(layer_4_rn)
70
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73
+
74
+ out = self.scratch.output_conv(path_1)
75
+
76
+ return torch.squeeze(out, dim=1)
extralibs/midas/midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_small(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17
+ blocks={'expand': True}):
18
+ """Init.
19
+
20
+ Args:
21
+ path (str, optional): Path to saved model. Defaults to None.
22
+ features (int, optional): Number of features. Defaults to 256.
23
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
24
+ """
25
+ print("Loading weights: ", path)
26
+
27
+ super(MidasNet_small, self).__init__()
28
+
29
+ use_pretrained = False if path else True
30
+
31
+ self.channels_last = channels_last
32
+ self.blocks = blocks
33
+ self.backbone = backbone
34
+
35
+ self.groups = 1
36
+
37
+ features1=features
38
+ features2=features
39
+ features3=features
40
+ features4=features
41
+ self.expand = False
42
+ if "expand" in self.blocks and self.blocks['expand'] == True:
43
+ self.expand = True
44
+ features1=features
45
+ features2=features*2
46
+ features3=features*4
47
+ features4=features*8
48
+
49
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50
+
51
+ self.scratch.activation = nn.ReLU(False)
52
+
53
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57
+
58
+
59
+ self.scratch.output_conv = nn.Sequential(
60
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61
+ Interpolate(scale_factor=2, mode="bilinear"),
62
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63
+ self.scratch.activation,
64
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65
+ nn.ReLU(True) if non_negative else nn.Identity(),
66
+ nn.Identity(),
67
+ )
68
+
69
+ if path:
70
+ self.load(path)
71
+
72
+
73
+ def forward(self, x):
74
+ """Forward pass.
75
+
76
+ Args:
77
+ x (tensor): input data (image)
78
+
79
+ Returns:
80
+ tensor: depth
81
+ """
82
+ if self.channels_last==True:
83
+ print("self.channels_last = ", self.channels_last)
84
+ x.contiguous(memory_format=torch.channels_last)
85
+
86
+
87
+ layer_1 = self.pretrained.layer1(x)
88
+ layer_2 = self.pretrained.layer2(layer_1)
89
+ layer_3 = self.pretrained.layer3(layer_2)
90
+ layer_4 = self.pretrained.layer4(layer_3)
91
+
92
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
93
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
94
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
95
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
96
+
97
+
98
+ path_4 = self.scratch.refinenet4(layer_4_rn)
99
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102
+
103
+ out = self.scratch.output_conv(path_1)
104
+
105
+ return torch.squeeze(out, dim=1)
106
+
107
+
108
+
109
+ def fuse_model(m):
110
+ prev_previous_type = nn.Identity()
111
+ prev_previous_name = ''
112
+ previous_type = nn.Identity()
113
+ previous_name = ''
114
+ for name, module in m.named_modules():
115
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116
+ # print("FUSED ", prev_previous_name, previous_name, name)
117
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119
+ # print("FUSED ", prev_previous_name, previous_name)
120
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122
+ # print("FUSED ", previous_name, name)
123
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124
+
125
+ prev_previous_type = previous_type
126
+ prev_previous_name = previous_name
127
+ previous_type = type(module)
128
+ previous_name = name
extralibs/midas/midas/transforms.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ width,
55
+ height,
56
+ resize_target=True,
57
+ keep_aspect_ratio=False,
58
+ ensure_multiple_of=1,
59
+ resize_method="lower_bound",
60
+ image_interpolation_method=cv2.INTER_AREA,
61
+ ):
62
+ """Init.
63
+
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
+ Defaults to "lower_bound".
84
+ """
85
+ self.__width = width
86
+ self.__height = height
87
+
88
+ self.__resize_target = resize_target
89
+ self.__keep_aspect_ratio = keep_aspect_ratio
90
+ self.__multiple_of = ensure_multiple_of
91
+ self.__resize_method = resize_method
92
+ self.__image_interpolation_method = image_interpolation_method
93
+
94
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
+
97
+ if max_val is not None and y > max_val:
98
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if y < min_val:
101
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ return y
104
+
105
+ def get_size(self, width, height):
106
+ # determine new height and width
107
+ scale_height = self.__height / height
108
+ scale_width = self.__width / width
109
+
110
+ if self.__keep_aspect_ratio:
111
+ if self.__resize_method == "lower_bound":
112
+ # scale such that output size is lower bound
113
+ if scale_width > scale_height:
114
+ # fit width
115
+ scale_height = scale_width
116
+ else:
117
+ # fit height
118
+ scale_width = scale_height
119
+ elif self.__resize_method == "upper_bound":
120
+ # scale such that output size is upper bound
121
+ if scale_width < scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "minimal":
128
+ # scale as least as possbile
129
+ if abs(1 - scale_width) < abs(1 - scale_height):
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ else:
136
+ raise ValueError(
137
+ f"resize_method {self.__resize_method} not implemented"
138
+ )
139
+
140
+ if self.__resize_method == "lower_bound":
141
+ new_height = self.constrain_to_multiple_of(
142
+ scale_height * height, min_val=self.__height
143
+ )
144
+ new_width = self.constrain_to_multiple_of(
145
+ scale_width * width, min_val=self.__width
146
+ )
147
+ elif self.__resize_method == "upper_bound":
148
+ new_height = self.constrain_to_multiple_of(
149
+ scale_height * height, max_val=self.__height
150
+ )
151
+ new_width = self.constrain_to_multiple_of(
152
+ scale_width * width, max_val=self.__width
153
+ )
154
+ elif self.__resize_method == "minimal":
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(
164
+ sample["image"].shape[1], sample["image"].shape[0]
165
+ )
166
+
167
+ # resize sample
168
+ sample["image"] = cv2.resize(
169
+ sample["image"],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if "disparity" in sample:
176
+ sample["disparity"] = cv2.resize(
177
+ sample["disparity"],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if "depth" in sample:
183
+ sample["depth"] = cv2.resize(
184
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
+ )
186
+
187
+ sample["mask"] = cv2.resize(
188
+ sample["mask"].astype(np.float32),
189
+ (width, height),
190
+ interpolation=cv2.INTER_NEAREST,
191
+ )
192
+ sample["mask"] = sample["mask"].astype(bool)
193
+
194
+ return sample
195
+
196
+
197
+ class NormalizeImage(object):
198
+ """Normlize image by given mean and std.
199
+ """
200
+
201
+ def __init__(self, mean, std):
202
+ self.__mean = mean
203
+ self.__std = std
204
+
205
+ def __call__(self, sample):
206
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
207
+
208
+ return sample
209
+
210
+
211
+ class PrepareForNet(object):
212
+ """Prepare sample for usage as network input.
213
+ """
214
+
215
+ def __init__(self):
216
+ pass
217
+
218
+ def __call__(self, sample):
219
+ image = np.transpose(sample["image"], (2, 0, 1))
220
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
+
222
+ if "mask" in sample:
223
+ sample["mask"] = sample["mask"].astype(np.float32)
224
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
225
+
226
+ if "disparity" in sample:
227
+ disparity = sample["disparity"].astype(np.float32)
228
+ sample["disparity"] = np.ascontiguousarray(disparity)
229
+
230
+ if "depth" in sample:
231
+ depth = sample["depth"].astype(np.float32)
232
+ sample["depth"] = np.ascontiguousarray(depth)
233
+
234
+ return sample
extralibs/midas/midas/vit.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Slice(nn.Module):
10
+ def __init__(self, start_index=1):
11
+ super(Slice, self).__init__()
12
+ self.start_index = start_index
13
+
14
+ def forward(self, x):
15
+ return x[:, self.start_index :]
16
+
17
+
18
+ class AddReadout(nn.Module):
19
+ def __init__(self, start_index=1):
20
+ super(AddReadout, self).__init__()
21
+ self.start_index = start_index
22
+
23
+ def forward(self, x):
24
+ if self.start_index == 2:
25
+ readout = (x[:, 0] + x[:, 1]) / 2
26
+ else:
27
+ readout = x[:, 0]
28
+ return x[:, self.start_index :] + readout.unsqueeze(1)
29
+
30
+
31
+ class ProjectReadout(nn.Module):
32
+ def __init__(self, in_features, start_index=1):
33
+ super(ProjectReadout, self).__init__()
34
+ self.start_index = start_index
35
+
36
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37
+
38
+ def forward(self, x):
39
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
+ features = torch.cat((x[:, self.start_index :], readout), -1)
41
+
42
+ return self.project(features)
43
+
44
+
45
+ class Transpose(nn.Module):
46
+ def __init__(self, dim0, dim1):
47
+ super(Transpose, self).__init__()
48
+ self.dim0 = dim0
49
+ self.dim1 = dim1
50
+
51
+ def forward(self, x):
52
+ x = x.transpose(self.dim0, self.dim1)
53
+ return x
54
+
55
+
56
+ activations = {}
57
+ def forward_vit(pretrained, x):
58
+ b, c, h, w = x.shape
59
+
60
+ glob = pretrained.model.forward_flex(x)
61
+ pretrained.activations = activations
62
+
63
+ layer_1 = pretrained.activations["1"]
64
+ layer_2 = pretrained.activations["2"]
65
+ layer_3 = pretrained.activations["3"]
66
+ layer_4 = pretrained.activations["4"]
67
+
68
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
69
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
70
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
71
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
72
+
73
+ unflatten = nn.Sequential(
74
+ nn.Unflatten(
75
+ 2,
76
+ torch.Size(
77
+ [
78
+ h // pretrained.model.patch_size[1],
79
+ w // pretrained.model.patch_size[0],
80
+ ]
81
+ ),
82
+ )
83
+ )
84
+
85
+ if layer_1.ndim == 3:
86
+ layer_1 = unflatten(layer_1)
87
+ if layer_2.ndim == 3:
88
+ layer_2 = unflatten(layer_2)
89
+ if layer_3.ndim == 3:
90
+ layer_3 = unflatten(layer_3)
91
+ if layer_4.ndim == 3:
92
+ layer_4 = unflatten(layer_4)
93
+
94
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
95
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
96
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
97
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
98
+
99
+ return layer_1, layer_2, layer_3, layer_4
100
+
101
+
102
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
103
+ posemb_tok, posemb_grid = (
104
+ posemb[:, : self.start_index],
105
+ posemb[0, self.start_index :],
106
+ )
107
+
108
+ gs_old = int(math.sqrt(len(posemb_grid)))
109
+
110
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
111
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
112
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
113
+
114
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
115
+
116
+ return posemb
117
+
118
+
119
+ def forward_flex(self, x):
120
+ b, c, h, w = x.shape
121
+
122
+ pos_embed = self._resize_pos_embed(
123
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
124
+ )
125
+
126
+ B = x.shape[0]
127
+
128
+ if hasattr(self.patch_embed, "backbone"):
129
+ x = self.patch_embed.backbone(x)
130
+ if isinstance(x, (list, tuple)):
131
+ x = x[-1] # last feature if backbone outputs list/tuple of features
132
+
133
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
134
+
135
+ if getattr(self, "dist_token", None) is not None:
136
+ cls_tokens = self.cls_token.expand(
137
+ B, -1, -1
138
+ ) # stole cls_tokens impl from Phil Wang, thanks
139
+ dist_token = self.dist_token.expand(B, -1, -1)
140
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
141
+ else:
142
+ cls_tokens = self.cls_token.expand(
143
+ B, -1, -1
144
+ ) # stole cls_tokens impl from Phil Wang, thanks
145
+ x = torch.cat((cls_tokens, x), dim=1)
146
+
147
+ x = x + pos_embed
148
+ x = self.pos_drop(x)
149
+
150
+ for blk in self.blocks:
151
+ x = blk(x)
152
+
153
+ x = self.norm(x)
154
+
155
+ return x
156
+
157
+
158
+ def get_activation(name):
159
+ def hook(model, input, output):
160
+ activations[name] = output
161
+ return hook
162
+
163
+
164
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
165
+ if use_readout == "ignore":
166
+ readout_oper = [Slice(start_index)] * len(features)
167
+ elif use_readout == "add":
168
+ readout_oper = [AddReadout(start_index)] * len(features)
169
+ elif use_readout == "project":
170
+ readout_oper = [
171
+ ProjectReadout(vit_features, start_index) for out_feat in features
172
+ ]
173
+ else:
174
+ assert (
175
+ False
176
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
177
+
178
+ return readout_oper
179
+
180
+
181
+ def _make_vit_b16_backbone(
182
+ model,
183
+ features=[96, 192, 384, 768],
184
+ size=[384, 384],
185
+ hooks=[2, 5, 8, 11],
186
+ vit_features=768,
187
+ use_readout="ignore",
188
+ start_index=1,
189
+ ):
190
+ pretrained = nn.Module()
191
+
192
+ pretrained.model = model
193
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
194
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
195
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
196
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
197
+
198
+ pretrained.activations = activations
199
+
200
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
201
+
202
+ # 32, 48, 136, 384
203
+ pretrained.act_postprocess1 = nn.Sequential(
204
+ readout_oper[0],
205
+ Transpose(1, 2),
206
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
207
+ nn.Conv2d(
208
+ in_channels=vit_features,
209
+ out_channels=features[0],
210
+ kernel_size=1,
211
+ stride=1,
212
+ padding=0,
213
+ ),
214
+ nn.ConvTranspose2d(
215
+ in_channels=features[0],
216
+ out_channels=features[0],
217
+ kernel_size=4,
218
+ stride=4,
219
+ padding=0,
220
+ bias=True,
221
+ dilation=1,
222
+ groups=1,
223
+ ),
224
+ )
225
+
226
+ pretrained.act_postprocess2 = nn.Sequential(
227
+ readout_oper[1],
228
+ Transpose(1, 2),
229
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
230
+ nn.Conv2d(
231
+ in_channels=vit_features,
232
+ out_channels=features[1],
233
+ kernel_size=1,
234
+ stride=1,
235
+ padding=0,
236
+ ),
237
+ nn.ConvTranspose2d(
238
+ in_channels=features[1],
239
+ out_channels=features[1],
240
+ kernel_size=2,
241
+ stride=2,
242
+ padding=0,
243
+ bias=True,
244
+ dilation=1,
245
+ groups=1,
246
+ ),
247
+ )
248
+
249
+ pretrained.act_postprocess3 = nn.Sequential(
250
+ readout_oper[2],
251
+ Transpose(1, 2),
252
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
253
+ nn.Conv2d(
254
+ in_channels=vit_features,
255
+ out_channels=features[2],
256
+ kernel_size=1,
257
+ stride=1,
258
+ padding=0,
259
+ ),
260
+ )
261
+
262
+ pretrained.act_postprocess4 = nn.Sequential(
263
+ readout_oper[3],
264
+ Transpose(1, 2),
265
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
266
+ nn.Conv2d(
267
+ in_channels=vit_features,
268
+ out_channels=features[3],
269
+ kernel_size=1,
270
+ stride=1,
271
+ padding=0,
272
+ ),
273
+ nn.Conv2d(
274
+ in_channels=features[3],
275
+ out_channels=features[3],
276
+ kernel_size=3,
277
+ stride=2,
278
+ padding=1,
279
+ ),
280
+ )
281
+
282
+ pretrained.model.start_index = start_index
283
+ pretrained.model.patch_size = [16, 16]
284
+
285
+ # We inject this function into the VisionTransformer instances so that
286
+ # we can use it with interpolated position embeddings without modifying the library source.
287
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
288
+ pretrained.model._resize_pos_embed = types.MethodType(
289
+ _resize_pos_embed, pretrained.model
290
+ )
291
+
292
+ return pretrained
293
+
294
+
295
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
296
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
297
+
298
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
299
+ return _make_vit_b16_backbone(
300
+ model,
301
+ features=[256, 512, 1024, 1024],
302
+ hooks=hooks,
303
+ vit_features=1024,
304
+ use_readout=use_readout,
305
+ )
306
+
307
+
308
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
309
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
310
+
311
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
312
+ return _make_vit_b16_backbone(
313
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
314
+ )
315
+
316
+
317
+ def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
318
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
319
+
320
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
321
+ return _make_vit_b16_backbone(
322
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
323
+ )
324
+
325
+
326
+ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
327
+ model = timm.create_model(
328
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
329
+ )
330
+
331
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
332
+ return _make_vit_b16_backbone(
333
+ model,
334
+ features=[96, 192, 384, 768],
335
+ hooks=hooks,
336
+ use_readout=use_readout,
337
+ start_index=2,
338
+ )
339
+
340
+
341
+ def _make_vit_b_rn50_backbone(
342
+ model,
343
+ features=[256, 512, 768, 768],
344
+ size=[384, 384],
345
+ hooks=[0, 1, 8, 11],
346
+ vit_features=768,
347
+ use_vit_only=False,
348
+ use_readout="ignore",
349
+ start_index=1,
350
+ ):
351
+ pretrained = nn.Module()
352
+
353
+ pretrained.model = model
354
+
355
+ if use_vit_only == True:
356
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
357
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
358
+ else:
359
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
360
+ get_activation("1")
361
+ )
362
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
363
+ get_activation("2")
364
+ )
365
+
366
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
367
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
368
+
369
+ pretrained.activations = activations
370
+
371
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
372
+
373
+ if use_vit_only == True:
374
+ pretrained.act_postprocess1 = nn.Sequential(
375
+ readout_oper[0],
376
+ Transpose(1, 2),
377
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
378
+ nn.Conv2d(
379
+ in_channels=vit_features,
380
+ out_channels=features[0],
381
+ kernel_size=1,
382
+ stride=1,
383
+ padding=0,
384
+ ),
385
+ nn.ConvTranspose2d(
386
+ in_channels=features[0],
387
+ out_channels=features[0],
388
+ kernel_size=4,
389
+ stride=4,
390
+ padding=0,
391
+ bias=True,
392
+ dilation=1,
393
+ groups=1,
394
+ ),
395
+ )
396
+
397
+ pretrained.act_postprocess2 = nn.Sequential(
398
+ readout_oper[1],
399
+ Transpose(1, 2),
400
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
401
+ nn.Conv2d(
402
+ in_channels=vit_features,
403
+ out_channels=features[1],
404
+ kernel_size=1,
405
+ stride=1,
406
+ padding=0,
407
+ ),
408
+ nn.ConvTranspose2d(
409
+ in_channels=features[1],
410
+ out_channels=features[1],
411
+ kernel_size=2,
412
+ stride=2,
413
+ padding=0,
414
+ bias=True,
415
+ dilation=1,
416
+ groups=1,
417
+ ),
418
+ )
419
+ else:
420
+ pretrained.act_postprocess1 = nn.Sequential(
421
+ nn.Identity(), nn.Identity(), nn.Identity()
422
+ )
423
+ pretrained.act_postprocess2 = nn.Sequential(
424
+ nn.Identity(), nn.Identity(), nn.Identity()
425
+ )
426
+
427
+ pretrained.act_postprocess3 = nn.Sequential(
428
+ readout_oper[2],
429
+ Transpose(1, 2),
430
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
431
+ nn.Conv2d(
432
+ in_channels=vit_features,
433
+ out_channels=features[2],
434
+ kernel_size=1,
435
+ stride=1,
436
+ padding=0,
437
+ ),
438
+ )
439
+
440
+ pretrained.act_postprocess4 = nn.Sequential(
441
+ readout_oper[3],
442
+ Transpose(1, 2),
443
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
444
+ nn.Conv2d(
445
+ in_channels=vit_features,
446
+ out_channels=features[3],
447
+ kernel_size=1,
448
+ stride=1,
449
+ padding=0,
450
+ ),
451
+ nn.Conv2d(
452
+ in_channels=features[3],
453
+ out_channels=features[3],
454
+ kernel_size=3,
455
+ stride=2,
456
+ padding=1,
457
+ ),
458
+ )
459
+
460
+ pretrained.model.start_index = start_index
461
+ pretrained.model.patch_size = [16, 16]
462
+
463
+ # We inject this function into the VisionTransformer instances so that
464
+ # we can use it with interpolated position embeddings without modifying the library source.
465
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
466
+
467
+ # We inject this function into the VisionTransformer instances so that
468
+ # we can use it with interpolated position embeddings without modifying the library source.
469
+ pretrained.model._resize_pos_embed = types.MethodType(
470
+ _resize_pos_embed, pretrained.model
471
+ )
472
+
473
+ return pretrained
474
+
475
+
476
+ def _make_pretrained_vitb_rn50_384(
477
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
478
+ ):
479
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
480
+
481
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
482
+ return _make_vit_b_rn50_backbone(
483
+ model,
484
+ features=[256, 512, 768, 768],
485
+ size=[384, 384],
486
+ hooks=hooks,
487
+ use_vit_only=use_vit_only,
488
+ use_readout=use_readout,
489
+ )
extralibs/midas/utils.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utils for monoDepth."""
2
+ import sys
3
+ import re
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+
8
+
9
+ def read_pfm(path):
10
+ """Read pfm file.
11
+
12
+ Args:
13
+ path (str): path to file
14
+
15
+ Returns:
16
+ tuple: (data, scale)
17
+ """
18
+ with open(path, "rb") as file:
19
+
20
+ color = None
21
+ width = None
22
+ height = None
23
+ scale = None
24
+ endian = None
25
+
26
+ header = file.readline().rstrip()
27
+ if header.decode("ascii") == "PF":
28
+ color = True
29
+ elif header.decode("ascii") == "Pf":
30
+ color = False
31
+ else:
32
+ raise Exception("Not a PFM file: " + path)
33
+
34
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35
+ if dim_match:
36
+ width, height = list(map(int, dim_match.groups()))
37
+ else:
38
+ raise Exception("Malformed PFM header.")
39
+
40
+ scale = float(file.readline().decode("ascii").rstrip())
41
+ if scale < 0:
42
+ # little-endian
43
+ endian = "<"
44
+ scale = -scale
45
+ else:
46
+ # big-endian
47
+ endian = ">"
48
+
49
+ data = np.fromfile(file, endian + "f")
50
+ shape = (height, width, 3) if color else (height, width)
51
+
52
+ data = np.reshape(data, shape)
53
+ data = np.flipud(data)
54
+
55
+ return data, scale
56
+
57
+
58
+ def write_pfm(path, image, scale=1):
59
+ """Write pfm file.
60
+
61
+ Args:
62
+ path (str): pathto file
63
+ image (array): data
64
+ scale (int, optional): Scale. Defaults to 1.
65
+ """
66
+
67
+ with open(path, "wb") as file:
68
+ color = None
69
+
70
+ if image.dtype.name != "float32":
71
+ raise Exception("Image dtype must be float32.")
72
+
73
+ image = np.flipud(image)
74
+
75
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
76
+ color = True
77
+ elif (
78
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79
+ ): # greyscale
80
+ color = False
81
+ else:
82
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83
+
84
+ file.write("PF\n" if color else "Pf\n".encode())
85
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86
+
87
+ endian = image.dtype.byteorder
88
+
89
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
90
+ scale = -scale
91
+
92
+ file.write("%f\n".encode() % scale)
93
+
94
+ image.tofile(file)
95
+
96
+
97
+ def read_image(path):
98
+ """Read image and output RGB image (0-1).
99
+
100
+ Args:
101
+ path (str): path to file
102
+
103
+ Returns:
104
+ array: RGB image (0-1)
105
+ """
106
+ img = cv2.imread(path)
107
+
108
+ if img.ndim == 2:
109
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110
+
111
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112
+
113
+ return img
114
+
115
+
116
+ def resize_image(img):
117
+ """Resize image and make it fit for network.
118
+
119
+ Args:
120
+ img (array): image
121
+
122
+ Returns:
123
+ tensor: data ready for network
124
+ """
125
+ height_orig = img.shape[0]
126
+ width_orig = img.shape[1]
127
+
128
+ if width_orig > height_orig:
129
+ scale = width_orig / 384
130
+ else:
131
+ scale = height_orig / 384
132
+
133
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135
+
136
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137
+
138
+ img_resized = (
139
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140
+ )
141
+ img_resized = img_resized.unsqueeze(0)
142
+
143
+ return img_resized
144
+
145
+
146
+ def resize_depth(depth, width, height):
147
+ """Resize depth map and bring to CPU (numpy).
148
+
149
+ Args:
150
+ depth (tensor): depth
151
+ width (int): image width
152
+ height (int): image height
153
+
154
+ Returns:
155
+ array: processed depth
156
+ """
157
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158
+
159
+ depth_resized = cv2.resize(
160
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161
+ )
162
+
163
+ return depth_resized
164
+
165
+ def write_depth(path, depth, bits=1):
166
+ """Write depth map to pfm and png file.
167
+
168
+ Args:
169
+ path (str): filepath without extension
170
+ depth (array): depth
171
+ """
172
+ write_pfm(path + ".pfm", depth.astype(np.float32))
173
+
174
+ depth_min = depth.min()
175
+ depth_max = depth.max()
176
+
177
+ max_val = (2**(8*bits))-1
178
+
179
+ if depth_max - depth_min > np.finfo("float").eps:
180
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
181
+ else:
182
+ out = np.zeros(depth.shape, dtype=depth.type)
183
+
184
+ if bits == 1:
185
+ cv2.imwrite(path + ".png", out.astype("uint8"))
186
+ elif bits == 2:
187
+ cv2.imwrite(path + ".png", out.astype("uint16"))
188
+
189
+ return
lvdm/models/ddpm3d.py CHANGED
@@ -1432,3 +1432,53 @@ class DiffusionWrapper(pl.LightningModule):
1432
 
1433
  return out
1434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1432
 
1433
  return out
1434
 
1435
+
1436
+ class T2VAdapterDepth(LatentDiffusion):
1437
+ def __init__(self, depth_stage_config, adapter_config, *args, **kwargs):
1438
+ super(T2VAdapterDepth, self).__init__(*args, **kwargs)
1439
+ self.adapter = instantiate_from_config(adapter_config)
1440
+ self.condtype = adapter_config.cond_name
1441
+ self.depth_stage_model = instantiate_from_config(depth_stage_config)
1442
+
1443
+ def prepare_midas_input(self, batch_x):
1444
+ # input: b,c,h,w
1445
+ x_midas = torch.nn.functional.interpolate(batch_x, size=(384, 384), mode='bicubic')
1446
+ return x_midas
1447
+
1448
+ @torch.no_grad()
1449
+ def get_batch_depth(self, batch_x, target_size, encode_bs=1):
1450
+ b, c, t, h, w = batch_x.shape
1451
+ merge_x = rearrange(batch_x, 'b c t h w -> (b t) c h w')
1452
+ split_x = torch.split(merge_x, encode_bs, dim=0)
1453
+ cond_depth_list = []
1454
+ for x in split_x:
1455
+ x_midas = self.prepare_midas_input(x)
1456
+ cond_depth = self.depth_stage_model(x_midas)
1457
+ cond_depth = torch.nn.functional.interpolate(
1458
+ cond_depth,
1459
+ size=target_size,
1460
+ mode="bicubic",
1461
+ align_corners=False,
1462
+ )
1463
+ depth_min, depth_max = torch.amin(cond_depth, dim=[1, 2, 3], keepdim=True), torch.amax(cond_depth, dim=[1, 2, 3], keepdim=True)
1464
+ cond_depth = 2. * (cond_depth - depth_min) / (depth_max - depth_min + 1e-7) - 1.
1465
+ cond_depth_list.append(cond_depth)
1466
+ batch_cond_depth = torch.cat(cond_depth_list, dim=0)
1467
+ batch_cond_depth = rearrange(batch_cond_depth, '(b t) c h w -> b c t h w', b=b, t=t)
1468
+ return batch_cond_depth
1469
+
1470
+ def get_adapter_features(self, extra_cond, encode_bs=1):
1471
+ b, c, t, h, w = extra_cond.shape
1472
+ ## process in 2D manner
1473
+ merge_extra_cond = rearrange(extra_cond, 'b c t h w -> (b t) c h w')
1474
+ split_extra_cond = torch.split(merge_extra_cond, encode_bs, dim=0)
1475
+ features_adapter_list = []
1476
+ for extra_cond in split_extra_cond:
1477
+ features_adapter = self.adapter(extra_cond)
1478
+ features_adapter_list.append(features_adapter)
1479
+ merge_features_adapter_list = []
1480
+ for i in range(len(features_adapter_list[0])):
1481
+ merge_features_adapter = torch.cat([features_adapter_list[num][i] for num in range(len(features_adapter_list))], dim=0)
1482
+ merge_features_adapter_list.append(merge_features_adapter)
1483
+ merge_features_adapter_list = [rearrange(feature, '(b t) c h w -> b c t h w', b=b, t=t) for feature in merge_features_adapter_list]
1484
+ return merge_features_adapter_list
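Both helpers above follow the same pattern: fold the time axis into the batch axis, run a 2D model frame by frame (in chunks of encode_bs), then fold the result back to b,c,t,h,w. A standalone sketch of that pattern on dummy tensors (einops assumed available, as elsewhere in the class; the per-frame model call is replaced by a trivial rescaling):

import torch
from einops import rearrange

b, c, t, h, w = 2, 1, 16, 256, 256
batch_x = torch.rand(b, c, t, h, w)                                # e.g. a batch of per-frame depth maps
merged = rearrange(batch_x, 'b c t h w -> (b t) c h w')            # 32 individual frames
chunks = torch.split(merged, 1, dim=0)                             # encode_bs = 1
processed = torch.cat([x * 2. - 1. for x in chunks], dim=0)        # stand-in for the per-frame depth/adapter call
restored = rearrange(processed, '(b t) c h w -> b c t h w', b=b, t=t)
print(restored.shape)                                              # torch.Size([2, 1, 16, 256, 256])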
lvdm/models/modules/adapter.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from collections import OrderedDict
4
+ from lvdm.models.modules.util import (
5
+ zero_module,
6
+ conv_nd,
7
+ avg_pool_nd
8
+ )
9
+
10
+ class Downsample(nn.Module):
11
+ """
12
+ A downsampling layer with an optional convolution.
13
+ :param channels: channels in the inputs and outputs.
14
+ :param use_conv: a bool determining if a convolution is applied.
15
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
16
+ downsampling occurs in the inner-two dimensions.
17
+ """
18
+
19
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
20
+ super().__init__()
21
+ self.channels = channels
22
+ self.out_channels = out_channels or channels
23
+ self.use_conv = use_conv
24
+ self.dims = dims
25
+ stride = 2 if dims != 3 else (1, 2, 2)
26
+ if use_conv:
27
+ self.op = conv_nd(
28
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
29
+ )
30
+ else:
31
+ assert self.channels == self.out_channels
32
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
33
+
34
+ def forward(self, x):
35
+ assert x.shape[1] == self.channels
36
+ return self.op(x)
37
+
38
+
39
+ class ResnetBlock(nn.Module):
40
+ def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
41
+ super().__init__()
42
+ ps = ksize // 2
43
+ if in_c != out_c or sk == False:
44
+ self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
45
+ else:
46
+ # print('n_in')
47
+ self.in_conv = None
48
+ self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
49
+ self.act = nn.ReLU()
50
+ self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
51
+ if sk == False:
52
+ self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
53
+ else:
54
+ self.skep = None
55
+
56
+ self.down = down
57
+ if self.down == True:
58
+ self.down_opt = Downsample(in_c, use_conv=use_conv)
59
+
60
+ def forward(self, x):
61
+ if self.down == True:
62
+ x = self.down_opt(x)
63
+ if self.in_conv is not None: # edit
64
+ x = self.in_conv(x)
65
+
66
+ h = self.block1(x)
67
+ h = self.act(h)
68
+ h = self.block2(h)
69
+ if self.skep is not None:
70
+ return h + self.skep(x)
71
+ else:
72
+ return h + x
73
+
74
+
75
+ class Adapter(nn.Module):
76
+ def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
77
+ super(Adapter, self).__init__()
78
+ self.unshuffle = nn.PixelUnshuffle(8)
79
+ self.channels = channels
80
+ self.nums_rb = nums_rb
81
+ self.body = []
82
+ for i in range(len(channels)):
83
+ for j in range(nums_rb):
84
+ if (i != 0) and (j == 0):
85
+ self.body.append(
86
+ ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
87
+ else:
88
+ self.body.append(
89
+ ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
90
+ self.body = nn.ModuleList(self.body)
91
+ self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
92
+
93
+ def forward(self, x):
94
+ # unshuffle
95
+ x = self.unshuffle(x)
96
+ # extract features
97
+ features = []
98
+ x = self.conv_in(x)
99
+ for i in range(len(self.channels)):
100
+ for j in range(self.nums_rb):
101
+ idx = i * self.nums_rb + j
102
+ x = self.body[idx](x)
103
+ features.append(x)
104
+
105
+ return features
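A quick way to inspect the Adapter's feature pyramid, using the hyperparameters from models/adapter_t2v_depth/model_config.yaml below (a sketch only; it assumes the lvdm package is importable and just prints shapes):

import torch
from lvdm.models.modules.adapter import Adapter

# Depth maps enter as 1-channel 256x256 frames; PixelUnshuffle(8) turns them into 64 channels at 32x32.
adapter = Adapter(cin=64, channels=[320, 640, 1280, 1280], nums_rb=2, ksize=1, sk=True, use_conv=False)
depth = torch.rand(1, 1, 256, 256)
features = adapter(depth)
print([tuple(f.shape) for f in features])
# [(1, 320, 32, 32), (1, 640, 16, 16), (1, 1280, 8, 8), (1, 1280, 4, 4)]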
lvdm/models/modules/lora.py CHANGED
@@ -622,7 +622,7 @@ def net_load_lora(net, checkpoint_path, alpha=1.0, remove=False):
622
  state_dict = torch.load(checkpoint_path)
623
  for k, v in state_dict.items():
624
  state_dict[k] = v.to(net.device)
625
- # import pdb;pdb.set_trace()
626
  for key in state_dict:
627
  if ".alpha" in key or key in visited:
628
  continue
@@ -685,7 +685,7 @@ def net_load_lora_v2(net, checkpoint_path, alpha=1.0, remove=False, origin_weigh
685
  state_dict = torch.load(checkpoint_path)
686
  for k, v in state_dict.items():
687
  state_dict[k] = v.to(net.device)
688
- # import pdb;pdb.set_trace()
689
  for key in state_dict:
690
  if ".alpha" in key or key in visited:
691
  continue
 
622
  state_dict = torch.load(checkpoint_path)
623
  for k, v in state_dict.items():
624
  state_dict[k] = v.to(net.device)
625
+
626
  for key in state_dict:
627
  if ".alpha" in key or key in visited:
628
  continue
 
685
  state_dict = torch.load(checkpoint_path)
686
  for k, v in state_dict.items():
687
  state_dict[k] = v.to(net.device)
688
+
689
  for key in state_dict:
690
  if ".alpha" in key or key in visited:
691
  continue
lvdm/models/modules/openaimodel3d.py CHANGED
@@ -629,7 +629,7 @@ class UNetModel(nn.Module):
629
  self.middle_block.apply(convert_module_to_f32)
630
  self.output_blocks.apply(convert_module_to_f32)
631
 
632
- def forward(self, x, timesteps=None, time_emb_replace=None, context=None, y=None, **kwargs):
633
  """
634
  Apply the model to an input batch.
635
  :param x: an [N x C x ...] Tensor of inputs.
@@ -651,9 +651,17 @@ class UNetModel(nn.Module):
651
  emb = emb + self.label_emb(y)
652
 
653
  h = x.type(self.dtype)
654
- for module in self.input_blocks:
 
655
  h = module(h, emb, context, **kwargs)
 
 
 
 
656
  hs.append(h)
 
 
 
657
  h = self.middle_block(h, emb, context, **kwargs)
658
  for module in self.output_blocks:
659
  h = th.cat([h, hs.pop()], dim=1)
 
629
  self.middle_block.apply(convert_module_to_f32)
630
  self.output_blocks.apply(convert_module_to_f32)
631
 
632
+ def forward(self, x, timesteps=None, time_emb_replace=None, context=None, features_adapter=None, y=None, **kwargs):
633
  """
634
  Apply the model to an input batch.
635
  :param x: an [N x C x ...] Tensor of inputs.
 
651
  emb = emb + self.label_emb(y)
652
 
653
  h = x.type(self.dtype)
654
+ adapter_idx = 0
655
+ for id, module in enumerate(self.input_blocks):
656
  h = module(h, emb, context, **kwargs)
657
+ ## plug-in adapter features
658
+ if ((id+1)%3 == 0) and features_adapter is not None:
659
+ h = h + features_adapter[adapter_idx]
660
+ adapter_idx += 1
661
  hs.append(h)
662
+ if features_adapter is not None:
663
+ assert len(features_adapter)==adapter_idx, 'Mismatch features adapter'
664
+
665
  h = self.middle_block(h, emb, context, **kwargs)
666
  for module in self.output_blocks:
667
  h = th.cat([h, hs.pop()], dim=1)
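With the UNet configuration below (model_channels 320, num_res_blocks 2, channel_mult [1, 2, 4, 4], i.e. 12 input blocks), the `(id+1) % 3 == 0` rule injects the four adapter features after input blocks 2, 5, 8 and 11 (0-indexed), exactly where the UNet runs at 320@32x32, 640@16x16, 1280@8x8 and 1280@4x4. A tiny sketch of that bookkeeping in isolation:

# Illustrative only: which input blocks receive which adapter feature.
features_adapter = ['320@32x32', '640@16x16', '1280@8x8', '1280@4x4']
adapter_idx = 0
for block_id in range(12):                      # 1 stem conv + 2 res blocks per level + 3 downsamples
    if (block_id + 1) % 3 == 0:
        print(f'input_blocks[{block_id}] += {features_adapter[adapter_idx]}')
        adapter_idx += 1
assert adapter_idx == len(features_adapter)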
lvdm/samplers/ddim.py CHANGED
@@ -197,7 +197,7 @@ class DDIMSampler(object):
197
  def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
198
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
199
  unconditional_guidance_scale=1., unconditional_conditioning=None, sample_noise=None,
200
- cond_fn=None,uc_type=None, model_kwargs={},
201
  **kwargs,
202
  ):
203
  b, *_, device = *x.shape, x.device
@@ -206,15 +206,15 @@ class DDIMSampler(object):
206
  else:
207
  is_video = False
208
  if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
209
- e_t = self.model.apply_model(x, t, c, **model_kwargs) # unet denoiser
210
  else:
211
  # with unconditional condition
212
  if isinstance(c, torch.Tensor):
213
- e_t = self.model.apply_model(x, t, c, **model_kwargs)
214
- e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **model_kwargs)
215
  elif isinstance(c, dict):
216
- e_t = self.model.apply_model(x, t, c, **model_kwargs)
217
- e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **model_kwargs)
218
  else:
219
  raise NotImplementedError
220
  # text cfg
 
197
  def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
198
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
199
  unconditional_guidance_scale=1., unconditional_conditioning=None, sample_noise=None,
200
+ cond_fn=None, uc_type=None,
201
  **kwargs,
202
  ):
203
  b, *_, device = *x.shape, x.device
 
206
  else:
207
  is_video = False
208
  if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
209
+ e_t = self.model.apply_model(x, t, c, **kwargs) # unet denoiser
210
  else:
211
  # with unconditional condition
212
  if isinstance(c, torch.Tensor):
213
+ e_t = self.model.apply_model(x, t, c, **kwargs)
214
+ e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs)
215
  elif isinstance(c, dict):
216
+ e_t = self.model.apply_model(x, t, c, **kwargs)
217
+ e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs)
218
  else:
219
  raise NotImplementedError
220
  # text cfg
lvdm/utils/saving_utils.py CHANGED
@@ -14,6 +14,24 @@ from torchvision.utils import make_grid
14
  from torch import Tensor
15
  from torchvision.transforms.functional import to_tensor
16
 
17
  # ----------------------------------------------------------------------------------------------
18
  def savenp2sheet(imgs, savepath, nrow=None):
19
  """ save multiple imgs (in numpy array type) to a img sheet.
 
14
  from torch import Tensor
15
  from torchvision.transforms.functional import to_tensor
16
 
17
+
18
+ def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
19
+ """
20
+ video: torch.Tensor of shape (b, c, t, h, w), values in [0, 1]
21
+ if the values are in [-1, 1], set rescale=True
22
+ """
23
+ n = video.shape[0]
24
+ video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
25
+ nrow = int(np.sqrt(n)) if nrow is None else nrow
26
+ frame_grids = [torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video] # [3, grid_h, grid_w]
27
+ grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w]
28
+ grid = torch.clamp(grid.float(), -1., 1.)
29
+ if rescale:
30
+ grid = (grid + 1.0) / 2.0
31
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
32
+ #print(f'Save video to {savepath}')
33
+ torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
34
+
35
  # ----------------------------------------------------------------------------------------------
36
  def savenp2sheet(imgs, savepath, nrow=None):
37
  """ save multiple imgs (in numpy array type) to a img sheet.
models/adapter_t2v_depth/model_config.yaml ADDED
@@ -0,0 +1,89 @@
1
+ model:
2
+ target: lvdm.models.ddpm3d.T2VAdapterDepth
3
+ params:
4
+ linear_start: 0.00085
5
+ linear_end: 0.012
6
+ num_timesteps_cond: 1
7
+ log_every_t: 200
8
+ timesteps: 1000
9
+ first_stage_key: video
10
+ cond_stage_key: caption
11
+ image_size:
12
+ - 32
13
+ - 32
14
+ video_length: 16
15
+ channels: 4
16
+ cond_stage_trainable: false
17
+ conditioning_key: crossattn
18
+ scale_by_std: false
19
+ scale_factor: 0.18215
20
+
21
+ unet_config:
22
+ target: lvdm.models.modules.openaimodel3d.UNetModel
23
+ params:
24
+ image_size: 32
25
+ in_channels: 4
26
+ out_channels: 4
27
+ model_channels: 320
28
+ attention_resolutions:
29
+ - 4
30
+ - 2
31
+ - 1
32
+ num_res_blocks: 2
33
+ channel_mult:
34
+ - 1
35
+ - 2
36
+ - 4
37
+ - 4
38
+ num_heads: 8
39
+ transformer_depth: 1
40
+ context_dim: 768
41
+ use_checkpoint: true
42
+ legacy: false
43
+ kernel_size_t: 1
44
+ padding_t: 0
45
+ temporal_length: 16
46
+ use_relative_position: true
47
+
48
+ first_stage_config:
49
+ target: lvdm.models.autoencoder.AutoencoderKL
50
+ params:
51
+ embed_dim: 4
52
+ monitor: val/rec_loss
53
+ ddconfig:
54
+ double_z: true
55
+ z_channels: 4
56
+ resolution: 256
57
+ in_channels: 3
58
+ out_ch: 3
59
+ ch: 128
60
+ ch_mult:
61
+ - 1
62
+ - 2
63
+ - 4
64
+ - 4
65
+ num_res_blocks: 2
66
+ attn_resolutions: []
67
+ dropout: 0.0
68
+ lossconfig:
69
+ target: torch.nn.Identity
70
+
71
+ cond_stage_config:
72
+ target: lvdm.models.modules.condition_modules.FrozenCLIPEmbedder
73
+
74
+ depth_stage_config:
75
+ target: extralibs.midas.api.MiDaSInference
76
+ params:
77
+ model_type: "dpt_hybrid"
78
+ model_path: models/adapter_t2v_depth/dpt_hybrid-midas.pt
79
+
80
+ adapter_config:
81
+ target: lvdm.models.modules.adapter.Adapter
82
+ cond_name: depth
83
+ params:
84
+ cin: 64
85
+ channels: [320, 640, 1280, 1280]
86
+ nums_rb: 2
87
+ ksize: 1
88
+ sk: True
89
+ use_conv: False
sample_adapter.sh ADDED
@@ -0,0 +1,22 @@
1
+ PROMPT="An ostrich walking in the desert, photorealistic, 4k"
2
+ VIDEO="input/flamingo.mp4"
3
+ OUTDIR="results/"
4
+
5
+ NAME="video_adapter"
6
+ CONFIG_PATH="models/adapter_t2v_depth/model_config.yaml"
7
+ BASE_PATH="models/base_t2v/model.ckpt"
8
+ ADAPTER_PATH="models/adapter_t2v_depth/adapter.pth"
9
+
10
+ python scripts/sample_text2video_adapter.py \
11
+ --seed 123 \
12
+ --ckpt_path $BASE_PATH \
13
+ --adapter_ckpt $ADAPTER_PATH \
14
+ --base $CONFIG_PATH \
15
+ --savedir $OUTDIR/$NAME \
16
+ --bs 1 --height 256 --width 256 \
17
+ --frame_stride -1 \
18
+ --unconditional_guidance_scale 15.0 \
19
+ --ddim_steps 50 \
20
+ --ddim_eta 1.0 \
21
+ --prompt "$PROMPT" \
22
+ --video $VIDEO
scripts/ddp_wrapper.py ADDED
@@ -0,0 +1,47 @@
1
+ import datetime
2
+ import argparse, importlib
3
+ from pytorch_lightning import seed_everything
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+
9
+ def setup_dist(local_rank):
10
+ if dist.is_initialized():
11
+ return
12
+ torch.cuda.set_device(local_rank)
13
+ torch.distributed.init_process_group('nccl', init_method='env://')
14
+
15
+
16
+ def get_dist_info():
17
+ if dist.is_available():
18
+ initialized = dist.is_initialized()
19
+ else:
20
+ initialized = False
21
+ if initialized:
22
+ rank = dist.get_rank()
23
+ world_size = dist.get_world_size()
24
+ else:
25
+ rank = 0
26
+ world_size = 1
27
+ return rank, world_size
28
+
29
+
30
+ if __name__ == '__main__':
31
+ now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument("--module", type=str, help="module name", default="inference")
34
+ parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0)
35
+ args, unknown = parser.parse_known_args()
36
+ inference_api = importlib.import_module(args.module, package=None)
37
+
38
+ inference_parser = inference_api.get_parser()
39
+ inference_args, unknown = inference_parser.parse_known_args()
40
+
41
+ seed_everything(inference_args.seed)
42
+ setup_dist(args.local_rank)
43
+ torch.backends.cudnn.benchmark = True
44
+ rank, gpu_num = get_dist_info()
45
+
46
+ print("@CoVideoGen Inference [rank%d]: %s"%(rank, now))
47
+ inference_api.run_inference(inference_args, rank)
scripts/sample_text2video_adapter.py ADDED
@@ -0,0 +1,206 @@
1
+ import argparse, os, sys, glob
2
+ import datetime, time
3
+ from omegaconf import OmegaConf
4
+
5
+ import torch
6
+ from decord import VideoReader, cpu
7
+ import torchvision
8
+ from pytorch_lightning import seed_everything
9
+
10
+ from lvdm.samplers.ddim import DDIMSampler
11
+ from lvdm.utils.common_utils import instantiate_from_config
12
+ from lvdm.utils.saving_utils import tensor_to_mp4
13
+
14
+
15
+ def get_filelist(data_dir, ext='*'):
16
+ file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext))
17
+ file_list.sort()
18
+ return file_list
19
+
20
+ def load_model_checkpoint(model, ckpt, adapter_ckpt=None):
21
+ print('>>> Loading checkpoints ...')
22
+ if adapter_ckpt:
23
+ ## main model
24
+ state_dict = torch.load(ckpt, map_location="cpu")
25
+ if "state_dict" in list(state_dict.keys()):
26
+ state_dict = state_dict["state_dict"]
27
+ model.load_state_dict(state_dict, strict=False)
28
+ print('@model checkpoint loaded.')
29
+ ## adapter
30
+ state_dict = torch.load(adapter_ckpt, map_location="cpu")
31
+ if "state_dict" in list(state_dict.keys()):
32
+ state_dict = state_dict["state_dict"]
33
+ model.adapter.load_state_dict(state_dict, strict=True)
34
+ print('@adapter checkpoint loaded.')
35
+ else:
36
+ state_dict = torch.load(ckpt, map_location="cpu")
37
+ if "state_dict" in list(state_dict.keys()):
38
+ state_dict = state_dict["state_dict"]
39
+ model.load_state_dict(state_dict, strict=True)
40
+ print('@model checkpoint loaded.')
41
+ return model
42
+
43
+ def load_prompts(prompt_file):
44
+ f = open(prompt_file, 'r')
45
+ prompt_list = []
46
+ for idx, line in enumerate(f.readlines()):
47
+ l = line.strip()
48
+ if len(l) != 0:
49
+ prompt_list.append(l)
50
+ f.close()
51
+ return prompt_list
52
+
53
+ def load_video(filepath, frame_stride, video_size=(256,256), video_frames=16):
54
+ vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
55
+ max_frames = len(vidreader)
56
+ temp_stride = max_frames // video_frames if frame_stride == -1 else frame_stride
57
+ if temp_stride * (video_frames-1) >= max_frames:
58
+ print(f'Warning: falling back to the default frame stride because the input video ({max_frames} frames) is too short.')
59
+ temp_stride = max_frames // video_frames
60
+ frame_indices = [temp_stride*i for i in range(video_frames)]
61
+ frames = vidreader.get_batch(frame_indices)
62
+
63
+ ## [t,h,w,c] -> [c,t,h,w]
64
+ frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
65
+ frame_tensor = (frame_tensor / 255. - 0.5) * 2
66
+ return frame_tensor
67
+
68
+
69
+ def save_results(prompt, samples, inputs, filename, realdir, fakedir, fps=10):
70
+ ## save prompt
71
+ prompt = prompt[0] if isinstance(prompt, list) else prompt
72
+ path = os.path.join(realdir, "%s.txt"%filename)
73
+ with open(path, 'w') as f:
74
+ f.write(f'{prompt}')
75
+ f.close()
76
+
77
+ ## save video
78
+ videos = [inputs, samples]
79
+ savedirs = [realdir, fakedir]
80
+ for idx, video in enumerate(videos):
81
+ if video is None:
82
+ continue
83
+ # b,c,t,h,w
84
+ video = video.detach().cpu()
85
+ video = torch.clamp(video.float(), -1., 1.)
86
+ n = video.shape[0]
87
+ video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
88
+ frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n)) for framesheet in video] #[3, 1*h, n*w]
89
+ grid = torch.stack(frame_grids, dim=0) # stack in temporal dim -> [t, 3, h, n*w]
90
+ grid = (grid + 1.0) / 2.0
91
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
92
+ path = os.path.join(savedirs[idx], "%s.mp4"%filename)
93
+ torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'})
94
+
95
+
96
+ def adapter_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
97
+ unconditional_guidance_scale=1.0, unconditional_guidance_scale_temporal=None, **kwargs):
98
+ ddim_sampler = DDIMSampler(model)
99
+
100
+ batch_size = noise_shape[0]
101
+ ## get condition embeddings (support single prompt only)
102
+ if isinstance(prompts, str):
103
+ prompts = [prompts]
104
+ cond = model.get_learned_conditioning(prompts)
105
+ if unconditional_guidance_scale != 1.0:
106
+ prompts = batch_size * [""]
107
+ uc = model.get_learned_conditioning(prompts)
108
+ else:
109
+ uc = None
110
+
111
+ ## adapter features: process in 2D manner
112
+ b, c, t, h, w = videos.shape
113
+ extra_cond = model.get_batch_depth(videos, (h,w))
114
+ features_adapter = model.get_adapter_features(extra_cond)
115
+
116
+ batch_variants = []
117
+ for _ in range(n_samples):
118
+ if ddim_sampler is not None:
119
+ samples, _ = ddim_sampler.sample(S=ddim_steps,
120
+ conditioning=cond,
121
+ batch_size=noise_shape[0],
122
+ shape=noise_shape[1:],
123
+ verbose=False,
124
+ unconditional_guidance_scale=unconditional_guidance_scale,
125
+ unconditional_conditioning=uc,
126
+ eta=ddim_eta,
127
+ temporal_length=noise_shape[2],
128
+ conditional_guidance_scale_temporal=unconditional_guidance_scale_temporal,
129
+ features_adapter=features_adapter,
130
+ **kwargs
131
+ )
132
+ ## reconstruct from latent to pixel space
133
+ batch_images = model.decode_first_stage(samples, decode_bs=1, return_cpu=False)
134
+ batch_variants.append(batch_images)
135
+ ## variants, batch, c, t, h, w
136
+ batch_variants = torch.stack(batch_variants)
137
+ return batch_variants.permute(1, 0, 2, 3, 4, 5), extra_cond
138
+
139
+
140
+ def run_inference(args, gpu_idx):
141
+ ## model config
142
+ config = OmegaConf.load(args.base)
143
+ model_config = config.pop("model", OmegaConf.create())
144
+ model = instantiate_from_config(model_config)
145
+ model = model.cuda(gpu_idx)
146
+ assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
147
+ model = load_model_checkpoint(model, args.ckpt_path, args.adapter_ckpt)
148
+ model.eval()
149
+
150
+ ## run over data
151
+ assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
152
+ ## latent noise shape
153
+ h, w = args.height // 8, args.width // 8
154
+ channels = model.channels
155
+ frames = model.temporal_length
156
+ noise_shape = [args.bs, channels, frames, h, w]
157
+
158
+ ## inference
159
+ start = time.time()
160
+ prompt = args.prompt
161
+ video = load_video(args.video, args.frame_stride, video_size=(args.height, args.width), video_frames=16)
162
+ video = video.unsqueeze(0).to("cuda")
163
+ with torch.no_grad():
164
+ batch_samples, batch_conds = adapter_guided_synthesis(model, prompt, video, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \
165
+ args.unconditional_guidance_scale, args.unconditional_guidance_scale_temporal)
166
+ batch_samples = batch_samples[0]
167
+ os.makedirs(args.savedir, exist_ok=True)
168
+ filename = f"{args.prompt}_seed{args.seed}"
169
+ filename = filename.replace("/", "_slash_") if "/" in filename else filename
170
+ filename = filename.replace(" ", "_") if " " in filename else filename
171
+ tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=os.path.join(args.savedir, f'{filename}_depth.mp4'), fps=10)
172
+ tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=os.path.join(args.savedir, f'{filename}_sample.mp4'), fps=10)
173
+
174
+ print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")
175
+
176
+
177
+ def get_parser():
178
+ parser = argparse.ArgumentParser()
179
+ parser.add_argument("--savedir", type=str, default=None, help="results saving path")
180
+ parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
181
+ parser.add_argument("--adapter_ckpt", type=str, default=None, help="adapter checkpoint path")
182
+ parser.add_argument("--base", type=str, help="config (yaml) path")
183
+ parser.add_argument("--prompt", type=str, default=None, help="prompt string")
184
+ parser.add_argument("--video", type=str, default=None, help="video path")
185
+ parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",)
186
+ parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",)
187
+ parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",)
188
+ parser.add_argument("--bs", type=int, default=1, help="batch size for inference")
189
+ parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
190
+ parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
191
+ parser.add_argument("--frame_stride", type=int, default=-1, help="frame extracting from input video")
192
+ parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
193
+ parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None, help="temporal consistency guidance")
194
+ parser.add_argument("--seed", type=int, default=2023, help="seed for seed_everything")
195
+ return parser
196
+
197
+
198
+ if __name__ == '__main__':
199
+ now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
200
+ print("@CoVideoGen cond-Inference: %s"%now)
201
+ parser = get_parser()
202
+ args = parser.parse_args()
203
+
204
+ seed_everything(args.seed)
205
+ rank = 0
206
+ run_inference(args, rank)