# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
import torch.nn.functional as F

from detection_models.sync_batchnorm import DataParallelWithCallback
from detection_models.antialiasing import Downsample


class UNet(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        depth=5,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    ):
        """
        Implementation based on U-Net: Convolutional Networks for Biomedical
        Image Segmentation (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        This variant keeps the overall U-Net layout but uses LeakyReLU
        activations, reflection padding, and optional antialiased
        downsampling, so it is not an exact reproduction of the architecture
        in the original paper.

        Args:
            in_channels (int): number of input channels
            out_channels (int): number of output channels
            depth (int): depth of the network
            conv_num (int): number of convolutions in each UNetConvBlock
            wf (int): number of filters in the first layer is 2**wf
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output. This may introduce artifacts
            batch_norm (bool): use BatchNorm after layers with an activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super().__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        self.depth = depth - 1
        prev_channels = in_channels

        self.first = nn.Sequential(
            *[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)]
        )
        prev_channels = 2 ** wf

        self.down_path = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(depth):
            if antialiasing and depth > 0:
                self.down_sample.append(
                    nn.Sequential(
                        *[
                            nn.ReflectionPad2d(1),
                            nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0),
                            nn.BatchNorm2d(prev_channels),
                            nn.LeakyReLU(0.2, True),
                            Downsample(channels=prev_channels, stride=2),
                        ]
                    )
                )
            else:
                self.down_sample.append(
                    nn.Sequential(
                        *[
                            nn.ReflectionPad2d(1),
                            nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0),
                            nn.BatchNorm2d(prev_channels),
                            nn.LeakyReLU(0.2, True),
                        ]
                    )
                )
            self.down_path.append(
                UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i + 1)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth)):
            self.up_path.append(
                UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        if with_tanh:
            self.last = nn.Sequential(
                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()]
            )
        else:
            self.last = nn.Sequential(
                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)]
            )

        if sync_bn:
            self = DataParallelWithCallback(self)

    def forward(self, x):
        x = self.first(x)

        blocks = []
        for i, down_block in enumerate(self.down_path):
            blocks.append(x)  # keep the pre-downsampling feature map for the skip connection
            x = self.down_sample[i](x)
            x = down_block(x)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)


class UNetConvBlock(nn.Module):
    def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        for _ in range(conv_num):
            block.append(nn.ReflectionPad2d(padding=int(padding)))
            block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=0))
            if batch_norm:
                block.append(nn.BatchNorm2d(out_size))
            block.append(nn.LeakyReLU(0.2, True))
            in_size = out_size

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out

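
# A minimal, hedged smoke-test sketch for the UNet above. The helper name and the
# chosen hyper-parameters are illustrative, not part of the original module. With
# padding=True every block preserves the spatial size apart from the explicit
# stride-2 downsampling, so a 256x256 input should come back as a 256x256 output.
def _unet_smoke_test():
    net = UNet(
        in_channels=3,
        out_channels=3,
        depth=4,
        conv_num=1,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=False,  # keep a plain nn.Module; no multi-GPU sync-BN wrapper needed here
        antialiasing=True,
    )
    net.eval()
    with torch.no_grad():
        y = net(torch.zeros(1, 3, 256, 256))
    assert y.shape == (1, 3, 256, 256)
    return y.shape
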
class UNetUpBlock(nn.Module):
    def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
            )

        self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)
        return out


class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="BN", use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example,
                               if |num_downs| == 7, an image of size 128x128 becomes
                               of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_type (str) -- "BN" for BatchNorm2d or "IN" for InstanceNorm2d
            use_dropout (bool) -- whether to use dropout in the intermediate layers

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super().__init__()
        if norm_type == "BN":
            norm_layer = nn.BatchNorm2d
        elif norm_type == "IN":
            norm_layer = nn.InstanceNorm2d
        else:
            raise NameError("Unknown norm layer")

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True
        )  # add the innermost layer
        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(
                ngf * 8,
                ngf * 8,
                input_nc=None,
                submodule=unet_block,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
            )
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(
            ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
        )
        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer
        )  # add the outermost layer

    def forward(self, input):
        return self.model(input)

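
# A hedged usage sketch for UnetGenerator (an illustrative helper, not part of the
# original module). Each nested UnetSkipConnectionBlock halves the spatial size with
# a stride-2 4x4 convolution and doubles it back with a stride-2 transposed
# convolution, so the input height and width should be divisible by 2 ** num_downs
# for the skip concatenations to line up.
def _unet_generator_smoke_test():
    gen = UnetGenerator(input_nc=3, output_nc=3, num_downs=5, ngf=64, norm_type="BN")
    gen.eval()
    with torch.no_grad():
        y = gen(torch.zeros(1, 3, 256, 256))  # 256 is divisible by 2 ** 5
    assert y.shape == (1, 3, 256, 256)
    return y.shape
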
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.

    -------------------identity----------------------
    |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a Unet submodule with skip connections.
        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)   -- if this module is the outermost module
            innermost (bool)   -- if this module is the innermost module
            norm_layer         -- normalization layer
            use_dropout (bool) -- whether to use dropout layers
        """
        super().__init__()
        self.outermost = outermost
        use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.LeakyReLU(0.2, True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
            )
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)


# ============================================
# Network testing
# ============================================
if __name__ == "__main__":
    from torchsummary import summary
    from torchviz import make_dot

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet(
        in_channels=3,
        out_channels=3,
        depth=4,
        conv_num=1,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
    )
    model.to(device)

    model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False)
    model_pix2pix.to(device)

    print("customized unet:")
    summary(model, (3, 256, 256))

    print("cyclegan unet:")
    summary(model_pix2pix, (3, 256, 256))

    x = torch.zeros(1, 3, 256, 256, device=device).requires_grad_(True)
    g = make_dot(model(x))
    g.render("models/Digraph.gv", view=False)
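
    # Hedged add-on check (not in the original test script): exercise a single
    # innermost UnetSkipConnectionBlock in isolation. Since a non-outermost block
    # returns torch.cat([x, self.model(x)], 1), a 64-channel input comes back with
    # 64 + outer_nc = 128 channels at the original 32x32 resolution.
    inner_block = UnetSkipConnectionBlock(outer_nc=64, inner_nc=128, input_nc=64, innermost=True).to(device)
    inner_block.eval()
    with torch.no_grad():
        skip_out = inner_block(torch.zeros(1, 64, 32, 32, device=device))
    print("innermost skip block output:", tuple(skip_out.shape))  # expected (1, 128, 32, 32)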