# File size: 2,212 Bytes
# 3ef85e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tvf

from core.conv_mixer import ConvMixer

# Per-channel RGB normalization, applied in DescPipeline.extract_descs to images
# that have already been divided by 255.
# NOTE(review): the mean/std values match the commonly used ImageNet statistics;
# presumably the descriptor network was trained with them — confirm before changing.
norm_RGB = tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


class PixelDesc (nn.Module):
    """Dense pixel descriptor computed by a ConvMixer restored from a checkpoint.

    The wrapped network is stored in ``self.pixel_desc`` and is switched to
    eval mode at construction time.
    """

    def __init__(self, path='models/PUMP_st.pt'):
        """Load the checkpoint at ``path`` (onto the CPU) into a fresh ConvMixer.

        Args:
            path: location of the serialized state dict.
        """
        super().__init__()
        weights = torch.load(path, map_location='cpu')
        network = ConvMixer(
            output_dim=128,
            hidden_dim=512,
            depth=7,
            patch_size=4,
            kernel_size=9,
        )
        self.pixel_desc = network.eval()
        self.pixel_desc.load_state_dict(weights)

    def configure(self, pipeline):
        """Re-class ``pipeline`` in place so DescPipeline's methods take precedence.

        A new type named ``<OriginalName>_Trained`` is created with
        ``DescPipeline`` first in the bases, followed by the pipeline's current
        class, and assigned to ``pipeline.__class__``.

        Returns:
            ``self``, so the call can be chained.
        """
        # hot-update of the default HOG-based pipeline
        current_cls = type(pipeline)
        patched_name = current_cls.__name__ + '_Trained'
        patched_bases = (DescPipeline, current_cls)
        pipeline.__class__ = type(patched_name, patched_bases, {})
        return self

    def get_atomic_patch_size(self):
        """Return the atomic patch size, fixed to 4."""
        return 4

    def forward(self, img, stride=1, offset=0):
        """Compute descriptors for ``img`` and subsample the last two axes.

        Args:
            img: input tensor; a 3-dimensional input gets a leading axis added.
            stride: step used when subsampling the last two axes of the output.
            offset: first index kept along each of the last two axes.

        Returns:
            ``(desc, trf)`` where ``desc`` is the subsampled, contiguous
            descriptor tensor and ``trf`` is a 3x3 identity matrix on the
            input's device.
        """
        batch = img if img.ndim != 3 else img[None]
        identity = torch.eye(3, device=batch.device)

        dense = self.pixel_desc(batch)
        picked = slice(offset, None, stride)
        # .contiguous() so the result does not stay a strided view of the
        # full-resolution map (original note: "free memory")
        sampled = dense[..., picked, picked].contiguous()
        return sampled, identity


class DescPipeline:
    """Pipeline overrides that use a learned pixel descriptor.

    The class relies on the host object providing ``self.demultiplex_img_trf``
    and a callable ``self.pixel_desc``; it is combined with an existing
    pipeline class by ``PixelDesc.configure``.
    """

    def extract_descs(self, img1, img2, dtype=None):
        """Compute normalized-image descriptors for an image pair.

        Args:
            img1, img2: inputs forwarded to ``self.demultiplex_img_trf``.
            dtype: type the images and the descriptors are cast to. ``None``
                falls back to ``torch.float32``.

        Returns:
            ``((img1, img2), (desc1, desc2), (trf1, trf2))`` where the images
            are those returned by ``demultiplex_img_trf`` and each ``trf`` is
            the matrix product of the demultiplexed transform with the one
            returned by the descriptor.
        """
        if dtype is None:
            # Tensor.type(None) returns the type *name* (a str) instead of
            # casting, which would make the division below raise TypeError.
            dtype = torch.float32

        # this will rotate the image if needed
        img1, sca1 = self.demultiplex_img_trf(img1)
        img2, sca2 = self.demultiplex_img_trf(img2)

        # convert to float and normalize std
        fimg1, fimg2 = [norm_RGB(img.type(dtype)/255) for img in (img1, img2)]

        # keep the descriptor network in the same type as its input
        self.pixel_desc.type(fimg1.dtype)
        # first image: descriptors sampled every 4 pixels starting at index 2;
        # second image: descriptors kept at every position
        desc1, trf1 = self.pixel_desc(fimg1, stride=4, offset=2)
        desc2, trf2 = self.pixel_desc(fimg2)
        return (img1, img2), (desc1.type(dtype), desc2.type(dtype)), (sca1@trf1, sca2@trf2)

    def first_level(self, desc1, desc2, **kw):
        """Correlate every descriptor of ``desc1`` with every position of ``desc2``.

        Each spatial position of ``desc1`` becomes a 1x1 convolution kernel that
        is slid over ``desc2``; only the first batch element of the result is
        kept.

        Args:
            desc1: 4-D tensor ``(B, C, H1, W1)``; the reshape below only
                succeeds when ``B == 1``.
            desc2: 4-D tensor with the same channel count as ``desc1``.
            **kw: accepted and ignored.

        Returns:
            ``(corr, norms)`` with ``corr`` of shape ``(H1, W1, H2, W2)`` and
            ``norms`` a tensor of ones of shape ``(H1, W1)`` on ``corr``'s
            device.
        """
        _, C, H, W = desc1.shape
        # rearrange(desc1, '1 C H W -> (H W) C 1 1')
        # reshape (not view) so that desc1 with a non-contiguous layout works too
        weights = desc1.permute(0, 2, 3, 1).reshape(H*W, C, 1, 1)
        corr = F.conv2d(desc2, weights, padding=0, bias=None)[0]
        norms = torch.ones(desc1.shape[-2:], device=corr.device)
        return corr.reshape(desc1.shape[-2:]+desc2.shape[-2:]), norms