|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from vgg_model import vgg19 |
|
|
|
class DoubleConv(nn.Module):
    """Two stacked ``conv3x3 -> BatchNorm -> LeakyReLU(0.1)`` stages.

    Padding of 1 keeps the spatial size unchanged.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count produced by the second stage.
        mid_channels: width of the intermediate feature map. Any falsy value
            (``None`` by default) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        stages = []
        # Layers are appended in the order conv, bn, act for each stage, so the
        # Sequential indices (0..5) match the checkpointed layout.
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(negative_slope=0.1, inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
|
|
|
class ResBlock(nn.Module):
    """1x1 channel projection followed by a scaled residual conv pair.

    The input is projected from ``in_channels`` to ``out_channels`` with a 1x1
    convolution. The projection is then added to the output of
    ``conv3x3 -> BatchNorm -> LeakyReLU(0.2) -> conv3x3``, and the sum is
    divided by sqrt(2). Spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bottle_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.double_conv = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        shortcut = self.bottle_conv(x)
        summed = self.double_conv(shortcut) + shortcut
        # Dividing by sqrt(2) keeps the variance of the sum of two branches
        # comparable to a single branch.
        return summed / math.sqrt(2)
|
|
|
|
|
class Down(nn.Module):
    """Encoder stage: strided-conv downsampling followed by a ResBlock.

    A 4x4 stride-2 convolution with padding 1 (channel count unchanged)
    reduces H and W to ``floor(H / 2)`` and ``floor(W / 2)``. After a
    LeakyReLU(0.1), a ``ResBlock`` maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            ResBlock(in_channels, out_channels),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
|
|
|
class SDFT(nn.Module):
    """Colour-conditioned modulated convolution with weight demodulation.

    A single shared ``channels -> channels`` kernel is processed per sample:

    1. It is scaled by ``1 / sqrt(channels * kernel_size**2)``.
    2. It is multiplied along its input-channel axis by a style vector, which
       is computed from ``color_style`` by a 1x1 convolution.
    3. It is demodulated so each output filter has (approximately) unit L2 norm.

    The per-sample kernels are then applied to ``fea`` as one grouped
    convolution, with one group per batch element.

    Args:
        color_dim: channel count of ``color_style``.
        channels: channel count of ``fea``; the output has the same count.
        kernel_size: spatial size of the modulated kernel. Padding is
            ``kernel_size // 2``, so odd sizes preserve H and W.
    """

    def __init__(self, color_dim, channels, kernel_size=3):
        super().__init__()
        fan_in = channels * kernel_size ** 2
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        # The raw weight is stored as N(0, 1) samples and rescaled at run time.
        self.scale = 1 / math.sqrt(fan_in)
        self.modulation = nn.Conv2d(color_dim, channels, 1)
        # Layout: (1, out_channels, in_channels, k, k); the leading 1 broadcasts
        # against the per-sample style in forward().
        self.weight = nn.Parameter(
            torch.randn(1, channels, channels, kernel_size, kernel_size)
        )

    def forward(self, fea, color_style):
        """Modulate the shared kernel by ``color_style`` and convolve ``fea``.

        Args:
            fea: feature map of shape (B, channels, H, W). It may be
                non-contiguous (e.g. ``channels_last``).
            color_style: conditioning tensor whose 1x1 projection must hold
                exactly ``B * channels`` elements, i.e. shape
                (B, color_dim, 1, 1).

        Returns:
            Tensor of shape (B, channels, H, W) when ``kernel_size`` is odd.
        """
        B, C, H, W = fea.size()

        # One modulation vector per sample, applied to the input-channel axis.
        style = self.modulation(color_style).reshape(B, 1, C, 1, 1)
        weight = self.scale * self.weight * style  # (B, C_out, C_in, k, k)

        # Normalise each output filter over (C_in, k, k); 1e-8 avoids division by zero.
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.view(B, C, 1, 1, 1)
        weight = weight.view(B * C, C, self.kernel_size, self.kernel_size)

        # Fold the batch into channels so one grouped conv applies each
        # sample's kernel to that sample only. reshape (not view) is used
        # because fea may be non-contiguous, e.g. channels_last, where merging
        # the batch and channel dims is not a valid view.
        fea = fea.reshape(1, B * C, H, W)
        out = F.conv2d(fea, weight, padding=self.padding, groups=B)
        return out.reshape(B, C, H, W)
|
|
|
|
|
class UpBlock(nn.Module):
    """Decoder stage: upsample, fuse with a thinned skip, colour-modulate.

    Args:
        color_dim: channel count of the ``color_style`` tensor passed to
            ``forward``; forwarded to ``SDFT``.
        in_channels: nominal input width. With ``bilinear=True`` the
            upsampler keeps the channel count, so ``x1`` must already carry
            ``in_channels // 2`` channels. With ``bilinear=False`` a 2x2
            stride-2 transposed convolution maps ``in_channels`` to
            ``in_channels // 2``.
        out_channels: channel count of the block output.
        kernel_size: kernel size of the ``SDFT`` modulated convolution.
        bilinear: use bilinear upsampling (True) or a learned transposed
            convolution (False).
    """

    def __init__(self, color_dim, in_channels, out_channels, kernel_size = 3, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)

        # Input: upsampled x1 (in_channels // 2) concatenated with every 4th
        # channel of the skip (expected to be in_channels // 8 channels).
        self.conv_cat = nn.Sequential(
            nn.Conv2d(in_channels // 2 + in_channels // 8, out_channels, 1, 1, 0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True)
        )

        # 1x1 shortcut from the upsampled input straight to the output width.
        self.conv_s = nn.Conv2d(in_channels // 2, out_channels, 1, 1, 0)

        self.SDFT = SDFT(color_dim, out_channels, kernel_size)

    def forward(self, x1, x2, color_style):
        """Fuse the coarser feature ``x1`` with the skip ``x2`` under ``color_style``.

        Args:
            x1: coarser decoder feature; it is upsampled by 2.
            x2: encoder skip feature. Every 4th channel is kept, and that
                count must equal ``in_channels // 8``.
            color_style: conditioning tensor consumed by ``SDFT``.

        Returns:
            ``SDFT(conv_cat(cat(up(x1), x2[:, ::4]))) + conv_s(up(x1))``, at
            the spatial size of ``x2``.
        """
        x1 = self.up(x1)

        # The stride-2 downsampling floors odd sizes, so 2x upsampling can come
        # out smaller than the skip. Zero-pad (symmetrically, with any extra
        # pixel on the bottom/right) so the concatenation below is well-defined.
        # This is a no-op when the sizes already match.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        if diff_y or diff_x:
            x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                            diff_y // 2, diff_y - diff_y // 2])

        x1_s = self.conv_s(x1)

        x = torch.cat([x1, x2[:, ::4, :, :]], dim=1)
        x = self.conv_cat(x)
        x = self.SDFT(x, color_style)

        return x + x1_s
|
|
|
|
|
class ColorEncoder(nn.Module):
    """Encode a colour reference image into a ``(B, color_dim, 1, 1)`` code.

    The project's ``vgg19`` is queried with ``layer_name='relu5_2'`` and the
    last returned feature is reduced to the code. The reduction is two stages
    of (4x4 stride-2 conv, 3x3 conv), then global average pooling, then a
    three-layer 1x1-conv MLP with a ``color_dim // 2`` hidden width.

    NOTE(review): the first reduction conv expects ``color_dim`` input
    channels, so the VGG feature must have ``color_dim`` channels (presumably
    512 for relu5_2). Confirm against ``vgg_model`` before changing
    ``color_dim``.
    """

    def __init__(self, color_dim=512):
        super(ColorEncoder, self).__init__()

        self.vgg = vgg19()

        layers = []
        # Two spatial reduction stages; module order matches indices 0..7.
        for _ in range(2):
            layers += [
                nn.Conv2d(color_dim, color_dim, 4, 2, 2),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(color_dim, color_dim, 3, 1, 1),
                nn.LeakyReLU(0.2, True),
            ]
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        hidden = color_dim // 2
        layers += [
            nn.Conv2d(color_dim, hidden, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(hidden, hidden, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(hidden, color_dim, 1),
        ]
        self.feature2vector = nn.Sequential(*layers)

        self.color_dim = color_dim

    def forward(self, x):
        # Assumes self.vgg returns an indexable sequence whose last entry is
        # the requested relu5_2 feature -- confirm against vgg_model.
        features = self.vgg(x, layer_name='relu5_2')
        return self.feature2vector(features[-1])
|
|
|
|
|
class ColorUNet(nn.Module):
    """Colour-conditioned U-Net producing a 2-channel output in (-1, 1).

    ``forward`` takes ``x`` indexable as:

    - ``x[0]``: the input image with ``n_channels`` channels;
    - ``x[1]``: the colour code handed to every ``UpBlock`` (each built with
      ``color_dim=512``).

    The encoder is ``DoubleConv`` followed by four ``Down`` stages. The decoder
    is four ``UpBlock`` stages fed with the encoder features in reverse order,
    then a conv head ending in ``Tanh``.

    ``n_classes`` is stored but not used: the head always emits 2 channels
    (named ``x_ab`` originally, presumably Lab chroma -- confirm).
    ``bilinear`` halves the deepest width and selects bilinear upsampling in
    the decoder.
    """

    def __init__(self, n_channels=1, n_classes=3, bilinear=True):
        super(ColorUNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 2 if bilinear else 1
        color_dim = 512

        # Encoder (module creation order preserved for identical seeded init).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder: UpBlock(color_dim, in_channels, out_channels, sdft_kernel, bilinear).
        self.up1 = UpBlock(color_dim, 1024, 512 // factor, 3, bilinear)
        self.up2 = UpBlock(color_dim, 512, 256 // factor, 3, bilinear)
        self.up3 = UpBlock(color_dim, 256, 128 // factor, 5, bilinear)
        self.up4 = UpBlock(color_dim, 128, 64, 5, bilinear)

        self.outc = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 2, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        image, x_color = x[0], x[1]

        # Encoder: keep every intermediate feature as a skip connection.
        feats = [self.inc(image)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            feats.append(down(feats[-1]))

        # Decoder: start from the deepest feature and consume skips deepest-first.
        out = feats.pop()
        decoder = (self.up1, self.up2, self.up3, self.up4)
        for up, skip in zip(decoder, reversed(feats)):
            out = up(out, skip, x_color)

        return self.outc(out)
|
|