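# NOTE: web-page residue from the Hugging Face file viewer, kept as a comment
# so that the module parses:
#   uploader avatar: fffiloni; commit message: "Duplicate from fffiloni/lama";
#   commit 24eb05d; viewer links: raw / history / blame; file size: 28 kB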
# original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py
import collections
from functools import partial
import functools
import logging
from collections import defaultdict
import numpy as np
import torch.nn as nn
from saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
from saicinpainting.training.modules.ffc import FFCResnetBlock
from saicinpainting.training.modules.multidilated_conv import MultidilatedConv
class DotDict(defaultdict):
    """A ``defaultdict`` whose items can also be read, set and deleted as attributes.

    Recipe adapted from
    https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.  Uses ``get`` so a
        # missing key yields None and does not invoke ``default_factory``.
        return defaultdict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment stores an item instead of an instance attribute.
        defaultdict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Attribute deletion removes the item; a missing key raises KeyError.
        defaultdict.__delitem__(self, name)
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands back its argument unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        # No computation and no copy: the very same object is returned.
        return x
class ResnetBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    ``forward`` returns ``skip(x) + conv_block(x)``, where ``skip`` is a 1x1
    convolution when ``in_dim`` is given and the identity otherwise.

    Args:
        dim: output channels of both convolutions.
        padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
        norm_layer: callable; ``norm_layer(dim)`` must return a module.
        activation: module placed after the first normalisation.  The default
            instance is created once and shared by every block that uses it.
        use_dropout: insert ``nn.Dropout(0.5)`` between the two convolutions.
        conv_kind: key handed to ``get_conv_block_ctor``.
        dilation: dilation (and padding amount) of the first convolution.
        in_dim: input channels when different from ``dim``; enables the 1x1
            projection on the skip path.
        groups: ``groups`` argument of the second convolution only.
        second_dilation: dilation of the second convolution; falls back to
            ``dilation`` when None.

    Raises:
        NotImplementedError: if ``padding_type`` is not one of the three
            supported values.
    """

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
                 dilation=1, in_dim=None, groups=1, second_dilation=None):
        super(ResnetBlock, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        if second_dilation is None:
            second_dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
                                                second_dilation=second_dilation)

        if self.in_dim is not None:
            self.input_conv = nn.Conv2d(in_dim, dim, 1)

        self.out_channels = dim
        # Misspelled name kept so code reading the old attribute keeps working.
        self.out_channnels = dim

    @staticmethod
    def _padding(padding_type, amount):
        """Return ``(explicit_pad_layers, conv_padding)`` for one convolution."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(amount)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(amount)], 0
        if padding_type == 'zero':
            # Zero padding is delegated to the convolution itself.
            return [], amount
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
                         dilation=1, in_dim=None, groups=1, second_dilation=1):
        """Assemble pad -> conv -> norm -> activation [-> dropout] -> pad -> conv -> norm."""
        conv_layer = get_conv_block_ctor(conv_kind)

        # A 3x3 kernel with dilation d needs d pixels of padding to keep the size.
        conv_block, p = self._padding(padding_type, dilation)
        if in_dim is None:
            in_dim = dim

        conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        pad_layers, p = self._padding(padding_type, second_dilation)
        conv_block += pad_layers
        conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups),
                       norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        x_before = x
        if self.in_dim is not None:
            # Project the skip path to ``dim`` channels; the conv branch still
            # consumes the unprojected input.
            x = self.input_conv(x)
        out = x + self.conv_block(x_before)
        return out
class ResnetBlock5x5(nn.Module):
    """Residual block made of two 5x5 convolutions.

    ``forward`` returns ``skip(x) + conv_block(x)``, where ``skip`` is a 1x1
    convolution when ``in_dim`` is given and the identity otherwise.

    Args:
        dim: output channels of both convolutions.
        padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
        norm_layer: callable; ``norm_layer(dim)`` must return a module.
        activation: module placed after the first normalisation.  The default
            instance is created once and shared by every block that uses it.
        use_dropout: insert ``nn.Dropout(0.5)`` between the two convolutions.
        conv_kind: key handed to ``get_conv_block_ctor``.
        dilation: dilation of the first convolution (padding is twice that).
        in_dim: input channels when different from ``dim``; enables the 1x1
            projection on the skip path.
        groups: ``groups`` argument of the second convolution only.
        second_dilation: dilation of the second convolution; falls back to
            ``dilation`` when None.

    Raises:
        NotImplementedError: if ``padding_type`` is not one of the three
            supported values.
    """

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
                 dilation=1, in_dim=None, groups=1, second_dilation=None):
        super(ResnetBlock5x5, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        if second_dilation is None:
            second_dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
                                                second_dilation=second_dilation)

        if self.in_dim is not None:
            self.input_conv = nn.Conv2d(in_dim, dim, 1)

        self.out_channels = dim
        # Misspelled name kept so code reading the old attribute keeps working.
        self.out_channnels = dim

    @staticmethod
    def _padding(padding_type, amount):
        """Return ``(explicit_pad_layers, conv_padding)`` for one convolution."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(amount)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(amount)], 0
        if padding_type == 'zero':
            # Zero padding is delegated to the convolution itself.
            return [], amount
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
                         dilation=1, in_dim=None, groups=1, second_dilation=1):
        """Assemble pad -> conv -> norm -> activation [-> dropout] -> pad -> conv -> norm."""
        conv_layer = get_conv_block_ctor(conv_kind)

        # A 5x5 kernel with dilation d needs 2*d pixels of padding to keep the size.
        conv_block, p = self._padding(padding_type, dilation * 2)
        if in_dim is None:
            in_dim = dim

        conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        pad_layers, p = self._padding(padding_type, second_dilation * 2)
        conv_block += pad_layers
        conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups),
                       norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        x_before = x
        if self.in_dim is not None:
            # Project the skip path to ``dim`` channels; the conv branch still
            # consumes the unprojected input.
            x = self.input_conv(x)
        out = x + self.conv_block(x_before)
        return out
class MultidilatedResnetBlock(nn.Module):
    """Residual block whose two 3x3 convolutions come from a caller-supplied constructor.

    ``conv_layer`` is called as
    ``conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type)`` and
    ``norm_layer`` as ``norm_layer(dim)``.  ``forward`` returns
    ``x + conv_block(x)``.
    """

    def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super().__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
        """Return conv -> norm -> activation [-> dropout] -> conv -> norm as ``nn.Sequential``.

        ``dilation`` is accepted but not used by this method.
        """
        def new_conv():
            return conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type)

        # Layers are created in the order they appear in the sequence.
        first_half = [new_conv(), norm_layer(dim), activation]
        dropout = [nn.Dropout(0.5)] if use_dropout else []
        second_half = [new_conv(), norm_layer(dim)]
        return nn.Sequential(*first_half, *dropout, *second_half)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
class MultiDilatedGlobalGenerator(nn.Module):
    """Encoder / bottleneck / decoder generator with multidilated residual blocks.

    Layout of ``self.model``: 7x7 stem, ``n_downsampling`` stride-2
    convolutions, ``n_blocks`` ``MultidilatedResnetBlock`` modules (each
    preceded by an ``FFCResnetBlock`` when its index is in ``ffc_positions``),
    ``n_downsampling`` up-sampling stages from ``deconv_factory``, a 7x7 output
    convolution and an optional output activation.

    Args:
        input_nc / output_nc: channels of the input and of the output.
        ngf: channels after the stem; doubled per down-sampling stage and
            capped at ``max_features``.
        norm_layer / up_norm_layer: resolved through ``get_norm_layer``;
            wrapped with ``affine=affine`` when ``affine`` is not None.
        padding_type: forwarded to the residual blocks (the stem and output
            convolution always use reflection padding).
        conv_kind: key for ``get_conv_block_ctor`` used by stem/down-sampling.
        deconv_kind: first argument of ``deconv_factory``.
        add_out_act: ``True`` selects ``'tanh'``; any other truthy value is
            passed to ``get_activation``; a falsy value disables it.
        multidilation_kwargs: extra keyword arguments for the ``'multidilated'``
            convolution constructor; ``None`` means none.
        ffc_positions: container of block indices that get an FFC block.
        ffc_kwargs: extra keyword arguments for ``FFCResnetBlock``; ``None``
            means none.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default',
                 deconv_kind='convtranspose', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
                 add_out_act=True, max_features=1024, multidilation_kwargs=None,
                 ffc_positions=None, ffc_kwargs=None):
        assert (n_blocks >= 0)
        super().__init__()

        # ``None`` stands in for "no extra kwargs" so that no dict object is
        # shared between calls through the signature defaults.
        if multidilation_kwargs is None:
            multidilation_kwargs = {}
        if ffc_kwargs is None:
            ffc_kwargs = {}

        conv_layer = get_conv_block_ctor(conv_kind)
        resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        ### resnet blocks
        for i in range(n_blocks):
            if ffc_positions is not None and i in ffc_positions:
                model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
                                         inline=True, **ffc_kwargs)]
            model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                              conv_layer=resnet_conv_layer, activation=activation,
                                              norm_layer=norm_layer)]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
class ConfigGlobalGenerator(nn.Module):
    """Generator whose bottleneck is described by a list of block specs.

    Layout of ``self.model``: 7x7 stem, ``n_downsampling`` stride-2
    convolutions, the residual blocks requested by ``manual_block_spec``,
    ``n_downsampling`` up-sampling stages from ``deconv_factory``, a 7x7 output
    convolution and an optional output activation.

    Args:
        input_nc / output_nc: channels of the input and of the output.
        ngf: channels after the stem; doubled per down-sampling stage and
            capped at ``max_features``.
        n_blocks: number of blocks of the implicit spec used when
            ``manual_block_spec`` is empty.
        norm_layer / up_norm_layer: resolved through ``get_norm_layer``;
            wrapped with ``affine=affine`` when ``affine`` is not None.
        add_out_act: ``True`` selects ``'tanh'``; any other truthy value is
            passed to ``get_activation``; a falsy value disables it.
        manual_block_spec: sequence of mappings.  Keys read from each spec:
            ``n_blocks``, ``use_default``, ``resnet_conv_kind``,
            ``resnet_block_kind``, ``resnet_dilation``, ``multidilation_kwargs``.
            With a truthy ``use_default`` the constructor-level
            ``resnet_*`` / ``multidilation_kwargs`` arguments are used.
        resnet_block_kind: ``'multidilatedresnetblock'``, ``'resnetblock'``,
            ``'resnetblock5x5'`` or ``'resnetblockdwdil'``; any other value
            adds no block.
        resnet_conv_kind: key for ``get_conv_block_ctor`` used inside blocks.
        resnet_dilation: dilation used by ``'resnetblockdwdil'`` blocks.
        multidilation_kwargs: extra keyword arguments bound to the block
            convolution constructor.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default',
                 deconv_kind='convtranspose', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
                 add_out_act=True, max_features=1024,
                 manual_block_spec=[],
                 resnet_block_kind='multidilatedresnetblock',
                 resnet_conv_kind='multidilated',
                 resnet_dilation=1,
                 multidilation_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()

        conv_layer = get_conv_block_ctor(conv_kind)
        resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        if not manual_block_spec:
            # No explicit spec: one group of ``n_blocks`` default blocks.
            manual_block_spec = [
                DotDict(lambda: None, {
                    'n_blocks': n_blocks,
                    'use_default': True})
            ]

        ### resnet blocks
        for raw_spec in manual_block_spec:
            spec = DotDict(lambda: None, raw_spec)

            # Every per-spec setting starts from the constructor-level value,
            # so a ``use_default`` spec (or one without ``resnet_dilation``)
            # always has all four settings bound.
            block_conv_layer = resnet_conv_layer
            block_conv_kind = resnet_conv_kind
            block_kind = resnet_block_kind
            block_dilation = resnet_dilation
            if not spec.use_default:
                block_conv_layer = functools.partial(get_conv_block_ctor(spec.resnet_conv_kind),
                                                     **(spec.multidilation_kwargs or {}))
                block_conv_kind = spec.resnet_conv_kind
                block_kind = spec.resnet_block_kind
                if spec.resnet_dilation is not None:
                    block_dilation = spec.resnet_dilation

            for _ in range(spec.n_blocks):
                # All block kinds use the capped bottleneck width so they match
                # the output of the down-sampling path.
                if block_kind == "multidilatedresnetblock":
                    model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
                                                      conv_layer=block_conv_layer, activation=activation,
                                                      norm_layer=norm_layer)]
                elif block_kind == "resnetblock":
                    model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                          norm_layer=norm_layer, conv_kind=block_conv_kind)]
                elif block_kind == "resnetblock5x5":
                    model += [ResnetBlock5x5(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                             norm_layer=norm_layer, conv_kind=block_conv_kind)]
                elif block_kind == "resnetblockdwdil":
                    model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                          norm_layer=norm_layer, conv_kind=block_conv_kind,
                                          dilation=block_dilation, second_dilation=block_dilation)]
                # NOTE(review): any other kind adds nothing, silently -- confirm
                # whether that should be an error.

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
    """Return a list of ``dilated_blocks_n`` residual blocks.

    ``'simple'`` builds ``ResnetBlock`` instances whose dilation doubles with
    each block (2, 4, 8, ...); ``'multi'`` builds ``MultidilatedResnetBlock``
    instances.  ``dilated_block_kwargs`` is forwarded to the block
    constructor.  The kind is only checked while building a block, so an
    unknown kind raises ``ValueError`` only when ``dilated_blocks_n > 0``.
    """
    def build(index):
        if dilation_block_kind == 'simple':
            return ResnetBlock(**dilated_block_kwargs, dilation=2 ** (index + 1))
        if dilation_block_kind == 'multi':
            return MultidilatedResnetBlock(**dilated_block_kwargs)
        raise ValueError(f'dilation_block_kind could not be "{dilation_block_kind}"')

    return [build(index) for index in range(dilated_blocks_n)]
class GlobalGenerator(nn.Module):
    """Encoder / residual bottleneck / transposed-conv decoder generator.

    Layout of ``self.model``: 7x7 stem, ``n_downsampling`` stride-2
    convolutions, optional dilated blocks, ``n_blocks`` ``ResnetBlock``
    modules (with optional dilated blocks in the middle and FFC blocks at the
    requested positions), optional dilated blocks, ``n_downsampling``
    ``nn.ConvTranspose2d`` stages, a 7x7 output convolution and an optional
    output activation.

    Args:
        input_nc / output_nc: channels of the input and of the output.
        ngf: channels after the stem; doubled per down-sampling stage and
            capped at ``max_features``.
        norm_layer / up_norm_layer: resolved through ``get_norm_layer``;
            wrapped with ``affine=affine`` when ``affine`` is not None.
        conv_kind: key for ``get_conv_block_ctor``.
        dilated_blocks_n_start / dilated_blocks_n_middle / dilated_blocks_n:
            number of blocks from ``make_dil_blocks`` placed before, in the
            middle of (before block ``n_blocks // 2``) and after the plain
            residual blocks; ``None`` or ``0`` disables a group.
        add_out_act: ``True`` selects ``'tanh'``; any other truthy value is
            passed to ``get_activation``; a falsy value disables it.
        is_resblock_depthwise: use ``groups == channels`` in the residual
            blocks instead of ``groups == 1``.
        ffc_positions: iterable of block indices; an index listed k times gets
            k ``FFCResnetBlock`` modules in front of that residual block.
        ffc_kwargs: extra keyword arguments for ``FFCResnetBlock``; ``None``
            means none.
        dilation / second_dilation: forwarded to the plain residual blocks.
        dilation_block_kind: ``'simple'`` or ``'multi'`` (see ``make_dil_blocks``).
        multidilation_kwargs: extra keyword arguments for the ``'multidilated'``
            convolution constructor used by ``'multi'`` blocks; ``None`` means
            none.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
                 up_norm_layer=nn.BatchNorm2d, affine=None,
                 up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
                 dilated_blocks_n_middle=0,
                 add_out_act=True,
                 max_features=1024, is_resblock_depthwise=False,
                 ffc_positions=None, ffc_kwargs=None, dilation=1, second_dilation=None,
                 dilation_block_kind='simple', multidilation_kwargs=None):
        assert (n_blocks >= 0)
        super().__init__()

        # ``None`` stands in for "no extra kwargs" so that no dict object is
        # shared between calls through the signature defaults.
        if ffc_kwargs is None:
            ffc_kwargs = {}
        if multidilation_kwargs is None:
            multidilation_kwargs = {}

        conv_layer = get_conv_block_ctor(conv_kind)
        norm_layer = get_norm_layer(norm_layer)
        if affine is not None:
            norm_layer = partial(norm_layer, affine=affine)
        up_norm_layer = get_norm_layer(up_norm_layer)
        if affine is not None:
            up_norm_layer = partial(up_norm_layer, affine=affine)

        if ffc_positions is not None:
            # Counter keeps how many times each position was requested.
            ffc_positions = collections.Counter(ffc_positions)

        model = [nn.ReflectionPad2d(3),
                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 activation]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [conv_layer(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1),
                      norm_layer(min(max_features, ngf * mult * 2)),
                      activation]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
                                    activation=activation, norm_layer=norm_layer)
        if dilation_block_kind == 'simple':
            dilated_block_kwargs['conv_kind'] = conv_kind
        elif dilation_block_kind == 'multi':
            dilated_block_kwargs['conv_layer'] = functools.partial(
                get_conv_block_ctor('multidilated'), **multidilation_kwargs)

        # dilated blocks at the start of the bottleneck sausage
        if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
            model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)

        # The group count does not depend on the block index.
        resblock_groups = feats_num_bottleneck if is_resblock_depthwise else 1

        # resnet blocks
        for i in range(n_blocks):
            # dilated blocks at the middle of the bottleneck sausage
            if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
                model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)

            if ffc_positions is not None and i in ffc_positions:
                for _ in range(ffc_positions[i]):  # same position can occur more than once
                    model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
                                             inline=True, **ffc_kwargs)]

            model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
                                  norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
                                  dilation=dilation, second_dilation=second_dilation)]

        # dilated blocks at the end of the bottleneck sausage
        if dilated_blocks_n is not None and dilated_blocks_n > 0:
            model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)

        # upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            in_features = min(max_features, ngf * mult)
            out_features = min(max_features, int(ngf * mult / 2))
            model += [nn.ConvTranspose2d(in_features, out_features,
                                         kernel_size=3, stride=2, padding=1, output_padding=1),
                      up_norm_layer(out_features),
                      up_activation]
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
class GlobalGeneratorGated(GlobalGenerator):
    """``GlobalGenerator`` preset with gated-convolution defaults.

    Unless the caller overrides them, ``conv_kind`` is ``'gated_bn_relu'`` and
    both the activation and the normalisation are identity.
    """

    def __init__(self, *args, **kwargs):
        gated_defaults = {
            'conv_kind': 'gated_bn_relu',
            'activation': nn.Identity(),
            'norm_layer': nn.Identity,
        }
        # Caller-supplied keyword arguments win over the preset.
        merged = {**gated_defaults, **kwargs}
        super().__init__(*args, **merged)
class GlobalGeneratorFromSuperChannels(nn.Module):
    """Generator whose per-stage widths are derived from a ``super_channels`` list.

    ``convert_super_channels`` turns ``super_channels`` into the flat list
    ``self.channels`` (down-sampling widths, three bottleneck widths, three
    up-sampling widths).  The network is: 7x7 stem, ``n_downsampling`` stride-2
    convolutions, ``n_blocks`` ``ResnetBlock`` modules split into three groups,
    ``n_downsampling`` transposed convolutions, 7x7 output convolution and an
    optional output activation.

    Args:
        input_nc / output_nc: channels of the input and of the output.
        n_downsampling: 2 or 3 (anything else raises ``NotImplementedError``).
        n_blocks: total number of residual blocks; groups one and two get
            ``n_blocks // 3`` each, group three gets the remainder.
        super_channels: sequence of numbers indexed by ``convert_super_channels``.
        norm_layer: resolved through ``get_norm_layer``.
        padding_type: forwarded to the residual blocks.
        add_out_act: ``True`` selects ``'tanh'``; any other truthy value is
            passed to ``get_activation``; a falsy value disables it.
    """

    def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True):
        super().__init__()
        self.n_downsampling = n_downsampling
        norm_layer = get_norm_layer(norm_layer)
        # Convolutions get a bias only when the norm is InstanceNorm2d.
        if isinstance(norm_layer, functools.partial):
            use_bias = (norm_layer.func == nn.InstanceNorm2d)
        else:
            use_bias = (norm_layer == nn.InstanceNorm2d)

        channels = self.convert_super_channels(super_channels)
        self.channels = channels

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(channels[0]),
                 nn.ReLU(True)]

        for i in range(n_downsampling):  # add downsampling layers
            model += [nn.Conv2d(channels[i], channels[i + 1], kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(channels[i + 1]),
                      nn.ReLU(True)]

        n_blocks1 = n_blocks // 3
        n_blocks2 = n_blocks1
        n_blocks3 = n_blocks - n_blocks1 - n_blocks2

        # Group k uses channels[n_downsampling + k]; the first block of groups
        # two and three takes the previous group's width as ``in_dim``.
        for group, group_size in enumerate((n_blocks1, n_blocks2, n_blocks3)):
            c = n_downsampling + group
            for i in range(group_size):
                kwargs = {}
                if group > 0 and i == 0:
                    kwargs = {"in_dim": channels[c - 1]}
                model += [ResnetBlock(channels[c], padding_type=padding_type, norm_layer=norm_layer, **kwargs)]

        for i in range(n_downsampling):  # add upsampling layers
            src = n_downsampling + 3 + i
            model += [nn.ConvTranspose2d(channels[src],
                                         channels[src + 1],
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(channels[src + 1]),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(channels[2 * n_downsampling + 3], output_nc, kernel_size=7, padding=0)]

        if add_out_act:
            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
        self.model = nn.Sequential(*model)

    def convert_super_channels(self, super_channels):
        """Expand ``super_channels`` into the flat per-stage channel list.

        Returns down-sampling widths, then three bottleneck widths, then three
        up-sampling widths (the latter truncated with ``int``).

        Raises:
            NotImplementedError: if ``self.n_downsampling`` is not 2 or 3.
        """
        n_downsampling = self.n_downsampling
        result = []
        cnt = 0

        if n_downsampling == 2:
            N1 = 10
        elif n_downsampling == 3:
            N1 = 13
        else:
            raise NotImplementedError

        # NOTE(review): the hard-coded indices below presumably are layer
        # positions in the source super-network -- confirm.
        for i in range(0, N1):
            if i in [1, 4, 7, 10]:
                channel = super_channels[cnt] * (2 ** cnt)
                result.append(channel)
                logging.info("Downsample channels %s", result[-1])
                cnt += 1

        for i in range(3):
            # A 6-entry list shares one bottleneck width across the three groups.
            if len(super_channels) == 6:
                channel = super_channels[3] * 4
            else:
                channel = super_channels[i + 3] * 4
            result.append(channel)
            logging.info("Bottleneck channels %s", result[-1])

        cnt = 2
        for i in range(N1 + 9, N1 + 21):
            if i in [22, 25, 28]:
                cnt -= 1
                # cnt reaches -1 on the last hit, so the multiplier becomes 0.5.
                if len(super_channels) == 6:
                    channel = super_channels[5 - cnt] * (2 ** cnt)
                else:
                    channel = super_channels[7 - cnt] * (2 ** cnt)
                result.append(int(channel))
                logging.info("Upsample channels %s", result[-1])
        return result

    def forward(self, input):
        return self.model(input)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(BaseDiscriminator):
    """PatchGAN discriminator stored as sub-modules ``model0`` .. ``model{n_layers + 1}``.

    Stage 0 is conv + LeakyReLU; the next ``n_layers - 1`` stages are stride-2
    conv + norm + LeakyReLU with widths doubling up to 512; then one stride-1
    conv + norm + LeakyReLU stage and a final 1-channel convolution.
    ``forward`` returns ``(last_output, [outputs of all earlier stages])``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,):
        super().__init__()
        self.n_layers = n_layers

        kernel = 4
        pad = int(np.ceil((kernel - 1.0) / 2))

        def conv(in_ch, out_ch, stride):
            return nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad)

        stages = [[conv(input_nc, ndf, 2), nn.LeakyReLU(0.2, True)]]

        width = ndf
        for _ in range(1, n_layers):
            previous, width = width, min(width * 2, 512)
            stages.append([conv(previous, width, 2), norm_layer(width), nn.LeakyReLU(0.2, True)])

        previous, width = width, min(width * 2, 512)
        stages.append([conv(previous, width, 1), norm_layer(width), nn.LeakyReLU(0.2, True)])

        stages.append([conv(width, 1, 1)])

        for index, layers in enumerate(stages):
            setattr(self, f'model{index}', nn.Sequential(*layers))

    def get_all_activations(self, x):
        """Run the stages in order and return the list of every stage output."""
        outputs = []
        current = x
        for index in range(self.n_layers + 2):
            current = getattr(self, f'model{index}')(current)
            outputs.append(current)
        return outputs

    def forward(self, x):
        activations = self.get_all_activations(x)
        prediction = activations[-1]
        features = activations[:-1]
        return prediction, features
class MultidilatedNLayerDiscriminator(BaseDiscriminator):
    """N-layer discriminator whose intermediate stride-2 stages use ``MultidilatedConv``.

    Stage 0 is conv + LeakyReLU; the next ``n_layers - 1`` stages are
    ``MultidilatedConv`` (stride 2, ``padding=[2, 3]``, plus
    ``multidilation_kwargs``) + norm + LeakyReLU with widths doubling up to
    512; then one stride-1 conv + norm + LeakyReLU stage and a final
    1-channel convolution.  Stages are stored as ``model0``, ``model1``, ...
    ``forward`` returns ``(last_output, [outputs of all earlier stages])``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}):
        super().__init__()
        self.n_layers = n_layers

        kernel = 4
        pad = int(np.ceil((kernel - 1.0) / 2))

        stages = [[nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
                   nn.LeakyReLU(0.2, True)]]

        width = ndf
        for _ in range(1, n_layers):
            previous, width = width, min(width * 2, 512)
            stages.append([
                MultidilatedConv(previous, width, kernel_size=kernel, stride=2, padding=[2, 3],
                                 **multidilation_kwargs),
                norm_layer(width),
                nn.LeakyReLU(0.2, True),
            ])

        previous, width = width, min(width * 2, 512)
        stages.append([
            nn.Conv2d(previous, width, kernel_size=kernel, stride=1, padding=pad),
            norm_layer(width),
            nn.LeakyReLU(0.2, True),
        ])

        stages.append([nn.Conv2d(width, 1, kernel_size=kernel, stride=1, padding=pad)])

        for index, layers in enumerate(stages):
            setattr(self, f'model{index}', nn.Sequential(*layers))

    def get_all_activations(self, x):
        """Run the stages in order and return the list of every stage output."""
        outputs = []
        current = x
        for index in range(self.n_layers + 2):
            current = getattr(self, f'model{index}')(current)
            outputs.append(current)
        return outputs

    def forward(self, x):
        activations = self.get_all_activations(x)
        prediction = activations[-1]
        features = activations[:-1]
        return prediction, features
class NLayerDiscriminatorAsGen(NLayerDiscriminator):
    """``NLayerDiscriminator`` variant whose ``forward`` returns only the final output."""

    def forward(self, x):
        # The parent returns (final_output, intermediate_outputs); drop the latter.
        result = super().forward(x)
        return result[0]