open-range / tests /test_models.py
Aaron Brown
Full OpenRange implementation: environment, builder, validator, agents, tests
8c486a8
"""Tests for OpenEnv-compatible model serialization."""
import json
from open_range.server.models import RangeAction, RangeObservation, RangeState
class TestRangeAction:
"""RangeAction serialization."""
def test_create(self):
a = RangeAction(command="nmap -sV web", mode="red")
assert a.command == "nmap -sV web"
assert a.mode == "red"
def test_serialize_deserialize(self):
a = RangeAction(command="curl http://web/", mode="red")
data = a.model_dump()
assert data["command"] == "curl http://web/"
assert data["mode"] == "red"
a2 = RangeAction(**data)
assert a2.command == a.command
assert a2.mode == a.mode
def test_mode_literal(self):
"""mode must be 'red' or 'blue'."""
a = RangeAction(command="x", mode="blue")
assert a.mode == "blue"
def test_json_roundtrip(self):
a = RangeAction(command="hydra -l admin web ssh", mode="red")
js = a.model_dump_json()
a2 = RangeAction.model_validate_json(js)
assert a2.command == a.command
assert a2.mode == a.mode
class TestRangeObservation:
"""RangeObservation inherits done/reward from Observation base."""
def test_defaults(self):
obs = RangeObservation()
assert obs.done is False
assert obs.reward is None
assert obs.stdout == ""
assert obs.stderr == ""
assert obs.flags_captured == []
assert obs.alerts == []
def test_done_inherited(self):
obs = RangeObservation(done=True, reward=1.0)
assert obs.done is True
assert obs.reward == 1.0
def test_custom_fields(self):
obs = RangeObservation(
stdout="flag found",
flags_captured=["FLAG{test}"],
alerts=["IDS alert"],
)
assert obs.flags_captured == ["FLAG{test}"]
assert obs.alerts == ["IDS alert"]
def test_json_roundtrip(self):
obs = RangeObservation(
stdout="output",
stderr="err",
done=True,
reward=0.5,
flags_captured=["FLAG{a}"],
alerts=["alert1"],
)
js = obs.model_dump_json()
obs2 = RangeObservation.model_validate_json(js)
assert obs2.stdout == obs.stdout
assert obs2.done is True
assert obs2.reward == 0.5
assert obs2.flags_captured == ["FLAG{a}"]
class TestRangeState:
"""RangeState inherits episode_id/step_count from State base."""
def test_defaults(self):
s = RangeState()
assert s.episode_id is None
assert s.step_count == 0
assert s.mode == ""
assert s.tier == 1
assert s.flags_found == []
def test_inherited_fields(self):
s = RangeState(episode_id="ep_42", step_count=5)
assert s.episode_id == "ep_42"
assert s.step_count == 5
def test_custom_fields(self):
s = RangeState(
episode_id="ep1",
mode="red",
tier=3,
flags_found=["FLAG{a}"],
services_status={"web": "running"},
)
assert s.tier == 3
assert s.flags_found == ["FLAG{a}"]
assert s.services_status["web"] == "running"
def test_json_roundtrip(self):
s = RangeState(
episode_id="ep99",
step_count=10,
mode="blue",
flags_found=["FLAG{x}"],
tier=2,
)
js = s.model_dump_json()
s2 = RangeState.model_validate_json(js)
assert s2.episode_id == "ep99"
assert s2.step_count == 10
assert s2.mode == "blue"
assert s2.tier == 2
def test_model_dump_and_back(self):
s = RangeState(episode_id="e1", step_count=3, mode="red", tier=1)
d = s.model_dump()
s2 = RangeState(**d)
assert s2.episode_id == s.episode_id
assert s2.step_count == s.step_count