Spaces:
Running
on
Zero
Running
on
Zero
Update MT.py
Browse files
MT.py
CHANGED
@@ -285,10 +285,11 @@ class FeatureTransformer(nn.Module):
|
|
285 |
for i in range(self.num_layers):
|
286 |
value, att, attn_viz = self.layers(att=att, value=value, shape=[h, w], iteration=i)
|
287 |
value_decode = self.normalize(torch.square(self.re_proj(value))) # map to motion energy, Do use normalization here
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
|
|
292 |
return feature_list, attn_list, attn_viz_list
|
293 |
|
294 |
def forward_save_mem(self, feature0, add_position_embedding=True):
|
|
|
285 |
for i in range(self.num_layers):
|
286 |
value, att, attn_viz = self.layers(att=att, value=value, shape=[h, w], iteration=i)
|
287 |
value_decode = self.normalize(torch.square(self.re_proj(value))) # map to motion energy, Do use normalization here
|
288 |
+
if i % 2 == 0:
|
289 |
+
attn_viz_list.append(attn_viz.reshape(b, h, w, h, w))
|
290 |
+
attn_list.append(att.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
|
291 |
+
feature_list.append(value_decode.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
|
292 |
+
|
293 |
return feature_list, attn_list, attn_viz_list
|
294 |
|
295 |
def forward_save_mem(self, feature0, add_position_embedding=True):
|