File size: 3,132 Bytes
32408ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Data manipulation helpers"""

import os.path
import pickle

from cirtorch.datasets.datahelpers import cid2filename
from cirtorch.datasets.testdataset import configdataset


def _read_list_file(list_path):
    """Return the lines of the text file `list_path` with newline characters stripped"""
    # The context manager closes the handle as soon as the lines are read,
    # instead of leaving an open file behind for the garbage collector.
    with open(list_path) as handle:
        return [line.strip("\n") for line in handle]


def load_dataset(dataset, data_root=''):
    """Return tuple (image list, query list, bounding boxes, gnd dictionary)

    The way the data are loaded depends on the `dataset` argument:

    * dict -- keys 'image_root', 'database_list' and 'query_list' are read. Each
      list is a path to a text file with one entry per line (or None, in which
      case the corresponding returned list is None). Every entry is combined
      with `os.path.join(data_root, dataset['image_root'])` through `path_join`.
      Bounding boxes and gnd are None.
    * 'train' -- the 'train' entry of
      `<data_root>/train/retrieval-SfM-120k/retrieval-SfM-120k.pkl` is loaded and
      image paths are built from its 'cids' by `cid2filename`. The query list is
      empty, bounding boxes and gnd are None.
    * 'val_eccv20' -- `retrieval-SfM-120k-val-eccv2020.pkl` from the same folder
      is loaded; images come from its 'cids', queries are the images indexed by
      its 'qidx' and gnd is its 'gnd'. Bounding boxes are None.
    * any other string containing "/" -- treated as a path to a pickle file
      with the keys 'imlist', 'qimlist' and 'gnd'. Bounding boxes are None.
    * anything else -- passed to `configdataset` together with
      `<data_root>/test`; bounding boxes are returned only if the first gnd
      entry has the 'bbx' key.

    NOTE: pickle files are loaded with `pickle.load`, which can execute
    arbitrary code, so only trusted files must be passed in.
    """

    if isinstance(dataset, dict):
        root = os.path.join(data_root, dataset['image_root'])
        images, qimages = None, None
        if dataset['database_list'] is not None:
            images = [path_join(root, x) for x in _read_list_file(dataset['database_list'])]
        if dataset['query_list'] is not None:
            qimages = [path_join(root, x) for x in _read_list_file(dataset['query_list'])]
        bbxs = None
        gnd = None

    elif dataset == 'train':
        training_set = 'retrieval-SfM-120k'
        db_root = os.path.join(data_root, 'train', training_set)
        ims_root = os.path.join(db_root, 'ims')
        db_fn = os.path.join(db_root, '{}.pkl'.format(training_set))
        with open(db_fn, 'rb') as f:
            db = pickle.load(f)['train']
        images = [cid2filename(cid, ims_root) for cid in db['cids']]
        qimages = []
        bbxs = None
        gnd = None

    elif dataset == 'val_eccv20':
        db_root = os.path.join(data_root, 'train', 'retrieval-SfM-120k')
        # pos are all with #inl >=3 & <= 10
        fn_val_proper = os.path.join(db_root, 'retrieval-SfM-120k-val-eccv2020.pkl')
        with open(fn_val_proper, 'rb') as f:
            db = pickle.load(f)
        ims_root = os.path.join(db_root, 'ims')
        images = [cid2filename(cid, ims_root) for cid in db['cids']]
        gnd = db['gnd']
        qimages = [images[x] for x in db['qidx']]
        bbxs = None

    elif "/" in dataset:
        # A string with a slash is interpreted as a path to a pickled dataset
        with open(dataset, 'rb') as handle:
            db = pickle.load(handle)
        images, qimages, bbxs, gnd = db['imlist'], db['qimlist'], None, db['gnd']

    else:
        cfg = configdataset(dataset, os.path.join(data_root, 'test'))
        images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
        qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]
        # The presence of bounding boxes is decided by the first gnd entry only
        if 'bbx' in cfg['gnd'][0]:
            bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
        else:
            bbxs = None
        gnd = cfg['gnd']

    return images, qimages, bbxs, gnd


def path_join(root, name):
    """Join root and name, treating an asterisk in the last component of root as a placeholder.

    When the part of root after its last "/" contains an asterisk, every asterisk in root
    is replaced by the name. Otherwise, the result is a plain os.path.join of the two.

    >>> path_join('/data/img_*.jpg', '001')
    '/data/img_001.jpg'
    """
    # Only the last component decides whether root is a template
    last_component = root.rpartition("/")[2]
    if "*" not in last_component:
        return os.path.join(root, name)
    return root.replace("*", name)


class AverageMeter:
    """Compute and store the average and last value

    Attributes:
        val: the value passed to the most recent `update` call
        sum: accumulated `val * n` over all `update` calls
        count: accumulated `n` over all `update` calls
        avg: `sum / count`, or 0 while `count` is zero
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set all stored statistics back to zero"""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Update the counter by a new value

        The value `val` is counted `n` times, i.e. it contributes `val * n` to the sum
        and `n` to the count.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division so that a zero total count (e.g. an update with n=0
        # on a fresh meter) keeps the average at 0 instead of raising ZeroDivisionError
        self.avg = self.sum / self.count if self.count else 0