# cernenv-trainer / server / environment.py
# anugrah55's picture
# Update CERNenv Space
# 5f78183 verified
"""``CERNCollisionEnvironment``: orchestrates simulator + rules + rewards.

This is the OpenEnv-compatible ``Environment`` that the FastAPI app exposes.
It owns one episode at a time:

    reset(seed)  → builds a fresh latent state from a sampled scenario.
    step(action) → validates → generates noisy output → updates state →
                   computes reward → builds the agent observation.
"""
from __future__ import annotations
import logging
import uuid
from typing import Any, List, Optional
from openenv.core.env_server import Environment, State
from models import (
AGENT_ENVIRONMENT_RULES,
ActionType,
CollisionObservation,
DiscoveryClaim,
ExperimentAction,
IntermediateOutput,
OutputType,
PipelineStepRecord,
ResourceUsage,
TaskSpec,
build_agent_system_prompt,
)
from server.rewards import (
RewardWeights,
compute_step_reward,
compute_terminal_reward,
)
from server.rules import RulesEngine, ViolationCode
from server.simulator import (
NoiseModel,
OutputGenerator,
TransitionEngine,
compute_action_cost,
)
from server.simulator.latent_state import FullLatentState
from server.tasks import sample_scenario, Scenario
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# ── State container ──────────────────────────────────────────────────────
class CernState(State):
    """OpenEnv State subclass: includes hidden truth & runtime stats.

    ``CERNCollisionEnvironment.reset`` builds a fresh instance for every
    episode and ``CERNCollisionEnvironment.step`` updates it in place.
    """

    # Name and difficulty of the scenario sampled for the current episode.
    scenario_name: Optional[str] = None
    difficulty: Optional[str] = None

    # Set to True by ``step`` once the episode has ended (accepted discovery
    # claim, step cap reached, or budget/time exhausted).
    episode_done: bool = False

    # Running sum of per-step rewards, plus the terminal reward when a
    # discovery claim is scored.
    cumulative_reward: float = 0.0

    # Outcome of scoring an accepted discovery claim; all remain None until
    # ``step`` scores one.
    terminal_reward: Optional[float] = None
    discovered: Optional[bool] = None
    correct_mass: Optional[bool] = None
    correct_channel: Optional[bool] = None
    correct_spin: Optional[bool] = None

    # Hidden ground truth, copied from the latent particle during ``reset``.
    truth_mass_gev: Optional[float] = None
    truth_channel: Optional[str] = None
# ── Environment ──────────────────────────────────────────────────────────
class CERNCollisionEnvironment(Environment[ExperimentAction, CollisionObservation, CernState]):
    """LHC particle-discovery POMDP environment.

    The instance holds exactly one episode. ``reset`` samples a scenario and
    rebuilds the simulator components (noise model, output generator,
    transition engine, rules engine); ``step`` validates an action, produces
    an output, advances the latent state, scores the step and returns the
    agent-facing observation.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        *,
        max_steps: int = 40,
        default_difficulty: Optional[str] = None,
        default_scenario_name: Optional[str] = None,
        reward_weights: Optional[RewardWeights] = None,
    ) -> None:
        """Create an environment with no active episode.

        Args:
            max_steps: Step cap; the episode ends once this many steps have
                been taken.
            default_difficulty: Difficulty used by ``reset`` when the caller
                does not supply a ``difficulty`` keyword.
            default_scenario_name: Scenario name used by ``reset`` when the
                caller does not supply a ``scenario`` keyword.
            reward_weights: Weights handed to the reward functions; a default
                ``RewardWeights()`` is used when omitted.
        """
        super().__init__()
        self.max_steps = max_steps
        self.default_difficulty = default_difficulty
        self.default_scenario_name = default_scenario_name
        self.reward_weights = reward_weights or RewardWeights()

        # Per-episode members; everything below is (re)built by ``reset``.
        self._state = CernState()
        self._scenario: Optional[Scenario] = None
        self._latent: Optional[FullLatentState] = None
        self._task: Optional[TaskSpec] = None
        self._noise: Optional[NoiseModel] = None
        self._output_gen: Optional[OutputGenerator] = None
        self._transition: Optional[TransitionEngine] = None
        self._rules: Optional[RulesEngine] = None
        self._history: List[PipelineStepRecord] = []
        self._all_outputs: List[IntermediateOutput] = []

    # ── Environment API ────────────────────────────────────────────────

    @property
    def state(self) -> CernState:
        """Return the current episode state container."""
        return self._state

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CollisionObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Forwarded to ``sample_scenario``; when given it also
                overrides the latent state's ``rng_seed``, which seeds the
                noise model.
            episode_id: Identifier for the episode; a random ``ep-<8 hex>``
                id is generated when omitted.
            **kwargs: ``difficulty`` and ``scenario`` select the scenario;
                missing or falsy values fall back to the constructor defaults.

        Returns:
            The initial observation (``done=False``, ``reward=0.0``, no
            latest output).
        """
        difficulty = kwargs.get("difficulty") or self.default_difficulty
        scenario_name = kwargs.get("scenario") or self.default_scenario_name
        scenario = sample_scenario(
            difficulty=difficulty,
            name=scenario_name,
            seed=seed,
        )

        self._scenario = scenario
        self._latent = scenario.fresh_latent()
        self._task = scenario.task
        if seed is not None:
            self._latent.rng_seed = int(seed)

        self._noise = NoiseModel(seed=self._latent.rng_seed)
        self._output_gen = OutputGenerator(self._noise)
        self._transition = TransitionEngine()
        self._rules = RulesEngine(
            mass_search_window_gev=tuple(self._task.mass_search_window_gev),
        )
        self._history = []
        self._all_outputs = []

        self._state = CernState(
            episode_id=episode_id or f"ep-{uuid.uuid4().hex[:8]}",
            step_count=0,
            scenario_name=scenario.name,
            difficulty=scenario.difficulty,
            episode_done=False,
            cumulative_reward=0.0,
            truth_mass_gev=self._latent.particle.mass_gev,
            truth_channel=self._latent.particle.primary_channel,
        )

        return self._build_observation(
            latest_output=None,
            done=False,
            reward=0.0,
            step_breakdown={},
            rule_violations=[],
        )

    def step(
        self,
        action: ExperimentAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CollisionObservation:
        """Apply one agent action and return the resulting observation.

        If no episode is active, ``reset()`` is called first. If the episode
        is already finished, a terminal failure observation is returned and
        nothing else changes.

        Args:
            action: The action to validate and execute.
            timeout_s: Accepted but not used by this implementation.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            The observation after the step. Its ``reward`` is the step reward
            plus the terminal reward when an accepted discovery claim ends
            the episode.
        """
        if self._latent is None:
            self.reset()
        if self._state.episode_done:
            return self._build_terminal_observation(reason="episode already complete")

        # Narrowing for type checkers; ``reset`` populates all of these.
        assert self._latent is not None
        assert self._rules is not None
        assert self._output_gen is not None
        assert self._transition is not None

        # Snapshot taken before any mutation so the reward can compare
        # before/after states.
        prev_state = self._latent.model_copy(deep=True)
        rule_result = self._rules.validate(action, self._latent)

        if rule_result.allowed:
            output = self._output_gen.generate(
                action=action,
                state=self._latent,
                step_index=self._state.step_count,
            )
            # Apply transition (state mutation + cost accounting).
            self._transition.step(self._latent, action, output)
            cost = compute_action_cost(action, output)
        else:
            output = self._rejected_output(rule_result)
            # The transition engine is bypassed for rejected actions, so the
            # cost is charged and the latent step counter advanced by hand.
            # NOTE(review): only budget and time are charged here, while the
            # history record below still reports a luminosity cost; confirm
            # this asymmetry is intended.
            cost = compute_action_cost(action, output)
            self._latent.resources.budget_used_musd += cost["musd"]
            self._latent.resources.time_used_days += cost["days"]
            self._latent.step_count += 1

        self._all_outputs.append(output)
        self._history.append(
            PipelineStepRecord(
                step_index=self._state.step_count,
                action_type=action.action_type,
                method=action.method,
                parameters=action.parameters,
                output_summary=output.summary,
                output_type=output.output_type,
                success=output.success,
                quality_score=float(output.quality_score),
                cost_musd=float(cost["musd"]),
                luminosity_cost_fb=float(cost["luminosity_fb"]),
                time_cost_days=float(cost["days"]),
            )
        )

        step_reward = compute_step_reward(
            action=action,
            output=output,
            state_before=prev_state,
            state_after=self._latent,
            rule_result=rule_result,
            weights=self.reward_weights,
        )
        self._state.cumulative_reward += step_reward.reward
        self._state.step_count += 1

        # An accepted discovery claim ends the episode and is scored; hitting
        # the step cap or exhausting budget/time ends it without that score.
        terminal_now = (
            action.action_type == ActionType.SUBMIT_DISCOVERY_CLAIM
            and rule_result.allowed
        )
        time_up = (
            self._state.step_count >= self.max_steps
            or self._latent.resources.budget_exhausted
            or self._latent.resources.time_exhausted
        )

        terminal_reward_value = 0.0
        if terminal_now:
            terminal_reward_value = self._apply_terminal_reward(action)

        done = terminal_now or time_up
        if done:
            self._state.episode_done = True

        return self._build_observation(
            latest_output=output,
            done=done,
            reward=step_reward.reward + terminal_reward_value,
            step_breakdown=step_reward.breakdown.components,
            rule_violations=[
                *(v.value for v in rule_result.violations),
                *(v.value for v in rule_result.soft_violations),
            ],
        )

    # ── Helpers ────────────────────────────────────────────────────────

    def _rejected_output(self, rule_result: Any) -> IntermediateOutput:
        """Build the failure report returned for an action the rules rejected."""
        return IntermediateOutput(
            output_type=OutputType.FAILURE_REPORT,
            step_index=self._state.step_count,
            success=False,
            quality_score=0.0,
            summary="Action rejected: " + "; ".join(rule_result.messages),
            warnings=rule_result.messages,
        )

    def _apply_terminal_reward(self, action: ExperimentAction) -> float:
        """Score the submitted claim, record the verdict on state, return the reward."""
        term = compute_terminal_reward(
            state=self._latent,
            claim=self._claim_from_action(action),
            weights=self.reward_weights,
        )
        self._state.cumulative_reward += term.reward
        self._state.terminal_reward = term.reward
        self._state.discovered = term.discovered
        self._state.correct_mass = term.correct_mass
        self._state.correct_channel = term.correct_channel
        self._state.correct_spin = term.correct_spin
        return term.reward

    def _claim_from_action(self, action: ExperimentAction) -> DiscoveryClaim:
        """Extract the discovery claim carried in ``action.parameters["claim"]``.

        A missing or falsy payload yields a default ``DiscoveryClaim()``. A
        payload that cannot be turned into a claim is logged and also replaced
        by the default rather than aborting the step.
        """
        raw = action.parameters.get("claim") or {}
        # In-process callers may pass an already-built claim; ``**raw`` would
        # raise on it and the claim would be silently replaced by a default.
        if isinstance(raw, DiscoveryClaim):
            return raw
        try:
            return DiscoveryClaim(**raw)
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("Malformed claim, defaulting: %s", exc)
            return DiscoveryClaim()

    def _build_terminal_observation(self, reason: str) -> CollisionObservation:
        """Return a ``done`` observation carrying a failure report with ``reason``."""
        failure = IntermediateOutput(
            output_type=OutputType.FAILURE_REPORT,
            step_index=self._state.step_count,
            success=False,
            summary=reason,
        )
        return self._build_observation(
            latest_output=failure,
            done=True,
            reward=0.0,
            step_breakdown={},
            rule_violations=["episode_terminated"],
        )

    def _build_observation(
        self,
        *,
        latest_output: Optional[IntermediateOutput],
        done: bool,
        reward: float,
        step_breakdown: dict,
        rule_violations: list,
    ) -> CollisionObservation:
        """Assemble the agent-facing observation from the current episode.

        Args:
            latest_output: Output of the most recent step, or None right
                after ``reset``.
            done: Whether the episode is over.
            reward: Reward to report for this observation.
            step_breakdown: Per-component step reward breakdown (copied).
            rule_violations: Violation codes to report.

        Returns:
            A ``CollisionObservation`` holding copies of the history, outputs
            and candidate lists, plus resource usage and detector
            uncertainties read from the latent state.
        """
        assert self._latent is not None
        assert self._task is not None

        latent = self._latent
        task = self._task
        res = latent.resources
        detector = latent.detector

        usage = ResourceUsage(
            budget_used_musd=res.budget_used_musd,
            budget_remaining_musd=res.budget_remaining,
            luminosity_used_fb=res.luminosity_used_fb,
            luminosity_remaining_fb=res.luminosity_remaining,
            time_used_days=res.time_used_days,
            time_remaining_days=res.time_remaining,
            compute_hours_used=res.compute_hours_used,
        )

        return CollisionObservation(
            done=done,
            reward=float(reward),
            task=task,
            step_index=self._state.step_count,
            pipeline_history=list(self._history),
            available_channels=task.available_channels,
            available_triggers=task.available_triggers,
            available_tools=task.available_tools,
            resource_usage=usage,
            latest_output=latest_output,
            all_outputs=list(self._all_outputs),
            candidate_masses_gev=list(latent.candidate_masses_gev),
            candidate_significances=list(latent.candidate_significances),
            selected_channel=latent.selected_channel,
            selected_beam_energy=latent.selected_beam_energy,
            # A missing (falsy) best significance is reported as 0.0.
            cumulative_significance=float(
                latent.progress.best_significance_sigma or 0.0
            ),
            uncertainty_summary={
                "energy_scale_unc_gev": detector.energy_scale_uncertainty,
                "luminosity_unc": detector.luminosity_uncertainty,
                "resolution_gev": detector.detector_resolution_gev,
            },
            rule_violations=rule_violations,
            step_reward_breakdown=dict(step_breakdown),
        )

    # ── Convenience for diagnostics ────────────────────────────────────

    def hidden_truth(self) -> Optional[dict]:
        """Reveal the hidden particle (debug / evaluation only).

        Returns:
            The latent particle dumped to a dict, or None when no episode has
            been started.
        """
        if self._latent is None:
            return None
        return self._latent.particle.model_dump()
# Public API of this module. The last two names are imported from ``models``
# above and re-exported here.
__all__ = [
    "CernState",
    "CERNCollisionEnvironment",
    "AGENT_ENVIRONMENT_RULES",
    "build_agent_system_prompt",
]