Spaces:
Running
Running
from __future__ import annotations

import json

from replicalab.agents.lab_manager_agent import LabManagerAgent
from replicalab.env import ReplicaLabEnv
from replicalab.models import ScientistAction
from replicalab.oracle import Oracle
from replicalab.oracle_models import (
    AdjudicatorRoundScore,
    AdjudicatorTerminalScore,
    EnvironmentEvent,
    OracleLabManagerObservation,
    PostMortem,
    Scenario,
)
def _scenario_payload() -> dict:
    """Build the raw scenario payload used as a fixture throughout this module.

    The mapping is served as the fake client's reply for
    ``Oracle.generate_scenario`` and is also fed straight into
    ``Scenario.model_validate``. A fresh dict is built on every call.

    Returns:
        A dict with the top-level keys ``paper``, ``lab_constraints``,
        ``minimum_viable_spec``, ``difficulty`` and ``narrative_hook``.
    """
    # NOTE(review): ``required_equipment`` spells "validation server" in
    # lowercase while the lab inventory and ``critical_equipment`` use
    # "Validation server" -- confirm the casing difference is intentional.
    paper = {
        "title": "Reproducing a Small Vision Benchmark",
        "domain": "ml_benchmark",
        "claim": "A compact model can recover >90% of reference accuracy under budget.",
        "method_summary": "Train a compact CNN with fixed augmentations and evaluate on a held-out split.",
        "original_sample_size": 1200,
        "original_duration_days": 3,
        "original_technique": "compact_cnn",
        "required_controls": ["seed_control", "baseline_model"],
        "required_equipment": ["GPU cluster", "validation server"],
        "required_reagents": ["dataset snapshot"],
        "statistical_test": "accuracy_gap",
    }

    # --- lab inventory -----------------------------------------------------
    gpu_cluster = {
        "name": "GPU cluster",
        "available": True,
        "condition": "shared_booking",
        "booking_conflicts": ["Monday"],
        "cost_per_use": 250.0,
    }
    validation_server = {
        "name": "Validation server",
        "available": True,
        "condition": "operational",
        "booking_conflicts": [],
        "cost_per_use": 20.0,
    }
    dataset_snapshot = {
        "name": "dataset snapshot",
        "in_stock": True,
        "quantity_available": 1.0,
        "unit": "copy",
        "lead_time_days": 0,
        "cost": 0.0,
    }
    engineer = {
        "name": "Alex",
        "role": "engineer",
        "available_days": ["Monday", "Tuesday"],
        "skills": ["training", "evaluation"],
    }
    gpu_substitution = {
        "original": "GPU cluster",
        "substitute": "single high-memory GPU",
        "validity": "acceptable_with_caveats",
        "caveats": "Lower throughput is acceptable if evaluation fidelity is preserved.",
    }

    lab_constraints = {
        "budget_total": 2400.0,
        "budget_remaining": 2400.0,
        "equipment": [gpu_cluster, validation_server],
        "reagents": [dataset_snapshot],
        "staff": [engineer],
        "max_duration_days": 5,
        "safety_rules": ["No external internet during training."],
        "valid_substitutions": [gpu_substitution],
    }

    minimum_viable_spec = {
        "min_sample_size": 800,
        "must_keep_controls": ["seed_control", "baseline_model"],
        "acceptable_techniques": ["compact_cnn", "distilled_cnn"],
        "min_duration_days": 2,
        "critical_equipment": ["Validation server"],
        "flexible_equipment": ["GPU cluster"],
        "critical_reagents": ["dataset snapshot"],
        "flexible_reagents": [],
        "power_threshold": 0.8,
    }

    return {
        "paper": paper,
        "lab_constraints": lab_constraints,
        "minimum_viable_spec": minimum_viable_spec,
        "difficulty": "medium",
        "narrative_hook": "The compute team just reduced your preferred GPU window.",
    }
def _round_score_payload() -> dict:
    """Build the raw per-round score payload used as a fixture in this module.

    It is served as the fake client's reply for ``Oracle.score_round`` and is
    also validated directly with ``AdjudicatorRoundScore.model_validate``.
    A fresh dict is built on every call.
    """
    flags = {
        "rigor_flags": ["kept baseline_model"],
        "feasibility_flags": ["GPU window narrowed"],
    }
    signals = {
        "info_gain": 0.6,
        "protocol_delta": 0.4,
        "momentum": 0.7,
    }
    detections = {
        "contradiction_detected": False,
        "stalling_detected": False,
    }
    # Merge in the original key order so the serialized JSON is unchanged.
    return {
        **flags,
        **signals,
        **detections,
        "step_reward": 0.55,
        "notes": "Scientist asked a useful scheduling question and preserved controls.",
    }
def _post_mortem_payload() -> dict:
    """Build the raw post-mortem payload used as a fixture in this module.

    It is served as the fake client's reply for
    ``Oracle.generate_post_mortem`` and is also validated directly with
    ``PostMortem.model_validate``. A fresh dict is built on every call.
    """
    key_decisions = ["Kept seed control", "Accepted lower-throughput compute"]
    missed_opportunities = ["Could have asked about booking conflicts earlier"]

    payload = {
        "overall_summary": "The Scientist converged on a feasible compact CNN plan.",
        "rigor_explanation": "Controls and the validation server were preserved.",
        "feasibility_explanation": "The final plan fit the available compute and duration window.",
        "fidelity_explanation": "The protocol stayed close to the benchmark setup.",
    }
    payload["key_decisions"] = key_decisions
    payload["missed_opportunities"] = missed_opportunities
    payload["comparison_note"] = (
        "An optimal Scientist would have requested the alternate GPU window one round sooner."
    )
    return payload
class _FakeMessagesAPI:
    """Scripted stand-in for a client's ``messages`` endpoint.

    Every ``create`` call hands back the next payload from the list given at
    construction time, JSON-encoded and wrapped in a response object that
    exposes the text at ``response.content[0].text``.

    Attributes:
        calls: Number of ``create`` calls that have returned a payload.
    """

    # The response helpers live at class level so they are defined once
    # instead of being rebuilt on every ``create`` call.
    class _Chunk:
        """Single content item carrying the reply text."""

        def __init__(self, text: str) -> None:
            self.text = text

    class _Response:
        """Response whose ``content`` list holds exactly one chunk."""

        def __init__(self, text: str) -> None:
            self.content = [_FakeMessagesAPI._Chunk(text)]

    def __init__(self, payloads: list[dict]) -> None:
        self._payloads = payloads
        self.calls = 0

    def create(self, **_: object) -> "_FakeMessagesAPI._Response":
        """Return the next scripted payload as a JSON-text response.

        All keyword arguments are accepted and ignored.

        Raises:
            IndexError: If more calls are made than payloads were scripted.
        """
        if self.calls >= len(self._payloads):
            # Same exception type as plain list indexing, but with context
            # that points at the under-scripted fake.
            raise IndexError(
                f"_FakeMessagesAPI has only {len(self._payloads)} scripted "
                f"payload(s) but create() was called {self.calls + 1} time(s)"
            )
        payload = self._payloads[self.calls]
        self.calls += 1
        return self._Response(json.dumps(payload))
class _FakeClient:
    """Minimal client double exposing only a scripted ``messages`` attribute."""

    def __init__(self, payloads: list[dict]) -> None:
        # Replies come from ``payloads`` in order, one per ``messages.create``.
        scripted_api = _FakeMessagesAPI(payloads)
        self.messages = scripted_api
def test_oracle_generate_scenario_parses_json() -> None:
    """``Oracle.generate_scenario`` turns the client's JSON reply into a ``Scenario``."""
    scripted_client = _FakeClient([_scenario_payload()])
    oracle = Oracle(scripted_client)

    result = oracle.generate_scenario(
        seed=7,
        difficulty="medium",
        domain="ml_benchmark",
    )

    assert isinstance(result, Scenario)
    assert result.paper.domain == "ml_benchmark"
    first_equipment = result.lab_constraints.equipment[0]
    assert first_equipment.name == "GPU cluster"
def test_oracle_score_round_parses_structured_payload() -> None:
    """``Oracle.score_round`` parses the scripted reply into an ``AdjudicatorRoundScore``."""
    oracle = Oracle(_FakeClient([_round_score_payload()]))
    scenario = Scenario.model_validate(_scenario_payload())

    # An information request: the protocol fields are left empty and only a
    # question is supplied.
    action = ScientistAction(
        action_type="request_info",
        sample_size=0,
        controls=[],
        technique="",
        duration_days=0,
        required_equipment=[],
        required_reagents=[],
        questions=["When is the GPU cluster available?"],
        rationale="",
    )

    # The lab manager's reply is scripted through its own fake client.
    scripted_report = {
        "response_type": "feasibility_report",
        "feasible": False,
        "issues": ["GPU cluster is shared-booked on Monday"],
        "suggestions": ["Use the single high-memory GPU instead"],
        "cost_estimate": 250.0,
        "time_estimate_days": 3,
        "message": "The GPU cluster is shared-booked Monday; the single high-memory GPU is acceptable with caveats.",
    }
    lab_manager = LabManagerAgent(_FakeClient([scripted_report]))
    manager_observation = OracleLabManagerObservation(
        lab_constraints=scenario.lab_constraints,
        current_protocol=None,
        scientist_action=action,
        round_number=1,
    )
    response = lab_manager.respond(manager_observation)

    score = oracle.score_round(
        scenario=scenario,
        round_number=1,
        scientist_action=action,
        lab_manager_response=response,
        conversation_history=[],
        current_protocol=None,
        previous_scores=[],
    )

    assert isinstance(score, AdjudicatorRoundScore)
    assert score.step_reward == 0.55
def test_oracle_maybe_inject_event_returns_optional_event() -> None:
    """``Oracle.maybe_inject_event`` returns an ``EnvironmentEvent`` when the reply says to inject."""
    scripted_event = {
        "event_type": "budget_cut",
        "description": "Finance reduced the remaining budget.",
        "state_changes": {"lab_constraints.budget_remaining": 1800.0},
        "severity": "moderate",
    }
    scripted_reply = {"inject": True, "event": scripted_event}
    oracle = Oracle(_FakeClient([scripted_reply]))
    scenario = Scenario.model_validate(_scenario_payload())

    event = oracle.maybe_inject_event(
        scenario=scenario,
        round_number=3,
        current_protocol=None,
        conversation_history=[],
        inject_enabled=True,
    )

    assert isinstance(event, EnvironmentEvent)
    assert event.event_type == "budget_cut"
def test_oracle_generate_post_mortem_parses_json() -> None:
    """``Oracle.generate_post_mortem`` parses the scripted reply into a ``PostMortem``.

    ``AdjudicatorTerminalScore`` comes from the module-level
    ``replicalab.oracle_models`` import rather than a function-local import.
    """
    oracle = Oracle(_FakeClient([_post_mortem_payload()]))
    scenario = Scenario.model_validate(_scenario_payload())
    terminal_score = AdjudicatorTerminalScore(
        rigor=0.9,
        feasibility=0.8,
        fidelity=0.85,
        parsimony=0.9,
        robustness=0.8,
        power_preservation=0.8,
        efficiency_bonus=0.2,
        communication_bonus=0.1,
        penalties={},
        terminal_reward=5.0,
        total_reward=5.6,
    )

    post_mortem = oracle.generate_post_mortem(
        scenario=scenario,
        final_protocol={"technique": "compact_cnn"},
        conversation_history=[],
        terminal_score=terminal_score,
    )

    assert isinstance(post_mortem, PostMortem)
    assert "feasible compact CNN plan" in post_mortem.overall_summary
def test_env_can_reset_from_oracle_scenario_without_changing_outer_contract() -> None:
    """An env built with an oracle still returns scientist and lab-manager views on reset."""

    class _FakeOracle:
        """Oracle double that returns fixed, pre-validated fixture objects."""

        def __init__(self) -> None:
            self.scenario = Scenario.model_validate(_scenario_payload())

        def generate_scenario(self, seed: int, difficulty: str, domain: str) -> Scenario:
            # Expected values mirror the arguments given to ``env.reset``
            # below -- presumably the env forwards them (``scenario`` as
            # ``domain``); confirm against ``ReplicaLabEnv.reset``.
            assert (seed, difficulty, domain) == (11, "medium", "ml_benchmark")
            return self.scenario

        def score_round(self, **_: object):
            return AdjudicatorRoundScore.model_validate(_round_score_payload())

        def maybe_inject_event(self, **_: object):
            return None

        def generate_post_mortem(self, **_: object):
            return PostMortem.model_validate(_post_mortem_payload())

    fake_oracle = _FakeOracle()
    env = ReplicaLabEnv(oracle=fake_oracle, enable_oracle_post_mortem=True)

    observation = env.reset(seed=11, scenario="ml_benchmark", difficulty="medium")

    scientist_view = observation.scientist
    assert scientist_view is not None
    assert scientist_view.paper_title == "Reproducing a Small Vision Benchmark"

    manager_view = observation.lab_manager
    assert manager_view is not None
    assert "Validation server" in manager_view.equipment_available