| import torch | |
| from models.observation_memory import DualObservationMemory | |
| def _slot_scene(hidden_dim: int, slot_idx: int, slot_size: int = 3) -> torch.Tensor: | |
| scene = torch.zeros(1, slot_size * 4, hidden_dim) | |
| start = slot_idx * slot_size | |
| scene[:, start : start + slot_size] = 1.0 | |
| return scene | |
def test_spatial_memory_occlusion_persistence(tiny_policy_config):
    """Belief memory must depend on history during occlusion and react to reappearance.

    Builds a two-frame history (slot 0 filled, then an all-zero frame) and
    checks that:

    * with an occluded current frame, the belief tokens differ from a run
      whose history is all zeros, i.e. the history reaches the belief bank;
    * with the object visible again in the current frame, the belief tokens
      differ from the occluded-frame result.
    """
    # Fixed seed so the randomly initialised weights, and therefore the
    # deltas compared against the thresholds below, are reproducible.
    torch.manual_seed(0)

    config = tiny_policy_config(hidden_dim=16)
    config.memory.scene_bank_size = 4
    config.memory.belief_bank_size = 4
    memory = DualObservationMemory(config.memory)
    # Eval mode disables stochastic layers (e.g. dropout) if the module has
    # any. Otherwise the deltas could be non-zero from noise alone and the
    # assertions would pass vacuously.
    if isinstance(memory, torch.nn.Module):
        memory.eval()

    visible = _slot_scene(config.backbone.hidden_dim, 0)
    occluded = torch.zeros_like(visible)
    # Shape (1, 2, tokens, hidden_dim): entry 0 is the visible frame,
    # entry 1 the occluded one.
    history = torch.stack([visible[0], occluded[0]], dim=0).unsqueeze(0)
    # NOTE(review): 14 is presumably the action width expected by the memory
    # module; confirm against the config rather than hard-coding it.
    history_actions = torch.zeros(1, 2, 14)

    with torch.no_grad():
        during_occlusion = memory(
            occluded,
            history_scene_tokens=history,
            history_actions=history_actions,
        )
        no_history = memory(
            occluded,
            history_scene_tokens=torch.zeros_like(history),
            history_actions=history_actions,
        )
        on_reappearance = memory(
            visible,
            history_scene_tokens=history,
            history_actions=history_actions,
        )

    belief_during = during_occlusion["belief_memory_tokens"]
    occluded_delta = (belief_during - no_history["belief_memory_tokens"]).norm().item()
    reappeared_delta = (on_reappearance["belief_memory_tokens"] - belief_during).norm().item()

    assert occluded_delta > 1e-3, (
        f"belief tokens ignore history during occlusion (delta={occluded_delta:.3e})"
    )
    assert reappeared_delta > 1e-3, (
        f"belief tokens unchanged on reappearance (delta={reappeared_delta:.3e})"
    )