# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class RPN(nn.Module):
    """Abstract base class for region proposal heads."""
    def __init__(self):
        super(RPN, self).__init__()

    def forward(self, z_f, x_f):
        raise NotImplementedError

    def template(self, template):
        raise NotImplementedError

    def track(self, search):
        raise NotImplementedError

    def param_groups(self, start_lr, feature_mult=1, key=None):
        # Collect trainable parameters, optionally filtered by a substring
        # of the parameter name, and attach a scaled learning rate.
        if key is None:
            params = filter(lambda x: x.requires_grad, self.parameters())
        else:
            params = [v for k, v in self.named_parameters() if (key in k) and v.requires_grad]
        params = [{'params': params, 'lr': start_lr * feature_mult}]
        return params


def conv2d_dw_group(x, kernel):
    # Depth-wise cross-correlation: each channel of the search feature is
    # correlated with the matching channel of the template feature, using
    # grouped convolution with one group per (batch, channel) pair.
    batch, channel = kernel.shape[:2]
    x = x.view(1, batch*channel, x.size(2), x.size(3))  # 1 * (b*c) * H * W
    kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3))  # (b*c) * 1 * k * k
    out = F.conv2d(x, kernel, groups=batch*channel)
    out = out.view(batch, channel, out.size(2), out.size(3))
    return out


class DepthCorr(nn.Module):
    def __init__(self, in_channels, hidden, out_channels, kernel_size=3):
        super(DepthCorr, self).__init__()
        # adjust layers for asymmetrical features: the template and search
        # branches are embedded by separate, non-shared convolutions
        self.conv_kernel = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        )
        self.conv_search = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        )

        self.head = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=1)
        )

    def forward_corr(self, kernel, input):
        kernel = self.conv_kernel(kernel)
        input = self.conv_search(input)
        feature = conv2d_dw_group(input, kernel)
        return feature

    def forward(self, kernel, search):
        feature = self.forward_corr(kernel, search)
        out = self.head(feature)
        return out
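
# --------------------------------------------------------
# Usage sketch (not part of the original file): a minimal shape check for
# DepthCorr. The sizes below (256 channels, a 7x7 template feature, a
# 31x31 search feature, 10 output channels) are illustrative assumptions,
# not values taken from any SiamMask config.
# --------------------------------------------------------
if __name__ == '__main__':
    import torch

    head = DepthCorr(in_channels=256, hidden=256, out_channels=10)
    z_f = torch.randn(1, 256, 7, 7)    # template (exemplar) feature
    x_f = torch.randn(1, 256, 31, 31)  # search-region feature
    out = head(z_f, x_f)
    # The 3x3 embedding convs shrink 7->5 and 31->29; depth-wise
    # correlation of a 5x5 kernel over a 29x29 map yields 25x25.
    print(out.shape)  # torch.Size([1, 10, 25, 25])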