|
import torch
import torch.nn.functional as F
from torch import nn

# Size of the style latent vector w consumed by the AdaIN affine layers.
W_SIZE = 512
|
|
def calc_mean_std(feat, eps=1e-5):
    """Return per-channel mean and std of a (N, C, H, W) feature map, each shaped (N, C, 1, 1)."""
    size = feat.size()
    assert len(size) == 4
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std
|
|
def adain(content_feat, style_feat):
    """Adaptive instance normalization: re-normalize content_feat so that its
    per-channel statistics match the (mean, std) tuple style_feat, where each
    entry is an (N, C) tensor."""
    assert content_feat.size()[:2] == style_feat[0].size()[:2]
    assert content_feat.size()[:2] == style_feat[1].size()[:2]
    size = content_feat.size()
    style_mean, style_std = style_feat
    style_mean = style_mean.unsqueeze(-1).unsqueeze(-1)
    style_std = style_std.unsqueeze(-1).unsqueeze(-1)
    content_mean, content_std = calc_mean_std(content_feat)

    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
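

# A minimal usage sketch (not part of the original pipeline): shows the shapes
# calc_mean_std() and adain() expect. The tensor sizes are arbitrary examples.
def _adain_example():
    content = torch.rand(2, 64, 32, 32)            # (N, C, H, W) content features
    mean, std = calc_mean_std(content)             # each (N, C, 1, 1)
    style_mean = torch.rand(2, 64)                 # per-channel style statistics
    style_std = torch.rand(2, 64) + 1e-5
    out = adain(content, (style_mean, style_std))  # same shape as content
    assert out.shape == content.shape
    return out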
|
|
class FullyConnected(nn.Module):
    """Stack of linear layers whose widths interpolate linearly from
    input_channels to output_channels (no nonlinearities in between)."""

    def __init__(self, input_channels: int, output_channels: int, layers=3):
        super(FullyConnected, self).__init__()
        self.channels = torch.linspace(input_channels, output_channels, layers + 1).long()
        self.layers = nn.Sequential(
            *[nn.Linear(self.channels[i].item(), self.channels[i + 1].item())
              for i in range(len(self.channels) - 1)]
        )

    def forward(self, x):
        return self.layers(x)
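

# A minimal usage sketch (illustrative assumption, not from the original code):
# with layers=3, FullyConnected(512, 128) stacks Linear(512, 384),
# Linear(384, 256) and Linear(256, 128), with no nonlinearity in between.
def _fully_connected_example():
    mlp = FullyConnected(512, 128, layers=3)
    w = torch.rand(4, 512)
    return mlp(w)                                  # -> (4, 128)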
|
|
class Affine(nn.Module):
    """Learned affine map from the style latent w to per-channel statistics,
    with an extra learnable bias initialized from N(0, 1)."""

    def __init__(self, input_channels: int, output_channels: int):
        super(Affine, self).__init__()
        self.lin = nn.Linear(input_channels, output_channels)
        bias = torch.zeros(output_channels)
        nn.init.normal_(bias, 0, 1)
        self.bias = nn.Parameter(bias)

    def forward(self, x):
        return self.lin(x) + self.bias
|
|
class DoubleConv(nn.Module):
    """(convolution => [AdaIN or InstanceNorm] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None, ada=False, padding='zeros'):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.ada = ada
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, padding_mode=padding)
        if ada:
            # AdaIN statistics are predicted from the style latent w.
            self.a1_mean = Affine(W_SIZE, mid_channels)
            self.a1_std = Affine(W_SIZE, mid_channels)
        else:
            self.norm1 = nn.InstanceNorm2d(mid_channels, affine=True)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding)
        if ada:
            self.a2_mean = Affine(W_SIZE, out_channels)
            self.a2_std = Affine(W_SIZE, out_channels)
        else:
            self.norm2 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x, w=None):
        if self.ada:
            assert w is not None

        x = self.conv1(x)
        if self.ada:
            x = adain(x, (self.a1_mean(w), self.a1_std(w)))
        else:
            x = self.norm1(x)
        x = self.relu(x)

        x = self.conv2(x)
        if self.ada:
            x = adain(x, (self.a2_mean(w), self.a2_std(w)))
        else:
            x = self.norm2(x)
        x = self.relu(x)

        return x
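

# A minimal usage sketch (illustrative only): in AdaIN mode a DoubleConv block
# needs a style latent w of size W_SIZE on every forward pass. Batch and
# spatial sizes below are arbitrary assumptions.
def _double_conv_example():
    block = DoubleConv(3, 64, ada=True)
    x = torch.rand(1, 3, 64, 64)
    w = torch.rand(1, W_SIZE)
    return block(x, w)                             # -> (1, 64, 64, 64)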
|
|
class DiluteConv(nn.Module):
    """Dilated convolution => InstanceNorm => ReLU, optionally concatenating a
    second input along the channel dimension first."""

    def __init__(self, in_channels, out_channels, dilation, padding='zeros'):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               padding=1 + dilation, dilation=dilation, padding_mode=padding)
        self.norm1 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x, y=None):
        if y is not None:
            x = torch.cat([x, y], dim=1)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        return x
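

# A minimal usage sketch (illustrative only): with kernel_size=3 and
# padding=1+dilation, the output grows by 2 pixels per spatial dimension
# (out = in + 2*padding - dilation*(kernel_size-1)), regardless of dilation.
def _dilute_conv_example():
    block = DiluteConv(in_channels=16, out_channels=32, dilation=2)
    x = torch.rand(1, 16, 64, 64)
    return block(x)                                # -> (1, 32, 66, 66)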
|
|
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels, ada=False, padding='zeros'):
        super().__init__()
        self.max_pool = nn.MaxPool2d(2)
        self.double_conv = DoubleConv(in_channels, out_channels, ada=ada, padding=padding)

    def forward(self, x, w=None):
        x = self.max_pool(x)
        return self.double_conv(x, w)
|
|
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True, ada=False, padding='zeros'):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2, ada=ada, padding=padding)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, ada=ada, padding=padding)

    def forward(self, x1, x2, w=None):
        x1 = self.up(x1)

        # Pad the upsampled feature so it matches the skip connection spatially.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x, w)
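

# A minimal usage sketch (channel counts are assumptions, not taken from the
# original training script): Up upsamples the decoder feature x1, pads it to
# the skip connection x2, concatenates both along channels, then runs DoubleConv.
def _up_example():
    up = Up(in_channels=256, out_channels=64, bilinear=True)
    x1 = torch.rand(1, 128, 16, 16)                # decoder feature, upsampled to 32x32
    x2 = torch.rand(1, 128, 32, 32)                # encoder skip connection
    return up(x1, x2)                              # cat -> 256 channels -> (1, 64, 32, 32)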
|
|
class OutConv(nn.Module):
    """Final 1x1 convolution mapping to the desired number of output channels."""

    def __init__(self, in_channels, out_channels, padding='zeros'):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding_mode=padding)

    def forward(self, x):
        return self.conv(x)
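

# Minimal smoke test, a hedged sketch rather than the original entry point:
# wires a tiny encoder/decoder out of the blocks above and checks shapes.
if __name__ == "__main__":
    inc = DoubleConv(3, 32)
    down1 = Down(32, 64)
    up1 = Up(64 + 32, 32, bilinear=True)
    outc = OutConv(32, 1)

    x = torch.rand(1, 3, 64, 64)
    x1 = inc(x)           # (1, 32, 64, 64)
    x2 = down1(x1)        # (1, 64, 32, 32)
    y = up1(x2, x1)       # (1, 32, 64, 64)
    print(outc(y).shape)  # torch.Size([1, 1, 64, 64])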