|
|
|
"""Contains the implementation of discriminator described in StyleGAN. |
|
|
|
Paper: https://arxiv.org/pdf/1812.04948.pdf |
|
|
|
Official TensorFlow implementation: https://github.com/NVlabs/stylegan |
|
""" |
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.cuda.amp import autocast |
|
|
|
import math |
|
|
|
__all__ = ['PiGANDiscriminator'] |
|
|
|
|
|
class PiGANDiscriminator(nn.Module):
    """Progressive CoordConv discriminator with score, latent and camera heads.

    An input image is lifted to feature space by the `fromRGB` adapter that
    matches its width, passed through the remaining `ResidualCCBlock` layers
    (each of which downsamples with stride 2), and finally fed to three
    2x2 convolution heads that produce the realness score, a latent code
    prediction and a 2-value camera prediction.
    """

    def __init__(self,
                 resolution,
                 latent_dim=256,
                 label_dim=0,
                 embedding_dim=256,
                 normalize_embedding=True,
                 **kwargs):
        """Builds the discriminator.

        Args:
            resolution: Unused by this class.
            latent_dim: Number of output channels of the latent head.
            label_dim: Dimension of the conditioning label. `0` disables
                label conditioning.
            embedding_dim: Dimension of the label embedding. The embedding
                is only used when both `label_dim` and `embedding_dim` are
                positive.
            normalize_embedding: Whether to pixel-normalize the label
                embedding before projecting the score onto it.
            **kwargs: Unused by this class.
        """
        super().__init__()
        self.label_dim = label_dim
        self.embedding_dim = embedding_dim
        self.normalize_embedding = normalize_embedding

        # Scalar buffer, registered exactly once.
        self.register_buffer('lod', torch.zeros(()))

        self.use_embedding = label_dim > 0 and embedding_dim > 0
        if self.use_embedding:
            self.class_embedding = EqualLinear(label_dim,
                                               embedding_dim,
                                               bias=True,
                                               bias_init=0,
                                               lr_mul=1)
            self.norm = PixelNormLayer(dim=1, eps=1e-8)

        # `layers[k]` consumes the feature width produced by `fromRGB[k]`.
        self.layers = nn.ModuleList([
            ResidualCCBlock(32, 64),
            ResidualCCBlock(64, 128),
            ResidualCCBlock(128, 256),
            ResidualCCBlock(256, 400),
            ResidualCCBlock(400, 400),
            ResidualCCBlock(400, 400),
            ResidualCCBlock(400, 400),
        ])

        # One more adapter than layers: the last one serves 2x2 inputs, which
        # go straight to the output heads.
        self.fromRGB = nn.ModuleList([
            AdapterBlock(32),
            AdapterBlock(64),
            AdapterBlock(128),
            AdapterBlock(256),
            AdapterBlock(400),
            AdapterBlock(400),
            AdapterBlock(400),
            AdapterBlock(400),
        ])
        self.score_conv = nn.Conv2d(
            400, embedding_dim if self.use_embedding else max(label_dim, 1), 2)
        self.latent_conv = nn.Conv2d(400, latent_dim, 2)
        self.camera_conv = nn.Conv2d(400, 2, 2)

        # Maps the input width to the index of the first adapter/layer to use.
        self.img_size_to_layer = {
            2: 7,
            4: 6,
            8: 5,
            16: 4,
            32: 3,
            64: 2,
            128: 1,
            256: 0
        }

    def forward(self,
                input,
                label=None,
                options=None,
                alpha=None,
                enable_amp=False,
                **kwargs):
        """Scores a batch of images.

        Args:
            input: Image batch. Its last dimension (width) must be a key of
                `self.img_size_to_layer`.
            label: Label tensor with shape `[batch_size, label_dim]`.
                Required when `label_dim > 0`.
            options: Unused.
            alpha: Fade-in weight of the current-resolution branch. When it
                is below 1, the features after the first block are blended
                with features computed from a nearest-neighbour
                half-resolution copy of `input`. `None` is treated as 1,
                i.e. no blending.
            enable_amp: Whether to run the network under `autocast`.
            **kwargs: May contain `instance_noise`; when positive, Gaussian
                noise scaled by this value is added to the adapter output.

        Returns:
            A dict with keys `score`, `latent` and `camera`, each holding a
            tensor flattened to `[batch_size, -1]`.

        Raises:
            ValueError: If a label is required but missing or mis-shaped.
            KeyError: If the input width is not in `self.img_size_to_layer`.
        """
        if self.label_dim > 0:
            if label is None:
                raise ValueError(
                    f'Model requires an additional label '
                    f'(with dimension {self.label_dim}) as input, '
                    f'but no label is received!')
            if label.ndim != 2 or label.shape != (input.shape[0],
                                                  self.label_dim):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_dim], where '
                                 f'`batch_size` equals to that of '
                                 f'images ({input.shape[0]}) and '
                                 f'`label_dim` equals to {self.label_dim}!\n'
                                 f'But `{label.shape}` is received!')
            label = label.to(dtype=torch.float32)
            if self.use_embedding:
                embed = self.class_embedding(label)
                if self.normalize_embedding:
                    embed = self.norm(embed)

        start = self.img_size_to_layer[input.shape[-1]]

        # The signature defaults `alpha` to None; comparing None with a
        # number raises TypeError, so treat a missing value as fully faded in.
        fade = 1 if alpha is None else alpha

        with autocast(enabled=enable_amp):
            x = self.fromRGB[start](input)

            instance_noise = kwargs.get('instance_noise', 0)
            if instance_noise > 0:
                x = x + torch.randn_like(x) * instance_noise

            for i, layer in enumerate(self.layers[start:]):
                if i == 1 and fade < 1:
                    # Progressive-growing fade-in: mix the newly added block's
                    # output with the lower-resolution adapter's features.
                    down_image = F.interpolate(input,
                                               scale_factor=0.5,
                                               mode='nearest')
                    x = fade * x + (1 - fade) * self.fromRGB[start +
                                                             1](down_image)

                x = layer(x)

            score = self.score_conv(x).reshape(x.shape[0], -1)
            if self.use_embedding:
                # Projection onto the label embedding.
                score = (score * embed).sum(dim=1, keepdim=True)
                score = score / np.sqrt(self.embedding_dim)
            elif self.label_dim > 0:
                # Without an embedding, the label selects/weights the scores.
                score = (score * label).sum(dim=1, keepdim=True)

            latent = self.latent_conv(x).reshape(x.shape[0], -1)
            position = self.camera_conv(x).reshape(x.shape[0], -1)

        results = {
            'score': score,
            'latent': latent,
            'camera': position,
        }
        return results
|
|
|
|
|
class ResidualCCBlock(nn.Module):
    """Residual block made of two CoordConv layers with a strided shortcut.

    The main branch applies CoordConv -> LeakyReLU -> CoordConv (stride 2)
    -> LeakyReLU. The shortcut is a 1x1 convolution with stride 2. Both
    branches are summed and divided by sqrt(2).
    """

    def __init__(self, inplanes, planes, kernel_size=3):
        """Builds the block.

        Args:
            inplanes: Number of input channels.
            planes: Number of output channels.
            kernel_size: Kernel size of both CoordConv layers.
        """
        super().__init__()
        # Padding of kernel_size // 2 is shared by both CoordConv layers.
        conv_args = dict(kernel_size=kernel_size, padding=kernel_size // 2)
        stages = [
            CoordConv(inplanes, planes, **conv_args),
            nn.LeakyReLU(0.2, inplace=True),
            CoordConv(planes, planes, stride=2, **conv_args),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.network = nn.Sequential(*stages)
        self.proj = nn.Conv2d(inplanes, planes, 1, stride=2)

    def init_weights(self):
        """Applies Kaiming-normal initialization to `nn.Linear` submodules.

        NOTE(review): the submodules created in `__init__` are convolutional,
        so this looks like a no-op for this block -- confirm whether
        `nn.Conv2d` was intended.
        """
        linear_layers = (m for m in self.modules()
                         if isinstance(m, nn.Linear))
        for linear in linear_layers:
            nn.init.kaiming_normal_(linear.weight,
                                    a=0.2,
                                    mode='fan_in',
                                    nonlinearity='leaky_relu')

    def forward(self, input):
        """Returns `(network(input) + proj(input)) / sqrt(2)`."""
        residual = self.network(input)
        shortcut = self.proj(input)
        return (residual + shortcut) / math.sqrt(2)
|
|
|
|
|
class AdapterBlock(nn.Module):
    """Maps a 3-channel image to features with a 1x1 conv and LeakyReLU."""

    def __init__(self, output_channels):
        """Builds the adapter.

        Args:
            output_channels: Number of feature channels to produce.
        """
        super().__init__()
        to_features = nn.Conv2d(in_channels=3,
                                out_channels=output_channels,
                                kernel_size=1,
                                padding=0)
        activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.model = nn.Sequential(to_features, activation)

    def forward(self, input):
        """Applies the 1x1 convolution followed by the activation."""
        features = self.model(input)
        return features
|
|
|
|
|
class AddCoords(nn.Module):
    """Appends normalized coordinate channels to a 4-D feature map.

    Source:
    https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py
    """

    def __init__(self, with_r=False):
        """Builds the layer.

        Args:
            with_r: Whether to append a third, radius-like channel.
        """
        super().__init__()
        self.with_r = with_r

    @staticmethod
    def _axis_coords(size, device):
        """Returns `size` float32 coordinates running from -1 to 1."""
        steps = torch.arange(size, device=device, dtype=torch.float32)
        # An axis of length 1 would divide by zero and fill the channel with
        # NaN; clamp the divisor so a single position maps to -1.
        return steps / max(size - 1, 1) * 2 - 1

    def forward(self, input_tensor):
        """Concatenates coordinate channels to `input_tensor`.

        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)

        Returns:
            Tensor of shape (batch, channel + 2, x_dim, y_dim), or
            (batch, channel + 3, x_dim, y_dim) when `with_r` is set. The
            first extra channel varies along dim 2, the second along dim 3,
            both spanning [-1, 1].
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()
        full_shape = (batch_size, 1, x_dim, y_dim)

        # Build the grids on the input's device so that no host-to-device
        # copy is needed on every call, then match the input's dtype.
        xx_channel = self._axis_coords(x_dim, input_tensor.device)
        yy_channel = self._axis_coords(y_dim, input_tensor.device)
        xx_channel = xx_channel.to(input_tensor.dtype).view(1, 1, x_dim, 1)
        yy_channel = yy_channel.to(input_tensor.dtype).view(1, 1, 1, y_dim)
        xx_channel = xx_channel.expand(full_shape)
        yy_channel = yy_channel.expand(full_shape)

        ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)

        if self.with_r:
            # Distance of each position from the point (0.5, 0.5), as in the
            # referenced source.
            rr = torch.sqrt(
                torch.pow(xx_channel - 0.5, 2) +
                torch.pow(yy_channel - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)

        return ret
|
|
|
|
|
class CoordConv(nn.Module):
    """2-D convolution applied after appending coordinate channels.

    Source:
    https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py
    """

    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
        """Builds the layer.

        Args:
            in_channels: Number of channels of the incoming feature map,
                before coordinate channels are appended.
            out_channels: Number of channels produced by the convolution.
            with_r: Forwarded to `AddCoords`; adds one more input channel.
            **kwargs: Forwarded to `nn.Conv2d` (e.g. `kernel_size`).
        """
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        # Two coordinate channels are always appended, plus one when
        # `with_r` is set.
        extra_channels = 3 if with_r else 2
        self.conv = nn.Conv2d(in_channels + extra_channels, out_channels,
                              **kwargs)

    def forward(self, x):
        """Appends coordinate channels to `x` and convolves the result."""
        return self.conv(self.addcoords(x))
|
|
|
|
|
class EqualLinear(nn.Module):
    """Linear layer whose weight is rescaled at every forward pass.

    The weight is drawn from a standard normal distribution divided by
    `lr_mul`, and multiplied by `lr_mul / sqrt(in_dim)` when applied. The
    bias, if present, is multiplied by `lr_mul` when applied.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        bias=True,
        bias_init=0,
        lr_mul=1,
    ):
        """Builds the layer.

        Args:
            in_dim: Number of input features.
            out_dim: Number of output features.
            bias: Whether to add a learnable bias.
            bias_init: Initial value of every bias element.
            lr_mul: Multiplier folded into the weight scale and the bias.
        """
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        """Returns `input @ (weight * scale).T`, plus the scaled bias if any."""
        # `self.bias` is None when the layer was built with `bias=False`;
        # multiplying None by `lr_mul` would raise a TypeError.
        bias = None if self.bias is None else self.bias * self.lr_mul
        return F.linear(input, self.weight * self.scale, bias=bias)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
|
|
|
|
|
class PixelNormLayer(nn.Module):
    """Implements pixel-wise feature vector normalization layer."""

    def __init__(self, dim, eps):
        """Stores the normalization axis and the stabilizing epsilon.

        Args:
            dim: Dimension along which the mean square is computed.
            eps: Small constant added to the mean square before the
                reciprocal square root.
        """
        super().__init__()
        self.dim, self.eps = dim, eps

    def extra_repr(self):
        """Returns the settings shown in the module's printed form."""
        return f'dim={self.dim}, epsilon={self.eps}'

    def forward(self, x):
        """Divides `x` by the root mean square taken along `self.dim`."""
        mean_square = x.square().mean(dim=self.dim, keepdim=True)
        inv_rms = (mean_square + self.eps).rsqrt()
        return x * inv_rms