|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import functools |
|
from . import base_function |
|
from .stylegan_ops import style_function |
|
from .transformer_ops import transformer_function |
|
|
|
|
|
|
|
|
|
|
|
def define_D(opt, img_size):
    """Build the discriminator selected by ``opt.netD`` and pass it to ``init_net``.

    The architecture is picked by substring match on ``opt.netD``:
    ``'patch'`` is checked first and selects ``NLayerDiscriminator``,
    otherwise ``'style'`` selects ``StyleDiscriminator``.

    Parameters:
        opt      -- options object; reads netD, norm, img_nc, ndf, n_layers_D,
                    attn_D, init_type and init_gain
        img_size -- only forwarded to ``StyleDiscriminator``

    Returns:
        whatever ``base_function.init_net`` returns for the constructed network

    Raises:
        NotImplementedError -- ``opt.netD`` contains neither 'patch' nor 'style'
    """
    # The norm layer is resolved up front, even on the branches that never use it.
    norm_value = base_function.get_norm_layer(opt.norm)

    arch = opt.netD
    wants_patch = 'patch' in arch
    wants_style = 'style' in arch
    if not (wants_patch or wants_style):
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % opt.netD)

    if wants_patch:
        net = NLayerDiscriminator(opt.img_nc, opt.ndf, opt.n_layers_D, norm_value, use_attn=opt.attn_D)
    else:
        net = StyleDiscriminator(img_size, ndf=opt.ndf, use_attn=opt.attn_D)

    # initialize_weights is False whenever 'style' appears anywhere in the name.
    return base_function.init_net(net, opt.init_type, opt.init_gain, initialize_weights=not wants_style)
|
|
|
|
|
def define_G(opt):
    """Build the decoder selected by ``opt.netG`` and pass it to ``init_net``.

    The architecture is picked by substring match on ``opt.netG``, tested in
    this order: ``'diff'`` -> ``base_function.DiffDecoder``, ``'linear'`` ->
    ``base_function.LinearDecoder``, ``'refine'`` -> ``RefinedGenerator``.

    Parameters:
        opt -- options object; which fields are read depends on the branch
               (see the constructor calls below), plus init_type and init_gain

    Returns:
        whatever ``base_function.init_net`` returns for the constructed network

    Raises:
        NotImplementedError -- ``opt.netG`` matches none of the known names
    """
    arch = opt.netG
    if not any(tag in arch for tag in ('diff', 'linear', 'refine')):
        raise NotImplementedError('Decoder model name [%s] is not recognized' % opt.netG)

    if 'diff' in arch:
        net = base_function.DiffDecoder(
            opt.img_nc, opt.ngf, opt.kernel_G, opt.embed_dim, opt.n_layers_G, opt.num_res_blocks,
            word_size=opt.word_size, activation=opt.activation, norm=opt.norm,
            add_noise=opt.add_noise, use_attn=opt.attn_G, use_pos=opt.use_pos_G,
        )
    elif 'linear' in arch:
        net = base_function.LinearDecoder(
            opt.img_nc, opt.ngf, opt.kernel_G, opt.embed_dim, opt.activation, opt.norm,
        )
    else:
        # Only 'refine' can reach this point because of the guard above.
        net = RefinedGenerator(
            opt.img_nc, opt.ngf, opt.embed_dim, opt.down_layers, opt.mid_layers, opt.num_res_blocks,
            activation=opt.activation, norm=opt.norm,
        )

    # initialize_weights is False whenever 'style' appears anywhere in the name.
    skip_weight_init = 'style' in arch
    return base_function.init_net(net, opt.init_type, opt.init_gain, initialize_weights=not skip_weight_init)
|
|
|
|
|
def define_E(opt):
    """Build the encoder selected by ``opt.netE`` and pass it to ``init_net``.

    The architecture is picked by substring match on ``opt.netE``:
    ``'diff'`` is checked first and selects ``base_function.DiffEncoder``,
    otherwise ``'linear'`` selects ``base_function.LinearEncoder``.

    Parameters:
        opt -- options object; which fields are read depends on the branch
               (see the constructor calls below), plus init_type and init_gain.
               Note the 'diff' encoder takes its layer count from
               ``opt.n_layers_G``.

    Returns:
        whatever ``base_function.init_net`` returns for the constructed network

    Raises:
        NotImplementedError -- ``opt.netE`` contains neither 'diff' nor 'linear'
    """
    arch = opt.netE
    is_diff = 'diff' in arch
    if not is_diff and 'linear' not in arch:
        raise NotImplementedError('Encoder model name [%s] is not recognized' % opt.netE)

    if is_diff:
        net = base_function.DiffEncoder(
            opt.img_nc, opt.ngf, opt.kernel_E, opt.embed_dim, opt.n_layers_G, opt.num_res_blocks,
            activation=opt.activation, norm=opt.norm, use_attn=opt.attn_E,
        )
    else:
        net = base_function.LinearEncoder(opt.img_nc, opt.kernel_E, opt.embed_dim)

    # initialize_weights is False whenever 'style' appears anywhere in the name.
    skip_weight_init = 'style' in arch
    return base_function.init_net(net, opt.init_type, opt.init_gain, initialize_weights=not skip_weight_init)
|
|
|
|
|
def define_T(opt):
    """Build the transformer selected by ``opt.netT``.

    Only names containing ``"original"`` are supported; they map to
    ``transformer_function.Transformer``. Unlike the other ``define_*``
    factories, the network is returned directly without going through
    ``base_function.init_net``.

    Parameters:
        opt -- options object; reads netT, ngf, n_layers_G, embed_dim,
               kernel_T, n_encoders, n_decoders and embed_type

    Raises:
        NotImplementedError -- ``opt.netT`` does not contain "original"
    """
    if "original" not in opt.netT:
        raise NotImplementedError('Transformer model name [%s] is not recognized' % opt.netT)

    # The same width, ngf * 2 ** n_layers_G, is passed as both the first and third argument.
    feature_dim = int(opt.ngf * (2 ** opt.n_layers_G))
    return transformer_function.Transformer(
        feature_dim, opt.embed_dim, feature_dim, kernel=opt.kernel_T,
        n_encoders=opt.n_encoders, n_decoders=opt.n_decoders, embed_type=opt.embed_type,
    )
|
|
|
|
|
|
|
|
|
|
|
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_attn=False):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input examples
            ndf (int)       -- the number of filters in the first conv layer; deeper
                               layers use ndf * min(2 ** n, 8) filters
            n_layers (int)  -- the number of stride-2 conv layers; a stride-1 conv block
                               and a 1-channel output conv are appended after them
            norm_layer      -- normalization layer class, or a functools.partial wrapping one
            use_attn (bool) -- if True, a 1x1 conv followed by base_function.AttnAware is
                               inserted after the third stride-2 conv block (so it only
                               takes effect when n_layers >= 3)
        """
        super(NLayerDiscriminator, self).__init__()
        # Unwrap a functools.partial so the underlying norm class can be inspected.
        if isinstance(norm_layer, functools.partial):
            norm_cls = norm_layer.func
        else:
            norm_cls = norm_layer
        # The normalised convs carry a bias only when the norm layer is InstanceNorm2d.
        use_bias = norm_cls == nn.InstanceNorm2d

        kw = 4
        padw = 1

        # First layer: no normalisation, bias left at the Conv2d default.
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        nf_mult = 1
        for n in range(1, n_layers):
            # Double the filter count at every stage, capped at 8 * ndf.
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]
            if n == 2 and use_attn:
                sequence += [
                    nn.Conv2d(ndf * nf_mult, ndf * nf_mult, kernel_size=1, stride=1, bias=use_bias),
                    base_function.AttnAware(ndf * nf_mult),
                ]

        # One more conv block at stride 1, i.e. without further spatial halving.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Single-channel prediction map; no activation is applied here.
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward: run ``input`` through the sequential model."""
        return self.model(input)
|
|
|
|
|
class StyleDiscriminator(nn.Module):
    """Discriminator assembled from ``style_function`` layers.

    The input passes through a ``ConvLayer`` and a chain of ``StyleBlock``
    stages, a minibatch standard-deviation channel is appended, and a final
    ``ConvLayer`` plus two ``EqualLinear`` layers produce the output.
    """

    def __init__(self, img_size, ndf=32, blur_kernel=[1, 3, 3, 1], use_attn=False):
        """Construct the discriminator.

        Parameters:
            img_size (int) -- input resolution; must be a key of the channel table
                              below (4 ... 1024), otherwise a KeyError is raised
            ndf (int)      -- scales the channel counts for resolutions >= 32
                              through the multiplier ndf / 64
            blur_kernel    -- forwarded unchanged to every ``StyleBlock``
            use_attn (bool)-- if True, a ``base_function.AttnAware`` layer is inserted
                              before the block at loop index log2(img_size) - 3
        """
        super(StyleDiscriminator, self).__init__()

        channel_multiplier = ndf / 64
        # Channel count per resolution; the three smallest resolutions are fixed at 512.
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: int(512 * channel_multiplier),
            64: int(256 * channel_multiplier),
            128: int(128 * channel_multiplier),
            256: int(64 * channel_multiplier),
            512: int(32 * channel_multiplier),
            1024: int(16 * channel_multiplier),
        }

        # NOTE(review): the first argument is hard-coded to 3 -- presumably the
        # number of input image channels; confirm against ConvLayer.
        convs = [style_function.ConvLayer(3, channels[img_size], 1)]

        log_size = int(np.log2(img_size))

        in_channel = channels[img_size]

        # One StyleBlock per power of two from img_size down to 8, each targeting
        # the channel count of the next smaller resolution.
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            if i == log_size - 3 and use_attn:
                convs.append(base_function.AttnAware(in_channel))
            convs.append(style_function.StyleBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        # Upper bound on the group size and number of feature groups used by the
        # minibatch standard-deviation statistic in forward().
        self.stddev_group = 4
        self.stddev_feat = 1

        # +1 input channel for the appended standard-deviation map.
        self.final_conv = style_function.ConvLayer(in_channel + 1, channels[4], 3)
        # The flattened size channels[4] * 4 * 4 presumes a 4x4 feature map at this
        # point -- TODO confirm against StyleBlock's downsampling.
        self.final_linear = nn.Sequential(
            style_function.EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            style_function.EqualLinear(channels[4], 1),
        )

    def forward(self, x):
        """Run the discriminator on ``x`` and return the output of ``final_linear``.

        Any batch size is accepted: the standard-deviation statistic is computed
        over groups whose size is the largest divisor of the batch size that does
        not exceed ``self.stddev_group``.
        """
        out = self.convs(x)

        b, c, h, w = out.shape

        # The view below needs the group size to divide the batch size exactly.
        # Plain min(b, stddev_group) breaks for e.g. a batch of 6, so step down
        # to the largest divisor; batches that already divide are unaffected.
        group = min(b, self.stddev_group)
        while group > 1 and b % group:
            group -= 1

        # (group, b // group, stddev_feat, c // stddev_feat, h, w)
        stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
        # Standard deviation across the members of each group (epsilon keeps sqrt stable).
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        # Average over channels and spatial positions -> one scalar per (sub-batch, feat).
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        # Tile back to (b, stddev_feat, h, w) and append as extra channel(s).
        stddev = stddev.repeat(group, 1, h, w)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)
        out = out.view(b, -1)
        out = self.final_linear(out)

        return out
|
|
|
|
|
|
|
|
|
|
|
class RefinedGenerator(nn.Module):
    """Down-sampling / residual / up-sampling generator built from ``base_function`` layers.

    ``forward`` encodes the input, runs ``down_layers`` down stages and
    ``mid_layers`` residual blocks, then ``down_layers`` up stages. Each up
    stage feeds a ``ToRGB`` head that is chained through ``skip``, and a final
    ``ToRGB`` produces the returned tensor.
    """

    def __init__(self, input_nc, ngf=64, embed_dim=512, down_layers=3, mid_layers=6, num_res_blocks=1, dropout=0.0,
                 rample_with_conv=True, activation='gelu', norm='pixel'):
        """Construct the generator.

        Parameters:
            input_nc (int)        -- channels of the input; also passed to every ToRGB head
            ngf (int)             -- channel width after the first conv
            embed_dim (int)       -- upper bound on the channel width in the down path
            down_layers (int)     -- number of down stages (and of up stages)
            mid_layers (int)      -- number of residual blocks between down and up paths
            num_res_blocks (int)  -- residual blocks per down/up stage
            dropout (float)       -- forwarded to every ResnetBlock
            rample_with_conv      -- forwarded to DownSample / UpSample
            activation, norm      -- layer names forwarded to the base_function blocks
        """
        super(RefinedGenerator, self).__init__()

        # The returned layers are not used here; the lookups are kept as in the
        # original construction sequence.
        base_function.get_nonlinearity_layer(activation)
        base_function.get_norm_layer(norm)

        self.down_layers = down_layers
        self.mid_layers = mid_layers
        self.num_res_blocks = num_res_blocks

        self.encode = base_function.PartialConv2d(input_nc, ngf, kernel_size=3, stride=1, padding=1)

        # ---- down path: width doubles per stage, capped at embed_dim ----
        stage_inputs = []  # input width of every down stage, reused by the up path
        width = ngf
        self.down = nn.ModuleList()
        for _ in range(self.down_layers):
            stage = nn.Module()
            stage_inputs.append(width)
            next_width = min(int(width * 2), embed_dim)
            stage.downsample = base_function.DownSample(width, rample_with_conv, kernel_size=3)
            res_blocks = nn.ModuleList()
            block_in = width
            for _ in range(self.num_res_blocks):
                res_blocks.append(base_function.ResnetBlock(block_in, next_width, 3, dropout, activation, norm))
                block_in = next_width
            stage.block = res_blocks
            self.down.append(stage)
            width = next_width

        # ---- middle: constant-width residual blocks ----
        self.mid = nn.ModuleList()
        for _ in range(self.mid_layers):
            self.mid.append(base_function.ResnetBlock(width, width, 3, dropout, activation, norm))

        # ---- up path: mirror the down-stage input widths in reverse order ----
        self.up = nn.ModuleList()
        for idx, mirrored_width in enumerate(reversed(stage_inputs)):
            stage = nn.Module()
            target = max(mirrored_width, ngf)
            res_blocks = nn.ModuleList()
            block_in = width
            for _ in range(self.num_res_blocks):
                res_blocks.append(base_function.ResnetBlock(block_in, target, 3, dropout, activation, norm))
                block_in = target
            # Attention exists only on the stage that forward() pairs with the
            # feature saved at down stage 2.
            if idx == self.down_layers - 3:
                stage.attn = base_function.AttnAware(target, activation, norm)
            stage.block = res_blocks
            # Every RGB head except the first is built with its upsample flag set.
            stage.out = base_function.ToRGB(target, input_nc, idx != 0, activation, norm)
            stage.upsample = base_function.UpSample(target, rample_with_conv, kernel_size=3)
            self.up.append(stage)
            width = target

        self.decode = base_function.ToRGB(width, input_nc, True, activation, norm)

    def forward(self, x, mask=None):
        """Run the generator.

        Parameters:
            x    -- input passed to ``self.encode``
            mask -- optional; when given it is bilinearly resized to the feature
                    size at the attention stage and handed to the attention layer

        Returns:
            the output of the final ``ToRGB`` head
        """
        feat = self.encode(x)

        # Feature taken right after the third down-sampling; later passed to attention.
        pre = None
        for level in range(self.down_layers):
            stage = self.down[level]
            feat = stage.downsample(feat)
            if level == 2:
                pre = feat
            for k in range(self.num_res_blocks):
                feat = stage.block[k](feat)

        for k in range(self.mid_layers):
            feat = self.mid[k](feat)

        attn_level = self.down_layers - 3
        skip = None
        for level in range(self.down_layers):
            stage = self.up[level]
            for k in range(self.num_res_blocks):
                feat = stage.block[k](feat)
            if level == attn_level:
                if mask is not None:
                    mask = F.interpolate(mask, size=feat.size()[2:], mode='bilinear', align_corners=True)
                feat = stage.attn(feat, pre=pre, mask=mask)
            # Each RGB head receives the previous head's output as ``skip``.
            skip = stage.out(feat, skip)
            feat = stage.upsample(feat)

        return self.decode(feat, skip)