.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *.whl filter=lfs diff=lfs merge=lfs -text
36
  *tfevents* filter=lfs diff=lfs merge=lfs -text
37
  packages/** filter=lfs diff=lfs merge=lfs -text
 
 
35
  *.whl filter=lfs diff=lfs merge=lfs -text
36
  *tfevents* filter=lfs diff=lfs merge=lfs -text
37
  packages/** filter=lfs diff=lfs merge=lfs -text
38
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .ipynb_checkpoints
2
-
 
 
1
  .ipynb_checkpoints
2
+ __pycache__/
3
+ data
color_mappings.py DELETED
@@ -1,182 +0,0 @@
1
- gestalt_color_mapping = {
2
- "unclassified": (215, 62, 138),
3
- "apex": (235, 88, 48),
4
- "eave_end_point": (248, 130, 228),
5
- "flashing_end_point": (71, 11, 161),
6
- "ridge": (214, 251, 248),
7
- "rake": (13, 94, 47),
8
- "eave": (54, 243, 63),
9
- "post": (187, 123, 236),
10
- "ground_line": (136, 206, 14),
11
- "flashing": (162, 162, 32),
12
- "step_flashing": (169, 255, 219),
13
- "hip": (8, 89, 52),
14
- "valley": (85, 27, 65),
15
- "roof": (215, 232, 179),
16
- "door": (110, 52, 23),
17
- "garage": (50, 233, 171),
18
- "window": (230, 249, 40),
19
- "shutter": (122, 4, 233),
20
- "fascia": (95, 230, 240),
21
- "soffit": (2, 102, 197),
22
- "horizontal_siding": (131, 88, 59),
23
- "vertical_siding": (110, 187, 198),
24
- "brick": (171, 252, 7),
25
- "concrete": (32, 47, 246),
26
- "other_wall": (112, 61, 240),
27
- "trim": (151, 206, 58),
28
- "unknown": (127, 127, 127),
29
- }
30
-
31
- ade20k_color_mapping = {
32
- 'wall': (120, 120, 120),
33
- 'building;edifice': (180, 120, 120),
34
- 'sky': (6, 230, 230),
35
- 'floor;flooring': (80, 50, 50),
36
- 'tree': (4, 200, 3),
37
- 'ceiling': (120, 120, 80),
38
- 'road;route': (140, 140, 140),
39
- 'bed': (204, 5, 255),
40
- 'windowpane;window': (230, 230, 230),
41
- 'grass': (4, 250, 7),
42
- 'cabinet': (224, 5, 255),
43
- 'sidewalk;pavement': (235, 255, 7),
44
- 'person;individual;someone;somebody;mortal;soul': (150, 5, 61),
45
- 'earth;ground': (120, 120, 70),
46
- 'door;double;door': (8, 255, 51),
47
- 'table': (255, 6, 82),
48
- 'mountain;mount': (143, 255, 140),
49
- 'plant;flora;plant;life': (204, 255, 4),
50
- 'curtain;drape;drapery;mantle;pall': (255, 51, 7),
51
- 'chair': (204, 70, 3),
52
- 'car;auto;automobile;machine;motorcar': (0, 102, 200),
53
- 'water': (61, 230, 250),
54
- 'painting;picture': (255, 6, 51),
55
- 'sofa;couch;lounge': (11, 102, 255),
56
- 'shelf': (255, 7, 71),
57
- 'house': (255, 9, 224),
58
- 'sea': (9, 7, 230),
59
- 'mirror': (220, 220, 220),
60
- 'rug;carpet;carpeting': (255, 9, 92),
61
- 'field': (112, 9, 255),
62
- 'armchair': (8, 255, 214),
63
- 'seat': (7, 255, 224),
64
- 'fence;fencing': (255, 184, 6),
65
- 'desk': (10, 255, 71),
66
- 'rock;stone': (255, 41, 10),
67
- 'wardrobe;closet;press': (7, 255, 255),
68
- 'lamp': (224, 255, 8),
69
- 'bathtub;bathing;tub;bath;tub': (102, 8, 255),
70
- 'railing;rail': (255, 61, 6),
71
- 'cushion': (255, 194, 7),
72
- 'base;pedestal;stand': (255, 122, 8),
73
- 'box': (0, 255, 20),
74
- 'column;pillar': (255, 8, 41),
75
- 'signboard;sign': (255, 5, 153),
76
- 'chest;of;drawers;chest;bureau;dresser': (6, 51, 255),
77
- 'counter': (235, 12, 255),
78
- 'sand': (160, 150, 20),
79
- 'sink': (0, 163, 255),
80
- 'skyscraper': (140, 140, 140),
81
- 'fireplace;hearth;open;fireplace': (250, 10, 15),
82
- 'refrigerator;icebox': (20, 255, 0),
83
- 'grandstand;covered;stand': (31, 255, 0),
84
- 'path': (255, 31, 0),
85
- 'stairs;steps': (255, 224, 0),
86
- 'runway': (153, 255, 0),
87
- 'case;display;case;showcase;vitrine': (0, 0, 255),
88
- 'pool;table;billiard;table;snooker;table': (255, 71, 0),
89
- 'pillow': (0, 235, 255),
90
- 'screen;door;screen': (0, 173, 255),
91
- 'stairway;staircase': (31, 0, 255),
92
- 'river': (11, 200, 200),
93
- 'bridge;span': (255 ,82, 0),
94
- 'bookcase': (0, 255, 245),
95
- 'blind;screen': (0, 61, 255),
96
- 'coffee;table;cocktail;table': (0, 255, 112),
97
- 'toilet;can;commode;crapper;pot;potty;stool;throne': (0, 255, 133),
98
- 'flower': (255, 0, 0),
99
- 'book': (255, 163, 0),
100
- 'hill': (255, 102, 0),
101
- 'bench': (194, 255, 0),
102
- 'countertop': (0, 143, 255),
103
- 'stove;kitchen;stove;range;kitchen;range;cooking;stove': (51, 255, 0),
104
- 'palm;palm;tree': (0, 82, 255),
105
- 'kitchen;island': (0, 255, 41),
106
- 'computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system': (0, 255, 173),
107
- 'swivel;chair': (10, 0, 255),
108
- 'boat': (173, 255, 0),
109
- 'bar': (0, 255, 153),
110
- 'arcade;machine': (255, 92, 0),
111
- 'hovel;hut;hutch;shack;shanty': (255, 0, 255),
112
- 'bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle': (255, 0, 245),
113
- 'towel': (255, 0, 102),
114
- 'light;light;source': (255, 173, 0),
115
- 'truck;motortruck': (255, 0, 20),
116
- 'tower': (255, 184, 184),
117
- 'chandelier;pendant;pendent': (0, 31, 255),
118
- 'awning;sunshade;sunblind': (0, 255, 61),
119
- 'streetlight;street;lamp': (0, 71, 255),
120
- 'booth;cubicle;stall;kiosk': (255, 0, 204),
121
- 'television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box': (0, 255, 194),
122
- 'airplane;aeroplane;plane': (0, 255, 82),
123
- 'dirt;track': (0, 10, 255),
124
- 'apparel;wearing;apparel;dress;clothes': (0, 112, 255),
125
- 'pole': (51, 0, 255),
126
- 'land;ground;soil': (0, 194, 255),
127
- 'bannister;banister;balustrade;balusters;handrail': (0, 122, 255),
128
- 'escalator;moving;staircase;moving;stairway': (0, 255, 163),
129
- 'ottoman;pouf;pouffe;puff;hassock': (255, 153, 0),
130
- 'bottle': (0, 255, 10),
131
- 'buffet;counter;sideboard': (255, 112, 0),
132
- 'poster;posting;placard;notice;bill;card': (143, 255, 0),
133
- 'stage': (82, 0, 255),
134
- 'van': (163, 255, 0),
135
- 'ship': (255, 235, 0),
136
- 'fountain': (8, 184, 170),
137
- 'conveyer;belt;conveyor;belt;conveyer;conveyor;transporter': (133, 0, 255),
138
- 'canopy': (0, 255, 92),
139
- 'washer;automatic;washer;washing;machine': (184, 0, 255),
140
- 'plaything;toy': (255, 0, 31),
141
- 'swimming;pool;swimming;bath;natatorium': (0, 184, 255),
142
- 'stool': (0, 214, 255),
143
- 'barrel;cask': (255, 0, 112),
144
- 'basket;handbasket': (92, 255, 0),
145
- 'waterfall;falls': (0, 224, 255),
146
- 'tent;collapsible;shelter': (112, 224, 255),
147
- 'bag': (70, 184, 160),
148
- 'minibike;motorbike': (163, 0, 255),
149
- 'cradle': (153, 0, 255),
150
- 'oven': (71, 255, 0),
151
- 'ball': (255, 0, 163),
152
- 'food;solid;food': (255, 204, 0),
153
- 'step;stair': (255, 0, 143),
154
- 'tank;storage;tank': (0, 255, 235),
155
- 'trade;name;brand;name;brand;marque': (133, 255, 0),
156
- 'microwave;microwave;oven': (255, 0, 235),
157
- 'pot;flowerpot': (245, 0, 255),
158
- 'animal;animate;being;beast;brute;creature;fauna': (255, 0, 122),
159
- 'bicycle;bike;wheel;cycle': (255, 245, 0),
160
- 'lake': (10, 190, 212),
161
- 'dishwasher;dish;washer;dishwashing;machine': (214, 255, 0),
162
- 'screen;silver;screen;projection;screen': (0, 204, 255),
163
- 'blanket;cover': (20, 0, 255),
164
- 'sculpture': (255, 255, 0),
165
- 'hood;exhaust;hood': (0, 153, 255),
166
- 'sconce': (0, 41, 255),
167
- 'vase': (0, 255, 204),
168
- 'traffic;light;traffic;signal;stoplight': (41, 0, 255),
169
- 'tray': (41, 255, 0),
170
- 'ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin': (173, 0, 255),
171
- 'fan': (0, 245, 255),
172
- 'pier;wharf;wharfage;dock': (71, 0, 255),
173
- 'crt;screen': (122, 0, 255),
174
- 'plate': (0, 255, 184),
175
- 'monitor;monitoring;device': (0, 92, 255),
176
- 'bulletin;board;notice;board': (184, 255, 0),
177
- 'shower': (0, 133, 255),
178
- 'radiator': (255, 214, 0),
179
- 'glass;drinking;glass': (25, 194, 194),
180
- 'clock': (102, 255, 0),
181
- 'flag': (92, 0, 255),
182
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
handcrafted_solution.py CHANGED
@@ -1,15 +1,16 @@
1
  # Description: This file contains the handcrafted solution for the task of wireframe reconstruction
2
 
3
  import io
4
- from read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
5
  from PIL import Image as PImage
6
  import numpy as np
7
- from color_mappings import gestalt_color_mapping, ade20k_color_mapping
8
  from collections import defaultdict
9
  import cv2
10
  from typing import Tuple, List
11
  from scipy.spatial.distance import cdist
12
 
 
 
 
13
 
14
  def empty_solution():
15
  '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
@@ -234,9 +235,9 @@ def predict(entry, visualize=False) -> Tuple[np.ndarray, List[int]]:
234
  all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d)
235
  if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
236
  print (f'Not enough vertices or connections in the 3D vertices')
237
- return empty_solution()
238
  if visualize:
239
- from viz3d import plot_estimate_and_gt
240
  plot_estimate_and_gt(all_3d_vertices_clean, connections_3d_clean, good_entry['wf_vertices'],
241
  good_entry['wf_edges'])
242
- return all_3d_vertices_clean, connections_3d_clean, [0 for i in range(len(connections_3d_clean))]
 
1
  # Description: This file contains the handcrafted solution for the task of wireframe reconstruction
2
 
3
  import io
 
4
  from PIL import Image as PImage
5
  import numpy as np
 
6
  from collections import defaultdict
7
  import cv2
8
  from typing import Tuple, List
9
  from scipy.spatial.distance import cdist
10
 
11
+ from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
12
+ from hoho.color_mappings import gestalt_color_mapping, ade20k_color_mapping
13
+
14
 
15
  def empty_solution():
16
  '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
 
235
  all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d)
236
  if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
237
  print (f'Not enough vertices or connections in the 3D vertices')
238
+ return (good_entry['__key__'], *empty_solution())
239
  if visualize:
240
+ from hoho.viz3d import plot_estimate_and_gt
241
  plot_estimate_and_gt(all_3d_vertices_clean, connections_3d_clean, good_entry['wf_vertices'],
242
  good_entry['wf_edges'])
243
+ return good_entry['__key__'], all_3d_vertices_clean, connections_3d_clean, [0 for i in range(len(connections_3d_clean))]
hoho.py DELETED
@@ -1,261 +0,0 @@
1
- import os
2
- import json
3
- import shutil
4
- from pathlib import Path
5
- from typing import Dict
6
-
7
- from PIL import ImageFile
8
- ImageFile.LOAD_TRUNCATED_IMAGES = True
9
-
10
- LOCAL_DATADIR = None
11
-
12
- def setup(local_dir='./data/usm-training-data/data'):
13
-
14
- # If we are in the test environment, we need to link the data directory to the correct location
15
- tmp_datadir = Path('/tmp/data/data')
16
- local_test_datadir = Path('./data/usm-test-data-x/data')
17
- local_val_datadir = Path(local_dir)
18
-
19
- os.system('pwd')
20
- os.system('ls -lahtr .')
21
-
22
- if tmp_datadir.exists() and not local_test_datadir.exists():
23
- global LOCAL_DATADIR
24
- LOCAL_DATADIR = local_test_datadir
25
- # shutil.move(datadir, './usm-test-data-x/data')
26
- print(f"Linking {tmp_datadir} to {LOCAL_DATADIR} (we are in the test environment)")
27
- LOCAL_DATADIR.parent.mkdir(parents=True, exist_ok=True)
28
- LOCAL_DATADIR.symlink_to(tmp_datadir)
29
- else:
30
- LOCAL_DATADIR = local_val_datadir
31
- print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
32
-
33
- # os.system("ls -lahtr")
34
-
35
- assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist"
36
- return LOCAL_DATADIR
37
-
38
-
39
-
40
-
41
- import importlib
42
- from pathlib import Path
43
- import subprocess
44
-
45
- def download_package(package_name, path_to_save='packages'):
46
- """
47
- Downloads a package using pip and saves it to a specified directory.
48
-
49
- Parameters:
50
- package_name (str): The name of the package to download.
51
- path_to_save (str): The path to the directory where the package will be saved.
52
- """
53
- try:
54
- # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
55
- subprocess.check_call([subprocess.sys.executable, "-m", "pip", "download", package_name,
56
- "-d", str(Path(path_to_save)/package_name), # Download the package to the specified directory
57
- "--platform", "manylinux1_x86_64", # Specify the platform
58
- "--python-version", "38", # Specify the Python version
59
- "--only-binary=:all:"]) # Download only binary packages
60
- print(f'Package "{package_name}" downloaded successfully')
61
- except subprocess.CalledProcessError as e:
62
- print(f'Failed to downloaded package "{package_name}". Error: {e}')
63
-
64
-
65
- def install_package_from_local_file(package_name, folder='packages'):
66
- """
67
- Installs a package from a local .whl file or a directory containing .whl files using pip.
68
-
69
- Parameters:
70
- path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
71
- """
72
- try:
73
- pth = str(Path(folder) / package_name)
74
- subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
75
- "--no-index", # Do not use package index
76
- "--find-links", pth, # Look for packages in the specified directory or at the file
77
- package_name]) # Specify the package to install
78
- print(f"Package installed successfully from {pth}")
79
- except subprocess.CalledProcessError as e:
80
- print(f"Failed to install package from {pth}. Error: {e}")
81
-
82
-
83
- def importt(module_name, as_name=None):
84
- """
85
- Imports a module and returns it.
86
-
87
- Parameters:
88
- module_name (str): The name of the module to import.
89
- as_name (str): The name to use for the imported module. If None, the original module name will be used.
90
-
91
- Returns:
92
- The imported module.
93
- """
94
- for _ in range(2):
95
- try:
96
- if as_name is None:
97
- print(f'imported {module_name}')
98
- return importlib.import_module(module_name)
99
- else:
100
- print(f'imported {module_name} as {as_name}')
101
- return importlib.import_module(module_name, as_name)
102
- except ModuleNotFoundError as e:
103
- install_package_from_local_file(module_name)
104
- print(f"Failed to import module {module_name}. Error: {e}")
105
-
106
-
107
- def prepare_submission():
108
- # Download packages from requirements.txt
109
- if Path('requirements.txt').exists():
110
- print('downloading packages from requirements.txt')
111
- Path('packages').mkdir(exist_ok=True)
112
- with open('requirements.txt') as f:
113
- packages = f.readlines()
114
- for p in packages:
115
- download_package(p.strip())
116
-
117
-
118
- print('all packages downloaded. Don\'t foget to include the packages in the submission by adding them with git lfs.')
119
-
120
-
121
- def Rt_to_eye_target(im, K, R, t):
122
- height = im.height
123
- focal_length = K[0,0]
124
- fov = 2.0 * np.arctan2((0.5 * height), focal_length) / (np.pi / 180.0)
125
-
126
- x_axis, y_axis, z_axis = R
127
-
128
- eye = -(R.T @ t).squeeze()
129
- z_axis = z_axis.squeeze()
130
- target = eye + z_axis
131
- up = -y_axis
132
-
133
- return eye, target, up, fov
134
-
135
-
136
- ########## general utilities ##########
137
- import contextlib
138
- import tempfile
139
- from pathlib import Path
140
-
141
- @contextlib.contextmanager
142
- def working_directory(path):
143
- """Changes working directory and returns to previous on exit."""
144
- prev_cwd = Path.cwd()
145
- os.chdir(path)
146
- try:
147
- yield
148
- finally:
149
- os.chdir(prev_cwd)
150
-
151
- @contextlib.contextmanager
152
- def temp_working_directory():
153
- with tempfile.TemporaryDirectory(dir='.') as D:
154
- with working_directory(D):
155
- yield
156
-
157
-
158
- ############# Dataset #############
159
- def proc(row, split='train'):
160
- # column_names_train = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'mesh', 'wireframe']
161
- # column_names_test = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'wireframe']
162
- # cols = column_names_train if split == 'train' else column_names_test
163
- out = {}
164
- for k, v in row.items():
165
- colname = k.split('.')[0]
166
- if colname in {'ade20k', 'depthcm', 'gestalt'}:
167
- if colname in out:
168
- out[colname].append(v)
169
- else:
170
- out[colname] = [v]
171
- elif colname in {'wireframe', 'mesh'}:
172
- # out.update({a: b.tolist() for a,b in v.items()})
173
- out.update({a: b for a,b in v.items()})
174
- elif colname in 'kr':
175
- out[colname.upper()] = v
176
- else:
177
- out[colname] = v
178
-
179
- return Sample(out)
180
-
181
-
182
- class Sample(Dict):
183
- def __repr__(self):
184
- return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
185
-
186
-
187
-
188
- def get_params():
189
- exmaple_param_dict = {
190
- "competition_id": "usm3d/S23DR",
191
- "competition_type": "script",
192
- "metric": "custom",
193
- "token": "hf_**********************************",
194
- "team_id": "local-test-team_id",
195
- "submission_id": "local-test-submission_id",
196
- "submission_id_col": "__key__",
197
- "submission_cols": [
198
- "__key__",
199
- "wf_edges",
200
- "wf_vertices",
201
- "edge_semantics"
202
- ],
203
- "submission_rows": 180,
204
- "output_path": ".",
205
- "submission_repo": "<THE HF MODEL ID of THIS REPO",
206
- "time_limit": 7200,
207
- "dataset": "usm3d/usm-test-data-x",
208
- "submission_filenames": [
209
- "submission.parquet"
210
- ]
211
- }
212
-
213
- param_path = Path('params.json')
214
-
215
- if not param_path.exists():
216
- print('params.json not found (this means we probably aren\'t in the test env). Using example params.')
217
- params = exmaple_param_dict
218
- else:
219
- print('found params.json (this means we are probably in the test env). Using params from file.')
220
- with param_path.open() as f:
221
- params = json.load(f)
222
- print(params)
223
- return params
224
-
225
-
226
-
227
- import webdataset as wds
228
- import numpy as np
229
-
230
- def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
231
- if LOCAL_DATADIR is None:
232
- raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
233
-
234
- local_dir = Path(LOCAL_DATADIR)
235
- if split != 'all':
236
- local_dir = local_dir / split
237
-
238
- paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
239
-
240
- dataset = wds.WebDataset(paths)
241
- if decode is not None:
242
- dataset = dataset.decode(decode)
243
- else:
244
- dataset = dataset.decode()
245
-
246
- dataset = dataset.map(proc)
247
-
248
- if dataset_type == 'webdataset':
249
- return dataset
250
-
251
- if dataset_type == 'hf':
252
- import datasets
253
- from datasets import Features, Value, Sequence, Image, Array2D
254
-
255
- if split == 'train':
256
- return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
257
- elif split == 'val':
258
- return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
259
-
260
-
261
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages/webdataset/numpy-1.21.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl → notebooks/EDA.ipynb RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dde972a1e11bb7b702ed0e447953e7617723760f420decb97305e66fb4afc54f
3
- size 14092363
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a847cfb0c6458dc4edfc707f30c6156345dbbc486ad651abd8bfad4f0ee659a
3
+ size 14355368
notebooks/example_on_training.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
packages/webdataset/PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287
3
- size 701221
 
 
 
 
packages/webdataset/braceexpand-0.1.7-py2.py3-none-any.whl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:91332d53de7828103dcae5773fb43bc34950b0c8160e35e0f44c4427a3b85014
3
- size 5923
 
 
 
 
packages/webdataset/webdataset-0.2.86-py3-none-any.whl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:843a2b57c6356ebba25e811adf38a476da8e176f1b192f8bd5c8270daf1a6989
3
- size 70378
 
 
 
 
read_write_colmap.py DELETED
@@ -1,489 +0,0 @@
1
- # Modified to read from bytes-like object by Dmytro Mishkin.
2
- # The original license is below:
3
- # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
4
- # All rights reserved.
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # * Redistributions of source code must retain the above copyright
10
- # notice, this list of conditions and the following disclaimer.
11
- #
12
- # * Redistributions in binary form must reproduce the above copyright
13
- # notice, this list of conditions and the following disclaimer in the
14
- # documentation and/or other materials provided with the distribution.
15
- #
16
- # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
17
- # its contributors may be used to endorse or promote products derived
18
- # from this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23
- # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
24
- # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25
- # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26
- # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27
- # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28
- # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29
- # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30
- # POSSIBILITY OF SUCH DAMAGE.
31
- #
32
- # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
33
-
34
- import os
35
- import collections
36
- import numpy as np
37
- import struct
38
- import argparse
39
-
40
-
41
- CameraModel = collections.namedtuple(
42
- "CameraModel", ["model_id", "model_name", "num_params"])
43
- Camera = collections.namedtuple(
44
- "Camera", ["id", "model", "width", "height", "params"])
45
- BaseImage = collections.namedtuple(
46
- "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
47
- Point3D = collections.namedtuple(
48
- "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
49
-
50
-
51
- class Image(BaseImage):
52
- def qvec2rotmat(self):
53
- return qvec2rotmat(self.qvec)
54
-
55
-
56
- CAMERA_MODELS = {
57
- CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
58
- CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
59
- CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
60
- CameraModel(model_id=3, model_name="RADIAL", num_params=5),
61
- CameraModel(model_id=4, model_name="OPENCV", num_params=8),
62
- CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
63
- CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
64
- CameraModel(model_id=7, model_name="FOV", num_params=5),
65
- CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
66
- CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
67
- CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
68
- }
69
- CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
70
- for camera_model in CAMERA_MODELS])
71
- CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
72
- for camera_model in CAMERA_MODELS])
73
-
74
-
75
- def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
76
- """Read and unpack the next bytes from a binary file.
77
- :param fid:
78
- :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
79
- :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
80
- :param endian_character: Any of {@, =, <, >, !}
81
- :return: Tuple of read and unpacked values.
82
- """
83
- data = fid.read(num_bytes)
84
- return struct.unpack(endian_character + format_char_sequence, data)
85
-
86
-
87
- def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
88
- """pack and write to a binary file.
89
- :param fid:
90
- :param data: data to send, if multiple elements are sent at the same time,
91
- they should be encapsuled either in a list or a tuple
92
- :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
93
- should be the same length as the data list or tuple
94
- :param endian_character: Any of {@, =, <, >, !}
95
- """
96
- if isinstance(data, (list, tuple)):
97
- bytes = struct.pack(endian_character + format_char_sequence, *data)
98
- else:
99
- bytes = struct.pack(endian_character + format_char_sequence, data)
100
- fid.write(bytes)
101
-
102
-
103
- def read_cameras_text(path):
104
- """
105
- see: src/base/reconstruction.cc
106
- void Reconstruction::WriteCamerasText(const std::string& path)
107
- void Reconstruction::ReadCamerasText(const std::string& path)
108
- """
109
- cameras = {}
110
- with open(path, "r") as fid:
111
- while True:
112
- line = fid.readline()
113
- if not line:
114
- break
115
- line = line.strip()
116
- if len(line) > 0 and line[0] != "#":
117
- elems = line.split()
118
- camera_id = int(elems[0])
119
- model = elems[1]
120
- width = int(elems[2])
121
- height = int(elems[3])
122
- params = np.array(tuple(map(float, elems[4:])))
123
- cameras[camera_id] = Camera(id=camera_id, model=model,
124
- width=width, height=height,
125
- params=params)
126
- return cameras
127
-
128
-
129
- def read_cameras_binary(path_to_model_file=None, fid=None):
130
- """
131
- see: src/base/reconstruction.cc
132
- void Reconstruction::WriteCamerasBinary(const std::string& path)
133
- void Reconstruction::ReadCamerasBinary(const std::string& path)
134
- """
135
- cameras = {}
136
- if fid is None:
137
- fid = open(path_to_model_file, "rb")
138
- num_cameras = read_next_bytes(fid, 8, "Q")[0]
139
- for _ in range(num_cameras):
140
- camera_properties = read_next_bytes(
141
- fid, num_bytes=24, format_char_sequence="iiQQ")
142
- camera_id = camera_properties[0]
143
- model_id = camera_properties[1]
144
- model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
145
- width = camera_properties[2]
146
- height = camera_properties[3]
147
- num_params = CAMERA_MODEL_IDS[model_id].num_params
148
- params = read_next_bytes(fid, num_bytes=8*num_params,
149
- format_char_sequence="d"*num_params)
150
- cameras[camera_id] = Camera(id=camera_id,
151
- model=model_name,
152
- width=width,
153
- height=height,
154
- params=np.array(params))
155
- assert len(cameras) == num_cameras
156
- if path_to_model_file is not None:
157
- fid.close()
158
- return cameras
159
-
160
-
161
- def write_cameras_text(cameras, path):
162
- """
163
- see: src/base/reconstruction.cc
164
- void Reconstruction::WriteCamerasText(const std::string& path)
165
- void Reconstruction::ReadCamerasText(const std::string& path)
166
- """
167
- HEADER = "# Camera list with one line of data per camera:\n" + \
168
- "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + \
169
- "# Number of cameras: {}\n".format(len(cameras))
170
- with open(path, "w") as fid:
171
- fid.write(HEADER)
172
- for _, cam in cameras.items():
173
- to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
174
- line = " ".join([str(elem) for elem in to_write])
175
- fid.write(line + "\n")
176
-
177
-
178
- def write_cameras_binary(cameras, path_to_model_file):
179
- """
180
- see: src/base/reconstruction.cc
181
- void Reconstruction::WriteCamerasBinary(const std::string& path)
182
- void Reconstruction::ReadCamerasBinary(const std::string& path)
183
- """
184
- with open(path_to_model_file, "wb") as fid:
185
- write_next_bytes(fid, len(cameras), "Q")
186
- for _, cam in cameras.items():
187
- model_id = CAMERA_MODEL_NAMES[cam.model].model_id
188
- camera_properties = [cam.id,
189
- model_id,
190
- cam.width,
191
- cam.height]
192
- write_next_bytes(fid, camera_properties, "iiQQ")
193
- for p in cam.params:
194
- write_next_bytes(fid, float(p), "d")
195
- return cameras
196
-
197
-
198
- def read_images_text(path):
199
- """
200
- see: src/base/reconstruction.cc
201
- void Reconstruction::ReadImagesText(const std::string& path)
202
- void Reconstruction::WriteImagesText(const std::string& path)
203
- """
204
- images = {}
205
- with open(path, "r") as fid:
206
- while True:
207
- line = fid.readline()
208
- if not line:
209
- break
210
- line = line.strip()
211
- if len(line) > 0 and line[0] != "#":
212
- elems = line.split()
213
- image_id = int(elems[0])
214
- qvec = np.array(tuple(map(float, elems[1:5])))
215
- tvec = np.array(tuple(map(float, elems[5:8])))
216
- camera_id = int(elems[8])
217
- image_name = elems[9]
218
- elems = fid.readline().split()
219
- xys = np.column_stack([tuple(map(float, elems[0::3])),
220
- tuple(map(float, elems[1::3]))])
221
- point3D_ids = np.array(tuple(map(int, elems[2::3])))
222
- images[image_id] = Image(
223
- id=image_id, qvec=qvec, tvec=tvec,
224
- camera_id=camera_id, name=image_name,
225
- xys=xys, point3D_ids=point3D_ids)
226
- return images
227
-
228
-
229
- def read_images_binary(path_to_model_file=None, fid=None):
230
- """
231
- see: src/base/reconstruction.cc
232
- void Reconstruction::ReadImagesBinary(const std::string& path)
233
- void Reconstruction::WriteImagesBinary(const std::string& path)
234
- """
235
- images = {}
236
- if fid is None:
237
- fid = open(path_to_model_file, "rb")
238
- num_reg_images = read_next_bytes(fid, 8, "Q")[0]
239
- for _ in range(num_reg_images):
240
- binary_image_properties = read_next_bytes(
241
- fid, num_bytes=64, format_char_sequence="idddddddi")
242
- image_id = binary_image_properties[0]
243
- qvec = np.array(binary_image_properties[1:5])
244
- tvec = np.array(binary_image_properties[5:8])
245
- camera_id = binary_image_properties[8]
246
- image_name = ""
247
- current_char = read_next_bytes(fid, 1, "c")[0]
248
- while current_char != b"\x00": # look for the ASCII 0 entry
249
- image_name += current_char.decode("utf-8")
250
- current_char = read_next_bytes(fid, 1, "c")[0]
251
- num_points2D = read_next_bytes(fid, num_bytes=8,
252
- format_char_sequence="Q")[0]
253
- x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
254
- format_char_sequence="ddq"*num_points2D)
255
- xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
256
- tuple(map(float, x_y_id_s[1::3]))])
257
- point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
258
- images[image_id] = Image(
259
- id=image_id, qvec=qvec, tvec=tvec,
260
- camera_id=camera_id, name=image_name,
261
- xys=xys, point3D_ids=point3D_ids)
262
- if path_to_model_file is not None:
263
- fid.close()
264
- return images
265
-
266
-
267
- def write_images_text(images, path):
268
- """
269
- see: src/base/reconstruction.cc
270
- void Reconstruction::ReadImagesText(const std::string& path)
271
- void Reconstruction::WriteImagesText(const std::string& path)
272
- """
273
- if len(images) == 0:
274
- mean_observations = 0
275
- else:
276
- mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images)
277
- HEADER = "# Image list with two lines of data per image:\n" + \
278
- "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + \
279
- "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + \
280
- "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations)
281
-
282
- with open(path, "w") as fid:
283
- fid.write(HEADER)
284
- for _, img in images.items():
285
- image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
286
- first_line = " ".join(map(str, image_header))
287
- fid.write(first_line + "\n")
288
-
289
- points_strings = []
290
- for xy, point3D_id in zip(img.xys, img.point3D_ids):
291
- points_strings.append(" ".join(map(str, [*xy, point3D_id])))
292
- fid.write(" ".join(points_strings) + "\n")
293
-
294
-
295
- def write_images_binary(images, path_to_model_file):
296
- """
297
- see: src/base/reconstruction.cc
298
- void Reconstruction::ReadImagesBinary(const std::string& path)
299
- void Reconstruction::WriteImagesBinary(const std::string& path)
300
- """
301
- with open(path_to_model_file, "wb") as fid:
302
- write_next_bytes(fid, len(images), "Q")
303
- for _, img in images.items():
304
- write_next_bytes(fid, img.id, "i")
305
- write_next_bytes(fid, img.qvec.tolist(), "dddd")
306
- write_next_bytes(fid, img.tvec.tolist(), "ddd")
307
- write_next_bytes(fid, img.camera_id, "i")
308
- for char in img.name:
309
- write_next_bytes(fid, char.encode("utf-8"), "c")
310
- write_next_bytes(fid, b"\x00", "c")
311
- write_next_bytes(fid, len(img.point3D_ids), "Q")
312
- for xy, p3d_id in zip(img.xys, img.point3D_ids):
313
- write_next_bytes(fid, [*xy, p3d_id], "ddq")
314
-
315
-
316
- def read_points3D_text(path):
317
- """
318
- see: src/base/reconstruction.cc
319
- void Reconstruction::ReadPoints3DText(const std::string& path)
320
- void Reconstruction::WritePoints3DText(const std::string& path)
321
- """
322
- points3D = {}
323
- with open(path, "r") as fid:
324
- while True:
325
- line = fid.readline()
326
- if not line:
327
- break
328
- line = line.strip()
329
- if len(line) > 0 and line[0] != "#":
330
- elems = line.split()
331
- point3D_id = int(elems[0])
332
- xyz = np.array(tuple(map(float, elems[1:4])))
333
- rgb = np.array(tuple(map(int, elems[4:7])))
334
- error = float(elems[7])
335
- image_ids = np.array(tuple(map(int, elems[8::2])))
336
- point2D_idxs = np.array(tuple(map(int, elems[9::2])))
337
- points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
338
- error=error, image_ids=image_ids,
339
- point2D_idxs=point2D_idxs)
340
- return points3D
341
-
342
-
343
- def read_points3D_binary(path_to_model_file=None, fid=None):
344
- """
345
- see: src/base/reconstruction.cc
346
- void Reconstruction::ReadPoints3DBinary(const std::string& path)
347
- void Reconstruction::WritePoints3DBinary(const std::string& path)
348
- """
349
- points3D = {}
350
- if fid is None:
351
- fid = open(path_to_model_file, "rb")
352
- num_points = read_next_bytes(fid, 8, "Q")[0]
353
- for _ in range(num_points):
354
- binary_point_line_properties = read_next_bytes(
355
- fid, num_bytes=43, format_char_sequence="QdddBBBd")
356
- point3D_id = binary_point_line_properties[0]
357
- xyz = np.array(binary_point_line_properties[1:4])
358
- rgb = np.array(binary_point_line_properties[4:7])
359
- error = np.array(binary_point_line_properties[7])
360
- track_length = read_next_bytes(
361
- fid, num_bytes=8, format_char_sequence="Q")[0]
362
- track_elems = read_next_bytes(
363
- fid, num_bytes=8*track_length,
364
- format_char_sequence="ii"*track_length)
365
- image_ids = np.array(tuple(map(int, track_elems[0::2])))
366
- point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
367
- points3D[point3D_id] = Point3D(
368
- id=point3D_id, xyz=xyz, rgb=rgb,
369
- error=error, image_ids=image_ids,
370
- point2D_idxs=point2D_idxs)
371
- if path_to_model_file is not None:
372
- fid.close()
373
- return points3D
374
-
375
-
376
- def write_points3D_text(points3D, path):
377
- """
378
- see: src/base/reconstruction.cc
379
- void Reconstruction::ReadPoints3DText(const std::string& path)
380
- void Reconstruction::WritePoints3DText(const std::string& path)
381
- """
382
- if len(points3D) == 0:
383
- mean_track_length = 0
384
- else:
385
- mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D)
386
- HEADER = "# 3D point list with one line of data per point:\n" + \
387
- "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + \
388
- "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length)
389
-
390
- with open(path, "w") as fid:
391
- fid.write(HEADER)
392
- for _, pt in points3D.items():
393
- point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
394
- fid.write(" ".join(map(str, point_header)) + " ")
395
- track_strings = []
396
- for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
397
- track_strings.append(" ".join(map(str, [image_id, point2D])))
398
- fid.write(" ".join(track_strings) + "\n")
399
-
400
-
401
- def write_points3D_binary(points3D, path_to_model_file):
402
- """
403
- see: src/base/reconstruction.cc
404
- void Reconstruction::ReadPoints3DBinary(const std::string& path)
405
- void Reconstruction::WritePoints3DBinary(const std::string& path)
406
- """
407
- with open(path_to_model_file, "wb") as fid:
408
- write_next_bytes(fid, len(points3D), "Q")
409
- for _, pt in points3D.items():
410
- write_next_bytes(fid, pt.id, "Q")
411
- write_next_bytes(fid, pt.xyz.tolist(), "ddd")
412
- write_next_bytes(fid, pt.rgb.tolist(), "BBB")
413
- write_next_bytes(fid, pt.error, "d")
414
- track_length = pt.image_ids.shape[0]
415
- write_next_bytes(fid, track_length, "Q")
416
- for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
417
- write_next_bytes(fid, [image_id, point2D_id], "ii")
418
-
419
-
420
- def detect_model_format(path, ext):
421
- if os.path.isfile(os.path.join(path, "cameras" + ext)) and \
422
- os.path.isfile(os.path.join(path, "images" + ext)) and \
423
- os.path.isfile(os.path.join(path, "points3D" + ext)):
424
- print("Detected model format: '" + ext + "'")
425
- return True
426
-
427
- return False
428
-
429
-
430
- def read_model(path, ext=""):
431
- # try to detect the extension automatically
432
- if ext == "":
433
- if detect_model_format(path, ".bin"):
434
- ext = ".bin"
435
- elif detect_model_format(path, ".txt"):
436
- ext = ".txt"
437
- else:
438
- print("Provide model format: '.bin' or '.txt'")
439
- return
440
-
441
- if ext == ".txt":
442
- cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
443
- images = read_images_text(os.path.join(path, "images" + ext))
444
- points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
445
- else:
446
- cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
447
- images = read_images_binary(os.path.join(path, "images" + ext))
448
- points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
449
- return cameras, images, points3D
450
-
451
-
452
- def write_model(cameras, images, points3D, path, ext=".bin"):
453
- if ext == ".txt":
454
- write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
455
- write_images_text(images, os.path.join(path, "images" + ext))
456
- write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
457
- else:
458
- write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
459
- write_images_binary(images, os.path.join(path, "images" + ext))
460
- write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
461
- return cameras, images, points3D
462
-
463
-
464
- def qvec2rotmat(qvec):
465
- return np.array([
466
- [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
467
- 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
468
- 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
469
- [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
470
- 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
471
- 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
472
- [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
473
- 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
474
- 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
475
-
476
-
477
- def rotmat2qvec(R):
478
- Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
479
- K = np.array([
480
- [Rxx - Ryy - Rzz, 0, 0, 0],
481
- [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
482
- [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
483
- [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
484
- eigvals, eigvecs = np.linalg.eigh(K)
485
- qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
486
- if qvec[0] < 0:
487
- qvec *= -1
488
- return qvec
489
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
script.py CHANGED
@@ -10,34 +10,55 @@
10
  ### You can use any additional files and subdirectories to organize your code.
11
 
12
  '''---compulsory---'''
13
- import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
14
- import subprocess
15
- import importlib
16
- from pathlib import Path
17
- import subprocess
18
-
19
-
20
- ### The function below is useful for installing additional python wheels.
21
- def install_package_from_local_file(package_name, folder='packages'):
22
- """
23
- Installs a package from a local .whl file or a directory containing .whl files using pip.
 
 
 
 
 
 
 
 
 
24
 
25
- Parameters:
26
- path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
27
- """
28
- try:
29
- pth = str(Path(folder) / package_name)
30
- subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
31
- "--no-index", # Do not use package index
32
- "--find-links", pth, # Look for packages in the specified directory or at the file
33
- package_name]) # Specify the package to install
34
- print(f"Package installed successfully from {pth}")
35
- except subprocess.CalledProcessError as e:
36
- print(f"Failed to install package from {pth}. Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
40
- install_package_from_local_file('webdataset')
41
  # install_package_from_local_file('tqdm')
42
 
43
  ### Here you can import any library or module you want.
@@ -52,34 +73,11 @@ from transformers import AutoTokenizer
52
  import os
53
  import time
54
  import io
55
- from read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
56
  from PIL import Image as PImage
57
  import numpy as np
58
 
59
-
60
- def proc(row, split='train'):
61
- out = {}
62
- for k, v in row.items():
63
- colname = k.split('.')[0]
64
- if colname in {'ade20k', 'depthcm', 'gestalt'}:
65
- if colname in out:
66
- out[colname].append(v)
67
- else:
68
- out[colname] = [v]
69
- elif colname in {'wireframe', 'mesh'}:
70
- # out.update({a: b.tolist() for a,b in v.items()})
71
- out.update({a: b for a,b in v.items()})
72
- elif colname in 'kr':
73
- out[colname.upper()] = v
74
- else:
75
- out[colname] = v
76
-
77
- return Sample(out)
78
-
79
-
80
- class Sample(Dict):
81
- def __repr__(self):
82
- return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
83
 
84
  def convert_entry_to_human_readable(entry):
85
  out = {}
@@ -102,44 +100,47 @@ def convert_entry_to_human_readable(entry):
102
 
103
  '''---end of compulsory---'''
104
 
105
- def download_package(package_name, path_to_save='packages'):
 
 
 
106
  """
107
- Downloads a package using pip and saves it to a specified directory.
108
 
109
  Parameters:
110
- package_name (str): The name of the package to download.
111
- path_to_save (str): The path to the directory where the package will be saved.
112
  """
113
- try:
114
- # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
115
- subprocess.check_call([subprocess.sys.executable, "-m", "pip", "download", package_name,
116
- "-d", str(Path(path_to_save)/package_name), # Download the package to the specified directory
117
- "--platform", "manylinux1_x86_64", # Specify the platform
118
- "--python-version", "38", # Specify the Python version
119
- "--only-binary=:all:"]) # Download only binary packages
120
- print(f'Package "{package_name}" downloaded successfully')
121
- except subprocess.CalledProcessError as e:
122
- print(f'Failed to downloaded package "{package_name}". Error: {e}')
123
-
124
-
125
- ### The part below is used to define and test your solution.
126
 
127
  if __name__ == "__main__":
128
  from handcrafted_solution import predict
129
  print ("------------ Loading dataset------------ ")
130
  params = hoho.get_params()
131
  dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
 
132
  print('------------ Now you can do your solution ---------------')
133
  solution = []
134
- for i, sample in enumerate(tqdm(dataset)):
135
- pred_vertices, pred_edges, semantics = predict(sample, visualize=False)
136
- solution.append({
137
- '__key__': sample['__key__'],
138
- 'wf_vertices': pred_vertices.tolist(),
139
- 'wf_edges': pred_edges,
140
- 'edge_semantics': semantics,
141
- })
 
 
 
 
 
 
 
 
 
 
142
  print('------------ Saving results ---------------')
143
- sub = pd.DataFrame(solution, columns=["__key__", "wf_vertices", "wf_edges", "edge_semantics"])
144
- sub.to_parquet(Path(params['output_path']) / "submission.parquet")
145
  print("------------ Done ------------ ")
 
10
  ### You can use any additional files and subdirectories to organize your code.
11
 
12
  '''---compulsory---'''
13
+ # import subprocess
14
+ # from pathlib import Path
15
+ # def install_package_from_local_file(package_name, folder='packages'):
16
+ # """
17
+ # Installs a package from a local .whl file or a directory containing .whl files using pip.
18
+
19
+ # Parameters:
20
+ # path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
21
+ # """
22
+ # try:
23
+ # pth = str(Path(folder) / package_name)
24
+ # subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
25
+ # "--no-index", # Do not use package index
26
+ # "--find-links", pth, # Look for packages in the specified directory or at the file
27
+ # package_name]) # Specify the package to install
28
+ # print(f"Package installed successfully from {pth}")
29
+ # except subprocess.CalledProcessError as e:
30
+ # print(f"Failed to install package from {pth}. Error: {e}")
31
+
32
+ # install_package_from_local_file('hoho')
33
 
34
+ import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
35
+ # import subprocess
36
+ # import importlib
37
+ # from pathlib import Path
38
+ # import subprocess
39
+
40
+
41
+ # ### The function below is useful for installing additional python wheels.
42
+ # def install_package_from_local_file(package_name, folder='packages'):
43
+ # """
44
+ # Installs a package from a local .whl file or a directory containing .whl files using pip.
45
+
46
+ # Parameters:
47
+ # path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
48
+ # """
49
+ # try:
50
+ # pth = str(Path(folder) / package_name)
51
+ # subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
52
+ # "--no-index", # Do not use package index
53
+ # "--find-links", pth, # Look for packages in the specified directory or at the file
54
+ # package_name]) # Specify the package to install
55
+ # print(f"Package installed successfully from {pth}")
56
+ # except subprocess.CalledProcessError as e:
57
+ # print(f"Failed to install package from {pth}. Error: {e}")
58
 
59
 
60
  # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
61
+ # install_package_from_local_file('webdataset')
62
  # install_package_from_local_file('tqdm')
63
 
64
  ### Here you can import any library or module you want.
 
73
  import os
74
  import time
75
  import io
 
76
  from PIL import Image as PImage
77
  import numpy as np
78
 
79
+ from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
80
+ from hoho import proc, Sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  def convert_entry_to_human_readable(entry):
83
  out = {}
 
100
 
101
  '''---end of compulsory---'''
102
 
103
+ ### The part below is used to define and test your solution.
104
+
105
+ from pathlib import Path
106
+ def save_submission(submission, path):
107
  """
108
+ Saves the submission to a specified path.
109
 
110
  Parameters:
111
+ submission (List[Dict[]]): The submission to save.
112
+ path (str): The path to save the submission to.
113
  """
114
+ sub = pd.DataFrame(submission, columns=["__key__", "wf_vertices", "wf_edges", "edge_semantics"])
115
+ sub.to_parquet(path)
116
+ print(f"Submission saved to {path}")
 
 
 
 
 
 
 
 
 
 
117
 
118
  if __name__ == "__main__":
119
  from handcrafted_solution import predict
120
  print ("------------ Loading dataset------------ ")
121
  params = hoho.get_params()
122
  dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
123
+
124
  print('------------ Now you can do your solution ---------------')
125
  solution = []
126
+ from concurrent.futures import ProcessPoolExecutor
127
+ with ProcessPoolExecutor(max_workers=8) as pool:
128
+ results = []
129
+ for i, sample in enumerate(tqdm(dataset)):
130
+ results.append(pool.submit(predict, sample, visualize=False))
131
+
132
+ for i, result in enumerate(tqdm(results)):
133
+ key, pred_vertices, pred_edges, semantics = result.result()
134
+ solution.append({
135
+ '__key__': key,
136
+ 'wf_vertices': pred_vertices.tolist(),
137
+ 'wf_edges': pred_edges,
138
+ 'edge_semantics': semantics,
139
+ })
140
+ if i % 100 == 0:
141
+ # incrementally save the results in case we run out of time
142
+ print(f"Processed {i} samples")
143
+ # save_submission(solution, Path(params['output_path']) / "submission.parquet")
144
  print('------------ Saving results ---------------')
145
+ save_submission(solution, Path(params['output_path']) / "submission.parquet")
 
146
  print("------------ Done ------------ ")
viz3d.py DELETED
@@ -1,302 +0,0 @@
1
-
2
- """
3
- Copyright [2022] [Paul-Edouard Sarlin and Philipp Lindenberger]
4
-
5
- Licensed under the Apache License, Version 2.0 (the "License");
6
- you may not use this file except in compliance with the License.
7
- You may obtain a copy of the License at
8
-
9
- http://www.apache.org/licenses/LICENSE-2.0
10
-
11
- Unless required by applicable law or agreed to in writing, software
12
- distributed under the License is distributed on an "AS IS" BASIS,
13
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- See the License for the specific language governing permissions and
15
- limitations under the License.
16
-
17
- 3D visualization based on plotly.
18
- Works for a small number of points and cameras, might be slow otherwise.
19
-
20
- 1) Initialize a figure with `init_figure`
21
- 2) Add 3D points, camera frustums, or both as a pycolmap.Reconstruction
22
-
23
- Written by Paul-Edouard Sarlin and Philipp Lindenberger.
24
- """
25
- # Slightly modified by Dmytro Mishkin
26
-
27
- from typing import Optional
28
- import numpy as np
29
- import pycolmap
30
- import plotly.graph_objects as go
31
-
32
-
33
- ### Some helper functions for geometry
34
- def qvec2rotmat(qvec):
35
- return np.array([
36
- [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
37
- 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
38
- 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
39
- [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
40
- 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
41
- 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
42
- [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
43
- 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
44
- 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
45
-
46
-
47
- def to_homogeneous(points):
48
- pad = np.ones((points.shape[:-1]+(1,)), dtype=points.dtype)
49
- return np.concatenate([points, pad], axis=-1)
50
-
51
- def t_to_proj_center(qvec, tvec):
52
- Rr = qvec2rotmat(qvec)
53
- tt = (-Rr.T) @ tvec
54
- return tt
55
-
56
- def calib(params):
57
- out = np.eye(3)
58
- if len(params) == 3:
59
- out[0,0] = params[0]
60
- out[1,1] = params[0]
61
- out[0,2] = params[1]
62
- out[1,2] = params[2]
63
- else:
64
- out[0,0] = params[0]
65
- out[1,1] = params[1]
66
- out[0,2] = params[2]
67
- out[1,2] = params[3]
68
- return out
69
-
70
-
71
- ### Plotting functions
72
-
73
- def init_figure(height: int = 800) -> go.Figure:
74
- """Initialize a 3D figure."""
75
- fig = go.Figure()
76
- axes = dict(
77
- visible=False,
78
- showbackground=False,
79
- showgrid=False,
80
- showline=False,
81
- showticklabels=True,
82
- autorange=True,
83
- )
84
- fig.update_layout(
85
- template="plotly_dark",
86
- height=height,
87
- scene_camera=dict(
88
- eye=dict(x=0., y=-.1, z=-2),
89
- up=dict(x=0, y=-1., z=0),
90
- projection=dict(type="orthographic")),
91
- scene=dict(
92
- xaxis=axes,
93
- yaxis=axes,
94
- zaxis=axes,
95
- aspectmode='data',
96
- dragmode='orbit',
97
- ),
98
- margin=dict(l=0, r=0, b=0, t=0, pad=0),
99
- legend=dict(
100
- orientation="h",
101
- yanchor="top",
102
- y=0.99,
103
- xanchor="left",
104
- x=0.1
105
- ),
106
- )
107
- return fig
108
-
109
-
110
- def plot_lines_3d(
111
- fig: go.Figure,
112
- pts: np.ndarray,
113
- color: str = 'rgba(255, 255, 255, 1)',
114
- ps: int = 2,
115
- colorscale: Optional[str] = None,
116
- name: Optional[str] = None):
117
- """Plot a set of 3D points."""
118
- x = pts[..., 0]
119
- y = pts[..., 1]
120
- z = pts[..., 2]
121
- traces = [go.Scatter3d(x=x1, y=y1, z=z1,
122
- mode='lines',
123
- line=dict(color=color, width=2)) for x1, y1, z1 in zip(x,y,z)]
124
- for t in traces:
125
- fig.add_trace(t)
126
- fig.update_traces(showlegend=False)
127
-
128
-
129
- def plot_points(
130
- fig: go.Figure,
131
- pts: np.ndarray,
132
- color: str = 'rgba(255, 0, 0, 1)',
133
- ps: int = 2,
134
- colorscale: Optional[str] = None,
135
- name: Optional[str] = None):
136
- """Plot a set of 3D points."""
137
- x, y, z = pts.T
138
- tr = go.Scatter3d(
139
- x=x, y=y, z=z, mode='markers', name=name, legendgroup=name,
140
- marker=dict(
141
- size=ps, color=color, line_width=0.0, colorscale=colorscale))
142
- fig.add_trace(tr)
143
-
144
- def plot_camera(
145
- fig: go.Figure,
146
- R: np.ndarray,
147
- t: np.ndarray,
148
- K: np.ndarray,
149
- color: str = 'rgb(0, 0, 255)',
150
- name: Optional[str] = None,
151
- legendgroup: Optional[str] = None,
152
- size: float = 1.0):
153
- """Plot a camera frustum from pose and intrinsic matrix."""
154
- W, H = K[0, 2]*2, K[1, 2]*2
155
- corners = np.array([[0, 0], [W, 0], [W, H], [0, H], [0, 0]])
156
- if size is not None:
157
- image_extent = max(size * W / 1024.0, size * H / 1024.0)
158
- world_extent = max(W, H) / (K[0, 0] + K[1, 1]) / 0.5
159
- scale = 0.5 * image_extent / world_extent
160
- else:
161
- scale = 1.0
162
- corners = to_homogeneous(corners) @ np.linalg.inv(K).T
163
- corners = (corners / 2 * scale) @ R.T + t
164
-
165
- x, y, z = corners.T
166
- rect = go.Scatter3d(
167
- x=x, y=y, z=z, line=dict(color=color), legendgroup=legendgroup,
168
- name=name, marker=dict(size=0.0001), showlegend=False)
169
- fig.add_trace(rect)
170
-
171
- x, y, z = np.concatenate(([t], corners)).T
172
- i = [0, 0, 0, 0]
173
- j = [1, 2, 3, 4]
174
- k = [2, 3, 4, 1]
175
-
176
- pyramid = go.Mesh3d(
177
- x=x, y=y, z=z, color=color, i=i, j=j, k=k,
178
- legendgroup=legendgroup, name=name, showlegend=False)
179
- fig.add_trace(pyramid)
180
- triangles = np.vstack((i, j, k)).T
181
- vertices = np.concatenate(([t], corners))
182
- tri_points = np.array([
183
- vertices[i] for i in triangles.reshape(-1)
184
- ])
185
-
186
- x, y, z = tri_points.T
187
- pyramid = go.Scatter3d(
188
- x=x, y=y, z=z, mode='lines', legendgroup=legendgroup,
189
- name=name, line=dict(color=color, width=1), showlegend=False)
190
- fig.add_trace(pyramid)
191
-
192
-
193
- def plot_camera_colmap(
194
- fig: go.Figure,
195
- image: pycolmap.Image,
196
- camera: pycolmap.Camera,
197
- name: Optional[str] = None,
198
- **kwargs):
199
- """Plot a camera frustum from PyCOLMAP objects"""
200
- intr = calib(camera.params)
201
- if intr[0][0] > 10000:
202
- print("Bad camera")
203
- return
204
- plot_camera(
205
- fig,
206
- qvec2rotmat(image.qvec).T,
207
- t_to_proj_center(image.qvec, image.tvec),
208
- intr,#calibration_matrix(),
209
- name=name or str(image.id),
210
- **kwargs)
211
-
212
-
213
- def plot_cameras(
214
- fig: go.Figure,
215
- reconstruction,#: pycolmap.Reconstruction,
216
- **kwargs):
217
- """Plot a camera as a cone with camera frustum."""
218
- for image_id, image in reconstruction["images"].items():
219
- plot_camera_colmap(
220
- fig, image, reconstruction["cameras"][image.camera_id], **kwargs)
221
-
222
-
223
- def plot_reconstruction(
224
- fig: go.Figure,
225
- rec,
226
- color: str = 'rgb(0, 0, 255)',
227
- name: Optional[str] = None,
228
- points: bool = True,
229
- cameras: bool = True,
230
- cs: float = 1.0,
231
- single_color_points=False,
232
- camera_color='rgba(0, 255, 0, 0.5)'):
233
- # rec is result of loading reconstruction from "read_write_colmap.py"
234
- # Filter outliers
235
- xyzs = []
236
- rgbs = []
237
- for k, p3D in rec['points'].items():
238
- xyzs.append(p3D.xyz)
239
- rgbs.append(p3D.rgb)
240
-
241
- if points:
242
- plot_points(fig, np.array(xyzs), color=color if single_color_points else np.array(rgbs), ps=1, name=name)
243
- if cameras:
244
- plot_cameras(fig, rec, color=camera_color, legendgroup=name, size=cs)
245
-
246
-
247
- def plot_pointcloud(
248
- fig: go.Figure,
249
- pts: np.ndarray,
250
- colors: np.ndarray,
251
- ps: int = 2,
252
- name: Optional[str] = None):
253
- """Plot a set of 3D points."""
254
- plot_points(fig, np.array(pts), color=colors, ps=ps, name=name)
255
-
256
-
257
- def plot_triangle_mesh(
258
- fig: go.Figure,
259
- vert: np.ndarray,
260
- colors: np.ndarray,
261
- triangles: np.ndarray,
262
- name: Optional[str] = None):
263
- """Plot a triangle mesh."""
264
- tr = go.Mesh3d(
265
- x=vert[:,0],
266
- y=vert[:,1],
267
- z=vert[:,2],
268
- vertexcolor = np.clip(255*colors, 0, 255),
269
- # i, j and k give the vertices of triangles
270
- # here we represent the 4 triangles of the tetrahedron surface
271
- i=triangles[:,0],
272
- j=triangles[:,1],
273
- k=triangles[:,2],
274
- name=name,
275
- showscale=False
276
- )
277
- fig.add_trace(tr)
278
-
279
- def plot_estimate_and_gt(pred_vertices, pred_connections, gt_vertices=None, gt_connections=None):
280
- fig3d = init_figure()
281
- c1 = (30, 20, 255)
282
- img_color = [c1 for _ in range(len(pred_vertices))]
283
- plot_points(fig3d, pred_vertices, color = img_color, ps = 10)
284
- lines = []
285
- for c in pred_connections:
286
- v1 = pred_vertices[c[0]]
287
- v2 = pred_vertices[c[1]]
288
- lines.append(np.stack([v1, v2], axis=0))
289
- plot_lines_3d(fig3d, np.array(lines), img_color, ps=4)
290
- if gt_vertices is not None:
291
- c2 = (30, 255, 20)
292
- img_color2 = [c2 for _ in range(len(gt_vertices))]
293
- plot_points(fig3d, gt_vertices, color = img_color2, ps = 10)
294
- if gt_connections is not None:
295
- gt_lines = []
296
- for c in gt_connections:
297
- v1 = gt_vertices[c[0]]
298
- v2 = gt_vertices[c[1]]
299
- gt_lines.append(np.stack([v1, v2], axis=0))
300
- plot_lines_3d(fig3d, np.array(gt_lines), img_color2, ps=4)
301
- fig3d.show()
302
- return fig3d