# z2p/networks.py
import torch
import torch.nn.functional as F
from torch import nn

W_SIZE = 512  # width of the latent style vector w fed to the Affine heads below


def calc_mean_std(feat, eps=1e-5):
# eps is a small value added to the variance to avoid divide-by-zero.
size = feat.size()
assert (len(size) == 4)
N, C = size[:2]
feat_var = feat.view(N, C, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(N, C, 1, 1)
feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return feat_mean, feat_std
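

# Shape sketch (illustrative, not used by the model): statistics are computed
# per sample and per channel over the spatial dimensions.
#   feat = torch.randn(2, 8, 16, 16)
#   mean, std = calc_mean_std(feat)   # each has shape (2, 8, 1, 1)
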
def adain(content_feat, style_feat):
    """Adaptive instance normalization (AdaIN).

    style_feat is a (mean, std) tuple of (N, C) tensors, e.g. predicted by the
    Affine heads below, rather than statistics computed from a style image.
    """
    assert (content_feat.size()[:2] == style_feat[0].size()[:2]
            and content_feat.size()[:2] == style_feat[1].size()[:2])
size = content_feat.size()
style_mean, style_std = style_feat
style_mean, style_std = style_mean.unsqueeze(-1).unsqueeze(-1), style_std.unsqueeze(-1).unsqueeze(-1)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(
size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
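

# Usage sketch (illustrative): with zero mean and unit std the call reduces to
# plain instance normalization of the content features.
#   content = torch.randn(2, 8, 16, 16)
#   mean, std = torch.zeros(2, 8), torch.ones(2, 8)   # per-channel stats, (N, C)
#   out = adain(content, (mean, std))                 # (2, 8, 16, 16)
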
class FullyConnected(nn.Module):
def __init__(self, input_channels: int, output_channels: int, layers=3):
super(FullyConnected, self).__init__()
        # layer widths interpolate linearly from input to output channels
        self.channels = torch.linspace(input_channels, output_channels, layers + 1).long()
        # note: plain Linear layers stacked with no nonlinearities in between
        self.layers = nn.Sequential(
            *[nn.Linear(self.channels[i].item(), self.channels[i + 1].item()) for i in range(len(self.channels) - 1)]
        )

    def forward(self, x):
return self.layers(x)
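

# Usage sketch (illustrative): with input 512, output 64 and layers=3 the
# widths are 512 -> 362 -> 213 -> 64 (linspace values truncated to integers).
#   mlp = FullyConnected(512, 64, layers=3)
#   y = mlp(torch.randn(10, 512))   # (10, 64)
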
class Affine(nn.Module):
    def __init__(self, input_channels: int, output_channels: int):
super(Affine, self).__init__()
self.lin = nn.Linear(input_channels, output_channels)
        # extra learnable per-channel bias, initialized from N(0, 1), added on
        # top of the Linear layer's own bias
        bias = torch.zeros(output_channels)
        nn.init.normal_(bias, 0, 1)
        self.bias = nn.Parameter(bias)

    def forward(self, x):
        return self.lin(x) + self.bias
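

# Usage sketch (illustrative): a pair of Affine heads maps a latent vector w
# to the per-channel (mean, std) consumed by adain above.
#   w = torch.randn(2, W_SIZE)
#   mean_head, std_head = Affine(W_SIZE, 8), Affine(W_SIZE, 8)
#   out = adain(torch.randn(2, 8, 16, 16), (mean_head(w), std_head(w)))
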
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None, ada=False, padding='zeros'):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.ada = ada
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, padding_mode=padding)
if ada:
self.a1_mean = Affine(W_SIZE, mid_channels)
self.a1_std = Affine(W_SIZE, mid_channels)
else:
self.norm1 = nn.InstanceNorm2d(mid_channels, affine=True)
self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding)
if ada:
self.a2_mean = Affine(W_SIZE, out_channels)
self.a2_std = Affine(W_SIZE, out_channels)
else:
            self.norm2 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x, w=None):
        if self.ada:
            assert w is not None, 'AdaIN blocks require the latent vector w'
x = self.conv1(x)
if self.ada:
x = adain(x, (self.a1_mean(w), self.a1_std(w)))
else:
x = self.norm1(x)
x = self.relu(x)
x = self.conv2(x)
if self.ada:
x = adain(x, (self.a2_mean(w), self.a2_std(w)))
else:
x = self.norm2(x)
x = self.relu(x)
return x
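

# Usage sketch (illustrative): the same block in its two modes.
#   x = torch.randn(2, 3, 32, 32)
#   DoubleConv(3, 64)(x)                                       # InstanceNorm path
#   DoubleConv(3, 64, ada=True)(x, w=torch.randn(2, W_SIZE))   # AdaIN path
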
class DiluteConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, dilation, padding='zeros'):
super().__init__()
self.relu = nn.ReLU(inplace=True)
        # padding = 1 + dilation over-pads the dilated 3x3 conv, so the output
        # is 2 pixels larger than the input in each spatial dimension
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               padding=1 + dilation, dilation=dilation, padding_mode=padding)
        self.norm1 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x, y=None):
        # optionally concatenate a second feature map along the channel dim
        if y is not None:
x = torch.cat([x, y], dim=1)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
return x
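

# Usage sketch (illustrative): note the spatial growth from the over-padding.
#   block = DiluteConv(3, 16, dilation=2)
#   y = block(torch.randn(2, 3, 32, 32))   # (2, 16, 34, 34)
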
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels, ada=False, padding='zeros'):
super().__init__()
self.max_pool = nn.MaxPool2d(2)
        self.double_conv = DoubleConv(in_channels, out_channels, ada=ada, padding=padding)

    def forward(self, x, w=None):
x = self.max_pool(x)
        return self.double_conv(x, w)


class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True, ada=False, padding='zeros'):
super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels;
        # pass the padding mode through to the DoubleConv in both branches
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2, ada=ada, padding=padding)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, ada=ada, padding=padding)

    def forward(self, x1, x2, w=None):
x1 = self.up(x1)
        # inputs are NCHW; pad x1 so its spatial size matches the skip tensor x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x, w)
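

# Usage sketch (illustrative): x1 is the coarser decoder feature, x2 the
# encoder skip connection; x1 is upsampled, padded and concatenated with x2.
#   up = Up(128, 64, bilinear=True)
#   x1, x2 = torch.randn(2, 64, 16, 16), torch.randn(2, 64, 32, 32)
#   y = up(x1, x2)   # (2, 64, 32, 32)
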
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels, padding='zeros'):
super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding_mode=padding)

    def forward(self, x):
return self.conv(x)
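

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch; the real model wiring lives in
    # the rest of the repo). Builds a tiny AdaIN U-Net from the blocks above.
    x = torch.randn(1, 3, 64, 64)
    w = torch.randn(1, W_SIZE)
    inc = DoubleConv(3, 32, ada=True)
    down = Down(32, 64, ada=True)
    up = Up(96, 32, bilinear=True, ada=True)
    out = OutConv(32, 4)
    x1 = inc(x, w)            # (1, 32, 64, 64)
    x2 = down(x1, w)          # (1, 64, 32, 32)
    y = out(up(x2, x1, w))    # (1, 4, 64, 64)
    print(y.shape)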