# Source: PatchFusion/estimator/models/blocks/guided_fusion_model.py
# (last commit 1f418ff "refactor" by Zhyever; 8.96 kB)
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Zhenyu Li
import torch
import torch.nn as nn
import torch.nn.functional as F
# from zoedepth.models.layers.swin_layers import G2LFusion
from estimator.models.blocks.swin_layers import G2LFusion
from torchvision.ops import roi_align as torch_roi_align
from estimator.registry import MODELS
class DoubleConvWOBN(nn.Module):
    """Two 3x3 convolutions (with bias), each followed by ReLU -- no BatchNorm.

    Layout: Conv2d(in -> mid) => ReLU => Conv2d(mid -> out) => ReLU.
    Both convolutions use ``padding=1``, so the spatial size is preserved.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        mid_channels: Width of the intermediate feature. Any falsy value
            (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
        ]
        # The Sequential indices (0: conv, 1: relu, 2: conv, 3: relu) define the
        # state_dict keys, so this layout must stay stable for saved checkpoints.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both conv/ReLU stages."""
        out = self.double_conv(x)
        return out
class DoubleConv(nn.Module):
    """Two (3x3 conv => BatchNorm => ReLU) stages.

    The convolutions are created without bias because each one is directly
    followed by a BatchNorm2d, whose affine shift makes a conv bias redundant.
    ``padding=1`` keeps the spatial size unchanged.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        mid_channels: Width of the intermediate feature. Any falsy value
            (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        def conv_bn_relu(c_in, c_out):
            # One stage: conv -> norm -> activation.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Resulting indices: 0 conv, 1 bn, 2 relu, 3 conv, 4 bn, 5 relu
        # (these indices form the state_dict keys).
        self.double_conv = nn.Sequential(
            *conv_bn_relu(in_channels, hidden),
            *conv_bn_relu(hidden, out_channels),
        )

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        out = self.double_conv(x)
        return out
class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Index 0 is the pool and index 1 the DoubleConv; these indices are
        # part of the state_dict keys.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and then convolve it."""
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved
class Upv1(nn.Module):
    """Decoder stage: resize ``x1`` to ``x2``'s size, concatenate, then convolve.

    The resize target comes from ``x2`` (bilinear, ``align_corners=True``), so
    the block is not tied to a fixed upscaling factor.

    Args:
        in_channels: Channel count of ``cat([x2, x1])``.
        out_channels: Number of output channels.
        mid_channels: Hidden width of the ``DoubleConvWOBN``. Defaults to
            ``in_channels`` only when it is ``None``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = in_channels if mid_channels is None else mid_channels
        self.conv = DoubleConvWOBN(in_channels, out_channels, hidden)

    def forward(self, x1, x2):
        """Resize ``x1`` onto ``x2``'s spatial grid and fuse the two."""
        target_hw = x2.shape[-2:]
        resized = F.interpolate(x1, size=target_hw, mode='bilinear', align_corners=True)
        # x2 goes first; the conv weights depend on this channel order.
        return self.conv(torch.cat((x2, resized), dim=1))
@MODELS.register_module()
class GuidedFusionPatchFusion(nn.Module):
    """U-Net style network fusing patch features with coarse whole-image features.

    The encoder is ``inc`` followed by ``len(in_channels) - 1`` ``Down`` stages.
    The decoder walks the resulting feature pyramid from the deepest level
    upwards. At every level it does three things:
      * applies that level's ``G2LFusion`` layer to the whole-image coarse
        feature;
      * crops the patch region given by ``bbox`` out of it with ``roi_align``;
      * fuses the crop with the (up-sampled) encoder feature via a
        ``DoubleConvWOBN``.

    Args:
        n_channels: Number of channels of ``input_tensor``.
        g2l: Whether to build the ``G2LFusion`` branch (``g2l_list``,
            ``convs``, ``patch_process_shape``). ``forward`` always relies on
            this branch, so the module can only be run with ``g2l`` truthy.
        in_channels: Encoder channel widths, shallowest level first.
        depth: Per-level ``G2LFusion`` depth, shallowest level first.
        num_heads: Per-level ``G2LFusion`` head count, shallowest level first.
        num_patches: Per-level ``num_patches`` for ``G2LFusion``, shallowest
            level first.
        patch_process_shape: ``(height, width)`` of the processed patch. Its
            height sets the ``roi_align`` spatial scale.
    """

    def __init__(
        self,
        n_channels,
        g2l,
        in_channels=(32, 256, 256, 256, 256, 256),
        depth=(2, 2, 3, 3, 4, 4),
        num_heads=(8, 8, 16, 16, 32, 32),
        num_patches=(384 * 512, 192 * 256, 96 * 128, 48 * 64, 24 * 32, 12 * 16),
        patch_process_shape=(384, 512)):
        super().__init__()
        self.n_channels = n_channels

        # Encoder: stem plus one Down stage per additional level.
        self.inc = DoubleConv(n_channels, in_channels[0])
        self.down_conv_list = nn.ModuleList()
        for idx in range(len(in_channels) - 1):
            self.down_conv_list.append(Down(in_channels[idx], in_channels[idx + 1]))

        # Decoder, deepest level first. Up-block idx - 1 receives
        #   cat([fused output of the deeper level, guide_cat[idx - 1]])
        # (2 * deeper width) together with the current level's skip feature.
        in_channels_inv = in_channels[::-1]
        self.up_conv_list = nn.ModuleList()
        for idx in range(1, len(in_channels)):
            deeper = in_channels_inv[idx - 1]
            current = in_channels_inv[idx]
            self.up_conv_list.append(Upv1(current + deeper + deeper, current))

        self.g2l = g2l
        if self.g2l:
            # Never used by forward(); it holds no parameters and is retained
            # only so existing references to the attribute keep working.
            self.g2l_att = nn.ModuleList()
            win = 12
            self.patch_process_shape = patch_process_shape

            # Per-level settings are given shallowest-first; reverse them to
            # line up with the deepest-first decoder order.
            num_heads_inv = num_heads[::-1]
            depth_inv = depth[::-1]
            num_patches_inv = num_patches[::-1]

            self.g2l_list = nn.ModuleList()
            self.convs = nn.ModuleList()
            for idx, ch in enumerate(in_channels_inv):
                self.g2l_list.append(
                    G2LFusion(
                        input_dim=ch,
                        embed_dim=ch,
                        window_size=win,
                        num_heads=num_heads_inv[idx],
                        depth=depth_inv[idx],
                        num_patches=num_patches_inv[idx]))
                # Fuses cat([encoder feature, cropped coarse feature]).
                self.convs.append(DoubleConvWOBN(ch * 2, ch, ch))

    def forward(self,
                input_tensor,
                guide_plus,
                guide_cat,
                bbox=None,
                fine_feat_crop=None,
                coarse_feat_whole=None,
                coarse_feat_whole_hack=None,
                coarse_feat_crop=None):
        """Fuse patch features with whole-image coarse features at every level.

        Args:
            input_tensor: Input fed to the encoder (``n_channels`` channels).
            guide_plus: Unused; kept for interface compatibility.
            guide_cat: Sequence of guidance features, deepest level first.
                ``guide_cat[i]`` is concatenated with the fused output of
                decoder level ``i`` before up-block ``i``, so it must have
                that level's channel count.
            bbox: Boxes for ``torchvision.ops.roi_align`` that select the patch
                region in each coarse map (presumably in
                ``patch_process_shape`` pixel coordinates -- confirm). Required.
            fine_feat_crop: Unused; kept for interface compatibility.
            coarse_feat_whole: Sequence of whole-image coarse features, deepest
                level first, one per encoder level.
            coarse_feat_whole_hack: When given, used instead of
                ``coarse_feat_whole``.
            coarse_feat_crop: Unused; kept for interface compatibility.

        Returns:
            List of fused feature maps, shallowest level first.

        Raises:
            RuntimeError: If the module was built with ``g2l`` falsy.
            ValueError: If ``bbox`` or the coarse features are missing.
        """
        if not self.g2l:
            raise RuntimeError(
                "GuidedFusionPatchFusion.forward requires g2l=True; "
                "the G2LFusion layers were not built.")
        if bbox is None:
            raise ValueError("bbox is required to crop the coarse features with roi_align.")

        coarse_feats = coarse_feat_whole if coarse_feat_whole_hack is None else coarse_feat_whole_hack
        if coarse_feats is None:
            raise ValueError("coarse_feat_whole (or coarse_feat_whole_hack) is required.")

        # Encoder pass: one feature map per level, shallowest first.
        x = self.inc(input_tensor)
        enc_feats = [x]
        for down in self.down_conv_list:
            x = down(x)
            enc_feats.append(x)

        outputs = []
        prev_fused = None
        for idx, (feat_enc, feat_c) in enumerate(zip(enc_feats[::-1], coarse_feats)):
            # Coarse features may arrive at a different resolution (e.g. from
            # a Depth-Anything backbone); resample the encoder feature to match.
            _, _, h, w = feat_enc.shape
            if h != feat_c.shape[-2] or w != feat_c.shape[-1]:
                feat_enc = F.interpolate(feat_enc, size=feat_c.shape[-2:], mode='bilinear', align_corners=True)

            if idx > 0:
                # Bring the deeper fused output (plus its guidance feature) up
                # to this level and merge it with the skip feature.
                guided = torch.cat([prev_fused, guide_cat[idx - 1]], dim=1)
                feat_enc = self.up_conv_list[idx - 1](guided, feat_enc)

            _, _, h, w = feat_c.shape
            feat_c = self.g2l_list[idx](feat_c, None)
            # Crop the patch region back out at this map's own (h, w). Only the
            # height ratio is used as spatial_scale, which assumes the same
            # down-scaling factor along both axes.
            feat_c = torch_roi_align(feat_c, bbox, (h, w), h / self.patch_process_shape[0], aligned=True)

            fused = self.convs[idx](torch.cat([feat_enc, feat_c], dim=1))
            prev_fused = fused
            outputs.append(fused)

        return outputs[::-1]