import torch
import torch.nn as nn
import torch.nn.functional as F


def _make_grid(s):
    """Return a (1, s*s, 2) tensor of (x, y) points on a regular s x s grid over [-1, 1]^2.

    Points are stored row-major: y is constant along a row, x varies along the
    columns (the layout of ``torch.meshgrid(ys, xs)`` with 'ij' indexing). The grid
    is built by broadcasting so it does not depend on meshgrid's ``indexing``
    argument, which only exists (and warns when omitted) in newer PyTorch.
    """
    lin = torch.linspace(-1, 1, s)
    ix = lin.view(1, s).expand(s, s)
    iy = lin.view(s, 1).expand(s, s)
    return torch.stack((ix, iy), dim=2).reshape((1, -1, 2))


class TpsWarp(nn.Module):
    """Radial-basis spline warp evaluated on a fixed ``s x s`` grid over [-1, 1]^2.

    The spline is ``f(p) = sum_i w_i * U(|p - src_i|) + a0 + a1 * x + a2 * y`` with
    kernel ``U(r) = r``. It is fitted so that ``f(src_i) = dst_i`` for every control
    point, with the usual side conditions ``P^T w = 0``.
    """

    def __init__(self, s):
        super(TpsWarp, self).__init__()
        # Non-persistent buffer: follows .to()/.cuda() but stays out of state_dict,
        # so checkpoints saved before this was a buffer still load strictly.
        self.register_buffer('gs', _make_grid(s), persistent=False)
        self.sz = s

    def forward(self, src, dst):
        """Fit the spline to ``src -> dst`` and evaluate it on the grid.

        Args:
            src: B x n x 2 control-point positions.
            dst: B x n x 2 target values at those control points.

        Returns:
            B x 2 x sz x sz tensor; channel c is the spline fitted to ``dst[..., c]``
            evaluated at every grid point.
        """
        B, n, _ = src.size()
        # Pairwise differences between control points: B.n.n.2
        delta = src.unsqueeze(2)
        delta = delta - delta.permute(0, 2, 1, 3)
        # B.n.n kernel matrix, U(r) = r.
        # delta.norm is used rather than torch.sqrt(sum(delta**2)) because
        # torch.sqrt has a NaN gradient at 0, and the diagonal distances are 0.
        # Kernels tried previously: 0.5 * r^2 * log(r^2), exp(-150 * r^2),
        # and |r^2 - 0.5| - 0.5.
        K = delta.norm(dim=3)
        # Affine part [1, x, y]; allocated like src so device/dtype match.
        P = torch.cat((src.new_ones((B, n, 1)), src), 2)
        L = torch.cat((K, P), 2)
        t = torch.cat((P.permute(0, 2, 1), src.new_zeros((B, 3, 3))), 2)
        # (n+3) x (n+3) system: [[K, P], [P^T, 0]]
        L = torch.cat((L, t), 1)
        rhs = torch.cat((dst, dst.new_zeros((B, 3, 2))), 1)
        # wv is B.(n+3).2: kernel weights followed by the 3 affine coefficients.
        # Solved directly instead of via L.inverse(), which was numerically
        # unstable near the boundaries.
        wv = torch.linalg.solve(L, rhs)

        # Evaluate the spline at every grid point.
        gs = self.gs.to(device=src.device, dtype=src.dtype)
        s = gs.size(1)
        # B.(sz*sz).n distances from each grid point to each control point
        K = (gs.unsqueeze(2) - src.unsqueeze(1)).norm(dim=3)
        gs = gs.expand(B, -1, -1)
        P = torch.cat((src.new_ones((B, s, 1)), gs), 2)
        L = torch.cat((K, P), 2)
        gs = torch.matmul(L, wv)
        return gs.reshape(B, self.sz, self.sz, 2).permute(0, 3, 1, 2)


class PspWarp(nn.Module):
    """Perspective (8-parameter projective) mapping fitted to four point pairs."""

    def __init__(self):
        super().__init__()

    def pspmat(self, src, dst):
        """Solve for the projective parameters mapping ``src`` onto ``dst``.

        Args:
            src: B x 4 x 2 source points.
            dst: B x 4 x 2 corresponding destination points.

        Returns:
            B x 8 x 1 parameters ``M`` as consumed by :meth:`forward`.
        """
        B, _, _ = src.size()
        # Allocated like src so the system lives on the inputs' device/dtype.
        ones = src.new_ones((B, 4, 1))
        zeros = src.new_zeros((B, 4, 3))
        # Each correspondence gives two linear equations in m0..m7:
        #   u = m0 x + m1 y + m2 - m6 x u - m7 y u
        #   v = m3 x + m4 y + m5 - m6 x v - m7 y v
        s = torch.cat([
            torch.cat([src, ones, zeros,
                       -dst[..., 0:1] * src[..., 0:1],
                       -dst[..., 0:1] * src[..., 1:2]], dim=2),
            torch.cat([zeros, src, ones,
                       -dst[..., 1:2] * src[..., 0:1],
                       -dst[..., 1:2] * src[..., 1:2]], dim=2),
        ], dim=1)
        t = torch.cat([dst[..., 0:1], dst[..., 1:2]], dim=1)
        # Solved directly rather than as s.inverse() @ t. M is B 8 1.
        M = torch.linalg.solve(s, t)
        return M

    def forward(self, xy, M):
        """Apply the projective mapping ``M`` (B x 8 x 1) to points ``xy`` (B x N x 2).

        Returns B x N x 2 points ``(u, v)`` with
        ``u = (m0 x + m1 y + m2) / (m6 x + m7 y + 1)`` and
        ``v = (m3 x + m4 y + m5) / (m6 x + m7 y + 1)``.
        """
        # permute M to B 1 8 so M[..., k] broadcasts against xy[..., 0] (B N)
        M = M.permute(0, 2, 1)
        t = M[..., 6] * xy[..., 0] + M[..., 7] * xy[..., 1] + 1
        u = (M[..., 0] * xy[..., 0] + M[..., 1] * xy[..., 1] + M[..., 2]) / t
        v = (M[..., 3] * xy[..., 0] + M[..., 4] * xy[..., 1] + M[..., 5]) / t
        return torch.stack((u, v), dim=2)


class IdwWarp(nn.Module):
    """Inverse-distance-weighting interpolation onto a fixed ``s x s`` grid.

    Each grid point receives ``sum_i w_i * dst_i / sum_i w_i`` with
    ``w_i = 1 / (|g - src_i|^2) ** p``. A grid point that coincides exactly with
    control point(s) takes the value of those point(s) only.
    """

    def __init__(self, s, p=1):
        super().__init__()
        # Non-persistent buffer (was hard-coded to 'cuda', which failed on CPU-only
        # machines); forward() moves it to the input's device/dtype.
        self.register_buffer('gs', _make_grid(s), persistent=False)
        self.s = s
        # exponent applied to the squared distance (weight = 1 / d^(2p))
        self.p = p

    def forward(self, src, dst):
        """Interpolate ``dst`` values (B x n x 2) given at ``src`` (B x n x 2).

        Returns:
            B x 2 x s x s tensor of interpolated values on the grid.
        """
        B, n, _ = src.size()
        gs = self.gs.to(device=src.device, dtype=src.dtype)
        # B.n.K.2 offsets from every control point to every grid point
        delta = src.unsqueeze(2) - gs.unsqueeze(0)
        # B.n.K squared distances
        d2 = torch.sum(delta ** 2, dim=3)
        # Pairs whose weight 1/d2**p would be infinite (exact coincidence).
        with torch.no_grad():
            hit = torch.isinf(1 / d2 ** self.p)
        hit_col = hit.any(dim=1, keepdim=True)
        # Replace the zero distances before dividing so the backward pass never
        # sees inf (the previous in-place overwrite produced NaN gradients).
        safe_d2 = torch.where(hit, torch.ones_like(d2), d2)
        # Grid points hit by a control point use weights [0...1...0]; the rest
        # use the usual inverse-distance weights.
        w = torch.where(hit_col, hit.to(d2.dtype), 1 / safe_d2 ** self.p)
        wsum = w.sum(dim=1)
        wwx = (w * dst[..., 0:1]).sum(dim=1) / wsum
        wwy = (w * dst[..., 1:2]).sum(dim=1) / wsum
        gs = torch.stack((wwx, wwy), dim=2).reshape(
            B, self.s, self.s, 2).permute(0, 3, 1, 2)
        return gs


if __name__ == "__main__":
    # Visual sanity check: warp a 16x16 grid with TpsWarp, pass it through
    # tsdeform.WarperUtil.global_post_warp together with corner homographies,
    # and plot the resulting mesh to visdom.
    import cv2
    import numpy as np
    from hdf5storage import loadmat
    from visdom import Visdom

    vis = Visdom(port=10086)
    # Earlier experiment (kept for reference): fit TpsWarp(64) to 50 random samples
    # of a 'bm' map loaded from a .mat file and resample an image with it.
    # bm_path = '/nfs/bigdisk/sagnik/swat3d/bm/7/2_471_7-ec_Page_375-5LI0001.mat'
    # img_path = '/nfs/bigdisk/sagnik/swat3d/img/7/2_471_7-ec_Page_375-5LI0001.png'
    # bm = loadmat(bm_path)['bm']
    # bm = (bm - 224) / 224.
    # bm = cv2.resize(bm, (64, 64), cv2.INTER_LINEAR).astype(np.float32)
    # im = cv2.imread(img_path) / 255.
    # im = im[..., ::-1].copy()
    # im = cv2.resize(im, (256, 256), cv2.INTER_AREA).astype(np.float32)
    # im = torch.from_numpy(im.transpose(2, 0, 1)).unsqueeze(0).to('cuda')
    # x = np.random.choice(np.arange(64), 50, False)
    # y = np.random.choice(np.arange(64), 50, False)
    # src = torch.tensor([[x, y]], dtype=torch.float32).permute(0, 2, 1)
    # src = (src - 32) / 32.
    # dst = torch.from_numpy(bm[y, x, :]).unsqueeze(0).to('cuda')
    # tpswarp = TpsWarp(64)
    # import time
    # t = time.time()
    # for _ in range(100):
    #     gs = tpswarp(src, dst)
    # print(f'time:{time.time() - t}')
    # gs = gs.view(-1, 64, 64, 2)
    # bm2x2 = F.interpolate(gs.permute(0, 3, 1, 2), size=256, mode='bilinear',
    #                       align_corners=True).permute(0, 2, 3, 1)
    # rim = F.grid_sample(im, bm2x2, align_corners=True)
    # vis.images(rim, win='sk3')
    tpswarp = TpsWarp(16)
    import matplotlib.pyplot as plt

    cn = torch.tensor(
        [[-1, -1], [1, -1], [1, 1], [-1, 1], [-0.5, -1], [0, -1], [0.5, -1]],
        dtype=torch.float).unsqueeze(0)
    pn = torch.tensor(
        [[-1, -0.5], [1, -1], [1, 1], [-1, 0.5], [-0.5, -1], [0, -0.5], [0.5, -1]]
    ).unsqueeze(0)
    pspwarp = PspWarp()
    # Projective mappings between the first four (corner) points, both directions.
    M = pspwarp.pspmat(cn[..., 0:4, :], pn[..., 0:4, :])
    invM = pspwarp.pspmat(pn[..., 0:4, :], cn[..., 0:4, :])
    t = tpswarp(cn, pn)
    from tsdeform import WarperUtil

    wu = WarperUtil(16)
    tgs = wu.global_post_warp(t, 16, invM, M)
    t = tgs.permute(0, 2, 3, 1)[0].detach().cpu().numpy()
    plt.clf()
    plt.pcolormesh(t[..., 0], t[..., 1], np.zeros_like(t[..., 0]), edgecolors='r')
    plt.gca().invert_yaxis()
    plt.gca().axis('equal')
    vis.matplot(plt, env='grid', win='mpl')