|
import torch.nn as nn |
|
import torchsparse.nn as spnn |
|
from torchsparse.point_tensor import PointTensor |
|
|
|
from lib.spvcnn_utils import * |
|
# Public API of this module: only the classification network is exported;
# the building-block modules below are internal helpers.
__all__ = ['SPVCNN_CLASSIFICATION']
|
|
|
|
|
|
|
class BasicConvolutionBlock(nn.Module):
    """Sparse 3-D convolution -> batch norm -> in-place ReLU.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: convolution kernel size.
        stride: convolution stride.
        dilation: convolution dilation.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        conv = spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride)
        norm = spnn.BatchNorm(outc)
        act = spnn.ReLU(True)
        # Children keep indices 0/1/2 inside ``net`` so checkpoint keys match.
        self.net = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through the conv/BN/ReLU stack and return the result."""
        return self.net(x)
|
|
|
|
|
class BasicDeconvolutionBlock(nn.Module):
    """Transposed sparse 3-D convolution -> batch norm -> in-place ReLU.

    The convolution is built with ``transpose=True``. This block is not used by
    the classification network in this file.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, inc, outc, ks=3, stride=1):
        super().__init__()
        deconv = spnn.Conv3d(inc, outc, kernel_size=ks, stride=stride, transpose=True)
        norm = spnn.BatchNorm(outc)
        act = spnn.ReLU(True)
        self.net = nn.Sequential(deconv, norm, act)

    def forward(self, x):
        """Apply the transposed-conv/BN/ReLU stack to ``x``."""
        result = self.net(x)
        return result
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Two sparse convolutions with a shortcut connection, followed by ReLU.

    The main path is conv(ks, stride) -> BN -> ReLU -> conv(ks, stride 1) -> BN.
    The shortcut is the identity (an empty ``nn.Sequential``) when the channel
    count and stride are unchanged. Otherwise it is a 1x1x1 conv plus BN
    projection with the same stride.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: kernel size of both main-path convolutions.
        stride: stride of the first main-path conv and of the projection.
        dilation: dilation of both main-path convolutions.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        main_path = [
            spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
            spnn.Conv3d(outc, outc, kernel_size=ks, dilation=dilation, stride=1),
            spnn.BatchNorm(outc),
        ]
        self.net = nn.Sequential(*main_path)

        # A projection is only needed when the shapes of the two paths differ.
        needs_projection = inc != outc or stride != 1
        if needs_projection:
            shortcut = nn.Sequential(
                spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride),
                spnn.BatchNorm(outc),
            )
        else:
            shortcut = nn.Sequential()
        self.downsample = shortcut

        self.relu = spnn.ReLU(True)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.net(x)
        return self.relu(main + self.downsample(x))
|
|
|
|
|
class SPVCNN_CLASSIFICATION(nn.Module):
    """Sparse point-voxel CNN (SPVCNN-style encoder) for whole-input classification.

    The network has five parts:

    * A voxel stem.
    * Four stride-2 encoder stages.
    * One point-branch skip: stem features mapped to points and projected from
      ``cs[0]`` to ``cs[4]`` channels, then added to the deepest features
      mapped back to points.
    * Global average pooling.
    * A linear classifier.

    Keyword Args:
        input_channel (int): required; feature channels of the input.
        num_classes (int): required; output size of the classifier.
        cr (float): channel-width multiplier applied to every stage (default 1.0).
        pres, vres: resolutions passed to ``initial_voxelize`` in ``forward``.
            They are stored only when BOTH are given. Otherwise they must be
            assigned as attributes (``model.pres = ...``) before ``forward``.
    """

    def __init__(self, **kwargs):
        super().__init__()

        cr = kwargs.get('cr', 1.0)
        cs = [32, 32, 64, 128, 256, 256, 128, 96, 96]
        cs = [int(cr * x) for x in cs]
        # Only cs[0]..cs[4] are used by this classification network; the
        # remaining widths are unused here.

        if 'pres' in kwargs and 'vres' in kwargs:
            self.pres = kwargs['pres']
            self.vres = kwargs['vres']

        self.stem = nn.Sequential(
            spnn.Conv3d(kwargs['input_channel'], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
            spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True))

        # Built in the same order and with the same children as before, so
        # state_dict keys (stageN.0/1/2...) are unchanged.
        self.stage1 = self._make_stage(cs[0], cs[1])
        self.stage2 = self._make_stage(cs[1], cs[2])
        self.stage3 = self._make_stage(cs[2], cs[3])
        self.stage4 = self._make_stage(cs[3], cs[4])

        self.avg_pool = spnn.GlobalAveragePooling()
        self.classifier = nn.Sequential(nn.Linear(cs[4], kwargs['num_classes']))

        # Projects stem-level point features (cs[0]) to the deepest width
        # (cs[4]) so they can be added to the stage-4 point features.
        self.point_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(cs[0], cs[4]),
                nn.BatchNorm1d(cs[4]),
                nn.ReLU(True),
            ),
        ])

        self.weight_initialization()

        # NOTE(review): this dropout is constructed but never applied in
        # forward(). It is kept so attribute access on existing models still
        # works; confirm whether it was meant to precede the classifier.
        self.dropout = nn.Dropout(0.3, True)

    @staticmethod
    def _make_stage(inc, outc):
        """Build one encoder stage: stride-2 conv block then two residual blocks."""
        return nn.Sequential(
            BasicConvolutionBlock(inc, inc, ks=2, stride=2, dilation=1),
            ResidualBlock(inc, outc, ks=3, stride=1, dilation=1),
            ResidualBlock(outc, outc, ks=3, stride=1, dilation=1),
        )

    def weight_initialization(self):
        """Reset every ``nn.BatchNorm1d`` to weight=1, bias=0.

        Only the dense ``BatchNorm1d`` layers (the point branch) are touched;
        the sparse ``spnn.BatchNorm`` layers keep their default initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Classify the input.

        Args:
            x: input exposing ``.F`` and ``.C`` (presumably a torchsparse
                ``SparseTensor`` holding features and coordinates; TODO confirm).

        Returns:
            Classifier output. Assumed shape is ``(batch, num_classes)``,
            provided ``GlobalAveragePooling`` yields ``(batch, cs[4])``.

        Raises:
            AttributeError: if ``pres``/``vres`` were never configured.
        """
        if not (hasattr(self, 'pres') and hasattr(self, 'vres')):
            raise AttributeError(
                "SPVCNN_CLASSIFICATION needs both 'pres' and 'vres': pass them "
                "as constructor keyword arguments or set them as attributes "
                "before calling forward().")

        z = PointTensor(x.F, x.C.float())

        # Voxelize the points, then run the stem in voxel space.
        x0 = initial_voxelize(z, self.pres, self.vres)
        x0 = self.stem(x0)
        z0 = voxel_to_point(x0, z, nearest=False)

        # Voxel encoder.
        x1 = point_to_voxel(x0, z0)
        x1 = self.stage1(x1)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)

        # Point-branch fusion: deepest features at the points plus projected
        # stem-level point features.
        z1 = voxel_to_point(x4, z0)
        z1.F = z1.F + self.point_transforms[0](z0.F)

        y1 = point_to_voxel(x4, z1)
        pool = self.avg_pool(y1)
        return self.classifier(pool)
|
|
|
|
|
|