# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import copy
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from imaginaire.layers import (Conv2dBlock, HyperConv2dBlock, HyperRes2dBlock,
                               LinearBlock, Res2dBlock)
from imaginaire.model_utils.fs_vid2vid import (extract_valid_pose_labels,
                                               pick_image, resample)
from imaginaire.utils.data import (get_paired_input_image_channel_number,
                                   get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.init_weight import weights_init
from imaginaire.utils.misc import get_and_setattr, get_nested_attr


class Generator(nn.Module):
    r"""Few-shot vid2vid generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.gen_cfg = gen_cfg
        self.data_cfg = data_cfg
        self.num_frames_G = data_cfg.num_frames_G
        self.flow_cfg = flow_cfg = gen_cfg.flow

        # For pose dataset.
        self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
        if self.is_pose_data:
            pose_cfg = data_cfg.for_pose_dataset
            self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
            self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
                                              False)

        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        self.num_downsamples = num_downsamples = \
            get_and_setattr(gen_cfg, 'num_downsamples', 5)
        conv_kernel_size = get_and_setattr(gen_cfg, 'kernel_size', 3)
        num_filters = get_and_setattr(gen_cfg, 'num_filters', 32)
        max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
        self.max_num_filters = gen_cfg.max_num_filters = \
            min(max_num_filters, num_filters * (2 ** num_downsamples))
        # Get the number of filters at each layer in the main branch.
        num_filters_each_layer = [min(self.max_num_filters,
                                      num_filters * (2 ** i))
                                  for i in range(num_downsamples + 2)]

        # Hyper normalization / convolution.
        hyper_cfg = gen_cfg.hyper
        # Use adaptive weight generation for SPADE.
        self.use_hyper_spade = hyper_cfg.is_hyper_spade
        # Use adaptive weight generation for convolutional layers in the
        # main branch.
        self.use_hyper_conv = hyper_cfg.is_hyper_conv
        # Number of hyper layers.
        self.num_hyper_layers = getattr(hyper_cfg, 'num_hyper_layers', 4)
        if self.num_hyper_layers == -1:
            self.num_hyper_layers = num_downsamples
        gen_cfg.hyper.num_hyper_layers = self.num_hyper_layers

        # Network weight generator.
        self.weight_generator = WeightGenerator(gen_cfg, data_cfg)

        # Number of layers to perform the multi-SPADE combine.
        self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
                                              'num_layers', 3)
        # Whether to generate raw output for additional losses.
        self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
                                           False)

        # Main branch image generation.
        padding = conv_kernel_size // 2
        activation_norm_type = get_and_setattr(gen_cfg, 'activation_norm_type',
                                               'sync_batch')
        weight_norm_type = get_and_setattr(gen_cfg, 'weight_norm_type',
                                           'spectral')
        activation_norm_params = get_and_setattr(gen_cfg,
                                                 'activation_norm_params',
                                                 None)
        spade_in_channels = []  # Input channel size in the SPADE module.
        for i in range(num_downsamples + 1):
            spade_in_channels += [[num_filters_each_layer[i]]] \
                if i >= self.num_multi_spade_layers \
                else [[num_filters_each_layer[i]] * 3]

        order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC')
        for i in reversed(range(num_downsamples + 1)):
            activation_norm_params.cond_dims = spade_in_channels[i]
            is_hyper_conv = self.use_hyper_conv and i < self.num_hyper_layers
            is_hyper_norm = self.use_hyper_spade and i < self.num_hyper_layers
            setattr(self, 'up_%d' % i, HyperRes2dBlock(
                num_filters_each_layer[i + 1], num_filters_each_layer[i],
                conv_kernel_size, padding=padding,
                weight_norm_type=weight_norm_type,
                activation_norm_type=activation_norm_type,
                activation_norm_params=activation_norm_params,
                order=order * 2,
                is_hyper_conv=is_hyper_conv, is_hyper_norm=is_hyper_norm))

        self.conv_img = Conv2dBlock(num_filters, num_img_channels,
                                    conv_kernel_size, padding=padding,
                                    nonlinearity='leakyrelu', order='AC')
        self.upsample = partial(F.interpolate, scale_factor=2)

        # Flow estimation module.
        # Whether to warp the reference image and combine it with the
        # synthesized one.
        self.warp_ref = getattr(flow_cfg, 'warp_ref', True)
        if self.warp_ref:
            self.flow_network_ref = FlowGenerator(flow_cfg, data_cfg, 2)
            self.ref_image_embedding = \
                LabelEmbedder(flow_cfg.multi_spade_combine.embed,
                              num_img_channels + 1)
        # At the beginning of training, only train an image generator.
        self.temporal_initialized = False
        if getattr(gen_cfg, 'init_temporal', True):
            self.init_temporal_network()

    def forward(self, data):
        r"""Few-shot vid2vid generator forward.

        Args:
            data (dict) : Dictionary of input data.
        Returns:
            output (dict) : Dictionary of output data.
        """
        label = data['label']
        ref_labels, ref_images = data['ref_labels'], data['ref_images']
        prev_labels, prev_images = data['prev_labels'], data['prev_images']
        is_first_frame = prev_labels is None

        if self.is_pose_data:
            label, prev_labels = extract_valid_pose_labels(
                [label, prev_labels], self.pose_type, self.remove_face_labels)
            ref_labels = extract_valid_pose_labels(
                ref_labels, self.pose_type, self.remove_face_labels,
                do_remove=False)

        # Weight generation.
        x, encoded_label, conv_weights, norm_weights, atn, atn_vis, ref_idx = \
            self.weight_generator(ref_images, ref_labels, label,
                                  is_first_frame)

        # Flow estimation.
        flow, flow_mask, img_warp, cond_inputs = \
            self.flow_generation(label, ref_labels, ref_images,
                                 prev_labels, prev_images, ref_idx)

        for i in range(len(encoded_label)):
            encoded_label[i] = [encoded_label[i]]
        if self.generate_raw_output:
            encoded_label_raw = [encoded_label[i] for i in
                                 range(self.num_multi_spade_layers)]
            x_raw = None
        encoded_label = self.SPADE_combine(encoded_label, cond_inputs)

        # Main branch image generation.
        for i in range(self.num_downsamples, -1, -1):
            conv_weight = norm_weight = [None] * 3
            if self.use_hyper_conv and i < self.num_hyper_layers:
                conv_weight = conv_weights[i]
            if self.use_hyper_spade and i < self.num_hyper_layers:
                norm_weight = norm_weights[i]

            # Main branch residual blocks.
            x = self.one_up_conv_layer(x, encoded_label,
                                       conv_weight, norm_weight, i)

            # For raw output generation.
            if self.generate_raw_output and i < self.num_multi_spade_layers:
                x_raw = self.one_up_conv_layer(x_raw, encoded_label_raw,
                                               conv_weight, norm_weight, i)
            else:
                x_raw = x

        # Final conv layer.
        if self.generate_raw_output:
            img_raw = torch.tanh(self.conv_img(x_raw))
        else:
            img_raw = None
        img_final = torch.tanh(self.conv_img(x))

        output = dict()
        output['fake_images'] = img_final
        output['fake_flow_maps'] = flow
        output['fake_occlusion_masks'] = flow_mask
        output['fake_raw_images'] = img_raw
        output['warped_images'] = img_warp
        output['attention_visualization'] = atn_vis
        output['ref_idx'] = ref_idx
        return output

    def one_up_conv_layer(self, x, encoded_label, conv_weight, norm_weight, i):
        r"""One residual block layer in the main branch.

        Args:
            x (4D tensor) : Current feature map.
            encoded_label (list of tensors) : Encoded input label maps.
            conv_weight (list of tensors) : Hyper conv weights.
            norm_weight (list of tensors) : Hyper norm weights.
            i (int) : Layer index.
        Returns:
            x (4D tensor) : Output feature map.
        """
        layer = getattr(self, 'up_' + str(i))
        x = layer(x, *encoded_label[i], conv_weights=conv_weight,
                  norm_weights=norm_weight)
        if i != 0:
            x = self.upsample(x)
        return x

    def init_temporal_network(self, cfg_init=None):
        r"""Initialize the flow network when starting to train with multiple
        frames.

        Args:
            cfg_init (dict) : Weight initialization config.
        """
        flow_cfg = self.flow_cfg
        emb_cfg = flow_cfg.multi_spade_combine.embed
        num_frames_G = self.num_frames_G
        self.temporal_initialized = True

        self.sep_prev_flownet = flow_cfg.sep_prev_flow or (num_frames_G != 2) \
            or not flow_cfg.warp_ref
        if self.sep_prev_flownet:
            self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg,
                                                   num_frames_G)
            if cfg_init is not None:
                self.flow_network_temp.apply(weights_init(cfg_init.type,
                                                          cfg_init.gain))
        else:
            self.flow_network_temp = self.flow_network_ref

        self.sep_prev_embedding = emb_cfg.sep_warp_embed or \
            not flow_cfg.warp_ref
        if self.sep_prev_embedding:
            num_img_channels = get_paired_input_image_channel_number(
                self.data_cfg)
            self.prev_image_embedding = \
                LabelEmbedder(emb_cfg, num_img_channels + 1)
            if cfg_init is not None:
                self.prev_image_embedding.apply(
                    weights_init(cfg_init.type, cfg_init.gain))
        else:
            self.prev_image_embedding = self.ref_image_embedding

        if self.warp_ref:
            if self.sep_prev_flownet:
                self.init_network_weights(self.flow_network_ref,
                                          self.flow_network_temp)
                print('Initialized temporal flow network with the reference '
                      'one.')
            if self.sep_prev_embedding:
                self.init_network_weights(self.ref_image_embedding,
                                          self.prev_image_embedding)
                print('Initialized temporal embedding network with the '
                      'reference one.')
            self.flow_temp_is_initialized = True

    def init_network_weights(self, net_src, net_dst):
        r"""Initialize weights in net_dst with those in net_src."""
        source_weights = net_src.state_dict()
        target_weights = net_dst.state_dict()

        for k, v in source_weights.items():
            if k in target_weights and target_weights[k].size() == v.size():
                target_weights[k] = v
        net_dst.load_state_dict(target_weights)

    def load_pretrained_network(self, pretrained_dict, prefix='module.'):
        r"""Load a pretrained network into this network.

        Args:
            pretrained_dict (dict): Pretrained network weights.
            prefix (str): Prefix added to the weight names.
        """
        # print(pretrained_dict.keys())
        model_dict = self.state_dict()
        print('Pretrained network has fewer layers; the following are '
              'not initialized:')
        not_initialized = set()
        for k, v in model_dict.items():
            kp = prefix + k
            if kp in pretrained_dict and v.size() == pretrained_dict[kp].size():
                model_dict[k] = pretrained_dict[kp]
            else:
                not_initialized.add('.'.join(k.split('.')[:2]))
        print(sorted(not_initialized))
        self.load_state_dict(model_dict)

    def reset(self):
        r"""Reset the network at the beginning of a sequence."""
        self.weight_generator.reset()

    def flow_generation(self, label, ref_labels, ref_images, prev_labels,
                        prev_images, ref_idx):
        r"""Generates flows and masks for warping reference / previous images.

        Args:
            label (NxCxHxW tensor): Target label map.
            ref_labels (NxKxCxHxW tensor): Reference label maps.
            ref_images (NxKx3xHxW tensor): Reference images.
            prev_labels (NxTxCxHxW tensor): Previous label maps.
            prev_images (NxTx3xHxW tensor): Previous images.
            ref_idx (Nx1 tensor): Index for which image to use from the
                reference images.
        Returns:
            (tuple):
              - flow (list of Nx2xHxW tensors): Optical flows.
              - occ_mask (list of Nx1xHxW tensors): Occlusion masks.
              - img_warp (list of Nx3xHxW tensors): Warped reference /
                previous images.
              - cond_inputs (list of Nx4xHxW tensors): Conditional inputs for
                the SPADE combination.
        """
        # Pick an image from the reference images using ref_idx.
        ref_label, ref_image = pick_image([ref_labels, ref_images], ref_idx)
        # Only start using previous frames when enough of them have been
        # generated.
        has_prev = prev_labels is not None and \
            prev_labels.shape[1] == (self.num_frames_G - 1)
        flow, occ_mask, img_warp, cond_inputs = [None] * 2, [None] * 2, \
            [None] * 2, [None] * 2
        if self.warp_ref:
            # Generate flows/masks for warping the reference image.
            flow_ref, occ_mask_ref = \
                self.flow_network_ref(label, ref_label, ref_image)
            ref_image_warp = resample(ref_image, flow_ref)
            flow[0], occ_mask[0], img_warp[0] = \
                flow_ref, occ_mask_ref, ref_image_warp[:, :3]
            # Concat the warped image and occlusion mask to form the
            # conditional input.
            cond_inputs[0] = torch.cat([img_warp[0], occ_mask[0]], dim=1)

        if self.temporal_initialized and has_prev:
            # Generate flows/masks for warping the previous image.
            b, t, c, h, w = prev_labels.shape
            prev_labels_concat = prev_labels.view(b, -1, h, w)
            prev_images_concat = prev_images.view(b, -1, h, w)
            flow_prev, occ_mask_prev = \
                self.flow_network_temp(label, prev_labels_concat,
                                       prev_images_concat)
            img_prev_warp = resample(prev_images[:, -1], flow_prev)
            flow[1], occ_mask[1], img_warp[1] = \
                flow_prev, occ_mask_prev, img_prev_warp
            cond_inputs[1] = torch.cat([img_warp[1], occ_mask[1]], dim=1)

        return flow, occ_mask, img_warp, cond_inputs

    def SPADE_combine(self, encoded_label, cond_inputs):
        r"""Use Multi-SPADE to combine the raw synthesized image with the
        warped images.

        Args:
            encoded_label (list of tensors): Original label map embeddings.
            cond_inputs (list of tensors): New SPADE conditional inputs from
                the warped images.
        Returns:
            encoded_label (list of tensors): Combined conditional inputs.
        """
        # Generate the conditional embeddings from the inputs.
        embedded_img_feat = [None, None]
        if cond_inputs[0] is not None:
            embedded_img_feat[0] = self.ref_image_embedding(cond_inputs[0])
        if cond_inputs[1] is not None:
            embedded_img_feat[1] = self.prev_image_embedding(cond_inputs[1])

        # Combine the original encoded label maps with the new conditional
        # embeddings.
        for i in range(self.num_multi_spade_layers):
            encoded_label[i] += [w[i] if w is not None else None
                                 for w in embedded_img_feat]
        return encoded_label

    def custom_init(self):
        r"""This function deals with the numerical issue that might occur
        during mixed precision training.
        """
        print('Use custom initialization for the generator.')
        for k, m in self.named_modules():
            if 'weight_generator.ref_label_' in k and 'norm' in k:
                m.eps = 1e-1
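

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal example of the input dictionary `Generator.forward` expects. All
# shapes below are hypothetical; in practice `gen_cfg`/`data_cfg` come from an
# imaginaire yaml config and the dataloader provides these tensors.
def _example_generator_inputs(batch_size=1, num_ref=2, num_labels=6,
                              height=256, width=256):
    # K reference label/image pairs, the current target label map, and (after
    # the first frame) the previously generated labels/frames.
    return dict(
        label=torch.randn(batch_size, num_labels, height, width),
        ref_labels=torch.randn(batch_size, num_ref, num_labels,
                               height, width),
        ref_images=torch.randn(batch_size, num_ref, 3, height, width),
        prev_labels=None,  # None marks the first frame of a sequence.
        prev_images=None,
    )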


class WeightGenerator(nn.Module):
    r"""Weight generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.data_cfg = data_cfg
        self.embed_cfg = embed_cfg = gen_cfg.embed
        self.embed_arch = embed_cfg.arch

        num_filters = gen_cfg.num_filters
        self.max_num_filters = gen_cfg.max_num_filters
        self.num_downsamples = num_downsamples = gen_cfg.num_downsamples
        self.num_filters_each_layer = num_filters_each_layer = \
            [min(self.max_num_filters, num_filters * (2 ** i))
             for i in range(num_downsamples + 2)]
        if getattr(embed_cfg, 'num_filters', 32) != num_filters:
            raise ValueError('Embedding network must have the same number of '
                             'filters as the generator.')

        # Normalization params.
        hyper_cfg = gen_cfg.hyper
        kernel_size = getattr(hyper_cfg, 'kernel_size', 3)
        activation_norm_type = getattr(hyper_cfg, 'activation_norm_type',
                                       'sync_batch')
        weight_norm_type = getattr(hyper_cfg, 'weight_norm_type', 'spectral')

        # Conv kernel size in the main branch.
        self.conv_kernel_size = conv_kernel_size = gen_cfg.kernel_size
        # Conv kernel size in the embedding network.
        self.embed_kernel_size = embed_kernel_size = \
            getattr(gen_cfg.embed, 'kernel_size', 3)
        # Conv kernel size in SPADE.
        self.kernel_size = kernel_size = \
            getattr(gen_cfg.activation_norm_params, 'kernel_size', 1)
        # Input channel size in the SPADE module.
        self.spade_in_channels = []
        for i in range(num_downsamples + 1):
            self.spade_in_channels += [num_filters_each_layer[i]]

        # Hyper normalization / convolution.
        # Use adaptive weight generation for SPADE.
        self.use_hyper_spade = hyper_cfg.is_hyper_spade
        # Use adaptive weight generation for the label embedding network.
        self.use_hyper_embed = hyper_cfg.is_hyper_embed
        # Use adaptive weight generation for convolutional layers in the
        # main branch.
        self.use_hyper_conv = hyper_cfg.is_hyper_conv
        # Number of hyper layers.
        self.num_hyper_layers = hyper_cfg.num_hyper_layers
        # Order of operations in the conv block.
        order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC')
        self.conv_before_norm = order.find('C') < order.find('N')

        # For reference image encoding.
        # How to utilize the reference label map: concat | mul.
        self.concat_ref_label = 'concat' in hyper_cfg.method_to_use_ref_labels
        self.mul_ref_label = 'mul' in hyper_cfg.method_to_use_ref_labels
        # Output spatial size for the adaptive pooling layer.
        self.sh_fix = self.sw_fix = 32
        # Number of fc layers in weight generation.
        self.num_fc_layers = getattr(hyper_cfg, 'num_fc_layers', 2)

        # Reference image encoding network.
        num_input_channels = get_paired_input_label_channel_number(data_cfg)
        if num_input_channels == 0:
            num_input_channels = getattr(data_cfg, 'label_channels', 1)
        elif get_nested_attr(data_cfg, 'for_pose_dataset.pose_type',
                             'both') == 'open':
            num_input_channels -= 3
        data_cfg.num_input_channels = num_input_channels
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        num_ref_channels = num_img_channels + (num_input_channels
                                               if self.concat_ref_label else 0)
        conv_2d_block = partial(
            Conv2dBlock, kernel_size=kernel_size,
            padding=(kernel_size // 2), weight_norm_type=weight_norm_type,
            activation_norm_type=activation_norm_type,
            nonlinearity='leakyrelu')

        self.ref_img_first = conv_2d_block(num_ref_channels, num_filters)
        if self.mul_ref_label:
            self.ref_label_first = conv_2d_block(num_input_channels,
                                                 num_filters)

        for i in range(num_downsamples):
            in_ch, out_ch = num_filters_each_layer[i], \
                num_filters_each_layer[i + 1]
            setattr(self, 'ref_img_down_%d' % i,
                    conv_2d_block(in_ch, out_ch, stride=2))
            setattr(self, 'ref_img_up_%d' % i, conv_2d_block(out_ch, in_ch))
            if self.mul_ref_label:
                setattr(self, 'ref_label_down_%d' % i,
                        conv_2d_block(in_ch, out_ch, stride=2))
                setattr(self, 'ref_label_up_%d' % i,
                        conv_2d_block(out_ch, in_ch))

        # Normalization / main branch conv weight generation.
        if self.use_hyper_spade or self.use_hyper_conv:
            for i in range(self.num_hyper_layers):
                ch_in, ch_out = num_filters_each_layer[i], \
                    num_filters_each_layer[i + 1]
                conv_ks2 = conv_kernel_size ** 2
                embed_ks2 = embed_kernel_size ** 2
                spade_ks2 = kernel_size ** 2
                spade_in_ch = self.spade_in_channels[i]

                fc_names, fc_ins, fc_outs = [], [], []
                if self.use_hyper_spade:
                    fc0_out = fcs_out = (spade_in_ch * spade_ks2 + 1) * (
                        1 if self.conv_before_norm else 2)
                    fc1_out = (spade_in_ch * spade_ks2 + 1) * (
                        1 if ch_in != ch_out else 2)
                    fc_names += ['fc_spade_0', 'fc_spade_1', 'fc_spade_s']
                    fc_ins += [ch_out] * 3
                    fc_outs += [fc0_out, fc1_out, fcs_out]
                    if self.use_hyper_embed:
                        fc_names += ['fc_spade_e']
                        fc_ins += [ch_out]
                        fc_outs += [ch_in * embed_ks2 + 1]
                if self.use_hyper_conv:
                    fc0_out = ch_out * conv_ks2 + 1
                    fc1_out = ch_in * conv_ks2 + 1
                    fcs_out = ch_out + 1
                    fc_names += ['fc_conv_0', 'fc_conv_1', 'fc_conv_s']
                    fc_ins += [ch_in] * 3
                    fc_outs += [fc0_out, fc1_out, fcs_out]

                linear_block = partial(LinearBlock,
                                       weight_norm_type='spectral',
                                       nonlinearity='leakyrelu')
                for n, l in enumerate(fc_names):
                    fc_in = fc_ins[n] if self.mul_ref_label \
                        else self.sh_fix * self.sw_fix
                    fc_layer = [linear_block(fc_in, ch_out)]
                    for k in range(1, self.num_fc_layers):
                        fc_layer += [linear_block(ch_out, ch_out)]
                    fc_layer += [LinearBlock(ch_out, fc_outs[n],
                                             weight_norm_type='spectral')]
                    setattr(self, '%s_%d' % (l, i), nn.Sequential(*fc_layer))

        # Label embedding network.
        num_hyper_layers = self.num_hyper_layers if self.use_hyper_embed else 0
        self.label_embedding = LabelEmbedder(self.embed_cfg,
                                             num_input_channels,
                                             num_hyper_layers=num_hyper_layers)

        # For multiple reference images.
        if hasattr(hyper_cfg, 'attention'):
            self.num_downsample_atn = get_and_setattr(hyper_cfg.attention,
                                                      'num_downsamples', 2)
            if data_cfg.initial_few_shot_K > 1:
                self.attention_module = AttentionModule(hyper_cfg, data_cfg,
                                                        conv_2d_block,
                                                        num_filters_each_layer)
        else:
            self.num_downsample_atn = 0

    def forward(self, ref_image, ref_label, label, is_first_frame):
        r"""Generate network weights based on the reference images.

        Args:
            ref_image (NxKx3xHxW tensor): Reference images.
            ref_label (NxKxCxHxW tensor): Reference labels.
            label (NxCxHxW tensor): Target label.
            is_first_frame (bool): Whether the current frame is the first
                frame.
        Returns:
            (tuple):
              - x (NxC2xH2xW2 tensor): Encoded features from reference images
                for the main branch (as input to the decoder).
              - encoded_label (list of tensors): Encoded target label map for
                SPADE.
              - conv_weights (list of tensors): Network weights for conv
                layers in the main network.
              - norm_weights (list of tensors): Network weights for SPADE
                layers in the main network.
              - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
              - atn_vis (1x1xH1xW1 tensor): Visualization for attention
                scores.
              - ref_idx (Nx1 tensor): Index for which image to use from the
                reference images.
        """
        b, k, c, h, w = ref_image.size()
        ref_image = ref_image.view(b * k, -1, h, w)
        if ref_label is not None:
            ref_label = ref_label.view(b * k, -1, h, w)

        # Encode the reference images to get the features.
        x, encoded_ref, atn, atn_vis, ref_idx = \
            self.encode_reference(ref_image, ref_label, label, k)

        # Recompute the network weights whenever the reference image may have
        # changed: always during training, at the first frame of a test
        # sequence, and whenever multiple reference images are used.
        if self.training or is_first_frame or k > 1:
            embedding_weights, norm_weights, conv_weights = [], [], []
            for i in range(self.num_hyper_layers):
                if self.use_hyper_spade:
                    feat = encoded_ref[min(len(encoded_ref) - 1, i + 1)]
                    embedding_weight, norm_weight = \
                        self.get_norm_weights(feat, i)
                    embedding_weights.append(embedding_weight)
                    norm_weights.append(norm_weight)
                if self.use_hyper_conv:
                    feat = encoded_ref[min(len(encoded_ref) - 1, i)]
                    conv_weights.append(self.get_conv_weights(feat, i))

            if not self.training:
                self.embedding_weights, self.conv_weights, self.norm_weights \
                    = embedding_weights, conv_weights, norm_weights
        else:
            # print('Reusing network weights.')
            embedding_weights, conv_weights, norm_weights \
                = self.embedding_weights, self.conv_weights, self.norm_weights

        # Encode the target label to get the encoded features.
        encoded_label = self.label_embedding(label, weights=(
            embedding_weights if self.use_hyper_embed else None))

        return x, encoded_label, conv_weights, norm_weights, \
            atn, atn_vis, ref_idx

    def encode_reference(self, ref_image, ref_label, label, k):
        r"""Encode the reference image to get features for weight generation.

        Args:
            ref_image ((NxK)x3xHxW tensor): Reference images.
            ref_label ((NxK)xCxHxW tensor): Reference labels.
            label (NxCxHxW tensor): Target label.
            k (int): Number of reference images.
        Returns:
            (tuple):
              - x (NxC2xH2xW2 tensor): Encoded features from reference images
                for the main branch (as input to the decoder).
              - encoded_ref (list of tensors): Encoded features from reference
                images for the weight generation branch.
              - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
              - atn_vis (1x1xH1xW1 tensor): Visualization for attention
                scores.
              - ref_idx (Nx1 tensor): Index for which image to use from the
                reference images.
        """
        if self.concat_ref_label:
            # Concat reference label map and image together for encoding.
            concat_ref = torch.cat([ref_image, ref_label], dim=1)
            x = self.ref_img_first(concat_ref)
        elif self.mul_ref_label:
            # Apply conv to both reference label and image, then multiply them
            # together for encoding.
            x = self.ref_img_first(ref_image)
            x_label = self.ref_label_first(ref_label)
        else:
            x = self.ref_img_first(ref_image)

        # Attention map and the index of the most similar reference image.
        atn = atn_vis = ref_idx = None
        for i in range(self.num_downsamples):
            x = getattr(self, 'ref_img_down_' + str(i))(x)
            if self.mul_ref_label:
                x_label = getattr(self, 'ref_label_down_' + str(i))(x_label)

            # Combine different reference images at a particular layer.
            if k > 1 and i == self.num_downsample_atn - 1:
                x, atn, atn_vis = self.attention_module(x, label, ref_label)
                if self.mul_ref_label:
                    x_label, _, _ = self.attention_module(x_label, None, None,
                                                          atn)

                atn_sum = atn.view(label.shape[0], k, -1).sum(2)
                ref_idx = torch.argmax(atn_sum, dim=1)

        # Get all corresponding layers in the encoder output for generating
        # weights in corresponding layers.
        encoded_image_ref = [x]
        if self.mul_ref_label:
            encoded_ref_label = [x_label]

        for i in reversed(range(self.num_downsamples)):
            conv = getattr(self, 'ref_img_up_' + str(i))(
                encoded_image_ref[-1])
            encoded_image_ref.append(conv)
            if self.mul_ref_label:
                conv_label = getattr(self, 'ref_label_up_' + str(i))(
                    encoded_ref_label[-1])
                encoded_ref_label.append(conv_label)

        if self.mul_ref_label:
            encoded_ref = []
            for i in range(len(encoded_image_ref)):
                conv, conv_label \
                    = encoded_image_ref[i], encoded_ref_label[i]
                b, c, h, w = conv.size()
                conv_label = nn.Softmax(dim=1)(conv_label)
                conv_prod = (conv.view(b, c, 1, h * w) *
                             conv_label.view(b, 1, c,
                                             h * w)).sum(3, keepdim=True)
                encoded_ref.append(conv_prod)
        else:
            encoded_ref = encoded_image_ref
        encoded_ref = encoded_ref[::-1]

        return x, encoded_ref, atn, atn_vis, ref_idx

    def get_norm_weights(self, x, i):
        r"""Adaptively generate weights for SPADE in layer i of the generator.

        Args:
            x (NxCxHxW tensor): Input features.
            i (int): Layer index.
        Returns:
            (tuple):
              - embedding_weights (list of tensors): Weights for the label
                embedding network.
              - norm_weights (list of tensors): Weights for the SPADE layers.
        """
        if not self.mul_ref_label:
            # Get a fixed output size for the fc layers.
            x = nn.AdaptiveAvgPool2d((self.sh_fix, self.sw_fix))(x)

        in_ch = self.num_filters_each_layer[i]
        out_ch = self.num_filters_each_layer[i + 1]
        spade_ch = self.spade_in_channels[i]
        eks, sks = self.embed_kernel_size, self.kernel_size

        b = x.size(0)
        weight_reshaper = WeightReshaper()
        x = weight_reshaper.reshape_embed_input(x)

        # Weights for the label embedding network.
        embedding_weights = None
        if self.use_hyper_embed:
            fc_e = getattr(self, 'fc_spade_e_' + str(i))(x).view(b, -1)
            if 'decoder' in self.embed_arch:
                weight_shape = [in_ch, out_ch, eks, eks]
                fc_e = fc_e[:, :-in_ch]
            else:
                weight_shape = [out_ch, in_ch, eks, eks]
            embedding_weights = weight_reshaper.reshape_weight(fc_e,
                                                               weight_shape)

        # Weights for the 3 layers in the SPADE module: conv_0, conv_1,
        # and shortcut.
        fc_0 = getattr(self, 'fc_spade_0_' + str(i))(x).view(b, -1)
        fc_1 = getattr(self, 'fc_spade_1_' + str(i))(x).view(b, -1)
        fc_s = getattr(self, 'fc_spade_s_' + str(i))(x).view(b, -1)
        if self.conv_before_norm:
            out_ch = in_ch
        weight_0 = weight_reshaper.reshape_weight(fc_0, [out_ch * 2, spade_ch,
                                                         sks, sks])
        weight_1 = weight_reshaper.reshape_weight(fc_1, [in_ch * 2, spade_ch,
                                                         sks, sks])
        weight_s = weight_reshaper.reshape_weight(fc_s, [out_ch * 2, spade_ch,
                                                         sks, sks])
        norm_weights = [weight_0, weight_1, weight_s]

        return embedding_weights, norm_weights

    def get_conv_weights(self, x, i):
        r"""Adaptively generate weights for layer i in the main branch
        convolutions.

        Args:
            x (NxCxHxW tensor): Input features.
            i (int): Layer index.
        Returns:
            conv_weights (list of tensors): Weights for the conv layers in
                the main branch.
        """
        if not self.mul_ref_label:
            x = nn.AdaptiveAvgPool2d((self.sh_fix, self.sw_fix))(x)
        in_ch = self.num_filters_each_layer[i]
        out_ch = self.num_filters_each_layer[i + 1]
        cks = self.conv_kernel_size
        b = x.size()[0]
        weight_reshaper = WeightReshaper()
        x = weight_reshaper.reshape_embed_input(x)

        fc_0 = getattr(self, 'fc_conv_0_' + str(i))(x).view(b, -1)
        fc_1 = getattr(self, 'fc_conv_1_' + str(i))(x).view(b, -1)
        fc_s = getattr(self, 'fc_conv_s_' + str(i))(x).view(b, -1)
        weight_0 = weight_reshaper.reshape_weight(fc_0, [in_ch, out_ch,
                                                         cks, cks])
        weight_1 = weight_reshaper.reshape_weight(fc_1, [in_ch, in_ch,
                                                         cks, cks])
        weight_s = weight_reshaper.reshape_weight(fc_s, [in_ch, out_ch, 1, 1])
        return [weight_0, weight_1, weight_s]

    def reset(self):
        r"""Reset the network at the beginning of a sequence."""
        self.embedding_weights = self.conv_weights = self.norm_weights = None
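

# --- Usage sketch (not part of the original file) ---------------------------
# Hypothetical driver code showing how the weight cache above behaves at test
# time: weights are generated on the first frame of a sequence and reused for
# later frames, so `reset()` must be called between sequences.
def _example_weight_reuse(weight_generator, data_sequence):
    weight_generator.eval()
    weight_generator.reset()  # Clear cached embedding / conv / norm weights.
    outputs = []
    for t, data in enumerate(data_sequence):
        outputs.append(weight_generator(data['ref_images'],
                                        data['ref_labels'],
                                        data['label'],
                                        is_first_frame=(t == 0)))
    return outputs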


class WeightReshaper():
    r"""Handles all weight reshape related tasks."""

    def reshape_weight(self, x, weight_shape):
        r"""Reshape input x to the desired weight shape.

        Args:
            x (tensor or list of tensors): Input features.
            weight_shape (list of int): Desired shape of the weight.
        Returns:
            (tuple):
              - weight (tensor): Network weights.
              - bias (tensor): Network bias.
        """
        # If the desired shape is a list, first divide x into the target list
        # of features.
        if type(weight_shape[0]) == list and type(x) != list:
            x = self.split_weights(x, self.sum_mul(weight_shape))

        if type(x) == list:
            return [self.reshape_weight(xi, wi)
                    for xi, wi in zip(x, weight_shape)]

        # Get the output shape, and divide x into either weight + bias or
        # just weight.
        weight_shape = [x.size(0)] + weight_shape
        bias_size = weight_shape[1]
        try:
            weight = x[:, :-bias_size].view(weight_shape)
            bias = x[:, -bias_size:]
        except Exception:
            weight = x.view(weight_shape)
            bias = None
        return [weight, bias]

    def split_weights(self, weight, sizes):
        r"""When the desired shape is a list, first divide the input to each
        corresponding weight shape in the list.

        Args:
            weight (tensor): Input weight.
            sizes (int or list of int): Target sizes.
        Returns:
            weight (list of tensors): Divided weights.
        """
        if isinstance(sizes, list):
            weights = []
            cur_size = 0
            for i in range(len(sizes)):
                # For each target size in sizes, get the number of elements
                # needed.
                next_size = cur_size + self.sum(sizes[i])
                # Recursively divide the weights.
                weights.append(self.split_weights(
                    weight[:, cur_size:next_size], sizes[i]))
                cur_size = next_size
            # All elements of the input must have been consumed.
            assert next_size == weight.size(1)
            return weights
        return weight

    def reshape_embed_input(self, x):
        r"""Reshape input to be (B x C) x (H x W).

        Args:
            x (tensor or list of tensors): Input features.
        Returns:
            x (tensor or list of tensors): Reshaped features.
        """
        if isinstance(x, list):
            return [self.reshape_embed_input(xi) for xi in x]
        b, c, _, _ = x.size()
        x = x.view(b * c, -1)
        return x

    def sum(self, x):
        r"""Sum all elements recursively in a nested list.

        Args:
            x (nested list of int): Input list of elements.
        Returns:
            out (int): Sum of all elements.
        """
        if type(x) != list:
            return x
        return sum([self.sum(xi) for xi in x])

    def sum_mul(self, x):
        r"""Given a weight shape, compute the number of elements needed for
        weight + bias. If input is a list of shapes, sum all the elements.

        Args:
            x (list of int): Input list of elements.
        Returns:
            out (int or list of int): Summed number of elements.
        """
        assert type(x) == list
        if type(x[0]) != list:
            return np.prod(x) + x[0]  # x[0] accounts for the bias.
        return [self.sum_mul(xi) for xi in x]
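

# --- Worked example (not part of the original file) --------------------------
# Hypothetical shapes illustrating WeightReshaper: a flat per-sample vector of
# length out_ch * in_ch * k * k + out_ch is split into a conv weight of shape
# (N, out_ch, in_ch, k, k) plus a bias of shape (N, out_ch); `sum_mul`
# computes exactly that flat length.
def _example_weight_reshape():
    reshaper = WeightReshaper()
    out_ch, in_ch, k = 8, 4, 3
    flat = torch.randn(2, out_ch * in_ch * k * k + out_ch)
    weight, bias = reshaper.reshape_weight(flat, [out_ch, in_ch, k, k])
    assert weight.shape == (2, out_ch, in_ch, k, k)
    assert bias.shape == (2, out_ch)
    assert reshaper.sum_mul([[out_ch, in_ch, k, k]]) == [flat.size(1)]
    return weight, bias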


class AttentionModule(nn.Module):
    r"""Attention module constructor.

    Args:
        atn_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
        conv_2d_block: Conv2dBlock constructor.
        num_filters_each_layer (list of int): The number of filters in each
            layer.
    """

    def __init__(self, atn_cfg, data_cfg, conv_2d_block,
                 num_filters_each_layer):
        super().__init__()
        self.initial_few_shot_K = data_cfg.initial_few_shot_K

        num_input_channels = data_cfg.num_input_channels
        num_filters = getattr(atn_cfg, 'num_filters', 32)

        self.num_downsample_atn = getattr(atn_cfg, 'num_downsamples', 2)
        self.atn_query_first = conv_2d_block(num_input_channels, num_filters)
        self.atn_key_first = conv_2d_block(num_input_channels, num_filters)
        for i in range(self.num_downsample_atn):
            f_in, f_out = num_filters_each_layer[i], \
                num_filters_each_layer[i + 1]
            setattr(self, 'atn_key_%d' % i,
                    conv_2d_block(f_in, f_out, stride=2))
            setattr(self, 'atn_query_%d' % i,
                    conv_2d_block(f_in, f_out, stride=2))

    def forward(self, in_features, label, ref_label, attention=None):
        r"""Get the attention map to combine multiple image features in the
        case of multiple reference images.

        Args:
            in_features ((NxK)xC1xH1xW1 tensor): Input features.
            label (NxC2xH2xW2 tensor): Target label.
            ref_label (NxC2xH2xW2 tensor): Reference label.
            attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
        Returns:
            (tuple):
              - out_features (NxC1xH1xW1 tensor): Attention-combined features.
              - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
              - atn_vis (1x1xH1xW1 tensor): Visualization for attention
                scores.
        """
        b, c, h, w = in_features.size()
        k = self.initial_few_shot_K
        b = b // k

        if attention is None:
            # Compute the attention map by encoding ref_label and label as
            # key and query. The map represents how much energy for the k-th
            # map at location (h_i, w_j) can contribute to the final map at
            # location (h_i2, w_j2).
            atn_key = self.attention_encode(ref_label, 'atn_key')
            atn_query = self.attention_encode(label, 'atn_query')

            atn_key = atn_key.view(b, k, c, -1).permute(
                0, 1, 3, 2).contiguous().view(b, -1, c)  # B x KHW x C
            atn_query = atn_query.view(b, c, -1)  # B x C x HW
            energy = torch.bmm(atn_key, atn_query)  # B x KHW x HW
            attention = nn.Softmax(dim=1)(energy)

        # Combine the K features from different reference images into one by
        # using the attention map.
        in_features = in_features.view(b, k, c, h * w).permute(
            0, 2, 1, 3).contiguous().view(b, c, -1)  # B x C x KHW
        out_features = torch.bmm(in_features, attention).view(b, c, h, w)

        # Get a slice of the attention map for visualization.
        atn_vis = attention.view(b, k, h * w, h * w).sum(2).view(b, k, h, w)
        return out_features, attention, atn_vis[-1:, 0:1]

    def attention_encode(self, img, net_name):
        r"""Encode the input image to get the attention map.

        Args:
            img (NxCxHxW tensor): Input image.
            net_name (str): Name for the attention network.
        Returns:
            x (NxC2xH2xW2 tensor): Encoded feature.
        """
        x = getattr(self, net_name + '_first')(img)
        for i in range(self.num_downsample_atn):
            x = getattr(self, net_name + '_' + str(i))(x)
        return x
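

# --- Numeric sketch (not part of the original file) --------------------------
# A standalone illustration (hypothetical shapes) of the attention combination
# used in AttentionModule.forward: K reference feature maps are merged into
# one via a (K*H*W) x (H*W) attention matrix with batched matrix multiplies.
def _example_attention_combine(b=1, k=3, c=8, h=4, w=4):
    feats = torch.randn(b * k, c, h, w)
    energy = torch.randn(b, k * h * w, h * w)
    attention = nn.Softmax(dim=1)(energy)  # Normalize over the K*H*W axis.
    feats = feats.view(b, k, c, h * w).permute(
        0, 2, 1, 3).contiguous().view(b, c, -1)  # B x C x KHW
    return torch.bmm(feats, attention).view(b, c, h, w)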


class FlowGenerator(nn.Module):
    r"""Flow generator constructor.

    Args:
        flow_cfg (obj): Flow definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
        num_frames (int): Number of input frames.
    """

    def __init__(self, flow_cfg, data_cfg, num_frames):
        super().__init__()
        num_input_channels = data_cfg.num_input_channels
        if num_input_channels == 0:
            num_input_channels = 1
        num_prev_img_channels = get_paired_input_image_channel_number(data_cfg)
        num_downsamples = getattr(flow_cfg, 'num_downsamples', 3)
        kernel_size = getattr(flow_cfg, 'kernel_size', 3)
        padding = kernel_size // 2
        num_blocks = getattr(flow_cfg, 'num_blocks', 6)
        num_filters = getattr(flow_cfg, 'num_filters', 32)
        max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
        num_filters_each_layer = [min(max_num_filters, num_filters * (2 ** i))
                                  for i in range(num_downsamples + 1)]

        self.flow_output_multiplier = getattr(flow_cfg,
                                              'flow_output_multiplier', 20)
        self.sep_up_mask = getattr(flow_cfg, 'sep_up_mask', False)
        activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
                                       'sync_batch')
        weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')

        base_conv_block = partial(Conv2dBlock, kernel_size=kernel_size,
                                  padding=padding,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  nonlinearity='leakyrelu')

        num_input_channels = num_input_channels * num_frames + \
            num_prev_img_channels * (num_frames - 1)
        # First layer.
        down_flow = [base_conv_block(num_input_channels, num_filters)]
        # Downsamples.
        for i in range(num_downsamples):
            down_flow += [base_conv_block(num_filters_each_layer[i],
                                          num_filters_each_layer[i + 1],
                                          stride=2)]
        # Resnet blocks.
        res_flow = []
        ch = num_filters_each_layer[num_downsamples]
        for i in range(num_blocks):
            res_flow += [
                Res2dBlock(ch, ch, kernel_size, padding=padding,
                           weight_norm_type=weight_norm_type,
                           activation_norm_type=activation_norm_type,
                           order='NACNAC')]
        # Upsamples.
        up_flow = []
        for i in reversed(range(num_downsamples)):
            up_flow += [nn.Upsample(scale_factor=2),
                        base_conv_block(num_filters_each_layer[i + 1],
                                        num_filters_each_layer[i])]

        conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)]
        conv_mask = [Conv2dBlock(num_filters, 1, kernel_size, padding=padding,
                                 nonlinearity='sigmoid')]

        self.down_flow = nn.Sequential(*down_flow)
        self.res_flow = nn.Sequential(*res_flow)
        self.up_flow = nn.Sequential(*up_flow)
        if self.sep_up_mask:
            self.up_mask = nn.Sequential(*copy.deepcopy(up_flow))
        self.conv_flow = nn.Sequential(*conv_flow)
        self.conv_mask = nn.Sequential(*conv_mask)

    def forward(self, label, ref_label, ref_image):
        r"""Flow generator forward.

        Args:
            label (4D tensor) : Input label tensor.
            ref_label (4D tensor) : Reference label tensors.
            ref_image (4D tensor) : Reference image tensors.
        Returns:
            (tuple):
              - flow (4D tensor) : Generated flow map.
              - mask (4D tensor) : Generated occlusion mask.
        """
        label_concat = torch.cat([label, ref_label, ref_image], dim=1)
        downsample = self.down_flow(label_concat)
        res = self.res_flow(downsample)
        flow_feat = self.up_flow(res)
        flow = self.conv_flow(flow_feat) * self.flow_output_multiplier
        mask_feat = self.up_mask(res) if self.sep_up_mask else flow_feat
        mask = self.conv_mask(mask_feat)
        return flow, mask
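

# --- Usage sketch (not part of the original file) ---------------------------
# A hedged illustration of how the (flow, mask) outputs are typically
# consumed; the real pipeline uses imaginaire's `resample` helper, and the
# blending convention below is an assumption. The flow (in pixels) warps the
# reference image via grid_sample, and the mask in [0, 1] blends the warped
# image with a freshly synthesized one.
def _example_warp_and_blend(image, flow, mask, synthesized):
    b, _, h, w = image.shape
    # Base sampling grid in normalized [-1, 1] coordinates.
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, h),
                            torch.linspace(-1, 1, w), indexing='ij')
    grid = torch.stack([xs, ys], dim=-1).unsqueeze(0).expand(b, -1, -1, -1)
    # Convert the pixel-space flow to normalized offsets and warp.
    offset = torch.stack([flow[:, 0] / ((w - 1) / 2),
                          flow[:, 1] / ((h - 1) / 2)], dim=-1)
    warped = F.grid_sample(image, grid + offset, align_corners=True)
    return mask * warped + (1 - mask) * synthesized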


class LabelEmbedder(nn.Module):
    r"""Embed the input label map to get embedded features.

    Args:
        emb_cfg (obj): Embed network configuration.
        num_input_channels (int): Number of input channels.
        num_hyper_layers (int): Number of hyper layers.
    """

    def __init__(self, emb_cfg, num_input_channels, num_hyper_layers=0):
        super().__init__()
        num_filters = getattr(emb_cfg, 'num_filters', 32)
        max_num_filters = getattr(emb_cfg, 'max_num_filters', 1024)
        self.arch = getattr(emb_cfg, 'arch', 'encoderdecoder')
        self.num_downsamples = num_downsamples = \
            getattr(emb_cfg, 'num_downsamples', 5)
        kernel_size = getattr(emb_cfg, 'kernel_size', 3)
        weight_norm_type = getattr(emb_cfg, 'weight_norm_type', 'spectral')
        activation_norm_type = getattr(emb_cfg, 'activation_norm_type', 'none')

        self.unet = 'unet' in self.arch
        self.has_decoder = 'decoder' in self.arch or self.unet
        self.num_hyper_layers = num_hyper_layers \
            if num_hyper_layers != -1 else num_downsamples

        base_conv_block = partial(HyperConv2dBlock, kernel_size=kernel_size,
                                  padding=(kernel_size // 2),
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  nonlinearity='leakyrelu')

        ch = [min(max_num_filters, num_filters * (2 ** i))
              for i in range(num_downsamples + 1)]

        self.conv_first = base_conv_block(num_input_channels, num_filters,
                                          activation_norm_type='none')

        # Downsample.
        for i in range(num_downsamples):
            is_hyper_conv = (i < num_hyper_layers) and not self.has_decoder
            setattr(self, 'down_%d' % i,
                    base_conv_block(ch[i], ch[i + 1], stride=2,
                                    is_hyper_conv=is_hyper_conv))

        # Upsample.
        if self.has_decoder:
            self.upsample = nn.Upsample(scale_factor=2)
            for i in reversed(range(num_downsamples)):
                ch_i = ch[i + 1] * (
                    2 if self.unet and i != num_downsamples - 1 else 1)
                setattr(self, 'up_%d' % i,
                        base_conv_block(ch_i, ch[i],
                                        is_hyper_conv=(i < num_hyper_layers)))

    def forward(self, input, weights=None):
        r"""Embedding network forward.

        Args:
            input (NxCxHxW tensor): Network input.
            weights (list of tensors): Conv weights if using a hyper network.
        Returns:
            output (list of tensors): Network outputs at different layers.
        """
        if input is None:
            return None

        output = [self.conv_first(input)]

        for i in range(self.num_downsamples):
            layer = getattr(self, 'down_%d' % i)
            # For hyper networks, the hyper layers are at the last few layers
            # of the decoder (if the network has a decoder). Otherwise, the
            # hyper layers will be at the first few layers of the network.
            if i >= self.num_hyper_layers or self.has_decoder:
                conv = layer(output[-1])
            else:
                conv = layer(output[-1], conv_weights=weights[i])
            # We will use outputs from different layers as input to different
            # SPADE layers in the main branch.
            output.append(conv)

        if not self.has_decoder:
            return output

        # If the network has a decoder, use outputs from the decoder layers
        # instead of the encoding layers.
        if not self.unet:
            output = [output[-1]]

        for i in reversed(range(self.num_downsamples)):
            input_i = output[-1]
            if self.unet and i != self.num_downsamples - 1:
                input_i = torch.cat([input_i, output[i + 1]], dim=1)

            input_i = self.upsample(input_i)
            layer = getattr(self, 'up_%d' % i)
            # The last few layers will be hyper layers if necessary.
            if i >= self.num_hyper_layers:
                conv = layer(input_i)
            else:
                conv = layer(input_i, conv_weights=weights[i])
            output.append(conv)

        if self.unet:
            output = output[self.num_downsamples:]
        return output[::-1]
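

# --- Usage sketch (not part of the original file) ---------------------------
# Hypothetical config values (via SimpleNamespace) standing in for the yaml
# `emb_cfg`: with an encoder-only arch, the embedder returns one feature map
# per scale, ordered fine-to-coarse, which the main branch feeds into its
# SPADE layers.
def _example_label_embedder():
    from types import SimpleNamespace
    emb_cfg = SimpleNamespace(num_filters=8, max_num_filters=64,
                              arch='encoder', num_downsamples=2,
                              kernel_size=3, weight_norm_type='none',
                              activation_norm_type='none')
    embedder = LabelEmbedder(emb_cfg, num_input_channels=6)
    features = embedder(torch.randn(1, 6, 64, 64))
    # Three scales for num_downsamples=2: 64x64, 32x32, and 16x16.
    return [f.shape for f in features]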