# Source: Vincentqyw, commit 404d2af ("update: features and matchers"), 10.1 kB.
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import os, pdb
import numpy as np
from PIL import Image
from .dataset import Dataset, CatDataset
from tools.transforms import instanciate_transformation
from tools.transforms_tools import persp_apply
class PairDataset (Dataset):
    """ A dataset that serves image pairs with ground-truth pixel correspondences.

    Subclasses set `self.npairs`, implement `get_pair`, and (for
    `get_paired_images`) provide `self.image_pairs`, an indexable collection
    of (img_idx_a, img_idx_b) pairs.
    """
    def __init__(self):
        Dataset.__init__(self)
        self.npairs = 0

    def get_filename(self, img_idx, root=None):
        """ Return the filename of image `img_idx`.

        If `img_idx` is a pair of indices (see `is_pair`), return a tuple
        with the two corresponding filenames instead.
        """
        if is_pair(img_idx):
            return tuple(Dataset.get_filename(self, i, root) for i in img_idx)
        return Dataset.get_filename(self, img_idx, root)

    def get_image(self, img_idx):
        """ Return image `img_idx`, or a tuple of two images if `img_idx` is a pair. """
        if is_pair(img_idx):
            return tuple(Dataset.get_image(self, i) for i in img_idx)
        return Dataset.get_image(self, img_idx)

    def get_corres_filename(self, pair_idx):
        raise NotImplementedError()

    def get_homography_filename(self, pair_idx):
        raise NotImplementedError()

    def get_flow_filename(self, pair_idx):
        raise NotImplementedError()

    def get_mask_filename(self, pair_idx):
        raise NotImplementedError()

    def get_pair(self, idx, output=()):
        """ returns (img1, img2, `metadata`)

        `metadata` is a dict() that can contain:
            flow: optical flow
            aflow: absolute flow
            corres: list of 2d-2d correspondences
            mask: boolean image of flow validity (in the first image)
            ...
        """
        raise NotImplementedError()

    def get_paired_images(self):
        """ Return the set of filenames of every image used by at least one pair. """
        fns = set()
        for i in range(self.npairs):
            a, b = self.image_pairs[i]
            fns.add(self.get_filename(a))
            fns.add(self.get_filename(b))
        return fns

    def __len__(self):
        return self.npairs  # size should correspond to the number of pairs, not images

    def __repr__(self):
        res = 'Dataset: %s\n' % self.__class__.__name__
        res += ' %d images,' % self.nimg
        res += ' %d image pairs' % self.npairs
        res += '\n root: %s...\n' % self.root
        return res

    @staticmethod
    def _flow2png(flow, path):
        """ Save a flow field to `path` as an 8-bit image (presumably PNG).

        The flow is quantized to 1/16 pixel and clipped to the int16 range.
        Each int16 value is reinterpreted as two uint8 bytes, so PIL can write
        it as an ordinary 8-bit image.

        Returns:
            The quantized flow that was actually stored, in pixels.
        """
        flow = np.clip(np.around(16*flow), -2**15, 2**15-1)
        raw = np.int16(flow).view(np.uint8)  # byte-level view, no value change
        Image.fromarray(raw).save(path)
        return flow / 16

    @staticmethod
    def _png2flow(path):
        """ Load a flow field written by `_flow2png`.

        Returns:
            float32 flow, in pixels.

        Raises:
            IOError: if the file cannot be opened, decoded, or reinterpreted
                as int16 data. The original exception is chained as the cause.
        """
        try:
            # Close the file handle deterministically; np.asarray copies the
            # pixel data out, so the array stays valid after the file is closed.
            with Image.open(path) as img:
                raw = np.asarray(img)
            flow = raw.view(np.int16)
        except Exception as err:
            raise IOError("Error loading flow for %s" % path) from err
        return np.float32(flow) / 16
class StillPairDataset (PairDataset):
    """ A dataset of 'still' image pairs.

    When layered over a regular image dataset that defines `image_pairs` and
    `get_image`, this class adds `get_pair(i)`. It loads the two images of
    pair `i` and relates them by a pure per-axis scaling derived from their
    sizes. This is typically used for trivial pairs where both entries
    refer to the same image.
    """
    def get_pair(self, pair_idx, output=()):
        """ Return (img1, img2, meta) for pair `pair_idx`.

        `output` lists the requested ground-truth keys, either as a collection
        or as a space-separated string. Supported keys are 'aflow', 'flow',
        'mask' and 'homography'.
        """
        wanted = output.split() if isinstance(output, str) else output

        img1, img2 = (self.get_image(k) for k in self.image_pairs[pair_idx])
        width, height = img1.size
        scale_x = img2.size[0] / float(width)
        scale_y = img2.size[1] / float(height)

        meta = {}
        if 'aflow' in wanted or 'flow' in wanted:
            # (x, y) coordinates of every pixel of img1, shape (H, W, 2)
            grid = np.mgrid[0:height, 0:width][::-1].transpose(1, 2, 0).astype(np.float32)
            absolute = grid * (scale_x, scale_y)
            meta['aflow'] = absolute
            meta['flow'] = absolute - grid
        if 'mask' in wanted:
            meta['mask'] = np.ones((height, width), np.uint8)
        if 'homography' in wanted:
            meta['homography'] = np.diag(np.float32([scale_x, scale_y, 1]))
        return img1, img2, meta
class SyntheticPairDataset (PairDataset):
    """ A synthetic generator of image pairs.

    Given a normal (non-pair) image dataset, each pair is built from a single
    image. The image is rescaled, duplicated, and the copy is distorted with
    a random transformation (e.g. homography & noise). The applied
    perspective transform is returned by the distortion, so the pair comes
    with exact ground truth.
    """
    def __init__(self, dataset, scale='', distort=''):
        self.attach_dataset(dataset)
        self.distort = instanciate_transformation(distort)
        self.scale = instanciate_transformation(scale)

    def attach_dataset(self, dataset):
        """ Wrap `dataset`: one pair per image; image accessors are forwarded to it. """
        assert isinstance(dataset, Dataset) and not isinstance(dataset, PairDataset)
        self.dataset = dataset
        self.npairs = dataset.nimg
        self.get_image = dataset.get_image
        self.get_key = dataset.get_key
        self.get_filename = dataset.get_filename
        self.root = None

    def make_pair(self, img):
        """ Return the two images the pair is built from (here: the same image twice). """
        return img, img

    def get_pair(self, i, output=('aflow',)):
        """ Return (img_a, img_b, meta) synthesized from image `i`.

        Procedure: image `i` is passed through `self.scale` and split with
        `make_pair`. The second image is then passed through `self.distort`,
        starting from the identity perspective. The 8-parameter 'persp'
        transform it returns (the 9th homography entry is implicitly 1) is
        the ground truth.

        `output` lists the requested keys, either as a collection or as a
        space-separated string. Supported keys are 'aflow', 'flow' and
        'homography'.
        """
        if isinstance(output, str):
            output = output.split()
        original_img = self.dataset.get_image(i)
        scaled_image = self.scale(original_img)
        scaled_image, scaled_image2 = self.make_pair(scaled_image)
        # `distort` is expected to return the warped image together with the
        # updated 'persp' transform mapping img_a pixels into img_b.
        scaled_and_distorted_image = self.distort(
            dict(img=scaled_image2, persp=(1,0,0,0,1,0,0,0)))
        W, H = scaled_image.size
        trf = scaled_and_distorted_image['persp']

        meta = dict()
        if 'aflow' in output or 'flow' in output:
            # compute optical flow: map every (x, y) pixel of img_a through trf
            xy = np.mgrid[0:H,0:W][::-1].reshape(2,H*W).T
            aflow = np.float32(persp_apply(trf, xy).reshape(H,W,2))
            meta['flow'] = aflow - xy.reshape(H,W,2)
            meta['aflow'] = aflow
        if 'homography' in output:
            meta['homography'] = np.float32(trf+(1,)).reshape(3,3)
        return scaled_image, scaled_and_distorted_image['img'], meta

    def __repr__(self):
        res = 'Dataset: %s\n' % self.__class__.__name__
        res += ' %d images and pairs' % self.npairs
        res += '\n root: %s...' % self.dataset.root
        res += '\n Scale: %s' % (repr(self.scale).replace('\n',''))
        res += '\n Distort: %s' % (repr(self.distort).replace('\n',''))
        return res + '\n'
class TransformedPairs (PairDataset):
    """ Automatic data augmentation for pre-existing image pairs.

    Given an image pair dataset, it generates synthetically jittered pairs.
    A random transformation (e.g. homography & noise) is applied to the
    second image of each pair, and the pair's ground truth is updated to
    match.
    """
    def __init__(self, dataset, trf=''):
        self.attach_dataset(dataset)
        self.trf = instanciate_transformation(trf)

    def attach_dataset(self, dataset):
        """ Wrap the pair dataset `dataset`, forwarding its image accessors. """
        assert isinstance(dataset, PairDataset)
        self.dataset = dataset
        self.nimg = dataset.nimg
        self.npairs = dataset.npairs
        self.get_image = dataset.get_image
        self.get_key = dataset.get_key
        self.get_filename = dataset.get_filename
        self.root = None

    def get_pair(self, i, output=''):
        """ Return (img_a, img_b, metadata) for pair `i`, with img_b jittered.

        The metadata returned by the wrapped dataset is updated in place:
          - 'aflow' and/or 'flow' (either may be present alone) are mapped
            through the applied perspective transform;
          - `corres[:, 1]` (presumably the second-image points; TODO confirm
            the layout) is mapped through it as well;
          - 'homography' is left-multiplied by it (p_b = H @ p_a).
        """
        img_a, img_b_, metadata = self.dataset.get_pair(i, output)
        img_b = self.trf({'img': img_b_, 'persp': (1,0,0,0,1,0,0,0)})
        trf = img_b['persp']

        if 'aflow' in metadata or 'flow' in metadata:
            W, H = img_a.size
            mgrid = np.mgrid[0:H, 0:W][::-1].transpose(1,2,0).astype(np.float32)
            if 'aflow' in metadata:
                aflow = metadata['aflow']
            else:
                # Only relative flow is available: rebuild absolute positions.
                aflow = metadata['flow'] + mgrid
            aflow[:] = persp_apply(trf, aflow.reshape(-1,2)).reshape(aflow.shape)
            if 'flow' in metadata:
                metadata['flow'][:] = aflow - mgrid

        if 'corres' in metadata:
            corres = metadata['corres']
            corres[:,1] = persp_apply(trf, corres[:,1])

        if 'homography' in metadata:
            # p_b = homography * p_a
            trf_ = np.float32(trf+(1,)).reshape(3,3)
            metadata['homography'] = np.float32(trf_ @ metadata['homography'])

        return img_a, img_b['img'], metadata

    def __repr__(self):
        res = 'Transformed Pairs from %s\n' % type(self.dataset).__name__
        res += ' %d images and pairs' % self.npairs
        res += '\n root: %s...' % self.dataset.root
        res += '\n transform: %s' % (repr(self.trf).replace('\n',''))
        return res + '\n'
class CatPairDataset (CatDataset):
    ''' Concatenation of several pair datasets.

    Pair indices run consecutively over the given datasets, in order.
    '''
    def __init__(self, *datasets):
        CatDataset.__init__(self, *datasets)
        pair_offsets = [0]
        for db in datasets:
            pair_offsets.append(db.npairs)
        # pair_offsets[k] is the global index of the first pair of datasets[k];
        # the last entry is the total number of pairs.
        self.pair_offsets = np.cumsum(pair_offsets)
        self.npairs = self.pair_offsets[-1]

    def __len__(self):
        return self.npairs

    def __repr__(self):
        parts = [str(db).replace("\n", " ") for db in self.datasets]
        return "CatPairDataset(" + ', '.join(parts) + ')'

    def pair_which(self, i):
        """ Map global pair index `i` to (dataset position, local pair index).

        Raises:
            IndexError: if `i` is not in [0, npairs).
        """
        # Validate the pair index itself. Comparing the dataset position
        # against npairs would let negative or too-large indices through.
        if not 0 <= i < self.npairs:
            raise IndexError('Bad pair index %d, expected 0 <= index < %d' % (i, self.npairs))
        pos = np.searchsorted(self.pair_offsets, i, side='right')-1
        return pos, i - self.pair_offsets[pos]

    def pair_call(self, func, i, *args, **kwargs):
        """ Call method `func` of the dataset owning pair `i`, with its local index. """
        b, j = self.pair_which(i)
        return getattr(self.datasets[b], func)(j, *args, **kwargs)

    def get_pair(self, i, output=()):
        return self.pair_call('get_pair', i, output)

    def get_flow_filename(self, pair_idx, *args, **kwargs):
        return self.pair_call('get_flow_filename', pair_idx, *args, **kwargs)

    def get_mask_filename(self, pair_idx, *args, **kwargs):
        return self.pair_call('get_mask_filename', pair_idx, *args, **kwargs)

    def get_corres_filename(self, pair_idx, *args, **kwargs):
        return self.pair_call('get_corres_filename', pair_idx, *args, **kwargs)

    def get_homography_filename(self, pair_idx, *args, **kwargs):
        return self.pair_call('get_homography_filename', pair_idx, *args, **kwargs)
def is_pair(x):
    """ Tell whether `x` denotes a pair of indices.

    A pair is a tuple or list of length 2, or a 1-D numpy array with exactly
    two entries. Anything else (scalars, strings, longer sequences) is not.
    """
    if isinstance(x, np.ndarray):
        return x.ndim == 1 and x.shape[0] == 2
    return isinstance(x, (tuple, list)) and len(x) == 2