# PUMP / datasets / pair_loader.py
# Philippe Weinzaepfel
# huggingface demo
# 3ef85e9
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
from PIL import Image
import numpy as np
from core import functional as myF
from tools.common import todevice
from .transforms import instanciate_transforms
from .utils import *
class FastPairLoader (DatasetWithRng):
    """ On-the-fly generation of related image pairs, cropped to a fixed size.

    dataset:      pair dataset; `dataset[idx]` must yield `(img1, img2), gt` where
                  `gt` is a dict with a 'homography' or a 'corres' entry.
    crop:         side length (in pixels) of the square crops that are returned.
    transform:    transform specification forwarded to `instanciate_transforms`;
                  the resulting transform is applied to the second image only.
    p_flip:       probability of flipping both images (see `flip_image_pair`).
    p_swap:       probability of swapping img1 and img2.
    scale_jitter: random jitter of the relative scale; the jitter factor is drawn
                  as exp(uniform(-s, s)) with s = |log1p(scale_jitter)|.
    seed:         forwarded to the `DatasetWithRng` constructor.

    self[idx] returns: (img1, img2), dict(homography=, corres=)
        (homography: 3x3 float32 array, can be nan)
    """

    def __init__(self, dataset, crop=256, transform='', p_flip=0, p_swap=0, scale_jitter=0, seed=None):
        super().__init__(seed)
        self.dataset = self.with_same_rng(dataset)
        self.transform = instanciate_transforms(transform, rng=self.rng)
        self.crop_size = crop
        self.p_flip = p_flip
        self.p_swap = p_swap
        # stored in log-space so that the jitter is symmetric around x1
        self.scale_jitter = abs(np.log1p(scale_jitter))

    def __len__(self):
        """ Number of pairs, i.e. the length of the wrapped dataset. """
        return len(self.dataset)

    def __repr__(self):
        text = f'FastPairLoader({self.dataset},\n'
        # Flatten the transform's repr onto one line and trim 14 leading chars and
        # 1 trailing char (presumably the wrapper's prefix/closing bracket -- verify
        # against what `instanciate_transforms` returns).
        # NOTE(review): as written, the final replace(' ',' ') is a no-op; it presumably
        # was meant to collapse double spaces -- confirm against the upstream file.
        transform_desc = repr(self.transform).strip().replace('\n',', ')[14:-1].replace(' ',' ')
        text += f' Transform:\t{transform_desc}\n'
        jitter_factor = np.exp(self.scale_jitter)
        text += f' Crop={self.crop_size}, scale_jitter=x{jitter_factor:g}, p_swap={self.p_swap:g}'
        return text

    def init_worker(self, tid):
        """ Initialize this loader and then the wrapped dataset for worker `tid`. """
        super().init_worker(tid)
        self.dataset.init_worker(tid)

    def set_epoch(self, epoch):
        """ Forward the epoch number to the wrapped dataset. """
        self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        """ Build the idx-th training pair.

        Returns `(img1, img2), dict(homography=..., corres=...)`. When no usable crop
        window is found, both images are black `crop_size` x `crop_size` squares and
        `corres` is an empty (0, 6) array.
        """
        self.init_worker(idx) # preserve RNG for this pair
        (img1, img2), gt = self.dataset[idx]

        # random swap of the two images (ground-truth is inverted accordingly)
        if self.rng.random() < self.p_swap:
            img1, img2 = img2, img1
            if 'homography' in gt:
                gt['homography'] = invh(gt['homography'])
            if 'corres' in gt:
                gt['corres'] = swap_corres(gt['corres'])

        # random flip of both images
        if self.rng.random() < self.p_flip:
            img1, img2, gt = flip_image_pair(img1, img2, gt)

        # apply transformations to the second image
        img2 = self.transform(dict(img=img2))

        homography, corres = spatial_relationship(img1, img2, gt)

        # from here on, images are RGB numpy arrays at least crop_size wide and high
        img1 = self._pad_rgb_numpy(img1)
        img2 = self._pad_rgb_numpy(img2['img'])

        if not 'debug':  # always False: manual debugging switch
            from tools.viz import show_correspondences
            print(np.median(corres[:,5]))
            show_correspondences(img1, img2, corres, bb=bb)

        # windows are always computed on the padded (not yet cropped) images
        shape1, shape2 = img1.shape, img2.shape

        def crop_windows(corres_idx, jitter=1):
            # one window per image, centered on the two ends of a correspondence;
            # the window of the image seen at the finer scale is enlarged
            match = corres[corres_idx]
            scale = match[6] * jitter
            win_a = window(*match[0:2], self.crop_size, max(1, 1/scale), shape1)
            win_b = window(*match[2:4], self.crop_size, max(1, scale/1), shape2)
            return win_a, win_b

        # try up to 5 random correspondences and keep the best-scoring one
        best_score, best_idx = 0, None
        n_trials = min(len(corres), 5)
        for candidate in self.rng.choice(len(corres), size=n_trials, replace=False):
            win1, win2 = crop_windows(candidate)
            # check how many matches are in the 2 windows
            in_win1 = is_in(corres[:,0:2], win1)
            in_win2 = is_in(corres[:,2:4], win2)
            score = score_windows(in_win1, in_win2)
            if score > best_score:
                best_score, best_idx = score, candidate

        if best_idx is None:
            # couldn't find a good window: return a blank pair without correspondences
            img1 = img2 = np.zeros((self.crop_size, self.crop_size, 3), dtype=np.uint8)
            corres = np.empty((0, 6), dtype=np.float32)
        else:
            # jitter scales
            jitter = np.exp(self.rng.uniform(-self.scale_jitter, self.scale_jitter))
            win1, win2 = crop_windows(best_idx, jitter)

            img1 = imresize(img1[win1], self.crop_size)
            img2 = imresize(img2[win2], self.crop_size)
            trf1 = wintrf(win1, img1)
            trf2 = wintrf(win2, img2)

            # fix rotation if necessary: vote over column 5 of the correspondences
            # (modulo 8) and convert the winning bin to a signed number of quarter turns
            votes = np.bincount(corres[:,5].astype(int) % 8)
            signed_bin = ((votes.argmax() + 4) % 8) - 4
            rot90 = int(signed_bin / 2)
            if rot90: # rectify rotation
                img2, trf = myF.rotate_img_90((img2, np.eye(3)), 90*rot90)
                trf2 = invh(trf) @ trf2

            # express the ground-truth in the coordinates of the final crops
            homography = trf2 @ homography @ invh(trf1)
            corres = myF.affmul((trf1,trf2), corres)

        def as_writable(arr, **kw):
            # C-contiguous, writeable, aligned array that owns its data
            return np.require(arr, requirements='CWAE', **kw)

        out_imgs = (as_writable(img1), as_writable(img2))
        out_gt = dict(homography=as_writable(homography, dtype=np.float32), corres=corres)
        return out_imgs, out_gt

    def _pad_rgb_numpy(self, img):
        """ Convert a PIL image to an RGB numpy array, zero-padding it on the right and
        bottom so that both sides are at least `crop_size` pixels.
        """
        if img.mode != 'RGB':
            img = img.convert('RGB')
        width, height = img.size
        if min(width, height) < self.crop_size:
            canvas_size = (max(width, self.crop_size), max(height, self.crop_size))
            canvas = Image.new('RGB', canvas_size, 0)
            canvas.paste(img, (0, 0))
            img = canvas
        return np.asarray(img)
def swap_corres( corres ):
    """ Exchange the roles of img1 and img2 in an array of correspondences.

    Columns 0:2 and 2:4 (the two point sets) are swapped. When more than 4 columns
    are present, column 5 is decoded into (scale, rotation) and re-encoded as
    (1/scale, -rotation). The input array is left untouched; a copy is returned.
    """
    swapped = corres.copy()
    swapped[:, 0:2] = corres[:, 2:4]
    swapped[:, 2:4] = corres[:, 0:2]
    if corres.shape[1] <= 4:
        return swapped

    # invert rotation and scale
    scale, rot = myF.decode_scale_rot(corres[:,5])
    swapped[:,5] = myF.encode_scale_rot(1/scale, -rot)
    return swapped
def flip(img):
    """ Mirror a PIL image left-right.

    Returns the flipped image and the 3x3 float32 matrix that maps pixel
    coordinates of the input to the flipped image (x -> w-1-x, y unchanged).
    """
    width, _ = img.size
    mirror = np.float32([[-1, 0, width-1],
                         [0, 1, 0],
                         [0, 0, 1]])
    return img.transpose(Image.FLIP_LEFT_RIGHT), mirror
def flip_image_pair(img1, img2, gt):
    """ Mirror both images left-right and update the ground-truth accordingly.

    Supported `gt` keys:
      'homography': conjugated by the two flip matrices.
      'corres':     both point sets are flipped; for 6-column input, column 4 is
                    kept and the rotation encoded in column 5 is negated. For any
                    other width only the 4 coordinate columns are returned.
      'aflow':      not implemented (assertion failure).
    Any other key raises ValueError.

    Returns (flipped img1, flipped img2, new gt dict).
    """
    img1, F1 = flip(img1)
    img2, F2 = flip(img2)

    flipped_gt = {}
    for key, value in gt.items():
        if key == 'homography':
            flipped_gt['homography'] = F2 @ value @ F1

        elif key == 'aflow':
            assert False, 'flip for aflow: todo'

        elif key == 'corres':
            pts1 = applyh(F1, value[:,0:2])
            pts2 = applyh(F2, value[:,2:4])
            if value.shape[1] == 6:
                # mirroring reverses the direction of rotation
                scale, rot = myF.decode_scale_rot(value[:,5])
                mirrored_code = myF.encode_scale_rot(scale, -rot)
                flipped_gt['corres'] = np.c_[pts1, pts2, value[:,4], mirrored_code]
            else:
                flipped_gt['corres'] = np.c_[pts1, pts2]

        else:
            raise ValueError(f"flip_image_pair: bad gt field '{key}'")

    return img1, img2, flipped_gt
def spatial_relationship( img1, img2, gt ):
    """ Derive the homography and the correspondences relating img1 to the transformed img2.

    img1: first image; only `img1.size` is read (homography case).
    img2: dict produced by the transform applied to the second image. If it carries a
          'homography' entry, that transform is composed with the ground-truth.
          In the 'corres' case a missing entry is set to the identity (img2 is modified).
    gt:   dict with either a 'homography' (3x3) or a 'corres' (N x 4 or N x 6) entry.
          `gt['corres']` is never modified: a copy is transformed instead.

    Returns (homography, corres):
      - 'homography' case: the composed homography and the output of
        `corres_from_homography`.
      - 'corres' case: a 3x3 NaN homography and an N x 7 array whose last column
        is the local scale change.

    Raises ValueError if `gt` has neither key, or if `gt['corres']` is not N x 4 / N x 6.
    """
    if 'homography' in gt:
        homography = gt['homography']
        if 'homography' in img2:
            homography = np.float32(img2['homography']) @ homography
        corres = corres_from_homography(homography, *img1.size)

    elif 'corres' in gt:
        homography = np.full((3,3), np.nan, dtype=np.float32)

        # work on a copy: the in-place update below must not corrupt the dataset's array
        corres = np.array(gt['corres'])
        if corres.ndim != 2 or corres.shape[1] not in (4, 6):
            raise ValueError(f'bad shape for corres: {corres.shape}')

        if 'homography' in img2:
            corres[:,2:4] = applyh(img2['homography'], corres[:,2:4])
        else:
            img2['homography'] = np.eye(3)

        # local scale change induced by the transform applied to img2
        scales = np.sqrt(np.abs(np.linalg.det(jacobianh(img2['homography'], corres[:,0:2]).T)))

        if corres.shape[1] == 4:
            # no scale/rotation provided: estimate them from neighboring correspondences
            scales, rots = scale_rot_from_corres(corres)
            corres = np.c_[corres, np.ones_like(scales), myF.encode_scale_rot(scales,rots*180/np.pi), scales]
        else:
            # 6 columns: combine the transform's scale with the scale encoded in column 5
            corres = np.c_[corres, scales * myF.decode_scale_rot(corres[:,5])[0]]

    else:
        raise ValueError(f"spatial_relationship: gt must contain 'homography' or 'corres', got {sorted(gt)}")

    return homography, corres
def scale_rot_from_corres( corres, sub=256, nn=16 ):
    """ Estimate a local scale and rotation for every correspondence.

    corres: N x 4 array of correspondences (x1, y1, x2, y2).
    sub:    maximum number of correspondences randomly drawn (global numpy RNG)
            as candidate neighbors.
    nn:     number of nearest neighbors used to fit each local affine transform.

    Returns (scale, rot): scale is clipped to [0.2, 5]; rot is the output of
    arctan2, i.e. in radians.
    """
    # select a subset of relevant correspondences
    n_anchors = min(len(corres), sub)
    anchor_ids = np.random.choice(len(corres), size=n_anchors, replace=False)
    anchors = corres[anchor_ids]

    # for each corres, find its NNs among the anchors (distance in the 4-D corres space)
    from scipy.spatial.distance import cdist
    sq_dists = cdist(corres, anchors, metric='sqeuclidean')
    nearest = sq_dists.argsort(axis=1)[:,:nn]

    # affine transform for this set of neighboring correspondences
    neighbors = anchors[nearest] # shape = npts x nn x 4
    src = neighbors[:,:,0:2]
    dst = neighbors[:,:,2:4]

    # least-squares fit of [P1,1] @ A = P2 with A = 3x2 matrix, i.e. A = [P1,1]^-1 @ P2
    src_h = np.concatenate((src, np.ones_like(src[:,:,:1])), axis=-1)
    affine = (np.linalg.pinv(src_h) @ dst).transpose(0,2,1) # npts x 2 x 3

    # keep the 2x2 linear part only (translation dropped, first two columns swapped)
    linear = affine.transpose(1,2,0)[:,1::-1]
    scale, (angy, angx) = detect_scale_rotation(linear)
    rot = np.arctan2(angy, angx)
    return scale.clip(min=0.2, max=5), rot
def window1(x, size, w):
    """ 1-D window of (rounded) length `size` centered on `x`, kept inside [0, w].

    The window is shifted back inside the range when it overflows on either
    side; if it is longer than `w`, the whole range [0, w) is returned.
    """
    start = x - int(0.5 + size / 2)
    stop = start + int(0.5 + size)

    if start < 0:
        # overflow on the left: shift right
        stop -= start
        start = 0
    if stop > w:
        # overflow on the right: shift left
        start += w - stop
        stop = w
    if start < 0:
        # larger than width
        start, stop = 0, w

    return slice(start, stop)
def window(cx, cy, win_size, scale, img_shape):
    """ 2-D window of side `win_size*scale` centered on (cx, cy), kept inside the image.

    Returns a (row slice, column slice) pair, usable directly to index an
    array of shape `img_shape` (height, width, ...).
    """
    extent = win_size*scale
    rows = window1(int(cy), extent, img_shape[0])
    cols = window1(int(cx), extent, img_shape[1])
    return rows, cols
def is_in( pts, window ):
    """ Boolean mask telling which (x, y) points fall inside a window.

    pts:    N x 2 array of (x, y) coordinates.
    window: (row slice, column slice); bounds are start-inclusive, stop-exclusive.
    """
    rows, cols = window
    xs, ys = pts.T
    inside_x = (cols.start <= xs) & (xs < cols.stop)
    inside_y = (rows.start <= ys) & (ys < rows.stop)
    return inside_x & inside_y
def score_windows( valid1, valid2 ):
    """ Score a pair of windows from the boolean masks of the matches they contain.

    The score is the number of matches lying in both windows, weighted by the
    smaller of the two fractions (shared matches / matches in each window).
    """
    eps = 1e-8  # avoids a division by zero when a window is empty
    shared = (valid1 & valid2).sum()
    frac1 = shared / (valid1.sum() + eps)
    frac2 = shared / (valid2.sum() + eps)
    return shared * min(frac1, frac2)
def imresize( img, max_size, resample=Image.LANCZOS):
    """ Downscale an array image to a `max_size` x `max_size` square if it is larger.

    img:      numpy image (H x W x C). A last dimension of 2 is treated as a
              2-channel map: each channel is resized separately and the result is float32.
    max_size: target side length; images whose larger side does not exceed it
              are returned unchanged.
    resample: PIL resampling filter. Defaults to `Image.LANCZOS` (the filter formerly
              also exposed as `Image.ANTIALIAS`, a name removed in Pillow 10).

    The aspect ratio is not preserved: the output is always square.
    Raises AssertionError if the result is not `max_size` x `max_size`.
    """
    if max(img.shape[:2]) > max_size:
        target = (max_size, max_size)
        if img.shape[-1] == 2:
            # PIL cannot handle 2-channel images: resize each channel on its own
            channels = [np.float32(Image.fromarray(img[...,i]).resize(target, resample=resample))
                        for i in range(2)]
            img = np.stack(channels, axis=-1)
        else:
            img = np.asarray(Image.fromarray(img).resize(target, resample=resample))

    assert img.shape[0] == img.shape[1] == max_size, \
        f'imresize: expected a {max_size}x{max_size} image, got shape {img.shape}'
    return img
def wintrf( window, final_img ):
    """ Transform mapping coordinates of the original image to those of its resized crop.

    window:    (row slice, column slice) that was cropped from the original image.
    final_img: the crop after resizing; only its height and width are read.

    Returns the inverse (via `invh`) of the 3x3 float32 matrix that scales and
    shifts crop coordinates back into the original image.
    """
    rows, cols = window
    out_h, out_w = final_img.shape[:2]
    scale_x = (cols.stop-cols.start)/out_w
    scale_y = (rows.stop-rows.start)/out_h
    crop_to_orig = np.float32([[scale_x, 0, cols.start],
                               [0, scale_y, rows.start],
                               [0, 0, 1]])
    return invh(crop_to_orig)
def collate_ordered(batch, _use_shared_memory=True):
    """ Collate function that puts supervised samples first and unsupervised ones last.

    batch: sequence of `((img1, img2), gt)` samples, where `gt` is a dict with
           at least a 'homography' entry.
    _use_shared_memory: if True, stack directly into a shared-memory tensor
           to avoid an extra copy.

    A sample is "supervised" when its homography is entirely finite; every
    other sample (NaN or inf entries) is "unsupervised", so no sample is ever dropped.

    Returns `(imgs1, imgs2), gt` where images and every gt field are stacked
    along a new first dimension, except 'corres' which stays a list (cannot concat).
    """
    pairs, gt = zip(*batch)
    imgs1, imgs2 = zip(*pairs)
    assert len(imgs1) == len(imgs2) == len(gt) and isinstance(gt[0], dict)

    # reorder samples (supervised ones first, unsupervised ones last)
    is_supervised = [bool(np.isfinite(b['homography']).all()) for b in gt]
    supervised = [i for i, ok in enumerate(is_supervised) if ok]
    unsupervsd = [i for i, ok in enumerate(is_supervised) if not ok]
    order = supervised + unsupervsd

    def collate( tensors, key=None ):
        import torch
        batch = todevice([tensors[i] for i in order], 'cpu')
        if key == 'corres': return batch # cannot concat

        out = None # plain stacking unless a shared buffer is requested
        if _use_shared_memory: # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = batch[0].storage()._new_shared(numel)
            # pre-shape the buffer so that torch.stack does not have to resize it
            out = batch[0].new(storage).resize_(len(batch), *batch[0].shape)
        return torch.stack(batch, dim=0, out=out)

    return (collate(imgs1), collate(imgs2)), {k:collate([b[k] for b in gt],k) for k in gt[0]}
if __name__ == '__main__':
    # Visual sanity check: build a mixed synthetic/real pair dataset and display random pairs.
    from datasets import *
    from tools.viz import show_random_pairs

    # the three sources are instantiated in the same order as they are listed below
    web_pairs = SyntheticImagePairs(RandomWebImages(0,52),distort='RandomTilting(0.5)')
    sfm_synth_pairs = SyntheticImagePairs(SfM120k_Images(),distort='RandomTilting(0.5)')
    sfm_real_pairs = SfM120k_Pairs()

    # each count precedes the dataset it applies to
    pairs_db = BalancedCatImagePairs(
        3125, web_pairs,
        4875, sfm_synth_pairs,
        8000, sfm_real_pairs)

    loader = FastPairLoader(
        pairs_db,
        crop=256,
        transform='RandomRotation(20), RandomScale(256,1536,ar=1.3,can_upscale=True), PixelNoise()',
        p_swap=0.5, p_flip=0.5, scale_jitter=0, seed=777)

    show_random_pairs(loader)