zdou0830's picture
desco
749745d
raw history blame
No virus
5.07 kB
import torch
import torch.nn.functional as F
from torch import nn
from .deform_conv import ModulatedDeformConv
from .dyrelu import h_sigmoid, DYReLU
class Conv3x3Norm(torch.nn.Module):
    """3x3 convolution optionally followed by group normalization.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the 3x3 convolution (padding is fixed to 1).
        deformable: If True, build a ``ModulatedDeformConv`` instead of a
            plain ``nn.Conv2d``.
        use_gn: If True, normalize the convolution output with
            ``nn.GroupNorm`` using 16 groups, so ``out_channels`` must be
            divisible by 16.
    """

    def __init__(self, in_channels, out_channels, stride, deformable=False, use_gn=False):
        super(Conv3x3Norm, self).__init__()

        # The conditional expression only resolves ModulatedDeformConv when
        # the deformable variant is actually requested.
        conv_cls = ModulatedDeformConv if deformable else nn.Conv2d
        self.conv = conv_cls(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)

        # The attribute is called ``bn`` although it holds a GroupNorm; the
        # name is kept because it is the key the parameters are registered
        # under.
        if use_gn:
            self.bn = nn.GroupNorm(num_groups=16, num_channels=out_channels)
        else:
            self.bn = None

    def forward(self, input, **kwargs):
        """Apply the convolution and, when configured, the normalization.

        Args:
            input: Tensor passed as the first argument to the convolution.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                convolution (``DyConv`` passes ``offset`` and ``mask`` here
                when deformable convolution is enabled).

        Returns:
            The convolved (and optionally normalized) tensor.
        """
        x = self.conv(input, **kwargs)
        # Explicit None check: do not rely on the truthiness of a module.
        if self.bn is not None:
            x = self.bn(x)
        return x
class DyConv(nn.Module):
    """Fuse every level of a feature list with its neighbouring levels.

    For each level the block combines up to three branches: a convolution
    of the level itself, a stride-2 convolution of the previous level, and
    a convolution of the next level resized bilinearly to the current
    spatial size. The branches are averaged (optionally weighted by a
    learned per-branch attention) and passed through an activation.

    Args:
        in_channels: Channel count of the incoming feature maps.
        out_channels: Channel count produced by every branch.
        conv_func: Callable ``(in_channels, out_channels, stride)`` that
            returns the convolution module for one branch.
        use_dyfuse: If True, weight each branch by a scalar attention
            computed from its globally pooled output before averaging.
        use_dyrelu: If True, use ``DYReLU`` as the activation instead of
            ``nn.ReLU``.
        use_deform: If True, predict ``offset``/``mask`` tensors from the
            current level and pass them to every branch convolution as
            keyword arguments.
    """

    def __init__(
        self,
        in_channels=256,
        out_channels=256,
        conv_func=Conv3x3Norm,
        use_dyfuse=True,
        use_dyrelu=False,
        use_deform=False,
    ):
        super(DyConv, self).__init__()

        # Branch convolutions, addressed by index in forward():
        #   0 -> next level (stride 1, output is resized afterwards)
        #   1 -> current level (stride 1)
        #   2 -> previous level (stride 2)
        self.DyConv = nn.ModuleList()
        self.DyConv.append(conv_func(in_channels, out_channels, 1))
        self.DyConv.append(conv_func(in_channels, out_channels, 1))
        self.DyConv.append(conv_func(in_channels, out_channels, 2))

        if use_dyfuse:
            # The attention branch consumes the *outputs* of the branch
            # convolutions, so its input width must be out_channels (the
            # two are identical with the default 256/256 configuration).
            self.AttnConv = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(out_channels, 1, kernel_size=1),
                nn.ReLU(inplace=True),
            )
            self.h_sigmoid = h_sigmoid()
        else:
            self.AttnConv = None

        if use_dyrelu:
            # NOTE(review): the activation is applied to a tensor with
            # out_channels channels; confirm DYReLU's first argument is
            # meant to be in_channels when the two widths differ.
            self.relu = DYReLU(in_channels, out_channels)
        else:
            self.relu = nn.ReLU()

        if use_deform:
            # 27 output channels: forward() uses the first 18 as offsets
            # and the remaining 9 as the modulation mask (presumably
            # 2 coordinates + 1 mask value per tap of a 3x3 kernel).
            self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1)
        else:
            self.offset = None

        self.init_weights()

    def init_weights(self):
        """Initialise every ``nn.Conv2d`` in the branches and attention.

        Weights are drawn from N(0, 0.01) and biases are zeroed. The
        branch convolutions are visited before the attention convolution
        so the random-number stream is consumed in a fixed order.
        """
        containers = [self.DyConv]
        if self.AttnConv is not None:
            containers.append(self.AttnConv)

        for container in containers:
            for m in container.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight, 0, 0.01)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

    def forward(self, x):
        """Fuse each level with its neighbours.

        Args:
            x: Indexable sequence of 4-D feature tensors, one per level.
                Presumably ordered from the highest to the lowest spatial
                resolution with a factor of 2 between consecutive levels
                (the previous level is reduced with a stride-2 conv) --
                verify against the caller.

        Returns:
            A list with one fused tensor per input level.
        """
        num_levels = len(x)
        next_x = []
        for level, feature in enumerate(x):
            conv_args = {}
            if self.offset is not None:
                # Offsets and mask are predicted from the current level
                # and shared by all three branch convolutions.
                offset_mask = self.offset(feature)
                conv_args = {
                    "offset": offset_mask[:, :18, :, :],
                    "mask": offset_mask[:, 18:, :, :].sigmoid(),
                }

            temp_fea = [self.DyConv[1](feature, **conv_args)]
            if level > 0:
                temp_fea.append(self.DyConv[2](x[level - 1], **conv_args))
            if level < num_levels - 1:
                # F.upsample_bilinear is deprecated; this is its documented
                # equivalent (bilinear with align_corners=True).
                temp_fea.append(
                    F.interpolate(
                        self.DyConv[0](x[level + 1], **conv_args),
                        size=[feature.size(2), feature.size(3)],
                        mode="bilinear",
                        align_corners=True,
                    )
                )

            stacked = torch.stack(temp_fea)
            if self.AttnConv is not None:
                # One scalar weight per branch and sample, broadcast over
                # channels and spatial positions.
                spa_pyr_attn = self.h_sigmoid(
                    torch.stack([self.AttnConv(fea) for fea in temp_fea])
                )
                stacked = stacked * spa_pyr_attn

            mean_fea = torch.mean(stacked, dim=0, keepdim=False)
            next_x.append(self.relu(mean_fea))

        return next_x
class DyHead(nn.Module):
    """Stack of ``DyConv`` blocks configured from ``cfg.MODEL.DYHEAD``.

    Args:
        cfg: Configuration object. The fields read here are
            ``MODEL.DYHEAD.CHANNELS``, ``USE_GN``, ``USE_DYRELU``,
            ``USE_DYFUSE``, ``USE_DFCONV`` and ``NUM_CONVS``.
        in_channels: Channel count of the features fed to the first block;
            every later block takes ``CHANNELS`` channels as input.
    """

    def __init__(self, cfg, in_channels):
        super(DyHead, self).__init__()
        self.cfg = cfg

        head_cfg = cfg.MODEL.DYHEAD
        channels = head_cfg.CHANNELS
        use_gn = head_cfg.USE_GN
        use_dyrelu = head_cfg.USE_DYRELU
        use_dyfuse = head_cfg.USE_DYFUSE
        use_deform = head_cfg.USE_DFCONV

        def conv_func(i, o, s):
            # Branch-convolution factory handed to every DyConv block.
            return Conv3x3Norm(i, o, s, deformable=use_deform, use_gn=use_gn)

        dyhead_tower = [
            DyConv(
                # Only the first block sees the incoming feature width.
                in_channels if i == 0 else channels,
                channels,
                conv_func=conv_func,
                use_dyrelu=use_dyrelu,
                use_dyfuse=use_dyfuse,
                use_deform=use_deform,
            )
            for i in range(head_cfg.NUM_CONVS)
        ]

        self.add_module("dyhead_tower", nn.Sequential(*dyhead_tower))

    def forward(self, x):
        """Run the features through every block of the tower in order.

        Args:
            x: Input passed to the first ``DyConv`` block.

        Returns:
            The output of the last block (``x`` itself when the tower is
            empty).
        """
        return self.dyhead_tower(x)