"""Generic N-dimensional version: any combination of spline orders"""
import torch
from typing import List, Optional, Tuple
from .bounds import Bound
from .splines import Spline
from .jit_utils import sub2ind_list, make_sign, list_prod_int, cartesian_prod
Tensor = torch.Tensor


@torch.jit.script
def inbounds_mask(extrapolate: int, grid, shape: List[int])\
-> Optional[Tensor]:
    """Compute a mask of in-bounds voxels, or None if no masking is needed.

    extrapolate == 0: mask out voxels that fall outside the field of view.
    extrapolate == 2: 'hist' mode; tolerate up to half a voxel outside.
    any other value:  extrapolate everywhere (returns None).
    """
mask: Optional[Tensor] = None
if extrapolate in (0, 2): # no / hist
grid = grid.unsqueeze(1)
tiny = 5e-2
threshold = tiny
if extrapolate == 2:
threshold = 0.5 + tiny
mask = torch.ones(grid.shape[:-1],
dtype=torch.bool, device=grid.device)
for grid1, shape1 in zip(grid.unbind(-1), shape):
mask = mask & (grid1 > -threshold)
mask = mask & (grid1 < shape1 - 1 + threshold)
return mask
return mask


@torch.jit.script
def get_weights(grid, bound: List[Bound], spline: List[Spline],
shape: List[int], grad: bool = False, hess: bool = False) \
-> Tuple[List[List[Tensor]],
List[List[Optional[Tensor]]],
List[List[Optional[Tensor]]],
List[List[Tensor]],
List[List[Optional[Tensor]]]]:
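    """Precompute everything needed to interpolate along each dimension.

    For each dimension d and each node of the spline support
    (spline[d].order + 1 nodes per dimension), this returns the
    interpolation weights, the first- and second-derivative weights
    (None unless grad/hess is True), the in-bounds node coordinates,
    and the boundary signs (which may be None).
    """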
weights: List[List[Tensor]] = []
grads: List[List[Optional[Tensor]]] = []
hesss: List[List[Optional[Tensor]]] = []
coords: List[List[Tensor]] = []
signs: List[List[Optional[Tensor]]] = []
for g, b, s, n in zip(grid.unbind(-1), bound, spline, shape):
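        # leftmost node of the spline support around each sampled coordinate,
        # and the distance from that node to the sample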
grid0 = (g - (s.order-1)/2).floor()
dist0 = g - grid0
grid0 = grid0.long()
nb_nodes = s.order + 1
subweights: List[Tensor] = []
subcoords: List[Tensor] = []
subgrads: List[Optional[Tensor]] = []
subhesss: List[Optional[Tensor]] = []
subsigns: List[Optional[Tensor]] = []
for node in range(nb_nodes):
grid1 = grid0 + node
sign1: Optional[Tensor] = b.transform(grid1, n)
subsigns.append(sign1)
grid1 = b.index(grid1, n)
subcoords.append(grid1)
dist1 = dist0 - node
weight1 = s.fastweight(dist1)
subweights.append(weight1)
grad1: Optional[Tensor] = None
if grad:
grad1 = s.fastgrad(dist1)
subgrads.append(grad1)
hess1: Optional[Tensor] = None
if hess:
hess1 = s.fasthess(dist1)
subhesss.append(hess1)
weights.append(subweights)
coords.append(subcoords)
signs.append(subsigns)
grads.append(subgrads)
hesss.append(subhesss)
return weights, grads, hesss, coords, signs


@torch.jit.script
def pull(inp, grid, bound: List[Bound], spline: List[Spline],
extrapolate: int = 1):
"""
    inp: (B, C, *ishape) tensor
    grid: (B, *oshape, D) tensor
    bound: List{D}[Bound]
    spline: List{D}[Spline]
    extrapolate: int
    returns: (B, C, *oshape) tensor
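
    Example: a minimal sketch of trilinear sampling. The integer argument
    passed to Bound and the Spline(order) constructor are assumed
    conventions of the sibling .bounds/.splines modules, not guaranteed by
    this file.

        B, C, D = 1, 2, 3
        inp = torch.randn([B, C, 32, 32, 32])
        grid = torch.rand([B, 8, 8, 8, D]) * 31      # voxel coordinates
        bound = [Bound(0)] * D                       # assumed int bound code
        spline = [Spline(1)] * D                     # order 1 = (tri)linear
        out = pull(inp, grid, bound, spline)         # (B, C, 8, 8, 8)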
"""
dim = grid.shape[-1]
shape = list(inp.shape[-dim:])
oshape = list(grid.shape[-dim-1:-1])
batch = max(inp.shape[0], grid.shape[0])
channel = inp.shape[1]
grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
mask = inbounds_mask(extrapolate, grid, shape)
# precompute weights along each dimension
weights, _, _, coords, signs = get_weights(grid, bound, spline, shape, False, False)
# initialize
out = torch.zeros([batch, channel, grid.shape[1]],
dtype=inp.dtype, device=inp.device)
# iterate across nodes/corners
range_nodes = [torch.as_tensor([d for d in range(n)])
for n in [s.order + 1 for s in spline]]
if dim == 1:
# cartesian_prod does not work as expected when only one
# element is provided
all_nodes = range_nodes[0].unsqueeze(-1)
else:
all_nodes = cartesian_prod(range_nodes)
for nodes in all_nodes:
# gather
idx = [c[n] for c, n in zip(coords, nodes)]
idx = sub2ind_list(idx, shape).unsqueeze(1)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.gather(-1, idx)
# apply sign
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
sign1: Optional[Tensor] = make_sign(sign0)
if sign1 is not None:
out1 = out1 * sign1.unsqueeze(1)
# apply weights
for weight, n in zip(weights, nodes):
out1 = out1 * weight[n].unsqueeze(1)
# accumulate
out = out + out1
# out-of-bounds mask
if mask is not None:
out = out * mask
out = out.reshape(list(out.shape[:2]) + oshape)
return out


@torch.jit.script
def push(inp, grid, shape: Optional[List[int]], bound: List[Bound],
spline: List[Spline], extrapolate: int = 1):
"""
    inp: (B, C, *ishape) tensor
    grid: (B, *ishape, D) tensor
    shape: List{D}[int], optional
    bound: List{D}[Bound]
    spline: List{D}[Spline]
    extrapolate: int
    returns: (B, C, *shape) tensor
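
    Notes: push is the adjoint of pull (spline "splatting"): each value of
    inp is multiplied by the same separable weights that pull would use and
    scatter-added into an output volume whose spatial shape is given by the
    shape argument.

    Example: a minimal sketch, with the same assumed Bound/Spline
    constructors as in the pull example.

        B, C, D = 1, 2, 3
        val = torch.randn([B, C, 8, 8, 8])
        grid = torch.rand([B, 8, 8, 8, D]) * 31
        out = push(val, grid, [32, 32, 32],
                   [Bound(0)] * D, [Spline(1)] * D)  # (B, C, 32, 32, 32)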
"""
dim = grid.shape[-1]
ishape = list(grid.shape[-dim - 1:-1])
if shape is None:
shape = ishape
shape = list(shape)
batch = max(inp.shape[0], grid.shape[0])
channel = inp.shape[1]
grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
mask = inbounds_mask(extrapolate, grid, shape)
# precompute weights along each dimension
weights, _, _, coords, signs = get_weights(grid, bound, spline, shape)
# initialize
out = torch.zeros([batch, channel, list_prod_int(shape)],
dtype=inp.dtype, device=inp.device)
# iterate across nodes/corners
range_nodes = [torch.as_tensor([d for d in range(n)])
for n in [s.order + 1 for s in spline]]
if dim == 1:
# cartesian_prod does not work as expected when only one
# element is provided
all_nodes = range_nodes[0].unsqueeze(-1)
else:
all_nodes = cartesian_prod(range_nodes)
for nodes in all_nodes:
# gather
idx = [c[n] for c, n in zip(coords, nodes)]
idx = sub2ind_list(idx, shape).unsqueeze(1)
idx = idx.expand([batch, channel, idx.shape[-1]])
out1 = inp.clone()
# apply sign
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
sign1: Optional[Tensor] = make_sign(sign0)
if sign1 is not None:
out1 = out1 * sign1.unsqueeze(1)
# out-of-bounds mask
if mask is not None:
out1 = out1 * mask
# apply weights
for weight, n in zip(weights, nodes):
out1 = out1 * weight[n].unsqueeze(1)
# accumulate
out.scatter_add_(-1, idx, out1)
out = out.reshape(list(out.shape[:2]) + shape)
return out


@torch.jit.script
def grad(inp, grid, bound: List[Bound], spline: List[Spline],
extrapolate: int = 1):
"""
inp: (B, C, *ishape) tensor
grid: (B, *oshape, D) tensor
    bound: List{D}[Bound]
    spline: List{D}[Spline]
extrapolate: int
returns: (B, C, *oshape, D) tensor
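
    Example: a minimal sketch, with the same assumed Bound/Spline
    constructors as in the pull example; the last dimension holds the
    spatial derivative along each of the D grid axes.

        g = grad(inp, grid, [Bound(0)] * 3,
                 [Spline(2)] * 3)                    # (B, C, *oshape, 3)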
"""
dim = grid.shape[-1]
shape = list(inp.shape[-dim:])
oshape = list(grid.shape[-dim-1:-1])
batch = max(inp.shape[0], grid.shape[0])
channel = inp.shape[1]
grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
mask = inbounds_mask(extrapolate, grid, shape)
# precompute weights along each dimension
weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape,
grad=True)
# initialize
out = torch.zeros([batch, channel, grid.shape[1], dim],
dtype=inp.dtype, device=inp.device)
# iterate across nodes/corners
range_nodes = [torch.as_tensor([d for d in range(n)])
for n in [s.order + 1 for s in spline]]
if dim == 1:
# cartesian_prod does not work as expected when only one
# element is provided
all_nodes = range_nodes[0].unsqueeze(-1)
else:
all_nodes = cartesian_prod(range_nodes)
for nodes in all_nodes:
# gather
idx = [c[n] for c, n in zip(coords, nodes)]
idx = sub2ind_list(idx, shape).unsqueeze(1)
idx = idx.expand([batch, channel, idx.shape[-1]])
out0 = inp.gather(-1, idx)
# apply sign
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
sign1: Optional[Tensor] = make_sign(sign0)
if sign1 is not None:
out0 = out0 * sign1.unsqueeze(1)
for d in range(dim):
out1 = out0.clone()
            # apply weights: derivative weight along dimension d,
            # interpolation weight along the other dimensions
for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)):
if d == dd:
grad11 = grad1[n]
if grad11 is not None:
out1 = out1 * grad11.unsqueeze(1)
else:
out1 = out1 * weight[n].unsqueeze(1)
# accumulate
out.unbind(-1)[d].add_(out1)
# out-of-bounds mask
if mask is not None:
out = out * mask.unsqueeze(-1)
out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-1:]))
return out


@torch.jit.script
def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound],
spline: List[Spline], extrapolate: int = 1):
"""
inp: (B, C, *ishape, D) tensor
    grid: (B, *ishape, D) tensor
    shape: List{D}[int], optional
    bound: List{D}[Bound]
    spline: List{D}[Spline]
extrapolate: int
returns: (B, C, *shape) tensor
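
    Notes: adjoint of grad. The d-th component of inp is weighted by the
    derivative of the spline basis along dimension d (and by the plain
    interpolation weights along the other dimensions) before being
    scatter-added into a volume of spatial shape given by the shape argument.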
"""
dim = grid.shape[-1]
oshape = list(grid.shape[-dim-1:-1])
if shape is None:
shape = oshape
shape = list(shape)
batch = max(inp.shape[0], grid.shape[0])
channel = inp.shape[1]
grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
inp = inp.reshape([inp.shape[0], inp.shape[1], -1, dim])
mask = inbounds_mask(extrapolate, grid, shape)
# precompute weights along each dimension
weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape, grad=True)
# initialize
out = torch.zeros([batch, channel, list_prod_int(shape)],
dtype=inp.dtype, device=inp.device)
# iterate across nodes/corners
range_nodes = [torch.as_tensor([d for d in range(n)])
for n in [s.order + 1 for s in spline]]
if dim == 1:
# cartesian_prod does not work as expected when only one
# element is provided
all_nodes = range_nodes[0].unsqueeze(-1)
else:
all_nodes = cartesian_prod(range_nodes)
for nodes in all_nodes:
# gather
idx = [c[n] for c, n in zip(coords, nodes)]
idx = sub2ind_list(idx, shape).unsqueeze(1)
idx = idx.expand([batch, channel, idx.shape[-1]])
out0 = inp.clone()
# apply sign
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
sign1: Optional[Tensor] = make_sign(sign0)
if sign1 is not None:
out0 = out0 * sign1.unsqueeze(1).unsqueeze(-1)
# out-of-bounds mask
if mask is not None:
out0 = out0 * mask.unsqueeze(-1)
for d in range(dim):
out1 = out0.unbind(-1)[d].clone()
            # apply weights: derivative weight along dimension d,
            # interpolation weight along the other dimensions
for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)):
if d == dd:
grad11 = grad1[n]
if grad11 is not None:
out1 = out1 * grad11.unsqueeze(1)
else:
out1 = out1 * weight[n].unsqueeze(1)
# accumulate
out.scatter_add_(-1, idx, out1)
out = out.reshape(list(out.shape[:2]) + shape)
return out


@torch.jit.script
def hess(inp, grid, bound: List[Bound], spline: List[Spline],
extrapolate: int = 1):
"""
inp: (B, C, *ishape) tensor
grid: (B, *oshape, D) tensor
    bound: List{D}[Bound]
    spline: List{D}[Spline]
extrapolate: int
returns: (B, C, *oshape, D, D) tensor
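
    Notes: the returned Hessian is symmetric. Diagonal terms use
    second-derivative spline weights; off-diagonal terms use products of
    first-derivative weights along the two dimensions involved.

    Example: a minimal sketch, with the same assumed Bound/Spline
    constructors as in the pull example.

        H = hess(inp, grid, [Bound(0)] * 3,
                 [Spline(3)] * 3)                    # (B, C, *oshape, 3, 3)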
"""
dim = grid.shape[-1]
shape = list(inp.shape[-dim:])
oshape = list(grid.shape[-dim-1:-1])
batch = max(inp.shape[0], grid.shape[0])
channel = inp.shape[1]
grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
mask = inbounds_mask(extrapolate, grid, shape)
# precompute weights along each dimension
weights, grads, hesss, coords, signs \
= get_weights(grid, bound, spline, shape, grad=True, hess=True)
# initialize
out = torch.zeros([batch, channel, grid.shape[1], dim, dim],
dtype=inp.dtype, device=inp.device)
# iterate across nodes/corners
range_nodes = [torch.as_tensor([d for d in range(n)])
for n in [s.order + 1 for s in spline]]
if dim == 1:
# cartesian_prod does not work as expected when only one
# element is provided
all_nodes = range_nodes[0].unsqueeze(-1)
else:
all_nodes = cartesian_prod(range_nodes)
for nodes in all_nodes:
# gather
idx = [c[n] for c, n in zip(coords, nodes)]
idx = sub2ind_list(idx, shape).unsqueeze(1)
idx = idx.expand([batch, channel, idx.shape[-1]])
out0 = inp.gather(-1, idx)
# apply sign
sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
sign1: Optional[Tensor] = make_sign(sign0)
if sign1 is not None:
out0 = out0 * sign1.unsqueeze(1)
for d in range(dim):
# -- diagonal --
out1 = out0.clone()
            # apply weights: second-derivative weight along d,
            # interpolation weights along the other dimensions
for dd, (weight, hess1, n) \
in enumerate(zip(weights, hesss, nodes)):
if d == dd:
hess11 = hess1[n]
if hess11 is not None:
out1 = out1 * hess11.unsqueeze(1)
else:
out1 = out1 * weight[n].unsqueeze(1)
# accumulate
out.unbind(-1)[d].unbind(-1)[d].add_(out1)
# -- off diagonal --
for d2 in range(d+1, dim):
out1 = out0.clone()
                # apply weights: first-derivative weights along d and d2,
                # interpolation weights along the other dimensions
for dd, (weight, grad1, n) \
in enumerate(zip(weights, grads, nodes)):
if dd in (d, d2):
grad11 = grad1[n]
if grad11 is not None:
out1 = out1 * grad11.unsqueeze(1)
else:
out1 = out1 * weight[n].unsqueeze(1)
# accumulate
out.unbind(-1)[d].unbind(-1)[d2].add_(out1)
# out-of-bounds mask
if mask is not None:
        out = out * mask.unsqueeze(-1).unsqueeze(-1)
# fill lower triangle
for d in range(dim):
for d2 in range(d+1, dim):
out.unbind(-1)[d2].unbind(-1)[d].copy_(out.unbind(-1)[d].unbind(-1)[d2])
out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-2:]))
return out