SuperFeatures / how /utils /data_helpers.py
YannisK's picture
temp state
32408ed
"""Data manipulation helpers"""
import os.path
import pickle
from cirtorch.datasets.datahelpers import cid2filename
from cirtorch.datasets.testdataset import configdataset
def _read_image_list(list_path, root):
    """Return one image path per line of `list_path`, each combined with `root` by path_join."""
    # The context manager closes the list file deterministically instead of
    # leaving the handle open until garbage collection.
    with open(list_path) as handle:
        return [path_join(root, line.strip("\n")) for line in handle]


def load_dataset(dataset, data_root=''):
    """Return tuple (image list, query list, bounding boxes, gnd dictionary)

    :param dataset: Selects what to load. Accepted forms:
        - dict with keys 'image_root', 'database_list' and 'query_list'; each list entry is
          either None or a path to a text file with one image name per line. The list files
          are opened as given (they are not joined with data_root); image names are joined
          with data_root/image_root by path_join. Bounding boxes and gnd are None.
        - 'train': images of the retrieval-SfM-120k training split, with an empty query list.
        - 'val_eccv20': the retrieval-SfM-120k-val-eccv2020 split, with queries picked from
          the images by the stored 'qidx' indices.
        - a string containing "/": path to a pickle holding 'imlist', 'qimlist' and 'gnd'.
        - any other string: a dataset name handed to cirtorch's configdataset, rooted at
          data_root/test.
    :param str data_root: Root folder that the dataset folders are resolved against
    :return: Tuple (images, qimages, bbxs, gnd); individual items are None where the chosen
        dataset form does not provide them
    """
    # NOTE(security): several branches unpickle files from disk; pickle can execute arbitrary
    # code while loading, so only point this function at trusted dataset files.
    if isinstance(dataset, dict):
        root = os.path.join(data_root, dataset['image_root'])
        images, qimages = None, None
        if dataset['database_list'] is not None:
            images = _read_image_list(dataset['database_list'], root)
        if dataset['query_list'] is not None:
            qimages = _read_image_list(dataset['query_list'], root)
        bbxs = None
        gnd = None

    elif dataset == 'train':
        training_set = 'retrieval-SfM-120k'
        db_root = os.path.join(data_root, 'train', training_set)
        ims_root = os.path.join(db_root, 'ims')
        db_fn = os.path.join(db_root, '{}.pkl'.format(training_set))
        with open(db_fn, 'rb') as f:
            db = pickle.load(f)['train']
        images = [cid2filename(cid, ims_root) for cid in db['cids']]
        qimages = []
        bbxs = None
        gnd = None

    elif dataset == 'val_eccv20':
        db_root = os.path.join(data_root, 'train', 'retrieval-SfM-120k')
        # pos are all with #inl >=3 & <= 10
        fn_val_proper = os.path.join(db_root, 'retrieval-SfM-120k-val-eccv2020.pkl')
        with open(fn_val_proper, 'rb') as f:
            db = pickle.load(f)
        ims_root = os.path.join(db_root, 'ims')
        images = [cid2filename(cid, ims_root) for cid in db['cids']]
        gnd = db['gnd']
        # Queries are a subset of the database images, addressed by index
        qimages = [images[idx] for idx in db['qidx']]
        bbxs = None

    elif "/" in dataset:
        # Anything that looks like a path is treated as a ready-made pickled dataset
        with open(dataset, 'rb') as handle:
            db = pickle.load(handle)
        images, qimages, bbxs, gnd = db['imlist'], db['qimlist'], None, db['gnd']

    else:
        cfg = configdataset(dataset, os.path.join(data_root, 'test'))
        images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
        qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]
        # Only the first gnd entry is inspected to decide whether boxes are available
        if 'bbx' in cfg['gnd'][0]:
            bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
        else:
            bbxs = None
        gnd = cfg['gnd']

    return images, qimages, bbxs, gnd
def path_join(root, name):
    """Combine root and name into a single path.

    When the part of root after its last "/" contains an asterisk, root acts as a template
    and every asterisk in it is replaced by name; otherwise this is a plain os.path.join.

    >>> path_join('/data/img_*.jpg', '001')
    '/data/img_001.jpg'
    """
    last_component = root.rpartition("/")[2]
    if "*" not in last_component:
        return os.path.join(root, name)
    return root.replace("*", name)
class AverageMeter:
    """Compute and store the average and last value"""

    def __init__(self):
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running weighted average: sum / count
        self.sum = 0    # weighted sum of all values seen so far
        self.count = 0  # total weight (sum of all n) seen so far

    def update(self, val, n=1):
        """Update the counter by a new value

        :param val: New value to record
        :param n: Weight of the value, e.g. the number of samples it was averaged over
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the first update comes with n=0) would divide by zero;
        # report an average of 0 in that case, matching the initial state.
        self.avg = self.sum / self.count if self.count else 0