|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
def disp_to_depth(disp, min_depth, max_depth):
    """Map the network's sigmoid output to a scaled disparity and a depth.

    The mapping is linear in disparity: ``disp == 0`` becomes ``1 / max_depth``
    and ``disp == 1`` becomes ``1 / min_depth``. Depth is the reciprocal of that
    scaled disparity. See the 'additional considerations' section of the paper.

    Args:
        disp: sigmoid output (tensor or number), nominally in [0, 1].
        min_depth: smallest representable depth; must be non-zero.
        max_depth: largest representable depth; must be non-zero.

    Returns:
        Tuple ``(scaled_disp, depth)``.
    """
    lowest_disp, highest_disp = 1 / max_depth, 1 / min_depth
    scaled_disp = lowest_disp + (highest_disp - lowest_disp) * disp
    return scaled_disp, 1 / scaled_disp
|
|
|
|
|
def transformation_from_parameters(axisangle, translation, invert=False):
    """Turn the network's (axisangle, translation) output into 4x4 matrices.

    Args:
        axisangle: rotation in axis-angle form, passed to ``rot_from_axisangle``
            (whose docstring requires shape Bx1x3).
        translation: translation components, passed to ``get_translation_matrix``.
        invert: when False, return ``T(t) @ R``; when True, return
            ``R^T @ T(-t)``. For an orthonormal R, the latter is the inverse of
            the former.

    Returns:
        Batched 4x4 transformation matrices.
    """
    rotation = rot_from_axisangle(axisangle)

    if not invert:
        # Rotate first, then translate.
        return torch.matmul(get_translation_matrix(translation.clone()), rotation)

    # Undo the transform: translate by -t first, then apply the transposed rotation.
    shift = get_translation_matrix(-translation)
    return torch.matmul(rotation.transpose(1, 2), shift)
|
|
|
|
|
def get_translation_matrix(translation_vector):
    """Convert a batch of translation vectors into 4x4 homogeneous matrices.

    Args:
        translation_vector: tensor whose first dimension is the batch size B and
            which holds 3 translation components per batch item, e.g. shape
            (B, 3) or (B, 1, 3).

    Returns:
        Tensor of shape (B, 4, 4): the identity with the translation written into
        the first three rows of the last column. It is created on the input's
        device with torch's default floating dtype (not the input's dtype). The
        translation values are cast on assignment.
    """
    batch_size = translation_vector.shape[0]

    # Allocate directly on the target device. Building the matrix on the host and
    # calling .to(device) costs an extra allocation plus a copy on every call.
    T = torch.eye(4, device=translation_vector.device).repeat(batch_size, 1, 1)

    T[:, :3, 3, None] = translation_vector.contiguous().view(-1, 3, 1)
    return T
|
|
|
|
|
def rot_from_axisangle(vec):
    """Convert an axis-angle rotation into a 4x4 homogeneous rotation matrix.

    (adapted from https://github.com/Wallacoloo/printipi)

    Args:
        vec: tensor of shape (B, 1, 3). Its direction is the rotation axis and
            its L2 norm is the rotation angle in radians.

    Returns:
        Tensor of shape (B, 4, 4) on ``vec``'s device, with torch's default
        floating dtype. The top-left 3x3 block is the rotation (Rodrigues'
        formula), ``[3, 3]`` is 1, and the rest is zero.
    """
    angle = torch.norm(vec, 2, 2, True)
    # The epsilon avoids 0/0 for a zero vector. The axis is then all zeros,
    # sin = 0 and 1 - cos = 0, so the result is the identity.
    axis = vec / (angle + 1e-7)

    ca = torch.cos(angle)
    sa = torch.sin(angle)
    C = 1 - ca

    x = axis[..., 0].unsqueeze(1)
    y = axis[..., 1].unsqueeze(1)
    z = axis[..., 2].unsqueeze(1)

    xs = x * sa
    ys = y * sa
    zs = z * sa
    xC = x * C
    yC = y * C
    zC = z * C
    xyC = x * yC
    yzC = y * zC
    zxC = z * xC

    # Allocate directly on vec's device instead of on the host followed by a
    # .to(device) copy on every call.
    rot = torch.zeros((vec.shape[0], 4, 4), device=vec.device)

    # R = cos*I + sin*[axis]_x + (1 - cos) * axis axis^T
    rot[:, 0, 0] = torch.squeeze(x * xC + ca)
    rot[:, 0, 1] = torch.squeeze(xyC - zs)
    rot[:, 0, 2] = torch.squeeze(zxC + ys)
    rot[:, 1, 0] = torch.squeeze(xyC + zs)
    rot[:, 1, 1] = torch.squeeze(y * yC + ca)
    rot[:, 1, 2] = torch.squeeze(yzC - xs)
    rot[:, 2, 0] = torch.squeeze(zxC - ys)
    rot[:, 2, 1] = torch.squeeze(yzC + xs)
    rot[:, 2, 2] = torch.squeeze(z * zC + ca)
    rot[:, 3, 3] = 1

    return rot
|
|
|
|
|
class ConvBlock(nn.Module):
    """A ``Conv3x3`` layer followed by an in-place ELU activation."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        # Attribute names are kept as-is: they define the state_dict keys.
        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        return self.nonlin(self.conv(x))
|
|
|
|
|
class Conv3x3(nn.Module):
    """Pad the input by one pixel on each side, then apply a 3x3 ``nn.Conv2d``.

    With the one-pixel padding and a 3x3 kernel at Conv2d's default stride, the
    spatial size of the input is preserved.

    Args:
        in_channels: number of input channels (converted with ``int``).
        out_channels: number of output channels (converted with ``int``).
        use_refl: reflection padding when True, zero padding otherwise.
    """

    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()

        padding_layer = nn.ReflectionPad2d if use_refl else nn.ZeroPad2d
        self.pad = padding_layer(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        return self.conv(self.pad(x))
|
|
|
|
|
class BackprojectDepth(nn.Module):
    """Lift a depth map to homogeneous camera-space points.

    Args:
        batch_size: batch size the pixel grid is tiled to (fixed at construction).
        height, width: spatial size of the depth map.
        shift_rays_half_pixel: constant added to both the x and y pixel
            coordinates (not to the homogeneous row of ones).

    Buffers:
        pix_coords: (batch_size, 3, height*width) rows of x, y (shifted) and 1.
        id_coords: (2, height, width) float32 grid of x and y pixel indices.
        ones: (batch_size, 1, height*width) ones used for the homogeneous row.
    """

    def __init__(self, batch_size, height, width, shift_rays_half_pixel=0):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        # 'xy' indexing: grid_x varies along columns, grid_y along rows; both (H, W).
        grid_x, grid_y = np.meshgrid(range(width), range(height), indexing='xy')
        id_coords = torch.from_numpy(np.stack([grid_x, grid_y], axis=0).astype(np.float32))

        ones = torch.ones(batch_size, 1, height * width)

        # Flatten the grid to (1, 2, H*W) and tile it across the batch.
        flat_xy = id_coords.reshape(2, -1).unsqueeze(0).repeat(batch_size, 1, 1)
        pix_coords = torch.cat((flat_xy + shift_rays_half_pixel, ones), dim=1)

        self.register_buffer("pix_coords", pix_coords)
        self.register_buffer("id_coords", id_coords)
        self.register_buffer("ones", ones)

    def forward(self, depth, inv_K):
        """Back-project ``depth`` using the top-left 3x3 block of ``inv_K``.

        ``depth`` must hold batch_size * height * width elements (presumably
        shaped (B, 1, H, W) — confirm with callers). The result has shape
        (batch_size, 4, height*width): depth-scaled rays plus a row of ones.
        """
        target = depth.device
        rays = inv_K[:, :3, :3] @ self.pix_coords.to(target)
        points = depth.view(self.batch_size, 1, -1) * rays
        return torch.cat((points, self.ones.to(target)), dim=1)
|
|
|
|
|
class Project3D(nn.Module):
    """Project homogeneous 3D points through intrinsics K and optional pose T.

    The output is a (batch_size, height, width, 2) grid of pixel coordinates
    rescaled so that pixel index 0 maps to -1 and index width-1 / height-1
    maps to +1. This is presumably intended for ``F.grid_sample`` with
    ``align_corners=True``; confirm at the call site.
    """

    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T=None):
        projection = K if T is None else torch.matmul(K, T)
        # Only the first three rows are needed to produce (u*z, v*z, z).
        projected = torch.matmul(projection[:, :3, :], points)

        # Perspective divide; eps guards against a zero depth.
        z = projected[:, 2, :].unsqueeze(1) + self.eps
        uv = (projected[:, :2, :] / z).view(self.batch_size, 2, self.height, self.width)
        uv = uv.permute(0, 2, 3, 1)

        # Rescale pixel indices from [0, size - 1] to [-1, 1].
        u = uv[..., 0] / (self.width - 1)
        v = uv[..., 1] / (self.height - 1)
        return (torch.stack((u, v), dim=-1) - 0.5) * 2
|
|
|
|
|
class Project3DSimple(nn.Module):
    """Project homogeneous 3D points through intrinsics K into pixel coordinates.

    The (batch_size, height, width, 2) result is left in pixel units; no
    normalization to [-1, 1] is applied.
    """

    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3DSimple, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K):
        projected = torch.matmul(K[:, :3, :], points)
        # Perspective divide; eps guards against a zero depth.
        depth = projected[:, 2, :].unsqueeze(1) + self.eps
        pixels = projected[:, :2, :] / depth
        grid = pixels.view(self.batch_size, 2, self.height, self.width)
        return grid.permute(0, 2, 3, 1)
|
|
|
def upsample(x):
    """Double the spatial size of ``x`` using nearest-neighbour interpolation."""
    doubled = F.interpolate(x, scale_factor=2, mode="nearest")
    return doubled
|
|
|
|
|
def get_smooth_loss(disp, img):
    """Edge-aware smoothness loss for a disparity map.

    Absolute first differences of ``disp`` along width and height are
    down-weighted by ``exp(-|image gradient|)``. The image gradient is averaged
    over dim 1 (channels, assuming NCHW). The loss is the sum of the two
    weighted means.
    """
    def horizontal_step(t):
        return torch.abs(t[:, :, :, :-1] - t[:, :, :, 1:])

    def vertical_step(t):
        return torch.abs(t[:, :, :-1, :] - t[:, :, 1:, :])

    edge_weight_x = torch.exp(-torch.mean(horizontal_step(img), 1, keepdim=True))
    edge_weight_y = torch.exp(-torch.mean(vertical_step(img), 1, keepdim=True))

    smooth_x = horizontal_step(disp)
    smooth_y = vertical_step(disp)
    # In-place weighting keeps the disparity-gradient shape as the result shape.
    smooth_x *= edge_weight_x
    smooth_y *= edge_weight_y

    return smooth_x.mean() + smooth_y.mean()
|
|
|
|
|
class SSIM(nn.Module):
    """Per-pixel SSIM dissimilarity, ``clamp((1 - SSIM) / 2, 0, 1)``.

    Inputs are reflection-padded by one pixel, and the local statistics use
    3x3 average pools with stride 1, so the output has the input's spatial
    size. C1 and C2 are the usual SSIM stabilizing constants
    (0.01**2 and 0.03**2).
    """

    def __init__(self):
        super(SSIM, self).__init__()
        self.mu_x_pool = nn.AvgPool2d(3, 1)
        self.mu_y_pool = nn.AvgPool2d(3, 1)
        self.sig_x_pool = nn.AvgPool2d(3, 1)
        self.sig_y_pool = nn.AvgPool2d(3, 1)
        self.sig_xy_pool = nn.AvgPool2d(3, 1)

        self.refl = nn.ReflectionPad2d(1)

        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        x_pad = self.refl(x)
        y_pad = self.refl(y)

        mean_x = self.mu_x_pool(x_pad)
        mean_y = self.mu_y_pool(y_pad)

        # Local variances and covariance: E[a*b] - E[a]*E[b].
        var_x = self.sig_x_pool(x_pad ** 2) - mean_x ** 2
        var_y = self.sig_y_pool(y_pad ** 2) - mean_y ** 2
        cov_xy = self.sig_xy_pool(x_pad * y_pad) - mean_x * mean_y

        numerator = (2 * mean_x * mean_y + self.C1) * (2 * cov_xy + self.C2)
        denominator = (mean_x ** 2 + mean_y ** 2 + self.C1) * (var_x + var_y + self.C2)

        dissimilarity = (1 - numerator / denominator) / 2
        return torch.clamp(dissimilarity, 0, 1)
|
|
|
|
|
def compute_depth_errors(gt, pred):
    """Standard depth error metrics between ground-truth and predicted depths.

    Returns, in order:
        abs_rel:  mean(|gt - pred| / gt)
        sq_rel:   mean((gt - pred)**2 / gt)
        rmse:     sqrt(mean((gt - pred)**2))
        rmse_log: sqrt(mean((ln gt - ln pred)**2))
        a1, a2, a3: fraction of elements where max(gt/pred, pred/gt) is below
            1.25, 1.25**2 and 1.25**3 respectively.

    Both inputs are presumably strictly positive, already-masked depths (a zero
    would produce inf/nan); this is not checked here.
    """
    ratio = torch.max((gt / pred), (pred / gt))
    a1, a2, a3 = ((ratio < 1.25 ** power).float().mean() for power in (1, 2, 3))

    residual = gt - pred
    squared = residual ** 2
    rmse = torch.sqrt(squared.mean())

    log_residual = torch.log(gt) - torch.log(pred)
    rmse_log = torch.sqrt((log_residual ** 2).mean())

    abs_rel = torch.mean(torch.abs(residual) / gt)
    sq_rel = torch.mean(squared / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
|
|