jacklangerman committed
Commit
c77b687
1 Parent(s): 7c466d7

update solution

Files changed (2)
  1. hoho.py +0 -247
  2. script.py +22 -4
hoho.py DELETED
@@ -1,247 +0,0 @@
-import os
-import json
-import shutil
-from pathlib import Path
-from typing import Dict
-
-from PIL import ImageFile
-ImageFile.LOAD_TRUNCATED_IMAGES = True
-
-LOCAL_DATADIR = None
-
-def setup(local_dir='./data/usm-training-data/data'):
-
-    # If we are in the test environment, we need to link the data directory to the correct location
-    tmp_datadir = Path('/tmp/data/data')
-    local_test_datadir = Path('./data/usm-test-data-x/data')
-    local_val_datadir = Path(local_dir)
-
-    os.system('pwd')
-    os.system('ls -lahtr .')
-
-    if tmp_datadir.exists() and not local_test_datadir.exists():
-        global LOCAL_DATADIR
-        LOCAL_DATADIR = local_test_datadir
-        # shutil.move(datadir, './usm-test-data-x/data')
-        print(f"Linking {tmp_datadir} to {LOCAL_DATADIR} (we are in the test environment)")
-        LOCAL_DATADIR.parent.mkdir(parents=True, exist_ok=True)
-        LOCAL_DATADIR.symlink_to(tmp_datadir)
-    else:
-        LOCAL_DATADIR = local_val_datadir
-        print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
-
-    # os.system("ls -lahtr")
-
-    assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist"
-    return LOCAL_DATADIR
-
-
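A note on the deletion: script.py (below) still runs `import hoho; hoho.setup()`, so these helpers presumably now come from an installed hoho package rather than this vendored file. A minimal sketch, assuming the same API, of how setup() is driven:

    # Minimal sketch, assuming the hoho API shown above; the directory layout
    # (train/ and val/ subtrees of tar.gz shards) is inferred from get_dataset() below.
    import hoho

    data_dir = hoho.setup()                            # pathlib.Path to the data directory
    print(sorted(p.name for p in data_dir.iterdir()))  # e.g. the split subdirectories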
-import importlib
-from pathlib import Path
-import subprocess
-
-def download_package(package_name, path_to_save='packages'):
-    """
-    Downloads a package using pip and saves it to a specified directory.
-
-    Parameters:
-        package_name (str): The name of the package to download.
-        path_to_save (str): The path to the directory where the package will be saved.
-    """
-    try:
-        # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
-        subprocess.check_call([subprocess.sys.executable, "-m", "pip", "download", package_name,
-                               "-d", str(Path(path_to_save)/package_name),  # Download the package to the specified directory
-                               "--platform", "manylinux1_x86_64",  # Specify the platform
-                               "--python-version", "38",  # Specify the Python version
-                               "--only-binary=:all:"])  # Download only binary packages
-        print(f'Package "{package_name}" downloaded successfully')
-    except subprocess.CalledProcessError as e:
-        print(f'Failed to downloaded package "{package_name}". Error: {e}')
-
-
-def install_package_from_local_file(package_name, folder='packages'):
-    """
-    Installs a package from a local .whl file or a directory containing .whl files using pip.
-
-    Parameters:
-        path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
-    """
-    try:
-        pth = str(Path(folder) / package_name)
-        subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
-                               "--no-index",  # Do not use package index
-                               "--find-links", pth,  # Look for packages in the specified directory or at the file
-                               package_name])  # Specify the package to install
-        print(f"Package installed successfully from {pth}")
-    except subprocess.CalledProcessError as e:
-        print(f"Failed to install package from {pth}. Error: {e}")
-
-
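These two helpers implement an offline-install workflow: wheels are downloaded ahead of time (for manylinux1/Python 3.8), committed alongside the solution, and installed with --no-index so no network is needed at evaluation time. A sketch of the round trip, reusing webdataset from the code's own example comment:

    # Offline-packaging round trip (sketch; assumes this hoho.py is importable).
    from hoho import download_package, install_package_from_local_file

    download_package('webdataset')                 # wheels land in packages/webdataset/
    install_package_from_local_file('webdataset')  # pip install --no-index --find-links packages/webdataset webdataset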
-def importt(module_name, as_name=None):
-    """
-    Imports a module and returns it.
-
-    Parameters:
-        module_name (str): The name of the module to import.
-        as_name (str): The name to use for the imported module. If None, the original module name will be used.
-
-    Returns:
-        The imported module.
-    """
-    for _ in range(2):
-        try:
-            if as_name is None:
-                print(f'imported {module_name}')
-                return importlib.import_module(module_name)
-            else:
-                print(f'imported {module_name} as {as_name}')
-                return importlib.import_module(module_name, as_name)
-        except ModuleNotFoundError as e:
-            install_package_from_local_file(module_name)
-            print(f"Failed to import module {module_name}. Error: {e}")
-
-
-def prepare_submission():
-    # Download packages from requirements.txt
-    if Path('requirements.txt').exists():
-        print('downloading packages from requirements.txt')
-        Path('packages').mkdir(exist_ok=True)
-        with open('requirements.txt') as f:
-            packages = f.readlines()
-        for p in packages:
-            download_package(p.strip())
-
-    print('all packages downloaded. Don\'t foget to include the packages in the submission by adding them with git lfs.')
-
-
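importt() retries an import once, attempting the offline install above in between, and prepare_submission() bulk-downloads every entry of requirements.txt. A hypothetical pre-submission flow:

    # Hypothetical pre-submission flow; assumes a requirements.txt listing one
    # package per line (e.g. a single line: webdataset).
    from hoho import prepare_submission, importt

    prepare_submission()          # downloads wheels for every listed package
    wds = importt('webdataset')   # import, falling back to the local wheels once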
-########## general utilities ##########
-import contextlib
-import tempfile
-from pathlib import Path
-
-@contextlib.contextmanager
-def working_directory(path):
-    """Changes working directory and returns to previous on exit."""
-    prev_cwd = Path.cwd()
-    os.chdir(path)
-    try:
-        yield
-    finally:
-        os.chdir(prev_cwd)
-
-@contextlib.contextmanager
-def temp_working_directory():
-    with tempfile.TemporaryDirectory(dir='.') as D:
-        with working_directory(D):
-            yield
-
-
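working_directory() is a chdir-and-restore context manager; temp_working_directory() nests it inside a TemporaryDirectory created under the current directory, so anything written in the block disappears on exit. Usage sketch:

    # Usage sketch for the two context managers above.
    from pathlib import Path
    from hoho import working_directory, temp_working_directory

    with working_directory('/tmp'):
        print(Path.cwd())    # /tmp
    print(Path.cwd())        # restored to the previous directory

    with temp_working_directory():
        Path('scratch.txt').write_text('deleted when the block exits')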
-############# Dataset #############
-def proc(row, split='train'):
-    # column_names_train = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'mesh', 'wireframe']
-    # column_names_test = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'wireframe']
-    # cols = column_names_train if split == 'train' else column_names_test
-    out = {}
-    for k, v in row.items():
-        colname = k.split('.')[0]
-        if colname in {'ade20k', 'depthcm', 'gestalt'}:
-            if colname in out:
-                out[colname].append(v)
-            else:
-                out[colname] = [v]
-        elif colname in {'wireframe', 'mesh'}:
-            # out.update({a: b.tolist() for a,b in v.items()})
-            out.update({a: b for a,b in v.items()})
-        elif colname in 'kr':
-            out[colname.upper()] = v
-        else:
-            out[colname] = v
-
-    return Sample(out)
-
-
-class Sample(Dict):
-    def __repr__(self):
-        return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
-
-
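proc() regroups a flat WebDataset row by the prefix before the first '.': per-view image columns (ade20k, depthcm, gestalt) are gathered into lists, wireframe/mesh archives are flattened into top-level keys, and 'k'/'r' are upper-cased to 'K'/'R'; Sample is a dict whose repr shows shapes and types. An illustrative round trip (keys and shapes are assumptions, not the real shard schema):

    # Illustrative round trip through proc(); keys and shapes are invented
    # for demonstration, not the real shard schema.
    import numpy as np
    from hoho import proc

    row = {
        '__key__': 'building_0000',
        'gestalt.0.png': np.zeros((512, 512, 3)),  # repeated prefix ->
        'gestalt.1.png': np.zeros((512, 512, 3)),  # collected into one list
        'k.npy': np.eye(3),                        # 'k' -> 'K'
        'wireframe.npz': {'wf_vertices': np.zeros((8, 3)),
                          'wf_edges': np.zeros((7, 2))},  # flattened to top level
    }
    print(proc(row))
    # -> {'__key__': <class 'str'>, 'gestalt': [<class 'numpy.ndarray'>],
    #     'K': (3, 3), 'wf_vertices': (8, 3), 'wf_edges': (7, 2)}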
-def get_params():
-    exmaple_param_dict = {
-        "competition_id": "usm3d/S23DR",
-        "competition_type": "script",
-        "metric": "custom",
-        "token": "hf_**********************************",
-        "team_id": "local-test-team_id",
-        "submission_id": "local-test-submission_id",
-        "submission_id_col": "__key__",
-        "submission_cols": [
-            "__key__",
-            "wf_edges",
-            "wf_vertices",
-            "edge_semantics"
-        ],
-        "submission_rows": 180,
-        "output_path": ".",
-        "submission_repo": "<THE HF MODEL ID of THIS REPO",
-        "time_limit": 7200,
-        "dataset": "usm3d/usm-test-data-x",
-        "submission_filenames": [
-            "submission.parquet"
-        ]
-    }
-
-    param_path = Path('params.json')
-
-    if not param_path.exists():
-        print('params.json not found (this means we probably aren\'t in the test env). Using example params.')
-        params = exmaple_param_dict
-    else:
-        print('found params.json (this means we are probably in the test env). Using params from file.')
-        with param_path.open() as f:
-            params = json.load(f)
-    print(params)
-    return params
-
-
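get_params() gives local runs and server runs one code path: the example dict stands in whenever params.json is absent. The fields a solution actually consumes are output_path and the submission schema:

    # Minimal consumption of get_params(), mirroring script.py below.
    from pathlib import Path
    import hoho

    params = hoho.get_params()
    out_file = Path(params['output_path']) / 'submission.parquet'  # where the grader looks
    cols = params['submission_cols']  # ['__key__', 'wf_edges', 'wf_vertices', 'edge_semantics']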
-import webdataset as wds
-import numpy as np
-
-def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
-    if LOCAL_DATADIR is None:
-        raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
-
-    local_dir = Path(LOCAL_DATADIR)
-    if split != 'all':
-        local_dir = local_dir / split
-
-    paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
-
-    dataset = wds.WebDataset(paths)
-    if decode is not None:
-        dataset = dataset.decode(decode)
-    else:
-        dataset = dataset.decode()
-
-    dataset = dataset.map(proc)
-
-    if dataset_type == 'webdataset':
-        return dataset
-
-    if dataset_type == 'hf':
-        import datasets
-        from datasets import Features, Value, Sequence, Image, Array2D
-
-        if split == 'train':
-            return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
-        elif split == 'val':
-            return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
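get_dataset() globs *.tar.gz shards under the split directory ('all' uses the whole tree), decodes them (PIL images by default), and maps proc() over the stream; dataset_type='hf' wraps the same iterator in a datasets.IterableDataset. A local iteration sketch:

    # Iterate a few decoded training samples locally (sketch).
    import hoho

    hoho.setup()                                   # must run before get_dataset()
    dataset = hoho.get_dataset(decode='pil', split='train')
    for i, sample in enumerate(dataset):
        print(sample)                              # Sample repr: shape/type per key
        if i == 2:
            break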
script.py CHANGED
@@ -8,13 +8,15 @@
 
 '''---compulsory---'''
 import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
+'''---compulsory---'''
+
 from pathlib import Path
 from tqdm import tqdm
 import pandas as pd
 import numpy as np
 
 
-def empty_solution():
+def empty_solution(sample):
     '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
     return np.zeros((2,3)), [(0, 1)]
 
@@ -22,11 +24,27 @@ def empty_solution():
 if __name__ == "__main__":
     print ("------------ Loading dataset------------ ")
     params = hoho.get_params()
-    dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
+
+    # by default it is usually better to use `get_dataset()` like this
+    #
+    # dataset = hoho.get_dataset(split='all')
+    #
+    # but in this case (because we don't do anything with the sample
+    # anyway) we set `decode=None`. We can set the `split` argument
+    # to 'train' or 'val' ('all' defaults back to 'train') if we are
+    # testing ourselves locally.
+    #
+    # On the test server *`split` must be set to 'all'*
+    # to compute both the public and private leaderboards.
+    #
+    dataset = hoho.get_dataset(split='all', decode=None)
+
     print('------------ Now you can do your solution ---------------')
     solution = []
     for i, sample in enumerate(tqdm(dataset)):
-        pred_vertices, pred_edges = empty_solution()
+        # replace this with your solution
+        pred_vertices, pred_edges = empty_solution(sample)
+
         solution.append({
             '__key__': sample['__key__'],
             'wf_vertices': pred_vertices.tolist(),
@@ -35,4 +53,4 @@ if __name__ == "__main__":
     print('------------ Saving results ---------------')
     sub = pd.DataFrame(solution, columns=["__key__", "wf_vertices", "wf_edges"])
     sub.to_parquet(Path(params['output_path']) / "submission.parquet")
-    print("------------ Done ------------ ")
\ No newline at end of file
+    print("------------ Done ------------ ")