# hackathon/tests/test_rewards.py
# Ev3Dev's picture
# Upload folder using huggingface_hub
# db03c40 verified
"""Tests for the decomposable reward function."""
from models import ActionType, ConclusionClaim, ExperimentAction, IntermediateOutput, OutputType
from server.rewards.reward import RewardComputer
from server.simulator.latent_state import (
ExperimentProgress,
FullLatentState,
LatentBiologicalState,
ResourceState,
)
def _states(
    prev_flags: dict | None = None,
    next_flags: dict | None = None,
    budget_used: float = 0.0,
) -> "tuple[FullLatentState, FullLatentState]":
    """Build a (previous, next) pair of latent states around one step.

    The previous state uses ``prev_flags`` as its progress flags. The next
    state uses the same flags overlaid with ``next_flags``. Both states have
    a total budget of 100_000; the next state has spent 5000 more than
    ``budget_used``.

    Args:
        prev_flags: Keyword arguments for the previous ``ExperimentProgress``.
        next_flags: Extra or overriding flags applied to the next state.
        budget_used: Budget already spent in the previous state.

    Returns:
        A ``(previous_state, next_state)`` tuple.
    """
    before_flags = prev_flags or {}
    # Later entries win, so next_flags overrides anything set before the step.
    after_flags = {**before_flags, **(next_flags or {})}

    before = FullLatentState(
        progress=ExperimentProgress(**before_flags),
        resources=ResourceState(budget_total=100_000, budget_used=budget_used),
    )
    after = FullLatentState(
        progress=ExperimentProgress(**after_flags),
        resources=ResourceState(
            budget_total=100_000,
            budget_used=budget_used + 5000,
        ),
    )
    return before, after
class TestStepReward:
    """Checks on the breakdown returned by ``RewardComputer.step_reward``."""

    def test_valid_step_positive(self):
        """A successful sequencing step yields a positive total."""
        computer = RewardComputer()
        before, after = _states(
            prev_flags={"samples_collected": True, "library_prepared": True},
            next_flags={"cells_sequenced": True},
        )
        sequencing_result = IntermediateOutput(
            output_type=OutputType.SEQUENCING_RESULT,
            step_index=1,
            quality_score=0.85,
            uncertainty=0.15,
        )
        action = ExperimentAction(action_type=ActionType.SEQUENCE_CELLS)

        breakdown = computer.step_reward(
            action, before, after, sequencing_result, [], [],
        )

        assert breakdown.total > 0

    def test_hard_violation_negative(self):
        """A failed step reported together with ``["blocked"]`` totals below zero."""
        computer = RewardComputer()
        before, after = _states()
        failure = IntermediateOutput(
            output_type=OutputType.FAILURE_REPORT,
            step_index=1,
            success=False,
        )
        action = ExperimentAction(action_type=ActionType.SEQUENCE_CELLS)

        # NOTE(review): the fifth argument is presumably the hard-violation
        # list (going by the test name) -- confirm against step_reward.
        breakdown = computer.step_reward(
            action, before, after, failure, ["blocked"], [],
        )

        assert breakdown.total < 0

    def test_premature_meta_action_gets_penalized(self):
        """Designing a follow-up right after normalization incurs a penalty."""
        computer = RewardComputer()
        before, after = _states(
            prev_flags={"data_normalized": True},
            next_flags={"followup_designed": True},
            budget_used=2_000,
        )
        followup = IntermediateOutput(
            output_type=OutputType.FOLLOWUP_DESIGN,
            step_index=2,
            quality_score=1.0,
            uncertainty=0.0,
        )
        action = ExperimentAction(action_type=ActionType.DESIGN_FOLLOWUP)

        breakdown = computer.step_reward(action, before, after, followup, [], [])

        # A missing component defaults to 0.0, which fails the strict check.
        penalty = breakdown.components.get("premature_meta_action_penalty", 0.0)
        assert penalty < 0.0
class TestTerminalReward:
    """Checks on the breakdown returned by ``RewardComputer.terminal_reward``."""

    def test_correct_conclusion_rewarded(self):
        """A claim matching the latent mechanism gives a positive terminal value."""
        computer = RewardComputer()
        biology = LatentBiologicalState(
            causal_mechanisms=["TGF-beta-driven fibrosis"],
            true_markers=["NPPA"],
        )
        progress = ExperimentProgress(
            samples_collected=True,
            cells_sequenced=True,
            qc_performed=True,
            data_filtered=True,
            data_normalized=True,
            de_performed=True,
            conclusion_reached=True,
        )
        resources = ResourceState(budget_total=100_000, budget_used=40_000)
        latent = FullLatentState(
            biology=biology,
            progress=progress,
            resources=resources,
        )
        matching_claim = ConclusionClaim(
            claim="TGF-beta-driven fibrosis observed",
            confidence=0.9,
            claim_type="causal",
        )

        breakdown = computer.terminal_reward(
            latent,
            [matching_claim],
            [],
            discovered_markers=["NPPA"],
            candidate_mechanisms=["TGF-beta-driven fibrosis"],
        )

        assert breakdown.terminal > 0

    def test_overconfident_wrong_claim_penalised(self):
        """A high-confidence claim naming the wrong mechanism is penalised."""
        computer = RewardComputer()
        latent = FullLatentState(
            biology=LatentBiologicalState(causal_mechanisms=["real_mechanism"]),
            progress=ExperimentProgress(conclusion_reached=True),
        )
        wrong_claim = ConclusionClaim(
            claim="completely_wrong_mechanism",
            confidence=0.95,
            claim_type="causal",
        )

        breakdown = computer.terminal_reward(latent, [wrong_claim], [])

        # A missing component defaults to 0, which fails the strict check.
        penalty = breakdown.components.get("overconfidence_penalty", 0)
        assert penalty < 0

    def test_discovery_error_penalizes_wrong_markers_and_mechanisms(self):
        """Matching markers and mechanisms outscore unrelated ones."""
        computer = RewardComputer()
        biology = LatentBiologicalState(
            true_markers=["NPPA", "NPPB"],
            causal_mechanisms=["TGF-beta-driven fibrosis"],
        )
        progress = ExperimentProgress(
            samples_collected=True,
            cells_sequenced=True,
            qc_performed=True,
            data_filtered=True,
            data_normalized=True,
            de_performed=True,
            markers_discovered=True,
            conclusion_reached=True,
        )
        resources = ResourceState(budget_total=100_000, budget_used=40_000)
        latent = FullLatentState(
            biology=biology,
            progress=progress,
            resources=resources,
        )

        # Both runs share the state and pass no claims; only the discovered
        # markers and candidate mechanisms differ.
        aligned = computer.terminal_reward(
            latent,
            [],
            [],
            discovered_markers=["NPPA", "NPPB"],
            candidate_mechanisms=["TGF-beta-driven fibrosis"],
        )
        misaligned = computer.terminal_reward(
            latent,
            [],
            [],
            discovered_markers=["WRONG1", "WRONG2"],
            candidate_mechanisms=["unrelated inflammatory process"],
        )

        good, bad = aligned.components, misaligned.components
        assert good["discovery_alignment"] > bad["discovery_alignment"]
        # NOTE(review): the penalty is presumably non-positive, so "greater"
        # means "penalised less" -- confirm against RewardComputer.
        assert good["discovery_error_penalty"] > bad["discovery_error_penalty"]
        assert aligned.terminal > misaligned.terminal

    def test_conclusion_error_penalizes_wrong_structured_claims(self):
        """Structured claims matching the latent biology outscore wrong ones."""
        computer = RewardComputer()
        biology = LatentBiologicalState(
            true_markers=["NPPA", "NPPB"],
            causal_mechanisms=["TGF-beta-driven fibrosis"],
        )
        progress = ExperimentProgress(
            data_normalized=True,
            de_performed=True,
            markers_discovered=True,
            pathways_analyzed=True,
            conclusion_reached=True,
        )
        resources = ResourceState(budget_total=100_000, budget_used=40_000)
        latent = FullLatentState(
            biology=biology,
            progress=progress,
            resources=resources,
        )

        # Same confidence in both claims, so only their content differs.
        correct_claims = [
            ConclusionClaim(
                top_markers=["NPPA", "NPPB"],
                causal_mechanisms=["TGF-beta-driven fibrosis"],
                confidence=0.8,
            ),
        ]
        aligned = computer.terminal_reward(latent, correct_claims, [])

        incorrect_claims = [
            ConclusionClaim(
                top_markers=["WRONG1"],
                causal_mechanisms=["unrelated process"],
                confidence=0.8,
            ),
        ]
        misaligned = computer.terminal_reward(latent, incorrect_claims, [])

        good, bad = aligned.components, misaligned.components
        assert good["conclusion_alignment"] > bad["conclusion_alignment"]
        assert good["conclusion_error_penalty"] > bad["conclusion_error_penalty"]
        assert aligned.terminal > misaligned.terminal