Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching/core/extractor_depthany.py
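"""Feature extractors built on a frozen Depth Anything V2 backbone, used to
inject monocular depth priors into a stereo-matching network.

`DepthAnyExtractor` produces multi-scale context features (the heads named
outputs08/outputs16/outputs32) plus the monocular depth prediction;
`DepthMatchExtractor` produces a single matching-feature map and accepts
paired (left/right) inputs as a list."""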
import os
import sys

import numpy as np

sys.path.insert(0, 'Depth-Anything-V2')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from core.extractor import ResidualBlock
from depth_anything_v2.dpt import DepthAnythingV2
from core.utils.utils import sv_intermediate_results
from huggingface_hub import hf_hub_download
def resize_tensor(tensor, target_size=512, ratio=16):
    # Input tensor has shape (B, C, H, W).
    _, _, H, W = tensor.shape
    # Scale so that the longer side becomes target_size.
    if H > W:
        new_H = target_size
        new_W = int(W * (target_size / H))
    else:
        new_W = target_size
        new_H = int(H * (target_size / W))
    # Round both sides up to the nearest multiple of `ratio`
    # (e.g. the ViT patch size).
    new_W = int(np.ceil(new_W / ratio) * ratio)
    new_H = int(np.ceil(new_H / ratio) * ratio)
    # Resize with bicubic interpolation.
    resized_tensor = F.interpolate(tensor, size=(new_H, new_W), mode='bicubic', align_corners=False)
    return resized_tensor
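# Worked example: a (1, 3, 540, 960) tensor with target_size=518, ratio=14
# first scales to 291x518 (longer side -> 518), then rounds both sides up to
# multiples of 14, giving an output of shape (1, 3, 294, 518).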
def resize_to_quarter(tensor, original_size, ratio):
    # Downsample to 1/ratio of the original spatial size
    # (the name reflects the common ratio=4 case).
    quarter_H = original_size[0] // ratio
    quarter_W = original_size[1] // ratio
    # Resize with bilinear interpolation.
    resized_tensor = F.interpolate(tensor, size=(quarter_H, quarter_W), mode='bilinear', align_corners=False)
    return resized_tensor
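# Worked example: resize_to_quarter(x, (540, 960), ratio=4) resizes x to
# spatial size (135, 240).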
class DepthAnyExtractor(nn.Module):
    def __init__(self, model_dir, output_dim=[128], norm_fn='batch', downsample=2, args=None):
        super(DepthAnyExtractor, self).__init__()
        self.args = args
        self.norm_fn = norm_fn
        self.downsample = downsample

        # Each entry of output_dim is indexed as [dim32, dim16, dim08] below,
        # so callers must pass a list of 3-element dims (the default [128] is
        # a placeholder and is not usable as-is).
        output_list = []
        for dim in output_dim:
            conv_out = nn.Sequential(
                ResidualBlock(128, 128, self.norm_fn, stride=1),
                nn.Conv2d(128, dim[2], 3, padding=1))
            output_list.append(conv_out)
        self.outputs08 = nn.ModuleList(output_list)

        output_list = []
        for dim in output_dim:
            conv_out = nn.Sequential(
                ResidualBlock(128, 128, self.norm_fn, stride=1),
                nn.Conv2d(128, dim[1], 3, padding=1))
            output_list.append(conv_out)
        self.outputs16 = nn.ModuleList(output_list)

        output_list = []
        for dim in output_dim:
            conv_out = nn.Conv2d(128, dim[0], 3, padding=1)
            output_list.append(conv_out)
        self.outputs32 = nn.ModuleList(output_list)

        self.layer1 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.in_planes = 128
        self.layer2 = self._make_layer(128, stride=2)
        self.layer3 = self._make_layer(128, stride=2)
        # self._init_weights()

        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
        }
        encoder = "vitl"
        depth_anything = DepthAnythingV2(**model_configs[encoder])
        checkpoint_path = hf_hub_download(
            repo_id="BFZD/Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching",
            filename="dav2_models/depth_anything_v2_vitl.pth",
        )
        # depth_anything.load_state_dict(torch.load(os.path.join(model_dir, f'depth_anything_v2_{encoder}.pth'),
        #                                           map_location='cpu'))
        depth_anything.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        # self.depth_anything = depth_anything.to('cuda')
        self.depth_anything = depth_anything

        # ImageNet normalization statistics.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.mean = torch.tensor(mean).view(1, 3, 1, 1).cuda()
        self.std = torch.tensor(std).view(1, 3, 1, 1).cuda()

        # Freeze all parameters of the depth_anything model.
        for param in self.depth_anything.parameters():
            param.requires_grad = False
    # def _init_weights(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Conv2d):
    #             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    #         elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
    #             if m.weight is not None:
    #                 nn.init.constant_(m.weight, 1)
    #             if m.bias is not None:
    #                 nn.init.constant_(m.bias, 0)
    def _make_layer(self, dim, stride=1):
        # Two ResidualBlocks; the first applies the (optional) stride.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
        self.in_planes = dim
        return nn.Sequential(*layers)
    def forward(self, image, dual_inp=False, num_layers=3):
        # Resize so the longer side is 518 and both sides are multiples of 14
        # (the ViT patch size used by Depth Anything V2).
        B, _, H, W = image.shape
        img = resize_tensor(image, target_size=518, ratio=14)
        # Map the [-1, 1] input to [0, 1], then apply ImageNet normalization.
        img = ((img + 1) / 2 - self.mean) / self.std
        # Run the frozen Depth Anything V2 backbone.
        with torch.no_grad():
            # depth: e.g. [1, 1, 518, 756]; depth_fea: e.g. [1, 128, 296, 432]
            depth, depth_fea = self.depth_anything(img)
        # Resize back to the working resolution,
        # e.g. [1, 128, H//4, W//4] for downsample=2.
        depth = resize_to_quarter(depth, (H, W), 2 ** self.downsample)
        x = resize_to_quarter(depth_fea, (H, W), 2 ** self.downsample)
        if self.args is not None and hasattr(self.args, "vis_inter") and self.args.vis_inter:
            sv_intermediate_results(x, "depthAnything_features", self.args.sv_root)

        x = self.layer1(x)
        if dual_inp:
            # Following the RAFT-Stereo MultiBasicEncoder convention: keep the
            # full (concatenated) feature map as `v` and run the output heads
            # on the first half of the batch only. Without this branch, `v`
            # in the returns below would be undefined.
            v = x
            x = x[:(x.shape[0] // 2)]
        outputs08 = [f(x) for f in self.outputs08]
        if num_layers == 1:
            return (outputs08, v) if dual_inp else (outputs08,)

        y = self.layer2(x)  # [1, 128, H//8, W//8]
        outputs16 = [f(y) for f in self.outputs16]
        if num_layers == 2:
            return (outputs08, outputs16, v) if dual_inp else (outputs08, outputs16)

        z = self.layer3(y)  # [1, 128, H//16, W//16]
        outputs32 = [f(z) for f in self.outputs32]
        return (outputs08, outputs16, outputs32), depth
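# Minimal usage sketch (hypothetical shapes and dims; each output_dim entry
# is indexed as [dim32, dim16, dim08], per the heads above):
#   extractor = DepthAnyExtractor(model_dir='dav2_models', output_dim=[[128, 128, 128]]).cuda()
#   image = torch.rand(1, 3, 448, 448, device='cuda') * 2 - 1  # expected range [-1, 1]
#   (out08, out16, out32), depth = extractor(image)  # depth: [1, 1, 112, 112] for downsample=2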
class DepthMatchExtractor(nn.Module):
    def __init__(self, model_dir, output_dim=256, norm_fn='batch', downsample=2):
        super(DepthMatchExtractor, self).__init__()
        self.norm_fn = norm_fn
        self.downsample = downsample

        self.layer1 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.in_planes = 128
        self.layer2 = self._make_layer(128, stride=1)
        self.conv = nn.Conv2d(128, output_dim, kernel_size=1)
        # self._init_weights()

        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
        }
        encoder = "vitl"
        depth_anything = DepthAnythingV2(**model_configs[encoder])
        checkpoint_path = hf_hub_download(
            repo_id="BFZD/Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching",
            filename="dav2_models/depth_anything_v2_vitl.pth",
        )
        # depth_anything.load_state_dict(torch.load(os.path.join(model_dir, f'depth_anything_v2_{encoder}.pth'),
        #                                           map_location='cpu'))
        depth_anything.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        self.depth_anything = depth_anything.to('cuda')

        # ImageNet normalization statistics.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.mean = torch.tensor(mean).view(1, 3, 1, 1).cuda()
        self.std = torch.tensor(std).view(1, 3, 1, 1).cuda()

        # Freeze all parameters of the depth_anything model.
        for param in self.depth_anything.parameters():
            param.requires_grad = False
    # def _init_weights(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Conv2d):
    #             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    #         elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
    #             if m.weight is not None:
    #                 nn.init.constant_(m.weight, 1)
    #             if m.bias is not None:
    #                 nn.init.constant_(m.bias, 0)
    def _make_layer(self, dim, stride=1):
        # Two ResidualBlocks; the first applies the (optional) stride.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
        self.in_planes = dim
        return nn.Sequential(*layers)
    def forward(self, x, dual_inp=False, num_layers=3):
        # (`dual_inp` and `num_layers` are accepted for interface
        # compatibility but unused here.)
        # If the input is a list/tuple (e.g. left and right views), combine
        # along the batch dimension and split the result back at the end.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        # Resize so the longer side is 518 and both sides are multiples of 14.
        B, _, H, W = x.shape
        x = resize_tensor(x, target_size=518, ratio=14)
        # Map the [-1, 1] input to [0, 1], then apply ImageNet normalization.
        x = ((x + 1) / 2 - self.mean) / self.std
        # Run the frozen Depth Anything V2 backbone; the predicted depth is
        # not used here, only the backbone features.
        with torch.no_grad():
            # depth: e.g. [1, 1, 518, 756]; depth_fea: e.g. [1, 128, 296, 432]
            depth, depth_fea = self.depth_anything(x)
        # Resize back to the working resolution,
        # e.g. [B, 128, H//4, W//4] for downsample=2.
        x = resize_to_quarter(depth_fea, (H, W), 2 ** self.downsample)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.conv(x)

        if is_list:
            x = x.split(split_size=batch_dim, dim=0)
        return x
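# Minimal usage sketch (hypothetical shapes): passing the two views as a list
# concatenates them along the batch dimension and splits the result back:
#   fnet = DepthMatchExtractor(model_dir='dav2_models', output_dim=256).cuda()
#   left = torch.rand(1, 3, 448, 448, device='cuda') * 2 - 1
#   right = torch.rand(1, 3, 448, 448, device='cuda') * 2 - 1
#   fmap1, fmap2 = fnet([left, right])  # each [1, 256, 112, 112] for downsample=2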