# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib

import torch
import torch.nn as nn
import torch.nn.functional as F


def class_for_name(module_name, class_name):
    """Dynamically resolve and return ``class_name`` from ``module_name``.

    Raises ImportError if the module cannot be loaded, and AttributeError if
    the attribute does not exist in it.
    """
    m = importlib.import_module(module_name)
    return getattr(m, class_name)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding (reflect padding, no bias)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
        padding_mode="reflect",
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (no bias; padding_mode is irrelevant at padding=0)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
        padding_mode="reflect",
    )


class BasicBlock(nn.Module):
    """Two-conv residual block (ResNet-18/34 style).

    Norm layers are built with ``track_running_stats=False``, so batch
    statistics are always used and no running mean/var buffers are kept.
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input
        # when stride != 1.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes, track_running_stats=False, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes, track_running_stats=False, affine=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the identity when the main path changed resolution/channels.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Three-conv bottleneck residual block (ResNet-50/101/152 style).

    Bottleneck in torchvision places the stride for downsampling at the 3x3
    convolution (self.conv2), while the original implementation places it at
    the first 1x1 convolution (self.conv1), per "Deep Residual Learning for
    Image Recognition" (https://arxiv.org/abs/1512.03385). This variant is
    also known as ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input
        # when stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width, track_running_stats=False, affine=True)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width, track_running_stats=False, affine=True)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(
            planes * self.expansion, track_running_stats=False, affine=True
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the identity when the main path changed resolution/channels.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class conv(nn.Module):
    """Conv2d + BatchNorm + ELU, with reflect padding that preserves
    spatial size for odd kernels at stride 1."""

    def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
        super(conv, self).__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(
            num_in_layers,
            num_out_layers,
            kernel_size=kernel_size,
            stride=stride,
            padding=(self.kernel_size - 1) // 2,
            padding_mode="reflect",
        )
        self.bn = nn.BatchNorm2d(num_out_layers, track_running_stats=False, affine=True)

    def forward(self, x):
        return F.elu(self.bn(self.conv(x)), inplace=True)


class upconv(nn.Module):
    """Bilinear upsampling by ``scale`` followed by a stride-1 `conv` block."""

    def __init__(self, num_in_layers, num_out_layers, kernel_size, scale):
        super(upconv, self).__init__()
        self.scale = scale
        self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1)

    def forward(self, x):
        x = F.interpolate(
            x, scale_factor=self.scale, align_corners=True, mode="bilinear"
        )
        return self.conv(x)


class ResUNet(nn.Module):
    """ResNet-encoder U-Net that maps a 3-channel image to a feature map at
    1/4 of the input resolution with ``coarse_out_ch + fine_out_ch`` channels
    (just ``coarse_out_ch`` when ``coarse_only=True``).

    The encoder downsamples 16x (stem stride 2 + three stride-2 stages); the
    decoder upsamples 4x with skip connections from the first two stages.
    """

    def __init__(
        self,
        encoder="resnet34",
        coarse_out_ch=32,
        fine_out_ch=32,
        norm_layer=None,
        coarse_only=False,
    ):
        super(ResUNet, self).__init__()
        assert encoder in [
            "resnet18",
            "resnet34",
            "resnet50",
            "resnet101",
            "resnet152",
        ], "Incorrect encoder type"
        if encoder in ["resnet18", "resnet34"]:
            block = BasicBlock
            filters = [64, 128, 256, 512]
        else:
            # Deep encoders need Bottleneck blocks (expansion 4) so the stage
            # output channels actually match `filters`; the previous code
            # always used BasicBlock, which made the decoder's channel counts
            # wrong (and crashed) for resnet50/101/152.
            block = Bottleneck
            filters = [256, 512, 1024, 2048]
        self.coarse_only = coarse_only
        if self.coarse_only:
            fine_out_ch = 0
        self.coarse_out_ch = coarse_out_ch
        self.fine_out_ch = fine_out_ch
        out_ch = coarse_out_ch + fine_out_ch

        # original layers = [3, 4, 6, 3]; only the first three stages are
        # instantiated below (the deepest stage is unused by this U-Net).
        layers = [3, 4, 6, 3]
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.dilation = 1
        replace_stride_with_dilation = [False, False, False]
        self.inplanes = 64
        self.groups = 1
        self.base_width = 64

        # Encoder stem: 7x7 stride-2 conv (1/2 resolution).
        self.conv1 = nn.Conv2d(
            3,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            padding_mode="reflect",
        )
        self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True)
        self.relu = nn.ReLU(inplace=True)
        # Note: unlike torchvision's ResNet, layer1 also strides (no maxpool).
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )

        # decoder
        self.upconv3 = upconv(filters[2], 128, 3, 2)
        self.iconv3 = conv(filters[1] + 128, 128, 3, 1)
        self.upconv2 = upconv(128, 64, 3, 2)
        self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1)

        # fine-level conv
        self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks; the first block
        carries the stride (or dilation) and an optional 1x1 projection."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(
                    planes * block.expansion, track_running_stats=False, affine=True
                ),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def skipconnect(self, x1, x2):
        """Pad ``x1`` to match ``x2``'s spatial size, then concat on channels."""
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))

        # for padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return x

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)

        x = self.upconv3(x3)
        x = self.skipconnect(x2, x)
        x = self.iconv3(x)

        x = self.upconv2(x)
        x = self.skipconnect(x1, x)
        x = self.iconv2(x)

        x_out = self.out_conv(x)
        return x_out
        # NOTE(review): a disabled variant previously split x_out into
        # (x_coarse, x_fine) using coarse_out_ch/fine_out_ch when
        # coarse_only=False; those attributes are kept for callers that
        # slice the output themselves.