Sequence training: pairs→K-frame clips, mLSTM memory carries across frames
Browse files - vil_tracker/models/backbone.py (+54 -28)
vil_tracker/models/backbone.py
CHANGED
|
@@ -134,19 +134,27 @@ class mLSTMBlockWithTMoE(nn.Module):
|
|
| 134 |
|
| 135 |
|
| 136 |
class ViLBackbone(nn.Module):
|
| 137 |
-
"""Vision-LSTM backbone for tracking with
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
injected between blocks at regular intervals, then separates outputs.
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
Bidirectional scanning: even blocks L→R, odd blocks R→L.
|
| 148 |
-
|
| 149 |
-
|
| 150 |
"""
|
| 151 |
def __init__(
|
| 152 |
self,
|
|
@@ -213,52 +221,70 @@ class ViLBackbone(nn.Module):
|
|
| 213 |
def forward(
|
| 214 |
self,
|
| 215 |
template: torch.Tensor,
|
| 216 |
-
|
| 217 |
temporal_mod_manager=None,
|
| 218 |
) -> tuple:
|
| 219 |
"""
|
|
|
|
|
|
|
| 220 |
Args:
|
| 221 |
template: (B, 3, 128, 128) template image
|
| 222 |
-
|
|
|
|
| 223 |
temporal_mod_manager: optional TemporalModulationManager for FiLM
|
| 224 |
Returns:
|
| 225 |
template_feat: (B, 64, D) template features
|
| 226 |
-
|
|
|
|
| 227 |
"""
|
| 228 |
B = template.shape[0]
|
|
|
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
s_tokens = self.patch_embed(search) # (B, 256, D)
|
| 233 |
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
| 235 |
t_tokens = t_tokens + self.template_pos + self.template_type
|
| 236 |
-
|
| 237 |
|
| 238 |
-
#
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
for i, block in enumerate(self.blocks):
|
| 244 |
-
reverse = (i % 2 == 1)
|
| 245 |
tokens = block(tokens, reverse=reverse)
|
| 246 |
|
| 247 |
-
# Apply FiLM temporal modulation between blocks
|
| 248 |
if temporal_mod_manager is not None:
|
| 249 |
tokens = temporal_mod_manager.modulate(tokens, i)
|
| 250 |
|
| 251 |
tokens = self.norm(tokens)
|
| 252 |
|
| 253 |
-
# Update temporal context after full forward pass
|
| 254 |
if temporal_mod_manager is not None:
|
| 255 |
temporal_mod_manager.update_temporal_context(tokens)
|
| 256 |
|
| 257 |
-
# Split
|
| 258 |
-
template_feat = tokens[:, :n_template]
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
-
return template_feat,
|
| 262 |
|
| 263 |
def freeze_shared_experts(self):
|
| 264 |
"""Freeze shared experts in TMoE blocks for Phase 2 training."""
|
|
|
|
| 134 |
|
| 135 |
|
| 136 |
class ViLBackbone(nn.Module):
|
| 137 |
+
"""Vision-LSTM backbone for tracking with sequential multi-frame processing.
|
| 138 |
|
| 139 |
+
Processes template + K search frames as one long mLSTM sequence:
|
| 140 |
+
[template_tokens | search_1_tokens | search_2_tokens | ... | search_K_tokens]
|
|
|
|
| 141 |
|
| 142 |
+
The mLSTM memory state C carries information across frames:
|
| 143 |
+
- Template tokens establish the target appearance in memory
|
| 144 |
+
- Search_1 tokens are processed with template context in memory
|
| 145 |
+
- Search_2 tokens are processed with template + search_1 context, etc.
|
| 146 |
+
|
| 147 |
+
This is the core advantage over ViT: temporal information accumulates
|
| 148 |
+
in the recurrent memory state, not through attention over all tokens.
|
| 149 |
+
|
| 150 |
+
Token counts:
|
| 151 |
+
Template: 128x128 → 8x8 = 64 tokens
|
| 152 |
+
Each search: 256x256 → 16x16 = 256 tokens
|
| 153 |
+
K=3 sequence: 64 + 3×256 = 832 tokens
|
| 154 |
|
| 155 |
Bidirectional scanning: even blocks L→R, odd blocks R→L.
|
| 156 |
+
FiLM modulation: applied between blocks at interval=6.
|
| 157 |
+
TMoE: last `tmoe_blocks` blocks.
|
| 158 |
"""
|
| 159 |
def __init__(
|
| 160 |
self,
|
|
|
|
| 221 |
def forward(
    self,
    template: torch.Tensor,
    searches: torch.Tensor,
    temporal_mod_manager=None,
) -> tuple:
    """Process template + K search frames as one mLSTM sequence.

    The template tokens and the tokens of every search frame are
    concatenated along the sequence axis as
    ``[template | search_1 | ... | search_K]``, run through ``self.blocks``
    and ``self.norm``, then split back into template and per-frame parts.

    NOTE(review): odd-indexed blocks are called with ``reverse=True``. If
    that flag scans the sequence right-to-left, earlier frames can receive
    information from later frames, so per-frame outputs are not strictly
    causal. Confirm against the block implementation.

    Args:
        template: (B, 3, 128, 128) template image.
        searches: (B, K, 3, 256, 256) K consecutive search frames,
            OR (B, 3, 256, 256) single search frame (backward compat).
        temporal_mod_manager: optional TemporalModulationManager for FiLM.
            When given, ``modulate(tokens, i)`` is applied after block ``i``
            and ``update_temporal_context(tokens)`` is called once on the
            normalised full sequence.

    Returns:
        template_feat: (B, 64, D) template features.
        search_feats: (B, K, 256, D) per-frame search features,
            OR (B, 256, D) if a single search frame was passed.

    Raises:
        ValueError: if ``searches`` is neither 4-D nor 5-D, or if its batch
            size differs from that of ``template``.
    """
    # Fail early with a readable message instead of an opaque reshape error.
    if searches.ndim not in (4, 5):
        raise ValueError(
            "searches must have shape (B, 3, H, W) or (B, K, 3, H, W), "
            f"got {tuple(searches.shape)}"
        )

    B = template.shape[0]
    if searches.shape[0] != B:
        raise ValueError(
            "template and searches must share the batch dimension, "
            f"got {B} and {searches.shape[0]}"
        )

    single_frame = searches.ndim == 4  # (B, 3, H, W) vs (B, K, 3, H, W)
    if single_frame:
        searches = searches.unsqueeze(1)  # (B, 1, 3, H, W)
    K = searches.shape[1]

    # Patch embed template
    t_tokens = self.patch_embed(template)  # (B, 64, D)
    t_tokens = t_tokens + self.template_pos + self.template_type
    n_template = t_tokens.shape[1]  # 64

    # Patch embed all search frames in one call by folding K into the batch:
    # (B, K, 3, H, W) -> (B*K, 3, H, W)
    s_flat = searches.reshape(B * K, *searches.shape[2:])
    s_tokens_flat = self.patch_embed(s_flat)  # (B*K, 256, D)
    s_tokens = s_tokens_flat.reshape(B, K, -1, self.dim)  # (B, K, 256, D)
    # search_pos gets an extra axis so it broadcasts over the K frames.
    s_tokens = s_tokens + self.search_pos.unsqueeze(1) + self.search_type
    n_search = s_tokens.shape[2]  # 256

    # Build full sequence: [template | search_1 | search_2 | ... | search_K]
    s_tokens_concat = s_tokens.reshape(B, K * n_search, self.dim)  # (B, K*256, D)
    tokens = torch.cat([t_tokens, s_tokens_concat], dim=1)  # (B, 64 + K*256, D)

    # Alternate scan direction: even-indexed blocks get reverse=False,
    # odd-indexed blocks get reverse=True.
    for i, block in enumerate(self.blocks):
        tokens = block(tokens, reverse=(i % 2 == 1))

        # Optional FiLM temporal modulation after each block
        if temporal_mod_manager is not None:
            tokens = temporal_mod_manager.modulate(tokens, i)

    tokens = self.norm(tokens)

    # The manager sees the normalised full sequence (template + all frames)
    if temporal_mod_manager is not None:
        temporal_mod_manager.update_temporal_context(tokens)

    # Split: template features + per-frame search features
    template_feat = tokens[:, :n_template]  # (B, 64, D)
    search_tokens = tokens[:, n_template:]  # (B, K*256, D)
    search_feats = search_tokens.reshape(B, K, n_search, self.dim)  # (B, K, 256, D)

    if single_frame:
        return template_feat, search_feats.squeeze(1)  # (B, 256, D)

    return template_feat, search_feats
|
| 288 |
|
| 289 |
def freeze_shared_experts(self):
|
| 290 |
"""Freeze shared experts in TMoE blocks for Phase 2 training."""
|