diffumatch / snk /loss.py
daidedou
forgot a few things lol
e321b92
from trimesh.graph import face_adjacency
import torch
import torch.nn as nn
class PrismRegularizationLoss(nn.Module):
"""
Calculate the loss based on the PriMo energy, as described in the paper:
PriMo: Coupled Prisms for Intuitive Surface Modeling
"""
def __init__(self, primo_h):
super().__init__()
self.h = primo_h
# compute coefficient for the energy
indices = torch.tensor([(i, j) for i in range(2) for j in range(2)])
indices_A = indices.repeat_interleave(4, dim=0)
indices_B = indices.repeat(4, 1)
self.coeff = (torch.ones(1) * 2).pow(((indices_A - indices_B).abs() * -1).sum(dim=1))[None, :]
def forward(self, transformed_prism, rotations, verts, faces, normals):
# transformed_prism is (n_faces, 3, 3)
# verts and faces are from the template (shape 2)
# * for now assumes there is only one batch
# todo add batch support
bs = 1
verts = verts.reshape(-1, 3)
normals = normals.reshape(-1, 3)
faces = faces
# get the area of each face
face_areas = self.get_face_areas(verts, faces) # (n_faces,)
# get list of edges and the faces that share each edge
face_ids, edges = face_adjacency(faces.cpu().numpy(), return_edges=True) # (n_edges, 2), (n_edges, 2)
face_ids, edges = torch.from_numpy(face_ids).to(verts.device), torch.from_numpy(edges).to(verts.device)
# normals and rotations of the faces that share each edge
normals1, normals2 = normals[edges[:, 0]], normals[edges[:, 1]] # (n_edges, 3), normals are per vertex
rotations1, rotations2 = rotations[face_ids[:, 0]], rotations[face_ids[:, 1]] # (n_edges, 3, 3), rotations are per face
# computed normals from the transformed prism
# normals = self.compute_normals(transformed_prism)
# compute the loss
face_id1, face_id2 = face_ids[:, 0], face_ids[:, 1] # (n_edges,)
faces_to_verts = self.get_verts_id_face(faces, edges, face_ids) # (n_edges, 4)
verts1_p1, verts2_p1 = transformed_prism[face_id1, faces_to_verts[:, 0]], transformed_prism[face_id1, faces_to_verts[:, 1]] # (n_edges, 3)
verts1_p2, verts2_p2 = transformed_prism[face_id2, faces_to_verts[:, 2]], transformed_prism[face_id2, faces_to_verts[:, 3]] # (n_edges, 3)
# get the normals per vertex
# normals1, normals2 = normals[face_id1], normals[face_id2] # (n_edges, 3) # normals per face (NOT USED)
prism1_n1, prism1_n2 = (normals1[:, None] @ rotations1).squeeze(1), (normals2[:, None] @ rotations1).squeeze(1) # todo check if this is correct
prism2_n1, prism2_n2 = (normals1[:, None] @ rotations2).squeeze(1), (normals2[:, None] @ rotations2).squeeze(1)
# get the coordinates of the face of the prism
# prism1 (1 -> 2)
f_p1_00, f_p1_01 = verts1_p1 + prism1_n1 * self.h, verts2_p1 + prism1_n2 * self.h # (n_edges, 3)
f_p1_10, f_p1_11 = verts1_p1 - prism1_n1 * self.h, verts2_p1 - prism1_n2 * self.h # (n_edges, 3)
# prism2 (2 -> 1)
f_p2_00, f_p2_01 = verts1_p2 + prism2_n1 * self.h, verts2_p2 + prism2_n2 * self.h # (n_edges, 3)
f_p2_10, f_p2_11 = verts1_p2 - prism2_n1 * self.h, verts2_p2 - prism2_n2 * self.h # (n_edges, 3)
# compute the energy
A, B = torch.stack((f_p1_00, f_p1_01, f_p1_10, f_p1_11), dim=1), torch.stack((f_p2_00, f_p2_01, f_p2_10, f_p2_11), dim=1) # (n_edges, 4, 3)
energy = self.compute_energy(A - B, A - B) # (n_edges,)
# compute weight
area1, area2 = face_areas[face_id1], face_areas[face_id2] # (n_edges,)
weight = torch.norm(verts[edges[:, 0]] - verts[edges[:, 1]], dim=1).square() / (area1 + area2) # (n_edges,)
# weight = torch.ones_like(weight).to(weight.device) # todo remove
energy = energy * weight # (n_edges,)
loss = energy.sum() / bs # todo when batch enabled, need to divide by batch size
return loss
def compute_energy(self, A, B):
"""
Computes the formula sum_{i,j,k,l=0}^{1} a_{ij}b_{kl} 2^{-|i - k| - |j - l|}.
Assumes that A and B are tensors of size bs x 4 x 3, where bs is the batch size.
"""
self.coeff = self.coeff.to(A.device)
A_repeated = A.repeat_interleave(4, dim=1)
B_repeated = B.repeat(1, 4, 1)
energy = (A_repeated * B_repeated).sum(dim=-1)
energy = (energy * self.coeff).sum(dim=1)
energy = energy / 9
return energy
def get_face_areas(self, verts, faces):
# get the area of each face
v1, v2, v3 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
area = 0.5 * torch.cross(v2 - v1, v3 - v1, dim=-1).norm(dim=1)
return area
def get_verts_id_face(self, F, E, Q):
e = E.shape[0]
Z = torch.zeros((e, 4), dtype=torch.long)
v1 = F[:, 0][Q[:, 0]]
v2 = F[:, 1][Q[:, 0]]
v3 = F[:, 2][Q[:, 0]]
v4 = F[:, 0][Q[:, 1]]
v5 = F[:, 1][Q[:, 1]]
v6 = F[:, 2][Q[:, 1]]
idx1 = torch.where(v1 == E[:, 0], 0, torch.where(v2 == E[:, 0], 1, torch.where(v3 == E[:, 0], 2, -1)))
idx2 = torch.where(v1 == E[:, 1], 0, torch.where(v2 == E[:, 1], 1, torch.where(v3 == E[:, 1], 2, -1)))
idx3 = torch.where(v4 == E[:, 0], 0, torch.where(v5 == E[:, 0], 1, torch.where(v6 == E[:, 0], 2, -1)))
idx4 = torch.where(v4 == E[:, 1], 0, torch.where(v5 == E[:, 1], 1, torch.where(v6 == E[:, 1], 2, -1)))
Z[:, 0:2] = torch.stack((idx1, idx2), dim=1)
Z[:, 2:4] = torch.stack((idx3, idx4), dim=1)
Z = Z.to(F.device)
return Z