Spaces:
Build error
Build error
"""Functions for downloading files necessary for training and evaluation""" | |
import os.path | |
from cirtorch.utils.download import download_train, download_test | |
from . import io_helpers | |
def download_for_eval(evaluation, demo_eval, dataset_url, globals):
    """Download datasets for evaluation and the network if given by URL.

    All names listed under ``evaluation['global_descriptor']['datasets']`` and
    ``evaluation['local_descriptor']['datasets']`` are passed to
    ``download_datasets``. If ``demo_eval`` is truthy and its ``net_path`` starts
    with ``http://`` or ``https://``, the file is downloaded into
    ``globals['root_path'] / "models"`` and ``demo_eval['net_path']`` is
    rewritten in place to point at that local copy.

    :param evaluation: config with ``global_descriptor`` / ``local_descriptor``
        sections, each holding a ``datasets`` sequence (concatenated with ``+``)
    :param demo_eval: config mapping with a ``net_path`` entry, or a falsy value
        to skip the network download; mutated in place when a download happens
    :param dataset_url: base URL forwarded to ``download_datasets``
    :param globals: shared state; ``root_path`` (joined with ``/``, presumably a
        ``pathlib.Path``) and ``logger`` (its ``info`` is used for logging) are read
    """
    datasets = evaluation['global_descriptor']['datasets'] \
        + evaluation['local_descriptor']['datasets']
    download_datasets(datasets, dataset_url, globals)

    if not demo_eval:
        return
    # net_path may already be a path object (e.g. a previous call rewrote it to
    # the local copy); path objects have no startswith(), so compare as a string.
    net_path = str(demo_eval['net_path'])
    if not net_path.startswith(("http://", "https://")):
        return

    net_name = os.path.basename(net_path)
    models_dir = globals['root_path'] / "models"
    # download_files expects the base URL (with trailing slash) and file names separately.
    io_helpers.download_files([net_name], models_dir, os.path.dirname(net_path) + "/",
                              logfunc=globals["logger"].info)
    demo_eval['net_path'] = models_dir / net_name


def download_for_train(validation, dataset_url, globals):
    """Fetch everything a training run needs.

    The ``"train"`` split is always requested first, followed by the validation
    datasets listed in the global- and local-descriptor sections of
    ``validation``; the combined list is handed to ``download_datasets``.
    """
    wanted = ["train"] + validation['global_descriptor']['datasets']
    wanted = wanted + validation['local_descriptor']['datasets']
    download_datasets(wanted, dataset_url, globals)


def download_datasets(datasets, dataset_url, globals):
    """Fetch the on-disk data behind each requested dataset name.

    - ``"val_eccv20"``: training data plus ``retrieval-SfM-120k-val-eccv2020.pkl``
      fetched from ``dataset_url`` into ``train/retrieval-SfM-120k``;
    - ``"train"``: training data only;
    - ``"roxford5k"`` / ``"rparis6k"``: test data.

    ``download_train`` is called at most once even if both ``"train"`` and
    ``"val_eccv20"`` are requested. Any other name is ignored, and ``globals``
    is only read when at least one of the names above is present.
    """
    wants_val = "val_eccv20" in datasets
    if wants_val or "train" in datasets:
        download_train(globals['root_path'])
    if wants_val:
        io_helpers.download_files(
            ["retrieval-SfM-120k-val-eccv2020.pkl"],
            globals['root_path'] / "train/retrieval-SfM-120k",
            dataset_url,
            logfunc=globals["logger"].info,
        )
    if any(name in datasets for name in ("roxford5k", "rparis6k")):
        download_test(globals['root_path'])