|
"""Generic N-dimensional version: any combination of spline orders""" |
|
import torch |
|
from typing import List, Optional, Tuple |
|
from .bounds import Bound |
|
from .splines import Spline |
|
from .jit_utils import sub2ind_list, make_sign, list_prod_int, cartesian_prod |
|
Tensor = torch.Tensor |
|
|
|
|
|
@torch.jit.script
def inbounds_mask(extrapolate: int, grid, shape: List[int])\
        -> Optional[Tensor]:
    """
    Flag the coordinates of `grid` that fall inside the field of view.

    extrapolate: int, a mask is only built for modes 0 and 2
    grid: (..., D) tensor of coordinates, one entry per dimension
    shape: List{D}[int], size of the sampled volume
    returns: boolean tensor shaped like `grid` without its last axis and
        with a singleton axis inserted at position 1, or None when
        `extrapolate` is neither 0 nor 2
    """
    if extrapolate != 0 and extrapolate != 2:
        return None

    # tolerance around the field of view: 0.05 in mode 0, 0.55 in mode 2
    margin = 5e-2
    if extrapolate == 2:
        margin = 0.5 + margin

    # singleton axis at position 1 so the mask broadcasts against
    # the tensors it is multiplied with by the callers
    coords = grid.unsqueeze(1)
    inside = torch.ones(coords.shape[:-1],
                        dtype=torch.bool, device=coords.device)
    for coord, size in zip(coords.unbind(-1), shape):
        # both bounds are exclusive
        above_lower = coord > -margin
        below_upper = coord < size - 1 + margin
        inside = inside & above_lower & below_upper
    return inside
|
|
|
|
|
@torch.jit.script
def get_weights(grid, bound: List[Bound], spline: List[Spline],
                shape: List[int], grad: bool = False, hess: bool = False) \
        -> Tuple[List[List[Tensor]],
                 List[List[Optional[Tensor]]],
                 List[List[Optional[Tensor]]],
                 List[List[Tensor]],
                 List[List[Optional[Tensor]]]]:
    """
    Precompute, per dimension and per spline node, the quantities needed
    to interpolate at the coordinates stored in `grid`.

    grid: (..., D) tensor of coordinates
    bound: List{D}[Bound]
    spline: List{D}[Spline]
    shape: List{D}[int]
    grad: bool, also evaluate `fastgrad` at each node
    hess: bool, also evaluate `fasthess` at each node
    returns: (weights, grads, hesss, coords, signs), each a list over
        dimensions of lists over the `order + 1` nodes of that dimension.
        Entries of `grads` / `hesss` are None when not requested.
    """
    all_weights: List[List[Tensor]] = []
    all_grads: List[List[Optional[Tensor]]] = []
    all_hesss: List[List[Optional[Tensor]]] = []
    all_coords: List[List[Tensor]] = []
    all_signs: List[List[Optional[Tensor]]] = []

    for axis_grid, axis_bound, axis_spline, axis_size \
            in zip(grid.unbind(-1), bound, spline, shape):
        # index of the first node visited and offset of each point from it
        first_float = (axis_grid - (axis_spline.order-1)/2).floor()
        offset = axis_grid - first_float
        first = first_float.long()

        node_weights: List[Tensor] = []
        node_coords: List[Tensor] = []
        node_grads: List[Optional[Tensor]] = []
        node_hesss: List[Optional[Tensor]] = []
        node_signs: List[Optional[Tensor]] = []

        for k in range(axis_spline.order + 1):
            raw_index = first + k

            # sign and wrapped index, as given by the boundary condition
            node_sign: Optional[Tensor] = axis_bound.transform(raw_index, axis_size)
            node_signs.append(node_sign)
            node_coords.append(axis_bound.index(raw_index, axis_size))

            # spline basis (and derivatives) at the offset from this node
            delta = offset - k
            node_weights.append(axis_spline.fastweight(delta))

            node_grad: Optional[Tensor] = None
            if grad:
                node_grad = axis_spline.fastgrad(delta)
            node_grads.append(node_grad)

            node_hess: Optional[Tensor] = None
            if hess:
                node_hess = axis_spline.fasthess(delta)
            node_hesss.append(node_hess)

        all_weights.append(node_weights)
        all_coords.append(node_coords)
        all_signs.append(node_signs)
        all_grads.append(node_grads)
        all_hesss.append(node_hesss)

    return all_weights, all_grads, all_hesss, all_coords, all_signs
|
|
|
|
|
@torch.jit.script
def pull(inp, grid, bound: List[Bound], spline: List[Spline],
         extrapolate: int = 1):
    """
    Sample `inp` at the coordinates stored in `grid`.

    inp: (B, C, *ishape) tensor
    grid: (B, *oshape, D) tensor
    bound: List{D}[Bound] tensor
    spline: List{D}[Spline] tensor
    extrapolate: int
    returns: (B, C, *oshape) tensor

    The batch size of `inp` or `grid` may be 1, in which case it is
    broadcast against the other one.
    """

    dim = grid.shape[-1]
    shape = list(inp.shape[-dim:])
    oshape = list(grid.shape[-dim-1:-1])
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten the spatial dimensions
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
    # gather() requires the index to be no larger than the input along the
    # batch dimension, so a singleton input batch is broadcast up front
    inp = inp.expand([batch, channel, inp.shape[-1]])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights along each dimension
    weights, _, _, coords, signs = get_weights(grid, bound, spline, shape, False, False)

    # initialize the accumulator
    out = torch.zeros([batch, channel, grid.shape[1]],
                      dtype=inp.dtype, device=inp.device)

    # enumerate every combination of nodes across dimensions
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # a single dimension is handled without cartesian_prod: the node
        # indices just get a trailing singleton axis
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:

        # gather the input values at this combination of nodes
        subs = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(subs, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out1 = inp.gather(-1, idx)

        # apply the sign returned by the boundary conditions
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out1 = out1 * sign1.unsqueeze(1)

        # apply the separable interpolation weights
        for weight, n in zip(weights, nodes):
            out1 = out1 * weight[n].unsqueeze(1)

        # accumulate
        out = out + out1

    # zero the out-of-bounds locations
    if mask is not None:
        out = out * mask

    out = out.reshape(list(out.shape[:2]) + oshape)
    return out
|
|
|
|
|
@torch.jit.script
def push(inp, grid, shape: Optional[List[int]], bound: List[Bound],
         spline: List[Spline], extrapolate: int = 1):
    """
    Scatter-add the values of `inp` into a volume of size `shape`, at the
    coordinates stored in `grid`.

    inp: (B, C, *ishape) tensor
    grid: (B, *ishape, D) tensor
    shape: List{D}[int], optional (defaults to `ishape`)
    bound: List{D}[Bound] tensor
    spline: List{D}[Spline] tensor
    extrapolate: int
    returns: (B, C, *shape) tensor
    """

    ndim = grid.shape[-1]
    # output size: the spatial shape of the grid unless stated otherwise
    oshape: List[int] = list(grid.shape[-ndim - 1:-1])
    if shape is not None:
        oshape = list(shape)
    nbatch = max(inp.shape[0], grid.shape[0])
    nchannel = inp.shape[1]

    # flatten the spatial dimensions
    flat_grid = grid.reshape([grid.shape[0], -1, ndim])
    flat_inp = inp.reshape([inp.shape[0], nchannel, -1])
    inside = inbounds_mask(extrapolate, flat_grid, oshape)

    # per-dimension weights, indices and signs
    weights, _, _, coords, signs = get_weights(flat_grid, bound, spline, oshape)

    # accumulator over the flattened output volume
    out = torch.zeros([nbatch, nchannel, list_prod_int(oshape)],
                      dtype=flat_inp.dtype, device=flat_inp.device)

    # enumerate every combination of nodes across dimensions
    node_ranges = [torch.as_tensor([k for k in range(count)])
                   for count in [s.order + 1 for s in spline]]
    if ndim == 1:
        # single dimension: node indices with a trailing singleton axis
        combos = node_ranges[0].unsqueeze(-1)
    else:
        combos = cartesian_prod(node_ranges)

    for nodes in combos:

        # linear index of the destination voxels
        subs = [c[n] for c, n in zip(coords, nodes)]
        target = sub2ind_list(subs, oshape).unsqueeze(1)
        target = target.expand([nbatch, nchannel, target.shape[-1]])
        contrib = flat_inp.clone()

        # sign returned by the boundary conditions
        sign_parts: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign: Optional[Tensor] = make_sign(sign_parts)
        if sign is not None:
            contrib = contrib * sign.unsqueeze(1)

        # drop the contribution of out-of-bounds locations
        if inside is not None:
            contrib = contrib * inside

        # separable interpolation weights
        for weight, n in zip(weights, nodes):
            contrib = contrib * weight[n].unsqueeze(1)

        # accumulate into the output volume
        out.scatter_add_(-1, target, contrib)

    return out.reshape(list(out.shape[:2]) + oshape)
|
|
|
|
|
@torch.jit.script
def grad(inp, grid, bound: List[Bound], spline: List[Spline],
         extrapolate: int = 1):
    """
    Sample the spatial gradient of `inp` at the coordinates stored in `grid`.

    inp: (B, C, *ishape) tensor
    grid: (B, *oshape, D) tensor
    bound: List{D}[Bound] tensor
    spline: List{D}[Spline] tensor
    extrapolate: int
    returns: (B, C, *oshape, D) tensor

    The batch size of `inp` or `grid` may be 1, in which case it is
    broadcast against the other one.
    """

    dim = grid.shape[-1]
    shape = list(inp.shape[-dim:])
    oshape = list(grid.shape[-dim-1:-1])
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten the spatial dimensions
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
    # gather() requires the index to be no larger than the input along the
    # batch dimension, so a singleton input batch is broadcast up front
    inp = inp.expand([batch, channel, inp.shape[-1]])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights and their first derivatives along each dimension
    weights, grads, _, coords, signs = get_weights(grid, bound, spline, shape,
                                                   grad=True)

    # initialize the accumulator (one gradient component per dimension)
    out = torch.zeros([batch, channel, grid.shape[1], dim],
                      dtype=inp.dtype, device=inp.device)

    # enumerate every combination of nodes across dimensions
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # a single dimension is handled without cartesian_prod: the node
        # indices just get a trailing singleton axis
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:

        # gather the input values at this combination of nodes
        subs = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(subs, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out0 = inp.gather(-1, idx)

        # apply the sign returned by the boundary conditions
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out0 = out0 * sign1.unsqueeze(1)

        for d in range(dim):
            out1 = out0.clone()

            # derivative of the basis along `d`, plain weights elsewhere
            for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)):
                if d == dd:
                    grad11 = grad1[n]
                    if grad11 is not None:
                        out1 = out1 * grad11.unsqueeze(1)
                else:
                    out1 = out1 * weight[n].unsqueeze(1)

            # accumulate into the d-th gradient component (in-place view)
            out.unbind(-1)[d].add_(out1)

    # zero the out-of-bounds locations
    if mask is not None:
        out = out * mask.unsqueeze(-1)

    out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-1:]))
    return out
|
|
|
|
|
@torch.jit.script
def pushgrad(inp, grid, shape: Optional[List[int]], bound: List[Bound],
             spline: List[Spline], extrapolate: int = 1):
    """
    Scatter-add a D-component field into a volume of size `shape`, each
    component being weighted by the spline derivative along its dimension.

    inp: (B, C, *ishape, D) tensor
    grid: (B, *ishape, D) tensor
    shape: List{D}[int], optional (defaults to `ishape`)
    bound: List{D}[Bound] tensor
    spline: List{D}[Spline] tensor
    extrapolate: int
    returns: (B, C, *shape) tensor
    """
    ndim = grid.shape[-1]
    # output size: the spatial shape of the grid unless stated otherwise
    oshape: List[int] = list(grid.shape[-ndim-1:-1])
    if shape is not None:
        oshape = list(shape)
    nbatch = max(inp.shape[0], grid.shape[0])
    nchannel = inp.shape[1]

    # flatten the spatial dimensions, keeping the component axis last
    flat_grid = grid.reshape([grid.shape[0], -1, ndim])
    flat_inp = inp.reshape([inp.shape[0], nchannel, -1, ndim])
    inside = inbounds_mask(extrapolate, flat_grid, oshape)

    # per-dimension weights, first derivatives, indices and signs
    weights, grads, _, coords, signs = get_weights(flat_grid, bound, spline, oshape, grad=True)

    # accumulator over the flattened output volume
    out = torch.zeros([nbatch, nchannel, list_prod_int(oshape)],
                      dtype=flat_inp.dtype, device=flat_inp.device)

    # enumerate every combination of nodes across dimensions
    node_ranges = [torch.as_tensor([k for k in range(count)])
                   for count in [s.order + 1 for s in spline]]
    if ndim == 1:
        # single dimension: node indices with a trailing singleton axis
        combos = node_ranges[0].unsqueeze(-1)
    else:
        combos = cartesian_prod(node_ranges)

    for nodes in combos:

        # linear index of the destination voxels
        subs = [c[n] for c, n in zip(coords, nodes)]
        target = sub2ind_list(subs, oshape).unsqueeze(1)
        target = target.expand([nbatch, nchannel, target.shape[-1]])
        base = flat_inp.clone()

        # sign returned by the boundary conditions
        sign_parts: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign: Optional[Tensor] = make_sign(sign_parts)
        if sign is not None:
            base = base * sign.unsqueeze(1).unsqueeze(-1)

        # drop the contribution of out-of-bounds locations
        if inside is not None:
            base = base * inside.unsqueeze(-1)

        for d in range(ndim):
            contrib = base.unbind(-1)[d].clone()

            # plain weights everywhere except along `d`, where the
            # derivative of the basis is used
            for dd, (weight, grad1, n) in enumerate(zip(weights, grads, nodes)):
                if d != dd:
                    contrib = contrib * weight[n].unsqueeze(1)
                else:
                    grad11 = grad1[n]
                    if grad11 is not None:
                        contrib = contrib * grad11.unsqueeze(1)

            # accumulate into the output volume
            out.scatter_add_(-1, target, contrib)

    return out.reshape(list(out.shape[:2]) + oshape)
|
|
|
|
|
@torch.jit.script
def hess(inp, grid, bound: List[Bound], spline: List[Spline],
         extrapolate: int = 1):
    """
    Sample the spatial Hessian of `inp` at the coordinates stored in `grid`.

    inp: (B, C, *ishape) tensor
    grid: (B, *oshape, D) tensor
    bound: List{D}[Bound] tensor
    spline: List{D}[Spline] tensor
    extrapolate: int
    returns: (B, C, *oshape, D, D) tensor

    The batch size of `inp` or `grid` may be 1, in which case it is
    broadcast against the other one. Only the diagonal and one triangle
    are accumulated; the other triangle is filled by symmetry at the end.
    """

    dim = grid.shape[-1]
    shape = list(inp.shape[-dim:])
    oshape = list(grid.shape[-dim-1:-1])
    batch = max(inp.shape[0], grid.shape[0])
    channel = inp.shape[1]

    # flatten the spatial dimensions
    grid = grid.reshape([grid.shape[0], -1, grid.shape[-1]])
    inp = inp.reshape([inp.shape[0], inp.shape[1], -1])
    # gather() requires the index to be no larger than the input along the
    # batch dimension, so a singleton input batch is broadcast up front
    inp = inp.expand([batch, channel, inp.shape[-1]])
    mask = inbounds_mask(extrapolate, grid, shape)

    # precompute weights and their first/second derivatives
    weights, grads, hesss, coords, signs \
        = get_weights(grid, bound, spline, shape, grad=True, hess=True)

    # initialize the accumulator
    out = torch.zeros([batch, channel, grid.shape[1], dim, dim],
                      dtype=inp.dtype, device=inp.device)

    # enumerate every combination of nodes across dimensions
    range_nodes = [torch.as_tensor([d for d in range(n)])
                   for n in [s.order + 1 for s in spline]]
    if dim == 1:
        # a single dimension is handled without cartesian_prod: the node
        # indices just get a trailing singleton axis
        all_nodes = range_nodes[0].unsqueeze(-1)
    else:
        all_nodes = cartesian_prod(range_nodes)
    for nodes in all_nodes:

        # gather the input values at this combination of nodes
        subs = [c[n] for c, n in zip(coords, nodes)]
        idx = sub2ind_list(subs, shape).unsqueeze(1)
        idx = idx.expand([batch, channel, idx.shape[-1]])
        out0 = inp.gather(-1, idx)

        # apply the sign returned by the boundary conditions
        sign0: List[Optional[Tensor]] = [sgn[n] for sgn, n in zip(signs, nodes)]
        sign1: Optional[Tensor] = make_sign(sign0)
        if sign1 is not None:
            out0 = out0 * sign1.unsqueeze(1)

        for d in range(dim):

            # -- diagonal: second derivative along `d`, weights elsewhere --
            out1 = out0.clone()
            for dd, (weight, hess1, n) \
                    in enumerate(zip(weights, hesss, nodes)):
                if d == dd:
                    hess11 = hess1[n]
                    if hess11 is not None:
                        out1 = out1 * hess11.unsqueeze(1)
                else:
                    out1 = out1 * weight[n].unsqueeze(1)

            # accumulate (in-place view of the (d, d) entry)
            out.unbind(-1)[d].unbind(-1)[d].add_(out1)

            # -- off diagonal: first derivatives along `d` and `d2` --
            for d2 in range(d+1, dim):
                out1 = out0.clone()
                for dd, (weight, grad1, n) \
                        in enumerate(zip(weights, grads, nodes)):
                    if dd in (d, d2):
                        grad11 = grad1[n]
                        if grad11 is not None:
                            out1 = out1 * grad11.unsqueeze(1)
                    else:
                        out1 = out1 * weight[n].unsqueeze(1)

                # accumulate (in-place view of one off-diagonal entry)
                out.unbind(-1)[d].unbind(-1)[d2].add_(out1)

    # zero the out-of-bounds locations.
    # The mask already carries a singleton channel axis, (B, 1, N), so only
    # the two trailing Hessian axes must be added for it to broadcast
    # against the (B, C, N, D, D) output.
    if mask is not None:
        out = out * mask.unsqueeze(-1).unsqueeze(-1)

    # fill the other triangle by symmetry
    for d in range(dim):
        for d2 in range(d+1, dim):
            out.unbind(-1)[d2].unbind(-1)[d].copy_(out.unbind(-1)[d].unbind(-1)[d2])

    out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-2:]))
    return out
|
|