import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .score import peakiness_score


class BaseNet(nn.Module):
    """Helper class to construct a fully-convolutional network that
    extracts an L2-normalized patch descriptor.
    """

    def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
        super(BaseNet, self).__init__()
        self.inchan = inchan
        self.curchan = inchan
        self.dilated = dilated
        self.dilation = dilation
        self.bn = bn
        self.bn_affine = bn_affine

    def _make_bn(self, outd):
        return nn.BatchNorm2d(outd, affine=self.bn_affine)

    def _add_conv(
        self,
        outd,
        k=3,
        stride=1,
        dilation=1,
        bn=True,
        relu=True,
        k_pool=1,
        pool_type="max",
        bias=False,
    ):
        # Effective dilation compounds the network-wide dilation factor.
        d = self.dilation * dilation

        # "Same" padding for a dilated k x k kernel.
        conv_params = dict(
            padding=((k - 1) * d) // 2, dilation=d, stride=stride, bias=bias
        )

        ops = nn.ModuleList([])
        ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params))
        if bn and self.bn:
            ops.append(self._make_bn(outd))
        if relu:
            ops.append(nn.ReLU(inplace=True))
        self.curchan = outd

        if k_pool > 1:
            if pool_type == "avg":
                ops.append(nn.AvgPool2d(kernel_size=k_pool))
            elif pool_type == "max":
                ops.append(nn.MaxPool2d(kernel_size=k_pool))
            else:
                raise ValueError(f"Unknown pooling type {pool_type}")

        return nn.Sequential(*ops)
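

# Illustrative sketch (hypothetical, not part of the model): how a BaseNet
# subclass composes _add_conv blocks. Each call returns a
# Conv2d -> (BatchNorm2d) -> (ReLU) -> (optional pooling) Sequential and
# updates self.curchan, so successive calls chain channel counts automatically.
class _ExampleNet(BaseNet):
    def __init__(self, **kw):
        BaseNet.__init__(self, **kw)
        self.block0 = self._add_conv(32)                 # 3x3 conv + BN + ReLU
        self.block1 = self._add_conv(64, stride=2)       # downsample by 2
        self.block2 = self._add_conv(64, k_pool=2, pool_type="avg")

    def forward(self, x):
        return self.block2(self.block1(self.block0(x)))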


class Quad_L2Net(BaseNet):
    """Same as L2_Net, but replaces the final 8x8 conv with 3 successive 3x3 convs."""

    def __init__(self, dim=128, mchan=4, relu22=False, **kw):
        # Note: relu22 is unused in this variant.
        BaseNet.__init__(self, **kw)
        self.conv0 = self._add_conv(8 * mchan)
        self.conv1 = self._add_conv(8 * mchan, bn=False)
        self.bn1 = self._make_bn(8 * mchan)
        self.conv2 = self._add_conv(16 * mchan, stride=2)
        self.conv3 = self._add_conv(16 * mchan, bn=False)
        self.bn3 = self._make_bn(16 * mchan)
        self.conv4 = self._add_conv(32 * mchan, stride=2)
        self.conv5 = self._add_conv(32 * mchan)

        # Three successive convs replacing the final 8x8 conv of L2-Net.
        self.conv6_0 = self._add_conv(32 * mchan)
        self.conv6_1 = self._add_conv(32 * mchan)
        self.conv6_2 = self._add_conv(dim, bn=False, relu=False)
        self.out_dim = dim

        # Per-level running activation maxima for the three detection levels
        # (x1, x3, x6_2); updated by an exponential moving average in forward.
        self.moving_avg_params = nn.ParameterList(
            [
                Parameter(torch.tensor(1.0), requires_grad=False),
                Parameter(torch.tensor(1.0), requires_grad=False),
                Parameter(torch.tensor(1.0), requires_grad=False),
            ]
        )
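
    # Design note: the moving averages are stored as non-trainable Parameters
    # (requires_grad=False) rather than plain tensors so they are included in
    # the state_dict and moved by .to()/.cuda() with the rest of the model;
    # self.register_buffer(...) would be an equivalent alternative.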

    def forward(self, x):
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        x1_bn = self.bn1(x1)
        x2 = self.conv2(x1_bn)
        x3 = self.conv3(x2)
        x3_bn = self.bn3(x3)
        x4 = self.conv4(x3_bn)
        x5 = self.conv5(x4)
        x6_0 = self.conv6_0(x5)
        x6_1 = self.conv6_1(x6_0)
        x6_2 = self.conv6_2(x6_1)

        # Multi-level detection: combine peakiness score maps from three
        # feature levels, weighting deeper levels more heavily.
        comb_weights = torch.tensor([1.0, 2.0, 3.0], device=x.device)
        comb_weights /= torch.sum(comb_weights)
        ksize = [3, 2, 1]
        det_score_maps = []

        for idx, xx in enumerate([x1, x3, x6_2]):
            if self.training:
                # Track the running activation maximum with an exponential
                # moving average; it normalizes the peakiness score below.
                instance_max = torch.max(xx)
                self.moving_avg_params[idx].data = (
                    self.moving_avg_params[idx] * 0.99 + instance_max.detach() * 0.01
                )

            alpha, beta = peakiness_score(
                xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx]
            )

            # Keep the strongest channel response at each location, upsample
            # to the input resolution, and apply the per-level weight.
            score_vol = alpha * beta
            det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0]
            det_score_map = F.interpolate(
                det_score_map, size=x.shape[2:], mode="bilinear", align_corners=True
            )
            det_score_map = comb_weights[idx] * det_score_map
            det_score_maps.append(det_score_map)

        det_score_map = torch.sum(torch.stack(det_score_maps, dim=0), dim=0)

        return x6_2, det_score_map, x1, x3
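

# Minimal smoke test (a sketch, not part of the original module). Because of
# the relative import above, run it from inside the package, e.g.
# `python -m <package>.<this_module>`. Shape comments assume the default
# dim=128/mchan=4 and that peakiness_score preserves spatial size.
if __name__ == "__main__":
    net = Quad_L2Net()
    net.eval()  # skip the moving-average updates
    img = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        desc, score, x1, x3 = net(img)
    print(desc.shape)   # torch.Size([1, 128, 16, 16]): descriptors at 1/4 res
    print(score.shape)  # torch.Size([1, 1, 64, 64]): scores at input res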