jacklangerman commited on
Commit
3150ada
1 Parent(s): 91b350c

separate cv_ins and cv_del + bump to 0.4

Browse files
Files changed (2) hide show
  1. hoho/wed.py +16 -10
  2. setup.py +1 -1
hoho/wed.py CHANGED
@@ -28,7 +28,17 @@ def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
28
  return transformed_verts
29
 
30
 
31
- def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=-1/4, ce=1.0, normalized=True, preregister=True, single_scale=True):
 
 
 
 
 
 
 
 
 
 
32
  '''The function computes the Wireframe Edge Distance (WED) between two graphs.
33
  pd_vertices: list of predicted vertices
34
  pd_edges: list of predicted edges
@@ -51,13 +61,9 @@ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=-1/4, ce=1.0, n
51
  if preregister:
52
  pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)
53
 
54
- if cv < 0:
55
- diameter = cdist(gt_vertices, gt_vertices).max()
56
- # Cost of adding or deleting a vertex is set to -cv times the diameter of the ground truth wireframe
57
- cv = -cv * diameter
58
- elif cv == 0:
59
- # Cost of adding or deleting a vertex is set to the average distance of the ground truth vertices from their mean
60
- cv = np.linalg.norm(np.mean(gt_vertices, axis=0) - gt_vertices, axis=1).mean()
61
  # Step 0: Prenormalize / preregister
62
 
63
 
@@ -74,11 +80,11 @@ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=-1/4, ce=1.0, n
74
 
75
  # Additional: Vertex Deletion
76
  unmatched_pd_indices = set(range(len(pd_vertices))) - set(row_ind)
77
- deletion_costs = cv * len(unmatched_pd_indices)
78
 
79
  # Step 3: Vertex Insertion
80
  unmatched_gt_indices = set(range(len(gt_vertices))) - set(col_ind)
81
- insertion_costs = cv * len(unmatched_gt_indices)
82
 
83
  # Step 4: Edge Deletion and Insertion
84
  updated_pd_edges = [(col_ind[np.where(row_ind == edge[0])[0][0]], col_ind[np.where(row_ind == edge[1])[0][0]]) for edge in pd_edges if edge[0] in row_ind and edge[1] in row_ind]
 
28
  return transformed_verts
29
 
30
 
31
+ def update_cv(cv, gt_vertices):
32
+ if cv < 0:
33
+ diameter = cdist(gt_vertices, gt_vertices).max()
34
+ # Cost of adding or deleting a vertex is set to -cv times the diameter of the ground truth wireframe
35
+ cv = -cv * diameter
36
+ elif cv == 0:
37
+ # Cost of adding or deleting a vertex is set to the average distance of the ground truth vertices from their mean
38
+ cv = np.linalg.norm(np.mean(gt_vertices, axis=0) - gt_vertices, axis=1).mean()
39
+ return cv
40
+
41
+ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv_ins=-1/2, cv_del=-1/4, ce=1.0, normalized=True, preregister=True, single_scale=True):
42
  '''The function computes the Wireframe Edge Distance (WED) between two graphs.
43
  pd_vertices: list of predicted vertices
44
  pd_edges: list of predicted edges
 
61
  if preregister:
62
  pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)
63
 
64
+ cv_del = update_cv(cv_del, gt_vertices)
65
+ cv_ins = update_cv(cv_ins, gt_vertices)
66
+
 
 
 
 
67
  # Step 0: Prenormalize / preregister
68
 
69
 
 
80
 
81
  # Additional: Vertex Deletion
82
  unmatched_pd_indices = set(range(len(pd_vertices))) - set(row_ind)
83
+ deletion_costs = cv_del * len(unmatched_pd_indices)
84
 
85
  # Step 3: Vertex Insertion
86
  unmatched_gt_indices = set(range(len(gt_vertices))) - set(col_ind)
87
+ insertion_costs = cv_ins * len(unmatched_gt_indices)
88
 
89
  # Step 4: Edge Deletion and Insertion
90
  updated_pd_edges = [(col_ind[np.where(row_ind == edge[0])[0][0]], col_ind[np.where(row_ind == edge[1])[0][0]]) for edge in pd_edges if edge[0] in row_ind and edge[1] in row_ind]
setup.py CHANGED
@@ -6,7 +6,7 @@ with open('requirements.txt') as f:
6
  required = f.read().splitlines()
7
 
8
  setup(name='hoho',
9
- version='0.0.3',
10
  description='Tools and utilites for the HoHo Dataset and S23DR Competition',
11
  url='usm3d.github.io',
12
  author='Jack Langerman, Dmytro Mishkin, S23DR Orgainizing Team',
 
6
  required = f.read().splitlines()
7
 
8
  setup(name='hoho',
9
+ version='0.0.4',
10
  description='Tools and utilites for the HoHo Dataset and S23DR Competition',
11
  url='usm3d.github.io',
12
  author='Jack Langerman, Dmytro Mishkin, S23DR Orgainizing Team',