einsafutdinov commited on
Commit
201e424
β€’
1 Parent(s): 735bc3b

add flash3d model and unidepth code

Browse files
Files changed (49) hide show
  1. flash3d/networks/depth_decoder.py +81 -0
  2. flash3d/networks/gaussian_decoder.py +196 -0
  3. flash3d/networks/gaussian_predictor.py +293 -0
  4. flash3d/networks/layers.py +295 -0
  5. flash3d/networks/resnet_encoder.py +115 -0
  6. flash3d/networks/unidepth.py +577 -0
  7. flash3d/networks/unidepth_extension.py +205 -0
  8. flash3d/unidepth/layers/__init__.py +21 -0
  9. flash3d/unidepth/layers/activation.py +15 -0
  10. flash3d/unidepth/layers/attention.py +308 -0
  11. flash3d/unidepth/layers/convnext.py +44 -0
  12. flash3d/unidepth/layers/drop_path.py +25 -0
  13. flash3d/unidepth/layers/layer_scale.py +17 -0
  14. flash3d/unidepth/layers/mlp.py +34 -0
  15. flash3d/unidepth/layers/nystrom_attention.py +74 -0
  16. flash3d/unidepth/layers/positional_encoding.py +228 -0
  17. flash3d/unidepth/layers/upsample.py +69 -0
  18. flash3d/unidepth/models/__init__.py +5 -0
  19. flash3d/unidepth/models/backbones/__init__.py +9 -0
  20. flash3d/unidepth/models/backbones/convnext.py +590 -0
  21. flash3d/unidepth/models/backbones/convnext2.py +288 -0
  22. flash3d/unidepth/models/backbones/dinov2.py +552 -0
  23. flash3d/unidepth/models/backbones/metadinov2/__init__.py +12 -0
  24. flash3d/unidepth/models/backbones/metadinov2/attention.py +85 -0
  25. flash3d/unidepth/models/backbones/metadinov2/block.py +284 -0
  26. flash3d/unidepth/models/backbones/metadinov2/dino_head.py +68 -0
  27. flash3d/unidepth/models/backbones/metadinov2/drop_path.py +37 -0
  28. flash3d/unidepth/models/backbones/metadinov2/layer_scale.py +28 -0
  29. flash3d/unidepth/models/backbones/metadinov2/mlp.py +41 -0
  30. flash3d/unidepth/models/backbones/metadinov2/patch_embed.py +101 -0
  31. flash3d/unidepth/models/backbones/metadinov2/swiglu_ffn.py +63 -0
  32. flash3d/unidepth/models/encoder.py +184 -0
  33. flash3d/unidepth/models/unidepthv1/__init__.py +5 -0
  34. flash3d/unidepth/models/unidepthv1/decoder.py +542 -0
  35. flash3d/unidepth/models/unidepthv1/unidepthv1.py +329 -0
  36. flash3d/unidepth/ops/__init__.py +9 -0
  37. flash3d/unidepth/ops/losses.py +429 -0
  38. flash3d/unidepth/ops/scheduler.py +70 -0
  39. flash3d/unidepth/utils/__init__.py +35 -0
  40. flash3d/unidepth/utils/constants.py +21 -0
  41. flash3d/unidepth/utils/distributed.py +179 -0
  42. flash3d/unidepth/utils/ema_torch.py +342 -0
  43. flash3d/unidepth/utils/evaluation_depth.py +173 -0
  44. flash3d/unidepth/utils/geometric.py +248 -0
  45. flash3d/unidepth/utils/misc.py +403 -0
  46. flash3d/unidepth/utils/positional_embedding.py +274 -0
  47. flash3d/unidepth/utils/sht.py +1637 -0
  48. flash3d/unidepth/utils/visualization.py +201 -0
  49. flash3d/util/vis3d.py +135 -0
flash3d/networks/depth_decoder.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Niantic 2019. Patent Pending. All rights reserved.
2
+ #
3
+ # This software is licensed under the terms of the Monodepth2 licence
4
+ # which allows for non-commercial use only, the full terms of which are made
5
+ # available in the LICENSE file.
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from collections import OrderedDict
12
+ from networks.layers import upsample, ConvBlock, Conv3x3
13
+
14
+ from einops import rearrange
15
+
16
+
17
+ class DepthDecoder(nn.Module):
18
+ def __init__(self, cfg, num_ch_enc, num_output_channels=1, use_skips=True):
19
+ super(DepthDecoder, self).__init__()
20
+
21
+ self.cfg = cfg
22
+ depth_num = cfg.model.gaussians_per_pixel - 1 if "unidepth" in cfg.model.name else cfg.model.gaussians_per_pixel
23
+ self.num_output_channels = num_output_channels * depth_num
24
+ self.use_skips = use_skips
25
+ self.upsample_mode = 'nearest'
26
+ self.scales = cfg.model.scales
27
+
28
+ self.num_ch_enc = num_ch_enc
29
+ self.num_ch_dec = np.array([16, 32, 64, 128, 256])
30
+
31
+ # decoder
32
+ self.convs = OrderedDict()
33
+ for i in range(4, -1, -1):
34
+ # upconv_0
35
+ num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
36
+ num_ch_out = self.num_ch_dec[i]
37
+ self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
38
+
39
+ # upconv_1
40
+ num_ch_in = self.num_ch_dec[i]
41
+ if self.use_skips and i > 0:
42
+ num_ch_in += self.num_ch_enc[i - 1]
43
+ num_ch_out = self.num_ch_dec[i]
44
+ self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
45
+
46
+ for s in self.scales:
47
+ out = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
48
+ self.convs[("dispconv", s)] = out
49
+ nn.init.xavier_uniform_(out.conv.weight, cfg.model.depth_scale)
50
+ nn.init.constant_(out.conv.bias, cfg.model.depth_bias)
51
+
52
+ self.decoder = nn.ModuleList(list(self.convs.values()))
53
+ if cfg.model.depth_type in ["disp", "disp_inc"]:
54
+ self.activate = nn.Sigmoid()
55
+ elif cfg.model.depth_type == "depth":
56
+ self.activate = nn.Softplus()
57
+ elif cfg.model.depth_type == "depth_inc":
58
+ self.activate = torch.exp
59
+
60
+ def forward(self, input_features):
61
+ outputs = {}
62
+ x = input_features[-1]
63
+ for i in range(4, -1, -1):
64
+ x = self.convs[("upconv", i, 0)](x)
65
+ x = [upsample(x)]
66
+ if self.use_skips and i > 0:
67
+ x += [input_features[i - 1]]
68
+ x = torch.cat(x, 1)
69
+ x = self.convs[("upconv", i, 1)](x)
70
+ if i in self.scales:
71
+ depth_num = self.cfg.model.gaussians_per_pixel - 1 if "unidepth" in self.cfg.model.name else self.cfg.model.gaussians_per_pixel
72
+ if self.cfg.model.depth_type == "depth_inc":
73
+ outputs[("depth", i)] = rearrange(self.activate(torch.clamp(self.convs[("dispconv", i)](x), min=-10.0, max=6.0)),
74
+ 'b (n c) ...-> (b n) c ...', n = depth_num)
75
+ elif self.cfg.model.depth_type in ["disp", "disp_inc"]:
76
+ outputs[("disp", i)] = rearrange(self.activate(self.convs[("dispconv", i)](x)),
77
+ 'b (n c) ...-> (b n) c ...', n = depth_num)
78
+ else:
79
+ outputs[(self.cfg.model.depth_type, i)] = rearrange(self.activate(self.convs[("dispconv", i)](x)),
80
+ 'b (n c) ...-> (b n) c ...', n = depth_num)
81
+ return outputs
flash3d/networks/gaussian_decoder.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+
8
+ def upsample(x):
9
+ """Upsample input tensor by a factor of 2
10
+ """
11
+ return F.interpolate(x, scale_factor=2, mode="nearest")
12
+
13
+
14
+ class Conv3x3(nn.Module):
15
+ """Layer to pad and convolve input
16
+ """
17
+ def __init__(self, in_channels, out_channels, use_refl=True):
18
+ super(Conv3x3, self).__init__()
19
+
20
+ if use_refl:
21
+ self.pad = nn.ReflectionPad2d(1)
22
+ else:
23
+ self.pad = nn.ZeroPad2d(1)
24
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
25
+
26
+ def forward(self, x):
27
+ out = self.pad(x)
28
+ out = self.conv(out)
29
+ return out
30
+
31
+
32
+ class ConvBlock(nn.Module):
33
+ """Layer to perform a convolution followed by ELU
34
+ """
35
+ def __init__(self, in_channels, out_channels):
36
+ super(ConvBlock, self).__init__()
37
+
38
+ self.conv = Conv3x3(in_channels, out_channels)
39
+ self.nonlin = nn.ELU(inplace=True)
40
+
41
+ def forward(self, x):
42
+ out = self.conv(x)
43
+ out = self.nonlin(out)
44
+ return out
45
+
46
+
47
+ def get_splits_and_inits(cfg):
48
+ split_dimensions = []
49
+ scale_inits = []
50
+ bias_inits = []
51
+
52
+ for g_idx in range(cfg.model.gaussians_per_pixel):
53
+ if cfg.model.predict_offset:
54
+ split_dimensions += [3]
55
+ scale_inits += [cfg.model.xyz_scale]
56
+ bias_inits += [cfg.model.xyz_bias]
57
+
58
+ split_dimensions += [1, 3, 4, 3]
59
+ scale_inits += [cfg.model.opacity_scale,
60
+ cfg.model.scale_scale,
61
+ 1.0,
62
+ 5.0]
63
+ bias_inits += [cfg.model.opacity_bias,
64
+ np.log(cfg.model.scale_bias),
65
+ 0.0,
66
+ 0.0]
67
+
68
+ if cfg.model.max_sh_degree != 0:
69
+ sh_num = (cfg.model.max_sh_degree + 1) ** 2 - 1
70
+ sh_num_rgb = sh_num * 3
71
+ split_dimensions.append(sh_num_rgb)
72
+ scale_inits.append(cfg.model.sh_scale)
73
+ bias_inits.append(0.0)
74
+ if not cfg.model.one_gauss_decoder:
75
+ break
76
+
77
+ return split_dimensions, scale_inits, bias_inits,
78
+
79
+
80
+ class GaussianDecoder(nn.Module):
81
+ def __init__(self, cfg, num_ch_enc, use_skips=True):
82
+ super(GaussianDecoder, self).__init__()
83
+
84
+ self.cfg = cfg
85
+ self.use_skips = use_skips
86
+ self.upsample_mode = 'nearest'
87
+
88
+ self.num_ch_enc = num_ch_enc
89
+ self.num_ch_dec = np.array(cfg.model.num_ch_dec)
90
+
91
+ split_dimensions, scale, bias = get_splits_and_inits(cfg)
92
+
93
+ # [offset], opacity, scaling, rotation, feat_dc
94
+ assert not cfg.model.unified_decoder
95
+
96
+ self.split_dimensions = split_dimensions
97
+
98
+ self.num_output_channels = sum(self.split_dimensions)
99
+
100
+ # decoder
101
+ self.convs = OrderedDict()
102
+ for i in range(4, -1, -1):
103
+ # upconv_0
104
+ num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
105
+ num_ch_out = self.num_ch_dec[i]
106
+ self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
107
+
108
+ # upconv_1
109
+ num_ch_in = self.num_ch_dec[i]
110
+ if self.use_skips and i > 0:
111
+ num_ch_in += self.num_ch_enc[i - 1]
112
+ num_ch_out = self.num_ch_dec[i]
113
+ self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
114
+
115
+ self.out = nn.Conv2d(self.num_ch_dec[0], self.num_output_channels, 1)
116
+
117
+ out_channels = self.split_dimensions
118
+ start_channels = 0
119
+ for out_channel, b, s in zip(out_channels, bias, scale):
120
+ nn.init.xavier_uniform_(
121
+ self.out.weight[start_channels:start_channels+out_channel,
122
+ :, :, :], s)
123
+ nn.init.constant_(
124
+ self.out.bias[start_channels:start_channels+out_channel], b)
125
+ start_channels += out_channel
126
+
127
+ self.decoder = nn.ModuleList(list(self.convs.values()))
128
+
129
+ self.scaling_activation = torch.exp
130
+ self.opacity_activation = torch.sigmoid
131
+ self.rotation_activation = torch.nn.functional.normalize
132
+ self.scaling_lambda = cfg.model.scale_lambda
133
+ self.sigmoid = nn.Sigmoid()
134
+
135
+ def forward(self, input_features):
136
+ self.outputs = {}
137
+
138
+ # decoder
139
+ x = input_features[-1]
140
+ for i in range(4, -1, -1):
141
+ x = self.convs[("upconv", i, 0)](x)
142
+ x = [upsample(x)]
143
+ if self.use_skips and i > 0:
144
+ x += [input_features[i - 1]]
145
+ x = torch.cat(x, 1)
146
+ x = self.convs[("upconv", i, 1)](x)
147
+
148
+ x = self.out(x)
149
+
150
+ split_network_outputs = x.split(self.split_dimensions, dim=1)
151
+
152
+ offset_list = []
153
+ opacity_list = []
154
+ scaling_list = []
155
+ rotation_list = []
156
+ feat_dc_list = []
157
+ feat_rest_list = []
158
+
159
+ assert not self.cfg.model.unified_decoder
160
+
161
+ for i in range(self.cfg.model.gaussians_per_pixel):
162
+ assert self.cfg.model.max_sh_degree > 0
163
+ if self.cfg.model.predict_offset:
164
+ offset_s, opacity_s, scaling_s, \
165
+ rotation_s, feat_dc_s, features_rest_s = split_network_outputs[i*6:(i+1)*6]
166
+ offset_list.append(offset_s[:, None, ...])
167
+ else:
168
+ opacity_s, scaling_s, rotation_s, feat_dc_s, features_rest_s = split_network_outputs[i*5:(i+1)*5]
169
+ opacity_list.append(opacity_s[:, None, ...])
170
+ scaling_list.append(scaling_s[:, None, ...])
171
+ rotation_list.append(rotation_s[:, None, ...])
172
+ feat_dc_list.append(feat_dc_s[:, None, ...])
173
+ feat_rest_list.append(features_rest_s[:, None, ...])
174
+ if not self.cfg.model.one_gauss_decoder:
175
+ break
176
+
177
+ # squeezing will remove dimension if there is only one gaussian per pixel
178
+ opacity = torch.cat(opacity_list, dim=1).squeeze(1)
179
+ scaling = torch.cat(scaling_list, dim=1).squeeze(1)
180
+ rotation = torch.cat(rotation_list, dim=1).squeeze(1)
181
+ feat_dc = torch.cat(feat_dc_list, dim=1).squeeze(1)
182
+ features_rest = torch.cat(feat_rest_list, dim=1).squeeze(1)
183
+
184
+ out = {
185
+ ("gauss_opacity", 0): self.opacity_activation(opacity),
186
+ ("gauss_scaling", 0): self.scaling_activation(scaling) * self.scaling_lambda,
187
+ ("gauss_rotation", 0): self.rotation_activation(rotation),
188
+ ("gauss_features_dc", 0): feat_dc,
189
+ ("gauss_features_rest", 0): features_rest
190
+ }
191
+
192
+ if self.cfg.model.predict_offset:
193
+ offset = torch.cat(offset_list, dim=1).squeeze(1)
194
+ out[("gauss_offset", 0)] = offset
195
+ return out
196
+
flash3d/networks/gaussian_predictor.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import logging
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from networks.layers import BackprojectDepth, disp_to_depth
9
+ from networks.resnet_encoder import ResnetEncoder
10
+ from networks.depth_decoder import DepthDecoder
11
+ from networks.gaussian_decoder import GaussianDecoder
12
+
13
+
14
+ def default_param_group(model):
15
+ return [{'params': model.parameters()}]
16
+
17
+
18
+ def to_device(inputs, device):
19
+ for key, ipt in inputs.items():
20
+ if isinstance(ipt, torch.Tensor):
21
+ inputs[key] = ipt.to(device)
22
+ return inputs
23
+
24
+
25
+ class GaussianPredictor(nn.Module):
26
+ def __init__(self, cfg):
27
+ super().__init__()
28
+ self.cfg = cfg
29
+
30
+ # checking height and width are multiples of 32
31
+ # assert cfg.dataset.width % 32 == 0, "'width' must be a multiple of 32"
32
+
33
+ models = {}
34
+ self.parameters_to_train = []
35
+
36
+ self.num_scales = len(cfg.model.scales)
37
+
38
+ assert cfg.model.frame_ids[0] == 0, "frame_ids must start with 0"
39
+
40
+ if cfg.model.use_stereo:
41
+ cfg.model.frame_ids.append("s")
42
+
43
+ model_name = cfg.model.name
44
+ if model_name == "resnet":
45
+ models["encoder"] = ResnetEncoder(
46
+ cfg.model.num_layers,
47
+ cfg.model.weights_init == "pretrained",
48
+ cfg.model.resnet_bn_order
49
+ )
50
+ self.parameters_to_train += default_param_group(models["encoder"])
51
+ if not cfg.model.unified_decoder:
52
+ models["depth"] = DepthDecoder(
53
+ cfg, models["encoder"].num_ch_enc)
54
+ self.parameters_to_train += default_param_group(models["depth"])
55
+ if cfg.model.gaussian_rendering:
56
+ for i in range(cfg.model.gaussians_per_pixel):
57
+ gauss_decoder = GaussianDecoder(
58
+ cfg, models["encoder"].num_ch_enc,
59
+ )
60
+ self.parameters_to_train += default_param_group(gauss_decoder)
61
+ models["gauss_decoder_"+str(i)] = gauss_decoder
62
+ elif model_name == "unidepth":
63
+ from networks.unidepth import UniDepthSplatter
64
+ models["unidepth"] = UniDepthSplatter(cfg)
65
+ self.parameters_to_train += models["unidepth"].get_parameter_groups()
66
+ elif model_name in ["unidepth_unprojector_vit", "unidepth_unprojector_cnvnxtl"]:
67
+ from networks.unidepth import UniDepthUnprojector
68
+ models["unidepth"] = UniDepthUnprojector(cfg)
69
+ self.parameters_to_train += models["unidepth"].get_parameter_groups()
70
+ elif model_name in ["unidepth_extension_vit", "unidepth_extension_cnvnxtl"]:
71
+ from networks.unidepth_extension import UniDepthExtended
72
+ models["unidepth_extended"] = UniDepthExtended(cfg)
73
+ self.parameters_to_train += models["unidepth_extended"].get_parameter_groups()
74
+
75
+ self.models = nn.ModuleDict(models)
76
+
77
+ backproject_depth = {}
78
+ H = cfg.dataset.height
79
+ W = cfg.dataset.width
80
+ for scale in cfg.model.scales:
81
+ h = H // (2 ** scale)
82
+ w = W // (2 ** scale)
83
+ if cfg.model.shift_rays_half_pixel == "zero":
84
+ shift_rays_half_pixel = 0
85
+ elif cfg.model.shift_rays_half_pixel == "forward":
86
+ shift_rays_half_pixel = 0.5
87
+ elif cfg.model.shift_rays_half_pixel == "backward":
88
+ shift_rays_half_pixel = -0.5
89
+ else:
90
+ raise NotImplementedError
91
+ backproject_depth[str(scale)] = BackprojectDepth(
92
+ cfg.optimiser.batch_size * cfg.model.gaussians_per_pixel,
93
+ # backprojection can be different if padding was used
94
+ h + 2 * self.cfg.dataset.pad_border_aug,
95
+ w + 2 * self.cfg.dataset.pad_border_aug,
96
+ shift_rays_half_pixel=shift_rays_half_pixel
97
+ )
98
+ self.backproject_depth = nn.ModuleDict(backproject_depth)
99
+
100
+ def set_train(self):
101
+ """Convert all models to training mode
102
+ """
103
+ for m in self.models.values():
104
+ m.train()
105
+ self._is_train = True
106
+
107
+ def set_eval(self):
108
+ """Convert all models to testing/evaluation mode
109
+ """
110
+ for m in self.models.values():
111
+ m.eval()
112
+ self._is_train = False
113
+
114
+ def is_train(self):
115
+ return self._is_train
116
+
117
+ def forward(self, inputs):
118
+ cfg = self.cfg
119
+ B = cfg.optimiser.batch_size
120
+
121
+ if cfg.model.name == "resnet":
122
+ do_flip = self.is_train() and \
123
+ cfg.train.lazy_flip_augmentation and \
124
+ (torch.rand(1) > .5).item()
125
+ # Otherwise, we only feed the image with frame_id 0 through the depth encoder
126
+ input_img = inputs["color_aug", 0, 0]
127
+ if do_flip:
128
+ input_img = torch.flip(input_img, dims=(-1, ))
129
+ features = self.models["encoder"](input_img)
130
+ if not cfg.model.unified_decoder:
131
+ outputs = self.models["depth"](features)
132
+ else:
133
+ outputs = dict()
134
+
135
+ if self.cfg.model.gaussian_rendering:
136
+ # gauss_feats = self.models["gauss_encoder"](inputs["color_aug", 0, 0])
137
+ input_f_id = 0
138
+ gauss_feats = features
139
+ gauss_outs = dict()
140
+ for i in range(self.cfg.model.gaussians_per_pixel):
141
+ outs = self.models["gauss_decoder_"+str(i)](gauss_feats)
142
+ for key, v in outs.items():
143
+ gauss_outs[key] = outs[key][:,None,...] if i==0 else torch.cat([gauss_outs[key], outs[key][:,None,...]], dim=1)
144
+ for key, v in gauss_outs.items():
145
+ gauss_outs[key] = rearrange(gauss_outs[key], 'b n ... -> (b n) ...')
146
+ outputs |= gauss_outs
147
+ outputs = {(key[0], input_f_id, key[1]): v for key, v in outputs.items()}
148
+ else:
149
+ for scale in cfg.model.scales:
150
+ outputs[("disp", 0, scale)] = outputs[("disp", scale)]
151
+
152
+ # unflip all outputs
153
+ if do_flip:
154
+ for k, v in outputs.items():
155
+ outputs[k] = torch.flip(v, dims=(-1, ))
156
+ elif "unidepth" in cfg.model.name:
157
+ if cfg.model.name in ["unidepth",
158
+ "unidepth_unprojector_vit",
159
+ "unidepth_unprojector_cnvnxtl"]:
160
+ outputs = self.models["unidepth"](inputs)
161
+ elif cfg.model.name in ["unidepth_extension_vit",
162
+ "unidepth_extension_cnvnxtl"]:
163
+ outputs = self.models["unidepth_extended"](inputs)
164
+
165
+ input_f_id = 0
166
+ outputs = {(key[0], input_f_id, key[1]): v for key, v in outputs.items()}
167
+
168
+ input_f_id = 0
169
+ scale = 0
170
+ if not ("depth", input_f_id, scale) in outputs:
171
+ disp = outputs[("disp", input_f_id, scale)]
172
+ _, depth = disp_to_depth(disp, cfg.model.min_depth, cfg.model.max_depth)
173
+ outputs[("depth", input_f_id, scale)] = depth
174
+
175
+ self.compute_gauss_means(inputs, outputs)
176
+
177
+ return outputs
178
+
179
+ def target_tensor_image_dims(self, inputs):
180
+ B, _, H, W = inputs["color", 0, 0].shape
181
+ return B, H, W
182
+
183
+ def compute_gauss_means(self, inputs, outputs):
184
+ cfg = self.cfg
185
+ input_f_id = 0
186
+ scale = 0
187
+ depth = outputs[("depth", input_f_id, scale)]
188
+ B, _, H, W = depth.shape
189
+ if ("inv_K_src", scale) in inputs:
190
+ inv_K = inputs[("inv_K_src", scale)]
191
+ else:
192
+ inv_K = outputs[("inv_K_src", input_f_id, scale)]
193
+ if self.cfg.model.gaussians_per_pixel > 1:
194
+ inv_K = rearrange(inv_K[:,None,...].
195
+ repeat(1, self.cfg.model.gaussians_per_pixel, 1, 1),
196
+ 'b n ... -> (b n) ...')
197
+ xyz = self.backproject_depth[str(scale)](
198
+ depth, inv_K
199
+ )
200
+ inputs[("inv_K_src", scale)] = inv_K
201
+ if cfg.model.predict_offset:
202
+ offset = outputs[("gauss_offset", input_f_id, scale)]
203
+ if cfg.model.scaled_offset:
204
+ offset = offset * depth.detach()
205
+ offset = offset.view(B, 3, -1)
206
+ zeros = torch.zeros(B, 1, H * W, device=depth.device)
207
+ offset = torch.cat([offset, zeros], 1)
208
+ xyz = xyz + offset # [B, 4, W*H]
209
+ outputs[("gauss_means", input_f_id, scale)] = xyz
210
+
211
+ def checkpoint_dir(self):
212
+ return Path("checkpoints")
213
+
214
+ def save_model(self, optimizer, step, ema=None):
215
+ """Save model weights to disk
216
+ """
217
+ save_folder = self.checkpoint_dir()
218
+ save_folder.mkdir(exist_ok=True, parents=True)
219
+
220
+ save_path = save_folder / f"model_{step:07}.pth"
221
+ logging.info(f"saving checkpoint to {str(save_path)}")
222
+
223
+ model = ema.ema_model if ema is not None else self
224
+ save_dict = {
225
+ "model": model.state_dict(),
226
+ "version": "1.0",
227
+ "optimiser": optimizer.state_dict(),
228
+ "step": step
229
+ }
230
+ torch.save(save_dict, save_path)
231
+
232
+ num_ckpts = self.cfg.optimiser.num_keep_ckpts
233
+ ckpts = sorted(list(save_folder.glob("model_*.pth")), reverse=True)
234
+ if len(ckpts) > num_ckpts:
235
+ for ckpt in ckpts[num_ckpts:]:
236
+ ckpt.unlink()
237
+
238
+ def load_model(self, weights_path, optimizer=None):
239
+ """Load model(s) from disk
240
+ """
241
+ weights_path = Path(weights_path)
242
+
243
+ # determine if it is an old or new saving format
244
+ if weights_path.is_dir() and weights_path.joinpath("encoder.pth").exists():
245
+ self.load_model_old(weights_path, optimizer)
246
+ return
247
+
248
+ logging.info(f"Loading weights from {weights_path}...")
249
+ state_dict = torch.load(weights_path)
250
+ if "version" in state_dict and state_dict["version"] == "1.0":
251
+ new_dict = {}
252
+ for k, v in state_dict["model"].items():
253
+ if "backproject_depth" in k:
254
+ new_dict[k] = self.state_dict()[k].clone()
255
+ else:
256
+ new_dict[k] = v.clone()
257
+ # for k, v in state_dict["model"].items():
258
+ # if "backproject_depth" in k and ("pix_coords" in k or "ones" in k):
259
+ # # model has these parameters set as a function of batch size
260
+ # # when batch size changes in eval this results in a loading error
261
+ # state_dict["model"][k] = v[:1, ...]
262
+ self.load_state_dict(new_dict, strict=False)
263
+ else:
264
+ # TODO remove loading according to the old format
265
+ for name in self.cfg.train.models_to_load:
266
+ if name not in self.models:
267
+ continue
268
+ self.models[name].load_state_dict(state_dict[name])
269
+
270
+ # loading adam state
271
+ if optimizer is not None:
272
+ optimizer.load_state_dict(state_dict["optimiser"])
273
+ self.step = state_dict["step"]
274
+
275
+ def load_model_old(self, weights_folder, optimizer=None):
276
+ for n in self.cfg.train.models_to_load:
277
+ print(f"Loading {n} weights...")
278
+ path = weights_folder / f"{n}.pth"
279
+ if n not in self.models:
280
+ continue
281
+ model_dict = self.models[n].state_dict()
282
+ pretrained_dict = torch.load(path)
283
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
284
+ model_dict.update(pretrained_dict)
285
+ self.models[n].load_state_dict(model_dict)
286
+
287
+ # loading adam state
288
+ optimizer_load_path = weights_folder / "adam.pth"
289
+ if optimizer is not None and optimizer_load_path.is_file():
290
+ print("Loading Adam weights")
291
+ optimizer_state = torch.load(optimizer_load_path)
292
+ optimizer.load_state_dict(optimizer_state["adam"])
293
+ self.step = optimizer_state["step"]
flash3d/networks/layers.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Niantic 2019. Patent Pending. All rights reserved.
2
+ #
3
+ # This software is licensed under the terms of the Monodepth2 licence
4
+ # which allows for non-commercial use only, the full terms of which are made
5
+ # available in the LICENSE file.
6
+
7
+ import numpy as np
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ def disp_to_depth(disp, min_depth, max_depth):
15
+ """Convert network's sigmoid output into depth prediction
16
+ The formula for this conversion is given in the 'additional considerations'
17
+ section of the paper.
18
+ """
19
+ min_disp = 1 / max_depth
20
+ max_disp = 1 / min_depth
21
+ scaled_disp = min_disp + (max_disp - min_disp) * disp
22
+ depth = 1 / scaled_disp
23
+ return scaled_disp, depth
24
+
25
+
26
+ def transformation_from_parameters(axisangle, translation, invert=False):
27
+ """Convert the network's (axisangle, translation) output into a 4x4 matrix
28
+ """
29
+ R = rot_from_axisangle(axisangle)
30
+ t = translation.clone()
31
+
32
+ if invert:
33
+ R = R.transpose(1, 2)
34
+ t *= -1
35
+
36
+ T = get_translation_matrix(t)
37
+
38
+ if invert:
39
+ M = torch.matmul(R, T)
40
+ else:
41
+ M = torch.matmul(T, R)
42
+
43
+ return M
44
+
45
+
46
+ def get_translation_matrix(translation_vector):
47
+ """Convert a translation vector into a 4x4 transformation matrix
48
+ """
49
+ T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
50
+
51
+ t = translation_vector.contiguous().view(-1, 3, 1)
52
+
53
+ T[:, 0, 0] = 1
54
+ T[:, 1, 1] = 1
55
+ T[:, 2, 2] = 1
56
+ T[:, 3, 3] = 1
57
+ T[:, :3, 3, None] = t
58
+
59
+ return T
60
+
61
+
62
+ def rot_from_axisangle(vec):
63
+ """Convert an axisangle rotation into a 4x4 transformation matrix
64
+ (adapted from https://github.com/Wallacoloo/printipi)
65
+ Input 'vec' has to be Bx1x3
66
+ """
67
+ angle = torch.norm(vec, 2, 2, True)
68
+ axis = vec / (angle + 1e-7)
69
+
70
+ ca = torch.cos(angle)
71
+ sa = torch.sin(angle)
72
+ C = 1 - ca
73
+
74
+ x = axis[..., 0].unsqueeze(1)
75
+ y = axis[..., 1].unsqueeze(1)
76
+ z = axis[..., 2].unsqueeze(1)
77
+
78
+ xs = x * sa
79
+ ys = y * sa
80
+ zs = z * sa
81
+ xC = x * C
82
+ yC = y * C
83
+ zC = z * C
84
+ xyC = x * yC
85
+ yzC = y * zC
86
+ zxC = z * xC
87
+
88
+ rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
89
+
90
+ rot[:, 0, 0] = torch.squeeze(x * xC + ca)
91
+ rot[:, 0, 1] = torch.squeeze(xyC - zs)
92
+ rot[:, 0, 2] = torch.squeeze(zxC + ys)
93
+ rot[:, 1, 0] = torch.squeeze(xyC + zs)
94
+ rot[:, 1, 1] = torch.squeeze(y * yC + ca)
95
+ rot[:, 1, 2] = torch.squeeze(yzC - xs)
96
+ rot[:, 2, 0] = torch.squeeze(zxC - ys)
97
+ rot[:, 2, 1] = torch.squeeze(yzC + xs)
98
+ rot[:, 2, 2] = torch.squeeze(z * zC + ca)
99
+ rot[:, 3, 3] = 1
100
+
101
+ return rot
102
+
103
+
104
+ class ConvBlock(nn.Module):
105
+ """Layer to perform a convolution followed by ELU
106
+ """
107
+ def __init__(self, in_channels, out_channels):
108
+ super(ConvBlock, self).__init__()
109
+
110
+ self.conv = Conv3x3(in_channels, out_channels)
111
+ self.nonlin = nn.ELU(inplace=True)
112
+
113
+ def forward(self, x):
114
+ out = self.conv(x)
115
+ out = self.nonlin(out)
116
+ return out
117
+
118
+
119
+ class Conv3x3(nn.Module):
120
+ """Layer to pad and convolve input
121
+ """
122
+ def __init__(self, in_channels, out_channels, use_refl=True):
123
+ super(Conv3x3, self).__init__()
124
+
125
+ if use_refl:
126
+ self.pad = nn.ReflectionPad2d(1)
127
+ else:
128
+ self.pad = nn.ZeroPad2d(1)
129
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
130
+
131
+ def forward(self, x):
132
+ out = self.pad(x)
133
+ out = self.conv(out)
134
+ return out
135
+
136
+
137
+ class BackprojectDepth(nn.Module):
138
+ """Layer to transform a depth image into a point cloud
139
+ """
140
+ def __init__(self, batch_size, height, width, shift_rays_half_pixel=0):
141
+ super(BackprojectDepth, self).__init__()
142
+
143
+ self.batch_size = batch_size
144
+ self.height = height
145
+ self.width = width
146
+
147
+ meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
148
+ id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
149
+ id_coords = torch.from_numpy(id_coords)
150
+
151
+ ones = torch.ones(self.batch_size, 1, self.height * self.width)
152
+
153
+ pix_coords = torch.unsqueeze(torch.stack(
154
+ [id_coords[0].view(-1), id_coords[1].view(-1)], 0), 0)
155
+ pix_coords = pix_coords.repeat(batch_size, 1, 1)
156
+ pix_coords = torch.cat([pix_coords + shift_rays_half_pixel,
157
+ ones], 1)
158
+ self.register_buffer("pix_coords", pix_coords)
159
+ self.register_buffer("id_coords", id_coords)
160
+ self.register_buffer("ones", ones)
161
+ # self.pix_coords = pix_coords
162
+ # self.ones = ones
163
+
164
+ def forward(self, depth, inv_K):
165
+ cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords.to(depth.device))
166
+ cam_points = depth.view(self.batch_size, 1, -1) * cam_points
167
+ cam_points = torch.cat([cam_points, self.ones.to(depth.device)], 1)
168
+
169
+ return cam_points
170
+
171
+
172
+ class Project3D(nn.Module):
173
+ """Layer which projects 3D points into a camera with intrinsics K and at position T
174
+ """
175
+ def __init__(self, batch_size, height, width, eps=1e-7):
176
+ super(Project3D, self).__init__()
177
+
178
+ self.batch_size = batch_size
179
+ self.height = height
180
+ self.width = width
181
+ self.eps = eps
182
+
183
+ def forward(self, points, K, T=None):
184
+ if T is None:
185
+ P = K
186
+ else:
187
+ P = torch.matmul(K, T)
188
+ P = P[:, :3, :]
189
+
190
+ cam_points = torch.matmul(P, points)
191
+
192
+ pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
193
+ pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
194
+ pix_coords = pix_coords.permute(0, 2, 3, 1)
195
+ pix_coords[..., 0] /= self.width - 1
196
+ pix_coords[..., 1] /= self.height - 1
197
+ pix_coords = (pix_coords - 0.5) * 2
198
+ return pix_coords
199
+
200
+
201
+ class Project3DSimple(nn.Module):
202
+ """Layer which projects 3D points into a camera with intrinsics K and at position T
203
+ """
204
+ def __init__(self, batch_size, height, width, eps=1e-7):
205
+ super(Project3DSimple, self).__init__()
206
+
207
+ self.batch_size = batch_size
208
+ self.height = height
209
+ self.width = width
210
+ self.eps = eps
211
+
212
+ def forward(self, points, K):
213
+ K = K[:, :3, :]
214
+
215
+ cam_points = torch.matmul(K, points)
216
+
217
+ pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
218
+ pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
219
+ pix_coords = pix_coords.permute(0, 2, 3, 1)
220
+ return pix_coords
221
+
222
+ def upsample(x):
223
+ """Upsample input tensor by a factor of 2
224
+ """
225
+ return F.interpolate(x, scale_factor=2, mode="nearest")
226
+
227
+
228
+ def get_smooth_loss(disp, img):
229
+ """Computes the smoothness loss for a disparity image
230
+ The color image is used for edge-aware smoothness
231
+ """
232
+ grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
233
+ grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
234
+
235
+ grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
236
+ grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
237
+
238
+ grad_disp_x *= torch.exp(-grad_img_x)
239
+ grad_disp_y *= torch.exp(-grad_img_y)
240
+
241
+ return grad_disp_x.mean() + grad_disp_y.mean()
242
+
243
+
244
+ class SSIM(nn.Module):
245
+ """Layer to compute the SSIM loss between a pair of images
246
+ """
247
+ def __init__(self):
248
+ super(SSIM, self).__init__()
249
+ self.mu_x_pool = nn.AvgPool2d(3, 1)
250
+ self.mu_y_pool = nn.AvgPool2d(3, 1)
251
+ self.sig_x_pool = nn.AvgPool2d(3, 1)
252
+ self.sig_y_pool = nn.AvgPool2d(3, 1)
253
+ self.sig_xy_pool = nn.AvgPool2d(3, 1)
254
+
255
+ self.refl = nn.ReflectionPad2d(1)
256
+
257
+ self.C1 = 0.01 ** 2
258
+ self.C2 = 0.03 ** 2
259
+
260
+ def forward(self, x, y):
261
+ x = self.refl(x)
262
+ y = self.refl(y)
263
+
264
+ mu_x = self.mu_x_pool(x)
265
+ mu_y = self.mu_y_pool(y)
266
+
267
+ sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
268
+ sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
269
+ sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
270
+
271
+ SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
272
+ SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
273
+
274
+ return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
275
+
276
+
277
+ def compute_depth_errors(gt, pred):
278
+ """Computation of error metrics between predicted and ground truth depths
279
+ """
280
+ thresh = torch.max((gt / pred), (pred / gt))
281
+ a1 = (thresh < 1.25 ).float().mean()
282
+ a2 = (thresh < 1.25 ** 2).float().mean()
283
+ a3 = (thresh < 1.25 ** 3).float().mean()
284
+
285
+ rmse = (gt - pred) ** 2
286
+ rmse = torch.sqrt(rmse.mean())
287
+
288
+ rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
289
+ rmse_log = torch.sqrt(rmse_log.mean())
290
+
291
+ abs_rel = torch.mean(torch.abs(gt - pred) / gt)
292
+
293
+ sq_rel = torch.mean((gt - pred) ** 2 / gt)
294
+
295
+ return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
flash3d/networks/resnet_encoder.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Niantic 2019. Patent Pending. All rights reserved.
2
+ #
3
+ # This software is licensed under the terms of the Monodepth2 licence
4
+ # which allows for non-commercial use only, the full terms of which are made
5
+ # available in the LICENSE file.
6
+
7
+ import numpy as np
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torchvision.models as models
12
+
13
+
14
+ RESNETS = {18: (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
15
+ 50: (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2)}
16
+
17
+
18
+ class ResNetMultiImageInput(models.ResNet):
19
+ """Constructs a resnet model with varying number of input images.
20
+ Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
21
+ """
22
+ def __init__(self, block, layers, num_classes=1000, num_input_images=1):
23
+ super(ResNetMultiImageInput, self).__init__(block, layers)
24
+ self.inplanes = 64
25
+ self.conv1 = nn.Conv2d(
26
+ num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
27
+ self.bn1 = nn.BatchNorm2d(64)
28
+ self.relu = nn.ReLU(inplace=True)
29
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
30
+ self.layer1 = self._make_layer(block, 64, layers[0])
31
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
32
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
33
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
34
+
35
+ for m in self.modules():
36
+ if isinstance(m, nn.Conv2d):
37
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
38
+ elif isinstance(m, nn.BatchNorm2d):
39
+ nn.init.constant_(m.weight, 1)
40
+ nn.init.constant_(m.bias, 0)
41
+
42
+
43
+ def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
44
+ """Constructs a ResNet model.
45
+ Args:
46
+ num_layers (int): Number of resnet layers. Must be 18 or 50
47
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
48
+ num_input_images (int): Number of frames stacked as input
49
+ """
50
+ assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
51
+ blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
52
+ block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
53
+ model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
54
+ model, weigths = RESNETS[num_layers]
55
+
56
+ if pretrained:
57
+ loaded = torch.hub.load_state_dict_from_url(weigths.url)
58
+ loaded['conv1.weight'] = torch.cat(
59
+ [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
60
+ model.load_state_dict(loaded)
61
+ return model
62
+
63
+
64
+ class ResnetEncoder(nn.Module):
65
+ """Pytorch module for a resnet encoder
66
+ """
67
+ def __init__(self, num_layers, pretrained, bn_order, num_input_images=1):
68
+ super(ResnetEncoder, self).__init__()
69
+
70
+ self.num_ch_enc = np.array([64, 64, 128, 256, 512])
71
+ self.bn_order = bn_order
72
+
73
+ if num_layers not in RESNETS:
74
+ raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
75
+
76
+ if num_input_images > 1:
77
+ self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
78
+ else:
79
+ model, weights = RESNETS[num_layers]
80
+ self.encoder = model(weights=weights)
81
+
82
+ if num_layers > 34:
83
+ self.num_ch_enc[1:] *= 4
84
+
85
+ def forward(self, input_image):
86
+ encoder = self.encoder
87
+ features = []
88
+ x = (input_image - 0.45) / 0.225
89
+ x = encoder.conv1(x)
90
+
91
+ if self.bn_order == "pre_bn":
92
+ # Concatenating pre-norm features allows us to
93
+ # keep the scale and shift of RGB colours
94
+ # and recover them at output
95
+ features.append(x)
96
+ x = encoder.bn1(x)
97
+ x = encoder.relu(x)
98
+ features.append(encoder.layer1(encoder.maxpool(x)))
99
+ elif self.bn_order == "monodepth":
100
+ # Batchnorm gets rid of constants due to colour shift
101
+ # will make the network not able to recover absolute colour shift
102
+ # of the input image
103
+ # used in old models
104
+ x = encoder.bn1(x)
105
+ x = encoder.relu(x)
106
+ features.append(x)
107
+ features.append(encoder.layer1(encoder.maxpool(x)))
108
+ else:
109
+ assert False
110
+
111
+ features.append(encoder.layer2(features[-1]))
112
+ features.append(encoder.layer3(features[-1]))
113
+ features.append(encoder.layer4(features[-1]))
114
+
115
+ return features
flash3d/networks/unidepth.py ADDED
@@ -0,0 +1,577 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import List, Tuple
4
+ from math import ceil
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision.transforms.functional as TF
9
+ from einops import rearrange
10
+
11
+ from unidepth.models.unidepthv1 import UniDepthV1
12
+ from unidepth.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD
13
+ from unidepth.utils.geometric import (
14
+ generate_rays,
15
+ spherical_zbuffer_to_euclidean,
16
+ flat_interpolate,
17
+ )
18
+ from unidepth.layers import (
19
+ MLP,
20
+ AttentionBlock,
21
+ NystromBlock,
22
+ PositionEmbeddingSine,
23
+ ConvUpsample,
24
+ )
25
+ from unidepth.utils.sht import rsh_cart_8
26
+
27
+ from networks.gaussian_decoder import get_splits_and_inits
28
+
29
+
30
+ # inference helpers
31
+ def _paddings(image_shape, network_shape):
32
+ cur_h, cur_w = image_shape
33
+ h, w = network_shape
34
+ pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
35
+ pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
36
+ return pad_left, pad_right, pad_top, pad_bottom
37
+
38
+
39
+ def _shapes(image_shape, network_shape):
40
+ h, w = image_shape
41
+ input_ratio = w / h
42
+ output_ratio = network_shape[1] / network_shape[0]
43
+ if output_ratio > input_ratio:
44
+ ratio = network_shape[0] / h
45
+ elif output_ratio <= input_ratio:
46
+ ratio = network_shape[1] / w
47
+ return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio
48
+
49
+
50
+ def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes):
51
+ (pad_left, pad_right, pad_top, pad_bottom) = pads
52
+ rgbs = F.interpolate(
53
+ rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True
54
+ )
55
+ rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant")
56
+ if intrinsics is not None:
57
+ intrinsics = intrinsics.clone()
58
+ intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
59
+ intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
60
+ intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left
61
+ intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top
62
+ return rgbs, intrinsics
63
+ return rgbs, None
64
+
65
+
66
+ def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes):
67
+
68
+ (pad_left, pad_right, pad_top, pad_bottom) = pads
69
+ # pred mean, trim paddings, and upsample to input dim
70
+ predictions = sum(
71
+ [
72
+ F.interpolate(
73
+ x,
74
+ size=shapes,
75
+ mode="bilinear",
76
+ align_corners=False,
77
+ antialias=True,
78
+ )
79
+ for x in predictions
80
+ ]
81
+ ) / len(predictions)
82
+
83
+ shapes = predictions.shape[2:]
84
+ predictions = predictions[
85
+ ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right
86
+ ]
87
+
88
+ predictions = F.interpolate(
89
+ predictions,
90
+ size=original_shapes,
91
+ mode="bilinear",
92
+ align_corners=False,
93
+ antialias=True,
94
+ )
95
+
96
+ if intrinsics is not None:
97
+ intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio
98
+ intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio
99
+ intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio
100
+ intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio
101
+
102
+ return predictions, intrinsics
103
+
104
+
105
+ def scale_intrinsics_xy(intrinsics, x_ratio, y_ratio):
106
+ intrinsics = intrinsics.clone()
107
+ intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * x_ratio
108
+ intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * y_ratio
109
+ intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * x_ratio
110
+ intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * y_ratio
111
+ return intrinsics
112
+
113
+
114
+ def scale_intrinsics(intrinsics, ratio):
115
+ intrinsics = intrinsics.clone()
116
+ intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
117
+ intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
118
+ intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio
119
+ intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio
120
+ return intrinsics
121
+
122
+
123
+ def unidepthv1_forward(model, rgbs, intrinsics, skip_camera,
124
+ return_raw_preds=False):
125
+ B, _, H, W = rgbs.shape
126
+
127
+ rgbs = TF.normalize(
128
+ rgbs,
129
+ mean=IMAGENET_DATASET_MEAN,
130
+ std=IMAGENET_DATASET_STD,
131
+ )
132
+
133
+ (h, w), ratio = _shapes((H, W), model.image_shape)
134
+ pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), model.image_shape)
135
+ rgbs, gt_intrinsics = _preprocess(
136
+ rgbs,
137
+ intrinsics,
138
+ (h, w),
139
+ (pad_left, pad_right, pad_top, pad_bottom),
140
+ ratio,
141
+ model.image_shape,
142
+ )
143
+
144
+ encoder_outputs, cls_tokens = model.pixel_encoder(rgbs)
145
+ if "dino" in model.pixel_encoder.__class__.__name__.lower():
146
+ encoder_outputs = [
147
+ (x + y.unsqueeze(1)).contiguous()
148
+ for x, y in zip(encoder_outputs, cls_tokens)
149
+ ]
150
+
151
+ # get data for decoder and adapt to given camera
152
+ inputs = {}
153
+ inputs["encoder_outputs"] = encoder_outputs
154
+ inputs["cls_tokens"] = cls_tokens
155
+ inputs["image"] = rgbs
156
+ if gt_intrinsics is not None:
157
+ rays, angles = generate_rays(
158
+ gt_intrinsics, model.image_shape, noisy=False
159
+ )
160
+ inputs["rays"] = rays
161
+ inputs["angles"] = angles
162
+ inputs["K"] = gt_intrinsics
163
+ model.pixel_decoder.test_fixed_camera = True
164
+ model.pixel_decoder.skip_camera = skip_camera
165
+
166
+ # decode all
167
+ pred_intrinsics, predictions, features, rays = model.pixel_decoder(inputs, {})
168
+
169
+ pads = (pad_left, pad_right, pad_top, pad_bottom)
170
+
171
+ # undo the reshaping and get original image size (slow)
172
+ predictions, pred_intrinsics = _postprocess(
173
+ predictions,
174
+ pred_intrinsics,
175
+ model.image_shape,
176
+ pads,
177
+ ratio,
178
+ (H, W),
179
+ )
180
+
181
+ if return_raw_preds:
182
+ return inputs, predictions
183
+
184
+ # final 3D points backprojection
185
+ intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics
186
+ angles = generate_rays(intrinsics, (H, W), noisy=False)[-1]
187
+ angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
188
+ points_3d = torch.cat((angles, predictions), dim=1)
189
+ points_3d = spherical_zbuffer_to_euclidean(
190
+ points_3d.permute(0, 2, 3, 1)
191
+ ).permute(0, 3, 1, 2)
192
+
193
+ # output data
194
+ outputs = {
195
+ "intrinsics": intrinsics,
196
+ "points": points_3d,
197
+ "depth": predictions[:, -1:],
198
+ "depth_feats": features,
199
+ "rays": rays,
200
+ "padding": pads
201
+ }
202
+ model.pixel_decoder.test_fixed_camera = False
203
+ model.pixel_decoder.skip_camera = False
204
+ return inputs, outputs
205
+
206
+ class UniDepthDepth(nn.Module):
207
+ def __init__(
208
+ self,
209
+ cfg,
210
+ return_raw_preds=False
211
+ ):
212
+ super().__init__()
213
+
214
+ self.cfg = cfg
215
+ self.return_raw_preds = return_raw_preds
216
+
217
+ if "cnvnxtl" in cfg.model.name:
218
+ self.depth_prediction_model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-cnvnxtl")
219
+ elif "vit" in cfg.model.name:
220
+ self.depth_prediction_model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
221
+
222
+ self.skip_camera = True
223
+
224
+ def get_depth(self, img, intrinsics):
225
+ depth_inputs, outputs = unidepthv1_forward(
226
+ self.depth_prediction_model,
227
+ img,
228
+ intrinsics,
229
+ self.skip_camera,
230
+ return_raw_preds=self.return_raw_preds)
231
+ return outputs
232
+
233
+ def forward(self, inputs):
234
+ input_img = inputs["color_aug", 0, 0]
235
+ # here we need the intrinsics of the source image to condition on
236
+ # the depth prediction. needs to account for padding
237
+ if ("K_src", 0) in inputs:
238
+ intrinsics = inputs[("K_src", 0)]
239
+ else:
240
+ intrinsics = None
241
+
242
+ depth_inputs, outputs = unidepthv1_forward(
243
+ self.depth_prediction_model,
244
+ input_img,
245
+ intrinsics,
246
+ self.skip_camera,
247
+ return_raw_preds=self.return_raw_preds)
248
+
249
+ return depth_inputs, outputs
250
+
251
+ class UniDepthUnprojector(nn.Module):
252
+ def __init__(
253
+ self,
254
+ cfg
255
+ ):
256
+ super().__init__()
257
+
258
+ self.cfg = cfg
259
+
260
+ if cfg.model.name == "unidepth_unprojector_cnvnxtl":
261
+ model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-cnvnxtl")
262
+ elif cfg.model.name == "unidepth_unprojector_vit":
263
+ model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
264
+ self.unidepth = model
265
+
266
+ self.skip_camera = True
267
+
268
+ self.register_buffer("gauss_opacity", torch.ones(1, 1, 1).float())
269
+ self.register_buffer("gauss_scaling", torch.ones(3, 1, 1).float())
270
+ self.register_buffer("gauss_rotation", torch.ones(4, 1, 1).float() * 0.5)
271
+ self.register_buffer("gauss_features_rest", torch.zeros(9, 1, 1).float())
272
+ self.register_buffer("gauss_offset", torch.zeros(3, 1, 1).float())
273
+
274
+ self.all_params = nn.ParameterDict({
275
+ "opacity_scaling": nn.Parameter(torch.tensor(cfg.model.opacity_bias).float()),
276
+ "scale_scaling": nn.Parameter(torch.tensor(cfg.model.scale_bias).float()),
277
+ "colour_scaling": nn.Parameter(torch.tensor(self.cfg.model.colour_scale).float())})
278
+
279
+
280
+ self.scaling_activation = torch.exp
281
+ self.opacity_activation = torch.sigmoid
282
+ self.relu = nn.ReLU()
283
+
284
+ def get_parameter_groups(self):
285
+ # tune scalars for size, opacity and colour modulation
286
+ return [{'params': self.all_params.parameters()}]
287
+
288
+ def forward(self, inputs):
289
+ model = self.unidepth
290
+ input_img = inputs["color_aug", 0, 0]
291
+ # here we need the intrinsics of the source image to condition on
292
+ # the depth prediction. needs to account for padding
293
+ intrinsics = inputs[("K_src", 0)]
294
+ b, c, h, w = inputs["color_aug", 0, 0].shape
295
+
296
+ with torch.no_grad():
297
+ _, depth_outs = unidepthv1_forward(model, input_img, intrinsics, self.skip_camera)
298
+
299
+ outs = {}
300
+
301
+ outs[("gauss_opacity", 0)] = self.gauss_opacity.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \
302
+ * self.opacity_activation(self.all_params["opacity_scaling"])
303
+ if not self.cfg.model.scale_with_depth:
304
+ outs[("gauss_scaling", 0)] = self.gauss_scaling.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \
305
+ * self.scaling_activation(self.all_params["scale_scaling"])
306
+ else:
307
+ outs[("gauss_scaling", 0)] = self.gauss_scaling.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \
308
+ * self.scaling_activation(self.all_params["scale_scaling"]) * depth_outs["depth"] / 10.0
309
+ outs[("gauss_rotation", 0)] = self.gauss_rotation.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w)
310
+ outs[("gauss_offset", 0)] = self.gauss_offset.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w)
311
+ outs[("gauss_features_rest", 0)] = self.gauss_features_rest.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w)
312
+ # rendering adds 0.5 to go from rendered colours to output
313
+ outs[("gauss_features_dc", 0)] = (input_img - 0.5)* self.relu(self.all_params["colour_scaling"])
314
+
315
+ outs[("depth", 0)] = depth_outs["depth"]
316
+
317
+ return outs
318
+
319
+ class UniDepthSplatter(nn.Module):
320
+ def __init__(
321
+ self,
322
+ cfg
323
+ ):
324
+ super().__init__()
325
+
326
+ self.cfg = cfg
327
+
328
+ config_path = Path("/work/eldar/src/UniDepth")
329
+ with open(config_path / "configs/config_v1_cnvnxtl.json") as f:
330
+ config = json.load(f)
331
+ self.unidepth = UniDepthDepth(self.cfg)
332
+
333
+ hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
334
+ expansion = config["model"]["expansion"]
335
+ depth = config["model"]["pixel_decoder"]["depths"]
336
+ num_heads = config["model"]["num_heads"]
337
+ dropout = config["model"]["pixel_decoder"]["dropout"]
338
+ layer_scale = 1.0
339
+ self.splat_decoder = GaussSplatHead(
340
+ cfg,
341
+ hidden_dim=hidden_dim,
342
+ num_heads=num_heads,
343
+ expansion=expansion,
344
+ depths=depth,
345
+ camera_dim=81,
346
+ dropout=dropout,
347
+ layer_scale=layer_scale,
348
+ )
349
+
350
+ self.skip_camera = True
351
+
352
+ def get_parameter_groups(self):
353
+ base_lr = self.cfg.optimiser.learning_rate
354
+ return [
355
+ {'params': self.unidepth.parameters(), "lr": base_lr * 0.05},
356
+ {'params': self.splat_decoder.parameters()}
357
+ ]
358
+
359
+ def forward(self, inputs):
360
+ gauss_head = self.splat_decoder
361
+
362
+ depth_inputs, depth_outs = self.unidepth(inputs)
363
+ depth_feats = depth_outs["depth_feats"]
364
+ rays = depth_outs["rays"]
365
+ padding = depth_outs["padding"]
366
+
367
+ B, _, H, W = depth_inputs["image"].shape
368
+
369
+ # TODO remove hardcoded shapes
370
+ common_shape = (28, 38)
371
+ gauss_head.set_shapes(common_shape)
372
+ gauss_head.set_original_shapes((H, W))
373
+
374
+ depth_feats = rearrange(depth_feats, "b c h w -> b (h w) c")
375
+ outs = gauss_head(
376
+ latents_16=depth_feats,
377
+ rays_hr=rays,
378
+ )
379
+ for k, v in outs.items():
380
+ pred, _ = _postprocess([v], None, self.unidepth.depth_prediction_model.image_shape,
381
+ padding, None, inputs["color_aug", 0, 0].shape[2:4])
382
+ outs[k] = pred
383
+ outs[("depth", 0)] = depth_outs["depth"]
384
+
385
+ return outs
386
+
387
+
388
+ class GaussSplatHead(nn.Module):
389
+ def __init__(
390
+ self,
391
+ cfg,
392
+ hidden_dim: int,
393
+ num_heads: int = 8,
394
+ expansion: int = 4,
395
+ depths: int | list[int] = 4,
396
+ camera_dim: int = 256,
397
+ dropout: float = 0.0,
398
+ layer_scale: float = 1.0,
399
+ ) -> None:
400
+ super().__init__()
401
+
402
+ self.cfg = cfg
403
+
404
+ if isinstance(depths, int):
405
+ depths = [depths] * 3
406
+ assert len(depths) == 3
407
+
408
+ self.project_rays16 = MLP(
409
+ camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim
410
+ )
411
+ self.project_rays8 = MLP(
412
+ camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2
413
+ )
414
+ self.project_rays4 = MLP(
415
+ camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4
416
+ )
417
+
418
+ self.layers_8 = nn.ModuleList([])
419
+ self.layers_4 = nn.ModuleList([])
420
+ layers_16 = nn.ModuleList([])
421
+
422
+ self.up8 = ConvUpsample(
423
+ hidden_dim, expansion=expansion, layer_scale=layer_scale
424
+ )
425
+ self.up4 = ConvUpsample(
426
+ hidden_dim // 2, expansion=expansion, layer_scale=layer_scale
427
+ )
428
+ self.up2 = ConvUpsample(
429
+ hidden_dim // 4, expansion=expansion, layer_scale=layer_scale
430
+ )
431
+
432
+ split_dimensions, scale, bias = get_splits_and_inits(cfg)
433
+ start = 1
434
+ self.split_dimensions = split_dimensions[start:]
435
+ scale = scale[start:]
436
+ bias = bias[start:]
437
+
438
+ self.num_output_channels = sum(self.split_dimensions)
439
+
440
+ self.out2 = nn.Conv2d(hidden_dim // 8, self.num_output_channels, 3, padding=1)
441
+ # self.out4 = nn.Conv2d(hidden_dim // 4, self.num_output_channels, 3, padding=1)
442
+ # self.out8 = nn.Conv2d(hidden_dim // 2, self.num_output_channels, 3, padding=1)
443
+
444
+ start_channels = 0
445
+ for out_channel, b, s in zip(self.split_dimensions, bias, scale):
446
+ nn.init.xavier_uniform_(
447
+ self.out2.weight[start_channels:start_channels+out_channel,
448
+ :, :, :], s)
449
+ nn.init.constant_(
450
+ self.out2.bias[start_channels:start_channels+out_channel], b)
451
+ start_channels += out_channel
452
+
453
+ for i, (blk_lst, depth) in enumerate(
454
+ zip([layers_16, self.layers_8, self.layers_4], depths)
455
+ ):
456
+ if i == 0:
457
+ continue
458
+ attn_cls = AttentionBlock if i == 0 else NystromBlock
459
+ for _ in range(depth):
460
+ blk_lst.append(
461
+ attn_cls(
462
+ hidden_dim // (2**i),
463
+ num_heads=num_heads // (2**i),
464
+ expansion=expansion,
465
+ dropout=dropout,
466
+ layer_scale=layer_scale,
467
+ )
468
+ )
469
+
470
+ self.scaling_activation = torch.exp
471
+ self.opacity_activation = torch.sigmoid
472
+ self.rotation_activation = torch.nn.functional.normalize
473
+ self.scaling_lambda = cfg.model.scale_lambda
474
+ self.sigmoid = nn.Sigmoid()
475
+
476
+ def set_original_shapes(self, shapes: Tuple[int, int]):
477
+ self.original_shapes = shapes
478
+
479
+ def set_shapes(self, shapes: Tuple[int, int]):
480
+ self.shapes = shapes
481
+
482
+ def forward(
483
+ self, latents_16: torch.Tensor, rays_hr: torch.Tensor
484
+ ) -> torch.Tensor:
485
+ shapes = self.shapes
486
+
487
+ # camera_embedding
488
+ # torch.cuda.synchronize()
489
+ # start = time()
490
+ rays_embedding_16 = F.normalize(
491
+ flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1
492
+ )
493
+ rays_embedding_8 = F.normalize(
494
+ flat_interpolate(
495
+ rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes]
496
+ ),
497
+ dim=-1,
498
+ )
499
+ rays_embedding_4 = F.normalize(
500
+ flat_interpolate(
501
+ rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes]
502
+ ),
503
+ dim=-1,
504
+ )
505
+ rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16))
506
+ rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8))
507
+ rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4))
508
+
509
+ # Block 16 - Out 8
510
+ latents_8 = self.up8(
511
+ rearrange(
512
+ latents_16 + rays_embedding_16,
513
+ "b (h w) c -> b c h w",
514
+ h=shapes[0],
515
+ w=shapes[1],
516
+ ).contiguous()
517
+ )
518
+ # out8 = self.out8(
519
+ # rearrange(
520
+ # latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2
521
+ # )
522
+ # )
523
+
524
+ # Block 8 - Out 4
525
+ for layer in self.layers_8:
526
+ latents_8 = layer(latents_8, pos_embed=rays_embedding_8)
527
+ latents_4 = self.up4(
528
+ rearrange(
529
+ latents_8 + rays_embedding_8,
530
+ "b (h w) c -> b c h w",
531
+ h=shapes[0] * 2,
532
+ w=shapes[1] * 2,
533
+ ).contiguous()
534
+ )
535
+ # out4 = self.out4(
536
+ # rearrange(
537
+ # latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4
538
+ # )
539
+ # )
540
+
541
+ # Block 4 - Out 2
542
+ for layer in self.layers_4:
543
+ latents_4 = layer(latents_4, pos_embed=rays_embedding_4)
544
+ latents_2 = self.up2(
545
+ rearrange(
546
+ latents_4 + rays_embedding_4,
547
+ "b (h w) c -> b c h w",
548
+ h=shapes[0] * 4,
549
+ w=shapes[1] * 4,
550
+ ).contiguous()
551
+ )
552
+ out2 = self.out2(
553
+ rearrange(
554
+ latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8
555
+ )
556
+ )
557
+
558
+ split_network_outputs = out2.split(self.split_dimensions, dim=1)
559
+ last = 5
560
+ offset, opacity, scaling, rotation, feat_dc = split_network_outputs[:last]
561
+
562
+ out = {
563
+ ("gauss_opacity", 0): self.opacity_activation(opacity),
564
+ ("gauss_scaling", 0): self.scaling_activation(scaling) * self.scaling_lambda,
565
+ ("gauss_rotation", 0): self.rotation_activation(rotation),
566
+ ("gauss_features_dc", 0): feat_dc
567
+ }
568
+
569
+ if self.cfg.model.max_sh_degree > 0:
570
+ features_rest = split_network_outputs[last]
571
+ out[("gauss_features_rest", 0)] = features_rest
572
+
573
+ if self.cfg.model.predict_offset:
574
+ out[("gauss_offset", 0)] = offset
575
+
576
+ return out
577
+ # return out8, out4, out2, proj_latents_16
flash3d/networks/unidepth_extension.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+ from .unidepth import UniDepthDepth
7
+ from unidepth.models import UniDepthV1
8
+ from .resnet_encoder import ResnetEncoder
9
+ from .gaussian_decoder import GaussianDecoder
10
+ from .depth_decoder import DepthDecoder
11
+
12
+ from networks.layers import disp_to_depth
13
+ from networks.gaussian_decoder import get_splits_and_inits
14
+
15
+
16
+ class UniDepthExtended(nn.Module):
17
+ def __init__(self,cfg):
18
+ super().__init__()
19
+
20
+ self.cfg = cfg
21
+
22
+ self.unidepth = UniDepthDepth(cfg)
23
+ # self.unidepth = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
24
+
25
+ self.parameters_to_train = []
26
+ if self.cfg.model.splat_branch == "resnet":
27
+ self.encoder = ResnetEncoder(cfg.model.num_layers,
28
+ cfg.model.weights_init == "pretrained",
29
+ cfg.model.resnet_bn_order
30
+ )
31
+ # change encoder to take depth as conditioning
32
+ if self.cfg.model.depth_cond:
33
+ self.encoder.encoder.conv1 = nn.Conv2d(
34
+ 4,
35
+ self.encoder.encoder.conv1.out_channels,
36
+ kernel_size = self.encoder.encoder.conv1.kernel_size,
37
+ padding = self.encoder.encoder.conv1.padding,
38
+ stride = self.encoder.encoder.conv1.stride
39
+ )
40
+ self.parameters_to_train += [{"params": self.encoder.parameters()}]
41
+
42
+ # use depth branch only for more gaussians
43
+ if cfg.model.gaussians_per_pixel > 1:
44
+ models ={}
45
+ models["depth"] = DepthDecoder(cfg, self.encoder.num_ch_enc)
46
+ self.parameters_to_train +=[{"params": models["depth"].parameters()}]
47
+ for i in range(cfg.model.gaussians_per_pixel):
48
+ models["gauss_decoder_"+str(i)] = GaussianDecoder(cfg, self.encoder.num_ch_enc)
49
+ self.parameters_to_train += [{"params": models["gauss_decoder_"+str(i)].parameters()}]
50
+ if cfg.model.one_gauss_decoder:
51
+ break
52
+ self.models = nn.ModuleDict(models)
53
+ else:
54
+ self.gauss_decoder = GaussianDecoder(cfg, self.encoder.num_ch_enc)
55
+ self.parameters_to_train += [{"params": self.gauss_decoder.parameters()}]
56
+
57
+ elif self.cfg.model.splat_branch == "unidepth_vit" or self.cfg.model.splat_branch == "unidepth_cnvnxtl":
58
+ self.splat_branch = UniDepthDepth(cfg,
59
+ return_raw_preds=True)
60
+ # modify the head to output the channels for Gaussian parameters
61
+ self.init_ouput_head_splat_branch()
62
+ self.parameters_to_train +=[{"params": self.splat_branch.parameters()}]
63
+
64
+ self.scaling_activation = torch.exp
65
+ self.opacity_activation = torch.sigmoid
66
+ self.rotation_activation = torch.nn.functional.normalize
67
+
68
+ def init_ouput_head_splat_branch(self):
69
+ split_dimensions, scale, bias = get_splits_and_inits(self.cfg)
70
+ # the first dim in the output is for depth - we don't use that in this branch
71
+ self.split_dimensions = split_dimensions[1:]
72
+ scale = scale[1:]
73
+ bias = bias[1:]
74
+
75
+ self.num_output_channels = sum(self.split_dimensions)
76
+
77
+ self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2 = \
78
+ nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.in_channels,
79
+ self.num_output_channels,
80
+ kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.kernel_size,
81
+ padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.padding)
82
+
83
+ self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4 = \
84
+ nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.in_channels,
85
+ self.num_output_channels,
86
+ kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.kernel_size,
87
+ padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.padding)
88
+
89
+ self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8 = \
90
+ nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.in_channels,
91
+ self.num_output_channels,
92
+ kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.kernel_size,
93
+ padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.padding)
94
+
95
+ start_channels = 0
96
+ for out_channel, b, s in zip(split_dimensions, bias, scale):
97
+ nn.init.xavier_uniform_(
98
+ self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.weight[start_channels:start_channels+out_channel,
99
+ :, :, :], s)
100
+ nn.init.constant_(
101
+ self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.bias[start_channels:start_channels+out_channel], b)
102
+ start_channels += out_channel
103
+
104
+ start_channels = 0
105
+ for out_channel, b, s in zip(split_dimensions, bias, scale):
106
+ nn.init.xavier_uniform_(
107
+ self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.weight[start_channels:start_channels+out_channel,
108
+ :, :, :], s)
109
+ nn.init.constant_(
110
+ self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.bias[start_channels:start_channels+out_channel], b)
111
+ start_channels += out_channel
112
+
113
+ start_channels = 0
114
+ for out_channel, b, s in zip(split_dimensions, bias, scale):
115
+ nn.init.xavier_uniform_(
116
+ self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.weight[start_channels:start_channels+out_channel,
117
+ :, :, :], s)
118
+ nn.init.constant_(
119
+ self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.bias[start_channels:start_channels+out_channel], b)
120
+ start_channels += out_channel
121
+
122
+ def get_parameter_groups(self):
123
+ # only the resnet encoder and gaussian parameter decoder are optimisable
124
+ return self.parameters_to_train
125
+
126
+ def forward(self, inputs):
127
+ if ('unidepth', 0, 0) in inputs.keys() and inputs[('unidepth', 0, 0)] is not None:
128
+ depth_outs = dict()
129
+ depth_outs["depth"] = inputs[('unidepth', 0, 0)]
130
+ else:
131
+ with torch.no_grad():
132
+ # if self.training and self.cfg.dataset.pad_border_aug > 0:
133
+ # pad = self.cfg.dataset.pad_border_aug
134
+ # input = inputs["color_aug", 0, 0][:,:,pad:-pad, pad:-pad]
135
+ # intrincs = inputs[("K_tgt", 0)]
136
+ # else:
137
+ # input = inputs["color_aug", 0, 0]
138
+ # intrincs = inputs[("K_src", 0)]
139
+ _, depth_outs = self.unidepth(inputs)
140
+ # depth_outs = self.unidepth.infer(input, intrincs)
141
+ # if self.training and self.cfg.dataset.pad_border_aug > 0:
142
+ # depth_outs["depth"] = F.pad(depth_outs["depth"], (pad,pad,pad,pad), mode="replicate")
143
+
144
+ outputs_gauss = {}
145
+
146
+ K = depth_outs["intrinsics"]
147
+ outputs_gauss[("K_src", 0)] = K
148
+ outputs_gauss[("inv_K_src", 0)] = torch.linalg.inv(K)
149
+
150
+ if self.cfg.model.splat_branch == "resnet":
151
+ if self.cfg.model.depth_cond:
152
+ # division by 20 is to put depth in a similar range to RGB
153
+ resnet_input = torch.cat([inputs["color_aug", 0, 0],
154
+ depth_outs["depth"] / 20.0], dim=1)
155
+ else:
156
+ resnet_input = inputs["color_aug", 0, 0]
157
+ resnet_features = self.encoder(resnet_input)
158
+ if self.cfg.model.gaussians_per_pixel > 1:
159
+ pred_depth = dict()
160
+ depth = self.models["depth"](resnet_features)
161
+ if self.cfg.model.depth_type == "disp":
162
+ for key, v in depth.items():
163
+ _, pred_depth[("depth", key[1])] = disp_to_depth(v, self.cfg.model.min_depth, self.cfg.model.max_depth)
164
+ elif self.cfg.model.depth_type in ["depth", "depth_inc"]:
165
+ pred_depth = depth
166
+ pred_depth[("depth", 0)] = rearrange(pred_depth[("depth", 0)], "(b n) ... -> b n ...", n=self.cfg.model.gaussians_per_pixel - 1)
167
+ if self.cfg.model.depth_type in ["depth_inc", "disp_inc"]:
168
+ pred_depth[("depth", 0)] = torch.cumsum(torch.cat((depth_outs["depth"][:,None,...], pred_depth[("depth", 0)]), dim=1), dim=1)
169
+ else:
170
+ pred_depth[("depth", 0)] = torch.cat((depth_outs["depth"][:,None,...], pred_depth[("depth", 0)]), dim=1)
171
+ outputs_gauss[("depth", 0)] = rearrange(pred_depth[("depth", 0)], "b n c ... -> (b n) c ...", n = self.cfg.model.gaussians_per_pixel)
172
+ gauss_outs = dict()
173
+ for i in range(self.cfg.model.gaussians_per_pixel):
174
+ outs = self.models["gauss_decoder_"+str(i)](resnet_features)
175
+ if not self.cfg.model.one_gauss_decoder:
176
+ for key, v in outs.items():
177
+ gauss_outs[key] = outs[key][:,None,...] if i==0 else torch.cat([gauss_outs[key], outs[key][:,None,...]], dim=1)
178
+ else:
179
+ gauss_outs |= outs
180
+ for key, v in gauss_outs.items():
181
+ gauss_outs[key] = rearrange(gauss_outs[key], 'b n ... -> (b n) ...')
182
+ outputs_gauss |= gauss_outs
183
+ else:
184
+ outputs_gauss[("depth", 0)] = depth_outs["depth"]
185
+ outputs_gauss |= self.gauss_decoder(resnet_features)
186
+ elif self.cfg.model.splat_branch == "unidepth_vit" or self.cfg.model.splat_branch == "unidepth_cnvnxtl":
187
+ split_network_outputs = self.splat_branch(inputs)[1].split(self.split_dimensions, dim=1)
188
+ offset, opacity, scaling, rotation, feat_dc = split_network_outputs[:5]
189
+
190
+ outputs_gauss |= {
191
+ ("gauss_opacity", 0): self.opacity_activation(opacity),
192
+ ("gauss_scaling", 0): self.scaling_activation(scaling),
193
+ ("gauss_rotation", 0): self.rotation_activation(rotation),
194
+ ("gauss_features_dc", 0): feat_dc
195
+ }
196
+
197
+ if self.cfg.model.max_sh_degree > 0:
198
+ features_rest = split_network_outputs[5]
199
+ outputs_gauss[("gauss_features_rest", 0)] = features_rest
200
+
201
+ assert self.cfg.model.predict_offset
202
+ outputs_gauss[("gauss_offset", 0)] = offset
203
+
204
+ return outputs_gauss
205
+
flash3d/unidepth/layers/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .activation import SwiGLU, GEGLU
2
+ from .convnext import CvnxtBlock
3
+ from .attention import AttentionBlock, AttentionDecoderBlock
4
+ from .nystrom_attention import NystromBlock
5
+ from .positional_encoding import PositionEmbeddingSine
6
+ from .upsample import ConvUpsample, ConvUpsampleShuffle
7
+ from .mlp import MLP
8
+
9
+
10
+ __all__ = [
11
+ "SwiGLU",
12
+ "GEGLU",
13
+ "CvnxtBlock",
14
+ "AttentionBlock",
15
+ "NystromBlock",
16
+ "PositionEmbeddingSine",
17
+ "ConvUpsample",
18
+ "MLP",
19
+ "ConvUpsampleShuffle",
20
+ "AttentionDecoderBlock",
21
+ ]
flash3d/unidepth/layers/activation.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SwiGLU(nn.Module):
7
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
8
+ x, gates = x.chunk(2, dim=-1)
9
+ return x * F.silu(gates)
10
+
11
+
12
+ class GEGLU(nn.Module):
13
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
14
+ x, gates = x.chunk(2, dim=-1)
15
+ return x * F.gelu(gates)
flash3d/unidepth/layers/attention.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from functools import partial
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+
13
+ from .layer_scale import LayerScale
14
+ from .mlp import MLP
15
+
16
+
17
+ class SimpleAttention(nn.Module):
18
+ def __init__(
19
+ self,
20
+ dim: int,
21
+ num_heads: int = 4,
22
+ dropout: float = 0.0,
23
+ cosine: bool = False,
24
+ context_dim: int | None = None,
25
+ ):
26
+ super().__init__()
27
+ self.dropout = dropout
28
+ self.num_heads = num_heads
29
+ self.hidden_dim = dim
30
+ context_dim = context_dim or dim
31
+
32
+ self.kv = nn.Linear(context_dim, dim * 2, bias=False)
33
+ self.q = nn.Linear(dim, dim, bias=False)
34
+ self.norm_attnx = nn.LayerNorm(dim)
35
+ self.norm_attnctx = nn.LayerNorm(context_dim)
36
+ self.cosine = cosine
37
+ self.out = nn.Linear(dim, dim)
38
+
39
+ def forward(
40
+ self,
41
+ x: torch.Tensor,
42
+ attn_bias: torch.Tensor | None = None,
43
+ context: torch.Tensor | None = None,
44
+ pos_embed: torch.Tensor | None = None,
45
+ pos_embed_context: torch.Tensor | None = None,
46
+ rope: nn.Module | None = None,
47
+ ) -> torch.Tensor:
48
+ context = x if context is None else context
49
+ x = self.norm_attnx(x)
50
+ context = self.norm_attnctx(context)
51
+ k, v = rearrange(
52
+ self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
53
+ ).unbind(dim=-1)
54
+ q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
55
+
56
+ if rope is not None:
57
+ q = rope(q)
58
+ k = rope(k)
59
+ else:
60
+ if pos_embed is not None:
61
+ pos_embed = rearrange(
62
+ pos_embed, "b n (h d) -> b h n d", h=self.num_heads
63
+ )
64
+ q = q + pos_embed
65
+ if pos_embed_context is not None:
66
+ pos_embed_context = rearrange(
67
+ pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
68
+ )
69
+ k = k + pos_embed_context
70
+
71
+ if self.cosine:
72
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
73
+ x = F.scaled_dot_product_attention(
74
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
75
+ )
76
+ x = rearrange(x, "b h n d -> b n (h d)")
77
+ x = self.out(x)
78
+ return x
79
+
80
+
81
+ class AttentionBlock(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim: int,
85
+ num_heads: int = 4,
86
+ expansion: int = 4,
87
+ dropout: float = 0.0,
88
+ cosine: bool = False,
89
+ gated: bool = False,
90
+ layer_scale: float = 1.0,
91
+ context_dim: int | None = None,
92
+ ):
93
+ super().__init__()
94
+ self.dropout = dropout
95
+ self.num_heads = num_heads
96
+ self.hidden_dim = dim
97
+ context_dim = context_dim or dim
98
+ self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
99
+ self.kv = nn.Linear(context_dim, dim * 2)
100
+ self.q = nn.Linear(dim, dim)
101
+ self.norm_attnx = nn.LayerNorm(dim)
102
+ self.norm_attnctx = nn.LayerNorm(context_dim)
103
+ self.cosine = cosine
104
+ self.out = nn.Linear(dim, dim)
105
+ self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
106
+ self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
107
+
108
+ def attn(
109
+ self,
110
+ x: torch.Tensor,
111
+ attn_bias: torch.Tensor | None = None,
112
+ context: torch.Tensor | None = None,
113
+ pos_embed: torch.Tensor | None = None,
114
+ pos_embed_context: torch.Tensor | None = None,
115
+ rope: nn.Module | None = None,
116
+ ) -> torch.Tensor:
117
+ x = self.norm_attnx(x)
118
+ context = self.norm_attnctx(context)
119
+ k, v = rearrange(
120
+ self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
121
+ ).unbind(dim=-1)
122
+ q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
123
+
124
+ if rope is not None:
125
+ q = rope(q)
126
+ k = rope(k)
127
+ else:
128
+ if pos_embed is not None:
129
+ pos_embed = rearrange(
130
+ pos_embed, "b n (h d) -> b h n d", h=self.num_heads
131
+ )
132
+ q = q + pos_embed
133
+ if pos_embed_context is not None:
134
+ pos_embed_context = rearrange(
135
+ pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
136
+ )
137
+ k = k + pos_embed_context
138
+
139
+ if self.cosine:
140
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
141
+
142
+ x = F.scaled_dot_product_attention(
143
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
144
+ )
145
+ x = rearrange(x, "b h n d -> b n (h d)")
146
+ x = self.out(x)
147
+ return x
148
+
149
+ def forward(
150
+ self,
151
+ x: torch.Tensor,
152
+ attn_bias: torch.Tensor | None = None,
153
+ context: torch.Tensor | None = None,
154
+ pos_embed: torch.Tensor | None = None,
155
+ pos_embed_context: torch.Tensor | None = None,
156
+ rope: nn.Module | None = None,
157
+ ) -> torch.Tensor:
158
+ context = x if context is None else context
159
+ x = (
160
+ self.ls1(
161
+ self.attn(
162
+ x,
163
+ rope=rope,
164
+ attn_bias=attn_bias,
165
+ context=context,
166
+ pos_embed=pos_embed,
167
+ pos_embed_context=pos_embed_context,
168
+ )
169
+ )
170
+ + x
171
+ )
172
+ x = self.ls2(self.mlp(x)) + x
173
+ return x
174
+
175
+
176
+ class AttentionDecoderBlock(nn.Module):
177
+ def __init__(
178
+ self,
179
+ dim: int,
180
+ num_heads: int = 4,
181
+ expansion: int = 4,
182
+ dropout: float = 0.0,
183
+ cosine: bool = False,
184
+ gated: bool = False,
185
+ layer_scale: float = 1.0,
186
+ context_dim: int | None = None,
187
+ single_head_ca: bool = True,
188
+ ):
189
+ super().__init__()
190
+ self.dropout = dropout
191
+ self.num_heads = num_heads
192
+ self.hidden_dim = dim
193
+ self.single_head_ca = single_head_ca
194
+ context_dim = context_dim or dim
195
+ self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
196
+ self.kv_ca = nn.Linear(context_dim, dim * 2)
197
+ self.q_ca = nn.Linear(dim, dim)
198
+ self.kv_sa = nn.Linear(dim, dim * 2)
199
+ self.q_sa = nn.Linear(dim, dim)
200
+ self.norm_x_sa = nn.LayerNorm(dim)
201
+ self.norm_x_ca = nn.LayerNorm(dim)
202
+ self.norm_ctx_ca = nn.LayerNorm(context_dim)
203
+ self.cosine = cosine
204
+ self.out_ca = nn.Linear(dim, dim)
205
+ self.out_sa = nn.Linear(dim, dim)
206
+ self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
207
+ self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
208
+ self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
209
+
210
+ def cross_attn(
211
+ self,
212
+ x: torch.Tensor,
213
+ attn_bias: torch.Tensor | None = None,
214
+ context: torch.Tensor | None = None,
215
+ pos_embed: torch.Tensor | None = None,
216
+ pos_embed_context: torch.Tensor | None = None,
217
+ rope: nn.Module | None = None,
218
+ ) -> torch.Tensor:
219
+ num_heads = 1 if self.single_head_ca else self.num_heads
220
+ x = self.norm_x_ca(x)
221
+ context = self.norm_ctx_ca(context)
222
+ k, v = rearrange(
223
+ self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2
224
+ ).unbind(dim=-1)
225
+ q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads)
226
+
227
+ if rope is not None:
228
+ q = rope(q)
229
+ k = rope(k)
230
+ else:
231
+ if pos_embed is not None:
232
+ pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads)
233
+ q = q + pos_embed
234
+ if pos_embed_context is not None:
235
+ pos_embed_context = rearrange(
236
+ pos_embed_context, "b n (h d) -> b h n d", h=num_heads
237
+ )
238
+ k = k + pos_embed_context
239
+
240
+ if self.cosine:
241
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
242
+ x = F.scaled_dot_product_attention(
243
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
244
+ )
245
+ x = rearrange(x, "b h n d -> b n (h d)")
246
+ x = self.out_ca(x)
247
+ return x
248
+
249
+ def self_attn(
250
+ self,
251
+ x: torch.Tensor,
252
+ attn_bias: torch.Tensor | None = None,
253
+ pos_embed: torch.Tensor | None = None,
254
+ rope: nn.Module | None = None,
255
+ ) -> torch.Tensor:
256
+ x = self.norm_x_sa(x)
257
+ k, v = rearrange(
258
+ self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
259
+ ).unbind(dim=-1)
260
+ q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads)
261
+
262
+ if rope is not None:
263
+ q = rope(q)
264
+ k = rope(k)
265
+ elif pos_embed is not None:
266
+ pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads)
267
+ q = q + pos_embed
268
+
269
+ if self.cosine:
270
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
271
+ x = F.scaled_dot_product_attention(
272
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
273
+ )
274
+ x = rearrange(x, "b h n d -> b n (h d)")
275
+ x = self.out_sa(x)
276
+ return x
277
+
278
+ def forward(
279
+ self,
280
+ x: torch.Tensor,
281
+ attn_bias: torch.Tensor | None = None,
282
+ context: torch.Tensor | None = None,
283
+ pos_embed: torch.Tensor | None = None,
284
+ pos_embed_context: torch.Tensor | None = None,
285
+ rope: nn.Module | None = None,
286
+ ) -> torch.Tensor:
287
+ context = x if context is None else context
288
+ x = (
289
+ self.ls1(
290
+ self.cross_attn(
291
+ x,
292
+ rope=rope,
293
+ attn_bias=attn_bias,
294
+ context=context,
295
+ pos_embed=pos_embed,
296
+ pos_embed_context=pos_embed_context,
297
+ )
298
+ )
299
+ + x
300
+ )
301
+ x = (
302
+ self.ls2(
303
+ self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed)
304
+ )
305
+ + x
306
+ )
307
+ x = self.ls3(self.mlp(x)) + x
308
+ return x
flash3d/unidepth/layers/convnext.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class CvnxtBlock(nn.Module):
6
+ def __init__(
7
+ self,
8
+ dim,
9
+ kernel_size=7,
10
+ layer_scale=1.0,
11
+ expansion=4,
12
+ dilation=1,
13
+ ):
14
+ super().__init__()
15
+ self.dwconv = nn.Conv2d(
16
+ dim,
17
+ dim,
18
+ kernel_size=kernel_size,
19
+ padding="same",
20
+ groups=dim,
21
+ dilation=dilation,
22
+ ) # depthwise conv
23
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
24
+ self.pwconv1 = nn.Linear(
25
+ dim, expansion * dim
26
+ ) # pointwise/1x1 convs, implemented with linear layers
27
+ self.act = nn.GELU()
28
+ self.pwconv2 = nn.Linear(expansion * dim, dim)
29
+ self.gamma = (
30
+ nn.Parameter(layer_scale * torch.ones((dim))) if layer_scale > 0.0 else 1.0
31
+ )
32
+
33
+ def forward(self, x):
34
+ input = x
35
+ x = self.dwconv(x)
36
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
37
+ x = self.norm(x)
38
+ x = self.pwconv1(x)
39
+ x = self.act(x)
40
+ x = self.pwconv2(x)
41
+
42
+ x = self.gamma * x
43
+ x = input + x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
44
+ return x
flash3d/unidepth/layers/drop_path.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False):
6
+ if drop_prob == 0.0 or not training:
7
+ return x
8
+ keep_prob = 1 - drop_prob
9
+ shape = (x.shape[0],) + (1,) * (
10
+ x.ndim - 1
11
+ ) # work with diff dim tensors, not just 2D ConvNets
12
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
13
+ if keep_prob > 0.0:
14
+ random_tensor.div_(keep_prob)
15
+ output = x * random_tensor
16
+ return output
17
+
18
+
19
+ class DropPath(nn.Module):
20
+ def __init__(self, drop_prob=None):
21
+ super(DropPath, self).__init__()
22
+ self.drop_prob = drop_prob
23
+
24
+ def forward(self, x):
25
+ return drop_path(x, self.drop_prob, self.training)
flash3d/unidepth/layers/layer_scale.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class LayerScale(nn.Module):
6
+ def __init__(
7
+ self,
8
+ dim: int,
9
+ init_values: float | torch.Tensor = 1e-5,
10
+ inplace: bool = False,
11
+ ) -> None:
12
+ super().__init__()
13
+ self.inplace = inplace
14
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
15
+
16
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
17
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
flash3d/unidepth/layers/mlp.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from unidepth.utils.misc import default
5
+ from .activation import SwiGLU
6
+
7
+
8
+ class MLP(nn.Module):
9
+ def __init__(
10
+ self,
11
+ input_dim: int,
12
+ expansion: int = 4,
13
+ dropout: float = 0.0,
14
+ gated: bool = False,
15
+ output_dim: int | None = None,
16
+ ):
17
+ super().__init__()
18
+ if gated:
19
+ expansion = int(expansion * 2 / 3)
20
+ hidden_dim = int(input_dim * expansion)
21
+ output_dim = default(output_dim, input_dim)
22
+ self.norm = nn.LayerNorm(input_dim)
23
+ self.proj1 = nn.Linear(input_dim, hidden_dim)
24
+ self.proj2 = nn.Linear(hidden_dim, output_dim)
25
+ self.act = nn.GELU() if not gated else SwiGLU()
26
+ self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
27
+
28
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
29
+ x = self.norm(x)
30
+ x = self.proj1(x)
31
+ x = self.act(x)
32
+ x = self.proj2(x)
33
+ x = self.dropout(x)
34
+ return x
flash3d/unidepth/layers/nystrom_attention.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from xformers.components.attention import NystromAttention
8
+
9
+ from .attention import AttentionBlock
10
+
11
+
12
+ class NystromBlock(AttentionBlock):
13
+ def __init__(
14
+ self,
15
+ dim: int,
16
+ num_heads: int = 4,
17
+ expansion: int = 4,
18
+ dropout: float = 0.0,
19
+ cosine: bool = False,
20
+ gated: bool = False,
21
+ layer_scale: float = 1.0,
22
+ context_dim: int | None = None,
23
+ ):
24
+ super().__init__(
25
+ dim=dim,
26
+ num_heads=num_heads,
27
+ expansion=expansion,
28
+ dropout=dropout,
29
+ cosine=cosine,
30
+ gated=gated,
31
+ layer_scale=layer_scale,
32
+ context_dim=context_dim,
33
+ )
34
+ self.attention_fn = NystromAttention(
35
+ num_landmarks=128, num_heads=num_heads, dropout=dropout
36
+ )
37
+
38
+ def attn(
39
+ self,
40
+ x: torch.Tensor,
41
+ attn_bias: torch.Tensor | None = None,
42
+ context: torch.Tensor | None = None,
43
+ pos_embed: torch.Tensor | None = None,
44
+ pos_embed_context: torch.Tensor | None = None,
45
+ rope: nn.Module | None = None,
46
+ ) -> torch.Tensor:
47
+ x = self.norm_attnx(x)
48
+ context = self.norm_attnctx(context)
49
+ k, v = rearrange(
50
+ self.kv(context), "b n (kv h d) -> b n h d kv", h=self.num_heads, kv=2
51
+ ).unbind(dim=-1)
52
+ q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads)
53
+
54
+ if rope is not None:
55
+ q = rope(q)
56
+ k = rope(k)
57
+ else:
58
+ if pos_embed is not None:
59
+ pos_embed = rearrange(
60
+ pos_embed, "b n (h d) -> b n h d", h=self.num_heads
61
+ )
62
+ q = q + pos_embed
63
+ if pos_embed_context is not None:
64
+ pos_embed_context = rearrange(
65
+ pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads
66
+ )
67
+ k = k + pos_embed_context
68
+
69
+ if self.cosine:
70
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
71
+ x = self.attention_fn(q, k, v, key_padding_mask=attn_bias)
72
+ x = rearrange(x, "b n h d -> b n (h d)")
73
+ x = self.out(x)
74
+ return x
flash3d/unidepth/layers/positional_encoding.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from math import pi
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from einops import rearrange, repeat
13
+
14
+
15
+ class PositionEmbeddingSine(nn.Module):
16
+ def __init__(
17
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
18
+ ):
19
+ super().__init__()
20
+ self.num_pos_feats = num_pos_feats
21
+ self.temperature = temperature
22
+ self.normalize = normalize
23
+ if scale is not None and normalize is False:
24
+ raise ValueError("normalize should be True if scale is passed")
25
+ if scale is None:
26
+ scale = 2 * pi
27
+ self.scale = scale
28
+
29
+ def forward(
30
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
31
+ ) -> torch.Tensor:
32
+ if mask is None:
33
+ mask = torch.zeros(
34
+ (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
35
+ )
36
+ not_mask = ~mask
37
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
38
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
39
+ if self.normalize:
40
+ eps = 1e-6
41
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
42
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
43
+
44
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
45
+ dim_t = self.temperature ** (
46
+ 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
47
+ )
48
+
49
+ pos_x = x_embed[:, :, :, None] / dim_t
50
+ pos_y = y_embed[:, :, :, None] / dim_t
51
+ pos_x = torch.stack(
52
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
53
+ ).flatten(3)
54
+ pos_y = torch.stack(
55
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
56
+ ).flatten(3)
57
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
58
+ return pos
59
+
60
+ def __repr__(self, _repr_indent=4):
61
+ head = "Positional encoding " + self.__class__.__name__
62
+ body = [
63
+ "num_pos_feats: {}".format(self.num_pos_feats),
64
+ "temperature: {}".format(self.temperature),
65
+ "normalize: {}".format(self.normalize),
66
+ "scale: {}".format(self.scale),
67
+ ]
68
+ # _repr_indent = 4
69
+ lines = [head] + [" " * _repr_indent + line for line in body]
70
+ return "\n".join(lines)
71
+
72
+
73
+ class LearnedSinusoidalPosEmb(nn.Module):
74
+ def __init__(self, dim):
75
+ super().__init__()
76
+ assert (dim % 2) == 0
77
+ half_dim = dim // 2
78
+ self.weights = nn.Parameter(torch.randn(half_dim))
79
+
80
+ def forward(self, x):
81
+ x = rearrange(x, "b -> b 1")
82
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
83
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
84
+ fouriered = torch.cat((x, fouriered), dim=-1)
85
+ return fouriered
86
+
87
+
88
+ def generate_fourier_features(x, max_freq=64, num_bands=16):
89
+ x = x.unsqueeze(-1)
90
+ device, dtype, orig_x = x.device, x.dtype, x
91
+
92
+ scales = torch.linspace(
93
+ -max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype
94
+ )
95
+ scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
96
+
97
+ x = x * scales * pi
98
+ x = torch.cat([x.sin(), x.cos()], dim=-1)
99
+ x = torch.cat((x, orig_x), dim=-1)
100
+ return x.flatten(-2)
101
+
102
+
103
+ def broadcat(tensors, dim=-1):
104
+ num_tensors = len(tensors)
105
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
106
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
107
+ shape_len = list(shape_lens)[0]
108
+ dim = (dim + shape_len) if dim < 0 else dim
109
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
110
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
111
+ assert all(
112
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
113
+ ), "invalid dimensions for broadcastable concatentation"
114
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
115
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
116
+ expanded_dims.insert(dim, (dim, dims[dim]))
117
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
118
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
119
+ return torch.cat(tensors, dim=dim)
120
+
121
+
122
+ def rotate_half(x):
123
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
124
+ x1, x2 = x.unbind(dim=-1)
125
+ x = torch.stack((-x2, x1), dim=-1)
126
+ return rearrange(x, "... d r -> ... (d r)")
127
+
128
+
129
+ class VisionRotaryEmbedding(nn.Module):
130
+ def __init__(
131
+ self,
132
+ dim,
133
+ pt_seq_len,
134
+ ft_seq_len=None,
135
+ custom_freqs=None,
136
+ freqs_for="lang",
137
+ theta=10000,
138
+ max_freq=10,
139
+ num_freqs=1,
140
+ ):
141
+ super().__init__()
142
+ if custom_freqs:
143
+ freqs = custom_freqs
144
+ elif freqs_for == "lang":
145
+ freqs = 1.0 / (
146
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
147
+ )
148
+ elif freqs_for == "pixel":
149
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
150
+ elif freqs_for == "constant":
151
+ freqs = torch.ones(num_freqs).float()
152
+ else:
153
+ raise ValueError(f"unknown modality {freqs_for}")
154
+
155
+ if ft_seq_len is None:
156
+ ft_seq_len = pt_seq_len
157
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
158
+
159
+ freqs_h = torch.einsum("..., f -> ... f", t, freqs)
160
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
161
+
162
+ freqs_w = torch.einsum("..., f -> ... f", t, freqs)
163
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
164
+
165
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
166
+
167
+ self.register_buffer("freqs_cos", freqs.cos())
168
+ self.register_buffer("freqs_sin", freqs.sin())
169
+
170
+ print("======== shape of rope freq", self.freqs_cos.shape, "========")
171
+
172
+ def forward(self, t, start_index=0):
173
+ rot_dim = self.freqs_cos.shape[-1]
174
+ end_index = start_index + rot_dim
175
+ assert (
176
+ rot_dim <= t.shape[-1]
177
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
178
+ t_left, t, t_right = (
179
+ t[..., :start_index],
180
+ t[..., start_index:end_index],
181
+ t[..., end_index:],
182
+ )
183
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
184
+ return torch.cat((t_left, t, t_right), dim=-1)
185
+
186
+
187
+ class VisionRotaryEmbeddingFast(nn.Module):
188
+ def __init__(
189
+ self,
190
+ dim,
191
+ pt_seq_len,
192
+ ft_seq_len=None,
193
+ custom_freqs=None,
194
+ freqs_for="lang",
195
+ theta=10000,
196
+ max_freq=10,
197
+ num_freqs=1,
198
+ ):
199
+ super().__init__()
200
+ if custom_freqs:
201
+ freqs = custom_freqs
202
+ elif freqs_for == "lang":
203
+ freqs = 1.0 / (
204
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
205
+ )
206
+ elif freqs_for == "pixel":
207
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
208
+ elif freqs_for == "constant":
209
+ freqs = torch.ones(num_freqs).float()
210
+ else:
211
+ raise ValueError(f"unknown modality {freqs_for}")
212
+
213
+ if ft_seq_len is None:
214
+ ft_seq_len = pt_seq_len
215
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
216
+
217
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
218
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
219
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
220
+
221
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
222
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
223
+
224
+ self.register_buffer("freqs_cos", freqs_cos)
225
+ self.register_buffer("freqs_sin", freqs_sin)
226
+
227
+ def forward(self, t):
228
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
flash3d/unidepth/layers/upsample.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+
10
+ from .convnext import CvnxtBlock
11
+
12
+
13
+ class ConvUpsample(nn.Module):
14
+ def __init__(
15
+ self,
16
+ hidden_dim,
17
+ num_layers: int = 2,
18
+ expansion: int = 4,
19
+ layer_scale: float = 1.0,
20
+ kernel_size: int = 7,
21
+ **kwargs
22
+ ):
23
+ super().__init__()
24
+ self.convs = nn.ModuleList([])
25
+ for _ in range(num_layers):
26
+ self.convs.append(
27
+ CvnxtBlock(
28
+ hidden_dim,
29
+ kernel_size=kernel_size,
30
+ expansion=expansion,
31
+ layer_scale=layer_scale,
32
+ )
33
+ )
34
+ self.up = nn.Sequential(
35
+ nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0),
36
+ nn.UpsamplingBilinear2d(scale_factor=2),
37
+ nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1),
38
+ )
39
+
40
+ def forward(self, x: torch.Tensor):
41
+ for conv in self.convs:
42
+ x = conv(x)
43
+ x = self.up(x)
44
+ x = rearrange(x, "b c h w -> b (h w) c")
45
+ return x
46
+
47
+
48
+ class ConvUpsampleShuffle(nn.Module):
49
+ def __init__(
50
+ self, hidden_dim, expansion: int = 4, layer_scale: float = 1.0, **kwargs
51
+ ):
52
+ super().__init__()
53
+ self.conv1 = CvnxtBlock(
54
+ hidden_dim, expansion=expansion, layer_scale=layer_scale
55
+ )
56
+ self.conv2 = CvnxtBlock(
57
+ hidden_dim, expansion=expansion, layer_scale=layer_scale
58
+ )
59
+ self.up = nn.Sequential(
60
+ nn.PixelShuffle(2),
61
+ nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1),
62
+ )
63
+
64
+ def forward(self, x: torch.Tensor):
65
+ x = self.conv1(x)
66
+ x = self.conv2(x)
67
+ x = self.up(x)
68
+ x = rearrange(x, "b c h w -> b (h w) c")
69
+ return x
flash3d/unidepth/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .unidepthv1 import UniDepthV1
2
+
3
+ __all__ = [
4
+ "UniDepthV1",
5
+ ]
flash3d/unidepth/models/backbones/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .convnext2 import ConvNeXtV2
2
+ from .convnext import ConvNeXt
3
+ from .dinov2 import _make_dinov2_model
4
+
5
+ __all__ = [
6
+ "ConvNeXt",
7
+ "ConvNeXtV2",
8
+ "_make_dinov2_model",
9
+ ]
flash3d/unidepth/models/backbones/convnext.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from functools import partial
3
+ from typing import Callable, Optional, Tuple, Union, Sequence
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.checkpoint import checkpoint
8
+
9
+ from timm.layers import (
10
+ trunc_normal_,
11
+ AvgPool2dSame,
12
+ DropPath,
13
+ Mlp,
14
+ GlobalResponseNormMlp,
15
+ LayerNorm2d,
16
+ LayerNorm,
17
+ create_conv2d,
18
+ get_act_layer,
19
+ make_divisible,
20
+ to_ntuple,
21
+ )
22
+
23
+
24
+ def get_num_layer_for_convnext(var_name):
25
+ """
26
+ Divide [3, 3, 27, 3] layers into 12 groups; each group is three
27
+ consecutive blocks, including possible neighboring downsample layers;
28
+ adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
29
+ """
30
+ if var_name.startswith("downsample_layers"):
31
+ stage_id = int(var_name.split(".")[1])
32
+ if stage_id == 0:
33
+ layer_id = 0
34
+ elif stage_id == 1 or stage_id == 2:
35
+ layer_id = stage_id + 1
36
+ elif stage_id == 3:
37
+ layer_id = 12
38
+
39
+ elif var_name.startswith("stages"):
40
+ stage_id = int(var_name.split(".")[1])
41
+ block_id = int(var_name.split(".")[3])
42
+ if stage_id == 0 or stage_id == 1:
43
+ layer_id = stage_id + 1
44
+ elif stage_id == 2:
45
+ layer_id = 3 + block_id // 3
46
+ elif stage_id == 3:
47
+ layer_id = 12
48
+
49
+ elif var_name.startswith("stem"):
50
+ return 0
51
+ else:
52
+ layer_id = 12
53
+ return layer_id + 1
54
+
55
+
56
+ def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=None):
57
+ parameter_group_names = {}
58
+ parameter_group_vars = {}
59
+ skip = set()
60
+ if skip_list is not None:
61
+ skip = skip_list
62
+ if hasattr(model, "no_weight_decay"):
63
+ skip.update(model.no_weight_decay())
64
+ num_layers = 12
65
+ layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
66
+ for name, param in model.named_parameters():
67
+ if not param.requires_grad:
68
+ continue # frozen weights
69
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip:
70
+ group_name = "no_decay"
71
+ this_wd = 0.0
72
+ else:
73
+ group_name = "decay"
74
+ this_wd = wd
75
+
76
+ layer_id = get_num_layer_for_convnext(name)
77
+ group_name = "layer_%d_%s" % (layer_id, group_name)
78
+
79
+ if group_name not in parameter_group_names:
80
+ scale = layer_scale[layer_id]
81
+ cur_lr = lr * scale
82
+
83
+ parameter_group_names[group_name] = {
84
+ "weight_decay": this_wd,
85
+ "weight_decay_init": this_wd,
86
+ "weight_decay_base": this_wd,
87
+ "params": [],
88
+ "lr_init": cur_lr,
89
+ "lr_base": lr,
90
+ "lr": cur_lr,
91
+ }
92
+ parameter_group_vars[group_name] = {
93
+ "weight_decay": this_wd,
94
+ "weight_decay_init": this_wd,
95
+ "weight_decay_base": this_wd,
96
+ "params": [],
97
+ "lr_init": cur_lr,
98
+ "lr_base": lr,
99
+ "lr": cur_lr,
100
+ }
101
+ if this_wd == 0.0:
102
+ parameter_group_names[group_name]["weight_decay_final"] = 0.0
103
+ parameter_group_vars[group_name]["weight_decay_final"] = 0.0
104
+ parameter_group_vars[group_name]["params"].append(param)
105
+ parameter_group_names[group_name]["params"].append(name)
106
+ # from unidepth.utils import is_main_process
107
+ # import json
108
+ # if is_main_process():
109
+ # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
110
+ return list(parameter_group_vars.values()), [
111
+ v["lr"] for k, v in parameter_group_vars.items()
112
+ ]
113
+
114
+
115
+ class Downsample(nn.Module):
116
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1):
117
+ super().__init__()
118
+ avg_stride = stride if dilation == 1 else 1
119
+ if stride > 1 or dilation > 1:
120
+ avg_pool_fn = (
121
+ AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
122
+ )
123
+ self.pool = avg_pool_fn(
124
+ 2, avg_stride, ceil_mode=True, count_include_pad=False
125
+ )
126
+ else:
127
+ self.pool = nn.Identity()
128
+
129
+ if in_chs != out_chs:
130
+ self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
131
+ else:
132
+ self.conv = nn.Identity()
133
+
134
+ def forward(self, x):
135
+ x = self.pool(x)
136
+ x = self.conv(x)
137
+ return x
138
+
139
+
140
+ class ConvNeXtBlock(nn.Module):
141
+ """ConvNeXt Block
142
+ There are two equivalent implementations:
143
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
144
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
145
+
146
+ Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
147
+ choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
148
+ is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
149
+ """
150
+
151
+ def __init__(
152
+ self,
153
+ in_chs: int,
154
+ out_chs: Optional[int] = None,
155
+ kernel_size: int = 7,
156
+ stride: int = 1,
157
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
158
+ mlp_ratio: float = 4,
159
+ conv_mlp: bool = False,
160
+ conv_bias: bool = True,
161
+ use_grn: bool = False,
162
+ ls_init_value: Optional[float] = 1e-6,
163
+ act_layer: Union[str, Callable] = "gelu",
164
+ norm_layer: Optional[Callable] = None,
165
+ drop_path: float = 0.0,
166
+ ):
167
+ """
168
+
169
+ Args:
170
+ in_chs: Block input channels.
171
+ out_chs: Block output channels (same as in_chs if None).
172
+ kernel_size: Depthwise convolution kernel size.
173
+ stride: Stride of depthwise convolution.
174
+ dilation: Tuple specifying input and output dilation of block.
175
+ mlp_ratio: MLP expansion ratio.
176
+ conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
177
+ conv_bias: Apply bias for all convolution (linear) layers.
178
+ use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
179
+ ls_init_value: Layer-scale init values, layer-scale applied if not None.
180
+ act_layer: Activation layer.
181
+ norm_layer: Normalization layer (defaults to LN if not specified).
182
+ drop_path: Stochastic depth probability.
183
+ """
184
+ super().__init__()
185
+ out_chs = out_chs or in_chs
186
+ dilation = to_ntuple(2)(dilation)
187
+ act_layer = get_act_layer(act_layer)
188
+ if not norm_layer:
189
+ norm_layer = LayerNorm2d if conv_mlp else LayerNorm
190
+ mlp_layer = partial(
191
+ GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp
192
+ )
193
+ self.use_conv_mlp = conv_mlp
194
+ self.conv_dw = create_conv2d(
195
+ in_chs,
196
+ out_chs,
197
+ kernel_size=kernel_size,
198
+ stride=stride,
199
+ dilation=dilation[0],
200
+ depthwise=True,
201
+ bias=conv_bias,
202
+ )
203
+ self.norm = norm_layer(out_chs)
204
+ self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
205
+ self.gamma = (
206
+ nn.Parameter(ls_init_value * torch.ones(out_chs))
207
+ if ls_init_value is not None
208
+ else None
209
+ )
210
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
211
+ self.shortcut = Downsample(
212
+ in_chs, out_chs, stride=stride, dilation=dilation[0]
213
+ )
214
+ else:
215
+ self.shortcut = nn.Identity()
216
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
217
+
218
+ def forward(self, x):
219
+ shortcut = x
220
+ x = self.conv_dw(x.contiguous())
221
+ if self.use_conv_mlp:
222
+ x = self.norm(x)
223
+ x = self.mlp(x)
224
+ else:
225
+ x = x.permute(0, 2, 3, 1).contiguous()
226
+ x = self.norm(x)
227
+ x = self.mlp(x)
228
+ x = x.permute(0, 3, 1, 2).contiguous()
229
+ if self.gamma is not None:
230
+ x = x.mul(self.gamma.reshape(1, -1, 1, 1))
231
+
232
+ x = self.drop_path(x) + self.shortcut(shortcut)
233
+ return x.contiguous()
234
+
235
+
236
+ class ConvNeXtStage(nn.Module):
237
+ def __init__(
238
+ self,
239
+ in_chs,
240
+ out_chs,
241
+ kernel_size=7,
242
+ stride=2,
243
+ depth=2,
244
+ dilation=(1, 1),
245
+ drop_path_rates=None,
246
+ ls_init_value=1.0,
247
+ conv_mlp=False,
248
+ conv_bias=True,
249
+ use_grn=False,
250
+ act_layer="gelu",
251
+ norm_layer=None,
252
+ norm_layer_cl=None,
253
+ ):
254
+ super().__init__()
255
+ self.grad_checkpointing = False
256
+
257
+ if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
258
+ ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
259
+ pad = (
260
+ "same" if dilation[1] > 1 else 0
261
+ ) # same padding needed if dilation used
262
+ self.downsample = nn.Sequential(
263
+ norm_layer(in_chs),
264
+ create_conv2d(
265
+ in_chs,
266
+ out_chs,
267
+ kernel_size=ds_ks,
268
+ stride=stride,
269
+ dilation=dilation[0],
270
+ padding=pad,
271
+ bias=conv_bias,
272
+ ),
273
+ )
274
+ in_chs = out_chs
275
+ else:
276
+ self.downsample = nn.Identity()
277
+
278
+ drop_path_rates = drop_path_rates or [0.0] * depth
279
+ stage_blocks = []
280
+ for i in range(depth):
281
+ stage_blocks.append(
282
+ ConvNeXtBlock(
283
+ in_chs=in_chs,
284
+ out_chs=out_chs,
285
+ kernel_size=kernel_size,
286
+ dilation=dilation[1],
287
+ drop_path=drop_path_rates[i],
288
+ ls_init_value=ls_init_value,
289
+ conv_mlp=conv_mlp,
290
+ conv_bias=conv_bias,
291
+ use_grn=use_grn,
292
+ act_layer=act_layer,
293
+ norm_layer=norm_layer if conv_mlp else norm_layer_cl,
294
+ )
295
+ )
296
+ in_chs = out_chs
297
+ self.blocks = nn.ModuleList(stage_blocks)
298
+
299
+ def forward(self, x):
300
+ xs = []
301
+ x = self.downsample(x)
302
+ for block in self.blocks:
303
+ if self.grad_checkpointing:
304
+ x = checkpoint(block, x)
305
+ else:
306
+ x = block(x)
307
+ xs.append(x)
308
+ return xs
309
+
310
+
311
+ class ConvNeXt(nn.Module):
312
+ def __init__(
313
+ self,
314
+ in_chans: int = 3,
315
+ output_stride: int = 32,
316
+ depths: Tuple[int, ...] = (3, 3, 9, 3),
317
+ dims: Tuple[int, ...] = (96, 192, 384, 768),
318
+ kernel_sizes: Union[int, Tuple[int, ...]] = 7,
319
+ ls_init_value: Optional[float] = 1e-6,
320
+ stem_type: str = "patch",
321
+ patch_size: int = 4,
322
+ conv_mlp: bool = False,
323
+ conv_bias: bool = True,
324
+ use_grn: bool = False,
325
+ act_layer: Union[str, Callable] = "gelu",
326
+ norm_layer: Optional[Union[str, Callable]] = None,
327
+ norm_eps: Optional[float] = None,
328
+ drop_path_rate: float = 0.0,
329
+ output_idx=[],
330
+ use_checkpoint=False,
331
+ ):
332
+ """
333
+ Args:
334
+ in_chans: Number of input image channels.
335
+ num_classes: Number of classes for classification head.
336
+ global_pool: Global pooling type.
337
+ output_stride: Output stride of network, one of (8, 16, 32).
338
+ depths: Number of blocks at each stage.
339
+ dims: Feature dimension at each stage.
340
+ kernel_sizes: Depthwise convolution kernel-sizes for each stage.
341
+ ls_init_value: Init value for Layer Scale, disabled if None.
342
+ stem_type: Type of stem.
343
+ patch_size: Stem patch size for patch stem.
344
+ head_init_scale: Init scaling value for classifier weights and biases.
345
+ head_norm_first: Apply normalization before global pool + head.
346
+ head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
347
+ conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
348
+ conv_bias: Use bias layers w/ all convolutions.
349
+ use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
350
+ act_layer: Activation layer type.
351
+ norm_layer: Normalization layer type.
352
+ drop_rate: Head pre-classifier dropout rate.
353
+ drop_path_rate: Stochastic depth drop rate.
354
+ """
355
+ super().__init__()
356
+ self.num_layers = len(depths)
357
+ self.depths = output_idx
358
+ self.embed_dims = [
359
+ int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
360
+ ]
361
+ self.embed_dim = dims[0]
362
+
363
+ assert output_stride in (8, 16, 32)
364
+ kernel_sizes = to_ntuple(4)(kernel_sizes)
365
+ if norm_layer is None:
366
+ norm_layer = LayerNorm2d
367
+ norm_layer_cl = norm_layer if conv_mlp else LayerNorm
368
+ if norm_eps is not None:
369
+ norm_layer = partial(norm_layer, eps=norm_eps)
370
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
371
+ else:
372
+ assert (
373
+ conv_mlp
374
+ ), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input"
375
+ norm_layer_cl = norm_layer
376
+ if norm_eps is not None:
377
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
378
+
379
+ self.feature_info = []
380
+
381
+ assert stem_type in ("patch", "overlap", "overlap_tiered")
382
+ if stem_type == "patch":
383
+ # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
384
+ self.stem = nn.Sequential(
385
+ nn.Conv2d(
386
+ in_chans,
387
+ dims[0],
388
+ kernel_size=patch_size,
389
+ stride=patch_size,
390
+ bias=conv_bias,
391
+ ),
392
+ norm_layer(dims[0]),
393
+ )
394
+ stem_stride = patch_size
395
+ else:
396
+ mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0]
397
+ self.stem = nn.Sequential(
398
+ nn.Conv2d(
399
+ in_chans,
400
+ mid_chs,
401
+ kernel_size=3,
402
+ stride=2,
403
+ padding=1,
404
+ bias=conv_bias,
405
+ ),
406
+ nn.Conv2d(
407
+ mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias
408
+ ),
409
+ norm_layer(dims[0]),
410
+ )
411
+ stem_stride = 4
412
+
413
+ self.stages = nn.Sequential()
414
+ dp_rates = [
415
+ x.tolist()
416
+ for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
417
+ ]
418
+ stages = []
419
+ prev_chs = dims[0]
420
+ curr_stride = stem_stride
421
+ dilation = 1
422
+ # 4 feature resolution stages, each consisting of multiple residual blocks
423
+ for i in range(4):
424
+ stride = 2 if curr_stride == 2 or i > 0 else 1
425
+ if curr_stride >= output_stride and stride > 1:
426
+ dilation *= stride
427
+ stride = 1
428
+ curr_stride *= stride
429
+ first_dilation = 1 if dilation in (1, 2) else 2
430
+ out_chs = dims[i]
431
+ stages.append(
432
+ ConvNeXtStage(
433
+ prev_chs,
434
+ out_chs,
435
+ kernel_size=kernel_sizes[i],
436
+ stride=stride,
437
+ dilation=(first_dilation, dilation),
438
+ depth=depths[i],
439
+ drop_path_rates=dp_rates[i],
440
+ ls_init_value=ls_init_value,
441
+ conv_mlp=conv_mlp,
442
+ conv_bias=conv_bias,
443
+ use_grn=use_grn,
444
+ act_layer=act_layer,
445
+ norm_layer=norm_layer,
446
+ norm_layer_cl=norm_layer_cl,
447
+ )
448
+ )
449
+ prev_chs = out_chs
450
+ # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
451
+ self.feature_info += [
452
+ dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}")
453
+ ]
454
+ self.stages = nn.ModuleList(stages)
455
+ self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
456
+ self.num_features = prev_chs
457
+ self.apply(self._init_weights)
458
+ self.set_grad_checkpointing(use_checkpoint)
459
+
460
+ def _init_weights(self, module):
461
+ if isinstance(module, nn.Conv2d):
462
+ trunc_normal_(module.weight, std=0.02)
463
+ if module.bias is not None:
464
+ nn.init.zeros_(module.bias)
465
+ elif isinstance(module, nn.Linear):
466
+ trunc_normal_(module.weight, std=0.02)
467
+ nn.init.zeros_(module.bias)
468
+
469
+ def forward(self, x, masks=None):
470
+ outs = []
471
+ x = self.stem(x)
472
+ if masks is not None:
473
+ masks = torch.nn.functional.interpolate(
474
+ masks.float(), size=x.shape[-2:], mode="nearest"
475
+ )
476
+ x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous()
477
+ for stage in self.stages:
478
+ xs = stage(x)
479
+ outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs])
480
+ x = xs[-1]
481
+ return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
482
+
483
+ @torch.jit.ignore
484
+ def group_matcher(self, coarse=False):
485
+ return dict(
486
+ stem=r"^stem",
487
+ blocks=(
488
+ r"^stages\.(\d+)"
489
+ if coarse
490
+ else [
491
+ (r"^stages\.(\d+)\.downsample", (0,)), # blocks
492
+ (r"^stages\.(\d+)\.blocks\.(\d+)", None),
493
+ (r"^norm_pre", (99999,)),
494
+ ]
495
+ ),
496
+ )
497
+
498
+ @torch.jit.ignore
499
+ def set_grad_checkpointing(self, enable=True):
500
+ for s in self.stages:
501
+ s.grad_checkpointing = enable
502
+
503
+ def freeze(self) -> None:
504
+ for module in self.modules():
505
+ module.eval()
506
+ for parameters in self.parameters():
507
+ parameters.requires_grad = False
508
+
509
+ def get_params(self, lr, wd, ld, *args, **kwargs):
510
+ encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
511
+ return encoder_p, encoder_lr
512
+
513
+ def no_weight_decay(self):
514
+ return {"mask_token"}
515
+
516
+ @classmethod
517
+ def build(cls, config):
518
+ obj = globals()[config["model"]["encoder"]["name"]](config)
519
+ return obj
520
+
521
+
522
+ def checkpoint_filter_fn(state_dict, model):
523
+ """Remap FB checkpoints -> timm"""
524
+ if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict:
525
+ return state_dict # non-FB checkpoint
526
+ if "model" in state_dict:
527
+ state_dict = state_dict["model"]
528
+
529
+ out_dict = {}
530
+ if "visual.trunk.stem.0.weight" in state_dict:
531
+ out_dict = {
532
+ k.replace("visual.trunk.", ""): v
533
+ for k, v in state_dict.items()
534
+ if k.startswith("visual.trunk.")
535
+ }
536
+ if "visual.head.proj.weight" in state_dict:
537
+ out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"]
538
+ out_dict["head.fc.bias"] = torch.zeros(
539
+ state_dict["visual.head.proj.weight"].shape[0]
540
+ )
541
+ elif "visual.head.mlp.fc1.weight" in state_dict:
542
+ out_dict["head.pre_logits.fc.weight"] = state_dict[
543
+ "visual.head.mlp.fc1.weight"
544
+ ]
545
+ out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"]
546
+ out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"]
547
+ out_dict["head.fc.bias"] = torch.zeros(
548
+ state_dict["visual.head.mlp.fc2.weight"].shape[0]
549
+ )
550
+ return out_dict
551
+
552
+ import re
553
+
554
+ for k, v in state_dict.items():
555
+ k = k.replace("downsample_layers.0.", "stem.")
556
+ k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k)
557
+ k = re.sub(
558
+ r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k
559
+ )
560
+ k = k.replace("dwconv", "conv_dw")
561
+ k = k.replace("pwconv", "mlp.fc")
562
+ if "grn" in k:
563
+ k = k.replace("grn.beta", "mlp.grn.bias")
564
+ k = k.replace("grn.gamma", "mlp.grn.weight")
565
+ v = v.reshape(v.shape[-1])
566
+ k = k.replace("head.", "head.fc.")
567
+ if k.startswith("norm."):
568
+ k = k.replace("norm", "head.norm")
569
+ if v.ndim == 2 and "head" not in k:
570
+ model_shape = model.state_dict()[k].shape
571
+ v = v.reshape(model_shape)
572
+ out_dict[k] = v
573
+
574
+ return out_dict
575
+
576
+
577
+ HF_URL = {
578
+ "convnext_xxlarge_pt": (
579
+ "laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup",
580
+ "open_clip_pytorch_model.bin",
581
+ ),
582
+ "convnext_large_pt": (
583
+ "laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
584
+ "open_clip_pytorch_model.bin",
585
+ ),
586
+ "convnext_large": (
587
+ "timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384",
588
+ "pytorch_model.bin",
589
+ ),
590
+ }
flash3d/unidepth/models/backbones/convnext2.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from timm.models.layers import trunc_normal_, DropPath
5
+
6
+
7
+ def get_num_layer_for_convnext_single(var_name, depths):
8
+ """
9
+ Each layer is assigned distinctive layer ids
10
+ """
11
+ if var_name.startswith("downsample_layers"):
12
+ stage_id = int(var_name.split(".")[1])
13
+ layer_id = sum(depths[:stage_id]) + 1
14
+ return layer_id
15
+
16
+ elif var_name.startswith("stages"):
17
+ stage_id = int(var_name.split(".")[1])
18
+ block_id = int(var_name.split(".")[2])
19
+ layer_id = sum(depths[:stage_id]) + block_id + 1
20
+ return layer_id
21
+
22
+ else:
23
+ return sum(depths) + 1
24
+
25
+
26
+ def get_num_layer_for_convnext(var_name):
27
+ """
28
+ Divide [3, 3, 27, 3] layers into 12 groups; each group is three
29
+ consecutive blocks, including possible neighboring downsample layers;
30
+ adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
31
+ """
32
+ num_max_layer = 12
33
+ if var_name.startswith("downsample_layers"):
34
+ stage_id = int(var_name.split(".")[1])
35
+ if stage_id == 0:
36
+ layer_id = 0
37
+ elif stage_id == 1 or stage_id == 2:
38
+ layer_id = stage_id + 1
39
+ elif stage_id == 3:
40
+ layer_id = 12
41
+ return layer_id
42
+
43
+ elif var_name.startswith("stages"):
44
+ stage_id = int(var_name.split(".")[1])
45
+ block_id = int(var_name.split(".")[2])
46
+ if stage_id == 0 or stage_id == 1:
47
+ layer_id = stage_id + 1
48
+ elif stage_id == 2:
49
+ layer_id = 3 + block_id // 3
50
+ elif stage_id == 3:
51
+ layer_id = 12
52
+ return layer_id
53
+ else:
54
+ return num_max_layer + 1
55
+
56
+
57
+ def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
58
+ parameter_group_names = {}
59
+ parameter_group_vars = {}
60
+ skip = {}
61
+ if skip_list is not None:
62
+ skip = skip_list
63
+ elif hasattr(model, "no_weight_decay"):
64
+ skip = model.no_weight_decay()
65
+ num_layers = 12 # sum(model.depths)
66
+ layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
67
+ for name, param in model.named_parameters():
68
+ if not param.requires_grad:
69
+ continue # frozen weights
70
+ if (
71
+ len(param.shape) == 1
72
+ or name.endswith(".bias")
73
+ or name in skip
74
+ or name.endswith(".gamma")
75
+ or name.endswith(".beta")
76
+ ):
77
+ group_name = "no_decay"
78
+ this_weight_decay = 0.0
79
+ else:
80
+ group_name = "decay"
81
+ this_weight_decay = wd
82
+
83
+ # layer_id = get_num_layer_for_convnext_single(name, model.depths)
84
+ layer_id = get_num_layer_for_convnext(name)
85
+ group_name = "layer_%d_%s" % (layer_id, group_name)
86
+
87
+ if group_name not in parameter_group_names:
88
+ scale = layer_scale[layer_id]
89
+ cur_lr = lr * scale
90
+
91
+ parameter_group_names[group_name] = {
92
+ "weight_decay": this_weight_decay,
93
+ "params": [],
94
+ "lr_scale": scale,
95
+ "lr": cur_lr,
96
+ }
97
+ parameter_group_vars[group_name] = {
98
+ "weight_decay": this_weight_decay,
99
+ "params": [],
100
+ "lr_scale": scale,
101
+ "lr": cur_lr,
102
+ }
103
+ parameter_group_vars[group_name]["params"].append(param)
104
+ parameter_group_names[group_name]["params"].append(name)
105
+ # if is_main_process():
106
+ # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
107
+ return list(parameter_group_vars.values()), [
108
+ v["lr"] for k, v in parameter_group_vars.items()
109
+ ]
110
+
111
+
112
+ class LayerNorm(nn.Module):
113
+ """LayerNorm that supports two data formats: channels_last (default) or channels_first.
114
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
115
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
116
+ with shape (batch_size, channels, height, width).
117
+ """
118
+
119
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
120
+ super().__init__()
121
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
122
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
123
+ self.eps = eps
124
+ self.data_format = data_format
125
+ if self.data_format not in ["channels_last", "channels_first"]:
126
+ raise NotImplementedError
127
+ self.normalized_shape = (normalized_shape,)
128
+
129
+ def forward(self, x):
130
+ if self.data_format == "channels_last":
131
+ return F.layer_norm(
132
+ x, self.normalized_shape, self.weight, self.bias, self.eps
133
+ )
134
+ elif self.data_format == "channels_first":
135
+ u = x.mean(1, keepdim=True)
136
+ s = (x - u).pow(2).mean(1, keepdim=True)
137
+ x = (x - u) / torch.sqrt(s + self.eps)
138
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
139
+ return x
140
+
141
+
142
+ class GRN(nn.Module):
143
+ """GRN (Global Response Normalization) layer"""
144
+
145
+ def __init__(self, dim):
146
+ super().__init__()
147
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
148
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
149
+
150
+ def forward(self, x):
151
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
152
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
153
+ return self.gamma * (x * Nx) + self.beta + x
154
+
155
+
156
+ class Block(nn.Module):
157
+ """ConvNeXtV2 Block.
158
+
159
+ Args:
160
+ dim (int): Number of input channels.
161
+ drop_path (float): Stochastic depth rate. Default: 0.0
162
+ """
163
+
164
+ def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False):
165
+ super().__init__()
166
+ self.dwconv = nn.Conv2d(
167
+ dim, dim, kernel_size=7, padding=3, groups=dim
168
+ ) # depthwise conv
169
+ self.norm = LayerNorm(dim, eps=1e-6)
170
+ self.pwconv1 = nn.Linear(
171
+ dim, mult * dim
172
+ ) # pointwise/1x1 convs, implemented with linear layers
173
+ self.act = nn.GELU()
174
+ self.grn = GRN(mult * dim)
175
+ self.pwconv2 = nn.Linear(mult * dim, dim)
176
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
177
+ self.use_checkpoint = use_checkpoint
178
+
179
+ def forward(self, x):
180
+ input = x
181
+ x = self.dwconv(x)
182
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
183
+ x = self.norm(x)
184
+ x = self.pwconv1(x)
185
+ x = self.act(x)
186
+ x = self.grn(x)
187
+ x = self.pwconv2(x)
188
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
189
+
190
+ x = input + self.drop_path(x)
191
+ return x
192
+
193
+
194
+ class ConvNeXtV2(nn.Module):
195
+ """ConvNeXt V2
196
+
197
+ Args:
198
+ in_chans (int): Number of input image channels. Default: 3
199
+ num_classes (int): Number of classes for classification head. Default: 1000
200
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
201
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
202
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
203
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
204
+ """
205
+
206
+ def __init__(
207
+ self,
208
+ in_chans=3,
209
+ depths=[3, 3, 9, 3],
210
+ dims=96,
211
+ drop_path_rate=0.0,
212
+ output_idx=[],
213
+ use_checkpoint=False,
214
+ ):
215
+ super().__init__()
216
+ self.num_layers = len(depths)
217
+ self.depths = output_idx
218
+ self.embed_dims = [
219
+ int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
220
+ ]
221
+ self.embed_dim = dims[0]
222
+
223
+ self.downsample_layers = (
224
+ nn.ModuleList()
225
+ ) # stem and 3 intermediate downsampling conv layers
226
+ stem = nn.Sequential(
227
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
228
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
229
+ )
230
+ self.downsample_layers.append(stem)
231
+ for i in range(3):
232
+ downsample_layer = nn.Sequential(
233
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
234
+ nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
235
+ )
236
+ self.downsample_layers.append(downsample_layer)
237
+
238
+ self.stages = (
239
+ nn.ModuleList()
240
+ ) # 4 feature resolution stages, each consisting of multiple residual blocks
241
+ self.out_norms = nn.ModuleList()
242
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
243
+ cur = 0
244
+ for i in range(4):
245
+ stage = nn.ModuleList(
246
+ [
247
+ Block(
248
+ dim=dims[i],
249
+ drop_path=dp_rates[cur + j],
250
+ use_checkpoint=use_checkpoint,
251
+ )
252
+ for j in range(depths[i])
253
+ ]
254
+ )
255
+ self.stages.append(stage)
256
+ cur += depths[i]
257
+
258
+ self.apply(self._init_weights)
259
+
260
+ def _init_weights(self, m):
261
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
262
+ trunc_normal_(m.weight, std=0.02)
263
+ nn.init.constant_(m.bias, 0)
264
+
265
+ def forward(self, x):
266
+ outs = []
267
+ for i in range(4):
268
+ x = self.downsample_layers[i](x)
269
+ for stage in self.stages[i]:
270
+ x = stage(x)
271
+ outs.append(x.permute(0, 2, 3, 1))
272
+ cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
273
+ return outs, cls_tokens
274
+
275
+ def get_params(self, lr, wd, ld, *args, **kwargs):
276
+ encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
277
+ return encoder_p, encoder_lr
278
+
279
+ def freeze(self) -> None:
280
+ for module in self.modules():
281
+ module.eval()
282
+ for parameters in self.parameters():
283
+ parameters.requires_grad = False
284
+
285
+ @classmethod
286
+ def build(cls, config):
287
+ obj = globals()[config["model"]["encoder"]["name"]](config)
288
+ return obj
flash3d/unidepth/models/backbones/dinov2.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import math
3
+ import logging
4
+ from typing import Sequence, Tuple, Union, Callable
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+ from torch.nn.init import trunc_normal_
10
+
11
+ from .metadinov2 import (
12
+ Mlp,
13
+ PatchEmbed,
14
+ SwiGLUFFNFused,
15
+ MemEffAttention,
16
+ NestedTensorBlock as Block,
17
+ )
18
+
19
+
20
+ logger = logging.getLogger("dinov2")
21
+
22
+
23
+ def named_apply(
24
+ fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
25
+ ) -> nn.Module:
26
+ if not depth_first and include_root:
27
+ fn(module=module, name=name)
28
+ for child_name, child_module in module.named_children():
29
+ child_name = ".".join((name, child_name)) if name else child_name
30
+ named_apply(
31
+ fn=fn,
32
+ module=child_module,
33
+ name=child_name,
34
+ depth_first=depth_first,
35
+ include_root=True,
36
+ )
37
+ if depth_first and include_root:
38
+ fn(module=module, name=name)
39
+ return module
40
+
41
+
42
+ def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
43
+ parameter_group_names = {}
44
+ parameter_group_vars = {}
45
+ skip = {}
46
+ if skip_list is not None:
47
+ skip = skip_list
48
+ elif hasattr(model, "no_weight_decay"):
49
+ skip = model.no_weight_decay()
50
+
51
+ num_layers = model.n_blocks
52
+ layer_scale = list(ld ** (num_layers - i) for i in range(num_layers))
53
+
54
+ for name, param in model.named_parameters():
55
+ if not param.requires_grad:
56
+ continue
57
+
58
+ if len(param.shape) == 1: # norm
59
+ group_name = "no_decay"
60
+ this_wd = 0.0
61
+ # layer scale, bias beta?
62
+ elif (
63
+ name in skip
64
+ or name.endswith(".gamma")
65
+ or name.endswith(".beta")
66
+ or name.endswith(".bias")
67
+ ):
68
+ group_name = "no_decay"
69
+ this_wd = 0.0
70
+ elif "cls_token" in name or "pos_embed" in name or "mask_token" in name:
71
+ group_name = "no_decay"
72
+ this_wd = 0.0
73
+ else:
74
+ group_name = "decay"
75
+ this_wd = wd
76
+
77
+ if name.startswith("blocks"):
78
+ layer_id = int(name.split(".")[1])
79
+ elif name.startswith("patch_embed"):
80
+ layer_id = 0
81
+ else:
82
+ layer_id = 0
83
+
84
+ group_name = f"layer_{layer_id}_{group_name}"
85
+
86
+ if group_name not in parameter_group_names:
87
+ scale = layer_scale[layer_id]
88
+ cur_lr = lr * scale
89
+
90
+ parameter_group_names[group_name] = {
91
+ "weight_decay": this_wd,
92
+ "params": [],
93
+ "lr_init": cur_lr,
94
+ "lr_base": lr,
95
+ "lr": cur_lr,
96
+ }
97
+ parameter_group_vars[group_name] = {
98
+ "weight_decay": this_wd,
99
+ "params": [],
100
+ "lr_init": cur_lr,
101
+ "lr_base": lr,
102
+ "lr": cur_lr,
103
+ }
104
+ parameter_group_vars[group_name]["params"].append(param)
105
+ parameter_group_names[group_name]["params"].append(name)
106
+ return list(parameter_group_vars.values()), [
107
+ v["lr"] for k, v in parameter_group_vars.items()
108
+ ]
109
+
110
+
111
+ class BlockChunk(nn.ModuleList):
112
+ def forward(self, x):
113
+ for b in self:
114
+ x = b(x)
115
+ return x
116
+
117
+
118
+ class DinoVisionTransformer(nn.Module):
119
+ def __init__(
120
+ self,
121
+ img_size=224,
122
+ patch_size=16,
123
+ in_chans=3,
124
+ embed_dim=768,
125
+ depth=12,
126
+ num_heads=12,
127
+ mlp_ratio=4.0,
128
+ qkv_bias=True,
129
+ ffn_bias=True,
130
+ proj_bias=True,
131
+ drop_path_rate=0.0,
132
+ drop_path_uniform=False,
133
+ init_values=None, # for layerscale: None or 0 => no layerscale
134
+ embed_layer=PatchEmbed,
135
+ act_layer=nn.GELU,
136
+ block_fn=Block,
137
+ ffn_layer="mlp",
138
+ block_chunks=1,
139
+ output_idx=[5, 12, 18, 24],
140
+ checkpoint: bool = False,
141
+ num_register_tokens=0,
142
+ interpolate_antialias=False,
143
+ interpolate_offset=0.1,
144
+ ):
145
+ """
146
+ Args:
147
+ img_size (int, tuple): input image size
148
+ patch_size (int, tuple): patch size
149
+ in_chans (int): number of input channels
150
+ embed_dim (int): embedding dimension
151
+ depth (int): depth of transformer
152
+ num_heads (int): number of attention heads
153
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
154
+ qkv_bias (bool): enable bias for qkv if True
155
+ proj_bias (bool): enable bias for proj in attn if True
156
+ ffn_bias (bool): enable bias for ffn if True
157
+ drop_path_rate (float): stochastic depth rate
158
+ drop_path_uniform (bool): apply uniform drop rate across blocks
159
+ weight_init (str): weight init scheme
160
+ init_values (float): layer-scale init values
161
+ embed_layer (nn.Module): patch embedding layer
162
+ act_layer (nn.Module): MLP activation layer
163
+ block_fn (nn.Module): transformer block class
164
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
165
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
166
+ """
167
+ super().__init__()
168
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
169
+
170
+ self.num_features = self.embed_dim = (
171
+ embed_dim # num_features for consistency with other models
172
+ )
173
+ self.embed_dims = [embed_dim] * output_idx[-1]
174
+ self.num_tokens = 1
175
+ self.n_blocks = depth
176
+ self.num_heads = num_heads
177
+ self.patch_size = patch_size
178
+ self.depths = output_idx
179
+ self.checkpoint = checkpoint
180
+ self.num_register_tokens = num_register_tokens
181
+ self.interpolate_antialias = interpolate_antialias
182
+ self.interpolate_offset = interpolate_offset
183
+
184
+ self.patch_embed = embed_layer(
185
+ img_size=img_size,
186
+ patch_size=patch_size,
187
+ in_chans=in_chans,
188
+ embed_dim=embed_dim,
189
+ )
190
+ num_patches = self.patch_embed.num_patches
191
+
192
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
193
+ self.pos_embed = nn.Parameter(
194
+ torch.zeros(1, num_patches + self.num_tokens, embed_dim)
195
+ )
196
+ assert num_register_tokens >= 0
197
+ self.register_tokens = nn.Parameter(
198
+ torch.zeros(1, max(1, num_register_tokens), embed_dim)
199
+ )
200
+
201
+ if drop_path_uniform is True:
202
+ dpr = [drop_path_rate] * depth
203
+ else:
204
+ dpr = [
205
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
206
+ ] # stochastic depth decay rule
207
+
208
+ if ffn_layer == "mlp":
209
+ logger.info("using MLP layer as FFN")
210
+ ffn_layer = Mlp
211
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
212
+ logger.info("using SwiGLU layer as FFN")
213
+ ffn_layer = SwiGLUFFNFused
214
+ elif ffn_layer == "identity":
215
+ logger.info("using Identity layer as FFN")
216
+
217
+ def f(*args, **kwargs):
218
+ return nn.Identity()
219
+
220
+ ffn_layer = f
221
+ else:
222
+ raise NotImplementedError
223
+
224
+ blocks_list = [
225
+ block_fn(
226
+ dim=embed_dim,
227
+ num_heads=num_heads,
228
+ mlp_ratio=mlp_ratio,
229
+ qkv_bias=qkv_bias,
230
+ proj_bias=proj_bias,
231
+ ffn_bias=ffn_bias,
232
+ drop_path=dpr[i],
233
+ norm_layer=norm_layer,
234
+ act_layer=act_layer,
235
+ ffn_layer=ffn_layer,
236
+ init_values=init_values,
237
+ )
238
+ for i in range(depth)
239
+ ]
240
+ if block_chunks > 0:
241
+ self.chunked_blocks = True
242
+ chunked_blocks = []
243
+ chunksize = depth // block_chunks
244
+ for i in range(0, depth, chunksize):
245
+ # this is to keep the block index consistent if we chunk the block list
246
+ chunked_blocks.append(
247
+ [nn.Identity()] * i + blocks_list[i : i + chunksize]
248
+ )
249
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
250
+ else:
251
+ self.chunked_blocks = False
252
+ self.blocks = nn.ModuleList(blocks_list)
253
+
254
+ # self.norm = norm_layer(embed_dim)
255
+ self.head = nn.Identity()
256
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
257
+ self.init_weights()
258
+
259
+ def init_weights(self):
260
+ trunc_normal_(self.pos_embed, std=0.02)
261
+ nn.init.normal_(self.cls_token, std=1e-6)
262
+ if self.num_register_tokens:
263
+ nn.init.normal_(self.register_tokens, std=1e-6)
264
+ named_apply(init_weights_vit_timm, self)
265
+
266
+ def interpolate_pos_encoding(self, x, w, h):
267
+ previous_dtype = x.dtype
268
+ npatch = x.shape[1] - 1
269
+ N = self.pos_embed.shape[1] - 1
270
+ if npatch == N and w == h:
271
+ return self.pos_embed
272
+ pos_embed = self.pos_embed.float()
273
+ class_pos_embed = pos_embed[:, 0]
274
+ patch_pos_embed = pos_embed[:, 1:]
275
+ dim = x.shape[-1]
276
+ w0 = w // self.patch_size
277
+ h0 = h // self.patch_size
278
+ # we add a small number to avoid floating point error in the interpolation
279
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
280
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
281
+
282
+ patch_pos_embed = nn.functional.interpolate(
283
+ patch_pos_embed.reshape(
284
+ 1, int(math.sqrt(N)), int(math.sqrt(N)), dim
285
+ ).permute(0, 3, 1, 2),
286
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
287
+ mode="bicubic",
288
+ antialias=self.interpolate_antialias,
289
+ )
290
+
291
+ assert (
292
+ int(w0) == patch_pos_embed.shape[-2]
293
+ and int(h0) == patch_pos_embed.shape[-1]
294
+ )
295
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
296
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
297
+ previous_dtype
298
+ )
299
+
300
+ def prepare_tokens_with_masks(self, x, masks=None):
301
+ B, nc, w, h = x.shape
302
+ x = self.patch_embed(x)
303
+ if masks is not None:
304
+ masks = masks.bool().view(B, -1, 1)
305
+ x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x)
306
+
307
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
308
+ x = x + self.interpolate_pos_encoding(x, w, h)
309
+
310
+ if self.num_register_tokens:
311
+ x = torch.cat(
312
+ (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
313
+ dim=1,
314
+ )
315
+
316
+ return x
317
+
318
+ def forward_features(self, x, masks=None):
319
+ # if isinstance(x, list):
320
+ # return self.forward_features_list(x, masks)
321
+ shapes = [val // self.patch_size for val in x.shape[-2:]]
322
+ batch_size = x.shape[0]
323
+ x = self.prepare_tokens_with_masks(x, masks)
324
+ output, cls_tokens = [], []
325
+
326
+ for i, blk in enumerate(self.blocks):
327
+ x = blk(x)
328
+ cls_token = x[:, :1]
329
+
330
+ out = x[:, self.num_register_tokens + 1 :]
331
+ # was like this before, add cls to dense features
332
+ # out = out + cls_token
333
+
334
+ output.append(out.view(batch_size, *shapes, -1))
335
+ cls_tokens.append(cls_token)
336
+ return (output, cls_tokens)
337
+
338
+ def get_params(self, lr, wd, ld, *args, **kwargs):
339
+ encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
340
+ return encoder_p, encoder_lr
341
+
342
+ def freeze(self) -> None:
343
+ for module in self.modules():
344
+ module.eval()
345
+ for parameters in self.parameters():
346
+ parameters.requires_grad = False
347
+
348
+ def train(self, mode=True):
349
+ super().train(mode)
350
+ self.mask_token.requires_grad = False
351
+ self.register_tokens.requires_grad = False
352
+
353
+ def forward(self, *args, is_training=False, **kwargs):
354
+ ret = self.forward_features(*args, **kwargs)
355
+ return ret
356
+
357
+
358
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
359
+ """ViT weight initialization, original timm impl (for reproducibility)"""
360
+ if isinstance(module, nn.Linear):
361
+ trunc_normal_(module.weight, std=0.02)
362
+ if module.bias is not None:
363
+ nn.init.zeros_(module.bias)
364
+
365
+
366
+ def vit_small(patch_size=16, **kwargs):
367
+ model = DinoVisionTransformer(
368
+ patch_size=patch_size,
369
+ embed_dim=384,
370
+ depth=12,
371
+ num_heads=6,
372
+ mlp_ratio=4,
373
+ block_fn=partial(Block, attn_class=MemEffAttention),
374
+ **kwargs,
375
+ )
376
+ return model
377
+
378
+
379
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
380
+ model = DinoVisionTransformer(
381
+ patch_size=patch_size,
382
+ embed_dim=768,
383
+ depth=12,
384
+ num_heads=12,
385
+ mlp_ratio=4,
386
+ num_register_tokens=num_register_tokens,
387
+ block_fn=partial(Block, attn_class=MemEffAttention),
388
+ **kwargs,
389
+ )
390
+ return model
391
+
392
+
393
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
394
+ model = DinoVisionTransformer(
395
+ patch_size=patch_size,
396
+ embed_dim=1024,
397
+ depth=24,
398
+ num_heads=16,
399
+ mlp_ratio=4,
400
+ num_register_tokens=num_register_tokens,
401
+ block_fn=partial(Block, attn_class=MemEffAttention),
402
+ **kwargs,
403
+ )
404
+ return model
405
+
406
+
407
+ def vit_giant2(patch_size=16, **kwargs):
408
+ """
409
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
410
+ """
411
+ model = DinoVisionTransformer(
412
+ patch_size=patch_size,
413
+ embed_dim=1536,
414
+ depth=40,
415
+ num_heads=24,
416
+ mlp_ratio=4,
417
+ block_fn=partial(Block, attn_class=MemEffAttention),
418
+ **kwargs,
419
+ )
420
+ return model
421
+
422
+
423
+ import torch
424
+ import torch.nn as nn
425
+
426
+
427
+ dependencies = ["torch"]
428
+
429
+
430
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
431
+
432
+
433
+ def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
434
+ compact_arch_name = arch_name.replace("_", "")[:4]
435
+ return f"dinov2_{compact_arch_name}{patch_size}"
436
+
437
+
438
+ def _make_dinov2_model(
439
+ *,
440
+ arch_name: str = "vit_large",
441
+ img_size: int = 518,
442
+ patch_size: int = 14,
443
+ init_values: float = 1.0,
444
+ ffn_layer: str = "mlp",
445
+ block_chunks: int = 0,
446
+ pretrained: str = "",
447
+ output_idx: Sequence[int] = [],
448
+ num_register_tokens: int = 0,
449
+ drop_path_rate: float = 0.0,
450
+ **kwargs,
451
+ ):
452
+ model_name = _make_dinov2_model_name(arch_name, patch_size)
453
+ print("Instantiate:", model_name)
454
+
455
+ vit_kwargs = dict(
456
+ img_size=img_size,
457
+ patch_size=patch_size,
458
+ init_values=init_values,
459
+ ffn_layer=ffn_layer,
460
+ block_chunks=block_chunks,
461
+ output_idx=output_idx,
462
+ drop_path_rate=drop_path_rate,
463
+ num_register_tokens=num_register_tokens,
464
+ )
465
+ vit_kwargs.update(**kwargs)
466
+ model = eval(arch_name)(**vit_kwargs)
467
+ if pretrained == "":
468
+ url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}"
469
+ if num_register_tokens > 0:
470
+ url += "_reg4"
471
+ url += "_pretrain.pth"
472
+ state_dict = torch.hub.load_state_dict_from_url(
473
+ url, map_location="cpu", progress=False
474
+ )
475
+ info = model.load_state_dict(state_dict, strict=False)
476
+ print(info)
477
+ elif pretrained is not None:
478
+ state_dict = torch.load(pretrained, map_location="cpu")
479
+ info = model.load_state_dict(state_dict, strict=False)
480
+ print(f"loading from {pretrained} with:", info)
481
+ return model
482
+
483
+ # def forward_features_list(self, x_list, masks_list):
484
+ # x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
485
+ # for blk in self.blocks:
486
+ # x = blk(x)
487
+
488
+ # all_x = x
489
+ # output = []
490
+ # for x, masks in zip(all_x, masks_list):
491
+ # x_norm = self.norm(x)
492
+ # output.append(
493
+ # {
494
+ # "x_norm_clstoken": x_norm[:, 0],
495
+ # "x_norm_patchtokens": x_norm[:, 1:],
496
+ # "x_prenorm": x,
497
+ # "masks": masks,
498
+ # }
499
+ # )
500
+ # return output
501
+
502
+ # def _get_intermediate_layers_not_chunked(self, x, n=1):
503
+ # x = self.prepare_tokens_with_masks(x)
504
+ # # If n is an int, take the n last blocks. If it's a list, take them
505
+ # output, total_block_len = [], len(self.blocks)
506
+ # blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
507
+ # for i, blk in enumerate(self.blocks):
508
+ # x = blk(x)
509
+ # if i in blocks_to_take:
510
+ # output.append(x)
511
+ # assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
512
+ # return output
513
+
514
+ # def _get_intermediate_layers_chunked(self, x, n=1):
515
+ # x = self.prepare_tokens_with_masks(x)
516
+ # output, i, total_block_len = [], 0, len(self.blocks[-1])
517
+ # # If n is an int, take the n last blocks. If it's a list, take them
518
+ # blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
519
+ # for block_chunk in self.blocks:
520
+ # for blk in block_chunk[i:]: # Passing the nn.Identity()
521
+ # x = blk(x)
522
+ # if i in blocks_to_take:
523
+ # output.append(x)
524
+ # i += 1
525
+ # assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
526
+ # return output
527
+
528
+ # def get_intermediate_layers(
529
+ # self,
530
+ # x: torch.Tensor,
531
+ # n: Union[int, Sequence] = 1, # Layers or n last layers to take
532
+ # reshape: bool = False,
533
+ # return_class_token: bool = False,
534
+ # norm=True,
535
+ # ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
536
+ # if self.chunked_blocks:
537
+ # outputs = self._get_intermediate_layers_chunked(x, n)
538
+ # else:
539
+ # outputs = self._get_intermediate_layers_not_chunked(x, n)
540
+ # if norm:
541
+ # outputs = [self.norm(out) for out in outputs]
542
+ # class_tokens = [out[:, 0] for out in outputs]
543
+ # outputs = [out[:, 1:] for out in outputs]
544
+ # if reshape:
545
+ # B, _, w, h = x.shape
546
+ # outputs = [
547
+ # out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
548
+ # for out in outputs
549
+ # ]
550
+ # if return_class_token:
551
+ # return tuple(zip(outputs, class_tokens))
552
+ # return tuple(outputs)
flash3d/unidepth/models/backbones/metadinov2/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .dino_head import DINOHead
8
+ from .mlp import Mlp
9
+ from .patch_embed import PatchEmbed
10
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
11
+ from .block import NestedTensorBlock
12
+ from .attention import MemEffAttention
flash3d/unidepth/models/backbones/metadinov2/attention.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ try:
21
+ from xformers.ops import memory_efficient_attention, unbind, fmha
22
+
23
+ XFORMERS_AVAILABLE = True
24
+ except ImportError:
25
+ logger.warning("xFormers not available")
26
+ XFORMERS_AVAILABLE = False
27
+
28
+
29
+ class Attention(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim: int,
33
+ num_heads: int = 8,
34
+ qkv_bias: bool = False,
35
+ proj_bias: bool = True,
36
+ attn_drop: float = 0.0,
37
+ proj_drop: float = 0.0,
38
+ ) -> None:
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ head_dim = dim // num_heads
42
+ self.scale = head_dim**-0.5
43
+
44
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+
49
+ def forward(self, x: Tensor) -> Tensor:
50
+ B, N, C = x.shape
51
+ qkv = (
52
+ self.qkv(x)
53
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
54
+ .permute(2, 0, 3, 1, 4)
55
+ )
56
+
57
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
58
+ attn = q @ k.transpose(-2, -1)
59
+
60
+ attn = attn.softmax(dim=-1)
61
+ attn = self.attn_drop(attn)
62
+
63
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
64
+ x = self.proj(x)
65
+ x = self.proj_drop(x)
66
+ return x
67
+
68
+
69
+ class MemEffAttention(Attention):
70
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
71
+ if not XFORMERS_AVAILABLE:
72
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
73
+ return super().forward(x)
74
+
75
+ B, N, C = x.shape
76
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
77
+
78
+ q, k, v = unbind(qkv, 2)
79
+
80
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
81
+ x = x.reshape([B, N, C])
82
+
83
+ x = self.proj(x)
84
+ x = self.proj_drop(x)
85
+ return x
flash3d/unidepth/models/backbones/metadinov2/block.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
+ class Block(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int,
41
+ mlp_ratio: float = 4.0,
42
+ qkv_bias: bool = False,
43
+ proj_bias: bool = True,
44
+ ffn_bias: bool = True,
45
+ drop: float = 0.0,
46
+ attn_drop: float = 0.0,
47
+ init_values=None,
48
+ drop_path: float = 0.0,
49
+ act_layer: Callable[..., nn.Module] = nn.GELU,
50
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51
+ attn_class: Callable[..., nn.Module] = Attention,
52
+ ffn_layer: Callable[..., nn.Module] = Mlp,
53
+ ) -> None:
54
+ super().__init__()
55
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56
+ self.norm1 = norm_layer(dim)
57
+ self.attn = attn_class(
58
+ dim,
59
+ num_heads=num_heads,
60
+ qkv_bias=qkv_bias,
61
+ proj_bias=proj_bias,
62
+ attn_drop=attn_drop,
63
+ proj_drop=drop,
64
+ )
65
+ self.ls1 = (
66
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
67
+ )
68
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
69
+
70
+ self.norm2 = norm_layer(dim)
71
+ mlp_hidden_dim = int(dim * mlp_ratio)
72
+ self.mlp = ffn_layer(
73
+ in_features=dim,
74
+ hidden_features=mlp_hidden_dim,
75
+ act_layer=act_layer,
76
+ drop=drop,
77
+ bias=ffn_bias,
78
+ )
79
+ self.ls2 = (
80
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
81
+ )
82
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
83
+
84
+ self.sample_drop_ratio = drop_path
85
+
86
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
87
+ def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
88
+ return self.ls1(self.attn(self.norm1(x)))
89
+
90
+ def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
91
+ return self.ls2(self.mlp(self.norm2(x)))
92
+
93
+ if self.training and self.sample_drop_ratio > 0.1:
94
+ # the overhead is compensated only for a drop path rate larger than 0.1
95
+ x = drop_add_residual_stochastic_depth(
96
+ x,
97
+ residual_func=attn_residual_func,
98
+ sample_drop_ratio=self.sample_drop_ratio,
99
+ )
100
+ x = drop_add_residual_stochastic_depth(
101
+ x,
102
+ residual_func=ffn_residual_func,
103
+ sample_drop_ratio=self.sample_drop_ratio,
104
+ )
105
+ elif self.training and self.sample_drop_ratio > 0.0:
106
+ x = x + self.drop_path1(attn_residual_func(x))
107
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
108
+ else:
109
+ x = x + attn_residual_func(x)
110
+ x = x + ffn_residual_func(x)
111
+ return x
112
+
113
+
114
+ def drop_add_residual_stochastic_depth(
115
+ x: torch.Tensor,
116
+ residual_func: Callable[[torch.Tensor], torch.Tensor],
117
+ sample_drop_ratio: float = 0.0,
118
+ ) -> torch.Tensor:
119
+ # 1) extract subset using permutation
120
+ b, n, d = x.shape
121
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
122
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
123
+ x_subset = x[brange]
124
+
125
+ # 2) apply residual_func to get residual
126
+ residual = residual_func(x_subset)
127
+
128
+ x_flat = x.flatten(1)
129
+ residual = residual.flatten(1)
130
+
131
+ residual_scale_factor = b / sample_subset_size
132
+
133
+ # 3) add the residual
134
+ x_plus_residual = torch.index_add(
135
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
136
+ )
137
+ return x_plus_residual.view_as(x)
138
+
139
+
140
+ def get_branges_scales(x, sample_drop_ratio=0.0):
141
+ b, n, d = x.shape
142
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
143
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
144
+ residual_scale_factor = b / sample_subset_size
145
+ return brange, residual_scale_factor
146
+
147
+
148
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
149
+ if scaling_vector is None:
150
+ x_flat = x.flatten(1)
151
+ residual = residual.flatten(1)
152
+ x_plus_residual = torch.index_add(
153
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
154
+ )
155
+ else:
156
+ x_plus_residual = scaled_index_add(
157
+ x,
158
+ brange,
159
+ residual.to(dtype=x.dtype),
160
+ scaling=scaling_vector,
161
+ alpha=residual_scale_factor,
162
+ )
163
+ return x_plus_residual
164
+
165
+
166
+ attn_bias_cache: Dict[Tuple, Any] = {}
167
+
168
+
169
+ def get_attn_bias_and_cat(x_list, branges=None):
170
+ """
171
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
172
+ """
173
+ batch_sizes = (
174
+ [b.shape[0] for b in branges]
175
+ if branges is not None
176
+ else [x.shape[0] for x in x_list]
177
+ )
178
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
179
+ if all_shapes not in attn_bias_cache.keys():
180
+ seqlens = []
181
+ for b, x in zip(batch_sizes, x_list):
182
+ for _ in range(b):
183
+ seqlens.append(x.shape[1])
184
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
185
+ attn_bias._batch_sizes = batch_sizes
186
+ attn_bias_cache[all_shapes] = attn_bias
187
+
188
+ if branges is not None:
189
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
190
+ 1, -1, x_list[0].shape[-1]
191
+ )
192
+ else:
193
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
194
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
195
+
196
+ return attn_bias_cache[all_shapes], cat_tensors
197
+
198
+
199
+ def drop_add_residual_stochastic_depth_list(
200
+ x_list: List[torch.Tensor],
201
+ residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
202
+ sample_drop_ratio: float = 0.0,
203
+ scaling_vector=None,
204
+ ) -> torch.Tensor:
205
+ # 1) generate random set of indices for dropping samples in the batch
206
+ branges_scales = [
207
+ get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
208
+ ]
209
+ branges = [s[0] for s in branges_scales]
210
+ residual_scale_factors = [s[1] for s in branges_scales]
211
+
212
+ # 2) get attention bias and index+concat the tensors
213
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
214
+
215
+ # 3) apply residual_func to get residual, and split the result
216
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
217
+
218
+ outputs = []
219
+ for x, brange, residual, residual_scale_factor in zip(
220
+ x_list, branges, residual_list, residual_scale_factors
221
+ ):
222
+ outputs.append(
223
+ add_residual(
224
+ x, brange, residual, residual_scale_factor, scaling_vector
225
+ ).view_as(x)
226
+ )
227
+ return outputs
228
+
229
+
230
+ class NestedTensorBlock(Block):
231
+ def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
232
+ """
233
+ x_list contains a list of tensors to nest together and run
234
+ """
235
+ assert isinstance(self.attn, MemEffAttention)
236
+
237
+ if self.training and self.sample_drop_ratio > 0.0:
238
+
239
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
240
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
241
+
242
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
243
+ return self.mlp(self.norm2(x))
244
+
245
+ x_list = drop_add_residual_stochastic_depth_list(
246
+ x_list,
247
+ residual_func=attn_residual_func,
248
+ sample_drop_ratio=self.sample_drop_ratio,
249
+ scaling_vector=(
250
+ self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
251
+ ),
252
+ )
253
+ x_list = drop_add_residual_stochastic_depth_list(
254
+ x_list,
255
+ residual_func=ffn_residual_func,
256
+ sample_drop_ratio=self.sample_drop_ratio,
257
+ scaling_vector=(
258
+ self.ls2.gamma if isinstance(self.ls1, LayerScale) else None
259
+ ),
260
+ )
261
+ return x_list
262
+ else:
263
+
264
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
265
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
266
+
267
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
268
+ return self.ls2(self.mlp(self.norm2(x)))
269
+
270
+ attn_bias, x = get_attn_bias_and_cat(x_list)
271
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
272
+ x = x + ffn_residual_func(x)
273
+ return attn_bias.split(x)
274
+
275
+ def forward(self, x_or_x_list):
276
+ if isinstance(x_or_x_list, torch.Tensor):
277
+ return super().forward(x_or_x_list)
278
+ elif isinstance(x_or_x_list, list):
279
+ assert (
280
+ XFORMERS_AVAILABLE
281
+ ), "Please install xFormers for nested tensors usage"
282
+ return self.forward_nested(x_or_x_list)
283
+ else:
284
+ raise AssertionError
flash3d/unidepth/models/backbones/metadinov2/dino_head.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn.init import trunc_normal_
10
+ from torch.nn.utils import weight_norm
11
+
12
+
13
+ class DINOHead(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_dim,
17
+ out_dim,
18
+ use_bn=False,
19
+ nlayers=3,
20
+ hidden_dim=2048,
21
+ bottleneck_dim=256,
22
+ mlp_bias=True,
23
+ ):
24
+ super().__init__()
25
+ nlayers = max(nlayers, 1)
26
+ self.mlp = _build_mlp(
27
+ nlayers,
28
+ in_dim,
29
+ bottleneck_dim,
30
+ hidden_dim=hidden_dim,
31
+ use_bn=use_bn,
32
+ bias=mlp_bias,
33
+ )
34
+ self.apply(self._init_weights)
35
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
36
+ self.last_layer.weight_g.data.fill_(1)
37
+
38
+ def _init_weights(self, m):
39
+ if isinstance(m, nn.Linear):
40
+ trunc_normal_(m.weight, std=0.02)
41
+ if isinstance(m, nn.Linear) and m.bias is not None:
42
+ nn.init.constant_(m.bias, 0)
43
+
44
+ def forward(self, x):
45
+ x = self.mlp(x)
46
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
47
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
48
+ x = self.last_layer(x)
49
+ return x
50
+
51
+
52
+ def _build_mlp(
53
+ nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
54
+ ):
55
+ if nlayers == 1:
56
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
57
+ else:
58
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
59
+ if use_bn:
60
+ layers.append(nn.BatchNorm1d(hidden_dim))
61
+ layers.append(nn.GELU())
62
+ for _ in range(nlayers - 2):
63
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
64
+ if use_bn:
65
+ layers.append(nn.BatchNorm1d(hidden_dim))
66
+ layers.append(nn.GELU())
67
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
68
+ return nn.Sequential(*layers)
flash3d/unidepth/models/backbones/metadinov2/drop_path.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ import torch.nn as nn
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ if drop_prob == 0.0 or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ shape = (x.shape[0],) + (1,) * (
20
+ x.ndim - 1
21
+ ) # work with diff dim tensors, not just 2D ConvNets
22
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
23
+ if keep_prob > 0.0:
24
+ random_tensor.div_(keep_prob)
25
+ output = x * random_tensor
26
+ return output
27
+
28
+
29
+ class DropPath(nn.Module):
30
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
31
+
32
+ def __init__(self, drop_prob=None):
33
+ super(DropPath, self).__init__()
34
+ self.drop_prob = drop_prob
35
+
36
+ def forward(self, x):
37
+ return drop_path(x, self.drop_prob, self.training)
flash3d/unidepth/models/backbones/metadinov2/layer_scale.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8
+
9
+ from typing import Union
10
+
11
+ import torch
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ class LayerScale(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ init_values: Union[float, Tensor] = 1e-5,
21
+ inplace: bool = False,
22
+ ) -> None:
23
+ super().__init__()
24
+ self.inplace = inplace
25
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
26
+
27
+ def forward(self, x: Tensor) -> Tensor:
28
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
flash3d/unidepth/models/backbones/metadinov2/mlp.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10
+
11
+
12
+ from typing import Callable, Optional
13
+
14
+ from torch import Tensor, nn
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features: int,
21
+ hidden_features: Optional[int] = None,
22
+ out_features: Optional[int] = None,
23
+ act_layer: Callable[..., nn.Module] = nn.GELU,
24
+ drop: float = 0.0,
25
+ bias: bool = True,
26
+ ) -> None:
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
flash3d/unidepth/models/backbones/metadinov2/patch_embed.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+
17
+ def make_2tuple(x):
18
+ if isinstance(x, tuple):
19
+ assert len(x) == 2
20
+ return x
21
+
22
+ assert isinstance(x, int)
23
+ return (x, x)
24
+
25
+
26
+ class PatchEmbed(nn.Module):
27
+ """
28
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29
+
30
+ Args:
31
+ img_size: Image size.
32
+ patch_size: Patch token size.
33
+ in_chans: Number of input image channels.
34
+ embed_dim: Number of linear projection output channels.
35
+ norm_layer: Normalization layer.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ img_size: Union[int, Tuple[int, int]] = 224,
41
+ patch_size: Union[int, Tuple[int, int]] = 16,
42
+ in_chans: int = 3,
43
+ embed_dim: int = 768,
44
+ norm_layer: Optional[Callable] = None,
45
+ flatten_embedding: bool = True,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ image_HW = make_2tuple(img_size)
50
+ patch_HW = make_2tuple(patch_size)
51
+ patch_grid_size = (
52
+ image_HW[0] // patch_HW[0],
53
+ image_HW[1] // patch_HW[1],
54
+ )
55
+
56
+ self.img_size = image_HW
57
+ self.patch_size = patch_HW
58
+ self.patches_resolution = patch_grid_size
59
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60
+
61
+ self.in_chans = in_chans
62
+ self.embed_dim = embed_dim
63
+
64
+ self.flatten_embedding = flatten_embedding
65
+
66
+ self.proj = nn.Conv2d(
67
+ in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
68
+ )
69
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
70
+
71
+ def forward(self, x: Tensor) -> Tensor:
72
+ _, _, H, W = x.shape
73
+ patch_H, patch_W = self.patch_size
74
+
75
+ assert (
76
+ H % patch_H == 0
77
+ ), f"Input image height {H} is not a multiple of patch height {patch_H}"
78
+ assert (
79
+ W % patch_W == 0
80
+ ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
81
+
82
+ x = self.proj(x) # B C H W
83
+ H, W = x.size(2), x.size(3)
84
+ x = x.flatten(2).transpose(1, 2) # B HW C
85
+ x = self.norm(x)
86
+ if not self.flatten_embedding:
87
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
88
+ return x
89
+
90
+ def flops(self) -> float:
91
+ Ho, Wo = self.patches_resolution
92
+ flops = (
93
+ Ho
94
+ * Wo
95
+ * self.embed_dim
96
+ * self.in_chans
97
+ * (self.patch_size[0] * self.patch_size[1])
98
+ )
99
+ if self.norm is not None:
100
+ flops += Ho * Wo * self.embed_dim
101
+ return flops
flash3d/unidepth/models/backbones/metadinov2/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ from torch import Tensor, nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class SwiGLUFFN(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_features: int,
17
+ hidden_features: Optional[int] = None,
18
+ out_features: Optional[int] = None,
19
+ act_layer: Callable[..., nn.Module] = None,
20
+ drop: float = 0.0,
21
+ bias: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28
+
29
+ def forward(self, x: Tensor) -> Tensor:
30
+ x12 = self.w12(x)
31
+ x1, x2 = x12.chunk(2, dim=-1)
32
+ hidden = F.silu(x1) * x2
33
+ return self.w3(hidden)
34
+
35
+
36
+ try:
37
+ from xformers.ops import SwiGLU
38
+
39
+ XFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SwiGLU = SwiGLUFFN
42
+ XFORMERS_AVAILABLE = False
43
+
44
+
45
+ class SwiGLUFFNFused(SwiGLU):
46
+ def __init__(
47
+ self,
48
+ in_features: int,
49
+ hidden_features: Optional[int] = None,
50
+ out_features: Optional[int] = None,
51
+ act_layer: Callable[..., nn.Module] = None,
52
+ drop: float = 0.0,
53
+ bias: bool = True,
54
+ ) -> None:
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
57
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58
+ super().__init__(
59
+ in_features=in_features,
60
+ hidden_features=hidden_features,
61
+ out_features=out_features,
62
+ bias=bias,
63
+ )
flash3d/unidepth/models/encoder.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from unidepth.models.backbones import ConvNeXtV2, _make_dinov2_model, ConvNeXt
5
+
6
+
7
+ class ModelWrap(nn.Module):
8
+ def __init__(self, model) -> None:
9
+ super().__init__()
10
+ self.backbone = model
11
+
12
+ def forward(self, x, *args, **kwargs):
13
+ features = []
14
+ for layer in self.backbone.features:
15
+ x = layer(x)
16
+ features.append(x)
17
+ return features
18
+
19
+
20
+ def convnextv2_base(config, **kwargs):
21
+ model = ConvNeXtV2(
22
+ depths=[3, 3, 27, 3],
23
+ dims=[128, 256, 512, 1024],
24
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
25
+ use_checkpoint=config.get("use_checkpoint", False),
26
+ **kwargs,
27
+ )
28
+ url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt"
29
+ state_dict = torch.hub.load_state_dict_from_url(
30
+ url, map_location="cpu", progress=False
31
+ )["model"]
32
+ info = model.load_state_dict(state_dict, strict=False)
33
+ print(info)
34
+ return model
35
+
36
+
37
+ def convnextv2_large(config, **kwargs):
38
+ model = ConvNeXtV2(
39
+ depths=[3, 3, 27, 3],
40
+ dims=[192, 384, 768, 1536],
41
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
42
+ use_checkpoint=config.get("use_checkpoint", False),
43
+ **kwargs,
44
+ )
45
+ url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt"
46
+ state_dict = torch.hub.load_state_dict_from_url(
47
+ url, map_location="cpu", progress=False
48
+ )["model"]
49
+ info = model.load_state_dict(state_dict, strict=False)
50
+ print(info)
51
+ return model
52
+
53
+
54
+ def convnextv2_large_mae(config, **kwargs):
55
+ model = ConvNeXtV2(
56
+ depths=[3, 3, 27, 3],
57
+ dims=[192, 384, 768, 1536],
58
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
59
+ use_checkpoint=config.get("use_checkpoint", False),
60
+ **kwargs,
61
+ )
62
+ url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt"
63
+ state_dict = torch.hub.load_state_dict_from_url(
64
+ url, map_location="cpu", progress=False
65
+ )["model"]
66
+ info = model.load_state_dict(state_dict, strict=False)
67
+ print(info)
68
+ return model
69
+
70
+
71
+ def convnextv2_huge(config, **kwargs):
72
+ model = ConvNeXtV2(
73
+ depths=[3, 3, 27, 3],
74
+ dims=[352, 704, 1408, 2816],
75
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
76
+ use_checkpoint=config.get("use_checkpoint", False),
77
+ **kwargs,
78
+ )
79
+ url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt"
80
+ state_dict = torch.hub.load_state_dict_from_url(
81
+ url, map_location="cpu", progress=False
82
+ )["model"]
83
+ info = model.load_state_dict(state_dict, strict=False)
84
+ print(info)
85
+ return model
86
+
87
+
88
+ def convnextv2_huge_mae(config, **kwargs):
89
+ model = ConvNeXtV2(
90
+ depths=[3, 3, 27, 3],
91
+ dims=[352, 704, 1408, 2816],
92
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
93
+ use_checkpoint=config.get("use_checkpoint", False),
94
+ **kwargs,
95
+ )
96
+ url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt"
97
+ state_dict = torch.hub.load_state_dict_from_url(
98
+ url, map_location="cpu", progress=False
99
+ )["model"]
100
+ info = model.load_state_dict(state_dict, strict=False)
101
+ print(info)
102
+ return model
103
+
104
+
105
+ def convnext_large_pt(config, **kwargs):
106
+ model = ConvNeXt(
107
+ depths=[3, 3, 27, 3],
108
+ dims=[192, 384, 768, 1536],
109
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
110
+ use_checkpoint=config.get("use_checkpoint", False),
111
+ **kwargs,
112
+ )
113
+ from unidepth.models.backbones.convnext import HF_URL, checkpoint_filter_fn
114
+ from huggingface_hub import hf_hub_download
115
+ from huggingface_hub.utils import disable_progress_bars
116
+
117
+ disable_progress_bars()
118
+ repo_id, filename = HF_URL["convnext_large_pt"]
119
+ state_dict = torch.load(hf_hub_download(repo_id=repo_id, filename=filename))
120
+ state_dict = checkpoint_filter_fn(state_dict, model)
121
+ info = model.load_state_dict(state_dict, strict=False)
122
+ print(info)
123
+ return model
124
+
125
+
126
+ def convnext_large(config, **kwargs):
127
+ model = ConvNeXt(
128
+ depths=[3, 3, 27, 3],
129
+ dims=[192, 384, 768, 1536],
130
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
131
+ use_checkpoint=config.get("use_checkpoint", False),
132
+ drop_path_rate=config.get("drop_path", 0.0),
133
+ **kwargs,
134
+ )
135
+ return model
136
+
137
+
138
+ def dinov2_vitb14(config, pretrained: bool = True, **kwargs):
139
+ """
140
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
141
+ """
142
+ vit = _make_dinov2_model(
143
+ arch_name="vit_base",
144
+ pretrained=pretrained,
145
+ output_idx=config.get("output_idx", [3, 6, 9, 12]),
146
+ checkpoint=config.get("use_checkpoint", False),
147
+ drop_path_rate=config.get("drop_path", 0.0),
148
+ num_register_tokens=config.get("num_register_tokens", 0),
149
+ **kwargs,
150
+ )
151
+ return vit
152
+
153
+
154
+ def dinov2_vitl14(config, pretrained: str = "", **kwargs):
155
+ """
156
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
157
+ """
158
+ vit = _make_dinov2_model(
159
+ arch_name="vit_large",
160
+ pretrained=config["pretrained"],
161
+ output_idx=config.get("output_idx", [5, 12, 18, 24]),
162
+ checkpoint=config.get("use_checkpoint", False),
163
+ drop_path_rate=config.get("drop_path", 0.0),
164
+ num_register_tokens=config.get("num_register_tokens", 0),
165
+ **kwargs,
166
+ )
167
+ return vit
168
+
169
+
170
+ def dinov2_vitg14(config, pretrained: bool = True, **kwargs):
171
+ """
172
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
173
+ """
174
+ vit = _make_dinov2_model(
175
+ arch_name="vit_giant2",
176
+ ffn_layer="swiglufused",
177
+ pretrained=pretrained,
178
+ output_idx=config.get("output_idx", [10, 20, 30, 40]),
179
+ checkpoint=config.get("use_checkpoint", False),
180
+ drop_path_rate=config.get("drop_path", 0.0),
181
+ num_register_tokens=config.get("num_register_tokens", 0),
182
+ **kwargs,
183
+ )
184
+ return vit
flash3d/unidepth/models/unidepthv1/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .unidepthv1 import UniDepthV1
2
+
3
+ __all__ = [
4
+ "UniDepthV1",
5
+ ]
flash3d/unidepth/models/unidepthv1/decoder.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from typing import List, Tuple
7
+
8
+ from einops import rearrange
9
+ from timm.models.layers import trunc_normal_
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from unidepth.layers import (
15
+ MLP,
16
+ AttentionBlock,
17
+ NystromBlock,
18
+ PositionEmbeddingSine,
19
+ ConvUpsample,
20
+ )
21
+ from unidepth.utils.sht import rsh_cart_8
22
+ from unidepth.utils.geometric import (
23
+ generate_rays,
24
+ flat_interpolate,
25
+ )
26
+ from unidepth.utils.misc import max_stack
27
+
28
+
29
+ class ListAdapter(nn.Module):
30
+ def __init__(self, input_dims: List[int], hidden_dim: int):
31
+ super().__init__()
32
+ self.input_adapters = nn.ModuleList([])
33
+ self.num_chunks = len(input_dims)
34
+ for input_dim in input_dims:
35
+ self.input_adapters.append(
36
+ nn.Sequential(
37
+ nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU()
38
+ )
39
+ )
40
+
41
+ def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor:
42
+ xs = torch.split(x, splits.int().tolist(), dim=-1)
43
+ xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)]
44
+ return torch.cat(xs, dim=-1)
45
+
46
+
47
+ class CameraHead(nn.Module):
48
+ def __init__(
49
+ self,
50
+ input_dim: int,
51
+ hidden_dim: int,
52
+ num_heads: int = 8,
53
+ expansion: int = 4,
54
+ depth: int = 4,
55
+ dropout: float = 0.0,
56
+ layer_scale: float = 1.0,
57
+ **kwargs,
58
+ ):
59
+ super().__init__()
60
+
61
+ self.aggregate = AttentionBlock(
62
+ hidden_dim,
63
+ num_heads=1,
64
+ expansion=expansion,
65
+ dropout=dropout,
66
+ layer_scale=layer_scale,
67
+ )
68
+ self.latents_pos = nn.Parameter(
69
+ torch.randn(1, 4, hidden_dim), requires_grad=True
70
+ )
71
+ self.layers = nn.ModuleList([])
72
+ self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout)
73
+ for _ in range(depth):
74
+ blk = AttentionBlock(
75
+ hidden_dim,
76
+ num_heads=num_heads,
77
+ expansion=expansion,
78
+ dropout=dropout,
79
+ layer_scale=layer_scale,
80
+ )
81
+ self.layers.append(blk)
82
+ self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1)
83
+ self.cls_project = nn.Sequential(
84
+ nn.LayerNorm(input_dim),
85
+ nn.Linear(input_dim, hidden_dim // 2),
86
+ nn.GELU(),
87
+ nn.Linear(hidden_dim // 2, hidden_dim),
88
+ )
89
+
90
+ def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor:
91
+ features = features.unbind(dim=-1)
92
+ cls_tokens = self.cls_project(cls_tokens)
93
+ features_stack = torch.cat(features, dim=1)
94
+ features_stack = features_stack + pos_embed
95
+ latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1)
96
+ features_stack = self.in_features(features_stack)
97
+ features = torch.cat((features_stack, cls_tokens), dim=1)
98
+ cls_tokens = self.aggregate(cls_tokens, context=features, pos_embed=latents_pos)
99
+ for i, layer in enumerate(self.layers):
100
+ cls_tokens = layer(cls_tokens, pos_embed=latents_pos)
101
+
102
+ # project
103
+ x = self.out(cls_tokens).squeeze(-1)
104
+ camera_intrinsics = torch.zeros(
105
+ x.shape[0], 3, 3, device=x.device, requires_grad=False
106
+ )
107
+ camera_intrinsics[:, 0, 0] = x[:, 0].exp()
108
+ camera_intrinsics[:, 1, 1] = x[:, 1].exp()
109
+ camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid()
110
+ camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid()
111
+ camera_intrinsics[:, 2, 2] = 1.0
112
+ return camera_intrinsics
113
+
114
+ def set_shapes(self, shapes: Tuple[int, int]):
115
+ self.shapes = shapes
116
+
117
+
118
+ class DepthHead(nn.Module):
119
+ def __init__(
120
+ self,
121
+ hidden_dim: int,
122
+ num_heads: int = 8,
123
+ expansion: int = 4,
124
+ depths: int | list[int] = 4,
125
+ camera_dim: int = 256,
126
+ num_resolutions: int = 4,
127
+ dropout: float = 0.0,
128
+ layer_scale: float = 1.0,
129
+ **kwargs,
130
+ ) -> None:
131
+ super().__init__()
132
+ if isinstance(depths, int):
133
+ depths = [depths] * 3
134
+ assert len(depths) == 3
135
+
136
+ self.project_rays16 = MLP(
137
+ camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim
138
+ )
139
+ self.project_rays8 = MLP(
140
+ camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2
141
+ )
142
+ self.project_rays4 = MLP(
143
+ camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4
144
+ )
145
+ self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout)
146
+
147
+ self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim)
148
+
149
+ self.up8 = ConvUpsample(
150
+ hidden_dim, expansion=expansion, layer_scale=layer_scale
151
+ )
152
+ self.up4 = ConvUpsample(
153
+ hidden_dim // 2, expansion=expansion, layer_scale=layer_scale
154
+ )
155
+ self.up2 = ConvUpsample(
156
+ hidden_dim // 4, expansion=expansion, layer_scale=layer_scale
157
+ )
158
+
159
+ self.layers_16 = nn.ModuleList([])
160
+ self.layers_8 = nn.ModuleList([])
161
+ self.layers_4 = nn.ModuleList([])
162
+ self.aggregate_16 = AttentionBlock(
163
+ hidden_dim,
164
+ num_heads=1,
165
+ expansion=expansion,
166
+ dropout=dropout,
167
+ layer_scale=layer_scale,
168
+ context_dim=hidden_dim,
169
+ )
170
+ self.prompt_camera = AttentionBlock(
171
+ hidden_dim,
172
+ num_heads=1,
173
+ expansion=expansion,
174
+ dropout=dropout,
175
+ layer_scale=layer_scale,
176
+ context_dim=hidden_dim,
177
+ )
178
+ for i, (blk_lst, depth) in enumerate(
179
+ zip([self.layers_16, self.layers_8, self.layers_4], depths)
180
+ ):
181
+ attn_cls = AttentionBlock if i == 0 else NystromBlock
182
+ for _ in range(depth):
183
+ blk_lst.append(
184
+ attn_cls(
185
+ hidden_dim // (2**i),
186
+ num_heads=num_heads // (2**i),
187
+ expansion=expansion,
188
+ dropout=dropout,
189
+ layer_scale=layer_scale,
190
+ )
191
+ )
192
+
193
+ self.out2 = nn.Conv2d(hidden_dim // 8, 1, 3, padding=1)
194
+ self.out4 = nn.Conv2d(hidden_dim // 4, 1, 3, padding=1)
195
+ self.out8 = nn.Conv2d(hidden_dim // 2, 1, 3, padding=1)
196
+
197
+ def set_original_shapes(self, shapes: Tuple[int, int]):
198
+ self.original_shapes = shapes
199
+
200
+ def set_shapes(self, shapes: Tuple[int, int]):
201
+ self.shapes = shapes
202
+
203
+ def forward(
204
+ self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed
205
+ ) -> torch.Tensor:
206
+ features = features.unbind(dim=-1)
207
+ shapes = self.shapes
208
+
209
+ # camera_embedding
210
+ # torch.cuda.synchronize()
211
+ # start = time()
212
+ rays_embedding_16 = F.normalize(
213
+ flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1
214
+ )
215
+ rays_embedding_8 = F.normalize(
216
+ flat_interpolate(
217
+ rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes]
218
+ ),
219
+ dim=-1,
220
+ )
221
+ rays_embedding_4 = F.normalize(
222
+ flat_interpolate(
223
+ rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes]
224
+ ),
225
+ dim=-1,
226
+ )
227
+ rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16))
228
+ rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8))
229
+ rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4))
230
+ # torch.cuda.synchronize()
231
+ # print(f"camera_embedding took {time() - start} seconds")
232
+ features_tokens = torch.cat(features, dim=1)
233
+ features_tokens_pos = pos_embed + level_embed
234
+
235
+ # Generate latents with init as pooled features
236
+ features_channels = torch.cat(features, dim=-1)
237
+ features_16 = self.features_channel_cat(features_channels)
238
+ latents_16 = self.to_latents(
239
+ flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False)
240
+ )
241
+
242
+ # Aggregate features: F -> D
243
+ latents_16 = self.aggregate_16(
244
+ latents_16, context=features_tokens, pos_embed_context=features_tokens_pos
245
+ )
246
+
247
+ # Aggregate camera: D- > D|E
248
+ latents_16 = self.prompt_camera(latents_16, context=rays_embedding_16)
249
+
250
+ # Block 16 - Out 8
251
+ for layer in self.layers_16:
252
+ latents_16 = layer(latents_16, pos_embed=rays_embedding_16)
253
+ latents_8 = self.up8(
254
+ rearrange(
255
+ latents_16 + rays_embedding_16,
256
+ "b (h w) c -> b c h w",
257
+ h=shapes[0],
258
+ w=shapes[1],
259
+ ).contiguous()
260
+ )
261
+ out8 = self.out8(
262
+ rearrange(
263
+ latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2
264
+ )
265
+ )
266
+
267
+ # Block 8 - Out 4
268
+ for layer in self.layers_8:
269
+ latents_8 = layer(latents_8, pos_embed=rays_embedding_8)
270
+ latents_4 = self.up4(
271
+ rearrange(
272
+ latents_8 + rays_embedding_8,
273
+ "b (h w) c -> b c h w",
274
+ h=shapes[0] * 2,
275
+ w=shapes[1] * 2,
276
+ ).contiguous()
277
+ )
278
+ out4 = self.out4(
279
+ rearrange(
280
+ latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4
281
+ )
282
+ )
283
+
284
+ # Block 4 - Out 2
285
+ for layer in self.layers_4:
286
+ latents_4 = layer(latents_4, pos_embed=rays_embedding_4)
287
+ latents_2 = self.up2(
288
+ rearrange(
289
+ latents_4 + rays_embedding_4,
290
+ "b (h w) c -> b c h w",
291
+ h=shapes[0] * 4,
292
+ w=shapes[1] * 4,
293
+ ).contiguous()
294
+ )
295
+ out2 = self.out2(
296
+ rearrange(
297
+ latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8
298
+ )
299
+ )
300
+
301
+ # Depth features
302
+ proj_latents_16 = rearrange(
303
+ latents_16, "b (h w) c -> b c h w", h=shapes[0], w=shapes[1]
304
+ ).contiguous()
305
+
306
+ # MS Outputs
307
+ out2 = out2.clamp(-10.0, 10.0).exp()
308
+ out4 = out4.clamp(-10.0, 10.0).exp()
309
+ out8 = out8.clamp(-10.0, 10.0).exp()
310
+
311
+ return out8, out4, out2, proj_latents_16
312
+
313
+
314
+ class Decoder(nn.Module):
315
+ def __init__(
316
+ self,
317
+ config,
318
+ *args,
319
+ **kwargs,
320
+ ):
321
+ super().__init__()
322
+ self.build(config)
323
+ self.apply(self._init_weights)
324
+ self.test_fixed_camera = False
325
+ self.skip_camera = False
326
+
327
+ def _init_weights(self, m):
328
+ if isinstance(m, nn.Linear):
329
+ trunc_normal_(m.weight, std=0.02)
330
+ if m.bias is not None:
331
+ nn.init.constant_(m.bias, 0)
332
+ elif isinstance(m, nn.Conv2d):
333
+ trunc_normal_(m.weight, std=0.02)
334
+ if m.bias is not None:
335
+ nn.init.constant_(m.bias, 0)
336
+ elif isinstance(m, nn.LayerNorm):
337
+ nn.init.constant_(m.bias, 0)
338
+ nn.init.constant_(m.weight, 1.0)
339
+
340
+ def get_adapted_features(self, features_flat, splits):
341
+ features_flat_cat = torch.cat(features_flat, dim=-1)
342
+ features_projected = self.input_adapter(
343
+ features_flat_cat, splits
344
+ ) # list [b hw c] shapes
345
+ features = torch.chunk(features_projected, len(splits), dim=-1)
346
+ return features
347
+
348
+ def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays):
349
+ # get cls tokens projections
350
+ cls_tokens_splits = torch.tensor(
351
+ [x.shape[-1] for x in cls_tokens],
352
+ device=features.device,
353
+ requires_grad=False,
354
+ dtype=features.dtype,
355
+ )
356
+ cls_tokens = torch.cat(cls_tokens, dim=-1)
357
+ cls_tokens = self.token_adapter(cls_tokens, cls_tokens_splits)
358
+ cls_tokens = torch.cat(
359
+ torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1
360
+ )
361
+
362
+ # camera layer
363
+ intrinsics = self.camera_layer(
364
+ features=features, cls_tokens=cls_tokens, pos_embed=pos_embed
365
+ )
366
+ intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0]
367
+ intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1]
368
+ intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1]
369
+ intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0]
370
+ if not self.test_fixed_camera:
371
+ rays, _ = generate_rays(intrinsics, original_shapes, noisy=False)
372
+
373
+ return intrinsics, rays
374
+
375
+ def forward(self, inputs, image_metas) -> torch.Tensor:
376
+ B, _, H, W = inputs["image"].shape
377
+ device = inputs["image"].device
378
+
379
+ # make stride happy?
380
+ original_encoder_outputs = [x.contiguous() for x in inputs["encoder_outputs"]]
381
+ cls_tokens = [x.contiguous() for x in inputs["cls_tokens"]]
382
+
383
+ # collect features and tokens
384
+ original_encoder_outputs = [
385
+ max_stack(original_encoder_outputs[i:j])
386
+ for i, j in self.slices_encoder_range
387
+ ]
388
+ cls_tokens = [cls_tokens[-i - 1] for i in range(len(self.slices_encoder_range))]
389
+
390
+ # get features in b n d format
391
+ # level shapes, the shape per level, for swin like [[128, 128], [64, 64],...], for vit [[32,32]] -> mult times resolutions
392
+ resolutions = [
393
+ tuple(sorted([x.shape[1], x.shape[2]])) for x in original_encoder_outputs
394
+ ]
395
+ level_shapes = sorted(list(set(resolutions)))[::-1]
396
+
397
+ if len(level_shapes) == 1:
398
+ level_shapes = level_shapes * self.num_resolutions
399
+ input_shapes = [
400
+ level_shapes[i]
401
+ for i, (start, end) in enumerate(self.slices_encoder)
402
+ for _ in range(end - start)
403
+ ]
404
+ common_shape = level_shapes[-2]
405
+
406
+ # input shapes repeat shapes for each level, times the amount of the layers:
407
+ features_flat = [
408
+ flat_interpolate(
409
+ rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape
410
+ )
411
+ for x, input_shape in zip(original_encoder_outputs, input_shapes)
412
+ ]
413
+ features_splits = torch.tensor(
414
+ [x.shape[-1] for x in features_flat],
415
+ device=device,
416
+ requires_grad=False,
417
+ dtype=torch.float32,
418
+ )
419
+
420
+ # input adapter, then do mean of features in same blocks
421
+ features = self.get_adapted_features(features_flat, features_splits)
422
+ features = torch.stack(features, dim=-1)
423
+
424
+ # positional embeddings, spatial and level
425
+ level_embed = torch.cat(
426
+ [
427
+ self.level_embed_layer(self.level_embeds)[i : i + 1]
428
+ .unsqueeze(0)
429
+ .repeat(B, common_shape[0] * common_shape[1], 1)
430
+ for i in range(self.num_resolutions)
431
+ ],
432
+ dim=1,
433
+ )
434
+ pos_embed = self.pos_embed(
435
+ torch.zeros(
436
+ B,
437
+ 1,
438
+ common_shape[0],
439
+ common_shape[1],
440
+ device=device,
441
+ requires_grad=False,
442
+ )
443
+ )
444
+ pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(
445
+ 1, self.num_resolutions, 1
446
+ )
447
+
448
+ self.camera_layer.set_shapes(common_shape)
449
+ intrinsics, rays = (
450
+ self.run_camera(
451
+ cls_tokens,
452
+ features=features,
453
+ pos_embed=pos_embed + level_embed,
454
+ original_shapes=(H, W),
455
+ rays=inputs.get("rays", None),
456
+ )
457
+ if not self.skip_camera
458
+ else (inputs["K"], inputs["rays"])
459
+ )
460
+
461
+ # run bulk of the model
462
+ self.depth_layer.set_shapes(common_shape)
463
+ self.depth_layer.set_original_shapes((H, W))
464
+ out8, out4, out2, depth_features = self.depth_layer(
465
+ features=features,
466
+ rays_hr=rays,
467
+ pos_embed=pos_embed,
468
+ level_embed=level_embed,
469
+ )
470
+
471
+ return intrinsics, [out8, out4, out2], depth_features, rays
472
+
473
+ @torch.jit.ignore
474
+ def no_weight_decay_keywords(self):
475
+ return {"latents_pos", "level_embeds"}
476
+
477
+ def build(self, config):
478
+ depth = config["model"]["pixel_decoder"]["depths"]
479
+ input_dims = config["model"]["pixel_encoder"]["embed_dims"]
480
+ hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
481
+ num_heads = config["model"]["num_heads"]
482
+ expansion = config["model"]["expansion"]
483
+ dropout = config["model"]["pixel_decoder"]["dropout"]
484
+ depths_encoder = config["model"]["pixel_encoder"]["depths"]
485
+ num_steps = config["model"].get("num_steps", 100000)
486
+ layer_scale = 1.0
487
+
488
+ self.depth = depth
489
+ self.dim = hidden_dim
490
+ self.downsample = 4
491
+ self.num_heads = num_heads
492
+ self.num_resolutions = len(depths_encoder)
493
+ self.depths_encoder = depths_encoder
494
+
495
+ self.slices_encoder_single = list(
496
+ zip([d - 1 for d in self.depths_encoder], self.depths_encoder)
497
+ )
498
+ self.slices_encoder_range = list(
499
+ zip([0, *self.depths_encoder[:-1]], self.depths_encoder)
500
+ )
501
+ cls_token_input_dims = [input_dims[-i - 1] for i in range(len(depths_encoder))]
502
+
503
+ input_dims = [input_dims[d - 1] for d in depths_encoder]
504
+ self.slices_encoder = self.slices_encoder_single
505
+
506
+ # adapt from encoder features, just project
507
+ self.input_adapter = ListAdapter(input_dims, hidden_dim)
508
+ self.token_adapter = ListAdapter(cls_token_input_dims, hidden_dim)
509
+
510
+ # camera layer
511
+ self.camera_layer = CameraHead(
512
+ input_dim=hidden_dim,
513
+ hidden_dim=hidden_dim,
514
+ num_heads=num_heads,
515
+ expansion=expansion,
516
+ depth=2,
517
+ dropout=dropout,
518
+ layer_scale=layer_scale,
519
+ )
520
+
521
+ self.depth_layer = DepthHead(
522
+ hidden_dim=hidden_dim,
523
+ num_heads=num_heads,
524
+ expansion=expansion,
525
+ depths=depth,
526
+ dropout=dropout,
527
+ camera_dim=81,
528
+ num_resolutions=self.num_resolutions,
529
+ layer_scale=layer_scale,
530
+ )
531
+
532
+ # transformer part
533
+ self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
534
+ self.level_embeds = nn.Parameter(
535
+ torch.randn(len(input_dims), hidden_dim), requires_grad=True
536
+ )
537
+ self.level_embed_layer = nn.Sequential(
538
+ nn.Linear(hidden_dim, hidden_dim),
539
+ nn.GELU(),
540
+ nn.Linear(hidden_dim, hidden_dim),
541
+ nn.LayerNorm(hidden_dim),
542
+ )
flash3d/unidepth/models/unidepthv1/unidepthv1.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from copy import deepcopy
7
+ import importlib
8
+ from typing import Any, Dict, Tuple
9
+ from math import ceil
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torchvision.transforms.functional as TF
15
+ from einops import rearrange
16
+
17
+ from unidepth.utils.geometric import (
18
+ generate_rays,
19
+ spherical_zbuffer_to_euclidean,
20
+ )
21
+ from unidepth.utils.misc import get_params
22
+ from unidepth.utils.distributed import is_main_process
23
+ from unidepth.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD
24
+ from unidepth.models.unidepthv1.decoder import Decoder
25
+
26
+ from huggingface_hub import PyTorchModelHubMixin
27
+
28
+
29
+ MAP_BACKBONES = {"ViTL14": "vitl14", "ConvNextL": "cnvnxtl"}
30
+
31
+
32
+ # inference helpers
33
+ def _paddings(image_shape, network_shape):
34
+ cur_h, cur_w = image_shape
35
+ h, w = network_shape
36
+ pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
37
+ pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
38
+ return pad_left, pad_right, pad_top, pad_bottom
39
+
40
+
41
+ def _shapes(image_shape, network_shape):
42
+ h, w = image_shape
43
+ input_ratio = w / h
44
+ output_ratio = network_shape[1] / network_shape[0]
45
+ if output_ratio > input_ratio:
46
+ ratio = network_shape[0] / h
47
+ elif output_ratio <= input_ratio:
48
+ ratio = network_shape[1] / w
49
+ return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio
50
+
51
+
52
+ def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes):
53
+ (pad_left, pad_right, pad_top, pad_bottom) = pads
54
+ rgbs = F.interpolate(
55
+ rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True
56
+ )
57
+ rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant")
58
+ if intrinsics is not None:
59
+ intrinsics = intrinsics.clone()
60
+ intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
61
+ intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
62
+ intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left
63
+ intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top
64
+ return rgbs, intrinsics
65
+ return rgbs, None
66
+
67
+
68
+ def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes):
69
+ (pad_left, pad_right, pad_top, pad_bottom) = pads
70
+ # pred mean, trim paddings, and upsample to input dim
71
+ predictions = sum(
72
+ [
73
+ F.interpolate(
74
+ x.clone(),
75
+ size=shapes,
76
+ mode="bilinear",
77
+ align_corners=False,
78
+ antialias=True,
79
+ )
80
+ for x in predictions
81
+ ]
82
+ ) / len(predictions)
83
+ predictions = predictions[
84
+ ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right
85
+ ]
86
+ predictions = F.interpolate(
87
+ predictions,
88
+ size=original_shapes,
89
+ mode="bilinear",
90
+ align_corners=False,
91
+ antialias=True,
92
+ )
93
+ intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio
94
+ intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio
95
+ intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio
96
+ intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio
97
+ return predictions, intrinsics
98
+
99
+
100
+ class UniDepthV1(nn.Module,
101
+ PyTorchModelHubMixin,
102
+ library_name="UniDepth",
103
+ repo_url="https://github.com/lpiccinelli-eth/UniDepth",
104
+ tags=["monocular-metric-depth-estimation"]):
105
+ def __init__(
106
+ self,
107
+ config,
108
+ eps: float = 1e-6,
109
+ **kwargs,
110
+ ):
111
+ super().__init__()
112
+ self.build(config)
113
+ self.eps = eps
114
+
115
+ def forward(self, inputs, image_metas):
116
+ rgbs = inputs["image"]
117
+ gt_intrinsics = inputs.get("K")
118
+ H, W = rgbs.shape[-2:]
119
+
120
+ # Encode
121
+ encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
122
+ if "dino" in self.pixel_encoder.__class__.__name__.lower():
123
+ encoder_outputs = [
124
+ (x + y.unsqueeze(1)).contiguous()
125
+ for x, y in zip(encoder_outputs, cls_tokens)
126
+ ]
127
+ inputs["encoder_outputs"] = encoder_outputs
128
+ inputs["cls_tokens"] = cls_tokens
129
+
130
+ # Get camera infos, if any
131
+ if gt_intrinsics is not None:
132
+ rays, angles = generate_rays(
133
+ gt_intrinsics, self.image_shape, noisy=self.training
134
+ )
135
+ inputs["rays"] = rays
136
+ inputs["angles"] = angles
137
+ inputs["K"] = gt_intrinsics
138
+ self.pixel_decoder.test_fixed_camera = True # use GT camera in fwd
139
+
140
+ # Decode
141
+ pred_intrinsics, predictions, _, _ = self.pixel_decoder(inputs, {})
142
+ predictions = sum(
143
+ [
144
+ F.interpolate(
145
+ x.clone(),
146
+ size=self.image_shape,
147
+ mode="bilinear",
148
+ align_corners=False,
149
+ antialias=True,
150
+ )
151
+ for x in predictions
152
+ ]
153
+ ) / len(predictions)
154
+
155
+ # Final 3D points backprojection
156
+ pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False)[-1]
157
+ # You may want to use inputs["angles"] if available?
158
+ pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W)
159
+ points_3d = torch.cat((pred_angles, predictions), dim=1)
160
+ points_3d = spherical_zbuffer_to_euclidean(
161
+ points_3d.permute(0, 2, 3, 1)
162
+ ).permute(0, 3, 1, 2)
163
+
164
+ # Output data, use for loss computation
165
+ outputs = {
166
+ "angles": pred_angles,
167
+ "intrinsics": pred_intrinsics,
168
+ "points": points_3d,
169
+ "depth": predictions[:, -1:],
170
+ }
171
+ self.pixel_decoder.test_fixed_camera = False
172
+ return outputs
173
+
174
+ @torch.no_grad()
175
+ def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False):
176
+ if rgbs.ndim == 3:
177
+ rgbs = rgbs.unsqueeze(0)
178
+ if intrinsics is not None and intrinsics.ndim == 2:
179
+ intrinsics = intrinsics.unsqueeze(0)
180
+ B, _, H, W = rgbs.shape
181
+
182
+ rgbs = rgbs.to(self.device)
183
+ if intrinsics is not None:
184
+ intrinsics = intrinsics.to(self.device)
185
+
186
+ # process image and intrinsiscs (if any) to match network input (slow?)
187
+ if rgbs.max() > 5 or rgbs.dtype == torch.uint8:
188
+ rgbs = TF.normalize(
189
+ rgbs.to(torch.float32).div(255),
190
+ mean=IMAGENET_DATASET_MEAN,
191
+ std=IMAGENET_DATASET_STD,
192
+ )
193
+ else:
194
+ pass
195
+ # print("Image not normalized, was it already normalized?")
196
+ (h, w), ratio = _shapes((H, W), self.image_shape)
197
+ pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), self.image_shape)
198
+ rgbs, gt_intrinsics = _preprocess(
199
+ rgbs,
200
+ intrinsics,
201
+ (h, w),
202
+ (pad_left, pad_right, pad_top, pad_bottom),
203
+ ratio,
204
+ self.image_shape,
205
+ )
206
+
207
+ # run encoder
208
+ encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
209
+ if "dino" in self.pixel_encoder.__class__.__name__.lower():
210
+ encoder_outputs = [
211
+ (x + y.unsqueeze(1)).contiguous()
212
+ for x, y in zip(encoder_outputs, cls_tokens)
213
+ ]
214
+
215
+ # get data for decoder and adapt to given camera
216
+ inputs = {}
217
+ inputs["encoder_outputs"] = encoder_outputs
218
+ inputs["cls_tokens"] = cls_tokens
219
+ inputs["image"] = rgbs
220
+ if gt_intrinsics is not None:
221
+ rays, angles = generate_rays(
222
+ gt_intrinsics, self.image_shape, noisy=self.training
223
+ )
224
+ inputs["rays"] = rays
225
+ inputs["angles"] = angles
226
+ inputs["K"] = gt_intrinsics
227
+ self.pixel_decoder.test_fixed_camera = True
228
+ self.pixel_decoder.skip_camera = skip_camera
229
+
230
+ # decode all
231
+ pred_intrinsics, predictions, _, _ = self.pixel_decoder(inputs, {})
232
+
233
+ # undo the reshaping and get original image size (slow)
234
+ predictions, pred_intrinsics = _postprocess(
235
+ predictions,
236
+ pred_intrinsics,
237
+ self.image_shape,
238
+ (pad_left, pad_right, pad_top, pad_bottom),
239
+ ratio,
240
+ (H, W),
241
+ )
242
+
243
+ # final 3D points backprojection
244
+ intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics
245
+ angles = generate_rays(intrinsics, (H, W), noisy=False)[-1]
246
+ angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
247
+ points_3d = torch.cat((angles, predictions), dim=1)
248
+ points_3d = spherical_zbuffer_to_euclidean(
249
+ points_3d.permute(0, 2, 3, 1)
250
+ ).permute(0, 3, 1, 2)
251
+
252
+ # output data
253
+ outputs = {
254
+ "intrinsics": pred_intrinsics,
255
+ "points": points_3d,
256
+ "depth": predictions[:, -1:],
257
+ }
258
+ self.pixel_decoder.test_fixed_camera = False
259
+ self.pixel_decoder.skip_camera = False
260
+ return outputs
261
+
262
+ def load_pretrained(self, model_file):
263
+ device = (
264
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
265
+ )
266
+ dict_model = torch.load(model_file, map_location=device)
267
+
268
+ if "model" in dict_model:
269
+ dict_model = dict_model["model"]
270
+ new_state_dict = deepcopy(
271
+ {k.replace("module.", ""): v for k, v in dict_model.items()}
272
+ )
273
+
274
+ info = self.load_state_dict(new_state_dict, strict=False)
275
+ if is_main_process():
276
+ print(
277
+ f"Loaded from {model_file} for {self.__class__.__name__} results in:",
278
+ info,
279
+ )
280
+
281
+ def get_params(self, config):
282
+ if hasattr(self.pixel_encoder, "get_params"):
283
+ encoder_p, encoder_lr = self.pixel_encoder.get_params(
284
+ config["model"]["pixel_encoder"]["lr"],
285
+ config["training"]["wd"],
286
+ config["training"]["ld"],
287
+ )
288
+ else:
289
+ encoder_p, encoder_lr = get_params(
290
+ self.pixel_encoder,
291
+ config["model"]["pixel_encoder"]["lr"],
292
+ config["training"]["wd"],
293
+ )
294
+ decoder_p, decoder_lr = get_params(
295
+ self.pixel_decoder, config["training"]["lr"], config["training"]["wd"]
296
+ )
297
+ return [*encoder_p, *decoder_p], [*encoder_lr, *decoder_lr]
298
+
299
+ @property
300
+ def device(self):
301
+ return next(self.parameters()).device
302
+
303
+ def build(self, config: Dict[str, Dict[str, Any]]):
304
+ mod = importlib.import_module("unidepth.models.encoder")
305
+ pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
306
+ pixel_encoder_config = {
307
+ **config["training"],
308
+ **config["data"],
309
+ **config["model"]["pixel_encoder"],
310
+ }
311
+ pixel_encoder = pixel_encoder_factory(pixel_encoder_config)
312
+
313
+ config["model"]["pixel_encoder"]["patch_size"] = (
314
+ 14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
315
+ )
316
+ pixel_encoder_embed_dims = (
317
+ pixel_encoder.embed_dims
318
+ if hasattr(pixel_encoder, "embed_dims")
319
+ else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
320
+ )
321
+ config["model"]["pixel_encoder"]["embed_dim"] = getattr(
322
+ pixel_encoder, "embed_dim"
323
+ )
324
+ config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
325
+ config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths
326
+
327
+ self.pixel_encoder = pixel_encoder
328
+ self.pixel_decoder = Decoder(config)
329
+ self.image_shape = config["data"]["image_shape"]
flash3d/unidepth/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .losses import SILog, MSE, SelfCons
2
+ from .scheduler import CosineScheduler
3
+
4
+ __all__ = [
5
+ "SILog",
6
+ "MSE",
7
+ "SelfCons",
8
+ "CosineScheduler",
9
+ ]
flash3d/unidepth/ops/losses.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from typing import Any, Optional, Dict, Tuple, List
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ FNS = {
14
+ "sqrt": torch.sqrt,
15
+ "log": torch.log,
16
+ "log1": lambda x: torch.log(x + 1),
17
+ "linear": lambda x: x,
18
+ "square": torch.square,
19
+ "disp": lambda x: 1 / x,
20
+ }
21
+
22
+
23
+ FNS_INV = {
24
+ "sqrt": torch.square,
25
+ "log": torch.exp,
26
+ "log1": lambda x: torch.exp(x) - 1,
27
+ "linear": lambda x: x,
28
+ "square": torch.sqrt,
29
+ "disp": lambda x: 1 / x,
30
+ }
31
+
32
+
33
+ def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
34
+ if mask is None:
35
+ return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
36
+ mask = mask.float()
37
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
38
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
39
+ mask_sum, min=1.0
40
+ )
41
+ mask_var = torch.sum(
42
+ mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
43
+ ) / torch.clamp(mask_sum, min=1.0)
44
+ return mask_mean.squeeze(dim), mask_var.squeeze(dim)
45
+
46
+
47
+ def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]):
48
+ if mask is None:
49
+ return data.mean(dim=dim, keepdim=True)
50
+ mask = mask.float()
51
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
52
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
53
+ mask_sum, min=1.0
54
+ )
55
+ return mask_mean
56
+
57
+
58
+ def masked_mae(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]):
59
+ if mask is None:
60
+ return data.abs().mean(dim=dim, keepdim=True)
61
+ mask = mask.float()
62
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
63
+ mask_mean = torch.sum(data.abs() * mask, dim=dim, keepdim=True) / torch.clamp(
64
+ mask_sum, min=1.0
65
+ )
66
+ return mask_mean
67
+
68
+
69
+ def masked_mse(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]):
70
+ if mask is None:
71
+ return (data**2).mean(dim=dim, keepdim=True)
72
+ mask = mask.float()
73
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
74
+ mask_mean = torch.sum((data**2) * mask, dim=dim, keepdim=True) / torch.clamp(
75
+ mask_sum, min=1.0
76
+ )
77
+ return mask_mean
78
+
79
+
80
+ def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
81
+ ndim = data.ndim
82
+ data = data.flatten(ndim - len(dim))
83
+ mask = mask.flatten(ndim - len(dim))
84
+ mask_median = torch.median(data[mask], dim=-1).values
85
+ return mask_median
86
+
87
+
88
+ def masked_median_mad(data: torch.Tensor, mask: torch.Tensor):
89
+ data = data.flatten()
90
+ mask = mask.flatten()
91
+ mask_median = torch.median(data[mask])
92
+ n_samples = torch.clamp(torch.sum(mask.float()), min=1.0)
93
+ mask_mad = torch.sum((data[mask] - mask_median).abs()) / n_samples
94
+ return mask_median, mask_mad
95
+
96
+
97
+ def masked_weighted_mean_var(
98
+ data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...]
99
+ ):
100
+ if mask is None:
101
+ return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
102
+ mask = mask.float()
103
+ mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum(
104
+ mask * weights, dim=dim, keepdim=True
105
+ ).clamp(min=1.0)
106
+ # V1**2 - V2, V1: sum w_i, V2: sum w_i**2
107
+ denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum(
108
+ (mask * weights).square(), dim=dim, keepdim=True
109
+ )
110
+ # correction is V1 / (V1**2 - V2), if w_i=1 => N/(N**2 - N) => 1/(N-1) (unbiased estimator of variance, cvd)
111
+ correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp(
112
+ min=1.0
113
+ )
114
+ mask_var = correction_factor * torch.sum(
115
+ weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
116
+ )
117
+ return mask_mean, mask_var
118
+
119
+
120
+ def masked_mean_var_q(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
121
+ if mask is None:
122
+ return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
123
+ mask = mask.float()
124
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
125
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
126
+ mask_sum, min=1.0
127
+ )
128
+ mask_var = torch.sum(
129
+ mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
130
+ ) / torch.clamp(mask_sum, min=1.0)
131
+ return mask_mean, mask_var
132
+
133
+
134
+ class SILog(nn.Module):
135
+ def __init__(
136
+ self,
137
+ weight: float,
138
+ scale_pred_weight: float = 0.15,
139
+ output_fn: str = "sqrt",
140
+ input_fn: str = "log",
141
+ legacy: bool = False,
142
+ abs_rel: bool = False,
143
+ norm: bool = False,
144
+ eps: float = 1e-5,
145
+ ):
146
+ super().__init__()
147
+ assert output_fn in FNS
148
+ self.name: str = self.__class__.__name__
149
+ self.weight: float = weight
150
+
151
+ self.scale_pred_weight: float = scale_pred_weight
152
+ self.dims = (-4, -3, -2, -1) if legacy else (-2, -1)
153
+ self.output_fn = FNS[output_fn]
154
+ self.input_fn = FNS[input_fn]
155
+ self.abs_rel = abs_rel
156
+ self.norm = norm
157
+ self.eps: float = eps
158
+
159
+ @torch.cuda.amp.autocast(enabled=False)
160
+ def forward(
161
+ self,
162
+ input: torch.Tensor,
163
+ target: torch.Tensor,
164
+ mask: Optional[torch.Tensor] = None,
165
+ interpolate: bool = True,
166
+ scale_inv: torch.Tensor | None = None,
167
+ ss_inv: torch.Tensor | None = None,
168
+ **kwargs
169
+ ) -> torch.Tensor:
170
+ if interpolate:
171
+ input = F.interpolate(
172
+ input, target.shape[-2:], mode="bilinear", align_corners=False
173
+ )
174
+ if mask is not None:
175
+ mask = mask.to(torch.bool)
176
+ if ss_inv is not None:
177
+ ss_inv = ~ss_inv
178
+
179
+ if input.shape[1] > 1:
180
+ input_ = torch.cat(
181
+ [input[:, :-1], self.input_fn(input[:, -1:].clamp(min=self.eps))], dim=1
182
+ )
183
+ target_ = torch.cat(
184
+ [target[:, :-1], self.input_fn(target[:, -1:].clamp(min=self.eps))],
185
+ dim=1,
186
+ )
187
+ error = torch.norm(input_ - target_, dim=1, keepdim=True)
188
+ else:
189
+ input_ = self.input_fn(input.clamp(min=self.eps))
190
+ target_ = self.input_fn(target.clamp(min=self.eps))
191
+ error = input_ - target_
192
+
193
+ mean_error, var_error = masked_mean_var(data=error, mask=mask, dim=self.dims)
194
+
195
+ # prevoiusly was inverted!!
196
+ if self.abs_rel:
197
+ scale_error = (input - target).abs()[:, -1:] / target[:, -1:].clip(
198
+ min=self.eps
199
+ )
200
+ scale_error = masked_mean(data=scale_error, mask=mask, dim=self.dims)
201
+ else:
202
+ scale_error = mean_error**2
203
+
204
+ if var_error.ndim > 1:
205
+ var_error = var_error.sum(dim=1)
206
+ scale_error = scale_error.sum(dim=1)
207
+
208
+ # if scale inv -> mask scale error, if scale/shift, mask the full loss
209
+ if scale_inv is not None:
210
+ scale_error = (1 - scale_inv.int()) * scale_error
211
+ scale_error = self.scale_pred_weight * scale_error
212
+ loss = var_error + scale_error
213
+ out_loss = self.output_fn(loss.clamp(min=self.eps))
214
+ out_loss = masked_mean(data=out_loss, mask=ss_inv, dim=(0,))
215
+ return out_loss.mean()
216
+
217
+ @classmethod
218
+ def build(cls, config: Dict[str, Any]):
219
+ obj = cls(
220
+ weight=config["weight"],
221
+ legacy=config["legacy"],
222
+ output_fn=config["output_fn"],
223
+ input_fn=config["input_fn"],
224
+ norm=config.get("norm", False),
225
+ scale_pred_weight=config.get("gamma", 0.15),
226
+ abs_rel=config.get("abs_rel", False),
227
+ )
228
+ return obj
229
+
230
+
231
+ class MSE(nn.Module):
232
+ def __init__(
233
+ self,
234
+ weight: float = 1.0,
235
+ input_fn: str = "linear",
236
+ output_fn: str = "linear",
237
+ ):
238
+ super().__init__()
239
+ self.name: str = self.__class__.__name__
240
+ self.output_fn = FNS[output_fn]
241
+ self.input_fn = FNS[input_fn]
242
+ self.weight: float = weight
243
+ self.eps = 1e-6
244
+
245
+ @torch.cuda.amp.autocast(enabled=False)
246
+ def forward(
247
+ self,
248
+ input: torch.Tensor,
249
+ target: torch.Tensor,
250
+ mask: torch.Tensor | None = None,
251
+ batch_mask: torch.Tensor | None = None,
252
+ **kwargs
253
+ ) -> torch.Tensor:
254
+ input = input[..., : target.shape[-1]] # B N C or B H W C
255
+ error = self.input_fn(input + self.eps) - self.input_fn(target + self.eps)
256
+ abs_error = torch.square(error).sum(dim=-1)
257
+ mean_error = masked_mean(data=abs_error, mask=mask, dim=(-1,)).mean(dim=-1)
258
+ batched_error = masked_mean(
259
+ self.output_fn(mean_error.clamp(self.eps)), batch_mask, dim=(0,)
260
+ )
261
+ return batched_error.mean(), mean_error.detach()
262
+
263
+ @classmethod
264
+ def build(cls, config: Dict[str, Any]):
265
+ obj = cls(
266
+ weight=config["weight"],
267
+ output_fn=config["output_fn"],
268
+ input_fn=config["input_fn"],
269
+ )
270
+ return obj
271
+
272
+
273
+ class SelfCons(nn.Module):
274
+ def __init__(
275
+ self,
276
+ weight: float,
277
+ scale_pred_weight: float = 0.15,
278
+ output_fn: str = "sqrt",
279
+ input_fn: str = "log",
280
+ abs_rel: bool = False,
281
+ norm: bool = False,
282
+ eps: float = 1e-5,
283
+ ):
284
+ super().__init__()
285
+ assert output_fn in FNS
286
+ self.name: str = self.__class__.__name__
287
+ self.weight: float = weight
288
+
289
+ self.scale_pred_weight: float = scale_pred_weight
290
+ self.dims = (-2, -1)
291
+ self.output_fn = FNS[output_fn]
292
+ self.input_fn = FNS[input_fn]
293
+ self.abs_rel = abs_rel
294
+ self.norm = norm
295
+ self.eps: float = eps
296
+
297
+ @torch.cuda.amp.autocast(enabled=False)
298
+ def forward(
299
+ self,
300
+ input: torch.Tensor,
301
+ mask: torch.Tensor,
302
+ metas: List[Dict[str, torch.Tensor]],
303
+ ) -> torch.Tensor:
304
+ chunks = input.shape[0] // 2
305
+ device = input.device
306
+ mask = F.interpolate(mask.float(), size=input.shape[-2:], mode="nearest")
307
+
308
+ rescales = input.shape[-2] / torch.tensor(
309
+ [x["resized_shape"][0] for x in metas], device=device
310
+ )
311
+ cams = torch.cat([x["K_target"] for x in metas], dim=0).to(device)
312
+ flips = torch.tensor([x["flip"] for x in metas], device=device)
313
+
314
+ iters = zip(
315
+ input.chunk(chunks),
316
+ mask.chunk(chunks),
317
+ cams.chunk(chunks),
318
+ rescales.chunk(chunks),
319
+ flips.chunk(chunks),
320
+ )
321
+ inputs0, inputs1, masks = [], [], []
322
+ for i, (pair_input, pair_mask, pair_cam, pair_rescale, pair_flip) in enumerate(
323
+ iters
324
+ ):
325
+ mask0, mask1 = pair_mask
326
+ input0, input1 = pair_input
327
+ cam0, cam1 = pair_cam
328
+ rescale0, rescale1 = pair_rescale
329
+ flip0, flip1 = pair_flip
330
+
331
+ fx_0 = cam0[0, 0] * rescale0
332
+ fx_1 = cam1[0, 0] * rescale1
333
+ cx_0 = (cam0[0, 2] - 0.5) * rescale0 + 0.5
334
+ cx_1 = (cam1[0, 2] - 0.5) * rescale1 + 0.5
335
+ cy_0 = (cam0[1, 2] - 0.5) * rescale0 + 0.5
336
+ cy_1 = (cam1[1, 2] - 0.5) * rescale1 + 0.5
337
+
338
+ # flip image
339
+ if flip0 ^ flip1:
340
+ input0 = torch.flip(input0, dims=(2,))
341
+ mask0 = torch.flip(mask0, dims=(2,))
342
+ cx_0 = input0.shape[-1] - cx_0
343
+
344
+ # calc zoom
345
+ zoom_x = float(fx_1 / fx_0)
346
+
347
+ # apply zoom
348
+ input0 = F.interpolate(
349
+ input0.unsqueeze(0),
350
+ scale_factor=zoom_x,
351
+ mode="bilinear",
352
+ align_corners=True,
353
+ ).squeeze(0)
354
+ mask0 = F.interpolate(
355
+ mask0.unsqueeze(0), scale_factor=zoom_x, mode="nearest"
356
+ ).squeeze(0)
357
+
358
+ # calc translation
359
+ change_left = int(cx_1 - (cx_0 - 0.5) * zoom_x - 0.5)
360
+ change_top = int(cy_1 - (cy_0 - 0.5) * zoom_x - 0.5)
361
+ change_right = input1.shape[-1] - change_left - input0.shape[-1]
362
+ change_bottom = input1.shape[-2] - change_top - input0.shape[-2]
363
+
364
+ # apply translation
365
+ pad_left = max(0, change_left)
366
+ pad_right = max(0, change_right)
367
+ pad_top = max(0, change_top)
368
+ pad_bottom = max(0, change_bottom)
369
+
370
+ crop_left = max(0, -change_left)
371
+ crop_right = max(0, -change_right)
372
+ crop_top = max(0, -change_top)
373
+ crop_bottom = max(0, -change_bottom)
374
+
375
+ input0 = F.pad(
376
+ input0,
377
+ (pad_left, pad_right, pad_top, pad_bottom),
378
+ mode="constant",
379
+ value=0,
380
+ )
381
+ mask0 = F.pad(
382
+ mask0,
383
+ (pad_left, pad_right, pad_top, pad_bottom),
384
+ mode="constant",
385
+ value=0,
386
+ )
387
+ input0 = input0[
388
+ :,
389
+ crop_top : input0.shape[-2] - crop_bottom,
390
+ crop_left : input0.shape[-1] - crop_right,
391
+ ]
392
+ mask0 = mask0[
393
+ :,
394
+ crop_top : mask0.shape[-2] - crop_bottom,
395
+ crop_left : mask0.shape[-1] - crop_right,
396
+ ]
397
+
398
+ mask = torch.logical_and(mask0, mask1)
399
+
400
+ inputs0.append(input0)
401
+ inputs1.append(input1)
402
+ masks.append(mask)
403
+
404
+ inputs0 = torch.stack(inputs0, dim=0)
405
+ inputs1 = torch.stack(inputs1, dim=0)
406
+ masks = torch.stack(masks, dim=0)
407
+ loss1 = self.loss(inputs0, inputs1.detach(), masks)
408
+ loss2 = self.loss(inputs1, inputs0.detach(), masks)
409
+ return torch.cat([loss1, loss2], dim=0).mean()
410
+
411
+ def loss(
412
+ self,
413
+ input: torch.Tensor,
414
+ target: torch.Tensor,
415
+ mask: torch.Tensor,
416
+ ) -> torch.Tensor:
417
+ loss = masked_mean(
418
+ (input - target).square().mean(dim=1), mask=mask, dim=(-2, -1)
419
+ )
420
+ return self.output_fn(loss + self.eps)
421
+
422
+ @classmethod
423
+ def build(cls, config: Dict[str, Any]):
424
+ obj = cls(
425
+ weight=config["weight"],
426
+ output_fn=config["output_fn"],
427
+ input_fn=config["input_fn"],
428
+ )
429
+ return obj
flash3d/unidepth/ops/scheduler.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ import numpy as np
7
+
8
+
9
+ class CosineScheduler(object):
10
+ def __init__(
11
+ self,
12
+ optimizer,
13
+ warmup_iters,
14
+ total_iters,
15
+ key,
16
+ overwrite=False,
17
+ init_value=None,
18
+ base_value=None,
19
+ final_value=None,
20
+ step_init=-1,
21
+ ):
22
+ super().__init__()
23
+ self.iter = step_init
24
+ self.overwrite = overwrite
25
+ self.optimizer = optimizer
26
+ self.base_value = base_value
27
+ self.init_value = init_value
28
+ self.final_value = final_value
29
+ self.total_iters = total_iters
30
+ self.warmup_iters = warmup_iters
31
+ self.key = key
32
+ self.schedulers = [
33
+ self.get_schedulers(group) for group in optimizer.param_groups
34
+ ]
35
+
36
+ def get_schedulers(self, group):
37
+ init_value = group.get(self.key + "_init", self.init_value)
38
+ base_value = group.get(self.key + "_base", self.base_value)
39
+ final_value = group.get(self.key + "_final", self.final_value)
40
+ warmup_iters = self.warmup_iters
41
+ total_iters = self.total_iters
42
+ if self.overwrite:
43
+ final_value = self.final_value
44
+
45
+ # normalize in 0,1, then apply function (power) and denormalize
46
+ normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
47
+ normalized_schedule = np.power(normalized_schedule, 2)
48
+ warmup_schedule = (base_value - init_value) * normalized_schedule + init_value
49
+
50
+ # main scheduling
51
+ iters = np.arange(total_iters - warmup_iters)
52
+ schedule = final_value + 0.5 * (base_value - final_value) * (
53
+ 1 + np.cos(np.pi * iters / len(iters))
54
+ )
55
+ return np.concatenate((warmup_schedule, schedule))
56
+
57
+ def step(self):
58
+ self.iter = self.iter + 1
59
+ vals = self[self.iter]
60
+ for group, val in zip(self.optimizer.param_groups, vals):
61
+ if isinstance(group[self.key], (tuple, list)):
62
+ val = (val, *group[self.key][1:])
63
+ group[self.key] = val
64
+
65
+ def __getitem__(self, it):
66
+ it = min(it, self.total_iters - 1)
67
+ return [scheduler[it] for scheduler in self.schedulers]
68
+
69
+ def get(self):
70
+ return [group[self.key] for group in self.optimizer.param_groups]
flash3d/unidepth/utils/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .evaluation_depth import eval_depth, DICT_METRICS
2
+ from .visualization import colorize, image_grid, log_train_artifacts
3
+ from .misc import format_seconds, remove_padding, get_params, identity
4
+ from .distributed import (
5
+ is_main_process,
6
+ setup_multi_processes,
7
+ setup_slurm,
8
+ sync_tensor_across_gpus,
9
+ barrier,
10
+ get_rank,
11
+ get_dist_info,
12
+ )
13
+ from .geometric import unproject_points, spherical_zbuffer_to_euclidean
14
+
15
+ __all__ = [
16
+ "eval_depth",
17
+ "DICT_METRICS",
18
+ "colorize",
19
+ "image_grid",
20
+ "log_train_artifacts",
21
+ "format_seconds",
22
+ "remove_padding",
23
+ "get_params",
24
+ "identity",
25
+ "is_main_process",
26
+ "setup_multi_processes",
27
+ "setup_slurm",
28
+ "sync_tensor_across_gpus",
29
+ "barrier",
30
+ "get_rank",
31
+ "unproject_points",
32
+ "spherical_zbuffer_to_euclidean",
33
+ "validate",
34
+ "get_dist_info",
35
+ ]
flash3d/unidepth/utils/constants.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ import math
7
+ import torch
8
+
9
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
10
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
11
+ IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406)
12
+ IMAGENET_DATASET_STD = (0.229, 0.224, 0.225)
13
+ DEPTH_BINS = torch.cat(
14
+ (
15
+ torch.logspace(math.log10(0.1), math.log10(180.0), steps=512),
16
+ torch.tensor([260.0]),
17
+ ),
18
+ dim=0,
19
+ )
20
+ LOGERR_BINS = torch.linspace(-2, 2, steps=128 + 1)
21
+ LINERR_BINS = torch.linspace(-50, 50, steps=256 + 1)
flash3d/unidepth/utils/distributed.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ import os
7
+ import platform
8
+ import warnings
9
+ import subprocess
10
+
11
+ import cv2
12
+
13
+ import torch
14
+ import torch.utils.data.distributed
15
+ from torch import multiprocessing as mp
16
+ from torch import distributed as dist
17
+
18
+
19
+ def is_dist_avail_and_initialized():
20
+ if not dist.is_available():
21
+ return False
22
+ if not dist.is_initialized():
23
+ return False
24
+ return True
25
+
26
+
27
+ def get_rank():
28
+ if not is_dist_avail_and_initialized():
29
+ return 0
30
+ return dist.get_rank()
31
+
32
+
33
+ def barrier():
34
+ if not is_dist_avail_and_initialized():
35
+ return
36
+ dist.barrier()
37
+
38
+
39
+ def is_main_process():
40
+ return get_rank() == 0
41
+
42
+
43
+ def is_rank_zero(args):
44
+ return args.rank == 0
45
+
46
+
47
+ def get_dist_info():
48
+ if dist.is_available() and dist.is_initialized():
49
+ rank = dist.get_rank()
50
+ world_size = dist.get_world_size()
51
+ else:
52
+ rank = 0
53
+ world_size = 1
54
+ return rank, world_size
55
+
56
+
57
+ def setup_multi_processes(cfg):
58
+ """Setup multi-processing environment variables."""
59
+ # set multi-process start method as `fork` to speed up the training
60
+ if platform.system() != "Windows":
61
+ mp_start_method = cfg.get("mp_start_method", "fork")
62
+ current_method = mp.get_start_method(allow_none=True)
63
+ if current_method is not None and current_method != mp_start_method:
64
+ warnings.warn(
65
+ f"Multi-processing start method `{mp_start_method}` is "
66
+ f"different from the previous setting `{current_method}`."
67
+ f"It will be force set to `{mp_start_method}`. You can change "
68
+ f"this behavior by changing `mp_start_method` in your config."
69
+ )
70
+ mp.set_start_method(mp_start_method, force=True)
71
+
72
+ # disable opencv multithreading to avoid system being overloaded
73
+ opencv_num_threads = cfg.get("opencv_num_threads", 0)
74
+ cv2.setNumThreads(opencv_num_threads)
75
+
76
+ # setup OMP threads
77
+ # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
78
+ workers_per_gpu = cfg.get("workers_per_gpu", 4)
79
+
80
+ if "OMP_NUM_THREADS" not in os.environ and workers_per_gpu > 1:
81
+ omp_num_threads = 1
82
+ warnings.warn(
83
+ f"Setting OMP_NUM_THREADS environment variable for each process "
84
+ f"to be {omp_num_threads} in default, to avoid your system being "
85
+ f"overloaded, please further tune the variable for optimal "
86
+ f"performance in your application as needed."
87
+ )
88
+ os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
89
+
90
+ # setup MKL threads
91
+ if "MKL_NUM_THREADS" not in os.environ and workers_per_gpu > 1:
92
+ mkl_num_threads = os.environ.get("OMP_NUM_THREADS", 1)
93
+ warnings.warn(
94
+ f"Setting MKL_NUM_THREADS environment variable for each process "
95
+ f"to be {mkl_num_threads} in default, to avoid your system being "
96
+ f"overloaded, please further tune the variable for optimal "
97
+ f"performance in your application as needed."
98
+ )
99
+ os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads)
100
+
101
+
102
+ def setup_slurm(backend: str, port: str) -> None:
103
+ """Initialize slurm distributed training environment.
104
+ If argument ``port`` is not specified, then the master port will be system
105
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
106
+ environment variable, then a default port ``29500`` will be used.
107
+ Args:
108
+ backend (str): Backend of torch.distributed.
109
+ port (int, optional): Master port. Defaults to None.
110
+ """
111
+ proc_id = int(os.environ["SLURM_PROCID"])
112
+ ntasks = int(os.environ["SLURM_NTASKS"])
113
+ node_list = os.environ["SLURM_NODELIST"]
114
+
115
+ num_gpus = torch.cuda.device_count()
116
+
117
+ torch.cuda.set_device(proc_id % num_gpus)
118
+ addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
119
+ os.environ["MASTER_PORT"] = str(port)
120
+ os.environ["MASTER_ADDR"] = addr
121
+ os.environ["WORLD_SIZE"] = str(ntasks)
122
+ os.environ["LOCAL_RANK"] = str(proc_id % num_gpus)
123
+ os.environ["RANK"] = str(proc_id)
124
+ print(
125
+ proc_id,
126
+ ntasks,
127
+ num_gpus,
128
+ proc_id % num_gpus,
129
+ node_list,
130
+ addr,
131
+ os.environ["MASTER_PORT"],
132
+ os.system("nvidia-smi -L"),
133
+ )
134
+ dist.init_process_group(backend, rank=proc_id, world_size=ntasks)
135
+
136
+
137
+ def sync_tensor_across_gpus(t, dim=0, cat=True):
138
+ if t is None or not (dist.is_available() and dist.is_initialized()):
139
+ return t
140
+ t = torch.atleast_1d(t)
141
+ group = dist.group.WORLD
142
+ group_size = torch.distributed.get_world_size(group)
143
+
144
+ local_size = torch.tensor(t.size(dim), device=t.device)
145
+ all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)]
146
+ dist.all_gather(all_sizes, local_size)
147
+ max_size = max(all_sizes)
148
+ size_diff = max_size.item() - local_size.item()
149
+ if size_diff:
150
+ padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype)
151
+ t = torch.cat((t, padding))
152
+
153
+ gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)]
154
+ dist.all_gather(gather_t_tensor, t)
155
+ all_ts = []
156
+ for t, size in zip(gather_t_tensor, all_sizes):
157
+ all_ts.append(t[:size])
158
+ if cat:
159
+ return torch.cat(all_ts, dim=0)
160
+ return all_ts
161
+
162
+
163
+ import pickle
164
+
165
+
166
+ def sync_string_across_gpus(keys: list[str], device, dim=0):
167
+ keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL)
168
+ keys_serialized_tensor = torch.frombuffer(keys_serialized, dtype=torch.uint8).to(
169
+ device
170
+ )
171
+ keys_serialized_tensor = sync_tensor_across_gpus(
172
+ keys_serialized_tensor, dim=0, cat=False
173
+ )
174
+ keys = [
175
+ key
176
+ for keys in keys_serialized_tensor
177
+ for key in pickle.loads(bytes(keys.cpu().tolist()))
178
+ ]
179
+ return keys
flash3d/unidepth/utils/ema_torch.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from __future__ import division
7
+ from __future__ import unicode_literals
8
+
9
+ from typing import Iterable, Optional
10
+ import weakref
11
+ import copy
12
+ import contextlib
13
+ from math import tanh
14
+
15
+ import torch
16
+
17
+
18
+ class DummyExponentialMovingAverage:
19
+ def __init__(self, *args, **kwargs):
20
+ pass
21
+
22
+ def _get_parameters(self, *args, **kwargs):
23
+ pass
24
+
25
+ def get_current_decay(self, *args, **kwargs):
26
+ pass
27
+
28
+ def update(self, *args, **kwargs):
29
+ pass
30
+
31
+ def copy_to(self, *args, **kwargs):
32
+ pass
33
+
34
+ def store(self, *args, **kwargs):
35
+ return
36
+
37
+ def restore(self, *args, **kwargs):
38
+ return
39
+
40
+ @contextlib.contextmanager
41
+ def average_parameters(self, *args, **kwargs):
42
+ try:
43
+ yield
44
+ finally:
45
+ pass
46
+
47
+ def to(self, *args, **kwargs):
48
+ pass
49
+
50
+ def state_dict(self, *args, **kwargs):
51
+ pass
52
+
53
+ def load_state_dict(self, *args, **kwargs):
54
+ pass
55
+
56
+
57
+ class ExponentialMovingAverage:
58
+ """
59
+ Maintains (exponential) moving average of a set of parameters.
60
+
61
+ Args:
62
+ parameters: Iterable of `torch.nn.Parameter` (typically from
63
+ `model.parameters()`).
64
+ Note that EMA is computed on *all* provided parameters,
65
+ regardless of whether or not they have `requires_grad = True`;
66
+ this allows a single EMA object to be consistantly used even
67
+ if which parameters are trainable changes step to step.
68
+
69
+ If you want to some parameters in the EMA, do not pass them
70
+ to the object in the first place. For example:
71
+
72
+ ExponentialMovingAverage(
73
+ parameters=[p for p in model.parameters() if p.requires_grad],
74
+ decay=0.9
75
+ )
76
+
77
+ will ignore parameters that do not require grad.
78
+
79
+ decay: The exponential decay.
80
+
81
+ use_num_updates: Whether to use number of updates when computing
82
+ averages.
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ parameters: Iterable[torch.nn.Parameter],
88
+ decay: float,
89
+ use_num_updates: bool = True,
90
+ update_after_step: int = 10000,
91
+ tau: int = 20000,
92
+ switch: bool = False,
93
+ ):
94
+ if decay < 0.0 or decay > 1.0:
95
+ raise ValueError("Decay must be between 0 and 1")
96
+ self.decay = decay
97
+ self.switch = switch # fi keeping EMA params in model after epochs
98
+ self.num_updates = 0 if use_num_updates else None
99
+ parameters = list(parameters)
100
+ self.shadow_params = [p.clone().detach() for p in parameters]
101
+ self.collected_params = None
102
+ # By maintaining only a weakref to each parameter,
103
+ # we maintain the old GC behaviour of ExponentialMovingAverage:
104
+ # if the model goes out of scope but the ExponentialMovingAverage
105
+ # is kept, no references to the model or its parameters will be
106
+ # maintained, and the model will be cleaned up.
107
+ self._params_refs = [weakref.ref(p) for p in parameters]
108
+ self.update_after_step = update_after_step
109
+ self.tau = tau
110
+
111
+ def _get_parameters(
112
+ self, parameters: Optional[Iterable[torch.nn.Parameter]]
113
+ ) -> Iterable[torch.nn.Parameter]:
114
+ if parameters is None:
115
+ parameters = [p() for p in self._params_refs]
116
+ if any(p is None for p in parameters):
117
+ raise ValueError(
118
+ "(One of) the parameters with which this ExponentialMovingAverage was initialized no longer exists (was garbage collected);"
119
+ " please either provide `parameters` explicitly or keep the model to which they belong from being garbage collected."
120
+ )
121
+ return parameters
122
+ else:
123
+ parameters = list(parameters)
124
+ if len(parameters) != len(self.shadow_params):
125
+ raise ValueError(
126
+ "Number of parameters passed as argument is different "
127
+ "from number of shadow parameters maintained by this "
128
+ "ExponentialMovingAverage"
129
+ )
130
+ return parameters
131
+
132
+ def get_current_decay(self):
133
+ epoch = max(self.num_updates - self.update_after_step - 1, 0.0)
134
+ if epoch <= 0:
135
+ return 0.0
136
+ value = tanh(epoch / self.tau) * self.decay
137
+ return value
138
+
139
+ def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None:
140
+ """
141
+ Update currently maintained parameters.
142
+
143
+ Call this every time the parameters are updated, such as the result of
144
+ the `optimizer.step()` call.
145
+
146
+ Args:
147
+ parameters: Iterable of `torch.nn.Parameter`; usually the same set of
148
+ parameters used to initialize this object. If `None`, the
149
+ parameters with which this `ExponentialMovingAverage` was
150
+ initialized will be used.
151
+ """
152
+ parameters = self._get_parameters(parameters)
153
+ decay = self.get_current_decay()
154
+ if self.num_updates is not None:
155
+ self.num_updates += 1
156
+
157
+ one_minus_decay = 1.0 - decay
158
+ with torch.no_grad():
159
+ for s_param, param in zip(self.shadow_params, parameters):
160
+ tmp = s_param - param
161
+ # tmp will be a new tensor so we can do in-place
162
+ tmp.mul_(one_minus_decay)
163
+ s_param.sub_(tmp)
164
+
165
+ def copy_to(
166
+ self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
167
+ ) -> None:
168
+ """
169
+ Copy current averaged parameters into given collection of parameters.
170
+
171
+ Args:
172
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
173
+ updated with the stored moving averages. If `None`, the
174
+ parameters with which this `ExponentialMovingAverage` was
175
+ initialized will be used.
176
+ """
177
+ parameters = self._get_parameters(parameters)
178
+ for s_param, param in zip(self.shadow_params, parameters):
179
+ param.data.copy_(s_param.data)
180
+
181
+ def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None:
182
+ """
183
+ Save the current parameters for restoring later.
184
+
185
+ Args:
186
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
187
+ temporarily stored. If `None`, the parameters of with which this
188
+ `ExponentialMovingAverage` was initialized will be used.
189
+ """
190
+ parameters = self._get_parameters(parameters)
191
+ self.collected_params = [param.detach().clone() for param in parameters]
192
+
193
+ def restore(
194
+ self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
195
+ ) -> None:
196
+ """
197
+ Restore the parameters stored with the `store` method.
198
+ Useful to validate the model with EMA parameters without affecting the
199
+ original optimization process. Store the parameters before the
200
+ `copy_to` method. After validation (or model saving), use this to
201
+ restore the former parameters.
202
+
203
+ Args:
204
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
205
+ updated with the stored parameters. If `None`, the
206
+ parameters with which this `ExponentialMovingAverage` was
207
+ initialized will be used.
208
+ """
209
+ if self.collected_params is None:
210
+ raise RuntimeError(
211
+ "This ExponentialMovingAverage has no `store()`ed weights "
212
+ "to `restore()`"
213
+ )
214
+ parameters = self._get_parameters(parameters)
215
+ for c_param, param in zip(self.collected_params, parameters):
216
+ param.data.copy_(c_param.data)
217
+
218
+ @contextlib.contextmanager
219
+ def average_parameters(
220
+ self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
221
+ ):
222
+ r"""
223
+ Context manager for validation/inference with averaged parameters.
224
+
225
+ Equivalent to:
226
+
227
+ ema.store()
228
+ ema.copy_to()
229
+ try:
230
+ ...
231
+ finally:
232
+ ema.restore()
233
+
234
+ Args:
235
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
236
+ updated with the stored parameters. If `None`, the
237
+ parameters with which this `ExponentialMovingAverage` was
238
+ initialized will be used.
239
+ """
240
+ parameters = self._get_parameters(parameters)
241
+ self.store(parameters)
242
+ self.copy_to(parameters)
243
+ try:
244
+ yield
245
+ finally:
246
+ if not self.switch:
247
+ self.restore(parameters)
248
+
249
+ def to(self, device=None, dtype=None) -> None:
250
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
251
+
252
+ Args:
253
+ device: like `device` argument to `torch.Tensor.to`
254
+ """
255
+ # .to() on the tensors handles None correctly
256
+ self.shadow_params = [
257
+ (
258
+ p.to(device=device, dtype=dtype)
259
+ if p.is_floating_point()
260
+ else p.to(device=device)
261
+ )
262
+ for p in self.shadow_params
263
+ ]
264
+ if self.collected_params is not None:
265
+ self.collected_params = [
266
+ (
267
+ p.to(device=device, dtype=dtype)
268
+ if p.is_floating_point()
269
+ else p.to(device=device)
270
+ )
271
+ for p in self.collected_params
272
+ ]
273
+ return
274
+
275
+ def state_dict(self) -> dict:
276
+ r"""Returns the state of the ExponentialMovingAverage as a dict."""
277
+ # Following PyTorch conventions, references to tensors are returned:
278
+ # "returns a reference to the state and not its copy!" -
279
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
280
+ return {
281
+ "decay": self.decay,
282
+ "num_updates": self.num_updates,
283
+ "shadow_params": self.shadow_params,
284
+ "collected_params": self.collected_params,
285
+ }
286
+
287
+ def load_state_dict(self, state_dict: dict) -> None:
288
+ r"""Loads the ExponentialMovingAverage state.
289
+
290
+ Args:
291
+ state_dict (dict): EMA state. Should be an object returned
292
+ from a call to :meth:`state_dict`.
293
+ """
294
+ # deepcopy, to be consistent with module API
295
+ state_dict = copy.deepcopy(state_dict)
296
+ self.decay = state_dict["decay"]
297
+ if self.decay < 0.0 or self.decay > 1.0:
298
+ raise ValueError("Decay must be between 0 and 1")
299
+ self.num_updates = state_dict["num_updates"]
300
+ assert self.num_updates is None or isinstance(
301
+ self.num_updates, int
302
+ ), "Invalid num_updates"
303
+
304
+ self.shadow_params = state_dict["shadow_params"]
305
+ assert isinstance(self.shadow_params, list), "shadow_params must be a list"
306
+ assert all(
307
+ isinstance(p, torch.Tensor) for p in self.shadow_params
308
+ ), "shadow_params must all be Tensors"
309
+
310
+ self.collected_params = state_dict["collected_params"]
311
+ if self.collected_params is not None:
312
+ assert isinstance(
313
+ self.collected_params, list
314
+ ), "collected_params must be a list"
315
+ assert all(
316
+ isinstance(p, torch.Tensor) for p in self.collected_params
317
+ ), "collected_params must all be Tensors"
318
+ assert len(self.collected_params) == len(
319
+ self.shadow_params
320
+ ), "collected_params and shadow_params had different lengths"
321
+
322
+ if len(self.shadow_params) == len(self._params_refs):
323
+ # Consistant with torch.optim.Optimizer, cast things to consistant
324
+ # device and dtype with the parameters
325
+ params = [p() for p in self._params_refs]
326
+ # If parameters have been garbage collected, just load the state
327
+ # we were given without change.
328
+ if not any(p is None for p in params):
329
+ # ^ parameter references are still good
330
+ for i, p in enumerate(params):
331
+ self.shadow_params[i] = self.shadow_params[i].to(
332
+ device=p.device, dtype=p.dtype
333
+ )
334
+ if self.collected_params is not None:
335
+ self.collected_params[i] = self.collected_params[i].to(
336
+ device=p.device, dtype=p.dtype
337
+ )
338
+ else:
339
+ raise ValueError(
340
+ "Tried to `load_state_dict()` with the wrong number of "
341
+ "parameters in the saved state."
342
+ )
flash3d/unidepth/utils/evaluation_depth.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+ # We prefer not to install PyTorch3D in the package
6
+ # Code commented is how 3D metrics are computed
7
+
8
+ from collections import defaultdict
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ # from chamfer_distance import ChamferDistance
15
+
16
+ from unidepth.utils.constants import DEPTH_BINS
17
+
18
+
19
+ # chamfer_cls = ChamferDistance()
20
+
21
+
22
+ # def chamfer_dist(tensor1, tensor2):
23
+ # x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
24
+ # y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
25
+ # dist1, dist2, idx1, idx2 = chamfer_cls(
26
+ # tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
27
+ # )
28
+ # return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2
29
+
30
+
31
+ # def auc(tensor1, tensor2, thresholds):
32
+ # x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
33
+ # y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
34
+ # dist1, dist2, idx1, idx2 = chamfer_cls(
35
+ # tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
36
+ # )
37
+ # # compute precision recall
38
+ # precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
39
+ # recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
40
+ # auc_value = torch.trapz(
41
+ # torch.tensor(precisions, device=tensor1.device),
42
+ # torch.tensor(recalls, device=tensor1.device),
43
+ # )
44
+ # return auc_value
45
+
46
+
47
+ def delta(tensor1, tensor2, exponent):
48
+ inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1))
49
+ return (inlier < 1.25**exponent).to(torch.float32).mean()
50
+
51
+
52
+ def ssi(tensor1, tensor2, qtl=0.05):
53
+ stability_mat = 1e-9 * torch.eye(2, device=tensor1.device)
54
+ error = (tensor1 - tensor2).abs()
55
+ mask = error < torch.quantile(error, 1 - qtl)
56
+ tensor1_mask = tensor1[mask]
57
+ tensor2_mask = tensor2[mask]
58
+ tensor2_one = torch.stack(
59
+ [tensor2_mask.detach(), torch.ones_like(tensor2_mask).detach()], dim=1
60
+ )
61
+ scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (
62
+ tensor2_one.T @ tensor1_mask.unsqueeze(1)
63
+ )
64
+ scale, shift = scale_shift.squeeze().chunk(2, dim=0)
65
+ return tensor2 * scale + shift
66
+ # tensor2_one = torch.stack([tensor2.detach(), torch.ones_like(tensor2).detach()], dim=1)
67
+ # scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (tensor2_one.T @ tensor1.unsqueeze(1))
68
+ # scale, shift = scale_shift.squeeze().chunk(2, dim=0)
69
+ # return tensor2 * scale + shift
70
+
71
+
72
+ def d1_ssi(tensor1, tensor2):
73
+ delta_ = delta(tensor1, ssi(tensor1, tensor2), 1.0)
74
+ return delta_
75
+
76
+
77
+ def d_auc(tensor1, tensor2):
78
+ exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device)
79
+ deltas = [delta(tensor1, tensor2, exponent) for exponent in exponents]
80
+ return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0
81
+
82
+
83
+ # def f1_score(tensor1, tensor2, thresholds):
84
+ # x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
85
+ # y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
86
+ # dist1, dist2, idx1, idx2 = chamfer_cls(
87
+ # tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
88
+ # )
89
+ # # compute precision recall
90
+ # precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
91
+ # recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
92
+ # precisions = torch.tensor(precisions, device=tensor1.device)
93
+ # recalls = torch.tensor(recalls, device=tensor1.device)
94
+ # f1_thresholds = 2 * precisions * recalls / (precisions + recalls)
95
+ # f1_thresholds = torch.where(
96
+ # torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds
97
+ # )
98
+ # f1_value = torch.trapz(f1_thresholds) / len(thresholds)
99
+ # return f1_value
100
+
101
+
102
+ DICT_METRICS = {
103
+ "d1": partial(delta, exponent=1.0),
104
+ "d2": partial(delta, exponent=2.0),
105
+ "d3": partial(delta, exponent=3.0),
106
+ "rmse": lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()),
107
+ "rmselog": lambda gt, pred: torch.sqrt(
108
+ ((torch.log(gt) - torch.log(pred)) ** 2).mean()
109
+ ),
110
+ "arel": lambda gt, pred: (torch.abs(gt - pred) / gt).mean(),
111
+ "sqrel": lambda gt, pred: (((gt - pred) ** 2) / gt).mean(),
112
+ "log10": lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(),
113
+ "silog": lambda gt, pred: 100 * torch.std(torch.log(pred) - torch.log(gt)).mean(),
114
+ "medianlog": lambda gt, pred: 100
115
+ * (torch.log(pred) - torch.log(gt)).median().abs(),
116
+ "d_auc": d_auc,
117
+ "d1_ssi": d1_ssi,
118
+ }
119
+
120
+
121
+ # DICT_METRICS_3D = {
122
+ # "chamfer": lambda gt, pred, thresholds: chamfer_dist(
123
+ # gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1)
124
+ # ),
125
+ # "F1": lambda gt, pred, thresholds: f1_score(
126
+ # gt.unsqueeze(0).permute(0, 2, 1),
127
+ # pred.unsqueeze(0).permute(0, 2, 1),
128
+ # thresholds=thresholds,
129
+ # ),
130
+ # }
131
+
132
+
133
+ DICT_METRICS_D = {
134
+ "a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to(
135
+ torch.float32
136
+ ),
137
+ "abs_rel": lambda gt, pred: (torch.abs(gt - pred) / gt),
138
+ }
139
+
140
+
141
+ def eval_depth(
142
+ gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None
143
+ ):
144
+ summary_metrics = defaultdict(list)
145
+ preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear")
146
+ for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)):
147
+ if max_depth is not None:
148
+ mask = torch.logical_and(mask, gt <= max_depth)
149
+ for name, fn in DICT_METRICS.items():
150
+ summary_metrics[name].append(fn(gt[mask], pred[mask]).mean())
151
+ return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
152
+
153
+
154
+ # def eval_3d(
155
+ # gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None
156
+ # ):
157
+ # summary_metrics = defaultdict(list)
158
+ # w_max = min(gts.shape[-1] // 4, 400)
159
+ # gts = F.interpolate(
160
+ # gts, (int(w_max * gts.shape[-2] / gts.shape[-1]), w_max), mode="nearest"
161
+ # )
162
+ # preds = F.interpolate(preds, gts.shape[-2:], mode="nearest")
163
+ # masks = F.interpolate(
164
+ # masks.to(torch.float32), gts.shape[-2:], mode="nearest"
165
+ # ).bool()
166
+ # for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)):
167
+ # if not torch.any(mask):
168
+ # continue
169
+ # for name, fn in DICT_METRICS_3D.items():
170
+ # summary_metrics[name].append(
171
+ # fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean()
172
+ # )
173
+ # return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
flash3d/unidepth/utils/geometric.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from typing import Tuple
7
+
8
+ import torch
9
+ from torch.nn import functional as F
10
+
11
+
12
+ def generate_rays(
13
+ camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False
14
+ ):
15
+ batch_size, device, dtype = (
16
+ camera_intrinsics.shape[0],
17
+ camera_intrinsics.device,
18
+ camera_intrinsics.dtype,
19
+ )
20
+ height, width = image_shape
21
+ # Generate grid of pixel coordinates
22
+ pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
23
+ pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
24
+ if noisy:
25
+ pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5
26
+ pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5
27
+ pixel_coords = torch.stack(
28
+ [pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2
29
+ ) # (H, W, 2)
30
+ pixel_coords = pixel_coords + 0.5
31
+
32
+ # Calculate ray directions
33
+ intrinsics_inv = torch.inverse(camera_intrinsics.float()).to(dtype) # (B, 3, 3)
34
+ homogeneous_coords = torch.cat(
35
+ [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2
36
+ ) # (H, W, 3)
37
+ ray_directions = torch.matmul(
38
+ intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1)
39
+ ) # (3, H*W)
40
+ ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W)
41
+ ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3)
42
+
43
+ theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1])
44
+ phi = torch.acos(ray_directions[..., 1])
45
+ # pitch = torch.asin(ray_directions[..., 1])
46
+ # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1])
47
+ angles = torch.stack([theta, phi], dim=-1)
48
+ return ray_directions, angles
49
+
50
+
51
+ @torch.jit.script
52
+ def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
53
+ theta = spherical_tensor[..., 0] # Extract polar angle
54
+ phi = spherical_tensor[..., 1] # Extract azimuthal angle
55
+ z = spherical_tensor[..., 2] # Extract zbuffer depth
56
+
57
+ # y = r * cos(phi)
58
+ # x = r * sin(phi) * sin(theta)
59
+ # z = r * sin(phi) * cos(theta)
60
+ # =>
61
+ # r = z / sin(phi) / cos(theta)
62
+ # y = z / (sin(phi) / cos(phi)) / cos(theta)
63
+ # x = z * sin(theta) / cos(theta)
64
+ x = z * torch.tan(theta)
65
+ y = z / torch.tan(phi) / torch.cos(theta)
66
+
67
+ euclidean_tensor = torch.stack((x, y, z), dim=-1)
68
+ return euclidean_tensor
69
+
70
+
71
+ @torch.jit.script
72
+ def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
73
+ theta = spherical_tensor[..., 0] # Extract polar angle
74
+ phi = spherical_tensor[..., 1] # Extract azimuthal angle
75
+ r = spherical_tensor[..., 2] # Extract radius
76
+ # y = r * cos(phi)
77
+ # x = r * sin(phi) * sin(theta)
78
+ # z = r * sin(phi) * cos(theta)
79
+ x = r * torch.sin(phi) * torch.sin(theta)
80
+ y = r * torch.cos(phi)
81
+ z = r * torch.cos(theta) * torch.sin(phi)
82
+
83
+ euclidean_tensor = torch.stack((x, y, z), dim=-1)
84
+ return euclidean_tensor
85
+
86
+
87
+ @torch.jit.script
88
+ def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor:
89
+ x = spherical_tensor[..., 0] # Extract polar angle
90
+ y = spherical_tensor[..., 1] # Extract azimuthal angle
91
+ z = spherical_tensor[..., 2] # Extract radius
92
+ # y = r * cos(phi)
93
+ # x = r * sin(phi) * sin(theta)
94
+ # z = r * sin(phi) * cos(theta)
95
+ r = torch.sqrt(x**2 + y**2 + z**2)
96
+ theta = torch.atan2(x / r, z / r)
97
+ phi = torch.acos(y / r)
98
+
99
+ euclidean_tensor = torch.stack((theta, phi, r), dim=-1)
100
+ return euclidean_tensor
101
+
102
+
103
+ @torch.jit.script
104
+ def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor:
105
+ pitch = torch.asin(euclidean_tensor[..., 1])
106
+ yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1])
107
+ z = euclidean_tensor[..., 2] # Extract zbuffer depth
108
+ euclidean_tensor = torch.stack((pitch, yaw, z), dim=-1)
109
+ return euclidean_tensor
110
+
111
+
112
+ @torch.jit.script
113
+ def unproject_points(
114
+ depth: torch.Tensor, camera_intrinsics: torch.Tensor
115
+ ) -> torch.Tensor:
116
+ """
117
+ Unprojects a batch of depth maps to 3D point clouds using camera intrinsics.
118
+
119
+ Args:
120
+ depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W).
121
+ camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3).
122
+
123
+ Returns:
124
+ torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W).
125
+ """
126
+ batch_size, _, height, width = depth.shape
127
+ device = depth.device
128
+
129
+ # Create pixel grid
130
+ y_coords, x_coords = torch.meshgrid(
131
+ torch.arange(height, device=device),
132
+ torch.arange(width, device=device),
133
+ indexing="ij",
134
+ )
135
+ pixel_coords = torch.stack((x_coords, y_coords), dim=-1) # (H, W, 2)
136
+
137
+ # Get homogeneous coords (u v 1)
138
+ pixel_coords_homogeneous = torch.cat(
139
+ (pixel_coords, torch.ones((height, width, 1), device=device)), dim=-1
140
+ )
141
+ pixel_coords_homogeneous = pixel_coords_homogeneous.permute(2, 0, 1).flatten(
142
+ 1
143
+ ) # (3, H*W)
144
+ # Apply K^-1 @ (u v 1): [B, 3, 3] @ [3, H*W] -> [B, 3, H*W]
145
+ unprojected_points = torch.matmul(
146
+ torch.inverse(camera_intrinsics), pixel_coords_homogeneous
147
+ ) # (B, 3, H*W)
148
+ unprojected_points = unprojected_points.view(
149
+ batch_size, 3, height, width
150
+ ) # (B, 3, H, W)
151
+ unprojected_points = unprojected_points * depth # (B, 3, H, W)
152
+ return unprojected_points
153
+
154
+
155
+ @torch.jit.script
156
+ def project_points(
157
+ points_3d: torch.Tensor,
158
+ intrinsic_matrix: torch.Tensor,
159
+ image_shape: Tuple[int, int],
160
+ ) -> torch.Tensor:
161
+ # Project 3D points onto the image plane via intrinsics (u v w) = (x y z) @ K^T
162
+ points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2))
163
+
164
+ # Normalize projected points: (u v w) -> (u / w, v / w, 1)
165
+ points_2d = points_2d[..., :2] / points_2d[..., 2:]
166
+
167
+ # To pixels (rounding!!!), no int as it breaks gradient
168
+ points_2d = points_2d.round()
169
+
170
+ # pointa need to be inside the image (can it diverge onto all points out???)
171
+ valid_mask = (
172
+ (points_2d[..., 0] >= 0)
173
+ & (points_2d[..., 0] < image_shape[1])
174
+ & (points_2d[..., 1] >= 0)
175
+ & (points_2d[..., 1] < image_shape[0])
176
+ )
177
+
178
+ # Calculate the flat indices of the valid pixels
179
+ flat_points_2d = points_2d[..., 0] + points_2d[..., 1] * image_shape[1]
180
+ flat_indices = flat_points_2d.long()
181
+
182
+ # Create depth maps and counts using scatter_add, (B, H, W)
183
+ depth_maps = torch.zeros(
184
+ [points_3d.shape[0], *image_shape], device=points_3d.device
185
+ )
186
+ counts = torch.zeros([points_3d.shape[0], *image_shape], device=points_3d.device)
187
+
188
+ # Loop over batches to apply masks and accumulate depth/count values
189
+ for i in range(points_3d.shape[0]):
190
+ valid_indices = flat_indices[i, valid_mask[i]]
191
+ depth_maps[i].view(-1).scatter_add_(
192
+ 0, valid_indices, points_3d[i, valid_mask[i], 2]
193
+ )
194
+ counts[i].view(-1).scatter_add_(
195
+ 0, valid_indices, torch.ones_like(points_3d[i, valid_mask[i], 2])
196
+ )
197
+
198
+ # Calculate mean depth for each pixel in each batch
199
+ mean_depth_maps = depth_maps / counts.clamp(min=1.0)
200
+ return mean_depth_maps.reshape(-1, 1, *image_shape) # (B, 1, H, W)
201
+
202
+
203
+ @torch.jit.script
204
+ def downsample(data: torch.Tensor, downsample_factor: int = 2):
205
+ N, _, H, W = data.shape
206
+ data = data.view(
207
+ N,
208
+ H // downsample_factor,
209
+ downsample_factor,
210
+ W // downsample_factor,
211
+ downsample_factor,
212
+ 1,
213
+ )
214
+ data = data.permute(0, 1, 3, 5, 2, 4).contiguous()
215
+ data = data.view(-1, downsample_factor * downsample_factor)
216
+ data_tmp = torch.where(data == 0.0, 1e5 * torch.ones_like(data), data)
217
+ data = torch.min(data_tmp, dim=-1).values
218
+ data = data.view(N, 1, H // downsample_factor, W // downsample_factor)
219
+ data = torch.where(data > 1000, torch.zeros_like(data), data)
220
+ return data
221
+
222
+
223
+ @torch.jit.script
224
+ def flat_interpolate(
225
+ flat_tensor: torch.Tensor,
226
+ old: Tuple[int, int],
227
+ new: Tuple[int, int],
228
+ antialias: bool = True,
229
+ mode: str = "bilinear",
230
+ ) -> torch.Tensor:
231
+ if old[0] == new[0] and old[1] == new[1]:
232
+ return flat_tensor
233
+ tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute(
234
+ 0, 3, 1, 2
235
+ ) # b c h w
236
+ tensor_interp = F.interpolate(
237
+ tensor,
238
+ size=(new[0], new[1]),
239
+ mode=mode,
240
+ align_corners=False,
241
+ antialias=antialias,
242
+ )
243
+ flat_tensor_interp = tensor_interp.view(
244
+ flat_tensor.shape[0], -1, new[0] * new[1]
245
+ ).permute(
246
+ 0, 2, 1
247
+ ) # b (h w) c
248
+ return flat_tensor_interp.contiguous()
flash3d/unidepth/utils/misc.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from functools import wraps
7
+
8
+ import numpy as np
9
+ from scipy import interpolate
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from einops import rearrange, repeat, reduce
16
+
17
+
18
+ def max_stack(tensors):
19
+ return torch.stack(tensors, dim=-1).max(dim=-1)[0]
20
+
21
+
22
+ def softmax_stack(tensors, temperature=1.0):
23
+ return F.softmax(torch.stack(tensors, dim=-1) / temperature, dim=-1).sum(dim=-1)
24
+
25
+
26
+ def mean_stack(tensors):
27
+ if len(tensors) == 1:
28
+ return tensors[0]
29
+ return torch.stack(tensors, dim=-1).mean(dim=-1)
30
+
31
+
32
+ def sum_stack(tensors):
33
+ return torch.stack(tensors, dim=-1).sum(dim=-1)
34
+
35
+
36
+ def convert_module_to_f16(l):
37
+ """
38
+ Convert primitive modules to float16.
39
+ """
40
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
41
+ l.weight.data = l.weight.data.half()
42
+ if l.bias is not None:
43
+ l.bias.data = l.bias.data.half()
44
+
45
+
46
+ def convert_module_to_f32(l):
47
+ """
48
+ Convert primitive modules to float32, undoing convert_module_to_f16().
49
+ """
50
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
51
+ l.weight.data = l.weight.data.float()
52
+ if l.bias is not None:
53
+ l.bias.data = l.bias.data.float()
54
+
55
+
56
+ def format_seconds(seconds):
57
+ minutes, seconds = divmod(seconds, 60)
58
+ hours, minutes = divmod(minutes, 60)
59
+ return f"{hours:d}:{minutes:02d}:{seconds:02d}"
60
+
61
+
62
+ def get_params(module, lr, wd):
63
+ skip_list = {}
64
+ skip_keywords = {}
65
+ if hasattr(module, "no_weight_decay"):
66
+ skip_list = module.no_weight_decay()
67
+ if hasattr(module, "no_weight_decay_keywords"):
68
+ skip_keywords = module.no_weight_decay_keywords()
69
+ has_decay = []
70
+ no_decay = []
71
+ for name, param in module.named_parameters():
72
+ if not param.requires_grad:
73
+ continue # frozen weights
74
+ if (
75
+ (name in skip_list)
76
+ or any((kw in name for kw in skip_keywords))
77
+ or len(param.shape) == 1
78
+ ):
79
+ # if (name in skip_list) or any((kw in name for kw in skip_keywords)):
80
+ # print(name, skip_keywords)
81
+ no_decay.append(param)
82
+ else:
83
+ has_decay.append(param)
84
+
85
+ group1 = {
86
+ "params": has_decay,
87
+ "weight_decay": wd,
88
+ "lr": lr,
89
+ "weight_decay_init": wd,
90
+ "weight_decay_base": wd,
91
+ "lr_init": lr,
92
+ "lr_base": lr,
93
+ }
94
+ group2 = {
95
+ "params": no_decay,
96
+ "weight_decay": 0.0,
97
+ "lr": lr,
98
+ "weight_decay_init": 0.0,
99
+ "weight_decay_base": 0.0,
100
+ "weight_decay_final": 0.0,
101
+ "lr_init": lr,
102
+ "lr_base": lr,
103
+ }
104
+ return [group1, group2], [lr, lr]
105
+
106
+
107
+ def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
108
+ if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"):
109
+ return 0
110
+ elif var_name.startswith("patch_embed"):
111
+ return 0
112
+ elif var_name.startswith("layers"):
113
+ if var_name.split(".")[2] == "blocks":
114
+ stage_id = int(var_name.split(".")[1])
115
+ layer_id = int(var_name.split(".")[3]) + sum(layers_per_stage[:stage_id])
116
+ return layer_id + 1
117
+ elif var_name.split(".")[2] == "downsample":
118
+ stage_id = int(var_name.split(".")[1])
119
+ layer_id = sum(layers_per_stage[: stage_id + 1])
120
+ return layer_id
121
+ else:
122
+ return num_max_layer - 1
123
+
124
+
125
+ def get_params_layerdecayswin(module, lr, wd, ld):
126
+ skip_list = {}
127
+ skip_keywords = {}
128
+ if hasattr(module, "no_weight_decay"):
129
+ skip_list = module.no_weight_decay()
130
+ if hasattr(module, "no_weight_decay_keywords"):
131
+ skip_keywords = module.no_weight_decay_keywords()
132
+ layers_per_stage = module.depths
133
+ num_layers = sum(layers_per_stage) + 1
134
+ lrs = []
135
+ params = []
136
+ for name, param in module.named_parameters():
137
+ if not param.requires_grad:
138
+ print(f"{name} frozen")
139
+ continue # frozen weights
140
+ layer_id = get_num_layer_for_swin(name, num_layers, layers_per_stage)
141
+ lr_cur = lr * ld ** (num_layers - layer_id - 1)
142
+ # if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 or name.endswith(".bias"):
143
+ if (name in skip_list) or any((kw in name for kw in skip_keywords)):
144
+ wd_cur = 0.0
145
+ else:
146
+ wd_cur = wd
147
+ params.append({"params": param, "weight_decay": wd_cur, "lr": lr_cur})
148
+ lrs.append(lr_cur)
149
+ return params, lrs
150
+
151
+
152
+ def log(t, eps: float = 1e-5):
153
+ return torch.log(t.clamp(min=eps))
154
+
155
+
156
+ def l2norm(t):
157
+ return F.normalize(t, dim=-1)
158
+
159
+
160
+ def exists(val):
161
+ return val is not None
162
+
163
+
164
+ def identity(t, *args, **kwargs):
165
+ return t
166
+
167
+
168
+ def divisible_by(numer, denom):
169
+ return (numer % denom) == 0
170
+
171
+
172
+ def first(arr, d=None):
173
+ if len(arr) == 0:
174
+ return d
175
+ return arr[0]
176
+
177
+
178
+ def default(val, d):
179
+ if exists(val):
180
+ return val
181
+ return d() if callable(d) else d
182
+
183
+
184
+ def maybe(fn):
185
+ @wraps(fn)
186
+ def inner(x):
187
+ if not exists(x):
188
+ return x
189
+ return fn(x)
190
+
191
+ return inner
192
+
193
+
194
+ def once(fn):
195
+ called = False
196
+
197
+ @wraps(fn)
198
+ def inner(x):
199
+ nonlocal called
200
+ if called:
201
+ return
202
+ called = True
203
+ return fn(x)
204
+
205
+ return inner
206
+
207
+
208
+ def _many(fn):
209
+ @wraps(fn)
210
+ def inner(tensors, pattern, **kwargs):
211
+ return (fn(tensor, pattern, **kwargs) for tensor in tensors)
212
+
213
+ return inner
214
+
215
+
216
+ rearrange_many = _many(rearrange)
217
+ repeat_many = _many(repeat)
218
+ reduce_many = _many(reduce)
219
+
220
+
221
+ def load_pretrained(state_dict, checkpoint):
222
+ checkpoint_model = checkpoint["model"]
223
+ if any([True if "encoder." in k else False for k in checkpoint_model.keys()]):
224
+ checkpoint_model = {
225
+ k.replace("encoder.", ""): v
226
+ for k, v in checkpoint_model.items()
227
+ if k.startswith("encoder.")
228
+ }
229
+ print("Detect pre-trained model, remove [encoder.] prefix.")
230
+ else:
231
+ print("Detect non-pre-trained model, pass without doing anything.")
232
+ print(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
233
+ checkpoint = load_checkpoint_swin(state_dict, checkpoint_model)
234
+
235
+
236
+ def load_checkpoint_swin(model, checkpoint_model):
237
+ state_dict = model.state_dict()
238
+ # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size
239
+ all_keys = list(checkpoint_model.keys())
240
+ for key in all_keys:
241
+ if "relative_position_bias_table" in key:
242
+ relative_position_bias_table_pretrained = checkpoint_model[key]
243
+ relative_position_bias_table_current = state_dict[key]
244
+ L1, nH1 = relative_position_bias_table_pretrained.size()
245
+ L2, nH2 = relative_position_bias_table_current.size()
246
+ if nH1 != nH2:
247
+ print(f"Error in loading {key}, passing......")
248
+ else:
249
+ if L1 != L2:
250
+ print(f"{key}: Interpolate relative_position_bias_table using geo.")
251
+ src_size = int(L1**0.5)
252
+ dst_size = int(L2**0.5)
253
+
254
+ def geometric_progression(a, r, n):
255
+ return a * (1.0 - r**n) / (1.0 - r)
256
+
257
+ left, right = 1.01, 1.5
258
+ while right - left > 1e-6:
259
+ q = (left + right) / 2.0
260
+ gp = geometric_progression(1, q, src_size // 2)
261
+ if gp > dst_size // 2:
262
+ right = q
263
+ else:
264
+ left = q
265
+
266
+ # if q > 1.090307:
267
+ # q = 1.090307
268
+
269
+ dis = []
270
+ cur = 1
271
+ for i in range(src_size // 2):
272
+ dis.append(cur)
273
+ cur += q ** (i + 1)
274
+
275
+ r_ids = [-_ for _ in reversed(dis)]
276
+
277
+ x = r_ids + [0] + dis
278
+ y = r_ids + [0] + dis
279
+
280
+ t = dst_size // 2.0
281
+ dx = np.arange(-t, t + 0.1, 1.0)
282
+ dy = np.arange(-t, t + 0.1, 1.0)
283
+
284
+ print("Original positions = %s" % str(x))
285
+ print("Target positions = %s" % str(dx))
286
+
287
+ all_rel_pos_bias = []
288
+
289
+ for i in range(nH1):
290
+ z = (
291
+ relative_position_bias_table_pretrained[:, i]
292
+ .view(src_size, src_size)
293
+ .float()
294
+ .numpy()
295
+ )
296
+ f_cubic = interpolate.interp2d(x, y, z, kind="cubic")
297
+ all_rel_pos_bias.append(
298
+ torch.Tensor(f_cubic(dx, dy))
299
+ .contiguous()
300
+ .view(-1, 1)
301
+ .to(relative_position_bias_table_pretrained.device)
302
+ )
303
+
304
+ new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
305
+ checkpoint_model[key] = new_rel_pos_bias
306
+
307
+ # delete relative_position_index since we always re-init it
308
+ relative_position_index_keys = [
309
+ k for k in checkpoint_model.keys() if "relative_position_index" in k
310
+ ]
311
+ for k in relative_position_index_keys:
312
+ del checkpoint_model[k]
313
+
314
+ # delete relative_coords_table since we always re-init it
315
+ relative_coords_table_keys = [
316
+ k for k in checkpoint_model.keys() if "relative_coords_table" in k
317
+ ]
318
+ for k in relative_coords_table_keys:
319
+ del checkpoint_model[k]
320
+
321
+ # # re-map keys due to name change
322
+ rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k]
323
+ for k in rpe_mlp_keys:
324
+ checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k)
325
+
326
+ # delete attn_mask since we always re-init it
327
+ attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
328
+ for k in attn_mask_keys:
329
+ del checkpoint_model[k]
330
+
331
+ encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")]
332
+ for k in encoder_keys:
333
+ checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k)
334
+
335
+ return checkpoint_model
336
+
337
+
338
+ def add_padding_metas(out, image_metas):
339
+ device = out.device
340
+ # left, right, top, bottom
341
+ paddings = [img_meta.get("padding_size", [0] * 4) for img_meta in image_metas]
342
+ paddings = torch.stack(paddings).to(device)
343
+ outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)]
344
+ return torch.stack(outs)
345
+
346
+
347
+ def remove_padding(out, paddings):
348
+ B, C, H, W = out.shape
349
+ device = out.device
350
+ # left, right, top, bottom
351
+ paddings = torch.stack(paddings).to(device)
352
+ outs = [
353
+ o[:, padding[1] : H - padding[3], padding[0] : W - padding[2]]
354
+ for padding, o in zip(paddings, out)
355
+ ]
356
+ return torch.stack(outs)
357
+
358
+
359
+ def remove_padding_metas(out, image_metas):
360
+ B, C, H, W = out.shape
361
+ device = out.device
362
+ # left, right, top, bottom
363
+ paddings = [
364
+ torch.tensor(img_meta.get("padding_size", [0] * 4)) for img_meta in image_metas
365
+ ]
366
+ return remove_padding(out, paddings)
367
+
368
+
369
+ def ssi_helper(tensor1, tensor2):
370
+ stability_mat = 1e-4 * torch.eye(2, device=tensor1.device)
371
+ tensor2_one = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1)
372
+ scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (
373
+ tensor2_one.T @ tensor1.unsqueeze(1)
374
+ )
375
+ scale, shift = scale_shift.squeeze().chunk(2, dim=0)
376
+ return scale, shift
377
+
378
+
379
+ def calculate_mean_values(names, values):
380
+ # Create a defaultdict to store sum and count for each name
381
+ name_values = {name: {} for name in names}
382
+
383
+ # Iterate through the lists and accumulate values for each name
384
+ for name, value in zip(names, values):
385
+ name_values[name]["sum"] = name_values[name].get("sum", 0.0) + value
386
+ name_values[name]["count"] = name_values[name].get("count", 0.0) + 1
387
+
388
+ # Calculate mean values and create the output dictionary
389
+ output_dict = {
390
+ name: name_values[name]["sum"] / name_values[name]["count"]
391
+ for name in name_values
392
+ }
393
+
394
+ return output_dict
395
+
396
+
397
+ def remove_leading_dim(infos):
398
+ if isinstance(infos, dict):
399
+ return {k: remove_leading_dim(v) for k, v in infos.items()}
400
+ elif isinstance(infos, torch.Tensor):
401
+ return infos.squeeze(0)
402
+ else:
403
+ return infos
flash3d/unidepth/utils/positional_embedding.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ from math import pi
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from einops import rearrange, repeat
13
+
14
+
15
+ class PositionEmbeddingSine(nn.Module):
16
+ def __init__(
17
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
18
+ ):
19
+ super().__init__()
20
+ self.num_pos_feats = num_pos_feats
21
+ self.temperature = temperature
22
+ self.normalize = normalize
23
+ if scale is not None and normalize is False:
24
+ raise ValueError("normalize should be True if scale is passed")
25
+ if scale is None:
26
+ scale = 2 * pi
27
+ self.scale = scale
28
+
29
+ def forward(
30
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
31
+ ) -> torch.Tensor:
32
+ if mask is None:
33
+ mask = torch.zeros(
34
+ (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
35
+ )
36
+ not_mask = ~mask
37
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
38
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
39
+ if self.normalize:
40
+ eps = 1e-6
41
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
42
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
43
+
44
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
45
+ dim_t = self.temperature ** (
46
+ 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
47
+ )
48
+
49
+ pos_x = x_embed[:, :, :, None] / dim_t
50
+ pos_y = y_embed[:, :, :, None] / dim_t
51
+ pos_x = torch.stack(
52
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
53
+ ).flatten(3)
54
+ pos_y = torch.stack(
55
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
56
+ ).flatten(3)
57
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
58
+ return pos
59
+
60
+ def __repr__(self, _repr_indent=4):
61
+ head = "Positional encoding " + self.__class__.__name__
62
+ body = [
63
+ "num_pos_feats: {}".format(self.num_pos_feats),
64
+ "temperature: {}".format(self.temperature),
65
+ "normalize: {}".format(self.normalize),
66
+ "scale: {}".format(self.scale),
67
+ ]
68
+ # _repr_indent = 4
69
+ lines = [head] + [" " * _repr_indent + line for line in body]
70
+ return "\n".join(lines)
71
+
72
+
73
+ class LearnedSinusoidalPosEmb(nn.Module):
74
+ def __init__(self, dim):
75
+ super().__init__()
76
+ assert (dim % 2) == 0
77
+ half_dim = dim // 2
78
+ self.weights = nn.Parameter(torch.randn(half_dim))
79
+
80
+ def forward(self, x):
81
+ x = rearrange(x, "b -> b 1")
82
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
83
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
84
+ fouriered = torch.cat((x, fouriered), dim=-1)
85
+ return fouriered
86
+
87
+
88
+ def broadcat(tensors, dim=-1):
89
+ num_tensors = len(tensors)
90
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
91
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
92
+ shape_len = list(shape_lens)[0]
93
+ dim = (dim + shape_len) if dim < 0 else dim
94
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
95
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
96
+ assert all(
97
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
98
+ ), "invalid dimensions for broadcastable concatentation"
99
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
100
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
101
+ expanded_dims.insert(dim, (dim, dims[dim]))
102
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
103
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
104
+ return torch.cat(tensors, dim=dim)
105
+
106
+
107
+ def rotate_half(x):
108
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
109
+ x1, x2 = x.unbind(dim=-1)
110
+ x = torch.stack((-x2, x1), dim=-1)
111
+ return rearrange(x, "... d r -> ... (d r)")
112
+
113
+
114
+ class VisionRotaryEmbedding(nn.Module):
115
+ def __init__(
116
+ self,
117
+ dim,
118
+ pt_seq_len,
119
+ ft_seq_len=None,
120
+ custom_freqs=None,
121
+ freqs_for="lang",
122
+ theta=10000,
123
+ max_freq=10,
124
+ num_freqs=1,
125
+ ):
126
+ super().__init__()
127
+ if custom_freqs:
128
+ freqs = custom_freqs
129
+ elif freqs_for == "lang":
130
+ freqs = 1.0 / (
131
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
132
+ )
133
+ elif freqs_for == "pixel":
134
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
135
+ elif freqs_for == "constant":
136
+ freqs = torch.ones(num_freqs).float()
137
+ else:
138
+ raise ValueError(f"unknown modality {freqs_for}")
139
+
140
+ if ft_seq_len is None:
141
+ ft_seq_len = pt_seq_len
142
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
143
+
144
+ freqs_h = torch.einsum("..., f -> ... f", t, freqs)
145
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
146
+
147
+ freqs_w = torch.einsum("..., f -> ... f", t, freqs)
148
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
149
+
150
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
151
+
152
+ self.register_buffer("freqs_cos", freqs.cos())
153
+ self.register_buffer("freqs_sin", freqs.sin())
154
+
155
+ print("======== shape of rope freq", self.freqs_cos.shape, "========")
156
+
157
+ def forward(self, t, start_index=0):
158
+ rot_dim = self.freqs_cos.shape[-1]
159
+ end_index = start_index + rot_dim
160
+ assert (
161
+ rot_dim <= t.shape[-1]
162
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
163
+ t_left, t, t_right = (
164
+ t[..., :start_index],
165
+ t[..., start_index:end_index],
166
+ t[..., end_index:],
167
+ )
168
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
169
+ return torch.cat((t_left, t, t_right), dim=-1)
170
+
171
+
172
+ class VisionRotaryEmbeddingFast(nn.Module):
173
+ def __init__(
174
+ self,
175
+ dim,
176
+ pt_seq_len,
177
+ ft_seq_len=None,
178
+ custom_freqs=None,
179
+ freqs_for="lang",
180
+ theta=10000,
181
+ max_freq=10,
182
+ num_freqs=1,
183
+ ):
184
+ super().__init__()
185
+ if custom_freqs:
186
+ freqs = custom_freqs
187
+ elif freqs_for == "lang":
188
+ freqs = 1.0 / (
189
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
190
+ )
191
+ elif freqs_for == "pixel":
192
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
193
+ elif freqs_for == "constant":
194
+ freqs = torch.ones(num_freqs).float()
195
+ else:
196
+ raise ValueError(f"unknown modality {freqs_for}")
197
+
198
+ if ft_seq_len is None:
199
+ ft_seq_len = pt_seq_len
200
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
201
+
202
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
203
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
204
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
205
+
206
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
207
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
208
+
209
+ self.register_buffer("freqs_cos", freqs_cos)
210
+ self.register_buffer("freqs_sin", freqs_sin)
211
+
212
+ def forward(self, t):
213
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
214
+
215
+
216
+ from math import log2
217
+
218
+
219
+ def generate_fourier_features(
220
+ x: torch.Tensor,
221
+ dim: int = 512,
222
+ max_freq: int = 64,
223
+ use_cos: bool = False,
224
+ use_log: bool = False,
225
+ cat_orig: bool = False,
226
+ ):
227
+ x_orig = x
228
+ device, dtype, input_dim = x.device, x.dtype, x.shape[-1]
229
+ num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim
230
+
231
+ if use_log:
232
+ scales = 2.0 ** torch.linspace(
233
+ 0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype
234
+ )
235
+ else:
236
+ scales = torch.linspace(
237
+ 1.0, max_freq / 2, num_bands, device=device, dtype=dtype
238
+ )
239
+
240
+ x = x.unsqueeze(-1)
241
+ scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
242
+
243
+ x = x * scales * pi
244
+ x = torch.cat(
245
+ (
246
+ [x.sin(), x.cos()]
247
+ if use_cos
248
+ else [
249
+ x.sin(),
250
+ ]
251
+ ),
252
+ dim=-1,
253
+ )
254
+ x = x.flatten(-2)
255
+ if cat_orig:
256
+ return torch.cat((x, x_orig), dim=-1)
257
+ return x
258
+
259
+
260
+ # from PIL import Image
261
+ # from unidepth.utils import image_grid, colorize
262
+ # if __name__ == "__main__":
263
+ # H, W = 512, 512
264
+ # resolution = 128
265
+ # mesh = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W))
266
+ # mesh = torch.stack(mesh, dim=0).unsqueeze(0)
267
+ # mesh = mesh.view(1, 2, -1).permute(0, 2, 1)
268
+
269
+ # features = generate_fourier_features(mesh, dim=32, max_freq=resolution, use_log=True)
270
+ # channels = features.shape[-1]
271
+ # print(features.shape)
272
+
273
+ # features = features[0].view(H, W, channels).permute(2, 0, 1).numpy()
274
+ # Image.fromarray(image_grid([colorize(1+x, 0.0, 2.0, "viridis") for x in features], rows=8, cols=4)).save(f"tmp_{resolution}.png")
flash3d/unidepth/utils/sht.py ADDED
@@ -0,0 +1,1637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Real spherical harmonics in Cartesian form for PyTorch.
2
+
3
+ This is an autogenerated file. See
4
+ https://github.com/cheind/torch-spherical-harmonics
5
+ for more information.
6
+ """
7
+
8
+ import torch
9
+
10
+
11
+ def rsh_cart_0(xyz: torch.Tensor):
12
+ """Computes all real spherical harmonics up to degree 0.
13
+
14
+ This is an autogenerated method. See
15
+ https://github.com/cheind/torch-spherical-harmonics
16
+ for more information.
17
+
18
+ Params:
19
+ xyz: (N,...,3) tensor of points on the unit sphere
20
+
21
+ Returns:
22
+ rsh: (N,...,1) real spherical harmonics
23
+ projections of input. Ynm is found at index
24
+ `n*(n+1) + m`, with `0 <= n <= degree` and
25
+ `-n <= m <= n`.
26
+ """
27
+
28
+ return torch.stack(
29
+ [
30
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
31
+ ],
32
+ -1,
33
+ )
34
+
35
+
36
+ def rsh_cart_1(xyz: torch.Tensor):
37
+ """Computes all real spherical harmonics up to degree 1.
38
+
39
+ This is an autogenerated method. See
40
+ https://github.com/cheind/torch-spherical-harmonics
41
+ for more information.
42
+
43
+ Params:
44
+ xyz: (N,...,3) tensor of points on the unit sphere
45
+
46
+ Returns:
47
+ rsh: (N,...,4) real spherical harmonics
48
+ projections of input. Ynm is found at index
49
+ `n*(n+1) + m`, with `0 <= n <= degree` and
50
+ `-n <= m <= n`.
51
+ """
52
+ x = xyz[..., 0]
53
+ y = xyz[..., 1]
54
+ z = xyz[..., 2]
55
+
56
+ return torch.stack(
57
+ [
58
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
59
+ -0.48860251190292 * y,
60
+ 0.48860251190292 * z,
61
+ -0.48860251190292 * x,
62
+ ],
63
+ -1,
64
+ )
65
+
66
+
67
+ def rsh_cart_2(xyz: torch.Tensor):
68
+ """Computes all real spherical harmonics up to degree 2.
69
+
70
+ This is an autogenerated method. See
71
+ https://github.com/cheind/torch-spherical-harmonics
72
+ for more information.
73
+
74
+ Params:
75
+ xyz: (N,...,3) tensor of points on the unit sphere
76
+
77
+ Returns:
78
+ rsh: (N,...,9) real spherical harmonics
79
+ projections of input. Ynm is found at index
80
+ `n*(n+1) + m`, with `0 <= n <= degree` and
81
+ `-n <= m <= n`.
82
+ """
83
+ x = xyz[..., 0]
84
+ y = xyz[..., 1]
85
+ z = xyz[..., 2]
86
+
87
+ x2 = x**2
88
+ y2 = y**2
89
+ z2 = z**2
90
+ xy = x * y
91
+ xz = x * z
92
+ yz = y * z
93
+
94
+ return torch.stack(
95
+ [
96
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
97
+ -0.48860251190292 * y,
98
+ 0.48860251190292 * z,
99
+ -0.48860251190292 * x,
100
+ 1.09254843059208 * xy,
101
+ -1.09254843059208 * yz,
102
+ 0.94617469575756 * z2 - 0.31539156525252,
103
+ -1.09254843059208 * xz,
104
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
105
+ ],
106
+ -1,
107
+ )
108
+
109
+
110
+ def rsh_cart_3(xyz: torch.Tensor):
111
+ """Computes all real spherical harmonics up to degree 3.
112
+
113
+ This is an autogenerated method. See
114
+ https://github.com/cheind/torch-spherical-harmonics
115
+ for more information.
116
+
117
+ Params:
118
+ xyz: (N,...,3) tensor of points on the unit sphere
119
+
120
+ Returns:
121
+ rsh: (N,...,16) real spherical harmonics
122
+ projections of input. Ynm is found at index
123
+ `n*(n+1) + m`, with `0 <= n <= degree` and
124
+ `-n <= m <= n`.
125
+ """
126
+ x = xyz[..., 0]
127
+ y = xyz[..., 1]
128
+ z = xyz[..., 2]
129
+
130
+ x2 = x**2
131
+ y2 = y**2
132
+ z2 = z**2
133
+ xy = x * y
134
+ xz = x * z
135
+ yz = y * z
136
+
137
+ return torch.stack(
138
+ [
139
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
140
+ -0.48860251190292 * y,
141
+ 0.48860251190292 * z,
142
+ -0.48860251190292 * x,
143
+ 1.09254843059208 * xy,
144
+ -1.09254843059208 * yz,
145
+ 0.94617469575756 * z2 - 0.31539156525252,
146
+ -1.09254843059208 * xz,
147
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
148
+ -0.590043589926644 * y * (3.0 * x2 - y2),
149
+ 2.89061144264055 * xy * z,
150
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
151
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
152
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
153
+ 1.44530572132028 * z * (x2 - y2),
154
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
155
+ ],
156
+ -1,
157
+ )
158
+
159
+
160
+ def rsh_cart_4(xyz: torch.Tensor):
161
+ """Computes all real spherical harmonics up to degree 4.
162
+
163
+ This is an autogenerated method. See
164
+ https://github.com/cheind/torch-spherical-harmonics
165
+ for more information.
166
+
167
+ Params:
168
+ xyz: (N,...,3) tensor of points on the unit sphere
169
+
170
+ Returns:
171
+ rsh: (N,...,25) real spherical harmonics
172
+ projections of input. Ynm is found at index
173
+ `n*(n+1) + m`, with `0 <= n <= degree` and
174
+ `-n <= m <= n`.
175
+ """
176
+ x = xyz[..., 0]
177
+ y = xyz[..., 1]
178
+ z = xyz[..., 2]
179
+
180
+ x2 = x**2
181
+ y2 = y**2
182
+ z2 = z**2
183
+ xy = x * y
184
+ xz = x * z
185
+ yz = y * z
186
+ x4 = x2**2
187
+ y4 = y2**2
188
+ z4 = z2**2
189
+
190
+ return torch.stack(
191
+ [
192
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
193
+ -0.48860251190292 * y,
194
+ 0.48860251190292 * z,
195
+ -0.48860251190292 * x,
196
+ 1.09254843059208 * xy,
197
+ -1.09254843059208 * yz,
198
+ 0.94617469575756 * z2 - 0.31539156525252,
199
+ -1.09254843059208 * xz,
200
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
201
+ -0.590043589926644 * y * (3.0 * x2 - y2),
202
+ 2.89061144264055 * xy * z,
203
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
204
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
205
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
206
+ 1.44530572132028 * z * (x2 - y2),
207
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
208
+ 2.5033429417967 * xy * (x2 - y2),
209
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
210
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
211
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
212
+ 1.48099765681286
213
+ * z
214
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
215
+ - 0.952069922236839 * z2
216
+ + 0.317356640745613,
217
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
218
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
219
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
220
+ -3.75501441269506 * x2 * y2
221
+ + 0.625835735449176 * x4
222
+ + 0.625835735449176 * y4,
223
+ ],
224
+ -1,
225
+ )
226
+
227
+
228
+ def rsh_cart_5(xyz: torch.Tensor):
229
+ """Computes all real spherical harmonics up to degree 5.
230
+
231
+ This is an autogenerated method. See
232
+ https://github.com/cheind/torch-spherical-harmonics
233
+ for more information.
234
+
235
+ Params:
236
+ xyz: (N,...,3) tensor of points on the unit sphere
237
+
238
+ Returns:
239
+ rsh: (N,...,36) real spherical harmonics
240
+ projections of input. Ynm is found at index
241
+ `n*(n+1) + m`, with `0 <= n <= degree` and
242
+ `-n <= m <= n`.
243
+ """
244
+ x = xyz[..., 0]
245
+ y = xyz[..., 1]
246
+ z = xyz[..., 2]
247
+
248
+ x2 = x**2
249
+ y2 = y**2
250
+ z2 = z**2
251
+ xy = x * y
252
+ xz = x * z
253
+ yz = y * z
254
+ x4 = x2**2
255
+ y4 = y2**2
256
+ z4 = z2**2
257
+
258
+ return torch.stack(
259
+ [
260
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
261
+ -0.48860251190292 * y,
262
+ 0.48860251190292 * z,
263
+ -0.48860251190292 * x,
264
+ 1.09254843059208 * xy,
265
+ -1.09254843059208 * yz,
266
+ 0.94617469575756 * z2 - 0.31539156525252,
267
+ -1.09254843059208 * xz,
268
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
269
+ -0.590043589926644 * y * (3.0 * x2 - y2),
270
+ 2.89061144264055 * xy * z,
271
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
272
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
273
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
274
+ 1.44530572132028 * z * (x2 - y2),
275
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
276
+ 2.5033429417967 * xy * (x2 - y2),
277
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
278
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
279
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
280
+ 1.48099765681286
281
+ * z
282
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
283
+ - 0.952069922236839 * z2
284
+ + 0.317356640745613,
285
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
286
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
287
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
288
+ -3.75501441269506 * x2 * y2
289
+ + 0.625835735449176 * x4
290
+ + 0.625835735449176 * y4,
291
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
292
+ 8.30264925952416 * xy * z * (x2 - y2),
293
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
294
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
295
+ 0.241571547304372
296
+ * y
297
+ * (
298
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
299
+ + 9.375 * z2
300
+ - 1.875
301
+ ),
302
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
303
+ + 1.6840846433293
304
+ * z
305
+ * (
306
+ 1.75
307
+ * z
308
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
309
+ - 1.125 * z2
310
+ + 0.375
311
+ )
312
+ + 0.498988042467941 * z,
313
+ 0.241571547304372
314
+ * x
315
+ * (
316
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
317
+ + 9.375 * z2
318
+ - 1.875
319
+ ),
320
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
321
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
322
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
323
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
324
+ ],
325
+ -1,
326
+ )
327
+
328
+
329
+ def rsh_cart_6(xyz: torch.Tensor):
330
+ """Computes all real spherical harmonics up to degree 6.
331
+
332
+ This is an autogenerated method. See
333
+ https://github.com/cheind/torch-spherical-harmonics
334
+ for more information.
335
+
336
+ Params:
337
+ xyz: (N,...,3) tensor of points on the unit sphere
338
+
339
+ Returns:
340
+ rsh: (N,...,49) real spherical harmonics
341
+ projections of input. Ynm is found at index
342
+ `n*(n+1) + m`, with `0 <= n <= degree` and
343
+ `-n <= m <= n`.
344
+ """
345
+ x = xyz[..., 0]
346
+ y = xyz[..., 1]
347
+ z = xyz[..., 2]
348
+
349
+ x2 = x**2
350
+ y2 = y**2
351
+ z2 = z**2
352
+ xy = x * y
353
+ xz = x * z
354
+ yz = y * z
355
+ x4 = x2**2
356
+ y4 = y2**2
357
+ z4 = z2**2
358
+
359
+ return torch.stack(
360
+ [
361
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
362
+ -0.48860251190292 * y,
363
+ 0.48860251190292 * z,
364
+ -0.48860251190292 * x,
365
+ 1.09254843059208 * xy,
366
+ -1.09254843059208 * yz,
367
+ 0.94617469575756 * z2 - 0.31539156525252,
368
+ -1.09254843059208 * xz,
369
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
370
+ -0.590043589926644 * y * (3.0 * x2 - y2),
371
+ 2.89061144264055 * xy * z,
372
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
373
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
374
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
375
+ 1.44530572132028 * z * (x2 - y2),
376
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
377
+ 2.5033429417967 * xy * (x2 - y2),
378
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
379
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
380
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
381
+ 1.48099765681286
382
+ * z
383
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
384
+ - 0.952069922236839 * z2
385
+ + 0.317356640745613,
386
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
387
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
388
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
389
+ -3.75501441269506 * x2 * y2
390
+ + 0.625835735449176 * x4
391
+ + 0.625835735449176 * y4,
392
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
393
+ 8.30264925952416 * xy * z * (x2 - y2),
394
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
395
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
396
+ 0.241571547304372
397
+ * y
398
+ * (
399
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
400
+ + 9.375 * z2
401
+ - 1.875
402
+ ),
403
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
404
+ + 1.6840846433293
405
+ * z
406
+ * (
407
+ 1.75
408
+ * z
409
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
410
+ - 1.125 * z2
411
+ + 0.375
412
+ )
413
+ + 0.498988042467941 * z,
414
+ 0.241571547304372
415
+ * x
416
+ * (
417
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
418
+ + 9.375 * z2
419
+ - 1.875
420
+ ),
421
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
422
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
423
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
424
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
425
+ 4.09910463115149 * x**4 * xy
426
+ - 13.6636821038383 * xy**3
427
+ + 4.09910463115149 * xy * y**4,
428
+ -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
429
+ 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
430
+ 0.00584892228263444
431
+ * y
432
+ * (3.0 * x2 - y2)
433
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
434
+ 0.0701870673916132
435
+ * xy
436
+ * (
437
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
438
+ - 91.875 * z2
439
+ + 13.125
440
+ ),
441
+ 0.221950995245231
442
+ * y
443
+ * (
444
+ -2.8 * z * (1.5 - 7.5 * z2)
445
+ + 2.2
446
+ * z
447
+ * (
448
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
449
+ + 9.375 * z2
450
+ - 1.875
451
+ )
452
+ - 4.8 * z
453
+ ),
454
+ -1.48328138624466
455
+ * z
456
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
457
+ + 1.86469659985043
458
+ * z
459
+ * (
460
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
461
+ + 1.8
462
+ * z
463
+ * (
464
+ 1.75
465
+ * z
466
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
467
+ - 1.125 * z2
468
+ + 0.375
469
+ )
470
+ + 0.533333333333333 * z
471
+ )
472
+ + 0.953538034014426 * z2
473
+ - 0.317846011338142,
474
+ 0.221950995245231
475
+ * x
476
+ * (
477
+ -2.8 * z * (1.5 - 7.5 * z2)
478
+ + 2.2
479
+ * z
480
+ * (
481
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
482
+ + 9.375 * z2
483
+ - 1.875
484
+ )
485
+ - 4.8 * z
486
+ ),
487
+ 0.0350935336958066
488
+ * (x2 - y2)
489
+ * (
490
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
491
+ - 91.875 * z2
492
+ + 13.125
493
+ ),
494
+ 0.00584892228263444
495
+ * x
496
+ * (x2 - 3.0 * y2)
497
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
498
+ 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
499
+ -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
500
+ 0.683184105191914 * x2**3
501
+ + 10.2477615778787 * x2 * y4
502
+ - 10.2477615778787 * x4 * y2
503
+ - 0.683184105191914 * y2**3,
504
+ ],
505
+ -1,
506
+ )
507
+
508
+
509
+ def rsh_cart_7(xyz: torch.Tensor):
510
+ """Computes all real spherical harmonics up to degree 7.
511
+
512
+ This is an autogenerated method. See
513
+ https://github.com/cheind/torch-spherical-harmonics
514
+ for more information.
515
+
516
+ Params:
517
+ xyz: (N,...,3) tensor of points on the unit sphere
518
+
519
+ Returns:
520
+ rsh: (N,...,64) real spherical harmonics
521
+ projections of input. Ynm is found at index
522
+ `n*(n+1) + m`, with `0 <= n <= degree` and
523
+ `-n <= m <= n`.
524
+ """
525
+ x = xyz[..., 0]
526
+ y = xyz[..., 1]
527
+ z = xyz[..., 2]
528
+
529
+ x2 = x**2
530
+ y2 = y**2
531
+ z2 = z**2
532
+ xy = x * y
533
+ xz = x * z
534
+ yz = y * z
535
+ x4 = x2**2
536
+ y4 = y2**2
537
+ z4 = z2**2
538
+
539
+ return torch.stack(
540
+ [
541
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
542
+ -0.48860251190292 * y,
543
+ 0.48860251190292 * z,
544
+ -0.48860251190292 * x,
545
+ 1.09254843059208 * xy,
546
+ -1.09254843059208 * yz,
547
+ 0.94617469575756 * z2 - 0.31539156525252,
548
+ -1.09254843059208 * xz,
549
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
550
+ -0.590043589926644 * y * (3.0 * x2 - y2),
551
+ 2.89061144264055 * xy * z,
552
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
553
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
554
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
555
+ 1.44530572132028 * z * (x2 - y2),
556
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
557
+ 2.5033429417967 * xy * (x2 - y2),
558
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
559
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
560
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
561
+ 1.48099765681286
562
+ * z
563
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
564
+ - 0.952069922236839 * z2
565
+ + 0.317356640745613,
566
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
567
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
568
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
569
+ -3.75501441269506 * x2 * y2
570
+ + 0.625835735449176 * x4
571
+ + 0.625835735449176 * y4,
572
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
573
+ 8.30264925952416 * xy * z * (x2 - y2),
574
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
575
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
576
+ 0.241571547304372
577
+ * y
578
+ * (
579
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
580
+ + 9.375 * z2
581
+ - 1.875
582
+ ),
583
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
584
+ + 1.6840846433293
585
+ * z
586
+ * (
587
+ 1.75
588
+ * z
589
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
590
+ - 1.125 * z2
591
+ + 0.375
592
+ )
593
+ + 0.498988042467941 * z,
594
+ 0.241571547304372
595
+ * x
596
+ * (
597
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
598
+ + 9.375 * z2
599
+ - 1.875
600
+ ),
601
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
602
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
603
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
604
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
605
+ 4.09910463115149 * x**4 * xy
606
+ - 13.6636821038383 * xy**3
607
+ + 4.09910463115149 * xy * y**4,
608
+ -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
609
+ 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
610
+ 0.00584892228263444
611
+ * y
612
+ * (3.0 * x2 - y2)
613
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
614
+ 0.0701870673916132
615
+ * xy
616
+ * (
617
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
618
+ - 91.875 * z2
619
+ + 13.125
620
+ ),
621
+ 0.221950995245231
622
+ * y
623
+ * (
624
+ -2.8 * z * (1.5 - 7.5 * z2)
625
+ + 2.2
626
+ * z
627
+ * (
628
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
629
+ + 9.375 * z2
630
+ - 1.875
631
+ )
632
+ - 4.8 * z
633
+ ),
634
+ -1.48328138624466
635
+ * z
636
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
637
+ + 1.86469659985043
638
+ * z
639
+ * (
640
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
641
+ + 1.8
642
+ * z
643
+ * (
644
+ 1.75
645
+ * z
646
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
647
+ - 1.125 * z2
648
+ + 0.375
649
+ )
650
+ + 0.533333333333333 * z
651
+ )
652
+ + 0.953538034014426 * z2
653
+ - 0.317846011338142,
654
+ 0.221950995245231
655
+ * x
656
+ * (
657
+ -2.8 * z * (1.5 - 7.5 * z2)
658
+ + 2.2
659
+ * z
660
+ * (
661
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
662
+ + 9.375 * z2
663
+ - 1.875
664
+ )
665
+ - 4.8 * z
666
+ ),
667
+ 0.0350935336958066
668
+ * (x2 - y2)
669
+ * (
670
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
671
+ - 91.875 * z2
672
+ + 13.125
673
+ ),
674
+ 0.00584892228263444
675
+ * x
676
+ * (x2 - 3.0 * y2)
677
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
678
+ 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
679
+ -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
680
+ 0.683184105191914 * x2**3
681
+ + 10.2477615778787 * x2 * y4
682
+ - 10.2477615778787 * x4 * y2
683
+ - 0.683184105191914 * y2**3,
684
+ -0.707162732524596
685
+ * y
686
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
687
+ 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
688
+ 9.98394571852353e-5
689
+ * y
690
+ * (5197.5 - 67567.5 * z2)
691
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
692
+ 0.00239614697244565
693
+ * xy
694
+ * (x2 - y2)
695
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
696
+ 0.00397356022507413
697
+ * y
698
+ * (3.0 * x2 - y2)
699
+ * (
700
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
701
+ + 1063.125 * z2
702
+ - 118.125
703
+ ),
704
+ 0.0561946276120613
705
+ * xy
706
+ * (
707
+ -4.8 * z * (52.5 * z2 - 7.5)
708
+ + 2.6
709
+ * z
710
+ * (
711
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
712
+ - 91.875 * z2
713
+ + 13.125
714
+ )
715
+ + 48.0 * z
716
+ ),
717
+ 0.206472245902897
718
+ * y
719
+ * (
720
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
721
+ + 2.16666666666667
722
+ * z
723
+ * (
724
+ -2.8 * z * (1.5 - 7.5 * z2)
725
+ + 2.2
726
+ * z
727
+ * (
728
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
729
+ + 9.375 * z2
730
+ - 1.875
731
+ )
732
+ - 4.8 * z
733
+ )
734
+ - 10.9375 * z2
735
+ + 2.1875
736
+ ),
737
+ 1.24862677781952 * z * (1.5 * z2 - 0.5)
738
+ - 1.68564615005635
739
+ * z
740
+ * (
741
+ 1.75
742
+ * z
743
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
744
+ - 1.125 * z2
745
+ + 0.375
746
+ )
747
+ + 2.02901851395672
748
+ * z
749
+ * (
750
+ -1.45833333333333
751
+ * z
752
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
753
+ + 1.83333333333333
754
+ * z
755
+ * (
756
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
757
+ + 1.8
758
+ * z
759
+ * (
760
+ 1.75
761
+ * z
762
+ * (
763
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
764
+ - 0.666666666666667 * z
765
+ )
766
+ - 1.125 * z2
767
+ + 0.375
768
+ )
769
+ + 0.533333333333333 * z
770
+ )
771
+ + 0.9375 * z2
772
+ - 0.3125
773
+ )
774
+ - 0.499450711127808 * z,
775
+ 0.206472245902897
776
+ * x
777
+ * (
778
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
779
+ + 2.16666666666667
780
+ * z
781
+ * (
782
+ -2.8 * z * (1.5 - 7.5 * z2)
783
+ + 2.2
784
+ * z
785
+ * (
786
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
787
+ + 9.375 * z2
788
+ - 1.875
789
+ )
790
+ - 4.8 * z
791
+ )
792
+ - 10.9375 * z2
793
+ + 2.1875
794
+ ),
795
+ 0.0280973138060306
796
+ * (x2 - y2)
797
+ * (
798
+ -4.8 * z * (52.5 * z2 - 7.5)
799
+ + 2.6
800
+ * z
801
+ * (
802
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
803
+ - 91.875 * z2
804
+ + 13.125
805
+ )
806
+ + 48.0 * z
807
+ ),
808
+ 0.00397356022507413
809
+ * x
810
+ * (x2 - 3.0 * y2)
811
+ * (
812
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
813
+ + 1063.125 * z2
814
+ - 118.125
815
+ ),
816
+ 0.000599036743111412
817
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
818
+ * (-6.0 * x2 * y2 + x4 + y4),
819
+ 9.98394571852353e-5
820
+ * x
821
+ * (5197.5 - 67567.5 * z2)
822
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
823
+ 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
824
+ -0.707162732524596
825
+ * x
826
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
827
+ ],
828
+ -1,
829
+ )
830
+
831
+
832
+ # @torch.jit.script
833
+ def rsh_cart_8(xyz: torch.Tensor):
834
+ """Computes all real spherical harmonics up to degree 8.
835
+
836
+ This is an autogenerated method. See
837
+ https://github.com/cheind/torch-spherical-harmonics
838
+ for more information.
839
+
840
+ Params:
841
+ xyz: (N,...,3) tensor of points on the unit sphere
842
+
843
+ Returns:
844
+ rsh: (N,...,81) real spherical harmonics
845
+ projections of input. Ynm is found at index
846
+ `n*(n+1) + m`, with `0 <= n <= degree` and
847
+ `-n <= m <= n`.
848
+ """
849
+ x = xyz[..., 0]
850
+ y = xyz[..., 1]
851
+ z = xyz[..., 2]
852
+
853
+ x2 = x**2
854
+ y2 = y**2
855
+ z2 = z**2
856
+ xy = x * y
857
+ xz = x * z
858
+ yz = y * z
859
+ x4 = x2**2
860
+ y4 = y2**2
861
+ # z4 = z2**2
862
+ return torch.stack(
863
+ [
864
+ 0.282094791773878 * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]),
865
+ -0.48860251190292 * y,
866
+ 0.48860251190292 * z,
867
+ -0.48860251190292 * x,
868
+ 1.09254843059208 * xy,
869
+ -1.09254843059208 * yz,
870
+ 0.94617469575756 * z2 - 0.31539156525252,
871
+ -1.09254843059208 * xz,
872
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
873
+ -0.590043589926644 * y * (3.0 * x2 - y2),
874
+ 2.89061144264055 * xy * z,
875
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
876
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
877
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
878
+ 1.44530572132028 * z * (x2 - y2),
879
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
880
+ 2.5033429417967 * xy * (x2 - y2),
881
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
882
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
883
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
884
+ 1.48099765681286
885
+ * z
886
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
887
+ - 0.952069922236839 * z2
888
+ + 0.317356640745613,
889
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
890
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
891
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
892
+ -3.75501441269506 * x2 * y2
893
+ + 0.625835735449176 * x4
894
+ + 0.625835735449176 * y4,
895
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
896
+ 8.30264925952416 * xy * z * (x2 - y2),
897
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
898
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
899
+ 0.241571547304372
900
+ * y
901
+ * (
902
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
903
+ + 9.375 * z2
904
+ - 1.875
905
+ ),
906
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
907
+ + 1.6840846433293
908
+ * z
909
+ * (
910
+ 1.75
911
+ * z
912
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
913
+ - 1.125 * z2
914
+ + 0.375
915
+ )
916
+ + 0.498988042467941 * z,
917
+ 0.241571547304372
918
+ * x
919
+ * (
920
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
921
+ + 9.375 * z2
922
+ - 1.875
923
+ ),
924
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
925
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
926
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
927
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
928
+ 4.09910463115149 * x**4 * xy
929
+ - 13.6636821038383 * xy**3
930
+ + 4.09910463115149 * xy * y**4,
931
+ -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
932
+ 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
933
+ 0.00584892228263444
934
+ * y
935
+ * (3.0 * x2 - y2)
936
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
937
+ 0.0701870673916132
938
+ * xy
939
+ * (
940
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
941
+ - 91.875 * z2
942
+ + 13.125
943
+ ),
944
+ 0.221950995245231
945
+ * y
946
+ * (
947
+ -2.8 * z * (1.5 - 7.5 * z2)
948
+ + 2.2
949
+ * z
950
+ * (
951
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
952
+ + 9.375 * z2
953
+ - 1.875
954
+ )
955
+ - 4.8 * z
956
+ ),
957
+ -1.48328138624466
958
+ * z
959
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
960
+ + 1.86469659985043
961
+ * z
962
+ * (
963
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
964
+ + 1.8
965
+ * z
966
+ * (
967
+ 1.75
968
+ * z
969
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
970
+ - 1.125 * z2
971
+ + 0.375
972
+ )
973
+ + 0.533333333333333 * z
974
+ )
975
+ + 0.953538034014426 * z2
976
+ - 0.317846011338142,
977
+ 0.221950995245231
978
+ * x
979
+ * (
980
+ -2.8 * z * (1.5 - 7.5 * z2)
981
+ + 2.2
982
+ * z
983
+ * (
984
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
985
+ + 9.375 * z2
986
+ - 1.875
987
+ )
988
+ - 4.8 * z
989
+ ),
990
+ 0.0350935336958066
991
+ * (x2 - y2)
992
+ * (
993
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
994
+ - 91.875 * z2
995
+ + 13.125
996
+ ),
997
+ 0.00584892228263444
998
+ * x
999
+ * (x2 - 3.0 * y2)
1000
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
1001
+ 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
1002
+ -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
1003
+ 0.683184105191914 * x2**3
1004
+ + 10.2477615778787 * x2 * y4
1005
+ - 10.2477615778787 * x4 * y2
1006
+ - 0.683184105191914 * y2**3,
1007
+ -0.707162732524596
1008
+ * y
1009
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
1010
+ 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
1011
+ 9.98394571852353e-5
1012
+ * y
1013
+ * (5197.5 - 67567.5 * z2)
1014
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
1015
+ 0.00239614697244565
1016
+ * xy
1017
+ * (x2 - y2)
1018
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
1019
+ 0.00397356022507413
1020
+ * y
1021
+ * (3.0 * x2 - y2)
1022
+ * (
1023
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
1024
+ + 1063.125 * z2
1025
+ - 118.125
1026
+ ),
1027
+ 0.0561946276120613
1028
+ * xy
1029
+ * (
1030
+ -4.8 * z * (52.5 * z2 - 7.5)
1031
+ + 2.6
1032
+ * z
1033
+ * (
1034
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
1035
+ - 91.875 * z2
1036
+ + 13.125
1037
+ )
1038
+ + 48.0 * z
1039
+ ),
1040
+ 0.206472245902897
1041
+ * y
1042
+ * (
1043
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
1044
+ + 2.16666666666667
1045
+ * z
1046
+ * (
1047
+ -2.8 * z * (1.5 - 7.5 * z2)
1048
+ + 2.2
1049
+ * z
1050
+ * (
1051
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
1052
+ + 9.375 * z2
1053
+ - 1.875
1054
+ )
1055
+ - 4.8 * z
1056
+ )
1057
+ - 10.9375 * z2
1058
+ + 2.1875
1059
+ ),
1060
+ 1.24862677781952 * z * (1.5 * z2 - 0.5)
1061
+ - 1.68564615005635
1062
+ * z
1063
+ * (
1064
+ 1.75
1065
+ * z
1066
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
1067
+ - 1.125 * z2
1068
+ + 0.375
1069
+ )
1070
+ + 2.02901851395672
1071
+ * z
1072
+ * (
1073
+ -1.45833333333333
1074
+ * z
1075
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
1076
+ + 1.83333333333333
1077
+ * z
1078
+ * (
1079
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
1080
+ + 1.8
1081
+ * z
1082
+ * (
1083
+ 1.75
1084
+ * z
1085
+ * (
1086
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
1087
+ - 0.666666666666667 * z
1088
+ )
1089
+ - 1.125 * z2
1090
+ + 0.375
1091
+ )
1092
+ + 0.533333333333333 * z
1093
+ )
1094
+ + 0.9375 * z2
1095
+ - 0.3125
1096
+ )
1097
+ - 0.499450711127808 * z,
1098
+ 0.206472245902897
1099
+ * x
1100
+ * (
1101
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
1102
+ + 2.16666666666667
1103
+ * z
1104
+ * (
1105
+ -2.8 * z * (1.5 - 7.5 * z2)
1106
+ + 2.2
1107
+ * z
1108
+ * (
1109
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
1110
+ + 9.375 * z2
1111
+ - 1.875
1112
+ )
1113
+ - 4.8 * z
1114
+ )
1115
+ - 10.9375 * z2
1116
+ + 2.1875
1117
+ ),
1118
+ 0.0280973138060306
1119
+ * (x2 - y2)
1120
+ * (
1121
+ -4.8 * z * (52.5 * z2 - 7.5)
1122
+ + 2.6
1123
+ * z
1124
+ * (
1125
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
1126
+ - 91.875 * z2
1127
+ + 13.125
1128
+ )
1129
+ + 48.0 * z
1130
+ ),
1131
+ 0.00397356022507413
1132
+ * x
1133
+ * (x2 - 3.0 * y2)
1134
+ * (
1135
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
1136
+ + 1063.125 * z2
1137
+ - 118.125
1138
+ ),
1139
+ 0.000599036743111412
1140
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
1141
+ * (-6.0 * x2 * y2 + x4 + y4),
1142
+ 9.98394571852353e-5
1143
+ * x
1144
+ * (5197.5 - 67567.5 * z2)
1145
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
1146
+ 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
1147
+ -0.707162732524596
1148
+ * x
1149
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
1150
+ 5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
1151
+ -2.91570664069932
1152
+ * yz
1153
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
1154
+ 7.87853281621404e-6
1155
+ * (1013512.5 * z2 - 67567.5)
1156
+ * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
1157
+ 5.10587282657803e-5
1158
+ * y
1159
+ * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
1160
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
1161
+ 0.00147275890257803
1162
+ * xy
1163
+ * (x2 - y2)
1164
+ * (
1165
+ 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
1166
+ - 14293.125 * z2
1167
+ + 1299.375
1168
+ ),
1169
+ 0.0028519853513317
1170
+ * y
1171
+ * (3.0 * x2 - y2)
1172
+ * (
1173
+ -7.33333333333333 * z * (52.5 - 472.5 * z2)
1174
+ + 3.0
1175
+ * z
1176
+ * (
1177
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
1178
+ + 1063.125 * z2
1179
+ - 118.125
1180
+ )
1181
+ - 560.0 * z
1182
+ ),
1183
+ 0.0463392770473559
1184
+ * xy
1185
+ * (
1186
+ -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
1187
+ + 2.5
1188
+ * z
1189
+ * (
1190
+ -4.8 * z * (52.5 * z2 - 7.5)
1191
+ + 2.6
1192
+ * z
1193
+ * (
1194
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
1195
+ - 91.875 * z2
1196
+ + 13.125
1197
+ )
1198
+ + 48.0 * z
1199
+ )
1200
+ + 137.8125 * z2
1201
+ - 19.6875
1202
+ ),
1203
+ 0.193851103820053
1204
+ * y
1205
+ * (
1206
+ 3.2 * z * (1.5 - 7.5 * z2)
1207
+ - 2.51428571428571
1208
+ * z
1209
+ * (
1210
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
1211
+ + 9.375 * z2
1212
+ - 1.875
1213
+ )
1214
+ + 2.14285714285714
1215
+ * z
1216
+ * (
1217
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
1218
+ + 2.16666666666667
1219
+ * z
1220
+ * (
1221
+ -2.8 * z * (1.5 - 7.5 * z2)
1222
+ + 2.2
1223
+ * z
1224
+ * (
1225
+ 2.25
1226
+ * z
1227
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
1228
+ + 9.375 * z2
1229
+ - 1.875
1230
+ )
1231
+ - 4.8 * z
1232
+ )
1233
+ - 10.9375 * z2
1234
+ + 2.1875
1235
+ )
1236
+ + 5.48571428571429 * z
1237
+ ),
1238
+ 1.48417251362228
1239
+ * z
1240
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
1241
+ - 1.86581687426801
1242
+ * z
1243
+ * (
1244
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
1245
+ + 1.8
1246
+ * z
1247
+ * (
1248
+ 1.75
1249
+ * z
1250
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
1251
+ - 1.125 * z2
1252
+ + 0.375
1253
+ )
1254
+ + 0.533333333333333 * z
1255
+ )
1256
+ + 2.1808249179756
1257
+ * z
1258
+ * (
1259
+ 1.14285714285714 * z * (1.5 * z2 - 0.5)
1260
+ - 1.54285714285714
1261
+ * z
1262
+ * (
1263
+ 1.75
1264
+ * z
1265
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
1266
+ - 1.125 * z2
1267
+ + 0.375
1268
+ )
1269
+ + 1.85714285714286
1270
+ * z
1271
+ * (
1272
+ -1.45833333333333
1273
+ * z
1274
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
1275
+ + 1.83333333333333
1276
+ * z
1277
+ * (
1278
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
1279
+ + 1.8
1280
+ * z
1281
+ * (
1282
+ 1.75
1283
+ * z
1284
+ * (
1285
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
1286
+ - 0.666666666666667 * z
1287
+ )
1288
+ - 1.125 * z2
1289
+ + 0.375
1290
+ )
1291
+ + 0.533333333333333 * z
1292
+ )
1293
+ + 0.9375 * z2
1294
+ - 0.3125
1295
+ )
1296
+ - 0.457142857142857 * z
1297
+ )
1298
+ - 0.954110901614325 * z2
1299
+ + 0.318036967204775,
1300
+ 0.193851103820053
1301
+ * x
1302
+ * (
1303
+ 3.2 * z * (1.5 - 7.5 * z2)
1304
+ - 2.51428571428571
1305
+ * z
1306
+ * (
1307
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
1308
+ + 9.375 * z2
1309
+ - 1.875
1310
+ )
1311
+ + 2.14285714285714
1312
+ * z
1313
+ * (
1314
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
1315
+ + 2.16666666666667
1316
+ * z
1317
+ * (
1318
+ -2.8 * z * (1.5 - 7.5 * z2)
1319
+ + 2.2
1320
+ * z
1321
+ * (
1322
+ 2.25
1323
+ * z
1324
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
1325
+ + 9.375 * z2
1326
+ - 1.875
1327
+ )
1328
+ - 4.8 * z
1329
+ )
1330
+ - 10.9375 * z2
1331
+ + 2.1875
1332
+ )
1333
+ + 5.48571428571429 * z
1334
+ ),
1335
+ 0.0231696385236779
1336
+ * (x2 - y2)
1337
+ * (
1338
+ -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
1339
+ + 2.5
1340
+ * z
1341
+ * (
1342
+ -4.8 * z * (52.5 * z2 - 7.5)
1343
+ + 2.6
1344
+ * z
1345
+ * (
1346
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
1347
+ - 91.875 * z2
1348
+ + 13.125
1349
+ )
1350
+ + 48.0 * z
1351
+ )
1352
+ + 137.8125 * z2
1353
+ - 19.6875
1354
+ ),
1355
+ 0.0028519853513317
1356
+ * x
1357
+ * (x2 - 3.0 * y2)
1358
+ * (
1359
+ -7.33333333333333 * z * (52.5 - 472.5 * z2)
1360
+ + 3.0
1361
+ * z
1362
+ * (
1363
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
1364
+ + 1063.125 * z2
1365
+ - 118.125
1366
+ )
1367
+ - 560.0 * z
1368
+ ),
1369
+ 0.000368189725644507
1370
+ * (-6.0 * x2 * y2 + x4 + y4)
1371
+ * (
1372
+ 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
1373
+ - 14293.125 * z2
1374
+ + 1299.375
1375
+ ),
1376
+ 5.10587282657803e-5
1377
+ * x
1378
+ * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
1379
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
1380
+ 7.87853281621404e-6
1381
+ * (1013512.5 * z2 - 67567.5)
1382
+ * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
1383
+ -2.91570664069932
1384
+ * xz
1385
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
1386
+ -20.4099464848952 * x2**3 * y2
1387
+ - 20.4099464848952 * x2 * y2**3
1388
+ + 0.72892666017483 * x4**2
1389
+ + 51.0248662122381 * x4 * y4
1390
+ + 0.72892666017483 * y4**2,
1391
+ ],
1392
+ -1,
1393
+ )
1394
+
1395
+
1396
+ __all__ = [
1397
+ "rsh_cart_0",
1398
+ "rsh_cart_1",
1399
+ "rsh_cart_2",
1400
+ "rsh_cart_3",
1401
+ "rsh_cart_4",
1402
+ "rsh_cart_5",
1403
+ "rsh_cart_6",
1404
+ "rsh_cart_7",
1405
+ "rsh_cart_8",
1406
+ ]
1407
+
1408
+
1409
+ from typing import Optional
1410
+ import torch
1411
+
1412
+
1413
+ class SphHarm(torch.nn.Module):
1414
+ def __init__(self, m, n, dtype=torch.float32) -> None:
1415
+ super().__init__()
1416
+ self.dtype = dtype
1417
+ m = torch.tensor(list(range(-m + 1, m)))
1418
+ n = torch.tensor(list(range(n)))
1419
+ self.is_normalized = False
1420
+ vals = torch.cartesian_prod(m, n).T
1421
+ vals = vals[:, vals[0] <= vals[1]]
1422
+ m, n = vals.unbind(0)
1423
+
1424
+ self.register_buffer("m", tensor=m)
1425
+ self.register_buffer("n", tensor=n)
1426
+ self.register_buffer("l_max", tensor=torch.max(self.n))
1427
+
1428
+ f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre()
1429
+ self.register_buffer("f_a", tensor=f_a)
1430
+ self.register_buffer("f_b", tensor=f_b)
1431
+ self.register_buffer("d0_mask_3d", tensor=d0_mask_3d)
1432
+ self.register_buffer("d1_mask_3d", tensor=d1_mask_3d)
1433
+ self.register_buffer("initial_value", tensor=initial_value)
1434
+
1435
+ @property
1436
+ def device(self):
1437
+ return next(self.buffers()).device
1438
+
1439
+ def forward(self, points: torch.Tensor) -> torch.Tensor:
1440
+ """Computes the spherical harmonics."""
1441
+ # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi)
1442
+ B, N, D = points.shape
1443
+ dtype = points.dtype
1444
+ theta, phi = points.view(-1, D).to(self.dtype).unbind(-1)
1445
+ cos_colatitude = torch.cos(phi)
1446
+ legendre = self._gen_associated_legendre(cos_colatitude)
1447
+ vals = torch.stack([self.m.abs(), self.n], dim=0)
1448
+ vals = torch.cat(
1449
+ [
1450
+ vals.repeat(1, theta.shape[0]),
1451
+ torch.arange(theta.shape[0], device=theta.device)
1452
+ .unsqueeze(0)
1453
+ .repeat_interleave(vals.shape[1], dim=1),
1454
+ ],
1455
+ dim=0,
1456
+ )
1457
+ legendre_vals = legendre[vals[0], vals[1], vals[2]]
1458
+ legendre_vals = legendre_vals.reshape(-1, theta.shape[0])
1459
+ angle = torch.outer(self.m.abs(), theta)
1460
+ vandermonde = torch.complex(torch.cos(angle), torch.sin(angle))
1461
+ harmonics = torch.complex(
1462
+ legendre_vals * torch.real(vandermonde),
1463
+ legendre_vals * torch.imag(vandermonde),
1464
+ )
1465
+
1466
+ # Negative order.
1467
+ m = self.m.unsqueeze(-1)
1468
+ harmonics = torch.where(
1469
+ m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics
1470
+ )
1471
+ harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype)
1472
+ return harmonics
1473
+
1474
+ def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]:
1475
+ """Generates mask for recurrence relation on the remaining entries.
1476
+
1477
+ The remaining entries are with respect to the diagonal and offdiagonal
1478
+ entries.
1479
+
1480
+ Args:
1481
+ l_max: see `gen_normalized_legendre`.
1482
+ Returns:
1483
+ torch.Tensors representing the mask used by the recurrence relations.
1484
+ """
1485
+
1486
+ # Computes all coefficients.
1487
+ m_mat, l_mat = torch.meshgrid(
1488
+ torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
1489
+ torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
1490
+ indexing="ij",
1491
+ )
1492
+ if self.is_normalized:
1493
+ c0 = l_mat * l_mat
1494
+ c1 = m_mat * m_mat
1495
+ c2 = 2.0 * l_mat
1496
+ c3 = (l_mat - 1.0) * (l_mat - 1.0)
1497
+ d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
1498
+ d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
1499
+ else:
1500
+ d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
1501
+ d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
1502
+
1503
+ d0_mask_indices = torch.triu_indices(self.l_max + 1, 1)
1504
+ d1_mask_indices = torch.triu_indices(self.l_max + 1, 2)
1505
+
1506
+ d_zeros = torch.zeros(
1507
+ (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
1508
+ )
1509
+ d_zeros[d0_mask_indices] = d0[d0_mask_indices]
1510
+ d0_mask = d_zeros
1511
+
1512
+ d_zeros = torch.zeros(
1513
+ (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
1514
+ )
1515
+ d_zeros[d1_mask_indices] = d1[d1_mask_indices]
1516
+ d1_mask = d_zeros
1517
+
1518
+ # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
1519
+ i = torch.arange(self.l_max + 1, device=self.device)[:, None, None]
1520
+ j = torch.arange(self.l_max + 1, device=self.device)[None, :, None]
1521
+ k = torch.arange(self.l_max + 1, device=self.device)[None, None, :]
1522
+ mask = (i + j - k == 0).to(self.dtype)
1523
+ d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask)
1524
+ d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask)
1525
+ return (d0_mask_3d, d1_mask_3d)
1526
+
1527
+ def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
1528
+ coeff_0 = self.d0_mask_3d[i]
1529
+ coeff_1 = self.d1_mask_3d[i]
1530
+ h = torch.einsum(
1531
+ "ij,ijk->ijk",
1532
+ coeff_0,
1533
+ torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x),
1534
+ ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1))
1535
+ p_val = p_val + h
1536
+ return p_val
1537
+
1538
+ def _init_legendre(self):
1539
+ a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device)
1540
+ b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device)
1541
+ if self.is_normalized:
1542
+ # The initial value p(0,0).
1543
+ initial_value: torch.Tensor = torch.tensor(
1544
+ 0.5 / (torch.pi**0.5), device=self.device
1545
+ )
1546
+ f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0)
1547
+ f_b = torch.sqrt(2.0 * b_idx + 3.0)
1548
+ else:
1549
+ # The initial value p(0,0).
1550
+ initial_value = torch.tensor(1.0, device=self.device)
1551
+ f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0)
1552
+ f_b = 2.0 * b_idx + 1.0
1553
+
1554
+ d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask()
1555
+ return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d
1556
+
1557
+ def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor:
1558
+ r"""Computes associated Legendre functions (ALFs) of the first kind.
1559
+
1560
+ The ALFs of the first kind are used in spherical harmonics. The spherical
1561
+ harmonic of degree `l` and order `m` can be written as
1562
+ `Y_l^m(ΞΈ, Ο†) = N_l^m * P_l^m(cos(ΞΈ)) * exp(i m Ο†)`, where `N_l^m` is the
1563
+ normalization factor and ΞΈ and Ο† are the colatitude and longitude,
1564
+ repectively. `N_l^m` is chosen in the way that the spherical harmonics form
1565
+ a set of orthonormal basis function of L^2(S^2). For the computational
1566
+ efficiency of spherical harmonics transform, the normalization factor is
1567
+ used in the computation of the ALFs. In addition, normalizing `P_l^m`
1568
+ avoids overflow/underflow and achieves better numerical stability. Three
1569
+ recurrence relations are used in the computation.
1570
+
1571
+ Args:
1572
+ l_max: The maximum degree of the associated Legendre function. Both the
1573
+ degrees and orders are `[0, 1, 2, ..., l_max]`.
1574
+ x: A vector of type `float32`, `float64` containing the sampled points in
1575
+ spherical coordinates, at which the ALFs are computed; `x` is essentially
1576
+ `cos(ΞΈ)`. For the numerical integration used by the spherical harmonics
1577
+ transforms, `x` contains the quadrature points in the interval of
1578
+ `[-1, 1]`. There are several approaches to provide the quadrature points:
1579
+ Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
1580
+ method (`scipy.special.roots_chebyu`), and Driscoll & Healy
1581
+ method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
1582
+ transforms and convolutions on the 2-sphere." Advances in applied
1583
+ mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
1584
+ points are nearly equal-spaced along ΞΈ and provide exact discrete
1585
+ orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
1586
+ operation, `W` is a diagonal matrix containing the quadrature weights,
1587
+ and `I` is the identity matrix. The Gauss-Chebyshev points are equally
1588
+ spaced, which only provide approximate discrete orthogonality. The
1589
+ Driscoll & Healy qudarture points are equally spaced and provide the
1590
+ exact discrete orthogonality. The number of sampling points is required to
1591
+ be twice as the number of frequency points (modes) in the Driscoll & Healy
1592
+ approach, which enables FFT and achieves a fast spherical harmonics
1593
+ transform.
1594
+ is_normalized: True if the associated Legendre functions are normalized.
1595
+ With normalization, `N_l^m` is applied such that the spherical harmonics
1596
+ form a set of orthonormal basis functions of L^2(S^2).
1597
+
1598
+ Returns:
1599
+ The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
1600
+ of the ALFs at `x`; the dimensions in the sequence of order, degree, and
1601
+ evalution points.
1602
+ """
1603
+ p = torch.zeros(
1604
+ (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device
1605
+ )
1606
+ p[0, 0] = self.initial_value
1607
+
1608
+ # Compute the diagonal entries p(l,l) with recurrence.
1609
+ y = torch.cumprod(
1610
+ torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0
1611
+ )
1612
+ p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y)
1613
+ # torch.diag_indices(l_max + 1)
1614
+ diag_indices = torch.stack(
1615
+ [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0
1616
+ )
1617
+ p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag
1618
+
1619
+ diag_indices = torch.stack(
1620
+ [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0
1621
+ )
1622
+
1623
+ # Compute the off-diagonal entries with recurrence.
1624
+ p_offdiag = torch.einsum(
1625
+ "ij,ij->ij",
1626
+ torch.einsum("i,j->ij", self.f_b, x),
1627
+ p[(diag_indices[0], diag_indices[1])],
1628
+ ) # p[torch.diag_indices(l_max)])
1629
+ p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = (
1630
+ p_offdiag
1631
+ )
1632
+
1633
+ # Compute the remaining entries with recurrence.
1634
+ if self.l_max > 1:
1635
+ for i in range(2, self.l_max + 1):
1636
+ p = self._recursive(i, p, x)
1637
+ return p
flash3d/unidepth/utils/visualization.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Luigi Piccinelli
3
+ Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
4
+ """
5
+
6
+ import os
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+ import matplotlib.cm
11
+ import wandb
12
+ import torch
13
+
14
+ from unidepth.utils.misc import ssi_helper
15
+
16
+
17
+ def colorize(
18
+ value: np.ndarray, vmin: float = None, vmax: float = None, cmap: str = "magma_r"
19
+ ):
20
+ # if already RGB, do nothing
21
+ if value.ndim > 2:
22
+ if value.shape[-1] > 1:
23
+ return value
24
+ value = value[..., 0]
25
+ invalid_mask = value < 0.0001
26
+ # normalize
27
+ vmin = value.min() if vmin is None else vmin
28
+ vmax = value.max() if vmax is None else vmax
29
+ value = (value - vmin) / (vmax - vmin) # vmin..vmax
30
+
31
+ # set color
32
+ cmapper = matplotlib.cm.get_cmap(cmap)
33
+ value = cmapper(value, bytes=True) # (nxmx4)
34
+ value[invalid_mask] = 0
35
+ img = value[..., :3]
36
+ return img
37
+
38
+
39
+ def image_grid(imgs: list[np.ndarray], rows: int, cols: int) -> np.ndarray:
40
+ if not len(imgs):
41
+ return None
42
+ assert len(imgs) == rows * cols
43
+ h, w = imgs[0].shape[:2]
44
+ grid = Image.new("RGB", size=(cols * w, rows * h))
45
+
46
+ for i, img in enumerate(imgs):
47
+ grid.paste(
48
+ Image.fromarray(img.astype(np.uint8)).resize(
49
+ (w, h), resample=Image.BILINEAR
50
+ ),
51
+ box=(i % cols * w, i // cols * h),
52
+ )
53
+
54
+ return np.array(grid)
55
+
56
+
57
+ def get_pointcloud_from_rgbd(
58
+ image: np.array,
59
+ depth: np.array,
60
+ mask: np.ndarray,
61
+ intrinsic_matrix: np.array,
62
+ extrinsic_matrix: np.array = None,
63
+ ):
64
+ depth = np.array(depth).squeeze()
65
+ mask = np.array(mask).squeeze()
66
+ # Mask the depth array
67
+ masked_depth = np.ma.masked_where(mask == False, depth)
68
+ # masked_depth = np.ma.masked_greater(masked_depth, 8000)
69
+ # Create idx array
70
+ idxs = np.indices(masked_depth.shape)
71
+ u_idxs = idxs[1]
72
+ v_idxs = idxs[0]
73
+ # Get only non-masked depth and idxs
74
+ z = masked_depth[~masked_depth.mask]
75
+ compressed_u_idxs = u_idxs[~masked_depth.mask]
76
+ compressed_v_idxs = v_idxs[~masked_depth.mask]
77
+ image = np.stack(
78
+ [image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], axis=-1
79
+ )
80
+
81
+ # Calculate local position of each point
82
+ # Apply vectorized math to depth using compressed arrays
83
+ cx = intrinsic_matrix[0, 2]
84
+ fx = intrinsic_matrix[0, 0]
85
+ x = (compressed_u_idxs - cx) * z / fx
86
+ cy = intrinsic_matrix[1, 2]
87
+ fy = intrinsic_matrix[1, 1]
88
+ # Flip y as we want +y pointing up not down
89
+ y = -((compressed_v_idxs - cy) * z / fy)
90
+
91
+ # # Apply camera_matrix to pointcloud as to get the pointcloud in world coords
92
+ # if extrinsic_matrix is not None:
93
+ # # Calculate camera pose from extrinsic matrix
94
+ # camera_matrix = np.linalg.inv(extrinsic_matrix)
95
+ # # Create homogenous array of vectors by adding 4th entry of 1
96
+ # # At the same time flip z as for eye space the camera is looking down the -z axis
97
+ # w = np.ones(z.shape)
98
+ # x_y_z_eye_hom = np.vstack((x, y, -z, w))
99
+ # # Transform the points from eye space to world space
100
+ # x_y_z_world = np.dot(camera_matrix, x_y_z_eye_hom)[:3]
101
+ # return x_y_z_world.T
102
+ # else:
103
+ x_y_z_local = np.stack((x, y, z), axis=-1)
104
+ return np.concatenate([x_y_z_local, image], axis=-1)
105
+
106
+
107
+ def save_file_ply(xyz, rgb, pc_file):
108
+ if rgb.max() < 1.001:
109
+ rgb = rgb * 255.0
110
+ rgb = rgb.astype(np.uint8)
111
+ # print(rgb)
112
+ with open(pc_file, "w") as f:
113
+ # headers
114
+ f.writelines(
115
+ [
116
+ "ply\n" "format ascii 1.0\n",
117
+ "element vertex {}\n".format(xyz.shape[0]),
118
+ "property float x\n",
119
+ "property float y\n",
120
+ "property float z\n",
121
+ "property uchar red\n",
122
+ "property uchar green\n",
123
+ "property uchar blue\n",
124
+ "end_header\n",
125
+ ]
126
+ )
127
+
128
+ for i in range(xyz.shape[0]):
129
+ str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format(
130
+ xyz[i, 0], xyz[i, 1], xyz[i, 2], rgb[i, 0], rgb[i, 1], rgb[i, 2]
131
+ )
132
+ f.write(str_v)
133
+
134
+
135
+ # really awful fct... FIXME
136
+ def log_train_artifacts(rgbs, gts, preds, ds_name, step, infos={}):
137
+ rgbs = [
138
+ (127.5 * (rgb + 1))
139
+ .clip(0, 255)
140
+ .to(torch.uint8)
141
+ .cpu()
142
+ .detach()
143
+ .permute(1, 2, 0)
144
+ .numpy()
145
+ for rgb in rgbs
146
+ ]
147
+
148
+ new_gts, new_preds = [], []
149
+ if len(gts) > 0:
150
+ for i, gt in enumerate(gts):
151
+ scale, shift = ssi_helper(
152
+ gts[i][gts[i] > 0].cpu().detach(), preds[i][gts[i] > 0].cpu().detach()
153
+ )
154
+ gt = gts[i].cpu().detach().squeeze().numpy()
155
+ pred = (preds[i].cpu().detach() * scale + shift).squeeze().numpy()
156
+ vmin = gt[gt > 0].min() if (gt > 0).any() else 0.0
157
+ vmax = gt.max() if (gt > 0).any() else 0.1
158
+ new_gts.append(colorize(gt, vmin=vmin, vmax=vmax))
159
+ new_preds.append(colorize(pred, vmin=vmin, vmax=vmax))
160
+ gts, preds = new_gts, new_preds
161
+ else:
162
+ preds = [
163
+ colorize(pred.cpu().detach().squeeze().numpy(), 0.0, 80.0)
164
+ for i, pred in enumerate(preds)
165
+ ]
166
+
167
+ num_additional, additionals = 0, []
168
+ for name, info in infos.items():
169
+ num_additional += 1
170
+ if info.shape[1] == 3:
171
+ additionals.extend(
172
+ [
173
+ (127.5 * (x + 1))
174
+ .clip(0, 255)
175
+ .to(torch.uint8)
176
+ .cpu()
177
+ .detach()
178
+ .permute(1, 2, 0)
179
+ .numpy()
180
+ for x in info[:4]
181
+ ]
182
+ )
183
+ else:
184
+ additionals.extend(
185
+ [
186
+ colorize(x.cpu().detach().squeeze().numpy())
187
+ for i, x in enumerate(info[:4])
188
+ ]
189
+ )
190
+
191
+ num_rows = 2 + int(len(gts) > 0) + num_additional
192
+ artifacts_grid = image_grid(
193
+ [*rgbs, *gts, *preds, *additionals], num_rows, len(rgbs)
194
+ )
195
+ try:
196
+ wandb.log({f"{ds_name}_training": [wandb.Image(artifacts_grid)]}, step=step)
197
+ except:
198
+ Image.fromarray(artifacts_grid).save(
199
+ os.path.join(os.environ["HOME"], "Workspace", f"art_grid{step}.png")
200
+ )
201
+ print("Logging training images failed")
flash3d/util/vis3d.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from jaxtyping import Float
3
+ import numpy as np
4
+ from scipy.spatial.transform import Rotation as R
5
+ from plyfile import PlyData, PlyElement
6
+ import torch
7
+ from torch import Tensor
8
+ from einops import rearrange, einsum
9
+
10
+
11
+ def construct_list_of_attributes(num_rest: int) -> list[str]:
12
+ attributes = ["x", "y", "z", "nx", "ny", "nz"]
13
+ for i in range(3):
14
+ attributes.append(f"f_dc_{i}")
15
+ for i in range(num_rest):
16
+ attributes.append(f"f_rest_{i}")
17
+ attributes.append("opacity")
18
+ for i in range(3):
19
+ attributes.append(f"scale_{i}")
20
+ for i in range(4):
21
+ attributes.append(f"rot_{i}")
22
+ return attributes
23
+
24
+
25
+ def export_ply(
26
+ means: Float[Tensor, "gaussian 3"],
27
+ scales: Float[Tensor, "gaussian 3"],
28
+ rotations: Float[Tensor, "gaussian 4"],
29
+ harmonics: Float[Tensor, "gaussian 3 d_sh"],
30
+ opacities: Float[Tensor, "gaussian"],
31
+ path: Path,
32
+ ):
33
+ path = Path(path)
34
+ # Shift the scene so that the median Gaussian is at the origin.
35
+ means = means - means.median(dim=0).values
36
+
37
+ # Rescale the scene so that most Gaussians are within range [-1, 1].
38
+ scale_factor = means.abs().quantile(0.95, dim=0).max()
39
+ means = means / scale_factor
40
+ scales = scales / scale_factor
41
+ scales = scales * 4.0
42
+ scales = torch.clamp(scales, 0, 0.0075)
43
+
44
+ # Define a rotation that makes +Z be the world up vector.
45
+ # rotation = [
46
+ # [0, 0, 1],
47
+ # [-1, 0, 0],
48
+ # [0, -1, 0],
49
+ # ]
50
+ rotation = [
51
+ [1, 0, 0],
52
+ [0, 1, 0],
53
+ [0, 0, 1],
54
+ ]
55
+ rotation = torch.tensor(rotation, dtype=torch.float32, device=means.device)
56
+
57
+ # The Polycam viewer seems to start at a 45 degree angle. Since we want to be
58
+ # looking directly at the object, we compose a 45 degree rotation onto the above
59
+ # rotation.
60
+ # adjustment = torch.tensor(
61
+ # R.from_rotvec([0, 0, -45], True).as_matrix(),
62
+ # dtype=torch.float32,
63
+ # device=means.device,
64
+ # )
65
+ # rotation = adjustment @ rotation
66
+
67
+ # We also want to see the scene in camera space (as the default view). We therefore
68
+ # compose the w2c rotation onto the above rotation.
69
+ # rotation = rotation @ extrinsics[:3, :3].inverse()
70
+
71
+ # Apply the rotation to the means (Gaussian positions).
72
+ means = einsum(rotation, means, "i j, ... j -> ... i")
73
+
74
+ # Apply the rotation to the Gaussian rotations.
75
+ rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
76
+ rotations = rotation.detach().cpu().numpy() @ rotations
77
+ rotations = R.from_matrix(rotations).as_quat()
78
+ x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
79
+ rotations = np.stack((w, x, y, z), axis=-1)
80
+
81
+ # Since our axes are swizzled for the spherical harmonics, we only export the DC
82
+ # band.
83
+ harmonics_view_invariant = harmonics
84
+
85
+ dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)]
86
+ elements = np.empty(means.shape[0], dtype=dtype_full)
87
+ attributes = (
88
+ means.detach().cpu().numpy(),
89
+ torch.zeros_like(means).detach().cpu().numpy(),
90
+ harmonics_view_invariant.detach().cpu().contiguous().numpy(),
91
+ opacities.detach().cpu().numpy(),
92
+ scales.log().detach().cpu().numpy(),
93
+ rotations,
94
+ )
95
+ attributes = np.concatenate(attributes, axis=1)
96
+ elements[:] = list(map(tuple, attributes))
97
+ path.parent.mkdir(exist_ok=True, parents=True)
98
+ PlyData([PlyElement.describe(elements, "vertex")]).write(path)
99
+
100
+
101
+ def save_ply(outputs, path, num_gauss=3):
102
+ pad = 32
103
+
104
+ def crop_r(t):
105
+ h, w = 256, 384
106
+ H = h + pad * 2
107
+ W = w + pad * 2
108
+ t = rearrange(t, "b c (h w) -> b c h w", h=H, w=W)
109
+ t = t[..., pad:H-pad, pad:W-pad]
110
+ t = rearrange(t, "b c h w -> b c (h w)")
111
+ return t
112
+
113
+ def crop(t):
114
+ h, w = 256, 384
115
+ H = h + pad * 2
116
+ W = w + pad * 2
117
+ t = t[..., pad:H-pad, pad:W-pad]
118
+ return t
119
+
120
+ # import pdb
121
+ # pdb.set_trace()
122
+ means = rearrange(crop_r(outputs[('gauss_means', 0, 0)]), "(b v) c n -> b (v n) c", v=num_gauss)[0, :, :3]
123
+ scales = rearrange(crop(outputs[('gauss_scaling', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
124
+ rotations = rearrange(crop(outputs[('gauss_rotation', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
125
+ opacities = rearrange(crop(outputs[('gauss_opacity', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
126
+ harmonics = rearrange(crop(outputs[('gauss_features_dc', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
127
+
128
+ export_ply(
129
+ means,
130
+ scales,
131
+ rotations,
132
+ harmonics,
133
+ opacities,
134
+ path
135
+ )