# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# flake8: noqa
import importlib
import warnings
import torch
import torch.nn as nn
from imaginaire.model_utils.fs_vid2vid import (get_face_mask, get_fg_mask,
get_part_mask, pick_image,
resample)
class MaskedL1Loss(nn.Module):
r"""Masked L1 loss constructor."""
def __init__(self, normalize_over_valid=False):
super(MaskedL1Loss, self).__init__()
self.criterion = nn.L1Loss()
self.normalize_over_valid = normalize_over_valid
def forward(self, input, target, mask):
r"""Masked L1 loss computation.
Args:
input (tensor): Input tensor.
target (tensor): Target tensor.
mask (tensor): Mask to be applied to the output loss.
Returns:
(tensor): Loss value.
"""
mask = mask.expand_as(input)
loss = self.criterion(input * mask, target * mask)
if self.normalize_over_valid:
# The loss has been averaged over all pixels.
# Only average over regions which are valid.
loss = loss * torch.numel(mask) / (torch.sum(mask) + 1e-6)
return loss
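# Usage sketch (illustrative only; `pred`, `target` and `mask` are
# hypothetical tensors). With normalize_over_valid=True, the mean L1 is
# rescaled by numel(mask) / sum(mask), i.e. averaged over valid pixels
# instead of all pixels; a mask covering 25% of the pixels scales the
# averaged loss by roughly 4x.
#   criterion = MaskedL1Loss(normalize_over_valid=True)
#   loss = criterion(pred, target, mask)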
class FlowLoss(nn.Module):
r"""Flow loss constructor.
Args:
cfg (obj): Configuration.
"""
def __init__(self, cfg):
super(FlowLoss, self).__init__()
self.cfg = cfg
self.data_cfg = cfg.data
self.criterion = nn.L1Loss()
self.criterionMasked = MaskedL1Loss()
flow_module = importlib.import_module(cfg.flow_network.type)
self.flowNet = flow_module.FlowNet(pretrained=True)
self.warp_ref = getattr(cfg.gen.flow, 'warp_ref', False)
self.pose_cfg = pose_cfg = getattr(cfg.data, 'for_pose_dataset', None)
self.for_pose_dataset = pose_cfg is not None
self.has_fg = getattr(cfg.data, 'has_foreground', False)
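    # Illustrative configuration sketch (keys inferred from the attribute
    # accesses above; the module path and values are hypothetical):
    #   flow_network:
    #     type: imaginaire.third_party.flow_net.flow_net
    #   gen:
    #     flow:
    #       warp_ref: True
    #   data:
    #     for_pose_dataset: {...}  # its presence enables pose-specific terms
    #     has_foreground: True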
def forward(self, data, net_G_output, current_epoch):
r"""Compute losses on the output flow and occlusion mask.
Args:
data (dict): Input data.
net_G_output (dict): Generator output.
current_epoch (int): Current training epoch number.
Returns:
            (tuple):
- loss_flow_L1 (tensor): L1 loss compared to ground truth flow.
- loss_flow_warp (tensor): L1 loss between the warped image and the
target image when using the flow to warp.
- loss_mask (tensor): Loss for the occlusion mask.
"""
tgt_label, tgt_image = data['label'], data['image']
fake_image = net_G_output['fake_images']
warped_images = net_G_output['warped_images']
flow = net_G_output['fake_flow_maps']
occ_mask = net_G_output['fake_occlusion_masks']
if self.warp_ref:
# Pick the most similar reference image to warp.
ref_labels, ref_images = data['ref_labels'], data['ref_images']
ref_idx = net_G_output['ref_idx']
ref_label, ref_image = pick_image([ref_labels, ref_images], ref_idx)
else:
ref_label = ref_image = None
# Compute the ground truth flows and confidence maps.
flow_gt_prev = flow_gt_ref = conf_gt_prev = conf_gt_ref = None
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if self.warp_ref:
# Compute GT for warping reference -> target.
if self.for_pose_dataset:
# Use DensePose maps to compute flows for pose dataset.
flow_gt_ref, conf_gt_ref = self.flowNet(tgt_label[:, :3],
ref_label[:, :3])
else:
# Use RGB images for other datasets.
flow_gt_ref, conf_gt_ref = self.flowNet(tgt_image,
ref_image)
if current_epoch >= self.cfg.single_frame_epoch and \
data['real_prev_image'] is not None:
# Compute GT for warping previous -> target.
tgt_image_prev = data['real_prev_image']
flow_gt_prev, conf_gt_prev = self.flowNet(tgt_image,
tgt_image_prev)
flow_gt = [flow_gt_ref, flow_gt_prev]
flow_conf_gt = [conf_gt_ref, conf_gt_prev]
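        # Both lists are ordered [reference -> target, previous -> target],
        # matching the ordering of the generated flow maps and masks.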
# Get the foreground masks.
fg_mask, ref_fg_mask = get_fg_mask([tgt_label, ref_label], self.has_fg)
# Compute losses for flow maps and masks.
loss_flow_L1, loss_flow_warp, body_mask_diff = \
self.compute_flow_losses(flow, warped_images, tgt_image, flow_gt,
flow_conf_gt, fg_mask, tgt_label,
ref_label)
loss_mask = self.compute_mask_losses(
occ_mask, fake_image, warped_images, tgt_label, tgt_image,
fg_mask, ref_fg_mask, body_mask_diff)
return loss_flow_L1, loss_flow_warp, loss_mask
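    # Illustrative call site (hedged sketch; the weights w_flow, w_warp and
    # w_mask are hypothetical and not defined in this file):
    #   loss_flow_L1, loss_flow_warp, loss_mask = \
    #       flow_loss(data, net_G_output, current_epoch)
    #   total = w_flow * loss_flow_L1 + w_warp * loss_flow_warp \
    #       + w_mask * loss_mask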
def compute_flow_losses(self, flow, warped_images, tgt_image, flow_gt,
flow_conf_gt, fg_mask, tgt_label, ref_label):
r"""Compute losses on the generated flow maps.
Args:
flow (tensor or list of tensors): Generated flow maps.
warped_images (tensor or list of tensors): Warped images using the
flow maps.
tgt_image (tensor): Target image for the warped image.
flow_gt (tensor or list of tensors): Ground truth flow maps.
flow_conf_gt (tensor or list of tensors): Confidence for the ground
truth flow maps.
fg_mask (tensor): Foreground mask for the target image.
tgt_label (tensor): Target label map.
ref_label (tensor): Reference label map.
Returns:
            (tuple):
- loss_flow_L1 (tensor): L1 loss compared to ground truth flow.
- loss_flow_warp (tensor): L1 loss between the warped image and the
target image when using the flow to warp.
- body_mask_diff (tensor): Difference between warped body part map
and target body part map. Used for pose dataset only.
"""
loss_flow_L1 = torch.tensor(0., device=torch.device('cuda'))
loss_flow_warp = torch.tensor(0., device=torch.device('cuda'))
if isinstance(flow, list):
# Compute flow losses for both warping reference -> target and
# previous -> target.
for i in range(len(flow)):
loss_flow_L1_i, loss_flow_warp_i = \
self.compute_flow_loss(flow[i], warped_images[i], tgt_image,
flow_gt[i], flow_conf_gt[i], fg_mask)
loss_flow_L1 += loss_flow_L1_i
loss_flow_warp += loss_flow_warp_i
else:
# Compute loss for warping either reference or previous images.
loss_flow_L1, loss_flow_warp = \
self.compute_flow_loss(flow, warped_images, tgt_image,
flow_gt[-1], flow_conf_gt[-1], fg_mask)
        # body_mask_diff is used for the pose dataset only.
body_mask_diff = None
if self.warp_ref:
if self.for_pose_dataset:
# Warped reference body part map should be similar to target
# body part map.
body_mask = get_part_mask(tgt_label[:, 2])
ref_body_mask = get_part_mask(ref_label[:, 2])
warped_ref_body_mask = resample(ref_body_mask, flow[0])
loss_flow_warp += self.criterion(warped_ref_body_mask,
body_mask)
body_mask_diff = torch.sum(
abs(warped_ref_body_mask - body_mask), dim=1, keepdim=True)
if self.has_fg:
# Warped reference foreground map should be similar to target
# foreground map.
fg_mask, ref_fg_mask = \
get_fg_mask([tgt_label, ref_label], True)
warped_ref_fg_mask = resample(ref_fg_mask, flow[0])
loss_flow_warp += self.criterion(warped_ref_fg_mask, fg_mask)
return loss_flow_L1, loss_flow_warp, body_mask_diff
def compute_flow_loss(self, flow, warped_image, tgt_image, flow_gt,
flow_conf_gt, fg_mask):
r"""Compute losses on the generated flow map.
Args:
flow (tensor): Generated flow map.
warped_image (tensor): Warped image using the flow map.
tgt_image (tensor): Target image for the warped image.
flow_gt (tensor): Ground truth flow map.
flow_conf_gt (tensor): Confidence for the ground truth flow map.
fg_mask (tensor): Foreground mask for the target image.
Returns:
            (tuple):
- loss_flow_L1 (tensor): L1 loss compared to ground truth flow.
- loss_flow_warp (tensor): L1 loss between the warped image and
the target image when using the flow to warp.
"""
loss_flow_L1 = torch.tensor(0., device=torch.device('cuda'))
loss_flow_warp = torch.tensor(0., device=torch.device('cuda'))
if flow is not None and flow_gt is not None:
# L1 loss compared to flow ground truth.
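            # The mask multiplies the flow confidence with the foreground
            # mask, so supervision only applies where the pseudo ground
            # truth flow is reliable and inside the foreground.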
loss_flow_L1 = self.criterionMasked(flow, flow_gt,
flow_conf_gt * fg_mask)
if warped_image is not None:
# L1 loss between warped image and target image.
loss_flow_warp = self.criterion(warped_image, tgt_image)
return loss_flow_L1, loss_flow_warp
def compute_mask_losses(self, occ_mask, fake_image, warped_image,
tgt_label, tgt_image, fg_mask, ref_fg_mask,
body_mask_diff):
r"""Compute losses on the generated occlusion masks.
Args:
occ_mask (tensor or list of tensors): Generated occlusion masks.
fake_image (tensor): Generated image.
warped_image (tensor or list of tensors): Warped images using the
flow maps.
tgt_label (tensor): Target label map.
tgt_image (tensor): Target image for the warped image.
fg_mask (tensor): Foreground mask for the target image.
ref_fg_mask (tensor): Foreground mask for the reference image.
body_mask_diff (tensor): Difference between warped body part map
and target body part map. Used for pose dataset only.
Returns:
(tensor): Loss for the mask.
"""
loss_mask = torch.tensor(0., device=torch.device('cuda'))
if isinstance(occ_mask, list):
# Compute occlusion mask losses for both warping reference -> target
# and previous -> target.
for i in range(len(occ_mask)):
loss_mask += self.compute_mask_loss(occ_mask[i],
warped_image[i],
tgt_image)
else:
# Compute loss for warping either reference or previous images.
loss_mask += self.compute_mask_loss(occ_mask, warped_image,
tgt_image)
if self.warp_ref:
ref_occ_mask = occ_mask[0]
dummy0 = torch.zeros_like(ref_occ_mask)
dummy1 = torch.ones_like(ref_occ_mask)
if self.for_pose_dataset:
                # Encourage the output to rely more on the warped reference
                # image in the face region.
face_mask = get_face_mask(tgt_label[:, 2]).unsqueeze(1)
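                # Soften the binary face mask with a 15x15 average pool;
                # padding 7 with stride 1 preserves the spatial size.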
AvgPool = torch.nn.AvgPool2d(15, padding=7, stride=1)
face_mask = AvgPool(face_mask)
loss_mask += self.criterionMasked(ref_occ_mask, dummy0,
face_mask)
loss_mask += self.criterionMasked(fake_image, warped_image[0],
face_mask)
                # Encourage the output to rely more on the hallucinated
                # image in regions where the warped reference and target
                # body part masks disagree.
loss_mask += self.criterionMasked(ref_occ_mask, dummy1,
body_mask_diff)
if self.has_fg:
                # Encourage the output to rely more on the hallucinated
                # image in regions where the reference and target
                # foreground masks disagree.
fg_mask_diff = ((ref_fg_mask - fg_mask) > 0).float()
loss_mask += self.criterionMasked(ref_occ_mask, dummy1,
fg_mask_diff)
return loss_mask
def compute_mask_loss(self, occ_mask, warped_image, tgt_image):
r"""Compute losses on the generated occlusion mask.
Args:
occ_mask (tensor): Generated occlusion mask.
warped_image (tensor): Warped image using the flow map.
tgt_image (tensor): Target image for the warped image.
Returns:
(tensor): Loss for the mask.
"""
loss_mask = torch.tensor(0., device=torch.device('cuda'))
if occ_mask is not None:
dummy0 = torch.zeros_like(occ_mask)
dummy1 = torch.ones_like(occ_mask)
# Compute the confidence map based on L1 distance between warped
# and GT image.
img_diff = torch.sum(abs(warped_image - tgt_image), dim=1,
keepdim=True)
conf = torch.clamp(1 - img_diff, 0, 1)
# Force mask value to be small if warped image is similar to GT,
# and vice versa.
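            # Worked example (illustrative): a summed channel difference of
            # 0.2 gives conf = 0.8, so occ_mask is pulled toward 0 with
            # weight 0.8 and toward 1 with weight 0.2.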
loss_mask = self.criterionMasked(occ_mask, dummy0, conf)
loss_mask += self.criterionMasked(occ_mask, dummy1, 1 - conf)
return loss_mask