# vil-tracker / test_all.py
# omar-ah's picture
# Sequence training: pairs→K-frame clips, mLSTM memory carries across frames
# 4ba026e verified
"""
Comprehensive test suite for ViL Tracker.
16 tests covering all components:
1. mLSTM Cell (LinearHeadwiseExpand correctness + param count)
2. mLSTM Block (full block without MLP)
3. TMoE MLP
4. Backbone (standard, small depth)
5. Backbone (with TMoE + integrated FiLM, medium depth)
6. Prediction Heads
7. FiLM Temporal Modulation
8. Full Tracker (small depth for speed)
9. Loss Functions (all 6)
10. Kalman Filter (8-state, adaptive)
11. Dataset (synthetic)
12. Training Step (mini forward + backward with temporal)
13. Model Summary (FULL depth=24, constraint check)
14. Online Tracker (full inference pipeline)
15. Augmentation pipeline
16. ACL curriculum integration
"""
import sys
import time
import torch
import numpy as np
# Seed both RNGs with a fixed value so random inputs repeat from run to run.
np.random.seed(42)
torch.manual_seed(42)
# Running tallies of passed / failed tests; mutated by test().
PASS = FAIL = 0
def test(name, fn):
    """Run one test callable and record its outcome in the global tallies.

    Prints a numbered header for ``name``, calls ``fn()`` with no arguments,
    and increments ``PASS`` on a clean return. Any ``Exception`` increments
    ``FAIL`` instead, is reported with its traceback, and is not re-raised,
    so the tests that follow still run.
    """
    import traceback

    global PASS, FAIL
    ordinal = PASS + FAIL + 1
    print(f"\nTest {ordinal}: {name}...", flush=True)
    try:
        fn()
        PASS += 1
        print(" ✅ PASSED")
    except Exception as err:
        FAIL += 1
        print(f" ❌ FAILED: {err}")
        traceback.print_exc()
def count_params(model):
    """Return the total element count over everything ``model.parameters()`` yields."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
# ============================================================
# Test 1: mLSTM Cell
# ============================================================
def test_mlstm_cell():
    """Check LinearHeadwiseExpand and mLSTMCell: parameter counts, shapes, direction."""
    from vil_tracker.models.mlstm import LinearHeadwiseExpand, mLSTMCell

    # LinearHeadwiseExpand: exact parameter count and shape-preserving forward.
    expand = LinearHeadwiseExpand(768, num_heads=192, bias=False)
    n_expand = count_params(expand)
    assert n_expand == 192 * 4 * 4, f"LHE params: {n_expand} != {192*4*4}"
    expand_out = expand(torch.randn(2, 10, 768))
    assert expand_out.shape == (2, 10, 768), f"LHE output shape: {expand_out.shape}"

    # Full mLSTM cell: parameter count must fall strictly between 800K and 1M.
    cell = mLSTMCell(dim=384, proj_factor=2.0, qkv_proj_blocksize=4, num_heads=4)
    n_cell = count_params(cell)
    print(f" mLSTMCell params: {n_cell:,} ({n_cell/1e6:.3f}M)")
    assert n_cell < 1_000_000, f"Cell has {n_cell:,} params (should be <1M)"
    assert n_cell > 800_000, f"Cell has {n_cell:,} params (should be >800K)"

    # The output norm is expected to use 192 groups (num_proj_heads), not 4 (num_heads).
    norm_groups = cell.outnorm.num_groups
    assert norm_groups == 192, f"GroupNorm should have 192 groups, got {norm_groups}"
    print(f" GroupNorm groups: {norm_groups} (correct: per-projection-head)")

    seq = torch.randn(2, 20, 384)
    expected_shape = (2, 20, 384)
    fwd = cell(seq)
    assert fwd.shape == expected_shape, f"Cell output shape: {fwd.shape}"

    # Reverse mode keeps the shape but must not reproduce the forward output.
    bwd = cell(seq, reverse=True)
    assert bwd.shape == expected_shape, f"Reverse output shape: {bwd.shape}"
    assert not torch.allclose(fwd, bwd, atol=1e-3), "Forward and reverse should differ"


test("mLSTM Cell (LinearHeadwiseExpand)", test_mlstm_cell)
# ============================================================
# Test 2: mLSTM Block
# ============================================================
def test_mlstm_block():
    """Check the mLSTMBlock parameter budget and its shape-preserving forward pass."""
    from vil_tracker.models.mlstm import mLSTMBlock

    block_kwargs = dict(
        dim=384,
        proj_factor=2.0,
        qkv_proj_blocksize=4,
        num_heads=4,
        mlp_ratio=4.0,
    )
    block = mLSTMBlock(**block_kwargs)
    n_params = count_params(block)
    print(f" mLSTMBlock params: {n_params:,} ({n_params/1e6:.3f}M)")
    # The budget below only holds if the block carries no separate MLP.
    assert n_params < 1_050_000, f"Block has {n_params:,} params (should be <1.05M without MLP)"

    inp = torch.randn(2, 20, 384)
    out = block(inp)
    assert out.shape == (2, 20, 384), f"Block output shape: {out.shape}"

    # Reported for inspection only; no threshold is enforced on the residual gap.
    residual_gap = torch.mean(torch.abs(out - inp)).item()
    print(f" Residual diff from input: {residual_gap:.4f}")


test("mLSTM Block (no separate MLP)", test_mlstm_block)
# ============================================================
# Test 3: TMoE MLP
# ============================================================
def test_tmoe():
    """Check TMoEMLP output shape and that freeze_shared_expert() freezes every shared parameter."""
    from vil_tracker.models.backbone import TMoEMLP

    moe = TMoEMLP(dim=384, mlp_ratio=4.0, num_experts=4)
    n_params = count_params(moe)
    print(f" TMoEMLP params: {n_params:,} ({n_params/1e6:.3f}M)")

    tokens = torch.randn(2, 20, 384)
    out = moe(tokens)
    assert out.shape == (2, 20, 384), f"TMoE output shape: {out.shape}"

    # After freezing, no shared-expert parameter may still require gradients.
    moe.freeze_shared_expert()
    shared_params = list(moe.shared_expert.parameters())
    frozen = sum(1 for p in shared_params if not p.requires_grad)
    total_shared = len(shared_params)
    assert frozen == total_shared, "Shared expert should be fully frozen"


test("TMoE MLP", test_tmoe)
# ============================================================
# Test 4: Backbone (standard, small depth)
# ============================================================
def test_backbone_small():
    """Run a depth-4 ViLBackbone without TMoE blocks and check both feature shapes."""
    from vil_tracker.models.backbone import ViLBackbone

    net = ViLBackbone(dim=384, depth=4, patch_size=16, tmoe_blocks=0)
    n_params = count_params(net)
    print(f" Backbone (depth=4, no TMoE) params: {n_params:,} ({n_params/1e6:.3f}M)")

    template_img = torch.randn(2, 3, 128, 128)
    search_img = torch.randn(2, 3, 256, 256)
    t_feat, s_feat = net(template_img, search_img)

    # Expected token counts: (128 / 16) ** 2 = 64 and (256 / 16) ** 2 = 256.
    assert t_feat.shape == (2, 64, 384), f"Template feat shape: {t_feat.shape}"
    assert s_feat.shape == (2, 256, 384), f"Search feat shape: {s_feat.shape}"


test("Backbone (standard, depth=4)", test_backbone_small)
# ============================================================
# Test 5: Backbone with TMoE + integrated FiLM
# ============================================================
def test_backbone_tmoe_film():
    """Run a depth-6 backbone with TMoE blocks twice through a shared temporal manager."""
    from vil_tracker.models.backbone import ViLBackbone
    from vil_tracker.models.film_temporal import TemporalModulationManager

    net = ViLBackbone(
        dim=384,
        depth=6,
        patch_size=16,
        tmoe_blocks=2,
        num_experts=4,
        film_interval=3,
    )
    n_params = count_params(net)
    print(f" Backbone (depth=6, TMoE=2) params: {n_params:,} ({n_params/1e6:.3f}M)")

    manager = TemporalModulationManager(dim=384, num_blocks=6, modulation_interval=3)
    template_img = torch.randn(1, 3, 128, 128)
    search_img = torch.randn(1, 3, 256, 256)

    # Pass 1: the manager has not been given any temporal context yet.
    t_first, s_first = net(template_img, search_img, temporal_mod_manager=manager)
    assert t_first.shape == (1, 64, 384), f"Template feat shape: {t_first.shape}"
    assert s_first.shape == (1, 256, 384), f"Search feat shape: {s_first.shape}"

    # Pass 2: same inputs, same manager. Only the shape is asserted; whether the
    # features changed is printed, not enforced.
    t_second, _s_second = net(template_img, search_img, temporal_mod_manager=manager)
    assert t_second.shape == (1, 64, 384)
    features_differ = not torch.allclose(t_first, t_second, atol=1e-5)
    print(f" FiLM modulation active: features differ = {features_differ}")


test("Backbone (TMoE + integrated FiLM)", test_backbone_tmoe_film)
# ============================================================
# Test 6: Prediction Heads
# ============================================================
def test_heads():
    """Check CenterHead / UncertaintyHead output shapes and box decoding with and without a Hanning window."""
    from vil_tracker.models.heads import (
        CenterHead,
        UncertaintyHead,
        create_hanning_window,
        decode_predictions,
    )

    center_head = CenterHead(dim=384, feat_size=16)
    unc_head = UncertaintyHead(dim=384, feat_size=16)
    print(f" CenterHead params: {count_params(center_head):,}")
    print(f" UncertaintyHead params: {count_params(unc_head):,}")

    search_feat = torch.randn(2, 256, 384)
    preds = center_head(search_feat)
    heatmap, size_map, offset_map = preds['heatmap'], preds['size'], preds['offset']
    assert heatmap.shape == (2, 1, 16, 16), f"Heatmap shape: {heatmap.shape}"
    assert size_map.shape == (2, 2, 16, 16), f"Size shape: {size_map.shape}"
    assert offset_map.shape == (2, 2, 16, 16), f"Offset shape: {offset_map.shape}"

    # Plain decode: one box and one score per batch element.
    boxes, scores = decode_predictions(heatmap, size_map, offset_map)
    assert boxes.shape == (2, 4), f"Boxes shape: {boxes.shape}"
    assert scores.shape == (2,), f"Scores shape: {scores.shape}"

    # Hanning window: near 1 at the centre cell, near 0 at the corner.
    hann = create_hanning_window(16)
    assert hann.shape == (16, 16), f"Hanning shape: {hann.shape}"
    centre, corner = hann[8, 8], hann[0, 0]
    assert abs(centre.item() - 1.0) < 0.05, f"Hanning center should be ~1.0, got {centre}"
    assert corner.item() < 0.01, f"Hanning corner should be ~0, got {corner}"

    boxes_h, scores_h = decode_predictions(
        heatmap, size_map, offset_map, hanning_window=hann
    )
    assert boxes_h.shape == (2, 4), f"Hanning boxes shape: {boxes_h.shape}"
    print(f" Hanning window: center={centre:.3f}, corner={corner:.6f}")
    print(f" Without Hanning: box={boxes[0].tolist()}, score={scores[0].item():.4f}")
    print(f" With Hanning: box={boxes_h[0].tolist()}, score={scores_h[0].item():.4f}")

    # Uncertainty head yields a single-channel map on the same 16x16 grid.
    log_var = unc_head(search_feat)
    assert log_var.shape == (2, 1, 16, 16), f"Log variance shape: {log_var.shape}"


test("Prediction Heads", test_heads)
# ============================================================
# Test 7: FiLM Temporal Modulation
# ============================================================
def test_film():
    """Check the FiLM calibrator, modulation layer and TemporalModulationManager lifecycle."""
    from vil_tracker.models.film_temporal import (
        FiLMTemporalModulation,
        TemporalModulationManager,
        TemporalReliabilityCalibrator,
    )

    # Stand-alone components.
    calibrator = TemporalReliabilityCalibrator(384)
    film_layer = FiLMTemporalModulation(384)
    x = torch.randn(2, 20, 384)
    context = torch.randn(2, 20, 384)

    reliability = calibrator(context)
    assert reliability.shape == (2, 20, 1), f"Reliability shape: {reliability.shape}"
    assert (reliability >= 0).all() and (reliability <= 1).all(), "Reliability not in [0,1]"

    modulated = film_layer(x, context, reliability)
    assert modulated.shape == (2, 20, 384), f"Modulated shape: {modulated.shape}"

    # Manager lifecycle: no context -> identity; context -> modulate; reset -> identity.
    manager = TemporalModulationManager(dim=384, num_blocks=24, modulation_interval=6)
    print(f" TemporalModulationManager params: {count_params(manager):,}")

    untouched = manager.modulate(x, block_idx=5)
    assert torch.allclose(untouched, x), "Should return unchanged without temporal context"

    manager.update_temporal_context(x)
    # block_idx=5 satisfies (5 + 1) % 6 == 0; only the output shape is asserted here.
    with_context = manager.modulate(x, block_idx=5)
    assert with_context.shape == (2, 20, 384)

    manager.reset()
    after_reset = manager.modulate(x, block_idx=5)
    assert torch.allclose(after_reset, x), "After reset, should return unchanged"


test("FiLM Temporal Modulation", test_film)
# ============================================================
# Test 8: Full Tracker (small depth for speed)
# ============================================================
def test_full_tracker_small():
    """Run a depth-4 ViLTracker on a single search frame and on a K-frame clip."""
    from vil_tracker.models.tracker import ViLTracker, get_default_config

    config = get_default_config()
    # Shrink the model so the test stays fast.
    config.update({'depth': 4, 'tmoe_blocks': 1, 'film_interval': 2}) if isinstance(config, dict) else None
    if not isinstance(config, dict):
        config['depth'] = 4
        config['tmoe_blocks'] = 1
        config['film_interval'] = 2
    tracker = ViLTracker(config)
    n_params = count_params(tracker)
    print(f" Tracker (depth=4) params: {n_params:,} ({n_params/1e6:.3f}M)")

    B, K = 2, 3
    template = torch.randn(B, 3, 128, 128)

    # Single search frame, temporal path disabled.
    search_single = torch.randn(B, 3, 256, 256)
    single = tracker(template, search_single, use_temporal=False)
    assert single['heatmap'].shape == (B, 1, 16, 16), f"Single heatmap: {single['heatmap'].shape}"
    assert single['boxes'].shape == (B, 4), f"Single boxes: {single['boxes'].shape}"
    assert single['scores'].shape == (B,), f"Single scores: {single['scores'].shape}"
    print(f" Single-frame: boxes={single['boxes'][0].tolist()}")

    # K-frame clip, temporal path enabled: every output gains a K axis after batch.
    searches = torch.randn(B, K, 3, 256, 256)
    multi = tracker(template, searches, use_temporal=True)
    assert multi['heatmap'].shape == (B, K, 1, 16, 16), f"Multi heatmap: {multi['heatmap'].shape}"
    assert multi['boxes'].shape == (B, K, 4), f"Multi boxes: {multi['boxes'].shape}"
    assert multi['scores'].shape == (B, K), f"Multi scores: {multi['scores'].shape}"
    assert multi['search_feats'].shape == (B, K, 256, 384), f"Multi feats: {multi['search_feats'].shape}"
    print(f" Multi-frame (K={K}): frame 0 box={multi['boxes'][0,0].tolist()}")
    print(f" frame 2 box={multi['boxes'][0,2].tolist()}")

    tracker.reset_temporal()


test("Full Tracker (single + multi-frame)", test_full_tracker_small)
# ============================================================
# Test 9: Loss Functions (all 6)
# ============================================================
def test_losses():
    """Evaluate each loss module once on small synthetic inputs and sanity-check the value."""
    from vil_tracker.training.losses import (
        ADWLoss,
        AFKDDistillationLoss,
        CombinedTrackingLoss,
        FocalLoss,
        GIoULoss,
        MemoryContrastiveLoss,
        UncertaintyNLLLoss,
    )

    B = 4

    # Focal loss against a heatmap with a single positive cell at (8, 8).
    focal = FocalLoss()
    pred_hm = torch.randn(B, 1, 16, 16)
    gt_hm = torch.zeros(B, 1, 16, 16)
    gt_hm[:, :, 8, 8] = 1.0
    focal_val = focal(pred_hm, gt_hm)
    print(f" Focal loss: {focal_val.item():.4f}")
    assert focal_val.item() > 0, "Focal loss should be positive"

    # GIoU loss on two nearly coincident boxes; value must lie in [0, 2].
    giou = GIoULoss()
    pred_box = torch.tensor([[128.0, 128.0, 50.0, 50.0]] * B)
    gt_box = torch.tensor([[130.0, 130.0, 48.0, 48.0]] * B)
    giou_val = giou(pred_box, gt_box)
    print(f" GIoU loss: {giou_val.item():.4f}")
    assert 0 <= giou_val.item() <= 2, f"GIoU loss out of range: {giou_val.item()}"

    # Uncertainty NLL with log-variance 0 (unit variance).
    unc = UncertaintyNLLLoss()
    pred_v = torch.randn(B, 4)
    target_v = torch.randn(B, 4)
    unit_log_var = torch.zeros(B, 4)
    nll_val = unc(pred_v, target_v, unit_log_var)
    print(f" Uncertainty NLL loss: {nll_val.item():.4f}")
    assert nll_val.item() > 0

    # Contrastive loss between features and a lightly perturbed copy (printed only).
    contrastive = MemoryContrastiveLoss()
    feat_a = torch.randn(B, 384)
    feat_b = feat_a + torch.randn(B, 384) * 0.1
    contrastive_val = contrastive(feat_a, feat_b)
    print(f" Contrastive loss: {contrastive_val.item():.4f}")

    # AFKD distillation from 768-dim teacher features to 384-dim student features.
    afkd = AFKDDistillationLoss(student_dim=384, teacher_dim=768)
    student_feat = torch.randn(B, 256, 384)
    teacher_feat = torch.randn(B, 256, 768)
    distill_val = afkd(student_feat, teacher_feat)
    print(f" AFKD distillation loss: {distill_val.item():.4f}")
    assert distill_val.item() > 0

    # ADW weighting over three scalar task losses (printed only).
    adw = ADWLoss(num_tasks=3)
    task_losses = [torch.tensor(1.0), torch.tensor(0.5), torch.tensor(2.0)]
    adw_val = adw(task_losses)
    print(f" ADW loss: {adw_val.item():.4f}")

    # Combined loss on a prediction dict assembled from the tensors above.
    combined = CombinedTrackingLoss()
    pred = {
        'heatmap': pred_hm,
        'size': torch.rand(B, 2, 16, 16),
        'boxes': pred_box,
        'log_variance': torch.randn(B, 1, 16, 16),
    }
    gt_size = torch.tensor([[0.2, 0.2]] * B)
    loss_dict = combined(pred, gt_hm, gt_size, gt_box)
    print(f" Combined loss: {loss_dict['total'].item():.4f}")
    assert loss_dict['total'].item() > 0


test("Loss Functions (all 6)", test_losses)
# ============================================================
# Test 10: Kalman Filter
# ============================================================
def test_kalman():
    """Drive KalmanFilter through init, ten predict/update cycles, and one far outlier."""
    from vil_tracker.inference.kalman import KalmanFilter

    kf = KalmanFilter()
    assert not kf.initialized

    init_box = np.array([100.0, 100.0, 50.0, 50.0])
    kf.initialize(init_box)
    assert kf.initialized

    # Target drifts +2 px/step in x and +1 px/step in y, with Gaussian noise (sigma 2).
    for step in range(10):
        predicted = kf.predict()
        assert len(predicted) == 4, f"Prediction length: {len(predicted)}"
        noise = np.random.randn(4) * 2
        drift = np.array([step * 2, step * 1, 0, 0])
        measurement = init_box + drift + noise
        kf.update(measurement, uncertainty=1.0)

    state = kf.get_state()
    print(f" Final state: cx={state[0]:.1f}, cy={state[1]:.1f}, w={state[2]:.1f}, h={state[3]:.1f}")
    assert state[2] > 0 and state[3] > 0, "Width/height should be positive"

    # A measurement far from the track must not drag the centre to (500, 500).
    far_outlier = np.array([500.0, 500.0, 50.0, 50.0])
    kf.update(far_outlier, uncertainty=1.0)
    state_after = kf.get_state()
    assert state_after[0] < 200, f"Outlier should be rejected, cx={state_after[0]}"


test("Kalman Filter (8-state, adaptive)", test_kalman)
# ============================================================
# Test 11: Dataset (synthetic)
# ============================================================
def test_dataset():
    """Check SyntheticTrackingDataset sample shapes, ACL difficulty switching, and the TrackingDataset alias."""
    from vil_tracker.data.dataset import SyntheticTrackingDataset, TrackingDataset

    ds = SyntheticTrackingDataset(length=100, clip_length=3)
    assert len(ds) == 100

    sample = ds[0]
    assert sample['template'].shape == (3, 128, 128), f"Template shape: {sample['template'].shape}"
    assert sample['searches'].shape == (3, 3, 256, 256), f"Searches shape: {sample['searches'].shape}"
    assert sample['heatmaps'].shape == (3, 1, 16, 16), f"Heatmaps shape: {sample['heatmaps'].shape}"
    assert sample['sizes'].shape == (3, 2), f"Sizes shape: {sample['sizes'].shape}"
    assert sample['boxes'].shape == (3, 4), f"Boxes shape: {sample['boxes'].shape}"

    # Centre-x of the first and last frame, printed for inspection (not asserted).
    boxes = sample['boxes']
    cx_f0 = boxes[0, 0].item()
    cx_f2 = boxes[2, 0].item()
    print(f" Frame 0 cx: {cx_f0:.1f}, Frame 2 cx: {cx_f2:.1f} (moving target)")

    # Same index at difficulty 0.0 and 1.0; the cx spread of each is printed only.
    ds.set_acl_difficulty(0.0)
    easy_cx = ds[42]['boxes'][:, 0]
    ds.set_acl_difficulty(1.0)
    hard_cx = ds[42]['boxes'][:, 0]
    print(f" Easy frame spread: {(easy_cx.max() - easy_cx.min()).item():.1f} px")
    print(f" Hard frame spread: {(hard_cx.max() - hard_cx.min()).item():.1f} px")

    # Backward-compatible entry point configured for synthetic data.
    ds2 = TrackingDataset(synthetic=True, synthetic_length=50, clip_length=3)
    assert len(ds2) == 50
    sample2 = ds2[0]
    assert sample2['searches'].shape[0] == 3, "Clip length should be 3"


test("Dataset (synthetic + backward compat)", test_dataset)
# ============================================================
# Test 12: Training Step (with temporal modulation)
# ============================================================
def test_training_step():
    """Run one forward/backward/optimizer step on a K=3 clip with per-frame and contrastive losses."""
    from vil_tracker.models.tracker import ViLTracker, get_default_config
    from vil_tracker.training.losses import CombinedTrackingLoss, MemoryContrastiveLoss
    # Not referenced below; the import itself still fails the test if the symbol is missing.
    from vil_tracker.models.heads import generate_heatmap

    config = get_default_config()
    config['depth'] = 2
    config['tmoe_blocks'] = 0
    config['film_interval'] = 2
    model = ViLTracker(config)
    model.train()
    loss_fn = CombinedTrackingLoss()
    contrastive_loss = MemoryContrastiveLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    B, K = 2, 3
    template = torch.randn(B, 3, 128, 128)
    searches = torch.randn(B, K, 3, 256, 256)

    # Ground truth, identical for every frame: one positive cell at (8, 8).
    gt_heatmaps = torch.zeros(B, K, 1, 16, 16)
    gt_heatmaps[:, :, :, 8, 8] = 1.0
    gt_sizes = torch.tensor([[[0.2, 0.3]] * K] * B)
    gt_boxes = torch.tensor([[[128.0, 128.0, 51.2, 76.8]] * K] * B)

    # Multi-frame forward with the temporal path enabled.
    pred = model(template, searches, use_temporal=True)
    assert pred['heatmap'].shape == (B, K, 1, 16, 16), f"Heatmap shape: {pred['heatmap'].shape}"
    assert pred['boxes'].shape == (B, K, 4), f"Boxes shape: {pred['boxes'].shape}"
    assert pred['scores'].shape == (B, K), f"Scores shape: {pred['scores'].shape}"
    assert pred['search_feats'].shape == (B, K, 256, 384), f"Search feats: {pred['search_feats'].shape}"

    # Average the combined loss over the K frames.
    per_frame_keys = ['heatmap', 'size', 'boxes']
    has_log_var = 'log_variance' in pred
    total_loss = torch.tensor(0.0)
    for k in range(K):
        pred_k = {key: pred[key][:, k] for key in per_frame_keys}
        if has_log_var:
            pred_k['log_variance'] = pred['log_variance'][:, k]
        frame_losses = loss_fn(pred_k, gt_heatmaps[:, k], gt_sizes[:, k], gt_boxes[:, k])
        total_loss = total_loss + frame_losses['total']
    total_loss = total_loss / K

    # Contrastive term between mean-pooled template tokens and last-frame search tokens.
    t_pooled = pred['template_feat'].mean(dim=1)
    s_pooled = pred['search_feats'][:, -1].mean(dim=1)
    c_loss = contrastive_loss(t_pooled, s_pooled)
    total_loss = total_loss + 0.1 * c_loss

    total_loss.backward()
    # Count gradients before zero_grad() below clears them.
    all_params = list(model.parameters())
    has_grads = sum(1 for p in all_params if p.grad is not None)
    total_params_count = len(all_params)
    print(f" Total loss: {total_loss.item():.4f} (K={K} frames, contr={c_loss.item():.4f})")
    print(f" Params with gradients: {has_grads}/{total_params_count}")

    optimizer.step()
    optimizer.zero_grad()
    assert total_loss.item() > 0
    assert has_grads > 0


test("Training Step (K=3 sequence + contrastive)", test_training_step)
# ============================================================
# Test 13: Model Summary (FULL depth=24, constraint check)
# ============================================================
def test_model_summary():
    """Build the default-config ViLTracker and enforce the parameter and size limits."""
    from vil_tracker.models.tracker import ViLTracker, get_default_config
    from vil_tracker.utils.helpers import print_model_summary

    config = get_default_config()
    model = ViLTracker(config)
    summary = print_model_summary(model, config)
    total_m = summary['total_params'] / 1e6

    # Hard limits: either one failing fails the test.
    assert summary['param_ok'], f"FAIL: {total_m:.2f}M params exceeds 50M limit"
    assert summary['size_ok'], f"FAIL: {summary['size_fp16_mb']:.1f}MB exceeds 500MB limit"

    # The FLOP figure is an estimate, so going over the budget only warns.
    flops_within_budget = summary['flop_ok']
    if not flops_within_budget:
        print(f" ⚠️ GFLOPs estimate ({summary['gflops']:.2f}) exceeds 20, but this is approximate")


test("Model Summary (full depth=24)", test_model_summary)
# ============================================================
# Test 14: Online Tracker (inference pipeline)
# ============================================================
def test_online_tracker():
    """Run OnlineTracker over a six-frame synthetic sequence and validate each returned bbox."""
    from vil_tracker.models.tracker import ViLTracker, get_default_config
    from vil_tracker.inference.online_tracker import OnlineTracker

    config = get_default_config()
    config['depth'] = 2
    config['tmoe_blocks'] = 0
    config['film_interval'] = 2
    model = ViLTracker(config)
    model.eval()
    tracker = OnlineTracker(model, device='cpu', template_size=128, search_size=256)

    # 480x640 noise frames with a solid [255, 0, 0] rectangle as the target.
    H, W = 480, 640
    frame_shape = (H, W, 3)
    target_colour = [255, 0, 0]
    init_bbox = [200, 200, 60, 80]  # [x, y, w, h]
    x, y, w, h = init_bbox

    frame0 = np.random.randint(0, 255, frame_shape, dtype=np.uint8)
    frame0[y:y+h, x:x+w] = target_colour
    tracker.initialize(frame0, init_bbox)

    # Five follow-up frames; the rectangle shifts by (+5, +3) px per frame.
    for step in range(1, 6):
        frame = np.random.randint(0, 255, frame_shape, dtype=np.uint8)
        nx = x + step * 5
        ny = y + step * 3
        frame[ny:ny+h, nx:nx+w] = target_colour
        bbox = tracker.track(frame)
        assert len(bbox) == 4, f"Bbox should have 4 elements, got {len(bbox)}"
        assert all(isinstance(v, (int, float, np.floating)) for v in bbox), f"Bbox values: {bbox}"
        print(f" Frame {step}: predicted [{bbox[0]:.1f}, {bbox[1]:.1f}, {bbox[2]:.1f}, {bbox[3]:.1f}]")
    print(" Online tracker completed 5-frame sequence")


test("Online Tracker (inference pipeline)", test_online_tracker)
# ============================================================
# Test 15: Augmentation pipeline
# ============================================================
def test_augmentation():
    """Check TrackingAugmentation keeps tensor shapes and mirrors the bbox centre on flip.

    The flip probability is forced to 1.0. The box centre is deliberately
    placed off the vertical midline of the 256-px-wide search image: with a
    centred box (cx == W / 2) the mirrored cx equals the original cx, so the
    assertion would pass even if the bbox were never updated.
    """
    from vil_tracker.data.dataset import TrackingAugmentation

    aug = TrackingAugmentation(
        brightness=0.2,
        contrast=0.2,
        horizontal_flip_prob=1.0,  # Force flip to test bbox update
        grayscale_prob=0.0,
        blur_prob=0.0,
    )
    template = torch.rand(3, 128, 128)
    search = torch.rand(3, 256, 256)
    search_width = search.shape[-1]
    orig_cx = 100.0  # off-centre on purpose, see docstring
    bbox = torch.tensor([orig_cx, 128.0, 50.0, 50.0])  # [cx, cy, w, h]

    t_aug, s_aug, b_aug = aug(template, search, bbox)
    assert t_aug.shape == (3, 128, 128), f"Aug template shape: {t_aug.shape}"
    assert s_aug.shape == (3, 256, 256), f"Aug search shape: {s_aug.shape}"
    assert b_aug.shape == (4,), f"Aug bbox shape: {b_aug.shape}"

    # With flip_prob=1.0, cx should be mirrored: new_cx = W - old_cx = 256 - 100 = 156
    expected_cx = search_width - orig_cx
    print(f" Original bbox: {bbox.tolist()}")
    print(f" Augmented bbox: {b_aug.tolist()}")
    assert abs(b_aug[0].item() - expected_cx) < 1.0, (
        f"Flipped cx should be ~{expected_cx:.0f}, got {b_aug[0]}"
    )


test("Augmentation pipeline", test_augmentation)
# ============================================================
# Test 16: ACL curriculum integration
# ============================================================
def test_acl_curriculum():
    """Compare the average centre-x spread of the first 20 samples at ACL difficulty 0.0 and 1.0.

    The comparison is printed, not asserted; the test only fails if sampling raises.
    """
    from vil_tracker.data.dataset import SyntheticTrackingDataset

    ds = SyntheticTrackingDataset(length=100, acl_difficulty=0.0, clip_length=3)

    def _cx_spreads():
        """Return max-minus-min of the cx column for samples 0..19 at the current difficulty."""
        spreads = []
        for idx in range(20):
            cx = ds[idx]['boxes'][:, 0]
            spreads.append((cx.max() - cx.min()).item())
        return spreads

    easy_spreads = _cx_spreads()
    ds.set_acl_difficulty(1.0)
    hard_spreads = _cx_spreads()

    avg_easy = np.mean(easy_spreads)
    avg_hard = np.mean(hard_spreads)
    print(f" Avg cx spread (easy, d=0.0): {avg_easy:.1f} px")
    print(f" Avg cx spread (hard, d=1.0): {avg_hard:.1f} px")
    print(f" Hard > Easy: {avg_hard > avg_easy}")


test("ACL curriculum integration", test_acl_curriculum)
# ============================================================
# Summary
# ============================================================
# Final report; the process exit status is 0 only when no test failed.
print(f"\n{'=' * 60}")
print(f"Results: {PASS}/{PASS + FAIL} tests passed")
if FAIL <= 0:
    print(" ✅ All tests passed!")
    sys.exit(0)
else:
    print(f" ❌ {FAIL} test(s) FAILED")
    sys.exit(1)