Added a fix to the metric: corrected an index mismatch and added zero-mean normalization
2633f6b
from scipy.spatial.distance import cdist | |
from scipy.optimize import linear_sum_assignment | |
import numpy as np | |
def zeromean_normalize(vertices):
    """Center a point set on its centroid and rescale it to unit RMS radius.

    A single scale factor is applied to the whole set, so relative geometry
    (ratios of distances, collinearity) is preserved. Normalizing each row by
    its own norm would project every vertex onto the unit sphere and make
    vertices lying on the same ray from the centroid indistinguishable.

    Args:
        vertices: array-like of shape (n, d), one point per row.

    Returns:
        A float ndarray of shape (n, d) with zero column means and a
        root-mean-square distance to the origin of 1 (up to the 1e-6 guard
        against division by zero). Empty input is returned as an empty float
        array.
    """
    vertices = np.asarray(vertices, dtype=float)
    if vertices.size == 0:
        return vertices
    centered = vertices - vertices.mean(axis=0)
    # RMS distance of the centred points to the origin: one scale for the set.
    scale = np.sqrt(np.mean(np.sum(centered ** 2, axis=1)))
    # 1e-6 keeps degenerate sets (e.g. a single vertex) finite: they map to 0.
    return centered / (1e-6 + scale)
def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=1.0, ce=1.0, normalized=True, squared=False, verbose=False):
    """Compute the Wireframe Edit Distance between a predicted and a ground-truth wireframe.

    Both vertex sets are first normalized independently with
    ``zeromean_normalize``; all costs below are measured in those normalized
    coordinates. Predicted vertices are matched one-to-one to ground-truth
    vertices with ``linear_sum_assignment`` on their pairwise distances. The
    distance is the sum of:

    * vertex translation: ``cv`` times the summed matched distance (when
      ``squared``, ``cv`` times the square root of the summed squared
      distances);
    * vertex deletion: ``cv`` per unmatched predicted vertex;
    * vertex insertion: ``cv`` per unmatched ground-truth vertex;
    * edge deletion: ``ce`` times the predicted length of every predicted edge
      (both endpoints matched) whose image under the matching is not a
      ground-truth edge;
    * edge insertion: ``ce`` times the ground-truth length of every
      ground-truth edge that no mapped predicted edge produces.

    Edges are undirected. Predicted edges touching an unmatched predicted
    vertex are ignored (not charged).

    Args:
        pd_vertices: predicted vertex coordinates, one row per vertex.
        pd_edges: predicted edges as pairs of indices into ``pd_vertices``.
        gt_vertices: ground-truth vertex coordinates, one row per vertex.
        gt_edges: ground-truth edges as pairs of indices into ``gt_vertices``.
        cv: weight applied to vertex translation/deletion/insertion costs.
        ce: weight applied to edge deletion/insertion costs.
        normalized: if True, divide the result by the total length of
            ``gt_edges`` (in normalized coordinates). With no ground-truth
            edges, or only zero-length ones, this yields inf/nan and NumPy
            emits a RuntimeWarning.
        squared: if True, match on squared Euclidean distances.
        verbose: if True, print the cost breakdown (and the normalizing
            length when ``normalized``).

    Returns:
        The (optionally normalized) edit distance as a NumPy scalar.
    """
    pd_vertices = zeromean_normalize(pd_vertices)
    gt_vertices = zeromean_normalize(gt_vertices)
    # (n, 2) integer arrays, so empty edge lists still index cleanly below.
    pd_edges = np.asarray(pd_edges, dtype=int).reshape(-1, 2)
    gt_edges = np.asarray(gt_edges, dtype=int).reshape(-1, 2)

    # Step 1: bipartite matching of predicted to ground-truth vertices.
    metric = 'sqeuclidean' if squared else 'euclidean'
    distances = cdist(pd_vertices, gt_vertices, metric=metric)
    row_ind, col_ind = linear_sum_assignment(distances)
    # Explicit vertex maps in both directions. row_ind is only arange(n_pd)
    # when n_pd <= n_gt, so positions in col_ind are NOT predicted indices.
    pd_to_gt = dict(zip(row_ind.tolist(), col_ind.tolist()))
    gt_to_pd = dict(zip(col_ind.tolist(), row_ind.tolist()))

    # Step 2: vertex translation cost of the matched pairs.
    matched_cost = np.sum(distances[row_ind, col_ind])
    if squared:
        translation_costs = cv * np.sqrt(matched_cost)
    else:
        translation_costs = cv * matched_cost

    # Step 3: fixed cost per unmatched predicted (deleted) and ground-truth
    # (inserted) vertex; the matching is one-to-one, so unmatched = total - matched.
    deletion_costs = cv * (len(pd_vertices) - len(row_ind))
    insertion_costs = cv * (len(gt_vertices) - len(col_ind))

    # Step 4: edge deletion and insertion. Undirected edges are keyed as sorted
    # (low, high) tuples; tuple(set(edge)) is not canonical because set
    # iteration order depends on insertion order when hashes collide.
    # Predicted edges with an unmatched endpoint are dropped here, uncharged.
    pd_edges_set = {
        tuple(sorted((pd_to_gt[a], pd_to_gt[b])))
        for a, b in pd_edges.tolist()
        if a in pd_to_gt and b in pd_to_gt
    }
    gt_edges_set = {tuple(sorted(edge)) for edge in gt_edges.tolist()}

    # Mapped predicted edges absent from the ground truth, charged by their
    # length in the prediction: map GT indices back to the matched PD vertices.
    edges_to_delete = pd_edges_set - gt_edges_set
    deletion_edge_costs = ce * sum(
        np.linalg.norm(pd_vertices[gt_to_pd[u]] - pd_vertices[gt_to_pd[v]])
        for u, v in edges_to_delete
    )

    # Ground-truth edges missing from the mapped prediction, charged by GT length.
    edges_to_insert = gt_edges_set - pd_edges_set
    insertion_edge_costs = ce * sum(
        np.linalg.norm(gt_vertices[u] - gt_vertices[v])
        for u, v in edges_to_insert
    )

    # Step 5: total edit distance.
    WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs
    if verbose:
        print("translation_costs, deletion_costs, insertion_costs, deletion_edge_costs, insertion_edge_costs")
        print(translation_costs, deletion_costs, insertion_costs, deletion_edge_costs, insertion_edge_costs)

    if normalized:
        total_length_of_gt_edges = np.linalg.norm(
            gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]], axis=1
        ).sum()
        WED = WED / total_length_of_gt_edges
        if verbose:
            print("Total length", total_length_of_gt_edges)
    return WED