# vil_tracker/models/backbone.py
"""
ViL (Vision-LSTM) backbone for single-object tracking.
Architecture:
- Patch embedding (Conv2d) for template + search region
- Stack of mLSTM blocks with bidirectional scanning (even=L→R, odd=R→L)
- FiLM temporal modulation integrated BETWEEN blocks (at interval=6)
- Optional TMoE-MLP in last N blocks (dense routing, frozen shared expert)
- Outputs concatenated template+search features for head processing
ViL-S config: dim=384, depth=24, patch_size=16, ~23M backbone params
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .mlstm import mLSTMBlock, mLSTMCell, SwiGLUMLP, StochasticDepth
class PatchEmbed(nn.Module):
"""Convert image patches to token embeddings using Conv2d."""
def __init__(self, patch_size: int = 16, in_channels: int = 3, dim: int = 384):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
self.norm = nn.LayerNorm(dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, C, H, W) image tensor
Returns:
(B, N, D) patch token embeddings, N = (H/P)*(W/P)
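
        Example (shapes only; a minimal sketch with the default dim=384):
            >>> pe = PatchEmbed(patch_size=16, dim=384)
            >>> pe(torch.randn(2, 3, 128, 128)).shape
            torch.Size([2, 64, 384])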
"""
x = self.proj(x) # (B, D, H/P, W/P)
x = rearrange(x, 'b d h w -> b (h w) d')
x = self.norm(x)
return x
class TMoEMLP(nn.Module):
"""Temporal Mixture-of-Experts MLP.
Uses dense routing with a shared expert (frozen after Phase 1) and
K specialized experts. Output = shared_out + sum(gate_k * expert_k_out).
For tracking: experts specialize on different temporal dynamics
(fast motion, occlusion recovery, scale change).
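
    Example (shapes only; assumes the sibling ``.mlstm`` module providing
    SwiGLUMLP is importable):
        >>> moe = TMoEMLP(dim=384, num_experts=4)
        >>> moe(torch.randn(2, 832, 384)).shape
        torch.Size([2, 832, 384])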
"""
def __init__(
self,
dim: int = 384,
mlp_ratio: float = 4.0,
num_experts: int = 4,
bias: bool = False,
):
super().__init__()
self.num_experts = num_experts
# Shared expert (frozen after Phase 1 training)
self.shared_expert = SwiGLUMLP(dim=dim, mlp_ratio=mlp_ratio, bias=bias)
# Specialized experts (smaller: mlp_ratio/2)
small_ratio = mlp_ratio / 2
self.experts = nn.ModuleList([
SwiGLUMLP(dim=dim, mlp_ratio=small_ratio, bias=bias)
for _ in range(num_experts)
])
# Dense router: soft gating over experts
self.router = nn.Linear(dim, num_experts, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Shared expert output (always contributes)
shared_out = self.shared_expert(x)
# Router logits and softmax gates
gates = F.softmax(self.router(x), dim=-1) # (B, S, num_experts)
# Expert outputs, weighted by gates
expert_out = torch.zeros_like(shared_out)
for i, expert in enumerate(self.experts):
expert_out = expert_out + gates[..., i:i+1] * expert(x)
return shared_out + expert_out
def freeze_shared_expert(self):
"""Freeze the shared expert for Phase 2 training."""
for p in self.shared_expert.parameters():
p.requires_grad = False
class mLSTMBlockWithTMoE(nn.Module):
"""mLSTM block with TMoE MLP instead of standard SwiGLU MLP."""
def __init__(
self,
dim: int = 384,
proj_factor: float = 2.0,
qkv_proj_blocksize: int = 4,
num_heads: int = 4,
conv_kernel: int = 4,
mlp_ratio: float = 4.0,
drop_path: float = 0.0,
num_experts: int = 4,
bias: bool = False,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim, bias=False)
self.mlstm = mLSTMCell(
dim=dim,
proj_factor=proj_factor,
qkv_proj_blocksize=qkv_proj_blocksize,
num_heads=num_heads,
conv_kernel=conv_kernel,
bias=bias,
)
self.norm2 = nn.LayerNorm(dim, bias=False)
self.mlp = TMoEMLP(dim=dim, mlp_ratio=mlp_ratio, num_experts=num_experts, bias=bias)
self.drop_path = StochasticDepth(drop_path)
def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
x = x + self.drop_path(self.mlstm(self.norm1(x), reverse=reverse))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def freeze_shared_expert(self):
self.mlp.freeze_shared_expert()
class ViLBackbone(nn.Module):
"""Vision-LSTM backbone for tracking with sequential multi-frame processing.
Processes template + K search frames as one long mLSTM sequence:
[template_tokens | search_1_tokens | search_2_tokens | ... | search_K_tokens]
The mLSTM memory state C carries information across frames:
- Template tokens establish the target appearance in memory
- Search_1 tokens are processed with template context in memory
- Search_2 tokens are processed with template + search_1 context, etc.
This is the core advantage over ViT: temporal information accumulates
in the recurrent memory state, not through attention over all tokens.
Token counts:
Template: 128x128 → 8x8 = 64 tokens
Each search: 256x256 → 16x16 = 256 tokens
K=3 sequence: 64 + 3×256 = 832 tokens
Bidirectional scanning: even blocks L→R, odd blocks R→L.
FiLM modulation: applied between blocks at interval=6.
TMoE: last `tmoe_blocks` blocks.
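
    Example (shapes only; a sketch with the default ViL-S config):
        >>> backbone = ViLBackbone()
        >>> z = torch.randn(2, 3, 128, 128)       # template
        >>> x = torch.randn(2, 3, 3, 256, 256)    # K=3 clip of search frames
        >>> t, s = backbone(z, x)
        >>> t.shape, s.shape
        (torch.Size([2, 64, 384]), torch.Size([2, 3, 256, 384]))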
"""
def __init__(
self,
dim: int = 384,
depth: int = 24,
patch_size: int = 16,
in_channels: int = 3,
proj_factor: float = 2.0,
qkv_proj_blocksize: int = 4,
num_heads: int = 4,
conv_kernel: int = 4,
mlp_ratio: float = 4.0,
drop_path_rate: float = 0.05,
tmoe_blocks: int = 2,
num_experts: int = 4,
bias: bool = False,
film_interval: int = 6,
):
super().__init__()
self.dim = dim
self.depth = depth
self.patch_size = patch_size
self.film_interval = film_interval
# Patch embedding
self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, dim=dim)
# Positional embeddings for template and search regions
# Template: 128/16 = 8x8 = 64 tokens
# Search: 256/16 = 16x16 = 256 tokens
self.template_pos = nn.Parameter(torch.randn(1, 64, dim) * 0.02)
self.search_pos = nn.Parameter(torch.randn(1, 256, dim) * 0.02)
# Token type embeddings (template vs search)
self.template_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
self.search_type = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
# Stochastic depth rates (linearly increasing)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# Build blocks: last `tmoe_blocks` use TMoE MLP
self.blocks = nn.ModuleList()
for i in range(depth):
if i >= depth - tmoe_blocks:
block = mLSTMBlockWithTMoE(
dim=dim, proj_factor=proj_factor,
qkv_proj_blocksize=qkv_proj_blocksize,
num_heads=num_heads, conv_kernel=conv_kernel,
mlp_ratio=mlp_ratio, drop_path=dpr[i],
num_experts=num_experts, bias=bias,
)
else:
block = mLSTMBlock(
dim=dim, proj_factor=proj_factor,
qkv_proj_blocksize=qkv_proj_blocksize,
num_heads=num_heads, conv_kernel=conv_kernel,
mlp_ratio=mlp_ratio, drop_path=dpr[i], bias=bias,
)
self.blocks.append(block)
# Final norm
self.norm = nn.LayerNorm(dim, bias=False)
def forward(
self,
template: torch.Tensor,
searches: torch.Tensor,
temporal_mod_manager=None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
"""
Process template + K search frames as one mLSTM sequence.
Args:
template: (B, 3, 128, 128) template image
searches: (B, K, 3, 256, 256) K consecutive search frames
OR (B, 3, 256, 256) single search frame (backward compat)
temporal_mod_manager: optional TemporalModulationManager for FiLM
Returns:
template_feat: (B, 64, D) template features
search_feats: (B, K, 256, D) per-frame search features
OR (B, 256, D) if single search frame input
"""
B = template.shape[0]
single_frame = (searches.ndim == 4) # (B, 3, H, W) vs (B, K, 3, H, W)
if single_frame:
searches = searches.unsqueeze(1) # (B, 1, 3, H, W)
K = searches.shape[1]
# Patch embed template
t_tokens = self.patch_embed(template) # (B, 64, D)
t_tokens = t_tokens + self.template_pos + self.template_type
n_template = t_tokens.shape[1] # 64
# Patch embed all search frames
# Reshape (B, K, 3, H, W) → (B*K, 3, H, W) for batch patch embedding
s_flat = searches.reshape(B * K, *searches.shape[2:])
s_tokens_flat = self.patch_embed(s_flat) # (B*K, 256, D)
s_tokens = s_tokens_flat.reshape(B, K, -1, self.dim) # (B, K, 256, D)
s_tokens = s_tokens + self.search_pos.unsqueeze(1) + self.search_type
n_search = s_tokens.shape[2] # 256
# Build full sequence: [template | search_1 | search_2 | ... | search_K]
# The mLSTM memory carries information across this entire sequence
s_tokens_concat = s_tokens.reshape(B, K * n_search, self.dim) # (B, K*256, D)
tokens = torch.cat([t_tokens, s_tokens_concat], dim=1) # (B, 64 + K*256, D)
# Process through bidirectional mLSTM blocks
for i, block in enumerate(self.blocks):
reverse = (i % 2 == 1)
tokens = block(tokens, reverse=reverse)
            if temporal_mod_manager is not None:
                # The manager receives the block index and is expected to apply
                # FiLM only at its configured interval (see `film_interval`);
                # other indices pass tokens through unchanged.
                tokens = temporal_mod_manager.modulate(tokens, i)
tokens = self.norm(tokens)
if temporal_mod_manager is not None:
temporal_mod_manager.update_temporal_context(tokens)
# Split: template features + per-frame search features
template_feat = tokens[:, :n_template] # (B, 64, D)
search_tokens = tokens[:, n_template:] # (B, K*256, D)
search_feats = search_tokens.reshape(B, K, n_search, self.dim) # (B, K, 256, D)
if single_frame:
return template_feat, search_feats.squeeze(1) # (B, 256, D)
return template_feat, search_feats
def freeze_shared_experts(self):
"""Freeze shared experts in TMoE blocks for Phase 2 training."""
for block in self.blocks:
if hasattr(block, 'freeze_shared_expert'):
block.freeze_shared_expert()
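

# Minimal smoke test (a sketch, not part of the training pipeline): builds the
# default ViL-S backbone, runs a K=2 clip plus a single-frame call, and shows
# the Phase 2 step of freezing the TMoE shared experts. Requires the sibling
# .mlstm module; run as `python -m vil_tracker.models.backbone`.
if __name__ == "__main__":
    backbone = ViLBackbone()
    template = torch.randn(1, 3, 128, 128)
    clip = torch.randn(1, 2, 3, 256, 256)       # K=2 search frames
    t_feat, s_feats = backbone(template, clip)
    print(t_feat.shape, s_feats.shape)          # (1, 64, 384) (1, 2, 256, 384)

    single = torch.randn(1, 3, 256, 256)        # backward-compat single frame
    t_feat, s_feat = backbone(template, single)
    print(s_feat.shape)                         # (1, 256, 384)

    # Phase 2: freeze the shared experts inside the TMoE blocks
    backbone.freeze_shared_experts()
    frozen = sum(1 for p in backbone.parameters() if not p.requires_grad)
    print(f"frozen parameter tensors: {frozen}")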