# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms

from imaginaire.config import Config
from imaginaire.generators.vid2vid import Generator as Vid2VidGenerator
from imaginaire.model_utils.fs_vid2vid import resample
from imaginaire.model_utils.wc_vid2vid.render import SplatRenderer
from imaginaire.utils.trainer import (get_model_optimizer_and_scheduler,
                                      get_trainer)
from imaginaire.utils.visualization import tensor2im


class Generator(Vid2VidGenerator):
    r"""World-consistent vid2vid generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        # Guidance options.
        self.guidance_cfg = gen_cfg.guidance
        self.guidance_only_with_flow = getattr(
            self.guidance_cfg, 'only_with_flow', False)
        self.guidance_partial_conv = getattr(
            self.guidance_cfg, 'partial_conv', False)
        # Splatter for guidance.
        self.renderer = SplatRenderer()
        self.reset_renderer()
        # Single image model.
        self.single_image_model = None
        # Initialize the rest same as vid2vid.
        super().__init__(gen_cfg, data_cfg)

    def _init_single_image_model(self, load_weights=True):
        r"""Load single image model, if any."""
        if self.single_image_model is None and \
                hasattr(self.gen_cfg, 'single_image_model'):
            print('Using single image model...')
            single_image_cfg = Config(self.gen_cfg.single_image_model.config)
            # Init model.
            net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
                get_model_optimizer_and_scheduler(single_image_cfg)
            # Init trainer and load checkpoint.
            trainer = get_trainer(single_image_cfg, net_G, net_D,
                                  opt_G, opt_D,
                                  sch_G, sch_D,
                                  None, None)
            if load_weights:
                print('Loading single image model checkpoint')
                single_image_ckpt = self.gen_cfg.single_image_model.checkpoint
                trainer.load_checkpoint(single_image_cfg, single_image_ckpt)
                print('Loaded single image model checkpoint')
            self.single_image_model = net_G.module
            self.single_image_model_z = None
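
    # A hedged sketch of the YAML layout this method expects, inferred from
    # the attribute accesses above (the paths are placeholders, not taken
    # from the original repo):
    #
    #     gen:
    #         single_image_model:
    #             config: configs/single_image_model.yaml
    #             checkpoint: checkpoints/single_image_model.pt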

    def reset_renderer(self, is_flipped_input=False):
        r"""Reset the renderer.

        Args:
            is_flipped_input (bool): Is the input sequence left-right flipped?
        """
        self.renderer.reset()
        self.is_flipped_input = is_flipped_input
        self.renderer_num_forwards = 0
        self.single_image_model_z = None

    def renderer_update_point_cloud(self, image, point_info):
        r"""Update the renderer's color dictionary."""
        if point_info is None or len(point_info) == 0:
            return
        # Renderer expects an (h, w, c) [0-255] RGB image.
        if isinstance(image, torch.Tensor):
            image = tensor2im(image.detach())[0]
        # Flip this image to correspond to the SfM camera pose.
        if self.is_flipped_input:
            image = np.fliplr(image).copy()
        self.renderer.update_point_cloud(image, point_info)
        self.renderer_num_forwards += 1

    def get_guidance_images_and_masks(self, unprojection):
        r"""Render guidance images and masks from the unprojected point
        cloud, and pack them into a single conditioning tensor.

        Args:
            unprojection (dict): Unprojected 3D point information, keyed by
                resolution string.
        Returns:
            (tuple):
              - guidance_images_and_masks (4D tensor): Guidance image and
                mask concatenated along the channel dim, (1, 4, h, w).
              - point_info (dict): Point information at the used resolution.
        """
        resolution = 'w1024xh512'
        point_info = unprojection[resolution]
        # Parse the resolution string, e.g. 'w1024xh512' -> (1024, 512).
        w, h = resolution.split('x')
        w, h = int(w[1:]), int(h[1:])
        # This returns a guidance image in [0-255] RGB; it is converted to a
        # Tensor representation below. If the mask is empty, there is no
        # guidance.
        guidance_image, guidance_mask = self.renderer.render_image(
            point_info, w, h, return_mask=True)
        # Flip guidance image and guidance mask if needed.
        if self.is_flipped_input:
            guidance_image = np.fliplr(guidance_image).copy()
            guidance_mask = np.fliplr(guidance_mask).copy()
        # Go from (h, w, c) to (1, c, h, w). The guidance image is mapped to
        # [-1, 1], the mask to [0, 1].
        guidance_image = (transforms.ToTensor()(guidance_image) - 0.5) * 2
        guidance_mask = transforms.ToTensor()(guidance_mask)
        guidance = torch.cat((guidance_image, guidance_mask), dim=0)
        guidance = guidance.unsqueeze(0).cuda()
        guidance_images_and_masks = guidance
        return guidance_images_and_masks, point_info
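
    # A minimal, self-contained sketch of the packing performed above, with
    # dummy arrays standing in for SplatRenderer.render_image() output (the
    # uint8 {0, 255} mask encoding is an assumption):
    #
    #     h, w = 512, 1024
    #     guidance_image = np.zeros((h, w, 3), dtype=np.uint8)  # [0-255] RGB.
    #     guidance_mask = np.full((h, w, 1), 255, dtype=np.uint8)
    #     image = (transforms.ToTensor()(guidance_image) - 0.5) * 2  # [-1, 1]
    #     mask = transforms.ToTensor()(guidance_mask)                # [0, 1]
    #     guidance = torch.cat((image, mask), dim=0).unsqueeze(0)  # (1, 4, h, w)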

    def forward(self, data):
        r"""World-consistent vid2vid generator forward.

        Args:
            data (dict) : Dictionary of input data.
        Returns:
            output (dict) : Dictionary of output data.
        """
        self._init_single_image_model()

        label = data['label']
        unprojection = data['unprojection']
        label_prev, img_prev = data['prev_labels'], data['prev_images']
        is_first_frame = img_prev is None
        # Optional latent code; `data` is a dict, so use .get().
        z = data.get('z', None)
        bs, _, h, w = label.size()

        # Whether to warp the previous frame or not.
        flow = mask = img_warp = None
        warp_prev = self.temporal_initialized and not is_first_frame and \
            label_prev.shape[1] == self.num_frames_G - 1

        # Get guidance images and masks.
        guidance_images_and_masks, point_info = None, None
        if unprojection is not None:
            guidance_images_and_masks, point_info = \
                self.get_guidance_images_and_masks(unprojection)

        # Get SPADE conditional maps by embedding the current label input.
        cond_maps_now = self.get_cond_maps(label, self.label_embedding)

        # Use the single image model if flow features are not available.
        # Guidance features are used whenever flow features are available.
        if self.single_image_model is not None and not warp_prev:
            # Get the z vector for the single image model.
            if self.single_image_model_z is None:
                bs = data['label'].size(0)
                z = torch.randn(bs, self.single_image_model.style_dims,
                                dtype=torch.float32).cuda()
                if data['label'].dtype == torch.float16:
                    z = z.half()
                self.single_image_model_z = z
            # Get the output image.
            data['z'] = self.single_image_model_z
            self.single_image_model.eval()
            with torch.no_grad():
                output = self.single_image_model.spade_generator(data)
            img_final = output['fake_images'].detach()
            fake_images_source = 'pretrained'
        else:
            # Input to the generator will either be a noise/segmentation map
            # (for the first frame) or the encoded previous frame (for
            # subsequent frames).
            if is_first_frame:
                # First frame in the sequence, start from scratch.
                if self.use_segmap_as_input:
                    x_img = F.interpolate(label, size=(self.sh, self.sw))
                    x_img = self.fc(x_img)
                else:
                    if z is None:
                        z = torch.randn(bs, self.z_dim, dtype=label.dtype,
                                        device=label.get_device()).fill_(0)
                    x_img = self.fc(z).view(bs, -1, self.sh, self.sw)
                # Upsampling layers.
                for i in range(self.num_layers, self.num_downsamples_img, -1):
                    j = min(self.num_downsamples_embed, i)
                    x_img = getattr(
                        self, 'up_' + str(i))(x_img, *cond_maps_now[j])
                    x_img = self.upsample(x_img)
            else:
                # Not the first frame: encode the previous frame and feed it
                # to the generator.
                x_img = self.down_first(img_prev[:, -1])
                # Get the label embedding for the previous frame.
                cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
                                                    self.label_embedding)
                # Downsampling layers.
                for i in range(self.num_downsamples_img + 1):
                    j = min(self.num_downsamples_embed, i)
                    x_img = getattr(
                        self, 'down_' + str(i))(x_img, *cond_maps_prev[j])
                    if i != self.num_downsamples_img:
                        x_img = self.downsample(x_img)
                # Resnet blocks.
                j = min(self.num_downsamples_embed,
                        self.num_downsamples_img + 1)
                for i in range(self.num_res_blocks):
                    cond_maps = cond_maps_prev[j] if \
                        i < self.num_res_blocks // 2 else cond_maps_now[j]
                    x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)

            # Optical-flow-warped image features.
            if warp_prev:
                # Estimate flow & mask.
                label_concat = torch.cat([label_prev.view(bs, -1, h, w),
                                          label], dim=1)
                img_prev_concat = img_prev.view(bs, -1, h, w)
                flow, mask = self.flow_network_temp(
                    label_concat, img_prev_concat)
                img_warp = resample(img_prev[:, -1], flow)
                if self.spade_combine:
                    # If using SPADE combine, integrate the warped image (and
                    # occlusion mask) into the conditional inputs for SPADE.
                    img_embed = torch.cat([img_warp, mask], dim=1)
                    cond_maps_img = self.get_cond_maps(img_embed,
                                                       self.img_prev_embedding)

            x_raw_img = None
            # Main image generation branch.
            for i in range(self.num_downsamples_img, -1, -1):
                # Get SPADE conditional inputs.
                j = min(i, self.num_downsamples_embed)
                cond_maps = cond_maps_now[j]
                # For raw output generation.
                if self.generate_raw_output:
                    if i >= self.num_multi_spade_layers - 1:
                        x_raw_img = x_img
                    if i < self.num_multi_spade_layers:
                        x_raw_img = self.one_up_conv_layer(
                            x_raw_img, cond_maps, i)
                # Add flow and guidance features.
                if warp_prev:
                    if i < self.num_multi_spade_layers:
                        # Add flow.
                        cond_maps += cond_maps_img[j]
                        # Add guidance.
                        if guidance_images_and_masks is not None:
                            cond_maps += [guidance_images_and_masks]
                    elif not self.guidance_only_with_flow:
                        # Add guidance if it is to be applied to every layer.
                        if guidance_images_and_masks is not None:
                            cond_maps += [guidance_images_and_masks]
                x_img = self.one_up_conv_layer(x_img, cond_maps, i)

            # Final conv layer.
            img_final = torch.tanh(self.conv_img(x_img))
            fake_images_source = 'in_training'

        # Update the point cloud color dict of the renderer.
        self.renderer_update_point_cloud(img_final, point_info)

        output = dict()
        output['fake_images'] = img_final
        output['fake_flow_maps'] = flow
        output['fake_occlusion_masks'] = mask
        output['fake_raw_images'] = None
        output['warped_images'] = img_warp
        output['guidance_images_and_masks'] = guidance_images_and_masks
        output['fake_images_source'] = fake_images_source
        return output
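
    # A hedged per-frame inference sketch. The data keys mirror forward()
    # above; `cfg`, `sequence`, and the previous-frame buffers (None on the
    # first frame) are hypothetical placeholders:
    #
    #     net_G = Generator(cfg.gen, cfg.data).cuda().eval()
    #     net_G.reset_renderer(is_flipped_input=False)
    #     with torch.no_grad():
    #         for frame in sequence:
    #             data = {'label': frame['label'],
    #                     'unprojection': frame['unprojection'],
    #                     'prev_labels': prev_labels,
    #                     'prev_images': prev_images}
    #             output = net_G(data)
    #             fake_image = output['fake_images']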

    def get_cond_dims(self, num_downs=0):
        r"""Get the dimensions of conditional inputs.

        Args:
            num_downs (int) : How many downsamples at current layer.
        Returns:
            ch (list) : List of dimensions.
        """
        if not self.use_embed:
            ch = [self.num_input_channels]
        else:
            num_filters = getattr(self.emb_cfg, 'num_filters', 32)
            num_downs = min(num_downs, self.num_downsamples_embed)
            ch = [min(self.max_num_filters, num_filters * (2 ** num_downs))]
            if num_downs < self.num_multi_spade_layers:
                ch = ch * 2
                # Also add guidance (RGB + mask = 4 channels, or 3 if
                # partial).
                if self.guidance_partial_conv:
                    ch.append(3)
                else:
                    ch.append(4)
            elif not self.guidance_only_with_flow:
                if self.guidance_partial_conv:
                    ch.append(3)
                else:
                    ch.append(4)
        return ch
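
    # Worked example (assuming the default num_filters=32 and
    # max_num_filters >= 128): with use_embed=True, num_downs=2 and
    # num_multi_spade_layers=3, this returns [128, 128, 4] -- the label
    # embedding, the warped-image embedding, and the RGB+mask guidance
    # (3 channels instead of 4 when guidance_partial_conv is set).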

    def get_partial(self, num_downs=0):
        r"""Get whether each conv should be partial or not.

        Args:
            num_downs (int) : How many downsamples at current layer.
        Returns:
            partial (list) : List of booleans, partial or not.
        """
        partial = [False]
        if num_downs < self.num_multi_spade_layers:
            partial = partial * 2
            # The guidance branch uses a partial conv if enabled.
            if self.guidance_partial_conv:
                partial.append(True)
            else:
                partial.append(False)
        elif not self.guidance_only_with_flow:
            if self.guidance_partial_conv:
                partial.append(True)
            else:
                partial.append(False)
        return partial
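
    # Worked example: with num_downs=2, num_multi_spade_layers=3 and
    # guidance_partial_conv=True, this returns [False, False, True],
    # mirroring the [128, 128, 3] channel list from get_cond_dims() under
    # the same settings -- only the guidance branch is marked partial.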

    def get_cond_maps(self, label, embedder):
        r"""Get the conditional inputs.

        Args:
            label (4D tensor) : Input label tensor.
            embedder (obj) : Embedding network.
        Returns:
            cond_maps (list) : List of conditional inputs.
        """
        if not self.use_embed:
            return [label] * (self.num_layers + 1)
        embedded_label = embedder(label)
        cond_maps = [embedded_label]
        # Regroup so that each entry holds the embeddings for one scale.
        cond_maps = [[m[i] for m in cond_maps] for i in
                     range(len(cond_maps[0]))]
        return cond_maps