### This is an example of the script that will be run in the test environment.
### Some parts of the code are compulsory and you should NOT CHANGE THEM.
### They are between '''---compulsory---''' comments.
### You can change the rest of the code to define and test your solution.
### However, you should not change the signature of the provided function.
### The script saves a "submission.parquet" file in the directory given by params['output_path'].
### You can use any additional files and subdirectories to organize your code.
'''---compulsory---'''
import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
'''---compulsory---'''
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm
def empty_solution(sample):
    '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.

    The `sample` argument is accepted so the signature matches what a real
    solver would take, but it is not inspected here.

    Returns a tuple `(vertices, edges)`:
      * `vertices` -- a (2, 3) NumPy array filled with zeros;
      * `edges`    -- a list holding the single index pair `(0, 1)`.
    '''
    return np.zeros((2, 3)), [(0, 1)]
if __name__ == "__main__":
    # Script entry point: load the dataset, produce one prediction per
    # sample, and write all predictions to a single parquet file.
    print("------------ Loading dataset------------ ")
    params = hoho.get_params()

    # By default it is usually better to use `get_dataset()` like this:
    #
    #     dataset = hoho.get_dataset(split='all')
    #
    # but in this case (because we don't do anything with the sample
    # anyway) we set `decode=None`. We can set the `split` argument
    # to 'train' or 'val' ('all' defaults back to 'train') if we are
    # testing ourselves locally:
    #
    #     dataset = hoho.get_dataset(split='val', decode=None)
    #
    # On the test server *`split` must be set to 'all'*
    # to compute both the public and private leaderboards.
    #
    dataset = hoho.get_dataset(split='all', decode=None)

    print('------------ Now you can do your solution ---------------')
    solution = []
    for sample in tqdm(dataset):
        # replace this with your solution
        pred_vertices, pred_edges = empty_solution(sample)

        # One row per sample; the vertex array is converted to nested
        # lists before it is placed in the DataFrame.
        solution.append({
            '__key__': sample['__key__'],
            'wf_vertices': pred_vertices.tolist(),
            'wf_edges': pred_edges,
        })

    print('------------ Saving results ---------------')
    # Passing `columns` fixes the column order of the submission file.
    sub = pd.DataFrame(solution, columns=["__key__", "wf_vertices", "wf_edges"])
    sub.to_parquet(Path(params['output_path']) / "submission.parquet")
    print("------------ Done ------------ ")