# python3.8 """Contains utility functions for rendering.""" import torch def normalize_vecs(vectors): """ Normalize vector lengths. """ return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) def truncated_normal(tensor, mean=0, std=1): """ Samples from truncated normal distribution. """ size = tensor.shape tmp = tensor.new_empty(size + (4,)).normal_() valid = (tmp < 2) & (tmp > -2) ind = valid.max(-1, keepdim=True)[1] tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) tensor.data.mul_(std).add_(mean) return tensor def get_grid_coords(points, bounds): """ transform points from the world coordinate to the volume coordinate pts: batch_size, num_point, 3 bounds: 2, 3 """ # normalize the points bounds = bounds[None] min_xyz = bounds[:, :1] points = points - min_xyz # convert the voxel coordinate to [-1, 1] size = bounds[:, 1] - bounds[:, 0] points = (points / size[:, None]) * 2 - 1 return points def grid_sample_3d(image, optical): """grid sample images by the optical in 3D format image: batch_size, channel, D, H, W optical: batch_size, D, H, W, 3 """ N, C, ID, IH, IW = image.shape _, D, H, W, _ = optical.shape ix = optical[..., 0] iy = optical[..., 1] iz = optical[..., 2] ix = ((ix + 1) / 2) * (IW - 1) iy = ((iy + 1) / 2) * (IH - 1) iz = ((iz + 1) / 2) * (ID - 1) with torch.no_grad(): ix_tnw = torch.floor(ix) iy_tnw = torch.floor(iy) iz_tnw = torch.floor(iz) ix_tne = ix_tnw + 1 iy_tne = iy_tnw iz_tne = iz_tnw ix_tsw = ix_tnw iy_tsw = iy_tnw + 1 iz_tsw = iz_tnw ix_tse = ix_tnw + 1 iy_tse = iy_tnw + 1 iz_tse = iz_tnw ix_bnw = ix_tnw iy_bnw = iy_tnw iz_bnw = iz_tnw + 1 ix_bne = ix_tnw + 1 iy_bne = iy_tnw iz_bne = iz_tnw + 1 ix_bsw = ix_tnw iy_bsw = iy_tnw + 1 iz_bsw = iz_tnw + 1 ix_bse = ix_tnw + 1 iy_bse = iy_tnw + 1 iz_bse = iz_tnw + 1 tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz) tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz) tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz) tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz) bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse) bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw) bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne) bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw) with torch.no_grad(): torch.clamp(ix_tnw, 0, IW - 1, out=ix_tnw) torch.clamp(iy_tnw, 0, IH - 1, out=iy_tnw) torch.clamp(iz_tnw, 0, ID - 1, out=iz_tnw) torch.clamp(ix_tne, 0, IW - 1, out=ix_tne) torch.clamp(iy_tne, 0, IH - 1, out=iy_tne) torch.clamp(iz_tne, 0, ID - 1, out=iz_tne) torch.clamp(ix_tsw, 0, IW - 1, out=ix_tsw) torch.clamp(iy_tsw, 0, IH - 1, out=iy_tsw) torch.clamp(iz_tsw, 0, ID - 1, out=iz_tsw) torch.clamp(ix_tse, 0, IW - 1, out=ix_tse) torch.clamp(iy_tse, 0, IH - 1, out=iy_tse) torch.clamp(iz_tse, 0, ID - 1, out=iz_tse) torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw) torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw) torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw) torch.clamp(ix_bne, 0, IW - 1, out=ix_bne) torch.clamp(iy_bne, 0, IH - 1, out=iy_bne) torch.clamp(iz_bne, 0, ID - 1, out=iz_bne) torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw) torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw) torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw) torch.clamp(ix_bse, 0, IW - 1, out=ix_bse) torch.clamp(iy_bse, 0, IH - 1, out=iy_bse) torch.clamp(iz_bse, 0, ID - 1, out=iz_bse) image = image.view(N, C, ID * IH * IW) tnw_val = torch.gather(image, 2, (iz_tnw * IW * IH + iy_tnw * IW + ix_tnw).long().view(N, 1, D * H * W).repeat(1, C, 1)) tne_val = torch.gather(image, 2, (iz_tne * IW * IH + iy_tne * IW + ix_tne).long().view(N, 1, D * H * W).repeat(1, C, 1)) tsw_val = torch.gather(image, 2, 
(iz_tsw * IW * IH + iy_tsw * IW + ix_tsw).long().view(N, 1, D * H * W).repeat(1, C, 1)) tse_val = torch.gather(image, 2, (iz_tse * IW * IH + iy_tse * IW + ix_tse).long().view(N, 1, D * H * W).repeat(1, C, 1)) bnw_val = torch.gather(image, 2, (iz_bnw * IW * IH + iy_bnw * IW + ix_bnw).long().view(N, 1, D * H * W).repeat(1, C, 1)) bne_val = torch.gather(image, 2, (iz_bne * IW * IH + iy_bne * IW + ix_bne).long().view(N, 1, D * H * W).repeat(1, C, 1)) bsw_val = torch.gather(image, 2, (iz_bsw * IW * IH + iy_bsw * IW + ix_bsw).long().view(N, 1, D * H * W).repeat(1, C, 1)) bse_val = torch.gather(image, 2, (iz_bse * IW * IH + iy_bse * IW + ix_bse).long().view(N, 1, D * H * W).repeat(1, C, 1)) out_val = (tnw_val.view(N, C, D, H, W) * tnw.view(N, 1, D, H, W) + tne_val.view(N, C, D, H, W) * tne.view(N, 1, D, H, W) + tsw_val.view(N, C, D, H, W) * tsw.view(N, 1, D, H, W) + tse_val.view(N, C, D, H, W) * tse.view(N, 1, D, H, W) + bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) + bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) + bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) + bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W)) return out_val def interpolate_feature(points, volume, bounds): """ points: batch_size, num_point, 3 volume: batch_size, num_channel, d, h, w bounds: 2, 3 """ grid_coords = get_grid_coords(points, bounds) grid_coords = grid_coords[:, None, None] # point_features = F.grid_sample(volume, # grid_coords, # padding_mode='zeros', # align_corners=True) point_features = grid_sample_3d(volume, grid_coords) point_features = point_features[:, :, 0, 0] return point_features
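

# Usage sketch: a quick, self-contained check of the helpers above. The shapes,
# bounds, and the use of torch.nn.functional.grid_sample as a reference are
# illustrative assumptions for this demo, not values used elsewhere in the repo.
if __name__ == "__main__":
    import torch.nn.functional as F

    torch.manual_seed(0)
    volume = torch.rand(2, 8, 5, 6, 7)        # batch, channel, D, H, W
    grid = torch.rand(2, 3, 4, 5, 3) * 2 - 1  # xyz coordinates in [-1, 1]

    # For in-range coordinates, grid_sample_3d is expected to match
    # F.grid_sample with align_corners=True (the index clamping above acts
    # like border padding, which is never triggered inside [-1, 1]).
    ours = grid_sample_3d(volume, grid)
    ref = F.grid_sample(volume, grid, mode='bilinear',
                        padding_mode='border', align_corners=True)
    print('max |grid_sample_3d - F.grid_sample|:',
          (ours - ref).abs().max().item())

    # Sample per-point features from a volume defined over a world-space box.
    bounds = torch.tensor([[0.0, 0.0, 0.0],
                           [1.0, 1.0, 1.0]])  # (2, 3): min_xyz, max_xyz
    points = torch.rand(2, 1024, 3)           # world-space points inside the box
    features = interpolate_feature(points, volume, bounds)
    print('point features:', features.shape)  # (batch, channel, num_point)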