# Source: Hugging Face file view -- author qubvel-hf (HF staff), commit c509e76 "Init project", 31.9 kB.
from math import log
import torch
import torch.nn as nn
from torch.nn import init
import functools
from model.cbam import CBAM
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class SingleConv(nn.Module):
    """One conv stage: reflection pad => 3x3 conv => BN => ReLU.

    Padding by 1 followed by an unpadded, stride-1 3x3 convolution keeps the
    spatial size unchanged.  Only a single stage is active; the attribute is
    still named ``double_conv`` so existing parameter names remain valid.
    """

    def __init__(self, in_channels, out_channels):
        """Create the stage mapping ``in_channels`` to ``out_channels``."""
        super().__init__()
        stage = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stage)

    def forward(self, x):
        """Run ``x`` through the conv stage and return the result."""
        return self.double_conv(x)
class Down_single(nn.Module):
    """Halve the spatial size with 2x2 max pooling, then apply ``SingleConv``."""

    def __init__(self, in_channels, out_channels):
        """Create the pool + conv pipeline (``in_channels`` -> ``out_channels``)."""
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = SingleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Return the pooled and convolved feature map."""
        return self.maxpool_conv(x)
class Up_single(nn.Module):
    """Decoder step: transposed-conv upsampling, skip concat, then ``SingleConv``.

    ``bilinear`` is accepted but never read, and ``self.up`` is created but
    not called in ``forward``; upsampling is done by ``self.deconv``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        """Create the (unused) upsampler, the fusing conv and the deconv."""
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = SingleConv(in_channels, out_channels)
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=True
        )

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate it after skip ``x2`` and fuse them.

        The concatenation (channel dim 1, BCHW layout) feeds ``self.conv``,
        which expects ``in_channels`` channels, so ``x2`` has to supply
        ``in_channels - out_channels`` channels.
        """
        upsampled = self.deconv(x1)
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)
class DoubleConv(nn.Module):
    """Two stacked stages of (reflection pad => 3x3 conv => BN => ReLU).

    Each stage pads by 1 before an unpadded, stride-1 3x3 convolution, so the
    spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        """Create both stages; the second maps ``out_channels`` to itself."""
        super().__init__()
        layers = []
        channels = in_channels
        for _ in range(2):
            layers += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=0, stride=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both conv stages and return the result."""
        return self.double_conv(x)
class Down(nn.Module):
    """Halve the spatial size with 2x2 max pooling, then apply ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        """Create the pool + double-conv pipeline."""
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Return the pooled and convolved feature map."""
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Decoder step: transposed-conv upsampling, skip concat, then ``DoubleConv``.

    ``bilinear`` is accepted but never read, and ``self.up`` is created but
    not called in ``forward``; upsampling is done by ``self.deconv``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        """Create the (unused) upsampler, the fusing convs and the deconv."""
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = DoubleConv(in_channels, out_channels)
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=True
        )

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate it after skip ``x2`` and fuse them.

        The concatenation (channel dim 1, BCHW layout) feeds ``self.conv``,
        which expects ``in_channels`` channels, so ``x2`` has to supply
        ``in_channels - out_channels`` channels.
        """
        upsampled = self.deconv(x1)
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)
class OutConv(nn.Module):
    """Output head: a 1x1 convolution returning raw (unactivated) values.

    ``tanh``, ``hardtanh`` and ``sigmoid`` modules are registered but none of
    them is applied in ``forward``.
    """

    def __init__(self, in_channels, out_channels):
        """Create the 1x1 projection and the (unused) activation modules."""
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.tanh = nn.Tanh()
        self.hardtanh = nn.Hardtanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1):
        """Project ``x1`` to ``out_channels`` channels without any activation."""
        return self.conv(x1)
class GiemaskGenerator(nn.Module):
    """U-Net style network: one shared encoder feeding two decoders.

    The ``Up``/``outc`` decoder produces ``depth`` and the lighter
    ``Up_single``/``outc_1`` decoder produces ``mask``; both heads have one
    output channel.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the encoder and both decoders.

        None of the constructor arguments is read by this class: the stem is
        fixed to 3 input channels, the base width to 32 and each head to
        1 output channel.  The parameters are kept for signature
        compatibility with the other generators.
        """
        super().__init__()
        self.init_channel = 32
        c = self.init_channel

        # Shared encoder: 32 -> 1024 channels over five poolings.
        self.inc = DoubleConv(3, c)
        self.down1 = Down(c, c * 2)
        self.down2 = Down(c * 2, c * 4)
        self.down3 = Down(c * 4, c * 8)
        self.down4 = Down(c * 8, c * 16)
        self.down5 = Down(c * 16, c * 32)

        # Decoder for ``depth`` (double-conv fusion).
        self.up1 = Up(c * 32, c * 16)
        self.up2 = Up(c * 16, c * 8)
        self.up3 = Up(c * 8, c * 4)
        self.up4 = Up(c * 4, c * 2)
        self.up5 = Up(c * 2, c)
        self.outc = OutConv(c, 1)

        # Decoder for ``mask`` (single-conv fusion).
        self.up1_1 = Up_single(c * 32, c * 16)
        self.up2_1 = Up_single(c * 16, c * 8)
        self.up3_1 = Up_single(c * 8, c * 4)
        self.up4_1 = Up_single(c * 4, c * 2)
        self.up5_1 = Up_single(c * 2, c)
        self.outc_1 = OutConv(c, 1)

    def forward(self, input):
        """Return ``(depth, mask)`` for a 3-channel BCHW ``input``.

        Height and width should be multiples of 32 (five 2x poolings) so the
        upsampled maps line up with the skip connections.
        """
        feats = [self.inc(input)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))
        bottleneck = feats[-1]
        skips = feats[-2::-1]  # encoder outputs, deepest skip first

        # The mask decoder runs first, as in the original statement order.
        hidden = bottleneck
        for stage, skip in zip(
            (self.up1_1, self.up2_1, self.up3_1, self.up4_1, self.up5_1), skips
        ):
            hidden = stage(hidden, skip)
        mask = self.outc_1(hidden)

        hidden = bottleneck
        for stage, skip in zip(
            (self.up1, self.up2, self.up3, self.up4, self.up5), skips
        ):
            hidden = stage(hidden, skip)
        depth = self.outc(hidden)
        return depth, mask
"""Create a Unet-based generator"""
class Giemask2Generator(nn.Module):
    """U-Net style network: shared encoder, two decoders, three heads.

    The ``Up`` decoder feeds ``outc`` (``depth``); the ``Up_single`` decoder
    feeds both ``outc_1`` (``mask``) and ``outc_2`` (``edge``).  Every head
    has one output channel.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the encoder, both decoders and the three heads.

        None of the constructor arguments is read by this class: the stem is
        fixed to 3 input channels, the base width to 32 and each head to
        1 output channel.
        """
        super().__init__()
        self.init_channel = 32
        c = self.init_channel

        # Shared encoder.
        self.inc = DoubleConv(3, c)
        self.down1 = Down(c, c * 2)
        self.down2 = Down(c * 2, c * 4)
        self.down3 = Down(c * 4, c * 8)
        self.down4 = Down(c * 8, c * 16)
        self.down5 = Down(c * 16, c * 32)

        # Decoder for ``depth``.
        self.up1 = Up(c * 32, c * 16)
        self.up2 = Up(c * 16, c * 8)
        self.up3 = Up(c * 8, c * 4)
        self.up4 = Up(c * 4, c * 2)
        self.up5 = Up(c * 2, c)
        self.outc = OutConv(c, 1)

        # Decoder shared by the ``mask`` and ``edge`` heads.
        self.up1_1 = Up_single(c * 32, c * 16)
        self.up2_1 = Up_single(c * 16, c * 8)
        self.up3_1 = Up_single(c * 8, c * 4)
        self.up4_1 = Up_single(c * 4, c * 2)
        self.up5_1 = Up_single(c * 2, c)
        self.outc_1 = OutConv(c, 1)
        self.outc_2 = OutConv(c, 1)

    def forward(self, input):
        """Return ``(depth, mask, edge)`` for a 3-channel BCHW ``input``.

        Height and width should be multiples of 32 (five 2x poolings) so the
        upsampled maps line up with the skip connections.
        """
        feats = [self.inc(input)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))
        bottleneck = feats[-1]
        skips = feats[-2::-1]  # encoder outputs, deepest skip first

        hidden = bottleneck
        for stage, skip in zip(
            (self.up1_1, self.up2_1, self.up3_1, self.up4_1, self.up5_1), skips
        ):
            hidden = stage(hidden, skip)
        mask = self.outc_1(hidden)
        edge = self.outc_2(hidden)

        hidden = bottleneck
        for stage, skip in zip(
            (self.up1, self.up2, self.up3, self.up4, self.up5), skips
        ):
            hidden = stage(hidden, skip)
        depth = self.outc(hidden)
        return depth, mask, edge
"""Create a Unet-based generator"""
class GieGenerator(nn.Module):
    """Plain U-Net: five-level encoder, ``Up`` decoder, 2-channel output head."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the network.

        Parameters:
            input_nc (int) -- number of channels of the input tensor

        Every other argument is accepted but not read: the base width is
        fixed to 32 and the head always emits 2 channels.
        """
        super().__init__()
        self.init_channel = 32
        c = self.init_channel

        self.inc = DoubleConv(input_nc, c)
        self.down1 = Down(c, c * 2)
        self.down2 = Down(c * 2, c * 4)
        self.down3 = Down(c * 4, c * 8)
        self.down4 = Down(c * 8, c * 16)
        self.down5 = Down(c * 16, c * 32)

        self.up1 = Up(c * 32, c * 16)
        self.up2 = Up(c * 16, c * 8)
        self.up3 = Up(c * 8, c * 4)
        self.up4 = Up(c * 4, c * 2)
        self.up5 = Up(c * 2, c)
        self.outc = OutConv(c, 2)

    def forward(self, input):
        """Return the 2-channel output map for a BCHW ``input``.

        Height and width should be multiples of 32 (five 2x poolings) so the
        upsampled maps line up with the skip connections.
        """
        feats = [self.inc(input)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))

        hidden = feats[-1]
        decoder = (self.up1, self.up2, self.up3, self.up4, self.up5)
        # feats[-2::-1] walks the encoder outputs from deepest skip to the stem.
        for stage, skip in zip(decoder, feats[-2::-1]):
            hidden = stage(hidden, skip)
        return self.outc(hidden)
class GiecbamGenerator(nn.Module):
    """U-Net with a ``CBAM`` module on the bottleneck and dropout before the head."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the network.

        Parameters:
            input_nc (int) -- number of channels of the input tensor

        Every other argument is accepted but not read: the base width is
        fixed to 32, the head always emits 2 channels and dropout (p=0.1)
        is always present.
        """
        super().__init__()
        self.init_channel = 32
        c = self.init_channel

        self.inc = DoubleConv(input_nc, c)
        self.down1 = Down(c, c * 2)
        self.down2 = Down(c * 2, c * 4)
        self.down3 = Down(c * 4, c * 8)
        self.down4 = Down(c * 8, c * 16)
        self.down5 = Down(c * 16, c * 32)
        self.cbam = CBAM(gate_channels=c * 32)

        self.up1 = Up(c * 32, c * 16)
        self.up2 = Up(c * 16, c * 8)
        self.up3 = Up(c * 8, c * 4)
        self.up4 = Up(c * 4, c * 2)
        self.up5 = Up(c * 2, c)
        self.outc = OutConv(c, 2)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, input):
        """Return the 2-channel output map for a BCHW ``input``.

        Height and width should be multiples of 32 (five 2x poolings) so the
        upsampled maps line up with the skip connections.
        """
        feats = [self.inc(input)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))

        # Only the deepest feature map goes through CBAM; skips are untouched.
        hidden = self.cbam(feats[-1])
        decoder = (self.up1, self.up2, self.up3, self.up4, self.up5)
        for stage, skip in zip(decoder, feats[-2::-1]):
            hidden = stage(hidden, skip)
        hidden = self.dropout(hidden)
        return self.outc(hidden)
class Gie2headGenerator(nn.Module):
    """U-Net with one encoder and two independent ``Up`` decoders.

    Each decoder ends in a 1-channel head; the two outputs are concatenated
    along the channel dimension into a single 2-channel tensor.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the network.

        Parameters:
            input_nc (int) -- number of channels of the input tensor

        Every other argument is accepted but not read: the base width is
        fixed to 32 and each head emits 1 channel.
        """
        super().__init__()
        self.init_channel = 32
        c = self.init_channel

        self.inc = DoubleConv(input_nc, c)
        self.down1 = Down(c, c * 2)
        self.down2 = Down(c * 2, c * 4)
        self.down3 = Down(c * 4, c * 8)
        self.down4 = Down(c * 8, c * 16)
        self.down5 = Down(c * 16, c * 32)

        # First decoder / head.
        self.up1_1 = Up(c * 32, c * 16)
        self.up2_1 = Up(c * 16, c * 8)
        self.up3_1 = Up(c * 8, c * 4)
        self.up4_1 = Up(c * 4, c * 2)
        self.up5_1 = Up(c * 2, c)
        self.outc_1 = OutConv(c, 1)

        # Second decoder / head.
        self.up1_2 = Up(c * 32, c * 16)
        self.up2_2 = Up(c * 16, c * 8)
        self.up3_2 = Up(c * 8, c * 4)
        self.up4_2 = Up(c * 4, c * 2)
        self.up5_2 = Up(c * 2, c)
        self.outc_2 = OutConv(c, 1)

    def forward(self, input):
        """Return the two head outputs concatenated on dim 1 (head 1 first).

        Height and width should be multiples of 32 (five 2x poolings) so the
        upsampled maps line up with the skip connections.
        """
        feats = [self.inc(input)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))
        bottleneck = feats[-1]
        skips = feats[-2::-1]  # encoder outputs, deepest skip first

        hidden = bottleneck
        for stage, skip in zip(
            (self.up1_1, self.up2_1, self.up3_1, self.up4_1, self.up5_1), skips
        ):
            hidden = stage(hidden, skip)
        logits_1 = self.outc_1(hidden)

        hidden = bottleneck
        for stage, skip in zip(
            (self.up1_2, self.up2_2, self.up3_2, self.up4_2, self.up5_2), skips
        ):
            hidden = stage(hidden, skip)
        logits_2 = self.outc_2(hidden)

        return torch.cat((logits_1, logits_2), 1)
class BmpGenerator(nn.Module):
    """Plain U-Net with configurable input and output channel counts."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the network.

        Parameters:
            input_nc (int)  -- number of channels of the input tensor
            output_nc (int) -- number of channels produced by the head

        The remaining arguments are accepted but not read; the base width is
        fixed to 32.
        """
        super().__init__()
        self.init_channel = 32
        self.output_nc = output_nc
        c = self.init_channel

        self.inc = DoubleConv(input_nc, c)
        self.down1 = Down(c, c * 2)
        self.down2 = Down(c * 2, c * 4)
        self.down3 = Down(c * 4, c * 8)
        self.down4 = Down(c * 8, c * 16)
        self.down5 = Down(c * 16, c * 32)

        self.up1 = Up(c * 32, c * 16)
        self.up2 = Up(c * 16, c * 8)
        self.up3 = Up(c * 8, c * 4)
        self.up4 = Up(c * 4, c * 2)
        self.up5 = Up(c * 2, c)
        self.outc = OutConv(c, self.output_nc)

    def forward(self, input):
        """Return the ``output_nc``-channel map for a BCHW ``input``.

        Height and width should be multiples of 32 (five 2x poolings) so the
        upsampled maps line up with the skip connections.
        """
        feats = [self.inc(input)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))

        hidden = feats[-1]
        decoder = (self.up1, self.up2, self.up3, self.up4, self.up5)
        # feats[-2::-1] walks the encoder outputs from deepest skip to the stem.
        for stage, skip in zip(decoder, feats[-2::-1]):
            hidden = stage(hidden, skip)
        return self.outc(hidden)
class Bmp2Generator(nn.Module):
    """Two cascaded U-Nets ("gienet" followed by "bmpnet").

    The first network maps a 3-channel input to ``depth`` and ``mask``.  The
    mask is binarised at 0.5, applied to both the input and ``depth``, and the
    resulting 4-channel tensor is fed to the second network, whose 2-channel
    output is the only value returned.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build both sub-networks.

        None of the constructor arguments is read by this class: channel
        counts are fixed (3 in / 1-channel heads for the first network,
        4 in / 2 out for the second) and the base width is 32.
        """
        super().__init__()
        self.init_channel = 32
        c = self.init_channel

        # --- gienet: shared encoder ---
        self.inc = DoubleConv(3, c)
        self.down1 = Down(c, c * 2)
        self.down2 = Down(c * 2, c * 4)
        self.down3 = Down(c * 4, c * 8)
        self.down4 = Down(c * 8, c * 16)
        self.down5 = Down(c * 16, c * 32)
        # --- gienet: decoder for ``depth`` ---
        self.up1 = Up(c * 32, c * 16)
        self.up2 = Up(c * 16, c * 8)
        self.up3 = Up(c * 8, c * 4)
        self.up4 = Up(c * 4, c * 2)
        self.up5 = Up(c * 2, c)
        self.outc = OutConv(c, 1)
        # --- gienet: decoder for ``mask`` / ``edge`` ---
        self.up1_1 = Up_single(c * 32, c * 16)
        self.up2_1 = Up_single(c * 16, c * 8)
        self.up3_1 = Up_single(c * 8, c * 4)
        self.up4_1 = Up_single(c * 4, c * 2)
        self.up5_1 = Up_single(c * 2, c)
        self.outc_1 = OutConv(c, 1)
        self.outc_2 = OutConv(c, 1)

        # --- bmpnet: takes masked image (3) + masked depth (1) ---
        self.inc_b = DoubleConv(4, c)
        self.down1_b = Down(c, c * 2)
        self.down2_b = Down(c * 2, c * 4)
        self.down3_b = Down(c * 4, c * 8)
        self.down4_b = Down(c * 8, c * 16)
        self.down5_b = Down(c * 16, c * 32)
        self.up1_b = Up(c * 32, c * 16)
        self.up2_b = Up(c * 16, c * 8)
        self.up3_b = Up(c * 8, c * 4)
        self.up4_b = Up(c * 4, c * 2)
        self.up5_b = Up(c * 2, c)
        self.outc_b = OutConv(c, 2)

    def forward(self, input):
        """Return the 2-channel output of the second network.

        ``depth``, ``mask`` and ``edge`` are intermediate values and are not
        returned.  Height and width should be multiples of 32 (five 2x
        poolings per network) so upsampled maps line up with the skips.
        """
        # ---------------- gienet ----------------
        feats = [self.inc(input)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))
        bottleneck = feats[-1]
        skips = feats[-2::-1]  # encoder outputs, deepest skip first

        hidden = bottleneck
        for stage, skip in zip(
            (self.up1_1, self.up2_1, self.up3_1, self.up4_1, self.up5_1), skips
        ):
            hidden = stage(hidden, skip)
        mask = self.outc_1(hidden)
        edge = self.outc_2(hidden)  # computed but not used further

        hidden = bottleneck
        for stage, skip in zip(
            (self.up1, self.up2, self.up3, self.up4, self.up5), skips
        ):
            hidden = stage(hidden, skip)
        depth = self.outc(hidden)

        # ---------------- bmpnet ----------------
        # Binarise the mask in place: values > 0.5 become 1, the rest 0.
        mask[mask > 0.5] = 1.
        mask[mask <= 0.5] = 0.
        # The 1-channel mask broadcasts over the 3 image channels.
        image_cat_depth = torch.cat((input * mask, depth * mask), dim=1)

        feats_b = [self.inc_b(image_cat_depth)]
        for stage in (
            self.down1_b, self.down2_b, self.down3_b, self.down4_b, self.down5_b
        ):
            feats_b.append(stage(feats_b[-1]))

        hidden_b = feats_b[-1]
        for stage, skip in zip(
            (self.up1_b, self.up2_b, self.up3_b, self.up4_b, self.up5_b),
            feats_b[-2::-1],
        ):
            hidden_b = stage(hidden_b, skip)
        bm = self.outc_b(hidden_b)
        return bm
class UnetGenerator(nn.Module):
    """Recursive U-Net assembled from nested ``UnetSkipConnectionBlock``s.

    Blocks are created from the innermost (bottleneck) level outwards, each
    new block wrapping the previous one as its ``submodule``.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the nested block structure.

        Parameters:
            input_nc (int)   -- channels of the input tensor
            output_nc (int)  -- channels of the output tensor
            num_downs (int)  -- total number of stride-2 levels; values
                                above 5 add ``ngf * 8`` wide middle levels
            ngf (int)        -- channel width of the outermost level
            norm_layer       -- normalization layer factory
            use_dropout      -- enable dropout in the extra middle levels
        """
        super().__init__()
        # Bottleneck level.
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        # Extra constant-width levels (the only ones that honour use_dropout).
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block, norm_layer=norm_layer, use_dropout=use_dropout)
        # Three levels that halve the width going outwards.
        for outer_nc, inner_nc in ((ngf * 4, ngf * 8), (ngf * 2, ngf * 4), (ngf, ngf * 2)):
            block = UnetSkipConnectionBlock(outer_nc, inner_nc, input_nc=None, submodule=block, norm_layer=norm_layer)
        # Outermost level maps input_nc -> output_nc.
        block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer)
        self.model = block

    def forward(self, input):
        """Run ``input`` through the nested blocks."""
        return self.model(input)
#class GieGenerator(nn.Module):
# def __init__(self, input_nc, output_nc, num_downs, ngf=64,
# norm_layer=nn.BatchNorm2d, use_dropout=False):
# super(GieGenerator, self).__init__()
#
# # construct unet structure
# unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
# for i in range(num_downs - 5):
# unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
# unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
# unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
# unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
# unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
#
# self.model = unet_block
#
# def forward(self, input):
# return self.model(input)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample => submodule => upsample, plus a skip.

        X -------------------identity---------------------- X
          |-- downsampling -- |submodule| -- upsampling --|

    Except for the outermost level, the block's input is concatenated with
    its output along the channel dimension.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build one level of the nested U-Net.

        Parameters:
            outer_nc (int)  -- channels produced by this level's up-path
            inner_nc (int)  -- channels produced by this level's down-conv
            input_nc (int)  -- channels of the incoming tensor; defaults to
                               ``outer_nc`` when None
            submodule       -- the next-inner block (ignored when innermost)
            outermost (bool)-- no norm, Tanh output, no skip concatenation
            innermost (bool)-- no submodule and no down-path normalization
            norm_layer      -- normalization layer class, or a
                               ``functools.partial`` wrapping one
            use_dropout     -- append ``Dropout(0.5)`` (middle levels only)
        """
        super().__init__()
        self.outermost = outermost
        # Convolutions get a bias only when the norm layer is InstanceNorm2d.
        # isinstance (rather than an exact type comparison) also unwraps
        # subclasses of functools.partial.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # The submodule returns its input concatenated with its output,
            # hence inner_nc * 2 channels going into the up-conv.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Apply the level; non-outermost levels return ``cat([x, out], 1)``.

        NOTE: on non-outermost levels the first layer is the LeakyReLU created
        with its second positional argument (inplace) set to True, so ``x`` is
        modified in place before it is concatenated.
        """
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
##===================================================================================================
class DilatedDoubleConv(nn.Module):
    """Two stacked stages of (dilated 3x3 conv => BN => ReLU).

    Each convolution uses dilation 4 with padding 4 and stride 1, so the
    spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        """Create both stages; the second maps ``out_channels`` to itself."""
        super().__init__()
        layers = []
        channels = in_channels
        for _ in range(2):
            layers += [
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=4, stride=1, dilation=4),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both dilated conv stages."""
        return self.double_conv(x)
class DilatedDown(nn.Module):
    """Halve the spatial size with 2x2 max pooling, then ``DilatedDoubleConv``."""

    def __init__(self, in_channels, out_channels):
        """Create the pool + dilated double-conv pipeline."""
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DilatedDoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Return the pooled and convolved feature map."""
        return self.maxpool_conv(x)
class DilatedUp(nn.Module):
    """Decoder step: nearest 2x upsample, dilated conv, skip concat, double conv.

    ``bilinear`` is accepted but never read.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        """Create the upsampler, the fusing convs and the channel-reducing conv."""
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = DilatedDoubleConv(in_channels, out_channels)
        # Reduces the upsampled map from in_channels to out_channels.
        reduce_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=4, stride=1, dilation=4),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv1 = nn.Sequential(*reduce_layers)

    def forward(self, x1, x2):
        """Upsample and reduce ``x1``, concatenate it after ``x2`` and fuse.

        The concatenation (channel dim 1, BCHW layout) feeds ``self.conv``,
        which expects ``in_channels`` channels, so ``x2`` has to supply
        ``in_channels - out_channels`` channels.
        """
        reduced = self.conv1(self.up(x1))
        merged = torch.cat((x2, reduced), dim=1)
        return self.conv(merged)
class DilatedSingleUnet(nn.Module):
    """U-Net built from dilated conv blocks with ``CBAM`` on the bottleneck."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the network.

        Parameters:
            input_nc (int)  -- number of channels of the input tensor
            output_nc (int) -- number of channels produced by the head

        The remaining arguments are accepted but not read; the base width is
        fixed to 32.
        """
        super().__init__()
        self.init_channel = 32
        c = self.init_channel

        self.inc = DilatedDoubleConv(input_nc, c)
        self.down1 = DilatedDown(c, c * 2)
        self.down2 = DilatedDown(c * 2, c * 4)
        self.down3 = DilatedDown(c * 4, c * 8)
        self.down4 = DilatedDown(c * 8, c * 16)
        self.down5 = DilatedDown(c * 16, c * 32)
        self.cbam = CBAM(gate_channels=c * 32)

        self.up1 = DilatedUp(c * 32, c * 16)
        self.up2 = DilatedUp(c * 16, c * 8)
        self.up3 = DilatedUp(c * 8, c * 4)
        self.up4 = DilatedUp(c * 4, c * 2)
        self.up5 = DilatedUp(c * 2, c)
        self.outc = OutConv(c, output_nc)

    def forward(self, input):
        """Return the ``output_nc``-channel map for a BCHW ``input``.

        Height and width should be multiples of 32 (five 2x poolings) so the
        upsampled maps line up with the skip connections.
        """
        feats = [self.inc(input)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            feats.append(stage(feats[-1]))

        # Only the deepest feature map goes through CBAM; skips are untouched.
        hidden = self.cbam(feats[-1])
        decoder = (self.up1, self.up2, self.up3, self.up4, self.up5)
        for stage, skip in zip(decoder, feats[-2::-1]):
            hidden = stage(hidden, skip)
        return self.outc(hidden)