import torch
import torch.nn as nn
import numpy as np
import scipy.io as sio
import torch.nn.functional as F


class RCF(nn.Module):
    """Richer Convolutional Features (RCF) edge detector.

    A VGG16-style backbone (stage 5 uses dilation 2 and pool4 has stride 1,
    so stages 4 and 5 share the same spatial resolution) feeds per-layer
    1x1 "down" projections. Each stage's projections are summed and scored
    by a 1x1 conv (``score_dsn*``), upsampled back to the input resolution
    with fixed bilinear kernels, and finally fused by a learned 1x1 conv
    over the five side outputs.

    Args:
        pretrained: optional path to a MATLAB ``.mat`` file with VGG16
            weights keyed like ``"<layer>-<param>"`` (e.g.
            ``"conv1_1-weight"``); matching entries are copied into this
            module's state dict. Non-matching keys are ignored.
    """

    def __init__(self, pretrained=None):
        super(RCF, self).__init__()
        # --- VGG16 backbone ---
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1, dilation=1)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1, dilation=1)
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1, dilation=1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1, dilation=1)
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1, dilation=1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1, dilation=1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
        # Stage 5 keeps resolution via dilated convs (pool4 is stride 1).
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)

        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.pool4 = nn.MaxPool2d(2, stride=1, ceil_mode=True)

        self.act = nn.ReLU(inplace=True)

        # --- per-layer 1x1 projections to a common 21-channel space ---
        self.conv1_1_down = nn.Conv2d(64, 21, 1)
        self.conv1_2_down = nn.Conv2d(64, 21, 1)
        self.conv2_1_down = nn.Conv2d(128, 21, 1)
        self.conv2_2_down = nn.Conv2d(128, 21, 1)
        self.conv3_1_down = nn.Conv2d(256, 21, 1)
        self.conv3_2_down = nn.Conv2d(256, 21, 1)
        self.conv3_3_down = nn.Conv2d(256, 21, 1)
        self.conv4_1_down = nn.Conv2d(512, 21, 1)
        self.conv4_2_down = nn.Conv2d(512, 21, 1)
        self.conv4_3_down = nn.Conv2d(512, 21, 1)
        self.conv5_1_down = nn.Conv2d(512, 21, 1)
        self.conv5_2_down = nn.Conv2d(512, 21, 1)
        self.conv5_3_down = nn.Conv2d(512, 21, 1)

        # --- per-stage side outputs and learned fusion ---
        self.score_dsn1 = nn.Conv2d(21, 1, 1)
        self.score_dsn2 = nn.Conv2d(21, 1, 1)
        self.score_dsn3 = nn.Conv2d(21, 1, 1)
        self.score_dsn4 = nn.Conv2d(21, 1, 1)
        self.score_dsn5 = nn.Conv2d(21, 1, 1)
        self.score_fuse = nn.Conv2d(5, 1, 1)

        # Fixed bilinear upsampling kernels. Registered as non-persistent
        # buffers so they follow .to(device)/.cuda() instead of being
        # hard-wired to CUDA (the previous .cuda() calls broke CPU-only
        # use), while staying out of the state dict for checkpoint
        # compatibility.
        self.register_buffer(
            'weight_deconv2', self._make_bilinear_weights(4, 1), persistent=False)
        self.register_buffer(
            'weight_deconv3', self._make_bilinear_weights(8, 1), persistent=False)
        self.register_buffer(
            'weight_deconv4', self._make_bilinear_weights(16, 1), persistent=False)
        self.register_buffer(
            'weight_deconv5', self._make_bilinear_weights(16, 1), persistent=False)

        # init weights
        self.apply(self._init_weights)

        if pretrained is not None:
            vgg16 = sio.loadmat(pretrained)
            torch_params = self.state_dict()
            for k in vgg16.keys():
                name_par = k.split('-')
                if len(name_par) == 2:
                    name_space = name_par[0] + '.' + name_par[1]
                    # Skip .mat entries (metadata keys, layers this model
                    # does not have) — previously they were inserted
                    # verbatim and made load_state_dict fail on
                    # unexpected keys.
                    if name_space not in torch_params:
                        continue
                    data = np.squeeze(vgg16[k])
                    torch_params[name_space] = torch.from_numpy(data)
            self.load_state_dict(torch_params)

    def _init_weights(self, m):
        """Gaussian-init conv weights (std 0.01), zero biases; the 5->1
        fusion conv starts at the constant 0.2 so each side output
        initially contributes equally."""
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.01)
            # Identity check replaces the brittle shape test
            # (torch.Size([1, 5, 1, 1])) — score_fuse is the only 5->1 conv.
            if m is getattr(self, 'score_fuse', None):
                nn.init.constant_(m.weight, 0.2)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    # Based on HED implementation @ https://github.com/xwjabc/hed
    def _make_bilinear_weights(self, size, num_channels):
        """Build a fixed (num_channels, num_channels, size, size) bilinear
        interpolation kernel for ``conv_transpose2d``.

        Off-diagonal channel pairs are zero so channels do not mix; the
        kernel is frozen (``requires_grad=False``).
        """
        factor = (size + 1) // 2
        center = factor - 1 if size % 2 == 1 else factor - 0.5
        og = np.ogrid[:size, :size]
        filt = ((1 - abs(og[0] - center) / factor) *
                (1 - abs(og[1] - center) / factor))
        filt = torch.from_numpy(filt).float()
        w = torch.zeros(num_channels, num_channels, size, size)
        for i in range(num_channels):
            w[i, i] = filt  # bilinear filter on the diagonal only
        w.requires_grad = False
        return w

    # Based on BDCN implementation @ https://github.com/pkuCactus/BDCN
    def _crop(self, data, img_h, img_w, crop_h, crop_w):
        """Crop an upsampled map back to (img_h, img_w), discarding
        (crop_h, crop_w) rows/cols of deconvolution border at the
        top-left.

        Raises:
            ValueError: if the target size exceeds the input size
                (was an ``assert``, which ``python -O`` strips).
        """
        _, _, h, w = data.size()
        if img_h > h or img_w > w:
            raise ValueError(
                'crop target (%d, %d) exceeds input size (%d, %d)'
                % (img_h, img_w, h, w))
        return data[:, :, crop_h:crop_h + img_h, crop_w:crop_w + img_w]

    def forward(self, x):
        """Compute side outputs and their fusion.

        Args:
            x: image batch of shape (N, 3, H, W).

        Returns:
            List of six (N, 1, H, W) tensors, each passed through
            sigmoid: side outputs of stages 1-5 followed by the fused map.
        """
        img_h, img_w = x.shape[2], x.shape[3]

        # Backbone.
        conv1_1 = self.act(self.conv1_1(x))
        conv1_2 = self.act(self.conv1_2(conv1_1))
        pool1 = self.pool1(conv1_2)
        conv2_1 = self.act(self.conv2_1(pool1))
        conv2_2 = self.act(self.conv2_2(conv2_1))
        pool2 = self.pool2(conv2_2)
        conv3_1 = self.act(self.conv3_1(pool2))
        conv3_2 = self.act(self.conv3_2(conv3_1))
        conv3_3 = self.act(self.conv3_3(conv3_2))
        pool3 = self.pool3(conv3_3)
        conv4_1 = self.act(self.conv4_1(pool3))
        conv4_2 = self.act(self.conv4_2(conv4_1))
        conv4_3 = self.act(self.conv4_3(conv4_2))
        pool4 = self.pool4(conv4_3)
        conv5_1 = self.act(self.conv5_1(pool4))
        conv5_2 = self.act(self.conv5_2(conv5_1))
        conv5_3 = self.act(self.conv5_3(conv5_2))

        # 1x1 projections.
        conv1_1_down = self.conv1_1_down(conv1_1)
        conv1_2_down = self.conv1_2_down(conv1_2)
        conv2_1_down = self.conv2_1_down(conv2_1)
        conv2_2_down = self.conv2_2_down(conv2_2)
        conv3_1_down = self.conv3_1_down(conv3_1)
        conv3_2_down = self.conv3_2_down(conv3_2)
        conv3_3_down = self.conv3_3_down(conv3_3)
        conv4_1_down = self.conv4_1_down(conv4_1)
        conv4_2_down = self.conv4_2_down(conv4_2)
        conv4_3_down = self.conv4_3_down(conv4_3)
        conv5_1_down = self.conv5_1_down(conv5_1)
        conv5_2_down = self.conv5_2_down(conv5_2)
        conv5_3_down = self.conv5_3_down(conv5_3)

        # Per-stage scores from summed projections.
        out1 = self.score_dsn1(conv1_1_down + conv1_2_down)
        out2 = self.score_dsn2(conv2_1_down + conv2_2_down)
        out3 = self.score_dsn3(conv3_1_down + conv3_2_down + conv3_3_down)
        out4 = self.score_dsn4(conv4_1_down + conv4_2_down + conv4_3_down)
        out5 = self.score_dsn5(conv5_1_down + conv5_2_down + conv5_3_down)

        # Bilinear upsampling back toward input resolution; stages 4 and 5
        # share stride 8 because pool4 has stride 1.
        out2 = F.conv_transpose2d(out2, self.weight_deconv2, stride=2)
        out3 = F.conv_transpose2d(out3, self.weight_deconv3, stride=4)
        out4 = F.conv_transpose2d(out4, self.weight_deconv4, stride=8)
        out5 = F.conv_transpose2d(out5, self.weight_deconv5, stride=8)

        # Trim deconvolution borders so all maps are (img_h, img_w).
        out2 = self._crop(out2, img_h, img_w, 1, 1)
        out3 = self._crop(out3, img_h, img_w, 2, 2)
        out4 = self._crop(out4, img_h, img_w, 4, 4)
        out5 = self._crop(out5, img_h, img_w, 0, 0)

        fuse = torch.cat((out1, out2, out3, out4, out5), dim=1)
        fuse = self.score_fuse(fuse)
        results = [out1, out2, out3, out4, out5, fuse]
        results = [torch.sigmoid(r) for r in results]
        return results