# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

from __future__ import absolute_import, division, print_function

import numpy as np
from scipy.spatial.transform import Rotation as R

import torch
import torch.nn as nn
import torch.nn.functional as F

# from torchmetrics.image.fid import FrechetInceptionDistance

class SpatialTransformer(nn.Module):
    def __init__(self, size, mode='bilinear'):
        """
        Instantiate the block
        :param size: size of input to the spatial transformer block
        :param mode: method of interpolation for grid_sampler
        """
        super(SpatialTransformer, self).__init__()

        # Create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing='ij')
        grid = torch.stack(grids)  # y, x, z
        grid = torch.unsqueeze(grid, 0)  # add batch
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)

        self.mode = mode

    def forward(self, src, flow):
        """
        Push the src and flow through the spatial transform block
        :param src: the source image
        :param flow: the output from the U-Net
        """
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # Need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return F.grid_sample(src, new_locs, mode=self.mode, padding_mode="border")
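
# Usage sketch for SpatialTransformer (shapes illustrative, tensors hypothetical):
#   stn  = SpatialTransformer(size=(192, 640))
#   src  = torch.rand(1, 3, 192, 640)    # image to warp
#   flow = torch.zeros(1, 2, 192, 640)   # per-pixel displacement; zero flow = identity warp
#   warped = stn(src, flow)              # (1, 3, 192, 640)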

class optical_flow(nn.Module):
    def __init__(self, size, batch_size, height, width, eps=1e-7):
        super(optical_flow, self).__init__()

        # Create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing='ij')
        grid = torch.stack(grids)  # y, x, z
        grid = torch.unsqueeze(grid, 0)  # add batch
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        P = torch.matmul(K, T)[:, :3, :]

        cam_points = torch.matmul(P, points)

        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
        optical_flow = pix_coords[:, [1, 0], ...] - self.grid

        return optical_flow
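
# Usage sketch for optical_flow: given homogeneous camera-space points (e.g. from
# BackprojectDepth below), intrinsics K and a relative pose T, it returns the rigid
# flow field (projected pixel minus the identity grid) in (y, x) channel order.
# Names here are hypothetical:
#   flow_layer = optical_flow((192, 640), batch_size=1, height=192, width=640)
#   rigid_flow = flow_layer(cam_points, K, T)   # (1, 2, 192, 640)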

def get_corresponding_map(data):
    """
    :param data: unnormalized coordinates Bx2xHxW
    :return: Bx1xHxW
    """
    B, _, H, W = data.size()

    # x = data[:, 0, :, :].view(B, -1).clamp(0, W - 1)  # BxN (N=H*W)
    # y = data[:, 1, :, :].view(B, -1).clamp(0, H - 1)
    x = data[:, 0, :, :].view(B, -1)  # BxN (N=H*W)
    y = data[:, 1, :, :].view(B, -1)

    # invalid = (x < 0) | (x > W - 1) | (y < 0) | (y > H - 1)  # BxN
    # invalid = invalid.repeat([1, 4])

    x1 = torch.floor(x)
    x_floor = x1.clamp(0, W - 1)
    y1 = torch.floor(y)
    y_floor = y1.clamp(0, H - 1)
    x0 = x1 + 1
    x_ceil = x0.clamp(0, W - 1)
    y0 = y1 + 1
    y_ceil = y0.clamp(0, H - 1)

    x_ceil_out = x0 != x_ceil
    y_ceil_out = y0 != y_ceil
    x_floor_out = x1 != x_floor
    y_floor_out = y1 != y_floor
    invalid = torch.cat([x_ceil_out | y_ceil_out,
                         x_ceil_out | y_floor_out,
                         x_floor_out | y_ceil_out,
                         x_floor_out | y_floor_out], dim=1)

    # encode coordinates, since the scatter function can only index along one axis
    corresponding_map = torch.zeros(B, H * W).type_as(data)
    indices = torch.cat([x_ceil + y_ceil * W,
                         x_ceil + y_floor * W,
                         x_floor + y_ceil * W,
                         x_floor + y_floor * W], 1).long()  # BxN (N=4*H*W)
    values = torch.cat([(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_floor)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_floor))],
                       1)
    # values = torch.ones_like(values)
    values[invalid] = 0

    corresponding_map.scatter_add_(1, indices, values)

    # decode coordinates
    corresponding_map = corresponding_map.view(B, H, W)

    return corresponding_map.unsqueeze(1)
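
# Intuition: get_corresponding_map splats a bilinear weight of 1 from each source
# pixel onto the four integer neighbours of its warped location, via scatter_add_
# on flattened (x + y * W) indices. Target pixels that receive (almost) no mass
# have no source pixel landing on them. A hypothetical sanity check:
#   coords = torch.stack(torch.meshgrid(torch.arange(4.), torch.arange(5.),
#                                       indexing='ij'), 0)[[1, 0]].unsqueeze(0)
#   get_corresponding_map(coords)   # identity warp -> all ones, shape (1, 1, 4, 5)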

class get_occu_mask_backward(nn.Module):
    def __init__(self, size):
        super(get_occu_mask_backward, self).__init__()

        # Create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing='ij')
        grid = torch.stack(grids)  # y, x, z
        grid = torch.unsqueeze(grid, 0)  # add batch
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)

    def forward(self, flow, th=0.95):
        new_locs = self.grid + flow
        new_locs = new_locs[:, [1, 0], ...]
        corr_map = get_corresponding_map(new_locs)
        occu_map = corr_map
        occu_mask = (occu_map > th).float()
        return occu_mask, occu_map
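
# Usage sketch (names hypothetical): given the backward flow, occu_mask is 1
# wherever the splatted correspondence density exceeds the threshold, i.e. at
# pixels that do receive a source correspondence:
#   occ = get_occu_mask_backward(size=(192, 640))
#   occu_mask, occu_map = occ(flow21, th=0.95)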

class get_occu_mask_bidirection(nn.Module):
    def __init__(self, size, mode='bilinear'):
        super(get_occu_mask_bidirection, self).__init__()

        # Create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing='ij')
        grid = torch.stack(grids)  # y, x, z
        grid = torch.unsqueeze(grid, 0)  # add batch
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)

        self.mode = mode

    def forward(self, flow12, flow21, scale=0.01, bias=0.5):
        new_locs = self.grid + flow12
        shape = flow12.shape[2:]

        # Need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        flow21_warped = F.grid_sample(flow21, new_locs, mode=self.mode, padding_mode="border")
        flow12_diff = torch.abs(flow12 + flow21_warped)
        # mag = (flow12 * flow12).sum(1, keepdim=True) + \
        #       (flow21_warped * flow21_warped).sum(1, keepdim=True)
        # occ_thresh = scale * mag + bias
        # occ_mask = (flow12_diff * flow12_diff).sum(1, keepdim=True) < occ_thresh
        return flow12_diff
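
# Usage sketch: forward-backward consistency. The forward flow plus the warped
# backward flow should cancel; a large |flow12 + warp(flow21)| indicates occlusion
# or bad flow. The commented lines above show the usual magnitude-relative
# threshold; as written, the raw difference is returned. Hypothetical names:
#   fb = get_occu_mask_bidirection(size=(192, 640))
#   flow_diff = fb(flow12, flow21)   # (B, 2, 192, 640)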

# functions
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    Args:
        axis: Axis label "X", "Y", or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))

def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Homogeneous transformation matrices as tensor of shape (B, 4, 4),
        with the rotation in the top-left 3x3 block.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    matrices = [
        _axis_angle_rotation(c, e)
        for c, e in zip(convention, torch.unbind(euler_angles, -1))
    ]
    # return functools.reduce(torch.matmul, matrices)
    rotation_matrices = torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
    rot = torch.zeros((rotation_matrices.shape[0], 4, 4)).to(device=rotation_matrices.device)
    rot[:, :3, :3] = rotation_matrices.squeeze()
    rot[:, 3, 3] = 1
    return rot
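
# Example (angles in radians): a pure 90-degree rotation about Z in the 'ZYX'
# convention used elsewhere in this file:
#   euler = torch.tensor([[np.pi / 2, 0.0, 0.0]])
#   T = euler_angles_to_matrix(euler, 'ZYX')   # (1, 4, 4) homogeneous matrix
#   T[0, :3, :3] is approximately [[0, -1, 0], [1, 0, 0], [0, 0, 1]]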

def _angle_from_tan(
    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
    """
    Extract the first or third Euler angle from the two members of
    the matrix which are positive constant times its sine and cosine.

    Args:
        axis: Axis label "X", "Y", or "Z" for the angle we are finding.
        other_axis: Axis label "X", "Y", or "Z" for the middle axis in the
            convention.
        data: Rotation matrices as tensor of shape (..., 3, 3).
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in data as a tensor
        of shape (...).
    """
    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        i2, i1 = i1, i2
    even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    if horizontal == even:
        return torch.atan2(data[..., i1], data[..., i2])
    if tait_bryan:
        return torch.atan2(-data[..., i2], data[..., i1])
    return torch.atan2(data[..., i2], -data[..., i1])

def matrix_2_euler_vector(matrix, convention='ZYX', roll=True):
    # matrix = matrix_in.copy()
    euler = matrix_to_euler_angles(matrix[:, :3, :3], convention)  # to match with scipy: euler = -euler and transpose of this
    if roll:
        euler[:, 0] = 0.0  # zero the first ('Z') angle for every batch element
    t = matrix[:, :3, 3]
    out = torch.cat([euler, t], dim=1)  # (B, 6) pose vector
    return out

def _index_from_letter(letter: str) -> int:
    if letter == "X":
        return 0
    if letter == "Y":
        return 1
    if letter == "Z":
        return 2
    raise ValueError("letter must be either X, Y or Z.")


def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    tait_bryan = i0 != i2
    if tait_bryan:
        central_angle = torch.asin(
            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
        )
    else:
        central_angle = torch.acos(matrix[..., i0, i0])

    o = (
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return torch.stack(o, -1)
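
# Round-trip sanity check (a sketch; tolerance illustrative). Note that
# euler_angles_to_matrix above returns a 4x4 matrix, so pass its rotation block:
#   angles = torch.tensor([[0.3, -0.2, 0.1]])
#   T = euler_angles_to_matrix(angles, 'ZYX')
#   recovered = matrix_to_euler_angles(T[:, :3, :3], 'ZYX')
#   torch.allclose(recovered, angles, atol=1e-6)   # expected True away from gimbal lock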

def computeFID(real_images, fake_images, fid_criterion):
    # metric = FrechetInceptionDistance(feature)
    fid_criterion.update(real_images, real=True)
    fid_criterion.update(fake_images, real=False)
    return fid_criterion.compute()

class SLlog(nn.Module):
    def __init__(self):
        super(SLlog, self).__init__()

    def forward(self, fake1, real1):
        if fake1.shape != real1.shape:
            _, _, H, W = real1.shape
            fake1 = F.interpolate(fake1, size=(H, W), mode='bilinear', align_corners=False)
        # filter out invalid pixels
        real = real1.clone()
        fake = fake1.clone()
        N = (real > 0).float().sum()
        mask = (real <= 0) | (fake <= 0)
        fake[mask] = 1.
        real[mask] = 1.
        loss_ = torch.log(real) - torch.log(fake)
        loss = torch.sqrt((torch.sum(loss_ ** 2) / N) - ((torch.sum(loss_) / N) ** 2))
        # loss = 100. * torch.sum(torch.abs(torch.log(real) - torch.log(fake))) / N
        return loss
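
# SLlog is a scale-invariant log (SILog-style) RMSE:
#   sqrt( mean(d^2) - mean(d)^2 ),  d = log(real) - log(fake),
# computed over valid pixels (invalid ones are set to 1 so log() contributes 0).
# Hedged usage with hypothetical tensors:
#   criterion = SLlog()
#   loss = criterion(pred_depth, gt_depth)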

class RMSE_log(nn.Module):
    def __init__(self, use_cuda):
        super(RMSE_log, self).__init__()
        self.eps = 1e-8
        self.use_cuda = use_cuda

    def forward(self, fake, real):
        mask = real < 1.
        n, _, h, w = real.size()
        fake = F.interpolate(fake, size=(h, w), mode='bilinear', align_corners=False)
        fake += self.eps
        N = len(real[mask])
        loss = torch.sqrt(torch.sum(torch.abs(torch.log(real[mask]) - torch.log(fake[mask])) ** 2) / N)
        return loss

def depth_to_disp(depth, min_disp=0.00001, max_disp=1.000001):
    """Convert network's sigmoid depth output into a scaled depth and the
    corresponding disparity (the inverse mapping of disp_to_depth).
    """
    min_depth = 1 / max_disp
    max_depth = 1 / min_disp
    scaled_depth = min_depth + (max_depth - min_depth) * depth
    disp = 1 / scaled_depth
    return scaled_depth, disp

def disp_to_depth(disp, min_depth, max_depth):
    """Convert network's sigmoid output into depth prediction
    The formula for this conversion is given in the 'additional considerations'
    section of the paper.
    """
    min_disp = 1 / max_depth
    max_disp = 1 / min_depth
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    depth = 1 / scaled_disp
    return scaled_disp, depth
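
# Worked example with the usual monodepth2 range (min_depth=0.1, max_depth=100):
#   disp = 0  ->  scaled_disp = 1/100       ->  depth = 100 (farthest)
#   disp = 1  ->  scaled_disp = 1/0.1 = 10  ->  depth = 0.1 (nearest)
#   scaled_disp, depth = disp_to_depth(torch.tensor([0.5]), 0.1, 100.0)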

def disp_to_depth_no_scaling(disp):
    """Convert disparity directly into depth as its reciprocal, with no
    min/max rescaling (the epsilon guards against division by zero).
    """
    depth = 1 / (disp + 1e-7)
    return depth

def transformation_from_parameters(axisangle, translation, invert=False):
    """Convert the network's (axisangle, translation) output into a 4x4 matrix
    """
    R = rot_from_axisangle(axisangle)
    t = translation.clone()

    if invert:
        R = R.transpose(1, 2)
        t *= -1

    T = get_translation_matrix(t)

    if invert:
        M = torch.matmul(R, T)
    else:
        M = torch.matmul(T, R)

    return M
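
# Usage sketch (pose-network outputs are assumed to be (B, 1, 3) each):
#   axisangle   = torch.zeros(1, 1, 3)
#   translation = torch.tensor([[[0.1, 0.0, 0.0]]])
#   T = transformation_from_parameters(axisangle, translation)
#   T_inv = transformation_from_parameters(axisangle, translation, invert=True)
#   torch.allclose(torch.matmul(T, T_inv), torch.eye(4).unsqueeze(0))  # expected True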

def transformation_from_parameters_euler(euler, translation, invert=False):
    """Convert the network's (euler, translation) output into a 4x4 matrix
    """
    # R = torch.transpose(euler_angles_to_matrix(euler, 'ZYX'), 0, 1).permute(1, 0, 2)  # to match with scipy: euler = -euler and transpose of this
    R = euler_angles_to_matrix(euler, 'ZYX')  # to match with scipy: euler = -euler and transpose of this
    t = translation.clone()

    if invert:
        R = R.transpose(1, 2)
        t *= -1

    T = get_translation_matrix(t)

    if invert:
        M = torch.matmul(R, T)
    else:
        M = torch.matmul(T, R)

    return M

def get_translation_matrix(translation_vector):
    """Convert a translation vector into a 4x4 transformation matrix
    """
    T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)

    t = translation_vector.contiguous().view(-1, 3, 1)

    T[:, 0, 0] = 1
    T[:, 1, 1] = 1
    T[:, 2, 2] = 1
    T[:, 3, 3] = 1
    T[:, :3, 3, None] = t

    return T


def rot_from_euler(vec):
    rot = R.from_euler('zyx', vec, degrees=True)
    return rot

def rot_from_axisangle(vec):
    """Convert an axisangle rotation into a 4x4 transformation matrix
    (adapted from https://github.com/Wallacoloo/printipi)
    Input 'vec' has to be Bx1x3
    """
    angle = torch.norm(vec, 2, 2, True)
    axis = vec / (angle + 1e-7)

    ca = torch.cos(angle)
    sa = torch.sin(angle)
    C = 1 - ca

    x = axis[..., 0].unsqueeze(1)
    y = axis[..., 1].unsqueeze(1)
    z = axis[..., 2].unsqueeze(1)

    xs = x * sa
    ys = y * sa
    zs = z * sa
    xC = x * C
    yC = y * C
    zC = z * C
    xyC = x * yC
    yzC = y * zC
    zxC = z * xC

    rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)

    rot[:, 0, 0] = torch.squeeze(x * xC + ca)
    rot[:, 0, 1] = torch.squeeze(xyC - zs)
    rot[:, 0, 2] = torch.squeeze(zxC + ys)
    rot[:, 1, 0] = torch.squeeze(xyC + zs)
    rot[:, 1, 1] = torch.squeeze(y * yC + ca)
    rot[:, 1, 2] = torch.squeeze(yzC - xs)
    rot[:, 2, 0] = torch.squeeze(zxC - ys)
    rot[:, 2, 1] = torch.squeeze(yzC + xs)
    rot[:, 2, 2] = torch.squeeze(z * zC + ca)
    rot[:, 3, 3] = 1

    return rot
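
# rot_from_axisangle implements the Rodrigues formula,
#   R = I + sin(a) * [k]_x + (1 - cos(a)) * [k]_x^2,
# where a = |vec| and k = vec / a. Quick check with a hypothetical input:
#   vec = torch.tensor([[[0.0, 0.0, np.pi / 2]]])   # 90 degrees about Z
#   rot_from_axisangle(vec)[0, :2, :2] is approximately [[0, -1], [1, 0]]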

class ConvBlock(nn.Module):
    """Layer to perform a convolution followed by ELU
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.nonlin(out)
        return out


def batchNorm(num_ch_dec):
    return nn.BatchNorm2d(num_ch_dec)


class Conv3x3(nn.Module):
    """Layer to pad and convolve input
    """
    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()

        if use_refl:
            self.pad = nn.ReflectionPad2d(1)
        else:
            self.pad = nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        out = self.pad(x)
        out = self.conv(out)
        return out

class BackprojectDepth(nn.Module):
    """Layer to transform a depth image into a point cloud
    """
    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points
        cam_points = torch.cat([cam_points, self.ones], 1)

        return cam_points
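
# Usage sketch for BackprojectDepth (tensors hypothetical; inv_K is the 4x4
# inverse-intrinsics matrix used throughout monodepth2-style pipelines):
#   backproject = BackprojectDepth(batch_size=1, height=192, width=640)
#   cam_points = backproject(depth, inv_K)   # (1, 4, 192*640) homogeneous points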

class Project3D(nn.Module):
    """Layer which projects 3D points into a camera with intrinsics K and at position T
    """
    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        P = torch.matmul(K, T)[:, :3, :]

        cam_points = torch.matmul(P, points)

        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
        pix_coords = pix_coords.permute(0, 2, 3, 1)
        pix_coords[..., 0] /= self.width - 1
        pix_coords[..., 1] /= self.height - 1
        pix_coords = (pix_coords - 0.5) * 2
        return pix_coords
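
# The standard monodepth2 view-synthesis pipeline chains the two layers above;
# a hedged sketch with hypothetical tensors:
#   project = Project3D(batch_size=1, height=192, width=640)
#   pix_coords = project(cam_points, K, T)   # (1, 192, 640, 2) in [-1, 1]
#   warped = F.grid_sample(src_img, pix_coords, padding_mode="border")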

def upsample(x):
    """Upsample input tensor by a factor of 2
    """
    return F.interpolate(x, scale_factor=2, mode="nearest")


class deconv(nn.Module):
    """Layer to perform a stride-2 transposed convolution (learned upsampling)
    """
    def __init__(self, ch_in, ch_out):
        super(deconv, self).__init__()
        self.deconvlayer = nn.ConvTranspose2d(ch_in, ch_out, 3, stride=2, padding=1)

    def forward(self, x):
        out = self.deconvlayer(x)
        return out

def get_smooth_loss_gauss_mask(disp, img, gauss_mask):
    """Computes the smoothness loss for a disparity image
    The color image is used for edge-aware smoothness
    """
    grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

    # weighted mean
    # grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]) * gauss_mask[:, :, :, :-1], 1, keepdim=True)
    # grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]) * gauss_mask[:, :, :-1, :], 1, keepdim=True)
    grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
    grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)

    grad_disp_x *= torch.exp(-grad_img_x)
    grad_disp_y *= torch.exp(-grad_img_y)

    # take weighted mean
    grad_disp_x *= gauss_mask[:, :, :, :-1]
    grad_disp_y *= gauss_mask[:, :, :-1, :]

    return grad_disp_x.mean() + grad_disp_y.mean()

def get_smooth_loss(disp, img):
    """Computes the smoothness loss for a disparity image
    The color image is used for edge-aware smoothness
    """
    grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

    grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
    grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)

    grad_disp_x *= torch.exp(-grad_img_x)
    grad_disp_y *= torch.exp(-grad_img_y)

    return grad_disp_x.mean() + grad_disp_y.mean()
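
# Usage sketch: disp is typically the mean-normalized disparity so the penalty
# is scale-invariant (as in monodepth2):
#   norm_disp = disp / (disp.mean(2, True).mean(3, True) + 1e-7)
#   smooth_loss = get_smooth_loss(norm_disp, color_image)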

class SSIM(nn.Module):
    """Layer to compute the SSIM loss between a pair of images
    """
    def __init__(self):
        super(SSIM, self).__init__()
        self.mu_x_pool = nn.AvgPool2d(3, 1)
        self.mu_y_pool = nn.AvgPool2d(3, 1)
        self.sig_x_pool = nn.AvgPool2d(3, 1)
        self.sig_y_pool = nn.AvgPool2d(3, 1)
        self.sig_xy_pool = nn.AvgPool2d(3, 1)

        self.refl = nn.ReflectionPad2d(1)

        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        x = self.refl(x)
        y = self.refl(y)

        mu_x = self.mu_x_pool(x)
        mu_y = self.mu_y_pool(y)

        sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
        sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
        sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y

        SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
        SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)

        return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
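
# Usage sketch: SSIM is usually combined with L1 into the photometric error
# (monodepth2 weighting shown; tensors hypothetical):
#   ssim = SSIM()
#   l1 = torch.abs(target - pred).mean(1, True)
#   photometric = 0.85 * ssim(pred, target).mean(1, True) + 0.15 * l1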

def compute_depth_errors(gt, pred):
    """Computation of error metrics between predicted and ground truth depths
    """
    thresh = torch.max((gt / pred), (pred / gt))
    a1 = (thresh < 1.25).float().mean()
    a2 = (thresh < 1.25 ** 2).float().mean()
    a3 = (thresh < 1.25 ** 3).float().mean()

    rmse = (gt - pred) ** 2
    rmse = torch.sqrt(rmse.mean())

    rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
    rmse_log = torch.sqrt(rmse_log.mean())

    abs_rel = torch.mean(torch.abs(gt - pred) / gt)

    sq_rel = torch.mean((gt - pred) ** 2 / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
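
# Expected usage: gt and pred hold matched valid (positive) depths, masked and
# median-scaled beforehand where appropriate, e.g. (a hedged sketch):
#   mask = gt > 0
#   pred_m = pred[mask] * torch.median(gt[mask]) / torch.median(pred[mask])
#   errors = compute_depth_errors(gt[mask], pred_m)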
""" Parts of the U-Net model """ | |
class InstanceNormDoubleConv(nn.Module): | |
"""(convolution => [BN] => ReLU) * 2""" | |
def __init__(self, in_channels, out_channels, mid_channels=None): | |
super().__init__() | |
if not mid_channels: | |
mid_channels = out_channels | |
self.double_conv = nn.Sequential( | |
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), | |
nn.InstanceNorm2d(mid_channels, affine = True), | |
nn.ReLU(inplace=True), | |
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), | |
nn.BatchNorm2d(out_channels), | |
nn.ReLU(inplace=True) | |
) | |
def forward(self, x): | |
return self.double_conv(x) | |

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class DoubleConvIN(nn.Module):
    """(convolution => [IN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid_channels, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class DownIN(nn.Module):
    """Downscaling with maxpool then double conv (instance norm variant)"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConvIN(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
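
# Minimal U-Net assembly from the parts above (a sketch; channel widths are
# illustrative, not taken from any model in this repository):
#   inc   = DoubleConv(3, 64)
#   down1 = Down(64, 128)
#   up1   = Up(128 + 64, 64, bilinear=True)   # in_channels counts the skip concat
#   outc  = OutConv(64, 1)
#   x1 = inc(img); x2 = down1(x1)
#   logits = outc(up1(x2, x1))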

class UpIN(nn.Module):
    """Upscaling then double conv (instance norm variant)"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConvIN(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConvIN(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

# def gaussian_fn(M, std):
#     n = torch.arange(0, M) - (M - 1.0) / 2.0
#     sig2 = 2 * std * std
#     w = torch.exp(-n ** 2 / sig2)
#     return w

# def gkern(kernlen=256, std=128):
#     """Returns a 2D Gaussian kernel array."""
#     gkern1d = gaussian_fn(kernlen, std=std)
#     gkern2d = torch.outer(gkern1d, gkern1d)
#     return gkern2d

# A = np.random.rand(256*256).reshape([256, 256])
# A = torch.from_numpy(A)
# guassian_filter = gkern(256, std=32)

# class GaussianLayer(nn.Module):
#     def __init__(self):
#         super(GaussianLayer, self).__init__()
#         self.seq = nn.Sequential(
#             nn.ReflectionPad2d(10),
#             nn.Conv2d(3, 3, 21, stride=1, padding=0, bias=None, groups=3)
#         )
#         self.weights_init()
#     def forward(self, x):
#         return self.seq(x)
#     def weights_init(self):
#         n = np.zeros((21, 21))
#         n[10, 10] = 1
#         k = scipy.ndimage.gaussian_filter(n, sigma=3)
#         for name, f in self.named_parameters():
#             f.data.copy_(torch.from_numpy(k))