Spaces:
Build error
Build error
"""Data manipulation helpers""" | |
import os.path | |
import pickle | |
from cirtorch.datasets.datahelpers import cid2filename | |
from cirtorch.datasets.testdataset import configdataset | |
def load_dataset(dataset, data_root=''): | |
"""Return tuple (image list, query list, bounding boxes, gnd dictionary)""" | |
if isinstance(dataset, dict): | |
root = os.path.join(data_root, dataset['image_root']) | |
images, qimages = None, None | |
if dataset['database_list'] is not None: | |
images = [path_join(root, x.strip("\n")) for x in open(dataset['database_list']).readlines()] | |
if dataset['query_list'] is not None: | |
qimages = [path_join(root, x.strip("\n")) for x in open(dataset['query_list']).readlines()] | |
bbxs = None | |
gnd = None | |
elif dataset == 'train': | |
training_set = 'retrieval-SfM-120k' | |
db_root = os.path.join(data_root, 'train', training_set) | |
ims_root = os.path.join(db_root, 'ims') | |
db_fn = os.path.join(db_root, '{}.pkl'.format(training_set)) | |
with open(db_fn, 'rb') as f: | |
db = pickle.load(f)['train'] | |
images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))] | |
qimages = [] | |
bbxs = None | |
gnd = None | |
elif dataset == 'val_eccv20': | |
db_root = os.path.join(data_root, 'train', 'retrieval-SfM-120k') | |
fn_val_proper = db_root+'/retrieval-SfM-120k-val-eccv2020.pkl' # pos are all with #inl >=3 & <= 10 | |
with open(fn_val_proper, 'rb') as f: | |
db = pickle.load(f) | |
ims_root = os.path.join(db_root, 'ims') | |
images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))] | |
gnd = db['gnd'] | |
qidx = db['qidx'] | |
qimages = [images[x] for x in qidx] | |
bbxs = None | |
elif "/" in dataset: | |
with open(dataset, 'rb') as handle: | |
db = pickle.load(handle) | |
images, qimages, bbxs, gnd = db['imlist'], db['qimlist'], None, db['gnd'] | |
else: | |
cfg = configdataset(dataset, os.path.join(data_root, 'test')) | |
images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])] | |
qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])] | |
if 'bbx' in cfg['gnd'][0].keys(): | |
bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])] | |
else: | |
bbxs = None | |
gnd = cfg['gnd'] | |
return images, qimages, bbxs, gnd | |
def path_join(root, name): | |
"""Perform os.path.join by default; if asterisk is present in root, substitute with the name. | |
>>> path_join('/data/img_*.jpg', '001') | |
'/data/img_001.jpg' | |
""" | |
if "*" in root.rsplit("/", 1)[-1]: | |
return root.replace("*", name) | |
return os.path.join(root, name) | |
class AverageMeter: | |
"""Compute and store the average and last value""" | |
def __init__(self): | |
self.val = 0 | |
self.avg = 0 | |
self.sum = 0 | |
self.count = 0 | |
def update(self, val, n=1): | |
"""Update the counter by a new value""" | |
self.val = val | |
self.sum += val * n | |
self.count += n | |
self.avg = self.sum / self.count | |