# replicalab/tests/test_oracle.py
# maxxie114's picture
# Initial HF Spaces deployment
# 80d8c84
from __future__ import annotations

import json

from replicalab.agents.lab_manager_agent import LabManagerAgent
from replicalab.env import ReplicaLabEnv
from replicalab.models import ScientistAction
from replicalab.oracle import Oracle
from replicalab.oracle_models import (
    AdjudicatorRoundScore,
    AdjudicatorTerminalScore,
    EnvironmentEvent,
    OracleLabManagerObservation,
    PostMortem,
    Scenario,
)
def _scenario_payload() -> dict:
return {
"paper": {
"title": "Reproducing a Small Vision Benchmark",
"domain": "ml_benchmark",
"claim": "A compact model can recover >90% of reference accuracy under budget.",
"method_summary": "Train a compact CNN with fixed augmentations and evaluate on a held-out split.",
"original_sample_size": 1200,
"original_duration_days": 3,
"original_technique": "compact_cnn",
"required_controls": ["seed_control", "baseline_model"],
"required_equipment": ["GPU cluster", "validation server"],
"required_reagents": ["dataset snapshot"],
"statistical_test": "accuracy_gap",
},
"lab_constraints": {
"budget_total": 2400.0,
"budget_remaining": 2400.0,
"equipment": [
{
"name": "GPU cluster",
"available": True,
"condition": "shared_booking",
"booking_conflicts": ["Monday"],
"cost_per_use": 250.0,
},
{
"name": "Validation server",
"available": True,
"condition": "operational",
"booking_conflicts": [],
"cost_per_use": 20.0,
},
],
"reagents": [
{
"name": "dataset snapshot",
"in_stock": True,
"quantity_available": 1.0,
"unit": "copy",
"lead_time_days": 0,
"cost": 0.0,
}
],
"staff": [
{
"name": "Alex",
"role": "engineer",
"available_days": ["Monday", "Tuesday"],
"skills": ["training", "evaluation"],
}
],
"max_duration_days": 5,
"safety_rules": ["No external internet during training."],
"valid_substitutions": [
{
"original": "GPU cluster",
"substitute": "single high-memory GPU",
"validity": "acceptable_with_caveats",
"caveats": "Lower throughput is acceptable if evaluation fidelity is preserved.",
}
],
},
"minimum_viable_spec": {
"min_sample_size": 800,
"must_keep_controls": ["seed_control", "baseline_model"],
"acceptable_techniques": ["compact_cnn", "distilled_cnn"],
"min_duration_days": 2,
"critical_equipment": ["Validation server"],
"flexible_equipment": ["GPU cluster"],
"critical_reagents": ["dataset snapshot"],
"flexible_reagents": [],
"power_threshold": 0.8,
},
"difficulty": "medium",
"narrative_hook": "The compute team just reduced your preferred GPU window.",
}
def _round_score_payload() -> dict:
return {
"rigor_flags": ["kept baseline_model"],
"feasibility_flags": ["GPU window narrowed"],
"info_gain": 0.6,
"protocol_delta": 0.4,
"momentum": 0.7,
"contradiction_detected": False,
"stalling_detected": False,
"step_reward": 0.55,
"notes": "Scientist asked a useful scheduling question and preserved controls.",
}
def _post_mortem_payload() -> dict:
return {
"overall_summary": "The Scientist converged on a feasible compact CNN plan.",
"rigor_explanation": "Controls and the validation server were preserved.",
"feasibility_explanation": "The final plan fit the available compute and duration window.",
"fidelity_explanation": "The protocol stayed close to the benchmark setup.",
"key_decisions": ["Kept seed control", "Accepted lower-throughput compute"],
"missed_opportunities": ["Could have asked about booking conflicts earlier"],
"comparison_note": "An optimal Scientist would have requested the alternate GPU window one round sooner.",
}
class _FakeMessagesAPI:
def __init__(self, payloads: list[dict]) -> None:
self._payloads = payloads
self.calls = 0
def create(self, **_: object):
payload = self._payloads[self.calls]
self.calls += 1
class _Chunk:
def __init__(self, text: str) -> None:
self.text = text
class _Response:
def __init__(self, text: str) -> None:
self.content = [_Chunk(text)]
return _Response(json.dumps(payload))
class _FakeClient:
    """Minimal client double exposing only a scripted ``messages`` attribute."""

    def __init__(self, payloads: list[dict]) -> None:
        """Create a client whose ``messages.create`` replays *payloads* in order."""
        self.messages = _FakeMessagesAPI(payloads)
def test_oracle_generate_scenario_parses_json() -> None:
    """The oracle returns a ``Scenario`` built from the client's JSON reply."""
    fake_client = _FakeClient([_scenario_payload()])
    oracle = Oracle(fake_client)

    scenario = oracle.generate_scenario(
        seed=7,
        difficulty="medium",
        domain="ml_benchmark",
    )

    assert isinstance(scenario, Scenario)
    assert scenario.paper.domain == "ml_benchmark"
    assert scenario.lab_constraints.equipment[0].name == "GPU cluster"
def test_oracle_score_round_parses_structured_payload() -> None:
    """``Oracle.score_round`` returns an ``AdjudicatorRoundScore`` from the reply.

    A lab-manager response is first produced by a ``LabManagerAgent`` backed
    by its own fake client, then passed to the oracle for scoring.
    """
    oracle = Oracle(_FakeClient([_round_score_payload()]))
    scenario = Scenario.model_validate(_scenario_payload())

    # An information request: the protocol fields are left empty.
    action = ScientistAction(
        action_type="request_info",
        sample_size=0,
        controls=[],
        technique="",
        duration_days=0,
        required_equipment=[],
        required_reagents=[],
        questions=["When is the GPU cluster available?"],
        rationale="",
    )

    lab_manager_reply = {
        "response_type": "feasibility_report",
        "feasible": False,
        "issues": ["GPU cluster is shared-booked on Monday"],
        "suggestions": ["Use the single high-memory GPU instead"],
        "cost_estimate": 250.0,
        "time_estimate_days": 3,
        "message": "The GPU cluster is shared-booked Monday; the single high-memory GPU is acceptable with caveats.",
    }
    lab_manager = LabManagerAgent(_FakeClient([lab_manager_reply]))
    observation = OracleLabManagerObservation(
        lab_constraints=scenario.lab_constraints,
        current_protocol=None,
        scientist_action=action,
        round_number=1,
    )
    response = lab_manager.respond(observation)

    score = oracle.score_round(
        scenario=scenario,
        round_number=1,
        scientist_action=action,
        lab_manager_response=response,
        conversation_history=[],
        current_protocol=None,
        previous_scores=[],
    )

    assert isinstance(score, AdjudicatorRoundScore)
    assert score.step_reward == 0.55
def test_oracle_maybe_inject_event_returns_optional_event() -> None:
    """With ``inject`` true in the reply, the oracle returns an ``EnvironmentEvent``."""
    injection_reply = {
        "inject": True,
        "event": {
            "event_type": "budget_cut",
            "description": "Finance reduced the remaining budget.",
            "state_changes": {"lab_constraints.budget_remaining": 1800.0},
            "severity": "moderate",
        },
    }
    oracle = Oracle(_FakeClient([injection_reply]))

    event = oracle.maybe_inject_event(
        scenario=Scenario.model_validate(_scenario_payload()),
        round_number=3,
        current_protocol=None,
        conversation_history=[],
        inject_enabled=True,
    )

    assert isinstance(event, EnvironmentEvent)
    assert event.event_type == "budget_cut"
def test_oracle_generate_post_mortem_parses_json() -> None:
    """``Oracle.generate_post_mortem`` returns a ``PostMortem`` from the reply."""
    oracle = Oracle(_FakeClient([_post_mortem_payload()]))

    terminal_score = AdjudicatorTerminalScore(
        rigor=0.9,
        feasibility=0.8,
        fidelity=0.85,
        parsimony=0.9,
        robustness=0.8,
        power_preservation=0.8,
        efficiency_bonus=0.2,
        communication_bonus=0.1,
        penalties={},
        terminal_reward=5.0,
        total_reward=5.6,
    )

    post_mortem = oracle.generate_post_mortem(
        scenario=Scenario.model_validate(_scenario_payload()),
        final_protocol={"technique": "compact_cnn"},
        conversation_history=[],
        terminal_score=terminal_score,
    )

    assert isinstance(post_mortem, PostMortem)
    assert "feasible compact CNN plan" in post_mortem.overall_summary
def test_env_can_reset_from_oracle_scenario_without_changing_outer_contract() -> None:
    """``ReplicaLabEnv.reset`` exposes an oracle-supplied scenario in its observation.

    The environment is given a hand-written oracle double, so no fake client
    is involved. The double checks the arguments it receives from ``reset``.
    """

    class _FakeOracle:
        """Oracle double returning fixed objects built from this module's payloads."""

        def __init__(self) -> None:
            self.scenario = Scenario.model_validate(_scenario_payload())

        def generate_scenario(self, seed: int, difficulty: str, domain: str) -> Scenario:
            # Checks the values given to ``env.reset`` below; its ``scenario``
            # argument is expected to arrive here as ``domain``.
            assert seed == 11
            assert difficulty == "medium"
            assert domain == "ml_benchmark"
            return self.scenario

        def score_round(self, **_: object):
            return AdjudicatorRoundScore.model_validate(_round_score_payload())

        def maybe_inject_event(self, **_: object):
            return None

        def generate_post_mortem(self, **_: object):
            return PostMortem.model_validate(_post_mortem_payload())

    env = ReplicaLabEnv(
        oracle=_FakeOracle(),
        enable_oracle_post_mortem=True,
    )

    observation = env.reset(seed=11, scenario="ml_benchmark", difficulty="medium")

    scientist_view = observation.scientist
    assert scientist_view is not None
    assert scientist_view.paper_title == "Reproducing a Small Vision Benchmark"

    lab_manager_view = observation.lab_manager
    assert lab_manager_view is not None
    assert "Validation server" in lab_manager_view.equipment_available