|
import torch |
|
|
|
|
|
def generate_grid(n_vox, interval):
    """Build a dense grid of sample coordinates for a 3D volume.

    Every axis is sampled at ``0, interval, 2 * interval, ...`` up to (but
    not including) ``n_vox[axis]``. The channel dimension stores the
    coordinate of each sample, so ``grid[0, :, i, j, k]`` holds the values
    of the i-th, j-th and k-th samples along the three axes. With
    ``interval == 1`` this reduces to ``grid[:, :, x, y, z] == (x, y, z)``.

    :param n_vox: indexable giving the extent of the volume along each axis;
        only indices 0, 1 and 2 are read.
    :param interval: step between consecutive samples along every axis.
    :return: ``torch.float32`` tensor of shape ``(1, 3, X, Y, Z)``, where
        ``X``, ``Y`` and ``Z`` are the number of samples taken per axis.
    """
    # The grid is a constant lookup table, so keep it out of autograd.
    with torch.no_grad():
        grid_range = [torch.arange(0, n_vox[axis], interval) for axis in range(3)]
        # "ij" indexing keeps the output dimensions in the same order as the
        # inputs, i.e. dim 0 of each mesh follows axis 0 of the volume.
        grid = torch.stack(
            torch.meshgrid(grid_range[0], grid_range[1], grid_range[2], indexing="ij")
        )
        # Prepend a batch dimension and switch to float for later sampling.
        grid = grid.unsqueeze(0).type(torch.float32)

    return grid
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: sampling the coordinate grid at voxel (1, 2, 3) should
    # return that very coordinate.
    import torch.nn.functional as F

    grid = generate_grid([5, 6, 8], 1)

    # Map the voxel index into grid_sample's normalised [-1, 1] range.
    # Dividing by (size - 1) places -1 / +1 on the centres of the first and
    # last voxels, which is the ``align_corners=True`` convention.
    pts = 2 * torch.tensor([1, 2, 3]) / (torch.tensor([5, 6, 8]) - 1) - 1
    pts = pts.view(1, 1, 1, 1, 3)

    # grid_sample expects the last dimension ordered (x, y, z), where x
    # indexes the innermost volume axis, so reverse the axis order.
    pts = torch.flip(pts, dims=[-1])

    # align_corners=True matches the normalisation used above; relying on
    # the default (False) is inconsistent with it and emits a warning.
    sampled = F.grid_sample(grid, pts, mode='nearest', align_corners=True)

    print(sampled)
|
|