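# NOTE: the text "Spaces: Runtime error / Runtime error" preceded this file when it was
# captured; it appears to be status text from the hosting page and is not part of the module.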
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Zhenyu Li
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align as torch_roi_align

# from zoedepth.models.layers.swin_layers import G2LFusion
from estimator.models.blocks.swin_layers import G2LFusion
from estimator.registry import MODELS
class DoubleConvWOBN(nn.Module):
    """(convolution => ReLU) * 2, without batch normalisation.

    Both convolutions are 3x3 with ``padding=1`` and ``bias=True``, so the
    spatial size of the input is preserved and only the channel count changes.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the second convolution.
        mid_channels: Number of channels between the two convolutions. Any
            falsy value (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # The layer order fixes the state-dict keys (convolutions at indices
        # 0 and 2 of the Sequential), so do not insert or reorder layers.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through both conv + ReLU stages and return the result."""
        return self.double_conv(x)
class DoubleConv(nn.Module):
    """(convolution => BN => ReLU) * 2.

    Both convolutions are 3x3 with ``padding=1``, so the spatial size of the
    input is preserved. They are created with ``bias=False``; each is
    followed by a ``BatchNorm2d`` layer.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the second convolution.
        mid_channels: Number of channels between the two convolutions. Any
            falsy value (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # The layer order fixes the state-dict keys (conv at 0 and 3, batch
        # norm at 1 and 4), so do not insert or reorder layers.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through both conv + BN + ReLU stages and return the result."""
        return self.double_conv(x)
class Down(nn.Module):
    """Downscale with 2x2 max pooling, then apply a ``DoubleConv``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Pooling is at index 0 and the DoubleConv at index 1; the order fixes
        # the state-dict keys.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        """Pool ``x`` by a factor of two and run the double convolution on it."""
        return self.maxpool_conv(x)
class Upv1(nn.Module):
    """Resize one feature map to another's size, concatenate, then double conv.

    The convolution stage is a ``DoubleConvWOBN`` (no batch normalisation).

    Args:
        in_channels: Number of channels of the concatenated tensor, i.e. the
            channels of ``x2`` plus the channels of ``x1`` in ``forward``.
        out_channels: Number of channels of the output.
        mid_channels: Number of channels between the two convolutions. When
            ``None``, ``in_channels`` is used.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden_channels = in_channels if mid_channels is None else mid_channels
        self.conv = DoubleConvWOBN(in_channels, out_channels, hidden_channels)

    def forward(self, x1, x2):
        """Fuse ``x1`` into ``x2``.

        Args:
            x1: Feature map that is bilinearly resized (``align_corners=True``)
                to the spatial size of ``x2``.
            x2: Feature map that defines the output spatial size.

        Returns:
            The result of the double convolution applied to ``[x2, x1]``
            concatenated along the channel dimension.
        """
        # Resizing to x2's exact size (rather than a fixed 2x upsample) keeps
        # the concatenation valid when the two sizes are not an exact 2:1.
        x1 = F.interpolate(x1, size=x2.shape[-2:], mode='bilinear', align_corners=True)
        fused = torch.cat([x2, x1], dim=1)
        return self.conv(fused)
class GuidedFusionPatchFusion(nn.Module):
    """Encoder/decoder that fuses its own features with supplied coarse features.

    The encoder is a ``DoubleConv`` followed by one ``Down`` stage per
    additional entry of ``in_channels``. The decoder walks the encoder
    features from deepest to shallowest; at every level the corresponding
    coarse feature map is passed through a ``G2LFusion`` layer, cropped with
    RoI-align, concatenated with the encoder feature and reduced by a
    ``DoubleConvWOBN``.

    Args:
        n_channels: Number of channels of ``input_tensor`` in ``forward``.
        g2l: When truthy, the ``G2LFusion`` layers and fusion convolutions are
            built.
        in_channels: Channel count per encoder level, shallowest first.
        depth: ``depth`` argument for the ``G2LFusion`` layers, one per level,
            in the same order as ``in_channels``.
        num_heads: ``num_heads`` argument for the ``G2LFusion`` layers, one
            per level, in the same order as ``in_channels``.
        num_patches: ``num_patches`` argument for the ``G2LFusion`` layers,
            one per level, in the same order as ``in_channels``.
        patch_process_shape: Two-element sequence; its first entry is the
            divisor used to derive the RoI-align spatial scale in ``forward``.

    NOTE(review): ``forward`` uses ``g2l_list``, ``convs`` and
    ``patch_process_shape`` unconditionally, so an instance built with a falsy
    ``g2l`` cannot run ``forward``.
    """

    def __init__(
            self,
            n_channels,
            g2l,
            in_channels=[32, 256, 256, 256, 256, 256],
            depth=[2, 2, 3, 3, 4, 4],
            num_heads=[8, 8, 16, 16, 32, 32],
            num_patches=[384*512, 192*256, 96*128, 48*64, 24*32, 12*16],
            patch_process_shape=[384, 512]):
        # The list defaults are part of the existing signature. They are only
        # read or sliced below, never mutated, and the one that is stored on
        # the instance is copied first.
        super().__init__()
        self.n_channels = n_channels

        # Encoder: stem plus one Down stage per consecutive channel pair.
        self.inc = DoubleConv(n_channels, in_channels[0])
        self.down_conv_list = nn.ModuleList()
        for c_in, c_out in zip(in_channels[:-1], in_channels[1:]):
            self.down_conv_list.append(Down(c_in, c_out))

        # Decoder: built deepest-first. Each Upv1 receives the previous fused
        # feature concatenated with a guide feature (hence 2 * deeper) plus
        # the skip feature of the current level.
        in_channels_inv = in_channels[::-1]
        self.up_conv_list = nn.ModuleList()
        for deeper, skip in zip(in_channels_inv[:-1], in_channels_inv[1:]):
            self.up_conv_list.append(Upv1(skip + deeper + deeper, skip))

        self.g2l = g2l
        if self.g2l:
            # Empty and not used in this file; kept so the module's attribute
            # set is unchanged.
            self.g2l_att = nn.ModuleList()
            win = 12
            # Copy so the instance never aliases the shared default list.
            self.patch_process_shape = list(patch_process_shape)
            num_heads_inv = num_heads[::-1]
            depth_inv = depth[::-1]
            num_patches_inv = num_patches[::-1]

            self.g2l_list = nn.ModuleList()
            self.convs = nn.ModuleList()
            # The two layers are created alternately per level, which keeps
            # the parameter initialisation order fixed.
            for idx, channels in enumerate(in_channels_inv):
                self.g2l_list.append(G2LFusion(
                    input_dim=channels,
                    embed_dim=channels,
                    window_size=win,
                    num_heads=num_heads_inv[idx],
                    depth=depth_inv[idx],
                    num_patches=num_patches_inv[idx]))
                self.convs.append(DoubleConvWOBN(channels * 2, channels, channels))

    def forward(self,
                input_tensor,
                guide_plus,
                guide_cat,
                bbox=None,
                fine_feat_crop=None,
                coarse_feat_whole=None,
                coarse_feat_whole_hack=None,
                coarse_feat_crop=None):
        """Encode ``input_tensor`` and fuse it level by level with coarse features.

        Args:
            input_tensor: Input to the encoder, with ``n_channels`` channels.
            guide_plus: Not used by this method.
            guide_cat: Indexable collection of guide features;
                ``guide_cat[i]`` is concatenated (channel dimension) with the
                fused feature of decoder level ``i`` before it is upsampled
                into level ``i + 1``.
            bbox: Boxes forwarded to ``torchvision.ops.roi_align``.
            fine_feat_crop: Not used by this method.
            coarse_feat_whole: Iterable of coarse feature maps, deepest level
                first (it is iterated alongside the reversed encoder
                features).
            coarse_feat_whole_hack: When not ``None``, replaces
                ``coarse_feat_whole``.
            coarse_feat_crop: Not used by this method.

        Returns:
            List of fused feature maps ordered from the shallowest level to
            the deepest.
        """
        # Apply the unscaled features to the swin layers when they are given.
        if coarse_feat_whole_hack is not None:
            coarse_feat_whole = coarse_feat_whole_hack

        # Encoder pass, collecting one feature map per level.
        x = self.inc(input_tensor)
        feat_list = [x]
        for layer in self.down_conv_list:
            x = layer(x)
            feat_list.append(x)

        output = []
        temp_feat = None  # fused feature of the previous (deeper) level
        for idx, (feat_enc, feat_c) in enumerate(zip(feat_list[::-1], coarse_feat_whole)):
            # The encoder and coarse resolutions can differ (the original
            # comment mentions depth-anything); align the encoder feature.
            _, _, h, w = feat_enc.shape
            if h != feat_c.shape[-2] or w != feat_c.shape[-1]:
                feat_enc = F.interpolate(
                    feat_enc, size=feat_c.shape[-2:], mode='bilinear', align_corners=True)

            # Every level except the deepest merges the previous fused
            # feature and its guide feature into the skip feature.
            if idx > 0:
                upsample_in = torch.cat([temp_feat, guide_cat[idx - 1]], dim=1)
                feat_enc = self.up_conv_list[idx - 1](upsample_in, feat_enc)

            _, _, h, w = feat_c.shape
            feat_c = self.g2l_list[idx](feat_c, None)
            # Crop the coarse feature to the boxes, resampled back to (h, w).
            # The scale converts box coordinates to this level's resolution
            # (presumably boxes are given at patch_process_shape resolution).
            feat_c = torch_roi_align(
                feat_c,
                bbox,
                (h, w),
                spatial_scale=h / self.patch_process_shape[0],
                aligned=True)
            x = self.convs[idx](torch.cat([feat_enc, feat_c], dim=1))

            temp_feat = x
            output.append(x)

        # Collected deepest-first; return shallowest-first.
        return output[::-1]