import torch
import torch.nn as nn

# feature extractor
from .layers import DiffusionNet

def get_mask(evals1, evals2, gamma=0.5, device="cpu"):
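    """Resolvent mask for functional map regularization.

    For every pair of eigenvalues, computes the squared modulus of the difference
    of their complex resolvents (with frequencies rescaled by `gamma`), i.e. the
    resolvent variant of the Laplacian-commutativity penalty.
    """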
scaling_factor = max(torch.max(evals1), torch.max(evals2))
evals1, evals2 = evals1.to(device) / scaling_factor, evals2.to(device) / scaling_factor
evals_gamma1, evals_gamma2 = (evals1 ** gamma)[None, :], (evals2 ** gamma)[:, None]
M_re = evals_gamma2 / (evals_gamma2.square() + 1) - evals_gamma1 / (evals_gamma1.square() + 1)
M_im = 1 / (evals_gamma2.square() + 1) - 1 / (evals_gamma1.square() + 1)
return M_re.square() + M_im.square()
def get_CXX(feat_x, feat_y, evecs_trans_x, evecs_trans_y):
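    """Unregularized least-squares functional maps in both directions.

    A = F_hat and B = G_hat are the spectral coefficients of the features;
    C12 solves C12 @ A ~= B and C21 solves C21 @ B ~= A in the least-squares sense.
    """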
    # project the features onto the spectral bases
F_hat = torch.bmm(evecs_trans_x, feat_x)
G_hat = torch.bmm(evecs_trans_y, feat_y)
A, B = F_hat, G_hat
A_t, B_t = A.transpose(1, 2), B.transpose(1, 2)
A_A_t, B_B_t = torch.bmm(A, A_t), torch.bmm(B, B_t)
B_A_t, A_B_t = torch.bmm(B, A_t), torch.bmm(A, B_t)
C12 = torch.bmm(B_A_t, torch.inverse(A_A_t))
C21 = torch.bmm(A_B_t, torch.inverse(B_B_t))
return [C12, C21]
def get_mask_noise(C, diff_model, scale=1, N_est=200, device="cuda", normalize=True, absolute=False):
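    """Estimate a soft penalty mask for a functional map C with a denoising diffusion model.

    C is perturbed with N_est Gaussian noise samples at level `scale`, denoised by
    `diff_model.net`, and the per-entry reconstruction error (normalized by the mean
    signal magnitude) is turned into a mask: entries the model denoises poorly get a
    larger penalty.
    """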
    with torch.no_grad():
        sig = torch.ones([1, 1, 1, 1], device=device) * scale
        # NOTE: 30x30 matches the spectral resolution the diffusion model was trained on
        noise = torch.randn((N_est, 1, 30, 30), device=device)
        if absolute:
            noisy_new = torch.abs(torch.abs(C[None, :, :]) + noise * scale)
        else:
            noisy_new = C[None, :, :] + noise * scale
        denoised = diff_model.net(noisy_new, sig)
        # per-entry reconstruction error, normalized by the mean signal magnitude
        mask_squared = torch.mean(torch.abs(noisy_new - denoised) / (2 * scale), dim=0) / torch.mean(torch.abs(noisy_new), dim=0)
        if normalize:
            mask_median = torch.median(mask_squared).item()
            M_denoised = torch.sqrt(torch.clamp(mask_squared - mask_median / 2, 0, mask_median * 2))
        else:
            M_denoised = torch.sqrt(mask_squared)
        return M_denoised - M_denoised.min()
class RegularizedFMNet(nn.Module):
"""Compute the functional map matrix representation."""
def __init__(self, lambda_=1e-3, resolvant_gamma=0.5, use_resolvent=False):
super().__init__()
self.lambda_ = lambda_
self.use_resolvent = use_resolvent
self.resolvant_gamma = resolvant_gamma
def forward(self, feat_x, feat_y, evals_x, evals_y, evecs_trans_x, evecs_trans_y, diff_conf=None):
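        """Solve for C12 and C21 row by row with a regularized least-squares system.

        The diagonal penalty mask D comes from the diffusion model (when diff_conf
        is given), from the resolvent mask, or from the plain squared
        eigenvalue-difference, in that order of precedence.
        """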
        # compute the regularized functional maps C12 and C21
evecs_trans_x, evecs_trans_y = evecs_trans_x.unsqueeze(0), evecs_trans_y.unsqueeze(0)
evals_x, evals_y = evals_x.unsqueeze(0), evals_y.unsqueeze(0)
F_hat = torch.bmm(evecs_trans_x, feat_x)
G_hat = torch.bmm(evecs_trans_y, feat_y)
A, B = F_hat, G_hat
if diff_conf is not None:
diff_model, scale, normalize, absolute, N_est = diff_conf
C12_raw, C21_raw = get_CXX(feat_x, feat_y, evecs_trans_x, evecs_trans_y)
D12 = get_mask_noise(C12_raw, diff_model, scale=scale, N_est=N_est, normalize=normalize, absolute=absolute)
D21 = get_mask_noise(C21_raw, diff_model, scale=scale, N_est=N_est, normalize=normalize, absolute=absolute)
elif self.use_resolvent:
D12 = get_mask(evals_x.flatten(), evals_y.flatten(), self.resolvant_gamma, feat_x.device).unsqueeze(0)
D21 = get_mask(evals_y.flatten(), evals_x.flatten(), self.resolvant_gamma, feat_x.device).unsqueeze(0)
        else:
            # squared eigenvalue-difference mask (standard Laplacian-commutativity penalty)
            D12 = (torch.unsqueeze(evals_y, 2) - torch.unsqueeze(evals_x, 1)) ** 2
            D21 = (torch.unsqueeze(evals_x, 2) - torch.unsqueeze(evals_y, 1)) ** 2
A_t, B_t = A.transpose(1, 2), B.transpose(1, 2)
A_A_t, B_B_t = torch.bmm(A, A_t), torch.bmm(B, B_t)
B_A_t, A_B_t = torch.bmm(B, A_t), torch.bmm(A, B_t)
        # solve for one row of C at a time; row i gets its own diagonal penalty D[:, i, :]
        C12_i = []
        for i in range(evals_x.size(1)):
            D12_i = torch.diag_embed(D12[:, i, :])
            C12 = torch.bmm(torch.inverse(A_A_t + self.lambda_ * D12_i), B_A_t[:, i, :].unsqueeze(1).transpose(1, 2))
            C12_i.append(C12.transpose(1, 2))
        C12 = torch.cat(C12_i, dim=1)
        C21_i = []
        for i in range(evals_y.size(1)):
            D21_i = torch.diag_embed(D21[:, i, :])
            C21 = torch.bmm(torch.inverse(B_B_t + self.lambda_ * D21_i), A_B_t[:, i, :].unsqueeze(1).transpose(1, 2))
            C21_i.append(C21.transpose(1, 2))
        C21 = torch.cat(C21_i, dim=1)
        if diff_conf is not None:
            # also return the unregularized maps and the estimated masks
            return [C12_raw, C12, D12], [C21_raw, C21, D21]
else:
return [C12, C21]
class DFMNet(nn.Module):
"""
    Full model:
    - DiffusionNet as feature extractor
    - fmap + q-fmap
    - unsupervised loss
"""
def __init__(self, cfg):
super().__init__()
        # feature extractor
        with_grad = True
self.feat = cfg["feat"]
self.feature_extractor = DiffusionNet(
C_in=cfg["C_in"],
C_out=cfg["n_feat"],
C_width=128,
N_block=4,
dropout=True,
with_gradient_features=with_grad,
with_gradient_rotations=with_grad,
)
# regularized fmap
self.fmreg_net = RegularizedFMNet(lambda_=cfg["lambda_"],
resolvant_gamma=cfg.get("resolvent_gamma", 0.5), use_resolvent=cfg.get("use_resolvent", False))
# parameters
self.n_fmap = cfg["n_fmap"]
if cfg.get("diffusion", None) is not None:
self.normalize = cfg["diffusion"]["normalize"]
self.abs = cfg["diffusion"]["abs"]
            self.N_est = cfg["diffusion"].get("batch_mask", 100)
def forward(self, batch, diff_model=None, scale=1):
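        """Extract features on both shapes and compute the functional maps.

        batch: nested dict with "shape1"/"shape2" entries holding vertices, faces
        and the spectral operators (mass, L, evals, evecs, gradX, gradY).
        diff_model: optional pretrained diffusion model used to build the penalty mask.
        """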
if self.feat == "xyz":
feat_1, feat_2 = batch["shape1"]["vertices"], batch["shape2"]["vertices"]
elif self.feat == "wks":
feat_1, feat_2 = batch["shape1"]["wks"], batch["shape2"]["wks"]
elif self.feat == "hks":
feat_1, feat_2 = batch["shape1"]["hks"], batch["shape2"]["hks"]
        else:
            raise ValueError(f"Unknown feature type: {self.feat}")
verts1, faces1, mass1, L1, evals1, evecs1, gradX1, gradY1 = (feat_1, batch["shape1"]["faces"],
batch["shape1"]["mass"], batch["shape1"]["L"],
batch["shape1"]["evals"], batch["shape1"]["evecs"],
batch["shape1"]["gradX"], batch["shape1"]["gradY"])
verts2, faces2, mass2, L2, evals2, evecs2, gradX2, gradY2 = (feat_2, batch["shape2"]["faces"],
batch["shape2"]["mass"], batch["shape2"]["L"],
batch["shape2"]["evals"], batch["shape2"]["evecs"],
batch["shape2"]["gradX"], batch["shape2"]["gradY"])
        # input features for DiffusionNet (xyz / wks / hks, selected above)
        features1, features2 = verts1, verts2
feat1 = self.feature_extractor(features1, mass1, L=L1, evals=evals1, evecs=evecs1,
gradX=gradX1, gradY=gradY1, faces=faces1).unsqueeze(0)
feat2 = self.feature_extractor(features2, mass2, L=L2, evals=evals2, evecs=evecs2,
gradX=gradX2, gradY=gradY2, faces=faces2).unsqueeze(0)
        # spectral projection: mass-weighted transpose of the truncated eigenbasis (Phi^T M)
        evecs_trans1 = evecs1.t()[:self.n_fmap] @ torch.diag(mass1.squeeze())
        evecs_trans2 = evecs2.t()[:self.n_fmap] @ torch.diag(mass2.squeeze())
        evals1, evals2 = evals1[:self.n_fmap], evals2[:self.n_fmap]
if diff_model is not None:
C12_pred, C21_pred = self.fmreg_net(feat1, feat2, evals1, evals2, evecs_trans1, evecs_trans2,
diff_conf=[diff_model, scale, self.normalize, self.abs, self.N_est])
else:
C12_pred, C21_pred = self.fmreg_net(feat1, feat2, evals1, evals2, evecs_trans1, evecs_trans2)
return C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2
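

# --- Minimal usage sketch (illustrative only) ---
# The config keys below mirror the cfg[...] accesses in DFMNet.__init__;
# the concrete values and the batch layout are assumptions, not part of this file.
if __name__ == "__main__":
    cfg = {
        "feat": "xyz",     # or "wks" / "hks"
        "C_in": 3,         # input channel count must match the chosen feature
        "n_feat": 128,     # dimension of the learned descriptors
        "n_fmap": 30,      # spectral resolution of the functional maps
        "lambda_": 1e-3,   # regularization strength
    }
    model = DFMNet(cfg)
    # `batch` must provide, for each of "shape1" and "shape2", the tensors
    # unpacked in DFMNet.forward: vertices, faces, mass, L, evals, evecs,
    # gradX, gradY (plus "wks"/"hks" if those features are selected), e.g.:
    # C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = model(batch)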