|
|
|
|
|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
|
|
from .resnet import Resnet18
|
|
|
|
|
|
|
|
class ConvBNReLU(nn.Module):
    """Conv2d (bias-free) -> BatchNorm2d -> ReLU.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        ks: convolution kernel size.
        stride: convolution stride.
        padding: zero-padding applied by the convolution.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super().__init__()
        # No conv bias: the BatchNorm right after it has its own learnable shift.
        self.conv = nn.Conv2d(
            in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False
        )
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU in that order."""
        normed = self.bn(self.conv(x))
        return F.relu(normed)

    def init_weight(self):
        """Re-initialise direct Conv2d children: Kaiming-normal (a=1) weights, zero bias."""
        convs = (child for child in self.children() if isinstance(child, nn.Conv2d))
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)
|
|
|
|
class BiSeNetOutput(nn.Module):
    """Prediction head: a 3x3 ConvBNReLU followed by a bias-free 1x1 conv.

    Args:
        in_chan: channels of the incoming feature map.
        mid_chan: channels produced by the 3x3 ConvBNReLU.
        n_classes: output channels of the final 1x1 conv (one per class).
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        """Return raw (un-normalised) per-class score maps for ``x``."""
        return self.conv_out(self.conv(x))

    def init_weight(self):
        """Re-initialise direct Conv2d children: Kaiming-normal (a=1) weights, zero bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Split this module's parameters into two optimizer groups.

        Conv2d/Linear weights go into the first list; their biases and every
        BatchNorm2d parameter go into the second (names suggest with / without
        weight decay -- confirm against the optimizer setup).

        Returns:
            (wd_params, nowd_params): two lists of parameters.
        """
        decay, no_decay = [], []
        for mod in self.modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                decay.append(mod.weight)
                if mod.bias is not None:
                    no_decay.append(mod.bias)
            elif isinstance(mod, nn.BatchNorm2d):
                no_decay.extend(mod.parameters())
        return decay, no_decay
|
|
|
|
|
|
class AttentionRefinementModule(nn.Module):
    """Refine a feature map with a per-channel attention gate.

    A 3x3 ConvBNReLU maps ``in_chan`` -> ``out_chan`` channels. The result is
    globally average-pooled, passed through a 1x1 conv, BatchNorm and sigmoid,
    and the resulting per-channel weights rescale the features.

    Args:
        in_chan: channels of the incoming feature map.
        out_chan: channels of the refined output.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        """Return the conv features multiplied by their sigmoid channel gate."""
        features = self.conv(x)
        # Kernel equals the full spatial extent, so each channel pools to 1x1.
        pooled = F.avg_pool2d(features, features.size()[2:])
        gate = self.sigmoid_atten(self.bn_atten(self.conv_atten(pooled)))
        return torch.mul(features, gate)

    def init_weight(self):
        """Re-initialise direct Conv2d children: Kaiming-normal (a=1) weights, zero bias."""
        for child in filter(lambda c: isinstance(c, nn.Conv2d), self.children()):
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)
|
|
|
|
|
|
class ContextPath(nn.Module):
    """Context branch of BiSeNet on top of a ``Resnet18`` backbone.

    ``self.resnet`` must return three feature maps, named feat8 / feat16 /
    feat32 here. ``arm16`` and ``conv_avg``/``arm32`` are built for 256 and
    512 input channels, so feat16 and feat32 must have those channel counts.
    Judging by the names, they are presumably the stride-8/16/32 outputs
    (TODO confirm against ``Resnet18``).

    The forward pass works on the deepest map first. It adds a global-context
    term to the ARM-refined feat32 and lifts the sum to feat16's size. It then
    adds that to the ARM-refined feat16 and lifts the result to feat8's size.

    Args:
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        """Run the context path.

        Returns:
            (feat8, feat16_up, feat32_up):
              feat8     -- the backbone's first output, passed through unchanged;
              feat16_up -- fused map at feat8's spatial size (128 ch from conv_head16);
              feat32_up -- fused map at feat16's spatial size (128 ch from conv_head32).
        """
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        # Global context: pool feat32 to 1x1, project, and broadcast back to feat32's size.
        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        # Deepest level: ARM features + global context, lifted to feat16's size.
        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        # Middle level: ARM features + lifted deep features, lifted to feat8's size.
        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up

    def init_weight(self):
        """Re-initialise direct Conv2d children: Kaiming-normal (a=1) weights, zero bias.

        All submodules built in ``__init__`` are composite modules, so this is
        presumably a no-op unless ``Resnet18`` is itself a bare Conv2d.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split all parameters (backbone included) into two optimizer groups.

        Conv2d/Linear weights go into the first list. Their biases and every
        BatchNorm2d parameter go into the second list. The names suggest
        with / without weight decay; confirm against the optimizer setup.

        Returns:
            (wd_params, nowd_params): two lists of parameters.
        """
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
|
|
|
|
|
|
|
|
class SpatialPath(nn.Module):
    """Shallow high-resolution branch of BiSeNet.

    Three stride-2 ConvBNReLU stages (7x7, 3x3, 3x3) take a 3-channel input
    to 64 channels at roughly 1/8 of the input resolution. A 1x1 ConvBNReLU
    then projects the result to 128 channels.

    Args:
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        """Run ``x`` through conv1 -> conv2 -> conv3 -> conv_out."""
        out = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv_out):
            out = stage(out)
        return out

    def init_weight(self):
        """Re-initialise direct Conv2d children: Kaiming-normal (a=1) weights, zero bias.

        Every child built in ``__init__`` is a ConvBNReLU, so this touches nothing
        unless children are replaced with bare Conv2d layers.
        """
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Split parameters into two optimizer groups.

        Conv2d/Linear weights go into the first list. Their biases and every
        BatchNorm2d parameter go into the second list. The names suggest
        with / without weight decay; confirm against the optimizer setup.

        Returns:
            (wd_params, nowd_params): two lists of parameters.
        """
        decay, no_decay = [], []
        for mod in self.modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                decay.append(mod.weight)
                if mod.bias is not None:
                    no_decay.append(mod.bias)
            elif isinstance(mod, nn.BatchNorm2d):
                no_decay.extend(mod.parameters())
        return decay, no_decay
|
|
|
|
|
|
class FeatureFusionModule(nn.Module):
    """Fuse two feature maps with an SE-style channel gate and a residual add.

    The inputs are concatenated along channels and projected to ``out_chan`` by
    a 1x1 ConvBNReLU. A global average pool then feeds a 1x1 reduction to
    ``out_chan // 4``, a ReLU, a 1x1 expansion back to ``out_chan`` and a
    sigmoid. The result is per-channel weights ``g``, and the output is
    ``feat * g + feat``.

    Args:
        in_chan: total channels of the two inputs combined.
        out_chan: output channels.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        reduced = out_chan // 4
        self.conv1 = nn.Conv2d(out_chan, reduced, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(reduced, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        """Fuse ``fsp`` and ``fcp`` (concatenated in that order along dim 1)."""
        fused = self.convblk(torch.cat([fsp, fcp], dim=1))
        # Kernel equals the full spatial extent, so each channel pools to 1x1.
        squeezed = F.avg_pool2d(fused, fused.size()[2:])
        gate = self.sigmoid(self.conv2(self.relu(self.conv1(squeezed))))
        return torch.mul(fused, gate) + fused

    def init_weight(self):
        """Re-initialise direct Conv2d children (conv1, conv2): Kaiming-normal (a=1), zero bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Split parameters into two optimizer groups.

        Conv2d/Linear weights go into the first list. Their biases and every
        BatchNorm2d parameter go into the second list. The names suggest
        with / without weight decay; confirm against the optimizer setup.

        Returns:
            (wd_params, nowd_params): two lists of parameters.
        """
        decay, no_decay = [], []
        for mod in self.modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                decay.append(mod.weight)
                if mod.bias is not None:
                    no_decay.append(mod.bias)
            elif isinstance(mod, nn.BatchNorm2d):
                no_decay.extend(mod.parameters())
        return decay, no_decay
|
|
|
|
|
|
class BiSeNet(nn.Module):
    """BiSeNet variant without a dedicated spatial path.

    The context path's first output (``feat_res8``) is used where the
    spatial-path feature would normally go. The model produces three score maps,
    all bilinearly resized to the input's height and width:
      * the main head on the fused features,
      * a head on the context path's second output,
      * a head on its third output
    (presumably auxiliary outputs for deep supervision -- confirm with the
    training loop).

    Args:
        n_classes: number of output channels of every head.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, n_classes, *args, **kwargs):
        super().__init__()
        self.cp = ContextPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def forward(self, x):
        """Return ``(feat_out, feat_out16, feat_out32)``, each at ``x``'s spatial size."""
        H, W = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)
        # The backbone feature stands in for the spatial-path input of the FFM.
        feat_fuse = self.ffm(feat_res8, feat_cp8)
        head_outputs = (
            self.conv_out(feat_fuse),
            self.conv_out16(feat_cp8),
            self.conv_out32(feat_cp16),
        )
        return tuple(
            F.interpolate(scores, (H, W), mode='bilinear', align_corners=True)
            for scores in head_outputs
        )

    def init_weight(self):
        """Re-initialise direct Conv2d children: Kaiming-normal (a=1), zero bias.

        All children built in ``__init__`` are composite modules, so this touches
        nothing unless children are replaced with bare Conv2d layers.
        """
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Collect parameter groups from every child's own ``get_params``.

        Children that are FeatureFusionModule or BiSeNetOutput go into the
        ``lr_mul_*`` groups. Everything else (the context path, backbone
        included) goes into the plain groups. Judging by the name, the
        ``lr_mul_*`` groups presumably get a learning-rate multiplier; confirm
        with the optimizer setup.

        Returns:
            (wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)
        """
        wd_params, nowd_params = [], []
        lr_mul_wd_params, lr_mul_nowd_params = [], []
        for child in self.children():
            child_wd, child_nowd = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params.extend(child_wd)
                lr_mul_nowd_params.extend(child_nowd)
            else:
                wd_params.extend(child_wd)
                nowd_params.extend(child_nowd)
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a 19-class model and push a random batch through it.
    # Fall back to CPU so the check also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = BiSeNet(19).to(device)
    net.eval()
    in_ten = torch.randn(16, 3, 640, 480, device=device)
    # Inference only: don't keep activations around for a backward pass.
    with torch.no_grad():
        out, out16, out32 = net(in_ten)
    print(out.shape)

    # Exercise the optimizer-grouping helper and report the group sizes.
    wd, nowd, lr_mul_wd, lr_mul_nowd = net.get_params()
    print(len(wd), len(nowd), len(lr_mul_wd), len(lr_mul_nowd))
|
|
|