| import time |
| from datetime import datetime, timedelta, timezone |
|
|
| from trenches_env.env import FogOfWarDiplomacyEnv |
| from trenches_env.models import ( |
| AgentAction, |
| BenchmarkRunRequest, |
| ExternalSignal, |
| IngestNewsRequest, |
| LiveControlRequest, |
| StepSessionRequest, |
| ) |
| from trenches_env.session_manager import SessionManager |
| from trenches_env.source_bundles import AGENT_LIVE_SOURCE_BUNDLES |
| from trenches_env.source_ingestion import SourceHarvester |
|
|
|
|
| class ShippingFeedFetcher: |
| def fetch(self, _: str) -> tuple[str, str]: |
| return ( |
| """ |
| <rss> |
| <channel> |
| <title>Maritime Watch</title> |
| <item> |
| <title>Shipping risk rises in Hormuz after drone intercept near tanker lanes</title> |
| </item> |
| </channel> |
| </rss> |
| """, |
| "application/rss+xml", |
| ) |
|
|
|
|
| def build_live_manager() -> SessionManager: |
| harvester = SourceHarvester(fetcher=ShippingFeedFetcher(), auto_start=False) |
| env = FogOfWarDiplomacyEnv(source_harvester=harvester) |
| return SessionManager(env=env) |
|
|
|
|
| def test_session_lifecycle() -> None: |
| manager = SessionManager() |
| session = manager.create_session(seed=7) |
|
|
| assert session.world.turn == 0 |
| assert session.session_id |
| assert session.episode.max_turns == 1000 |
|
|
| live_session = manager.set_live_mode( |
| session.session_id, |
| LiveControlRequest(enabled=True, auto_step=False, poll_interval_ms=15_000), |
| ) |
|
|
| assert live_session.live.enabled is True |
| assert live_session.live.source_queue_sizes["us"] == len(AGENT_LIVE_SOURCE_BUNDLES["us"]) |
|
|
| response = manager.step_session( |
| session.session_id, |
| StepSessionRequest( |
| actions={ |
| "us": AgentAction( |
| actor="us", |
| type="negotiate", |
| target="gulf", |
| summary="Offer deconfliction and shipping guarantees.", |
| ), |
| "oversight": AgentAction( |
| actor="oversight", |
| type="oversight_review", |
| summary="Monitor for escalation drift.", |
| ), |
| } |
| ), |
| ) |
|
|
| assert response.session.world.turn == 1 |
| assert "gulf" in response.session.world.coalition_graph["us"] |
| assert "us" in response.session.observations |
| assert response.session.observations["us"].training_source_bundle |
| assert response.session.observations["us"].live_source_bundle |
| assert response.session.recent_traces |
| assert response.session.recent_traces[-1].turn == 1 |
| assert response.session.model_bindings["us"].action_tools |
| assert response.session.model_bindings["us"].decision_mode == "heuristic_fallback" |
| assert response.session.action_log |
| assert {entry.actor for entry in response.session.action_log[-2:]} == {"us", "oversight"} |
|
|
|
|
| def test_stage_1_disables_live_mode() -> None: |
| manager = SessionManager() |
| session = manager.create_session(seed=7, training_stage="stage_1_dense") |
|
|
| assert session.episode.training_stage == "stage_1_dense" |
| assert session.episode.fog_of_war is False |
| assert session.observations["us"].perceived_tension == session.world.tension_level |
|
|
| try: |
| manager.set_live_mode( |
| session.session_id, |
| LiveControlRequest(enabled=True, auto_step=False, poll_interval_ms=15_000), |
| ) |
| except ValueError: |
| pass |
| else: |
| raise AssertionError("stage_1_dense sessions should reject live mode") |
|
|
|
|
| def test_session_manager_creates_named_scenarios() -> None: |
| manager = SessionManager() |
| session = manager.create_session(seed=7, scenario_id="shipping_crisis") |
|
|
| assert session.episode.scenario_id == "shipping_crisis" |
| assert session.episode.scenario_name == "Shipping Crisis" |
| assert session.world.actor_state["gulf"]["shipping_continuity"] < 78.0 |
|
|
|
|
| def test_source_monitor_report_is_available_for_sessions() -> None: |
| manager = SessionManager() |
| session = manager.create_session(seed=7) |
| report = manager.source_monitor(session.session_id) |
|
|
| assert report.session_id == session.session_id |
| assert len(report.agents) == len(AGENT_LIVE_SOURCE_BUNDLES) |
| assert report.summary.active_source_count > 0 |
|
|
|
|
| def test_live_get_session_auto_steps_once_for_new_source_packets() -> None: |
| manager = build_live_manager() |
| session = manager.create_session(seed=7) |
| manager.set_live_mode( |
| session.session_id, |
| LiveControlRequest(enabled=True, auto_step=True, poll_interval_ms=1_000), |
| ) |
|
|
| first_live_tick = manager.get_session(session.session_id) |
|
|
| assert first_live_tick.world.turn == 1 |
| assert first_live_tick.live.last_auto_step_at is not None |
| assert first_live_tick.live.reacted_packet_fetched_at |
| assert first_live_tick.recent_traces[-1].actions["gulf"].type == "defend" |
| assert first_live_tick.recent_traces[-1].actions["oversight"].type == "oversight_review" |
|
|
| for observation in first_live_tick.observations.values(): |
| for packet in observation.source_packets: |
| if packet.fetched_at is not None: |
| manager._sessions[session.session_id].live.reacted_packet_fetched_at[packet.source_id] = packet.fetched_at |
| manager._sessions[session.session_id].live.last_auto_step_at = datetime.now(timezone.utc) - timedelta(seconds=2) |
| second_live_tick = manager.get_session(session.session_id) |
|
|
| assert second_live_tick.world.turn == 1 |
|
|
|
|
| def test_oversight_replaces_escalatory_actions_with_valid_overrides() -> None: |
| manager = SessionManager() |
| session = manager.create_session(seed=7) |
|
|
| response = manager.step_session( |
| session.session_id, |
| StepSessionRequest( |
| actions={ |
| "us": AgentAction(actor="us", type="strike", target="iran", summary="Strike escalation lane."), |
| "israel": AgentAction(actor="israel", type="strike", target="hezbollah", summary="Strike escalation lane."), |
| "iran": AgentAction(actor="iran", type="mobilize", summary="Mobilize escalation lane."), |
| "hezbollah": AgentAction(actor="hezbollah", type="deceive", summary="Deception escalation lane."), |
| } |
| ), |
| ) |
|
|
| assert response.oversight.triggered is True |
| assert response.oversight.action_override |
| for agent_id, override_action in response.oversight.action_override.items(): |
| assert override_action.type in {"hold", "negotiate", "defend", "intel_query"} |
| assert response.session.recent_traces[-1].actions[agent_id].type == override_action.type |
| assert response.session.world.last_actions |
|
|
|
|
| def test_background_runner_advances_live_sessions_without_dashboard_polling() -> None: |
| manager = build_live_manager() |
| session = manager.create_session(seed=7) |
| manager.set_live_mode( |
| session.session_id, |
| LiveControlRequest(enabled=True, auto_step=True, poll_interval_ms=1_000), |
| ) |
| manager.start_background_runner(tick_interval_seconds=0.05) |
|
|
| try: |
| deadline = time.time() + 1.5 |
| while time.time() < deadline: |
| current = manager._sessions[session.session_id] |
| if current.world.turn >= 1: |
| break |
| time.sleep(0.05) |
| current = manager._sessions[session.session_id] |
| finally: |
| manager.stop_background_runner() |
|
|
| assert current.world.turn >= 1 |
| assert current.live.last_auto_step_at is not None |
| assert current.reaction_log |
|
|
|
|
| def test_ingest_news_generates_structured_reaction_log() -> None: |
| manager = SessionManager() |
| session = manager.create_session(seed=7) |
|
|
| response = manager.ingest_news( |
| session.session_id, |
| IngestNewsRequest( |
| signals=[ |
| ExternalSignal( |
| source="wire-service", |
| headline="Shipping risk rises in Hormuz after reported drone intercept.", |
| region="gulf", |
| tags=["shipping", "attack"], |
| severity=0.76, |
| ) |
| ], |
| agent_ids=["us", "gulf", "oversight"], |
| ), |
| ) |
|
|
| assert response.session.world.turn == 1 |
| assert response.reaction is not None |
| assert response.reaction.turn == 1 |
| assert response.reaction.signals[0].source == "wire-service" |
| assert response.reaction.latent_event_ids |
| assert {outcome.agent_id for outcome in response.reaction.actor_outcomes} == {"us", "gulf", "oversight"} |
| assert all(outcome.action.metadata["mode"] in {"heuristic_fallback", "provider_inference"} for outcome in response.reaction.actor_outcomes) |
| assert response.session.reaction_log[-1].event_id == response.reaction.event_id |
| assert manager.reaction_log(session.session_id)[-1].event_id == response.reaction.event_id |
| assert response.session.belief_state["gulf"].beliefs |
| assert response.session.observations["gulf"].belief_brief |
|
|
|
|
| def test_provider_diagnostics_are_available_per_session() -> None: |
| manager = SessionManager() |
| session = manager.create_session(seed=7) |
|
|
| diagnostics = manager.provider_diagnostics(session.session_id) |
| us_diagnostics = next(entry for entry in diagnostics.agents if entry.agent_id == "us") |
|
|
| assert us_diagnostics.agent_id == "us" |
| assert us_diagnostics.status in {"idle", "fallback_only"} |
| assert us_diagnostics.request_count == 0 |
|
|
|
|
| def test_session_manager_lists_scenarios_and_runs_benchmarks() -> None: |
| manager = SessionManager() |
|
|
| scenarios = manager.list_scenarios() |
| assert any(scenario.id == "shipping_crisis" for scenario in scenarios) |
|
|
| result = manager.run_benchmark( |
| BenchmarkRunRequest( |
| scenario_ids=["shipping_crisis"], |
| seed=9, |
| steps_per_scenario=3, |
| ) |
| ) |
| assert result.scenario_count == 1 |
| assert result.results[0].scenario_id == "shipping_crisis" |
|
|