sunana commited on
Commit
cb0ba4a
1 Parent(s): 342d9ff

Update MT.py

Browse files
Files changed (1) hide show
  1. MT.py +6 -15
MT.py CHANGED
@@ -131,7 +131,7 @@ class PositionEmbeddingSine(nn.Module):
131
  return pos
132
 
133
 
134
- def feature_add_position(feature0, feature_channels, scale=1.0):
135
  temp = torch.mean(abs(feature0))
136
  pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
137
  # position = PositionalEncodingPermute2D(feature_channels)(feature0)
@@ -223,8 +223,6 @@ class TransformerLayer(nn.Module):
223
  att = feature_add_position(att.transpose(-1, -2).view(
224
  B, C, shape[0], shape[1]), C).reshape(B, C, -1).transpose(-1, -2)
225
 
226
- # att = feature_add_position(att.transpose(-1, -2).view(
227
- # B, C, shape[0], shape[1]), C).reshape(B, C, -1).transpose(-1, -2)
228
  val_proj = self.v_proj(value)
229
  att_proj = self.att_proj(att) # [B, L, C]
230
  norm_fac = torch.sum(att_proj ** 2, dim=-1, keepdim=True) ** 0.5
@@ -237,7 +235,6 @@ class TransformerLayer(nn.Module):
237
  D = 1 / (torch.sqrt(D) + 1e-6) # normalized node degrees
238
  A = D * A * D.transpose(-1, -2)
239
 
240
- # A = torch.softmax(A , dim=2) # [B, L, L]
241
  message = torch.matmul(A, val_proj) # [B, L, C]
242
 
243
  message = self.merge(message) # [B, L, C]
@@ -246,9 +243,6 @@ class TransformerLayer(nn.Module):
246
  message = self.mlp(torch.cat([value, message], dim=-1))
247
  message = self.norm2(message)
248
 
249
- # if iteration > 2:
250
- # message = self.drop(message)
251
-
252
  att = self.attn_updater(att, message, shape)
253
  value = self.gru(value, message, shape)
254
  return value, att, A
@@ -290,14 +284,11 @@ class FeatureTransformer(nn.Module):
290
  att = att.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
291
  for i in range(self.num_layers):
292
  value, att, attn_viz = self.layers(att=att, value=value, shape=[h, w], iteration=i)
293
- attn_viz = attn_viz.reshape(b, h, w, h, w)
294
- attn_viz_list.append(attn_viz)
295
- value_decode = self.normalize(
296
- torch.square(self.re_proj(value))) # map to motion energy, Do use normalization here
297
- # print("value_decode",value_decode.abs().mean())
298
- attn_list.append(att.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
299
- feature_list.append(value_decode.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
300
- # reshape back
301
  return feature_list, attn_list, attn_viz_list
302
 
303
  def forward_save_mem(self, feature0, add_position_embedding=True):
 
131
  return pos
132
 
133
 
134
+ def feature_add_position(feature0, feature_channels, scale=0.5):
135
  temp = torch.mean(abs(feature0))
136
  pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
137
  # position = PositionalEncodingPermute2D(feature_channels)(feature0)
 
223
  att = feature_add_position(att.transpose(-1, -2).view(
224
  B, C, shape[0], shape[1]), C).reshape(B, C, -1).transpose(-1, -2)
225
 
 
 
226
  val_proj = self.v_proj(value)
227
  att_proj = self.att_proj(att) # [B, L, C]
228
  norm_fac = torch.sum(att_proj ** 2, dim=-1, keepdim=True) ** 0.5
 
235
  D = 1 / (torch.sqrt(D) + 1e-6) # normalized node degrees
236
  A = D * A * D.transpose(-1, -2)
237
 
 
238
  message = torch.matmul(A, val_proj) # [B, L, C]
239
 
240
  message = self.merge(message) # [B, L, C]
 
243
  message = self.mlp(torch.cat([value, message], dim=-1))
244
  message = self.norm2(message)
245
 
 
 
 
246
  att = self.attn_updater(att, message, shape)
247
  value = self.gru(value, message, shape)
248
  return value, att, A
 
284
  att = att.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
285
  for i in range(self.num_layers):
286
  value, att, attn_viz = self.layers(att=att, value=value, shape=[h, w], iteration=i)
287
+ value_decode = self.normalize(torch.square(self.re_proj(value))) # map to motion energy, Do use normalization here
288
+
289
+ attn_viz_list.append(attn_viz.reshape(b, h, w, h, w))
290
+ attn_list.append(att.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
291
+ feature_list.append(value_decode.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
 
 
 
292
  return feature_list, attn_list, attn_viz_list
293
 
294
  def forward_save_mem(self, feature0, add_position_embedding=True):