r""" Convolutional Hough Matching Networks """

import torch.nn as nn
import torch

from . import chmlearner as chmlearner
from .base import backbone


class CHMNet(nn.Module):
    r""" Backbone feature extractor followed by a CHM learner.

    ``forward`` extracts one feature map per image with the backbone and
    passes the pair to ``chmlearner.CHMLearner``; the learner's output is
    returned unchanged.
    """

    def __init__(self, ktype):
        r""" Build the backbone and the learner.

        Args:
            ktype: forwarded unchanged as the first argument of
                ``chmlearner.CHMLearner`` (presumably the kernel type --
                confirm in chmlearner).
        """
        super(CHMNet, self).__init__()

        self.backbone = backbone.resnet101(pretrained=True)
        # NOTE(review): 1024 presumably matches the channel width of the
        # features returned by extract_features -- confirm against backbone.
        self.learner = chmlearner.CHMLearner(ktype, feat_dim=1024)

    def forward(self, src_img, trg_img):
        r""" Return the learner's output for a source/target image pair.

        Args:
            src_img: source image tensor.
            trg_img: target image tensor; must be concatenable with
                ``src_img`` along dim 1 (see ``extract_features``).

        Returns:
            Whatever ``self.learner(src_feat, trg_feat)`` returns.
        """
        src_feat, trg_feat = self.extract_features(src_img, trg_img)
        correlation = self.learner(src_feat, trg_feat)
        return correlation

    def extract_features(self, src_img, trg_img):
        r""" Run both images through the backbone up to and including layer3.

        The two images are concatenated along dim 1, passed through the stem
        (conv1, bn1, relu, maxpool) and layer1..layer3 in a single pass, and
        the result is split in half along dim 1.

        NOTE(review): the halves only correspond to the individual images if
        the project's backbone processes the two channel halves independently
        (e.g. grouped convolutions) -- confirm in .base.backbone.

        Args:
            src_img: source image tensor.
            trg_img: target image tensor.

        Returns:
            Tuple ``(src_feat, trg_feat)``: the first and second half (along
            dim 1) of the layer3 output, each cloned into its own tensor.
        """
        feat = torch.cat([src_img, trg_img], dim=1)

        # Call the modules themselves rather than .forward() so that any
        # registered forward hooks are honoured.
        stem = (
            self.backbone.conv1,
            self.backbone.bn1,
            self.backbone.relu,
            self.backbone.maxpool,
        )
        for module in stem:
            feat = module(feat)

        # Features are taken right after layer3; layer4 is never evaluated.
        for idx in range(1, 4):
            feat = getattr(self.backbone, 'layer%d' % idx)(feat)

        half = feat.size(1) // 2
        src_feat = feat.narrow(1, 0, half).clone()
        trg_feat = feat.narrow(1, half, half).clone()
        return src_feat, trg_feat

    @classmethod
    def training_objective(cls, prd_kps, trg_kps, npts):
        r""" Mean squared keypoint distance, averaged per sample then per batch.

        The squared difference between ``prd_kps`` and ``trg_kps`` is summed
        over dim 1 (no square root is taken). For every sample along dim 0,
        only the first ``npt`` entries of the remaining dimension contribute
        to that sample's mean; the per-sample means are then averaged.

        Args:
            prd_kps: predicted keypoints (presumably (batch, 2, max_pts) --
                TODO confirm against the caller).
            trg_kps: target keypoints, same shape as ``prd_kps``.
            npts: per-sample count of valid keypoints, iterated in lockstep
                with dim 0 of the inputs.

        Returns:
            Scalar tensor holding the batch loss. A sample whose ``npt`` is 0
            yields the mean of an empty slice (NaN), which propagates.
        """
        l2dist = (prd_kps - trg_kps).pow(2).sum(dim=1)
        # Entries past npt are excluded from each sample's mean.
        loss = [dist[:npt].mean() for dist, npt in zip(l2dist, npts)]
        return torch.stack(loss).mean()