import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from vgg_model import vgg19


class DoubleConv(nn.Module):
    """(convolution => [BN] => LeakyReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1, True)
        )

    def forward(self, x):
        x = self.double_conv(x)
        return x


class ResBlock(nn.Module):
    """1x1 bottleneck conv followed by a residual (conv => [BN] => LeakyReLU => conv) branch"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bottle_conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
        self.double_conv = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x = self.bottle_conv(x)
        x = self.double_conv(x) + x
        # scale the residual sum to keep the output variance roughly constant
        return x / math.sqrt(2)


class Down(nn.Module):
    """Downscaling with a strided conv followed by a residual block"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 4, 2, 1),
            nn.LeakyReLU(0.1, True),
            # DoubleConv(in_channels, out_channels)
            ResBlock(in_channels, out_channels)
        )

    def forward(self, x):
        x = self.main(x)
        return x


class SDFT(nn.Module):
    def __init__(self, color_dim, channels, kernel_size=3):
        super().__init__()

        # generate global conv weights
        fan_in = channels * kernel_size ** 2
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2

        self.scale = 1 / math.sqrt(fan_in)
        self.modulation = nn.Conv2d(color_dim, channels, 1)
        self.weight = nn.Parameter(
            torch.randn(1, channels, channels, kernel_size, kernel_size)
        )

    def forward(self, fea, color_style):
        # global adjustment: modulate the shared conv weight with the color style
        B, C, H, W = fea.size()

        style = self.modulation(color_style).view(B, 1, C, 1, 1)
        weight = self.scale * self.weight * style

        # demodulation
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.view(B, C, 1, 1, 1)

        # fold the batch into the channel dim and run one grouped conv per sample
        weight = weight.view(
            B * C, C, self.kernel_size, self.kernel_size
        )
        fea = fea.view(1, B * C, H, W)
        fea = F.conv2d(fea, weight, padding=self.padding, groups=B)
        fea = fea.view(B, C, H, W)
        return fea
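

# Hedged usage sketch (not part of the original file): SDFT acts as a per-sample
# modulated convolution in the spirit of StyleGAN2's weight (de)modulation, e.g.:
#   sdft = SDFT(color_dim=512, channels=64, kernel_size=3)
#   out = sdft(torch.randn(2, 64, 32, 32), torch.randn(2, 512, 1, 1))  # -> [2, 64, 32, 32]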


class UpBlock(nn.Module):
    def __init__(self, color_dim, in_channels, out_channels, kernel_size=3, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)

        self.conv_cat = nn.Sequential(
            nn.Conv2d(in_channels // 2 + in_channels // 8, out_channels, 1, 1, 0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True)
        )
        self.conv_s = nn.Conv2d(in_channels // 2, out_channels, 1, 1, 0)

        # generate global conv weights
        self.SDFT = SDFT(color_dim, out_channels, kernel_size)

    def forward(self, x1, x2, color_style):
        x1 = self.up(x1)
        x1_s = self.conv_s(x1)

        # keep every 4th channel of the skip feature so the concatenation matches
        # the in_channels // 2 + in_channels // 8 input expected by conv_cat
        x = torch.cat([x1, x2[:, ::4, :, :]], dim=1)
        x = self.conv_cat(x)
        x = self.SDFT(x, color_style)
        x = x + x1_s  # residual skip connection
        return x
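

# Hedged shape sketch (assumption: bilinear=True, as ColorUNet below uses it).
# For UpBlock(color_dim=512, in_channels=1024, out_channels=256) with a
# 512-channel x1 and a 512-channel skip x2:
#   up(x1)            -> [B, 512, 2H, 2W]  (bilinear upsample keeps channels)
#   x2[:, ::4]        -> [B, 128, 2H, 2W]  (in_channels // 8 skip channels)
#   conv_cat(cat)     -> [B, 256, 2H, 2W]
#   SDFT + conv_s(..) -> [B, 256, 2H, 2W]  residual output modulated by the color style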


class ColorEncoder(nn.Module):
    def __init__(self, color_dim=512):
        super(ColorEncoder, self).__init__()

        # self.vgg = vgg19(pretrained_path=None)
        self.vgg = vgg19()

        self.feature2vector = nn.Sequential(
            nn.Conv2d(color_dim, color_dim, 4, 2, 2),  # 16x16 -> 9x9 (for 256x256 inputs)
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(color_dim, color_dim, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(color_dim, color_dim, 4, 2, 2),  # 9x9 -> 5x5
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(color_dim, color_dim, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.AdaptiveAvgPool2d((1, 1)),  # 1x1
            nn.Conv2d(color_dim, color_dim // 2, 1),  # linear-1
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(color_dim // 2, color_dim // 2, 1),  # linear-2
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(color_dim // 2, color_dim, 1),  # linear-3
        )

        self.color_dim = color_dim

    def forward(self, x):
        # x: RGB reference image in [0, 1]
        vgg_fea = self.vgg(x, layer_name='relu5_2')  # [B, 512, 16, 16]
        x_color = self.feature2vector(vgg_fea[-1])  # [B, 512, 1, 1]
        return x_color
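

# Hedged usage sketch (assumes a 256x256 RGB reference in [0, 1] and that
# vgg_model.vgg19 returns the requested relu5_2 feature maps):
#   enc = ColorEncoder(color_dim=512)
#   color_vec = enc(torch.rand(2, 3, 256, 256))  # -> [2, 512, 1, 1]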


class ColorUNet(nn.Module):
    """UNet that predicts the ab color channels from a grayscale input and a color style vector."""

    def __init__(self, n_channels=1, n_classes=3, bilinear=True):
        super(ColorUNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = UpBlock(512, 1024, 512 // factor, 3, bilinear)
        self.up2 = UpBlock(512, 512, 256 // factor, 3, bilinear)
        self.up3 = UpBlock(512, 256, 128 // factor, 5, bilinear)
        self.up4 = UpBlock(512, 128, 64, 5, bilinear)
        self.outc = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 2, 3, 1, 1),
            nn.Tanh()  # [-1, 1]
        )

    def forward(self, x):
        # x[0]: gray image (L channel) in [-1, 1]; x[1]: color style vector
        x_color = x[1]  # [B, 512, 1, 1]

        x1 = self.inc(x[0])  # [B, 64, 256, 256]
        x2 = self.down1(x1)  # [B, 128, 128, 128]
        x3 = self.down2(x2)  # [B, 256, 64, 64]
        x4 = self.down3(x3)  # [B, 512, 32, 32]
        x5 = self.down4(x4)  # [B, 512, 16, 16]

        x6 = self.up1(x5, x4, x_color)  # [B, 256, 32, 32]
        x7 = self.up2(x6, x3, x_color)  # [B, 128, 64, 64]
        x8 = self.up3(x7, x2, x_color)  # [B, 64, 128, 128]
        x9 = self.up4(x8, x1, x_color)  # [B, 64, 256, 256]

        x_ab = self.outc(x9)
        return x_ab
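

if __name__ == "__main__":
    # Hedged sanity check (assumption, not part of the original file): run
    # ColorUNet on dummy tensors to confirm the shape comments above.
    # ColorEncoder is skipped here since it depends on the external vgg_model weights.
    net = ColorUNet(n_channels=1, n_classes=3, bilinear=True).eval()
    gray = torch.randn(1, 1, 256, 256)     # L channel in [-1, 1]
    color_vec = torch.randn(1, 512, 1, 1)  # color style vector (e.g. from ColorEncoder)
    with torch.no_grad():
        ab = net((gray, color_vec))
    print(ab.shape)  # expected: torch.Size([1, 2, 256, 256])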