#
# Authors: Wouter Van Gansbeke & Simon Vandenhende
# Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)

from functools import partial

import torch.nn as nn


class ResnetDilated(nn.Module):
    """Wrap a ResNet-style backbone, replacing late strides with dilation.

    The wrapper takes ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
    ``layer1`` .. ``layer4`` from ``orig_resnet``; no other attribute of the
    original network is used.

    Args:
        orig_resnet: Module exposing the attributes listed above.
        dilate_scale: Selects which stages are converted.
            ``8`` rewrites ``layer3`` (dilate=2) and ``layer4`` (dilate=4);
            ``16`` rewrites only ``layer4`` (dilate=2);
            any other value leaves every layer untouched.

    Note:
        The sub-modules are shared with ``orig_resnet``, not copied, and the
        stride/dilation/padding rewrite is applied in place. ``orig_resnet``
        is therefore modified by constructing this wrapper.
    """

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()

        # Module.apply visits every sub-module of the stage, so each
        # convolution inside the stage gets rewritten.
        if dilate_scale == 8:
            orig_resnet.layer3.apply(
                partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=2))

        # Stem.
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu = orig_resnet.relu
        self.maxpool = orig_resnet.maxpool

        # Residual stages (layer3/layer4 possibly rewritten above).
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        """Rewrite one module in place; intended as a ``Module.apply`` callback.

        Only modules whose class name contains ``'Conv'`` are touched.

        * A convolution with stride ``(2, 2)`` gets stride ``(1, 1)``; if its
          kernel is 3x3 its dilation and padding become ``dilate // 2``.
        * Any other convolution with a 3x3 kernel gets dilation and padding
          set to ``dilate``.
        * Convolutions with a different kernel size keep their dilation and
          padding.

        Args:
            m: The module currently visited.
            dilate: Dilation rate for the stage being rewritten.
        """
        if 'Conv' not in m.__class__.__name__:
            return

        if m.stride == (2, 2):
            # the convolution with stride: drop the stride, use half the rate
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            # other convolutions: use the full rate
            rate = dilate

        if m.kernel_size == (3, 3):
            # padding == dilation keeps the spatial size of a 3x3 convolution
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x):
        """Run the stem and the four residual stages; return layer4's output."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x