VLAarchtests / tests /test_memory_slot_write_gating.py
lsnu's picture
2026-03-25 runpod handoff update
e7d8e79 verified
import torch
from models.observation_memory import DualObservationMemory
def test_memory_slot_write_gating(tiny_policy_config):
    """Check that scene-memory write gating stays sparse and favours slot 0.

    Builds a ``DualObservationMemory`` with 4 scene slots and 4 belief slots.
    It feeds the memory a current scene in which only the first few tokens are
    non-zero, with all-zero scene/action history. It then asserts that:

    * at most two entries of ``scene_write_gate`` (batch element 0) exceed
      ``gate_threshold``, and
    * index 0 holds the largest write-gate value.
    """
    hidden_dim = 16
    num_tokens = 12
    history_len = 2
    # NOTE(review): hard-coded action width; confirm it matches the action
    # dimension DualObservationMemory expects for this config.
    action_dim = 14
    num_salient_tokens = 3
    gate_threshold = 0.2

    config = tiny_policy_config(hidden_dim=hidden_dim)
    config.memory.scene_bank_size = 4
    config.memory.belief_bank_size = 4

    # Seed inside a forked RNG so that module initialisation (and anything
    # stochastic in the forward pass) is reproducible. This avoids a test whose
    # outcome depends on RNG state left behind by other tests. fork_rng
    # restores the global RNG state on exit, so this test does not leak state.
    with torch.random.fork_rng():
        torch.manual_seed(0)
        memory = DualObservationMemory(config.memory)

        token_dim = config.backbone.hidden_dim
        scene_tokens = torch.zeros(1, num_tokens, token_dim)
        history_scene_tokens = torch.zeros(1, history_len, num_tokens, token_dim)
        history_actions = torch.zeros(1, history_len, action_dim)
        # Only the leading tokens carry signal; the rest of the scene is blank.
        scene_tokens[:, :num_salient_tokens] = 1.0

        output = memory(
            scene_tokens,
            history_scene_tokens=history_scene_tokens,
            history_actions=history_actions,
        )

    # Presumably one write-gate value per scene slot for batch element 0;
    # confirm against DualObservationMemory.
    write_gate = output["scene_write_gate"][0]
    active_slots = int((write_gate > gate_threshold).sum().item())
    assert active_slots <= 2, (
        f"expected at most 2 scene slots with write gate > {gate_threshold}, "
        f"got {active_slots}: {write_gate.tolist()}"
    )
    top_slot = int(write_gate.argmax().item())
    assert top_slot == 0, (
        f"expected slot 0 to have the largest write gate, got slot {top_slot}: "
        f"{write_gate.tolist()}"
    )