Spaces:
Running
on
T4
Running
on
T4
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as models | |
import numpy as np | |
from .fbConsistencyCheck import image_warp | |
class FlowWarpingLoss(nn.Module):
    """
    Masked warping loss: warp ``x`` with ``flow`` and compare the result to ``y`` using ``metric``.

    Args:
        metric: callable ``metric(prediction, target)`` returning the loss, e.g. ``nn.L1Loss()``
    """

    def __init__(self, metric):
        super(FlowWarpingLoss, self).__init__()
        self.metric = metric

    def warp(self, x, flow):
        """
        Backward-warp ``x`` with a dense pixel-displacement field using bilinear sampling.

        Args:
            x: torch tensor with shape [b, c, h, w]; c can be 3 (for rgb frame) or 2 (for optical flow)
            flow: torch tensor with shape [b, 2, h, w]; channel 0 is the displacement along the width
                axis, channel 1 along the height axis, both in pixels
        Returns: the warped x (an image or an optical flow) with shape [b, c, h, w]; samples that
            land outside x are filled with zeros
        """
        h, w = x.shape[2:]
        # grid_sample works in normalized [-1, 1] coordinates where -1 / 1 are the centres of the
        # first / last pixel (align_corners=True), so one pixel spans 2 / (size - 1). max(..., 1)
        # avoids dividing by zero for 1-pixel-wide/high inputs, where every sample lands on that
        # single pixel anyway.
        flow = torch.cat([flow[:, 0:1, :, :] / (max(w - 1, 1) / 2),
                          flow[:, 1:2, :, :] / (max(h - 1, 1) / 2)], dim=1)
        flow = flow.permute(0, 2, 3, 1)  # [b, h, w, 2], the layout grid_sample expects
        # Identity sampling grid [1, h, w, 2] in (x, y) order, built directly on x's device and dtype.
        xs = torch.linspace(-1.0, 1.0, w, device=x.device, dtype=x.dtype)
        ys = torch.linspace(-1.0, 1.0, h, device=x.device, dtype=x.dtype)
        grid = torch.stack((xs.view(1, 1, w).expand(1, h, w),
                            ys.view(1, h, 1).expand(1, h, w)), dim=3)
        # align_corners=True matches the linspace(-1, 1, size) grid above; the PyTorch default (False)
        # would shift every sample by a fraction of a pixel, so zero flow would not be an identity warp.
        return torch.nn.functional.grid_sample(x, grid + flow, mode='bilinear', padding_mode='zeros',
                                               align_corners=True)

    def forward(self, x, y, flow, mask):
        """
        Image/flow warping loss; only supports single image/flow warping.

        Args:
            x: optical flow or image with shape [b, c, h, w], c can be 2 or 3
            y: the ground truth of x (the extracted optical flow or image)
            flow: the flow used to warp x, with shape [b, 2, h, w]
            mask: the mask that indicates the hole of x, with shape [b, 1, h, w]
        Returns: ``metric(warp(x, flow) * mask, y * mask)``
        """
        warped_x = self.warp(x, flow)
        return self.metric(warped_x * mask, y * mask)
class TVLoss(object):
    """
    Anisotropic total-variation loss.

    Returns the mean absolute difference between horizontally adjacent pixels plus the mean
    absolute difference between vertically adjacent pixels of a [b, c, h, w] tensor.
    """

    def __init__(self):
        super(TVLoss, self).__init__()

    def __call__(self, x):
        # neighbour differences along the width axis (dim 3) and the height axis (dim 2)
        horizontal = (x[:, :, :, :-1] - x[:, :, :, 1:]).abs().mean()
        vertical = (x[:, :, :-1, :] - x[:, :, 1:, :]).abs().mean()
        return horizontal + vertical
class WarpLoss(nn.Module):
    """
    Masked L1 photometric loss: warp ``img2`` back onto ``img1`` with ``flow`` and compare the two
    inside ``mask``.
    """

    def __init__(self):
        super(WarpLoss, self).__init__()
        self.metric = nn.L1Loss()

    def forward(self, flow, mask, img1, img2):
        """
        Args:
            flow: flow indicating the motion from img1 to img2
            mask: mask corresponding to img1
            img1: frame t
            img2: frame t+1
        Returns: L1 loss between image_warp(img2, flow) and img1, both multiplied by mask
        """
        warped = image_warp(img2, flow)
        return self.metric(warped * mask, img1 * mask)
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337

    Supported types:
        nsgan: binary cross-entropy against constant real/fake labels (outputs must be probabilities)
        lsgan: mean squared error against constant real/fake labels
        hinge: discriminator uses mean(relu(1 - D(real))) / mean(relu(1 + D(fake))),
               generator uses -mean(D(fake))
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        Args:
            type: nsgan | lsgan | hinge
            target_real_label: label value used for real samples (nsgan / lsgan)
            target_fake_label: label value used for fake samples (nsgan / lsgan)
        Raises:
            ValueError: if ``type`` is not one of the supported names
        """
        super(AdversarialLoss, self).__init__()
        criteria = {'nsgan': nn.BCELoss, 'lsgan': nn.MSELoss, 'hinge': nn.ReLU}
        if type not in criteria:
            # fail fast instead of leaving ``criterion`` unset until the first call
            raise ValueError("unsupported adversarial loss type %r, expected one of %s"
                             % (type, sorted(criteria)))
        self.type = type  # note: shadows nn.Module.type(); kept for backward compatibility
        # buffers, so the labels follow .to(device) / .cuda() together with the module
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.criterion = criteria[type]()

    def forward(self, outputs, is_real, is_disc=None):
        """
        Args:
            outputs: discriminator outputs
            is_real: whether ``outputs`` should be judged as real
            is_disc: hinge only; truthy for the discriminator loss, falsy for the generator loss
        Returns: scalar loss tensor
        """
        if self.type == 'hinge':
            if is_disc:
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            return (-outputs).mean()
        labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
        return self.criterion(outputs, labels)
class StyleLoss(nn.Module):
    r"""
    Style loss, VGG-based: L1 distance between Gram matrices of VGG19 feature maps.
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    # VGG19 activations whose Gram matrices are compared, in summation order.
    _LAYERS = ('relu2_2', 'relu3_4', 'relu4_4', 'relu5_2')

    def __init__(self):
        super(StyleLoss, self).__init__()
        self.add_module('vgg', VGG19())
        self.criterion = torch.nn.L1Loss()

    def compute_gram(self, x):
        """
        Gram matrix of a [b, ch, h, w] feature map, normalized by h * w * ch.

        Returns: tensor with shape [b, ch, ch]
        """
        b, ch, h, w = x.size()
        # reshape (not view) so non-contiguous feature maps are accepted too
        f = x.reshape(b, ch, h * w)
        return f.bmm(f.transpose(1, 2)) / (h * w * ch)

    def forward(self, x, y):
        """
        Args:
            x, y: image batches accepted by the VGG19 feature extractor
        Returns: sum over ``_LAYERS`` of the L1 distance between the Gram matrices of x and y
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        style_loss = 0.0
        for layer in self._LAYERS:
            style_loss += self.criterion(self.compute_gram(x_vgg[layer]), self.compute_gram(y_vgg[layer]))
        return style_loss
class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based: weighted L1 distance between VGG19 activations.
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    # VGG19 activations compared by the loss; weights[i] scales _LAYERS[i].
    _LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 1.0)):
        """
        Args:
            weights: sequence of (at least) five per-layer weights, one per entry of ``_LAYERS``.
                An immutable default is used and the value is copied, so instances never share a list.
        """
        super(PerceptualLoss, self).__init__()
        self.add_module('vgg', VGG19())
        self.criterion = torch.nn.L1Loss()
        self.weights = list(weights)

    def forward(self, x, y):
        """
        Args:
            x, y: image batches accepted by the VGG19 feature extractor
        Returns: sum_i weights[i] * L1(x_vgg[layer_i], y_vgg[layer_i])
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        for i, layer in enumerate(self._LAYERS):
            content_loss += self.weights[i] * self.criterion(x_vgg[layer], y_vgg[layer])
        return content_loss
class VGG19(torch.nn.Module):
    """
    Frozen VGG19 feature extractor (torchvision ``vgg19(pretrained=True).features``).

    The feature stack is cut into sixteen consecutive stages; ``forward`` runs them in order and
    returns a dict mapping each stage name ('relu1_1' ... 'relu5_4') to that stage's output.
    """

    # (stage name, first index, stop index) into ``vgg19().features``, in execution order.
    _SLICES = (
        ('relu1_1', 0, 2),
        ('relu1_2', 2, 4),
        ('relu2_1', 4, 7),
        ('relu2_2', 7, 9),
        ('relu3_1', 9, 12),
        ('relu3_2', 12, 14),
        ('relu3_3', 14, 16),
        ('relu3_4', 16, 18),
        ('relu4_1', 18, 21),
        ('relu4_2', 21, 23),
        ('relu4_3', 23, 25),
        ('relu4_4', 25, 27),
        ('relu5_1', 27, 30),
        ('relu5_2', 30, 32),
        ('relu5_3', 32, 34),
        ('relu5_4', 34, 36),
    )

    def __init__(self):
        super(VGG19, self).__init__()
        features = models.vgg19(pretrained=True).features
        for name, start, stop in self._SLICES:
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                # the original layer index is kept as the child name, so state_dict keys are stable
                stage.add_module(str(idx), features[idx])
            setattr(self, name, stage)
        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        out = {}
        for name, _, _ in self._SLICES:
            # each stage consumes the previous stage's output
            x = getattr(self, name)(x)
            out[name] = x
        return out
# Some losses related to optical flows | |
# From Unflow: https://github.com/simonmeister/UnFlow | |
def fbLoss(forward_flow, backward_flow, forward_gt_flow, backward_gt_flow, fb_loss_weight, image_warp_loss_weight=0,
           occ_weight=0, beta=255, first_image=None, second_image=None):
    """
    Calculate the forward-backward consistency loss and, optionally, the related image warp loss
    (adapted from UnFlow's loss code).

    Args:
        forward_flow: torch tensor, with shape [b, c, h, w]
        backward_flow: torch tensor, with shape [b, c, h, w]
        forward_gt_flow: the ground truth of the forward flow (used for occlusion calculation)
        backward_gt_flow: the ground truth of the backward flow (used for occlusion calculation)
        fb_loss_weight: loss weight for forward-backward consistency check between two frames
        image_warp_loss_weight: loss weight for image warping; when it is 0 (default) the image
            warp and occlusion terms are skipped and first_image / second_image are not used
        occ_weight: loss weight for the occlusion area (serves as a punishment for the image warp
            loss); this term is added to the image warp loss WITHOUT being scaled by
            image_warp_loss_weight, and only when image_warp_loss_weight != 0
        beta: 255 by default, according to the original loss codes in unflow; scale applied to the
            image differences inside charbonnier_loss (image warp term only)
        first_image: the previous image (extraction for the optical flows)
        second_image: the later image (extraction for the optical flows)

    Note: forward and backward flow should be extracted from the same image pair.
        The predicted flows are warped with the *ground-truth* flow of the opposite direction, and
        the occlusion masks are computed from the ground-truth flows only; the predicted flows enter
        through the outgoing masks, the residuals and the image warping.
        NOTE(review): assumes create_outgoing_mask returns a [b, 1, h, w] mask; the in-place
        ``mask_fw *= ...`` below requires the product to broadcast to mask_fw's own shape.

    Returns: fb_loss_weight * (forward + backward consistency terms) + image warp loss
        (the latter is 0 when image_warp_loss_weight == 0)
    """
    # 1 where the predicted flow keeps the pixel inside the image, 0 where it leaves the frame
    mask_fw = create_outgoing_mask(forward_flow).float()
    mask_bw = create_outgoing_mask(backward_flow).float()
    # forward warp backward flow and backward forward flow to calculate the cycle consistency
    forward_flow_warped = image_warp(forward_flow, backward_gt_flow)
    forward_flow_warped_gt = image_warp(forward_gt_flow, backward_gt_flow)
    backward_flow_warped = image_warp(backward_flow, forward_gt_flow)
    backward_flow_warped_gt = image_warp(backward_gt_flow, forward_gt_flow)
    # for consistent flows, forward + warped backward (and vice versa) should be ~0
    flow_diff_fw = backward_flow_warped + forward_flow
    flow_diff_fw_gt = backward_flow_warped_gt + forward_gt_flow
    flow_diff_bw = backward_flow + forward_flow_warped
    flow_diff_bw_gt = backward_gt_flow + forward_flow_warped_gt
    # occlusion calculation: a pixel is occluded when the squared gt forward-backward residual
    # exceeds 0.01 * (sum of squared gt flow magnitudes) + 0.5
    mag_sq_fw = length_sq(forward_gt_flow) + length_sq(backward_flow_warped_gt)
    mag_sq_bw = length_sq(backward_gt_flow) + length_sq(forward_flow_warped_gt)
    occ_thresh_fw = 0.01 * mag_sq_fw + 0.5
    occ_thresh_bw = 0.01 * mag_sq_bw + 0.5
    fb_occ_fw = (length_sq(flow_diff_fw_gt) > occ_thresh_fw).float()
    fb_occ_bw = (length_sq(flow_diff_bw_gt) > occ_thresh_bw).float()
    # valid = inside the frame AND not occluded; occ_* is its complement
    mask_fw *= (1 - fb_occ_fw)
    mask_bw *= (1 - fb_occ_bw)
    occ_fw = 1 - mask_fw
    occ_bw = 1 - mask_bw
    if image_warp_loss_weight != 0:
        # warp images
        second_image_warped = image_warp(second_image, forward_flow)  # frame 2 -> 1
        first_image_warped = image_warp(first_image, backward_flow)  # frame 1 -> 2
        im_diff_fw = first_image - second_image_warped
        im_diff_bw = second_image - first_image_warped
        # calculate the image warp loss based on the occlusion regions calculated by forward and backward flows (gt)
        occ_loss = occ_weight * (charbonnier_loss(occ_fw) + charbonnier_loss(occ_bw))
        image_warp_loss = image_warp_loss_weight * (
                charbonnier_loss(im_diff_fw, mask_fw, beta=beta) + charbonnier_loss(im_diff_bw, mask_bw,
                                                                                     beta=beta)) + occ_loss
    else:
        image_warp_loss = 0
    fb_loss = fb_loss_weight * (charbonnier_loss(flow_diff_fw, mask_fw) + charbonnier_loss(flow_diff_bw, mask_bw))
    return fb_loss + image_warp_loss
def length_sq(x):
    """Squared L2 norm of ``x`` over dim 1, with that dim kept (size 1)."""
    return torch.square(x).sum(dim=1, keepdim=True)
def smoothness_loss(flow, cmask):
    """
    First-order smoothness penalty of a flow field.

    Args:
        flow: flow tensor with shape [b, 2, h, w]
        cmask: mask passed to charbonnier_loss for both components
    Returns: charbonnier_loss(delta_u, cmask) + charbonnier_loss(delta_v, cmask)
    """
    # the border mask returned by smoothness_deltas is not used here; cmask weights the terms instead
    delta_u, delta_v, _ = smoothness_deltas(flow)
    return charbonnier_loss(delta_u, cmask) + charbonnier_loss(delta_v, cmask)
def smoothness_deltas(flow):
    """
    First-order finite differences of each flow component.

    Args:
        flow: torch tensor with shape [b, 2, h, w] (dim 1 holds the u and v components)
    Returns:
        delta_u, delta_v: [b, 2, h, w] tensors; for a component f, channel 0 is f[i, j] - f[i, j + 1]
            and channel 1 is f[i, j] - f[i + 1, j], with zero padding beyond the border
        mask: [b, 2, h, w] mask that is 0 on the last column (channel 0) and the last row
            (channel 1), i.e. where the difference reaches into the zero padding
    """
    mask_x = create_mask(flow, [[0, 0], [0, 1]])
    mask_y = create_mask(flow, [[0, 1], [0, 0]])
    mask = torch.cat((mask_x, mask_y), dim=1).to(flow.device)
    filter_x = torch.tensor([[0, 0, 0.], [0, 1, -1], [0, 0, 0]])
    filter_y = torch.tensor([[0, 0, 0.], [0, 1, 0], [0, -1, 0]])
    # [out_channels=2, in_channels=1, 3, 3]; cast to the flow's dtype so conv2d also accepts
    # half / double flows instead of raising a dtype mismatch
    weights = torch.stack((filter_x, filter_y)).unsqueeze(1).to(device=flow.device, dtype=flow.dtype)
    flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)
    delta_u = F.conv2d(flow_u, weights, stride=1, padding=1)
    delta_v = F.conv2d(flow_v, weights, stride=1, padding=1)
    return delta_u, delta_v, mask
def second_order_loss(flow, cmask):
    """
    Second-order smoothness penalty of a flow field.

    Args:
        flow: flow tensor with shape [b, 2, h, w]
        cmask: mask passed to charbonnier_loss for both components
    Returns: charbonnier_loss(delta_u, cmask) + charbonnier_loss(delta_v, cmask)
    """
    # the border mask returned by second_order_deltas is not used here; cmask weights the terms instead
    delta_u, delta_v, _ = second_order_deltas(flow)
    return charbonnier_loss(delta_u, cmask) + charbonnier_loss(delta_v, cmask)
def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001):
    """
    Compute the generalized charbonnier loss of the difference tensor x:
        sum(mask * ((beta * x)^2 + epsilon^2)^alpha) / x.numel()
    All positions where mask == 0 are not taken into account (but are still counted in the
    normalization, which always uses the number of elements of x).

    Args:
        x: difference tensor, typically of shape [b, c, h, w] (any shape is accepted)
        mask: optional mask broadcastable against x, e.g. [b, mc, h, w] where mc is 1 or the number
            of channels of x. Entries should be 0 or 1
        truncate: optional upper bound applied element-wise to the (masked) error; a number or tensor
        alpha, beta, epsilon: exponent, input scale and smoothing constant of the penalty
    Returns: scalar loss tensor
    """
    norm = x.numel()
    error = torch.pow(torch.square(x * beta) + torch.square(torch.tensor(epsilon)), alpha)
    if mask is not None:
        error = mask * error
    if truncate is not None:
        # torch.min(tensor, other) needs a tensor for ``other``; accept plain numbers too
        error = torch.min(error, torch.as_tensor(truncate, dtype=error.dtype, device=error.device))
    return torch.sum(error) / norm
def second_order_deltas(flow):
    """
    Second-order finite differences of each flow component (consider the single flow first).

    Args:
        flow: torch tensor with shape [b, 2, h, w] (dim 1 holds the u and v components)
    Returns:
        delta_u, delta_v: [b, 4, h, w] tensors; for a component f the channels are
            0: f[i, j-1] - 2 f[i, j] + f[i, j+1]        (horizontal)
            1: f[i-1, j] - 2 f[i, j] + f[i+1, j]        (vertical)
            2: f[i-1, j-1] - 2 f[i, j] + f[i+1, j+1]    (main diagonal)
            3: f[i-1, j+1] - 2 f[i, j] + f[i+1, j-1]    (anti-diagonal)
            with zero padding beyond the border
        mask: [b, 4, h, w] mask that is 0 where the stencil reaches into the padding (first/last
            column for channel 0, first/last row for channel 1, the whole border for 2 and 3)
    """
    # create mask
    mask_x = create_mask(flow, [[0, 0], [1, 1]])
    mask_y = create_mask(flow, [[1, 1], [0, 0]])
    mask_diag = create_mask(flow, [[1, 1], [1, 1]])
    mask = torch.cat((mask_x, mask_y, mask_diag, mask_diag), dim=1).to(flow.device)
    filter_x = torch.tensor([[0, 0, 0.], [1, -2, 1], [0, 0, 0]])
    filter_y = torch.tensor([[0, 1, 0.], [0, -2, 0], [0, 1, 0]])
    filter_diag1 = torch.tensor([[1, 0, 0.], [0, -2, 0], [0, 0, 1]])
    filter_diag2 = torch.tensor([[0, 0, 1.], [0, -2, 0], [1, 0, 0]])
    # [out_channels=4, in_channels=1, 3, 3]; cast to the flow's dtype so conv2d also accepts
    # half / double flows instead of raising a dtype mismatch
    weights = torch.stack((filter_x, filter_y, filter_diag1, filter_diag2)).unsqueeze(1)
    weights = weights.to(device=flow.device, dtype=flow.dtype)
    # split the flow into flow_u and flow_v, conv them with the weights
    flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)
    delta_u = F.conv2d(flow_u, weights, stride=1, padding=1)
    delta_v = F.conv2d(flow_v, weights, stride=1, padding=1)
    return delta_u, delta_v, mask
def create_mask(tensor, paddings):
    """
    Build a [b, 1, h, w] mask of ones surrounded by a border of zeros.

    Args:
        tensor: reference tensor with shape [b, c, h, w]; only its shape is used
        paddings: [2 x 2] list; the first row holds the up and down paddings, the second row the
            left and right paddings (number of zero rows / columns on each side)
    Returns: detached tensor of shape [b, 1, h, w] (default dtype and device of torch.ones)

        |       |
        |   x   |
        | x * x |
        |   x   |
        |       |
    """
    batch, height, width = tensor.shape[0], tensor.shape[2], tensor.shape[3]
    top, bottom = paddings[0][0], paddings[0][1]
    left, right = paddings[1][0], paddings[1][1]
    inner = torch.ones([height - (top + bottom), width - (left + right)])
    # F.pad orders the amounts as (left, right, up, down) for the last two dims
    mask2d = F.pad(inner, pad=[left, right, top, bottom])
    return mask2d[None, None].repeat(batch, 1, 1, 1).detach()
def create_outgoing_mask(flow):
    """
    Computes a mask that is zero at all positions where the flow would carry a pixel over the image boundary.
    For such pixels, it's invalid to calculate the flow losses.

    Args:
        flow: torch tensor with shape [b, 2, h, w]; channel 0 is the x (width) displacement,
            channel 1 the y (height) displacement, in pixels
    Returns: bool mask with shape [b, 1, h, w]; True indicates in-boundary pixels
    """
    _, _, h, w = flow.shape
    # Pixel-coordinate grids shaped to broadcast against the [b, 1, h, w] flow components.
    # (A [b, h, w] grid would broadcast with [b, 1, h, w] to [b, b, h, w].)
    grid_x = torch.arange(0, w, device=flow.device).float().view(1, 1, 1, w)
    grid_y = torch.arange(0, h, device=flow.device).float().view(1, 1, h, 1)
    flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)  # each [b, 1, h, w]
    pos_x = grid_x + flow_u
    pos_y = grid_y + flow_v
    inside_x = torch.logical_and(pos_x <= w - 1, pos_x >= 0)
    inside_y = torch.logical_and(pos_y <= h - 1, pos_y >= 0)
    return torch.logical_and(inside_x, inside_y)