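# NOTE: file-viewer residue captured together with the source (not part of the module):
# oguzakif's picture | init repo | d4b77ac | raw history blame | No virus | 17.1 kB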
import os
import sys
import torch
import torch.nn as nn
import numpy as np
class AlignLoss(nn.Module):
    """L1 loss between target frames and their aligned reference frames.

    Both sides of the comparison are multiplied by the same weight,
    ``(1 - target mask) * aligned visibility map``, before the L1 loss is
    evaluated.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, frames, masks, aligned_vs, aligned_rs):
        """
        Accumulate the alignment loss over every target frame.

        :param frames: the original frames (GT); indexed along dim 2 per frame.
            If the tensor is not 5-D, a singleton axis is inserted at dim 2
            of both ``frames`` and ``masks`` first.
        :param masks: original masks, indexed the same way as ``frames``
        :param aligned_vs: sequence indexed by frame; each entry is the aligned
            visibility map coming from the reference frames
        :param aligned_rs: sequence indexed by frame; each entry holds the
            aligned reference frames, with the reference axis at dim 2
        :return: sum of the per-frame losses (the int ``0`` if there is no frame)
        """
        if len(frames.shape) != 5:
            frames = frames.unsqueeze(2)
            masks = masks.unsqueeze(2)
        # Unpacking (rather than indexing) keeps the strict 5-D requirement.
        _, _, num_frames, _, _ = frames.shape
        return sum(
            self._singleFrameAlignLoss(
                frames[:, :, t], masks[:, :, t], aligned_vs[t], aligned_rs[t]
            )
            for t in range(num_frames)
        )

    def _singleFrameAlignLoss(self, targetFrame, targetMask, aligned_v, aligned_r):
        """
        Alignment loss of one target frame against all of its references.

        :param targetFrame: target frame to be aligned; gets a singleton axis
            inserted at dim 2 so it broadcasts against the references
        :param targetMask: the mask of the target frame
        :param aligned_v: aligned visibility map from the reference frames
        :param aligned_r: aligned reference frames, reference axis at dim 2
        :return: sum of the L1 losses, one term per reference
        """
        # Weight shared by both operands of the L1 loss.
        weight = (1. - targetMask).unsqueeze(2) * aligned_v
        weighted_target = weight * targetFrame.unsqueeze(2)
        weighted_reference = weight * aligned_r
        return sum(
            self.loss_fn(weighted_target[:, :, k], weighted_reference[:, :, k])
            for k in range(aligned_r.shape[2])
        )
class HoleVisibleLoss(nn.Module):
    """L1 loss weighted by ``mask * c_mask``, summed over frames."""

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, outputs, masks, GTs, c_masks):
        """
        Sum the per-frame loss along dim 2.

        Inputs that are not 5-D get a singleton axis inserted at dim 2
        (applied to all four tensors) before the frames are iterated.

        :param outputs: predicted frames
        :param masks: masks multiplied into both prediction and ground truth
        :param GTs: ground-truth frames
        :param c_masks: second mask multiplied together with ``masks``
        :return: sum of the per-frame losses
        """
        if len(outputs.shape) != 5:
            outputs, masks, GTs, c_masks = (
                outputs.unsqueeze(2),
                masks.unsqueeze(2),
                GTs.unsqueeze(2),
                c_masks.unsqueeze(2),
            )
        _, _, num_frames, _, _ = outputs.shape
        total = 0
        for t in range(num_frames):
            total += self._singleFrameHoleVisibleLoss(
                outputs[:, :, t], masks[:, :, t], c_masks[:, :, t], GTs[:, :, t]
            )
        return total

    def _singleFrameHoleVisibleLoss(self, targetFrame, targetMask, c_mask, GT):
        """Return the L1 loss of one frame, weighted by ``targetMask * c_mask``."""
        region = targetMask * c_mask
        return self.loss_fn(region * targetFrame, region * GT)
class HoleInvisibleLoss(nn.Module):
    """L1 loss weighted by ``mask * (1 - c_mask)``, summed over frames."""

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, outputs, masks, GTs, c_masks):
        """
        Sum the per-frame loss along dim 2.

        Inputs that are not 5-D get a singleton axis inserted at dim 2
        (applied to all four tensors) before the frames are iterated.

        :param outputs: predicted frames
        :param masks: masks multiplied into both prediction and ground truth
        :param GTs: ground-truth frames
        :param c_masks: mask whose complement is multiplied together with ``masks``
        :return: sum of the per-frame losses
        """
        if len(outputs.shape) != 5:
            outputs, masks, GTs, c_masks = (
                outputs.unsqueeze(2),
                masks.unsqueeze(2),
                GTs.unsqueeze(2),
                c_masks.unsqueeze(2),
            )
        _, _, num_frames, _, _ = outputs.shape
        total = 0
        for t in range(num_frames):
            total += self._singleFrameHoleInvisibleLoss(
                outputs[:, :, t], masks[:, :, t], c_masks[:, :, t], GTs[:, :, t]
            )
        return total

    def _singleFrameHoleInvisibleLoss(self, targetFrame, targetMask, c_mask, GT):
        """Return the L1 loss of one frame, weighted by ``targetMask * (1 - c_mask)``."""
        region = targetMask * (1. - c_mask)
        return self.loss_fn(region * targetFrame, region * GT)
class NonHoleLoss(nn.Module):
    """L1 loss weighted by ``1 - mask``, summed over frames."""

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, outputs, masks, GTs):
        """
        Sum the per-frame loss along dim 2.

        Inputs that are not 5-D get a singleton axis inserted at dim 2
        (applied to all three tensors) before the frames are iterated.

        :param outputs: predicted frames
        :param masks: masks whose complement weights the comparison
        :param GTs: ground-truth frames
        :return: sum of the per-frame losses
        """
        if len(outputs.shape) != 5:
            outputs, masks, GTs = (
                outputs.unsqueeze(2),
                masks.unsqueeze(2),
                GTs.unsqueeze(2),
            )
        _, _, num_frames, _, _ = outputs.shape
        total = 0
        for t in range(num_frames):
            total += self._singleNonHoleLoss(outputs[:, :, t], masks[:, :, t], GTs[:, :, t])
        return total

    def _singleNonHoleLoss(self, targetFrame, targetMask, GT):
        """Return the L1 loss of one frame, weighted by ``1 - targetMask``."""
        keep = 1. - targetMask
        return self.loss_fn(keep * targetFrame, keep * GT)
class ReconLoss(nn.Module):
    """L1 reconstruction loss, optionally weighted by a mask."""

    def __init__(self, reduction='mean', masked=False):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)
        self.masked = masked

    def forward(self, model_output, target, mask):
        """
        Compare ``model_output`` with ``target``.

        :param model_output: predicted tensor
        :param target: ground-truth tensor
        :param mask: multiplied into both tensors when ``self.masked`` is
            truthy; ignored otherwise
        :return: the L1 loss
        """
        if not self.masked:
            # L1 loss over the whole region.
            return self.loss_fn(model_output, target)
        # L1 loss restricted to the masked region.
        return self.loss_fn(model_output * mask, target * mask)
class VGGLoss(nn.Module):
    """L1 distance between the feature maps a given ``vgg`` callable returns."""

    def __init__(self, vgg):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.vgg = vgg

    def vgg_loss(self, output, target):
        """
        Sum the L1 losses over the ``relu2_2``, ``relu3_3`` and ``relu4_3``
        attributes of the objects returned by ``self.vgg``.
        """
        predicted = self.vgg(output)
        expected = self.vgg(target)
        relu2_term = self.l1_loss(predicted.relu2_2, expected.relu2_2)
        relu3_term = self.l1_loss(predicted.relu3_3, expected.relu3_3)
        relu4_term = self.l1_loss(predicted.relu4_3, expected.relu4_3)
        return relu2_term + relu3_term + relu4_term

    def forward(self, data_input, model_output):
        """
        :param data_input: the targets
        :param model_output: the predictions
        :return: feature loss between predictions and targets
        """
        return self.vgg_loss(model_output, data_input)
class StyleLoss(nn.Module):
    """Gram-matrix style loss computed on the features of a ``vgg`` callable.

    Implements the style term of "Image Inpainting for Irregular Holes Using
    Partial Convolutions", Liu et al., 2018.
    """

    def __init__(self, vgg, original_channel_norm=True):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.vgg = vgg
        self.original_channel_norm = original_channel_norm

    # From https://github.com/pytorch/tutorials/blob/master/advanced_source/neural_style_tutorial.py
    def gram_matrix(self, input):
        """
        Return the normalized Gram matrix of a 4-D feature tensor.

        The tensor is flattened to ``(batch * maps, height * width)``, multiplied
        by its own transpose and divided by the total number of elements.
        NOTE(review): because the batch axis is folded into the rows, a batch
        larger than 1 also produces cross-sample products; the source tutorial
        assumes batch size 1.
        """
        n_batch, n_maps, height, width = input.size()
        rows = input.view(n_batch * n_maps, height * width)
        gram = torch.mm(rows, rows.t())
        return gram.div(n_batch * n_maps * height * width)

    def style_loss(self, output, target):
        """
        Sum, over three feature layers, the L1 distance between Gram matrices
        divided by a per-layer constant.

        With ``original_channel_norm`` the divider is ``2 ** (i + 1)`` for layer
        index ``i`` (original design, avoids a too small loss); otherwise it is
        the squared channel count, which is asserted to be ``128 * 2 ** i``.
        """
        predicted = self.vgg(output)
        expected = self.vgg(target)
        # n_channel: 128 (=2 ** 7), 256 (=2 ** 8), 512 (=2 ** 9)
        layer_names = ('relu2_2', 'relu3_3', 'relu4_3')
        total = 0
        for index, layer_name in enumerate(layer_names):
            predicted_map = getattr(predicted, layer_name)
            expected_map = getattr(expected, layer_name)
            _, n_channels, _, _ = predicted_map.shape
            predicted_gram = self.gram_matrix(predicted_map)
            expected_gram = self.gram_matrix(expected_map)
            if not self.original_channel_norm:
                divider = n_channels ** 2
                assert n_channels == 128 * 2 ** index
            else:
                divider = 2 ** (index + 1)
            total += self.l1_loss(predicted_gram, expected_gram) / divider
        return total

    def forward(self, data_input, model_output):
        """
        :param data_input: the targets
        :param model_output: the predictions
        :return: style loss between predictions and targets
        """
        return self.style_loss(model_output, data_input)
class L1LossMaskedMean(nn.Module):
    """Summed L1 loss over the ``1 - mask`` region, divided by that region's area."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.L1Loss(reduction='sum')

    def forward(self, x, y, mask):
        """
        :param x: first tensor to compare
        :param y: second tensor to compare
        :param mask: weight map; the loss uses ``1 - mask`` as per-pixel weight
            (presumably values lie in [0, 1] — confirm against the caller)
        :return: summed L1 loss divided by ``sum(1 - mask)``; ``0`` instead of
            NaN when ``1 - mask`` is zero everywhere
        """
        # Original author's note: by default the mask value is 0 in the
        # missing region and 1 in the original region.
        masked = 1 - mask
        l1_sum = self.l1(x * masked, y * masked)
        # An empty region used to give 0 / 0 = NaN, which poisons training.
        # Clamp the denominator so the result is a well-defined 0 instead.
        area = torch.sum(masked).to(l1_sum.dtype)
        return l1_sum / area.clamp(min=torch.finfo(l1_sum.dtype).eps)
class L2LossMaskedMean(nn.Module):
    """MSE loss over the ``1 - mask`` region, divided by that region's area."""

    def __init__(self, reduction='sum'):
        """
        :param reduction: forwarded to ``nn.MSELoss``; with ``'none'`` the
            per-element losses are kept and each one is divided by the area
        """
        super().__init__()
        self.l2 = nn.MSELoss(reduction=reduction)

    def forward(self, x, y, mask):
        """
        :param x: first tensor to compare
        :param y: second tensor to compare
        :param mask: weight map; the loss uses ``1 - mask`` as per-pixel weight
            (presumably values lie in [0, 1] — confirm against the caller)
        :return: MSE loss divided by ``sum(1 - mask)``; ``0`` instead of NaN
            when ``1 - mask`` is zero everywhere
        """
        masked = 1 - mask
        l2_sum = self.l2(x * masked, y * masked)
        # An empty region used to give 0 / 0 = NaN; clamp the denominator so
        # the result is a well-defined 0 instead.
        area = torch.sum(masked).to(l2_sum.dtype)
        return l2_sum / area.clamp(min=torch.finfo(l2_sum.dtype).eps)
class ImcompleteVideoReconLoss(nn.Module):
    """Masked-mean L1 loss between ``imcomplete_video`` and downscaled targets."""

    def __init__(self):
        super().__init__()
        self.loss_fn = L1LossMaskedMean()

    def forward(self, data_input, model_output):
        """
        :param data_input: dict providing ``'targets'`` and ``'masks'``
        :param model_output: dict providing ``'imcomplete_video'``
        :return: loss between the prediction and the rescaled targets, using
            the rescaled masks
        """

        def _rescale(video):
            # Swap dims 1 and 2, then scale the last two axes by 0.5 with the
            # default interpolation mode.
            return nn.functional.interpolate(
                video.transpose(1, 2), scale_factor=[1, 0.5, 0.5])

        prediction = model_output['imcomplete_video']
        small_targets = _rescale(data_input['targets'])
        small_masks = _rescale(data_input['masks'])
        return self.loss_fn(prediction, small_targets, small_masks)
class CompleteFramesReconLoss(nn.Module):
    """Masked-mean L1 loss between the model outputs and the targets."""

    def __init__(self):
        super().__init__()
        self.loss_fn = L1LossMaskedMean()

    def forward(self, data_input, model_output):
        """
        :param data_input: dict providing ``'targets'`` and ``'masks'``
        :param model_output: dict providing ``'outputs'``
        :return: ``self.loss_fn(outputs, targets, masks)``
        """
        predictions = model_output['outputs']
        return self.loss_fn(predictions, data_input['targets'], data_input['masks'])
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge

        :param type: which GAN objective to use (see above)
        :param target_real_label: label value used for real samples
        :param target_fake_label: label value used for fake samples
        :raises ValueError: if ``type`` is not one of the supported names
        """
        super(AdversarialLoss, self).__init__()
        # NOTE: this instance attribute shadows ``nn.Module.type``; it is kept
        # because callers may read ``loss.type``.
        self.type = type
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        if type == 'nsgan':
            # BCELoss expects the discriminator outputs to be probabilities.
            self.criterion = nn.BCELoss()
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()
        elif type == 'hinge':
            self.criterion = nn.ReLU()
        else:
            # Fail at construction time instead of with an AttributeError on
            # ``self.criterion`` at the first call.
            raise ValueError(
                "Unknown adversarial loss type {!r}; "
                "expected 'nsgan', 'lsgan' or 'hinge'".format(type))

    def forward(self, outputs, is_real, is_disc=None):
        """
        Compute the adversarial loss for discriminator ``outputs``.

        Defined as ``forward`` (not ``__call__``) so that ``nn.Module`` hooks
        and the regular call machinery keep working; ``loss(outputs, ...)``
        behaves as before.

        :param outputs: discriminator outputs
        :param is_real: whether ``outputs`` should be scored against the real
            label; ignored by the hinge generator branch
        :param is_disc: only used by the hinge loss — truthy selects the
            discriminator form, otherwise the generator form
        :return: scalar loss tensor
        """
        if self.type == 'hinge':
            if is_disc:
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            return (-outputs).mean()

        labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
        return self.criterion(outputs, labels)
# # From https://github.com/phoenix104104/fast_blind_video_consistency
# class TemporalWarpingLoss(nn.Module):
# def __init__(self, opts, flownet_checkpoint_path=None, alpha=50):
# super().__init__()
# self.loss_fn = L1LossMaskedMean()
# self.alpha = alpha
# self.opts = opts
#
# assert flownet_checkpoint_path is not None, "Flownet2 pretrained models must be provided"
#
# self.flownet_checkpoint_path = flownet_checkpoint_path
# raise NotImplementedError
#
# def get_flownet_checkpoint_path(self):
# return self.flownet_checkpoint_path
#
# def _flownetwrapper(self):
# Flownet = FlowNet2(self.opts, requires_grad=False)
# Flownet2_ckpt = torch.load(self.flownet_checkpoint_path)
# Flownet.load_state_dict(Flownet2_ckpt['state_dict'])
# Flownet.to(device)
# Flownet.eval()
# return Flownet
#
# def _setup(self):
# self.flownet = self._flownetwrapper()
#
# def _get_non_occlusion_mask(self, targets, warped_targets):
# non_occlusion_masks = torch.exp(
# -self.alpha * torch.sum(targets[:, 1:] - warped_targets, dim=2).pow(2)
# ).unsqueeze(2)
# return non_occlusion_masks
#
# def _get_loss(self, outputs, warped_outputs, non_occlusion_masks, masks):
# return self.loss_fn(
# outputs[:, 1:] * non_occlusion_masks,
# warped_outputs * non_occlusion_masks,
# masks[:, 1:]
# )
#
# def forward(self, data_input, model_output):
# if self.flownet is None:
# self._setup()
#
# targets = data_input['targets'].to(device)
# outputs = model_output['outputs'].to(device)
# flows = self.flownet.infer_video(targets).to(device)
#
# from utils.flow_utils import warp_optical_flow
# warped_targets = warp_optical_flow(targets[:, :-1], -flows).detach()
# warped_outputs = warp_optical_flow(outputs[:, :-1], -flows).detach()
# non_occlusion_masks = self._get_non_occlusion_mask(targets, warped_targets)
#
# # model_output is passed by name and dictionary is mutable
# # These values are sent to trainer for visualization
# model_output['warped_outputs'] = warped_outputs[0]
# model_output['warped_targets'] = warped_targets[0]
# model_output['non_occlusion_masks'] = non_occlusion_masks[0]
# from utils.flow_utils import flow_to_image
# flow_imgs = []
# for flow in flows[0]:
# flow_img = flow_to_image(flow.cpu().permute(1, 2, 0).detach().numpy()).transpose(2, 0, 1)
# flow_imgs.append(torch.Tensor(flow_img))
# model_output['flow_imgs'] = flow_imgs
#
# masks = data_input['masks'].to(device)
# return self._get_loss(outputs, warped_outputs, non_occlusion_masks, masks)
#
#
# class TemporalWarpingError(TemporalWarpingLoss):
# def __init__(self, flownet_checkpoint_path, alpha=50):
# super().__init__(flownet_checkpoint_path, alpha)
# self.loss_fn = L2LossMaskedMean(reduction='none')
#
# def _get_loss(self, outputs, warped_outputs, non_occlusion_masks, masks):
# # See https://arxiv.org/pdf/1808.00449.pdf 4.3
# # The sum of non_occlusion_masks is different for each video,
# # So the batch dim is kept
# loss = self.loss_fn(
# outputs[:, 1:] * non_occlusion_masks,
# warped_outputs * non_occlusion_masks,
# masks[:, 1:]
# ).sum(1).sum(1).sum(1).sum(1)
#
# loss = loss / non_occlusion_masks.sum(1).sum(1).sum(1).sum(1)
# return loss.sum()
class ValidLoss(nn.Module):
    """Mean L1 loss with both operands weighted by ``1 - mk``."""

    def __init__(self):
        super(ValidLoss, self).__init__()
        self.loss_fn = nn.L1Loss(reduction='mean')

    def forward(self, model_output, target, mk):
        """
        :param model_output: predicted tensor
        :param target: ground-truth tensor
        :param mk: mask; its complement ``1 - mk`` weights the comparison
        :return: mean L1 loss of the weighted tensors
        """
        keep = 1 - mk
        weighted_output = model_output * keep
        weighted_target = target * keep
        return self.loss_fn(weighted_output, weighted_target)
class TVLoss(nn.Module):
    """Total-variation style penalty, divided per frame by the mask sum."""

    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, mask_input, model_output):
        """
        :param mask_input: mask tensor; 4-D input gets a singleton axis at dim 2
        :param model_output: predicted tensor; 4-D input gets a singleton axis
            at dim 2
        :return: mean over frames of (sum of squared vertical and horizontal
            neighbour differences) / (sum of the frame's mask)
        """
        frames = model_output
        if mask_input.dim() == 4:
            mask_input = mask_input.unsqueeze(2)
        if frames.dim() == 4:
            frames = frames.unsqueeze(2)

        # View 3D data as 2D: swap dims 1 and 2, then fold dim 1 into the batch.
        frames = frames.permute(0, 2, 1, 3, 4).contiguous()
        masks = mask_input.permute(0, 2, 1, 3, 4).contiguous()
        n_batch, n_frames, n_channels, height, width = frames.shape
        images = frames.view(n_batch * n_frames, n_channels, height, width)
        mask_areas = masks.view(n_batch * n_frames, -1).sum(dim=1)

        # The finite differences approximate the gradient, so each term is
        # essentially a sum of squared gradients.
        vertical_step = images[:, :, 1:, :] - images[:, :, :-1, :]
        horizontal_step = images[:, :, :, 1:] - images[:, :, :, :-1]
        h_tv = vertical_step.pow(2).sum(1).sum(1).sum(1)
        w_tv = horizontal_step.pow(2).sum(1).sum(1).sum(1)
        return ((h_tv + w_tv) / mask_areas).mean()
# for debug
def show_images(image, name):
    """Debug helper: write ``image`` to the file ``name`` with OpenCV.

    The input is copied into a NumPy array, every value above 0.5 is set to
    255, and the axes are reordered with ``transpose((1, 2, 0))`` before the
    array is handed to ``cv2.imwrite``.
    """
    import cv2
    import numpy as np

    pixels = np.array(image)
    bright = pixels > 0.5
    pixels[bright] = 255.
    cv2.imwrite(name, pixels.transpose((1, 2, 0)))
if __name__ == '__main__':
    # Smoke test for the hole / non-hole losses on constant 32x32 frames:
    # the prediction is all ones, the ground truth all twos.
    targetFrame = torch.ones(1, 3, 32, 32)
    GT = torch.ones(1, 3, 32, 32) + 1

    # Central 16x16 square set to one in the mask.
    mask = torch.zeros(1, 1, 32, 32)
    mask[:, :, 8:24, 8:24] = 1.

    # Disabled test for the align loss, kept for reference.
    # referenceFrames = torch.ones(1, 3, 4, 32, 32)
    # referenceMasks = torch.zeros(1, 1, 4, 32, 32)
    # referenceMasks[:, :, 0, 4:12, 4:12] = 1.
    # referenceFrames[:, :, 0, 4:12, 4:12] = 2.
    # referenceMasks[:, :, 1, 4:12, 20:28] = 1.
    # referenceFrames[:, :, 1, 4:12, 20:28] = 2.
    # referenceMasks[:, :, 2, 20:28, 4:12] = 1.
    # referenceFrames[:, :, 2, 20:28, 4:12] = 2.
    # referenceMasks[:, :, 3, 20:28, 20:28] = 1.
    # referenceFrames[:, :, 3, 20:28, 20:28] = 2.
    #
    # aligned_v = referenceMasks
    # aligned_v, referenceFrames = [aligned_v], [referenceFrames]
    #
    # result = AlignLoss()(targetFrame, mask, aligned_v, referenceFrames)
    # print(result)

    # 8x8 patch inside the masked square set to one in c_mask.
    c_mask = torch.zeros(1, 1, 32, 32)
    c_mask[:, :, 8:16, 16:24] = 1.

    visible_loss = HoleVisibleLoss()
    invisible_loss = HoleInvisibleLoss()
    non_hole_loss = NonHoleLoss()
    result1 = visible_loss(targetFrame, mask, GT, c_mask)
    result2 = invisible_loss(targetFrame, mask, GT, c_mask)
    result3 = non_hole_loss(targetFrame, mask, GT)
    print('vis: {}, invis: {}, gt: {}'.format(result1, result2, result3))