import collections
import functools
import logging
from collections import defaultdict
from functools import partial

import numpy as np
import torch.nn as nn

from saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
from saicinpainting.training.modules.ffc import FFCResnetBlock
from saicinpainting.training.modules.multidilated_conv import MultidilatedConv


class DotDict(defaultdict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = defaultdict.get
    __setattr__ = defaultdict.__setitem__
    __delattr__ = defaultdict.__delitem__
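

# Illustrative usage of DotDict (a sketch, not part of the original training code).
# Because __getattr__ is bound to defaultdict.get, missing keys read as attributes
# resolve to None instead of raising AttributeError:
#
#   spec = DotDict(lambda: None, {'n_blocks': 3, 'use_default': True})
#   spec.n_blocks        # -> 3
#   spec.resnet_dilation # -> None (key absent)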


class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class ResnetBlock(nn.Module):
    """Residual block with two 3x3 convolutions.

    Supports reflect/replicate/zero padding, per-convolution dilation, a grouped
    second convolution and an optional 1x1 projection of the skip connection when
    the input has ``in_dim`` channels instead of ``dim``.
    """

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
                 dilation=1, in_dim=None, groups=1, second_dilation=None):
        super(ResnetBlock, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        if second_dilation is None:
            second_dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
                                                second_dilation=second_dilation)

        if self.in_dim is not None:
            self.input_conv = nn.Conv2d(in_dim, dim, 1)

        self.out_channnels = dim

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
                         dilation=1, in_dim=None, groups=1, second_dilation=1):
        conv_layer = get_conv_block_ctor(conv_kind)

        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(dilation)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(dilation)]
        elif padding_type == 'zero':
            p = dilation
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        if in_dim is None:
            in_dim = dim

        conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(second_dilation)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(second_dilation)]
        elif padding_type == 'zero':
            p = second_dilation
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        x_before = x
        if self.in_dim is not None:
            x = self.input_conv(x)
        out = x + self.conv_block(x_before)
        return out
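

# Illustrative usage of ResnetBlock (a sketch; assumes `torch` is importable in the
# calling code). The block is shape-preserving, so the residual addition is valid
# for any spatial size:
#
#   block = ResnetBlock(64, padding_type='reflect', norm_layer=nn.BatchNorm2d)
#   y = block(torch.randn(1, 64, 32, 32))   # y.shape == (1, 64, 32, 32)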


class ResnetBlock5x5(nn.Module):
    """Variant of ResnetBlock that uses 5x5 convolutions (padding = 2 * dilation)."""

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
                 dilation=1, in_dim=None, groups=1, second_dilation=None):
        super(ResnetBlock5x5, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        if second_dilation is None:
            second_dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
                                                second_dilation=second_dilation)

        if self.in_dim is not None:
            self.input_conv = nn.Conv2d(in_dim, dim, 1)

        self.out_channnels = dim

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
                         dilation=1, in_dim=None, groups=1, second_dilation=1):
        conv_layer = get_conv_block_ctor(conv_kind)

        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(dilation * 2)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(dilation * 2)]
        elif padding_type == 'zero':
            p = dilation * 2
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        if in_dim is None:
            in_dim = dim

        conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(second_dilation * 2)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(second_dilation * 2)]
        elif padding_type == 'zero':
            p = second_dilation * 2
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        x_before = x
        if self.in_dim is not None:
            x = self.input_conv(x)
        out = x + self.conv_block(x_before)
        return out


class MultidilatedResnetBlock(nn.Module):
    """Residual block built from two multi-dilated convolutions.

    ``conv_layer`` is expected to behave like MultidilatedConv, i.e. accept a
    ``padding_mode`` keyword and preserve the spatial size for kernel_size=3.
    """

    def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super().__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
        conv_block = []
        conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
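

# Illustrative construction (a sketch): conv_layer is typically a partially-applied
# multi-dilated convolution constructor, mirroring how the generators below build it:
#
#   md_conv = functools.partial(get_conv_block_ctor('multidilated'))
#   block = MultidilatedResnetBlock(512, padding_type='reflect',
#                                   conv_layer=md_conv, norm_layer=nn.BatchNorm2d)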


class MultiDilatedGlobalGenerator(nn.Module):
    """pix2pixHD-style encoder / bottleneck / decoder generator whose bottleneck is
    built from MultidilatedResnetBlock, optionally interleaved with FFCResnetBlock
    at the indices listed in ``ffc_positions``.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default',
                 deconv_kind='convtranspose', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
                 add_out_act=True, max_features=1024, multidilation_kwargs={},
                 ffc_positions=None, ffc_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)
        resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        identity = Identity()

        for i in range(n_downsampling):
            mult = 2 ** i

            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        for i in range(n_blocks):
            if ffc_positions is not None and i in ffc_positions:
                model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
                                         inline=True, **ffc_kwargs)]
            model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                              conv_layer=resnet_conv_layer, activation=activation,
                                              norm_layer=norm_layer)]

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
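

# Illustrative usage (a sketch). With the default n_downsampling=3 the spatial size
# should be divisible by 2 ** 3 = 8 so the decoder restores it exactly, and
# add_out_act=True appends a tanh on the output:
#
#   gen = MultiDilatedGlobalGenerator(input_nc=4, output_nc=3)
#   out = gen(torch.randn(1, 4, 256, 256))   # out.shape == (1, 3, 256, 256)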


class ConfigGlobalGenerator(nn.Module):
    """Generator whose bottleneck is assembled from ``manual_block_spec`` entries,
    allowing a different resnet block kind, conv kind and dilation per group of blocks.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default',
                 deconv_kind='convtranspose', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
                 add_out_act=True, max_features=1024,
                 manual_block_spec=[],
                 resnet_block_kind='multidilatedresnetblock',
                 resnet_conv_kind='multidilated',
                 resnet_dilation=1,
                 multidilation_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)
        resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        identity = Identity()

        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        if len(manual_block_spec) == 0:
            manual_block_spec = [
                DotDict(lambda: None, {
                    'n_blocks': n_blocks,
                    'use_default': True})
            ]

        for block_spec in manual_block_spec:
            def make_and_add_blocks(model, block_spec):
                # fall back to the constructor-level settings unless the spec overrides them;
                # without nonlocal these names would be unbound locals on the default path
                nonlocal resnet_conv_layer, resnet_conv_kind, resnet_block_kind, resnet_dilation
                block_spec = DotDict(lambda: None, block_spec)
                if not block_spec.use_default:
                    resnet_conv_layer = functools.partial(get_conv_block_ctor(block_spec.resnet_conv_kind),
                                                          **block_spec.multidilation_kwargs)
                    resnet_conv_kind = block_spec.resnet_conv_kind
                    resnet_block_kind = block_spec.resnet_block_kind
                    if block_spec.resnet_dilation is not None:
                        resnet_dilation = block_spec.resnet_dilation
                for i in range(block_spec.n_blocks):
                    if resnet_block_kind == "multidilatedresnetblock":
                        model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                                          conv_layer=resnet_conv_layer, activation=activation,
                                                          norm_layer=norm_layer)]
                    if resnet_block_kind == "resnetblock":
                        model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
                                              conv_kind=resnet_conv_kind)]
                    if resnet_block_kind == "resnetblock5x5":
                        model += [ResnetBlock5x5(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
                                                 conv_kind=resnet_conv_kind)]
                    if resnet_block_kind == "resnetblockdwdil":
                        model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
                                              conv_kind=resnet_conv_kind, dilation=resnet_dilation, second_dilation=resnet_dilation)]
            make_and_add_blocks(model, block_spec)

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
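

# Illustrative manual_block_spec (a sketch; the concrete values are hypothetical).
# Each entry describes one group of bottleneck blocks; entries with use_default=True
# keep the constructor-level settings, while overriding entries must also provide
# multidilation_kwargs (an empty dict is enough for non-multidilated convs):
#
#   spec = [
#       {'n_blocks': 2, 'use_default': True},
#       {'n_blocks': 4, 'use_default': False, 'resnet_block_kind': 'resnetblock',
#        'resnet_conv_kind': 'default', 'resnet_dilation': None, 'multidilation_kwargs': {}},
#   ]
#   gen = ConfigGlobalGenerator(input_nc=4, output_nc=3, manual_block_spec=spec)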


def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
    blocks = []
    for i in range(dilated_blocks_n):
        if dilation_block_kind == 'simple':
            blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1)))
        elif dilation_block_kind == 'multi':
            blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs))
        else:
            raise ValueError(f'Unknown dilation_block_kind "{dilation_block_kind}"')
    return blocks


class GlobalGenerator(nn.Module):
    """Classic pix2pixHD global generator with optional dilated and FFC residual
    blocks in the bottleneck.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None,
                 up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
                 dilated_blocks_n_middle=0,
                 add_out_act=True,
                 max_features=1024, is_resblock_depthwise=False,
                 ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None,
                 dilation_block_kind='simple', multidilation_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        if ffc_positions is not None:
            ffc_positions = collections.Counter(ffc_positions)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        identity = Identity()

        for i in range(n_downsampling):
            mult = 2 ** i

            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
                                    activation=activation, norm_layer=norm_layer)
        if dilation_block_kind == 'simple':
            dilated_block_kwargs['conv_kind'] = conv_kind
        elif dilation_block_kind == 'multi':
            dilated_block_kwargs['conv_layer'] = functools.partial(
                get_conv_block_ctor('multidilated'), **multidilation_kwargs)

        if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
            model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)

        for i in range(n_blocks):
            if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
                model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)

            if ffc_positions is not None and i in ffc_positions:
                for _ in range(ffc_positions[i]):
                    model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
                                             inline=True, **ffc_kwargs)]

            if is_resblock_depthwise:
                resblock_groups = feats_num_bottleneck
            else:
                resblock_groups = 1

            model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                  norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
                                  dilation=dilation, second_dilation=second_dilation)]

        if dilated_blocks_n is not None and dilated_blocks_n > 0:
            model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
                                         min(max_features, int(ngf * mult / 2)),
                                         kernel_size=3, stride=2, padding=1, output_padding=1),
                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
                      up_activation]
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
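

# Illustrative usage (a sketch; the 4-channel input mirrors an image concatenated
# with a binary mask, which is how inpainting generators are typically fed):
#
#   gen = GlobalGenerator(input_nc=4, output_nc=3, n_blocks=9)
#   out = gen(torch.randn(1, 4, 256, 256))   # out.shape == (1, 3, 256, 256)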


class GlobalGeneratorGated(GlobalGenerator):
    def __init__(self, *args, **kwargs):
        real_kwargs = dict(
            conv_kind='gated_bn_relu',
            activation=nn.Identity(),
            norm_layer=nn.Identity
        )
        real_kwargs.update(kwargs)
        super().__init__(*args, **real_kwargs)


class GlobalGeneratorFromSuperChannels(nn.Module):
    """Generator whose per-stage channel widths are taken from a super-network
    channel specification (see ``convert_super_channels``).
    """

    def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True):
        super().__init__()
        self.n_downsampling = n_downsampling
        norm_layer = get_norm_layer(norm_layer)
        if type(norm_layer) == functools.partial:
            use_bias = (norm_layer.func == nn.InstanceNorm2d)
        else:
            use_bias = (norm_layer == nn.InstanceNorm2d)

        channels = self.convert_super_channels(super_channels)
        self.channels = channels

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(channels[0]),
                 nn.ReLU(True)]

        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(channels[0 + i], channels[1 + i], kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(channels[1 + i]),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling

        n_blocks1 = n_blocks // 3
        n_blocks2 = n_blocks1
        n_blocks3 = n_blocks - n_blocks1 - n_blocks2

        for i in range(n_blocks1):
            c = n_downsampling
            dim = channels[c]
            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)]

        for i in range(n_blocks2):
            c = n_downsampling + 1
            dim = channels[c]
            kwargs = {}
            if i == 0:
                kwargs = {"in_dim": channels[c - 1]}
            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]

        for i in range(n_blocks3):
            c = n_downsampling + 2
            dim = channels[c]
            kwargs = {}
            if i == 0:
                kwargs = {"in_dim": channels[c - 1]}
            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(channels[n_downsampling + 3 + i],
                                         channels[n_downsampling + 3 + i + 1],
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(channels[n_downsampling + 3 + i + 1]),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(channels[2 * n_downsampling + 3], output_nc, kernel_size=7, padding=0)]

        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def convert_super_channels(self, super_channels):
        n_downsampling = self.n_downsampling
        result = []
        cnt = 0

        if n_downsampling == 2:
            N1 = 10
        elif n_downsampling == 3:
            N1 = 13
        else:
            raise NotImplementedError

        for i in range(0, N1):
            if i in [1, 4, 7, 10]:
                channel = super_channels[cnt] * (2 ** cnt)
                config = {'channel': channel}
                result.append(channel)
                logging.info(f"Downsample channels {result[-1]}")
                cnt += 1

        for i in range(3):
            for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)):
                if len(super_channels) == 6:
                    channel = super_channels[3] * 4
                else:
                    channel = super_channels[i + 3] * 4
                config = {'channel': channel}
                if counter == 0:
                    result.append(channel)
                    logging.info(f"Bottleneck channels {result[-1]}")
        cnt = 2

        for i in range(N1 + 9, N1 + 21):
            if i in [22, 25, 28]:
                cnt -= 1
                if len(super_channels) == 6:
                    channel = super_channels[5 - cnt] * (2 ** cnt)
                else:
                    channel = super_channels[7 - cnt] * (2 ** cnt)
                result.append(int(channel))
                logging.info(f"Upsample channels {result[-1]}")
        return result

    def forward(self, input):
        return self.model(input)


class NLayerDiscriminator(BaseDiscriminator):
    """PatchGAN discriminator. ``forward`` returns the final patch logits together
    with the list of intermediate activations (useful e.g. for feature matching).
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)

            cur_model = []
            cur_model += [
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True)
            ]
            sequence.append(cur_model)

        nf_prev = nf
        nf = min(nf * 2, 512)

        cur_model = []
        cur_model += [
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]
        sequence.append(cur_model)

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        for n in range(len(sequence)):
            setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))

    def get_all_activations(self, x):
        res = [x]
        for n in range(self.n_layers + 2):
            model = getattr(self, 'model' + str(n))
            res.append(model(res[-1]))
        return res[1:]

    def forward(self, x):
        act = self.get_all_activations(x)
        return act[-1], act[:-1]
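

# Illustrative usage (a sketch): the discriminator is fully convolutional, so the
# logits form a patch map whose spatial size depends on the input resolution:
#
#   disc = NLayerDiscriminator(input_nc=3)
#   logits, feats = disc(torch.randn(1, 3, 256, 256))
#   # logits: (1, 1, h, w) patch predictions; feats: list of intermediate activations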


class MultidilatedNLayerDiscriminator(BaseDiscriminator):
    """NLayerDiscriminator variant whose strided downsampling convolutions are
    MultidilatedConv layers.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}):
        super().__init__()
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)

            cur_model = []
            cur_model += [
                MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True)
            ]
            sequence.append(cur_model)

        nf_prev = nf
        nf = min(nf * 2, 512)

        cur_model = []
        cur_model += [
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]
        sequence.append(cur_model)

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        for n in range(len(sequence)):
            setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))

    def get_all_activations(self, x):
        res = [x]
        for n in range(self.n_layers + 2):
            model = getattr(self, 'model' + str(n))
            res.append(model(res[-1]))
        return res[1:]

    def forward(self, x):
        act = self.get_all_activations(x)
        return act[-1], act[:-1]


class NLayerDiscriminatorAsGen(NLayerDiscriminator):
    """Wrapper that returns only the patch logits, so the discriminator can be used
    where a single-output module is expected.
    """

    def forward(self, x):
        return super().forward(x)[0]
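

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the training
    # pipeline): build a small generator and discriminator with their default
    # settings and check that the tensor shapes line up.
    import torch

    gen = GlobalGenerator(input_nc=4, output_nc=3, ngf=32, n_blocks=2)
    disc = NLayerDiscriminator(input_nc=3, ndf=32)

    x = torch.randn(1, 4, 64, 64)
    fake = gen(x)                # (1, 3, 64, 64), tanh-activated by default
    logits, feats = disc(fake)   # patch logits plus intermediate activations
    print(fake.shape, logits.shape, len(feats))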