clustering_eps=80
Browse files- handcrafted_solution.py +18 -23
- script.py +1 -1
handcrafted_solution.py
CHANGED
@@ -430,7 +430,7 @@ def prune_not_connected(all_3d_vertices, connections_3d):
|
|
430 |
return np.array(new_verts), connected_out
|
431 |
|
432 |
|
433 |
-
def predict(entry, visualize=False, scale_estimation_coefficient=2.5, clustering_eps
|
434 |
if 'gestalt' not in entry or 'depthcm' not in entry or 'K' not in entry or 'R' not in entry or 't' not in entry:
|
435 |
print('Missing required fields in the entry')
|
436 |
return (entry['__key__'], *empty_solution())
|
@@ -441,31 +441,29 @@ def predict(entry, visualize=False, scale_estimation_coefficient=2.5, clustering
|
|
441 |
for k, v in entry["images"].items():
|
442 |
image_dict[v.name] = v
|
443 |
points = [v.xyz for k, v in entry["points3d"].items()]
|
444 |
-
too_big = False
|
445 |
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
point_keys = np.array(point_keys)
|
450 |
|
451 |
-
|
452 |
|
453 |
-
|
454 |
-
|
455 |
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
|
460 |
-
|
461 |
|
462 |
-
|
463 |
-
|
464 |
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
|
470 |
for i, (gest, depthcm, K, R, t, imagekey) in enumerate(zip(entry['gestalt'],
|
471 |
entry['depthcm'],
|
@@ -485,10 +483,7 @@ def predict(entry, visualize=False, scale_estimation_coefficient=2.5, clustering
|
|
485 |
continue
|
486 |
belonging_points = []
|
487 |
for i in image_dict[imagekey].point3D_ids[np.where(image_dict[imagekey].point3D_ids != -1)]:
|
488 |
-
if
|
489 |
-
if i in biggest_cluster_keys:
|
490 |
-
belonging_points.append(entry["points3d"][i])
|
491 |
-
else:
|
492 |
belonging_points.append(entry["points3d"][i])
|
493 |
|
494 |
if len(belonging_points) < 1:
|
|
|
430 |
return np.array(new_verts), connected_out
|
431 |
|
432 |
|
433 |
+
def predict(entry, visualize=False, scale_estimation_coefficient=2.5, clustering_eps=100, **kwargs) -> Tuple[np.ndarray, List[int]]:
|
434 |
if 'gestalt' not in entry or 'depthcm' not in entry or 'K' not in entry or 'R' not in entry or 't' not in entry:
|
435 |
print('Missing required fields in the entry')
|
436 |
return (entry['__key__'], *empty_solution())
|
|
|
441 |
for k, v in entry["images"].items():
|
442 |
image_dict[v.name] = v
|
443 |
points = [v.xyz for k, v in entry["points3d"].items()]
|
|
|
444 |
|
445 |
+
points = np.array(points)
|
446 |
+
point_keys = [k for k, v in entry["points3d"].items()]
|
447 |
+
point_keys = np.array(point_keys)
|
|
|
448 |
|
449 |
+
# print(len(points))
|
450 |
|
451 |
+
clustered = DBSCAN(eps=clustering_eps, min_samples=10).fit(points).labels_
|
452 |
+
clustered_indices = np.argsort(clustered)
|
453 |
|
454 |
+
points = points[clustered_indices]
|
455 |
+
point_keys = point_keys[clustered_indices]
|
456 |
+
clustered = clustered[clustered_indices]
|
457 |
|
458 |
+
_, cluster_indices = np.unique(clustered, return_index=True)
|
459 |
|
460 |
+
clustered_points = np.split(points, cluster_indices[1:])
|
461 |
+
clustered_keys = np.split(point_keys, cluster_indices[1:])
|
462 |
|
463 |
+
biggest_cluster_index = np.argmax([len(i) for i in clustered_points])
|
464 |
+
# biggest_cluster = clustered_points[biggest_cluster_index]
|
465 |
+
biggest_cluster_keys = clustered_keys[biggest_cluster_index]
|
466 |
+
biggest_cluster_keys = set(biggest_cluster_keys)
|
467 |
|
468 |
for i, (gest, depthcm, K, R, t, imagekey) in enumerate(zip(entry['gestalt'],
|
469 |
entry['depthcm'],
|
|
|
483 |
continue
|
484 |
belonging_points = []
|
485 |
for i in image_dict[imagekey].point3D_ids[np.where(image_dict[imagekey].point3D_ids != -1)]:
|
486 |
+
if i in biggest_cluster_keys:
|
|
|
|
|
|
|
487 |
belonging_points.append(entry["points3d"][i])
|
488 |
|
489 |
if len(belonging_points) < 1:
|
script.py
CHANGED
@@ -137,7 +137,7 @@ if __name__ == "__main__":
|
|
137 |
merge_th=100.0,
|
138 |
min_missing_distance=30000000.0,
|
139 |
scale_estimation_coefficient=2.54,
|
140 |
-
clustering_eps=
|
141 |
))
|
142 |
|
143 |
for i, result in enumerate(tqdm(results)):
|
|
|
137 |
merge_th=100.0,
|
138 |
min_missing_distance=30000000.0,
|
139 |
scale_estimation_coefficient=2.54,
|
140 |
+
clustering_eps=80,
|
141 |
))
|
142 |
|
143 |
for i, result in enumerate(tqdm(results)):
|