File size: 7,341 Bytes
854f0d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import torch
from torch.nn.functional import grid_sample
def back_project_sparse_type(coords, origin, voxel_size, feats, KRcam, sizeH=None, sizeW=None, only_mask=False,
                             with_proj_z=False):
    # - modified version from NeuRecon
    '''
    Unproject the image features to form a 3D (sparse) feature volume

    :param coords: coordinates of voxels,
    dim: (num of voxels, 4) (4 : batch ind, x, y, z)
    :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
    dim: (batch size, 3) (3: x, y, z)
    :param voxel_size: floats specifying the size of a voxel
    :param feats: image features
    dim: (num of views, batch size, C, H, W)
    :param KRcam: projection matrix (world -> image)
    dim: (num of views, batch size, 4, 4)
    :param sizeH, sizeW: image size the KRcam matrices were built for;
    defaults to (H, W) of feats when not given
    :param only_mask: if True, return only mask_volume_all
    :param with_proj_z: if True, additionally return the projected depth per voxel/view
    :return: feature_volume_all: 3D feature volumes
    dim: (num of voxels, num_of_views, c)
    :return: mask_volume_all: indicate the voxel of sampled feature volume is valid or not
    dim: (num of voxels, num_of_views)
    '''
    n_views, bs, c, h, w = feats.shape
    device = feats.device

    if sizeH is None:
        sizeH, sizeW = h, w  # - if the KRcam is not suitable for the current feats

    feature_volume_all = torch.zeros(coords.shape[0], n_views, c, device=device)
    mask_volume_all = torch.zeros([coords.shape[0], n_views], dtype=torch.int32, device=device)
    # depth of each voxel in each view's camera frame, filled only when requested
    proj_z_all = torch.zeros(coords.shape[0], n_views, 1, device=device) if with_proj_z else None

    for batch in range(bs):
        # rows of `coords` belonging to this batch element
        batch_ind = torch.nonzero(coords[:, 0] == batch).squeeze(1)
        coords_batch = coords[batch_ind][:, 1:]
        coords_batch = coords_batch.view(-1, 3)
        origin_batch = origin[batch].unsqueeze(0)
        feats_batch = feats[:, batch]
        proj_batch = KRcam[:, batch]

        # voxel indices -> world coordinates, replicated per view
        grid_batch = coords_batch * voxel_size + origin_batch.float()
        rs_grid = grid_batch.unsqueeze(0).expand(n_views, -1, -1)
        rs_grid = rs_grid.permute(0, 2, 1).contiguous()
        nV = rs_grid.shape[-1]
        rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV], device=device)], dim=1)

        # Project grid: homogeneous world pts -> image UV space
        im_p = proj_batch @ rs_grid
        im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]
        # clamp tiny positive depths so the perspective divide stays finite
        im_z[im_z >= 0] = im_z[im_z >= 0].clamp(min=1e-6)
        im_x = im_x / im_z
        im_y = im_y / im_z

        # normalize pixel coords to [-1, 1] for grid_sample
        im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1)
        # valid = projects inside the image AND in front of the camera
        mask = im_grid.abs() <= 1
        mask = (mask.sum(dim=-1) == 2) & (im_z > 0)
        mask = mask.view(n_views, -1)
        mask = mask.permute(1, 0).contiguous()  # [num_pts, nviews]
        mask_volume_all[batch_ind] = mask.to(torch.int32)

        if only_mask:
            # BUGFIX: original returned here, so batches > 0 never got their masks;
            # skip the feature sampling instead and return after the loop.
            continue

        feats_batch = feats_batch.view(n_views, c, h, w)
        im_grid = im_grid.view(n_views, 1, -1, 2)
        features = grid_sample(feats_batch, im_grid, padding_mode='zeros', align_corners=True)
        features = features.view(n_views, c, -1)
        features = features.permute(2, 0, 1).contiguous()  # [num_pts, nviews, c]
        feature_volume_all[batch_ind] = features

        if with_proj_z:
            # BUGFIX: original returned inside the loop with only the current batch's
            # depths; accumulate per-voxel instead so all batches are covered.
            proj_z_all[batch_ind] = im_z.view(n_views, 1, -1).permute(2, 0, 1).contiguous()  # [num_pts, nviews, 1]

    if only_mask:
        return mask_volume_all
    if with_proj_z:
        return feature_volume_all, mask_volume_all, proj_z_all
    return feature_volume_all, mask_volume_all
def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode, sizeH=None, sizeW=None, with_depth=False):
    """Transform coordinates in the camera frame to the pixel frame.

    Args:
        cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 3, H, W]
        proj_c2p_rot: rotation matrix of cameras -- [B, 3, 3] (None means identity)
        proj_c2p_tr: translation vectors of cameras -- [B, 3, 1] (None means zero)
        padding_mode: when 'zeros', out-of-view points are pushed to coordinate 2
            so a later grid_sample fills them with zeros instead of edge values
        sizeH, sizeW: image size used for normalization; defaults to (H, W)
        with_depth: if True, append the (clamped) depth Z as a third channel
    Returns:
        array of [-1,1] coordinates -- [B, H, W, 2] (or [B, H, W, 3] if with_depth)
    """
    b, _, h, w = cam_coords.size()
    if sizeH is None:
        sizeH = h
        sizeW = w

    cam_coords_flat = cam_coords.view(b, 3, -1)  # [B, 3, H*W]
    if proj_c2p_rot is not None:
        pcoords = proj_c2p_rot.bmm(cam_coords_flat)
    else:
        pcoords = cam_coords_flat
    if proj_c2p_tr is not None:
        pcoords = pcoords + proj_c2p_tr  # [B, 3, H*W]

    X = pcoords[:, 0]
    Y = pcoords[:, 1]
    # clamp depth so the perspective divide below is always finite
    Z = pcoords[:, 2].clamp(min=1e-3)

    X_norm = 2 * (X / Z) / (sizeW - 1) - 1  # Normalized, -1 if on extreme left,
    # 1 if on extreme right (x = w-1) [B, H*W]
    Y_norm = 2 * (Y / Z) / (sizeH - 1) - 1  # Idem [B, H*W]
    if padding_mode == 'zeros':
        # FIX(idiom): use logical `|` rather than deprecated bool-tensor addition
        X_mask = ((X_norm > 1) | (X_norm < -1)).detach()
        X_norm[X_mask] = 2  # make sure that no point in warped image is a combination of im and gray
        Y_mask = ((Y_norm > 1) | (Y_norm < -1)).detach()
        Y_norm[Y_mask] = 2

    if with_depth:
        pixel_coords = torch.stack([X_norm, Y_norm, Z], dim=2)  # [B, H*W, 3]
        return pixel_coords.view(b, h, w, 3)
    else:
        pixel_coords = torch.stack([X_norm, Y_norm], dim=2)  # [B, H*W, 2]
        return pixel_coords.view(b, h, w, 2)
# * NOTE: already checked; still verify that proj_matrix matches the coordinate system and resolution in use
def back_project_dense_type(coords, origin, voxel_size, feats, proj_matrix, sizeH=None, sizeW=None):
    '''
    Unproject the image features to form a 3D (dense) feature volume

    :param coords: coordinates of voxels,
    dim: (batch, nviews, 3, X,Y,Z)
    :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
    dim: (batch size, 3) (3: x, y, z)
    :param voxel_size: floats specifying the size of a voxel
    :param feats: image features
    dim: (batch size, num of views, C, H, W)
    :param proj_matrix: projection matrix
    dim: (batch size, num of views, 4, 4)
    :param sizeH, sizeW: image size for pixel normalization; defaults to feats' (H, W)
    :return: feature_volume_all: 3D feature volumes
    dim: (batch, nviews, C, X,Y,Z)
    :return: count: number of times each voxel can be seen
    dim: (batch, nviews, 1, X,Y,Z)
    '''
    batch, nviews, _, wX, wY, wZ = coords.shape
    if sizeH is None:
        sizeH, sizeW = feats.shape[-2:]

    n_flat = batch * nviews
    proj_flat = proj_matrix.view(n_flat, *proj_matrix.shape[2:])

    # voxel indices -> world coordinates
    world_pts = coords * voxel_size + origin.view(batch, 1, 3, 1, 1, 1)
    world_pts = world_pts.view(n_flat, 3, wX * wY * wZ, 1)  # (b*nviews, 3, wX*wY*wZ, 1)

    # world -> normalized [-1, 1] pixel coordinates per view
    sample_grid = cam2pixel(world_pts, proj_flat[:, :3, :3], proj_flat[:, :3, 3:],
                            'zeros', sizeH=sizeH, sizeW=sizeW)  # (b*nviews, wX*wY*wZ, 1, 2)
    sample_grid = sample_grid.view(n_flat, 1, wX * wY * wZ, 2)

    feats_flat = feats.view(n_flat, *feats.shape[2:])  # (b*nviews, c, h, w)
    # sampling an all-ones image yields, per voxel, how visible it is
    ones = torch.ones((n_flat, 1, *feats_flat.shape[2:]), dtype=feats_flat.dtype, device=feats_flat.device)

    features_volume = torch.nn.functional.grid_sample(feats_flat, sample_grid, padding_mode='zeros', align_corners=True)
    counts_volume = torch.nn.functional.grid_sample(ones, sample_grid, padding_mode='zeros', align_corners=True)

    features_volume = features_volume.view(batch, nviews, -1, wX, wY, wZ)  # (batch, nviews, C, X,Y,Z)
    counts_volume = counts_volume.view(batch, nviews, -1, wX, wY, wZ)
    return features_volume, counts_volume
|