# DepthPoseEstimation / layers.py
# mkalia's picture
# Upload layers.py
# e015760 verified
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
from __future__ import absolute_import, division, print_function
import numpy as np
from scipy.spatial.transform import Rotation as R
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchmetrics.image.fid import FrechetInceptionDistance
# def silog(real1, fake1):
# # filter out invalid pixels
# real = real1.clone()
# fake = fake1.clone()
# N = (real>0).float().sum()
# mask1 = (real<=0)
# mask2 = (fake<=0)
# mask3 = mask1+mask2
# # mask = 1.0 - (mask3>0).float()
# mask = (mask3>0)
# fake[mask] = 1.
# real[mask] = 1.
# loss_ = torch.log(real)-torch.log(fake)
# loss = torch.sqrt((torch.sum( loss_ ** 2) / N ) - ((torch.sum(loss_)/N)**2))
# return loss
class SpatialTransformer(nn.Module):
    """Resample a source tensor at locations displaced by a dense flow field.

    An identity sampling grid (in pixel/voxel index units) is built once for
    the given spatial ``size`` and stored in the ``grid`` buffer, so it moves
    with the module across devices.
    """

    def __init__(self, size, mode='bilinear'):
        """
        Instantiate the block
        :param size: spatial size of the input to the spatial transformer
            block; one grid channel is created per entry
        :param mode: method of interpolation for grid_sample
        """
        super(SpatialTransformer, self).__init__()
        # Create sampling grid. torch.meshgrid defaults to 'ij' ordering, so
        # channel i of the grid holds the index along spatial axis i.
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)  # y, x, z
        grid = torch.unsqueeze(grid, 0)  # add batch
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)
        self.mode = mode

    def forward(self, src, flow):
        """
        Push the src and flow through the spatial transform block
        :param src: the source image
        :param flow: displacement added to the identity grid; its channel i
            is the offset along spatial axis i, in index units
        :return: ``src`` sampled at ``grid + flow`` (border padding)
        """
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # Need to normalize grid values to [-1, 1] for resampler.
        # Index 0 maps to -1 and index (extent - 1) maps to +1, i.e. the
        # align_corners=True convention of grid_sample.
        for i, extent in enumerate(shape):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (extent - 1) - 0.5)

        # grid_sample wants channels last, ordered (x, y[, z]) rather than
        # the (y, x[, z]) axis order used by the grid.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        # align_corners=True must match the (extent - 1) normalisation above;
        # with the library default (False) a zero flow would not reproduce
        # the source image.
        return F.grid_sample(src, new_locs, mode=self.mode,
                             padding_mode="border", align_corners=True)
class optical_flow(nn.Module):
    """Project 3D points into a camera and return their offset from the
    identity pixel grid.

    The identity grid is built from ``size`` and kept in the ``grid`` buffer.
    """

    def __init__(self, size, batch_size, height, width, eps=1e-7):
        """
        :param size: spatial size used to build the identity grid
        :param batch_size: batch dimension used when reshaping projections
        :param height: output height used when reshaping projections
        :param width: output width used when reshaping projections
        :param eps: added to the projected depth before dividing
        """
        super(optical_flow, self).__init__()
        # Identity grid, one channel per spatial axis (y, x, z order)
        axes = [torch.arange(0, extent) for extent in size]
        identity = torch.stack(torch.meshgrid(axes))
        identity = identity.unsqueeze(0).type(torch.FloatTensor)  # add batch
        self.register_buffer('grid', identity)

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        """
        :param points: points to project; presumably homogeneous coordinates
            of shape (B, 4, H*W) -- TODO confirm against caller
        :param K: intrinsics matrix, multiplied as ``K @ T``
        :param T: camera transform
        :return: projected pixel coordinates, reordered to (y, x), minus the
            identity grid
        """
        projection = torch.matmul(K, T)[:, :3, :]
        projected = torch.matmul(projection, points)

        # Perspective divide; eps keeps the denominator away from zero
        depth = projected[:, 2, :].unsqueeze(1) + self.eps
        pixels = projected[:, :2, :] / depth
        pixels = pixels.view(self.batch_size, 2, self.height, self.width)

        # Projection yields (x, y); the grid is stored as (y, x)
        return pixels[:, [1, 0], ...] - self.grid
def get_corresponding_map(data):
    """
    Splat every input coordinate onto its four neighbouring integer pixels
    with bilinear weights and accumulate the weights per pixel.

    :param data: unnormalized coordinates Bx2xHxW (channel 0 = x, 1 = y)
    :return: Bx1xHxW map of accumulated weights
    """
    B, _, H, W = data.size()

    x = data[:, 0, :, :].view(B, -1)  # BxN (N=H*W)
    y = data[:, 1, :, :].view(B, -1)

    # Unclamped neighbours: floor and floor + 1
    x_lo_raw = torch.floor(x)
    y_lo_raw = torch.floor(y)
    x_hi_raw = x_lo_raw + 1
    y_hi_raw = y_lo_raw + 1

    # Clamped neighbours are always valid scatter indices
    x_lo = x_lo_raw.clamp(0, W - 1)
    y_lo = y_lo_raw.clamp(0, H - 1)
    x_hi = x_hi_raw.clamp(0, W - 1)
    y_hi = y_hi_raw.clamp(0, H - 1)

    # A neighbour that had to be clamped was outside the image
    x_lo_out = x_lo_raw != x_lo
    y_lo_out = y_lo_raw != y_lo
    x_hi_out = x_hi_raw != x_hi
    y_hi_out = y_hi_raw != y_hi

    # (corner x, corner y, corner-is-outside) in the order
    # hi/hi, hi/lo, lo/hi, lo/lo
    corners = [
        (x_hi, y_hi, x_hi_out | y_hi_out),
        (x_hi, y_lo, x_hi_out | y_lo_out),
        (x_lo, y_hi, x_lo_out | y_hi_out),
        (x_lo, y_lo, x_lo_out | y_lo_out),
    ]

    invalid = torch.cat([outside for _, _, outside in corners], dim=1)

    # encode coordinates, since the scatter function can only index along one axis
    indices = torch.cat([cx + cy * W for cx, cy, _ in corners], 1).long()  # BxN (N=4*H*W)

    values = torch.cat(
        [(1 - torch.abs(x - cx)) * (1 - torch.abs(y - cy)) for cx, cy, _ in corners],
        1)
    # Contributions aimed at out-of-image corners are dropped
    values[invalid] = 0

    corresponding_map = torch.zeros(B, H * W).type_as(data)
    corresponding_map.scatter_add_(1, indices, values)

    # decode coordinates
    return corresponding_map.view(B, H, W).unsqueeze(1)
class get_occu_mask_backward(nn.Module):
    """Build a map from a flow field by forward-splatting the displaced grid
    with ``get_corresponding_map`` and thresholding the accumulated weights.
    """

    def __init__(self, size):
        """
        :param size: spatial size used to build the identity grid
        """
        super(get_occu_mask_backward, self).__init__()
        # Identity grid, one channel per spatial axis (y, x, z order)
        axes = [torch.arange(0, extent) for extent in size]
        identity = torch.stack(torch.meshgrid(axes)).unsqueeze(0)  # add batch
        self.register_buffer('grid', identity.type(torch.FloatTensor))

    def forward(self, flow, th=0.95):
        """
        :param flow: displacement added to the identity grid
        :param th: threshold applied to the accumulated-weight map
        :return: (mask, map) where mask is 1.0 wherever map > th
        """
        # Grid is (y, x); get_corresponding_map expects (x, y)
        targets = (self.grid + flow)[:, [1, 0], ...]
        occu_map = get_corresponding_map(targets)
        occu_mask = (occu_map > th).float()
        return occu_mask, occu_map
class get_occu_mask_bidirection(nn.Module):
    """Forward/backward flow consistency: warp ``flow21`` with ``flow12`` and
    return the absolute value of ``flow12 + warped flow21``.
    """

    def __init__(self, size, mode='bilinear'):
        """
        :param size: spatial size used to build the identity grid
        :param mode: method of interpolation for grid_sample
        """
        super(get_occu_mask_bidirection, self).__init__()
        # Create sampling grid ('ij' ordering: channel i indexes axis i)
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)  # y, x, z
        grid = torch.unsqueeze(grid, 0)  # add batch
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)
        self.mode = mode

    def forward(self, flow12, flow21, scale=0.01, bias=0.5):
        """
        :param flow12: flow from frame 1 to frame 2, added to the grid
        :param flow21: flow from frame 2 to frame 1, sampled at grid + flow12
        :param scale: currently unused (kept for interface compatibility)
        :param bias: currently unused (kept for interface compatibility)
        :return: ``|flow12 + warp(flow21)|``, same shape as ``flow12``
        """
        new_locs = self.grid + flow12
        shape = flow12.shape[2:]

        # Need to normalize grid values to [-1, 1] for resampler.
        # Index 0 -> -1 and index (extent - 1) -> +1, which is the
        # align_corners=True convention of grid_sample.
        for i, extent in enumerate(shape):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (extent - 1) - 0.5)

        # Channels last, reordered from (y, x[, z]) to (x, y[, z])
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        # align_corners=True matches the normalisation above; the library
        # default (False) would shift every sample, so a zero flow12 would
        # not return flow21 unchanged.
        flow21_warped = F.grid_sample(flow21, new_locs, mode=self.mode,
                                      padding_mode="border", align_corners=True)
        flow12_diff = torch.abs(flow12 + flow21_warped)
        # mag = (flow12 * flow12).sum(1, keepdim=True) + \
        #     (flow21_warped * flow21_warped).sum(1, keepdim=True)
        # occ_thresh = scale * mag + bias
        # occ_mask = (flow12_diff * flow12_diff).sum(1, keepdim=True) < occ_thresh
        return flow12_diff
# functions
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Build the elementary rotation matrix about one coordinate axis for every
    angle in ``angle``.

    Args:
        axis: Axis label "X", "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if ``axis`` is not one of the three labels.
    """
    c = torch.cos(angle)
    s = torch.sin(angle)
    neg_s = -s
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    # Row-major entries of the 3x3 matrix for each axis
    layouts = {
        "X": (one, zero, zero, zero, c, neg_s, zero, s, c),
        "Y": (c, zero, s, zero, one, zero, neg_s, zero, c),
        "Z": (c, neg_s, zero, s, c, zero, zero, zero, one),
    }
    if axis not in ("X", "Y", "Z"):
        raise ValueError("letter must be either X, Y or Z.")

    flat = torch.stack(layouts[axis], -1)
    return flat.reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to homogeneous 4x4
    rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Tensor of shape (N, 4, 4), where N is the product of the leading
        dimensions of ``euler_angles`` (e.g. B for a (B, 3) or (B, 1, 3)
        input). The upper-left 3x3 block is the rotation, the rest is the
        identity. The result is created with the default dtype on the
        input's device.

    Raises:
        ValueError: on a malformed ``euler_angles`` shape or ``convention``.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    matrices = [
        _axis_angle_rotation(c, e)
        for c, e in zip(convention, torch.unbind(euler_angles, -1))
    ]
    # return functools.reduce(torch.matmul, matrices)
    rotation_matrices = torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])

    # Flatten the leading dimensions explicitly instead of relying on
    # squeeze() plus broadcasting: squeeze() drops *every* singleton axis and
    # gave a (3, 4, 4) result for a single un-batched angle triple.
    flat = rotation_matrices.reshape(-1, 3, 3)
    rot = torch.zeros((flat.shape[0], 4, 4)).to(device=flat.device)
    rot[:, :3, :3] = flat
    rot[:, 3, 3] = 1
    return rot
def _angle_from_tan(
    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
    """
    Recover the first or third Euler angle from the two matrix entries that
    are a positive constant times its sine and cosine.

    Args:
        axis: Axis label "X", "Y" or "Z" for the angle we are finding.
        other_axis: Axis label "X", "Y" or "Z" for the middle axis in the
            convention.
        data: One row or column of the rotation matrices; indexed on its
            last dimension with 0, 1 or 2.
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in data as a tensor
        of shape (...).
    """
    index_pairs = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}
    first, second = index_pairs[axis]
    if horizontal:
        first, second = second, first

    is_even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    a = data[..., first]
    b = data[..., second]

    if horizontal == is_even:
        return torch.atan2(a, b)
    if tait_bryan:
        return torch.atan2(-b, a)
    return torch.atan2(b, -a)
def matrix_2_euler_vector(matrix, convention = 'ZYX', roll = True):
    """Split batched transforms into Euler angles and translation and
    concatenate the two along dimension 0.

    :param matrix: batched transforms; the code reads ``matrix[:, :3, :3]``
        as rotation and ``matrix[:, :3, 3]`` as translation
    :param convention: Euler convention passed to matrix_to_euler_angles
    :param roll: when true, entry 0 of the angle tensor is zeroed
    :return: ``cat([angles, translation], dim=0)``
    """
    # to match with scipy euler = -euler and transpose of this
    angles = matrix_to_euler_angles(matrix[:, :3, :3], convention)
    if roll:
        # NOTE(review): this indexes the first (batch) dimension, so it zeroes
        # all three angles of batch element 0 rather than one angle per
        # element -- confirm whether ``angles[..., 0]`` was intended.
        angles[0] = 0.0
    translation = matrix[:, :3, 3]
    # NOTE(review): dim=0 stacks angles and translations as extra rows
    # (2B x 3) rather than producing a B x 6 vector -- confirm with callers.
    return torch.cat([angles, translation], dim = 0)
def _index_from_letter(letter: str) -> int:
    """Map an axis label to its index: "X" -> 0, "Y" -> 1, "Z" -> 2.

    Raises:
        ValueError: for any other label.
    """
    for position, label in enumerate(("X", "Y", "Z")):
        if letter == label:
            return position
    raise ValueError("letter must be either X, Y or Z.")
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: on a malformed ``convention`` or matrix shape.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    first_axis, middle_axis, last_axis = convention[0], convention[1], convention[2]
    i0 = _index_from_letter(first_axis)
    i2 = _index_from_letter(last_axis)
    tait_bryan = i0 != i2

    # Middle angle: asin of one signed entry when the outer axes differ,
    # acos of a diagonal entry when they coincide.
    if tait_bryan:
        sign = -1.0 if i0 - i2 in [-1, 2] else 1.0
        central_angle = torch.asin(matrix[..., i0, i2] * sign)
    else:
        central_angle = torch.acos(matrix[..., i0, i0])

    # Outer angles come from a column (first) and a row (third) of the matrix
    first_angle = _angle_from_tan(
        first_axis, middle_axis, matrix[..., i2], False, tait_bryan
    )
    third_angle = _angle_from_tan(
        last_axis, middle_axis, matrix[..., i0, :], True, tait_bryan
    )
    return torch.stack((first_angle, central_angle, third_angle), -1)
def computeFID(real_images, fake_images, fid_criterion):
    """Feed one real and one fake batch to ``fid_criterion`` and return the
    value of its ``compute()``.

    :param real_images: batch passed to ``update(..., real=True)``
    :param fake_images: batch passed to ``update(..., real=False)``
    :param fid_criterion: object exposing ``update(images, real=...)`` and
        ``compute()`` (e.g. a FrechetInceptionDistance metric)
    """
    # metric = FrechetInceptionDistance(feature)
    for batch, is_real in ((real_images, True), (fake_images, False)):
        fid_criterion.update(batch, real=is_real)
    return fid_criterion.compute()
class SLlog(nn.Module):
    """Scale-invariant log loss: the standard deviation of
    ``log(real) - log(fake)`` over the evaluated pixels.
    """

    def __init__(self):
        super(SLlog, self).__init__()

    def forward(self, fake1, real1):
        """
        :param fake1: prediction; bilinearly resized to the spatial size of
            ``real1`` when the shapes differ
        :param real1: target; its last two dimensions give the target size
        :return: scalar loss tensor
        """
        if not fake1.shape == real1.shape:
            _, _, H, W = real1.shape
            # Resize the actual input. The previous code read the not yet
            # assigned local ``fake`` here (UnboundLocalError) and would have
            # discarded the result anyway.
            fake1 = F.interpolate(fake1, size=(H, W), mode='bilinear',
                                  align_corners=False)

        # filter out invalid pixels
        real = real1.clone()
        fake = fake1.clone()
        # Normaliser counts the pixels with a positive target
        N = (real > 0).float().sum()
        mask = (real <= 0) | (fake <= 0)
        # Setting both to 1 makes the log difference exactly 0 there
        fake[mask] = 1.
        real[mask] = 1.

        loss_ = torch.log(real) - torch.log(fake)
        loss = torch.sqrt((torch.sum(loss_ ** 2) / N) - ((torch.sum(loss_) / N) ** 2))
        # loss = 100.* torch.sum( torch.abs(torch.log(real)-torch.log(fake)) ) / N
        return loss
class RMSE_log(nn.Module):
    """Root-mean-square error between log(real) and log(fake) over the
    pixels selected by ``real < 1``.
    """

    def __init__(self, use_cuda):
        """
        :param use_cuda: stored on the module; not read by ``forward``
        """
        super(RMSE_log, self).__init__()
        self.eps = 1e-8
        self.use_cuda = use_cuda

    def forward(self, fake, real):
        """
        :param fake: prediction, bilinearly resized to the size of ``real``
        :param real: target of shape (n, c, h, w)
        :return: scalar loss tensor (NaN when no pixel satisfies real < 1)
        """
        # NOTE(review): the mask also keeps real <= 0, where log() is not
        # finite -- confirm the expected value range of ``real``.
        mask = real < 1.
        _, _, h, w = real.size()
        # F.upsample is deprecated; F.interpolate is its direct replacement.
        # align_corners=False is the default both functions used.
        fake = F.interpolate(fake, size=(h, w), mode='bilinear',
                             align_corners=False)
        # eps keeps log() finite for zero predictions
        fake = fake + self.eps

        N = len(real[mask])
        loss = torch.sqrt(torch.sum(torch.abs(torch.log(real[mask]) - torch.log(fake[mask])) ** 2) / N)
        return loss
def depth_to_disp(depth, min_disp=0.00001, max_disp = 1.000001):
    """Rescale a value in [0, 1] to a depth between ``1 / max_disp`` and
    ``1 / min_disp`` and return it together with its reciprocal.

    :param depth: value(s) to rescale (number or tensor)
    :param min_disp: smallest disparity; sets the largest depth
    :param max_disp: largest disparity; sets the smallest depth
    :return: (scaled_depth, disp) with ``disp = 1 / scaled_depth``
    """
    nearest = 1 / max_disp
    farthest = 1 / min_disp
    scaled_depth = nearest + (farthest - nearest) * depth
    return scaled_depth, 1 / scaled_depth
def disp_to_depth(disp, min_depth, max_depth):
    """Convert network's sigmoid output into depth prediction

    The input is rescaled linearly to a disparity between ``1 / max_depth``
    and ``1 / min_depth``; depth is the reciprocal of that disparity.

    :return: (scaled_disp, depth)
    """
    lowest = 1 / max_depth
    highest = 1 / min_depth
    scaled_disp = lowest + (highest - lowest) * disp
    return scaled_disp, 1 / scaled_disp
def disp_to_depth_no_scaling(disp):
    """Return the reciprocal of ``disp`` as depth, without range scaling.

    A small constant (1e-7) is added to the denominator so a zero disparity
    does not divide by zero.
    """
    denominator = disp + 1e-7
    return 1 / denominator
def transformation_from_parameters(axisangle, translation, invert=False):
    """Convert the network's (axisangle, translation) output into a 4x4 matrix

    With ``invert=False`` the result is ``T @ R``. With ``invert=True`` the
    rotation is transposed, the translation negated and the product order
    swapped to ``R^T @ T(-t)``.
    """
    rotation = rot_from_axisangle(axisangle)
    shift = translation.clone()  # never modify the caller's tensor

    if invert:
        rotation = rotation.transpose(1, 2)
        shift *= -1

    shift_matrix = get_translation_matrix(shift)

    if invert:
        return torch.matmul(rotation, shift_matrix)
    return torch.matmul(shift_matrix, rotation)
def transformation_from_parameters_euler(euler, translation, invert=False):
    """Convert the network's (euler angles, translation) output into a 4x4 matrix

    The rotation is built with ``euler_angles_to_matrix(euler, 'ZYX')``.
    With ``invert=False`` the result is ``T @ R``; with ``invert=True`` it is
    ``R^T @ T(-t)``.
    """
    # to match with scipy euler = -euler and transpose of this
    rotation = euler_angles_to_matrix(euler, 'ZYX')
    shift = translation.clone()  # never modify the caller's tensor

    if invert:
        rotation = rotation.transpose(1, 2)
        shift *= -1

    shift_matrix = get_translation_matrix(shift)

    if invert:
        return torch.matmul(rotation, shift_matrix)
    return torch.matmul(shift_matrix, rotation)
def get_translation_matrix(translation_vector):
    """Convert a translation vector into a 4x4 transformation matrix

    The batch size is taken from the first dimension of the input; the
    result is an identity matrix per batch element with the translation
    written into the last column.
    """
    batch = translation_vector.shape[0]
    T = torch.zeros(batch, 4, 4).to(device=translation_vector.device)

    # Ones on the diagonal
    for k in range(4):
        T[:, k, k] = 1

    # Translation goes into the first three rows of the last column
    column = translation_vector.contiguous().view(-1, 3, 1)
    T[:, :3, 3, None] = column
    return T
def rot_from_euler(vec):
    """Build a scipy ``Rotation`` from 'zyx' Euler angles given in degrees.

    :param vec: angles in degrees, in the form accepted by
        ``scipy.spatial.transform.Rotation.from_euler``
    :return: the resulting ``Rotation`` object (previously the function
        built it and then returned ``None``)
    """
    rot = R.from_euler('zyx', vec, degrees=True)
    return rot
def rot_from_axisangle(vec):
    """Convert an axisangle rotation into a 4x4 transformation matrix
    (adapted from https://github.com/Wallacoloo/printipi)
    Input 'vec' has to be Bx1x3

    The rotation angle is the norm of ``vec`` and the axis is ``vec``
    divided by that norm (plus 1e-7 to avoid dividing by zero).
    """
    angle = torch.norm(vec, 2, 2, True)
    axis = vec / (angle + 1e-7)

    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    one_minus_cos = 1 - cos_a

    ax = axis[..., 0].unsqueeze(1)
    ay = axis[..., 1].unsqueeze(1)
    az = axis[..., 2].unsqueeze(1)

    # Sine-scaled axis components
    x_sin = ax * sin_a
    y_sin = ay * sin_a
    z_sin = az * sin_a
    # (1 - cos)-scaled axis components and their cross products
    x_c = ax * one_minus_cos
    y_c = ay * one_minus_cos
    z_c = az * one_minus_cos
    xy_c = ax * y_c
    yz_c = ay * z_c
    zx_c = az * x_c

    # (row, column) -> entry of the 3x3 rotation block
    entries = (
        ((0, 0), ax * x_c + cos_a),
        ((0, 1), xy_c - z_sin),
        ((0, 2), zx_c + y_sin),
        ((1, 0), xy_c + z_sin),
        ((1, 1), ay * y_c + cos_a),
        ((1, 2), yz_c - x_sin),
        ((2, 0), zx_c - y_sin),
        ((2, 1), yz_c + x_sin),
        ((2, 2), az * z_c + cos_a),
    )

    rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
    for (row, col), value in entries:
        rot[:, row, col] = torch.squeeze(value)
    rot[:, 3, 3] = 1
    return rot
class ConvBlock(nn.Module):
    """Layer to perform a convolution followed by ELU

    The convolution is a padded 3x3 ``Conv3x3``; the ELU runs in place.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        return self.nonlin(self.conv(x))
def batchNorm(num_ch_dec):
    """Return a new ``nn.BatchNorm2d`` over ``num_ch_dec`` channels."""
    layer = nn.BatchNorm2d(num_ch_dec)
    return layer
class Conv3x3(nn.Module):
    """Layer to pad and convolve input

    One pixel of padding (reflection by default, zeros otherwise) followed
    by an unpadded 3x3 convolution, so the spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()

        padding_layer = nn.ReflectionPad2d if use_refl else nn.ZeroPad2d
        self.pad = padding_layer(1)
        # Channel counts are cast so float-valued sizes are accepted
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        return self.conv(self.pad(x))
class BackprojectDepth(nn.Module):
    """Layer to transform a depth image into a point cloud

    Pixel coordinates for a fixed (batch_size, height, width) are
    precomputed and stored as non-trainable parameters.
    """

    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        # (2, H, W) array holding the x and y index of every pixel
        xs, ys = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        coords = np.stack([xs, ys], axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(coords),
                                      requires_grad=False)

        self.ones = nn.Parameter(
            torch.ones(self.batch_size, 1, self.height * self.width),
            requires_grad=False)

        # Homogeneous pixel coordinates (x, y, 1), flattened to (B, 3, H*W)
        flat_xy = torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0)
        flat_xy = torch.unsqueeze(flat_xy, 0).repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([flat_xy, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        """
        :param depth: depth values, reshaped to (batch_size, 1, H*W)
        :param inv_K: inverse intrinsics; only its upper-left 3x3 is used
        :return: homogeneous points of shape (batch_size, 4, H*W)
        """
        rays = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        scaled = depth.view(self.batch_size, 1, -1) * rays
        return torch.cat([scaled, self.ones], 1)
class Project3D(nn.Module):
    """Layer which projects 3D points into a camera with intrinsics K and at position T

    The output is a sampling grid normalised so that pixel 0 maps to -1 and
    pixel (size - 1) maps to +1 along each axis.
    """

    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        """
        :param points: points to project; presumably homogeneous
            coordinates of shape (B, 4, H*W) -- TODO confirm against caller
        :param K: intrinsics matrix, multiplied as ``K @ T``
        :param T: camera transform
        :return: tensor of shape (batch_size, height, width, 2) holding
            normalised (x, y) coordinates
        """
        projection = torch.matmul(K, T)[:, :3, :]
        projected = torch.matmul(projection, points)

        # Perspective divide; eps keeps the denominator away from zero
        depth = projected[:, 2, :].unsqueeze(1) + self.eps
        pixels = projected[:, :2, :] / depth
        pixels = pixels.view(self.batch_size, 2, self.height, self.width)
        pixels = pixels.permute(0, 2, 3, 1)  # channels last: (..., [x, y])

        # Scale each axis to [0, 1], then shift/stretch to [-1, 1]
        x_unit = pixels[..., 0] / (self.width - 1)
        y_unit = pixels[..., 1] / (self.height - 1)
        return (torch.stack([x_unit, y_unit], dim=-1) - 0.5) * 2
def upsample(x):
    """Upsample input tensor by a factor of 2

    Uses nearest-neighbour interpolation, so each value is repeated.
    """
    doubled = F.interpolate(x, scale_factor=2, mode="nearest")
    return doubled
class deconv(nn.Module):
    """Layer to perform a single transposed convolution
    (3x3 kernel, stride 2, padding 1); no activation is applied.
    """

    def __init__(self, ch_in, ch_out):
        super(deconv, self).__init__()

        self.deconvlayer = nn.ConvTranspose2d(
            ch_in, ch_out, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.deconvlayer(x)
def get_smooth_loss_gauss_mask(disp, img, gauss_mask):
    """Computes the smoothness loss for a disparity image
    The color image is used for edge-aware smoothness

    Each disparity gradient is additionally multiplied by ``gauss_mask``
    (cropped by one column / one row to match the gradient size) before the
    plain mean is taken.
    """
    # Absolute first differences of the disparity along x and y
    disp_dx = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    disp_dy = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

    # Image gradients, averaged over the channel dimension
    img_dx = torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]).mean(1, keepdim=True)
    img_dy = torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]).mean(1, keepdim=True)

    # Down-weight disparity gradients where the image itself has an edge
    disp_dx.mul_(torch.exp(-img_dx))
    disp_dy.mul_(torch.exp(-img_dy))

    # Apply the spatial weighting mask
    disp_dx.mul_(gauss_mask[:, :, :, :-1])
    disp_dy.mul_(gauss_mask[:, :, :-1, :])

    return disp_dx.mean() + disp_dy.mean()
def get_smooth_loss(disp, img):
    """Computes the smoothness loss for a disparity image
    The color image is used for edge-aware smoothness

    Disparity gradients are weighted by ``exp(-|image gradient|)`` so that
    disparity changes at image edges are penalised less.
    """
    # Absolute first differences of the disparity along x and y
    disp_dx = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    disp_dy = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

    # Image gradients, averaged over the channel dimension
    img_dx = torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]).mean(1, keepdim=True)
    img_dy = torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]).mean(1, keepdim=True)

    disp_dx.mul_(torch.exp(-img_dx))
    disp_dy.mul_(torch.exp(-img_dy))

    return disp_dx.mean() + disp_dy.mean()
class SSIM(nn.Module):
    """Layer to compute the SSIM loss between a pair of images

    Local statistics are 3x3 averages over reflection-padded inputs; the
    returned value is ``(1 - SSIM) / 2`` clamped to [0, 1].
    """

    def __init__(self):
        super(SSIM, self).__init__()
        # One 3x3, stride-1 average pool per local statistic
        self.mu_x_pool = nn.AvgPool2d(3, 1)
        self.mu_y_pool = nn.AvgPool2d(3, 1)
        self.sig_x_pool = nn.AvgPool2d(3, 1)
        self.sig_y_pool = nn.AvgPool2d(3, 1)
        self.sig_xy_pool = nn.AvgPool2d(3, 1)

        self.refl = nn.ReflectionPad2d(1)

        # Stabilising constants of the SSIM formula
        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        # Pad so the unpadded 3x3 pooling keeps the input size
        x = self.refl(x)
        y = self.refl(y)

        # Local means
        mu_x = self.mu_x_pool(x)
        mu_y = self.mu_y_pool(y)

        # Local variances and covariance
        sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
        sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
        sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y

        numerator = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
        denominator = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)

        dissimilarity = (1 - numerator / denominator) / 2
        return torch.clamp(dissimilarity, 0, 1)
def compute_depth_errors(gt, pred):
    """Computation of error metrics between predicted and ground truth depths

    :return: tuple (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3), where
        a1..a3 are the fractions of elements whose ratio
        max(gt / pred, pred / gt) is below 1.25, 1.25**2 and 1.25**3
    """
    ratio = torch.max((gt / pred), (pred / gt))
    a1, a2, a3 = [(ratio < 1.25 ** power).float().mean() for power in (1, 2, 3)]

    diff = gt - pred
    rmse = torch.sqrt((diff ** 2).mean())

    log_diff = torch.log(gt) - torch.log(pred)
    rmse_log = torch.sqrt((log_diff ** 2).mean())

    abs_rel = torch.mean(torch.abs(diff) / gt)
    sq_rel = torch.mean(diff ** 2 / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
""" Parts of the U-Net model """
class InstanceNormDoubleConv(nn.Module):
    """(convolution => norm => ReLU) * 2

    The first normalisation is an affine InstanceNorm2d, the second is a
    BatchNorm2d.
    NOTE(review): the class name suggests InstanceNorm for both stages;
    confirm the mixed normalisation is intended before changing it, since
    that would alter the state_dict layout.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        # Fall back to out_channels when mid_channels is None (or 0)
        mid_channels = mid_channels or out_channels

        stages = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid_channels, affine = True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2

    Both convolutions are 3x3 with padding 1 and no bias, so the spatial
    size is preserved.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        # Fall back to out_channels when mid_channels is None (or 0)
        mid_channels = mid_channels or out_channels

        stages = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
class DoubleConvIN(nn.Module):
    """(convolution => [IN] => ReLU) * 2

    Both convolutions are 3x3 with padding 1 and no bias; both
    normalisations are affine InstanceNorm2d layers.

    The layers are created on the default device. Move the whole module with
    ``.to(device)`` / ``.cuda()`` as usual; the constructor no longer forces
    the norm layers onto 'cuda', which raised on CPU-only machines and left
    them on a different device than the convolutions.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid_channels, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.double_conv(x)
class Down(nn.Module):
    """Downscaling with maxpool then double conv

    A 2x2 max pool followed by a ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.maxpool_conv(x)
class DownIN(nn.Module):
    """Downscaling with maxpool then double conv

    A 2x2 max pool followed by a ``DoubleConvIN``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.MaxPool2d(2),
            DoubleConvIN(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscaling then double conv

    ``x1`` is upsampled, padded to the spatial size of ``x2``, concatenated
    after ``x2`` along the channel dimension and passed through DoubleConv.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            # if bilinear, use the normal convolutions to reduce the number of channels
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # Learned upsampling that halves the channel count
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # input is CHW
        gap_y = x2.size()[2] - upsampled.size()[2]
        gap_x = x2.size()[3] - upsampled.size()[3]

        # Split each gap between the two sides; the extra pixel (odd gap)
        # goes to the right / bottom.
        left, top = gap_x // 2, gap_y // 2
        upsampled = F.pad(upsampled, [left, gap_x - left, top, gap_y - top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected
class UpIN(nn.Module):
    """Upscaling then double conv

    ``x1`` is upsampled, padded to the spatial size of ``x2``, concatenated
    after ``x2`` along the channel dimension and passed through DoubleConvIN.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            # if bilinear, use the normal convolutions to reduce the number of channels
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConvIN(in_channels, out_channels, in_channels // 2)
        else:
            # Learned upsampling that halves the channel count
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConvIN(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # input is CHW
        gap_y = x2.size()[2] - upsampled.size()[2]
        gap_x = x2.size()[3] - upsampled.size()[3]

        # Split each gap between the two sides; the extra pixel (odd gap)
        # goes to the right / bottom.
        left, top = gap_x // 2, gap_y // 2
        upsampled = F.pad(upsampled, [left, gap_x - left, top, gap_y - top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
# def gaussian_fn(M, std):
# n = torch.arange(0, M) - (M - 1.0) / 2.0
# sig2 = 2 * std * std
# w = torch.exp(-n ** 2 / sig2)
# return w
# def gkern(kernlen=256, std=128):
# """Returns a 2D Gaussian kernel array."""
# gkern1d = gaussian_fn(kernlen, std=std)
# gkern2d = torch.outer(gkern1d, gkern1d)
# return gkern2d
# A = np.random.rand(256*256).reshape([256,256])
# A = torch.from_numpy(A)
# guassian_filter = gkern(256, std=32)
# class GaussianLayer(nn.Module):
# def __init__(self):
# super(GaussianLayer, self).__init__()
# self.seq = nn.Sequential(
# nn.ReflectionPad2d(10),
# nn.Conv2d(3, 3, 21, stride=1, padding=0, bias=None, groups=3)
# )
# self.weights_init()
# def forward(self, x):
# return self.seq(x)
# def weights_init(self):
# n= np.zeros((21,21))
# n[10,10] = 1
# k = scipy.ndimage.gaussian_filter(n,sigma=3)
# for name, f in self.named_parameters():
# f.data.copy_(torch.from_numpy(k))