Spaces:
Running
Running
| """Tests for PatchHawkEnv (OpenEnv compliance + reward logic).""" | |
| import pytest | |
| from patchhawk.agent.environment import PatchHawkEnv | |
| from patchhawk.env_models import PatchHawkAction, PatchHawkObservation, PatchHawkState | |
@pytest.fixture
def env():
    """Yield a PatchHawkEnv built from the default scenarios file.

    Docker is disabled (``use_docker=False``) for the test run. The
    environment is closed after each test: pytest runs the code after
    ``yield`` as teardown even when the test body fails.
    """
    e = PatchHawkEnv(use_docker=False)
    yield e
    e.close()
# ── Basic API ─────────────────────────────────────────────────────
def test_reset_returns_observation(env):
    """reset() hands back a PatchHawkObservation exposing the OpenEnv fields."""
    observation = env.reset()
    assert isinstance(observation, PatchHawkObservation)

    expected_fields = (
        "code_snippet",
        "static_flags",
        "risk_score",
        "done",
        "reward",
        "metadata",
    )
    for field_name in expected_fields:
        assert hasattr(observation, field_name), field_name
def test_observation_fields(env):
    """Each observation field returned by reset() has the expected Python type."""
    observation = env.reset()

    # Checked in insertion order, the same order as the individual assertions.
    expected_types = {
        "code_snippet": str,
        "static_flags": list,
        "risk_score": float,
        "done": bool,
        "metadata": dict,
    }
    for field_name, field_type in expected_types.items():
        value = getattr(observation, field_name)
        assert isinstance(value, field_type), (field_name, type(value))
def test_step_returns_observation(env):
    """step() yields a PatchHawkObservation with typed reward/done/metadata."""
    env.reset()
    result = env.step(PatchHawkAction(action_type=env.ACTION_ANALYZE))

    assert isinstance(result, PatchHawkObservation)
    # The reward may be reported as either an int or a float.
    assert isinstance(result.reward, (int, float))
    assert isinstance(result.done, bool)
    assert isinstance(result.metadata, dict)
def test_state_property(env):
    """After reset(), env.state is a PatchHawkState with episode bookkeeping."""
    env.reset()
    current = env.state
    assert isinstance(current, PatchHawkState)
    for attribute in ("episode_id", "step_count", "scenario_id"):
        assert hasattr(current, attribute), attribute
def test_all_action_types_accepted(env):
    """Every action type id in 0-4 is accepted by step() after a fresh reset()."""
    for action_type in range(5):
        # Start a fresh episode per action so that a terminating action
        # (e.g. block PR / request review) cannot affect the next one.
        env.reset()
        result = env.step(PatchHawkAction(action_type=action_type))
        assert isinstance(result, PatchHawkObservation), action_type
# ── Reward logic ──────────────────────────────────────────────────
def test_block_malicious_positive_reward(env):
    """Blocking a malicious PR yields +2.0 and ends the episode."""
    # First scenario labelled "malicious", or None if there is none.
    scenario = next(
        (s for s in env.scenarios if s.get("label") == "malicious"), None
    )
    if scenario is None:
        pytest.skip("No malicious scenarios available")

    env.reset(scenario=scenario)
    outcome = env.step(PatchHawkAction(action_type=env.ACTION_BLOCK_PR))
    assert outcome.reward == 2.0
    assert outcome.done is True
def test_block_benign_negative_reward(env):
    """Blocking a benign PR is a false positive: -1.0 and the episode ends."""
    # First scenario labelled "benign", or None if there is none.
    scenario = next(
        (s for s in env.scenarios if s.get("label") == "benign"), None
    )
    if scenario is None:
        pytest.skip("No benign scenarios available")

    env.reset(scenario=scenario)
    outcome = env.step(PatchHawkAction(action_type=env.ACTION_BLOCK_PR))
    assert outcome.reward == -1.0
    assert outcome.done is True
def test_execute_sandbox_reward(env):
    """Executing the sandbox gives a small positive reward (0.1), non-terminal."""
    env.reset()
    obs = env.step(PatchHawkAction(action_type=env.ACTION_EXECUTE_SANDBOX))
    # 0.1 has no exact binary representation; compare with a tolerance so an
    # equivalent computed value (e.g. a sum of float terms) does not fail.
    assert obs.reward == pytest.approx(0.1)
    assert obs.done is False
def test_analyze_no_reward(env):
    """ANALYZE gives zero reward and the episode keeps running."""
    env.reset()
    analyze = PatchHawkAction(action_type=env.ACTION_ANALYZE)
    outcome = env.step(analyze)
    assert outcome.reward == 0.0
    assert outcome.done is False
def test_request_review_terminates(env):
    """Requesting review ends the episode with a neutral 0.0 reward."""
    env.reset()
    review = PatchHawkAction(action_type=env.ACTION_REQUEST_REVIEW)
    outcome = env.step(review)
    assert outcome.reward == 0.0
    assert outcome.done is True
def test_max_steps_penalty(env):
    """Running out of steps on a malicious PR without acting is penalised.

    Only ANALYZE (non-terminal, zero reward) is taken, so the episode is
    expected to run until ``env.max_steps``; the final observation must
    carry the -5.0 penalty and be terminal.
    """
    scenario = next(
        (s for s in env.scenarios if s.get("label") == "malicious"), None
    )
    if scenario is None:
        pytest.skip("No malicious scenarios available")

    env.reset(scenario=scenario)
    action = PatchHawkAction(action_type=env.ACTION_ANALYZE)
    obs = None
    for _ in range(env.max_steps):
        obs = env.step(action)
        if obs.done:
            break

    # If max_steps <= 0 the loop never runs; fail clearly instead of raising
    # an AttributeError on None below.
    assert obs is not None, "env.max_steps must allow at least one step"
    # Last step on a malicious PR without block/patch -> -5.0 penalty.
    assert obs.reward == -5.0
    assert obs.done is True
def test_episode_with_scenario_kwarg(env):
    """reset(scenario=...) starts an episode on the supplied scenario dict."""
    custom = dict(
        id="test_custom",
        type="functional",
        label="benign",
        code_snippet="x = 42",
        patch=None,
        unit_test_code=None,
        attack_type=None,
    )
    first_obs = env.reset(scenario=custom)
    assert first_obs.code_snippet == "x = 42"
    assert first_obs.metadata["scenario_id"] == "test_custom"
def test_step_counter_increments(env):
    """After reset(), each step() call increments state.step_count by one."""
    env.reset()
    # The action is loop-invariant, so build it once and reuse it.
    action = PatchHawkAction(action_type=env.ACTION_ANALYZE)
    n_steps = 3
    for _ in range(n_steps):
        env.step(action)
    assert env.state.step_count == n_steps
def test_close_resets_scenario(env):
    """close() clears the scenario that reset() loaded."""
    env.reset()
    # Precondition: without it, the test would pass vacuously if reset()
    # never populated current_scenario in the first place.
    assert env.current_scenario is not None
    env.close()
    assert env.current_scenario is None