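# Scraped Hugging Face file-page header (not part of the original source):
# uploader: radames
# commit: c7f097c ("initial commit")
# page links: raw | history | blame
# malware scan: no virus detected
# file size: 2.41 kB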
import torch
import torch.nn as nn
import torch.nn.functional as F
class SurfaceClassifier(nn.Module):
    """Point-wise MLP built from 1x1 ``nn.Conv1d`` layers.

    Layer ``l`` maps ``filter_channels[l]`` channels to
    ``filter_channels[l + 1]`` channels. A LeakyReLU follows every layer
    except the last. Layers are registered as submodules named ``conv0``,
    ``conv1``, ... (these names appear in ``state_dict`` keys). They are also
    kept, in order, in the plain list ``self.filters``.

    Args:
        filter_channels: sequence of channel widths, from the input width
            ``filter_channels[0]`` to the output width ``filter_channels[-1]``.
        num_views: number of views folded into the batch dimension. When
            greater than 1, activations are averaged over the views after
            layer ``len(self.filters) // 2``.
        no_residual: when False, every layer after the first also receives
            the input features concatenated along the channel dimension, so
            its input width grows by ``filter_channels[0]``.
        last_op: optional callable applied to the output of the last layer.
            It is skipped when falsy (e.g. ``None``).
    """

    def __init__(self, filter_channels, num_views=1, no_residual=True, last_op=None):
        super().__init__()
        self.filters = []
        self.num_views = num_views
        self.no_residual = no_residual
        self.last_op = last_op

        for l in range(len(filter_channels) - 1):
            in_channels = filter_channels[l]
            if l != 0 and not no_residual:
                # Leave room for the input features that forward() re-injects.
                in_channels += filter_channels[0]
            conv = nn.Conv1d(in_channels, filter_channels[l + 1], 1)
            self.filters.append(conv)
            # Keep the historical "conv<l>" names so checkpoints still load.
            self.add_module("conv%d" % l, conv)

    def forward(self, feature):
        """Apply the MLP independently at every point.

        Args:
            feature: tensor of shape ``[B, C_in, N]`` with
                ``C_in == filter_channels[0]``. When ``num_views > 1``, ``B``
                must be divisible by ``num_views``. The views of one sample
                are expected in consecutive batch entries, which is the
                grouping that ``view(-1, num_views, ...)`` imposes.

        Returns:
            Tensor of shape ``[B', filter_channels[-1], N]``. Here ``B' = B``
            when ``num_views == 1`` and ``B // num_views`` otherwise. The
            tensor is passed through ``last_op`` when one was given.
        """
        y = feature
        skip = feature  # input features re-injected by the residual path
        last = len(self.filters) - 1
        mid = len(self.filters) // 2

        for i in range(len(self.filters)):
            conv = getattr(self, "conv%d" % i)
            if self.no_residual or i == 0:
                y = conv(y)
            else:
                y = conv(torch.cat([y, skip], 1))

            if i != last:
                y = F.leaky_relu(y)

            if self.num_views > 1 and i == mid:
                # Merge the views: (B*V, C, N) -> (B, V, C, N) -> mean over V.
                y = y.view(-1, self.num_views, y.shape[1], y.shape[2]).mean(dim=1)
                if not self.no_residual:
                    # Later residual concatenations must match the reduced batch.
                    skip = feature.view(
                        -1, self.num_views, feature.shape[1], feature.shape[2]
                    ).mean(dim=1)

        if self.last_op:
            y = self.last_op(y)
        return y