|
import torch |
|
from torch.nn.functional import grid_sample |
|
|
|
|
|
def back_project_sparse_type(coords, origin, voxel_size, feats, KRcam, sizeH=None, sizeW=None, only_mask=False,
                             with_proj_z=False):
    '''
    Unproject the image features to form a 3D (sparse) feature volume

    :param coords: coordinates of voxels,
    dim: (num of voxels, 4) (4 : batch ind, x, y, z)
    :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
    dim: (batch size, 3) (3: x, y, z)
    :param voxel_size: floats specifying the size of a voxel
    :param feats: image features
    dim: (num of views, batch size, C, H, W)
    :param KRcam: projection matrix
    dim: (num of views, batch size, 4, 4)
    :param sizeH: image height used to normalise the projected y coordinate.
    When None, both sizeH and sizeW are taken from feats (H, W).
    :param sizeW: image width used to normalise the projected x coordinate.
    When None (and sizeH is given) it defaults to W of feats.
    :param only_mask: if True, skip feature sampling and return only mask_volume_all
    :param with_proj_z: if True, additionally return the projected depth of every voxel
    :return: feature_volume_all: 3D feature volumes
    dim: (num of voxels, num_of_views, c)
    :return: mask_volume_all: indicate the voxel of sampled feature volume is valid or not
    dim: (num of voxels, num_of_views)
    :return: im_z (only when with_proj_z is True): third row of KRcam @ [x, y, z, 1],
    with non-negative values clamped to at least 1e-6; rows of voxels whose batch
    index is not in [0, batch size) stay zero
    dim: (num of voxels, num_of_views, 1)
    '''
    n_views, bs, c, h, w = feats.shape
    device = feats.device

    if sizeH is None:
        sizeH, sizeW = h, w
    elif sizeW is None:
        # Only the height was supplied: fall back to the feature-map width
        # instead of failing on `None - 1` below.
        sizeW = w

    num_voxels = coords.shape[0]
    # Allocate directly on the target device (no CPU allocation + copy).
    feature_volume_all = torch.zeros(num_voxels, n_views, c, device=device)
    mask_volume_all = torch.zeros([num_voxels, n_views], dtype=torch.int32, device=device)
    # Allocated lazily so that it can adopt the dtype of the projected depth.
    proj_z_all = None

    for batch in range(bs):
        batch_ind = torch.nonzero(coords[:, 0] == batch).squeeze(1)
        if batch_ind.numel() == 0:
            # No voxel belongs to this batch element; its rows stay zero.
            continue

        coords_batch = coords[batch_ind][:, 1:].reshape(-1, 3)
        origin_batch = origin[batch].unsqueeze(0)
        feats_batch = feats[:, batch]
        proj_batch = KRcam[:, batch]

        # Voxel indices -> positions: index * voxel_size + origin.
        grid_batch = coords_batch * voxel_size + origin_batch.float()
        rs_grid = grid_batch.unsqueeze(0).expand(n_views, -1, -1)
        rs_grid = rs_grid.permute(0, 2, 1).contiguous()
        nV = rs_grid.shape[-1]
        # Append a row of ones to obtain homogeneous coordinates (n_views, 4, nV).
        rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV], device=device)], dim=1)

        # Project every voxel with every view's 4x4 matrix.
        im_p = proj_batch @ rs_grid
        im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]

        # Keep non-negative depths away from zero before dividing; negative
        # depths are left untouched and are rejected by the mask below.
        im_z = torch.where(im_z >= 0, im_z.clamp(min=1e-6), im_z)

        im_x = im_x / im_z
        im_y = im_y / im_z

        # Normalise pixel coordinates to [-1, 1] as expected by grid_sample
        # with align_corners=True.
        im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1)
        # Valid = both coordinates inside the image and positive depth.
        mask = (im_grid.abs() <= 1).all(dim=-1) & (im_z > 0)
        mask = mask.view(n_views, -1).permute(1, 0).contiguous()

        mask_volume_all[batch_ind] = mask.to(torch.int32)

        if only_mask:
            # Keep looping so that the mask of *every* batch element is
            # filled; returning here would leave later batches at zero.
            continue

        features = grid_sample(feats_batch, im_grid.view(n_views, 1, -1, 2),
                               padding_mode='zeros', align_corners=True)
        features = features.view(n_views, c, -1)
        features = features.permute(2, 0, 1).contiguous()

        feature_volume_all[batch_ind] = features

        if with_proj_z:
            if proj_z_all is None:
                proj_z_all = torch.zeros(num_voxels, n_views, 1, dtype=im_z.dtype, device=im_z.device)
            # Scatter per-batch depths so the result is aligned with coords
            # (rather than only holding the last batch element's depths).
            proj_z_all[batch_ind] = im_z.view(n_views, 1, -1).permute(2, 0, 1).contiguous()

    if only_mask:
        return mask_volume_all

    if with_proj_z:
        if proj_z_all is None:
            proj_z_all = torch.zeros(num_voxels, n_views, 1, device=device)
        return feature_volume_all, mask_volume_all, proj_z_all

    return feature_volume_all, mask_volume_all
|
|
|
|
|
def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode, sizeH=None, sizeW=None, with_depth=False):
    """Transform coordinates in the camera frame to the pixel frame.

    Args:
        cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 3, H, W]
        proj_c2p_rot: rotation matrix of cameras -- [B, 3, 3]; skipped when None
        proj_c2p_tr: translation vectors of cameras -- [B, 3, 1]; skipped when None
        padding_mode: when equal to 'zeros', normalised coordinates outside
            [-1, 1] are overwritten with 2
        sizeH: image height used for normalising y. When None, both sizeH and
            sizeW are taken from the H, W of cam_coords.
        sizeW: image width used for normalising x. When None (and sizeH is
            given) it defaults to the W of cam_coords.
        with_depth: if True, append the clamped depth Z as a third channel
    Returns:
        array of [-1,1] coordinates -- [B, H, W, 2], or [B, H, W, 3] when
        with_depth is True
    """
    b, _, h, w = cam_coords.size()
    if sizeH is None:
        sizeH = h
        sizeW = w
    elif sizeW is None:
        # Only the height was supplied: use the input width rather than
        # failing on `None - 1` below.
        sizeW = w

    cam_coords_flat = cam_coords.view(b, 3, -1)
    if proj_c2p_rot is not None:
        pcoords = proj_c2p_rot.bmm(cam_coords_flat)
    else:
        pcoords = cam_coords_flat

    if proj_c2p_tr is not None:
        pcoords = pcoords + proj_c2p_tr

    X = pcoords[:, 0]
    Y = pcoords[:, 1]
    # Clamp the depth so the perspective division below never divides by a
    # value smaller than 1e-3.
    Z = pcoords[:, 2].clamp(min=1e-3)

    # Perspective division followed by normalisation to [-1, 1].
    X_norm = 2 * (X / Z) / (sizeW - 1) - 1
    Y_norm = 2 * (Y / Z) / (sizeH - 1) - 1

    if padding_mode == 'zeros':
        # Overwrite out-of-range coordinates with the constant 2, which lies
        # outside the [-1, 1] range.
        X_norm = X_norm.masked_fill((X_norm > 1) | (X_norm < -1), 2)
        Y_norm = Y_norm.masked_fill((Y_norm > 1) | (Y_norm < -1), 2)

    if with_depth:
        pixel_coords = torch.stack([X_norm, Y_norm, Z], dim=2)
        return pixel_coords.view(b, h, w, 3)

    pixel_coords = torch.stack([X_norm, Y_norm], dim=2)
    return pixel_coords.view(b, h, w, 2)
|
|
|
|
|
|
|
def back_project_dense_type(coords, origin, voxel_size, feats, proj_matrix, sizeH=None, sizeW=None):
    '''
    Unproject the image features to form a 3D (dense) feature volume

    :param coords: coordinates of voxels,
    dim: (batch, nviews, 3, X,Y,Z)
    :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
    dim: (batch size, 3) (3: x, y, z)
    :param voxel_size: floats specifying the size of a voxel
    :param feats: image features
    dim: (batch size, num of views, C, H, W)
    :param proj_matrix: projection matrix
    dim: (batch size, num of views, 4, 4)
    :param sizeH: image height used to normalise projected coordinates;
    when None, sizeH and sizeW are both read from the last two dims of feats
    :param sizeW: image width used to normalise projected coordinates
    :return: feature_volume_all: 3D feature volumes
    dim: (batch, nviews, C, X,Y,Z)
    :return: count: an all-ones image sampled at the same locations with zero
    padding, i.e. the per-view in-image weight of each voxel
    dim: (batch, nviews, 1, X,Y,Z)
    '''
    batch, nviews, _, wX, wY, wZ = coords.shape

    if sizeH is None:
        sizeH, sizeW = feats.shape[-2:]

    n_images = batch * nviews
    n_voxels = wX * wY * wZ
    sample = torch.nn.functional.grid_sample

    # Fold (batch, view) into one leading axis for the projection matrices.
    flat_proj = proj_matrix.view(n_images, *proj_matrix.shape[2:])
    rotation = flat_proj[:, :3, :3]
    translation = flat_proj[:, :3, 3:]

    # Voxel indices -> positions, laid out as a (n_voxels x 1) "image" so
    # that cam2pixel can consume it.
    world_pts = coords * voxel_size + origin.view(batch, 1, 3, 1, 1, 1)
    world_pts = world_pts.view(n_images, 3, n_voxels, 1)

    sample_grid = cam2pixel(world_pts, rotation, translation,
                            'zeros', sizeH=sizeH, sizeW=sizeW)
    sample_grid = sample_grid.view(n_images, 1, n_voxels, 2)

    flat_feats = feats.view(n_images, *feats.shape[2:])

    # Single-channel image of ones with the same spatial size as the
    # features; sampling it yields the visibility weight.
    ones_img = torch.ones((n_images, 1, *flat_feats.shape[2:]),
                          dtype=flat_feats.dtype, device=flat_feats.device)

    sampled_feats = sample(flat_feats, sample_grid, padding_mode='zeros', align_corners=True)
    sampled_ones = sample(ones_img, sample_grid, padding_mode='zeros', align_corners=True)

    volume_shape = (batch, nviews, -1, wX, wY, wZ)
    return sampled_feats.view(*volume_shape), sampled_ones.view(*volume_shape)
|
|
|
|