PBWR's picture
Upload evaluate.py
7fdaedc verified
#!/usr/bin/python3
# _*_ coding: utf-8 _*_
# ---------------------------------------------------
# @Time : 2026-03-10 8:58 p.m.
# @Author : shangfeng
# @Organization: University of Calgary
# @File : evaluate.py
# @IDE : PyCharm
# -----------------Evaluation TASK---------------------
# Evaluation
# 1. Chamfer Distance (CD): Measures the geometric discrepancy between the predicted mesh and the ground-truth mesh, reflecting the overall reconstruction accuracy.
#
# 2. Edge Chamfer Distance (ECD): Evaluates the geometric similarity between the edges of the reconstructed mesh and those of the ground-truth mesh, serving as an indicator of edge sharpness and structural fidelity.
#
# 3. Normal Consistency (NC): Assesses the alignment between surface normals of the predicted mesh and the ground-truth mesh, indicating the consistency of local surface orientation.
#
# 4. V_Ratio: Defined as the ratio between the number of vertices in the predicted mesh and that of the ground-truth mesh, reflecting changes in geometric complexity.
#
# 5. F_Ratio: Defined as the ratio between the number of faces in the predicted mesh and that of the ground-truth mesh, indicating variations in mesh resolution.
# ---------------------------------------------------
import os
import trimesh
import numpy as np
from scipy.spatial import cKDTree
import faiss
# --------------------------- Load mesh using trimesh and normalization --------------------------------------
def load_mesh(p_file, gt_file):
    """
    Read the predicted and ground-truth meshes from disk with trimesh.

    :param p_file: path to the predicted mesh file
    :param gt_file: path to the ground-truth mesh file
    :return: tuple ``(p_mesh, gt_mesh)``, each exactly what ``trimesh.load``
             returned for the corresponding path (prediction loaded first)

    NOTE(review): depending on the file contents ``trimesh.load`` may return a
    Scene instead of a single Trimesh; downstream code reads ``.vertices`` /
    ``.faces`` directly, so confirm inputs are single-mesh files.
    """
    return tuple(map(trimesh.load, (p_file, gt_file)))
def normalization(p_mesh, gt_mesh):
    """
    Bring the predicted and ground-truth meshes into a common normalized frame.

    Ground truth: translated so the centre of its axis-aligned bounding box
    is at the origin, then divided by (largest bbox extent + 1e-6), so it fits
    in a cube of side ~1 centred at the origin.

    Prediction: NOT translated. Its vertices are multiplied by the GT bbox
    diagonal length and divided by (largest GT bbox extent + 1e-6).

    NOTE(review): the asymmetric treatment presumably assumes predictions are
    delivered already centred on the GT bbox centre and scaled to a unit bbox
    diagonal, so the factor diagonal / max_extent maps them into the GT's
    normalized frame -- confirm against the submission format.
    NOTE(review): both meshes are rebuilt with trimesh.Trimesh's default
    constructor options; per trimesh's docs that processes (merges) vertices,
    which can affect the vertex counts later used for V_Ratio -- verify.

    :param p_mesh: predicted mesh (trimesh.Trimesh)
    :param gt_mesh: ground-truth mesh (trimesh.Trimesh)
    :return: ``(p_mesh_normalized, gt_mesh_normalized)`` as new trimesh.Trimesh objects
    """
    gt_pts = np.asarray(gt_mesh.vertices)
    pred_pts = np.asarray(p_mesh.vertices)

    # Centre the GT on the middle of its axis-aligned bounding box.
    bbox_mid = 0.5 * (gt_pts.min(axis=0) + gt_pts.max(axis=0))
    gt_centred = gt_pts - bbox_mid

    # Bounding-box size of the centred GT: per-axis extents, largest extent, diagonal.
    extents = gt_centred.max(axis=0) - gt_centred.min(axis=0)
    max_extent = np.max(extents)
    diagonal = np.sqrt(np.sum(extents ** 2))
    denom = max_extent + 1e-6  # epsilon guards against a zero-size GT

    gt_scaled = gt_centred / denom
    pred_scaled = pred_pts * diagonal / denom

    return (
        trimesh.Trimesh(vertices=pred_scaled, faces=p_mesh.faces),
        trimesh.Trimesh(vertices=gt_scaled, faces=gt_mesh.faces),
    )
# --------------------------- L1 Chamfer distance --------------------------------------
def chamfer_l1_distance_kdtree(p, q):
    """
    Symmetric "L1" Chamfer distance between two point sets, using KD-trees.

    "L1" follows the usual naming in reconstruction benchmarks: the
    nearest-neighbour distances are plain (non-squared) Euclidean distances,
    averaged in each direction and summed:

        CD(P, Q) = mean_{p in P} min_{q in Q} ||p - q||
                 + mean_{q in Q} min_{p in P} ||q - p||

    Rows containing NaN/inf are dropped before the distances are computed.

    :param p: (N, D) array-like, prediction points (D = 3 in this file)
    :param q: (M, D) array-like, ground-truth points
    :return: float Chamfer distance. 0.0 when both sets are empty after
             filtering; inf when exactly one is (the distance from a point
             to an empty set is unbounded).
    """
    p = np.asarray(p)
    q = np.asarray(q)

    # Remove invalid points so they cannot poison the nearest-neighbour queries.
    p = p[np.isfinite(p).all(axis=1)]
    q = q[np.isfinite(q).all(axis=1)]

    # Decide the degenerate cases explicitly instead of averaging empty arrays
    # (which yields NaN plus a "Mean of empty slice" RuntimeWarning).
    if len(p) == 0 and len(q) == 0:
        return 0.0
    if len(p) == 0 or len(q) == 0:
        return float("inf")

    dist_pq, _ = cKDTree(q).query(p)  # P -> Q
    dist_qp, _ = cKDTree(p).query(q)  # Q -> P
    return float(np.mean(dist_pq) + np.mean(dist_qp))
def chamfer_l1_distance_faiss(p, q, use_gpu=False):
    """
    FAISS-backed symmetric "L1" Chamfer distance (same definition as
    ``chamfer_l1_distance_kdtree``): mean nearest-neighbour Euclidean distance
    P -> Q plus mean nearest-neighbour Euclidean distance Q -> P.

    Rows containing NaN/inf are dropped first.

    :param p: (N, D) array-like, prediction points (D = 3 in this file)
    :param q: (M, D) array-like, ground-truth points
    :param use_gpu: move the exact (flat) indices to GPU 0 via FAISS
    :return: float Chamfer distance. 0.0 when both sets are empty after
             filtering; inf when exactly one is.
    """
    p = np.asarray(p)
    q = np.asarray(q)

    # ---------- 1. remove invalid ----------
    p = p[np.isfinite(p).all(axis=1)]
    q = q[np.isfinite(q).all(axis=1)]

    # Degenerate inputs: searching an empty FAISS index returns sentinel
    # distances rather than failing, so decide these cases explicitly.
    if len(p) == 0 and len(q) == 0:
        return 0.0
    if len(p) == 0 or len(q) == 0:
        return float("inf")

    # FAISS expects C-contiguous float32 arrays.
    p = np.ascontiguousarray(p, dtype=np.float32)
    q = np.ascontiguousarray(q, dtype=np.float32)
    dim = p.shape[1]

    # ---------- 2. build exact L2 indices ----------
    index_p = faiss.IndexFlatL2(dim)
    index_q = faiss.IndexFlatL2(dim)

    # ---------- 3. optional GPU ----------
    if use_gpu:
        res = faiss.StandardGpuResources()
        index_p = faiss.index_cpu_to_gpu(res, 0, index_p)
        index_q = faiss.index_cpu_to_gpu(res, 0, index_q)

    index_p.add(p)
    index_q.add(q)

    # ---------- 4. nearest neighbour (FAISS returns SQUARED distances) ----------
    sq_pq, _ = index_q.search(p, 1)  # p -> q
    sq_qp, _ = index_p.search(q, 1)  # q -> p

    # ---------- 5. back to plain Euclidean distances ----------
    # Clamp defensively so any negative round-off cannot turn into NaN under sqrt.
    dist_pq = np.sqrt(np.maximum(sq_pq[:, 0], 0.0))
    dist_qp = np.sqrt(np.maximum(sq_qp[:, 0], 0.0))

    # Accumulate in float64: the inputs can hold ~1e6 float32 values.
    return float(np.mean(dist_pq, dtype=np.float64) + np.mean(dist_qp, dtype=np.float64))
# --------------------------- Mesh sampling points --------------------------------------
def mesh_sample_points(p_mesh, gt_mesh, sample_points=1000000):
    """
    Draw the same number of surface samples from the predicted and GT meshes.

    The prediction is sampled first, then the ground truth.

    :param p_mesh: predicted mesh (trimesh.Trimesh); only ``.sample(n)`` is used
    :param gt_mesh: ground-truth mesh (trimesh.Trimesh); only ``.sample(n)`` is used
    :param sample_points: number of points drawn from each mesh
    :return: tuple ``(p_points, gt_points)`` as returned by ``.sample``,
             presumably each of shape (sample_points, 3)
    """
    return tuple(mesh.sample(sample_points) for mesh in (p_mesh, gt_mesh))
# --------------------------- Edge Chamfer L1 Distance --------------------------------------
def extract_sharp_edges(mesh, angle_threshold_deg=30.0):
    """
    Collect the "sharp" edges of a triangle mesh with plain NumPy / dict logic
    (no reliance on version-specific trimesh adjacency helpers).

    An undirected edge is reported when any of the following holds:
      * it belongs to exactly one face (boundary edge);
      * it is shared by more than two faces (non-manifold edge);
      * it is shared by exactly two faces whose unit normals satisfy
        |n1 . n2| < cos(angle_threshold_deg).

    The absolute value makes the test insensitive to the relative orientation
    of the two normals: anti-parallel normals count as smooth, so inconsistent
    winding does not create spurious sharp edges -- but a fold of ~180 degrees
    is not reported either.

    :param mesh: object exposing ``faces`` (F, 3) vertex indices and
                 ``face_normals`` (F, 3), e.g. a trimesh.Trimesh
    :param angle_threshold_deg: dihedral-angle threshold in degrees
    :return: (E, 2) int64 array of vertex-index pairs (smaller index first),
             ordered by first appearance while scanning the faces;
             shape (0, 2) when no edge qualifies
    """
    tri = np.asarray(mesh.faces)
    normals = np.asarray(mesh.face_normals)
    # Re-normalise defensively; the epsilon keeps zero-length normals finite.
    normals = normals / (np.linalg.norm(normals, axis=1, keepdims=True) + 1e-12)

    # Undirected edge (lo, hi) -> indices of incident faces, in first-seen order.
    incident = {}
    for face_id, row in enumerate(tri):
        a, b, c = row[0], row[1], row[2]
        for u, v in ((a, b), (b, c), (c, a)):
            key = (u, v) if u <= v else (v, u)
            incident.setdefault(key, []).append(face_id)

    cos_limit = np.cos(np.deg2rad(angle_threshold_deg))

    def _is_sharp(face_ids):
        # Boundary (1 face) and non-manifold (>2 faces) edges always count.
        if len(face_ids) != 2:
            return True
        cosine = np.clip(np.dot(normals[face_ids[0]], normals[face_ids[1]]), -1.0, 1.0)
        return np.abs(cosine) < cos_limit

    sharp = [edge for edge, face_ids in incident.items() if _is_sharp(face_ids)]
    if not sharp:
        return np.zeros((0, 2), dtype=np.int64)
    return np.asarray(sharp, dtype=np.int64)
def sample_points_on_edges_global(vertices, edges, total_samples=100000, rng=None):
    """
    Sample points on a set of edges, choosing edges proportionally to length.

    Each sample first picks an edge with probability length / total_length and
    then a uniform position along it, so the points are uniform over the union
    of the edge segments. If every edge has zero length, edges are picked
    uniformly instead (all samples then coincide with edge endpoints).

    Args:
        vertices (array-like): (V, 3) vertex positions
        edges (array-like): (E, 2) vertex-index pairs
        total_samples (int): number of points to draw
        rng: optional ``np.random.Generator`` or ``np.random.RandomState`` for
            reproducible sampling; ``None`` uses NumPy's global RNG (``np.random``).
    Returns:
        np.ndarray: (total_samples, 3) float32; (0, 3) when there are no edges.
    Raises:
        ValueError: from ``choice`` if the edge lengths are not finite.
    """
    edges = np.asarray(edges)
    if edges.shape[0] == 0:
        return np.zeros((0, 3), dtype=np.float32)
    vertices = np.asarray(vertices)
    random = np.random if rng is None else rng

    # --- 1. Endpoints of edges --------------
    start = vertices[edges[:, 0]]  # (E, 3)
    end = vertices[edges[:, 1]]  # (E, 3)

    # --- 2. Length-proportional edge probabilities --------------
    lengths = np.linalg.norm(end - start, axis=1)  # (E,)
    total_length = lengths.sum()
    if total_length == 0:
        # All edges degenerate: an all-zero probability vector would make
        # choice() raise, so fall back to a uniform pick.
        probs = None
    else:
        probs = lengths / total_length

    # --- 3. Pick edges, then a uniform position along each --------------
    edge_ids = random.choice(len(edges), size=total_samples, p=probs)
    t = random.random((total_samples, 1))  # (N, 1) in [0, 1)
    points = (1 - t) * start[edge_ids] + t * end[edge_ids]
    return points.astype(np.float32)
def compute_edge_chamfer_distance(p_mesh, gt_mesh, angle_threshold_deg=30.0, num_samples=100000):
    """
    Edge Chamfer Distance (ECD) between the sharp edges of two meshes.

    Sharp edges are extracted from both meshes with ``extract_sharp_edges``.
    ``num_samples`` points are then sampled along each mesh's sharp edges
    (length-weighted, GT first, then prediction). The result is the L1 Chamfer
    distance between the two point sets.

    :param p_mesh: predicted mesh (trimesh.Trimesh)
    :param gt_mesh: ground-truth mesh (trimesh.Trimesh)
    :param angle_threshold_deg: dihedral-angle threshold for "sharp", in degrees
    :param num_samples: number of points sampled on each mesh's sharp edges
    :return: ECD as a float. It is 0.0 when neither mesh has sharp edges and
             inf when exactly one of them has none.
    """
    # ---------- Extract sharp edges ----------
    sharp_edges_gt = extract_sharp_edges(gt_mesh, angle_threshold_deg)
    sharp_edges_pred = extract_sharp_edges(p_mesh, angle_threshold_deg)

    # Handle the degenerate cases here. Otherwise the Chamfer step would
    # average empty arrays and return NaN.
    if len(sharp_edges_gt) == 0 and len(sharp_edges_pred) == 0:
        return 0.0
    if len(sharp_edges_gt) == 0 or len(sharp_edges_pred) == 0:
        return float("inf")

    # ---------- Sample points on edges ----------
    edge_pts_gt = sample_points_on_edges_global(
        gt_mesh.vertices, sharp_edges_gt, total_samples=num_samples
    )
    edge_pts_pred = sample_points_on_edges_global(
        p_mesh.vertices, sharp_edges_pred, total_samples=num_samples
    )

    # ---------- Compute ECD ----------
    return chamfer_l1_distance_kdtree(edge_pts_pred, edge_pts_gt)
# --------------------------- Normal Consistency (NC) --------------------------------------
def normal_consistency(p_mesh, gt_mesh, num_samples=100000):
    """
    Normal Consistency (NC) from the ground truth towards the prediction.

    ``num_samples`` points are sampled on the GT surface. For each point, the
    closest point on the predicted surface is found, and the absolute cosine
    between the GT face normal and the predicted face normal is taken. NC is
    the mean of those values.

    The measure is one-directional (GT -> prediction). Because it uses |cos|,
    flipped face orientation is not penalised.

    :param p_mesh: predicted mesh (trimesh.Trimesh)
    :param gt_mesh: ground-truth mesh (trimesh.Trimesh)
    :param num_samples: number of surface points sampled on ``gt_mesh``
    :return: NC in [0, 1]; higher is better
    """
    # ---------- 1. sample surface points (and their faces) on the GT ----------
    pts_gt, gt_face_ids = trimesh.sample.sample_surface(gt_mesh, num_samples)
    normals_gt = gt_mesh.face_normals[gt_face_ids]

    # ---------- 2. closest face on the predicted mesh ----------
    _, _, pred_face_ids = p_mesh.nearest.on_surface(pts_gt)
    normals_pred = p_mesh.face_normals[pred_face_ids]

    # ---------- 3. normalize ----------
    # The epsilon matches extract_sharp_edges. A zero normal (e.g. from a
    # degenerate face) would otherwise give 0/0 = NaN and make the whole mean
    # NaN. With the epsilon, such a sample simply scores 0.
    normals_gt = normals_gt / (np.linalg.norm(normals_gt, axis=1, keepdims=True) + 1e-12)
    normals_pred = normals_pred / (np.linalg.norm(normals_pred, axis=1, keepdims=True) + 1e-12)

    # ---------- 4. orientation-agnostic cosine similarity ----------
    cos_sim = np.abs(np.sum(normals_gt * normals_pred, axis=1))
    return float(cos_sim.mean())
# --------------------------- V_Ratio & F_Ratio --------------------------------------
def calculate_vertices_face_ratio(p_mesh, gt_mesh):
    """
    Size ratios of the predicted mesh relative to the ground truth.

    :param p_mesh: predicted mesh (trimesh.Trimesh; only ``len`` of
                   ``faces`` / ``vertices`` is used)
    :param gt_mesh: ground-truth mesh (trimesh.Trimesh)
    :return: ``(v_ratio, f_ratio)``, where v_ratio = #pred vertices / #GT vertices
             and f_ratio = #pred faces / #GT faces
    :raises ZeroDivisionError: if the ground truth has no faces or no vertices
    """
    pred_faces = len(p_mesh.faces)
    gt_faces = len(gt_mesh.faces)
    f_ratio = pred_faces / gt_faces

    pred_verts = len(p_mesh.vertices)
    gt_verts = len(gt_mesh.vertices)
    v_ratio = pred_verts / gt_verts

    return v_ratio, f_ratio
# --------------------------- Mesh Evaluation For 3rd USM3D ----------------------------
def mesh_evaluation(p_file, gt_file, sample_points=1000000, angle_threshold_deg=30.0, nc_samples=100000):
    """
    Run the full evaluation for one predicted / ground-truth mesh pair.

    Both meshes are loaded and passed through ``normalization``. The metrics
    below are then computed on the normalized meshes:
      1. CD      - L1 Chamfer distance between ``sample_points`` surface
                   samples of each mesh
      2. ECD     - Chamfer distance between points sampled on the sharp edges
                   (dihedral threshold ``angle_threshold_deg``)
      3. NC      - normal consistency from GT towards prediction, using
                   ``nc_samples`` GT surface samples
      4. V_Ratio - #pred vertices / #GT vertices
      5. F_Ratio - #pred faces / #GT faces

    No random seed is set here, so the sampling-based metrics (CD, ECD, NC)
    may vary slightly between runs.

    :param p_file: path of the predicted mesh
    :param gt_file: path of the ground-truth mesh
    :param sample_points: surface samples per mesh for CD
    :param angle_threshold_deg: sharp-edge threshold for ECD, in degrees
    :param nc_samples: GT surface samples for NC
    :return: tuple ``(mesh_chamfer_distance, edge_chamfer_distance,
             normals_consistency, v_ratio, f_ratio)``
    """
    # --------------- Load Mesh using trimesh & normalization ----------------
    p_mesh, gt_mesh = load_mesh(p_file, gt_file)
    p_mesh, gt_mesh = normalization(p_mesh, gt_mesh)

    # ----------------------- Mesh Chamfer Distance --------------------------
    p_points, gt_points = mesh_sample_points(p_mesh, gt_mesh, sample_points=sample_points)
    mesh_chamfer_distance = chamfer_l1_distance_kdtree(p_points, gt_points)

    # ---------------------- Edge Chamfer Distance ---------------------------
    edge_chamfer_distance = compute_edge_chamfer_distance(
        p_mesh, gt_mesh, angle_threshold_deg=angle_threshold_deg
    )

    # ---------------------- Normal Consistency ------------------------------
    normals_consistency = normal_consistency(p_mesh, gt_mesh, num_samples=nc_samples)

    # ---------------------- V_ratio & F_ratio -------------------------------
    v_ratio, f_ratio = calculate_vertices_face_ratio(p_mesh, gt_mesh)

    return mesh_chamfer_distance, edge_chamfer_distance, normals_consistency, v_ratio, f_ratio
# if __name__ == '__main__':
# p_file = r'./pred/1a_0.obj'
# gt_file = r'./gt/1a_0.obj'
# print(mesh_evaluation(p_file, gt_file))