import torch
import torch.nn as nn
import torch.nn.functional as F


class SurfaceClassifier(nn.Module):
    """Stack of 1x1 ``nn.Conv1d`` layers applied to a ``[B x C x N]`` feature tensor.

    Layers are registered as submodules ``conv0``, ``conv1``, ... (these names
    appear in ``state_dict`` keys) and are also kept, in order, in the plain
    list ``self.filters``.

    Args:
        filter_channels: channel widths ``[C_0, C_1, ..., C_K]``; layer ``l``
            outputs ``C_{l+1}`` channels.
        num_views: when > 1, the batch dimension is reshaped to
            ``(-1, num_views, C, N)`` and averaged over the view axis right
            after layer ``len(self.filters) // 2``.
        no_residual: if False, every layer except the first also receives the
            original input feature (``C_0`` channels) concatenated on dim 1.
        last_op: optional callable applied to the final output.
    """

    def __init__(self, filter_channels, num_views=1, no_residual=True, last_op=None):
        super(SurfaceClassifier, self).__init__()
        # Plain list (not nn.ModuleList) kept for backward compatibility; the
        # layers are registered via add_module so state_dict keys stay "convN.*".
        self.filters = []
        self.num_views = num_views
        self.no_residual = no_residual
        self.last_op = last_op

        for idx in range(len(filter_channels) - 1):
            in_channels = filter_channels[idx]
            if not no_residual and idx != 0:
                # Residual layers additionally see the raw input (C_0 channels).
                in_channels += filter_channels[0]
            conv = nn.Conv1d(in_channels, filter_channels[idx + 1], 1)
            self.filters.append(conv)
            self.add_module("conv%d" % idx, conv)

    def forward(self, feature):
        """Run the convolution stack.

        :param feature: [B x C_0 x N] tensor (C_0 == filter_channels[0]).
            When num_views > 1, B must be divisible by num_views.
        :return: [B' x C_K x N] tensor, where B' = B // num_views when
            num_views > 1 and B otherwise; last_op is applied if given.
        """
        y = feature
        # Tensor concatenated to each non-first layer on the residual path;
        # replaced by its view-averaged version once the views are merged.
        tmpy = feature
        last_idx = len(self.filters) - 1
        merge_idx = len(self.filters) // 2
        for i in range(len(self.filters)):
            # Look up by registered name so modules reassigned on the instance
            # (e.g. ``model.conv0 = ...``) are honoured.
            conv = self._modules["conv%d" % i]
            if self.no_residual or i == 0:
                y = conv(y)
            else:
                y = conv(torch.cat([y, tmpy], 1))
            if i != last_idx:
                y = F.leaky_relu(y)
            if self.num_views > 1 and i == merge_idx:
                y = y.view(-1, self.num_views, y.shape[1], y.shape[2]).mean(dim=1)
                if not self.no_residual:
                    # Only needed when later layers concatenate the input.
                    tmpy = feature.view(
                        -1, self.num_views, feature.shape[1], feature.shape[2]
                    ).mean(dim=1)

        if self.last_op is not None:
            y = self.last_op(y)
        return y