jacklangerman commited on
Commit
4a7e4e0
1 Parent(s): 2633f6b

update metric

Browse files
Files changed (3) hide show
  1. hoho/vis.py +3 -2
  2. hoho/wed.py +76 -19
  3. requirements.txt +3 -1
hoho/vis.py CHANGED
@@ -133,7 +133,8 @@ def create_image_grid(images, target_length=312, num_per_row=2):
133
  return grid_img
134
 
135
 
136
- import matplotlib
 
137
  def visualize_depth(depth, min_depth=None, max_depth=None, cmap='rainbow'):
138
  depth = np.array(depth)
139
 
@@ -148,7 +149,7 @@ def visualize_depth(depth, min_depth=None, max_depth=None, cmap='rainbow'):
148
  depth = np.clip(depth, 0, 1)
149
 
150
  # Use the matplotlib colormap to convert the depth to an RGB image
151
- cmap = matplotlib.cm.get_cmap(cmap)
152
  depth_image = (cmap(depth) * 255).astype(np.uint8)
153
 
154
  # Convert the depth image to a PIL image
 
133
  return grid_img
134
 
135
 
136
+ import matplotlib.pyplot as plt
137
+
138
  def visualize_depth(depth, min_depth=None, max_depth=None, cmap='rainbow'):
139
  depth = np.array(depth)
140
 
 
149
  depth = np.clip(depth, 0, 1)
150
 
151
  # Use the matplotlib colormap to convert the depth to an RGB image
152
+ cmap = plt.get_cmap(cmap)
153
  depth_image = (cmap(depth) * 255).astype(np.uint8)
154
 
155
  # Convert the depth image to a PIL image
hoho/wed.py CHANGED
@@ -6,40 +6,96 @@ import numpy as np
6
  def zeromean_normalize(vertices):
7
  vertices = np.array(vertices)
8
  vertices = vertices - vertices.mean(axis=0)
9
- vertices = vertices / (1e-6 + np.linalg.norm(vertices, axis=1)[:, None])
10
  return vertices
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=1.0, ce=1.0, normalized=True, squared=False):
 
14
  pd_vertices = np.array(pd_vertices)
15
  gt_vertices = np.array(gt_vertices)
16
- pd_vertices = zeromean_normalize(pd_vertices)
17
- gt_vertices = zeromean_normalize(gt_vertices)
 
 
 
 
 
 
 
18
 
19
  pd_edges = np.array(pd_edges)
20
  gt_edges = np.array(gt_edges)
21
 
22
- # Step 1: Bipartite Matching
23
- if squared:
24
- distances = cdist(pd_vertices, gt_vertices, metric='sqeuclidean')
25
- else:
26
- distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  row_ind, col_ind = linear_sum_assignment(distances)
29
- # Step 2: Vertex Translation
30
 
31
- if squared:
32
- translation_costs = cv * np.sqrt(np.sum(distances[row_ind, col_ind]))
33
- else:
34
- translation_costs = cv * np.sum(distances[row_ind, col_ind])
35
 
36
  # Additional: Vertex Deletion
37
  unmatched_pd_indices = set(range(len(pd_vertices))) - set(row_ind)
38
- deletion_costs = cv * len(unmatched_pd_indices) # Assuming a fixed cost for vertex deletion
39
 
40
  # Step 3: Vertex Insertion
41
  unmatched_gt_indices = set(range(len(gt_vertices))) - set(col_ind)
42
- insertion_costs = cv * len(unmatched_gt_indices) # Assuming a fixed cost for vertex insertion
43
 
44
  # Step 4: Edge Deletion and Insertion
45
  updated_pd_edges = [(col_ind[np.where(row_ind == edge[0])[0][0]], col_ind[np.where(row_ind == edge[1])[0][0]]) for edge in pd_edges if edge[0] in row_ind and edge[1] in row_ind]
@@ -61,11 +117,12 @@ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=1.0, ce=1.0, no
61
 
62
  # Step 5: Calculation of WED
63
  WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs
64
- print ("translation_costs, deletion_costs, insertion_costs, deletion_edge_costs, insertion_edge_costs")
65
- print (translation_costs, deletion_costs, insertion_costs, deletion_edge_costs, insertion_edge_costs)
66
 
67
  if normalized:
68
  total_length_of_gt_edges = np.linalg.norm((gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]]), axis=1).sum()
69
  WED = WED / total_length_of_gt_edges
70
- print ("Total length", total_length_of_gt_edges)
 
71
  return WED
 
6
  def zeromean_normalize(vertices):
7
  vertices = np.array(vertices)
8
  vertices = vertices - vertices.mean(axis=0)
9
+ vertices = vertices / (1e-6 + np.linalg.norm(vertices, axis=1)[:, None]) # project all verts to sphere (not what we meant)
10
  return vertices
11
 
12
+ def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
13
+ mu_target = target_verts.mean(axis=0)
14
+ mu_in = verts_to_transform.mean(axis=0)
15
+ std_target = np.std(target_verts, axis=0)
16
+ std_in = np.std(verts_to_transform, axis=0)
17
+
18
+ if np.any(std_in == 0):
19
+ std_in[std_in == 0] = 1
20
+ if np.any(std_target == 0):
21
+ std_target[std_target == 0] = 1
22
+ if np.any(np.isnan(std_in)):
23
+ std_in[np.isnan(std_in)] = 1
24
+ if np.any(np.isnan(std_target)):
25
+ std_target[np.isnan(std_target)] = 1
26
+
27
+ if single_scale:
28
+ std_target = np.linalg.norm(std_target)
29
+ std_in = np.linalg.norm(std_in)
30
+
31
+ transformed_verts = (verts_to_transform - mu_in) / std_in
32
+ transformed_verts = transformed_verts * std_target + mu_target
33
+
34
+ return transformed_verts
35
 
36
+
37
+ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=100.0, ce=1.0, normalized=True, prenorm=False, preregister=True, register=True, single_scale=True):
38
  pd_vertices = np.array(pd_vertices)
39
  gt_vertices = np.array(gt_vertices)
40
+
41
+ # Step 0: Prenormalize / preregister
42
+ if prenorm:
43
+ pd_vertices = zeromean_normalize(pd_vertices)
44
+ gt_vertices = zeromean_normalize(gt_vertices)
45
+
46
+ if preregister:
47
+ pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)
48
+
49
 
50
  pd_edges = np.array(pd_edges)
51
  gt_edges = np.array(gt_edges)
52
 
 
 
 
 
 
53
 
54
+ # Step 0.5: Register
55
+ if register:
56
+ # find the optimal rotation, translation, and scale
57
+ from scipy.spatial.transform import Rotation as R
58
+ from scipy.optimize import minimize
59
+
60
+ def transform(x, pd_vertices):
61
+ # x is a 7-element vector, first 3 elements are the rotation vector, next 3 elements are the translation vector, finally scale
62
+ rotation = R.from_rotvec(x[:3])
63
+ translation = x[3:6]
64
+ scale = x[6]
65
+ return scale * rotation.apply(pd_vertices) + translation
66
+
67
+ def cost_function(x, pd_vertices, gt_vertices):
68
+ pd_vertices_transformed = transform(x, pd_vertices)
69
+ distances = cdist(pd_vertices_transformed, gt_vertices, metric='euclidean')
70
+ row_ind, col_ind = linear_sum_assignment(distances)
71
+ translation_costs = np.sum(distances[row_ind, col_ind])
72
+
73
+ return translation_costs
74
+
75
+ x0 = np.array([0, 0, 0, 0, 0, 0, 1])
76
+ # minimize subject to scale > 1e-6
77
+ # res = minimize(cost_function, x0, args=(pd_vertices, gt_vertices), constraints={'type': 'ineq', 'fun': lambda x: x[6] - 1e-6})
78
+ res = minimize(cost_function, x0, args=(pd_vertices, gt_vertices), bounds=[(-np.pi, np.pi), (-np.pi, np.pi), (-np.pi, np.pi), (-500, 500), (-500, 500), (-500, 500), (0.1, 3)])
79
+ # print("scale:", res.x)
80
+
81
+ pd_vertices = transform(res.x, pd_vertices)
82
+
83
+
84
+ # Step 1: Bipartite Matching
85
+ distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
86
  row_ind, col_ind = linear_sum_assignment(distances)
87
+
88
 
89
+ # Step 2: Vertex Translation
90
+ translation_costs = np.sum(distances[row_ind, col_ind])
 
 
91
 
92
  # Additional: Vertex Deletion
93
  unmatched_pd_indices = set(range(len(pd_vertices))) - set(row_ind)
94
+ deletion_costs = cv * len(unmatched_pd_indices)
95
 
96
  # Step 3: Vertex Insertion
97
  unmatched_gt_indices = set(range(len(gt_vertices))) - set(col_ind)
98
+ insertion_costs = cv * len(unmatched_gt_indices)
99
 
100
  # Step 4: Edge Deletion and Insertion
101
  updated_pd_edges = [(col_ind[np.where(row_ind == edge[0])[0][0]], col_ind[np.where(row_ind == edge[1])[0][0]]) for edge in pd_edges if edge[0] in row_ind and edge[1] in row_ind]
 
117
 
118
  # Step 5: Calculation of WED
119
  WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs
120
+ # print("translation_costs, deletion_costs, insertion_costs, deletion_edge_costs, insertion_edge_costs")
121
+ # print(translation_costs, deletion_costs, insertion_costs, deletion_edge_costs, insertion_edge_costs)
122
 
123
  if normalized:
124
  total_length_of_gt_edges = np.linalg.norm((gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]]), axis=1).sum()
125
  WED = WED / total_length_of_gt_edges
126
+
127
+ # print ("Total length", total_length_of_gt_edges)
128
  return WED
requirements.txt CHANGED
@@ -3,4 +3,6 @@ pillow
3
  webdataset
4
  trimesh
5
  scipy
6
- datasets
 
 
 
3
  webdataset
4
  trimesh
5
  scipy
6
+ datasets
7
+ ipywidgets
8
+ matplotlib