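# Spaces: Build error
# (The line above appears to be page-header residue from where this file was
# copied; it is not part of the module.)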
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class TpsWarp(nn.Module):
    """Spline warp fitted to control points and sampled on an ``s`` x ``s`` grid.

    A radial-basis interpolant with kernel ``phi(r) = r`` plus an affine term
    is fitted so that it maps every ``src`` control point to its ``dst``
    counterpart; the fitted map is then evaluated at every node of a regular
    grid spanning [-1, 1] x [-1, 1].

    Args:
        s: number of grid nodes per side.
    """

    def __init__(self, s):
        super(TpsWarp, self).__init__()
        # Built by broadcasting instead of torch.meshgrid (which warns when
        # called without ``indexing=`` on recent PyTorch). Layout matches the
        # 'ij' convention: iy varies along rows, ix along columns.
        lin = torch.linspace(-1, 1, s)
        iy = lin.view(s, 1).expand(s, s)
        ix = lin.view(1, s).expand(s, s)
        # 1.(s*s).2 list of (x, y) grid nodes
        self.gs = torch.stack((ix, iy), dim=2).reshape((1, -1, 2))
        self.sz = s

    @staticmethod
    def _solve(A, B):
        """Solve ``A @ X = B`` for ``X`` (batched)."""
        # torch.solve(B, A) was deprecated and later removed from PyTorch;
        # torch.linalg.solve(A, B) is the replacement. The fallback keeps
        # very old installations working.
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
            return torch.linalg.solve(A, B)
        return torch.solve(B, A)[0]

    def forward(self, src, dst):
        """Fit the warp ``src -> dst`` and evaluate it on the grid.

        Args:
            src: B.n.2 control points.
            dst: B.n.2 target positions of the control points; expected to
                live on the same device as ``src``.

        Returns:
            B.2.s.s tensor; channel 0 holds the warped x coordinate of each
            grid node and channel 1 the warped y coordinate.
        """
        B, n, _ = src.size()

        # Pairwise control-point distances, B.n.n.
        # ``norm`` is used rather than sqrt(sum(delta**2)) because torch.sqrt
        # has a NaN gradient at 0 (the diagonal here is exactly 0).
        # Other kernels tried by the original author on Rsq = r**2:
        # 0.5 * Rsq * log(Rsq), exp(-150 * Rsq), |Rsq - 0.5| - 0.5.
        K = (src.unsqueeze(2) - src.unsqueeze(1)).norm(dim=3)

        # Helper tensors follow src/dst so CUDA and non-float32 inputs work.
        P = torch.cat((src.new_ones((B, n, 1)), src), 2)  # B.n.3
        upper = torch.cat((K, P), 2)  # B.n.(n+3)
        lower = torch.cat(
            (P.permute(0, 2, 1), src.new_zeros((B, 3, 3))), 2)  # B.3.(n+3)
        L = torch.cat((upper, lower), 1)  # B.(n+3).(n+3)

        # Solve the linear system directly; forming L.inverse() explicitly
        # has stability problems near the boundaries.
        rhs = torch.cat((dst, dst.new_zeros((B, 3, 2))), 1)
        wv = self._solve(L, rhs)  # B.(n+3).2

        # Evaluate the fitted map at every grid node.
        gs = self.gs.to(device=src.device, dtype=src.dtype)
        s = gs.size(1)
        K = (gs.unsqueeze(2) - src.unsqueeze(1)).norm(dim=3)  # B.(s*s).n
        gs = gs.expand(B, -1, -1)
        P = torch.cat((src.new_ones((B, s, 1)), gs), 2)  # B.(s*s).3
        L = torch.cat((K, P), 2)  # B.(s*s).(n+3)
        gs = torch.matmul(L, wv)  # B.(s*s).2
        return gs.reshape(B, self.sz, self.sz, 2).permute(0, 3, 1, 2)
class PspWarp(nn.Module):
    """Perspective warp defined by four point correspondences."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def _solve(A, B):
        """Solve ``A @ X = B`` for ``X`` (batched)."""
        # torch.solve(B, A) was deprecated and later removed from PyTorch;
        # torch.linalg.solve(A, B) is the replacement. The fallback keeps
        # very old installations working.
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
            return torch.linalg.solve(A, B)
        return torch.solve(B, A)[0]

    def pspmat(self, src, dst):
        """Compute the 8 perspective parameters mapping ``src`` onto ``dst``.

        For each correspondence (x, y) -> (u, v) two rows are stacked into an
        8x8 system:

            [x, y, 1, 0, 0, 0, -u*x, -u*y] . M = u
            [0, 0, 0, x, y, 1, -v*x, -v*y] . M = v

        Args:
            src: B.4.2 source points.
            dst: B.4.2 target points.

        Returns:
            M as a B.8.1 tensor.
        """
        B, _, _ = src.size()
        x, y = src[..., 0:1], src[..., 1:2]
        u, v = dst[..., 0:1], dst[..., 1:2]
        # Constant columns follow src so CUDA and non-float32 inputs work.
        ones = src.new_ones((B, 4, 1))
        zeros = src.new_zeros((B, 4, 3))
        rows_u = torch.cat([src, ones, zeros, -u * x, -u * y], dim=2)
        rows_v = torch.cat([zeros, src, ones, -v * x, -v * y], dim=2)
        A = torch.cat([rows_u, rows_v], dim=1)  # B.8.8
        rhs = torch.cat([u, v], dim=1)  # B.8.1
        # Solved directly rather than via A.inverse() @ rhs.
        return self._solve(A, rhs)

    def forward(self, xy, M):
        """Apply the perspective parameters ``M`` to the points ``xy``.

        Args:
            xy: points whose last dimension holds (x, y); the result is
                stacked on dim 2, so a B.N.2 layout is assumed.
            M: B.8.1 parameters as returned by ``pspmat``.

        Returns:
            B.N.2 tensor of transformed (u, v) points.
        """
        # permute M to B.1.8 so each parameter broadcasts over the N points
        M = M.permute(0, 2, 1)
        px, py = xy[..., 0], xy[..., 1]
        # projective denominator
        t = M[..., 6] * px + M[..., 7] * py + 1
        u = (M[..., 0] * px + M[..., 1] * py + M[..., 2]) / t
        v = (M[..., 3] * px + M[..., 4] * py + M[..., 5]) / t
        return torch.stack((u, v), dim=2)
class IdwWarp(nn.Module):
    """Inverse-distance-weighting warp sampled on an ``s`` x ``s`` grid.

    Every grid node receives the average of the ``dst`` points weighted by
    the inverse squared distance between the node and the matching ``src``
    points.

    Args:
        s: number of grid nodes per side.
    """

    def __init__(self, s):
        super().__init__()
        # Built by broadcasting instead of torch.meshgrid (which warns when
        # called without ``indexing=`` on recent PyTorch). Layout matches the
        # 'ij' convention: iy varies along rows, ix along columns.
        lin = torch.linspace(-1, 1, s)
        iy = lin.view(s, 1).expand(s, s)
        ix = lin.view(1, s).expand(s, s)
        # Keep the grid on the GPU when one exists (as before), but do not
        # crash at construction time on CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.gs = torch.stack((ix, iy), dim=2).reshape((1, -1, 2)).to(device)
        self.s = s

    def forward(self, src, dst):
        """Interpolate ``dst`` over the grid using distances to ``src``.

        Args:
            src: B.n.2 control points.
            dst: B.n.2 values attached to the control points.

        Returns:
            B.2.s.s tensor; channel 0 is interpolated from ``dst[..., 0]``
            and channel 1 from ``dst[..., 1]``.
        """
        B, n, _ = src.size()
        # Follow the input's device/dtype instead of assuming CUDA float32.
        gs = self.gs.to(device=src.device, dtype=src.dtype)
        # B.n.K.2 offsets from every control point to every grid node
        delta = src.unsqueeze(2) - gs.unsqueeze(0)
        # B.n.K
        p = 1
        Rsq = torch.sum(delta ** 2, dim=3) ** p
        w = 1 / Rsq
        # A grid node that coincides with a control point has an infinite
        # weight. For such nodes replace the whole weight column with an
        # indicator of the coinciding point(s), i.e. turn inf into
        # [0...1...0], so the node takes that point's value exactly.
        exact = torch.isinf(w)
        hit = exact.any(dim=1, keepdim=True)  # B.1.K
        w = torch.where(hit, exact.to(w.dtype), w)
        total = w.sum(dim=1)
        wwx = (w * dst[..., 0:1]).sum(dim=1) / total
        wwy = (w * dst[..., 1:2]).sum(dim=1) / total
        gs = torch.stack((wwx, wwy), dim=2).reshape(
            B, self.s, self.s, 2).permute(0, 3, 1, 2)
        return gs
if __name__ == "__main__":
    # Manual visual check: warp a 16x16 grid with TpsWarp, post-process it
    # with WarperUtil.global_post_warp (imported from tsdeform; its exact
    # semantics are not visible from this file) and send the resulting mesh
    # plot to a Visdom server on port 10086.
    import cv2
    import numpy as np
    from hdf5storage import loadmat
    from visdom import Visdom

    board = Visdom(port=10086)

    # Disabled experiment (kept for reference): time TpsWarp on a real
    # backward map and resample an image with the result.
    # bm_path = '/nfs/bigdisk/sagnik/swat3d/bm/7/2_471_7-ec_Page_375-5LI0001.mat'
    # img_path = '/nfs/bigdisk/sagnik/swat3d/img/7/2_471_7-ec_Page_375-5LI0001.png'
    # bm = loadmat(bm_path)['bm']
    # bm = (bm - 224) / 224.
    # bm = cv2.resize(bm, (64, 64), cv2.INTER_LINEAR).astype(np.float32)
    # im = cv2.imread(img_path) / 255.
    # im = im[..., ::-1].copy()
    # im = cv2.resize(im, (256, 256), cv2.INTER_AREA).astype(np.float32)
    # im = torch.from_numpy(im.transpose(2, 0, 1)).unsqueeze(0).to('cuda')
    # x = np.random.choice(np.arange(64), 50, False)
    # y = np.random.choice(np.arange(64), 50, False)
    # src = torch.tensor([[x, y]], dtype=torch.float32).permute(0, 2, 1)
    # src = (src - 32) / 32.
    # dst = torch.from_numpy(bm[y, x, :]).unsqueeze(0).to('cuda')
    # tpswarp = TpsWarp(64)
    # import time
    # t = time.time()
    # for _ in range(100):
    #     gs = tpswarp(src, dst)
    # print(f'time:{time.time() - t}')
    # gs = gs.view(-1, 64, 64, 2)
    # print(gs.size())
    # bm2x2 = F.interpolate(gs.permute(0, 3, 1, 2), size=256, mode='bilinear',
    #                       align_corners=True).permute(0, 2, 3, 1)
    # rim = F.grid_sample(im, bm2x2, align_corners=True)
    # vis.images(rim, win='sk3')

    tps = TpsWarp(16)
    import matplotlib.pyplot as plt

    # Seven control points (first four are the square's corners) and the
    # positions they are moved to; both are 1.7.2.
    ctrl_src = torch.tensor(
        [[-1, -1], [1, -1], [1, 1], [-1, 1], [-0.5, -1], [0, -1], [0.5, -1]],
        dtype=torch.float,
    ).unsqueeze(0)
    ctrl_dst = torch.tensor(
        [[-1, -0.5], [1, -1], [1, 1], [-1, 0.5], [-0.5, -1], [0, -0.5], [0.5, -1]]
    ).unsqueeze(0)

    # Perspective parameters between the four corner correspondences, in
    # both directions.
    psp = PspWarp()
    corners_src = ctrl_src[..., 0: 4, :]
    corners_dst = ctrl_dst[..., 0: 4, :]
    fwd_mat = psp.pspmat(corners_src, corners_dst)
    inv_mat = psp.pspmat(corners_dst, corners_src)

    warped = tps(ctrl_src, ctrl_dst)

    from tsdeform import WarperUtil
    util = WarperUtil(16)
    post = util.global_post_warp(warped, 16, inv_mat, fwd_mat)
    # B.2.H.W -> H.W.2 numpy array of the first batch element
    mesh = post.permute(0, 2, 3, 1)[0].detach().cpu().numpy()

    # Draw the warped grid as a red wireframe with y pointing down.
    plt.clf()
    mesh_x = mesh[..., 0]
    mesh_y = mesh[..., 1]
    plt.pcolormesh(mesh_x, mesh_y, np.zeros_like(mesh_x), edgecolors='r')
    axes = plt.gca()
    axes.invert_yaxis()
    axes.axis('equal')
    board.matplot(plt, env='grid', win='mpl')