"""Encoder-decoder ("skip") network builder with per-scale skip connections."""
from .common import *

# Module-wide normalization-layer factory. Swap the alias to change the
# normalization used by every layer this module builds.
_norm = bn
# _norm = pixelnorm
def norm(channels):
    """Return a normalization layer (via the `_norm` alias) for `channels` feature channels."""
    return _norm(channels)
def skip(
    num_input_channels=2,
    num_output_channels=3,
    num_channels_down=(16, 32, 64, 128, 128),
    num_channels_up=(16, 32, 64, 128, 128),
    num_channels_skip=(4, 4, 4, 4, 4),
    filter_size_down=3,
    filter_size_up=3,
    filter_skip_size=1,
    need_sigmoid=True,
    need_tanh=False,
    need_bias=True,
    pad="reflection",
    upsample_mode="bilinear",
    downsample_mode="stride",
    act_fun="LeakyReLU",
    need1x1_up=True,
    decorr_rgb=True,
):
    """Assembles an encoder-decoder ("hourglass") network with skip connections.

    Arguments:
        num_input_channels: channels of the network input.
        num_output_channels: channels of the network output.
        num_channels_down / num_channels_up / num_channels_skip: per-scale
            channel counts for the encoder, decoder and skip branches.
            Must all have the same length (= number of scales).
        filter_size_down / filter_size_up: conv kernel size(s); a scalar is
            broadcast to every scale, or pass a per-scale list/tuple.
        filter_skip_size: kernel size of the skip-branch convolutions.
        need_sigmoid: append nn.Sigmoid() to the output (takes priority
            over need_tanh).
        need_tanh: append nn.Tanh() to the output when need_sigmoid is False.
        need_bias: use bias terms in all convolutions.
        act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU)
        pad (string): zero|reflection (default: 'reflection')
        upsample_mode (string): 'nearest|bilinear' (default: 'bilinear');
            scalar or per-scale list/tuple.
        downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride');
            scalar or per-scale list/tuple.
        need1x1_up: add an extra 1x1 conv after each decoder block.
        decorr_rgb: append a DecorrelatedColorsToRGB layer before the final
            nonlinearity.

    Returns:
        An nn.Sequential implementing the full encoder-decoder.
    """
    assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip)
    n_scales = len(num_channels_down)

    # Broadcast scalar per-scale options to one entry per scale.
    if not isinstance(upsample_mode, (list, tuple)):
        upsample_mode = [upsample_mode] * n_scales
    if not isinstance(downsample_mode, (list, tuple)):
        downsample_mode = [downsample_mode] * n_scales
    if not isinstance(filter_size_down, (list, tuple)):
        filter_size_down = [filter_size_down] * n_scales
    if not isinstance(filter_size_up, (list, tuple)):
        filter_size_up = [filter_size_up] * n_scales

    last_scale = n_scales - 1

    model = nn.Sequential()
    # model.add(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    model_tmp = model  # container we are currently filling; descends one scale per loop

    input_depth = num_input_channels
    for i in range(n_scales):
        deeper = nn.Sequential()
        # Renamed from `skip` so the local does not shadow this function's name.
        skip_branch = nn.Sequential()

        # Each scale's output is the channel-wise concat of the skip branch
        # and the deeper (downsample -> recurse -> upsample) branch.
        if num_channels_skip[i] != 0:
            model_tmp.add(Concat(1, skip_branch, deeper))
        else:
            model_tmp.add(deeper)

        model_tmp.add(norm(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i])))

        if num_channels_skip[i] != 0:
            skip_branch.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad))
            skip_branch.add(norm(num_channels_skip[i]))
            skip_branch.add(act(act_fun))
        # skip_branch.add(Concat(2, GenNoise(nums_noise[i]), skip_part))

        # Encoder half: a strided (downsampling) conv followed by a regular conv.
        deeper.add(
            conv(
                input_depth,
                num_channels_down[i],
                filter_size_down[i],
                2,  # stride 2 -> halves the spatial resolution
                bias=need_bias,
                pad=pad,
                downsample_mode=downsample_mode[i],
            )
        )
        deeper.add(norm(num_channels_down[i]))
        deeper.add(act(act_fun))

        deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad))
        deeper.add(norm(num_channels_down[i]))
        deeper.add(act(act_fun))

        deeper_main = nn.Sequential()

        if i == last_scale:
            # The deepest scale: nothing below, the deeper branch feeds straight up.
            k = num_channels_down[i]
        else:
            # Not the deepest: the next scale's layers will be built inside
            # deeper_main on the following iteration.
            deeper.add(deeper_main)
            k = num_channels_up[i + 1]

        deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i]))

        # Decoder half: fuse the concatenated skip + deeper channels back down.
        model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad))
        model_tmp.add(norm(num_channels_up[i]))
        model_tmp.add(act(act_fun))

        if need1x1_up:
            model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad))
            if i > 0:
                model_tmp.add(norm(num_channels_up[i]))
            model_tmp.add(act(act_fun))

        input_depth = num_channels_down[i]
        model_tmp = deeper_main  # descend one scale for the next iteration

    model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad))
    if decorr_rgb:
        model.add(DecorrelatedColorsToRGB())
    if need_sigmoid:
        model.add(nn.Sigmoid())
    elif need_tanh:
        model.add(nn.Tanh())

    return model