import torch
from torch import nn


def knn(x, k):
    """Return the indices of the k nearest neighbors of every point.

    x: (batch_size, num_dims, num_points)
    """
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    # Negative squared Euclidean distance, so topk picks the nearest points.
    pairwise_distance = -xx - inner - xx.transpose(2, 1)

    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx, pairwise_distance


def local_operator(x, k):
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    idx, _ = knn(x, k=k)
    device = x.device  # keep index tensors on the same device as the input
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = idx + idx_base
    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()
    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor = neighbor.view(batch_size, num_points, k, num_dims)

    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    # Concatenate local differences and absolute neighbor coordinates: (b, 2c, n, k).
    feature = torch.cat((neighbor - x, neighbor), dim=3).permute(0, 3, 1, 2)

    return feature


def local_operator_withnorm(x, norm_plt, k):
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    norm_plt = norm_plt.view(batch_size, -1, num_points)
    idx, _ = knn(x, k=k)  # (batch_size, num_points, k)
    device = x.device  # keep index tensors on the same device as the input
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = idx + idx_base
    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()
    norm_plt = norm_plt.transpose(2, 1).contiguous()
    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor_norm = norm_plt.view(batch_size * num_points, -1)[idx, :]
    neighbor = neighbor.view(batch_size, num_points, k, num_dims)
    neighbor_norm = neighbor_norm.view(batch_size, num_points, k, num_dims)

    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    # Concatenate differences, neighbor coordinates, and neighbor normals: (b, 3c, n, k).
    feature = torch.cat((neighbor - x, neighbor, neighbor_norm), dim=3).permute(0, 3, 1, 2)

    return feature
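

# Minimal usage sketch for knn/local_operator, assuming point clouds are stored
# channels-first as (batch_size, 3, num_points). The _demo_local_operator helper
# and the random input are illustrative only, not part of the training pipeline.
def _demo_local_operator():
    x = torch.rand(2, 3, 128)                 # 2 clouds, xyz coordinates, 128 points each
    feature = local_operator(x, k=16)         # (neighbor - x, neighbor) for each point
    assert feature.shape == (2, 6, 128, 16)   # channels double: 2 * 3 = 6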


def GDM(x, M):
    """
    Geometry-Disentangle Module.
    M: number of disentangled points in both the sharp and the gentle variation components.
    """
    k = 64     # number of neighbors deciding the range of j in Eq.(5)
    tau = 0.2  # threshold in Eq.(2)
    sigma = 2  # parameter of f (Gaussian function) in Eq.(2)

    # Graph construction:
    device = x.device  # keep index tensors on the same device as the input
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)

    idx, p = knn(x, k=k)  # p: -[(x1-x2)^2+...], i.e. negative squared distances

    # Apply the threshold tau: recover Euclidean distances and mask out far pairs.
    p1 = torch.abs(p)
    p1 = torch.sqrt(p1)
    mask = p1 < tau

    # Apply the bandwidth sigma of the Gaussian kernel.
    p = p / (sigma * sigma)
    w = torch.exp(p)  # b,n,n
    w = torch.mul(mask.float(), w)

    b = 1 / torch.sum(w, dim=1)
    b = b.reshape(batch_size, num_points, 1).repeat(1, 1, num_points)
    c = torch.eye(num_points, num_points, device=device)
    c = c.expand(batch_size, num_points, num_points)
    D = b * c  # inverse degree matrix D^-1: b,n,n

    A = torch.matmul(D, w)  # normalized adjacency matrix A_hat

    # Gather Aij in a local area (column 0 is each point itself and is dropped):
    idx2 = idx.view(batch_size * num_points, -1)
    idx_base2 = torch.arange(0, batch_size * num_points, device=device).view(-1, 1) * num_points
    idx2 = idx2 + idx_base2

    idx2 = idx2.reshape(batch_size * num_points, k)[:, 1:k]
    idx2 = idx2.reshape(batch_size * num_points * (k - 1))
    idx2 = idx2.view(-1)

    A = A.view(-1)
    A = A[idx2].reshape(batch_size, num_points, k - 1)  # Aij: b,n,k-1

    # Disentangle the point cloud into sharp (xs) and gentle (xg) variation components:
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = idx + idx_base
    idx = idx.reshape(batch_size * num_points, k)[:, 1:k]
    idx = idx.reshape(batch_size * num_points * (k - 1))

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()  # b,n,c
    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor = neighbor.view(batch_size, num_points, k - 1, num_dims)  # b,n,k-1,c
    A = A.reshape(batch_size, num_points, k - 1, 1)  # b,n,k-1,1
    n = A.mul(neighbor)  # b,n,k-1,c
    n = torch.sum(n, dim=2)  # b,n,c

    pai = torch.norm(x - n, dim=-1).pow(2)  # Eq.(5)
    pais = pai.topk(k=M, dim=-1)[1]    # first M points form the sharp variation component
    paig = (-pai).topk(k=M, dim=-1)[1]  # last M points form the gentle variation component

    pai_base = torch.arange(0, batch_size, device=device).view(-1, 1) * num_points
    indices = (pais + pai_base).view(-1)
    indiceg = (paig + pai_base).view(-1)

    xs = x.view(batch_size * num_points, -1)[indices, :]
    xg = x.view(batch_size * num_points, -1)[indiceg, :]

    xs = xs.view(batch_size, M, -1)  # b,M,c
    xg = xg.view(batch_size, M, -1)  # b,M,c

    return xs, xg
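

# Minimal sketch of GDM's input/output contract, assuming input shaped
# (batch, 3, num_points) with num_points > k = 64 and M <= num_points.
# The _demo_gdm helper and the random stand-in cloud are illustrative only.
def _demo_gdm():
    x = torch.rand(2, 3, 256)  # random stand-in for a point cloud
    xs, xg = GDM(x, M=128)     # M points per variation component
    assert xs.shape == (2, 128, 3) and xg.shape == (2, 128, 3)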


class SGCAM(nn.Module):
    """Sharp-Gentle Complementary Attention Module."""

    def __init__(self, in_channels, inter_channels=None, bn_layer=True):
        super(SGCAM, self).__init__()

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        conv_nd = nn.Conv1d
        bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # Zero-init the output transform so the module starts as an identity mapping.
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

    def forward(self, x, x_2):
        batch_size = x.size(0)

        g_x = self.g(x_2).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x_2).view(batch_size, self.inter_channels, -1)

        attention = torch.matmul(theta_x, phi_x)  # attention matrix between x and x_2
        N = attention.size(-1)
        attention = attention / N  # scale by the number of positions

        y = torch.matmul(attention, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        y = W_y + x  # residual connection back to the query branch
        return y
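

# Minimal sketch of SGCAM's cross-attention between the two components' features,
# assuming both carry the same channel width. The _demo_sgcam helper and the
# random feature tensors are illustrative only.
def _demo_sgcam():
    sharp = torch.rand(2, 64, 128)   # (batch, channels, M) features of the sharp component
    gentle = torch.rand(2, 64, 128)  # features of the gentle component, same shape
    attn = SGCAM(in_channels=64)
    out = attn(sharp, gentle)        # sharp queries attend to gentle keys/values
    assert out.shape == sharp.shape  # residual keeps the query shape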