import torch
import torch.nn as nn
import numpy as np
import scipy.io as sio
import torch.nn.functional as F


class RCF(nn.Module):
    def __init__(self, pretrained=None):
        super(RCF, self).__init__()
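        # VGG16-style backbone. Stages 1-4 use plain 3x3 convs; stage 5
        # uses dilation 2 (with padding 2) to enlarge the receptive field
        # without further downsampling.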
        self.conv1_1 = nn.Conv2d(  3,  64, 3, padding=1, dilation=1)
        self.conv1_2 = nn.Conv2d( 64,  64, 3, padding=1, dilation=1)
        self.conv2_1 = nn.Conv2d( 64, 128, 3, padding=1, dilation=1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1, dilation=1)
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1, dilation=1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1, dilation=1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
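        # pool1-3 halve the resolution; pool4 uses stride 1, so stage 5
        # stays at 1/8 of the input resolution.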
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.pool4 = nn.MaxPool2d(2, stride=1, ceil_mode=True)
        self.act = nn.ReLU(inplace=True)
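        # 1x1 convs reducing every backbone feature map to 21 channels,
        # so the maps within a stage can be summed before scoring.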
        self.conv1_1_down = nn.Conv2d( 64, 21, 1)
        self.conv1_2_down = nn.Conv2d( 64, 21, 1)
        self.conv2_1_down = nn.Conv2d(128, 21, 1)
        self.conv2_2_down = nn.Conv2d(128, 21, 1)
        self.conv3_1_down = nn.Conv2d(256, 21, 1)
        self.conv3_2_down = nn.Conv2d(256, 21, 1)
        self.conv3_3_down = nn.Conv2d(256, 21, 1)
        self.conv4_1_down = nn.Conv2d(512, 21, 1)
        self.conv4_2_down = nn.Conv2d(512, 21, 1)
        self.conv4_3_down = nn.Conv2d(512, 21, 1)
        self.conv5_1_down = nn.Conv2d(512, 21, 1)
        self.conv5_2_down = nn.Conv2d(512, 21, 1)
        self.conv5_3_down = nn.Conv2d(512, 21, 1)
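        # one 1x1 scoring conv per stage, plus a 1x1 conv that fuses the
        # five side outputs into a single edge map.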
        self.score_dsn1 = nn.Conv2d(21, 1, 1)
        self.score_dsn2 = nn.Conv2d(21, 1, 1)
        self.score_dsn3 = nn.Conv2d(21, 1, 1)
        self.score_dsn4 = nn.Conv2d(21, 1, 1)
        self.score_dsn5 = nn.Conv2d(21, 1, 1)
        self.score_fuse = nn.Conv2d(5, 1, 1)
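        # fixed (non-learned) bilinear kernels for upsampling the side
        # outputs, kept as non-persistent buffers so they follow the
        # module's device and stay out of state_dict().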
        self.register_buffer('weight_deconv2', self._make_bilinear_weights( 4, 1), persistent=False)
        self.register_buffer('weight_deconv3', self._make_bilinear_weights( 8, 1), persistent=False)
        self.register_buffer('weight_deconv4', self._make_bilinear_weights(16, 1), persistent=False)
        self.register_buffer('weight_deconv5', self._make_bilinear_weights(16, 1), persistent=False)
        # init weights
        self.apply(self._init_weights)
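        # optionally copy VGG16 weights from a .mat file; a '-'-separated
        # key such as 'conv1_1-weight' maps to the 'conv1_1.weight' entry
        # of state_dict() (keys without exactly one '-' are skipped).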
        if pretrained is not None:
            vgg16 = sio.loadmat(pretrained)
            torch_params = self.state_dict()
            for k in vgg16.keys():
                name_par = k.split('-')
                size = len(name_par)
                if size == 2:
                    name_space = name_par[0] + '.' + name_par[1]
                    data = np.squeeze(vgg16[k])
                    torch_params[name_space] = torch.from_numpy(data)
            self.load_state_dict(torch_params)
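    # Gaussian(0, 0.01) init for every conv; the fuse conv (weight shape
    # [1, 5, 1, 1]) is instead set to a constant 0.2, an equal-weight
    # average of the five side outputs.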
    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.weight.shape == torch.Size([1, 5, 1, 1]):
                nn.init.constant_(m.weight, 0.2)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    # Based on HED implementation @ https://github.com/xwjabc/hed
    def _make_bilinear_weights(self, size, num_channels):
        factor = (size + 1) // 2
        if size % 2 == 1:
            center = factor - 1
        else:
            center = factor - 0.5
        og = np.ogrid[:size, :size]
        filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
        filt = torch.from_numpy(filt)
        w = torch.zeros(num_channels, num_channels, size, size)
        w.requires_grad = False
        for i in range(num_channels):
            for j in range(num_channels):
                if i == j:
                    w[i, j] = filt
        return w
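    # Illustration: for size=4 the 1-D bilinear profile is
    # [0.25, 0.75, 0.75, 0.25] (factor=2, center=1.5), so each diagonal
    # entry of w holds its outer product:
    #   [[0.0625, 0.1875, 0.1875, 0.0625],
    #    [0.1875, 0.5625, 0.5625, 0.1875],
    #    [0.1875, 0.5625, 0.5625, 0.1875],
    #    [0.0625, 0.1875, 0.1875, 0.0625]]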
    # Based on BDCN implementation @ https://github.com/pkuCactus/BDCN
    def _crop(self, data, img_h, img_w, crop_h, crop_w):
        _, _, h, w = data.size()
        assert img_h <= h and img_w <= w
        data = data[:, :, crop_h:crop_h + img_h, crop_w:crop_w + img_w]
        return data
    def forward(self, x):
        img_h, img_w = x.shape[2], x.shape[3]
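        # VGG backbone; every conv activation is kept so it can feed a
        # side output below.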
        conv1_1 = self.act(self.conv1_1(x))
        conv1_2 = self.act(self.conv1_2(conv1_1))
        pool1 = self.pool1(conv1_2)
        conv2_1 = self.act(self.conv2_1(pool1))
        conv2_2 = self.act(self.conv2_2(conv2_1))
        pool2 = self.pool2(conv2_2)
        conv3_1 = self.act(self.conv3_1(pool2))
        conv3_2 = self.act(self.conv3_2(conv3_1))
        conv3_3 = self.act(self.conv3_3(conv3_2))
        pool3 = self.pool3(conv3_3)
        conv4_1 = self.act(self.conv4_1(pool3))
        conv4_2 = self.act(self.conv4_2(conv4_1))
        conv4_3 = self.act(self.conv4_3(conv4_2))
        pool4 = self.pool4(conv4_3)
        conv5_1 = self.act(self.conv5_1(pool4))
        conv5_2 = self.act(self.conv5_2(conv5_1))
        conv5_3 = self.act(self.conv5_3(conv5_2))
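        # reduce each stage's activations to 21 channels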
        conv1_1_down = self.conv1_1_down(conv1_1)
        conv1_2_down = self.conv1_2_down(conv1_2)
        conv2_1_down = self.conv2_1_down(conv2_1)
        conv2_2_down = self.conv2_2_down(conv2_2)
        conv3_1_down = self.conv3_1_down(conv3_1)
        conv3_2_down = self.conv3_2_down(conv3_2)
        conv3_3_down = self.conv3_3_down(conv3_3)
        conv4_1_down = self.conv4_1_down(conv4_1)
        conv4_2_down = self.conv4_2_down(conv4_2)
        conv4_3_down = self.conv4_3_down(conv4_3)
        conv5_1_down = self.conv5_1_down(conv5_1)
        conv5_2_down = self.conv5_2_down(conv5_2)
        conv5_3_down = self.conv5_3_down(conv5_3)
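        # side outputs: sum the reduced maps within each stage, then
        # score with a 1x1 conv.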
        out1 = self.score_dsn1(conv1_1_down + conv1_2_down)
        out2 = self.score_dsn2(conv2_1_down + conv2_2_down)
        out3 = self.score_dsn3(conv3_1_down + conv3_2_down + conv3_3_down)
        out4 = self.score_dsn4(conv4_1_down + conv4_2_down + conv4_3_down)
        out5 = self.score_dsn5(conv5_1_down + conv5_2_down + conv5_3_down)
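        # upsample with the fixed bilinear kernels: out2 is at 1/2
        # resolution, out3 at 1/4, and out4/out5 both at 1/8 (pool4 has
        # stride 1), hence the stride-8 deconvolutions for both.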
        out2 = F.conv_transpose2d(out2, self.weight_deconv2, stride=2)
        out3 = F.conv_transpose2d(out3, self.weight_deconv3, stride=4)
        out4 = F.conv_transpose2d(out4, self.weight_deconv4, stride=8)
        out5 = F.conv_transpose2d(out5, self.weight_deconv5, stride=8)
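        # a transposed conv of kernel k and stride s yields (n - 1) * s + k
        # samples, slightly more than needed; crop each map back to
        # (img_h, img_w) with a per-scale offset.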
        out2 = self._crop(out2, img_h, img_w, 1, 1)
        out3 = self._crop(out3, img_h, img_w, 2, 2)
        out4 = self._crop(out4, img_h, img_w, 4, 4)
        out5 = self._crop(out5, img_h, img_w, 0, 0)
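        # fuse the five side outputs with a learned 1x1 conv
        # (initialized to an equal-weight average, see _init_weights)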
        fuse = torch.cat((out1, out2, out3, out4, out5), dim=1)
        fuse = self.score_fuse(fuse)
        results = [out1, out2, out3, out4, out5, fuse]
        results = [torch.sigmoid(r) for r in results]
        return results
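

# Minimal usage sketch: run a dummy batch through an untrained model. The
# input size (1, 3, 320, 480) is an arbitrary example. The model returns six
# sigmoid edge maps at full input resolution: the five side outputs followed
# by the fused map.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = RCF(pretrained=None).to(device).eval()
    dummy = torch.randn(1, 3, 320, 480, device=device)
    with torch.no_grad():
        outputs = model(dummy)
    for i, out in enumerate(outputs):
        print(i, tuple(out.shape))  # (1, 1, 320, 480) for every map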