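# Page-header residue from the source this file was copied from (presumably
# uploader "manhkhanhUIT", commit message "Init", commit e78c13e); kept as a
# comment so that the module is valid Python.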
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from detection_models.sync_batchnorm import DataParallelWithCallback
from detection_models.antialiasing import Downsample
class UNet(nn.Module):
    """Encoder/decoder U-Net with reflection padding and LeakyReLU activations.

    Inspired by "U-Net: Convolutional Networks for Biomedical Image Segmentation"
    (Ronneberger et al., 2015, https://arxiv.org/abs/1505.04597), but not the
    paper's exact network: it adds a 7x7 reflection-padded stem, uses
    LeakyReLU(0.2), always applies BatchNorm in the downsampling layers and can
    use anti-aliased downsampling.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        depth=5,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    ):
        """Build the stem, ``depth`` encoder stages, ``depth`` decoder stages and the head.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            depth (int): number of encoder stages (each downsamples once) and of
                matching decoder stages.
            conv_num (int): number of 3x3 convolutions inside every UNetConvBlock.
            wf (int): the stem outputs 2**wf channels; encoder stage k outputs
                2**(wf + k + 1) channels.
            padding (bool): forwarded to UNetConvBlock / UNetUpBlock; if True each
                3x3 conv there is preceded by a 1-pixel reflection pad.
            batch_norm (bool): BatchNorm after each conv inside UNetConvBlock
                (the downsampling layers use BatchNorm regardless).
            up_mode (str): 'upconv' (transposed convolution) or 'upsample'
                (bilinear upsampling followed by a 3x3 conv).
            with_tanh (bool): append a Tanh to the output head.
            sync_bn (bool): constructs a DataParallelWithCallback around this
                module and discards it -- see the NOTE at the end of this method.
            antialiasing (bool): downsample with a stride-1 conv followed by
                ``Downsample(stride=2)`` instead of a single stride-2 conv.
        """
        super().__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        self.depth = depth - 1  # NOTE(review): not read anywhere in the visible code

        width = 2 ** wf
        self.first = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, width, kernel_size=7),
            nn.LeakyReLU(0.2, True),
        )

        # Encoder: stage k first downsamples `width` channels, then widens them
        # to 2 ** (wf + k + 1) with a UNetConvBlock.
        self.down_path = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for level in range(depth):
            # This loop body only runs when depth > 0, so `antialiasing` alone
            # selects the downsampling variant.
            if antialiasing:
                reducer = [
                    nn.ReflectionPad2d(1),
                    nn.Conv2d(width, width, kernel_size=3, stride=1, padding=0),
                    nn.BatchNorm2d(width),
                    nn.LeakyReLU(0.2, True),
                    Downsample(channels=width, stride=2),
                ]
            else:
                reducer = [
                    nn.ReflectionPad2d(1),
                    nn.Conv2d(width, width, kernel_size=4, stride=2, padding=0),
                    nn.BatchNorm2d(width),
                    nn.LeakyReLU(0.2, True),
                ]
            self.down_sample.append(nn.Sequential(*reducer))
            widened = 2 ** (wf + level + 1)
            self.down_path.append(UNetConvBlock(conv_num, width, widened, padding, batch_norm))
            width = widened

        # Decoder: mirrors the encoder, halving the channel count at every stage.
        self.up_path = nn.ModuleList()
        for level in reversed(range(depth)):
            narrowed = 2 ** (wf + level)
            self.up_path.append(
                UNetUpBlock(conv_num, width, narrowed, up_mode, padding, batch_norm)
            )
            width = narrowed

        head = [nn.ReflectionPad2d(1), nn.Conv2d(width, out_channels, kernel_size=3)]
        if with_tanh:
            head.append(nn.Tanh())
        self.last = nn.Sequential(*head)

        if sync_bn:
            # NOTE(review): __init__ cannot replace the object being built, so the
            # wrapper is constructed and immediately discarded; this module is NOT
            # wrapped. Any wrapping for synchronized BatchNorm has to happen at the
            # call site. The call is kept only to preserve whatever side effects
            # constructing the wrapper has.
            DataParallelWithCallback(self)

    def forward(self, x):
        """Run stem -> encoder -> decoder -> head.

        The activation entering each encoder stage is saved as a skip connection;
        the decoder consumes them deepest-first.
        """
        x = self.first(x)
        skips = []
        for reduce, encode in zip(self.down_sample, self.down_path):
            skips.append(x)
            x = encode(reduce(x))
        for decode, skip in zip(self.up_path, reversed(skips)):
            x = decode(x, skip)
        return self.last(x)
class UNetConvBlock(nn.Module):
    """Stack of ``conv_num`` units: ReflectionPad -> 3x3 conv -> [BatchNorm] -> LeakyReLU(0.2).

    The first unit maps ``in_size`` to ``out_size`` channels; every later unit
    maps ``out_size`` to ``out_size``.
    """

    def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
        """
        Args:
            conv_num (int): number of pad/conv/(norm)/activation units.
            in_size (int): input channels of the first conv.
            out_size (int): output channels of every conv.
            padding (bool | int): reflection-pad width before each conv (True -> 1).
            batch_norm (bool): insert BatchNorm2d after each conv.
        """
        super(UNetConvBlock, self).__init__()
        layers = []
        channels = in_size
        for _ in range(conv_num):
            layers += [
                nn.ReflectionPad2d(padding=int(padding)),
                nn.Conv2d(channels, out_size, kernel_size=3, padding=0),
            ]
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.LeakyReLU(0.2, True))
            channels = out_size
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked units to ``x``."""
        return self.block(x)
class UNetUpBlock(nn.Module):
    """Decoder stage: upsample by 2, concatenate a (center-cropped) skip, then convolve."""

    def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
        """
        Args:
            conv_num (int): number of 3x3 convs in the trailing UNetConvBlock.
            in_size (int): channels of the incoming low-resolution features; also
                the channel count after concatenating the skip connection.
            out_size (int): channels produced by the upsampling layer and by the block.
            up_mode (str): 'upconv' -> 2x2 stride-2 transposed conv;
                'upsample' -> bilinear x2 upsampling, 1-pixel reflection pad, 3x3 conv.
            padding (bool): forwarded to UNetConvBlock.
            batch_norm (bool): forwarded to UNetConvBlock.

        Raises:
            ValueError: if ``up_mode`` is neither 'upconv' nor 'upsample'.
        """
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
            )
        else:
            # Fail fast: otherwise `self.up` would be undefined and forward()
            # would fail later with an AttributeError.
            raise ValueError(f"up_mode must be 'upconv' or 'upsample', got {up_mode!r}")
        # The conv block sees [upsampled, skip] concatenated along channels, so
        # its input width is `in_size`; this assumes the skip carries
        # in_size - out_size channels (UNet passes in_size == 2 * out_size).
        self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Return the central ``target_size[0] x target_size[1]`` window of a 4-D tensor.

        ``layer`` is unpacked as (N, C, H, W); the batch and channel dimensions
        are kept whole.
        """
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])]

    def forward(self, x, bridge):
        """Upsample ``x``, concatenate the cropped ``bridge`` on dim 1, and convolve."""
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        return self.conv_block(out)
class UnetGenerator(nn.Module):
    """pix2pix/CycleGAN-style U-Net generator made of nested UnetSkipConnectionBlocks."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="BN", use_dropout=False):
        """Assemble the generator recursively, from the innermost block outwards.

        Parameters:
            input_nc (int)     -- number of channels in input images
            output_nc (int)    -- number of channels in output images
            num_downs (int)    -- number of stride-2 downsamplings. For example, with
                                  num_downs == 7 a 128x128 image becomes 1x1 at the
                                  bottleneck. Values below 5 build the same network
                                  as num_downs == 5.
            ngf (int)          -- filters in the outermost conv layer; deeper blocks
                                  use ngf*2, ngf*4 and ngf*8
            norm_type (str)    -- "BN" for BatchNorm2d or "IN" for InstanceNorm2d
            use_dropout (bool) -- add Dropout(0.5) to the intermediate ngf*8 blocks

        Raises:
            NameError: if norm_type is neither "BN" nor "IN".
        """
        super().__init__()
        for name, layer in (("BN", nn.BatchNorm2d), ("IN", nn.InstanceNorm2d)):
            if norm_type == name:
                norm_layer = layer
                break
        else:
            raise NameError("Unknown norm layer")

        # Innermost (bottleneck) block.
        block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True
        )
        # Extra ngf*8 levels beyond the five fixed ones (none when num_downs <= 5).
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(
                ngf * 8,
                ngf * 8,
                input_nc=None,
                submodule=block,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
            )
        # Taper the width: ngf*8 -> ngf*4 -> ngf*2 -> ngf.
        for mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(
                ngf * mult, ngf * mult * 2, input_nc=None, submodule=block, norm_layer=norm_layer
            )
        # Outermost block maps input_nc image channels to output_nc.
        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer
        )

    def forward(self, input):
        """Run the nested U-Net on ``input``."""
        return self.model(input)
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample, run the nested submodule, upsample, add the skip.

        -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|

    Non-outermost blocks return ``cat([input, output], dim=1)``; the outermost
    block returns only its output.
    """

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct one level of the U-Net.

        Parameters:
            outer_nc (int)      -- filters in the outer (upsampling) conv layer
            inner_nc (int)      -- filters in the inner (downsampling) conv layer
            input_nc (int)      -- channels of the input images/features; defaults to outer_nc
            submodule (UnetSkipConnectionBlock) -- the previously built inner level
            outermost (bool)    -- build the outermost variant (checked before innermost)
            innermost (bool)    -- build the innermost variant (submodule is ignored)
            norm_layer          -- normalization layer class
            use_dropout (bool)  -- append Dropout(0.5); only used by intermediate blocks
        """
        super().__init__()
        self.outermost = outermost
        # Convs get a bias only when norm_layer is exactly nn.InstanceNorm2d.
        use_bias = norm_layer == nn.InstanceNorm2d
        down_in = outer_nc if input_nc is None else input_nc

        # Every variant builds all of these, even those it does not use.
        # NOTE: the LeakyReLUs are in-place, so in non-outermost blocks the
        # block's input tensor is overwritten before forward() concatenates it.
        downconv = nn.Conv2d(down_in, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.LeakyReLU(0.2, True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # No norm layers; Tanh on the output; the transposed conv keeps its default bias.
            layers = [
                downconv,
                submodule,
                uprelu,
                nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1),
                nn.Tanh(),
            ]
        elif innermost:
            # No submodule, so the transposed conv sees only inner_nc channels.
            layers = [
                downrelu,
                downconv,
                uprelu,
                nn.ConvTranspose2d(
                    inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
                ),
                upnorm,
            ]
        else:
            # A non-outermost UnetSkipConnectionBlock submodule concatenates its
            # input with its output, hence inner_nc * 2 channels going up.
            layers = [
                downrelu,
                downconv,
                downnorm,
                submodule,
                uprelu,
                nn.ConvTranspose2d(
                    inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
                ),
                upnorm,
            ]
            if use_dropout:
                layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the level; non-outermost blocks prepend ``x`` along dim 1."""
        out = self.model(x)
        if self.outermost:
            return out
        # skip connection
        return torch.cat([x, out], 1)
# ============================================
# Network testing
# ============================================
if __name__ == "__main__":
    # Smoke test: print layer summaries for both architectures and run one
    # forward pass through the customized UNet.
    from torchsummary import summary

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Exercise the UNet defined in this module. sync_bn=False because the
    # wrapper built for sync_bn inside UNet.__init__ is discarded anyway.
    model = UNet(
        in_channels=3,
        out_channels=3,
        depth=4,
        conv_num=1,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=False,
    )
    model.to(device)
    model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False)
    model_pix2pix.to(device)

    print("customized unet:")
    summary(model, (3, 256, 256), device=device.type)
    print("cyclegan unet:")
    summary(model_pix2pix, (3, 256, 256), device=device.type)

    # Build the probe on the selected device rather than unconditionally on CUDA.
    x = torch.zeros(1, 3, 256, 256, device=device, requires_grad=True)
    y = model(x)
    print("forward output shape:", tuple(y.shape))
    # To render the autograd graph, install torchviz and run:
    #     from torchviz import make_dot
    #     make_dot(y).render("models/Digraph.gv", view=False)