# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

from __future__ import absolute_import, division, print_function

import numpy as np
from scipy.spatial.transform import Rotation as R

import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchmetrics.image.fid import FrechetInceptionDistance

# def silog(real1, fake1):
#     # filter out invalid pixels
#     real = real1.clone()
#     fake = fake1.clone()
#     N = (real>0).float().sum()
#     mask1 = (real<=0)
#     mask2 = (fake<=0)
#     mask3 = mask1+mask2
#     # mask = 1.0 - (mask3>0).float()
#     mask = (mask3>0)
#     fake[mask] = 1.
#     real[mask] = 1.
#     loss_ = torch.log(real)-torch.log(fake)
#     loss = torch.sqrt((torch.sum( loss_ ** 2) / N ) - ((torch.sum(loss_)/N)**2))
#     return loss


def _make_identity_grid(size):
    """Build the float identity coordinate grid shared by the warping modules.

    :param size: iterable of spatial extents, e.g. (H, W)
    :return: float tensor of shape 1 x len(size) x *size whose channel ``i``
        holds the index along spatial axis ``i`` (row index first)
    """
    vectors = [torch.arange(0, s) for s in size]
    try:
        # "ij" is the ordering the legacy call used implicitly; passing it
        # explicitly avoids the deprecation warning on recent torch.
        grids = torch.meshgrid(*vectors, indexing="ij")
    except TypeError:
        # Older torch releases have no `indexing` keyword.
        grids = torch.meshgrid(*vectors)
    grid = torch.stack(grids)  # y, x, z
    grid = torch.unsqueeze(grid, 0)  # add batch
    return grid.float()


def _normalized_sampling_grid(grid, flow):
    """Add ``flow`` to ``grid`` and convert the result to grid_sample layout.

    Each coordinate channel is rescaled so that 0 maps to -1 and ``size - 1``
    maps to +1, the channel axis is moved last, and the channel order is
    reversed (grid_sample expects x first).

    :param grid: identity grid as produced by ``_make_identity_grid``
    :param flow: displacement field, B x ndim x *spatial
    :return: sampling locations, B x *spatial x ndim
    """
    new_locs = grid + flow  # fresh tensor, so the in-place writes below are safe
    shape = flow.shape[2:]

    # Need to normalize grid values to [-1, 1] for resampler
    for i, extent in enumerate(shape):
        new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (extent - 1) - 0.5)

    if len(shape) == 2:
        new_locs = new_locs.permute(0, 2, 3, 1)
        new_locs = new_locs[..., [1, 0]]
    elif len(shape) == 3:
        new_locs = new_locs.permute(0, 2, 3, 4, 1)
        new_locs = new_locs[..., [2, 1, 0]]
    return new_locs


class SpatialTransformer(nn.Module):
    """Warp a source tensor with a dense displacement field."""

    def __init__(self, size, mode='bilinear'):
        """
        Instantiate the block
            :param size: size of input to the spatial transformer block
            :param mode: method of interpolation for grid_sampler
        """
        super(SpatialTransformer, self).__init__()

        # Create sampling grid
        self.register_buffer('grid', _make_identity_grid(size))
        self.mode = mode

    def forward(self, src, flow):
        """
        Push the src and flow through the spatial transform block
            :param src: the source image
            :param flow: the output from the U-Net
            :return: ``src`` resampled at ``grid + flow`` with border padding
        """
        new_locs = _normalized_sampling_grid(self.grid, flow)
        # NOTE(review): the [-1, 1] normalisation above maps pixel centres 0
        # and size-1 to the extremes, which matches align_corners=True, but
        # grid_sample is called with its default. Left unchanged to keep the
        # numerical behaviour of existing models - confirm intent.
        return F.grid_sample(src, new_locs, mode=self.mode, padding_mode="border")


class optical_flow(nn.Module):
    """Turn projected 3D points into a displacement field relative to the identity grid."""

    def __init__(self, size, batch_size, height, width, eps=1e-7):
        """
        :param size: spatial size used to build the identity grid
        :param batch_size: batch size used when reshaping the projection
        :param height: image height used when reshaping the projection
        :param width: image width used when reshaping the projection
        :param eps: added to the depth coordinate before the perspective divide
        """
        super(optical_flow, self).__init__()

        # Create sampling grid
        self.register_buffer('grid', _make_identity_grid(size))
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        """
        :param points: homogeneous points, multiplied on the left by a Bx3x4 matrix
        :param K: intrinsics, multiplied with T (presumably Bx4x4 - confirm)
        :param T: camera transform (presumably Bx4x4 - confirm)
        :return: Bx2xHxW field ordered (row, col) to match the identity grid
        """
        P = torch.matmul(K, T)[:, :3, :]
        cam_points = torch.matmul(P, points)

        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
        # Projection yields (x, y); the grid stores (y, x), so swap before subtracting.
        flow = pix_coords[:, [1, 0], ...] - self.grid
        return flow


def get_corresponding_map(data):
    """Splat every coordinate onto its four integer neighbours with bilinear weights.

    :param data: unnormalized coordinates Bx2xHxW (channel 0 is x, channel 1 is y)
    :return: Bx1xHxW map holding, per pixel, the sum of weights that landed on it
    """
    B, _, H, W = data.size()

    # x = data[:, 0, :, :].view(B, -1).clamp(0, W - 1)  # BxN (N=H*W)
    # y = data[:, 1, :, :].view(B, -1).clamp(0, H - 1)

    x = data[:, 0, :, :].reshape(B, -1)  # BxN (N=H*W)
    y = data[:, 1, :, :].reshape(B, -1)

    # invalid = (x < 0) | (x > W - 1) | (y < 0) | (y > H - 1)   # BxN
    # invalid = invalid.repeat([1, 4])

    x1 = torch.floor(x)
    x_floor = x1.clamp(0, W - 1)
    y1 = torch.floor(y)
    y_floor = y1.clamp(0, H - 1)
    x0 = x1 + 1
    x_ceil = x0.clamp(0, W - 1)
    y0 = y1 + 1
    y_ceil = y0.clamp(0, H - 1)

    # A neighbour whose coordinate had to be clamped lies outside the image.
    x_ceil_out = x0 != x_ceil
    y_ceil_out = y0 != y_ceil
    x_floor_out = x1 != x_floor
    y_floor_out = y1 != y_floor
    invalid = torch.cat([x_ceil_out | y_ceil_out,
                         x_ceil_out | y_floor_out,
                         x_floor_out | y_ceil_out,
                         x_floor_out | y_floor_out], dim=1)

    # encode coordinates, since the scatter function can only index along one axis
    corresponding_map = data.new_zeros(B, H * W)
    indices = torch.cat([x_ceil + y_ceil * W,
                         x_ceil + y_floor * W,
                         x_floor + y_ceil * W,
                         x_floor + y_floor * W], 1).long()  # BxN   (N=4*H*W)
    values = torch.cat([(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_floor)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_floor))],
                       1)
    # values = torch.ones_like(values)

    values[invalid] = 0

    corresponding_map.scatter_add_(1, indices, values)
    # decode coordinates
    corresponding_map = corresponding_map.view(B, H, W)

    return corresponding_map.unsqueeze(1)


class get_occu_mask_backward(nn.Module):
    """Threshold the splatting density of a flow field."""

    def __init__(self, size):
        """
        :param size: spatial size (H, W) used to build the identity grid
        """
        super(get_occu_mask_backward, self).__init__()

        # Create sampling grid
        self.register_buffer('grid', _make_identity_grid(size))

    def forward(self, flow, th=0.95):
        """
        :param flow: Bx2xHxW displacement ordered (row, col) like the grid
        :param th: density threshold
        :return: (mask, density) where mask is 1.0 wherever density > th
        """
        new_locs = self.grid + flow
        new_locs = new_locs[:, [1, 0], ...]  # get_corresponding_map expects (x, y)
        corr_map = get_corresponding_map(new_locs)
        occu_map = corr_map
        occu_mask = (occu_map > th).float()

        return occu_mask, occu_map


class get_occu_mask_bidirection(nn.Module):
    """Forward/backward flow consistency residual."""

    def __init__(self, size, mode='bilinear'):
        """
        :param size: spatial size used to build the identity grid
        :param mode: interpolation mode passed to grid_sample
        """
        super(get_occu_mask_bidirection, self).__init__()

        # Create sampling grid
        self.register_buffer('grid', _make_identity_grid(size))
        self.mode = mode

    def forward(self, flow12, flow21, scale=0.01, bias=0.5):
        """
        :param flow12: flow from frame 1 to frame 2
        :param flow21: flow from frame 2 to frame 1
        :param scale: unused by the active code (only by the commented-out threshold)
        :param bias: unused by the active code (only by the commented-out threshold)
        :return: ``|flow12 + warp(flow21, flow12)|``, same shape as ``flow12``
        """
        new_locs = _normalized_sampling_grid(self.grid, flow12)

        flow21_warped = F.grid_sample(flow21, new_locs, mode=self.mode, padding_mode="border")
        flow12_diff = torch.abs(flow12 + flow21_warped)
        # mag = (flow12 * flow12).sum(1, keepdim=True) + \
        # (flow21_warped * flow21_warped).sum(1, keepdim=True)
        # occ_thresh = scale * mag + bias
        # occ_mask = (flow12_diff * flow12_diff).sum(1, keepdim=True) < occ_thresh
        return flow12_diff


# functions

def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    Args:
        axis: Axis label "X" or "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if ``axis`` is not one of "X", "Y", "Z".
    """
    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))


def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to homogeneous
    transformation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Tensor of shape (N, 4, 4), where N is the leading dimension of the
        composed rotations: the rotation sits in the upper-left 3x3 block and
        the bottom-right element is 1. Singleton dimensions of the rotation
        batch are squeezed before it is written into the output.

    Raises:
        ValueError: on a malformed ``euler_angles`` shape or ``convention``.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    matrices = [
        _axis_angle_rotation(c, e)
        for c, e in zip(convention, torch.unbind(euler_angles, -1))
    ]
    # return functools.reduce(torch.matmul, matrices)
    rotation_matrices = torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
    rot = torch.zeros((rotation_matrices.shape[0], 4, 4)).to(device=rotation_matrices.device)
    rot[:, :3, :3] = rotation_matrices.squeeze()
    rot[:, 3, 3] = 1
    return rot


def _angle_from_tan(
    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
    """
    Extract the first or third Euler angle from the two members of
    the matrix which are positive constant times its sine and cosine.

    Args:
        axis: Axis label "X" or "Y" or "Z" for the angle we are finding.
        other_axis: Axis label "X" or "Y" or "Z" for the middle axis in the
            convention.
        data: Rotation matrices as tensor of shape (..., 3, 3).
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in data as a tensor
        of shape (...).
    """
    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        i2, i1 = i1, i2
    even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    if horizontal == even:
        return torch.atan2(data[..., i1], data[..., i2])
    if tait_bryan:
        return torch.atan2(-data[..., i2], data[..., i1])
    return torch.atan2(data[..., i2], -data[..., i1])


def matrix_2_euler_vector(matrix, convention = 'ZYX', roll = True):
    """Concatenate the Euler angles and translation of a batch of 4x4 matrices.

    :param matrix: tensor indexed as ``matrix[:, :3, :3]`` (rotation) and
        ``matrix[:, :3, 3]`` (translation)
    :param convention: Euler convention forwarded to ``matrix_to_euler_angles``
    :param roll: when True, ``euler[0]`` is overwritten with zeros
    :return: ``torch.cat([euler, t], dim=0)``
    """
    # matrix = matrix_in.copy()
    euler = (matrix_to_euler_angles(matrix[:, :3, :3], convention))  # to match with scipy euler = -euler and transpose of this
    if roll:
        # NOTE(review): for a batched input this zeroes the whole first batch
        # row rather than one angle; left unchanged because callers are not
        # visible here - confirm the intended indexing.
        euler[0] = 0.0
    t = matrix[:, :3, 3]
    # NOTE(review): concatenation is along dim 0 (the batch axis for Bx3
    # inputs), not along the feature axis - confirm against callers.
    out = torch.cat([euler, t], dim = 0)
    return out


def _index_from_letter(letter: str) -> int:
    """Map an axis letter "X"/"Y"/"Z" to its index 0/1/2; raise ValueError otherwise."""
    if letter == "X":
        return 0
    if letter == "Y":
        return 1
    if letter == "Z":
        return 2
    raise ValueError("letter must be either X, Y or Z.")


def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: on a malformed ``convention`` or matrix shape.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    tait_bryan = i0 != i2
    if tait_bryan:
        central_angle = torch.asin(
            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
        )
    else:
        central_angle = torch.acos(matrix[..., i0, i0])

    o = (
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return torch.stack(o, -1)


def computeFID(real_images, fake_images, fid_criterion):
    """Update ``fid_criterion`` with both image sets and return its ``compute()`` result.

    :param real_images: passed to ``fid_criterion.update(..., real=True)``
    :param fake_images: passed to ``fid_criterion.update(..., real=False)``
    :param fid_criterion: object exposing ``update`` and ``compute``
        (presumably a torchmetrics FrechetInceptionDistance - confirm)
    """
    # metric = FrechetInceptionDistance(feature)
    fid_criterion.update(real_images, real=True)
    fid_criterion.update(fake_images, real=False)
    return fid_criterion.compute()


class SLlog(nn.Module):
    """Standard deviation of the log-ratio between target and prediction."""

    def __init__(self):
        super(SLlog, self).__init__()

    def forward(self, fake1, real1):
        """
        :param fake1: prediction; bilinearly resized to ``real1``'s HxW when
            the shapes differ
        :param real1: target, BxCxHxW; pixels <= 0 are treated as invalid
        :return: scalar ``sqrt(mean(d^2) - mean(d)^2)`` with
            ``d = log(real) - log(fake)``, where the mean divides by the number
            of positive target pixels
        """
        if fake1.shape != real1.shape:
            _, _, H, W = real1.shape
            # Resize the prediction itself before cloning; the previous code
            # read an unassigned local here and discarded the result.
            fake1 = F.interpolate(fake1, size=(H, W), mode='bilinear', align_corners=False)

        # filter out invalid pixels
        real = real1.clone()
        fake = fake1.clone()
        N = (real > 0).float().sum()
        invalid = (real <= 0) | (fake <= 0)

        # Setting both sides to 1 makes the log difference exactly 0 there.
        fake[invalid] = 1.
        real[invalid] = 1.

        loss_ = torch.log(real) - torch.log(fake)
        loss = torch.sqrt((torch.sum(loss_ ** 2) / N) - ((torch.sum(loss_) / N) ** 2))
        # loss = 100.* torch.sum( torch.abs(torch.log(real)-torch.log(fake)) ) / N
        return loss


class RMSE_log(nn.Module):
    """Root-mean-square error between log target and log prediction."""

    def __init__(self, use_cuda):
        """
        :param use_cuda: stored on the module; not read by ``forward``
        """
        super(RMSE_log, self).__init__()
        self.eps = 1e-8
        self.use_cuda = use_cuda

    def forward(self, fake, real):
        """
        :param fake: prediction; bilinearly resized to ``real``'s HxW
        :param real: target, BxCxHxW; only pixels with ``real < 1`` contribute
        :return: scalar RMSE of the log difference over the selected pixels
        """
        mask = real < 1.
        _, _, h, w = real.size()
        fake = F.interpolate(fake, size=(h, w), mode='bilinear', align_corners=False)
        # Out-of-place so the caller's tensor can never be modified.
        fake = fake + self.eps
        N = len(real[mask])
        loss = torch.sqrt(torch.sum(torch.abs(torch.log(real[mask]) - torch.log(fake[mask])) ** 2) / N)
        return loss


def depth_to_disp(depth, min_disp=0.00001, max_disp = 1.000001):
    """Rescale a [0, 1] depth output into [1/max_disp, 1/min_disp] and invert it.

    :param depth: network output to rescale
    :param min_disp: smallest disparity; its reciprocal is the largest depth
    :param max_disp: largest disparity; its reciprocal is the smallest depth
    :return: (scaled_depth, disp) with ``disp = 1 / scaled_depth``
    """
    min_depth = 1 / max_disp
    max_depth = 1 / min_disp
    scaled_depth = min_depth + (max_depth - min_depth) * depth
    disp = 1 / scaled_depth
    return scaled_depth, disp


def disp_to_depth(disp, min_depth, max_depth):
    """Convert network's sigmoid output into depth prediction
    The formula for this conversion is given in the 'additional considerations'
    section of the paper.

    :return: (scaled_disp, depth) with ``depth = 1 / scaled_disp``
    """
    min_disp = 1 / max_depth
    max_disp = 1 / min_depth
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    depth = 1 / scaled_disp
    return scaled_disp, depth


def disp_to_depth_no_scaling(disp):
    """Invert a disparity map directly: ``1 / (disp + 1e-7)``."""
    depth = 1 / (disp + 1e-7)
    return depth


def transformation_from_parameters(axisangle, translation, invert=False):
    """Convert the network's (axisangle, translation) output into a 4x4 matrix

    :param axisangle: axis-angle rotation, Bx1x3
    :param translation: translation reshaped to Bx3 internally
    :param invert: when True the rotation is transposed, the translation is
        negated and the product order becomes R @ T instead of T @ R
    """
    rot = rot_from_axisangle(axisangle)
    t = translation.clone()

    if invert:
        rot = rot.transpose(1, 2)  # uncomment before running
        t *= -1

    T = get_translation_matrix(t)

    if invert:
        M = torch.matmul(rot, T)
    else:
        M = torch.matmul(T, rot)

    return M


def transformation_from_parameters_euler(euler, translation, invert=False):
    """Convert the network's (euler angles, translation) output into a 4x4 matrix

    :param euler: Euler angles in radians, interpreted with the 'ZYX' convention
    :param translation: translation reshaped to Bx3 internally
    :param invert: when True the rotation is transposed, the translation is
        negated and the product order becomes R @ T instead of T @ R
    """
    # R = torch.transpose(euler_angles_to_matrix(euler, 'ZYX'), 0, 1).permute(1, 0, 2) # to match with scipy euler = -euler and transpose of this
    rot = euler_angles_to_matrix(euler, 'ZYX')  # to match with scipy euler = -euler and transpose of this
    t = translation.clone()

    if invert:
        rot = rot.transpose(1, 2)
        t *= -1

    T = get_translation_matrix(t)

    if invert:
        M = torch.matmul(rot, T)
    else:
        M = torch.matmul(T, rot)

    return M


def get_translation_matrix(translation_vector):
    """Convert a translation vector into a 4x4 transformation matrix

    :param translation_vector: tensor whose first dimension is the batch and
        which reshapes to Bx3x1
    :return: Bx4x4 identity with the translation in the last column
    """
    T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)

    t = translation_vector.contiguous().view(-1, 3, 1)

    T[:, 0, 0] = 1
    T[:, 1, 1] = 1
    T[:, 2, 2] = 1
    T[:, 3, 3] = 1
    T[:, :3, 3, None] = t

    return T


def rot_from_euler(vec):
    """Build a scipy ``Rotation`` from 'zyx' Euler angles given in degrees.

    :param vec: angles accepted by ``Rotation.from_euler('zyx', ..., degrees=True)``
    :return: the scipy ``Rotation`` object (previously the function fell
        through and returned None)
    """
    rot = R.from_euler('zyx', vec, degrees=True)
    return rot


def rot_from_axisangle(vec):
    """Convert an axisangle rotation into a 4x4 transformation matrix
    (adapted from https://github.com/Wallacoloo/printipi)
    Input 'vec' has to be Bx1x3
    """
    angle = torch.norm(vec, 2, 2, True)
    axis = vec / (angle + 1e-7)  # epsilon guards the zero-rotation case

    ca = torch.cos(angle)
    sa = torch.sin(angle)
    C = 1 - ca

    x = axis[..., 0].unsqueeze(1)
    y = axis[..., 1].unsqueeze(1)
    z = axis[..., 2].unsqueeze(1)

    xs = x * sa
    ys = y * sa
    zs = z * sa
    xC = x * C
    yC = y * C
    zC = z * C
    xyC = x * yC
    yzC = y * zC
    zxC = z * xC

    rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)

    rot[:, 0, 0] = torch.squeeze(x * xC + ca)
    rot[:, 0, 1] = torch.squeeze(xyC - zs)
    rot[:, 0, 2] = torch.squeeze(zxC + ys)
    rot[:, 1, 0] = torch.squeeze(xyC + zs)
    rot[:, 1, 1] = torch.squeeze(y * yC + ca)
    rot[:, 1, 2] = torch.squeeze(yzC - xs)
    rot[:, 2, 0] = torch.squeeze(zxC - ys)
    rot[:, 2, 1] = torch.squeeze(yzC + xs)
    rot[:, 2, 2] = torch.squeeze(z * zC + ca)
    rot[:, 3, 3] = 1

    return rot


class ConvBlock(nn.Module):
    """Layer to perform a convolution followed by ELU
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.nonlin(out)
        return out


def batchNorm(num_ch_dec):
    """Return an ``nn.BatchNorm2d`` over ``num_ch_dec`` channels."""
    return nn.BatchNorm2d(num_ch_dec)


class Conv3x3(nn.Module):
    """Layer to pad and convolve input
    """
    def __init__(self, in_channels, out_channels, use_refl=True):
        """
        :param in_channels: input channel count (cast to int)
        :param out_channels: output channel count (cast to int)
        :param use_refl: reflection padding when True, zero padding otherwise
        """
        super(Conv3x3, self).__init__()

        if use_refl:
            self.pad = nn.ReflectionPad2d(1)
        else:
            self.pad = nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        out = self.pad(x)
        out = self.conv(out)
        return out


class BackprojectDepth(nn.Module):
    """Layer to transform a depth image into a point cloud
    """
    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        # Constants are kept as frozen Parameters (not buffers), registered in
        # the order id_coords, ones, pix_coords, so state_dict keys are unchanged.
        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(id_coords),
                                      requires_grad=False)

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
        pix_coords = pix_coords.repeat(batch_size, 1, 1)
        # Homogeneous pixel coordinates (x, y, 1), shape B x 3 x (H*W).
        self.pix_coords = nn.Parameter(torch.cat([pix_coords, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        """
        :param depth: depth map viewable as B x 1 x (H*W)
        :param inv_K: inverse intrinsics; only ``inv_K[:, :3, :3]`` is used
        :return: homogeneous camera points, B x 4 x (H*W)
        """
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points
        cam_points = torch.cat([cam_points, self.ones], 1)

        return cam_points


class Project3D(nn.Module):
    """Layer which projects 3D points into a camera with intrinsics K and at position T
    """
    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        """
        :param points: homogeneous points, B x 4 x (H*W)
        :param K: intrinsics, multiplied with T
        :param T: camera transform
        :return: B x H x W x 2 sampling grid scaled so that pixel 0 maps to -1
            and pixel (size - 1) maps to +1
        """
        P = torch.matmul(K, T)[:, :3, :]

        cam_points = torch.matmul(P, points)

        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
        pix_coords = pix_coords.permute(0, 2, 3, 1)
        pix_coords[..., 0] /= self.width - 1
        pix_coords[..., 1] /= self.height - 1
        pix_coords = (pix_coords - 0.5) * 2
        return pix_coords


def upsample(x):
    """Upsample input tensor by a factor of 2
    """
    return F.interpolate(x, scale_factor=2, mode="nearest")


class deconv(nn.Module):
    """Single 3x3 transposed convolution with stride 2 and padding 1
    (no activation is applied).
    """
    def __init__(self, ch_in, ch_out):
        super(deconv, self).__init__()

        self.deconvlayer = nn.ConvTranspose2d(ch_in, ch_out, 3, stride=2, padding=1)

    def forward(self, x):
        out = self.deconvlayer(x)
        return out


def _edge_aware_disp_gradients(disp, img):
    """Return the absolute x/y disparity gradients down-weighted by image edges.

    Each disparity gradient is multiplied by ``exp(-g)``, where ``g`` is the
    channel-mean absolute image gradient at the same location.
    """
    grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

    grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
    grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)

    return grad_disp_x * torch.exp(-grad_img_x), grad_disp_y * torch.exp(-grad_img_y)


def get_smooth_loss_gauss_mask(disp, img, gauss_mask):
    """Computes the smoothness loss for a disparity image
    The color image is used for edge-aware smoothness; every gradient term is
    additionally multiplied by ``gauss_mask`` (cropped to the gradient's size)
    before averaging.
    """
    grad_disp_x, grad_disp_y = _edge_aware_disp_gradients(disp, img)

    # weighted mean
    # grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:])*gauss_mask[:, :, :, :-1], 1, keepdim=True)
    # grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :])*gauss_mask[:, :, :-1, :], 1, keepdim=True)

    # take weighted mean
    grad_disp_x = grad_disp_x * gauss_mask[:, :, :, :-1]
    grad_disp_y = grad_disp_y * gauss_mask[:, :, :-1, :]

    return grad_disp_x.mean() + grad_disp_y.mean()


def get_smooth_loss(disp, img):
    """Computes the smoothness loss for a disparity image
    The color image is used for edge-aware smoothness
    """
    grad_disp_x, grad_disp_y = _edge_aware_disp_gradients(disp, img)
    return grad_disp_x.mean() + grad_disp_y.mean()


class SSIM(nn.Module):
    """Layer to compute the SSIM loss between a pair of images
    """
    def __init__(self):
        super(SSIM, self).__init__()
        self.mu_x_pool   = nn.AvgPool2d(3, 1)
        self.mu_y_pool   = nn.AvgPool2d(3, 1)
        self.sig_x_pool  = nn.AvgPool2d(3, 1)
        self.sig_y_pool  = nn.AvgPool2d(3, 1)
        self.sig_xy_pool = nn.AvgPool2d(3, 1)

        self.refl = nn.ReflectionPad2d(1)

        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        """Return ``clamp((1 - SSIM(x, y)) / 2, 0, 1)`` per pixel."""
        x = self.refl(x)
        y = self.refl(y)

        mu_x = self.mu_x_pool(x)
        mu_y = self.mu_y_pool(y)

        sigma_x  = self.sig_x_pool(x ** 2) - mu_x ** 2
        sigma_y  = self.sig_y_pool(y ** 2) - mu_y ** 2
        sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y

        SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
        SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)

        return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)


def compute_depth_errors(gt, pred):
    """Computation of error metrics between predicted and ground truth depths

    :return: (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)
    """
    thresh = torch.max((gt / pred), (pred / gt))
    a1 = (thresh < 1.25     ).float().mean()
    a2 = (thresh < 1.25 ** 2).float().mean()
    a3 = (thresh < 1.25 ** 3).float().mean()

    rmse = (gt - pred) ** 2
    rmse = torch.sqrt(rmse.mean())

    rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
    rmse_log = torch.sqrt(rmse_log.mean())

    abs_rel = torch.mean(torch.abs(gt - pred) / gt)

    sq_rel = torch.mean((gt - pred) ** 2 / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3


""" Parts of the U-Net model """


def _pad_and_concat(x1, x2):
    """Pad ``x1`` to ``x2``'s spatial size and concatenate them as ``[x2, x1]`` on channels."""
    # input is CHW
    diffY = x2.size()[2] - x1.size()[2]
    diffX = x2.size()[3] - x1.size()[3]

    x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                    diffY // 2, diffY - diffY // 2])
    # if you have padding issues, see
    # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
    # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
    return torch.cat([x2, x1], dim=1)


class InstanceNormDoubleConv(nn.Module):
    """convolution => InstanceNorm => ReLU => convolution => BatchNorm => ReLU"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # NOTE(review): the second normalisation is BatchNorm2d although the
        # class name says InstanceNorm; kept as-is so existing checkpoints
        # still load - confirm whether this mix is intentional.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid_channels, affine = True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class DoubleConvIN(nn.Module):
    """(convolution => [IN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # The norm layers are no longer pinned to 'cuda' at construction: that
        # crashed on CPU-only hosts and put them on a different device from the
        # convolutions. Move the whole module with .to(device) instead.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid_channels, affine = True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine = True),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class DownIN(nn.Module):
    """Downscaling with maxpool then instance-norm double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConvIN(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s size, concatenate and convolve."""
        x1 = self.up(x1)
        x = _pad_and_concat(x1, x2)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UpIN(nn.Module):
    """Upscaling then instance-norm double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConvIN(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConvIN(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s size, concatenate and convolve."""
        x1 = self.up(x1)
        x = _pad_and_concat(x1, x2)
        return self.conv(x)


# def gaussian_fn(M, std):
#     n = torch.arange(0, M) - (M - 1.0) / 2.0
#     sig2 = 2 * std * std
#     w = torch.exp(-n ** 2 / sig2)
#     return w

# def gkern(kernlen=256, std=128):
#     """Returns a 2D Gaussian kernel array."""
#     gkern1d = gaussian_fn(kernlen, std=std)
#     gkern2d = torch.outer(gkern1d, gkern1d)
#     return gkern2d

# A = np.random.rand(256*256).reshape([256,256])
# A = torch.from_numpy(A)
# guassian_filter = gkern(256, std=32)


# class GaussianLayer(nn.Module):
#     def __init__(self):
#         super(GaussianLayer, self).__init__()
#         self.seq = nn.Sequential(
#             nn.ReflectionPad2d(10),
#             nn.Conv2d(3, 3, 21, stride=1, padding=0, bias=None, groups=3)
#         )

#         self.weights_init()
#     def forward(self, x):
#         return self.seq(x)

#     def weights_init(self):
#         n= np.zeros((21,21))
#         n[10,10] = 1
#         k = scipy.ndimage.gaussian_filter(n,sigma=3)
#         for name, f in self.named_parameters():
#             f.data.copy_(torch.from_numpy(k))