Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from . import functional as F | |
from .ball_query import BallQuery | |
from .shared_mlp import SharedMLP | |
__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule'] | |
class PointNetAModule(nn.Module):
    """Aggregate all points into a single feature vector per branch.

    Each branch is a ``SharedMLP`` applied to the (optionally
    coordinate-augmented) features, followed by a max over the last
    dimension. The outputs of all branches are concatenated along dim 1.

    Args:
        in_channels: Number of feature channels fed to the module (before
            the optional 3 coordinate channels are appended).
        out_channels: An int, a list of ints (one MLP), or a list of lists
            of ints (one MLP per inner list).
        include_coordinates: When true, ``coords`` is concatenated to the
            features along dim 1 before the MLPs.
    """

    def __init__(self, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        # Normalise ``out_channels`` to a list of per-branch channel lists.
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]]
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels]

        # Every branch sees the same input width; coordinates add 3 channels.
        mlp_in_channels = in_channels + (3 if include_coordinates else 0)
        branches = [
            SharedMLP(in_channels=mlp_in_channels, out_channels=branch_channels, dim=1)
            for branch_channels in out_channels
        ]

        self.include_coordinates = include_coordinates
        # Width of the concatenated output: last layer of every branch.
        self.out_channels = sum(branch_channels[-1] for branch_channels in out_channels)
        self.mlps = nn.ModuleList(branches)

    def forward(self, inputs):
        """Run the module.

        Args:
            inputs: A ``(features, coords)`` pair. Assumed to be laid out as
                (batch, channels, points) with 3 coordinate channels --
                TODO confirm against callers.

        Returns:
            ``(pooled_features, coords)`` where ``pooled_features`` keeps a
            trailing dimension of size 1 and ``coords`` is an all-zero
            placeholder of shape ``(coords.size(0), 3, 1)``.
        """
        features, coords = inputs
        if self.include_coordinates:
            features = torch.cat([features, coords], dim=1)
        # new_zeros keeps both the device and the dtype of the incoming
        # coordinates, so the placeholder does not silently become float32.
        coords = coords.new_zeros((coords.size(0), 3, 1))
        pooled = [mlp(features).max(dim=-1, keepdim=True).values for mlp in self.mlps]
        if len(pooled) > 1:
            return torch.cat(pooled, dim=1), coords
        return pooled[0], coords

    def extra_repr(self):
        return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
class PointNetSAModule(nn.Module):
    """Set-abstraction module with one ``BallQuery`` + ``SharedMLP`` per radius.

    Centers are picked with ``F.furthest_point_sample``; for every radius
    the neighbours of each center are grouped, passed through that scale's
    MLP and reduced with a max over the last dimension. With several radii
    the per-scale results are concatenated along dim 1, which is the width
    reported by ``self.out_channels``.

    Args:
        num_centers: Number of centers requested from furthest point sampling.
        radius: A single radius or a list/tuple of radii (one scale each).
        num_neighbors: Neighbours per center; a scalar is repeated for
            every radius.
        in_channels: Feature channels fed to the module (before the
            optional 3 coordinate channels).
        out_channels: An int or list of ints (shared by every scale), or a
            list of lists (one per radius).
        include_coordinates: Forwarded to ``BallQuery``; also adds 3 to the
            MLP input width.
    """

    def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        # Normalise every per-scale argument to a list with one entry per radius.
        if not isinstance(radius, (list, tuple)):
            radius = [radius]
        if not isinstance(num_neighbors, (list, tuple)):
            num_neighbors = [num_neighbors] * len(radius)
        assert len(radius) == len(num_neighbors), 'radius and num_neighbors must have the same length'
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]] * len(radius)
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels] * len(radius)
        assert len(radius) == len(out_channels), 'radius and out_channels must have the same length'

        mlp_in_channels = in_channels + (3 if include_coordinates else 0)
        groupers, mlps = [], []
        total_out_channels = 0
        for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
            groupers.append(
                BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
            )
            mlps.append(SharedMLP(in_channels=mlp_in_channels, out_channels=_out_channels, dim=2))
            total_out_channels += _out_channels[-1]

        self.num_centers = num_centers
        self.out_channels = total_out_channels
        self.groupers = nn.ModuleList(groupers)
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        """Run the module.

        Args:
            inputs: A ``(features, coords, temb)`` triple. ``temb`` is
                presumably a per-point time embedding -- verify against
                the caller.

        Returns:
            ``(features, centers_coords, temb)``: the max-reduced features
            (concatenated over scales along dim 1 when there are several),
            the sampled center coordinates, and the embedding returned by
            the last scale, max-reduced over its last dimension unless it
            has zero channels.
        """
        features, coords, temb = inputs
        # Sampling centers reduces the number of points handed to the next layer.
        centers_coords = F.furthest_point_sample(coords, self.num_centers)

        pooled_features = []
        grouped_temb = temb
        for grouper, mlp in zip(self.groupers, self.mlps):
            # Every scale must see the ORIGINAL features/temb: each MLP was
            # built for the same input width, so the inputs must not be
            # overwritten by the previous scale's output.
            scale_features, grouped_temb = mlp(grouper(coords, centers_coords, temb, features))
            pooled_features.append(scale_features.max(dim=-1).values)

        if len(pooled_features) > 1:
            # Multi-scale: concatenate so the width matches self.out_channels.
            out_features = torch.cat(pooled_features, dim=1)
        else:
            out_features = pooled_features[0]

        if grouped_temb.shape[1] > 0:
            out_temb = grouped_temb.max(dim=-1).values
        else:
            out_temb = grouped_temb
        return out_features, centers_coords, out_temb

    def extra_repr(self):
        return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
class PointNetFPModule(nn.Module):
    """Feature-propagation module.

    Center features (and the embedding ``temb``) are carried over to the
    points with ``F.nearest_neighbor_interpolate``, optionally concatenated
    with the points' own features along dim 1, and passed through a
    ``SharedMLP``.

    Args:
        in_channels: Input width of the MLP (interpolated features plus
            any concatenated point features).
        out_channels: Output channel specification forwarded to ``SharedMLP``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)

    @staticmethod
    def _unpack_inputs(inputs):
        """Split ``inputs`` into its five fields, using ``None`` for absent point features.

        A 4-item input is ``(points_coords, centers_coords, centers_features,
        temb)``; otherwise five items are expected, with ``points_features``
        in fourth position.
        """
        # The short form has FOUR items (temb included); testing for three
        # could never unpack successfully.
        if len(inputs) == 4:
            points_coords, centers_coords, centers_features, temb = inputs
            return points_coords, centers_coords, centers_features, None, temb
        points_coords, centers_coords, centers_features, points_features, temb = inputs
        return points_coords, centers_coords, centers_features, points_features, temb

    def forward(self, inputs):
        """Run the module.

        Args:
            inputs: Either ``(points_coords, centers_coords,
                centers_features, temb)`` or ``(points_coords,
                centers_coords, centers_features, points_features, temb)``.

        Returns:
            ``(mlp_output, points_coords, interpolated_temb)``.
        """
        (points_coords, centers_coords, centers_features,
         points_features, temb) = self._unpack_inputs(inputs)
        interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
        interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
        if points_features is not None:
            # Concatenate the interpolated features with the points' own features.
            interpolated_features = torch.cat([interpolated_features, points_features], dim=1)
        return self.mlp(interpolated_features), points_coords, interpolated_temb