PBWR commited on
Commit
7fdaedc
·
verified ·
1 Parent(s): 0c885dd

Upload evaluate.py

Browse files
Files changed (1) hide show
  1. evaluate.py +333 -0
evaluate.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # _*_ coding: utf-8 _*_
3
+ # ---------------------------------------------------
4
+ # @Time : 2026-03-10 8:58 p.m.
5
+ # @Author : shangfeng
6
+ # @Organization: University of Calgary
7
+ # @File : evaluate.py.py
8
+ # @IDE : PyCharm
9
+ # -----------------Evaluation TASK---------------------
10
+ # Evaluation
11
+ # 1. Chamfer Distance (CD): Measures the geometric discrepancy between the predicted mesh and the ground-truth mesh, reflecting the overall reconstruction accuracy.
12
+ #
13
+ # 2. Edge Chamfer Distance (ECD): Evaluates the geometric similarity between the edges of the reconstructed mesh and those of the ground-truth mesh, serving as an indicator of edge sharpness and structural fidelity.
14
+ #
15
+ # 3. Normal Consistency (NC): Assesses the alignment between surface normals of the predicted mesh and the ground-truth mesh, indicating the consistency of local surface orientation.
16
+ #
17
+ # 4. V_Ratio: Defined as the ratio between the number of vertices in the predicted mesh and that of the ground-truth mesh, reflecting changes in geometric complexity.
18
+ #
19
+ # 5. F_Ratio: Defined as the ratio between the number of faces in the predicted mesh and that of the ground-truth mesh, indicating variations in mesh resolution.
20
+ # ---------------------------------------------------
21
+ import os
22
+ import trimesh
23
+ import numpy as np
24
+ from scipy.spatial import cKDTree
25
+ import faiss
26
+
27
+
28
+ # --------------------------- Load mesh using trimesh and normalization --------------------------------------
29
+ def load_mesh(p_file, gt_file):
30
+ """
31
+ :param p_file:
32
+ :param gt_file:
33
+ :return:
34
+ """
35
+ p_mesh = trimesh.load(p_file)
36
+ gt_mesh = trimesh.load(gt_file)
37
+ return p_mesh, gt_mesh
38
+
39
+ def normalization(p_mesh, gt_mesh):
40
+ gt_vertices = np.asarray(gt_mesh.vertices)
41
+ p_vertices = np.asarray(p_mesh.vertices)
42
+ vert_min = gt_vertices.min(axis=0)
43
+ vert_max = gt_vertices.max(axis=0)
44
+
45
+ vert_center = 0.5 * (vert_min + vert_max)
46
+
47
+ gt_vertices = gt_vertices - vert_center
48
+ # p_vertices = p_vertices - vert_center
49
+
50
+ vert_min = gt_vertices.min(axis=0)
51
+ vert_max = gt_vertices.max(axis=0)
52
+ extents = vert_max - vert_min
53
+ scale = np.max(extents)
54
+
55
+ gt_vertices = gt_vertices / (scale + 1e-6)
56
+ # p_vertices = p_vertices / (scale + 1e-6)
57
+ p_vertices = p_vertices * np.sqrt(np.sum(extents ** 2)) / (scale + 1e-6)
58
+
59
+ return trimesh.Trimesh(vertices=p_vertices,faces=p_mesh.faces), trimesh.Trimesh(vertices=gt_vertices,faces=gt_mesh.faces)
60
+
61
+
62
+ # --------------------------- L1 Chamfer distance --------------------------------------
63
+ def chamfer_l1_distance_kdtree(p, q):
64
+ """
65
+ p: (N,3) prediction
66
+ q: (M,3) ground truth
67
+ """
68
+
69
+ # --- Remove invalid points to ensure numerical stability
70
+ p = p[np.isfinite(p).all(axis=1)]
71
+ q = q[np.isfinite(q).all(axis=1)]
72
+
73
+ # --- KDTree
74
+ tree_p = cKDTree(p)
75
+ tree_q = cKDTree(q)
76
+
77
+ # --- Distance
78
+ dist_pq, _ = tree_q.query(p) # P → Q
79
+ dist_qp, _ = tree_p.query(q) # Q → P
80
+
81
+ # L1 Chamfer Distance
82
+ chamfer_distance = np.mean(dist_pq) + np.mean(dist_qp)
83
+
84
+ return chamfer_distance
85
+
86
+ def chamfer_l1_distance_faiss(p, q, use_gpu=False):
87
+ """
88
+ p: (N,3) prediction
89
+ q: (M,3) ground truth
90
+ """
91
+
92
+ # ---------- 1. remove invalid ----------
93
+ p = p[np.isfinite(p).all(axis=1)]
94
+ q = q[np.isfinite(q).all(axis=1)]
95
+
96
+ # FAISS
97
+ p = p.astype(np.float32)
98
+ q = q.astype(np.float32)
99
+
100
+ # ---------- 2. build index ----------
101
+ index_p = faiss.IndexFlatL2(3) # dim=3
102
+ index_q = faiss.IndexFlatL2(3)
103
+
104
+ # ---------- 3. optional GPU ----------
105
+ if use_gpu:
106
+ res = faiss.StandardGpuResources()
107
+ index_p = faiss.index_cpu_to_gpu(res, 0, index_p)
108
+ index_q = faiss.index_cpu_to_gpu(res, 0, index_q)
109
+
110
+ index_p.add(p)
111
+ index_q.add(q)
112
+
113
+ # ---------- 4. nearest neighbor ----------
114
+ # FAISS return square distance
115
+ D_pq, _ = index_q.search(p, 1) # p → q
116
+ D_qp, _ = index_p.search(q, 1) # q → p
117
+
118
+ # ---------- 5. convert to L1 ----------
119
+ dist_pq = np.sqrt(D_pq[:, 0])
120
+ dist_qp = np.sqrt(D_qp[:, 0])
121
+
122
+ chamfer_distance = dist_pq.mean() + dist_qp.mean()
123
+
124
+ return float(chamfer_distance)
125
+
126
+ # --------------------------- Mesh sampling points --------------------------------------
127
+ def mesh_sample_points(p_mesh, gt_mesh, sample_points=1000000):
128
+ """
129
+ :param p_mesh: trimesh mesh
130
+ :param gt_mesh: Trimesh mesh
131
+ :param sample_points:
132
+ :return: (sample_points, 3)
133
+ """
134
+ p_points = p_mesh.sample(sample_points)
135
+ gt_points = gt_mesh.sample(sample_points)
136
+ return p_points, gt_points
137
+
138
+ # --------------------------- Edge Chamfer L1 Distance --------------------------------------
139
+ def extract_sharp_edges(mesh, angle_threshold_deg=30.0):
140
+ """
141
+ Version-agnostic sharp edge extraction.
142
+ Works with any trimesh version.
143
+ """
144
+ faces = np.asarray(mesh.faces)
145
+ face_normals = np.asarray(mesh.face_normals)
146
+
147
+ # ------------------ normalize normals ----------------------
148
+ face_normals = face_normals / (
149
+ np.linalg.norm(face_normals, axis=1, keepdims=True) + 1e-12
150
+ )
151
+
152
+ # --- Step 1: build edge -> faces mapping ---
153
+ edge_faces = dict()
154
+
155
+ for f_idx, face in enumerate(faces):
156
+ edges = [
157
+ tuple(sorted((face[0], face[1]))),
158
+ tuple(sorted((face[1], face[2]))),
159
+ tuple(sorted((face[2], face[0]))),
160
+ ]
161
+ for e in edges:
162
+ if e not in edge_faces:
163
+ edge_faces[e] = []
164
+ edge_faces[e].append(f_idx)
165
+
166
+ # --- Step 2: detect sharp edges ---
167
+ cos_thresh = np.cos(np.deg2rad(angle_threshold_deg))
168
+ sharp_edges = []
169
+
170
+ for edge, f_list in edge_faces.items():
171
+ # boundary edge → sharp
172
+ if len(f_list) == 1:
173
+ sharp_edges.append(edge)
174
+ continue
175
+
176
+ # non-manifold (>2 faces) → treat as sharp
177
+ if len(f_list) > 2:
178
+ sharp_edges.append(edge)
179
+ continue
180
+
181
+ # exactly two adjacent faces
182
+ f1, f2 = f_list
183
+ n1 = face_normals[f1]
184
+ n2 = face_normals[f2]
185
+
186
+ dot = np.dot(n1, n2)
187
+ dot = np.clip(dot, -1.0, 1.0)
188
+ if np.abs(dot) < cos_thresh:
189
+ sharp_edges.append(edge)
190
+
191
+ if len(sharp_edges) == 0:
192
+ return np.zeros((0, 2), dtype=np.int64)
193
+
194
+ return np.asarray(sharp_edges, dtype=np.int64)
195
+
196
+
197
+ def sample_points_on_edges_global(vertices, edges, total_samples=100000):
198
+ """
199
+ Sample points uniformly along edges, proportional to edge length.
200
+
201
+ Args:
202
+ vertices (np.ndarray): (V, 3)
203
+ edges (np.ndarray): (E, 2)
204
+ total_samples (int): total number of sampled points
205
+
206
+ Returns:
207
+ np.ndarray: (total_samples, 3)
208
+ """
209
+
210
+ if edges.shape[0] == 0:
211
+ return np.zeros((0, 3), dtype=np.float32)
212
+
213
+ # --- 1. Endpoints of edges --------------
214
+ p1 = vertices[edges[:, 0]] # (E, 3)
215
+ p2 = vertices[edges[:, 1]] # (E, 3)
216
+
217
+ # --- 2. Calculate the length of edge --------------
218
+ edge_lengths = np.linalg.norm(p2 - p1, axis=1) # (E,)
219
+
220
+ # --- 3. Calculate probability --------------
221
+ probs = edge_lengths / (edge_lengths.sum() + 1e-12)
222
+
223
+ # --- 4. edge weight --------------
224
+ edge_indices = np.random.choice(len(edges), size=total_samples, p=probs)
225
+
226
+ # --- 5. random points --------------
227
+ t = np.random.rand(total_samples, 1) # (N,1)
228
+
229
+ sampled_p1 = p1[edge_indices]
230
+ sampled_p2 = p2[edge_indices]
231
+
232
+ points = (1 - t) * sampled_p1 + t * sampled_p2
233
+
234
+ return points.astype(np.float32)
235
+
236
+
237
+ def compute_edge_chamfer_distance(p_mesh, gt_mesh, angle_threshold_deg=30.0):
238
+ """
239
+ :param p_mesh:
240
+ :param gt_mesh:
241
+ :param angle_threshold_deg:
242
+ :return:
243
+ """
244
+ # ---------- Extract sharp edges ----------
245
+ sharp_edges_gt = extract_sharp_edges(gt_mesh, angle_threshold_deg)
246
+ sharp_edges_pred = extract_sharp_edges(p_mesh, angle_threshold_deg)
247
+
248
+ # ---------- Sample points on edges ----------
249
+ edge_pts_gt = sample_points_on_edges_global(
250
+ gt_mesh.vertices, sharp_edges_gt
251
+ )
252
+ edge_pts_pred = sample_points_on_edges_global(
253
+ p_mesh.vertices, sharp_edges_pred
254
+ )
255
+
256
+ # ---------- Compute ECD ----------
257
+ ecd = chamfer_l1_distance_kdtree(edge_pts_pred, edge_pts_gt)
258
+
259
+ return ecd
260
+
261
+
262
+
263
+ # --------------------------- Normal Consistency (NC) --------------------------------------
264
+ def normal_consistency(
265
+ p_mesh,
266
+ gt_mesh,
267
+ num_samples=100000
268
+ ):
269
+ """
270
+ mesh_gt, mesh_pred: trimesh.Trimesh
271
+ return: NC in [0, 1]
272
+ """
273
+
274
+ # ---------- 1. sample surface points from GT ----------
275
+ pts_gt, face_ids = trimesh.sample.sample_surface(gt_mesh, num_samples)
276
+ normals_gt = gt_mesh.face_normals[face_ids]
277
+
278
+ # ---------- 2. find closest face on pred mesh---------
279
+ closest_points, distance, face_id = p_mesh.nearest.on_surface(pts_gt)
280
+ normals_pred = p_mesh.face_normals[face_id]
281
+
282
+ # ---------- 3. normalize ----------
283
+ normals_gt = normals_gt / np.linalg.norm(normals_gt, axis=1, keepdims=True)
284
+ normals_pred = normals_pred / np.linalg.norm(normals_pred, axis=1, keepdims=True)
285
+
286
+ # ---------- 4. cosine similarity ----------
287
+ cos_sim = np.abs(np.sum(normals_gt * normals_pred, axis=1))
288
+
289
+ return float(cos_sim.mean())
290
+
291
+ # --------------------------- V_Ratio & F_Ratio --------------------------------------
292
+ def calculate_vertices_face_ratio(p_mesh, gt_mesh):
293
+ """
294
+ :param p_mesh: trimesh.Trimesh
295
+ :param gt_mesh: trimesh.Trimesh
296
+ :return: float, float
297
+ """
298
+ f_ratio = len(p_mesh.faces) / len(gt_mesh.faces)
299
+ v_ratio = len(p_mesh.vertices) / len(gt_mesh.vertices)
300
+ return v_ratio, f_ratio
301
+
302
+
303
+ # --------------------------- Mesh Evaluation For 3rd USM3D ----------------------------
304
+ def mesh_evaluation(p_file, gt_file):
305
+ """
306
+ :param p_file: the path of predicted mesh
307
+ :param gt_file: the path of ground truth mesh
308
+ :return: mesh_chamfer_distance
309
+ """
310
+ # --------------- Load Mesh using trimesh & normalization----------------
311
+ p_mesh, gt_mesh = load_mesh(p_file, gt_file)
312
+ p_mesh, gt_mesh = normalization(p_mesh, gt_mesh)
313
+
314
+ # ----------------------- Mesh Chamfer Distance --------------------------
315
+ p_points, gt_points = mesh_sample_points(p_mesh, gt_mesh)
316
+ mesh_chamfer_distance = chamfer_l1_distance_kdtree(p_points, gt_points)
317
+
318
+ # ---------------------- Edge Chamfer Distance ---------------------------
319
+ edge_chamfer_distance = compute_edge_chamfer_distance(p_mesh, gt_mesh, angle_threshold_deg=30.0)
320
+
321
+ # ---------------------- Normal Consistency --------------------------
322
+ normals_consistency = normal_consistency(p_mesh, gt_mesh)
323
+
324
+ # ---------------------- V_ratio & F_ratio ---------------------------
325
+ v_ratio, f_ratio = calculate_vertices_face_ratio(p_mesh, gt_mesh)
326
+
327
+ return mesh_chamfer_distance, edge_chamfer_distance, normals_consistency, v_ratio, f_ratio
328
+
329
+
330
+ # if __name__ == '__main__':
331
+ # p_file = r'./pred/1a_0.obj'
332
+ # gt_file = r'./gt/1a_0.obj'
333
+ # print(mesh_evaluation(p_file, gt_file))