jacklangerman committed
Commit 22f7dc4
Parent: 718a478

multiprocess

Files changed (2)
  1. handcrafted_solution.py +1 -1
  2. script.py +18 -12
handcrafted_solution.py CHANGED
@@ -239,4 +239,4 @@ def predict(entry, visualize=False) -> Tuple[np.ndarray, List[int]]:
         from viz3d import plot_estimate_and_gt
         plot_estimate_and_gt(all_3d_vertices_clean, connections_3d_clean, good_entry['wf_vertices'],
                              good_entry['wf_edges'])
-    return all_3d_vertices_clean, connections_3d_clean, [0 for i in range(len(connections_3d_clean))]
+    return good_entry['__key__'], all_3d_vertices_clean, connections_3d_clean, [0 for i in range(len(connections_3d_clean))]
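With this change predict() returns the sample's webdataset key as the first element of its tuple, so a result computed in a worker process can be matched back to its sample without the caller holding the original dict. A minimal sketch of how a caller unpacks the new 4-tuple (variable names here are illustrative, not from the repo):

    # sketch: consuming the new 4-tuple returned by predict (names are illustrative)
    key, vertices, edges, edge_semantics = predict(sample, visualize=False)
    record = {
        '__key__': key,                    # sample identifier carried back from the worker
        'wf_vertices': vertices.tolist(),  # np.ndarray -> plain lists for the parquet submission
        'wf_edges': edges,
        'edge_semantics': edge_semantics,
    }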
script.py CHANGED
@@ -143,18 +143,24 @@ if __name__ == "__main__":
     dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
     print('------------ Now you can do your solution ---------------')
     solution = []
-    for i, sample in enumerate(tqdm(dataset)):
-        pred_vertices, pred_edges, semantics = predict(sample, visualize=False)
-        solution.append({
-            '__key__': sample['__key__'],
-            'wf_vertices': pred_vertices.tolist(),
-            'wf_edges': pred_edges,
-            'edge_semantics': semantics,
-        })
-        if i % 100 == 0:
-            # incrementally save the results in case we run out of time
-            print(f"Processed {i} samples")
-            # save_submission(solution, Path(params['output_path']) / "submission.parquet")
+    from concurrent.futures import ProcessPoolExecutor
+    with ProcessPoolExecutor(max_workers=8) as pool:
+        results = []
+        for i, sample in enumerate(tqdm(dataset)):
+            results.append(pool.submit(predict, sample, visualize=False))
+
+        for i, result in enumerate(tqdm(results)):
+            key, pred_vertices, pred_edges, semantics = result.result()
+            solution.append({
+                '__key__': key,
+                'wf_vertices': pred_vertices.tolist(),
+                'wf_edges': pred_edges,
+                'edge_semantics': semantics,
+            })
+            if i % 100 == 0:
+                # incrementally save the results in case we run out of time
+                print(f"Processed {i} samples")
+                # save_submission(solution, Path(params['output_path']) / "submission.parquet")
     print('------------ Saving results ---------------')
     save_submission(solution, Path(params['output_path']) / "submission.parquet")
     print("------------ Done ------------ ")