diff --git a/mono/configs/HourglassDecoder/convlarge.0.3_150.py b/mono/configs/HourglassDecoder/convlarge.0.3_150.py new file mode 100644 index 0000000000000000000000000000000000000000..37b91c80284d6db3df3017ec636f18198e42dc08 --- /dev/null +++ b/mono/configs/HourglassDecoder/convlarge.0.3_150.py @@ -0,0 +1,25 @@ +_base_=[ + '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +model = dict( + backbone=dict( + pretrained=False, + ) +) + +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + img_size=(512, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.3, 150), + crop_size = (544, 1216), +) + +batchsize_per_gpu = 2 +thread_per_gpu = 4 diff --git a/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py b/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py new file mode 100644 index 0000000000000000000000000000000000000000..cdd9156b7f2f0921fb01b1adaf9a2a7447332d6e --- /dev/null +++ b/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py @@ -0,0 +1,25 @@ +_base_=[ + '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +model = dict( + backbone=dict( + pretrained=False, + ) +) + +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + img_size=(512, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.3, 150), + crop_size = (512, 1088), +) + +batchsize_per_gpu = 2 +thread_per_gpu = 4 diff --git a/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py b/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py new file mode 100644 index 0000000000000000000000000000000000000000..6601f5cdfad07c5fad8b89fbf959e67039126dfa --- /dev/null +++ b/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py @@ -0,0 +1,25 @@ +_base_=[ + '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +model = dict( + backbone=dict( + pretrained=False, + ) +) + +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + img_size=(512, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.3, 150), + crop_size = (480, 1216), +) + +batchsize_per_gpu = 2 +thread_per_gpu = 4 diff --git a/mono/configs/HourglassDecoder/vit.raft5.large.py b/mono/configs/HourglassDecoder/vit.raft5.large.py new file mode 100644 index 0000000000000000000000000000000000000000..4febdcb2867513008496f394ce8dc513230fb480 --- /dev/null +++ b/mono/configs/HourglassDecoder/vit.raft5.large.py @@ -0,0 +1,33 @@ +_base_=[ + '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +import numpy as np +model=dict( + decode_head=dict( + type='RAFTDepthNormalDPT5', + iters=8, + n_downsample=2, + detach=False, + ) +) + + +max_value = 200 +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + # img_size=(540, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.1, max_value), + crop_size = (616, 1064), # %28 = 0 + clip_depth_range=(0.1, 200), + vit_size=(616,1064) +) + +batchsize_per_gpu = 1 +thread_per_gpu = 1 diff --git a/mono/configs/HourglassDecoder/vit.raft5.small.py b/mono/configs/HourglassDecoder/vit.raft5.small.py new file mode 100644 index 0000000000000000000000000000000000000000..25eb68cc151f090c7654b7ebbcaf9dfc6a478570 --- /dev/null +++ b/mono/configs/HourglassDecoder/vit.raft5.small.py @@ -0,0 +1,33 @@ +_base_=[ + '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +import numpy as np +model=dict( + decode_head=dict( + type='RAFTDepthNormalDPT5', + iters=4, + n_downsample=2, + detach=False, + ) +) + + +max_value = 200 +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + # img_size=(540, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.1, max_value), + crop_size = (616, 1064), # %28 = 0 + clip_depth_range=(0.1, 200), + vit_size=(616,1064) +) + +batchsize_per_gpu = 1 +thread_per_gpu = 1 diff --git a/mono/configs/__init__.py b/mono/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/mono/configs/__init__.py @@ -0,0 +1 @@ + diff --git a/mono/configs/_base_/_data_base_.py b/mono/configs/_base_/_data_base_.py new file mode 100644 index 0000000000000000000000000000000000000000..35f3844f24191b6b9452e136ea3205b7622466d7 --- /dev/null +++ b/mono/configs/_base_/_data_base_.py @@ -0,0 +1,13 @@ +# canonical camera setting and basic data setting +# we set it same as the E300 camera (crop version) +# +data_basic=dict( + canonical_space = dict( + img_size=(540, 960), + focal_length=1196.0, + ), + depth_range=(0.9, 150), + depth_normalize=(0.006, 1.001), + crop_size = (512, 960), + clip_depth_range=(0.9, 150), +) diff --git a/mono/configs/_base_/datasets/_data_base_.py b/mono/configs/_base_/datasets/_data_base_.py new file mode 100644 index 0000000000000000000000000000000000000000..b554444e9b75b4519b862e726890dcf7859be0ec --- /dev/null +++ b/mono/configs/_base_/datasets/_data_base_.py @@ -0,0 +1,12 @@ +# canonical camera setting and basic data setting +# +data_basic=dict( + canonical_space = dict( + img_size=(540, 960), + focal_length=1196.0, + ), + depth_range=(0.9, 150), + depth_normalize=(0.006, 1.001), + crop_size = (512, 960), + clip_depth_range=(0.9, 150), +) diff --git a/mono/configs/_base_/default_runtime.py b/mono/configs/_base_/default_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..a690b491bf50aad5c2fd7e9ac387609123a4594a --- /dev/null +++ b/mono/configs/_base_/default_runtime.py @@ -0,0 +1,4 @@ + +load_from = None +cudnn_benchmark = True +test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3','rmse_log', 'log10', 'sq_rel'] diff --git a/mono/configs/_base_/models/backbones/convnext_large.py b/mono/configs/_base_/models/backbones/convnext_large.py new file mode 100644 index 0000000000000000000000000000000000000000..5a22f7e1b53ca154bfae1672e6ee3b52028039b9 --- /dev/null +++ b/mono/configs/_base_/models/backbones/convnext_large.py @@ -0,0 +1,16 @@ +#_base_ = ['./_model_base_.py',] + +#'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth' +model = dict( + #type='EncoderDecoderAuxi', + backbone=dict( + type='convnext_large', + pretrained=True, + in_22k=True, + out_indices=[0, 1, 2, 3], + drop_path_rate=0.4, + layer_scale_init_value=1.0, + checkpoint='data/pretrained_weight_repo/convnext/convnext_large_22k_1k_384.pth', + prefix='backbones.', + out_channels=[192, 384, 768, 1536]), + ) diff --git a/mono/configs/_base_/models/backbones/dino_vit_large.py b/mono/configs/_base_/models/backbones/dino_vit_large.py new file mode 100644 index 0000000000000000000000000000000000000000..843178ed6e61d74070b971f01148f87fdf2a62cf --- /dev/null +++ b/mono/configs/_base_/models/backbones/dino_vit_large.py @@ -0,0 +1,7 @@ +model = dict( + backbone=dict( + type='vit_large', + prefix='backbones.', + out_channels=[1024, 1024, 1024, 1024], + drop_path_rate = 0.0), + ) diff --git a/mono/configs/_base_/models/backbones/dino_vit_large_reg.py b/mono/configs/_base_/models/backbones/dino_vit_large_reg.py new file mode 100644 index 0000000000000000000000000000000000000000..25e96747d459d42df299f8a6a1e14044a0e56164 --- /dev/null +++ b/mono/configs/_base_/models/backbones/dino_vit_large_reg.py @@ -0,0 +1,7 @@ +model = dict( + backbone=dict( + type='vit_large_reg', + prefix='backbones.', + out_channels=[1024, 1024, 1024, 1024], + drop_path_rate = 0.0), + ) diff --git a/mono/configs/_base_/models/backbones/dino_vit_small_reg.py b/mono/configs/_base_/models/backbones/dino_vit_small_reg.py new file mode 100644 index 0000000000000000000000000000000000000000..0c8bd97dccb9cdee7517250f40e01bb3124144e6 --- /dev/null +++ b/mono/configs/_base_/models/backbones/dino_vit_small_reg.py @@ -0,0 +1,7 @@ +model = dict( + backbone=dict( + type='vit_small_reg', + prefix='backbones.', + out_channels=[384, 384, 384, 384], + drop_path_rate = 0.0), + ) diff --git a/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py b/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f262288c49e7ffccb6174b09b0daf80ff79dd684 --- /dev/null +++ b/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py @@ -0,0 +1,10 @@ +# model settings +_base_ = ['../backbones/convnext_large.py',] +model = dict( + type='DensePredModel', + decode_head=dict( + type='HourglassDecoder', + in_channels=[192, 384, 768, 1536], + decoder_channel=[128, 128, 256, 512], + prefix='decode_heads.'), +) diff --git a/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py b/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py new file mode 100644 index 0000000000000000000000000000000000000000..bd69efefab2c03de435996c6b7b65ff941db1e5d --- /dev/null +++ b/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py @@ -0,0 +1,20 @@ +# model settings +_base_ = ['../backbones/dino_vit_large.py'] +model = dict( + type='DensePredModel', + decode_head=dict( + type='RAFTDepthDPT', + in_channels=[1024, 1024, 1024, 1024], + use_cls_token=True, + feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14] + decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14] + up_scale = 7, + hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536] + n_gru_layers=3, + n_downsample=2, + iters=12, + slow_fast_gru=True, + corr_radius=4, + corr_levels=4, + prefix='decode_heads.'), +) diff --git a/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py b/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py new file mode 100644 index 0000000000000000000000000000000000000000..26ab6dc090e9cdb840d84fab10587becb536dbb8 --- /dev/null +++ b/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py @@ -0,0 +1,19 @@ +# model settings +_base_ = ['../backbones/dino_vit_large_reg.py'] +model = dict( + type='DensePredModel', + decode_head=dict( + type='RAFTDepthDPT', + in_channels=[1024, 1024, 1024, 1024], + use_cls_token=True, + feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14] + decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14] + up_scale = 7, + hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536] + n_gru_layers=3, + n_downsample=2, + iters=3, + slow_fast_gru=True, + num_register_tokens=4, + prefix='decode_heads.'), +) diff --git a/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py b/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py new file mode 100644 index 0000000000000000000000000000000000000000..19466c191e9f2a83903e55ca4fc0827d9a11bcb9 --- /dev/null +++ b/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py @@ -0,0 +1,19 @@ +# model settings +_base_ = ['../backbones/dino_vit_small_reg.py'] +model = dict( + type='DensePredModel', + decode_head=dict( + type='RAFTDepthDPT', + in_channels=[384, 384, 384, 384], + use_cls_token=True, + feature_channels = [96, 192, 384, 768], # [2/7, 1/7, 1/14, 1/14] + decoder_channels = [48, 96, 192, 384, 384], # [-, 1/4, 1/7, 1/14, 1/14] + up_scale = 7, + hidden_channels=[48, 48, 48, 48], # [x_4, x_8, x_16, x_32] [1/4, 1/7, 1/14, -] + n_gru_layers=3, + n_downsample=2, + iters=3, + slow_fast_gru=True, + num_register_tokens=4, + prefix='decode_heads.'), +) diff --git a/mono/model/__init__.py b/mono/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1ea3d3e3b880e28ef880083b3c79e3b00cd119 --- /dev/null +++ b/mono/model/__init__.py @@ -0,0 +1,5 @@ +from .monodepth_model import DepthModel +# from .__base_model__ import BaseDepthModel + + +__all__ = ['DepthModel', 'BaseDepthModel'] diff --git a/mono/model/__pycache__/__init__.cpython-39.pyc b/mono/model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c9c860c14219cf199bdb577cb7e0e6dd7e5eadb Binary files /dev/null and b/mono/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/mono/model/__pycache__/monodepth_model.cpython-39.pyc b/mono/model/__pycache__/monodepth_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd965a942a758e150ac2ca0854800bce82b83f14 Binary files /dev/null and b/mono/model/__pycache__/monodepth_model.cpython-39.pyc differ diff --git a/mono/model/backbones/ConvNeXt.py b/mono/model/backbones/ConvNeXt.py new file mode 100644 index 0000000000000000000000000000000000000000..f1c4be0e6463ae2b0dda6d20fc273a300afa5ebf --- /dev/null +++ b/mono/model/backbones/ConvNeXt.py @@ -0,0 +1,271 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_, DropPath +from timm.models.registry import register_model + +class Block(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + +class ConvNeXt(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - + https://arxiv.org/pdf/2201.03545.pdf + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + def __init__(self, in_chans=3, num_classes=1000, + depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., + layer_scale_init_value=1e-6, head_init_scale=1., + **kwargs,): + super().__init__() + + self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first") + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.Sequential( + *[Block(dim=dims[i], drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + #self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + #self.head = nn.Linear(dims[-1], num_classes) + + self.apply(self._init_weights) + #self.head.weight.data.mul_(head_init_scale) + #self.head.bias.data.mul_(head_init_scale) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + features = [] + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + features.append(x) + return features # global average pooling, (N, C, H, W) -> (N, C) + + def forward(self, x): + #x = self.forward_features(x) + #x = self.head(x) + features = self.forward_features(x) + return features + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +model_urls = { + "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", + "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", + "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", + "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", + "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", + "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", + "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", + "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", + "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", +} + +def convnext_tiny(pretrained=True,in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) + if pretrained: + checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") + #url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) + model_dict = model.state_dict() + pretrained_dict = {} + unmatched_pretrained_dict = {} + for k, v in checkpoint['model'].items(): + if k in model_dict: + pretrained_dict[k] = v + else: + unmatched_pretrained_dict[k] = v + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + print( + 'Successfully loaded pretrained %d params, and %d paras are unmatched.' + %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) + print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) + return model + +def convnext_small(pretrained=True,in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) + if pretrained: + checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") + #url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model_dict = model.state_dict() + pretrained_dict = {} + unmatched_pretrained_dict = {} + for k, v in checkpoint['model'].items(): + if k in model_dict: + pretrained_dict[k] = v + else: + unmatched_pretrained_dict[k] = v + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + print( + 'Successfully loaded pretrained %d params, and %d paras are unmatched.' + %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) + print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) + return model + +def convnext_base(pretrained=True, in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + if pretrained: + checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") + #url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model_dict = model.state_dict() + pretrained_dict = {} + unmatched_pretrained_dict = {} + for k, v in checkpoint['model'].items(): + if k in model_dict: + pretrained_dict[k] = v + else: + unmatched_pretrained_dict[k] = v + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + print( + 'Successfully loaded pretrained %d params, and %d paras are unmatched.' + %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) + print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) + return model + +def convnext_large(pretrained=True, in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + if pretrained: + checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") + #url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model_dict = model.state_dict() + pretrained_dict = {} + unmatched_pretrained_dict = {} + for k, v in checkpoint['model'].items(): + if k in model_dict: + pretrained_dict[k] = v + else: + unmatched_pretrained_dict[k] = v + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + print( + 'Successfully loaded pretrained %d params, and %d paras are unmatched.' + %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) + print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) + return model + +def convnext_xlarge(pretrained=True, in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) + if pretrained: + assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" + checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") + #url = model_urls['convnext_xlarge_22k'] + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model_dict = model.state_dict() + pretrained_dict = {} + unmatched_pretrained_dict = {} + for k, v in checkpoint['model'].items(): + if k in model_dict: + pretrained_dict[k] = v + else: + unmatched_pretrained_dict[k] = v + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + print( + 'Successfully loaded pretrained %d params, and %d paras are unmatched.' + %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) + print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) + return model + +if __name__ == '__main__': + import torch + model = convnext_base(True, in_22k=False).cuda() + + rgb = torch.rand((2, 3, 256, 256)).cuda() + out = model(rgb) + print(len(out)) + for i, ft in enumerate(out): + print(i, ft.shape) diff --git a/mono/model/backbones/ViT_DINO.py b/mono/model/backbones/ViT_DINO.py new file mode 100644 index 0000000000000000000000000000000000000000..5a1998f0dd5024fbe69895e244fc054245a06568 --- /dev/null +++ b/mono/model/backbones/ViT_DINO.py @@ -0,0 +1,1504 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List + +import torch +import torch.nn as nn +from torch import Tensor +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + +class ConvBlock(nn.Module): + def __init__(self, channels): + super(ConvBlock, self).__init__() + + self.act = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d( + channels, + channels, + kernel_size=3, + stride=1, + padding=1 + ) + self.norm1 = nn.BatchNorm2d(channels) + self.conv2 = nn.Conv2d( + channels, + channels, + kernel_size=3, + stride=1, + padding=1 + ) + self.norm2 = nn.BatchNorm2d(channels) + + def forward(self, x): + + out = self.norm1(x) + out = self.act(out) + out = self.conv1(out) + out = self.norm2(out) + out = self.act(out) + out = self.conv2(out) + return x + out + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + from xformers.components.attention import ScaledDotProduct + from xformers.components import MultiHeadDispatch + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + window_size: int = 0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + #if not self.training: + # + # self.attn = ScaledDotProduct() + #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + if attn_bias is not None: + attn = attn + attn_bias[:, :, :N] + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + #if True: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x, attn_bias) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + if attn_bias is not None: + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N]) + else: + x = memory_efficient_attention(q, k, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values = None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + def attn_residual_func(x: Tensor, attn_bias) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + attn_bias=attn_bias + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, attn_bias)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, attn_bias) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, attn_bias=None +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset, attn_bias) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list, attn_bias=None): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list, attn_bias) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x, others=None): + for b in self: + if others == None: + x = b(x) + else: + x = b(x, others) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + #init_values=None, # for layerscale: None or 0 => no layerscale + init_values=1e-5, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=NestedTensorBlock, + ffn_layer="mlp", + block_chunks=1, + window_size=37, + **kwargs + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.window_size = window_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + B, C, H, W = x.size() + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) + if pad_h + pad_w > 0: + x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') + + x = self.prepare_tokens_with_masks(x, masks) + + features = [] + for blk in self.blocks: + x = blk(x) + # for idx in range(len(self.blocks[0])): + # x = self.blocks[0][idx](x) + # if (idx + 1) % (len(self.blocks[0]) // 4) == 0: + # features.append(x) + + #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + x_norm = self.norm(x) + # return { + # "x_norm_clstoken": x_norm[:, 0], + # "x_norm_patchtokens": x_norm[:, 1:], + # "x_prenorm": x, + # "masks": masks, + # } + features = [] + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + # if is_training: + # return ret + # else: + # return self.head(ret["x_norm_clstoken"]) + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim, 37, stride, 18, bias=True, groups=embed_dim), + ) + self.stride = stride + + def forward(self, x, size): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + #def no_weight_decay(self): + #return ['proj.%d.weight' % i for i in range(4)] + +class DinoWindowVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + #init_values=None, # for layerscale: None or 0 => no layerscale + init_values=1e-5, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=NestedTensorBlock, + ffn_layer="mlp", + block_chunks=1, + window_size=7, + **kwargs + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + #self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + #self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + + self.pos_conv = PosConv(self.embed_dim, self.embed_dim) + + self.window_size = window_size + #self.conv_block = nn.ModuleList([ConvBlock(embed_dim) for i in range(4)]) + #self.conv_block = nn.ModuleList([nn.Identity() for i in range(4)]) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.nh = -1 + self.nw = -1 + try: + H = cfg.data_basic['crop_size'][0] + W = cfg.data_basic['crop_size'][1] + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + self.nh = (H + pad_h) // self.patch_size + self.nw = (W + pad_w) // self.patch_size + self.prepare_attn_bias((self.nh, self.nw)) + except: + pass + self.init_weights() + + self.total_step = 10000 # For PE -> GPE transfer + self.start_step = 2000 + self.current_step = 20000 + + def init_weights(self): + #trunc_normal_(self.pos_embed, std=0.02) + #nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + for i in range(4): + try: + nn.init.constant_(self.conv_block[i].conv2.weight, 0.0) + except: + pass + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + #npatch = x.shape[1] - 1 + #N = self.pos_embed.shape[1] - 1 + npatch = x.shape[1] + N = self.pos_embed.shape[1] + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + #class_pos_embed = pos_embed[:, 0] + #patch_pos_embed = pos_embed[:, 1:] + patch_pos_embed = pos_embed + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed.to(previous_dtype) + #return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def window_partition(self, x: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + if conv_feature == False: + B, N, C = x.shape + H, W = hw[0], hw[1] + + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C) + else: + B, C, H, W = x.shape + + x = x.view(B, C, H // window_size, window_size, W // window_size, window_size) + + windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window_size * window_size, C) + + #y = torch.cat((x_cls, windows), dim=1) + return windows #, (Hp, Wp) + + + def window_unpartition(self, + windows: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False + ) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + H, W = hw + + B = windows.shape[0] // (H * W // window_size // window_size) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + + if conv_feature == False: + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp * Wp, -1) + else: + C = windows.shape[-1] + x = x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W) + + # if Hp > H or Wp > W: + # x = x[:, :H, :W, :].contiguous() + return x + + def prepare_tokens_with_masks(self, x, masks=None, step=-1): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + #x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if step == -1: + step = self.current_step + else: + self.current_step = step + + if step < self.start_step: + coef = 0.0 + elif step < self.total_step: + coef = (step - self.start_step) / (self.total_step - self.start_step) + else: + coef = 1.0 + + x = x + (1 - coef) * self.interpolate_pos_encoding(x, w, h) + coef * self.pos_conv(x, (self.nh, self.nw)) + + return x + + def prepare_attn_bias(self, shape): + window_size = self.window_size + if window_size <= 0: + return + + import xformers.components.attention.attention_patterns as AP + + nh, nw = shape + radius = (window_size-1)//2 + mask_ori = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() + + pad = (8 - (nh * nw) % 8) + if pad == 8: + pad = 0 + mask_pad = nn.functional.pad(mask_ori, (0, pad)).contiguous() + if pad > 0: + mask = mask_pad[:, :-pad].view(nh, nw, nh, nw) + else: + mask = mask_pad[:, :].view(nh, nw, nh, nw) + + # angle + mask[:radius+1, :radius+1, :window_size, :window_size] = True + mask[:radius+1, -radius-1:, :window_size, -window_size:] = True + mask[-radius-1:, :radius+1, -window_size:, :window_size] = True + mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True + + # edge + mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :] + mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :] + mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :] + mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :] + + mask = mask.view(nh*nw, nh*nw) + bias_pad = torch.log(mask_pad) + #bias = bias_pad[:, :-pad] + self.register_buffer('attn_bias', bias_pad) + + return bias_pad + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None, **kwargs): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + B, C, H, W = x.size() + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) + if pad_h + pad_w > 0: + x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') + + nh = (H+pad_h)//self.patch_size + nw = (W+pad_w)//self.patch_size + + if self.window_size > 0: + if nh == self.nh and nw == self.nw: + attn_bias = self.attn_bias + else: + attn_bias = self.prepare_attn_bias(((H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size)) + self.nh = nh + self.nw = nw + attn_bias = attn_bias.unsqueeze(0).repeat(B * self.num_heads, 1, 1) + else: + attn_bias = None + + x = self.prepare_tokens_with_masks(x, masks) + #x = self.patch_embed(x) + + features = [] + #x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size)) + for blk in self.blocks: + x = blk(x, attn_bias) + #x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size)) + + # for idx in range(len(self.blocks[0])): + # x = self.blocks[0][idx](x, attn_bias) + + # if (idx + 1) % (len(self.blocks[0]) // 4) == 0: + # x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True) + # x = self.conv_block[idx // (len(self.blocks[0]) // 4)](x) + # if idx + 1 != len(self.blocks[0]): + # x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True) + # else: + # b, c, h, w = x.size() + # x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, c) + #features.append(x) + + #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + x_norm = self.norm(x) + # return { + # "x_norm_clstoken": x_norm[:, 0], + # "x_norm_patchtokens": x_norm[:, 1:], + # "x_prenorm": x, + # "masks": masks, + # } + features = [] + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + # if is_training: + # return ret + # else: + # return self.head(ret["x_norm_clstoken"]) + + + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=14, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=14, **kwargs): + model = DinoWindowVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=14, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + img_size = 518, + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + try: + model.load_state_dict(state_dict, strict=True) + except: + new_state_dict = {} + for key, value in state_dict.items(): + if 'blocks' in key: + key_new = 'blocks.0' + key[len('blocks'):] + else: + key_new = key + new_state_dict[key_new] = value + + model.load_state_dict(new_state_dict, strict=True) + #del model.norm + del model.mask_token + return model + + # model = DinoWindowVisionTransformer( + # img_size = 518, + # patch_size=patch_size, + # embed_dim=1024, + # depth=24, + # num_heads=16, + # mlp_ratio=4, + # block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + # window_size=37, + # **kwargs, + # ) + + # if checkpoint is not None: + # with open(checkpoint, "rb") as f: + # state_dict = torch.load(f) + # try: + # model.load_state_dict(state_dict, strict=True) + # except: + # new_state_dict = {} + # for key, value in state_dict.items(): + # if 'blocks' in key: + # key_new = 'blocks.0' + key[len('blocks'):] + # else: + # key_new = key + # if 'pos_embed' in key: + # value = value[:, 1:, :] + # new_state_dict[key_new] = value + + # model.load_state_dict(new_state_dict, strict=False) + # #del model.norm + # del model.mask_token + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + +if __name__ == '__main__': + try: + from mmcv.utils import Config + except: + from mmengine import Config + + #rgb = torch.rand((2, 3, 518, 518)).cuda() + + #cfg.data_basic['crop_size']['0'] + #cfg.data_basic['crop_size']['1'] + cfg = Config.fromfile('/cpfs01/user/mu.hu/monodepth/mono/configs/HourglassDecoder/pub12.convlarge.0.3_150.py') + + #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036) + rgb = torch.zeros(1, 3, 1400, 1680).cuda() + model = vit_large(checkpoint="/cpfs02/shared/public/custom/group_local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth", kwarg=cfg).cuda() + + #import timm + #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda() + #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn) + + out1 = model(rgb) + #out2 = model2(rgb) + temp = 0 + + + +# import time +# window_size = 37 +# def prepare_window_masks(shape): +# if window_size <= 0: +# return None +# import xformers.components.attention.attention_patterns as AP + +# B, nh, nw, _, _ = shape +# radius = (window_size-1)//2 +# #time0 = time.time() +# d = AP.local_nd_distance(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() +# #mask = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() +# # mask = mask.view(nh, nw, nh, nw) +# # #time1 = time.time() - time0 + +# # # angle +# # mask[:radius+1, :radius+1, :window_size, :window_size] = True +# # mask[:radius+1, -radius-1:, :window_size, -window_size:] = True +# # mask[-radius-1:, :radius+1, -window_size:, :window_size] = True +# # mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True +# # time2 = time.time() - time0 - time1 + +# # # edge +# # mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :] +# # mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :] +# # mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :] +# # mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :] +# # time3 = time.time() - time0 - time2 +# # print(time1, time2, time3) + +# # return mask.view(nw*nw, nh*nw).unsqueeze(0).repeat(B, 1) + +# shape = (1, 55, 55, None, None) +# mask = prepare_window_masks(shape) +# # temp = 1 \ No newline at end of file diff --git a/mono/model/backbones/ViT_DINO_reg.py b/mono/model/backbones/ViT_DINO_reg.py new file mode 100644 index 0000000000000000000000000000000000000000..854f96320ea93752e023c8cd845bf38353dfab17 --- /dev/null +++ b/mono/model/backbones/ViT_DINO_reg.py @@ -0,0 +1,1293 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List + +import torch +import torch.nn as nn +from torch import Tensor +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ +import torch.nn.init +import torch.nn.functional as F + +#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + +# SSF finetuning originally by dongzelian +def init_ssf_scale_shift(dim): + scale = nn.Parameter(torch.ones(dim)) + shift = nn.Parameter(torch.zeros(dim)) + + nn.init.normal_(scale, mean=1, std=.02) + nn.init.normal_(shift, std=.02) + + return scale, shift + +def ssf_ada(x, scale, shift): + assert scale.shape == shift.shape + if x.shape[-1] == scale.shape[0]: + return x * scale + shift + elif x.shape[1] == scale.shape[0]: + return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1) + else: + raise ValueError('the input tensor shape does not match the shape of the scale factor.') + +# LoRA finetuning originally by edwardjhu +class LoRALayer(): + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + +class LoRALinear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + merge_weights: bool = True, + **kwargs + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, + merge_weights=merge_weights) + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + #nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize B the same way as the default for nn.Linear and A to zero + # this is different than what is described in the paper but should not affect performance + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + # def train(self, mode: bool = True): + # def T(w): + # return w.transpose(0, 1) if self.fan_in_fan_out else w + # nn.Linear.train(self, mode) + # if mode: + # if self.merge_weights and self.merged: + # # Make sure that the weights are not merged + # if self.r > 0: + # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + # self.merged = False + # else: + # if self.merge_weights and not self.merged: + # # Merge the weights and mark it + # if self.r > 0: + # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + # self.merged = True + + def forward(self, x: torch.Tensor): + def T(w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + if self.r > 0 and not self.merged: + result = F.linear(x, T(self.weight), bias=self.bias) + result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return F.linear(x, T(self.weight), bias=self.bias) + + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + tuning_mode: Optional[str] = None + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) + + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(2 * hidden_features) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + if self.tuning_mode == 'ssf': + x12 = ssf_ada(x12, self.ssf_scale_1, self.ssf_shift_1) + + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + out = self.w3(hidden) + + if self.tuning_mode == 'ssf': + out = ssf_ada(out, self.ssf_scale_2, self.ssf_scale_2) + + return out + + +try: + from xformers.ops import SwiGLU + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + from xformers.components.attention import ScaledDotProduct + from xformers.components import MultiHeadDispatch + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + window_size: int = 0, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + if tuning_mode == 'lora': + self.tuning_mode = tuning_mode + self.qkv = LoRALinear(dim, dim * 3, bias=qkv_bias, r=8) + else: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + + if tuning_mode == 'lora': + self.tuning_mode = tuning_mode + self.proj = LoRALinear(dim, dim, bias=proj_bias, r=8) + else: + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + #if not self.training: + # + # self.attn = ScaledDotProduct() + #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + if self.tuning_mode == 'ssf': + qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + if attn_bias is not None: + attn = attn + attn_bias[:, :, :N] + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + #if True: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x, attn_bias) + + B, N, C = x.shape + if self.tuning_mode == 'ssf': + qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + if attn_bias is not None: + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N]) + else: + x = memory_efficient_attention(q, k, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + + x = self.proj_drop(x) + return x + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values = None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + tuning_mode=tuning_mode + ) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + def attn_residual_func(x: Tensor, attn_bias) -> Tensor: + if self.tuning_mode == 'ssf': + return self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1), attn_bias)) + else: + return self.ls1(self.attn(self.norm1(x), attn_bias)) + + def ffn_residual_func(x: Tensor) -> Tensor: + if self.tuning_mode == 'ssf': + return self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2))) + else: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + attn_bias=attn_bias + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, attn_bias)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, attn_bias) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, attn_bias=None +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset, attn_bias) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list, attn_bias=None): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list, attn_bias) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x, others=None): + for b in self: + if others == None: + x = b(x) + else: + x = b(x, others) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=518, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=1e-5, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + tuning_mode=None, + **kwargs + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + tuning_mode_list = [tuning_mode] * depth + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + tuning_mode=tuning_mode_list[i] + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + mode="bicubic", + antialias=self.interpolate_antialias, + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + B, C, H, W = x.size() + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) + if pad_h + pad_w > 0: + x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + if self.tuning_mode == 'ssf': + x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1) + + # return { + # "x_norm_clstoken": x_norm[:, 0], + # "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + # "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + # "x_prenorm": x, + # "masks": masks, + # } + features = [] + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)] + + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + # if is_training: + # return ret + # else: + # return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def load_ckpt_dino(checkpoint, model): + if checkpoint is not None: + try: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + except: + print('NO pretrained imagenet ckpt available! Check your path!') + del model.mask_token + return + + try: + model.load_state_dict(state_dict, strict=True) + except: + new_state_dict = {} + for key, value in state_dict.items(): + if 'blocks' in key: + key_new = 'blocks.0' + key[len('blocks'):] + else: + key_new = key + new_state_dict[key_new] = value + + model.load_state_dict(new_state_dict, strict=True) + del model.mask_token + return + else: + return + + +def vit_small(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_base(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + try: + model.load_state_dict(state_dict, strict=True) + except: + new_state_dict = {} + for key, value in state_dict.items(): + if 'blocks' in key: + key_new = 'blocks.0' + key[len('blocks'):] + else: + key_new = key + new_state_dict[key_new] = value + + model.load_state_dict(new_state_dict, strict=True) + del model.mask_token + return model + + +def vit_giant2(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + ffn_layer='swiglu', + **kwargs, + ) + return model + + + +def vit_small_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + tuning_mode=tuning_mode, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_base_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_large_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): + model = DinoVisionTransformer( + img_size = 518, + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + tuning_mode=tuning_mode, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_giant2_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + ffn_layer='swiglu', + tuning_mode=tuning_mode, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + +if __name__ == '__main__': + try: + from mmcv.utils import Config + except: + from mmengine import Config + + #rgb = torch.rand((2, 3, 518, 518)).cuda() + + #cfg.data_basic['crop_size']['0'] + #cfg.data_basic['crop_size']['1'] + cfg = Config.fromfile('/opt/ml/project/mu.hu/projects/monodepth_vit/mono/configs/RAFTDecoder/vit.raft5.large.kitti.py') + + #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036) + rgb = torch.zeros(1, 3, 616, 1064).cuda() + cfg['tuning_mode'] = 'ssf' + #model = vit_large_reg(checkpoint="/cpfs02/shared/public/groups/local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_reg4_pretrain.pth", kwarg=cfg).cuda() + model = vit_large_reg(tuning_mode='ssf').cuda() + + #import timm + #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda() + #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn) + + out1 = model(rgb) + #out2 = model2(rgb) + temp = 0 + + diff --git a/mono/model/backbones/__init__.py b/mono/model/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc3ba70ef5ef867f0518d73a189e7531466cbab --- /dev/null +++ b/mono/model/backbones/__init__.py @@ -0,0 +1,11 @@ +from .ConvNeXt import convnext_xlarge +from .ConvNeXt import convnext_small +from .ConvNeXt import convnext_base +from .ConvNeXt import convnext_large +from .ConvNeXt import convnext_tiny +from .ViT_DINO import vit_large +from .ViT_DINO_reg import vit_small_reg, vit_large_reg + +__all__ = [ + 'convnext_xlarge', 'convnext_small', 'convnext_base', 'convnext_large', 'convnext_tiny', 'vit_small_reg', 'vit_large_reg' +] diff --git a/mono/model/backbones/__pycache__/ConvNeXt.cpython-39.pyc b/mono/model/backbones/__pycache__/ConvNeXt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..126ed2ec9338fdbaf1a3d9445815a8ff3f03aea5 Binary files /dev/null and b/mono/model/backbones/__pycache__/ConvNeXt.cpython-39.pyc differ diff --git a/mono/model/backbones/__pycache__/__init__.cpython-39.pyc b/mono/model/backbones/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16cdbeb696a2cca7a544cc37a01f66e764632aeb Binary files /dev/null and b/mono/model/backbones/__pycache__/__init__.cpython-39.pyc differ diff --git a/mono/model/decode_heads/HourGlassDecoder.py b/mono/model/decode_heads/HourGlassDecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e084382601e21e6ce5144abbd6a65f563905b659 --- /dev/null +++ b/mono/model/decode_heads/HourGlassDecoder.py @@ -0,0 +1,274 @@ +import torch +import torch.nn as nn +import numpy as np +import math +import torch.nn.functional as F + +def compute_depth_expectation(prob, depth_values): + depth_values = depth_values.view(*depth_values.shape, 1, 1) + depth = torch.sum(prob * depth_values, 1) + return depth + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3): + super(ConvBlock, self).__init__() + + if kernel_size == 3: + self.conv = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1), + ) + elif kernel_size == 1: + self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1) + + self.nonlin = nn.ELU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.nonlin(out) + return out + + +class ConvBlock_double(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3): + super(ConvBlock_double, self).__init__() + + if kernel_size == 3: + self.conv = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1), + ) + elif kernel_size == 1: + self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1) + + self.nonlin = nn.ELU(inplace=True) + self.conv_2 = nn.Conv2d(out_channels, out_channels, 1, padding=0, stride=1) + self.nonlin_2 =nn.ELU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.nonlin(out) + out = self.conv_2(out) + out = self.nonlin_2(out) + return out + +class DecoderFeature(nn.Module): + def __init__(self, feat_channels, num_ch_dec=[64, 64, 128, 256]): + super(DecoderFeature, self).__init__() + self.num_ch_dec = num_ch_dec + self.feat_channels = feat_channels + + self.upconv_3_0 = ConvBlock(self.feat_channels[3], self.num_ch_dec[3], kernel_size=1) + self.upconv_3_1 = ConvBlock_double( + self.feat_channels[2] + self.num_ch_dec[3], + self.num_ch_dec[3], + kernel_size=1) + + self.upconv_2_0 = ConvBlock(self.num_ch_dec[3], self.num_ch_dec[2], kernel_size=3) + self.upconv_2_1 = ConvBlock_double( + self.feat_channels[1] + self.num_ch_dec[2], + self.num_ch_dec[2], + kernel_size=3) + + self.upconv_1_0 = ConvBlock(self.num_ch_dec[2], self.num_ch_dec[1], kernel_size=3) + self.upconv_1_1 = ConvBlock_double( + self.feat_channels[0] + self.num_ch_dec[1], + self.num_ch_dec[1], + kernel_size=3) + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + + def forward(self, ref_feature): + x = ref_feature[3] + + x = self.upconv_3_0(x) + x = torch.cat((self.upsample(x), ref_feature[2]), 1) + x = self.upconv_3_1(x) + + x = self.upconv_2_0(x) + x = torch.cat((self.upsample(x), ref_feature[1]), 1) + x = self.upconv_2_1(x) + + x = self.upconv_1_0(x) + x = torch.cat((self.upsample(x), ref_feature[0]), 1) + x = self.upconv_1_1(x) + return x + + +class UNet(nn.Module): + def __init__(self, inp_ch=32, output_chal=1, down_sample_times=3, channel_mode='v0'): + super(UNet, self).__init__() + basic_block = ConvBnReLU + num_depth = 128 + + self.conv0 = basic_block(inp_ch, num_depth) + if channel_mode == 'v0': + channels = [num_depth, num_depth//2, num_depth//4, num_depth//8, num_depth // 8] + elif channel_mode == 'v1': + channels = [num_depth, num_depth, num_depth, num_depth, num_depth, num_depth] + self.down_sample_times = down_sample_times + for i in range(down_sample_times): + setattr( + self, 'conv_%d' % i, + nn.Sequential( + basic_block(channels[i], channels[i+1], stride=2), + basic_block(channels[i+1], channels[i+1]) + ) + ) + for i in range(down_sample_times-1,-1,-1): + setattr(self, 'deconv_%d' % i, + nn.Sequential( + nn.ConvTranspose2d( + channels[i+1], + channels[i], + kernel_size=3, + padding=1, + output_padding=1, + stride=2, + bias=False), + nn.BatchNorm2d(channels[i]), + nn.ReLU(inplace=True) + ) + ) + self.prob = nn.Conv2d(num_depth, output_chal, 1, stride=1, padding=0) + + def forward(self, x): + features = {} + conv0 = self.conv0(x) + x = conv0 + features[0] = conv0 + for i in range(self.down_sample_times): + x = getattr(self, 'conv_%d' % i)(x) + features[i+1] = x + for i in range(self.down_sample_times-1,-1,-1): + x = features[i] + getattr(self, 'deconv_%d' % i)(x) + x = self.prob(x) + return x + +class ConvBnReLU(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): + super(ConvBnReLU, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=pad, + bias=False + ) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + return F.relu(self.bn(self.conv(x)), inplace=True) + + +class HourglassDecoder(nn.Module): + def __init__(self, cfg): + super(HourglassDecoder, self).__init__() + self.inchannels = cfg.model.decode_head.in_channels # [256, 512, 1024, 2048] + self.decoder_channels = cfg.model.decode_head.decoder_channel # [64, 64, 128, 256] + self.min_val = cfg.data_basic.depth_normalize[0] + self.max_val = cfg.data_basic.depth_normalize[1] + + self.num_ch_dec = self.decoder_channels # [64, 64, 128, 256] + self.num_depth_regressor_anchor = 512 + self.feat_channels = self.inchannels + unet_in_channel = self.num_ch_dec[1] + unet_out_channel = 256 + + self.decoder_mono = DecoderFeature(self.feat_channels, self.num_ch_dec) + self.conv_out_2 = UNet(inp_ch=unet_in_channel, + output_chal=unet_out_channel + 1, + down_sample_times=3, + channel_mode='v0', + ) + + self.depth_regressor_2 = nn.Sequential( + nn.Conv2d(unet_out_channel, + self.num_depth_regressor_anchor, + kernel_size=3, + padding=1, + ), + nn.BatchNorm2d(self.num_depth_regressor_anchor), + nn.ReLU(inplace=True), + nn.Conv2d( + self.num_depth_regressor_anchor, + self.num_depth_regressor_anchor, + kernel_size=1, + ) + ) + self.residual_channel = 16 + self.conv_up_2 = nn.Sequential( + nn.Conv2d(1 + 2 + unet_out_channel, self.residual_channel, 3, padding=1), + nn.BatchNorm2d(self.residual_channel), + nn.ReLU(), + nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1), + nn.Upsample(scale_factor=4), + nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1), + nn.ReLU(), + nn.Conv2d(self.residual_channel, 1, 1, padding=0), + ) + + def get_bins(self, bins_num): + depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device='cuda') + depth_bins_vec = torch.exp(depth_bins_vec) + return depth_bins_vec + + def register_depth_expectation_anchor(self, bins_num, B): + depth_bins_vec = self.get_bins(bins_num) + depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1) + self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False) + + def upsample(self, x, scale_factor=2): + return F.interpolate(x, scale_factor=scale_factor, mode='nearest') + + def regress_depth_2(self, feature_map_d): + prob = self.depth_regressor_2(feature_map_d).softmax(dim=1) + B = prob.shape[0] + if "depth_expectation_anchor" not in self._buffers: + self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B) + d = compute_depth_expectation( + prob, + self.depth_expectation_anchor[:B, ...] + ).unsqueeze(1) + return d + + def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True): + y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device), + torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij') + meshgrid = torch.stack((x, y)) + meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1) + return meshgrid + + def forward(self, features_mono, **kwargs): + ''' + trans_ref2src: list of transformation matrix from the reference view to source view. [B, 4, 4] + inv_intrinsic_pool: list of inverse intrinsic matrix. + features_mono: features of reference and source views. [[ref_f1, ref_f2, ref_f3, ref_f4],[src1_f1, src1_f2, src1_f3, src1_f4], ...]. + ''' + outputs = {} + # get encoder feature of the reference view + ref_feat = features_mono + + feature_map_mono = self.decoder_mono(ref_feat) + feature_map_mono_pred = self.conv_out_2(feature_map_mono) + confidence_map_2 = feature_map_mono_pred[:, -1:, :, :] + feature_map_d_2 = feature_map_mono_pred[:, :-1, :, :] + + depth_pred_2 = self.regress_depth_2(feature_map_d_2) + + B, _, H, W = depth_pred_2.shape + + meshgrid = self.create_mesh_grid(H, W, B) + + depth_pred_mono = self.upsample(depth_pred_2, scale_factor=4) + 1e-1 * \ + self.conv_up_2( + torch.cat((depth_pred_2, meshgrid[:B, ...], feature_map_d_2), 1) + ) + confidence_map_mono = self.upsample(confidence_map_2, scale_factor=4) + + outputs=dict( + prediction=depth_pred_mono, + confidence=confidence_map_mono, + pred_logit=None, + ) + return outputs \ No newline at end of file diff --git a/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py b/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py new file mode 100644 index 0000000000000000000000000000000000000000..9af89f9b4b1878a2e4bcfcd489075c2e97cd8d3d --- /dev/null +++ b/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py @@ -0,0 +1,1033 @@ +import torch +import torch.nn as nn +import numpy as np +import math +import torch.nn.functional as F + +# LORA finetuning originally by edwardjhu +class LoRALayer(): + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + +class LoRALinear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + merge_weights: bool = True, + **kwargs + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, + merge_weights=merge_weights) + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + #nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize B the same way as the default for nn.Linear and A to zero + # this is different than what is described in the paper but should not affect performance + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + # def train(self, mode: bool = True): + # def T(w): + # return w.transpose(0, 1) if self.fan_in_fan_out else w + # nn.Linear.train(self, mode) + # if mode: + # if self.merge_weights and self.merged: + # # Make sure that the weights are not merged + # if self.r > 0: + # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + # self.merged = False + # else: + # if self.merge_weights and not self.merged: + # # Merge the weights and mark it + # if self.r > 0: + # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + # self.merged = True + + def forward(self, x: torch.Tensor): + def T(w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + if self.r > 0 and not self.merged: + result = F.linear(x, T(self.weight), bias=self.bias) + result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return F.linear(x, T(self.weight), bias=self.bias) + +class ConvLoRA(nn.Conv2d, LoRALayer): + def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs): + #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs) + nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) + assert isinstance(kernel_size, int) + + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter( + self.weight.new_zeros((r * kernel_size, in_channels * kernel_size)) + ) + self.lora_B = nn.Parameter( + self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) + ) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + self.merged = False + + def reset_parameters(self): + #self.conv.reset_parameters() + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + # def train(self, mode=True): + # super(ConvLoRA, self).train(mode) + # if mode: + # if self.merge_weights and self.merged: + # if self.r > 0: + # # Make sure that the weights are not merged + # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling + # self.merged = False + # else: + # if self.merge_weights and not self.merged: + # if self.r > 0: + # # Merge the weights and mark it + # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling + # self.merged = True + + def forward(self, x): + if self.r > 0 and not self.merged: + # return self.conv._conv_forward( + # x, + # self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling, + # self.conv.bias + # ) + weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + bias = self.bias + + return F.conv2d(x, weight, bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + else: + return F.conv2d(x, self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + +class ConvTransposeLoRA(nn.ConvTranspose2d, LoRALayer): + def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs): + #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs) + nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) + assert isinstance(kernel_size, int) + + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter( + self.weight.new_zeros((r * kernel_size, in_channels * kernel_size)) + ) + self.lora_B = nn.Parameter( + self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) + ) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + self.merged = False + + def reset_parameters(self): + #self.conv.reset_parameters() + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + # def train(self, mode=True): + # super(ConvTransposeLoRA, self).train(mode) + # if mode: + # if self.merge_weights and self.merged: + # if self.r > 0: + # # Make sure that the weights are not merged + # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling + # self.merged = False + # else: + # if self.merge_weights and not self.merged: + # if self.r > 0: + # # Merge the weights and mark it + # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling + # self.merged = True + + def forward(self, x): + if self.r > 0 and not self.merged: + weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + bias = self.bias + return F.conv_transpose2d(x, weight, + bias=bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding, + groups=self.groups, dilation=self.dilation) + else: + return F.conv_transpose2d(x, self.weight, + bias=self.bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding, + groups=self.groups, dilation=self.dilation) + #return self.conv(x) + +class Conv2dLoRA(ConvLoRA): + def __init__(self, *args, **kwargs): + super(Conv2dLoRA, self).__init__(*args, **kwargs) + +class ConvTranspose2dLoRA(ConvTransposeLoRA): + def __init__(self, *args, **kwargs): + super(ConvTranspose2dLoRA, self).__init__(*args, **kwargs) + + +def compute_depth_expectation(prob, depth_values): + depth_values = depth_values.view(*depth_values.shape, 1, 1) + depth = torch.sum(prob * depth_values, 1) + return depth + +def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None): + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False): + return F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners) + +# def upflow8(flow, mode='bilinear'): +# new_size = (8 * flow.shape[2], 8 * flow.shape[3]) +# return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + +def upflow4(flow, mode='bilinear'): + new_size = (4 * flow.shape[2], 4 * flow.shape[3]) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False): + return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + +def coords_grid(batch, ht, wd): + # coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = (torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd))) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + +def norm_normalize(norm_out): + min_kappa = 0.01 + norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1) + norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10 + kappa = F.elu(kappa) + 1.0 + min_kappa + final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1) + return final_out + +# uncertainty-guided sampling (only used during training) +@torch.no_grad() +def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta): + device = init_normal.device + B, _, H, W = init_normal.shape + N = int(sampling_ratio * H * W) + beta = beta + + # uncertainty map + uncertainty_map = -1 * init_normal[:, -1, :, :] # B, H, W + + # gt_invalid_mask (B, H, W) + if gt_norm_mask is not None: + gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest') + gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5 + uncertainty_map[gt_invalid_mask] = -1e4 + + # (B, H*W) + _, idx = uncertainty_map.view(B, -1).sort(1, descending=True) + + # importance sampling + if int(beta * N) > 0: + importance = idx[:, :int(beta * N)] # B, beta*N + + # remaining + remaining = idx[:, int(beta * N):] # B, H*W - beta*N + + # coverage + num_coverage = N - int(beta * N) + + if num_coverage <= 0: + samples = importance + else: + coverage_list = [] + for i in range(B): + idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" + coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N + coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N + samples = torch.cat((importance, coverage), dim=1) # B, N + + else: + # remaining + remaining = idx[:, :] # B, H*W + + # coverage + num_coverage = N + + coverage_list = [] + for i in range(B): + idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" + coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N + coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N + samples = coverage + + # point coordinates + rows_int = samples // W # 0 for first row, H-1 for last row + rows_float = rows_int / float(H-1) # 0 to 1.0 + rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0 + + cols_int = samples % W # 0 for first column, W-1 for last column + cols_float = cols_int / float(W-1) # 0 to 1.0 + cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0 + + point_coords = torch.zeros(B, 1, N, 2) + point_coords[:, 0, :, 0] = cols_float # x coord + point_coords[:, 0, :, 1] = rows_float # y coord + point_coords = point_coords.to(device) + return point_coords, rows_int, cols_int + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256, output_dim_depth=2, output_dim_norm=4, tuning_mode=None): + super(FlowHead, self).__init__() + self.conv1d = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) + self.conv2d = Conv2dLoRA(hidden_dim // 2, output_dim_depth, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) + + self.conv1n = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) + self.conv2n = Conv2dLoRA(hidden_dim // 2, output_dim_norm, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + depth = self.conv2d(self.relu(self.conv1d(x))) + normal = self.conv2n(self.relu(self.conv1n(x))) + return torch.cat((depth, normal), dim=1) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim, input_dim, kernel_size=3, tuning_mode=None): + super(ConvGRU, self).__init__() + self.convz = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) + self.convr = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) + self.convq = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) + + def forward(self, h, cz, cr, cq, *x_list): + x = torch.cat(x_list, dim=1) + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid((self.convz(hx) + cz)) + r = torch.sigmoid((self.convr(hx) + cr)) + q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq)) + + # z = torch.sigmoid((self.convz(hx) + cz).float()) + # r = torch.sigmoid((self.convr(hx) + cr).float()) + # q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq).float()) + + h = (1-z) * h + z * q + return h + +def pool2x(x): + return F.avg_pool2d(x, 3, stride=2, padding=1) + +def pool4x(x): + return F.avg_pool2d(x, 5, stride=4, padding=1) + +def interp(x, dest): + interp_args = {'mode': 'bilinear', 'align_corners': True} + return interpolate_float32(x, dest.shape[2:], **interp_args) + +class BasicMultiUpdateBlock(nn.Module): + def __init__(self, args, hidden_dims=[], out_dims=2, tuning_mode=None): + super().__init__() + self.args = args + self.n_gru_layers = args.model.decode_head.n_gru_layers # 3 + self.n_downsample = args.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K) + + # self.encoder = BasicMotionEncoder(args) + # encoder_output_dim = 128 # if there is corr volume + encoder_output_dim = 6 # no corr volume + + self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1), tuning_mode=tuning_mode) + self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2], tuning_mode=tuning_mode) + self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1], tuning_mode=tuning_mode) + self.flow_head = FlowHead(hidden_dims[2], hidden_dim=2*hidden_dims[2], tuning_mode=tuning_mode) + factor = 2**self.n_downsample + + self.mask = nn.Sequential( + Conv2dLoRA(hidden_dims[2], hidden_dims[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0), + nn.ReLU(inplace=True), + Conv2dLoRA(hidden_dims[2], (factor**2)*9, 1, padding=0, r = 8 if tuning_mode == 'lora' else 0)) + + def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True): + + if iter32: + net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1])) + if iter16: + if self.n_gru_layers > 2: + net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]), interp(net[2], net[1])) + else: + net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1])) + if iter08: + if corr is not None: + motion_features = self.encoder(flow, corr) + else: + motion_features = flow + if self.n_gru_layers > 1: + net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0])) + else: + net[0] = self.gru08(net[0], *(inp[0]), motion_features) + + if not update: + return net + + delta_flow = self.flow_head(net[0]) + + # scale mask to balence gradients + mask = .25 * self.mask(net[0]) + return net, mask, delta_flow + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, dim): + super(LayerNorm2d, self).__init__(dim) + + def forward(self, x): + x = x.permute(0, 2, 3, 1).contiguous() + x = super(LayerNorm2d, self).forward(x) + x = x.permute(0, 3, 1, 2).contiguous() + return x + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1, tuning_mode=None): + super(ResidualBlock, self).__init__() + + self.conv1 = Conv2dLoRA(in_planes, planes, kernel_size=3, padding=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0) + self.conv2 = Conv2dLoRA(planes, planes, kernel_size=3, padding=1, r = 8 if tuning_mode == 'lora' else 0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'layer': + self.norm1 = LayerNorm2d(planes) + self.norm2 = LayerNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = LayerNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.Sequential() + + if stride == 1 and in_planes == planes: + self.downsample = None + + else: + self.downsample = nn.Sequential( + Conv2dLoRA(in_planes, planes, kernel_size=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0), self.norm3) + + def forward(self, x): + y = x + y = self.conv1(y) + y = self.norm1(y) + y = self.relu(y) + y = self.conv2(y) + y = self.norm2(y) + y = self.relu(y) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ContextFeatureEncoder(nn.Module): + ''' + Encoder features are used to: + 1. initialize the hidden state of the update operator + 2. and also injected into the GRU during each iteration of the update operator + ''' + def __init__(self, in_dim, output_dim, tuning_mode=None): + ''' + in_dim = [x4, x8, x16, x32] + output_dim = [hindden_dims, context_dims] + [[x4,x8,x16,x32],[x4,x8,x16,x32]] + ''' + super().__init__() + + output_list = [] + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(in_dim[0], dim[0], 'layer', stride=1, tuning_mode=tuning_mode), + Conv2dLoRA(dim[0], dim[0], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) + output_list.append(conv_out) + + self.outputs04 = nn.ModuleList(output_list) + + output_list = [] + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(in_dim[1], dim[1], 'layer', stride=1, tuning_mode=tuning_mode), + Conv2dLoRA(dim[1], dim[1], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) + output_list.append(conv_out) + + self.outputs08 = nn.ModuleList(output_list) + + output_list = [] + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(in_dim[2], dim[2], 'layer', stride=1, tuning_mode=tuning_mode), + Conv2dLoRA(dim[2], dim[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) + output_list.append(conv_out) + + self.outputs16 = nn.ModuleList(output_list) + + # output_list = [] + # for dim in output_dim: + # conv_out = Conv2dLoRA(in_dim[3], dim[3], 3, padding=1) + # output_list.append(conv_out) + + # self.outputs32 = nn.ModuleList(output_list) + + def forward(self, encoder_features): + x_4, x_8, x_16, x_32 = encoder_features + + outputs04 = [f(x_4) for f in self.outputs04] + outputs08 = [f(x_8) for f in self.outputs08] + outputs16 = [f(x_16)for f in self.outputs16] + # outputs32 = [f(x_32) for f in self.outputs32] + + return (outputs04, outputs08, outputs16) + +class ConvBlock(nn.Module): + # reimplementation of DPT + def __init__(self, channels, tuning_mode=None): + super(ConvBlock, self).__init__() + + self.act = nn.ReLU(inplace=True) + self.conv1 = Conv2dLoRA( + channels, + channels, + kernel_size=3, + stride=1, + padding=1, + r = 8 if tuning_mode == 'lora' else 0 + ) + self.conv2 = Conv2dLoRA( + channels, + channels, + kernel_size=3, + stride=1, + padding=1, + r = 8 if tuning_mode == 'lora' else 0 + ) + + def forward(self, x): + out = self.act(x) + out = self.conv1(out) + out = self.act(out) + out = self.conv2(out) + return x + out + +class FuseBlock(nn.Module): + # reimplementation of DPT + def __init__(self, in_channels, out_channels, fuse=True, upsample=True, scale_factor=2, tuning_mode=None): + super(FuseBlock, self).__init__() + + self.fuse = fuse + self.scale_factor = scale_factor + self.way_trunk = ConvBlock(in_channels, tuning_mode=tuning_mode) + if self.fuse: + self.way_branch = ConvBlock(in_channels, tuning_mode=tuning_mode) + + self.out_conv = Conv2dLoRA( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + r = 8 if tuning_mode == 'lora' else 0 + ) + self.upsample = upsample + + def forward(self, x1, x2=None): + if x2 is not None: + x2 = self.way_branch(x2) + x1 = x1 + x2 + + out = self.way_trunk(x1) + + if self.upsample: + out = interpolate_float32( + out, scale_factor=self.scale_factor, mode="bilinear", align_corners=True + ) + out = self.out_conv(out) + return out + +class Readout(nn.Module): + # From DPT + def __init__(self, in_features, use_cls_token=True, num_register_tokens=0, tuning_mode=None): + super(Readout, self).__init__() + self.use_cls_token = use_cls_token + if self.use_cls_token == True: + self.project_patch = LoRALinear(in_features, in_features, r = 8 if tuning_mode == 'lora' else 0) + self.project_learn = LoRALinear((1 + num_register_tokens) * in_features, in_features, bias=False, r = 8 if tuning_mode == 'lora' else 0) + self.act = nn.GELU() + else: + self.project = nn.Identity() + + def forward(self, x): + + if self.use_cls_token == True: + x_patch = self.project_patch(x[0]) + x_learn = self.project_learn(x[1]) + x_learn = x_learn.expand_as(x_patch).contiguous() + features = x_patch + x_learn + return self.act(features) + else: + return self.project(x) + +class Token2Feature(nn.Module): + # From DPT + def __init__(self, vit_channel, feature_channel, scale_factor, use_cls_token=True, num_register_tokens=0, tuning_mode=None): + super(Token2Feature, self).__init__() + self.scale_factor = scale_factor + self.readoper = Readout(in_features=vit_channel, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) + if scale_factor > 1 and isinstance(scale_factor, int): + self.sample = ConvTranspose2dLoRA(r = 8 if tuning_mode == 'lora' else 0, + in_channels=vit_channel, + out_channels=feature_channel, + kernel_size=scale_factor, + stride=scale_factor, + padding=0, + ) + + elif scale_factor > 1: + self.sample = nn.Sequential( + # Upsample2(upscale=scale_factor), + # nn.Upsample(scale_factor=scale_factor), + Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0, + in_channels=vit_channel, + out_channels=feature_channel, + kernel_size=1, + stride=1, + padding=0, + ), + ) + + + elif scale_factor < 1: + scale_factor = int(1.0 / scale_factor) + self.sample = Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0, + in_channels=vit_channel, + out_channels=feature_channel, + kernel_size=scale_factor+1, + stride=scale_factor, + padding=1, + ) + + else: + self.sample = nn.Identity() + + def forward(self, x): + x = self.readoper(x) + #if use_cls_token == True: + x = x.permute(0, 3, 1, 2).contiguous() + if isinstance(self.scale_factor, float): + x = interpolate_float32(x.float(), scale_factor=self.scale_factor, mode='nearest') + x = self.sample(x) + return x + +class EncoderFeature(nn.Module): + def __init__(self, vit_channel, num_ch_dec=[256, 512, 1024, 1024], use_cls_token=True, num_register_tokens=0, tuning_mode=None): + super(EncoderFeature, self).__init__() + self.vit_channel = vit_channel + self.num_ch_dec = num_ch_dec + + self.read_3 = Token2Feature(self.vit_channel, self.num_ch_dec[3], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) + self.read_2 = Token2Feature(self.vit_channel, self.num_ch_dec[2], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) + self.read_1 = Token2Feature(self.vit_channel, self.num_ch_dec[1], scale_factor=2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) + self.read_0 = Token2Feature(self.vit_channel, self.num_ch_dec[0], scale_factor=7/2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) + + def forward(self, ref_feature): + x = self.read_3(ref_feature[3]) # 1/14 + x2 = self.read_2(ref_feature[2]) # 1/14 + x1 = self.read_1(ref_feature[1]) # 1/7 + x0 = self.read_0(ref_feature[0]) # 1/4 + + return x, x2, x1, x0 + +class DecoderFeature(nn.Module): + def __init__(self, vit_channel, num_ch_dec=[128, 256, 512, 1024, 1024], use_cls_token=True, tuning_mode=None): + super(DecoderFeature, self).__init__() + self.vit_channel = vit_channel + self.num_ch_dec = num_ch_dec + + self.upconv_3 = FuseBlock( + self.num_ch_dec[4], + self.num_ch_dec[3], + fuse=False, upsample=False, tuning_mode=tuning_mode) + + self.upconv_2 = FuseBlock( + self.num_ch_dec[3], + self.num_ch_dec[2], + tuning_mode=tuning_mode) + + self.upconv_1 = FuseBlock( + self.num_ch_dec[2], + self.num_ch_dec[1] + 2, + scale_factor=7/4, + tuning_mode=tuning_mode) + + # self.upconv_0 = FuseBlock( + # self.num_ch_dec[1], + # self.num_ch_dec[0] + 1, + # ) + + def forward(self, ref_feature): + x, x2, x1, x0 = ref_feature # 1/14 1/14 1/7 1/4 + + x = self.upconv_3(x) # 1/14 + x = self.upconv_2(x, x2) # 1/7 + x = self.upconv_1(x, x1) # 1/4 + # x = self.upconv_0(x, x0) # 4/7 + return x + +class RAFTDepthNormalDPT5(nn.Module): + def __init__(self, cfg): + super().__init__() + self.in_channels = cfg.model.decode_head.in_channels # [1024, 1024, 1024, 1024] + self.feature_channels = cfg.model.decode_head.feature_channels # [256, 512, 1024, 1024] [2/7, 1/7, 1/14, 1/14] + self.decoder_channels = cfg.model.decode_head.decoder_channels # [128, 256, 512, 1024, 1024] [-, 1/4, 1/7, 1/14, 1/14] + self.use_cls_token = cfg.model.decode_head.use_cls_token + self.up_scale = cfg.model.decode_head.up_scale + self.num_register_tokens = cfg.model.decode_head.num_register_tokens + self.min_val = cfg.data_basic.depth_normalize[0] + self.max_val = cfg.data_basic.depth_normalize[1] + self.regress_scale = 100.0\ + + try: + tuning_mode = cfg.model.decode_head.tuning_mode + except: + tuning_mode = None + self.tuning_mode = tuning_mode + + self.hidden_dims = self.context_dims = cfg.model.decode_head.hidden_channels # [128, 128, 128, 128] + self.n_gru_layers = cfg.model.decode_head.n_gru_layers # 3 + self.n_downsample = cfg.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K) + self.iters = cfg.model.decode_head.iters # 22 + self.slow_fast_gru = cfg.model.decode_head.slow_fast_gru # True + + self.num_depth_regressor_anchor = 256 # 512 + self.used_res_channel = self.decoder_channels[1] # now, use 2/7 res + self.token2feature = EncoderFeature(self.in_channels[0], self.feature_channels, self.use_cls_token, self.num_register_tokens, tuning_mode=tuning_mode) + self.decoder_mono = DecoderFeature(self.in_channels, self.decoder_channels, tuning_mode=tuning_mode) + self.depth_regressor = nn.Sequential( + Conv2dLoRA(self.used_res_channel, + self.num_depth_regressor_anchor, + kernel_size=3, + padding=1, r = 8 if tuning_mode == 'lora' else 0), + # nn.BatchNorm2d(self.num_depth_regressor_anchor), + nn.ReLU(inplace=True), + Conv2dLoRA(self.num_depth_regressor_anchor, + self.num_depth_regressor_anchor, + kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), + ) + self.normal_predictor = nn.Sequential( + Conv2dLoRA(self.used_res_channel, + 128, + kernel_size=3, + padding=1, r = 8 if tuning_mode == 'lora' else 0,), + # nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True), + Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True), + Conv2dLoRA(128, 3, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), + ) + + self.context_feature_encoder = ContextFeatureEncoder(self.feature_channels, [self.hidden_dims, self.context_dims], tuning_mode=tuning_mode) + self.context_zqr_convs = nn.ModuleList([Conv2dLoRA(self.context_dims[i], self.hidden_dims[i]*3, 3, padding=3//2, r = 8 if tuning_mode == 'lora' else 0) for i in range(self.n_gru_layers)]) + self.update_block = BasicMultiUpdateBlock(cfg, hidden_dims=self.hidden_dims, out_dims=6, tuning_mode=tuning_mode) + + self.relu = nn.ReLU(inplace=True) + + def get_bins(self, bins_num): + depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda") + depth_bins_vec = torch.exp(depth_bins_vec) + return depth_bins_vec + + def register_depth_expectation_anchor(self, bins_num, B): + depth_bins_vec = self.get_bins(bins_num) + depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1) + self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False) + + def clamp(self, x): + y = self.relu(x - self.min_val) + self.min_val + y = self.max_val - self.relu(self.max_val - y) + return y + + def regress_depth(self, feature_map_d): + prob_feature = self.depth_regressor(feature_map_d) + prob = prob_feature.softmax(dim=1) + #prob = prob_feature.float().softmax(dim=1) + + ## Error logging + if torch.isnan(prob).any(): + print('prob_feat_nan!!!') + if torch.isinf(prob).any(): + print('prob_feat_inf!!!') + + # h = prob[0,:,0,0].cpu().numpy().reshape(-1) + # import matplotlib.pyplot as plt + # plt.bar(range(len(h)), h) + B = prob.shape[0] + if "depth_expectation_anchor" not in self._buffers: + self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B) + d = compute_depth_expectation( + prob, + self.depth_expectation_anchor[:B, ...]).unsqueeze(1) + + ## Error logging + if torch.isnan(d ).any(): + print('d_nan!!!') + if torch.isinf(d ).any(): + print('d_inf!!!') + + return (self.clamp(d) - self.max_val)/ self.regress_scale, prob_feature + + def pred_normal(self, feature_map, confidence): + normal_out = self.normal_predictor(feature_map) + + ## Error logging + if torch.isnan(normal_out).any(): + print('norm_nan!!!') + if torch.isinf(normal_out).any(): + print('norm_feat_inf!!!') + + return norm_normalize(torch.cat([normal_out, confidence], dim=1)) + #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float()) + + def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True): + y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device), + torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij') + meshgrid = torch.stack((x, y)) + meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1) + #self.register_buffer('meshgrid', meshgrid, persistent=False) + return meshgrid + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, D, H, W = flow.shape + factor = 2 ** self.n_downsample + mask = mask.view(N, 1, 9, factor, factor, H, W) + mask = torch.softmax(mask, dim=2) + #mask = torch.softmax(mask.float(), dim=2) + + #up_flow = F.unfold(factor * flow, [3,3], padding=1) + up_flow = F.unfold(flow, [3,3], padding=1) + up_flow = up_flow.view(N, D, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, D, factor*H, factor*W) + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, _, H, W = img.shape + + coords0 = coords_grid(N, H, W).to(img.device) + coords1 = coords_grid(N, H, W).to(img.device) + + return coords0, coords1 + + def upsample(self, x, scale_factor=2): + """Upsample input tensor by a factor of 2 + """ + return interpolate_float32(x, scale_factor=scale_factor*self.up_scale/8, mode="nearest") + + def forward(self, vit_features, **kwargs): + ## read vit token to multi-scale features + B, H, W, _, _, num_register_tokens = vit_features[1] + vit_features = vit_features[0] + + ## Error logging + if torch.isnan(vit_features[0]).any(): + print('vit_feature_nan!!!') + if torch.isinf(vit_features[0]).any(): + print('vit_feature_inf!!!') + + if self.use_cls_token == True: + vit_features = [[ft[:, 1+num_register_tokens:, :].view(B, H, W, self.in_channels[0]), \ + ft[:, 0:1+num_register_tokens, :].view(B, 1, 1, self.in_channels[0] * (1+num_register_tokens))] for ft in vit_features] + else: + vit_features = [ft.view(B, H, W, self.in_channels[0]) for ft in vit_features] + encoder_features = self.token2feature(vit_features) # 1/14, 1/14, 1/7, 1/4 + + ## Error logging + for en_ft in encoder_features: + if torch.isnan(en_ft).any(): + print('decoder_feature_nan!!!') + print(en_ft.shape) + if torch.isinf(en_ft).any(): + print('decoder_feature_inf!!!') + print(en_ft.shape) + + ## decode features to init-depth (and confidence) + ref_feat= self.decoder_mono(encoder_features) # now, 1/4 for depth + + ## Error logging + if torch.isnan(ref_feat).any(): + print('ref_feat_nan!!!') + if torch.isinf(ref_feat).any(): + print('ref_feat_inf!!!') + + feature_map = ref_feat[:, :-2, :, :] # feature map share of depth and normal prediction + depth_confidence_map = ref_feat[:, -2:-1, :, :] + normal_confidence_map = ref_feat[:, -1:, :, :] + depth_pred, binmap = self.regress_depth(feature_map) # regress bin for depth + normal_pred = self.pred_normal(feature_map, normal_confidence_map) # mlp for normal + + depth_init = torch.cat((depth_pred, depth_confidence_map, normal_pred), dim=1) # (N, 1+1+4, H, W) + + ## encoder features to context-feature for init-hidden-state and contex-features + cnet_list = self.context_feature_encoder(encoder_features[::-1]) + net_list = [torch.tanh(x[0]) for x in cnet_list] # x_4, x_8, x_16 of hidden state + inp_list = [torch.relu(x[1]) for x in cnet_list] # x_4, x_8, x_16 context features + + # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning + inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)] + + coords0, coords1 = self.initialize_flow(net_list[0]) + if depth_init is not None: + coords1 = coords1 + depth_init + + if self.training: + low_resolution_init = [self.clamp(depth_init[:,:1] * self.regress_scale + self.max_val), depth_init[:,1:2], norm_normalize(depth_init[:,2:].clone())] + init_depth = upflow4(depth_init) + flow_predictions = [self.clamp(init_depth[:,:1] * self.regress_scale + self.max_val)] + conf_predictions = [init_depth[:,1:2]] + normal_outs = [norm_normalize(init_depth[:,2:].clone())] + + else: + flow_predictions = [] + conf_predictions = [] + samples_pred_list = [] + coord_list = [] + normal_outs = [] + low_resolution_init = [] + + for itr in range(self.iters): + # coords1 = coords1.detach() + flow = coords1 - coords0 + if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU + net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False) + if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU + net_list = self.update_block(net_list, inp_list, iter32=self.n_gru_layers==3, iter16=True, iter08=False, update=False) + net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, None, flow, iter32=self.n_gru_layers==3, iter16=self.n_gru_layers>=2) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # We do not need to upsample or output intermediate results in test_mode + #if (not self.training) and itr < self.iters-1: + #continue + + # upsample predictions + if up_mask is None: + flow_up = self.upsample(coords1-coords0, 4) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + # flow_up = self.upsample(coords1-coords0, 4) + + flow_predictions.append(self.clamp(flow_up[:,:1] * self.regress_scale + self.max_val)) + conf_predictions.append(flow_up[:,1:2]) + normal_outs.append(norm_normalize(flow_up[:,2:].clone())) + + outputs=dict( + prediction=flow_predictions[-1], + predictions_list=flow_predictions, + confidence=conf_predictions[-1], + confidence_list=conf_predictions, + pred_logit=None, + # samples_pred_list=samples_pred_list, + # coord_list=coord_list, + prediction_normal=normal_outs[-1], + normal_out_list=normal_outs, + low_resolution_init=low_resolution_init, + ) + + return outputs + + +if __name__ == "__main__": + try: + from mmcv.utils import Config + except: + from mmengine import Config + cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py') + cfg.model.decode_head.in_channels = [384, 384, 384, 384] + cfg.model.decode_head.feature_channels = [96, 192, 384, 768] + cfg.model.decode_head.decoder_channels = [48, 96, 192, 384, 384] + cfg.model.decode_head.hidden_channels = [48, 48, 48, 48, 48] + cfg.model.decode_head.up_scale = 7 + + # cfg.model.decode_head.use_cls_token = True + # vit_feature = [[torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ + # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ + # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ + # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()]] + + cfg.model.decode_head.use_cls_token = True + cfg.model.decode_head.num_register_tokens = 4 + vit_feature = [[torch.rand((2, (74 * 74) + 5, 384)).cuda(),\ + torch.rand((2, (74 * 74) + 5, 384)).cuda(), \ + torch.rand((2, (74 * 74) + 5, 384)).cuda(), \ + torch.rand((2, (74 * 74) + 5, 384)).cuda()], (2, 74, 74, 1036, 1036, 4)] + + decoder = RAFTDepthNormalDPT5(cfg).cuda() + output = decoder(vit_feature) + temp = 1 + + + + diff --git a/mono/model/decode_heads/__init__.py b/mono/model/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92381a5fc3dad0ca8009c1ab0a153ce6b107c634 --- /dev/null +++ b/mono/model/decode_heads/__init__.py @@ -0,0 +1,4 @@ +from .HourGlassDecoder import HourglassDecoder +from .RAFTDepthNormalDPTDecoder5 import RAFTDepthNormalDPT5 + +__all__=['HourglassDecoder', 'RAFTDepthNormalDPT5'] diff --git a/mono/model/decode_heads/__pycache__/HourGlassDecoder.cpython-39.pyc b/mono/model/decode_heads/__pycache__/HourGlassDecoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47c981bd3124006156222a76cf044e3d5033d77c Binary files /dev/null and b/mono/model/decode_heads/__pycache__/HourGlassDecoder.cpython-39.pyc differ diff --git a/mono/model/decode_heads/__pycache__/__init__.cpython-39.pyc b/mono/model/decode_heads/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1433a90f49acb41d64bf291bebc75e844e4bc5b Binary files /dev/null and b/mono/model/decode_heads/__pycache__/__init__.cpython-39.pyc differ diff --git a/mono/model/model_pipelines/__base_model__.py b/mono/model/model_pipelines/__base_model__.py new file mode 100644 index 0000000000000000000000000000000000000000..d599c418b3d9677a195fe87d45bb31bf1068fbce --- /dev/null +++ b/mono/model/model_pipelines/__base_model__.py @@ -0,0 +1,20 @@ +import torch +import torch.nn as nn +from mono.utils.comm import get_func + + +class BaseDepthModel(nn.Module): + def __init__(self, cfg, **kwargs) -> None: + super(BaseDepthModel, self).__init__() + model_type = cfg.model.type + self.depth_model = get_func('mono.model.model_pipelines.' + model_type)(cfg) + + def forward(self, data): + output = self.depth_model(**data) + + return output['prediction'], output['confidence'], output + + def inference(self, data): + with torch.no_grad(): + pred_depth, confidence, _ = self.forward(data) + return pred_depth, confidence \ No newline at end of file diff --git a/mono/model/model_pipelines/__init__.py b/mono/model/model_pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b962a3f858573466e429219c4ad70951b545b637 --- /dev/null +++ b/mono/model/model_pipelines/__init__.py @@ -0,0 +1,6 @@ + +from .dense_pipeline import DensePredModel +from .__base_model__ import BaseDepthModel +__all__ = [ + 'DensePredModel', 'BaseDepthModel', +] \ No newline at end of file diff --git a/mono/model/model_pipelines/__pycache__/__base_model__.cpython-39.pyc b/mono/model/model_pipelines/__pycache__/__base_model__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09f4eaf42b12b9e4820d30f7d2e0a651bef48ad1 Binary files /dev/null and b/mono/model/model_pipelines/__pycache__/__base_model__.cpython-39.pyc differ diff --git a/mono/model/model_pipelines/__pycache__/__init__.cpython-39.pyc b/mono/model/model_pipelines/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efc35f49a0d2964b808e7900728e404c68ba5435 Binary files /dev/null and b/mono/model/model_pipelines/__pycache__/__init__.cpython-39.pyc differ diff --git a/mono/model/model_pipelines/__pycache__/dense_pipeline.cpython-39.pyc b/mono/model/model_pipelines/__pycache__/dense_pipeline.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee8194cf14465d26eed75d612ba77ff7c19699f9 Binary files /dev/null and b/mono/model/model_pipelines/__pycache__/dense_pipeline.cpython-39.pyc differ diff --git a/mono/model/model_pipelines/dense_pipeline.py b/mono/model/model_pipelines/dense_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..1362a11b6b9d45e50795dd705906aa3f79ec4a9a --- /dev/null +++ b/mono/model/model_pipelines/dense_pipeline.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn +from mono.utils.comm import get_func + +class DensePredModel(nn.Module): + def __init__(self, cfg) -> None: + super(DensePredModel, self).__init__() + + self.encoder = get_func('mono.model.' + cfg.model.backbone.prefix + cfg.model.backbone.type)(**cfg.model.backbone) + self.decoder = get_func('mono.model.' + cfg.model.decode_head.prefix + cfg.model.decode_head.type)(cfg) + + def forward(self, input, **kwargs): + # [f_32, f_16, f_8, f_4] + features = self.encoder(input) + out = self.decoder(features, **kwargs) + return out \ No newline at end of file diff --git a/mono/model/monodepth_model.py b/mono/model/monodepth_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0b58b7643ee43f84fd4e621e5b3b61b1f3f85564 --- /dev/null +++ b/mono/model/monodepth_model.py @@ -0,0 +1,37 @@ +import torch +import torch.nn as nn +from .model_pipelines.__base_model__ import BaseDepthModel + +class DepthModel(BaseDepthModel): + def __init__(self, cfg, **kwards): + super(DepthModel, self).__init__(cfg) + model_type = cfg.model.type + + def inference(self, data): + with torch.no_grad(): + pred_depth, confidence, output_dict = self.forward(data) + return pred_depth, confidence, output_dict + +def get_monodepth_model( + cfg : dict, + **kwargs + ) -> nn.Module: + # config depth model + model = DepthModel(cfg, **kwargs) + #model.init_weights(load_imagenet_model, imagenet_ckpt_fpath) + assert isinstance(model, nn.Module) + return model + +def get_configured_monodepth_model( + cfg: dict, + ) -> nn.Module: + """ + Args: + @ configs: configures for the network. + @ load_imagenet_model: whether to initialize from ImageNet-pretrained model. + @ imagenet_ckpt_fpath: string representing path to file with weights to initialize model with. + Returns: + # model: depth model. + """ + model = get_monodepth_model(cfg) + return model diff --git a/mono/tools/test_scale_cano.py b/mono/tools/test_scale_cano.py new file mode 100644 index 0000000000000000000000000000000000000000..684fb841a004833e27edd52192ad0821bf2d43af --- /dev/null +++ b/mono/tools/test_scale_cano.py @@ -0,0 +1,158 @@ +import os +import os.path as osp +import cv2 +import time +import sys +CODE_SPACE=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(CODE_SPACE) +import argparse +import mmcv +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +try: + from mmcv.utils import Config, DictAction +except: + from mmengine import Config, DictAction +from datetime import timedelta +import random +import numpy as np +from mono.utils.logger import setup_logger +import glob +from mono.utils.comm import init_env +from mono.model.monodepth_model import get_configured_monodepth_model +from mono.utils.running import load_ckpt +from mono.utils.do_test import do_scalecano_test_with_custom_data +from mono.utils.mldb import load_data_info, reset_ckpt_path +from mono.utils.custom_data import load_from_annos, load_data + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a segmentor') + parser.add_argument('config', help='train config file path') + parser.add_argument('--show-dir', help='the dir to save logs and visualization results') + parser.add_argument('--load-from', help='the checkpoint file to load weights from') + parser.add_argument('--node_rank', type=int, default=0) + parser.add_argument('--nnodes', type=int, default=1, help='number of nodes') + parser.add_argument('--options', nargs='+', action=DictAction, help='custom options') + parser.add_argument('--launcher', choices=['None', 'pytorch', 'slurm', 'mpi', 'ror'], default='slurm', help='job launcher') + parser.add_argument('--test_data_path', default='None', type=str, help='the path of test data') + args = parser.parse_args() + return args + +def main(args): + os.chdir(CODE_SPACE) + cfg = Config.fromfile(args.config) + + if args.options is not None: + cfg.merge_from_dict(args.options) + + # show_dir is determined in this priority: CLI > segment in file > filename + if args.show_dir is not None: + # update configs according to CLI args if args.show_dir is not None + cfg.show_dir = args.show_dir + else: + # use condig filename + timestamp as default show_dir if args.show_dir is None + cfg.show_dir = osp.join('./show_dirs', + osp.splitext(osp.basename(args.config))[0], + args.timestamp) + + # ckpt path + if args.load_from is None: + raise RuntimeError('Please set model path!') + cfg.load_from = args.load_from + + # load data info + data_info = {} + load_data_info('data_info', data_info=data_info) + cfg.mldb_info = data_info + # update check point info + reset_ckpt_path(cfg.model, data_info) + + # create show dir + os.makedirs(osp.abspath(cfg.show_dir), exist_ok=True) + + # init the logger before other steps + cfg.log_file = osp.join(cfg.show_dir, f'{args.timestamp}.log') + logger = setup_logger(cfg.log_file) + + # log some basic info + logger.info(f'Config:\n{cfg.pretty_text}') + + # init distributed env dirst, since logger depends on the dist info + if args.launcher == 'None': + cfg.distributed = False + else: + cfg.distributed = True + init_env(args.launcher, cfg) + logger.info(f'Distributed training: {cfg.distributed}') + + # dump config + cfg.dump(osp.join(cfg.show_dir, osp.basename(args.config))) + test_data_path = args.test_data_path + if not os.path.isabs(test_data_path): + test_data_path = osp.join(CODE_SPACE, test_data_path) + + if 'json' in test_data_path: + test_data = load_from_annos(test_data_path) + else: + test_data = load_data(args.test_data_path) + + if not cfg.distributed: + main_worker(0, cfg, args.launcher, test_data) + else: + # distributed training + if args.launcher == 'ror': + local_rank = cfg.dist_params.local_rank + main_worker(local_rank, cfg, args.launcher, test_data) + else: + mp.spawn(main_worker, nprocs=cfg.dist_params.num_gpus_per_node, args=(cfg, args.launcher, test_data)) + +def main_worker(local_rank: int, cfg: dict, launcher: str, test_data: list): + if cfg.distributed: + cfg.dist_params.global_rank = cfg.dist_params.node_rank * cfg.dist_params.num_gpus_per_node + local_rank + cfg.dist_params.local_rank = local_rank + + if launcher == 'ror': + init_torch_process_group(use_hvd=False) + else: + torch.cuda.set_device(local_rank) + default_timeout = timedelta(minutes=30) + dist.init_process_group( + backend=cfg.dist_params.backend, + init_method=cfg.dist_params.dist_url, + world_size=cfg.dist_params.world_size, + rank=cfg.dist_params.global_rank, + timeout=default_timeout) + + logger = setup_logger(cfg.log_file) + # build model + model = get_configured_monodepth_model(cfg, ) + + # config distributed training + if cfg.distributed: + model = torch.nn.parallel.DistributedDataParallel(model.cuda(), + device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=True) + else: + model = torch.nn.DataParallel(model).cuda() + + # load ckpt + model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False) + model.eval() + + do_scalecano_test_with_custom_data( + model, + cfg, + test_data, + logger, + cfg.distributed, + local_rank + ) + +if __name__ == '__main__': + args = parse_args() + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + args.timestamp = timestamp + main(args) \ No newline at end of file diff --git a/mono/utils/__init__.py b/mono/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/mono/utils/__init__.py @@ -0,0 +1 @@ + diff --git a/mono/utils/__pycache__/__init__.cpython-39.pyc b/mono/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1124c910af2391269228568f250c6894520aab54 Binary files /dev/null and b/mono/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/mono/utils/__pycache__/avg_meter.cpython-39.pyc b/mono/utils/__pycache__/avg_meter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b12793a25471d7debe0eb2159c47beb8d732a51d Binary files /dev/null and b/mono/utils/__pycache__/avg_meter.cpython-39.pyc differ diff --git a/mono/utils/__pycache__/comm.cpython-39.pyc b/mono/utils/__pycache__/comm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bb531447618f36553e4a63ef17b8d874a97759c Binary files /dev/null and b/mono/utils/__pycache__/comm.cpython-39.pyc differ diff --git a/mono/utils/__pycache__/custom_data.cpython-39.pyc b/mono/utils/__pycache__/custom_data.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbcc9fc30c9e18379a164f674430f83afa87eb78 Binary files /dev/null and b/mono/utils/__pycache__/custom_data.cpython-39.pyc differ diff --git a/mono/utils/__pycache__/do_test.cpython-39.pyc b/mono/utils/__pycache__/do_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a32049271c1c59666f2f40d6e28775e96be8d64 Binary files /dev/null and b/mono/utils/__pycache__/do_test.cpython-39.pyc differ diff --git a/mono/utils/__pycache__/logger.cpython-39.pyc b/mono/utils/__pycache__/logger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e74e49882c0ff109592180b63ed208e949db92fa Binary files /dev/null and b/mono/utils/__pycache__/logger.cpython-39.pyc differ diff --git a/mono/utils/__pycache__/mldb.cpython-39.pyc b/mono/utils/__pycache__/mldb.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a19b422fb3981948973a2fd58165f160bf9e2824 Binary files /dev/null and b/mono/utils/__pycache__/mldb.cpython-39.pyc differ diff --git a/mono/utils/__pycache__/running.cpython-39.pyc b/mono/utils/__pycache__/running.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63270065ee2d57c2b20eeb400302824402cb0738 Binary files /dev/null and b/mono/utils/__pycache__/running.cpython-39.pyc differ diff --git a/mono/utils/__pycache__/transform.cpython-39.pyc b/mono/utils/__pycache__/transform.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7886dadc9311f0ae919a928fc157f377537b4ec Binary files /dev/null and b/mono/utils/__pycache__/transform.cpython-39.pyc differ diff --git a/mono/utils/__pycache__/unproj_pcd.cpython-39.pyc b/mono/utils/__pycache__/unproj_pcd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54df84153aadb93b17d666f2a058a7c53543ba2d Binary files /dev/null and b/mono/utils/__pycache__/unproj_pcd.cpython-39.pyc differ diff --git a/mono/utils/__pycache__/visualization.cpython-39.pyc b/mono/utils/__pycache__/visualization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..908f7aa5fecdfbc13df19333c0a8ac19a49d75a5 Binary files /dev/null and b/mono/utils/__pycache__/visualization.cpython-39.pyc differ diff --git a/mono/utils/avg_meter.py b/mono/utils/avg_meter.py new file mode 100644 index 0000000000000000000000000000000000000000..3f935df9760cee1d73c6cba00b954d03e659ccb3 --- /dev/null +++ b/mono/utils/avg_meter.py @@ -0,0 +1,475 @@ +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +import matplotlib.pyplot as plt + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self) -> None: + self.reset() + + def reset(self) -> None: + self.val = np.longdouble(0.0) + self.avg = np.longdouble(0.0) + self.sum = np.longdouble(0.0) + self.count = np.longdouble(0.0) + + def update(self, val, n: float = 1) -> None: + self.val = val + self.sum += val + self.count += n + self.avg = self.sum / (self.count + 1e-6) + +class MetricAverageMeter(AverageMeter): + """ + An AverageMeter designed specifically for evaluating segmentation results. + """ + def __init__(self, metrics: list) -> None: + """ Initialize object. """ + # average meters for metrics + self.abs_rel = AverageMeter() + self.rmse = AverageMeter() + self.silog = AverageMeter() + self.delta1 = AverageMeter() + self.delta2 = AverageMeter() + self.delta3 = AverageMeter() + + self.metrics = metrics + + self.consistency = AverageMeter() + self.log10 = AverageMeter() + self.rmse_log = AverageMeter() + self.sq_rel = AverageMeter() + + # normal + self.normal_mean = AverageMeter() + self.normal_rmse = AverageMeter() + self.normal_a1 = AverageMeter() + self.normal_a2 = AverageMeter() + + self.normal_median = AverageMeter() + self.normal_a3 = AverageMeter() + self.normal_a4 = AverageMeter() + self.normal_a5 = AverageMeter() + + + def update_metrics_cpu(self, + pred: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor,): + """ + Update metrics on cpu + """ + + assert pred.shape == target.shape + + if len(pred.shape) == 3: + pred = pred[:, None, :, :] + target = target[:, None, :, :] + mask = mask[:, None, :, :] + elif len(pred.shape) == 2: + pred = pred[None, None, :, :] + target = target[None, None, :, :] + mask = mask[None, None, :, :] + + + # Absolute relative error + abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask) + abs_rel_sum = abs_rel_sum.numpy() + valid_pics = valid_pics.numpy() + self.abs_rel.update(abs_rel_sum, valid_pics) + + # squared relative error + sqrel_sum, _ = get_sqrel_err(pred, target, mask) + sqrel_sum = sqrel_sum.numpy() + self.sq_rel.update(sqrel_sum, valid_pics) + + # root mean squared error + rmse_sum, _ = get_rmse_err(pred, target, mask) + rmse_sum = rmse_sum.numpy() + self.rmse.update(rmse_sum, valid_pics) + + # log root mean squared error + log_rmse_sum, _ = get_rmse_log_err(pred, target, mask) + log_rmse_sum = log_rmse_sum.numpy() + self.rmse.update(log_rmse_sum, valid_pics) + + # log10 error + log10_sum, _ = get_log10_err(pred, target, mask) + log10_sum = log10_sum.numpy() + self.rmse.update(log10_sum, valid_pics) + + # scale-invariant root mean squared error in log space + silog_sum, _ = get_silog_err(pred, target, mask) + silog_sum = silog_sum.numpy() + self.silog.update(silog_sum, valid_pics) + + # ratio error, delta1, .... + delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_error(pred, target, mask) + delta1_sum = delta1_sum.numpy() + delta2_sum = delta2_sum.numpy() + delta3_sum = delta3_sum.numpy() + + self.delta1.update(delta1_sum, valid_pics) + self.delta2.update(delta1_sum, valid_pics) + self.delta3.update(delta1_sum, valid_pics) + + + def update_metrics_gpu( + self, + pred: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + is_distributed: bool, + pred_next: torch.tensor = None, + pose_f1_to_f2: torch.tensor = None, + intrinsic: torch.tensor = None): + """ + Update metric on GPU. It supports distributed processing. If multiple machines are employed, please + set 'is_distributed' as True. + """ + assert pred.shape == target.shape + + if len(pred.shape) == 3: + pred = pred[:, None, :, :] + target = target[:, None, :, :] + mask = mask[:, None, :, :] + elif len(pred.shape) == 2: + pred = pred[None, None, :, :] + target = target[None, None, :, :] + mask = mask[None, None, :, :] + + + # Absolute relative error + abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask) + if is_distributed: + dist.all_reduce(abs_rel_sum), dist.all_reduce(valid_pics) + abs_rel_sum = abs_rel_sum.cpu().numpy() + valid_pics = int(valid_pics) + self.abs_rel.update(abs_rel_sum, valid_pics) + + # root mean squared error + rmse_sum, _ = get_rmse_err(pred, target, mask) + if is_distributed: + dist.all_reduce(rmse_sum) + rmse_sum = rmse_sum.cpu().numpy() + self.rmse.update(rmse_sum, valid_pics) + + # log root mean squared error + log_rmse_sum, _ = get_rmse_log_err(pred, target, mask) + if is_distributed: + dist.all_reduce(log_rmse_sum) + log_rmse_sum = log_rmse_sum.cpu().numpy() + self.rmse_log.update(log_rmse_sum, valid_pics) + + # log10 error + log10_sum, _ = get_log10_err(pred, target, mask) + if is_distributed: + dist.all_reduce(log10_sum) + log10_sum = log10_sum.cpu().numpy() + self.log10.update(log10_sum, valid_pics) + + # scale-invariant root mean squared error in log space + silog_sum, _ = get_silog_err(pred, target, mask) + if is_distributed: + dist.all_reduce(silog_sum) + silog_sum = silog_sum.cpu().numpy() + self.silog.update(silog_sum, valid_pics) + + # ratio error, delta1, .... + delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_error(pred, target, mask) + if is_distributed: + dist.all_reduce(delta1_sum), dist.all_reduce(delta2_sum), dist.all_reduce(delta3_sum) + delta1_sum = delta1_sum.cpu().numpy() + delta2_sum = delta2_sum.cpu().numpy() + delta3_sum = delta3_sum.cpu().numpy() + + self.delta1.update(delta1_sum, valid_pics) + self.delta2.update(delta2_sum, valid_pics) + self.delta3.update(delta3_sum, valid_pics) + + # video consistency error + consistency_rel_sum, valid_warps = get_video_consistency_err(pred, pred_next, pose_f1_to_f2, intrinsic) + if is_distributed: + dist.all_reduce(consistency_rel_sum), dist.all_reduce(valid_warps) + consistency_rel_sum = consistency_rel_sum.cpu().numpy() + valid_warps = int(valid_warps) + self.consistency.update(consistency_rel_sum, valid_warps) + + ## for surface normal + def update_normal_metrics_gpu( + self, + pred: torch.Tensor, # (B, 3, H, W) + target: torch.Tensor, # (B, 3, H, W) + mask: torch.Tensor, # (B, 1, H, W) + is_distributed: bool, + ): + """ + Update metric on GPU. It supports distributed processing. If multiple machines are employed, please + set 'is_distributed' as True. + """ + assert pred.shape == target.shape + + valid_pics = torch.sum(mask, dtype=torch.float32) + 1e-6 + + if valid_pics < 10: + return + + mean_error = rmse_error = a1_error = a2_error = dist_node_cnt = valid_pics + normal_error = torch.cosine_similarity(pred, target, dim=1) + normal_error = torch.clamp(normal_error, min=-1.0, max=1.0) + angle_error = torch.acos(normal_error) * 180.0 / torch.pi + angle_error = angle_error[:, None, :, :] + angle_error = angle_error[mask] + # Calculation error + mean_error = angle_error.sum() / valid_pics + rmse_error = torch.sqrt( torch.sum(torch.square(angle_error)) / valid_pics ) + median_error = angle_error.median() + a1_error = 100.0 * (torch.sum(angle_error < 5) / valid_pics) + a2_error = 100.0 * (torch.sum(angle_error < 7.5) / valid_pics) + + a3_error = 100.0 * (torch.sum(angle_error < 11.25) / valid_pics) + a4_error = 100.0 * (torch.sum(angle_error < 22.5) / valid_pics) + a5_error = 100.0 * (torch.sum(angle_error < 30) / valid_pics) + + # if valid_pics > 1e-5: + # If the current node gets data with valid normal + dist_node_cnt = (valid_pics - 1e-6) / valid_pics + + if is_distributed: + dist.all_reduce(dist_node_cnt) + dist.all_reduce(mean_error) + dist.all_reduce(rmse_error) + dist.all_reduce(a1_error) + dist.all_reduce(a2_error) + + dist.all_reduce(a3_error) + dist.all_reduce(a4_error) + dist.all_reduce(a5_error) + + dist_node_cnt = dist_node_cnt.cpu().numpy() + self.normal_mean.update(mean_error.cpu().numpy(), dist_node_cnt) + self.normal_rmse.update(rmse_error.cpu().numpy(), dist_node_cnt) + self.normal_a1.update(a1_error.cpu().numpy(), dist_node_cnt) + self.normal_a2.update(a2_error.cpu().numpy(), dist_node_cnt) + + self.normal_median.update(median_error.cpu().numpy(), dist_node_cnt) + self.normal_a3.update(a3_error.cpu().numpy(), dist_node_cnt) + self.normal_a4.update(a4_error.cpu().numpy(), dist_node_cnt) + self.normal_a5.update(a5_error.cpu().numpy(), dist_node_cnt) + + + def get_metrics(self,): + """ + """ + metrics_dict = {} + for metric in self.metrics: + metrics_dict[metric] = self.__getattribute__(metric).avg + return metrics_dict + + + def get_metrics(self,): + """ + """ + metrics_dict = {} + for metric in self.metrics: + metrics_dict[metric] = self.__getattribute__(metric).avg + return metrics_dict + +def get_absrel_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes absolute relative error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + # Mean Absolute Relative Error + rel = torch.abs(t_m - p_m) / (t_m + 1e-10) # compute errors + abs_rel_sum = torch.sum(rel.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + abs_err = abs_rel_sum / (num + 1e-10) + valid_pics = torch.sum(num > 0) + return torch.sum(abs_err), valid_pics + +def get_sqrel_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes squared relative error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + # squared Relative Error + sq_rel = torch.abs(t_m - p_m) ** 2 / (t_m + 1e-10) # compute errors + sq_rel_sum = torch.sum(sq_rel.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + sqrel_err = sq_rel_sum / (num + 1e-10) + valid_pics = torch.sum(num > 0) + return torch.sum(sqrel_err), valid_pics + +def get_log10_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes log10 error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask + log10_diff = torch.abs(diff_log) + log10_sum = torch.sum(log10_diff.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + log10_err = log10_sum / (num + 1e-10) + valid_pics = torch.sum(num > 0) + return torch.sum(log10_err), valid_pics + +def get_rmse_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes rmse error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + square = (t_m - p_m) ** 2 + rmse_sum = torch.sum(square.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + rmse = torch.sqrt(rmse_sum / (num + 1e-10)) + valid_pics = torch.sum(num > 0) + return torch.sum(rmse), valid_pics + +def get_rmse_log_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes log rmse error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask + square = diff_log ** 2 + rmse_log_sum = torch.sum(square.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + rmse_log = torch.sqrt(rmse_log_sum / (num + 1e-10)) + valid_pics = torch.sum(num > 0) + return torch.sum(rmse_log), valid_pics + +def get_silog_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes log rmse error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask + diff_log_sum = torch.sum(diff_log.reshape((b, c, -1)), dim=2) # [b, c] + diff_log_square = diff_log ** 2 + diff_log_square_sum = torch.sum(diff_log_square.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + silog = torch.sqrt(diff_log_square_sum / (num + 1e-10) - (diff_log_sum / (num + 1e-10)) ** 2) + valid_pics = torch.sum(num > 0) + return torch.sum(silog), valid_pics + +def get_ratio_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes the percentage of pixels for which the ratio of the two depth maps is less than a given threshold. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred + + gt_pred = t_m / (p_m + 1e-10) + pred_gt = p_m / (t_m + 1e-10) + gt_pred = gt_pred.reshape((b, c, -1)) + pred_gt = pred_gt.reshape((b, c, -1)) + gt_pred_gt = torch.cat((gt_pred, pred_gt), axis=1) + ratio_max = torch.amax(gt_pred_gt, axis=1) + + delta_1_sum = torch.sum((ratio_max < 1.25), dim=1) # [b, ] + delta_2_sum = torch.sum((ratio_max < 1.25 ** 2), dim=1) # [b, ] + delta_3_sum = torch.sum((ratio_max < 1.25 ** 3), dim=1) # [b, ] + num = torch.sum(mask.reshape((b, -1)), dim=1) # [b, ] + + delta_1 = delta_1_sum / (num + 1e-10) + delta_2 = delta_2_sum / (num + 1e-10) + delta_3 = delta_3_sum / (num + 1e-10) + valid_pics = torch.sum(num > 0) + + return torch.sum(delta_1), torch.sum(delta_2), torch.sum(delta_3), valid_pics + + +if __name__ == '__main__': + cfg = ['abs_rel', 'delta1'] + dam = MetricAverageMeter(cfg) + + pred_depth = np.random.random([2, 480, 640]) + gt_depth = np.random.random([2, 480, 640]) - 0.5 + intrinsic = [[100, 100, 200, 200], [200, 200, 300, 300]] + + pred = torch.from_numpy(pred_depth).cuda() + gt = torch.from_numpy(gt_depth).cuda() + + mask = gt > 0 + dam.update_metrics_gpu(pred, gt, mask, False) + eval_error = dam.get_metrics() + print(eval_error) + \ No newline at end of file diff --git a/mono/utils/comm.py b/mono/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..939e4e175c14563d5d13e77e6b56fd1a34668ebf --- /dev/null +++ b/mono/utils/comm.py @@ -0,0 +1,322 @@ +import importlib +import torch +import torch.distributed as dist +from .avg_meter import AverageMeter +from collections import defaultdict, OrderedDict +import os +import socket +from mmcv.utils import collect_env as collect_base_env +try: + from mmcv.utils import get_git_hash +except: + from mmengine.utils import get_git_hash +#import mono.mmseg as mmseg +# import mmseg +import time +import datetime +import logging + + +def main_process() -> bool: + return get_rank() == 0 + #return not cfg.distributed or \ + # (cfg.distributed and cfg.local_rank == 0) + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + +def _find_free_port(): + # refer to https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(('', 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + +def _is_free_port(port): + ips = socket.gethostbyname_ex(socket.gethostname())[-1] + ips.append('localhost') + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return all(s.connect_ex((ip, port)) != 0 for ip in ips) + + +# def collect_env(): +# """Collect the information of the running environments.""" +# env_info = collect_base_env() +# env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' + +# return env_info + +def init_env(launcher, cfg): + """Initialize distributed training environment. + If argument ``cfg.dist_params.dist_url`` is specified as 'env://', then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + """ + if launcher == 'slurm': + _init_dist_slurm(cfg) + elif launcher == 'ror': + _init_dist_ror(cfg) + elif launcher == 'None': + _init_none_dist(cfg) + else: + raise RuntimeError(f'{cfg.launcher} has not been supported!') + +def _init_none_dist(cfg): + cfg.dist_params.num_gpus_per_node = 1 + cfg.dist_params.world_size = 1 + cfg.dist_params.nnodes = 1 + cfg.dist_params.node_rank = 0 + cfg.dist_params.global_rank = 0 + cfg.dist_params.local_rank = 0 + os.environ["WORLD_SIZE"] = str(1) + +def _init_dist_ror(cfg): + from ac2.ror.comm import get_local_rank, get_world_rank, get_local_size, get_node_rank, get_world_size + cfg.dist_params.num_gpus_per_node = get_local_size() + cfg.dist_params.world_size = get_world_size() + cfg.dist_params.nnodes = (get_world_size()) // (get_local_size()) + cfg.dist_params.node_rank = get_node_rank() + cfg.dist_params.global_rank = get_world_rank() + cfg.dist_params.local_rank = get_local_rank() + os.environ["WORLD_SIZE"] = str(get_world_size()) + + +def _init_dist_slurm(cfg): + if 'NNODES' not in os.environ: + os.environ['NNODES'] = str(cfg.dist_params.nnodes) + if 'NODE_RANK' not in os.environ: + os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank) + + #cfg.dist_params. + num_gpus = torch.cuda.device_count() + world_size = int(os.environ['NNODES']) * num_gpus + os.environ['WORLD_SIZE'] = str(world_size) + + # config port + if 'MASTER_PORT' in os.environ: + master_port = str(os.environ['MASTER_PORT']) # use MASTER_PORT in the environment variable + else: + # if torch.distributed default port(29500) is available + # then use it, else find a free port + if _is_free_port(16500): + master_port = '16500' + else: + master_port = str(_find_free_port()) + os.environ['MASTER_PORT'] = master_port + + # config addr + if 'MASTER_ADDR' in os.environ: + master_addr = str(os.environ['MASTER_PORT']) # use MASTER_PORT in the environment variable + # elif cfg.dist_params.dist_url is not None: + # master_addr = ':'.join(str(cfg.dist_params.dist_url).split(':')[:2]) + else: + master_addr = '127.0.0.1' #'tcp://127.0.0.1' + os.environ['MASTER_ADDR'] = master_addr + + # set dist_url to 'env://' + cfg.dist_params.dist_url = 'env://' #f"{master_addr}:{master_port}" + + cfg.dist_params.num_gpus_per_node = num_gpus + cfg.dist_params.world_size = world_size + cfg.dist_params.nnodes = int(os.environ['NNODES']) + cfg.dist_params.node_rank = int(os.environ['NODE_RANK']) + + # if int(os.environ['NNODES']) > 1 and cfg.dist_params.dist_url.startswith("file://"): + # raise Warning("file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://") + + +def get_func(func_name): + """ + Helper to return a function object by name. func_name must identify + a function in this module or the path to a function relative to the base + module. + @ func_name: function name. + """ + if func_name == '': + return None + try: + parts = func_name.split('.') + # Refers to a function in this module + if len(parts) == 1: + return globals()[parts[0]] + # Otherwise, assume we're referencing a module under modeling + module_name = '.'.join(parts[:-1]) + module = importlib.import_module(module_name) + return getattr(module, parts[-1]) + except: + raise RuntimeError(f'Failed to find function: {func_name}') + +class Timer(object): + """A simple timer.""" + + def __init__(self): + self.reset() + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + self.total_time += self.diff + self.calls += 1 + self.average_time = self.total_time / self.calls + if average: + return self.average_time + else: + return self.diff + + def reset(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. + +class TrainingStats(object): + """Track vital training statistics.""" + def __init__(self, log_period, tensorboard_logger=None): + self.log_period = log_period + self.tblogger = tensorboard_logger + self.tb_ignored_keys = ['iter', 'eta', 'epoch', 'time'] + self.iter_timer = Timer() + # Window size for smoothing tracked values (with median filtering) + self.filter_size = log_period + def create_smoothed_value(): + return AverageMeter() + self.smoothed_losses = defaultdict(create_smoothed_value) + #self.smoothed_metrics = defaultdict(create_smoothed_value) + #self.smoothed_total_loss = AverageMeter() + + + def IterTic(self): + self.iter_timer.tic() + + def IterToc(self): + return self.iter_timer.toc(average=False) + + def reset_iter_time(self): + self.iter_timer.reset() + + def update_iter_stats(self, losses_dict): + """Update tracked iteration statistics.""" + for k, v in losses_dict.items(): + self.smoothed_losses[k].update(float(v), 1) + + def log_iter_stats(self, cur_iter, optimizer, max_iters, val_err={}): + """Log the tracked statistics.""" + if (cur_iter % self.log_period == 0): + stats = self.get_stats(cur_iter, optimizer, max_iters, val_err) + log_stats(stats) + if self.tblogger: + self.tb_log_stats(stats, cur_iter) + for k, v in self.smoothed_losses.items(): + v.reset() + + def tb_log_stats(self, stats, cur_iter): + """Log the tracked statistics to tensorboard""" + for k in stats: + # ignore some logs + if k not in self.tb_ignored_keys: + v = stats[k] + if isinstance(v, dict): + self.tb_log_stats(v, cur_iter) + else: + self.tblogger.add_scalar(k, v, cur_iter) + + + def get_stats(self, cur_iter, optimizer, max_iters, val_err = {}): + eta_seconds = self.iter_timer.average_time * (max_iters - cur_iter) + + eta = str(datetime.timedelta(seconds=int(eta_seconds))) + stats = OrderedDict( + iter=cur_iter, # 1-indexed + time=self.iter_timer.average_time, + eta=eta, + ) + optimizer_state_dict = optimizer.state_dict() + lr = {} + for i in range(len(optimizer_state_dict['param_groups'])): + lr_name = 'group%d_lr' % i + lr[lr_name] = optimizer_state_dict['param_groups'][i]['lr'] + + stats['lr'] = OrderedDict(lr) + for k, v in self.smoothed_losses.items(): + stats[k] = v.avg + + stats['val_err'] = OrderedDict(val_err) + stats['max_iters'] = max_iters + return stats + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + Args: + @input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + @average (bool): whether to do average or sum + Returns: + a dict with the same keys as input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def log_stats(stats): + logger = logging.getLogger() + """Log training statistics to terminal""" + lines = "[Step %d/%d]\n" % ( + stats['iter'], stats['max_iters']) + + lines += "\t\tloss: %.3f, time: %.6f, eta: %s\n" % ( + stats['total_loss'], stats['time'], stats['eta']) + + # log loss + lines += "\t\t" + for k, v in stats.items(): + if 'loss' in k.lower() and 'total_loss' not in k.lower(): + lines += "%s: %.3f" % (k, v) + ", " + lines = lines[:-3] + lines += '\n' + + # validate criteria + lines += "\t\tlast val err:" + ", ".join("%s: %.6f" % (k, v) for k, v in stats['val_err'].items()) + ", " + lines += '\n' + + # lr in different groups + lines += "\t\t" + ", ".join("%s: %.8f" % (k, v) for k, v in stats['lr'].items()) + lines += '\n' + logger.info(lines[:-1]) # remove last new linen_pxl + diff --git a/mono/utils/custom_data.py b/mono/utils/custom_data.py new file mode 100644 index 0000000000000000000000000000000000000000..d9fab47478bc471c51b5454cc15550079ebec21b --- /dev/null +++ b/mono/utils/custom_data.py @@ -0,0 +1,34 @@ +import glob +import os +import json +import cv2 + +def load_from_annos(anno_path): + with open(anno_path, 'r') as f: + annos = json.load(f)['files'] + + datas = [] + for i, anno in enumerate(annos): + rgb = anno['rgb'] + depth = anno['depth'] if 'depth' in anno else None + depth_scale = anno['depth_scale'] if 'depth_scale' in anno else 1.0 + intrinsic = anno['cam_in'] if 'cam_in' in anno else None + normal = anno['normal'] if 'normal' in anno else None + + data_i = { + 'rgb': rgb, + 'depth': depth, + 'depth_scale': depth_scale, + 'intrinsic': intrinsic, + 'filename': os.path.basename(rgb), + 'folder': rgb.split('/')[-3], + 'normal': normal + } + datas.append(data_i) + return datas + +def load_data(path: str): + rgbs = glob.glob(path + '/*.jpg') + glob.glob(path + '/*.png') + #intrinsic = [835.8179931640625, 835.8179931640625, 961.5419921875, 566.8090209960938] #[721.53769, 721.53769, 609.5593, 172.854] + data = [{'rgb': i, 'depth': None, 'intrinsic': None, 'filename': os.path.basename(i), 'folder': i.split('/')[-3]} for i in rgbs] + return data \ No newline at end of file diff --git a/mono/utils/do_test.py b/mono/utils/do_test.py new file mode 100644 index 0000000000000000000000000000000000000000..89ee4afc9d6cd67ec491af6726c850347cafc099 --- /dev/null +++ b/mono/utils/do_test.py @@ -0,0 +1,364 @@ +import torch +import torch.nn.functional as F +import logging +import os +import os.path as osp +from mono.utils.avg_meter import MetricAverageMeter +from mono.utils.visualization import save_val_imgs, create_html, save_raw_imgs, save_normal_val_imgs +import cv2 +from tqdm import tqdm +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt + +from mono.utils.unproj_pcd import reconstruct_pcd, save_point_cloud + +def to_cuda(data: dict): + for k, v in data.items(): + if isinstance(v, torch.Tensor): + data[k] = v.cuda(non_blocking=True) + if isinstance(v, list) and len(v)>=1 and isinstance(v[0], torch.Tensor): + for i, l_i in enumerate(v): + data[k][i] = l_i.cuda(non_blocking=True) + return data + +def align_scale(pred: torch.tensor, target: torch.tensor): + mask = target > 0 + if torch.sum(mask) > 10: + scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8) + else: + scale = 1 + pred_scaled = pred * scale + return pred_scaled, scale + +def align_scale_shift(pred: torch.tensor, target: torch.tensor): + mask = target > 0 + target_mask = target[mask].cpu().numpy() + pred_mask = pred[mask].cpu().numpy() + if torch.sum(mask) > 10: + scale, shift = np.polyfit(pred_mask, target_mask, deg=1) + if scale < 0: + scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8) + shift = 0 + else: + scale = 1 + shift = 0 + pred = pred * scale + shift + return pred, scale + +def align_scale_shift_numpy(pred: np.array, target: np.array): + mask = target > 0 + target_mask = target[mask] + pred_mask = pred[mask] + if np.sum(mask) > 10: + scale, shift = np.polyfit(pred_mask, target_mask, deg=1) + if scale < 0: + scale = np.median(target[mask]) / (np.median(pred[mask]) + 1e-8) + shift = 0 + else: + scale = 1 + shift = 0 + pred = pred * scale + shift + return pred, scale + + +def build_camera_model(H : int, W : int, intrinsics : list) -> np.array: + """ + Encode the camera intrinsic parameters (focal length and principle point) to a 4-channel map. + """ + fx, fy, u0, v0 = intrinsics + f = (fx + fy) / 2.0 + # principle point location + x_row = np.arange(0, W).astype(np.float32) + x_row_center_norm = (x_row - u0) / W + x_center = np.tile(x_row_center_norm, (H, 1)) # [H, W] + + y_col = np.arange(0, H).astype(np.float32) + y_col_center_norm = (y_col - v0) / H + y_center = np.tile(y_col_center_norm, (W, 1)).T # [H, W] + + # FoV + fov_x = np.arctan(x_center / (f / W)) + fov_y = np.arctan(y_center / (f / H)) + + cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2) + return cam_model + +def resize_for_input(image, output_shape, intrinsic, canonical_shape, to_canonical_ratio): + """ + Resize the input. + Resizing consists of two processed, i.e. 1) to the canonical space (adjust the camera model); 2) resize the image while the camera model holds. Thus the + label will be scaled with the resize factor. + """ + padding = [123.675, 116.28, 103.53] + h, w, _ = image.shape + resize_ratio_h = output_shape[0] / canonical_shape[0] + resize_ratio_w = output_shape[1] / canonical_shape[1] + to_scale_ratio = min(resize_ratio_h, resize_ratio_w) + + resize_ratio = to_canonical_ratio * to_scale_ratio + + reshape_h = int(resize_ratio * h) + reshape_w = int(resize_ratio * w) + + pad_h = max(output_shape[0] - reshape_h, 0) + pad_w = max(output_shape[1] - reshape_w, 0) + pad_h_half = int(pad_h / 2) + pad_w_half = int(pad_w / 2) + + # resize + image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR) + # padding + image = cv2.copyMakeBorder( + image, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=padding) + + # Resize, adjust principle point + intrinsic[2] = intrinsic[2] * to_scale_ratio + intrinsic[3] = intrinsic[3] * to_scale_ratio + + cam_model = build_camera_model(reshape_h, reshape_w, intrinsic) + cam_model = cv2.copyMakeBorder( + cam_model, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=-1) + + pad=[pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] + label_scale_factor=1/to_scale_ratio + return image, cam_model, pad, label_scale_factor + + +def get_prediction( + model: torch.nn.Module, + input: torch.tensor, + cam_model: torch.tensor, + pad_info: torch.tensor, + scale_info: torch.tensor, + gt_depth: torch.tensor, + normalize_scale: float, + ori_shape: list=[], +): + + data = dict( + input=input, + cam_model=cam_model, + ) + pred_depth, confidence, output_dict = model.module.inference(data) + pred_depth = pred_depth + pred_depth = pred_depth.squeeze() + pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]] + if gt_depth is not None: + resize_shape = gt_depth.shape + elif ori_shape != []: + resize_shape = ori_shape + else: + resize_shape = pred_depth.shape + + pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], resize_shape, mode='bilinear').squeeze() # to original size + pred_depth = pred_depth * normalize_scale / scale_info + if gt_depth is not None: + pred_depth_scale, scale = align_scale(pred_depth, gt_depth) + else: + pred_depth_scale = None + scale = None + + return pred_depth, pred_depth_scale, scale, output_dict + +def transform_test_data_scalecano(rgb, intrinsic, data_basic): + """ + Pre-process the input for forwarding. Employ `label scale canonical transformation.' + Args: + rgb: input rgb image. [H, W, 3] + intrinsic: camera intrinsic parameter, [fx, fy, u0, v0] + data_basic: predefined canonical space in configs. + """ + canonical_space = data_basic['canonical_space'] + forward_size = data_basic.crop_size + mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None] + std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None] + + # BGR to RGB + rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) + + ori_h, ori_w, _ = rgb.shape + ori_focal = (intrinsic[0] + intrinsic[1]) / 2 + canonical_focal = canonical_space['focal_length'] + + cano_label_scale_ratio = canonical_focal / ori_focal + + canonical_intrinsic = [ + intrinsic[0] * cano_label_scale_ratio, + intrinsic[1] * cano_label_scale_ratio, + intrinsic[2], + intrinsic[3], + ] + + # resize + rgb, cam_model, pad, resize_label_scale_ratio = resize_for_input(rgb, forward_size, canonical_intrinsic, [ori_h, ori_w], 1.0) + + # label scale factor + label_scale_factor = cano_label_scale_ratio * resize_label_scale_ratio + + rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float() + rgb = torch.div((rgb - mean), std) + rgb = rgb[None, :, :, :].cuda() + + cam_model = torch.from_numpy(cam_model.transpose((2, 0, 1))).float() + cam_model = cam_model[None, :, :, :].cuda() + cam_model_stacks = [ + torch.nn.functional.interpolate(cam_model, size=(cam_model.shape[2]//i, cam_model.shape[3]//i), mode='bilinear', align_corners=False) + for i in [2, 4, 8, 16, 32] + ] + return rgb, cam_model_stacks, pad, label_scale_factor + +def do_scalecano_test_with_custom_data( + model: torch.nn.Module, + cfg: dict, + test_data: list, + logger: logging.RootLogger, + is_distributed: bool = True, + local_rank: int = 0, +): + + show_dir = cfg.show_dir + save_interval = 1 + save_imgs_dir = show_dir + '/vis' + os.makedirs(save_imgs_dir, exist_ok=True) + save_pcd_dir = show_dir + '/pcd' + os.makedirs(save_pcd_dir, exist_ok=True) + + normalize_scale = cfg.data_basic.depth_range[1] + dam = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']) + dam_median = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']) + dam_global = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']) + + for i, an in tqdm(enumerate(test_data)): + #for i, an in enumerate(test_data): + print(an['rgb']) + rgb_origin = cv2.imread(an['rgb'])[:, :, ::-1].copy() + if an['depth'] is not None: + gt_depth = cv2.imread(an['depth'], -1) + gt_depth_scale = an['depth_scale'] + gt_depth = gt_depth / gt_depth_scale + gt_depth_flag = True + else: + gt_depth = None + gt_depth_flag = False + intrinsic = an['intrinsic'] + if intrinsic is None: + intrinsic = [1000.0, 1000.0, rgb_origin.shape[1]/2, rgb_origin.shape[0]/2] + # intrinsic = [542.0, 542.0, 963.706, 760.199] + print(intrinsic) + rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(rgb_origin, intrinsic, cfg.data_basic) + + pred_depth, pred_depth_scale, scale, output = get_prediction( + model = model, + input = rgb_input, + cam_model = cam_models_stacks, + pad_info = pad, + scale_info = label_scale_factor, + gt_depth = None, + normalize_scale = normalize_scale, + ori_shape=[rgb_origin.shape[0], rgb_origin.shape[1]], + ) + + pred_depth = (pred_depth > 0) * (pred_depth < 300) * pred_depth + if gt_depth_flag: + + pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], (gt_depth.shape[0], gt_depth.shape[1]), mode='bilinear').squeeze() # to original size + + gt_depth = torch.from_numpy(gt_depth).cuda() + + pred_depth_median = pred_depth * gt_depth[gt_depth != 0].median() / pred_depth[gt_depth != 0].median() + pred_global, _ = align_scale_shift(pred_depth, gt_depth) + + mask = (gt_depth > 1e-8) + dam.update_metrics_gpu(pred_depth, gt_depth, mask, is_distributed) + dam_median.update_metrics_gpu(pred_depth_median, gt_depth, mask, is_distributed) + dam_global.update_metrics_gpu(pred_global, gt_depth, mask, is_distributed) + print(gt_depth[gt_depth != 0].median() / pred_depth[gt_depth != 0].median(), ) + + if i % save_interval == 0: + os.makedirs(osp.join(save_imgs_dir, an['folder']), exist_ok=True) + rgb_torch = torch.from_numpy(rgb_origin).to(pred_depth.device).permute(2, 0, 1) + mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None].to(rgb_torch.device) + std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None].to(rgb_torch.device) + rgb_torch = torch.div((rgb_torch - mean), std) + + save_val_imgs( + i, + pred_depth, + gt_depth if gt_depth is not None else torch.ones_like(pred_depth, device=pred_depth.device), + rgb_torch, + osp.join(an['folder'], an['filename']), + save_imgs_dir, + ) + #save_raw_imgs(pred_depth.detach().cpu().numpy(), rgb_torch, osp.join(an['folder'], an['filename']), save_imgs_dir, 1000.0) + + # pcd + pred_depth = pred_depth.detach().cpu().numpy() + #pcd = reconstruct_pcd(pred_depth, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3]) + #os.makedirs(osp.join(save_pcd_dir, an['folder']), exist_ok=True) + #save_point_cloud(pcd.reshape((-1, 3)), rgb_origin.reshape(-1, 3), osp.join(save_pcd_dir, an['folder'], an['filename'][:-4]+'.ply')) + + if an['intrinsic'] == None: + #for r in [0.9, 1.0, 1.1]: + for r in [1.0]: + #for f in [600, 800, 1000, 1250, 1500]: + for f in [1000]: + pcd = reconstruct_pcd(pred_depth, f * r, f * (2-r), intrinsic[2], intrinsic[3]) + fstr = '_fx_' + str(int(f * r)) + '_fy_' + str(int(f * (2-r))) + os.makedirs(osp.join(save_pcd_dir, an['folder']), exist_ok=True) + save_point_cloud(pcd.reshape((-1, 3)), rgb_origin.reshape(-1, 3), osp.join(save_pcd_dir, an['folder'], an['filename'][:-4] + fstr +'.ply')) + + if "normal_out_list" in output.keys(): + + normal_out_list = output['normal_out_list'] + pred_normal = normal_out_list[0][:, :3, :, :] # (B, 3, H, W) + H, W = pred_normal.shape[2:] + pred_normal = pred_normal[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3]] + + gt_normal = None + #if gt_normal_flag: + if False: + pred_normal = torch.nn.functional.interpolate(pred_normal, size=gt_normal.shape[2:], mode='bilinear', align_corners=True) + gt_normal = cv2.imread(norm_path) + gt_normal = cv2.cvtColor(gt_normal, cv2.COLOR_BGR2RGB) + gt_normal = np.array(gt_normal).astype(np.uint8) + gt_normal = ((gt_normal.astype(np.float32) / 255.0) * 2.0) - 1.0 + norm_valid_mask = (np.linalg.norm(gt_normal, axis=2, keepdims=True) > 0.5) + gt_normal = gt_normal * norm_valid_mask + gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True) + dam.update_normal_metrics_gpu(pred_normal, gt_normal, gt_normal_mask, cfg.distributed)# save valiad normal + + if i % save_interval == 0: + save_normal_val_imgs(iter, + pred_normal, + gt_normal if gt_normal is not None else torch.ones_like(pred_normal, device=pred_normal.device), + rgb_torch, # data['input'], + osp.join(an['folder'], 'normal_'+an['filename']), + save_imgs_dir, + ) + + + #if gt_depth_flag: + if False: + eval_error = dam.get_metrics() + print('w/o match :', eval_error) + + eval_error_median = dam_median.get_metrics() + print('median match :', eval_error_median) + + eval_error_global = dam_global.get_metrics() + print('global match :', eval_error_global) + else: + print('missing gt_depth, only save visualizations...') diff --git a/mono/utils/logger.py b/mono/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ca48c613b2fdc5352b13ccb7d0bfdc1df5e3b531 --- /dev/null +++ b/mono/utils/logger.py @@ -0,0 +1,102 @@ +import atexit +import logging +import os +import sys +import time +import torch +from termcolor import colored + +__all__ = ["setup_logger", ] + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop("root_name") + "." + self._abbrev_name = kwargs.pop("abbrev_name", "") + if len(self._abbrev_name): + self._abbrev_name = self._abbrev_name + "." + super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + record.name = record.name.replace(self._root_name, self._abbrev_name) + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + +def setup_logger( + output=None, distributed_rank=0, *, name='metricdepth', color=True, abbrev_name=None +): + """ + Initialize the detectron2 logger and set its verbosity level to "DEBUG". + Args: + output (str): a file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + abbrev_name (str): an abbreviation of the module, to avoid log names in logs. + Set to "" not log the root module in logs. + By default, will abbreviate "detectron2" to "d2" and leave other + modules unchanged. + Returns: + logging.Logger: a logger + """ + logger = logging.getLogger() + logger.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG + logger.propagate = False + + if abbrev_name is None: + abbrev_name = "d2" + + plain_formatter = logging.Formatter( + "[%(asctime)s] %(name)s %(levelname)s %(message)s ", datefmt="%m/%d %H:%M:%S" + ) + # stdout logging: master only + if distributed_rank == 0: + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + root_name=name, + abbrev_name=str(abbrev_name), + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + logger.addHandler(ch) + + # file logging: all workers + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "log.txt") + if distributed_rank > 0: + filename = filename + ".rank{}".format(distributed_rank) + os.makedirs(os.path.dirname(filename), exist_ok=True) + + fh = logging.StreamHandler(_cached_log_stream(filename)) + fh.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + + return logger + +from iopath.common.file_io import PathManager as PathManagerBase + + +PathManager = PathManagerBase() + +# cache the opened file object, so that different calls to 'setup_logger +# with the same file name can safely write to the same file. +def _cached_log_stream(filename): + # use 1K buffer if writting to cloud storage + io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1) + atexit.register(io.close) + return io + \ No newline at end of file diff --git a/mono/utils/mldb.py b/mono/utils/mldb.py new file mode 100644 index 0000000000000000000000000000000000000000..d74ac53fd0302e2e954105bade52e6de4c18e2f6 --- /dev/null +++ b/mono/utils/mldb.py @@ -0,0 +1,34 @@ +from types import ModuleType +import data_info + +def load_data_info(module_name, data_info={}, mldb_type='mldb_info', module=None): + if module is None: + module = globals().get(module_name, None) + if module: + for key, value in module.__dict__.items(): + if not (key.startswith('__')) and not (key.startswith('_')): + if key == 'mldb_info': + data_info.update(value) + elif isinstance(value, ModuleType): + load_data_info(module_name + '.' + key, data_info, module=value) + else: + raise RuntimeError(f'Try to access "mldb_info", but cannot find {module_name} module.') + +def reset_ckpt_path(cfg, data_info): + if isinstance(cfg, dict): + for key in cfg.keys(): + if key == 'backbone': + new_ckpt_path = data_info['checkpoint']['mldb_root'] + '/' + data_info['checkpoint'][cfg.backbone.type] + cfg.backbone.update(checkpoint=new_ckpt_path) + continue + elif isinstance(cfg.get(key), dict): + reset_ckpt_path(cfg.get(key), data_info) + else: + continue + else: + return + +if __name__ == '__main__': + mldb_info_tmp = {} + load_data_info('mldb_data_info', mldb_info_tmp) + print('results', mldb_info_tmp.keys()) \ No newline at end of file diff --git a/mono/utils/pcd_filter.py b/mono/utils/pcd_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..2d26314d806ea961f6bf09d1fb195bf5e364f181 --- /dev/null +++ b/mono/utils/pcd_filter.py @@ -0,0 +1,24 @@ +import open3d as o3d +import numpy as np + +def downsample_and_filter(pcd_file): + pcd = o3d.io.read_point_cloud(pcd_file, max_bound_div = 750, neighbor_num = 8) + point_num = len(pcd.points) + if (point_num > 10000000): + voxel_down_pcd = o3d.geometry.PointCloud.uniform_down_sample(pcd, int(point_num / 10000000)+1) + else: + voxel_down_pcd = pcd + max_bound = voxel_down_pcd.get_max_bound() + ball_radius = np.linalg.norm(max_bound) / max_bound_div + pcd_filter, _ = voxel_down_pcd.remove_radius_outlier(neighbor_num, ball_radius) + print('filtered size', len(pcd_filter.points), 'pre size:', len(pcd.points)) + o3d.io.write_point_cloud(pcd_file[:-4] + '_filtered.ply', pcd_filter) + + +if __name__ == "__main__": + import os + dir_path = './data/demo_pcd' + for pcd_file in os.listdir(dir_path): + #if 'jonathan' in pcd_file: set max_bound_div to 300 and neighbot_num to 8 + downsample_and_filter(os.path.join(dir_path, pcd_file)) + \ No newline at end of file diff --git a/mono/utils/running.py b/mono/utils/running.py new file mode 100644 index 0000000000000000000000000000000000000000..8a8b8d2c1f355717f46f784a28ac5f327c01dfc5 --- /dev/null +++ b/mono/utils/running.py @@ -0,0 +1,77 @@ +import os +import torch +import torch.nn as nn +from mono.utils.comm import main_process +import copy +import inspect +import logging +import glob + + +def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None): + """ + Load the check point for resuming training or finetuning. + """ + logger = logging.getLogger() + if os.path.isfile(load_path): + if main_process(): + logger.info(f"Loading weight '{load_path}'") + checkpoint = torch.load(load_path, map_location="cpu") + ckpt_state_dict = checkpoint['model_state_dict'] + model.module.load_state_dict(ckpt_state_dict, strict=strict_match) + + if optimizer is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + if scheduler is not None: + scheduler.load_state_dict(checkpoint['scheduler']) + if loss_scaler is not None and 'scaler' in checkpoint: + scheduler.load_state_dict(checkpoint['scaler']) + del ckpt_state_dict + del checkpoint + if main_process(): + logger.info(f"Successfully loaded weight: '{load_path}'") + if scheduler is not None and optimizer is not None: + logger.info(f"Resume training from: '{load_path}'") + else: + if main_process(): + raise RuntimeError(f"No weight found at '{load_path}'") + return model, optimizer, scheduler, loss_scaler + + +def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None): + """ + Save the model, optimizer, lr scheduler. + """ + logger = logging.getLogger() + + if 'IterBasedRunner' in cfg.runner.type: + max_iters = cfg.runner.max_iters + elif 'EpochBasedRunner' in cfg.runner.type: + max_iters = cfg.runner.max_epochs + else: + raise TypeError(f'{cfg.runner.type} is not supported') + + ckpt = dict( + model_state_dict=model.module.state_dict(), + optimizer=optimizer.state_dict(), + max_iter=cfg.runner.max_iters if 'max_iters' in cfg.runner \ + else cfg.runner.max_epochs, + scheduler=scheduler.state_dict(), + ) + + if loss_scaler is not None: + ckpt.update(dict(scaler=loss_scaler.state_dict())) + + ckpt_dir = os.path.join(cfg.work_dir, 'ckpt') + os.makedirs(ckpt_dir, exist_ok=True) + + save_name = os.path.join(ckpt_dir, 'step%08d.pth' %curr_iter) + saved_ckpts = glob.glob(ckpt_dir + '/step*.pth') + torch.save(ckpt, save_name) + + # keep the last 8 ckpts + if len(saved_ckpts) > 20: + saved_ckpts.sort() + os.remove(saved_ckpts.pop(0)) + + logger.info(f'Save model: {save_name}') diff --git a/mono/utils/transform.py b/mono/utils/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..2af94efe754d6f72325db6fdc170f30fbfb8c2fe --- /dev/null +++ b/mono/utils/transform.py @@ -0,0 +1,408 @@ +import collections +import cv2 +import math +import numpy as np +import numbers +import random +import torch + +import matplotlib +import matplotlib.cm + + +""" +Provides a set of Pytorch transforms that use OpenCV instead of PIL (Pytorch default) +for image manipulation. +""" + +class Compose(object): + # Composes transforms: transforms.Compose([transforms.RandScale([0.5, 2.0]), transforms.ToTensor()]) + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): + for t in self.transforms: + images, labels, intrinsics, cam_models, other_labels, transform_paras = t(images, labels, intrinsics, cam_models, other_labels, transform_paras) + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +class ToTensor(object): + # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). + def __init__(self, **kwargs): + return + def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): + if not isinstance(images, list) or not isinstance(labels, list) or not isinstance(intrinsics, list): + raise (RuntimeError("transform.ToTensor() only handle inputs/labels/intrinsics lists.")) + if len(images) != len(intrinsics): + raise (RuntimeError("Numbers of images and intrinsics are not matched.")) + if not isinstance(images[0], np.ndarray) or not isinstance(labels[0], np.ndarray): + raise (RuntimeError("transform.ToTensor() only handle np.ndarray for the input and label." + "[eg: data readed by cv2.imread()].\n")) + if not isinstance(intrinsics[0], list): + raise (RuntimeError("transform.ToTensor() only handle list for the camera intrinsics")) + + if len(images[0].shape) > 3 or len(images[0].shape) < 2: + raise (RuntimeError("transform.ToTensor() only handle image(np.ndarray) with 3 dims or 2 dims.\n")) + if len(labels[0].shape) > 3 or len(labels[0].shape) < 2: + raise (RuntimeError("transform.ToTensor() only handle label(np.ndarray) with 3 dims or 2 dims.\n")) + + if len(intrinsics[0]) >4 or len(intrinsics[0]) < 3: + raise (RuntimeError("transform.ToTensor() only handle intrinsic(list) with 3 sizes or 4 sizes.\n")) + + for i, img in enumerate(images): + if len(img.shape) == 2: + img = np.expand_dims(img, axis=2) + images[i] = torch.from_numpy(img.transpose((2, 0, 1))).float() + for i, lab in enumerate(labels): + if len(lab.shape) == 2: + lab = np.expand_dims(lab, axis=0) + labels[i] = torch.from_numpy(lab).float() + for i, intrinsic in enumerate(intrinsics): + if len(intrinsic) == 3: + intrinsic = [intrinsic[0],] + intrinsic + intrinsics[i] = torch.tensor(intrinsic, dtype=torch.float) + if cam_models is not None: + for i, cam_model in enumerate(cam_models): + cam_models[i] = torch.from_numpy(cam_model.transpose((2, 0, 1))).float() if cam_model is not None else None + if other_labels is not None: + for i, lab in enumerate(other_labels): + if len(lab.shape) == 2: + lab = np.expand_dims(lab, axis=0) + other_labels[i] = torch.from_numpy(lab).float() + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +class Normalize(object): + # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std + def __init__(self, mean, std=None, **kwargs): + if std is None: + assert len(mean) > 0 + else: + assert len(mean) == len(std) + self.mean = torch.tensor(mean).float()[:, None, None] + self.std = torch.tensor(std).float()[:, None, None] if std is not None \ + else torch.tensor([1.0, 1.0, 1.0]).float()[:, None, None] + + def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): + # if self.std is None: + # # for t, m in zip(image, self.mean): + # # t.sub(m) + # image = image - self.mean + # if ref_images is not None: + # for i, ref_i in enumerate(ref_images): + # ref_images[i] = ref_i - self.mean + # else: + # # for t, m, s in zip(image, self.mean, self.std): + # # t.sub(m).div(s) + # image = (image - self.mean) / self.std + # if ref_images is not None: + # for i, ref_i in enumerate(ref_images): + # ref_images[i] = (ref_i - self.mean) / self.std + for i, img in enumerate(images): + img = torch.div((img - self.mean), self.std) + images[i] = img + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +class LableScaleCanonical(object): + """ + To solve the ambiguity observation for the mono branch, i.e. different focal length (object size) with the same depth, cameras are + mapped to a canonical space. To mimic this, we set the focal length to a canonical one and scale the depth value. NOTE: resize the image based on the ratio can also solve + Args: + images: list of RGB images. + labels: list of depth/disparity labels. + other labels: other labels, such as instance segmentations, semantic segmentations... + """ + def __init__(self, **kwargs): + self.canonical_focal = kwargs['focal_length'] + + def _get_scale_ratio(self, intrinsic): + target_focal_x = intrinsic[0] + label_scale_ratio = self.canonical_focal / target_focal_x + pose_scale_ratio = 1.0 + return label_scale_ratio, pose_scale_ratio + + def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): + assert len(images[0].shape) == 3 and len(labels[0].shape) == 2 + assert labels[0].dtype == np.float32 + + label_scale_ratio = None + pose_scale_ratio = None + + for i in range(len(intrinsics)): + img_i = images[i] + label_i = labels[i] if i < len(labels) else None + intrinsic_i = intrinsics[i].copy() + cam_model_i = cam_models[i] if cam_models is not None and i < len(cam_models) else None + + label_scale_ratio, pose_scale_ratio = self._get_scale_ratio(intrinsic_i) + + # adjust the focal length, map the current camera to the canonical space + intrinsics[i] = [intrinsic_i[0] * label_scale_ratio, intrinsic_i[1] * label_scale_ratio, intrinsic_i[2], intrinsic_i[3]] + + # scale the label to the canonical space + if label_i is not None: + labels[i] = label_i * label_scale_ratio + + if cam_model_i is not None: + # As the focal length is adjusted (canonical focal length), the camera model should be re-built + ori_h, ori_w, _ = img_i.shape + cam_models[i] = build_camera_model(ori_h, ori_w, intrinsics[i]) + + + if transform_paras is not None: + transform_paras.update(label_scale_factor=label_scale_ratio, focal_scale_factor=label_scale_ratio) + + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +class ResizeKeepRatio(object): + """ + Resize and pad to a given size. Hold the aspect ratio. + This resizing assumes that the camera model remains unchanged. + Args: + resize_size: predefined output size. + """ + def __init__(self, resize_size, padding=None, ignore_label=-1, **kwargs): + if isinstance(resize_size, int): + self.resize_h = resize_size + self.resize_w = resize_size + elif isinstance(resize_size, collections.Iterable) and len(resize_size) == 2 \ + and isinstance(resize_size[0], int) and isinstance(resize_size[1], int) \ + and resize_size[0] > 0 and resize_size[1] > 0: + self.resize_h = resize_size[0] + self.resize_w = resize_size[1] + else: + raise (RuntimeError("crop size error.\n")) + if padding is None: + self.padding = padding + elif isinstance(padding, list): + if all(isinstance(i, numbers.Number) for i in padding): + self.padding = padding + else: + raise (RuntimeError("padding in Crop() should be a number list\n")) + if len(padding) != 3: + raise (RuntimeError("padding channel is not equal with 3\n")) + else: + raise (RuntimeError("padding in Crop() should be a number list\n")) + if isinstance(ignore_label, int): + self.ignore_label = ignore_label + else: + raise (RuntimeError("ignore_label should be an integer number\n")) + # self.crop_size = kwargs['crop_size'] + self.canonical_focal = kwargs['focal_length'] + + def main_data_transform(self, image, label, intrinsic, cam_model, resize_ratio, padding, to_scale_ratio): + """ + Resize data first and then do the padding. + 'label' will be scaled. + """ + h, w, _ = image.shape + reshape_h = int(resize_ratio * h) + reshape_w = int(resize_ratio * w) + + pad_h, pad_w, pad_h_half, pad_w_half = padding + + # resize + image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR) + # padding + image = cv2.copyMakeBorder( + image, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=self.padding) + + if label is not None: + # label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST) + label = resize_depth_preserve(label, (reshape_h, reshape_w)) + label = cv2.copyMakeBorder( + label, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=self.ignore_label) + # scale the label + label = label / to_scale_ratio + + # Resize, adjust principle point + if intrinsic is not None: + intrinsic[0] = intrinsic[0] * resize_ratio / to_scale_ratio + intrinsic[1] = intrinsic[1] * resize_ratio / to_scale_ratio + intrinsic[2] = intrinsic[2] * resize_ratio + intrinsic[3] = intrinsic[3] * resize_ratio + + if cam_model is not None: + #cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR) + cam_model = build_camera_model(reshape_h, reshape_w, intrinsic) + cam_model = cv2.copyMakeBorder( + cam_model, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=self.ignore_label) + + # Pad, adjust the principle point + if intrinsic is not None: + intrinsic[2] = intrinsic[2] + pad_w_half + intrinsic[3] = intrinsic[3] + pad_h_half + return image, label, intrinsic, cam_model + + def get_label_scale_factor(self, image, intrinsic, resize_ratio): + ori_h, ori_w, _ = image.shape + # crop_h, crop_w = self.crop_size + ori_focal = intrinsic[0] + + to_canonical_ratio = self.canonical_focal / ori_focal + to_scale_ratio = resize_ratio / to_canonical_ratio + return to_scale_ratio + + def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): + target_h, target_w, _ = images[0].shape + resize_ratio_h = self.resize_h / target_h + resize_ratio_w = self.resize_w / target_w + resize_ratio = min(resize_ratio_h, resize_ratio_w) + reshape_h = int(resize_ratio * target_h) + reshape_w = int(resize_ratio * target_w) + pad_h = max(self.resize_h - reshape_h, 0) + pad_w = max(self.resize_w - reshape_w, 0) + pad_h_half = int(pad_h / 2) + pad_w_half = int(pad_w / 2) + + pad_info = [pad_h, pad_w, pad_h_half, pad_w_half] + to_scale_ratio = self.get_label_scale_factor(images[0], intrinsics[0], resize_ratio) + + for i in range(len(images)): + img = images[i] + label = labels[i] if i < len(labels) else None + intrinsic = intrinsics[i] if i < len(intrinsics) else None + cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None + img, label, intrinsic, cam_model = self.main_data_transform( + img, label, intrinsic, cam_model, resize_ratio, pad_info, to_scale_ratio) + images[i] = img + if label is not None: + labels[i] = label + if intrinsic is not None: + intrinsics[i] = intrinsic + if cam_model is not None: + cam_models[i] = cam_model + + if other_labels is not None: + + for i, other_lab in enumerate(other_labels): + # resize + other_lab = cv2.resize(other_lab, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST) + # pad + other_labels[i] = cv2.copyMakeBorder( + other_lab, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=self.ignore_label) + + pad = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] + if transform_paras is not None: + pad_old = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0] + new_pad = [pad_old[0] + pad[0], pad_old[1] + pad[1], pad_old[2] + pad[2], pad_old[3] + pad[3]] + transform_paras.update(dict(pad=new_pad)) + if 'label_scale_factor' in transform_paras: + transform_paras['label_scale_factor'] = transform_paras['label_scale_factor'] * 1.0 / to_scale_ratio + else: + transform_paras.update(label_scale_factor=1.0/to_scale_ratio) + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +class BGR2RGB(object): + # Converts image from BGR order to RGB order, for model initialized from Pytorch + def __init__(self, **kwargs): + return + def __call__(self, images, labels, intrinsics, cam_models=None,other_labels=None, transform_paras=None): + for i, img in enumerate(images): + images[i] = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +def resize_depth_preserve(depth, shape): + """ + Resizes depth map preserving all valid depth pixels + Multiple downsampled points can be assigned to the same pixel. + + Parameters + ---------- + depth : np.array [h,w] + Depth map + shape : tuple (H,W) + Output shape + + Returns + ------- + depth : np.array [H,W,1] + Resized depth map + """ + # Store dimensions and reshapes to single column + depth = np.squeeze(depth) + h, w = depth.shape + x = depth.reshape(-1) + # Create coordinate grid + uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2) + # Filters valid points + idx = x > 0 + crd, val = uv[idx], x[idx] + # Downsamples coordinates + crd[:, 0] = (crd[:, 0] * (shape[0] / h) + 0.5).astype(np.int32) + crd[:, 1] = (crd[:, 1] * (shape[1] / w) + 0.5).astype(np.int32) + # Filters points inside image + idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1]) + crd, val = crd[idx], val[idx] + # Creates downsampled depth image and assigns points + depth = np.zeros(shape) + depth[crd[:, 0], crd[:, 1]] = val + # Return resized depth map + return depth + + +def build_camera_model(H : int, W : int, intrinsics : list) -> np.array: + """ + Encode the camera intrinsic parameters (focal length and principle point) to a 4-channel map. + """ + fx, fy, u0, v0 = intrinsics + f = (fx + fy) / 2.0 + # principle point location + x_row = np.arange(0, W).astype(np.float32) + x_row_center_norm = (x_row - u0) / W + x_center = np.tile(x_row_center_norm, (H, 1)) # [H, W] + + y_col = np.arange(0, H).astype(np.float32) + y_col_center_norm = (y_col - v0) / H + y_center = np.tile(y_col_center_norm, (W, 1)).T + + # FoV + fov_x = np.arctan(x_center / (f / W)) + fov_y = np.arctan(y_center/ (f / H)) + + cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2) + return cam_model + +def gray_to_colormap(img, cmap='rainbow'): + """ + Transfer gray map to matplotlib colormap + """ + assert img.ndim == 2 + + img[img<0] = 0 + mask_invalid = img < 1e-10 + img = img / (img.max() + 1e-8) + norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1) + cmap_m = matplotlib.cm.get_cmap(cmap) + map = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m) + colormap = (map.to_rgba(img)[:, :, :3] * 255).astype(np.uint8) + colormap[mask_invalid] = 0 + return colormap \ No newline at end of file diff --git a/mono/utils/unproj_pcd.py b/mono/utils/unproj_pcd.py new file mode 100644 index 0000000000000000000000000000000000000000..a0986d482a2ec68be1dd65719adec662272b833c --- /dev/null +++ b/mono/utils/unproj_pcd.py @@ -0,0 +1,88 @@ +import numpy as np +import torch +from plyfile import PlyData, PlyElement +import cv2 + + +def get_pcd_base(H, W, u0, v0, fx, fy): + x_row = np.arange(0, W) + x = np.tile(x_row, (H, 1)) + x = x.astype(np.float32) + u_m_u0 = x - u0 + + y_col = np.arange(0, H) # y_col = np.arange(0, height) + y = np.tile(y_col, (W, 1)).T + y = y.astype(np.float32) + v_m_v0 = y - v0 + + x = u_m_u0 / fx + y = v_m_v0 / fy + z = np.ones_like(x) + pw = np.stack([x, y, z], axis=2) # [h, w, c] + return pw + + +def reconstruct_pcd(depth, fx, fy, u0, v0, pcd_base=None, mask=None): + if type(depth) == torch.__name__: + depth = depth.cpu().numpy().squeeze() + depth = cv2.medianBlur(depth, 5) + if pcd_base is None: + H, W = depth.shape + pcd_base = get_pcd_base(H, W, u0, v0, fx, fy) + pcd = depth[:, :, None] * pcd_base + if mask: + pcd[mask] = 0 + return pcd + + +def save_point_cloud(pcd, rgb, filename, binary=True): + """Save an RGB point cloud as a PLY file. + :paras + @pcd: Nx3 matrix, the XYZ coordinates + @rgb: Nx3 matrix, the rgb colors for each 3D point + """ + assert pcd.shape[0] == rgb.shape[0] + + if rgb is None: + gray_concat = np.tile(np.array([128], dtype=np.uint8), + (pcd.shape[0], 3)) + points_3d = np.hstack((pcd, gray_concat)) + else: + points_3d = np.hstack((pcd, rgb)) + python_types = (float, float, float, int, int, int) + npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), + ('green', 'u1'), ('blue', 'u1')] + if binary is True: + # Format into Numpy structured array + vertices = [] + for row_idx in range(points_3d.shape[0]): + cur_point = points_3d[row_idx] + vertices.append( + tuple( + dtype(point) + for dtype, point in zip(python_types, cur_point))) + vertices_array = np.array(vertices, dtype=npy_types) + el = PlyElement.describe(vertices_array, 'vertex') + + # write + PlyData([el]).write(filename) + else: + x = np.squeeze(points_3d[:, 0]) + y = np.squeeze(points_3d[:, 1]) + z = np.squeeze(points_3d[:, 2]) + r = np.squeeze(points_3d[:, 3]) + g = np.squeeze(points_3d[:, 4]) + b = np.squeeze(points_3d[:, 5]) + + ply_head = 'ply\n' \ + 'format ascii 1.0\n' \ + 'element vertex %d\n' \ + 'property float x\n' \ + 'property float y\n' \ + 'property float z\n' \ + 'property uchar red\n' \ + 'property uchar green\n' \ + 'property uchar blue\n' \ + 'end_header' % r.shape[0] + # ---- Save ply data to disk + np.savetxt(filename, np.column_stack[x, y, z, r, g, b], fmt='%f %f %f %d %d %d', header=ply_head, comments='') \ No newline at end of file diff --git a/mono/utils/visualization.py b/mono/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..07275030c48aeea062c0041b11ba60d911c14a3f --- /dev/null +++ b/mono/utils/visualization.py @@ -0,0 +1,140 @@ +import matplotlib.pyplot as plt +import os, cv2 +import numpy as np +from mono.utils.transform import gray_to_colormap +import shutil +import glob +from mono.utils.running import main_process +import torch +from html4vision import Col, imagetable + +def save_raw_imgs( + pred: torch.tensor, + rgb: torch.tensor, + filename: str, + save_dir: str, + scale: float=200.0, + target: torch.tensor=None, + ): + """ + Save raw GT, predictions, RGB in the same file. + """ + cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_rgb.jpg'), rgb) + cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_d.png'), (pred*scale).astype(np.uint16)) + if target is not None: + cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_gt.png'), (target*scale).astype(np.uint16)) + + +def save_val_imgs( + iter: int, + pred: torch.tensor, + target: torch.tensor, + rgb: torch.tensor, + filename: str, + save_dir: str, + tb_logger=None + ): + """ + Save GT, predictions, RGB in the same file. + """ + rgb, pred_scale, target_scale, pred_color, target_color = get_data_for_log(pred, target, rgb) + rgb = rgb.transpose((1, 2, 0)) + cat_img = np.concatenate([rgb, pred_color, target_color], axis=0) + plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img) + + # save to tensorboard + if tb_logger is not None: + tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter) + +def save_normal_val_imgs( + iter: int, + pred: torch.tensor, + targ: torch.tensor, + rgb: torch.tensor, + filename: str, + save_dir: str, + tb_logger=None, + mask=None, + ): + """ + Save GT, predictions, RGB in the same file. + """ + mean = np.array([123.675, 116.28, 103.53])[np.newaxis, np.newaxis, :] + std= np.array([58.395, 57.12, 57.375])[np.newaxis, np.newaxis, :] + pred = pred.squeeze() + targ = targ.squeeze() + rgb = rgb.squeeze() + + if pred.size(0) == 3: + pred = pred.permute(1,2,0) + if targ.size(0) == 3: + targ = targ.permute(1,2,0) + if rgb.size(0) == 3: + rgb = rgb.permute(1,2,0) + + pred_color = vis_surface_normal(pred, mask) + targ_color = vis_surface_normal(targ, mask) + rgb_color = ((rgb.cpu().numpy() * std) + mean).astype(np.uint8) + + try: + cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0) + except: + pred_color = cv2.resize(pred_color, (rgb.shape[1], rgb.shape[0])) + targ_color = cv2.resize(targ_color, (rgb.shape[1], rgb.shape[0])) + cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0) + + plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img) + # cv2.imwrite(os.path.join(save_dir, filename[:-4]+'.jpg'), pred_color) + # save to tensorboard + if tb_logger is not None: + tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter) + +def get_data_for_log(pred: torch.tensor, target: torch.tensor, rgb: torch.tensor): + mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis] + std= np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis] + + pred = pred.squeeze().cpu().numpy() + target = target.squeeze().cpu().numpy() + rgb = rgb.squeeze().cpu().numpy() + + pred[pred<0] = 0 + target[target<0] = 0 + max_scale = max(pred.max(), target.max()) + pred_scale = (pred/max_scale * 10000).astype(np.uint16) + target_scale = (target/max_scale * 10000).astype(np.uint16) + pred_color = gray_to_colormap(pred) + target_color = gray_to_colormap(target) + pred_color = cv2.resize(pred_color, (rgb.shape[2], rgb.shape[1])) + target_color = cv2.resize(target_color, (rgb.shape[2], rgb.shape[1])) + + rgb = ((rgb * std) + mean).astype(np.uint8) + return rgb, pred_scale, target_scale, pred_color, target_color + + +def create_html(name2path, save_path='index.html', size=(256, 384)): + # table description + cols = [] + for k, v in name2path.items(): + col_i = Col('img', k, v) # specify image content for column + cols.append(col_i) + # html table generation + imagetable(cols, out_file=save_path, imsize=size) + +def vis_surface_normal(normal: torch.tensor, mask: torch.tensor=None) -> np.array: + """ + Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255] + Aargs: + normal (torch.tensor, [h, w, 3]): surface normal + mask (torch.tensor, [h, w]): valid masks + """ + normal = normal.cpu().numpy().squeeze() + n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True)) + n_img_norm = normal / (n_img_L2 + 1e-8) + normal_vis = n_img_norm * 127 + normal_vis += 128 + normal_vis = normal_vis.astype(np.uint8) + if mask is not None: + mask = mask.cpu().numpy().squeeze() + normal_vis[~mask] = 0 + return normal_vis +