| |
|
|
| import json |
| import logging |
| from typing import List, Optional |
| import torch |
| from torch import nn |
|
|
| from detectron2.utils.file_io import PathManager |
|
|
| from densepose.structures.mesh import create_mesh |
|
|
|
|
class MeshAlignmentEvaluator:
    """
    Class for evaluation of 3D mesh alignment based on the learned vertex embeddings

    For every ordered pair of distinct meshes, each keyvertex of the first mesh
    is matched to the vertex of the second mesh with the highest embedding
    similarity; the geodesic distance on the second mesh between that match and
    the second mesh's keyvertex of the same name is the alignment error.
    """

    # Scale of the Gaussian kernel that turns a geodesic distance into a GPS
    # score: gps = exp(-d^2 / (2 * scale^2)).
    # NOTE(review): the original local was named `Current_Mean_Distances`;
    # presumably a mean geodesic distance - confirm the provenance of 0.255.
    GPS_DISTANCE_SCALE = 0.255

    def __init__(self, embedder: nn.Module, mesh_names: Optional[List[str]] = None):
        """
        Args:
            embedder: module that, when called with a mesh name, returns the
                vertex embeddings of that mesh as a 2-D tensor indexed by
                vertex along the first dimension. It must expose a
                `mesh_names` attribute if `mesh_names` is not provided.
            mesh_names: names of the meshes to evaluate; if None or empty,
                `embedder.mesh_names` is used instead.

        Note:
            The keyvertex annotations are fetched through `PathManager` from a
            fixed URL when the evaluator is constructed.
        """
        self.embedder = embedder
        # use the provided mesh names if not None and not an empty list
        self.mesh_names = mesh_names if mesh_names else embedder.mesh_names
        self.logger = logging.getLogger(__name__)
        with PathManager.open(
            "https://dl.fbaipublicfiles.com/densepose/data/cse/mesh_keyvertices_v0.json", "r"
        ) as f:
            # mapping: mesh name -> {keyvertex name -> vertex index}
            self.mesh_keyvertices = json.load(f)

    def evaluate(self):
        """
        Compute alignment metrics over all ordered pairs of distinct meshes.

        Returns:
            A tuple `(ge_mean_global, gps_mean_global, per_mesh_metrics)`:
            the first two are floats averaging the per-mesh values, and
            `per_mesh_metrics` is `{"GE": {mesh: float}, "GPS": {mesh: float}}`
            where each per-mesh value averages over all other meshes.
            With fewer than two meshes there are no pairs to compare and the
            corresponding means are NaN.

        Raises:
            KeyError: if a mesh has no keyvertex annotations, or the second
                mesh of a pair lacks a keyvertex name present in the first.
        """
        ge_per_mesh = {}
        gps_per_mesh = {}
        gps_denominator = 2 * (self.GPS_DISTANCE_SCALE**2)

        # Only Python floats leave this method, so autograd bookkeeping is
        # pure overhead here.
        with torch.no_grad():
            # One embedder call per mesh instead of one per ordered pair;
            # assumes the embedder is deterministic for a given mesh name.
            embeddings = {name: self.embedder(name) for name in self.mesh_names}

            for mesh_name_1 in self.mesh_names:
                avg_errors = []
                avg_gps = []
                embeddings_1 = embeddings[mesh_name_1]
                keyvertices_1 = self.mesh_keyvertices[mesh_name_1]
                keyvertex_names_1 = list(keyvertices_1.keys())
                keyvertex_indices_1 = [keyvertices_1[name] for name in keyvertex_names_1]
                # Embeddings of the keyvertices do not depend on the second mesh.
                keyvertex_embeddings_1 = embeddings_1[keyvertex_indices_1]

                for mesh_name_2 in self.mesh_names:
                    if mesh_name_1 == mesh_name_2:
                        continue
                    embeddings_2 = embeddings[mesh_name_2]
                    keyvertices_2 = self.mesh_keyvertices[mesh_name_2]
                    # Similarity of each keyvertex of mesh 1 to every vertex of mesh 2.
                    sim_matrix_12 = keyvertex_embeddings_1.mm(embeddings_2.T)
                    vertices_2_matching_keyvertices_1 = sim_matrix_12.argmax(axis=1)
                    mesh_2 = create_mesh(mesh_name_2, embeddings_2.device)
                    # Geodesic distance on mesh 2 between each predicted match
                    # and the keyvertex of mesh 2 carrying the same name.
                    geodists = mesh_2.geodists[
                        vertices_2_matching_keyvertices_1,
                        [keyvertices_2[name] for name in keyvertex_names_1],
                    ]
                    gps = (-(geodists**2) / gps_denominator).exp()
                    avg_errors.append(geodists.mean().item())
                    avg_gps.append(gps.mean().item())

                ge_per_mesh[mesh_name_1] = torch.as_tensor(avg_errors).mean().item()
                gps_per_mesh[mesh_name_1] = torch.as_tensor(avg_gps).mean().item()

        ge_mean_global = torch.as_tensor(list(ge_per_mesh.values())).mean().item()
        gps_mean_global = torch.as_tensor(list(gps_per_mesh.values())).mean().item()
        per_mesh_metrics = {
            "GE": ge_per_mesh,
            "GPS": gps_per_mesh,
        }
        return ge_mean_global, gps_mean_global, per_mesh_metrics
|
|