|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
def tps(theta, ctrl, grid):
    r'''Evaluate the thin-plate-spline (TPS) surface at xy locations arranged in a grid.

    The TPS surface is a minimum bend interpolation surface defined by a set of control points.
    The function value for a x,y location is given by

        TPS(x,y) := theta[-3] + theta[-2]*x + theta[-1]*y + \sum_{t=0}^{T-1} theta[t] U(x,y,ctrl[t])

    where U(x,y,c) = r**2 * log(r + 1e-6) and r is the Euclidean distance between (x,y) and c.

    This method computes the TPS value for multiple batches over multiple grid locations for 2
    surfaces in one go.

    Params
    ------
    theta: Nx(T+3)x2 tensor, or Nx(T+2)x2 tensor
        Batch size N, T+3 or T+2 (reduced form) model parameters for T control points in dx and dy.
        In the reduced form the first radial weight is omitted and reconstructed as the negated
        sum of the remaining T-1 radial weights, so the radial weights of each surface sum to 0.
    ctrl: NxTx2 tensor or Tx2 tensor
        T control points in normalized image coordinates [0..1]. A Tx2 tensor is shared by
        every batch entry.
    grid: NxHxWx3 tensor
        Grid locations to evaluate with homogeneous 1 in first coordinate. The tensor does not
        need to be contiguous.

    Returns
    -------
    z: NxHxWx2 tensor
        Function values at each grid location in dx and dy.

    Raises
    ------
    ValueError
        If theta.shape[1] is neither T+3 (full form) nor T+2 (reduced form).
    '''
    N, H, W, _ = grid.size()

    if ctrl.dim() == 2:
        # Share a single set of control points across the whole batch.
        ctrl = ctrl.expand(N, *ctrl.size())

    T = ctrl.shape[1]

    n_params = theta.shape[1]
    if n_params not in (T + 2, T + 3):
        raise ValueError(
            'theta must have T+3 ({}) or T+2 ({}) rows for T={} control points, '
            'got {}'.format(T + 3, T + 2, T, n_params))

    # Offsets between every grid location and every control point: NxHxWxTx2.
    diff = grid[..., 1:].unsqueeze(-2) - ctrl.unsqueeze(1).unsqueeze(1)
    D = torch.sqrt((diff**2).sum(-1))
    # Radial basis r^2 log r; the epsilon keeps log finite at r == 0, where U evaluates to 0.
    # NOTE: torch.sqrt has an infinite derivative at 0, so gradients w.r.t. grid or ctrl can
    # become NaN when a grid location coincides exactly with a control point.
    U = (D**2) * torch.log(D + 1e-6)

    # Radial weights w (first rows) and affine coefficients a (last 3 rows: 1, x, y).
    w, a = theta[:, :-3, :], theta[:, -3:, :]

    reduced = T + 2 == n_params
    if reduced:
        w = torch.cat((-w.sum(dim=1, keepdim=True), w), dim=1)

    # U is NxHxWxT; reshape (not view) because U inherits the memory layout of a
    # non-contiguous grid. The radial contribution b is NxHxWx2.
    b = torch.bmm(U.reshape(N, -1, T), w).view(N, H, W, 2)
    # Affine part [1, x, y] @ a; reshape so non-contiguous grids are accepted.
    z = torch.bmm(grid.reshape(N, -1, 3), a).view(N, H, W, 2) + b

    return z
|
|
|
def tps_grid(theta, ctrl, size):
    '''Compute a thin-plate-spline grid from parameters for sampling.

    Params
    ------
    theta: Nx(T+3)x2 tensor, or Nx(T+2)x2 tensor
        Batch size N, T+3 or T+2 (reduced form) model parameters for T control points in dx and dy.
    ctrl: NxTx2 tensor, or Tx2 tensor
        T control points in normalized image coordinates [0..1]
    size: tuple
        Output grid size as NxCxHxW. C unused. This defines the output image
        size when sampling.

    Returns
    -------
    grid : NxHxWx2 tensor
        Grid suitable for sampling in pytorch containing source image
        locations for each output pixel, in dtype and device of theta. The undisplaced
        grid spans [0, 1] and is mapped to [-1, 1], so the first/last row and column land
        exactly on -1 and +1 (presumably intended for grid_sample with
        align_corners=True -- confirm against callers).
    '''
    N, _, H, W = size

    # Build the base grid directly in theta's dtype/device. Creating linspace in the
    # default dtype on the CPU and copying it in would round float64 coordinates to
    # float32 precision and force a host->device copy for CUDA tensors.
    grid = theta.new_empty(N, H, W, 3)
    grid[:, :, :, 0] = 1.
    grid[:, :, :, 1] = torch.linspace(0, 1, W, dtype=theta.dtype, device=theta.device)
    grid[:, :, :, 2] = torch.linspace(
        0, 1, H, dtype=theta.dtype, device=theta.device).unsqueeze(-1)

    z = tps(theta, ctrl, grid)
    # Add the TPS displacement to the identity grid, then map [0, 1] to [-1, 1].
    return (grid[..., 1:] + z) * 2 - 1
|
|
|
def tps_sparse(theta, ctrl, xy):
    '''Displace a sparse set of 2D points with a thin-plate-spline.

    Params
    ------
    theta: Nx(T+3)x2 tensor, or Nx(T+2)x2 tensor
        TPS model parameters, in the same layout that `tps` accepts.
    ctrl: NxTx2 tensor, or Tx2 tensor
        T control points in normalized image coordinates [0..1]
    xy: NxMx2 tensor, or Mx2 tensor
        M query locations. A 2D tensor is expanded to the batch size of theta.

    Returns
    -------
    xy: NxMx2 tensor
        Query locations plus the TPS displacement evaluated at each of them.
    '''
    if xy.dim() == 2:
        xy = xy.expand(theta.shape[0], *xy.size())
    batch, count = xy.shape[:2]

    # Homogeneous coordinates [1, x, y] as expected by `tps`.
    homogeneous = xy.new_empty(batch, count, 3)
    homogeneous[..., :1] = 1.
    homogeneous[..., 1:] = xy

    # Treat the M points as an Mx1 "image" so `tps` can evaluate them.
    displacement = tps(theta, ctrl, homogeneous.view(batch, count, 1, 3))
    return xy + displacement.view(batch, count, 2)
|
|
|
def uniform_grid(shape):
    '''Uniformly place control points arranged in a grid across normalized image coordinates.

    Params
    ------
    shape : tuple
        HxW defining the number of control points in height and width dimension.
        Only the first two entries are used.

    Returns
    -------
    points: HxWx2 tensor
        Control points over [0,1] normalized image range; points[..., 0] varies along
        the width (columns) and points[..., 1] varies along the height (rows).
    '''
    rows, cols = shape[:2]
    xs = torch.linspace(0, 1, cols).expand(rows, cols)
    ys = torch.linspace(0, 1, rows).unsqueeze(-1).expand(rows, cols)
    return torch.stack((xs, ys), dim=-1)
|
|
|
if __name__ == '__main__':
    # Smoke test: four control points at the unit-square corners and all-zero
    # parameters (zero displacement) evaluated on a 6x3 output grid.
    corners = torch.tensor([[0., 0], [1., 0], [1., 1], [0, 1]]).unsqueeze(0)
    n_ctrl = corners.shape[1]
    params = torch.zeros(1, n_ctrl + 3, 2)
    out_size = (1, 1, 6, 3)
    print(tps_grid(params, corners, out_size).shape)