File size: 6,662 Bytes
3ef85e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
from tqdm import tqdm
import numpy as np
import torch

import test_singlescale as tss
import core.functional as myF
from tools.viz import dbgfig, show_correspondences


def arg_parser(parser = None):
    """ Add the recursive-matching options to a PUMP argument parser.

    Args:
        parser: parser to extend; if falsy, the single-scale parser returned by
            `test_singlescale.arg_parser()` is used as the starting point.

    Returns:
        the extended parser; each new option takes a float value.
    """
    parser = parser or tss.arg_parser()

    # (flag, default, help) of every option added on top of the single-scale ones
    recursive_options = [
        ('--rec-overlap', 0.5, 'overlap between tiles in [0,0.5]'),
        ('--rec-score-thr', 1, 'corres score threshold to guide fine levels'),
        ('--rec-fast-thr', 0.1, 'prune block if less than `fast` corres fall in it'),
    ]
    for flag, default, description in recursive_options:
        parser.add_argument(flag, type=float, default=default, help=description)

    return parser


class RecursivePUMP (tss.SingleScalePUMP):
    """ Recursive PUMP:
        1) find initial correspondences at a coarse scale,
        2) refine them at a selection of finer scales

    Args:
        coarse_size: for the coarse pass, both images are scaled by the same
            factor so that the largest side over the two images equals this.
        fine_size: side length (pixels) of the square tiles matched at the fine level.
        rec_overlap: sets the nominal tile stride, int(rec_overlap * fine_size)
            pixels. NOTE(review): despite the name, larger values give *less*
            overlap between tiles; the default 0.5 corresponds to 50% overlap.
        rec_score_thr: coarse correspondences with score <= this are discarded.
        rec_fast_thr: a pair of tiles is only refined when at least this fraction
            of the coarse correspondences starting in the img1 tile land in the img2 tile.
        **other_options: forwarded to `SingleScalePUMP.__init__`.
    """
    def __init__(self, coarse_size=512, fine_size=512, rec_overlap=0.5, rec_score_thr=1.0,
                       rec_fast_thr = 0.1, **other_options ):
        super().__init__(**other_options)
        assert 10 < coarse_size < 1024
        assert 10 < fine_size < 1024
        assert 0 <= rec_overlap < 1
        assert 0 < rec_fast_thr < 1
        self.coarse_size = coarse_size
        self.fine_size = fine_size
        self.overlap = rec_overlap
        self.score_thr = rec_score_thr
        self.fast_thr = rec_fast_thr

    @torch.no_grad()
    def forward(self, img1, img2, ret='corres', dbg=()):
        """ Match img1 with img2: a coarse pass, then tiled full-resolution passes.

        Args:
            img1, img2: input images, unpacked by `demultiplex_img_trf`, which also
                returns the transform used to map the results back ("rectify").
            ret: 'raw' returns ``(corres, input_trfs)``, where corres holds one
                (positions Nx4, scores N) pair per image; anything else returns
                ``self.reciprocal(*corres)``.
            dbg: debug flags, forwarded to `dbgfig` and to the single-scale matcher.
        """
        img1, sca1 = self.demultiplex_img_trf(img1, force=True)
        img2, sca2 = self.demultiplex_img_trf(img2, force=True)
        input_trfs = (sca1, sca2)

        # coarse first level with low-res images
        corres = self.coarse_correspondences(img1, img2)

        # fine level: iterate on HQ blocks (materialized so tqdm can show a total)
        accu1, accu2 = (self._make_accu(img1), self._make_accu(img2))
        for block1, block2 in tqdm(list(self._enumerate_blocks(img1, img2, corres))):
            accus, trfs = tss.SingleScalePUMP.forward(self, block1, block2, ret='raw', dbg=dbg)
            # trfs[i][:2,2] is used as the (x, y) pixel offset of the block in its
            # image (presumably the translation built in `_enumerate_blocks`)
            self._update_accu( accu1, accus[0], trfs[0][:2,2] )
            self._update_accu( accu2, accus[1], trfs[1][:2,2] )

        # split each accumulator into (positions Nx4, scores N)
        demul = lambda accu: (accu[:,:,:4].reshape(-1,4).clone(), accu[:,:,4].clone())
        corres = demul(accu1), demul(accu2)
        if dbgfig('corres', dbg):
            # show reciprocal matches, still in the (unrectified) coordinates of img1/img2
            show_correspondences(img1, img2, self.reciprocal(*corres), fig='last')
        corres = [(myF.affmul(input_trfs,pos),score) for pos, score in corres] # rectify scaling etc.
        if ret == 'raw': return corres, input_trfs
        return self.reciprocal(*corres)

    def coarse_correspondences(self, img1, img2, **kw):
        """ Match jointly downscaled copies of the images and keep confident matches.

        When either image has a side longer than `coarse_size`, both are resized
        by the same factor f = coarse_size / (largest side over both images), so
        that their relative scale is preserved.

        Returns:
            the rows of the single-scale result whose score (column 4) > `score_thr`.
        """
        # joint image resize, because relative size is important (multiscale)
        shape1, shape2 = img1.shape[-2:], img2.shape[-2:]
        if max(shape1 + shape2) > self.coarse_size:
            f = min(self.coarse_size / max(shape1), self.coarse_size / max(shape2))
            img1 = myF.imresize( img1, int(0.5+f*max(shape1)) )
            img2 = myF.imresize( img2, int(0.5+f*max(shape2)) )

        # NOTE(review): `_enumerate_blocks` uses the returned positions as
        # full-resolution coordinates; this assumes the single-scale matcher maps
        # them back from the resized images (no explicit 1/f rescaling here) — confirm.
        init_corres = tss.SingleScalePUMP.forward(self, img1, img2, **kw)
        corres = init_corres[init_corres[:,4] > self.score_thr]
        print(f"  keeping {len(corres)}/{len(init_corres)} corres with score > {self.score_thr} ...")
        return corres

    def _update_accu(self, accu, update, offset ):
        """ Merge one block's matches into an image accumulator, in place.

        For each cell, the match with the highest score seen so far is kept.

        Args:
            accu: (ceil(H/4), ceil(W/4), 5) float32 tensor from `_make_accu`.
            update: (pos, scores) for one block; scores is 2-D (h, w) and pos
                reshapes to (h, w, 4).
            offset: (x, y) pixel offset of the block; divided by 4 because the
                accumulator has a 4-pixel cell size.
        """
        pos, scores = update
        H, W = scores.shape
        offx, offy = map(lambda i: int(i/4), offset)
        accu = accu[offy:offy+H, offx:offx+W]  # a view: writes reach the caller's tensor
        better = accu[:,:,4] < scores
        # indexed assignment requires matching dtypes, hence the float32 casts
        accu[:,:,4][better] = scores[better].float()
        accu[:,:,0:4][better] = pos.reshape(H,W,4)[better].float()

    def _enumerate_blocks(self, img1, img2, corres):
        """ Yield the pairs of (img1 tile, img2 tile) to match at full resolution.

        img1 is covered by a regular grid of `fine_size` tiles. For each tile, the
        coarse correspondences starting in it are grouped by where they land in
        img2. An img2 tile is placed around the median landing point, or around a
        random one if that tile would contain none. Its correspondences are then
        removed, and this repeats until none remain. A tile pair is yielded only
        if it holds at least `fast_thr` of the img1 tile's correspondences.

        Yields:
            ((tile1, trf1), (tile2, trf2)), each trf being a 3x3 matrix whose last
            column holds the tile's top-left (x, y) in its image.

        Raises:
            ValueError: if int(rec_overlap * fine_size) == 0 while an image side
                exceeds `fine_size` (the tile stride would be null).
        """
        H1, W1, H2, W2 = img1.shape[1:] + img2.shape[1:]
        size, step = self.fine_size, int(self.overlap * self.fine_size)

        def regular_steps(length):
            # top-left coordinates (multiples of 4) of tiles covering [0, length)
            if length <= self.fine_size: return [0]
            if step <= 0:
                raise ValueError(f"rec_overlap={self.overlap} gives a null tile stride "
                                 f"for fine_size={self.fine_size}")
            # a side longer than one tile needs >= 2 tiles; a single tile (possible
            # when rec_overlap > 0.5) would leave the end of the image uncovered
            nb = max(2, int(np.ceil(length / step)) - 1)
            return (np.linspace(0, length-self.fine_size, nb) / 4 + 0.5).astype(int) * 4

        def translation(x,y):
            res = torch.eye(3, device=img1.device)
            res[0,2] = x
            res[1,2] = y
            return res

        def block2(x2,y2):
            return img2[:,y2:y2+size,x2:x2+size], translation(x2,y2)

        def align4(v, hi):
            # clamp v to [0, hi], then round to the nearest multiple of 4
            return int(max(0, min(hi, v)) / 4 + 0.5) * 4

        cx1, cy1 = corres[:,0:2].T

        for y1 in regular_steps(H1):
            for x1 in regular_steps(W1):
                block1 = (img1[:,y1:y1+size,x1:x1+size], translation(x1,y1))
                c2 = corres[(y1<=cy1) & (cy1<y1+size) & (x1<=cx1) & (cx1<x1+size)]
                nb_init = len(c2)
                while len(c2):
                    cx2, cy2 = c2[:,2:4].T
                    x2 = align4(cx2.median()-size//2, W2-size)
                    y2 = align4(cy2.median()-size//2, H2-size)
                    inside = (y2<=cy2) & (cy2<y2+size) & (x2<=cx2) & (cx2<x2+size)
                    if not inside.any():
                        # the median-centered tile is empty: center on a random corres
                        idx = int(np.random.choice(len(c2)))
                        x2 = align4(c2[idx,2]-size//2, W2-size)
                        y2 = align4(c2[idx,3]-size//2, H2-size)
                        inside = (y2<=cy2) & (cy2<y2+size) & (x2<=cx2) & (cx2<x2+size)
                        if not inside.any():
                            # defensive: the sampled point lies outside img2, so no
                            # tile can hold it; drop it so that the loop terminates
                            c2 = torch.cat((c2[:idx], c2[idx+1:]))
                            continue

                    if inside.sum()/nb_init >= self.fast_thr:
                        yield block1, block2(x2,y2)

                    c2 = c2[~inside] # remove

    def _make_accu(self, img):
        """ Return a zeroed float32 accumulator of shape (ceil(H/4), ceil(W/4), 5)
            for a (C, H, W) image: channels 0-3 store match positions, channel 4
            the best score so far. """
        _, H, W = img.shape
        return img.new_zeros(((H+3)//4, (W+3)//4, 5), dtype=torch.float32)



class Main (tss.Main):
    """ Command-line entry point running the recursive (coarse-to-fine) matcher. """

    @staticmethod
    def build_matcher(args, device):
        """ Build a `RecursivePUMP` from parsed command-line arguments.

        The now-obsolete ``--resize`` argument is reinterpreted as
        ``(coarse_size, fine_size)``; a single value is used for both.
        ``args`` is modified in place: ``rec_coarse_size`` / ``rec_fine_size`` are
        set, and ``resize`` is reset to 0 so that image loading does not
        downsize images.

        Returns:
            the matcher after `tss.Main.tune_matcher`, moved to `device`.

        Raises:
            ValueError: if ``--resize`` does not hold 1 or 2 values.
        """
        # set coarse and fine size based on now obsolete --resize argument
        resize = args.resize
        if isinstance(resize, int): resize = [resize]
        # build a new sequence instead of `*=`, which would extend in place a list
        # possibly shared with the argparse default
        if len(resize) == 1: resize = resize * 2
        if len(resize) != 2:
            raise ValueError(f"--resize expects 1 or 2 values, got {len(resize)}")
        args.rec_coarse_size, args.rec_fine_size = resize
        args.resize = 0 # disable it so that image loading does not downsize images

        options = Main.get_options( args )

        matcher = RecursivePUMP( coarse_size=args.rec_coarse_size, fine_size=args.rec_fine_size,
            rec_overlap=args.rec_overlap, rec_score_thr=args.rec_score_thr, rec_fast_thr=args.rec_fast_thr,
            **options)

        return tss.Main.tune_matcher(matcher, **vars(args) ).to(device)


if __name__ == '__main__':
    # build the app first, then parse the recursive-matcher command line and run it
    app = Main()
    app.run_from_args(arg_parser().parse_args())