# BrainFM / utils/interpol/pushpull.py
# peirong26's picture
# Upload 187 files
# 2571f24 verified
"""
Non-differentiable forward/backward components.
These components are put together in `interpol.autograd` to generate
differentiable functions.
Note
----
.. I removed @torch.jit.script from these entry-points because compiling
   all possible combinations of bound+interpolation made the first call
   extremely slow.
.. I am not using the dot/dot_multi helpers even though they should be
   more efficient than "multiply and sum" because I haven't had the time
   to test them. It would be worth doing.
"""
import torch
from typing import List, Optional, Tuple
from .jit_utils import list_all, dot, dot_multi, pad_list_int
from .bounds import Bound
from .splines import Spline
from . import iso0, iso1, nd
Tensor = torch.Tensor
@torch.jit.script
def make_bound(bound: List[int]) -> List[Bound]:
    """Wrap each integer of ``bound`` into a ``Bound`` object, preserving order."""
    bounds: List[Bound] = []
    for code in bound:
        bounds.append(Bound(code))
    return bounds
@torch.jit.script
def make_spline(spline: List[int]) -> List[Spline]:
    """Wrap each integer spline order into a ``Spline`` object, preserving order."""
    splines: List[Spline] = []
    for order in spline:
        splines.append(Spline(order))
    return splines
# @torch.jit.script
def grid_pull(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """Sample ``inp`` at the locations stored in ``grid``.

    If every dimension uses order 1 (or every dimension uses order 0) and
    D is 1, 2 or 3, the dedicated ``iso1`` (or ``iso0``) kernel is called;
    any other configuration goes through the generic ``nd.pull`` kernel.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_out) tensor
    """
    ndim = grid.shape[-1]
    # Per-dimension settings are normalised to D entries by pad_list_int.
    bound = pad_list_int(bound, ndim)
    orders = pad_list_int(interpolation, ndim)
    bounds = make_bound(bound)

    if list_all([o == 1 for o in orders]):
        kernel = {1: iso1.pull1d, 2: iso1.pull2d, 3: iso1.pull3d}.get(ndim)
        if kernel is not None:
            return kernel(inp, grid, bounds, extrapolate)
    if list_all([o == 0 for o in orders]):
        kernel = {1: iso0.pull1d, 2: iso0.pull2d, 3: iso0.pull3d}.get(ndim)
        if kernel is not None:
            return kernel(inp, grid, bounds, extrapolate)
    # Mixed orders, higher orders, or D > 3: generic spline path.
    return nd.pull(inp, grid, bounds, make_spline(orders), extrapolate)
# @torch.jit.script
def grid_push(inp, grid, shape: Optional[List[int]], bound: List[int],
              interpolation: List[int], extrapolate: int):
    """Push ``inp`` through ``grid`` onto an output lattice of size ``shape``.

    grid_pull_backward uses this function to compute the gradient of
    grid_pull with respect to its input. Dispatch mirrors grid_pull:
    ``iso1``/``iso0`` kernels for uniform order 1/0 with D in {1, 2, 3},
    otherwise the generic ``nd.push`` kernel.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int] tensor, optional, default=spatial_in
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *shape) tensor
    """
    ndim = grid.shape[-1]
    # Per-dimension settings are normalised to D entries by pad_list_int.
    bound = pad_list_int(bound, ndim)
    orders = pad_list_int(interpolation, ndim)
    bounds = make_bound(bound)

    if list_all([o == 1 for o in orders]):
        kernel = {1: iso1.push1d, 2: iso1.push2d, 3: iso1.push3d}.get(ndim)
        if kernel is not None:
            return kernel(inp, grid, shape, bounds, extrapolate)
    if list_all([o == 0 for o in orders]):
        kernel = {1: iso0.push1d, 2: iso0.push2d, 3: iso0.push3d}.get(ndim)
        if kernel is not None:
            return kernel(inp, grid, shape, bounds, extrapolate)
    # Mixed orders, higher orders, or D > 3: generic spline path.
    return nd.push(inp, grid, shape, bounds, make_spline(orders), extrapolate)
# @torch.jit.script
def grid_count(grid, shape: Optional[List[int]], bound: List[int],
               interpolation: List[int], extrapolate: int):
    """Push a tensor of ones through ``grid`` (``grid_push`` applied to ones).

    Implemented by delegating to grid_push, so both functions share a single
    bound/interpolation dispatch path instead of two copies of it.

    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int] tensor, optional, default=spatial_in
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, 1, *shape) tensor
    """
    dim = grid.shape[-1]
    gshape = list(grid.shape[-dim-1:-1])
    if shape is None:
        shape = gshape
    # expand() broadcasts a scalar 1 to (B, 1, *spatial_in) without
    # allocating a full-size tensor.
    ones = torch.ones([], dtype=grid.dtype, device=grid.device)
    ones = ones.expand([len(grid), 1] + gshape)
    # grid_push pads bound/interpolation to D entries itself.
    return grid_push(ones, grid, shape, bound, interpolation, extrapolate)
# @torch.jit.script
def grid_grad(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    """Spatial gradient of ``inp`` evaluated at the locations in ``grid``.

    Uniform order 1 with D in {1, 2, 3} uses the ``iso1`` kernels, uniform
    order 0 uses the single ``iso0.grad`` entry point (for any D), and
    everything else goes through the generic ``nd.grad`` kernel.

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_out, D) tensor
    """
    ndim = grid.shape[-1]
    # Per-dimension settings are normalised to D entries by pad_list_int.
    bound = pad_list_int(bound, ndim)
    orders = pad_list_int(interpolation, ndim)
    bounds = make_bound(bound)

    if list_all([o == 1 for o in orders]):
        kernel = {1: iso1.grad1d, 2: iso1.grad2d, 3: iso1.grad3d}.get(ndim)
        if kernel is not None:
            return kernel(inp, grid, bounds, extrapolate)
    if list_all([o == 0 for o in orders]):
        # iso0 exposes one grad entry point regardless of D.
        return iso0.grad(inp, grid, bounds, extrapolate)
    return nd.grad(inp, grid, bounds, make_spline(orders), extrapolate)
# @torch.jit.script
def grid_pushgrad(inp, grid, shape: List[int], bound: List[int],
                  interpolation: List[int], extrapolate: int):
    r"""Push a (..., D) gradient field through ``grid`` onto a lattice of ``shape``.

    /!\ Used only in backward pass of grid_grad

    grid_grad_backward calls this with the upstream gradient of grid_grad as
    ``inp`` to obtain the gradient with respect to grid_grad's input.
    (Raw docstring: the ``/!\`` marker would otherwise be an invalid escape.)

    inp: (B, C, *spatial_in, D) tensor
    grid: (B, *spatial_in, D) tensor
    shape: List{D}[int], output spatial shape (grid_grad_backward passes
           the spatial shape of grid_grad's input)
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *shape) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    if list_all([order == 1 for order in interpolation]):
        if dim == 3:
            return iso1.pushgrad3d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 2:
            return iso1.pushgrad2d(inp, grid, shape, bound_fn, extrapolate)
        elif dim == 1:
            return iso1.pushgrad1d(inp, grid, shape, bound_fn, extrapolate)
    if list_all([order == 0 for order in interpolation]):
        # iso0 exposes one pushgrad entry point regardless of D.
        return iso0.pushgrad(inp, grid, shape, bound_fn, extrapolate)
    spline_fn = make_spline(interpolation)
    return nd.pushgrad(inp, grid, shape, bound_fn, spline_fn, extrapolate)
# @torch.jit.script
def grid_hess(inp, grid, bound: List[int], interpolation: List[int],
              extrapolate: int):
    r"""Matrix of second spatial derivatives of ``inp`` evaluated at ``grid``.

    /!\ Used only in backward pass of grid_grad

    Uniform order 1 with D in {1, 2, 3} uses the ``iso1`` kernels, uniform
    order 0 uses the single ``iso0.hess`` entry point, and everything else
    goes through the generic ``nd.hess`` kernel.
    (Raw docstring: the ``/!\`` marker would otherwise be an invalid escape.)

    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_out, D, D) tensor
    """
    dim = grid.shape[-1]
    bound = pad_list_int(bound, dim)
    interpolation = pad_list_int(interpolation, dim)
    bound_fn = make_bound(bound)
    if list_all([order == 1 for order in interpolation]):
        if dim == 3:
            return iso1.hess3d(inp, grid, bound_fn, extrapolate)
        elif dim == 2:
            return iso1.hess2d(inp, grid, bound_fn, extrapolate)
        elif dim == 1:
            return iso1.hess1d(inp, grid, bound_fn, extrapolate)
    if list_all([order == 0 for order in interpolation]):
        # iso0 exposes one hess entry point regardless of D.
        return iso0.hess(inp, grid, bound_fn, extrapolate)
    spline_fn = make_spline(interpolation)
    return nd.hess(inp, grid, bound_fn, spline_fn, extrapolate)
# @torch.jit.script
def grid_pull_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor]]:
    """Backward pass of grid_pull.

    grad: (B, C, *spatial_out) tensor, upstream gradient of grid_pull
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in) tensor, (B, *spatial_out, D)
             -- either entry is None when the matching input does not
             require grad
    """
    ndim = grid.shape[-1]

    grad_inp: Optional[Tensor] = None
    if inp.requires_grad:
        # Gradient w.r.t. inp: push the upstream gradient back onto inp's
        # spatial shape.
        grad_inp = grid_push(grad, grid, inp.shape[-ndim:], bound,
                             interpolation, extrapolate)

    grad_grid: Optional[Tensor] = None
    if grid.requires_grad:
        # Weight the spatial gradient of inp at grid by the upstream
        # gradient, then reduce over channels.
        # (dot(spatial_grad, grad.unsqueeze(-1), dim=1) is an untested
        # alternative.)
        spatial_grad = grid_grad(inp, grid, bound, interpolation, extrapolate)
        grad_grid = (spatial_grad * grad.unsqueeze(-1)).sum(dim=1)

    return grad_inp, grad_grid
# @torch.jit.script
def grid_push_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor]]:
    """Backward pass of grid_push.

    grad: (B, C, *spatial_out) tensor, upstream gradient of grid_push
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_in, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in) tensor, (B, *spatial_in, D)
             -- either entry is None when the matching input does not
             require grad
    """
    grad_inp: Optional[Tensor] = None
    if inp.requires_grad:
        # Gradient w.r.t. the pushed values: pull the upstream gradient.
        grad_inp = grid_pull(grad, grid, bound, interpolation, extrapolate)

    grad_grid: Optional[Tensor] = None
    if grid.requires_grad:
        # Spatial gradient of the upstream gradient at grid, weighted by the
        # pushed values and reduced over channels.
        # (dot(spatial_grad, inp.unsqueeze(-1), dim=1) is an untested
        # alternative.)
        spatial_grad = grid_grad(grad, grid, bound, interpolation, extrapolate)
        grad_grid = (spatial_grad * inp.unsqueeze(-1)).sum(dim=1)

    return grad_inp, grad_grid
# @torch.jit.script
def grid_count_backward(grad, grid, bound: List[int],
                        interpolation: List[int], extrapolate: int) \
        -> Optional[Tensor]:
    """Backward pass of grid_count (gradient with respect to ``grid`` only).

    grid_count pushes a tensor of ones, so this is grid_push_backward's grid
    term with ``inp == 1``. That term is the spatial gradient of ``grad`` at
    ``grid``, summed over channels.

    grad: (B, C, *spatial_out) tensor, upstream gradient of grid_count
          (grid_count itself produces C == 1)
    grid: (B, *spatial_in, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, *spatial_in, D) tensor, or None when grid does not
             require grad
    """
    if not grid.requires_grad:
        return None
    return grid_grad(grad, grid, bound, interpolation, extrapolate).sum(1)
# @torch.jit.script
def grid_grad_backward(grad, inp, grid, bound: List[int],
                       interpolation: List[int], extrapolate: int) \
        -> Tuple[Optional[Tensor], Optional[Tensor]]:
    """Backward pass of grid_grad.

    grad: (B, C, *spatial_out, D) tensor, upstream gradient of grid_grad
    inp: (B, C, *spatial_in) tensor
    grid: (B, *spatial_out, D) tensor
    bound: List{D}[int] tensor
    interpolation: List{D}[int]
    extrapolate: int
    returns: (B, C, *spatial_in) tensor, (B, *spatial_out, D) tensor
             -- either entry is None when the matching input does not
             require grad
    """
    dim = grid.shape[-1]
    grad_inp: Optional[Tensor] = None
    grad_grid: Optional[Tensor] = None
    if inp.requires_grad:
        # The gradient w.r.t. inp has inp's shape: push onto its spatial dims.
        shape = inp.shape[-dim:]
        grad_inp = grid_pushgrad(grad, grid, shape, bound, interpolation,
                                 extrapolate)
    if grid.requires_grad:
        # hess: (B, C, *spatial_out, D, D). Contract the channel axis and the
        # first derivative axis against grad (B, C, *spatial_out, D, 1).
        # NOTE: dot_multi(hess, grad.unsqueeze(-1), dim=[1, -2]) is an
        # untested alternative.
        hess = grid_hess(inp, grid, bound, interpolation, extrapolate)
        grad_grid = (hess * grad.unsqueeze(-1)).sum(dim=[1, -2])
    return grad_inp, grad_grid