corrosivelogic commited on
Commit
7e94bda
1 Parent(s): 177513b

Add submission scripts

Browse files
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. gitattributes +38 -0
  3. handcrafted_solution.py +245 -0
  4. script.py +120 -33
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *.whl filter=lfs diff=lfs merge=lfs -text
36
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
37
+ packages/** filter=lfs diff=lfs merge=lfs -text
38
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
handcrafted_solution.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Description: This file contains the handcrafted solution for the task of wireframe reconstruction
2
+
3
+ import io
4
+ from PIL import Image as PImage
5
+ import numpy as np
6
+ from collections import defaultdict
7
+ import cv2
8
+ from typing import Tuple, List
9
+ from scipy.spatial.distance import cdist
10
+
11
+ from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
12
+ from hoho.color_mappings import gestalt_color_mapping, ade20k_color_mapping
13
+
14
+
15
+ def empty_solution():
16
+ '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
17
+ return np.zeros((2,3)), [(0, 1)]
18
+
19
+
20
+ def convert_entry_to_human_readable(entry):
21
+ out = {}
22
+ already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
23
+ for k, v in entry.items():
24
+ if k in already_good:
25
+ out[k] = v
26
+ continue
27
+ if k == 'points3d':
28
+ out[k] = read_points3D_binary(fid=io.BytesIO(v))
29
+ if k == 'cameras':
30
+ out[k] = read_cameras_binary(fid=io.BytesIO(v))
31
+ if k == 'images':
32
+ out[k] = read_images_binary(fid=io.BytesIO(v))
33
+ if k in ['ade20k', 'gestalt']:
34
+ out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
35
+ if k == 'depthcm':
36
+ out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
37
+ return out
38
+
39
+
40
+ def get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th = 50.0):
41
+ '''Get the vertices and edges from the gestalt segmentation mask of the house'''
42
+ vertices = []
43
+ connections = []
44
+ # Apex
45
+ apex_color = np.array(gestalt_color_mapping['apex'])
46
+ apex_mask = cv2.inRange(gest_seg_np, apex_color-0.5, apex_color+0.5)
47
+ if apex_mask.sum() > 0:
48
+ output = cv2.connectedComponentsWithStats(apex_mask, 8, cv2.CV_32S)
49
+ (numLabels, labels, stats, centroids) = output
50
+ stats, centroids = stats[1:], centroids[1:]
51
+
52
+ for i in range(numLabels-1):
53
+ vert = {"xy": centroids[i], "type": "apex"}
54
+ vertices.append(vert)
55
+
56
+ eave_end_color = np.array(gestalt_color_mapping['eave_end_point'])
57
+ eave_end_mask = cv2.inRange(gest_seg_np, eave_end_color-0.5, eave_end_color+0.5)
58
+ if eave_end_mask.sum() > 0:
59
+ output = cv2.connectedComponentsWithStats(eave_end_mask, 8, cv2.CV_32S)
60
+ (numLabels, labels, stats, centroids) = output
61
+ stats, centroids = stats[1:], centroids[1:]
62
+
63
+ for i in range(numLabels-1):
64
+ vert = {"xy": centroids[i], "type": "eave_end_point"}
65
+ vertices.append(vert)
66
+ # Connectivity
67
+ apex_pts = []
68
+ apex_pts_idxs = []
69
+ for j, v in enumerate(vertices):
70
+ apex_pts.append(v['xy'])
71
+ apex_pts_idxs.append(j)
72
+ apex_pts = np.array(apex_pts)
73
+
74
+ # Ridge connects two apex points
75
+ for edge_class in ['eave', 'ridge', 'rake', 'valley']:
76
+ edge_color = np.array(gestalt_color_mapping[edge_class])
77
+ mask = cv2.morphologyEx(cv2.inRange(gest_seg_np,
78
+ edge_color-0.5,
79
+ edge_color+0.5),
80
+ cv2.MORPH_DILATE, np.ones((11, 11)))
81
+ line_img = np.copy(gest_seg_np) * 0
82
+ if mask.sum() > 0:
83
+ output = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
84
+ (numLabels, labels, stats, centroids) = output
85
+ stats, centroids = stats[1:], centroids[1:]
86
+ edges = []
87
+ for i in range(1, numLabels):
88
+ y,x = np.where(labels == i)
89
+ xleft_idx = np.argmin(x)
90
+ x_left = x[xleft_idx]
91
+ y_left = y[xleft_idx]
92
+ xright_idx = np.argmax(x)
93
+ x_right = x[xright_idx]
94
+ y_right = y[xright_idx]
95
+ edges.append((x_left, y_left, x_right, y_right))
96
+ cv2.line(line_img, (x_left, y_left), (x_right, y_right), (255, 255, 255), 2)
97
+ edges = np.array(edges)
98
+ if (len(apex_pts) < 2) or len(edges) <1:
99
+ continue
100
+ pts_to_edges_dist = np.minimum(cdist(apex_pts, edges[:,:2]), cdist(apex_pts, edges[:,2:]))
101
+ connectivity_mask = pts_to_edges_dist <= edge_th
102
+ edge_connects = connectivity_mask.sum(axis=0)
103
+ for edge_idx, edgesum in enumerate(edge_connects):
104
+ if edgesum>=2:
105
+ connected_verts = np.where(connectivity_mask[:,edge_idx])[0]
106
+ for a_i, a in enumerate(connected_verts):
107
+ for b in connected_verts[a_i+1:]:
108
+ connections.append((a, b))
109
+ return vertices, connections
110
+
111
+ def get_uv_depth(vertices, depth):
112
+ '''Get the depth of the vertices from the depth image'''
113
+ uv = []
114
+ for v in vertices:
115
+ uv.append(v['xy'])
116
+ uv = np.array(uv)
117
+ uv_int = uv.astype(np.int32)
118
+ H, W = depth.shape[:2]
119
+ uv_int[:, 0] = np.clip( uv_int[:, 0], 0, W-1)
120
+ uv_int[:, 1] = np.clip( uv_int[:, 1], 0, H-1)
121
+ vertex_depth = depth[(uv_int[:, 1] , uv_int[:, 0])]
122
+ return uv, vertex_depth
123
+
124
+
125
+ def merge_vertices_3d(vert_edge_per_image, th=0.1):
126
+ '''Merge vertices that are close to each other in 3D space and are of same types'''
127
+ all_3d_vertices = []
128
+ connections_3d = []
129
+ all_indexes = []
130
+ cur_start = 0
131
+ types = []
132
+ for cimg_idx, (vertices, connections, vertices_3d) in vert_edge_per_image.items():
133
+ types += [int(v['type']=='apex') for v in vertices]
134
+ all_3d_vertices.append(vertices_3d)
135
+ connections_3d+=[(x+cur_start,y+cur_start) for (x,y) in connections]
136
+ cur_start+=len(vertices_3d)
137
+ all_3d_vertices = np.concatenate(all_3d_vertices, axis=0)
138
+ #print (connections_3d)
139
+ distmat = cdist(all_3d_vertices, all_3d_vertices)
140
+ types = np.array(types).reshape(-1,1)
141
+ same_types = cdist(types, types)
142
+ mask_to_merge = (distmat <= th) & (same_types==0)
143
+ new_vertices = []
144
+ new_connections = []
145
+ to_merge = sorted(list(set([tuple(a.nonzero()[0].tolist()) for a in mask_to_merge])))
146
+ to_merge_final = defaultdict(list)
147
+ for i in range(len(all_3d_vertices)):
148
+ for j in to_merge:
149
+ if i in j:
150
+ to_merge_final[i]+=j
151
+ for k, v in to_merge_final.items():
152
+ to_merge_final[k] = list(set(v))
153
+ already_there = set()
154
+ merged = []
155
+ for k, v in to_merge_final.items():
156
+ if k in already_there:
157
+ continue
158
+ merged.append(v)
159
+ for vv in v:
160
+ already_there.add(vv)
161
+ old_idx_to_new = {}
162
+ count=0
163
+ for idxs in merged:
164
+ new_vertices.append(all_3d_vertices[idxs].mean(axis=0))
165
+ for idx in idxs:
166
+ old_idx_to_new[idx] = count
167
+ count +=1
168
+ #print (connections_3d)
169
+ new_vertices=np.array(new_vertices)
170
+ #print (connections_3d)
171
+ for conn in connections_3d:
172
+ new_con = sorted((old_idx_to_new[conn[0]], old_idx_to_new[conn[1]]))
173
+ if new_con[0] == new_con[1]:
174
+ continue
175
+ if new_con not in new_connections:
176
+ new_connections.append(new_con)
177
+ #print (f'{len(new_vertices)} left after merging {len(all_3d_vertices)} with {th=}')
178
+ return new_vertices, new_connections
179
+
180
+ def prune_not_connected(all_3d_vertices, connections_3d):
181
+ '''Prune vertices that are not connected to any other vertex'''
182
+ connected = defaultdict(list)
183
+ for c in connections_3d:
184
+ connected[c[0]].append(c)
185
+ connected[c[1]].append(c)
186
+ new_indexes = {}
187
+ new_verts = []
188
+ connected_out = []
189
+ for k,v in connected.items():
190
+ vert = all_3d_vertices[k]
191
+ if tuple(vert) not in new_verts:
192
+ new_verts.append(tuple(vert))
193
+ new_indexes[k]=len(new_verts) -1
194
+ for k,v in connected.items():
195
+ for vv in v:
196
+ connected_out.append((new_indexes[vv[0]],new_indexes[vv[1]]))
197
+ connected_out=list(set(connected_out))
198
+
199
+ return np.array(new_verts), connected_out
200
+
201
+
202
+ def predict(entry, visualize=False) -> Tuple[np.ndarray, List[int]]:
203
+ good_entry = convert_entry_to_human_readable(entry)
204
+ vert_edge_per_image = {}
205
+ for i, (gest, depth, K, R, t) in enumerate(zip(good_entry['gestalt'],
206
+ good_entry['depthcm'],
207
+ good_entry['K'],
208
+ good_entry['R'],
209
+ good_entry['t']
210
+ )):
211
+ gest_seg = gest.resize(depth.size)
212
+ gest_seg_np = np.array(gest_seg).astype(np.uint8)
213
+ # Metric3D
214
+ depth_np = np.array(depth) / 2.5 # 2.5 is the scale estimation coefficient
215
+ vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th = 20.)
216
+ if (len(vertices) < 2) or (len(connections) < 1):
217
+ print (f'Not enough vertices or connections in image {i}')
218
+ vert_edge_per_image[i] = np.empty((0, 2)), [], np.empty((0, 3))
219
+ continue
220
+ uv, depth_vert = get_uv_depth(vertices, depth_np)
221
+ # Normalize the uv to the camera intrinsics
222
+ xy_local = np.ones((len(uv), 3))
223
+ xy_local[:, 0] = (uv[:, 0] - K[0,2]) / K[0,0]
224
+ xy_local[:, 1] = (uv[:, 1] - K[1,2]) / K[1,1]
225
+ # Get the 3D vertices
226
+ vertices_3d_local = depth_vert[...,None] * (xy_local/np.linalg.norm(xy_local, axis=1)[...,None])
227
+ world_to_cam = np.eye(4)
228
+ world_to_cam[:3, :3] = R
229
+ world_to_cam[:3, 3] = t.reshape(-1)
230
+ cam_to_world = np.linalg.inv(world_to_cam)
231
+ vertices_3d = cv2.transform(cv2.convertPointsToHomogeneous(vertices_3d_local), cam_to_world)
232
+ vertices_3d = cv2.convertPointsFromHomogeneous(vertices_3d).reshape(-1, 3)
233
+ vert_edge_per_image[i] = vertices, connections, vertices_3d
234
+ all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 3.0)
235
+ all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d)
236
+ if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
237
+ print (f'Not enough vertices or connections in the 3D vertices')
238
+ return (good_entry['__key__'], *empty_solution())
239
+ if visualize:
240
+ from hoho.viz3d import plot_estimate_and_gt
241
+ plot_estimate_and_gt( all_3d_vertices_clean,
242
+ connections_3d_clean,
243
+ good_entry['wf_vertices'],
244
+ good_entry['wf_edges'])
245
+ return good_entry['__key__'], all_3d_vertices_clean, connections_3d_clean
script.py CHANGED
@@ -4,55 +4,142 @@
4
  ### You can change the rest of the code to define and test your solution.
5
  ### However, you should not change the signature of the provided function.
6
  ### The script would save "submission.parquet" file in the current directory.
 
 
 
7
  ### You can use any additional files and subdirectories to organize your code.
8
 
9
  '''---compulsory---'''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
11
- '''---compulsory---'''
 
 
 
12
 
13
- from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from tqdm import tqdm
 
15
  import pandas as pd
 
 
 
 
 
16
  import numpy as np
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- def empty_solution(sample):
20
- '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
21
- return np.zeros((2,3)), [(0, 1)]
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  if __name__ == "__main__":
 
25
  print ("------------ Loading dataset------------ ")
26
  params = hoho.get_params()
27
-
28
- # by default it is usually better to use `get_dataset()` like this
29
- #
30
- # dataset = hoho.get_dataset(split='all')
31
- #
32
- # but in this case (because we don't do anything with the sample
33
- # anyway) we set `decode=None`. We can set the `split` argument
34
- # to 'train' or 'val' ('all' defaults back to 'train') if we are
35
- # testing ourselves locally.
36
- #
37
- # dataset = hoho.get_dataset(split='val', decode=None)
38
- #
39
- # On the test server *`split` must be set to 'all'*
40
- # to compute both the public and private leaderboards.
41
- #
42
- dataset = hoho.get_dataset(split='all', decode=None)
43
-
44
  print('------------ Now you can do your solution ---------------')
45
  solution = []
46
- for i, sample in enumerate(tqdm(dataset)):
47
- # replace this with your solution
48
- pred_vertices, pred_edges = empty_solution(sample)
 
 
49
 
50
- solution.append({
51
- '__key__': sample['__key__'],
52
- 'wf_vertices': pred_vertices.tolist(),
53
- 'wf_edges': pred_edges
54
- })
 
 
 
 
 
 
55
  print('------------ Saving results ---------------')
56
- sub = pd.DataFrame(solution, columns=["__key__", "wf_vertices", "wf_edges"])
57
- sub.to_parquet(Path(params['output_path']) / "submission.parquet")
58
- print("------------ Done ------------ ")
 
4
  ### You can change the rest of the code to define and test your solution.
5
  ### However, you should not change the signature of the provided function.
6
  ### The script would save "submission.parquet" file in the current directory.
7
+ ### The actual logic of the solution is implemented in the `handcrafted_solution.py` file.
8
+ ### The `handcrafted_solution.py` file is a placeholder for your solution.
9
+ ### You should implement the logic of your solution in that file.
10
  ### You can use any additional files and subdirectories to organize your code.
11
 
12
  '''---compulsory---'''
13
+ # import subprocess
14
+ # from pathlib import Path
15
+ # def install_package_from_local_file(package_name, folder='packages'):
16
+ # """
17
+ # Installs a package from a local .whl file or a directory containing .whl files using pip.
18
+
19
+ # Parameters:
20
+ # path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
21
+ # """
22
+ # try:
23
+ # pth = str(Path(folder) / package_name)
24
+ # subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
25
+ # "--no-index", # Do not use package index
26
+ # "--find-links", pth, # Look for packages in the specified directory or at the file
27
+ # package_name]) # Specify the package to install
28
+ # print(f"Package installed successfully from {pth}")
29
+ # except subprocess.CalledProcessError as e:
30
+ # print(f"Failed to install package from {pth}. Error: {e}")
31
+
32
+ # install_package_from_local_file('hoho')
33
+
34
  import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
35
+ # import subprocess
36
+ # import importlib
37
+ # from pathlib import Path
38
+ # import subprocess
39
 
40
+
41
+ # ### The function below is useful for installing additional python wheels.
42
+ # def install_package_from_local_file(package_name, folder='packages'):
43
+ # """
44
+ # Installs a package from a local .whl file or a directory containing .whl files using pip.
45
+
46
+ # Parameters:
47
+ # path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
48
+ # """
49
+ # try:
50
+ # pth = str(Path(folder) / package_name)
51
+ # subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
52
+ # "--no-index", # Do not use package index
53
+ # "--find-links", pth, # Look for packages in the specified directory or at the file
54
+ # package_name]) # Specify the package to install
55
+ # print(f"Package installed successfully from {pth}")
56
+ # except subprocess.CalledProcessError as e:
57
+ # print(f"Failed to install package from {pth}. Error: {e}")
58
+
59
+
60
+ # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
61
+ # install_package_from_local_file('webdataset')
62
+ # install_package_from_local_file('tqdm')
63
+
64
+ ### Here you can import any library or module you want.
65
+ ### The code below is used to read and parse the input dataset.
66
+ ### Please, do not modify it.
67
+
68
+ import webdataset as wds
69
  from tqdm import tqdm
70
+ from typing import Dict
71
  import pandas as pd
72
+ from transformers import AutoTokenizer
73
+ import os
74
+ import time
75
+ import io
76
+ from PIL import Image as PImage
77
  import numpy as np
78
 
79
+ from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
80
+ from hoho import proc, Sample
81
+
82
+ def convert_entry_to_human_readable(entry):
83
+ out = {}
84
+ already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
85
+ for k, v in entry.items():
86
+ if k in already_good:
87
+ out[k] = v
88
+ continue
89
+ if k == 'points3d':
90
+ out[k] = read_points3D_binary(fid=io.BytesIO(v))
91
+ if k == 'cameras':
92
+ out[k] = read_cameras_binary(fid=io.BytesIO(v))
93
+ if k == 'images':
94
+ out[k] = read_images_binary(fid=io.BytesIO(v))
95
+ if k in ['ade20k', 'gestalt']:
96
+ out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
97
+ if k == 'depthcm':
98
+ out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
99
+ return out
100
 
101
+ '''---end of compulsory---'''
 
 
102
 
103
+ ### The part below is used to define and test your solution.
104
+
105
+ from pathlib import Path
106
+ def save_submission(submission, path):
107
+ """
108
+ Saves the submission to a specified path.
109
+
110
+ Parameters:
111
+ submission (List[Dict[]]): The submission to save.
112
+ path (str): The path to save the submission to.
113
+ """
114
+ sub = pd.DataFrame(submission, columns=["__key__", "wf_vertices", "wf_edges"])
115
+ sub.to_parquet(path)
116
+ print(f"Submission saved to {path}")
117
 
118
  if __name__ == "__main__":
119
+ from handcrafted_solution import predict
120
  print ("------------ Loading dataset------------ ")
121
  params = hoho.get_params()
122
+ dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
123
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  print('------------ Now you can do your solution ---------------')
125
  solution = []
126
+ from concurrent.futures import ProcessPoolExecutor
127
+ with ProcessPoolExecutor(max_workers=8) as pool:
128
+ results = []
129
+ for i, sample in enumerate(tqdm(dataset)):
130
+ results.append(pool.submit(predict, sample, visualize=False))
131
 
132
+ for i, result in enumerate(tqdm(results)):
133
+ key, pred_vertices, pred_edges = result.result()
134
+ solution.append({
135
+ '__key__': key,
136
+ 'wf_vertices': pred_vertices.tolist(),
137
+ 'wf_edges': pred_edges
138
+ })
139
+ if i % 100 == 0:
140
+ # incrementally save the results in case we run out of time
141
+ print(f"Processed {i} samples")
142
+ # save_submission(solution, Path(params['output_path']) / "submission.parquet")
143
  print('------------ Saving results ---------------')
144
+ save_submission(solution, Path(params['output_path']) / "submission.parquet")
145
+ print("------------ Done ------------ ")