djl234 commited on
Commit
edc384d
1 Parent(s): df4f158

Create new file

Browse files
Files changed (1) hide show
  1. transformer.py +207 -0
transformer.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ from einops import rearrange
4
+ import numpy
5
+ def index_points(device, points, idx):
6
+ """
7
+
8
+ Input:
9
+ points: input points data, [B, N, C]
10
+ idx: sample index data, [B, S]
11
+ Return:
12
+ new_points:, indexed points data, [B, S, C]
13
+ """
14
+ B = points.shape[0]
15
+ view_shape = list(idx.shape)
16
+ view_shape[1:] = [1] * (len(view_shape) - 1)
17
+ repeat_shape = list(idx.shape)
18
+ repeat_shape[0] = 1
19
+ # batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
20
+ batch_indices = torch.arange(B, dtype=torch.long).cuda().view(view_shape).repeat(repeat_shape)
21
+ new_points = points[batch_indices, idx, :]
22
+ return new_points
23
+
24
+ def knn_l2(device, net, k, u):
25
+ '''
26
+ Input:
27
+ k: int32, number of k in k-nn search
28
+ net: (batch_size, npoint, c) float32 array, points
29
+ u: int32, block size
30
+ Output:
31
+ idx: (batch_size, npoint, k) int32 array, indices to input points
32
+ '''
33
+ INF = 1e8
34
+ batch_size = net.size(0)
35
+ npoint = net.size(1)
36
+ n_channel = net.size(2)
37
+
38
+ square = torch.pow(torch.norm(net, dim=2,keepdim=True),2)
39
+
40
+ def u_block(batch_size, npoint, u):
41
+ block = numpy.zeros([batch_size, npoint, npoint])
42
+ n = npoint // u
43
+ for i in range(n):
44
+ block[:, (i*u):(i*u+u), (i*u):(i*u+u)] = numpy.ones([batch_size, u, u]) * (-INF)
45
+ return block
46
+
47
+ # minus_distance = 2 * torch.matmul(net, net.transpose(2,1)) - square - square.transpose(2,1) + torch.Tensor(u_block(batch_size, npoint, u)).to(device)
48
+ minus_distance = 2 * torch.matmul(net, net.transpose(2,1)) - square - square.transpose(2,1) + torch.Tensor(u_block(batch_size, npoint, u)).cuda()
49
+ _, indices = torch.topk(minus_distance, k, largest=True, sorted=False)
50
+
51
+ return indices
52
+
53
+ class Residual(nn.Module):
54
+ def __init__(self, fn):
55
+ super().__init__()
56
+ self.fn = fn
57
+ def forward(self, x, **kwargs):
58
+ return self.fn(x, **kwargs) + x
59
+
60
+ class PreNorm(nn.Module):
61
+ def __init__(self, dim, fn):
62
+ super().__init__()
63
+ self.norm = nn.LayerNorm(dim)
64
+ self.fn = fn
65
+ def forward(self, x, **kwargs):
66
+ return self.fn(self.norm(x), **kwargs)
67
+
68
+ class FeedForward(nn.Module):
69
+ def __init__(self, dim, hidden_dim, dropout = 0.):
70
+ super().__init__()
71
+ self.net = nn.Sequential(
72
+ nn.Linear(dim, hidden_dim),
73
+ nn.GELU(),
74
+ nn.Dropout(dropout),
75
+ nn.Linear(hidden_dim, dim),
76
+ nn.Dropout(dropout)
77
+ )
78
+ def forward(self, x):
79
+ return self.net(x)
80
+
81
+ class Attention(nn.Module):
82
+ def __init__(self, dim, heads = 4, dropout = 0.):
83
+ super().__init__()
84
+ self.heads = heads
85
+ self.scale = dim ** -0.5
86
+
87
+ self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
88
+ self.to_out = nn.Sequential(
89
+ nn.Linear(dim, dim),
90
+ nn.Dropout(dropout)
91
+ )
92
+
93
+ def forward(self, x, mask = None):
94
+ b, n, _, h = *x.shape, self.heads
95
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
96
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
97
+
98
+ dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
99
+
100
+ if mask is not None:
101
+ mask = F.pad(mask.flatten(1), (1, 0), value = True)
102
+ assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
103
+ mask = mask[:, None, :] * mask[:, :, None]
104
+ dots.masked_fill_(~mask, float('-inf'))
105
+ del mask
106
+
107
+ attn = dots.softmax(dim=-1)
108
+
109
+ out = torch.einsum('bhij,bhjd->bhid', attn, v)
110
+ out = rearrange(out, 'b h n d -> b n (h d)')
111
+ out = self.to_out(out)
112
+ return out
113
+
114
+ class Local_Attention(nn.Module):
115
+ def __init__(self, dim, heads = 4,knn=4, dropout = 0.):
116
+ super().__init__()
117
+ self.heads = heads
118
+ self.scale = dim ** -0.5
119
+
120
+ #self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
121
+ self.q=nn.Linear(dim,dim,bias=False)
122
+ self.k=nn.Linear(dim,dim,bias=False)
123
+ self.v=nn.Linear(dim,dim,bias=False)
124
+
125
+ self.to_out = nn.Sequential(
126
+ nn.Linear(dim, dim),
127
+ nn.Dropout(dropout)
128
+ )
129
+ self.knn=knn
130
+
131
+ def forward(self, x, mask = None):
132
+ b, n, _, h = *x.shape, self.heads
133
+
134
+ point=x*1
135
+ X=x*1
136
+
137
+ idx = knn_l2(0, point.permute(0,2,1), 4, 1)
138
+ feat=idx
139
+ new_point = index_points(0, point.permute(0,2,1),idx)
140
+
141
+ group_point = new_point.permute(0, 3, 2, 1)
142
+
143
+ _1,_2,_3,_4=group_point.shape
144
+
145
+ q=self.q(X.reshape(_1*_2,1,_4))
146
+ k=self.k(torch.cat([group_point.reshape(_1*_2,self.knn,_4),X.reshape(_1*_2,1,_4)],dim=1))
147
+ v=self.v(torch.cat([group_point.reshape(_1*_2,self.knn,_4),X.reshape(_1*_2,1,_4)],dim=1))
148
+ q, k, v = rearrange(q, 'b n (h d) -> b h n d', h = h),rearrange(k, 'b n (h d) -> b h n d', h = h),rearrange(v, 'b n (h d) -> b h n d', h = h)
149
+
150
+ attn_map=q@k.permute(0,1,3,2)*self.scale
151
+ attn_map=attn_map.softmax(dim=-1)
152
+
153
+ out=attn_map@v
154
+ out=out.view(b,out.shape[0]//b,out.shape[1],out.shape[3]).permute(0,2,1,3)
155
+
156
+ out = rearrange(out, 'b h n d -> b n (h d)')
157
+ out = self.to_out(out)
158
+ return out
159
+
160
+
161
+ class Transformer(nn.Module):
162
+ def __init__(self, dim, depth, heads, mlp_dim, group=5, dropout=0.):
163
+ super().__init__()
164
+ self.layers = nn.ModuleList([])
165
+ for _ in range(depth):
166
+ self.layers.append(nn.ModuleList([
167
+ Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
168
+ Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
169
+ ]))
170
+ self.group=group
171
+ def forward(self, x, mask = None):
172
+ bs_gp,dim,wid,hei=x.shape[0],x.shape[1],x.shape[2],x.shape[3]
173
+ bs=bs_gp//self.group
174
+ gp=self.group
175
+ x=x.reshape(bs,gp,dim,wid,hei)
176
+ x=x.permute(0,1,3,4,2).reshape(bs,gp*wid*hei,dim)
177
+ for attn, ff in self.layers:
178
+ x = attn(x, mask = mask)
179
+ x = ff(x)
180
+
181
+ x=x.reshape(bs,gp,wid,hei,dim).permute(0,1,4,2,3).reshape(bs_gp,dim,wid,hei)
182
+
183
+ return x
184
+
185
+ class Transformer__local(nn.Module):
186
+ def __init__(self, dim, depth, heads, mlp_dim,knn_k=4, group=5, dropout=0.):
187
+ super().__init__()
188
+ self.layers = nn.ModuleList([])
189
+ for _ in range(depth):
190
+ self.layers.append(nn.ModuleList([
191
+ Residual(PreNorm(dim, Local_Attention(dim, heads = heads,knn=knn_k, dropout = dropout))),
192
+ Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
193
+ ]))
194
+ self.group=group
195
+ def forward(self, x, mask = None):
196
+ bs_gp,dim,wid,hei=x.shape[0],x.shape[1],x.shape[2],x.shape[3]
197
+ bs=bs_gp//self.group
198
+ gp=self.group
199
+ x=x.reshape(bs,gp,dim,wid,hei)
200
+ x=x.permute(0,1,3,4,2).reshape(bs,gp*wid*hei,dim)
201
+ for attn, ff in self.layers:
202
+ x = attn(x, mask = mask)
203
+ x = ff(x)
204
+
205
+ x=x.reshape(bs,gp,wid,hei,dim).permute(0,1,4,2,3).reshape(bs_gp,dim,wid,hei)
206
+
207
+ return x