omar-ah commited on
Commit
a4d3af5
·
verified ·
1 Parent(s): f871a5c

Sequence training: pairs→K-frame clips, mLSTM memory carries across frames

Browse files
Files changed (1) hide show
  1. vil_tracker/models/backbone.py +54 -28
vil_tracker/models/backbone.py CHANGED
@@ -134,19 +134,27 @@ class mLSTMBlockWithTMoE(nn.Module):
134
 
135
 
136
  class ViLBackbone(nn.Module):
137
- """Vision-LSTM backbone for tracking with integrated FiLM temporal modulation.
138
 
139
- Concatenates template + search patches into a single sequence,
140
- processes through bidirectional mLSTM blocks with FiLM modulation
141
- injected between blocks at regular intervals, then separates outputs.
142
 
143
- Template: 128x128 → 8x8 = 64 tokens
144
- Search: 256x256 → 16x16 = 256 tokens
145
- Total sequence: 320 tokens
 
 
 
 
 
 
 
 
 
146
 
147
  Bidirectional scanning: even blocks L→R, odd blocks R→L.
148
- Last `tmoe_blocks` blocks use TMoE MLP for temporal specialization.
149
- FiLM modulation: applied after every `film_interval`-th block.
150
  """
151
  def __init__(
152
  self,
@@ -213,52 +221,70 @@ class ViLBackbone(nn.Module):
213
  def forward(
214
  self,
215
  template: torch.Tensor,
216
- search: torch.Tensor,
217
  temporal_mod_manager=None,
218
  ) -> tuple:
219
  """
 
 
220
  Args:
221
  template: (B, 3, 128, 128) template image
222
- search: (B, 3, 256, 256) search region image
 
223
  temporal_mod_manager: optional TemporalModulationManager for FiLM
224
  Returns:
225
  template_feat: (B, 64, D) template features
226
- search_feat: (B, 256, D) search features
 
227
  """
228
  B = template.shape[0]
 
229
 
230
- # Patch embed
231
- t_tokens = self.patch_embed(template) # (B, 64, D)
232
- s_tokens = self.patch_embed(search) # (B, 256, D)
233
 
234
- # Add positional + type embeddings
 
 
 
235
  t_tokens = t_tokens + self.template_pos + self.template_type
236
- s_tokens = s_tokens + self.search_pos + self.search_type
237
 
238
- # Concatenate: [template | search]
239
- tokens = torch.cat([t_tokens, s_tokens], dim=1) # (B, 320, D)
240
- n_template = t_tokens.shape[1]
 
 
 
 
241
 
242
- # Process through bidirectional mLSTM blocks with optional FiLM
 
 
 
 
 
243
  for i, block in enumerate(self.blocks):
244
- reverse = (i % 2 == 1) # odd blocks: R→L
245
  tokens = block(tokens, reverse=reverse)
246
 
247
- # Apply FiLM temporal modulation between blocks
248
  if temporal_mod_manager is not None:
249
  tokens = temporal_mod_manager.modulate(tokens, i)
250
 
251
  tokens = self.norm(tokens)
252
 
253
- # Update temporal context after full forward pass
254
  if temporal_mod_manager is not None:
255
  temporal_mod_manager.update_temporal_context(tokens)
256
 
257
- # Split back
258
- template_feat = tokens[:, :n_template]
259
- search_feat = tokens[:, n_template:]
 
 
 
 
260
 
261
- return template_feat, search_feat
262
 
263
  def freeze_shared_experts(self):
264
  """Freeze shared experts in TMoE blocks for Phase 2 training."""
 
134
 
135
 
136
  class ViLBackbone(nn.Module):
137
+ """Vision-LSTM backbone for tracking with sequential multi-frame processing.
138
 
139
+ Processes template + K search frames as one long mLSTM sequence:
140
+ [template_tokens | search_1_tokens | search_2_tokens | ... | search_K_tokens]
 
141
 
142
+ The mLSTM memory state C carries information across frames:
143
+ - Template tokens establish the target appearance in memory
144
+ - Search_1 tokens are processed with template context in memory
145
+ - Search_2 tokens are processed with template + search_1 context, etc.
146
+
147
+ This is the core advantage over ViT: temporal information accumulates
148
+ in the recurrent memory state, not through attention over all tokens.
149
+
150
+ Token counts:
151
+ Template: 128x128 → 8x8 = 64 tokens
152
+ Each search: 256x256 → 16x16 = 256 tokens
153
+ K=3 sequence: 64 + 3×256 = 832 tokens
154
 
155
  Bidirectional scanning: even blocks L→R, odd blocks R→L.
156
+ FiLM modulation: applied between blocks at interval=6.
157
+ TMoE: last `tmoe_blocks` blocks.
158
  """
159
  def __init__(
160
  self,
 
221
  def forward(
222
  self,
223
  template: torch.Tensor,
224
+ searches: torch.Tensor,
225
  temporal_mod_manager=None,
226
  ) -> tuple:
227
  """
228
+ Process template + K search frames as one mLSTM sequence.
229
+
230
  Args:
231
  template: (B, 3, 128, 128) template image
232
+ searches: (B, K, 3, 256, 256) K consecutive search frames
233
+ OR (B, 3, 256, 256) single search frame (backward compat)
234
  temporal_mod_manager: optional TemporalModulationManager for FiLM
235
  Returns:
236
  template_feat: (B, 64, D) template features
237
+ search_feats: (B, K, 256, D) per-frame search features
238
+ OR (B, 256, D) if single search frame input
239
  """
240
  B = template.shape[0]
241
+ single_frame = (searches.ndim == 4) # (B, 3, H, W) vs (B, K, 3, H, W)
242
 
243
+ if single_frame:
244
+ searches = searches.unsqueeze(1) # (B, 1, 3, H, W)
 
245
 
246
+ K = searches.shape[1]
247
+
248
+ # Patch embed template
249
+ t_tokens = self.patch_embed(template) # (B, 64, D)
250
  t_tokens = t_tokens + self.template_pos + self.template_type
251
+ n_template = t_tokens.shape[1] # 64
252
 
253
+ # Patch embed all search frames
254
+ # Reshape (B, K, 3, H, W) (B*K, 3, H, W) for batch patch embedding
255
+ s_flat = searches.reshape(B * K, *searches.shape[2:])
256
+ s_tokens_flat = self.patch_embed(s_flat) # (B*K, 256, D)
257
+ s_tokens = s_tokens_flat.reshape(B, K, -1, self.dim) # (B, K, 256, D)
258
+ s_tokens = s_tokens + self.search_pos.unsqueeze(1) + self.search_type
259
+ n_search = s_tokens.shape[2] # 256
260
 
261
+ # Build full sequence: [template | search_1 | search_2 | ... | search_K]
262
+ # The mLSTM memory carries information across this entire sequence
263
+ s_tokens_concat = s_tokens.reshape(B, K * n_search, self.dim) # (B, K*256, D)
264
+ tokens = torch.cat([t_tokens, s_tokens_concat], dim=1) # (B, 64 + K*256, D)
265
+
266
+ # Process through bidirectional mLSTM blocks
267
  for i, block in enumerate(self.blocks):
268
+ reverse = (i % 2 == 1)
269
  tokens = block(tokens, reverse=reverse)
270
 
 
271
  if temporal_mod_manager is not None:
272
  tokens = temporal_mod_manager.modulate(tokens, i)
273
 
274
  tokens = self.norm(tokens)
275
 
 
276
  if temporal_mod_manager is not None:
277
  temporal_mod_manager.update_temporal_context(tokens)
278
 
279
+ # Split: template features + per-frame search features
280
+ template_feat = tokens[:, :n_template] # (B, 64, D)
281
+ search_tokens = tokens[:, n_template:] # (B, K*256, D)
282
+ search_feats = search_tokens.reshape(B, K, n_search, self.dim) # (B, K, 256, D)
283
+
284
+ if single_frame:
285
+ return template_feat, search_feats.squeeze(1) # (B, 256, D)
286
 
287
+ return template_feat, search_feats
288
 
289
  def freeze_shared_experts(self):
290
  """Freeze shared experts in TMoE blocks for Phase 2 training."""