|
|
|
|
|
import json |
|
import logging |
|
from typing import List, Optional |
|
import torch |
|
from torch import nn |
|
|
|
from detectron2.utils.file_io import PathManager |
|
|
|
from densepose.structures.mesh import create_mesh |
|
|
|
|
|
class MeshAlignmentEvaluator:
    """
    Class for evaluation of 3D mesh alignment based on the learned vertex embeddings

    For every ordered pair of distinct meshes, each keyvertex of the first mesh
    is matched to the vertex of the second mesh whose embedding has the largest
    dot product with it. The match is then scored by the value stored in the
    second mesh's `geodists` table between the matched vertex and the keyvertex
    of the same name on the second mesh.
    """

    # Scale of the Gaussian kernel `exp(-d^2 / (2 * scale^2))` that converts the
    # looked-up distances into GPS scores.
    GPS_DISTANCE_SCALE = 0.255

    def __init__(self, embedder: nn.Module, mesh_names: Optional[List[str]]):
        """
        Args:
            embedder: callable module that maps a mesh name to the embeddings of
                that mesh's vertices; must expose `mesh_names` if `mesh_names`
                is not provided here.
            mesh_names: names of the meshes to evaluate; when None or empty,
                `embedder.mesh_names` is used instead.
        """
        self.embedder = embedder
        # Fall back to the embedder's own list if nothing (or an empty list) was given.
        self.mesh_names = mesh_names if mesh_names else embedder.mesh_names
        self.logger = logging.getLogger(__name__)
        # Keyvertex table: mesh name -> {keyvertex name -> vertex index}.
        with PathManager.open(
            "https://dl.fbaipublicfiles.com/densepose/data/cse/mesh_keyvertices_v0.json", "r"
        ) as f:
            self.mesh_keyvertices = json.load(f)

    def evaluate(self):
        """
        Compute mesh alignment metrics over all ordered pairs of distinct meshes.

        Returns:
            A tuple `(ge_mean_global, gps_mean_global, per_mesh_metrics)` where
            the first two entries are floats (means of the per-mesh values) and
            `per_mesh_metrics` is `{"GE": {mesh_name: float}, "GPS": {mesh_name: float}}`.
            A mesh that has no other mesh to be compared with gets NaN metrics
            (mean over an empty list), and so do the global means in that case.

        Raises:
            KeyError: if a mesh name, or a keyvertex name of the first mesh in a
                pair, is missing from the keyvertex table.
        """
        # Embed every mesh exactly once instead of once per (mesh_1, mesh_2) pair.
        embeddings = {name: self.embedder(name) for name in self.mesh_names}
        gps_denominator = 2 * (self.GPS_DISTANCE_SCALE**2)

        ge_per_mesh = {}
        gps_per_mesh = {}
        for mesh_name_1 in self.mesh_names:
            avg_errors = []
            avg_gps = []
            keyvertices_1 = self.mesh_keyvertices[mesh_name_1]
            keyvertex_names_1 = list(keyvertices_1.keys())
            keyvertex_indices_1 = [keyvertices_1[name] for name in keyvertex_names_1]
            # Embeddings of the keyvertices of mesh 1; shared by all inner iterations.
            keyvertex_embeddings_1 = embeddings[mesh_name_1][keyvertex_indices_1]
            for mesh_name_2 in self.mesh_names:
                if mesh_name_1 == mesh_name_2:
                    continue
                embeddings_2 = embeddings[mesh_name_2]
                keyvertices_2 = self.mesh_keyvertices[mesh_name_2]
                # Dot-product similarity between keyvertices of mesh 1 and all
                # vertices of mesh 2; the best match is the row-wise argmax.
                sim_matrix_12 = keyvertex_embeddings_1.mm(embeddings_2.T)
                vertices_2_matching_keyvertices_1 = sim_matrix_12.argmax(dim=1)
                mesh_2 = create_mesh(mesh_name_2, embeddings_2.device)
                # presumably `geodists` holds pairwise geodesic distances between
                # vertices of mesh 2 - TODO confirm against the mesh structure.
                geodists = mesh_2.geodists[
                    vertices_2_matching_keyvertices_1,
                    [keyvertices_2[name] for name in keyvertex_names_1],
                ]
                gps = (-(geodists**2) / gps_denominator).exp()
                avg_errors.append(geodists.mean().item())
                avg_gps.append(gps.mean().item())

            ge_per_mesh[mesh_name_1] = torch.as_tensor(avg_errors).mean().item()
            gps_per_mesh[mesh_name_1] = torch.as_tensor(avg_gps).mean().item()

        ge_mean_global = torch.as_tensor(list(ge_per_mesh.values())).mean().item()
        gps_mean_global = torch.as_tensor(list(gps_per_mesh.values())).mean().item()
        per_mesh_metrics = {
            "GE": ge_per_mesh,
            "GPS": gps_per_mesh,
        }
        return ge_mean_global, gps_mean_global, per_mesh_metrics
|
|