File size: 1,935 Bytes
32408ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""Functions for downloading files necessary for training and evaluation"""

import os.path
from cirtorch.utils.download import download_train, download_test
from . import io_helpers


def download_for_eval(evaluation, demo_eval, dataset_url, globals):
    """Fetch the datasets needed for evaluation and, if given by url, the network

    The dataset names under ``evaluation['global_descriptor']['datasets']`` and
    ``evaluation['local_descriptor']['datasets']`` are concatenated and handed to
    download_datasets(). When `demo_eval` is truthy and its ``net_path`` starts
    with ``http://`` or ``https://``, io_helpers.download_files() is called with
    the file name, ``globals['root_path'] / "models"`` and the url's parent
    directory; afterwards ``demo_eval['net_path']`` is overwritten in place with
    ``globals['root_path'] / "models" / <file name>``.
    """
    # Datasets required by both descriptor types
    global_sets = evaluation['global_descriptor']['datasets']
    local_sets = evaluation['local_descriptor']['datasets']
    download_datasets(global_sets + local_sets, dataset_url, globals)

    # Nothing more to do without a demo_eval section
    if not demo_eval:
        return

    # Only a remote network needs fetching; other paths are left untouched
    net_url = demo_eval['net_path']
    if not net_url.startswith(("http://", "https://")):
        return

    net_name = os.path.basename(net_url)
    io_helpers.download_files(
        [net_name],
        globals['root_path'] / "models",
        os.path.dirname(net_url) + "/",
        logfunc=globals["logger"].info,
    )
    # From now on the caller's dict refers to the local path instead of the url
    demo_eval['net_path'] = globals['root_path'] / "models" / net_name


def download_for_train(validation, dataset_url, globals):
    """Fetch the training data together with every validation dataset

    The list passed on to download_datasets() starts with ``"train"`` and is
    followed by the names under ``validation['global_descriptor']['datasets']``
    and then ``validation['local_descriptor']['datasets']``.
    """
    required = ["train"]
    for descriptor in ('global_descriptor', 'local_descriptor'):
        # Plain concatenation builds a new list and leaves the config untouched
        required = required + validation[descriptor]['datasets']
    download_datasets(required, dataset_url, globals)


def download_datasets(datasets, dataset_url, globals):
    """Trigger the downloads that the requested dataset names call for

    * ``"val_eccv20"``: download_train() is called, then
      io_helpers.download_files() is called for
      ``retrieval-SfM-120k-val-eccv2020.pkl`` with `dataset_url`.
    * ``"train"`` (without ``"val_eccv20"``): download_train() only.
    * ``"roxford5k"`` or ``"rparis6k"``: download_test().

    Any other name in `datasets` triggers nothing in this function.
    """
    wants_val = "val_eccv20" in datasets

    # Both the validation split and plain training start with download_train()
    if wants_val or "train" in datasets:
        download_train(globals['root_path'])

    # The eccv20 validation split additionally needs its own pickle
    if wants_val:
        io_helpers.download_files(
            ["retrieval-SfM-120k-val-eccv2020.pkl"],
            globals['root_path'] / "train/retrieval-SfM-120k",
            dataset_url,
            logfunc=globals["logger"].info,
        )

    # A single download_test() call serves either test dataset
    if any(name in datasets for name in ("roxford5k", "rparis6k")):
        download_test(globals['root_path'])