oguzakif's picture
init repo
d4b77ac
raw history blame
No virus
18.5 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from .fbConsistencyCheck import image_warp
class FlowWarpingLoss(nn.Module):
    """Warp a tensor with an optical flow and compare it against a target.

    The comparison is delegated to ``metric``, a callable taking
    ``(prediction, target)`` and returning the loss value.
    """

    def __init__(self, metric):
        super(FlowWarpingLoss, self).__init__()
        self.metric = metric

    def warp(self, x, flow):
        """
        Backward-warp ``x`` by sampling it at ``pixel position + flow``.

        Args:
            x: torch tensor with shape [b, c, h, w], c can be 3 (for rgb frame) or 2 (for optical flow)
            flow: torch tensor with shape [b, 2, h, w]; channel 0 is the horizontal and
                channel 1 the vertical displacement, both measured in pixels

        Returns: the warped x (can be an image or an optical flow); samples that fall
            outside of x are filled with zeros
        """
        h, w = x.shape[2:]
        device = x.device

        # Normalize the flow to the [-1, 1] coordinate system of grid_sample.
        # max(..., 1) keeps a one-pixel-wide/high input from dividing by zero.
        half_w = max(w - 1, 1) / 2
        half_h = max(h - 1, 1) / 2
        flow = torch.cat([flow[:, 0:1, :, :] / half_w, flow[:, 1:2, :, :] / half_h], dim=1)
        flow = flow.permute(0, 2, 3, 1)  # change to [b, h, w, c]

        # Identity sampling grid of shape [1, h, w, 2] with (x, y) in the last dim,
        # built directly on the target device.
        xs = torch.linspace(-1, 1, w, dtype=torch.float32, device=device).view(1, 1, w, 1).expand(1, h, w, 1)
        ys = torch.linspace(-1, 1, h, dtype=torch.float32, device=device).view(1, h, 1, 1).expand(1, h, w, 1)
        grid = torch.cat((xs, ys), dim=3)

        # The grid puts -1 / 1 exactly on the first / last pixel centre and the flow is
        # scaled by (size - 1) / 2, which is the align_corners=True convention. Without
        # the explicit flag, grid_sample falls back to align_corners=False and even a
        # zero flow would no longer reproduce the input.
        output = torch.nn.functional.grid_sample(x, grid + flow, mode='bilinear', padding_mode='zeros',
                                                 align_corners=True)
        return output

    def __call__(self, x, y, flow, mask):
        """
        image/flow warping, only support the single image/flow warping

        Args:
            x: Can be optical flow or image with shape [b, c, h, w], c can be 2 or 3
            y: The ground truth of x (can be the extracted optical flow or image)
            flow: The flow used to warp x, whose shape is [b, 2, h, w]
            mask: The mask which indicates the hole of x, which must be [b, 1, h, w]

        Returns: ``metric`` evaluated between the masked warped x and the masked y
        """
        warped_x = self.warp(x, flow)
        loss = self.metric(warped_x * mask, y * mask)
        return loss
class TVLoss():
    """Total-variation style penalty on a 4-D tensor.

    Adds the mean absolute difference between horizontally adjacent entries
    (last dimension) to the mean absolute difference between vertically
    adjacent entries (third dimension).
    """

    def __init__(self):
        super(TVLoss, self).__init__()

    def __call__(self, x):
        # neighbour differences along the last (width) dimension
        horizontal = x[:, :, :, :-1] - x[:, :, :, 1:]
        # neighbour differences along the third (height) dimension
        vertical = x[:, :, :-1, :] - x[:, :, 1:, :]
        return torch.mean(torch.abs(horizontal)) + torch.mean(torch.abs(vertical))
class WarpLoss(nn.Module):
    """L1 loss between frame 1 and frame 2 warped back onto frame 1."""

    def __init__(self):
        super().__init__()
        self.metric = nn.L1Loss()

    def forward(self, flow, mask, img1, img2):
        """
        Args:
            flow: flow describing the motion from img1 to img2
            mask: mask that belongs to img1; both operands are multiplied by it
            img1: frame t
            img2: frame t+1

        Returns: L1 distance between the masked, warped img2 and the masked img1
        """
        warped_back = image_warp(img2, flow)
        return self.metric(warped_back * mask, img1 * mask)
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337

    Supported variants (selected with ``type``):
        nsgan: binary cross entropy (``nn.BCELoss``) against the real/fake label
        lsgan: mean squared error (``nn.MSELoss``) against the real/fake label
        hinge: hinge loss; the labels are not used
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge

        Args:
            type: name of the loss variant
            target_real_label: label value used for real samples (nsgan / lsgan)
            target_fake_label: label value used for fake samples (nsgan / lsgan)

        Raises:
            ValueError: if ``type`` is not one of the supported variants
        """
        super(AdversarialLoss, self).__init__()
        self.type = type
        # buffers follow the module across devices via .to() / .cuda()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        if type == 'nsgan':
            self.criterion = nn.BCELoss()
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()
        elif type == 'hinge':
            self.criterion = nn.ReLU()
        else:
            # Fail at construction time; otherwise the first loss evaluation would die
            # with an unrelated-looking AttributeError on the missing ``criterion``.
            raise ValueError(
                "Unsupported adversarial loss type {!r}; expected 'nsgan', 'lsgan' or 'hinge'".format(type))

    def __call__(self, outputs, is_real, is_disc=None):
        """
        Args:
            outputs: discriminator predictions
            is_real: whether ``outputs`` should be scored against the real label
            is_disc: only used by the hinge variant; truthy selects the discriminator
                objective, falsy the generator objective

        Returns: scalar loss tensor
        """
        if self.type == 'hinge':
            if is_disc:
                # discriminator: relu(1 - D(real)) or relu(1 + D(fake))
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            else:
                # generator: maximise the discriminator score
                return (-outputs).mean()
        else:
            labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
            loss = self.criterion(outputs, labels)
            return loss
class StyleLoss(nn.Module):
    r"""
    Style loss, VGG-based: L1 distance between Gram matrices of VGG features
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    def __init__(self):
        super().__init__()
        self.add_module('vgg', VGG19())
        self.criterion = torch.nn.L1Loss()

    def compute_gram(self, x):
        """Return the [b, ch, ch] Gram matrix of x ([b, ch, h, w]), divided by h * w * ch."""
        batch, channels, height, width = x.size()
        flat = x.view(batch, channels, width * height)
        gram = torch.bmm(flat, flat.transpose(1, 2))
        return gram / (height * width * channels)

    def __call__(self, x, y):
        # Feature dictionaries for both inputs
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)

        # Accumulate the Gram-matrix distance over the selected layers
        style_loss = 0.0
        for layer in ('relu2_2', 'relu3_4', 'relu4_4', 'relu5_2'):
            gram_x = self.compute_gram(feats_x[layer])
            gram_y = self.compute_gram(feats_y[layer])
            style_loss += self.criterion(gram_x, gram_y)
        return style_loss
class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    def __init__(self, weights=None):
        """
        Args:
            weights: five per-layer weights for relu1_1, relu2_1, relu3_1, relu4_1 and
                relu5_1 (in that order). ``None`` selects 1.0 for every layer.
        """
        super(PerceptualLoss, self).__init__()
        self.add_module('vgg', VGG19())
        self.criterion = torch.nn.L1Loss()
        # A list literal as default argument would be shared by every instance, so
        # mutating one instance's weights would silently change all the others.
        self.weights = [1.0, 1.0, 1.0, 1.0, 1.0] if weights is None else weights

    def __call__(self, x, y):
        """Return the weighted sum of L1 distances between VGG features of x and y."""
        # Compute features
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        content_loss = 0.0
        layers = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
        for idx, layer in enumerate(layers):
            content_loss += self.weights[idx] * self.criterion(x_vgg[layer], y_vgg[layer])
        return content_loss
class VGG19(torch.nn.Module):
    """Frozen feature extractor built from torchvision's pretrained VGG-19.

    The ``features`` layers (indices 0-35) are cut into 16 consecutive stages.
    ``forward`` runs them in sequence and returns every intermediate result,
    keyed by the stage name.
    """

    # (stage name, index of first layer, index one past the last layer)
    _STAGES = (
        ('relu1_1', 0, 2),
        ('relu1_2', 2, 4),
        ('relu2_1', 4, 7),
        ('relu2_2', 7, 9),
        ('relu3_1', 9, 12),
        ('relu3_2', 12, 14),
        ('relu3_3', 14, 16),
        ('relu3_4', 16, 18),
        ('relu4_1', 18, 21),
        ('relu4_2', 21, 23),
        ('relu4_3', 23, 25),
        ('relu4_4', 25, 27),
        ('relu5_1', 27, 30),
        ('relu5_2', 30, 32),
        ('relu5_3', 32, 34),
        ('relu5_4', 34, 36),
    )

    def __init__(self):
        super(VGG19, self).__init__()
        features = models.vgg19(pretrained=True).features

        # Each stage is a Sequential whose children keep their original layer index as name
        for name, start, stop in self._STAGES:
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), features[idx])
            setattr(self, name, stage)

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return a dict mapping each stage name to that stage's output for input x."""
        out = {}
        activation = x
        for name, _, _ in self._STAGES:
            activation = getattr(self, name)(activation)
            out[name] = activation
        return out
# Some losses related to optical flows
# From Unflow: https://github.com/simonmeister/UnFlow
def fbLoss(forward_flow, backward_flow, forward_gt_flow, backward_gt_flow, fb_loss_weight, image_warp_loss_weight=0,
           occ_weight=0, beta=255, first_image=None, second_image=None):
    """
    Forward-backward consistency loss, optionally combined with an image warp loss.

    Args:
        forward_flow: torch tensor, with shape [b, c, h, w]
        backward_flow: torch tensor, with shape [b, c, h, w]
        forward_gt_flow: ground truth of the forward flow (drives warping and occlusion estimation)
        backward_gt_flow: ground truth of the backward flow (drives warping and occlusion estimation)
        fb_loss_weight: weight of the forward-backward consistency term
        image_warp_loss_weight: weight of the image warp term; 0 disables it, and then
            first_image / second_image are not touched
        occ_weight: weight of the occlusion penalty that is added to the image warp term
        beta: scale applied to the image differences inside the charbonnier loss (255 by default,
            according to the original loss codes in unflow)
        first_image: the earlier image of the pair the flows were extracted from
        second_image: the later image of the pair the flows were extracted from

    Note: forward and backward flow should be extracted from the same image pair
    Returns: weighted forward-backward consistency loss plus the weighted image warp loss
    """
    # pixels whose flow stays inside the image
    valid_fw = create_outgoing_mask(forward_flow).float()
    valid_bw = create_outgoing_mask(backward_flow).float()

    # bring each flow into the frame of the opposite direction, using the ground-truth flows
    fw_warped = image_warp(forward_flow, backward_gt_flow)
    fw_gt_warped = image_warp(forward_gt_flow, backward_gt_flow)
    bw_warped = image_warp(backward_flow, forward_gt_flow)
    bw_gt_warped = image_warp(backward_gt_flow, forward_gt_flow)

    # a flow plus the opposite, warped flow; this sum is what the charbonnier terms penalise
    cycle_fw = bw_warped + forward_flow
    cycle_fw_gt = bw_gt_warped + forward_gt_flow
    cycle_bw = backward_flow + fw_warped
    cycle_bw_gt = backward_gt_flow + fw_gt_warped

    # occlusion estimation: a pixel counts as occluded when the ground-truth cycle error
    # exceeds a threshold that grows with the flow magnitudes
    thresh_fw = 0.01 * (length_sq(forward_gt_flow) + length_sq(bw_gt_warped)) + 0.5
    thresh_bw = 0.01 * (length_sq(backward_gt_flow) + length_sq(fw_gt_warped)) + 0.5
    occluded_fw = (length_sq(cycle_fw_gt) > thresh_fw).float()
    occluded_bw = (length_sq(cycle_bw_gt) > thresh_bw).float()
    valid_fw *= (1 - occluded_fw)
    valid_bw *= (1 - occluded_bw)

    if image_warp_loss_weight == 0:
        image_warp_loss = 0
    else:
        # everything that is not valid is treated as occluded
        occ_fw = 1 - valid_fw
        occ_bw = 1 - valid_bw

        # warp each image onto the other one with the predicted flows
        second_on_first = image_warp(second_image, forward_flow)  # frame 2 -> 1
        first_on_second = image_warp(first_image, backward_flow)  # frame 1 -> 2
        photo_diff_fw = first_image - second_on_first
        photo_diff_bw = second_image - first_on_second

        # photometric error on valid pixels plus a penalty on the size of the occluded area
        occ_loss = occ_weight * (charbonnier_loss(occ_fw) + charbonnier_loss(occ_bw))
        photo_loss = charbonnier_loss(photo_diff_fw, valid_fw, beta=beta) + charbonnier_loss(
            photo_diff_bw, valid_bw, beta=beta)
        image_warp_loss = image_warp_loss_weight * photo_loss + occ_loss

    fb_loss = fb_loss_weight * (charbonnier_loss(cycle_fw, valid_fw) + charbonnier_loss(cycle_bw, valid_bw))
    return fb_loss + image_warp_loss
def length_sq(x):
    """Return the sum of squared entries of x along dim 1, keeping that dim with size 1."""
    squared = x.pow(2)
    return squared.sum(dim=1, keepdim=True)
def smoothness_loss(flow, cmask):
    """
    First-order smoothness loss of a flow field.

    The charbonnier loss of the first-order deltas of both flow components is
    weighted by ``cmask``; the border mask returned by ``smoothness_deltas`` is
    not used here.
    """
    delta_u, delta_v, _ = smoothness_deltas(flow)
    return charbonnier_loss(delta_u, cmask) + charbonnier_loss(delta_v, cmask)
def smoothness_deltas(flow):
    """
    First-order differences of both flow components.

    flow: [b, c, h, w]; it is split into exactly two single-channel components along dim 1
    return: (delta_u, delta_v, mask) where each delta is [b, 2, h, w] (channel 0: difference to the
            right neighbour, channel 1: difference to the lower neighbour) and mask is [b, 2, h, w]
            with zeros in the last column (channel 0) / last row (channel 1)
    """
    # borders where the respective neighbour does not exist
    border_masks = (create_mask(flow, [[0, 0], [0, 1]]), create_mask(flow, [[0, 1], [0, 0]]))
    mask = torch.cat(border_masks, dim=1).to(flow.device)

    # 3x3 kernels: centre minus right neighbour, centre minus lower neighbour
    kernel_x = torch.tensor([[0, 0, 0.], [0, 1, -1], [0, 0, 0]])
    kernel_y = torch.tensor([[0, 0, 0.], [0, 1, 0], [0, -1, 0]])
    weights = torch.stack((kernel_x, kernel_y)).unsqueeze(1).to(flow.device)  # [2, 1, 3, 3]

    flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)
    delta_u = F.conv2d(flow_u, weights, stride=1, padding=1)
    delta_v = F.conv2d(flow_v, weights, stride=1, padding=1)
    return delta_u, delta_v, mask
def second_order_loss(flow, cmask):
    """
    Second-order smoothness loss of a flow field.

    The charbonnier loss of the second-order deltas of both flow components is
    weighted by ``cmask``; the border mask returned by ``second_order_deltas`` is
    not used here.
    """
    delta_u, delta_v, _ = second_order_deltas(flow)
    return charbonnier_loss(delta_u, cmask) + charbonnier_loss(delta_v, cmask)
def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001):
    """
    Compute the generalized charbonnier loss of the difference tensor x
    All positions where mask == 0 are not taken into account

    x: a tensor of shape [b, c, h, w]
    mask: a mask of shape [b, mc, h, w], where mask channels must be either 1 or the same as
        the number of channels of x. Entries should be 0 or 1
    truncate: optional upper bound for the per-element error, either a plain number or a
        tensor that broadcasts against x
    alpha: exponent applied to the squared, scaled difference
    beta: scale applied to x before squaring
    epsilon: small constant whose square is added before exponentiation
    return: loss, i.e. the summed error divided by the number of elements of x
        (masked-out positions still count towards that normalisation)
    """
    b, c, h, w = x.shape
    norm = b * c * h * w
    # epsilon ** 2 is a plain python float: no per-call tensor allocation, and it adopts
    # the dtype/device of x automatically
    error = torch.pow(torch.square(x * beta) + epsilon ** 2, alpha)
    if mask is not None:
        error = mask * error
    if truncate is not None:
        if torch.is_tensor(truncate):
            error = torch.min(error, truncate)
        else:
            # torch.min(tensor, number) raises a TypeError, so clamp handles scalar limits
            error = torch.clamp(error, max=truncate)
    return torch.sum(error) / norm
def second_order_deltas(flow):
    """
    Second-order differences of both flow components.

    flow shape: [b, c, h, w]; it is split into exactly two single-channel components along dim 1
    return: (delta_u, delta_v, mask) where each delta is [b, 4, h, w] (horizontal, vertical and the
            two diagonal second differences) and mask is [b, 4, h, w] with zeros on the borders
            where the respective stencil leaves the image
    """
    # border masks: left/right columns, top/bottom rows, full frame for both diagonals
    mask_horizontal = create_mask(flow, [[0, 0], [1, 1]])
    mask_vertical = create_mask(flow, [[1, 1], [0, 0]])
    mask_diagonal = create_mask(flow, [[1, 1], [1, 1]])
    mask = torch.cat((mask_horizontal, mask_vertical, mask_diagonal, mask_diagonal), dim=1).to(flow.device)

    # 3x3 second-difference stencils (neighbour - 2 * centre + opposite neighbour)
    stencils = (
        torch.tensor([[0, 0, 0.], [1, -2, 1], [0, 0, 0]]),  # horizontal
        torch.tensor([[0, 1, 0.], [0, -2, 0], [0, 1, 0]]),  # vertical
        torch.tensor([[1, 0, 0.], [0, -2, 0], [0, 0, 1]]),  # main diagonal
        torch.tensor([[0, 0, 1.], [0, -2, 0], [1, 0, 0]]),  # anti diagonal
    )
    weights = torch.stack(stencils).unsqueeze(1).to(flow.device)  # [4, 1, 3, 3]

    # convolve each flow component separately with all four stencils
    flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)
    delta_u = F.conv2d(flow_u, weights, stride=1, padding=1)
    delta_v = F.conv2d(flow_v, weights, stride=1, padding=1)
    return delta_u, delta_v, mask
def create_mask(tensor, paddings):
    """
    Build a [b, 1, h, w] mask of ones that is zero on the requested borders.

    tensor shape: [b, c, h, w]; only its batch size, height and width are used
    paddings: [2 x 2] shape list, the first row holds the number of zeroed rows at the top and
              bottom, the second row the number of zeroed columns at the left and right
    return: float mask created on the CPU (callers move it to their device), detached
    """
    batch, _, height, width = tensor.shape[0], tensor.shape[1], tensor.shape[2], tensor.shape[3]
    top, bottom = paddings[0][0], paddings[0][1]
    left, right = paddings[1][0], paddings[1][1]

    # ones for the interior, then zero-pad back to the full size
    interior = torch.ones([height - (top + bottom), width - (left + right)])
    mask2d = F.pad(interior, pad=[left, right, top, bottom])  # F.pad order: left, right, up, down

    # replicate over the batch and insert the channel dim in one go: [h, w] -> [b, 1, h, w]
    mask4d = mask2d.repeat(batch, 1, 1, 1)
    return mask4d.detach()
def create_outgoing_mask(flow):
    """
    Computes a mask that is zero at all positions where the flow would carry a pixel over the image boundary
    For such pixels, it's invalid to calculate the flow losses

    Args:
        flow: torch tensor: with shape [b, 2, h, w]; channel 0 is the horizontal and
            channel 1 the vertical displacement in pixels

    Returns: a boolean mask, True indicates in-boundary pixels, with shape [b, 1, h, w]
    """
    b, c, h, w = flow.shape

    # Pixel coordinate grids shaped [1, 1, 1, w] and [1, 1, h, 1]. They must be 4-D:
    # adding a [b, h, w] grid to the [b, 1, h, w] flow components would broadcast to
    # [b, b, h, w] for every batch size above 1.
    grid_x = torch.arange(0, w, device=flow.device).float().view(1, 1, 1, w)
    grid_y = torch.arange(0, h, device=flow.device).float().view(1, 1, h, 1)

    flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)  # each [b, 1, h, w]

    # position every pixel ends up at after applying the flow
    pos_x = grid_x + flow_u
    pos_y = grid_y + flow_v

    inside_x = torch.logical_and(pos_x <= w - 1, pos_x >= 0)
    inside_y = torch.logical_and(pos_y <= h - 1, pos_y >= 0)
    inside = torch.logical_and(inside_x, inside_y)
    return inside