# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import os, os.path as osp
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch

from .image_set import ImageSet
from .transforms import instanciate_transforms
from .utils import DatasetWithRng

invh = np.linalg.inv


class ImagePairs (DatasetWithRng):
    """ Base class for a dataset that serves image pairs.

    Each item is built from a pair of indices (i, j) into an underlying image
    dataset; the optional transform `trf` is applied to each image separately.
    """
    imgs = None  # regular image dataset (must provide get_image(idx))
    pairs = []   # list of (idx1, idx2), ...  (replaced per instance in __init__)

    def __init__(self, image_set, pairs, trf=None, **rng):
        """
        image_set: image dataset the pair indices refer to.
        pairs:     sequence of (idx1, idx2) index pairs into `image_set`.
        trf:       transform spec, instantiated with instanciate_transforms().
        **rng:     forwarded to DatasetWithRng.__init__.
        """
        assert image_set and pairs, 'empty images or pairs'
        super().__init__(**rng)
        self.imgs = image_set
        self.pairs = pairs
        self.trf = instanciate_transforms(trf, rng=self.rng)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """ Return ((img1, img2), {}) with `self.trf` applied to each image
            (identity when no transform was instantiated).
        """
        transform = self.trf or (lambda x: x)
        pair = tuple(map(transform, self._load_pair(idx)))
        return pair, {}

    def _load_pair(self, idx):
        """ Load the two raw images of pair `idx` (loaded only once when i == j). """
        i, j = self.pairs[idx]
        img1 = self.imgs.get_image(i)
        return (img1, img1) if i == j else (img1, self.imgs.get_image(j))

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)} pairs from {self.imgs})'


class StillImagePairs (ImagePairs):
    """ A dataset of 'still' image pairs used for debugging purposes.

    By default every image is paired with itself. The ground-truth is a diagonal
    homography that only accounts for the size ratio between the two images.
    """
    def __init__(self, image_set, pairs=None, **rng):
        """
        image_set: either an image dataset, or an existing ImagePairs whose images
                   (and, unless `pairs` is given, whose pairs) are reused.
        pairs:     optional explicit list of (idx1, idx2); defaults to [(i, i), ...].
        **rng:     forwarded to ImagePairs.__init__.
        """
        if isinstance(image_set, ImagePairs):
            super().__init__(image_set.imgs, pairs or image_set.pairs, **rng)
        else:
            super().__init__(image_set, pairs or [(i, i) for i in range(len(image_set))], **rng)

    def __getitem__(self, idx):
        """ Return ((img1, img2), dict(homography=diag(sx, sy, 1))).

        (sx, sy) is the size ratio img2 / img1. Note that `self.trf` is NOT applied
        here (images come straight from _load_pair). Assumes PIL-like `.size`
        = (width, height) — TODO confirm for non-PIL image sets.
        """
        img1, img2 = self._load_pair(idx)
        sx, sy = img2.size / np.float32(img1.size)
        return (img1, img2), dict(homography=np.diag(np.float32([sx, sy, 1])))


class SyntheticImagePairs (StillImagePairs):
    """ A synthetic generator of image pairs.

    Given a normal image dataset, it constructs pairs using random homographies & noise.

    scale:   prior image scaling (transform spec).
    distort: distortion (transform spec) applied independently to both img1 and img2
             if sym=True, otherwise only to img2.
    sym:     (bool) whether the distortion is symmetric, see above.
    """
    def __init__(self, image_set, scale='', distort='', sym=False, **rng):
        super().__init__(image_set, **rng)
        self.symmetric = sym
        self.scale = instanciate_transforms(scale, rng=self.rng)
        self.distort = instanciate_transforms(distort, rng=self.rng)

    def __getitem__(self, idx):
        """ Return ((img1, img2), dict(homography=H)) with H = H2 @ inv(H1).

        Each image travels through the scale/distort transforms as a
        dict(img=..., homography=...); the transforms presumably update the image
        and compose its homography (TODO confirm against .transforms).
        """
        (img1, img2), gt = super().__getitem__(idx)

        img1 = dict(img=img1, homography=np.eye(3, dtype=np.float32))
        if img1['img'] is img2:
            # still pair (same image object): img2 is a distorted copy of the scaled img1
            img1 = self.scale(img1)
            img2 = self.distort(dict(img1))  # shallow copy: img2 gets its own dict
            if self.symmetric:
                img1 = self.distort(img1)
        else:
            # two different images: img2 starts from the size-ratio homography `gt`
            if self.symmetric:
                img1 = self.distort(self.scale(img1))
            img2 = self.distort(self.scale(dict(img=img2, **gt)))

        return (img1['img'], img2['img']), dict(homography=img2['homography'] @ invh(img1['homography']))

    def __repr__(self):
        # flatten a (possibly multi-line) transform repr onto one line
        fmt = lambda s: ','.join(line.strip() for line in repr(s).splitlines() if line).replace(',', '', 1)
        return f"{self.__class__.__name__}({len(self)} images, scale={fmt(self.scale)}, distort={fmt(self.distort)})"


class CatImagePairs (DatasetWithRng):
    """ Concatenation of several ImagePairs datasets.

    Global pair indices are mapped to (dataset position, local index) using the
    cumulative offsets stored in `self._pair_offsets`.
    """
    def __init__(self, *pair_datasets, seed=torch.initial_seed()):
        """
        *pair_datasets: ImagePairs instances to concatenate.
        seed:           forwarded to DatasetWithRng.__init__. NOTE: the default value
                        is evaluated once, when this module is imported.
        """
        assert all(isinstance(db, ImagePairs) for db in pair_datasets)
        self.pair_datasets = pair_datasets
        DatasetWithRng.__init__(self, seed=seed)  # init last
        self._init()

    def _init(self):
        """ Compute cumulative pair offsets and the total number of pairs. """
        self._pair_offsets = np.cumsum([0] + [len(db) for db in self.pair_datasets])
        self.npairs = self._pair_offsets[-1]

    def __len__(self):
        return self.npairs

    def __repr__(self):
        fmt_str = f"{type(self).__name__}({len(self)} pairs,"
        for i, db in enumerate(self.pair_datasets):
            npairs = self._pair_offsets[i+1] - self._pair_offsets[i]
            fmt_str += f'\n\t{npairs} from ' + str(db).replace("\n", " ") + ','
        return fmt_str[:-1] + ')'

    def __getitem__(self, idx):
        b, i = self._which(idx)
        return self.pair_datasets[b].__getitem__(i)

    def _which(self, i):
        """ Map global pair index `i` to (dataset position, index within that dataset).

        Raises IndexError when `i` is outside [0, len(self)), so bad indices fail
        loudly and plain iteration over the dataset still stops cleanly.
        """
        if not 0 <= i < self.npairs:
            raise IndexError('Bad pair index %d (valid range is [0, %d))' % (i, self.npairs))
        # side='right' skips over empty datasets (repeated offsets)
        pos = np.searchsorted(self._pair_offsets, i, side='right') - 1
        return pos, i - self._pair_offsets[pos]

    def _call(self, func, i, *args, **kwargs):
        """ Call method `func` of the sub-dataset owning global pair index `i`,
            passing it the local index followed by *args / **kwargs.
        """
        b, j = self._which(i)
        return getattr(self.pair_datasets[b], func)(j, *args, **kwargs)

    def init_worker(self, tid):
        """ Forward worker initialization to every sub-dataset. """
        for db in self.pair_datasets:
            db.init_worker(tid)


class BalancedCatImagePairs (CatImagePairs):
    """ Balanced concatenation of several ImagePairs datasets.

    Two calling conventions:
      BalancedCatImagePairs(npairs, db1, db2, ...)
          `npairs` pairs in total, split (near-)evenly among the datasets;
          npairs=0 means "length of the largest dataset".
      BalancedCatImagePairs(n1, db1, n2, db2, ..., nk, dbk)
          explicit number of pairs drawn from each dataset.
    """
    def __init__(self, npairs=0, *pair_datasets, **kw):
        assert isinstance(npairs, int) and npairs >= 0, 'BalancedCatImagePairs(npairs != int)'
        assert len(pair_datasets) > 0, 'no dataset provided'

        if len(pair_datasets) >= 3 and isinstance(pair_datasets[1], int):
            # interleaved (n1, db1, n2, db2, ...) form
            assert len(pair_datasets) % 2 == 1
            pair_datasets = [npairs] + list(pair_datasets)
            npairs, pair_datasets = pair_datasets[0::2], pair_datasets[1::2]
            assert all(isinstance(n, int) for n in npairs)
            self._pair_offsets = np.cumsum([0] + npairs)
            self.npairs = self._pair_offsets[-1]
        else:
            self.npairs = npairs or max(len(db) for db in pair_datasets)
            self._pair_offsets = np.linspace(0, self.npairs, len(pair_datasets)+1).astype(int)
        # offsets are set above; CatImagePairs.__init__ calls the overridden _init()
        CatImagePairs.__init__(self, *pair_datasets, **kw)

    def set_epoch(self, epoch):
        DatasetWithRng.init_worker(self, epoch)  # random seed only depends on the epoch
        self._init()  # reset permutations for this epoch

    def init_worker(self, tid):
        # only forwards to the sub-datasets; this object's own rng is re-seeded in set_epoch()
        CatImagePairs.init_worker(self, tid)

    def _init(self):
        """ (Re)build, for each sub-dataset, the local indices it serves this epoch.

        A dataset allotted `avail` slots serves each of its indices `avail // len(db)`
        times, plus `avail % len(db)` distinct indices for the remainder, drawn at
        random when self.seed is set (taken in order otherwise).
        """
        self._perms = []
        for i, db in enumerate(self.pair_datasets):
            n = len(db)
            assert n, 'cannot balance if there is an empty dataset'
            avail = self._pair_offsets[i+1] - self._pair_offsets[i]
            reps, rest = divmod(int(avail), n)
            # `reps` full copies of [0..n), plus one extra copy feeding the remainder
            idxs = np.tile(np.arange(n), reps + (rest > 0))
            if self.seed and rest:  # if not seed, then no shuffle
                # shuffle only the last copy (a single permutation of [0..n)) so the
                # remainder is sampled *without* replacement
                self.rng.shuffle(idxs[reps*n:])
            self._perms.append(idxs[:avail])

    def _which(self, i):
        pos, idx = super()._which(i)
        return pos, self._perms[pos][idx]


class UnsupervisedPairs (ImagePairs):
    """ Unsupervised image pairs obtained from SfM.

    Pairs are listed by image name in a text file; the correspondences of each pair
    are stored as '<name1>_<name2>.npy' in a 'corres' directory next to that file.
    """
    def __init__(self, img_set, pair_file_path):
        """
        img_set:        ImageSet whose `imgs` list holds the names used in the pair file.
        pair_file_path: text file with one whitespace-separated "name1 name2" pair per line.
        """
        assert isinstance(img_set, ImageSet), f'expected an ImageSet, got {type(img_set).__name__}'
        self.pair_list = self._parse_pair_list(pair_file_path)
        self.corres_dir = osp.join(osp.split(pair_file_path)[0], 'corres')

        tag_to_idx = {n: i for i, n in enumerate(img_set.imgs)}
        img_indices = lambda pair: tuple([tag_to_idx[n] for n in pair])
        super().__init__(img_set, [img_indices(pair) for pair in self.pair_list])

    def __repr__(self):
        return f"{type(self).__name__}({len(self)} pairs from {self.imgs})"

    def _parse_pair_list(self, pair_file_path):
        """ Parse whitespace-separated "name1 name2" lines into a list of (name1, name2).

        Blank lines are skipped; any other line without exactly two fields raises IOError.
        """
        res = []
        with open(pair_file_path) as f:
            for lineno, row in enumerate(f.read().splitlines(), 1):
                row = row.split()
                if not row:
                    continue
                if len(row) != 2:
                    raise IOError(f'{pair_file_path}:{lineno}: expected 2 image names, got {len(row)}')
                res.append((row[0], row[1]))
        return res

    def get_corres_path(self, pair_idx):
        """ Path of the '<basename1>_<basename2>.npy' correspondence file of pair `pair_idx`. """
        img1, img2 = [osp.basename(self.imgs.imgs[i]) for i in self.pairs[pair_idx]]
        return osp.join(self.corres_dir, f'{img1}_{img2}.npy')

    def get_corres(self, pair_idx):
        return np.load(self.get_corres_path(pair_idx))

    def __getitem__(self, idx):
        """ Return ((img1, img2), dict(corres=...)) with the raw (untransformed) images. """
        img1, img2 = self._load_pair(idx)
        return (img1, img2), dict(corres=self.get_corres(idx))


if __name__ == '__main__':
    from datasets import *
    from tools.viz import show_random_pairs

    db = BalancedCatImagePairs(
            3125, SyntheticImagePairs(RandomWebImages(0,52),distort='RandomTilting(0.5)'),
            4875, SyntheticImagePairs(SfM120k_Images(),distort='RandomTilting(0.5)'),
            8000, SfM120k_Pairs())

    show_random_pairs(db)