BerfScene / models /pigan_discriminator.py
3v324v23's picture
init
2f85de4
raw
history blame
9.74 kB
# python3.8
"""Contains the implementation of discriminator described in StyleGAN.
Paper: https://arxiv.org/pdf/1812.04948.pdf
Official TensorFlow implementation: https://github.com/NVlabs/stylegan
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import math
__all__ = ['PiGANDiscriminator']
class PiGANDiscriminator(nn.Module):
def __init__(self,
resolution,
latent_dim=256,
label_dim=0,
embedding_dim=256,
normalize_embedding=True,
**kwargs): # from 4 * 2^0 to 4 * 2^7 4 -> 512
super().__init__()
self.label_dim = label_dim
self.embedding_dim = embedding_dim
self.normalize_embedding = normalize_embedding
self.register_buffer('lod', torch.zeros(()))
self.use_embedding = label_dim > 0 and embedding_dim > 0
if self.use_embedding > 0:
self.class_embedding = EqualLinear(label_dim,
embedding_dim,
bias=True,
bias_init=0,
lr_mul=1)
self.norm = PixelNormLayer(dim=1, eps=1e-8)
self.layers = nn.ModuleList([
ResidualCCBlock(32, 64), # 6 256x256 -> 128x128
ResidualCCBlock(64, 128), # 5 128x128 -> 64x64
ResidualCCBlock(128, 256), # 4 64x64 -> 32x32
ResidualCCBlock(256, 400), # 3 32x32 -> 16x16
ResidualCCBlock(400, 400), # 2 16x16 -> 8x8
ResidualCCBlock(400, 400), # 1 8x8 -> 4x4
ResidualCCBlock(400, 400), # 7 4x4 -> 2x2
])
self.fromRGB = nn.ModuleList([
AdapterBlock(32),
AdapterBlock(64),
AdapterBlock(128),
AdapterBlock(256),
AdapterBlock(400),
AdapterBlock(400),
AdapterBlock(400),
AdapterBlock(400),
])
self.score_conv = nn.Conv2d(
400, embedding_dim if self.use_embedding else max(label_dim, 1), 2)
self.latent_conv = nn.Conv2d(400, latent_dim, 2)
self.camera_conv = nn.Conv2d(400, 2, 2)
self.img_size_to_layer = {
2: 7,
4: 6,
8: 5,
16: 4,
32: 3,
64: 2,
128: 1,
256: 0
}
self.register_buffer('lod', torch.zeros(()))
def forward(self,
input,
label=None,
options=None,
alpha=None,
enable_amp=False,
**kwargs):
if self.label_dim > 0:
if label is None:
raise ValueError(
f'Model requires an additional label '
f'(with dimension {self.label_dim}) as input, '
f'but no label is received!')
if label.ndim != 2 or label.shape != (input.shape[0],
self.label_dim):
raise ValueError(f'Input label should be with shape '
f'[batch_size, label_dim], where '
f'`batch_size` equals to that of '
f'images ({input.shape[0]}) and '
f'`label_dim` equals to {self.label_dim}!\n'
f'But `{label.shape}` is received!')
label = label.to(dtype=torch.float32)
if self.use_embedding:
embed = self.class_embedding(label)
if self.normalize_embedding:
embed = self.norm(embed)
start = self.img_size_to_layer[input.shape[-1]]
with autocast(enabled=enable_amp):
x = self.fromRGB[start](input)
if kwargs.get('instance_noise', 0) > 0:
x = x + torch.randn_like(x) * kwargs['instance_noise']
for i, layer in enumerate(self.layers[start:]):
if i == 1 and alpha < 1:
down_image = F.interpolate(input,
scale_factor=0.5,
mode='nearest')
x = alpha * x + (1 - alpha) * self.fromRGB[start +
1](down_image)
x = layer(x)
# x = self.final_layer(x).reshape(x.shape[0], -1)
score = self.score_conv(x).reshape(x.shape[0], -1)
if self.use_embedding:
score = (score * embed).sum(dim=1, keepdim=True)
score = score / np.sqrt(self.embedding_dim)
elif self.label_dim > 0:
score = (score * label).sum(dim=1, keepdim=True)
latent = self.latent_conv(x).reshape(x.shape[0], -1)
position = self.camera_conv(x).reshape(x.shape[0], -1)
results = {
'score': score,
'latent': latent,
'camera': position,
}
return results
class ResidualCCBlock(nn.Module):
def __init__(self, inplanes, planes, kernel_size=3):
super().__init__()
p = kernel_size // 2
self.network = nn.Sequential(
CoordConv(inplanes, planes, kernel_size=kernel_size, padding=p),
nn.LeakyReLU(0.2, inplace=True),
CoordConv(planes,
planes,
kernel_size=kernel_size,
stride=2,
padding=p), nn.LeakyReLU(0.2, inplace=True))
self.proj = nn.Conv2d(inplanes, planes, 1, stride=2)
def init_weights(self):
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.kaiming_normal_(module.weight,
a=0.2,
mode='fan_in',
nonlinearity='leaky_relu')
def forward(self, input):
y = self.network(input)
identity = self.proj(input)
y = (y + identity) / math.sqrt(2)
return y
class AdapterBlock(nn.Module):
def __init__(self, output_channels):
super().__init__()
self.model = nn.Sequential(nn.Conv2d(3, output_channels, 1, padding=0),
nn.LeakyReLU(0.2, inplace=True))
def forward(self, input):
return self.model(input)
class AddCoords(nn.Module):
"""
Source:
https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py
"""
def __init__(self, with_r=False):
super().__init__()
self.with_r = with_r
def forward(self, input_tensor):
"""
Args:
input_tensor: shape(batch, channel, x_dim, y_dim)
"""
batch_size, _, x_dim, y_dim = input_tensor.size()
xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
xx_channel = xx_channel.float() / (x_dim - 1)
yy_channel = yy_channel.float() / (y_dim - 1)
xx_channel = xx_channel * 2 - 1
yy_channel = yy_channel * 2 - 1
xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
ret = torch.cat([
input_tensor,
xx_channel.type_as(input_tensor),
yy_channel.type_as(input_tensor)
],
dim=1)
if self.with_r:
rr = torch.sqrt(
torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) +
torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))
ret = torch.cat([ret, rr], dim=1)
return ret
class CoordConv(nn.Module):
"""
Source:
https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py
"""
def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
super().__init__()
self.addcoords = AddCoords(with_r=with_r)
in_size = in_channels + 2
if with_r:
in_size += 1
self.conv = nn.Conv2d(in_size, out_channels, **kwargs)
def forward(self, x):
ret = self.addcoords(x)
ret = self.conv(ret)
return ret
class EqualLinear(nn.Module):
def __init__(
self,
in_dim,
out_dim,
bias=True,
bias_init=0,
lr_mul=1,
):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
out = F.linear(input,
self.weight * self.scale,
bias=self.bias * self.lr_mul)
return out
def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
)
class PixelNormLayer(nn.Module):
"""Implements pixel-wise feature vector normalization layer."""
def __init__(self, dim, eps):
super().__init__()
self.dim = dim
self.eps = eps
def extra_repr(self):
return f'dim={self.dim}, epsilon={self.eps}'
def forward(self, x):
scale = (x.square().mean(dim=self.dim, keepdim=True) +
self.eps).rsqrt()
return x * scale