diffumatch / utils /torch_fmap.py
daidedou
Remove pykeops for space ZeroGPU demo.
3ccb54b
raw
history blame
3.68 kB
import torch
import torch.nn.functional as F
def euclidean_dist(x, y):
    """
    Pairwise Euclidean distances between the rows of two point sets.

    Args:
      x: tensor with shape [m, d]; any extra leading axes (e.g. [1, m, d])
         are flattened into the point axis
      y: tensor with shape [n, d]; same flattening rule as ``x``
    Returns:
      dist: tensor with shape [m, n]; every entry is at least 1e-6 because
            the squared distance is clamped to 1e-12 before the square root
    """
    # Flatten onto the feature axis instead of calling squeeze(): squeeze()
    # removes *every* size-1 axis, so a single point ([1, d]) or a
    # one-dimensional feature ([m, 1]) would collapse to 1-D and break the
    # row-wise sum below.
    x = x.reshape(-1, x.shape[-1])
    y = y.reshape(-1, y.shape[-1])

    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    xx = torch.pow(x, 2).sum(1, keepdim=True)      # [m, 1]
    yy = torch.pow(y, 2).sum(1, keepdim=True).t()  # [1, n]
    dist = xx + yy - 2 * torch.inner(x, y)         # [m, n]

    # Floating-point cancellation can leave tiny negative values here; the
    # clamp keeps the argument of sqrt strictly positive.
    dist = dist.clamp(min=1e-12).sqrt()
    return dist
def knnsearch(x, y, alpha=1./0.07, prod=False):
    """
    Soft assignment of the rows of ``x`` to the rows of ``y``.

    Args:
      x: tensor of points; size-1 axes are squeezed away before scoring
      y: tensor of points; size-1 axes are squeezed away before scoring
      alpha: scale applied to the scores before the softmax (an inverse
             temperature)
      prod: if True, score pairs by their inner product; otherwise by the
            negated Euclidean distance returned by ``euclidean_dist``
    Returns:
      The softmax of the scaled scores taken along dim 1, with size-1 axes
      squeezed out.
    """
    if not prod:
        # Closer pairs must score higher, hence the sign flip on the distance.
        scores = -alpha * euclidean_dist(x, y[None, :])
    else:
        scores = alpha * torch.inner(x.squeeze(), y.squeeze())
    return F.softmax(scores, dim=1).squeeze()
def extract_p2p_torch(reps_shape, reps_template):
    """
    Nearest-neighbour point-to-point maps between two embedded point sets.

    Args:
      reps_shape: array (or tensor) with shape [M, d], one row per shape point
      reps_template: array (or tensor) with shape [N, >= d]; only its first
                     d columns are compared against ``reps_shape``
    Returns:
      A pair of integer numpy arrays:
        * for every template row, the index of the closest shape row
          (length N)
        * for every shape row, the index of the closest template row
          (length M)
      Size-1 results are squeezed, as in the original implementation.
    """
    n_ev = reps_shape.shape[-1]
    # Use the GPU when one is present, but keep working on CPU-only hosts
    # instead of failing inside an unconditional .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        shape_pts = torch.as_tensor(reps_shape, dtype=torch.float32, device=device)
        template_pts = torch.as_tensor(reps_template, dtype=torch.float32, device=device)

        G_i = shape_pts[:, None, :]            # (M, 1, d)
        X_j = template_pts[None, :, :n_ev]     # (1, N, d)
        # NOTE: the difference is a dense (M, N, d) tensor, so memory grows
        # with M * N * d.
        D_ij = ((G_i - X_j) ** 2).sum(-1)      # (M, N) squared distances

        template_to_shape = torch.argmin(D_ij, dim=0).squeeze()  # (N,)
        shape_to_template = torch.argmin(D_ij, dim=1).squeeze()  # (M,)
    return template_to_shape.cpu().numpy(), shape_to_template.cpu().numpy()
def extract_p2p_torch_fmap(fmap_shape_template, evecs_shape, evecs_template):
    """
    Convert a functional map into nearest-neighbour point-to-point maps.

    The shape basis is pushed through the map (``evecs_shape @ fmap.T``) and
    each resulting row is matched against the rows of the template basis.

    Args:
      fmap_shape_template: tensor with shape [k_t, k_s]; leading size-1 axes
                           (e.g. [1, k_t, k_s]) are accepted
      evecs_shape: tensor with shape [M, >= k_s]
      evecs_template: tensor with shape [N, >= k_t]
    Returns:
      A pair of integer numpy arrays:
        * for every template row, the index of the closest shape row
          (length N)
        * for every shape row, the index of the closest template row
          (length M)
      Size-1 results are squeezed, as in the original implementation.
    """
    n_rows, n_cols = fmap_shape_template.shape[-2:]
    with torch.no_grad():
        # reshape() rather than squeeze(): squeeze() would turn a 1x1 map
        # into a 0-d tensor and break the matmul below.
        fmap = fmap_shape_template.reshape(n_rows, n_cols)

        # (M, k_s) @ (k_s, k_t) -> (M, k_t), then add the broadcast axis.
        G_i = (evecs_shape[:, :n_cols] @ fmap.t())[:, None, :]   # (M, 1, k_t)
        # The projected rows have k_t (= n_rows) coordinates, so the template
        # basis must be cut to n_rows columns. Cutting it to n_cols only
        # works for square maps.
        X_j = evecs_template[None, :, :n_rows]                   # (1, N, k_t)
        # NOTE: the difference is a dense (M, N, k_t) tensor.
        D_ij = ((G_i - X_j) ** 2).sum(-1)                        # (M, N)

        template_to_shape = torch.argmin(D_ij, dim=0).squeeze()  # (N,)
        shape_to_template = torch.argmin(D_ij, dim=1).squeeze()  # (M,)
    return template_to_shape.cpu().numpy(), shape_to_template.cpu().numpy()
def wlstsq(A, B, w):
    """
    Least-squares solve of ``A X = B`` with optional per-row scaling.

    Args:
      A: tensor with shape [..., n, k]
      B: right-hand side with n rows (shape [..., n, m], or [n] for a single
         vector)
      w: None, or a tensor with shape [..., n] holding one factor per row.
         Each row of ``A`` and ``B`` is multiplied by its factor before
         solving, i.e. the solve minimises ``|| diag(w) (A X - B) ||``.
    Returns:
      The ``solution`` field of ``torch.linalg.lstsq``.
    """
    if w is None:
        return torch.linalg.lstsq(A, B).solution

    assert w.dim() + 1 == A.dim() and w.shape[-1] == A.shape[-2]
    # Multiplying by diag(w) only rescales rows, so broadcast the factors
    # instead of materialising a dense n x n diagonal matrix.
    row_scale = w.unsqueeze(-1)
    # A 1-D right-hand side is a plain vector: scale it element-wise.
    scaled_B = w * B if B.dim() == 1 else row_scale * B
    return torch.linalg.lstsq(row_scale * A, scaled_B).solution
def torch_zoomout(evecs0, evecs1, evecs_1_trans, fmap01, target_size, step=1):
    """
    Iteratively enlarge a square functional map up to ``target_size``.

    Each round turns the current map into a point-to-point map with
    ``extract_p2p_torch_fmap`` and rebuilds a larger map from it as
    ``evecs_1_trans[:size] @ evecs0[p2p, :size]``.

    Args:
      evecs0: basis of the first shape, one row per point
      evecs1: basis of the second shape, one row per point
      evecs_1_trans: matrix whose rows pair with the columns of ``evecs1``
                     (presumably its transpose or pseudo-inverse; confirm
                     with the caller)
      fmap01: square initial map; a leading size-1 axis is tolerated
      target_size: side length of the map to return
      step: increment of the side length between rounds
    Returns:
      The refined map. ``fmap01`` is returned unchanged when it is already
      larger than ``target_size``.
    """
    assert fmap01.shape[-2] == fmap01.shape[-1], f"square fmap needed, got {fmap01.shape[-2]} and {fmap01.shape[-1]}"
    # Read the size from the trailing axis, as the assert above does;
    # shape[0] would be 1 for a map that carries a leading batch axis.
    start_size = fmap01.shape[-1]
    for size in range(start_size, target_size + 1, step):
        p2p, _ = extract_p2p_torch_fmap(fmap01, evecs0, evecs1)
        #fmap01 = wlstsq(evecs1[..., :size], evecs0[p2p, :size], None)
        fmap01 = evecs_1_trans[:size, :] @ evecs0[p2p, :size]

    if fmap01.shape[-1] < target_size:
        # The stride skipped over target_size: take one last round. The
        # correspondence is recomputed from the latest map; reusing the one
        # left over from the loop would discard the map just built.
        p2p, _ = extract_p2p_torch_fmap(fmap01, evecs0, evecs1)
        fmap01 = evecs_1_trans[:target_size, :] @ evecs0[p2p, :target_size]
    return fmap01