r""" Convolutional Hough Matching Networks """

import torch.nn as nn
import torch

from . import chmlearner as chmlearner
from .base import backbone


class CHMNet(nn.Module):
    def __init__(self, ktype):
        super(CHMNet, self).__init__()

        # ResNet-101 feature extractor and the convolutional Hough matching learner
        self.backbone = backbone.resnet101(pretrained=True)
        self.learner = chmlearner.CHMLearner(ktype, feat_dim=1024)

    def forward(self, src_img, trg_img):
        # Extract backbone features for the source/target images and
        # compute their correlation via convolutional Hough matching
        src_feat, trg_feat = self.extract_features(src_img, trg_img)
        correlation = self.learner(src_feat, trg_feat)
        return correlation

    def extract_features(self, src_img, trg_img):
        # Stack the two images along the channel dimension and run the ResNet stem
        feat = self.backbone.conv1.forward(torch.cat([src_img, trg_img], dim=1))
        feat = self.backbone.bn1.forward(feat)
        feat = self.backbone.relu.forward(feat)
        feat = self.backbone.maxpool.forward(feat)

        # Run the residual stages; features are taken at layer3, so layer4 is never reached
        for idx in range(1, 5):
            feat = getattr(self.backbone, 'layer%d' % idx)(feat)

            if idx == 3:
                # Split the stacked feature map back into source and target halves
                src_feat = feat.narrow(1, 0, feat.size(1) // 2).clone()
                trg_feat = feat.narrow(1, feat.size(1) // 2, feat.size(1) // 2).clone()
                return src_feat, trg_feat

    @classmethod
    def training_objective(cls, prd_kps, trg_kps, npts):
        # Squared Euclidean distance between predicted and ground-truth keypoints,
        # averaged over the valid keypoints of each image pair, then over the batch
        l2dist = (prd_kps - trg_kps).pow(2).sum(dim=1)
        loss = []
        for dist, npt in zip(l2dist, npts):
            loss.append(dist[:npt].mean())
        return torch.stack(loss).mean()
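

# A minimal usage sketch (not part of the original module). It assumes the package
# layout above (chmlearner, base.backbone) is importable, that the repository's
# custom backbone accepts the channel-concatenated image pair, and that `ktype`
# is a kernel-type string accepted by CHMLearner ('psi' is a placeholder here).
# Shapes below are illustrative only.
if __name__ == '__main__':
    model = CHMNet(ktype='psi')
    model.eval()

    # Two illustrative 240x240 RGB image batches
    src_img = torch.randn(2, 3, 240, 240)
    trg_img = torch.randn(2, 3, 240, 240)

    with torch.no_grad():
        correlation = model(src_img, trg_img)
    print('correlation:', correlation.shape)

    # Illustrative keypoint tensors for the training objective: coordinates lie
    # along dim=1, keypoints are padded to a fixed maximum, and `npts` gives the
    # number of valid keypoints per pair.
    prd_kps = torch.randn(2, 2, 10)
    trg_kps = torch.randn(2, 2, 10)
    npts = torch.tensor([6, 9])
    print('loss:', CHMNet.training_objective(prd_kps, trg_kps, npts))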