File size: 9,703 Bytes
3ef85e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 |
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
import sys
from pdb import set_trace as bb
from PIL import Image
import numpy as np
import matplotlib.pyplot as pl; pl.ion()
import torch
import torch.nn.functional as F
from core import functional as myF
from .common import cpu, nparray, image, image_with_trf
def dbgfig(*args, **kwargs):
    """ Open a named debug figure if it was requested.

    Usage: dbgfig(name1[, name2, ...], selection, **figure_kwargs)
    The last positional argument is the selection: a whitespace-separated
    string or an iterable of names. The first name that appears in the
    selection (or any name, if 'all' is selected) is opened with
    pl.figure(name, **kwargs) and returned. Returns False when nothing matches.
    """
    assert len(args) >= 2
    *names, selection = args
    if isinstance(selection, str):
        selection = selection.split()
    for name in names:
        if not ({name, 'all'} & set(selection)):
            continue
        return pl.figure(name, **kwargs)
    return False
def noticks(ax=None):
    """ Remove the x and y ticks of `ax` (the current axes when None) and return it. """
    target = pl.gca() if ax is None else ax
    for hide in (target.set_xticks, target.set_yticks):
        hide(())
    return target
def plot_grid( corres, ax1, ax2=None, marker='+' ):
    """ Scatter correspondences, colored by their angle around the centroid.

    corres: Nx2 or Nx4 list of correspondences. Columns 0:2 are plotted on
            `ax1` and, when `ax2` is given, columns 2:4 are plotted on `ax2`
            with the same color for the same correspondence.
    marker: matplotlib marker format; True means '+'.
    """
    marker = '+' if marker is True else marker
    pts = nparray(corres)

    # angle of each first-image point around the centroid, quantised to integer color levels
    yx = pts[:, [1, 0]]
    angles = np.arctan2(*(yx - yx.mean(axis=0)).T)
    levels_of = np.int32(64 * angles / np.pi) % 128

    levels = np.unique(levels_of)
    palette = {lvl: pl.cm.hsv(k / float(len(levels))) for k, lvl in enumerate(levels)}

    targets = [(ax1, slice(0, 2))]
    if ax2:
        targets.append((ax2, slice(2, 4)))
    for ax, cols in targets:
        for lvl in levels:
            x, y = pts[levels_of == lvl, cols].T
            ax.plot(x, y, marker, ms=10, mew=2, color=palette[lvl], scalex=0, scaley=0)
def show_correspondences( img0, img1, corres, F=None, fig='last', show_grid=True, bb=None, clf=False):
    """ Interactively display correspondences between two images.

    The figure has 2x2 axes: the left column shows img0, the right column img1.
    The bottom row shows all correspondences (if `show_grid`); moving the mouse
    over any axis highlights, in the top row, the correspondence whose point
    (in the hovered image) is nearest to the cursor.

    img0, img1: image, or (image, 3x3 transform) tuple (identity if omitted);
        `corres` is passed through myF.affmul with the inverse transforms.
    corres: (N, >=4) array of (x0, y0, x1, y1[, score[, code]]) rows, or a 3x3
        homography, in which case every pixel of img0 is mapped through it.
        When there are more than 4 columns, rows with column 4 <= 0 are dropped.
    F: optional 3x3 matrix; when given, the line it yields for the hovered point
        is drawn on the other image (presumably a fundamental matrix — TODO
        confirm its orientation convention).
    fig: 'last' reuses the current figure, a Figure object is reused, anything
        falsy uses the figure named 'viz_corres'.
    show_grid: plot all correspondences in the bottom row; a marker format
        string may be passed instead of True.
    bb: optional callable (e.g. a debugger breakpoint) called instead of
        pl.show(); when absent, interactive mode is turned off first.
    clf: clear each axis before drawing.
    """
    img0, trf0 = img0 if isinstance(img0, tuple) else (img0, torch.eye(3))
    img1, trf1 = img1 if isinstance(img1, tuple) else (img1, torch.eye(3))
    if not bb: pl.ioff()
    fig, axes = pl.subplots(2, 2, num=fig_num(fig, 'viz_corres'))
    for i, ax in enumerate(axes.ravel()):
        if clf: ax.cla()
        # tag each axis with the index of the image it shows (0: left, 1: right);
        # mouse_move reads it back to know which half of `corres` to search
        noticks(ax).numaxis = i % 2
        ax.imshow( [image(img0),image(img1)][i%2] )

    if corres.shape == (3,3): # corres is an homography matrix
        from pytools.hfuncs import applyh
        H, W = axes[0,0].images[0].get_size()
        pos1 = np.mgrid[:H,:W].reshape(2,-1)[::-1].T  # (H*W, 2) pixel (x, y) positions
        pos2 = applyh(corres, pos1)
        corres = np.concatenate((pos1,pos2), axis=-1)

    inv = np.linalg.inv
    corres = myF.affmul((inv(nparray(trf0)),inv(nparray(trf1))), nparray(corres)) # image are already downscaled
    print(f">> Displaying {len(corres)} correspondences (move you mouse over the images)")

    (ax1, ax2), (ax3, ax4) = axes
    if corres.shape[-1] > 4:
        corres = corres[corres[:,4]>0,:] # select non-null correspondences
    if show_grid: plot_grid(corres, ax3, ax4, marker=show_grid)

    def remove_lines(ax):
        # Axes.lines is an immutable ArtistList in recent matplotlib (mutating it
        # was deprecated in 3.5 and removed in 3.7): remove each artist instead
        for artist in list(ax.lines):
            artist.remove()

    def mouse_move(event):
        if event.inaxes is None: return
        # axes without a `numaxis` tag are ignored instead of raising AttributeError
        numaxis = getattr(event.inaxes, 'numaxis', -1)
        # nothing to highlight if every correspondence was filtered out
        if numaxis < 0 or len(corres) == 0: return
        x,y = event.xdata, event.ydata
        remove_lines(ax1)
        remove_lines(ax2)
        sl = slice(2*numaxis, 2*(numaxis+1))
        n = np.sum((corres[:,sl] - [x,y])**2,axis=1).argmin() # find nearest point
        print("\rdisplaying #%d (%d,%d) --> (%d,%d), score=%g, code=%g" % (n,
            corres[n,0],corres[n,1],corres[n,2],corres[n,3],
            corres[n,4] if corres.shape[-1] > 4 else np.nan,
            corres[n,5] if corres.shape[-1] > 5 else np.nan), end=' '*7)
        sys.stdout.flush()
        x,y = corres[n,0:2]
        ax1.plot(x, y, '+', ms=10, mew=2, color='blue', scalex=False, scaley=False)
        x,y = corres[n,2:4]
        ax2.plot(x, y, '+', ms=10, mew=2, color='red', scalex=False, scaley=False)
        if F is not None:
            ax = None
            if numaxis == 0:
                line = corres[n,0:2] @ F[:2] + F[2]
                ax = ax2
            if numaxis == 1:
                line = corres[n,2:4] @ F.T[:2] + F.T[2]
                ax = ax1
            if ax:
                # line coefficients (a, b, c) with a*x + b*y + c = 0, solved for y
                x = np.linspace(-10000,10000,2)
                y = (line[2]+line[0]*x) / -line[1]
                ax.plot(x, y, '-', scalex=0, scaley=0)

        # we redraw only the concerned axes
        renderer = fig.canvas.get_renderer()
        ax1.draw(renderer)
        ax2.draw(renderer)
        fig.canvas.blit(ax1.bbox)
        fig.canvas.blit(ax2.bbox)

    cid_move = fig.canvas.mpl_connect('motion_notify_event',mouse_move)
    pl.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99, wspace=0.02, hspace=0.02)
    bb() if bb else pl.show()
    fig.canvas.mpl_disconnect(cid_move)
def closest( grid, event ):
    """ Return the (row, col) index of the `grid` cell nearest to the event position.

    grid: (H, W, 2) array whose last axis is compared with (event.xdata, event.ydata).
    event: any object exposing `xdata` and `ydata`.
    """
    points = grid.reshape(-1, 2)
    dist = np.linalg.norm(points - (event.xdata, event.ydata), axis=1)
    return np.unravel_index(dist.argmin(), grid.shape[:2])
def local_maxima( arr2d, top=5 ):
    """ Return the coordinates of the `top` strongest local maxima of a 2D tensor.

    A pixel is a local maximum when it equals the max of its 3x3 neighbourhood
    (plateau pixels included). Returns a (2, K) tensor of (row, col) indices,
    K = min(top, number of maxima), ordered by increasing value.
    """
    maxpooled = F.max_pool2d( arr2d[None, None], 3, padding=1, stride=1)[0,0]
    coords = (arr2d == maxpooled).nonzero()
    values = arr2d[coords.split(1, dim=1)].ravel()
    order = values.argsort()
    # keep the `top` largest (previously a hard-coded 5 ignored `top`);
    # clamp the start at 0 so that top=0 keeps nothing rather than everything
    order = order[max(len(order) - top, 0):]
    return coords[order].T
def fig_num( fig, default, clf=False ):
    """ Turn a figure selector into a `num` usable by pl.figure / pl.subplots.

    fig: 'last' -> number of the current figure; a truthy object -> its
         `.number` attribute; anything falsy -> `default`.
    clf: when True, clear the selected figure before returning.
    """
    if fig == 'last':
        num = pl.gcf().number
    else:
        num = fig.number if fig else default
    if clf:
        pl.figure(num).clf()
    return num
def viz_correlation_maps( img1, img2, corr, level=0, fig=None, grid1=None, grid2=None, show_grid=False, bb=bb, **kw ):
    """ Interactively browse a 4D correlation volume between two images.

    Layout:
    - top-left: img1;
    - top-right: img2;
    - bottom-left: histogram of a strided subsample of the correlation values;
    - bottom-right: correlation map of the img1 cell nearest to the mouse.
    Local maxima of that map are marked on img2.

    corr: (H1, W1, H2, W2) tensor, or a tuple-like object with `.grid` and
        `.res_map` attributes that is reshaped into one.
    level: when 0, the default grid1 positions are shifted by half a cell.
    fig: figure selector, see fig_num (the figure is cleared first).
    grid1, grid2: (x, y) pixel position of each cell of corr's first / last two
        dimensions. By default they are inferred from the image-to-corr size
        ratio (assumes image() returns 3-channel arrays — TODO confirm).
    show_grid: draw grid1 on img1 (the hovered cell is then not marked).
    bb: callable run at the end (default: pdb breakpoint); the mouse handler
        is disconnected once it returns. Pass None/False to keep the handler.
    **kw: forwarded to imshow for the correlation map.
    """
    fig, ((ax1, ax2), (ax4, ax3)) = pl.subplots(2, 2, num=fig_num(fig, 'viz_correlation_maps', clf=True))
    img1 = image(img1)
    img2 = image(img2)
    noticks(ax1).imshow( img1 )
    noticks(ax2).imshow( img2 )

    if isinstance(corr, tuple):
        H1, W1 = corr.grid.shape[:2]
        corr = torch.from_numpy(corr.res_map).view(H1,W1,*corr.res_map.shape[-2:])

    # must come after the conversion above: a tuple input has no .ravel()
    ax4.hist(corr.ravel()[7:7777777:7].cpu().numpy(), bins=50)

    if grid1 is None:
        s1 = int(0.5 + np.sqrt(img1.size / (3 * corr[...,0,0].numel()))) # scale factor between img1 and corr
        grid1 = nparray(torch.ones_like(corr[:,:,0,0]).nonzero()*s1)[:,1::-1]
        if level == 0: grid1 += s1//2
    if show_grid: plot_grid(grid1, ax1)
    grid1 = nparray(grid1).reshape(*corr[:,:,0,0].shape,2)

    if grid2 is None:
        s2 = int(0.5 + np.sqrt(img2.size / (3 * corr[0,0,...].numel()))) # scale factor between img2 and corr
        grid2 = nparray(torch.ones_like(corr[0,0]).nonzero()*s2)[:,::-1]
    grid2 = nparray(grid2).reshape(*corr.shape[2:],2)

    def remove_all(artists):
        # Axes.lines / Axes.images are immutable ArtistLists in recent matplotlib
        # (mutation deprecated in 3.5, removed in 3.7): remove artists one by one
        for artist in list(artists):
            artist.remove()

    def mouse_move(ev):
        if ev.inaxes is not ax1: return
        remove_all(ax3.images)
        n = closest(grid1, ev)
        ax3.imshow(corr[n].cpu().float(), vmin=0, **kw)

        # local maxima of the hovered correlation map, as (rows, cols) in corr space
        lm = nparray(local_maxima(corr[n]))
        remove_all(ax3.lines)
        if not show_grid:
            # mark the hovered img1 cell (skipped when the whole grid is drawn)
            remove_all(ax1.lines)
            ax1.plot(*grid1[n], 'xr', ms=10, scalex=0, scaley=0)
        remove_all(ax2.lines)
        y, x = lm
        x, y = grid2[y,x].T  # map the maxima to img2 pixel coordinates
        ax2.plot(x, y, 'xr', ms=10, scalex=0, scaley=0, label='local maxima')
        print(f"\rCorr channel {n}. Min={corr[n].min():g}, Avg={corr[n].mean():g}, Max={corr[n].max():g} ", end='')

    mouse_move(FakeEvent(0,0,inaxes=ax1))
    cid_move = fig.canvas.mpl_connect('motion_notify_event', mouse_move)
    pl.subplots_adjust(0,0,1,1,0,0)
    pl.sca(ax4)
    if bb:
        bb()
        fig.canvas.mpl_disconnect(cid_move)
def viz_correspondences( img1, img2, corres1, corres2, fig=None, bb=bb ):
    """ Compare two sets of correspondences for the same image pair.

    Rows 1 and 2 plot corres1[0] and corres2[0] (via plot_grid) over
    (img1, img2). Row 3 shows corres1[1] and corres2[1] as images with a
    shared color range [0, ceil(max of both)].

    corres1, corres2: indexable pairs, passed through cpu(); item [1] must
        support .float()/.max() (presumably torch tensors — TODO confirm).
    fig: figure selector, see fig_num.
    bb: callable run once everything is drawn. Defaults to the pdb breakpoint
        (the previous, unconditional behavior). Pass None/False to return
        without breaking, e.g. from scripts.
    """
    img1, img2 = map(image, (img1, img2))
    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = pl.subplots(3,2, num=fig_num(fig, 'viz_correspondences'))
    for ax in fig.axes: noticks(ax)
    for ax, img in ((ax1, img1), (ax2, img2), (ax3, img1), (ax4, img2)):
        ax.imshow( img )
    corres1, corres2 = map(cpu, (corres1, corres2))
    plot_grid( corres1[0], ax1, ax2 )
    plot_grid( corres2[0], ax3, ax4 )

    corres1, corres2 = corres1[1].float(), corres2[1].float()
    ceiling = np.ceil(max(corres1.max(), corres2.max()).item())
    ax5.imshow( corres1, vmin=0, vmax=ceiling )
    ax6.imshow( corres2, vmin=0, vmax=ceiling )
    if bb: bb()
class FakeEvent:
    """ Minimal stand-in for a matplotlib mouse event.

    Stores `xdata` / `ydata` plus any extra keyword arguments (e.g. `inaxes`)
    as plain attributes, so handlers can be triggered programmatically.
    """
    def __init__(self, xdata, ydata, **kw):
        vars(self).update(xdata=xdata, ydata=ydata, **kw)
def show_random_pairs( db, pair_idxs=None, **kw ):
    """ Display the ground-truth correspondences of dataset pairs, one after another.

    db: dataset where `db[i]` returns ((img1, img2), gt), with `gt` holding
        either 'corres' or 'homography'. The optional `db.pairs`,
        `db.imgs.get_image_path` and `db.get_corres_path` are only used to
        print file paths.
    pair_idxs: indices to show (default: all pairs, in random order).
    **kw: forwarded to show_correspondences.
    """
    print('Showing random pairs from', db)

    if pair_idxs is None:
        pair_idxs = np.random.permutation(len(db))

    for pair_idx in pair_idxs:
        print(f'{pair_idx=}')
        # printing paths is best-effort (datasets may lack these attributes),
        # but KeyboardInterrupt / SystemExit must still stop the loop
        try:
            img1_path, img2_path = map(db.imgs.get_image_path, db.pairs[pair_idx])
            print(f'{img1_path=}\n{img2_path=}')
            if hasattr(db, 'get_corres_path'):
                print(f'corres_path = {db.get_corres_path(pair_idx)}')
        except Exception:
            pass

        (img1, img2), gt = db[pair_idx]
        if 'corres' in gt:
            corres = gt['corres']
        else:
            # make corres from homography
            from datasets.utils import corres_from_homography
            corres = corres_from_homography(gt['homography'], *img1.size)

        show_correspondences(img1, img2, corres, **kw)
if __name__=='__main__':
    # Command-line viewer: load two images and an .npz file holding a 'corres'
    # array, then display them interactively.
    import argparse
    import test_singlescale as pump

    parser = argparse.ArgumentParser('Correspondence visualization')
    parser.add_argument('--img1', required=True, help='path to first image')
    parser.add_argument('--img2', required=True, help='path to second image')
    parser.add_argument('--corres', required=True, help='path to correspondences')
    args = parser.parse_args()

    # np.load on an .npz returns an NpzFile that keeps the file open: close it
    with np.load(args.corres) as data:
        corres = data['corres']

    args.resize = 0 # don't resize images
    imgs = tuple(map(image, pump.Main.load_images(args)))
    show_correspondences(*imgs, corres)
|