r""" 4D and 6D convolutional Hough matching layers """

from torch.nn.modules.conv import _ConvNd
import torch.nn.functional as F
import torch.nn as nn
import torch

from common.logger import Logger
from . import chm_kernel


def fast4d(corr, kernel, bias=None):
    r""" Optimized implementation of 4D convolution

    The 4D convolution is decomposed into one 3D convolution per kernel slice
    along the first spatial kernel axis; each 3D result is shifted along the
    first spatial axis of the input and accumulated.

    Args:
        corr: 6-D tensor unpacked as (bsz, ch, srch, srcw, trgh, trgw).
        kernel: 6-D tensor unpacked as (out_channels, _, k, k, k, k).
        bias: optional tensor reshaped to (1, out_channels, 1, 1, 1, 1) and
            added to the result; skipped when None.

    Returns:
        Tensor of shape (bsz, out_channels, srch, srcw, trgh, trgw) with the
        same dtype and device as `corr`.
    """
    bsz, ch, srch, srcw, trgh, trgw = corr.size()
    out_channels, _, kernel_size, kernel_size, kernel_size, kernel_size = kernel.size()
    psz = kernel_size // 2

    # Allocate the accumulator next to the input (same device / dtype) so the
    # in-place additions below also work for non-CPU or non-float32 inputs.
    out_corr = corr.new_zeros((bsz, out_channels, srch, srcw, trgh, trgw))

    # Fold the first spatial axis into the batch axis for the 3D convolutions.
    corr = corr.transpose(1, 2).contiguous().view(bsz * srch, ch, srcw, trgh, trgw)

    for pidx, k3d in enumerate(kernel.permute(2, 0, 1, 3, 4, 5)):
        # Destination rows in the output and matching source rows in the
        # 3D-convolved tensor for this kernel offset.
        add_sid = max(psz - pidx, 0)
        add_fid = min(srch, srch + psz - pidx)
        slc_sid = max(pidx - psz, 0)
        slc_fid = min(srch, srch - psz + pidx)

        # The offset lies completely outside the input: nothing to accumulate
        # (this also prevents a negative slice end from being misread).
        if add_fid <= add_sid:
            continue

        inter_corr = F.conv3d(corr, k3d, bias=None, stride=1, padding=psz)
        inter_corr = inter_corr.view(bsz, srch, out_channels, srcw, trgh, trgw).transpose(1, 2).contiguous()

        out_corr[:, :, add_sid:add_fid, :, :, :] += inter_corr[:, :, slc_sid:slc_fid, :, :, :]

    if bias is not None:
        out_corr += bias.view(1, out_channels, 1, 1, 1, 1)

    return out_corr


def fast6d(corr, kernel, bias, diagonal_idx):
    r""" Optimized implementation of 6D convolutional Hough matching

    NOTE: this function only supports kernel size of (3, 3, 5, 5, 5, 5).

    Args:
        corr: 8-D tensor unpacked as (bsz, _, s6d, s6d, s4d, s4d, s4d, s4d).
        kernel: 8-D tensor unpacked as (_, _, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d).
        bias: tensor reshaped to (1, -1, 1, 1, 1, 1, 1, 1) and added to the
            result; skipped when None.
        diagonal_idx: sequence of index tensors; each entry selects the
            positions that are summed together in both summation passes.

    Returns:
        Tensor of shape (bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d).
    """
    bsz, _, s6d, s6d, s4d, s4d, s4d, s4d = corr.size()
    _, _, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d = kernel.size()

    # Run every (s6d x s6d) position through a 4D convolution with each of the
    # ks6d ** 2 4D kernels at once.
    corr = corr.permute(0, 2, 3, 1, 4, 5, 6, 7).contiguous().view(-1, 1, s4d, s4d, s4d, s4d)
    kernel = kernel.view(-1, ks6d ** 2, ks4d, ks4d, ks4d, ks4d).transpose(0, 1)
    corr = fast4d(corr, kernel).view(bsz, s6d * s6d, ks6d * ks6d, s4d, s4d, s4d, s4d)
    corr = corr.view(bsz, s6d, s6d, ks6d, ks6d, s4d, s4d, s4d, s4d).transpose(2, 3).\
        contiguous().view(-1, s6d * ks6d, s4d, s4d, s4d, s4d)

    ndiag = s6d + (ks6d // 2) * 2

    # First pass: sum the entries selected by each index group along dim 1.
    first_sum = torch.stack([corr[:, didx, :, :, :, :].sum(dim=1) for didx in diagonal_idx])
    first_sum = first_sum.transpose(0, 1).view(bsz, s6d * ks6d, ndiag, s4d, s4d, s4d, s4d)

    # Second pass: same grouping applied to the remaining (s6d * ks6d) axis.
    second_sum = torch.stack([first_sum[:, didx, :, :, :, :, :].sum(dim=1) for didx in diagonal_idx])

    # Crop the border introduced by the kernel extent.
    sidx = ks6d // 2
    eidx = ndiag - sidx
    corr = second_sum.transpose(0, 1)[:, sidx:eidx, sidx:eidx, :, :, :, :].unsqueeze(1).contiguous()

    if bias is not None:
        corr += bias.view(1, -1, 1, 1, 1, 1, 1, 1)

    # Reverse the order of the flattened (s6d * s6d) axis; integer arange
    # avoids the float round-trip of linspace(...).long().
    reverse_idx = torch.arange(s6d * s6d - 1, -1, -1, device=corr.device)
    corr = corr.view(bsz, 1, s6d * s6d, s4d, s4d, s4d, s4d)[:, :, reverse_idx, :, :, :, :].\
        view(bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)
    return corr


def init_param_idx4d(param_dict):
    r""" Converts every value of `param_dict` into an index tensor

    Args:
        param_dict: mapping whose values are index collections accepted by
            `torch.tensor`; keys are not interpreted.

    Returns:
        List of tensors, one per entry, in the iteration order of `param_dict`.
    """
    return [torch.tensor(indices) for indices in param_dict.values()]


class CHM4d(_ConvNd):
    r""" 4D convolutional Hough matching layer

    NOTE: this function only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz4d, ktype, bias=True):
        super(CHM4d, self).__init__(in_channels, out_channels, (ksz4d,) * 4, (1,) * 4, (0,) * 4,
                                    (1,) * 4, False, (0,) * 4, 1, bias, padding_mode='zeros')

        # Zero kernel initialization (used as the shape template in init_kernel)
        self.zero_kernel4d = torch.zeros((in_channels, out_channels, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Initialize kernel indices
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:
            # Initialize the shared parameters (multiplied by the number of times being shared)
            self.param_idx = init_param_idx4d(param_dict4d)
            weights = torch.abs(torch.randn(len(self.param_idx) * self.nkernels)) * 1e-3
            # NOTE(review): `weights.sort()[0]` is a sorted copy, so the in-place
            # scaling below does not appear to reach `weights` itself. Left
            # untouched so that the initialization stays identical -- confirm
            # whether scaling `weights` was intended.
            for weight, param_idx in zip(weights.sort()[0], self.param_idx):
                weight *= len(param_idx)
            self.weight = nn.Parameter(weights)
        else:  # full kernel initialization
            self.param_idx = None
            self.weight = nn.Parameter(torch.abs(self.weight))
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        Logger.info('(%s) # params in CHM 4D: %d' % (ktype, self.weight.numel()))

    def forward(self, x):
        r""" Builds the current kernel and applies the 4D convolution to `x` """
        kernel = self.init_kernel()
        x = fast4d(x, kernel, self.bias)
        return x

    def init_kernel(self):
        r""" Initialize CHM kernel (divided by the number of times being shared)

        Returns:
            `self.weight` itself for a full kernel; otherwise a tensor of shape
            (in_channels, out_channels, ksz, ksz, ksz, ksz) built from the
            shared parameters, on the device / dtype of `self.weight`.
        """
        if self.param_idx is None:
            return self.weight

        ksz = self.kernel_size[-1]
        n_shared = len(self.param_idx)

        # Build the kernel next to the parameters: the template attribute is a
        # plain tensor and does not follow the module across devices.
        kernel = torch.zeros_like(self.zero_kernel4d, dtype=self.weight.dtype, device=self.weight.device)
        kernel = kernel.view(-1, ksz, ksz, ksz, ksz)
        for idx, pdx in enumerate(self.param_idx):
            for jdx, kernel_single in enumerate(kernel):
                # Spread one shared weight evenly over all positions using it.
                weight = self.weight[idx + jdx * n_shared].repeat(len(pdx)) / len(pdx)
                kernel_single.view(-1)[pdx] += weight
        return kernel.view(self.in_channels, self.out_channels, ksz, ksz, ksz, ksz)


class CHM6d(_ConvNd):
    r""" 6D convolutional Hough matching layer with kernel (3, 3, 5, 5, 5, 5)

    NOTE: this function only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz6d, ksz4d, ktype):
        kernel_size = (ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d)
        super(CHM6d, self).__init__(in_channels, out_channels, kernel_size, (1,) * 6, (0,) * 6,
                                    (1,) * 6, False, (0,) * 6, 1, bias=True, padding_mode='zeros')

        # Zero kernel initialization (used as shape templates in init_kernel)
        self.zero_kernel4d = torch.zeros((ksz4d, ksz4d, ksz4d, ksz4d))
        self.zero_kernel6d = torch.zeros((ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Initialize kernel indices
        # Indices in scale-space where 4D convolutions are performed (3 by 3 scale-space)
        self.diagonal_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:  # psi & iso kernel initialization
            if ktype == 'psi':
                param_dict6d = [[4], [0, 8], [2, 6], [1, 3, 5, 7]]
            elif ktype == 'iso':
                param_dict6d = [[0, 4, 8], [2, 6], [1, 3, 5, 7]]
            else:
                # Fail early with a clear message instead of an attribute
                # lookup error on the missing index table.
                raise ValueError('Unsupported kernel type for shared CHM 6D parameters: %r' % (ktype,))
            self.param_dict6d = [torch.tensor(i) for i in param_dict6d]

            # Initialize the shared parameters (multiplied by the number of times being shared)
            self.param_idx = init_param_idx4d(param_dict4d)
            self.param = []
            for param_dict6d in self.param_dict6d:
                weights = torch.abs(torch.randn(len(self.param_idx))) * 1e-3
                # Iterating `weights` yields views, so this scales it in place.
                for weight, param_idx in zip(weights, self.param_idx):
                    weight *= (len(param_idx) * len(param_dict6d))
                self.param.append(nn.Parameter(weights))
            self.param = nn.ParameterList(self.param)
        else:  # full kernel initialization
            self.param_idx = None
            self.param = nn.Parameter(torch.abs(self.weight) * 1e-3)
        Logger.info('(%s) # params in CHM 6D: %d' % (ktype, sum(x.numel() for x in self.param)))
        # The learnable state lives in `self.param`; drop the inherited weight.
        self.weight = None

    def forward(self, corr):
        r""" Builds the current kernel and applies the 6D Hough matching to `corr` """
        kernel = self.init_kernel()
        corr = fast6d(corr, kernel, self.bias, self.diagonal_idx)
        return corr

    def init_kernel(self):
        r""" Initialize CHM kernel (divided by the number of times being shared)

        Returns:
            `self.param` itself for a full kernel; otherwise a tensor of shape
            (1, 1, ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d) built from the
            shared parameters, on the device / dtype of those parameters.
        """
        if self.param_idx is None:
            return self.param

        ksz4d = self.kernel_size[-1]

        # Build the kernels next to the parameters: the template attributes are
        # plain tensors and do not follow the module across devices.
        ref = self.param[0]
        kernel6d = torch.zeros_like(self.zero_kernel6d, dtype=ref.dtype, device=ref.device)
        for param, param_dict6d in zip(self.param, self.param_dict6d):
            kernel4d = torch.zeros_like(self.zero_kernel4d, dtype=ref.dtype, device=ref.device)
            for jdx, pdx in enumerate(self.param_idx):
                # Split each shared weight over its 4D positions and over the
                # 6D positions that reuse this 4D kernel.
                kernel4d.view(-1)[pdx] += ((param[jdx] / len(pdx)) / len(param_dict6d))
            kernel6d.view(-1, ksz4d, ksz4d, ksz4d, ksz4d)[param_dict6d] += kernel4d.view(ksz4d, ksz4d, ksz4d, ksz4d)
        kernel6d = kernel6d.unsqueeze(0).unsqueeze(0)

        return kernel6d