# python3.8
"""Contains utility functions for rendering."""
import torch


def normalize_vecs(vectors):
    """
    Normalizes vectors to unit length along the last dimension.
    """
    return vectors / torch.norm(vectors, dim=-1, keepdim=True)


def truncated_normal(tensor, mean=0, std=1):
    """
    Fills `tensor` in place with samples from a normal distribution truncated
    to (mean - 2 * std, mean + 2 * std).
    """
    size = tensor.shape
    # draw four standard-normal candidates per element and keep the first one
    # that falls inside (-2, 2)
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor
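

# Usage sketch (illustrative, not part of the original module): `truncated_normal`
# resamples a tensor in place, e.g. to draw latent codes restricted to two
# standard deviations:
#
#   z = torch.empty(4, 256)
#   truncated_normal(z, mean=0.0, std=0.5)  # entries should now lie in (-1.0, 1.0)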


def get_grid_coords(points, bounds):
    """Transforms points from world coordinates to normalized volume coordinates.
    points: batch_size, num_point, 3
    bounds: 2, 3 (min and max corners of the bounding box)
    """
# normalize the points
bounds = bounds[None]
min_xyz = bounds[:, :1]
points = points - min_xyz
    # rescale so the bounding box maps to the normalized range [-1, 1]
size = bounds[:, 1] - bounds[:, 0]
points = (points / size[:, None]) * 2 - 1
return points
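

# Example (illustrative): with bounds = torch.tensor([[0., 0., 0.], [2., 2., 2.]]),
# the box corner (0, 0, 0) maps to (-1, -1, -1), the box centre (1, 1, 1) maps to
# (0, 0, 0), and the opposite corner (2, 2, 2) maps to (1, 1, 1).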


def grid_sample_3d(image, optical):
    """Trilinearly samples a 3D volume at the grid locations given by `optical`.
    image: batch_size, channel, D, H, W
    optical: batch_size, D, H, W, 3 (normalized x, y, z coordinates in [-1, 1])
    """
    N, C, ID, IH, IW = image.shape
    _, D, H, W, _ = optical.shape
    ix = optical[..., 0]
    iy = optical[..., 1]
    iz = optical[..., 2]
    # map normalized coordinates in [-1, 1] to voxel indices
    # (align_corners=True convention)
    ix = ((ix + 1) / 2) * (IW - 1)
    iy = ((iy + 1) / 2) * (IH - 1)
    iz = ((iz + 1) / 2) * (ID - 1)
    with torch.no_grad():
        # indices of the eight surrounding voxel corners; the suffixes encode
        # top/bottom (z), north/south (y) and west/east (x)
        ix_tnw = torch.floor(ix)
        iy_tnw = torch.floor(iy)
        iz_tnw = torch.floor(iz)
ix_tne = ix_tnw + 1
iy_tne = iy_tnw
iz_tne = iz_tnw
ix_tsw = ix_tnw
iy_tsw = iy_tnw + 1
iz_tsw = iz_tnw
ix_tse = ix_tnw + 1
iy_tse = iy_tnw + 1
iz_tse = iz_tnw
ix_bnw = ix_tnw
iy_bnw = iy_tnw
iz_bnw = iz_tnw + 1
ix_bne = ix_tnw + 1
iy_bne = iy_tnw
iz_bne = iz_tnw + 1
ix_bsw = ix_tnw
iy_bsw = iy_tnw + 1
iz_bsw = iz_tnw + 1
ix_bse = ix_tnw + 1
iy_bse = iy_tnw + 1
iz_bse = iz_tnw + 1
    # trilinear interpolation weights: each corner is weighted by the volume of
    # the sub-box spanned by the sample point and the opposite corner
    tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz)
tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz)
tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz)
tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz)
bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse)
bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw)
bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne)
bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw)
    with torch.no_grad():
        # clamp the corner indices to the valid voxel range before gathering
        torch.clamp(ix_tnw, 0, IW - 1, out=ix_tnw)
        torch.clamp(iy_tnw, 0, IH - 1, out=iy_tnw)
        torch.clamp(iz_tnw, 0, ID - 1, out=iz_tnw)
torch.clamp(ix_tne, 0, IW - 1, out=ix_tne)
torch.clamp(iy_tne, 0, IH - 1, out=iy_tne)
torch.clamp(iz_tne, 0, ID - 1, out=iz_tne)
torch.clamp(ix_tsw, 0, IW - 1, out=ix_tsw)
torch.clamp(iy_tsw, 0, IH - 1, out=iy_tsw)
torch.clamp(iz_tsw, 0, ID - 1, out=iz_tsw)
torch.clamp(ix_tse, 0, IW - 1, out=ix_tse)
torch.clamp(iy_tse, 0, IH - 1, out=iy_tse)
torch.clamp(iz_tse, 0, ID - 1, out=iz_tse)
torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw)
torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw)
torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw)
torch.clamp(ix_bne, 0, IW - 1, out=ix_bne)
torch.clamp(iy_bne, 0, IH - 1, out=iy_bne)
torch.clamp(iz_bne, 0, ID - 1, out=iz_bne)
torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw)
torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw)
torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw)
torch.clamp(ix_bse, 0, IW - 1, out=ix_bse)
torch.clamp(iy_bse, 0, IH - 1, out=iy_bse)
torch.clamp(iz_bse, 0, ID - 1, out=iz_bse)
    # flatten the spatial dimensions so corner values can be fetched with gather
    image = image.view(N, C, ID * IH * IW)
    # gather each corner value at the flattened linear index iz * IW * IH + iy * IW + ix
    tnw_val = torch.gather(image, 2,
(iz_tnw * IW * IH + iy_tnw * IW +
ix_tnw).long().view(N, 1,
D * H * W).repeat(1, C, 1))
tne_val = torch.gather(image, 2,
(iz_tne * IW * IH + iy_tne * IW +
ix_tne).long().view(N, 1,
D * H * W).repeat(1, C, 1))
tsw_val = torch.gather(image, 2,
(iz_tsw * IW * IH + iy_tsw * IW +
ix_tsw).long().view(N, 1,
D * H * W).repeat(1, C, 1))
tse_val = torch.gather(image, 2,
(iz_tse * IW * IH + iy_tse * IW +
ix_tse).long().view(N, 1,
D * H * W).repeat(1, C, 1))
bnw_val = torch.gather(image, 2,
(iz_bnw * IW * IH + iy_bnw * IW +
ix_bnw).long().view(N, 1,
D * H * W).repeat(1, C, 1))
bne_val = torch.gather(image, 2,
(iz_bne * IW * IH + iy_bne * IW +
ix_bne).long().view(N, 1,
D * H * W).repeat(1, C, 1))
bsw_val = torch.gather(image, 2,
(iz_bsw * IW * IH + iy_bsw * IW +
ix_bsw).long().view(N, 1,
D * H * W).repeat(1, C, 1))
bse_val = torch.gather(image, 2,
(iz_bse * IW * IH + iy_bse * IW +
ix_bse).long().view(N, 1,
D * H * W).repeat(1, C, 1))
    # blend the eight corner values with their trilinear weights
    out_val = (tnw_val.view(N, C, D, H, W) * tnw.view(N, 1, D, H, W) +
tne_val.view(N, C, D, H, W) * tne.view(N, 1, D, H, W) +
tsw_val.view(N, C, D, H, W) * tsw.view(N, 1, D, H, W) +
tse_val.view(N, C, D, H, W) * tse.view(N, 1, D, H, W) +
bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) +
bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) +
bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) +
bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W))
return out_val
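

# Sanity-check sketch (an illustrative addition; the shapes below are assumptions):
# for grids strictly inside [-1, 1], grid_sample_3d should match
# torch.nn.functional.grid_sample with align_corners=True. The hand-written version
# is typically used when double backward through grid sampling is required.
def _check_grid_sample_3d():
    import torch.nn.functional as F
    image = torch.rand(2, 4, 8, 8, 8)
    grid = torch.rand(2, 3, 5, 7, 3) * 1.8 - 0.9  # keep samples away from the border
    expected = F.grid_sample(image, grid, mode='bilinear', align_corners=True)
    assert torch.allclose(grid_sample_3d(image, grid), expected, atol=1e-4)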


def interpolate_feature(points, volume, bounds):
    """
    Samples per-point features from a feature volume.
    points: batch_size, num_point, 3 (world coordinates)
    volume: batch_size, num_channel, d, h, w
    bounds: 2, 3 (min and max corners of the bounding box)
    """
grid_coords = get_grid_coords(points, bounds)
    # reshape to (batch_size, 1, 1, num_point, 3) so the points form a 1 x 1 x N grid
    grid_coords = grid_coords[:, None, None]
    # roughly equivalent built-in call, kept for reference:
    # point_features = F.grid_sample(volume,
    #                                grid_coords,
    #                                padding_mode='zeros',
    #                                align_corners=True)
    point_features = grid_sample_3d(volume, grid_coords)
    # drop the singleton spatial dimensions -> (batch_size, num_channel, num_point)
    point_features = point_features[:, :, 0, 0]
return point_features
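

if __name__ == '__main__':
    # Minimal smoke test (an illustrative addition; the shapes and bounds below are
    # assumptions, not taken from any particular caller of this module).
    volume = torch.rand(2, 16, 32, 32, 32)               # batch_size, num_channel, d, h, w
    points = torch.rand(2, 1024, 3)                      # world-space points inside the box
    bounds = torch.tensor([[0., 0., 0.], [1., 1., 1.]])  # min / max corners of the box
    features = interpolate_feature(points, volume, bounds)
    print(features.shape)  # expected: torch.Size([2, 16, 1024])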