import torch
import numpy
#from transformer import Local_Attention,Transformer_1

# NOTE: the helpers below were adapted from another codebase; the original
# attribution was lost from this comment -- TODO restore the source reference.


def index_points(device, points, idx):
    """
    Gather points along dimension 1 using per-batch indices.

    Input:
        device: device on which the batch-index helper tensor is created;
            must be compatible with the device of `points` / `idx`.
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (any [B, ...] shape is accepted)
    Return:
        new_points: indexed points data, [B, S, C]
            (in general idx.shape + [C])
    """
    B = points.shape[0]
    # Shape [B, 1, ..., 1]: one batch id per leading row of idx.
    batch_shape = [B] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(B, dtype=torch.long, device=device).view(batch_shape)
    # Advanced indexing broadcasts batch_indices against idx, so the
    # repeat/copy to idx.shape is unnecessary.
    return points[batch_indices, idx, :]


def knn_l2(device, net, k, u):
    '''
    k-nearest-neighbour search by squared L2 distance, excluding every point
    that lies in the same block of `u` consecutive indices as the query
    (with u == 1 this simply excludes the point itself).

    Input:
        device: device on which the block mask is created; must match the
            device of `net`.
        k: int, number of neighbours to return per point
        net: (batch_size, npoint, c) float tensor, points
        u: int >= 1, block size. If npoint is not a multiple of u, the
            trailing points form a final, smaller block that is masked too.
    Output:
        idx: (batch_size, npoint, k) int64 tensor, indices to input points.
            Neighbours are NOT ordered by distance (topk with sorted=False).
    Raises:
        ValueError: if u < 1.
    '''
    if u < 1:
        raise ValueError(f"block size u must be >= 1, got {u}")
    npoint = net.size(1)

    # Squared norms, (B, N, 1). Summing squares avoids the sqrt-then-square
    # round trip of norm(...) ** 2.
    square = net.pow(2).sum(dim=2, keepdim=True)
    # -||a - b||^2 = 2 a.b - ||a||^2 - ||b||^2, so larger means closer.
    minus_distance = 2 * torch.matmul(net, net.transpose(2, 1)) - square - square.transpose(2, 1)

    # Points i and j share a block iff i // u == j // u. Integer division
    # also covers a trailing partial block when npoint % u != 0.
    block_id = torch.arange(npoint, device=device) // u
    same_block = block_id.unsqueeze(0) == block_id.unsqueeze(1)  # (N, N), broadcast over batch
    # -inf (instead of adding a finite -1e8) keeps masked entries last no
    # matter how large the real distances are.
    minus_distance = minus_distance.masked_fill(same_block, float('-inf'))

    _, indices = torch.topk(minus_distance, k, largest=True, sorted=False)
    return indices


if __name__ == '__main__':
    # Smoke test: kNN grouping over channels followed by a neighbour
    # attention step, on random feature maps.
    # NOTE(review): a previous version first exercised Transformer_1 (from the
    # commented-out `transformer` import above) and then called exit(0),
    # leaving everything below unreachable; that call was dropped because the
    # import is disabled.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    bs, gs, k = 5, 5, 4
    # 5-D input (batch, group, channels, H, W); the permute below needs 5 dims.
    A = torch.rand(bs, gs, 512, 14, 14, device=device)
    feature_map_size = A.shape[-1]

    # NOTE(review): this reshape reinterprets the permuted memory as
    # (bs, gs*H*W, C) without separating axes cleanly -- kept as in the
    # original experiment, TODO confirm intent.
    point = A.permute(0, 2, 1, 3, 4).reshape(A.size(0), A.size(1) * A.shape[-1] * A.shape[-2], -1)
    point = point.permute(0, 2, 1)  # (bs, C, gs*H*W): each channel is a "point"
    X = point
    print(point.shape)

    idx = knn_l2(device, point, k, 1)
    new_point = index_points(device, point, idx)   # (bs, C, k, gs*H*W)
    group_point = new_point.permute(0, 3, 2, 1)    # (bs, gs*H*W, k, C)
    print(group_point.shape)
    n_batch, n_pos, _, n_chan = group_point.shape

    X = X.permute(0, 2, 1)                         # (bs, gs*H*W, C)
    print(X.shape)

    # NOTE(review): these lines were garbled in the source and are
    # reconstructed as attention of each query over itself plus its k
    # neighbours -- TODO confirm (e.g. whether softmax/scaling was intended).
    query = X.reshape(n_batch * n_pos, 1, n_chan)
    V = torch.cat([group_point.reshape(n_batch * n_pos, k, n_chan), query], dim=1)
    attn_map = query @ V.permute(0, 2, 1)          # (bs*gs*H*W, 1, k+1)
    print(attn_map.shape)
    Y = attn_map @ V                               # (bs*gs*H*W, 1, C)
    Y = Y.reshape(n_batch, n_pos, n_chan)

    group_point = Y
    print(group_point.shape)
    intra_mask = group_point.view(bs, gs, group_point.size(2), feature_map_size, feature_map_size)
    print(intra_mask.shape)