from .common import *

# Normalization layer factory used by every `norm()` call in this module.
# Switch the assignment below to change the normalization everywhere at once.
_norm = bn
# _norm = pixelnorm  # alternative normalization toggle


def norm(channels):
    """Return a normalization layer for ``channels`` feature maps (built by ``_norm``)."""
    return _norm(channels)


def _per_scale(value, n_scales, name):
    """Broadcast a per-scale setting to one entry per scale.

    A list or tuple is returned unchanged after checking it covers every scale
    (extra trailing entries are ignored by the caller); any other value is
    repeated ``n_scales`` times.

    Raises:
        ValueError: if a list/tuple has fewer than ``n_scales`` entries.
    """
    if not isinstance(value, (list, tuple)):
        return [value] * n_scales
    if len(value) < n_scales:
        raise ValueError(
            f"{name} has {len(value)} entries but the network has {n_scales} scales"
        )
    return value


def skip(
    num_input_channels=2,
    num_output_channels=3,
    num_channels_down=(16, 32, 64, 128, 128),
    num_channels_up=(16, 32, 64, 128, 128),
    num_channels_skip=(4, 4, 4, 4, 4),
    filter_size_down=3,
    filter_size_up=3,
    filter_skip_size=1,
    need_sigmoid=True,
    need_tanh=False,
    need_bias=True,
    pad="reflection",
    upsample_mode="bilinear",
    downsample_mode="stride",
    act_fun="LeakyReLU",
    need1x1_up=True,
    decorr_rgb=True,
):
    """Assemble an encoder-decoder network with skip connections.

    Scale ``i`` is built as a container whose encoder branch (``deeper``) holds
    the whole of scale ``i + 1`` nested inside it. Its output is concatenated
    along dim 1 with an optional skip branch, then passed through the decoder
    convs for scale ``i``.

    Arguments:
        num_input_channels: channels of the network input (fed to the first
            encoder conv and the first skip conv).
        num_output_channels: channels produced by the final 1x1 conv.
        num_channels_down: per-scale channel count of the two encoder convs.
        num_channels_up: per-scale channel count of the decoder conv(s).
        num_channels_skip: per-scale channel count of the skip branch; ``0``
            disables the skip branch at that scale.
            The three channel sequences must have the same length, which is
            the number of scales.
        filter_size_down: encoder kernel size, a scalar or one value per scale.
        filter_size_up: decoder kernel size, a scalar or one value per scale.
        filter_skip_size: kernel size of the skip-branch conv.
        need_sigmoid: append ``nn.Sigmoid()`` to the output.
        need_tanh: append ``nn.Tanh()`` to the output; ignored when
            ``need_sigmoid`` is true.
        need_bias: forwarded as ``bias`` to every conv.
        pad (string): padding mode forwarded to ``conv``, e.g.
            'zero|reflection' (default: 'reflection').
        upsample_mode (string): ``nn.Upsample`` mode, e.g. 'nearest|bilinear',
            a scalar or one value per scale (default: 'bilinear').
        downsample_mode (string): forwarded to the strided encoder ``conv``,
            e.g. 'stride|avg|max|lanczos2', a scalar or one value per scale
            (default: 'stride').
        act_fun: forwarded to ``act``; either string 'LeakyReLU|Swish|ELU|none'
            or module (e.g. nn.ReLU).
        need1x1_up: add an extra 1x1 conv (+ norm, except at the outermost
            scale) + activation after each decoder conv.
        decorr_rgb: append ``DecorrelatedColorsToRGB()`` after the output conv.

    Returns:
        The assembled ``nn.Sequential`` model.

    Raises:
        ValueError: if the channel sequences differ in length or are empty, or
            if a per-scale setting has fewer entries than there are scales.
    """
    n_scales = len(num_channels_down)
    if not (n_scales == len(num_channels_up) == len(num_channels_skip)):
        raise ValueError(
            "num_channels_down, num_channels_up and num_channels_skip must have "
            f"the same length (got {len(num_channels_down)}, "
            f"{len(num_channels_up)}, {len(num_channels_skip)})"
        )
    if n_scales == 0:
        raise ValueError("at least one scale is required")

    upsample_mode = _per_scale(upsample_mode, n_scales, "upsample_mode")
    downsample_mode = _per_scale(downsample_mode, n_scales, "downsample_mode")
    filter_size_down = _per_scale(filter_size_down, n_scales, "filter_size_down")
    filter_size_up = _per_scale(filter_size_up, n_scales, "filter_size_up")

    last_scale = n_scales - 1

    model = nn.Sequential()
    # Container that receives the current scale's layers; after each iteration
    # it is re-pointed at the (still empty) container nested for the next scale.
    model_tmp = model

    input_depth = num_input_channels
    for i in range(n_scales):
        deeper = nn.Sequential()
        skip_branch = nn.Sequential()

        # Channels leaving `deeper`: the next scale's decoder width, or this
        # scale's encoder width at the deepest scale (nothing nested there).
        deeper_depth = num_channels_down[i] if i == last_scale else num_channels_up[i + 1]

        if num_channels_skip[i] != 0:
            model_tmp.add(Concat(1, skip_branch, deeper))
        else:
            model_tmp.add(deeper)
        model_tmp.add(norm(num_channels_skip[i] + deeper_depth))

        if num_channels_skip[i] != 0:
            skip_branch.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad))
            skip_branch.add(norm(num_channels_skip[i]))
            skip_branch.add(act(act_fun))

        # Encoder. The positional `2` is presumably the stride of `conv` in
        # .common (the matching Upsample below uses scale_factor=2).
        deeper.add(
            conv(
                input_depth,
                num_channels_down[i],
                filter_size_down[i],
                2,
                bias=need_bias,
                pad=pad,
                downsample_mode=downsample_mode[i],
            )
        )
        deeper.add(norm(num_channels_down[i]))
        deeper.add(act(act_fun))

        deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad))
        deeper.add(norm(num_channels_down[i]))
        deeper.add(act(act_fun))

        deeper_main = nn.Sequential()
        if i != last_scale:
            # Nest the next scale here; it is populated on the next iteration.
            deeper.add(deeper_main)

        deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i]))

        # Decoder for this scale, applied after the concat + norm above.
        model_tmp.add(conv(num_channels_skip[i] + deeper_depth, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad))
        model_tmp.add(norm(num_channels_up[i]))
        model_tmp.add(act(act_fun))

        if need1x1_up:
            model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad))
            # No norm at the outermost scale (i == 0), so the layers right
            # before the output conv are just 1x1 conv + activation.
            if i > 0:
                model_tmp.add(norm(num_channels_up[i]))
            model_tmp.add(act(act_fun))

        input_depth = num_channels_down[i]
        model_tmp = deeper_main

    model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad))
    if decorr_rgb:
        model.add(DecorrelatedColorsToRGB())
    if need_sigmoid:
        model.add(nn.Sigmoid())
    elif need_tanh:
        model.add(nn.Tanh())
    return model