|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
import open3d as o3d
|
|
|
import numpy as np
|
|
|
import os
|
|
|
from urllib.request import urlretrieve
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
|
|
|
import MinkowskiEngine as ME
|
|
|
from MinkowskiEngine import SparseTensor
|
|
|
from MinkowskiEngine.utils import summary, batched_coordinates
|
|
|
|
|
|
|
|
|
class StackUNet(ME.MinkowskiNetwork):
    """Small U-Net-shaped sparse network built from nested ``MinkowskiStackSum`` blocks.

    The network has two nested ``MinkowskiStackSum`` levels:

    * outer level: a stride-1 convolution next to a branch that applies a
      stride-2 convolution, the inner level, and a stride-2 transposed pooling;
    * inner level: an identity branch next to a branch that applies a stride-2
      convolution, a stride-1 transposed convolution, and a stride-2
      transposed pooling.

    The result goes through ``MinkowskiToFeature`` and a final ``nn.Linear``
    that maps the 16 intermediate channels to ``out_nchannel``.

    NOTE(review): ``MinkowskiStackSum`` presumably sums the outputs of its
    branches -- confirm against the MinkowskiEngine documentation.
    """

    def __init__(self, in_nchannel, out_nchannel, D):
        """Assemble the network.

        Args:
            in_nchannel: number of input feature channels.
            out_nchannel: number of output channels of the final linear layer.
            D: dimension passed to ``ME.MinkowskiNetwork`` and to every
                Minkowski layer.
        """
        super().__init__(D)

        entry_width, mid_width, deep_width = in_nchannel, 16, 32

        def conv3(n_in, n_out, step):
            # 3-wide sparse convolution; only the stride varies between uses.
            return ME.MinkowskiConvolution(
                n_in, n_out, kernel_size=3, stride=step, dimension=D
            )

        def unpool2():
            # Stride-2 transposed pooling paired with each stride-2 convolution.
            return ME.MinkowskiPoolingTranspose(kernel_size=2, stride=2, dimension=D)

        # Learnable layers are created in the order they appear in the nested
        # structure below, so seeded weight initialisation is unaffected by
        # giving them names first.
        full_res_conv = conv3(entry_width, mid_width, 1)
        down_conv = conv3(entry_width, mid_width, 2)
        deep_conv = conv3(mid_width, deep_width, 2)
        deep_deconv = ME.MinkowskiConvolutionTranspose(
            deep_width, mid_width, kernel_size=3, stride=1, dimension=D
        )

        inner_level = ME.MinkowskiStackSum(
            nn.Identity(),
            nn.Sequential(deep_conv, deep_deconv, unpool2()),
        )
        outer_level = ME.MinkowskiStackSum(
            full_res_conv,
            nn.Sequential(down_conv, inner_level, unpool2()),
        )

        self.net = nn.Sequential(
            outer_level,
            ME.MinkowskiToFeature(),
            nn.Linear(mid_width, out_nchannel, bias=True),
        )

    def forward(self, x):
        """Run ``x`` through ``self.net`` and return the result."""
        return self.net(x)
|
|
|
|
|
|
|
|
|
|
|
|
class TestSummary(unittest.TestCase):
    """Calls ``summary`` on a ``StackUNet`` with a sparse tensor built from a point cloud."""

    def setUp(self):
        """Create the network and the sparse input used by the test.

        Downloads the example point cloud to ``1.ply`` in the current working
        directory when that file is not already present, reads it with Open3D,
        and wraps its colors (features) and its points divided by the voxel
        size (coordinates) in a ``SparseTensor`` on ``self.device``.
        """
        ply_path = "1.ply"
        voxel_size = 0.02

        # Use the GPU when one is available.
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.net = StackUNet(3, 20, D=3).to(self.device)

        if not os.path.isfile(ply_path):
            print('Downloading an example pointcloud...')
            urlretrieve("https://bit.ly/3c2iLhg", ply_path)

        cloud = o3d.io.read_point_cloud(ply_path)
        xyz = np.array(cloud.points)
        rgb = np.array(cloud.colors)

        feats = torch.from_numpy(rgb).float()
        # A single-element list makes this a batch containing one point cloud.
        coords = batched_coordinates([xyz / voxel_size], dtype=torch.float32)
        self.sinput = SparseTensor(
            features=feats,
            coordinates=coords,
            device=self.device,
        )

    def test(self):
        """Call ``summary`` on the network and sparse input; passes if nothing raises."""
        summary(self.net, self.sinput)
|
|
|
|
|
|
|