import math

import torch
import torch.nn as nn

from models.networks.drn import drn_c_26


def fill_up_weights(up):
    """Fill a grouped transposed convolution with fixed bilinear-upsampling weights."""
    w = up.weight.data
    f = math.ceil(w.size(2) / 2)
    c = (2 * f - 1 - f % 2) / (2. * f)
    # Build the 2D bilinear kernel once in channel 0.
    for i in range(w.size(2)):
        for j in range(w.size(3)):
            w[0, 0, i, j] = \
                (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
    # Copy that kernel into every remaining (grouped) channel.
    for ch in range(1, w.size(0)):
        w[ch, 0, :, :] = w[0, 0, :, :]


class DRNSeg(nn.Module):
    """DRN-C-26 backbone with a 1x1 segmentation head and 8x bilinear upsampling."""

    def __init__(self, classes, pretrained_drn=False,
                 pretrained_model=None, use_torch_up=False):
        super(DRNSeg, self).__init__()

        # Keep the DRN-C-26 trunk but drop its final classification layers.
        model = drn_c_26(pretrained=pretrained_drn)
        self.base = nn.Sequential(*list(model.children())[:-2])
        if pretrained_model:
            self.load_pretrained(pretrained_model)

        # 1x1 convolution mapping backbone features to per-class logits.
        self.seg = nn.Conv2d(model.out_dim, classes,
                             kernel_size=1, bias=True)

        # He-style initialization of the segmentation head.
        m = self.seg
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        m.bias.data.zero_()

        if use_torch_up:
            self.up = nn.UpsamplingBilinear2d(scale_factor=8)
        else:
            # Fixed (non-trainable) bilinear upsampling implemented as a
            # grouped transposed convolution.
            up = nn.ConvTranspose2d(classes, classes, 16, stride=8, padding=4,
                                    output_padding=0, groups=classes,
                                    bias=False)
            fill_up_weights(up)
            up.weight.requires_grad = False
            self.up = up

    def forward(self, x):
        x = self.base(x)
        x = self.seg(x)
        y = self.up(x)
        return y

    def optim_parameters(self, memo=None):
        # Only the backbone and the segmentation head are optimized;
        # the bilinear upsampling weights stay frozen.
        for param in self.base.parameters():
            yield param
        for param in self.seg.parameters():
            yield param

    def load_pretrained(self, pretrained_model):
        print("loading the pretrained drn model from %s" % pretrained_model)
        state_dict = torch.load(pretrained_model, map_location='cpu')
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata

        # Keep only the backbone weights and strip the leading 'base.' prefix
        # so the keys match those of self.base.
        pretrained_dict = state_dict['model']
        pretrained_dict = {k[5:]: v for k, v in pretrained_dict.items()
                           if k.split('.')[0] == 'base'}

        self.base.load_state_dict(pretrained_dict)


class DRNSub(nn.Module):
    """Image-level classifier that reuses the DRNSeg backbone with a global-pool head."""

    def __init__(self, num_classes, pretrained_model=None, fix_base=False):
        super(DRNSub, self).__init__()

        # Build a DRNSeg only to reuse its backbone; its segmentation head is discarded.
        drnseg = DRNSeg(2)
        if pretrained_model:
            print("loading the pretrained drn model from %s" % pretrained_model)
            state_dict = torch.load(pretrained_model, map_location='cpu')
            drnseg.load_state_dict(state_dict['model'])

        self.base = drnseg.base
        if fix_base:
            # Optionally freeze the backbone and train only the classifier head.
            for param in self.base.parameters():
                param.requires_grad = False

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.base(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
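

if __name__ == "__main__":
    # Minimal shape-check sketch, not part of the original training/inference
    # pipeline: the 256x256 input size and the variable names below are
    # illustrative assumptions. It builds both heads without pretrained
    # weights and runs a random batch through them.
    dummy = torch.randn(1, 3, 256, 256)

    seg_model = DRNSeg(classes=2, pretrained_drn=False)
    seg_out = seg_model(dummy)
    print(seg_out.shape)  # expected [1, 2, 256, 256] given the 8x upsampling head

    cls_model = DRNSub(num_classes=1)
    cls_out = cls_model(dummy)
    print(cls_out.shape)  # expected [1, 1]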