import torch import torch.nn.functional as F pixel_coords = None def set_id_grid(depth): global pixel_coords b, h, w = depth.size() i_range = torch.arange(0, h).view(1, h, 1).expand( 1, h, w).type_as(depth) # [1, H, W] j_range = torch.arange(0, w).view(1, 1, w).expand( 1, h, w).type_as(depth) # [1, H, W] ones = torch.ones(1, h, w).type_as(depth) pixel_coords = torch.stack((j_range, i_range, ones), dim=1) # [1, 3, H, W] def check_sizes(input, input_name, expected): condition = [input.ndimension() == len(expected)] for i, size in enumerate(expected): if size.isdigit(): condition.append(input.size(i) == int(size)) assert(all(condition)), "wrong size for {}, expected {}, got {}".format( input_name, 'x'.join(expected), list(input.size())) def pixel2cam(depth, intrinsics_inv): global pixel_coords """Transform coordinates in the pixel frame to the camera frame. Args: depth: depth maps -- [B, H, W] intrinsics_inv: intrinsics_inv matrix for each element of batch -- [B, 3, 3] Returns: array of (u,v,1) cam coordinates -- [B, 3, H, W] """ b, h, w = depth.size() if (pixel_coords is None) or pixel_coords.size(2) < h: set_id_grid(depth) current_pixel_coords = pixel_coords[:, :, :h, :w].expand( b, 3, h, w).reshape(b, 3, -1) # [B, 3, H*W] cam_coords = (intrinsics_inv @ current_pixel_coords).reshape(b, 3, h, w) out = depth.unsqueeze(1) * cam_coords return out def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode): """Transform coordinates in the camera frame to the pixel frame. Args: cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 4, H, W] proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4] proj_c2p_tr: translation vectors of cameras -- [B, 3, 1] Returns: array of [-1,1] coordinates -- [B, 2, H, W] """ b, _, h, w = cam_coords.size() cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W] if proj_c2p_rot is not None: pcoords = proj_c2p_rot @ cam_coords_flat else: pcoords = cam_coords_flat if proj_c2p_tr is not None: pcoords = pcoords + proj_c2p_tr # [B, 3, H*W] X = pcoords[:, 0] Y = pcoords[:, 1] Z = pcoords[:, 2].clamp(min=1e-3) # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W] X_norm = 2*(X / Z)/(w-1) - 1 Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W] pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2] return pixel_coords.reshape(b, h, w, 2) def euler2mat(angle): """Convert euler angles to rotation matrix. Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174 Args: angle: rotation angle along 3 axis (in radians) -- size = [B, 3] Returns: Rotation matrix corresponding to the euler angles -- size = [B, 3, 3] """ B = angle.size(0) x, y, z = angle[:, 0], angle[:, 1], angle[:, 2] cosz = torch.cos(z) sinz = torch.sin(z) zeros = z.detach()*0 ones = zeros.detach()+1 zmat = torch.stack([cosz, -sinz, zeros, sinz, cosz, zeros, zeros, zeros, ones], dim=1).reshape(B, 3, 3) cosy = torch.cos(y) siny = torch.sin(y) ymat = torch.stack([cosy, zeros, siny, zeros, ones, zeros, -siny, zeros, cosy], dim=1).reshape(B, 3, 3) cosx = torch.cos(x) sinx = torch.sin(x) xmat = torch.stack([ones, zeros, zeros, zeros, cosx, -sinx, zeros, sinx, cosx], dim=1).reshape(B, 3, 3) rotMat = xmat @ ymat @ zmat return rotMat def quat2mat(quat): """Convert quaternion coefficients to rotation matrix. Args: quat: first three coeff of quaternion of rotation. fourht is then computed to have a norm of 1 -- size = [B, 3] Returns: Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] """ norm_quat = torch.cat([quat[:, :1].detach()*0 + 1, quat], dim=1) norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] B = quat.size(0) w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) wx, wy, wz = w*x, w*y, w*z xy, xz, yz = x*y, x*z, y*z rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).reshape(B, 3, 3) return rotMat def pose_vec2mat(vec, rotation_mode='euler'): """ Convert 6DoF parameters to transformation matrix. Args:s vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz -- [B, 6] Returns: A transformation matrix -- [B, 3, 4] """ translation = vec[:, :3].unsqueeze(-1) # [B, 3, 1] rot = vec[:, 3:] if rotation_mode == 'euler': rot_mat = euler2mat(rot) # [B, 3, 3] elif rotation_mode == 'quat': rot_mat = quat2mat(rot) # [B, 3, 3] transform_mat = torch.cat([rot_mat, translation], dim=2) # [B, 3, 4] return transform_mat def inverse_warp(img, depth, pose, intrinsics, rotation_mode='euler', padding_mode='zeros'): """ Inverse warp a source image to the target image plane. Args: img: the source image (where to sample pixels) -- [B, 3, H, W] depth: depth map of the target image -- [B, H, W] pose: 6DoF pose parameters from target to source -- [B, 6] intrinsics: camera intrinsic matrix -- [B, 3, 3] Returns: projected_img: Source image warped to the target image plane valid_points: Boolean array indicating point validity """ check_sizes(img, 'img', 'B3HW') check_sizes(depth, 'depth', 'BHW') check_sizes(pose, 'pose', 'B6') check_sizes(intrinsics, 'intrinsics', 'B33') batch_size, _, img_height, img_width = img.size() cam_coords = pixel2cam(depth, intrinsics.inverse()) # [B,3,H,W] pose_mat = pose_vec2mat(pose, rotation_mode) # [B,3,4] # Get projection matrix for tgt camera frame to source pixel frame proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4] rot, tr = proj_cam_to_src_pixel[:, :, :3], proj_cam_to_src_pixel[:, :, -1:] src_pixel_coords = cam2pixel( cam_coords, rot, tr, padding_mode) # [B,H,W,2] projected_img = F.grid_sample( img, src_pixel_coords, padding_mode=padding_mode) valid_points = src_pixel_coords.abs().max(dim=-1)[0] <= 1 return projected_img, valid_points def cam2pixel2(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode): """Transform coordinates in the camera frame to the pixel frame. Args: cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 4, H, W] proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4] proj_c2p_tr: translation vectors of cameras -- [B, 3, 1] Returns: array of [-1,1] coordinates -- [B, 2, H, W] """ b, _, h, w = cam_coords.size() cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W] if proj_c2p_rot is not None: pcoords = proj_c2p_rot @ cam_coords_flat else: pcoords = cam_coords_flat if proj_c2p_tr is not None: pcoords = pcoords + proj_c2p_tr # [B, 3, H*W] X = pcoords[:, 0] Y = pcoords[:, 1] Z = pcoords[:, 2].clamp(min=1e-3) # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W] X_norm = 2*(X / Z)/(w-1) - 1 Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W] if padding_mode == 'zeros': X_mask = ((X_norm > 1)+(X_norm < -1)).detach() # make sure that no point in warped image is a combinaison of im and gray X_norm[X_mask] = 2 Y_mask = ((Y_norm > 1)+(Y_norm < -1)).detach() Y_norm[Y_mask] = 2 pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2] return pixel_coords.reshape(b, h, w, 2), Z.reshape(b, 1, h, w) def inverse_warp2(img, depth, ref_depth, pose, intrinsics, padding_mode='zeros'): """ Inverse warp a source image to the target image plane. Args: img: the source image (where to sample pixels) -- [B, 3, H, W] depth: depth map of the target image -- [B, 1, H, W] ref_depth: the source depth map (where to sample depth) -- [B, 1, H, W] pose: 6DoF pose parameters from target to source -- [B, 6] intrinsics: camera intrinsic matrix -- [B, 3, 3] Returns: projected_img: Source image warped to the target image plane valid_mask: Float array indicating point validity projected_depth: sampled depth from source image computed_depth: computed depth of source image using the target depth """ check_sizes(img, 'img', 'B3HW') check_sizes(depth, 'depth', 'B1HW') check_sizes(ref_depth, 'ref_depth', 'B1HW') check_sizes(pose, 'pose', 'B6') check_sizes(intrinsics, 'intrinsics', 'B33') batch_size, _, img_height, img_width = img.size() cam_coords = pixel2cam(depth.squeeze(1), intrinsics.inverse()) # [B,3,H,W] pose_mat = pose_vec2mat(pose) # [B,3,4] # Get projection matrix for tgt camera frame to source pixel frame proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4] rot, tr = proj_cam_to_src_pixel[:, :, :3], proj_cam_to_src_pixel[:, :, -1:] src_pixel_coords, computed_depth = cam2pixel2(cam_coords, rot, tr, padding_mode) # [B,H,W,2] projected_img = F.grid_sample(img, src_pixel_coords, padding_mode=padding_mode, align_corners=False) projected_depth = F.grid_sample(ref_depth, src_pixel_coords, padding_mode=padding_mode, align_corners=False) return projected_img, projected_depth, computed_depth def inverse_rotation_warp(img, rot, intrinsics, padding_mode='zeros'): b, _, h, w = img.size() cam_coords = pixel2cam(torch.ones(b, h, w).type_as(img), intrinsics.inverse()) # [B,3,H,W] rot_mat = euler2mat(rot) # [B, 3, 3] # Get projection matrix for tgt camera frame to source pixel frame proj_cam_to_src_pixel = intrinsics @ rot_mat # [B, 3, 3] src_pixel_coords, computed_depth = cam2pixel2(cam_coords, proj_cam_to_src_pixel, None, padding_mode) # [B,H,W,2] projected_img = F.grid_sample(img, src_pixel_coords, padding_mode=padding_mode, align_corners=True) return projected_img def grid_to_flow(grid): b, h, w, _ = grid.size() i_range = torch.arange(0, h).view(1, h, 1).expand(1, h, w).type_as(grid) # [1, H, W] j_range = torch.arange(0, w).view(1, 1, w).expand(1, h, w).type_as(grid) # [1, H, W] image_coords = torch.stack((j_range, i_range), dim=1) # [1, 2, H, W] flow = torch.zeros_like(grid).type_as(grid) flow[:, :, :, 0] = (grid[:, :, :, 0]+1) / 2 * (w-1) flow[:, :, :, 1] = (grid[:, :, :, 1]+1) / 2 * (h-1) flow = flow.permute([0, 3, 1, 2]) flow -= image_coords return flow def compute_translation_flow(depth, pose, intrinsics): cam_coords = pixel2cam(depth.squeeze(1), intrinsics.inverse()) # [B,3,H,W] pose_mat = pose_vec2mat(pose) # [B,3,4] # Get projection matrix for tgt camera frame to source pixel frame proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4] rot, tr = proj_cam_to_src_pixel[:, :, :3], proj_cam_to_src_pixel[:, :, -1:] grid_all, _ = cam2pixel2(cam_coords, rot, tr, padding_mode='zeros') # [B,H,W,2] grid_rot, _ = cam2pixel2(cam_coords, rot, None, padding_mode='zeros') # [B,H,W,2] flow_all = grid_to_flow(grid_all) flow_rot = grid_to_flow(grid_rot) flow_tr = (flow_all - flow_rot) return flow_tr