Spaces:
Running
on
Zero
Running
on
Zero
| from copy import deepcopy | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # feature extractor | |
| from .layers import DiffusionNet | |
| from omegaconf import OmegaConf | |
def get_mask(evals1, evals2, gamma=0.5, device="cpu"):
    """Build the resolvent-style penalty mask for a pair of spectra.

    Both spectra are divided by their joint maximum and raised to the power
    ``gamma``.  For every pair of resulting values ``(a, b)`` the mask holds
    the squared distance between the points ``(x / (x**2 + 1), 1 / (x**2 + 1))``
    evaluated at ``x = b`` and ``x = a``.

    Args:
        evals1: spectrum laid out along the columns of the mask.
        evals2: spectrum laid out along the rows of the mask.
        gamma: exponent applied to the normalised spectra.
        device: device the spectra are moved to before the computation.

    Returns:
        Tensor whose entry ``[i, j]`` pairs ``evals2[i]`` with ``evals1[j]``
        (shape ``(len(evals2), len(evals1))`` for 1-D inputs).
    """
    # Joint normalisation: the largest value across both spectra maps to 1.
    largest = max(torch.max(evals1), torch.max(evals2))
    cols = evals1.to(device) / largest
    rows = evals2.to(device) / largest

    # evals1 varies along columns, evals2 along rows.
    cols_pow = (cols ** gamma)[None, :]
    rows_pow = (rows ** gamma)[:, None]

    cols_denom = cols_pow.square() + 1
    rows_denom = rows_pow.square() + 1

    real_gap = rows_pow / rows_denom - cols_pow / cols_denom
    imag_gap = 1 / rows_denom - 1 / cols_denom
    return real_gap.square() + imag_gap.square()
def get_CXX(feat_x, feat_y, evecs_trans_x, evecs_trans_y):
    """Compute the unregularised least-squares maps in both directions.

    The features are first projected with ``evecs_trans_*`` (batched matrix
    product), giving coefficient matrices ``A`` and ``B``.  The returned maps
    are ``C12 = B A^T (A A^T)^-1`` and ``C21 = A B^T (B B^T)^-1``.

    Args:
        feat_x, feat_y: batched (3-D) feature tensors of the two shapes.
        evecs_trans_x, evecs_trans_y: batched (3-D) projection matrices.

    Returns:
        The list ``[C12, C21]``.
    """
    coeff_x = torch.bmm(evecs_trans_x, feat_x)
    coeff_y = torch.bmm(evecs_trans_y, feat_y)
    coeff_x_t = coeff_x.transpose(1, 2)
    coeff_y_t = coeff_y.transpose(1, 2)

    def _normal_equation_map(target, source, source_t):
        # target @ source^T @ (source @ source^T)^-1
        gram = torch.bmm(source, source_t)
        cross = torch.bmm(target, source_t)
        return torch.bmm(cross, torch.inverse(gram))

    C12 = _normal_equation_map(coeff_y, coeff_x, coeff_x_t)
    C21 = _normal_equation_map(coeff_x, coeff_y, coeff_y_t)
    return [C12, C21]
def get_mask_noise(C, diff_model, scale=1, N_est=200, device=None, normalize=True, absolute=False):
    """Estimate a penalty mask from how strongly a denoiser alters noisy copies of ``C``.

    ``N_est`` Gaussian perturbations of ``C`` (standard deviation ``scale``) are
    passed through ``diff_model.net``.  The mask is the mean absolute change made
    by the denoiser, divided by ``2 * scale`` and by the mean absolute value of
    the noisy inputs; its square root, shifted so that the minimum is zero, is
    returned.  Everything runs under ``torch.no_grad()``.

    Args:
        C: map tensor; its last two dimensions define the size of the noise
            (the original implementation only supported 30x30 maps).
        diff_model: object exposing ``net(noisy, sigma)`` that returns a tensor
            shaped like ``noisy``.
        scale: noise level, also passed to the denoiser as a ``(1, 1, 1, 1)`` tensor.
        N_est: number of noisy samples averaged over.
        device: device used for the noise and the noise level.  ``None``
            (default) follows ``C.device`` so the function also works on
            CPU-only machines.
        normalize: if true, the mask is shifted by half its median and clamped
            to ``[0, 2 * median]`` before the square root.
        absolute: if true, the noise is added to ``|C|`` and the absolute value
            of the result is denoised.

    Returns:
        Non-negative tensor with a minimum of exactly zero (the sample
        dimension is averaged away).
    """
    if device is None:
        device = C.device

    with torch.no_grad():
        sig = torch.ones([1, 1, 1, 1], device=device) * scale
        # Noise follows the trailing two dimensions of C instead of a fixed 30x30.
        noise = torch.randn((N_est, 1, *C.shape[-2:]), device=device)

        if absolute:
            noisy_new = torch.abs(torch.abs(C[None, :, :]) + noise * scale)
        else:
            noisy_new = C[None, :, :] + noise * scale

        denoised = diff_model.net(noisy_new, sig)

        # Mean denoiser correction relative to the mean magnitude of the input.
        residual = torch.mean(torch.abs(noisy_new - denoised) / (2 * scale), dim=0)
        mask_squared = residual / torch.mean(torch.abs(noisy_new), dim=0)

        if normalize:
            mask_median = torch.median(mask_squared).item()
            mask_squared = torch.clamp(mask_squared - mask_median / 2, 0, mask_median * 2)

        M_denoised = torch.sqrt(mask_squared)
        return M_denoised - M_denoised.min()
class RegularizedFMNet(nn.Module):
    """Compute regularised functional-map matrices in both directions.

    Every row of a map is obtained from its own linear system
    ``(G + lambda_ * diag(mask_row)) c = cross_row`` where ``G`` is the Gram
    matrix of the source coefficients.  The mask comes from one of three
    sources: a denoiser (``diff_conf``), the resolvent mask (``use_resolvent``)
    or squared eigenvalue differences (default).
    """

    def __init__(self, lambda_=1e-3, resolvant_gamma=0.5, use_resolvent=False):
        """Store the regularisation weight and the mask configuration.

        Args:
            lambda_: weight of the mask penalty in each row system.
            resolvant_gamma: exponent forwarded to ``get_mask``.
            use_resolvent: use the resolvent mask instead of squared
                eigenvalue differences (ignored when ``diff_conf`` is given).
        """
        super().__init__()
        self.lambda_ = lambda_
        self.use_resolvent = use_resolvent
        self.resolvant_gamma = resolvant_gamma

    def _solve_rows(self, gram, cross, mask):
        """Solve one penalised system per row of ``cross`` and stack the rows."""
        rows = []
        # Iterate over the rows of the map being solved.  The original bounded
        # this loop by the number of source eigenvalues, which is only correct
        # when both spectra have the same size.
        for i in range(cross.size(1)):
            penalty = torch.diag_embed(mask[:, i, :])
            rhs = cross[:, i, :].unsqueeze(2)
            row = torch.bmm(torch.inverse(gram + self.lambda_ * penalty), rhs)
            rows.append(row.transpose(1, 2))
        return torch.cat(rows, dim=1)

    def forward(self, feat_x, feat_y, evals_x, evals_y, evecs_trans_x, evecs_trans_y, diff_conf=None):
        """Estimate the maps ``C12`` and ``C21``.

        Args:
            feat_x, feat_y: batched (3-D) feature tensors.
            evals_x, evals_y: eigenvalues without batch dimension (one is added here).
            evecs_trans_x, evecs_trans_y: projection matrices without batch
                dimension (one is added here).
            diff_conf: optional sequence
                ``(diff_model, scale, normalize, absolute, N_est)`` selecting the
                denoiser-based mask.

        Returns:
            ``[C12, C21]`` by default, or
            ``([C12_raw, C12, D12], [C21_raw, C21, D21])`` when ``diff_conf`` is given.
        """
        evecs_trans_x, evecs_trans_y = evecs_trans_x.unsqueeze(0), evecs_trans_y.unsqueeze(0)
        evals_x, evals_y = evals_x.unsqueeze(0), evals_y.unsqueeze(0)

        # Coefficients of the features after projection.
        A = torch.bmm(evecs_trans_x, feat_x)
        B = torch.bmm(evecs_trans_y, feat_y)

        if diff_conf is not None:
            diff_model, scale, normalize, absolute, N_est = diff_conf
            C12_raw, C21_raw = get_CXX(feat_x, feat_y, evecs_trans_x, evecs_trans_y)
            # Keep the noise on the same device as the features rather than
            # relying on get_mask_noise's own default.
            D12 = get_mask_noise(C12_raw, diff_model, scale=scale, N_est=N_est,
                                 device=feat_x.device, normalize=normalize, absolute=absolute)
            D21 = get_mask_noise(C21_raw, diff_model, scale=scale, N_est=N_est,
                                 device=feat_x.device, normalize=normalize, absolute=absolute)
        elif self.use_resolvent:
            D12 = get_mask(evals_x.flatten(), evals_y.flatten(), self.resolvant_gamma, feat_x.device).unsqueeze(0)
            D21 = get_mask(evals_y.flatten(), evals_x.flatten(), self.resolvant_gamma, feat_x.device).unsqueeze(0)
        else:
            D12 = (torch.unsqueeze(evals_y, 2) - torch.unsqueeze(evals_x, 1)) ** 2
            D21 = (torch.unsqueeze(evals_x, 2) - torch.unsqueeze(evals_y, 1)) ** 2

        A_t, B_t = A.transpose(1, 2), B.transpose(1, 2)
        C12 = self._solve_rows(torch.bmm(A, A_t), torch.bmm(B, A_t), D12)
        C21 = self._solve_rows(torch.bmm(B, B_t), torch.bmm(A, B_t), D21)

        if diff_conf is not None:
            return [C12_raw, C12, D12], [C21_raw, C21, D21]
        return [C12, C21]
class DFMNet(nn.Module):
    """
    Global model combining:
    - a DiffusionNet feature extractor applied to both shapes
    - a regularised functional-map solver (RegularizedFMNet) on the
      extracted features
    """

    # Key of the per-shape batch entry used as network input for each
    # supported value of cfg["feat"].
    _FEATURE_KEYS = {"xyz": "vertices", "wks": "wks", "hks": "hks"}

    def __init__(self, cfg):
        """Build the feature extractor and the functional-map solver.

        Args:
            cfg: mapping-like configuration (must support item access and
                ``.get``).  Reads ``feat``, ``C_in``, ``n_feat``, ``lambda_``,
                ``n_fmap`` and optionally ``resolvent_gamma``,
                ``use_resolvent`` and ``diffusion`` (with ``normalize``,
                ``abs`` and optional ``batch_mask``).
        """
        super().__init__()

        # feature extractor
        with_grad = True
        self.feat = cfg["feat"]
        self.feature_extractor = DiffusionNet(
            C_in=cfg["C_in"],
            C_out=cfg["n_feat"],
            C_width=128,
            N_block=4,
            dropout=True,
            with_gradient_features=with_grad,
            with_gradient_rotations=with_grad,
        )

        # regularized fmap
        self.fmreg_net = RegularizedFMNet(
            lambda_=cfg["lambda_"],
            resolvant_gamma=cfg.get("resolvent_gamma", 0.5),
            use_resolvent=cfg.get("use_resolvent", False),
        )

        # parameters
        self.n_fmap = cfg["n_fmap"]
        diffusion_cfg = cfg.get("diffusion", None)
        if diffusion_cfg is not None:
            # Item access throughout so that plain dicts work as well as
            # attribute-style config objects.
            self.normalize = diffusion_cfg["normalize"]
            self.abs = diffusion_cfg["abs"]
            self.N_est = diffusion_cfg.get("batch_mask", 100)

    def _extract_features(self, shape, feat_key):
        """Run the feature extractor on one shape and add a batch dimension."""
        return self.feature_extractor(
            shape[feat_key], shape["mass"], L=shape["L"], evals=shape["evals"],
            evecs=shape["evecs"], gradX=shape["gradX"], gradY=shape["gradY"],
            faces=shape["faces"],
        ).unsqueeze(0)

    def forward(self, batch, diff_model=None, scale=1):
        """Extract features for both shapes and estimate the functional maps.

        Args:
            batch: mapping with entries ``shape1`` and ``shape2``; each holds
                ``faces``, ``mass``, ``L``, ``evals``, ``evecs``, ``gradX``,
                ``gradY`` and the input selected by ``cfg["feat"]``
                (``vertices``, ``wks`` or ``hks``).
            diff_model: optional denoiser; when given, the mask of the map
                solver is estimated with it.
            scale: noise level forwarded to the solver together with ``diff_model``.

        Returns:
            ``(C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2)``.
            With ``diff_model`` the two predictions are the 3-element lists
            returned by the solver.

        Raises:
            ValueError: if the configured feature type is unknown, or if
                ``diff_model`` is given but no diffusion settings were configured.
        """
        try:
            feat_key = self._FEATURE_KEYS[self.feat]
        except KeyError:
            raise ValueError(f"Unknown feature type: {self.feat!r}") from None

        shape1, shape2 = batch["shape1"], batch["shape2"]
        feat1 = self._extract_features(shape1, feat_key)
        feat2 = self._extract_features(shape2, feat_key)

        # First n_fmap rows of evecs^T with every column scaled by the vertex
        # mass.  Broadcasting gives the same product as multiplying by
        # diag(mass) without materialising a dense n x n matrix.
        # NOTE(review): only shape2 is squeezed, as in the original code;
        # presumably both arrive as 2-D tensors - confirm against the data loader.
        evecs_trans1 = shape1["evecs"].t()[:self.n_fmap] * shape1["mass"].squeeze()
        evecs_trans2 = shape2["evecs"].squeeze().t()[:self.n_fmap].squeeze() * shape2["mass"].squeeze()
        evals1, evals2 = shape1["evals"][:self.n_fmap], shape2["evals"][:self.n_fmap]

        if diff_model is not None:
            missing = [name for name in ("normalize", "abs", "N_est") if not hasattr(self, name)]
            if missing:
                raise ValueError(
                    "diff_model was given but the model was built without a "
                    f"'diffusion' config section (missing: {', '.join(missing)})"
                )
            C12_pred, C21_pred = self.fmreg_net(
                feat1, feat2, evals1, evals2, evecs_trans1, evecs_trans2,
                diff_conf=[diff_model, scale, self.normalize, self.abs, self.N_est],
            )
        else:
            C12_pred, C21_pred = self.fmreg_net(feat1, feat2, evals1, evals2, evecs_trans1, evecs_trans2)

        return C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2