dmytromishkin commited on
Commit
420d591
1 Parent(s): b513ce3

accept w/o semantics

Browse files
Files changed (2) hide show
  1. hoho/hoho.py +11 -18
  2. hoho/vis.py +8 -1
hoho/hoho.py CHANGED
@@ -4,6 +4,14 @@ import shutil
4
  from pathlib import Path
5
  from typing import Dict
6
  import warnings
 
 
 
 
 
 
 
 
7
 
8
  from PIL import ImageFile
9
 
@@ -40,13 +48,7 @@ def setup(local_dir='./data/usm-training-data/data'):
40
  LOCAL_DATADIR.mkdir(parents=True)
41
 
42
  return LOCAL_DATADIR
43
-
44
-
45
-
46
-
47
- import importlib
48
- from pathlib import Path
49
- import subprocess
50
 
51
  def download_package(package_name, path_to_save='packages'):
52
  """
@@ -139,9 +141,7 @@ def Rt_to_eye_target(im, K, R, t):
139
 
140
 
141
  ########## general utilities ##########
142
- import contextlib
143
- import tempfile
144
- from pathlib import Path
145
 
146
  @contextlib.contextmanager
147
  def working_directory(path):
@@ -184,10 +184,6 @@ def proc(row, split='train'):
184
  return Sample(out)
185
 
186
 
187
-
188
-
189
-
190
-
191
  from . import read_write_colmap
192
  def decode_colmap(s):
193
  with temp_working_directory():
@@ -209,8 +205,7 @@ def decode_colmap(s):
209
  )
210
  return cameras, images, points3D
211
 
212
- from PIL import Image
213
- import io
214
  def decode(row):
215
  cameras, images, points3D = decode_colmap(row)
216
 
@@ -288,8 +283,6 @@ def get_params():
288
 
289
 
290
 
291
- import webdataset as wds
292
- import numpy as np
293
 
294
 
295
  SHARD_IDS = {'train': (0, 25), 'val': (25, 26), 'public': (26, 27), 'private': (27, 32)}
 
4
  from pathlib import Path
5
  from typing import Dict
6
  import warnings
7
+ import contextlib
8
+ import tempfile
9
+ from PIL import Image
10
+ import io
11
+ import webdataset as wds
12
+ import numpy as np
13
+ import importlib
14
+ import subprocess
15
 
16
  from PIL import ImageFile
17
 
 
48
  LOCAL_DATADIR.mkdir(parents=True)
49
 
50
  return LOCAL_DATADIR
51
+
 
 
 
 
 
 
52
 
53
  def download_package(package_name, path_to_save='packages'):
54
  """
 
141
 
142
 
143
  ########## general utilities ##########
144
+
 
 
145
 
146
  @contextlib.contextmanager
147
  def working_directory(path):
 
184
  return Sample(out)
185
 
186
 
 
 
 
 
187
  from . import read_write_colmap
188
  def decode_colmap(s):
189
  with temp_working_directory():
 
205
  )
206
  return cameras, images, points3D
207
 
208
+
 
209
  def decode(row):
210
  cameras, images, points3D = decode_colmap(row)
211
 
 
283
 
284
 
285
 
 
 
286
 
287
 
288
  SHARD_IDS = {'train': (0, 25), 'val': (25, 26), 'public': (26, 27), 'private': (27, 32)}
hoho/vis.py CHANGED
@@ -51,7 +51,14 @@ def show_wf(row, radius=10):
51
  'valley',
52
  'hip',
53
  'transition_line']
54
- return [line(a,b, radius=radius, c=color_mappings.gestalt_color_mapping[EDGE_CLASSES[cls_id]]) for (a,b), cls_id in zip(np.stack([*row['wf_vertices']])[np.stack(row['wf_edges'])], row['edge_semantics'])]
 
 
 
 
 
 
 
55
  # return [line(a,b, radius=radius, c=color_mappings.edge_colors[cls_id]) for (a,b), cls_id in zip(np.stack([*row['wf_vertices']])[np.stack(row['wf_edges'])], row['edge_semantics'])]
56
 
57
 
 
51
  'valley',
52
  'hip',
53
  'transition_line']
54
+ if 'edge_semantics' not in row:
55
+ print ("Warning: edge semantics is not here, skipping")
56
+ return [line(a,b, radius=radius, c=(214, 251, 248)) for a,b in np.stack([*row['wf_vertices']])[np.stack(row['wf_edges'])]]
57
+ elif len(np.stack(row['wf_edges'])) == len(row['edge_semantics']):
58
+ return [line(a,b, radius=radius, c=color_mappings.gestalt_color_mapping[EDGE_CLASSES[cls_id]]) for (a,b), cls_id in zip(np.stack([*row['wf_vertices']])[np.stack(row['wf_edges'])], row['edge_semantics'])]
59
+ else:
60
+ print ("Warning: edge semantics has different length compared to edges, skipping semantics")
61
+ return [line(a,b, radius=radius, c=(214, 251, 248)) for a,b in np.stack([*row['wf_vertices']])[np.stack(row['wf_edges'])]]
62
  # return [line(a,b, radius=radius, c=color_mappings.edge_colors[cls_id]) for (a,b), cls_id in zip(np.stack([*row['wf_vertices']])[np.stack(row['wf_edges'])], row['edge_semantics'])]
63
 
64