| import importlib.util |
| import json |
| from pathlib import Path |
|
|
| import h5py |
| import pytest |
|
|
| from tests._shared.repo_paths import find_repo_root |
|
|
| pytestmark = [pytest.mark.lightweight, pytest.mark.gpu] |
|
|
|
|
| def _load_module(module_name: str, relative_path: str): |
| repo_root = find_repo_root(__file__) |
| module_path = repo_root / relative_path |
| spec = importlib.util.spec_from_file_location(module_name, module_path) |
| module = importlib.util.module_from_spec(spec) |
| assert spec.loader is not None |
| spec.loader.exec_module(module) |
| return module |
|
|
|
|
| matcher_mod = _load_module( |
| "oracle_action_matcher_under_test", |
| "src/robomme/robomme_env/utils/oracle_action_matcher.py", |
| ) |
| resolver_mod = _load_module( |
| "episode_dataset_resolver_under_test", |
| "src/robomme/env_record_wrapper/episode_dataset_resolver.py", |
| ) |
|
|
|
|
| def test_find_exact_label_option_index_matches_label_only(): |
| options = [ |
| {"label": "a", "action": "pick up the cube"}, |
| {"label": "b", "action": "put it down"}, |
| ] |
|
|
| assert matcher_mod.find_exact_label_option_index("a", options) == 0 |
| assert matcher_mod.find_exact_label_option_index("b", options) == 1 |
| assert matcher_mod.find_exact_label_option_index("pick up the cube", options) == -1 |
| assert matcher_mod.find_exact_label_option_index(1, options) == -1 |
|
|
|
|
| def test_map_action_text_to_option_label_strict_exact(): |
| options = [ |
| {"label": "a", "action": "pick up the cube"}, |
| {"label": "b", "action": "put it down"}, |
| ] |
|
|
| assert ( |
| matcher_mod.map_action_text_to_option_label("pick up the cube", options) == "a" |
| ) |
| assert matcher_mod.map_action_text_to_option_label("unknown action", options) is None |
| assert matcher_mod.map_action_text_to_option_label(None, options) is None |
|
|
|
|
| def test_episode_dataset_resolver_extracts_choice_command_and_ignores_empty_choice(tmp_path): |
| h5_path = tmp_path / "choice_oracle_commands.h5" |
|
|
| with h5py.File(h5_path, "w") as h5: |
| episode_group = h5.create_group("episode_0") |
|
|
| ts0 = episode_group.create_group("timestep_0") |
| ts0_action = ts0.create_group("action") |
| ts0_action.create_dataset( |
| "choice_action", |
| data=json.dumps( |
| { |
| "choice": "B", |
| "point": [34, 12], |
| } |
| ), |
| dtype=h5py.string_dtype(encoding="utf-8"), |
| ) |
| ts0_info = ts0.create_group("info") |
| ts0_info.create_dataset("is_video_demo", data=False) |
| ts0_info.create_dataset("is_subgoal_boundary", data=True) |
|
|
| ts1 = episode_group.create_group("timestep_1") |
| ts1_action = ts1.create_group("action") |
| ts1_action.create_dataset( |
| "choice_action", |
| data=json.dumps( |
| { |
| "choice": "", |
| "point": [30, 20], |
| } |
| ), |
| dtype=h5py.string_dtype(encoding="utf-8"), |
| ) |
| ts1_info = ts1.create_group("info") |
| ts1_info.create_dataset("is_video_demo", data=False) |
| ts1_info.create_dataset("is_subgoal_boundary", data=True) |
|
|
| resolver = resolver_mod.EpisodeDatasetResolver( |
| env_id="DummyEnv", |
| episode=0, |
| dataset_directory=h5_path, |
| ) |
| try: |
| command0 = resolver.get_step("multi_choice", 0) |
| assert command0 == {"choice": "B", "point": [34, 12]} |
| assert "position_3d" not in command0 |
|
|
| command1 = resolver.get_step("multi_choice", 1) |
| assert command1 is None |
| finally: |
| resolver.close() |
|
|