dmytromishkin
commited on
Commit
·
b4bc845
1
Parent(s):
ac0c4b0
Added minimal empty solution
Browse files
README.md
CHANGED
@@ -1,3 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
---
|
|
|
|
1 |
+
# Empty solution example for the S23DR competition
|
2 |
+
|
3 |
+
This repo provides a minimalistic example of a valid, but empty submission to S23DR competition.
|
4 |
+
We recommend to take a look at the [another example](https://huggingface.co/usm3d/handcrafted_baseline_submission),
|
5 |
+
which implement some primitive algorithm and provides useful I/O and visualization functions.
|
6 |
+
|
7 |
+
This one, though, containt the minimal code, which succeeds at reading the dataset and producing a solution, which consists of two vertices at the origin and edge of zero length connecting them.
|
8 |
+
|
9 |
+
|
10 |
+
The repo consistst of the following parts:
|
11 |
+
|
12 |
+
- `script.py` - the main file, which is run by the competition space. It should produce `submission.parquet` as the result of the run.
|
13 |
+
- `hoho.py` - the file for parsing the dataset at the inference time. Do NOT change it.
|
14 |
+
|
15 |
+
|
16 |
---
|
17 |
license: apache-2.0
|
18 |
---
|
19 |
+
|
hoho.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import shutil
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Dict
|
6 |
+
|
7 |
+
from PIL import ImageFile
|
8 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
9 |
+
|
10 |
+
LOCAL_DATADIR = None
|
11 |
+
|
12 |
+
def setup(local_dir='./data/usm-training-data/data'):
|
13 |
+
|
14 |
+
# If we are in the test environment, we need to link the data directory to the correct location
|
15 |
+
tmp_datadir = Path('/tmp/data/data')
|
16 |
+
local_test_datadir = Path('./data/usm-test-data-x/data')
|
17 |
+
local_val_datadir = Path(local_dir)
|
18 |
+
|
19 |
+
os.system('pwd')
|
20 |
+
os.system('ls -lahtr .')
|
21 |
+
|
22 |
+
if tmp_datadir.exists() and not local_test_datadir.exists():
|
23 |
+
global LOCAL_DATADIR
|
24 |
+
LOCAL_DATADIR = local_test_datadir
|
25 |
+
# shutil.move(datadir, './usm-test-data-x/data')
|
26 |
+
print(f"Linking {tmp_datadir} to {LOCAL_DATADIR} (we are in the test environment)")
|
27 |
+
LOCAL_DATADIR.parent.mkdir(parents=True, exist_ok=True)
|
28 |
+
LOCAL_DATADIR.symlink_to(tmp_datadir)
|
29 |
+
else:
|
30 |
+
LOCAL_DATADIR = local_val_datadir
|
31 |
+
print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
|
32 |
+
|
33 |
+
# os.system("ls -lahtr")
|
34 |
+
|
35 |
+
assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist"
|
36 |
+
return LOCAL_DATADIR
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
import importlib
|
42 |
+
from pathlib import Path
|
43 |
+
import subprocess
|
44 |
+
|
45 |
+
def download_package(package_name, path_to_save='packages'):
|
46 |
+
"""
|
47 |
+
Downloads a package using pip and saves it to a specified directory.
|
48 |
+
|
49 |
+
Parameters:
|
50 |
+
package_name (str): The name of the package to download.
|
51 |
+
path_to_save (str): The path to the directory where the package will be saved.
|
52 |
+
"""
|
53 |
+
try:
|
54 |
+
# pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
|
55 |
+
subprocess.check_call([subprocess.sys.executable, "-m", "pip", "download", package_name,
|
56 |
+
"-d", str(Path(path_to_save)/package_name), # Download the package to the specified directory
|
57 |
+
"--platform", "manylinux1_x86_64", # Specify the platform
|
58 |
+
"--python-version", "38", # Specify the Python version
|
59 |
+
"--only-binary=:all:"]) # Download only binary packages
|
60 |
+
print(f'Package "{package_name}" downloaded successfully')
|
61 |
+
except subprocess.CalledProcessError as e:
|
62 |
+
print(f'Failed to downloaded package "{package_name}". Error: {e}')
|
63 |
+
|
64 |
+
|
65 |
+
def install_package_from_local_file(package_name, folder='packages'):
|
66 |
+
"""
|
67 |
+
Installs a package from a local .whl file or a directory containing .whl files using pip.
|
68 |
+
|
69 |
+
Parameters:
|
70 |
+
path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
|
71 |
+
"""
|
72 |
+
try:
|
73 |
+
pth = str(Path(folder) / package_name)
|
74 |
+
subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
|
75 |
+
"--no-index", # Do not use package index
|
76 |
+
"--find-links", pth, # Look for packages in the specified directory or at the file
|
77 |
+
package_name]) # Specify the package to install
|
78 |
+
print(f"Package installed successfully from {pth}")
|
79 |
+
except subprocess.CalledProcessError as e:
|
80 |
+
print(f"Failed to install package from {pth}. Error: {e}")
|
81 |
+
|
82 |
+
|
83 |
+
def importt(module_name, as_name=None):
|
84 |
+
"""
|
85 |
+
Imports a module and returns it.
|
86 |
+
|
87 |
+
Parameters:
|
88 |
+
module_name (str): The name of the module to import.
|
89 |
+
as_name (str): The name to use for the imported module. If None, the original module name will be used.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
The imported module.
|
93 |
+
"""
|
94 |
+
for _ in range(2):
|
95 |
+
try:
|
96 |
+
if as_name is None:
|
97 |
+
print(f'imported {module_name}')
|
98 |
+
return importlib.import_module(module_name)
|
99 |
+
else:
|
100 |
+
print(f'imported {module_name} as {as_name}')
|
101 |
+
return importlib.import_module(module_name, as_name)
|
102 |
+
except ModuleNotFoundError as e:
|
103 |
+
install_package_from_local_file(module_name)
|
104 |
+
print(f"Failed to import module {module_name}. Error: {e}")
|
105 |
+
|
106 |
+
|
107 |
+
def prepare_submission():
|
108 |
+
# Download packages from requirements.txt
|
109 |
+
if Path('requirements.txt').exists():
|
110 |
+
print('downloading packages from requirements.txt')
|
111 |
+
Path('packages').mkdir(exist_ok=True)
|
112 |
+
with open('requirements.txt') as f:
|
113 |
+
packages = f.readlines()
|
114 |
+
for p in packages:
|
115 |
+
download_package(p.strip())
|
116 |
+
|
117 |
+
|
118 |
+
print('all packages downloaded. Don\'t foget to include the packages in the submission by adding them with git lfs.')
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
########## general utilities ##########
|
123 |
+
import contextlib
|
124 |
+
import tempfile
|
125 |
+
from pathlib import Path
|
126 |
+
|
127 |
+
@contextlib.contextmanager
|
128 |
+
def working_directory(path):
|
129 |
+
"""Changes working directory and returns to previous on exit."""
|
130 |
+
prev_cwd = Path.cwd()
|
131 |
+
os.chdir(path)
|
132 |
+
try:
|
133 |
+
yield
|
134 |
+
finally:
|
135 |
+
os.chdir(prev_cwd)
|
136 |
+
|
137 |
+
@contextlib.contextmanager
|
138 |
+
def temp_working_directory():
|
139 |
+
with tempfile.TemporaryDirectory(dir='.') as D:
|
140 |
+
with working_directory(D):
|
141 |
+
yield
|
142 |
+
|
143 |
+
|
144 |
+
############# Dataset #############
|
145 |
+
def proc(row, split='train'):
|
146 |
+
# column_names_train = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'mesh', 'wireframe']
|
147 |
+
# column_names_test = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'wireframe']
|
148 |
+
# cols = column_names_train if split == 'train' else column_names_test
|
149 |
+
out = {}
|
150 |
+
for k, v in row.items():
|
151 |
+
colname = k.split('.')[0]
|
152 |
+
if colname in {'ade20k', 'depthcm', 'gestalt'}:
|
153 |
+
if colname in out:
|
154 |
+
out[colname].append(v)
|
155 |
+
else:
|
156 |
+
out[colname] = [v]
|
157 |
+
elif colname in {'wireframe', 'mesh'}:
|
158 |
+
# out.update({a: b.tolist() for a,b in v.items()})
|
159 |
+
out.update({a: b for a,b in v.items()})
|
160 |
+
elif colname in 'kr':
|
161 |
+
out[colname.upper()] = v
|
162 |
+
else:
|
163 |
+
out[colname] = v
|
164 |
+
|
165 |
+
return Sample(out)
|
166 |
+
|
167 |
+
|
168 |
+
class Sample(Dict):
|
169 |
+
def __repr__(self):
|
170 |
+
return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
def get_params():
|
175 |
+
exmaple_param_dict = {
|
176 |
+
"competition_id": "usm3d/S23DR",
|
177 |
+
"competition_type": "script",
|
178 |
+
"metric": "custom",
|
179 |
+
"token": "hf_**********************************",
|
180 |
+
"team_id": "local-test-team_id",
|
181 |
+
"submission_id": "local-test-submission_id",
|
182 |
+
"submission_id_col": "__key__",
|
183 |
+
"submission_cols": [
|
184 |
+
"__key__",
|
185 |
+
"wf_edges",
|
186 |
+
"wf_vertices",
|
187 |
+
"edge_semantics"
|
188 |
+
],
|
189 |
+
"submission_rows": 180,
|
190 |
+
"output_path": ".",
|
191 |
+
"submission_repo": "<THE HF MODEL ID of THIS REPO",
|
192 |
+
"time_limit": 7200,
|
193 |
+
"dataset": "usm3d/usm-test-data-x",
|
194 |
+
"submission_filenames": [
|
195 |
+
"submission.parquet"
|
196 |
+
]
|
197 |
+
}
|
198 |
+
|
199 |
+
param_path = Path('params.json')
|
200 |
+
|
201 |
+
if not param_path.exists():
|
202 |
+
print('params.json not found (this means we probably aren\'t in the test env). Using example params.')
|
203 |
+
params = exmaple_param_dict
|
204 |
+
else:
|
205 |
+
print('found params.json (this means we are probably in the test env). Using params from file.')
|
206 |
+
with param_path.open() as f:
|
207 |
+
params = json.load(f)
|
208 |
+
print(params)
|
209 |
+
return params
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
import webdataset as wds
|
214 |
+
import numpy as np
|
215 |
+
|
216 |
+
def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
|
217 |
+
if LOCAL_DATADIR is None:
|
218 |
+
raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
|
219 |
+
|
220 |
+
local_dir = Path(LOCAL_DATADIR)
|
221 |
+
if split != 'all':
|
222 |
+
local_dir = local_dir / split
|
223 |
+
|
224 |
+
paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
|
225 |
+
|
226 |
+
dataset = wds.WebDataset(paths)
|
227 |
+
if decode is not None:
|
228 |
+
dataset = dataset.decode(decode)
|
229 |
+
else:
|
230 |
+
dataset = dataset.decode()
|
231 |
+
|
232 |
+
dataset = dataset.map(proc)
|
233 |
+
|
234 |
+
if dataset_type == 'webdataset':
|
235 |
+
return dataset
|
236 |
+
|
237 |
+
if dataset_type == 'hf':
|
238 |
+
import datasets
|
239 |
+
from datasets import Features, Value, Sequence, Image, Array2D
|
240 |
+
|
241 |
+
if split == 'train':
|
242 |
+
return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
|
243 |
+
elif split == 'val':
|
244 |
+
return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
|
245 |
+
|
246 |
+
|
247 |
+
|
script.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### This is example of the script that will be run in the test environment.
|
2 |
+
### Some parts of the code are compulsory and you should NOT CHANGE THEM.
|
3 |
+
### They are between '''---compulsory---''' comments.
|
4 |
+
### You can change the rest of the code to define and test your solution.
|
5 |
+
### However, you should not change the signature of the provided function.
|
6 |
+
### The script would save "submission.parquet" file in the current directory.
|
7 |
+
### You can use any additional files and subdirectories to organize your code.
|
8 |
+
|
9 |
+
'''---compulsory---'''
|
10 |
+
import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
|
11 |
+
from pathlib import Path
|
12 |
+
from tqdm import tqdm
|
13 |
+
import pandas as pd
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
|
17 |
+
def empty_solution():
|
18 |
+
'''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
|
19 |
+
return np.zeros((2,3)), [(0, 1)], [0]
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
print ("------------ Loading dataset------------ ")
|
24 |
+
params = hoho.get_params()
|
25 |
+
dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
|
26 |
+
print('------------ Now you can do your solution ---------------')
|
27 |
+
solution = []
|
28 |
+
for i, sample in enumerate(tqdm(dataset)):
|
29 |
+
pred_vertices, pred_edges, semantics = empty_solution()
|
30 |
+
solution.append({
|
31 |
+
'__key__': sample['__key__'],
|
32 |
+
'wf_vertices': pred_vertices.tolist(),
|
33 |
+
'wf_edges': pred_edges,
|
34 |
+
'edge_semantics': semantics,
|
35 |
+
})
|
36 |
+
print('------------ Saving results ---------------')
|
37 |
+
sub = pd.DataFrame(solution, columns=["__key__", "wf_vertices", "wf_edges", "edge_semantics"])
|
38 |
+
sub.to_parquet(Path(params['output_path']) / "submission.parquet")
|
39 |
+
print("------------ Done ------------ ")
|