# NOTE: the following lines are residue from the web page this file was copied from.
# qubvel-hf's picture
# qubvel-hf HF staff
# Init project
# c509e76
# raw
# history blame
# 31.9 kB
from math import log
import torch
import torch.nn as nn
from torch.nn import init
import functools
from model.cbam import CBAM
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class SingleConv(nn.Module):
    """Reflection pad => 3x3 convolution => BatchNorm => ReLU, applied once.

    The layer stack is stored under the attribute name ``double_conv`` even
    though it holds a single convolution; the name determines the
    ``state_dict`` keys (``double_conv.<index>.*``).
    """

    def __init__(self, in_channels, out_channels):
        """Build the stack.

        Parameters:
            in_channels (int)  -- channels of the incoming feature map
            out_channels (int) -- channels produced by the convolution
        """
        super().__init__()
        layers = [
            # Pad by one pixel so the unpadded 3x3 convolution keeps H and W.
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` after the pad/conv/BN/ReLU stack."""
        out = self.double_conv(x)
        return out
class Down_single(nn.Module):
    """Downscaling step: 2x2 max pooling followed by one ``SingleConv``."""

    def __init__(self, in_channels, out_channels):
        """Build the pooling + convolution pair.

        Parameters:
            in_channels (int)  -- channels of the incoming feature map
            out_channels (int) -- channels produced by the ``SingleConv``
        """
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = SingleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and convolve the result."""
        out = self.maxpool_conv(x)
        return out
class Up_single(nn.Module):
    """Upscaling step: transposed convolution, skip concatenation, ``SingleConv``."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        """Build the upsampling layers.

        Parameters:
            in_channels (int)  -- channels of the low-resolution input; also the
                                  channel count expected after concatenation
            out_channels (int) -- channels produced by this step
            bilinear (bool)    -- accepted but not read by this class
        """
        super().__init__()
        # Registered but never called in forward(); the transposed
        # convolution below does the actual upsampling.
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = SingleConv(in_channels, out_channels)
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=True
        )

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate it after ``x2`` on dim 1, and convolve.

        ``x2`` is the skip feature; presumably it has ``in_channels -
        out_channels`` channels and the same spatial size as the upsampled
        ``x1`` -- TODO confirm against callers.
        """
        upsampled = self.deconv(x1)
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)
class DoubleConv(nn.Module):
    """(reflection pad => 3x3 convolution => BatchNorm => ReLU) applied twice."""

    def __init__(self, in_channels, out_channels):
        """Build both stages.

        Parameters:
            in_channels (int)  -- channels of the incoming feature map
            out_channels (int) -- channels produced by each of the two convolutions
        """
        super().__init__()
        layers = []
        stage_in = in_channels
        for _ in range(2):
            layers.extend([
                # One pixel of padding keeps H and W through the unpadded 3x3 conv.
                nn.ReflectionPad2d(1),
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=0),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
            # The second stage maps out_channels -> out_channels.
            stage_in = out_channels
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` after both stages."""
        out = self.double_conv(x)
        return out
class Down(nn.Module):
    """Downscaling step: 2x2 max pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        """Build the pooling + double convolution pair.

        Parameters:
            in_channels (int)  -- channels of the incoming feature map
            out_channels (int) -- channels produced by the ``DoubleConv``
        """
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and convolve the result."""
        out = self.maxpool_conv(x)
        return out
class Up(nn.Module):
    """Upscaling step: transposed convolution, skip concatenation, ``DoubleConv``."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        """Build the upsampling layers.

        Parameters:
            in_channels (int)  -- channels of the low-resolution input; also the
                                  channel count expected after concatenation
            out_channels (int) -- channels produced by this step
            bilinear (bool)    -- accepted but not read by this class
        """
        super().__init__()
        # Registered but never called in forward(); the transposed
        # convolution below does the actual upsampling.
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = DoubleConv(in_channels, out_channels)
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=True
        )

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate it after ``x2`` on dim 1, and convolve.

        ``x2`` is the skip feature; presumably it has ``in_channels -
        out_channels`` channels and the same spatial size as the upsampled
        ``x1`` -- TODO confirm against callers.
        """
        upsampled = self.deconv(x1)
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)
class OutConv(nn.Module):
    """Output head: a single 1x1 convolution with no activation applied."""

    def __init__(self, in_channels, out_channels):
        """Build the head.

        Parameters:
            in_channels (int)  -- channels of the incoming feature map
            out_channels (int) -- channels of the returned map
        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # These activations are registered as attributes but forward() does
        # not call any of them.
        self.tanh = nn.Tanh()
        self.hardtanh = nn.Hardtanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1):
        """Return the raw 1x1-convolution output for ``x1``."""
        return self.conv(x1)
class GiemaskGenerator(nn.Module):
    """U-Net style network: one shared encoder feeding a depth and a mask decoder."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the encoder and both decoders.

        Parameters:
            input_nc, output_nc, num_downs, ngf, biline, norm_layer, use_dropout
                -- accepted but never read here. The first convolution is
                   always built for 3 input channels, the encoder always has
                   five downsampling steps, and each head emits 1 channel.
        """
        super().__init__()
        self.init_channel = 32
        base = self.init_channel
        c1, c2, c4, c8, c16, c32 = (base * m for m in (1, 2, 4, 8, 16, 32))

        # Shared encoder.
        self.inc = DoubleConv(3, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        self.down4 = Down(c8, c16)
        self.down5 = Down(c16, c32)

        # Depth decoder (double convolutions).
        self.up1 = Up(c32, c16)
        self.up2 = Up(c16, c8)
        self.up3 = Up(c8, c4)
        self.up4 = Up(c4, c2)
        self.up5 = Up(c2, c1)
        self.outc = OutConv(c1, 1)

        # Mask decoder (single convolutions).
        self.up1_1 = Up_single(c32, c16)
        self.up2_1 = Up_single(c16, c8)
        self.up3_1 = Up_single(c8, c4)
        self.up4_1 = Up_single(c4, c2)
        self.up5_1 = Up_single(c2, c1)
        self.outc_1 = OutConv(c1, 1)

    def forward(self, input):
        """Return ``(depth, mask)`` computed from ``input``.

        Presumably ``input`` is a (N, 3, H, W) batch with H and W multiples
        of 32 so that skip features line up -- TODO confirm.
        """
        feats = [self.inc(input)]
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(down(feats[-1]))
        bottleneck = feats[-1]
        # Skip features from deepest to shallowest (x5, x4, x3, x2, x1).
        skips = feats[-2::-1]

        # The mask branch runs first, then the depth branch.
        m = bottleneck
        for up, skip in zip((self.up1_1, self.up2_1, self.up3_1, self.up4_1, self.up5_1), skips):
            m = up(m, skip)
        mask = self.outc_1(m)

        d = bottleneck
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4, self.up5), skips):
            d = up(d, skip)
        depth = self.outc(d)
        return depth, mask
"""Create a Unet-based generator"""
class Giemask2Generator(nn.Module):
    """U-Net style network: shared encoder, a depth decoder, and a second
    decoder with two heads (mask and edge)."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the encoder, both decoders and the three output heads.

        Parameters:
            input_nc, output_nc, num_downs, ngf, biline, norm_layer, use_dropout
                -- accepted but never read here. The first convolution is
                   always built for 3 input channels, the encoder always has
                   five downsampling steps, and each head emits 1 channel.
        """
        super().__init__()
        self.init_channel = 32
        base = self.init_channel
        c1, c2, c4, c8, c16, c32 = (base * m for m in (1, 2, 4, 8, 16, 32))

        # Shared encoder.
        self.inc = DoubleConv(3, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        self.down4 = Down(c8, c16)
        self.down5 = Down(c16, c32)

        # Depth decoder.
        self.up1 = Up(c32, c16)
        self.up2 = Up(c16, c8)
        self.up3 = Up(c8, c4)
        self.up4 = Up(c4, c2)
        self.up5 = Up(c2, c1)
        self.outc = OutConv(c1, 1)

        # Mask/edge decoder; both heads read the same decoded features.
        self.up1_1 = Up_single(c32, c16)
        self.up2_1 = Up_single(c16, c8)
        self.up3_1 = Up_single(c8, c4)
        self.up4_1 = Up_single(c4, c2)
        self.up5_1 = Up_single(c2, c1)
        self.outc_1 = OutConv(c1, 1)
        self.outc_2 = OutConv(c1, 1)

    def _encode(self, image):
        """Run the encoder; return features ordered shallowest to deepest."""
        feats = [self.inc(image)]
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(down(feats[-1]))
        return feats

    @staticmethod
    def _decode(stages, feats):
        """Fold ``stages`` over the bottleneck and the skips, deepest skip first."""
        out = feats[-1]
        for stage, skip in zip(stages, feats[-2::-1]):
            out = stage(out, skip)
        return out

    def forward(self, input):
        """Return ``(depth, mask, edge)`` computed from ``input``."""
        feats = self._encode(input)

        shared = self._decode(
            (self.up1_1, self.up2_1, self.up3_1, self.up4_1, self.up5_1), feats
        )
        mask = self.outc_1(shared)
        edge = self.outc_2(shared)

        decoded = self._decode(
            (self.up1, self.up2, self.up3, self.up4, self.up5), feats
        )
        depth = self.outc(decoded)
        return depth, mask, edge
"""Create a Unet-based generator"""
class GieGenerator(nn.Module):
    """U-Net style encoder/decoder with a fixed 2-channel output head."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the encoder and decoder.

        Parameters:
            input_nc (int) -- the number of channels in input images
            output_nc, num_downs, ngf, biline, norm_layer, use_dropout
                -- accepted but never read here; the network always has five
                   downsampling steps and the head always emits 2 channels.
        """
        super().__init__()
        self.init_channel = 32
        base = self.init_channel
        c1, c2, c4, c8, c16, c32 = (base * m for m in (1, 2, 4, 8, 16, 32))

        self.inc = DoubleConv(input_nc, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        self.down4 = Down(c8, c16)
        self.down5 = Down(c16, c32)
        self.up1 = Up(c32, c16)
        self.up2 = Up(c16, c8)
        self.up3 = Up(c8, c4)
        self.up4 = Up(c4, c2)
        self.up5 = Up(c2, c1)
        self.outc = OutConv(c1, 2)

    def forward(self, input):
        """Return the 2-channel output map for ``input``."""
        # Encoder: keep every level for the skip connections.
        e0 = self.inc(input)
        e1 = self.down1(e0)
        e2 = self.down2(e1)
        e3 = self.down3(e2)
        e4 = self.down4(e3)
        bottom = self.down5(e4)

        # Decoder: each step consumes the skip from the matching level.
        out = self.up1(bottom, e4)
        out = self.up2(out, e3)
        out = self.up3(out, e2)
        out = self.up4(out, e1)
        out = self.up5(out, e0)
        return self.outc(out)
class GiecbamGenerator(nn.Module):
    """U-Net style encoder/decoder with a CBAM module at the bottleneck,
    dropout before the head, and a fixed 2-channel output."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the encoder, CBAM module, decoder and head.

        Parameters:
            input_nc (int) -- the number of channels in input images
            output_nc, num_downs, ngf, biline, norm_layer, use_dropout
                -- accepted but never read here; dropout (p=0.1) is always
                   created and the head always emits 2 channels.
        """
        super().__init__()
        self.init_channel = 32
        base = self.init_channel
        c1, c2, c4, c8, c16, c32 = (base * m for m in (1, 2, 4, 8, 16, 32))

        self.inc = DoubleConv(input_nc, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        self.down4 = Down(c8, c16)
        self.down5 = Down(c16, c32)
        self.cbam = CBAM(gate_channels=c32)
        self.up1 = Up(c32, c16)
        self.up2 = Up(c16, c8)
        self.up3 = Up(c8, c4)
        self.up4 = Up(c4, c2)
        self.up5 = Up(c2, c1)
        self.outc = OutConv(c1, 2)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, input):
        """Return the 2-channel output map for ``input``."""
        feats = [self.inc(input)]
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(down(feats[-1]))

        # Only the deepest feature map goes through CBAM; skips are untouched.
        out = self.cbam(feats[-1])
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4, self.up5), feats[-2::-1]):
            out = up(out, skip)

        out = self.dropout(out)
        return self.outc(out)
class Gie2headGenerator(nn.Module):
    """U-Net style network: shared encoder with two independent ``Up`` decoders
    whose 1-channel outputs are concatenated."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the encoder and both decoders.

        Parameters:
            input_nc (int) -- the number of channels in input images
            output_nc, num_downs, ngf, biline, norm_layer, use_dropout
                -- accepted but never read here; each decoder head always
                   emits 1 channel.
        """
        super().__init__()
        self.init_channel = 32
        base = self.init_channel
        c1, c2, c4, c8, c16, c32 = (base * m for m in (1, 2, 4, 8, 16, 32))

        # Shared encoder.
        self.inc = DoubleConv(input_nc, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        self.down4 = Down(c8, c16)
        self.down5 = Down(c16, c32)

        # First decoder.
        self.up1_1 = Up(c32, c16)
        self.up2_1 = Up(c16, c8)
        self.up3_1 = Up(c8, c4)
        self.up4_1 = Up(c4, c2)
        self.up5_1 = Up(c2, c1)
        self.outc_1 = OutConv(c1, 1)

        # Second decoder.
        self.up1_2 = Up(c32, c16)
        self.up2_2 = Up(c16, c8)
        self.up3_2 = Up(c8, c4)
        self.up4_2 = Up(c4, c2)
        self.up5_2 = Up(c2, c1)
        self.outc_2 = OutConv(c1, 1)

    def forward(self, input):
        """Return both decoder outputs concatenated on dim 1 (2 channels)."""
        feats = [self.inc(input)]
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(down(feats[-1]))
        bottleneck = feats[-1]
        skips = feats[-2::-1]

        first = bottleneck
        for up, skip in zip((self.up1_1, self.up2_1, self.up3_1, self.up4_1, self.up5_1), skips):
            first = up(first, skip)
        logits_1 = self.outc_1(first)

        second = bottleneck
        for up, skip in zip((self.up1_2, self.up2_2, self.up3_2, self.up4_2, self.up5_2), skips):
            second = up(second, skip)
        logits_2 = self.outc_2(second)

        return torch.cat((logits_1, logits_2), 1)
class BmpGenerator(nn.Module):
    """U-Net style encoder/decoder with configurable input and output channels."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the encoder and decoder.

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs, ngf, biline, norm_layer, use_dropout
                -- accepted but never read here; the network always has five
                   downsampling steps.
        """
        super().__init__()
        self.init_channel = 32
        self.output_nc = output_nc
        base = self.init_channel
        c1, c2, c4, c8, c16, c32 = (base * m for m in (1, 2, 4, 8, 16, 32))

        self.inc = DoubleConv(input_nc, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        self.down4 = Down(c8, c16)
        self.down5 = Down(c16, c32)
        self.up1 = Up(c32, c16)
        self.up2 = Up(c16, c8)
        self.up3 = Up(c8, c4)
        self.up4 = Up(c4, c2)
        self.up5 = Up(c2, c1)
        self.outc = OutConv(c1, self.output_nc)

    def forward(self, input):
        """Return the ``output_nc``-channel output map for ``input``."""
        feats = [self.inc(input)]
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(down(feats[-1]))

        out = feats[-1]
        # Walk the skips from deepest to shallowest.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4, self.up5), feats[-2::-1]):
            out = up(out, skip)
        return self.outc(out)
class Bmp2Generator(nn.Module):
    """Two stacked U-Net style networks.

    The first network ("gienet") predicts depth, mask and edge maps from the
    image. The second ("bmp net") takes the masked image concatenated with the
    masked depth (4 channels) and produces a 2-channel map, which is the only
    value returned.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build both networks.

        Parameters:
            input_nc, output_nc, num_downs, ngf, biline, norm_layer, use_dropout
                -- accepted but never read here. The first network is always
                   built for 3 input channels, the second for 4 input channels
                   and 2 output channels.
        """
        super().__init__()
        self.init_channel = 32
        base = self.init_channel
        c1, c2, c4, c8, c16, c32 = (base * m for m in (1, 2, 4, 8, 16, 32))

        # --- gienet: shared encoder ---
        self.inc = DoubleConv(3, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c4)
        self.down3 = Down(c4, c8)
        self.down4 = Down(c8, c16)
        self.down5 = Down(c16, c32)
        # --- gienet: depth decoder ---
        self.up1 = Up(c32, c16)
        self.up2 = Up(c16, c8)
        self.up3 = Up(c8, c4)
        self.up4 = Up(c4, c2)
        self.up5 = Up(c2, c1)
        self.outc = OutConv(c1, 1)
        # --- gienet: mask/edge decoder ---
        self.up1_1 = Up_single(c32, c16)
        self.up2_1 = Up_single(c16, c8)
        self.up3_1 = Up_single(c8, c4)
        self.up4_1 = Up_single(c4, c2)
        self.up5_1 = Up_single(c2, c1)
        self.outc_1 = OutConv(c1, 1)
        self.outc_2 = OutConv(c1, 1)

        # --- bmp net: 3 image channels + 1 depth channel in, 2 channels out ---
        self.inc_b = DoubleConv(4, c1)
        self.down1_b = Down(c1, c2)
        self.down2_b = Down(c2, c4)
        self.down3_b = Down(c4, c8)
        self.down4_b = Down(c8, c16)
        self.down5_b = Down(c16, c32)
        self.up1_b = Up(c32, c16)
        self.up2_b = Up(c16, c8)
        self.up3_b = Up(c8, c4)
        self.up4_b = Up(c4, c2)
        self.up5_b = Up(c2, c1)
        self.outc_b = OutConv(c1, 2)

    def forward(self, input):
        """Return the 2-channel output of the second network for ``input``."""
        # ---- gienet ----
        g_feats = [self.inc(input)]
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5):
            g_feats.append(down(g_feats[-1]))
        g_bottom = g_feats[-1]
        g_skips = g_feats[-2::-1]

        shared = g_bottom
        for up, skip in zip((self.up1_1, self.up2_1, self.up3_1, self.up4_1, self.up5_1), g_skips):
            shared = up(shared, skip)
        mask = self.outc_1(shared)
        # The edge head is evaluated but its result is neither used nor returned.
        edge = self.outc_2(shared)

        dec = g_bottom
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4, self.up5), g_skips):
            dec = up(dec, skip)
        depth = self.outc(dec)

        # ---- bmpnet ----
        # Binarise the mask in place at 0.5 before gating image and depth.
        mask[mask > 0.5] = 1.
        mask[mask <= 0.5] = 0.
        image_cat_depth = torch.cat((input * mask, depth * mask), dim=1)

        b_feats = [self.inc_b(image_cat_depth)]
        for down in (self.down1_b, self.down2_b, self.down3_b, self.down4_b, self.down5_b):
            b_feats.append(down(b_feats[-1]))

        out = b_feats[-1]
        for up, skip in zip((self.up1_b, self.up2_b, self.up3_b, self.up4_b, self.up5_b), b_feats[-2::-1]):
            out = up(out, skip)
        bm = self.outc_b(out)
        return bm
class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested ``UnetSkipConnectionBlock`` modules."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct the U-Net from the innermost block outwards.

        Parameters:
            input_nc (int)     -- the number of channels in input images
            output_nc (int)    -- the number of channels in output images
            num_downs (int)    -- the number of downsamplings; values below 5
                                  add no middle blocks, so at least five
                                  levels are always built
            ngf (int)          -- the number of filters in the outermost conv layer
            norm_layer         -- normalization layer
            use_dropout (bool) -- passed to the middle ``ngf * 8`` blocks only
        """
        super(UnetGenerator, self).__init__()
        # Innermost (bottleneck) block.
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        # Extra constant-width blocks; these are the only ones given dropout.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block, norm_layer=norm_layer, use_dropout=use_dropout)
        # Halve the filter count at each step out: ngf*8 -> ngf*4 -> ngf*2 -> ngf.
        for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, input_nc=None, submodule=block, norm_layer=norm_layer)
        # Outermost block maps input_nc -> output_nc.
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Run ``input`` through the nested blocks."""
        out = self.model(input)
        return out
#class GieGenerator(nn.Module):
# def __init__(self, input_nc, output_nc, num_downs, ngf=64,
# norm_layer=nn.BatchNorm2d, use_dropout=False):
# super(GieGenerator, self).__init__()
#
# # construct unet structure
# unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
# for i in range(num_downs - 5):
# unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
# unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
# unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
# unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
# unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
#
# self.model = unet_block
#
# def forward(self, input):
# return self.model(input)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample, run the nested submodule, upsample.

    Every block except the outermost returns its input concatenated with its
    own output along dim 1, which is how the skip connection is formed.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the block.

        Parameters:
            outer_nc (int)     -- channels produced by this block's up path
            inner_nc (int)     -- channels produced by this block's down path
            input_nc (int)     -- channels of the block input; defaults to outer_nc
            submodule          -- the nested block run between down and up paths
                                  (not used when ``innermost`` is True)
            outermost (bool)   -- if True: no norm layers, Tanh output, no skip concat
            innermost (bool)   -- if True: no submodule and no norm after the down conv
            norm_layer         -- normalization layer class, or a
                                  ``functools.partial`` wrapping one
            use_dropout (bool) -- append Dropout(0.5) (intermediate blocks only)
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Unwrap a functools.partial (isinstance also accepts subclasses) to
        # find the underlying norm class. Convolutions get a bias only for
        # InstanceNorm2d.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        # NOTE(review): this LeakyReLU is in-place, so the `x` concatenated in
        # forward() has presumably already been activated -- confirm before
        # changing it, since trained weights would depend on that.
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # The submodule returns inner_nc * 2 channels because of its skip concat.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # No submodule here, so the up conv sees inner_nc channels.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return the block output; non-outermost blocks prepend ``x`` on dim 1."""
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
##===================================================================================================
class DilatedDoubleConv(nn.Module):
    """(3x3 convolution with dilation 4 => BatchNorm => ReLU) applied twice."""

    def __init__(self, in_channels, out_channels):
        """Build both stages.

        Parameters:
            in_channels (int)  -- channels of the incoming feature map
            out_channels (int) -- channels produced by each of the two convolutions
        """
        super().__init__()
        layers = []
        stage_in = in_channels
        for _ in range(2):
            layers.extend([
                # Zero padding of 4 offsets the dilation, so H and W are kept.
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=4, dilation=4),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
            # The second stage maps out_channels -> out_channels.
            stage_in = out_channels
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` after both stages."""
        out = self.double_conv(x)
        return out
class DilatedDown(nn.Module):
    """Downscaling step: 2x2 max pooling followed by a ``DilatedDoubleConv``."""

    def __init__(self, in_channels, out_channels):
        """Build the pooling + dilated double convolution pair.

        Parameters:
            in_channels (int)  -- channels of the incoming feature map
            out_channels (int) -- channels produced by the ``DilatedDoubleConv``
        """
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DilatedDoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and convolve the result."""
        out = self.maxpool_conv(x)
        return out
class DilatedUp(nn.Module):
    """Upscaling step: nearest-neighbour upsample, dilated channel-reducing
    convolution, skip concatenation, then a ``DilatedDoubleConv``."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        """Build the upsampling layers.

        Parameters:
            in_channels (int)  -- channels of the low-resolution input; also the
                                  channel count expected after concatenation
            out_channels (int) -- channels produced by this step
            bilinear (bool)    -- accepted but not read by this class
        """
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = DilatedDoubleConv(in_channels, out_channels)
        # Maps the upsampled tensor from in_channels to out_channels before
        # it is concatenated with the skip feature.
        reduce_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=4, dilation=4),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv1 = nn.Sequential(*reduce_layers)

    def forward(self, x1, x2):
        """Upsample and reduce ``x1``, concatenate it after ``x2`` on dim 1, convolve.

        ``x2`` is the skip feature; presumably it has ``in_channels -
        out_channels`` channels and the same spatial size as the upsampled
        ``x1`` -- TODO confirm against callers.
        """
        upsampled = self.conv1(self.up(x1))
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)
class DilatedSingleUnet(nn.Module):
    """U-Net style encoder/decoder built from dilated blocks, with a CBAM
    module at the bottleneck."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the encoder, CBAM module, decoder and head.

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs, ngf, biline, norm_layer, use_dropout
                -- accepted but never read here; the network always has five
                   downsampling steps.
        """
        super().__init__()
        self.init_channel = 32
        base = self.init_channel
        c1, c2, c4, c8, c16, c32 = (base * m for m in (1, 2, 4, 8, 16, 32))

        self.inc = DilatedDoubleConv(input_nc, c1)
        self.down1 = DilatedDown(c1, c2)
        self.down2 = DilatedDown(c2, c4)
        self.down3 = DilatedDown(c4, c8)
        self.down4 = DilatedDown(c8, c16)
        self.down5 = DilatedDown(c16, c32)
        self.cbam = CBAM(gate_channels=c32)
        self.up1 = DilatedUp(c32, c16)
        self.up2 = DilatedUp(c16, c8)
        self.up3 = DilatedUp(c8, c4)
        self.up4 = DilatedUp(c4, c2)
        self.up5 = DilatedUp(c2, c1)
        self.outc = OutConv(c1, output_nc)

    def forward(self, input):
        """Return the ``output_nc``-channel output map for ``input``."""
        feats = [self.inc(input)]
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(down(feats[-1]))

        # Only the deepest feature map goes through CBAM; skips are untouched.
        out = self.cbam(feats[-1])
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4, self.up5), feats[-2::-1]):
            out = up(out, skip)
        return self.outc(out)