djl234 committed on
Commit
df4f158
1 Parent(s): 6056c5f

Create new file

Files changed (1)
  1. Intra_MLP.py +90 -0
Intra_MLP.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+ import numpy
+ #from transformer import Local_Attention,Transformer_1
+ 
+ # codes of this function are borrowed from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet2_utils.py
+ def index_points(device, points, idx):
+     """
+     Input:
+         points: input points data, [B, N, C]
+         idx: sample index data, [B, S]
+     Return:
+         new_points: indexed points data, [B, S, C]
+     """
+     B = points.shape[0]
+     view_shape = list(idx.shape)
+     view_shape[1:] = [1] * (len(view_shape) - 1)
+     repeat_shape = list(idx.shape)
+     repeat_shape[0] = 1
+     # batch index for every entry of idx, used for fancy indexing into points
+     batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
+     new_points = points[batch_indices, idx, :]
+     return new_points
+ 
+ def knn_l2(device, net, k, u):
+     '''
+     Input:
+         k: int32, number of k in k-nn search
+         net: (batch_size, npoint, c) float32 array, points
+         u: int32, block size
+     Output:
+         idx: (batch_size, npoint, k) int32 array, indices to input points
+     '''
+     INF = 1e8
+     batch_size = net.size(0)
+     npoint = net.size(1)
+     n_channel = net.size(2)
+ 
+     # squared L2 norm of every point, [batch_size, npoint, 1]
+     square = torch.pow(torch.norm(net, dim=2, keepdim=True), 2)
+ 
+     def u_block(batch_size, npoint, u):
+         # -INF mask on the diagonal u x u blocks so neighbours inside the same block are never selected
+         block = numpy.zeros([batch_size, npoint, npoint])
+         n = npoint // u
+         for i in range(n):
+             block[:, (i*u):(i*u+u), (i*u):(i*u+u)] = numpy.ones([batch_size, u, u]) * (-INF)
+         return block
+ 
+     # negative squared pairwise distance, 2*x.y - |x|^2 - |y|^2, plus the block mask
+     minus_distance = 2 * torch.matmul(net, net.transpose(2, 1)) - square - square.transpose(2, 1) + torch.Tensor(u_block(batch_size, npoint, u)).to(device)
+     # the k nearest neighbours are the k largest entries of the masked negative distance
+     _, indices = torch.topk(minus_distance, k, largest=True, sorted=False)
+ 
+     return indices
+ 
+ if __name__ == '__main__':
+ 
+     bs, gs, k = 5, 5, 4
+ 
+     # toy input: bs*gs feature maps of shape 512 x 14 x 14
+     A = torch.rand(bs*gs, 512, 14, 14).cuda()
+     # the Transformer_1 check needs the commented-out transformer import above
+     # net = Transformer_1(512, 4, 4, 782).cuda()
+     # Y = net(A)
+     # print(Y.shape)
+     feature_map_size = A.shape[-1]
+     # split the flattened batch back into (batch, group) so the 5-dim permute below is valid
+     A = A.view(bs, gs, A.size(1), feature_map_size, feature_map_size)
+     point = A.permute(0, 2, 1, 3, 4).reshape(A.size(0), A.size(1)*A.shape[-1]*A.shape[-2], -1)
+     point = point.permute(0, 2, 1)
+     X = point
+     print(point.shape)
+     idx = knn_l2(0, point, 4, 1)
+     # print(idx)
+ 
+     feat = idx
+     new_point = index_points(0, point, idx)
+ 
+     group_point = new_point.permute(0, 3, 2, 1)
+     print(group_point.shape)
+     _1, _2, _3, _4 = group_point.shape
+     X = X.permute(0, 2, 1)
+     print(X.shape)
+     # attention of every point over its k neighbours plus itself
+     attn_map = X.reshape(_1*_2, 1, _4) @ torch.cat([group_point.reshape(_1*_2, k, _4), X.reshape(_1*_2, 1, _4)], dim=1).permute(0, 2, 1)
+     V = torch.cat([group_point.reshape(_1*_2, k, _4), X.reshape(_1*_2, 1, _4)], dim=1)
+     print(attn_map.shape)
+     Y = attn_map @ V
+     Y = Y.reshape(_1, _2, _4)
+ 
+     # group_point = torch.max(group_point, 2)[0]  # [B, D', S]
+     group_point = Y
+     print(group_point.shape)
+ 
+     # reshape the attended features back to (batch, group, C, H, W)
+     intra_mask = group_point.view(bs, gs, group_point.size(2), feature_map_size, feature_map_size)
+     print(intra_mask.shape)
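
For a quick check without a GPU, here is a minimal CPU-only sketch of how the two helpers compose; the toy shapes and the Intra_MLP import path are illustrative assumptions and are not part of the commit:

import torch

from Intra_MLP import index_points, knn_l2  # assumes the committed file is importable as a module

device = torch.device('cpu')
pts = torch.rand(2, 8, 16)                   # [B=2, N=8 points, C=16 channels]
idx = knn_l2(device, pts, k=3, u=4)          # u=4: neighbours inside each block of 4 are masked out
print(idx.shape)                             # torch.Size([2, 8, 3])
neighbours = index_points(device, pts, idx)  # gather the k neighbour features per point
print(neighbours.shape)                      # torch.Size([2, 8, 3, 16])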