# Adapted from SPFN
import torch
import torch.nn as nn
from torch_cluster import fps
# from .point_convolution_universal import TransitionDown, TransitionUp
# from .model_util import construct_conv1d_modules, construct_conv_modules, CorrFlowPredNet, set_bn_not_training, set_grad_to_none
# from .utils import farthest_point_sampling, get_knn_idx, batched_index_select
def set_bn_not_training(module):
    """Switch every BatchNorm layer inside ``module`` to inference mode.

    ``module`` may be an ``nn.ModuleList`` (walked recursively, so nested lists
    are fine) or an ``nn.Sequential`` whose direct children are inspected. Every
    BatchNorm1d/2d/3d child is put into eval mode, so it normalises with its
    running statistics and stops updating them. Other children are untouched.

    Args:
        module: an ``nn.ModuleList`` or ``nn.Sequential``.

    Raises:
        ValueError: if ``module`` is neither of those container types.
    """
    if isinstance(module, nn.ModuleList):
        for block in module:
            set_bn_not_training(block)
    elif isinstance(module, nn.Sequential):
        for block in module:
            if isinstance(block, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                # BatchNorm decides between batch and running statistics from
                # ``self.training``. Setting only an ``is_training`` attribute has
                # no effect, so switch the real flag via eval().
                block.eval()
                # Legacy flag, kept for any external code that still inspects it.
                block.is_training = False
    else:
        raise ValueError("Not recognized module to set not training!")
def set_grad_to_none(module):
    """Drop the stored gradients of every parameter inside ``module``.

    ``module`` may be an ``nn.ModuleList`` (walked recursively) or an
    ``nn.Sequential``; every parameter reachable from the Sequential gets
    ``grad = None``, which releases the gradient tensors rather than zeroing them.

    Args:
        module: an ``nn.ModuleList`` or ``nn.Sequential``.

    Raises:
        ValueError: if ``module`` is neither of those container types.
    """
    if isinstance(module, nn.ModuleList):
        for block in module:
            set_grad_to_none(block)
    elif isinstance(module, nn.Sequential):
        for param in module.parameters():
            param.grad = None
    else:
        raise ValueError("Not recognized module to set grad to none!")
def apply_module_with_conv2d_bn(x, module):
    """Run stacked 1x1-conv blocks over a channels-last, 4-D tensor.

    The last axis of ``x`` is treated as channels: it is moved to position 1
    (the layout ``nn.Conv2d`` expects), every layer of every block in
    ``module`` is applied, the output of each block is cast to float32, and the
    channel axis is moved back to the end. Presumably ``x`` is
    (bsz, npts, k, feats) as built by the grouping code; confirm at call sites.

    Args:
        x: 4-D tensor with channels on the last axis.
        module: iterable of iterables of layers (e.g. ``nn.ModuleList`` of
            ``nn.Sequential``).

    Returns:
        4-D tensor with the same leading axes and the new channel count last.
    """
    channels_first = x.permute(0, 3, 1, 2).contiguous()
    for block in module:
        for layer in block:
            channels_first = layer(channels_first.contiguous())
        channels_first = channels_first.float()
    return channels_first.permute(0, 2, 3, 1)
def batched_index_select(values, indices, dim = 1):
    """Gather rows of ``values`` along axis ``dim`` with per-batch indices.

    ``values`` is ``(*lead, n, *feat)``, where ``lead`` are the first ``dim``
    axes and ``n`` is the axis being indexed. ``indices`` is ``(*lead, *idx)``
    holding positions in ``[0, n)``. The result is ``(*lead, *idx, *feat)`` with
    ``out[b, i..., f...] == values[b, indices[b, i...], f...]``.
    """
    feat_shape = tuple(values.shape[dim + 1:])
    idx_rank = indices.dim()
    n_extra = idx_rank - (dim + 1)

    # Append one singleton axis per feature axis, then broadcast over them.
    for _ in feat_shape:
        indices = indices.unsqueeze(-1)
    indices = indices.expand(*([-1] * idx_rank), *feat_shape)

    # Insert singleton axes before the indexed axis so that ``values`` lines up
    # with all index axes except the last, then broadcast across them.
    for _ in range(n_extra):
        values = values.unsqueeze(dim)
    target_shape = [-1] * values.dim()
    target_shape[dim:dim + n_extra] = indices.shape[dim:dim + n_extra]
    values = values.expand(*target_shape)

    return values.gather(dim + n_extra, indices)
def init_weight(blocks):
    """Initialise every ``nn.Linear`` found in ``blocks`` in place.

    Each item of ``blocks`` is either an ``nn.Linear`` or an ``nn.Sequential``
    whose direct ``nn.Linear`` children are initialised; anything else is
    skipped. Weights get Xavier-uniform init and biases (when present) are
    zeroed.

    NOTE(review): the branch bodies were missing from the source this was
    rebuilt from; Xavier-uniform + zero bias is a reconstruction — confirm
    against the upstream implementation.
    """
    def _init_linear(layer):
        # Xavier-uniform weights; zero bias unless the layer was built without one.
        nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    for module in blocks:
        if isinstance(module, nn.Sequential):
            for subm in module:
                if isinstance(subm, nn.Linear):
                    _init_linear(subm)
        elif isinstance(module, nn.Linear):
            _init_linear(module)
def construct_conv_modules(mlp_dims, n_in, last_act=True, bn=True):
    """Build one 1x1 ``nn.Conv2d`` block per entry of ``mlp_dims``.

    Block ``i`` maps ``n_in`` channels (for ``i == 0``) or ``mlp_dims[i - 1]``
    channels to ``mlp_dims[i]`` channels. Each block is an ``nn.Sequential`` of:

    * Conv2d -> BatchNorm2d -> activation for every block except the last, and
      for the last one too when ``last_act`` is true;
    * Conv2d -> BatchNorm2d for the last block when ``last_act`` is false and
      ``bn`` is true;
    * Conv2d alone for the last block when both are false.

    NOTE: ``bn`` only affects a non-activated final block; activated blocks
    always carry BatchNorm. This is kept as-is so module indices inside each
    Sequential (and hence saved state_dicts) are unchanged.
    NOTE(review): the activation line was missing from the source; ``nn.ReLU``
    is assumed.

    Args:
        mlp_dims: output channel count of each block.
        n_in: input channel count of the first block.
        last_act: whether the final block gets BatchNorm + activation.
        bn: whether a non-activated final block gets BatchNorm.

    Returns:
        ``nn.ModuleList`` with ``len(mlp_dims)`` ``nn.Sequential`` blocks.
    """
    rt_module_list = nn.ModuleList()
    n_layers = len(mlp_dims)
    for i, ouc in enumerate(mlp_dims):
        inc = n_in if i == 0 else mlp_dims[i - 1]
        layers = [
            nn.Conv2d(in_channels=inc, out_channels=ouc, kernel_size=(1, 1), stride=(1, 1), bias=True),
        ]
        if i < n_layers - 1 or last_act:
            layers.append(nn.BatchNorm2d(num_features=ouc, eps=1e-5, momentum=0.1))
            layers.append(nn.ReLU())
        elif bn:
            layers.append(nn.BatchNorm2d(num_features=ouc, eps=1e-5, momentum=0.1))
        rt_module_list.append(nn.Sequential(*layers))
    return rt_module_list
def farthest_point_sampling(pos: torch.FloatTensor, n_sampling: int):
    """Select points per cloud with ``torch_cluster.fps``.

    ``pos`` is read as ``(bz, N, feat_dim)``. All clouds are flattened into one
    ``(bz * N, feat_dim)`` float tensor with a per-row batch id, and ``fps`` is
    called deterministically (``random_start=False``) with ratio
    ``n_sampling / N``.

    Returns:
        The indices produced by ``fps``. They address rows of the flattened
        ``(bz * N, feat_dim)`` tensor, not per-cloud positions.

    NOTE(review): callers reshape the result assuming exactly ``n_sampling``
    points per cloud; confirm ``fps`` yields that count for the ratio used.
    """
    n_clouds, n_points, feat_dim = pos.size(0), pos.size(1), pos.size(-1)
    ratio = float(n_sampling / n_points)
    # Row r of the flattened tensor belongs to cloud r // n_points.
    batch = torch.arange(n_clouds, dtype=torch.long, device=pos.device).repeat_interleave(n_points)
    flat_pos = pos.float().contiguous().view(-1, feat_dim).contiguous()
    return fps(flat_pos, batch, ratio=ratio, random_start=False)
class PointnetPP(nn.Module):
    """PointNet++-style encoder/decoder producing per-point features.

    Encoder: one set-abstraction level per entry of ``self.n_samples``
    ([256, 128, 1]).
    * Levels with more than one centre pick centres by farthest-point sampling
      and group the ``k`` nearest neighbours of each centre. They then apply a
      shared 1x1-conv MLP and max-pool over the neighbours.
    * The single-centre level runs its MLP on every point and max-pools the
      whole cloud into one global feature.

    Decoder: walks back up the levels. At each step it interpolates features
    from the coarser level, concatenates the skip features of the finer level,
    and applies another 1x1-conv MLP.

    Args:
        in_feat_dim: channel count of the first grouped tensor. This is 3 when
            ``x`` is None and ``x.size(-1) + 3`` otherwise, because relative
            coordinates are prepended. The hard-coded widths (``128 + 3``,
            ``256 + 3``) assume ``pos`` has 3 coordinates.
    """

    def __init__(self, in_feat_dim: int):
        super(PointnetPP, self).__init__()
        self.skip_global = False
        # Number of centres kept at each set-abstraction level.
        self.n_samples = [256, 128, 1]
        # Encoder MLP widths; only mlps_in[i][0] (the input width) is consumed.
        mlps = [[64, 64, 128], [128, 128, 256], [256, 512, 1024]]
        mlps_in = [[in_feat_dim, 64, 64], [128 + 3, 128, 128], [256 + 3, 256, 512]]
        # Decoder widths: interpolated coarse features + skip features.
        up_mlps = [[512, 512], [512, 512], [512, 512, 512]]
        up_mlps_in = [1024 + 256, 512 + 128, 512 + in_feat_dim]
        self.in_feat_dim = in_feat_dim
        # Per-level radius for masked max-pooling; None means plain kNN max-pool.
        self.radius = [None, None, None]

        self.mlp_layers = nn.ModuleList()
        for dims_in, dims_out in zip(mlps_in, mlps):
            self.mlp_layers.append(
                construct_conv_modules(mlp_dims=dims_out, n_in=dims_in[0])
            )

        self.up_mlp_layers = nn.ModuleList()
        for dim_in, dims_out in zip(up_mlps_in, up_mlps):
            self.up_mlp_layers.append(
                construct_conv_modules(mlp_dims=dims_out, n_in=dim_in)
            )

    def eval(self):
        """Put the network into inference mode and return ``self``."""
        # super().eval() flips ``training`` on every submodule (BatchNorm
        # included); the BN helper is still run for anything extra it sets.
        super().eval()
        self.set_bn_no_training()
        return self

    def set_bn_no_training(self):
        """Apply ``set_bn_not_training`` to every encoder and decoder MLP."""
        for sub_module in self.mlp_layers:
            set_bn_not_training(sub_module)
        for sub_module in self.up_mlp_layers:
            set_bn_not_training(sub_module)

    def set_grad_to_none(self):
        """Apply the module-level ``set_grad_to_none`` to every MLP."""
        for sub_module in self.mlp_layers:
            set_grad_to_none(sub_module)
        for sub_module in self.up_mlp_layers:
            set_grad_to_none(sub_module)

    def sample_and_group(self, feat, pos, n_samples, use_pos=True, k=64):
        """Pick ``n_samples`` centres by FPS and gather their k nearest points.

        Args:
            feat: (bz, N, C) per-point features, or None.
            pos: (bz, N, D) coordinates; FPS uses only the first 3 columns.
            n_samples: number of centres per cloud.
            use_pos: when ``feat`` is given, prepend the centred neighbour
                coordinates to the grouped features.
            k: neighbours per centre. It is clamped to N so that clouds smaller
                than ``k`` do not make ``torch.topk`` fail.

        Returns:
            grouped_feat: (bz, n_samples, k, *) centred neighbour coordinates
                and/or neighbour features.
            topk_dist: (bz, n_samples, k) Euclidean distances to the neighbours.
            sampled_pos: (bz, n_samples, D) centre coordinates.
        """
        bz, N = pos.size(0), pos.size(1)
        k = min(k, N)
        fps_idx = farthest_point_sampling(pos=pos[:, :, :3], n_sampling=n_samples)
        # fps indices address rows of the flattened (bz * N) point tensor.
        sampled_pos = pos.contiguous().view(bz * N, -1)[fps_idx, :].contiguous().view(bz, n_samples, -1)
        # Centre-to-point Euclidean distances: bz x n_samples x N.
        ppdist = torch.sqrt(torch.sum((sampled_pos.unsqueeze(2) - pos.unsqueeze(1)) ** 2, dim=-1))
        topk_dist, topk_idx = torch.topk(ppdist, k=k, dim=2, largest=False)
        grouped_pos = batched_index_select(values=pos, indices=topk_idx, dim=1)
        # Express neighbours relative to their centre.
        grouped_pos = grouped_pos - sampled_pos.unsqueeze(2)
        if feat is None:
            return grouped_pos, topk_dist, sampled_pos
        grouped_feat = batched_index_select(values=feat, indices=topk_idx, dim=1)
        if use_pos:
            grouped_feat = torch.cat([grouped_pos, grouped_feat], dim=-1)
        return grouped_feat, topk_dist, sampled_pos

    def max_pooling_with_r(self, grouped_feat, ppdist, r=None):
        """Max-pool ``grouped_feat`` over its neighbour axis (dim 2).

        When ``r`` is given, neighbours whose distance is not ``<= r`` have 1e8
        subtracted before the max, so they effectively do not contribute.
        """
        if r is None:
            res, _ = torch.max(grouped_feat, dim=2)
            return res
        outside = ~(ppdist <= r)
        penalty = torch.zeros_like(ppdist, dtype=torch.float).masked_fill(outside, -1e8)
        res, _ = torch.max(grouped_feat + penalty.unsqueeze(-1), dim=2)
        return res

    def interpolate_features(self, feat, p1, p2):
        """Interpolate ``feat`` from points ``p1`` onto points ``p2``.

        Each target point takes the inverse-distance-weighted mean of the
        features of its (up to) 3 nearest source points.

        Args:
            feat: (bz, N1, C) features attached to ``p1``.
            p1: (bz, N1, D) source coordinates.
            p2: (bz, N2, D) target coordinates.

        Returns:
            (bz, N2, C) interpolated features.
        """
        dist = torch.norm(p2[:, :, None, :] - p1[:, None, :, :], dim=-1, p=2, keepdim=False)
        n_nearest = min(3, dist.size(-1))
        dist, idx = dist.topk(n_nearest, dim=-1, largest=False)
        # The epsilon keeps coincident points from dividing by zero.
        dist_recip = 1.0 / (dist + 1e-8)
        weight = dist_recip / torch.sum(dist_recip, dim=2, keepdim=True)
        nearest_feats = batched_index_select(feat, idx, dim=1)  # bz x N2 x n_nearest x C
        return torch.sum(nearest_feats * weight[:, :, :, None], dim=2, keepdim=False)

    def forward(self, x: torch.FloatTensor, pos: torch.FloatTensor, return_global=False):
        """Encode the cloud, then decode back to per-point features.

        Args:
            x: (bz, N, C) per-point input features, or None.
            pos: (bz, N, 3) point coordinates.
            return_global: currently unused; kept for interface compatibility.

        Returns:
            (x, pos): per-point features from the last decoder MLP, and ``pos``
            of the input level.
        """
        bz = pos.size(0)
        # cache[j] = (features, coordinates) at encoder level j; level 0 is the input.
        cache = [(None if x is None else x.clone(), pos.clone())]
        for i, n_samples in enumerate(self.n_samples):
            if n_samples == 1:
                # Global level: MLP over every point (absolute coords prepended),
                # then max-pool the whole cloud into one feature at the origin.
                grouped_feat = torch.cat([pos.unsqueeze(1), x.unsqueeze(1)], dim=-1)
                grouped_feat = apply_module_with_conv2d_bn(
                    grouped_feat, self.mlp_layers[i]
                ).squeeze(1)
                x, _ = torch.max(grouped_feat, dim=1, keepdim=True)
                pos = torch.zeros((bz, 1, 3), dtype=torch.float, device=pos.device)
            else:
                grouped_feat, topk_dist, pos = self.sample_and_group(x, pos, n_samples, use_pos=True, k=64)
                grouped_feat = apply_module_with_conv2d_bn(grouped_feat, self.mlp_layers[i])
                x = self.max_pooling_with_r(grouped_feat, topk_dist, r=self.radius[i])
            cache.append((x.clone(), pos.clone()))

        # Decoder: from the coarsest level back to the input resolution.
        n_up = len(self.up_mlp_layers)
        for i, up_conv_layers in enumerate(self.up_mlp_layers):
            prev_x, prev_pos = cache[-i - 2]
            interpolated_feats = self.interpolate_features(x, pos, prev_pos)
            if prev_x is None:
                # No input features: the coordinates serve as the skip feature.
                prev_x = prev_pos
            elif i == n_up - 1:
                # Input level: append coordinates to the raw input features.
                prev_x = torch.cat([prev_x, prev_pos], dim=-1)
            cur_up_feats = torch.cat([interpolated_feats, prev_x], dim=-1)
            # A singleton neighbour axis lets the 4-D conv helper run per point.
            x = apply_module_with_conv2d_bn(
                cur_up_feats.unsqueeze(2), up_conv_layers
            ).squeeze(2)
            pos = prev_pos
        return x, pos