|
""" |
|
Non-differentiable forward/backward components. |
|
These components are put together in `interpol.autograd` to generate |
|
differentiable functions. |
|
|
|
Note |
|
---- |
|
.. I removed @torch.jit.script from these entry-points because compiling |
|
all possible combinations of bound+interpolation made the first call |
|
extremely slow. |
|
.. I am not using the dot/multi_dot helpers even though they should be |
|
more efficient that "multiply and sum" because I haven't had the time |
|
to test them. It would be worth doing it. |
|
""" |
|
import torch |
|
from typing import List, Optional, Tuple |
|
from .jit_utils import list_all, dot, dot_multi, pad_list_int |
|
from .bounds import Bound |
|
from .splines import Spline |
|
from . import iso0, iso1, nd |
|
# Short alias for `torch.Tensor`; used in the return-type annotations and
# local variable annotations of the `*_backward` helpers in this module.
Tensor = torch.Tensor
|
|
|
|
|
# Convert a list of integer boundary codes into a list of `Bound` objects,
# one per entry and in the same order. Compiled with TorchScript; the
# resulting list is what the entry-points below pass to the `iso0`, `iso1`
# and `nd` kernels as `bound_fn`.
@torch.jit.script
def make_bound(bound: List[int]) -> List[Bound]:
    return [Bound(b) for b in bound]
|
|
|
|
|
# Convert a list of integer interpolation orders into a list of `Spline`
# objects, one per entry and in the same order. Compiled with TorchScript;
# the resulting list is what the entry-points below pass to the generic
# `nd` kernels as `spline_fn`.
@torch.jit.script
def make_spline(spline: List[int]) -> List[Spline]:
    return [Spline(s) for s in spline]
|
|
|
|
|
|
|
def grid_pull(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """Pull (sample) `inp` at the coordinates stored in `grid`.

    The number of coordinates D is read from the last axis of `grid`.
    When every interpolation order is 1 (resp. 0) and D is 1, 2 or 3,
    the call is routed to the dedicated `iso1` (resp. `iso0`) kernel;
    every other case goes through the generic `nd.pull`.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int], passed through `pad_list_int(bound, D)`
    interpolation: List{D}[int], passed through `pad_list_int(..., D)`
    extrapolate: int
    returns: (B, C, *spatial_out) tensor
    """
    ndim = grid.shape[-1]
    bounds = pad_list_int(bound, ndim)
    orders = pad_list_int(interpolation, ndim)
    bound_fn = make_bound(bounds)

    # Fast path: order-1 interpolation along every dimension.
    if list_all([o == 1 for o in orders]):
        if ndim == 1:
            return iso1.pull1d(inp, grid, bound_fn, extrapolate)
        if ndim == 2:
            return iso1.pull2d(inp, grid, bound_fn, extrapolate)
        if ndim == 3:
            return iso1.pull3d(inp, grid, bound_fn, extrapolate)

    # Fast path: order-0 interpolation along every dimension.
    if list_all([o == 0 for o in orders]):
        if ndim == 1:
            return iso0.pull1d(inp, grid, bound_fn, extrapolate)
        if ndim == 2:
            return iso0.pull2d(inp, grid, bound_fn, extrapolate)
        if ndim == 3:
            return iso0.pull3d(inp, grid, bound_fn, extrapolate)

    # Generic path: mixed orders, higher orders, or D outside 1..3.
    return nd.pull(inp, grid, bound_fn, make_spline(orders), extrapolate)
|
|
|
|
|
|
|
def grid_push(inp, grid, shape: Optional[List[int]], bound: List[int],
              interpolation: List[int], extrapolate: int):
    """Push (splat) `inp` to the coordinates stored in `grid`.

    The number of coordinates D is read from the last axis of `grid`.
    When every interpolation order is 1 (resp. 0) and D is 1, 2 or 3,
    the call is routed to the dedicated `iso1` (resp. `iso0`) kernel;
    every other case goes through the generic `nd.push`.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int], optional. Forwarded unchanged to the kernel;
        per the original notes, None means `spatial_in`.
    bound: List{D}[int], passed through `pad_list_int(bound, D)`
    interpolation: List{D}[int], passed through `pad_list_int(..., D)`
    extrapolate: int
    returns: (B, C, *shape) tensor
    """
    ndim = grid.shape[-1]
    bounds = pad_list_int(bound, ndim)
    orders = pad_list_int(interpolation, ndim)
    bound_fn = make_bound(bounds)

    # Fast path: order-1 interpolation along every dimension.
    if list_all([o == 1 for o in orders]):
        if ndim == 1:
            return iso1.push1d(inp, grid, shape, bound_fn, extrapolate)
        if ndim == 2:
            return iso1.push2d(inp, grid, shape, bound_fn, extrapolate)
        if ndim == 3:
            return iso1.push3d(inp, grid, shape, bound_fn, extrapolate)

    # Fast path: order-0 interpolation along every dimension.
    if list_all([o == 0 for o in orders]):
        if ndim == 1:
            return iso0.push1d(inp, grid, shape, bound_fn, extrapolate)
        if ndim == 2:
            return iso0.push2d(inp, grid, shape, bound_fn, extrapolate)
        if ndim == 3:
            return iso0.push3d(inp, grid, shape, bound_fn, extrapolate)

    # Generic path: mixed orders, higher orders, or D outside 1..3.
    splines = make_spline(orders)
    return nd.push(inp, grid, shape, bound_fn, splines, extrapolate)
|
|
|
|
|
|
|
def grid_count(grid, shape: Optional[List[int]], bound: List[int],
               interpolation: List[int], extrapolate: int):
    """Push an image of ones through `grid`.

    A scalar one (same dtype/device as `grid`) is expanded, without
    copying, to a single-channel image covering the grid's spatial
    extent, and then pushed exactly as in `grid_push`.

    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int], optional, default=spatial_in
    bound: List{D}[int], passed through `pad_list_int(bound, D)`
    interpolation: List{D}[int], passed through `pad_list_int(..., D)`
    extrapolate: int
    returns: (B, 1, *shape) tensor
    """
    ndim = grid.shape[-1]
    bounds = pad_list_int(bound, ndim)
    orders = pad_list_int(interpolation, ndim)
    bound_fn = make_bound(bounds)

    # Spatial extent of the grid: the `ndim` axes preceding the last one.
    spatial = list(grid.shape[-ndim-1:-1])
    if shape is None:
        shape = spatial

    # Virtual (B, 1, *spatial) image of ones built from a 0-d tensor.
    ones = torch.ones([], dtype=grid.dtype, device=grid.device)
    ones = ones.expand([len(grid), 1] + spatial)

    # Fast path: order-1 interpolation along every dimension.
    if list_all([o == 1 for o in orders]):
        if ndim == 1:
            return iso1.push1d(ones, grid, shape, bound_fn, extrapolate)
        if ndim == 2:
            return iso1.push2d(ones, grid, shape, bound_fn, extrapolate)
        if ndim == 3:
            return iso1.push3d(ones, grid, shape, bound_fn, extrapolate)

    # Fast path: order-0 interpolation along every dimension.
    if list_all([o == 0 for o in orders]):
        if ndim == 1:
            return iso0.push1d(ones, grid, shape, bound_fn, extrapolate)
        if ndim == 2:
            return iso0.push2d(ones, grid, shape, bound_fn, extrapolate)
        if ndim == 3:
            return iso0.push3d(ones, grid, shape, bound_fn, extrapolate)

    # Generic path: mixed orders, higher orders, or D outside 1..3.
    splines = make_spline(orders)
    return nd.push(ones, grid, shape, bound_fn, splines, extrapolate)
|
|
|
|
|
|
|
def grid_grad(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """Sample the spatial gradient of `inp` at the coordinates in `grid`.

    The number of coordinates D is read from the last axis of `grid`.
    All-order-1 inputs with D in 1..3 use the dedicated `iso1` kernels,
    all-order-0 inputs use `iso0.grad` (for any D), and everything else
    goes through the generic `nd.grad`.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int], passed through `pad_list_int(bound, D)`
    interpolation: List{D}[int], passed through `pad_list_int(..., D)`
    extrapolate: int
    returns: (B, C, *spatial_out, D) tensor
    """
    ndim = grid.shape[-1]
    bounds = pad_list_int(bound, ndim)
    orders = pad_list_int(interpolation, ndim)
    bound_fn = make_bound(bounds)

    # Fast path: order-1 interpolation along every dimension.
    if list_all([o == 1 for o in orders]):
        if ndim == 1:
            return iso1.grad1d(inp, grid, bound_fn, extrapolate)
        if ndim == 2:
            return iso1.grad2d(inp, grid, bound_fn, extrapolate)
        if ndim == 3:
            return iso1.grad3d(inp, grid, bound_fn, extrapolate)

    # Order 0 everywhere: a single kernel handles every dimensionality.
    if list_all([o == 0 for o in orders]):
        return iso0.grad(inp, grid, bound_fn, extrapolate)

    # Generic path.
    return nd.grad(inp, grid, bound_fn, make_spline(orders), extrapolate)
|
|
|
|
|
|
|
def grid_pushgrad(inp, grid, shape: List[int], bound: List[int],
                  interpolation: List[int], extrapolate: int):
    """Push a gradient-shaped tensor through `grid`.

    WARNING: used only in the backward pass of `grid_grad`.

    The number of coordinates D is read from the last axis of `grid`.
    All-order-1 inputs with D in 1..3 use the dedicated `iso1` kernels,
    all-order-0 inputs use `iso0.pushgrad` (for any D), and everything
    else goes through the generic `nd.pushgrad`.

    inp: (B, C, *spatial_in, D) tensor
    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int]. Forwarded unchanged to the kernel.
        NOTE(review): the original notes call this argument optional,
        but the annotation here is a plain List[int] -- confirm whether
        the kernels accept None before relying on it.
    bound: List{D}[int], passed through `pad_list_int(bound, D)`
    interpolation: List{D}[int], passed through `pad_list_int(..., D)`
    extrapolate: int
    returns: (B, C, *shape) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)

    # Fast path: order-1 interpolation along every dimension.
    is_iso1 = list_all([order == 1 for order in interpolation])
    if is_iso1:
        if dim == 3:
            return iso1.pushgrad3d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 2:
            return iso1.pushgrad2d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 1:
            return iso1.pushgrad1d(inp, grid, shape, bound_fn, extrapolate)

    # Order 0 everywhere: a single kernel handles every dimensionality.
    is_iso0 = list_all([order == 0 for order in interpolation])
    if is_iso0:
        return iso0.pushgrad(inp, grid, shape, bound_fn, extrapolate)

    # Generic path.
    spline_fn = make_spline(interpolation)
    return nd.pushgrad(inp, grid, shape, bound_fn, spline_fn, extrapolate)
|
|
|
|
|
|
|
def grid_hess(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """Sample the spatial Hessian of `inp` at the coordinates in `grid`.

    WARNING: used only in the backward pass of `grid_grad`.

    The number of coordinates D is read from the last axis of `grid`.
    All-order-1 inputs with D in 1..3 use the dedicated `iso1` kernels,
    all-order-0 inputs use `iso0.hess` (for any D), and everything else
    goes through the generic `nd.hess`.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int], passed through `pad_list_int(bound, D)`
    interpolation: List{D}[int], passed through `pad_list_int(..., D)`
    extrapolate: int
    returns: (B, C, *spatial_out, D, D) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)

    # Fast path: order-1 interpolation along every dimension.
    is_iso1 = list_all([order == 1 for order in interpolation])
    if is_iso1:
        if dim == 3:
            return iso1.hess3d(inp, grid, bound_fn, extrapolate)
        if dim == 2:
            return iso1.hess2d(inp, grid, bound_fn, extrapolate)
        if dim == 1:
            return iso1.hess1d(inp, grid, bound_fn, extrapolate)

    # Order 0 everywhere: a single kernel handles every dimensionality.
    is_iso0 = list_all([order == 0 for order in interpolation])
    if is_iso0:
        return iso0.hess(inp, grid, bound_fn, extrapolate)

    # Generic path.
    spline_fn = make_spline(interpolation)
    return nd.hess(inp, grid, bound_fn, spline_fn, extrapolate)
|
|
|
|
|
|
|
def grid_pull_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor], ]:
    """Backward pass of `grid_pull`.

    Each gradient is computed only if the matching tensor has
    `requires_grad` set; otherwise None is returned in its slot.

    grad: (B, C, *spatial_out) tensor
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int]
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in) tensor, (B, *spatial_out, D) tensor
    """
    ndim = grid.shape[-1]
    grad_inp: Optional[Tensor] = None
    grad_grid: Optional[Tensor] = None

    if inp.requires_grad:
        # Gradient w.r.t. the image: push `grad` back onto the input
        # lattice (its last `ndim` axes).
        in_shape = inp.shape[-ndim:]
        grad_inp = grid_push(grad, grid, in_shape, bound, interpolation,
                             extrapolate)

    if grid.requires_grad:
        # Gradient w.r.t. the coordinates: sampled spatial gradient of
        # `inp`, weighted by `grad` and summed over the channel axis.
        sampled = grid_grad(inp, grid, bound, interpolation, extrapolate)
        weighted = sampled * grad.unsqueeze(-1)
        grad_grid = weighted.sum(dim=1)

    return grad_inp, grad_grid
|
|
|
|
|
|
|
def grid_push_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor], ]:
    """Backward pass of `grid_push`.

    Each gradient is computed only if the matching tensor has
    `requires_grad` set; otherwise None is returned in its slot.

    grad: (B, C, *spatial_out) tensor
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_in, D) tensor
    bound: List{D}[int]
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in) tensor, (B, *spatial_in, D) tensor
    """
    grad_inp: Optional[Tensor] = None
    grad_grid: Optional[Tensor] = None

    if inp.requires_grad:
        # Gradient w.r.t. the image: pull `grad` at the grid locations.
        grad_inp = grid_pull(grad, grid, bound, interpolation, extrapolate)

    if grid.requires_grad:
        # Gradient w.r.t. the coordinates: sampled spatial gradient of
        # `grad`, weighted by `inp` and summed over the channel axis.
        sampled = grid_grad(grad, grid, bound, interpolation, extrapolate)
        weighted = sampled * inp.unsqueeze(-1)
        grad_grid = weighted.sum(dim=1)

    return grad_inp, grad_grid
|
|
|
|
|
|
|
def grid_count_backward(grad, grid, bound: List[int],
                        interpolation: List[int], extrapolate: int) \
        -> Optional[Tensor]:
    """Backward pass of `grid_count`.

    Returns None when `grid` does not require a gradient. Otherwise
    returns the spatial gradient of `grad` sampled at `grid`, summed
    over axis 1.

    grad: (B, C, *spatial_out) tensor
    grid: (B, *spatial_in, D) tensor
    bound: List{D}[int]
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, *spatial_in, D) tensor, or None
    """
    if not grid.requires_grad:
        return None
    sampled = grid_grad(grad, grid, bound, interpolation, extrapolate)
    return sampled.sum(1)
|
|
|
|
|
|
|
def grid_grad_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor]]:
    """Backward pass of `grid_grad`.

    Each gradient is computed only if the matching tensor has
    `requires_grad` set; otherwise None is returned in its slot.

    grad: (B, C, *spatial_out, D) tensor
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int]
    interpolation: List{D}[int]
    extrapolate: int
    returns: gradient w.r.t. `inp` (produced by `grid_pushgrad`;
        presumably (B, C, *spatial_in) -- TODO confirm),
        gradient w.r.t. `grid`: (B, *spatial_out, D) tensor
    """
    ndim = grid.shape[-1]
    in_shape = inp.shape[-ndim:]
    grad_inp: Optional[Tensor] = None
    grad_grid: Optional[Tensor] = None

    if inp.requires_grad:
        # Gradient w.r.t. the image: push the incoming gradient field
        # back onto the input lattice.
        grad_inp = grid_pushgrad(grad, grid, in_shape, bound, interpolation,
                                 extrapolate)

    if grid.requires_grad:
        # Gradient w.r.t. the coordinates: sampled Hessian of `inp`,
        # weighted by `grad`, then reduced over the channel axis and the
        # second-to-last axis.
        hessian = grid_hess(inp, grid, bound, interpolation, extrapolate)
        weighted = hessian * grad.unsqueeze(-1)
        grad_grid = weighted.sum(dim=[1, -2])

    return grad_inp, grad_grid
|
|