# WhyDidItFail / server / WhyDidItFail_environment.py
# (commit d3b224f by samrat-rm — "fix: clamp all rewards and scores to [0.10, 0.90]")
import random
from typing import Any, Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from models import WhyDidItFailAction, WhyDidItFailObservation, WhyDidItFailState
from server.scenarios import SCENARIOS
from server.graders import grade
class WhyDidItFailEnvironment(Environment):
    """Failure-diagnosis environment: an agent investigates artifacts of a
    failed training run (logs, config, gradient norms) and submits a root-cause
    diagnosis.

    Reward design: every step reward and final score is clamped to
    [0.10, 0.90] — there are no negative rewards. Required evidence sources
    pay a decaying bonus on *first* inspection only; redundant or irrelevant
    inspections pay the 0.10 floor. Final grading is delegated to the unified
    ``grade()`` function.
    """

    # Multiple episodes of this environment may run concurrently on one server,
    # so nothing here may mutate process-global state (see reset()).
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Rewards decay as more required sources are discovered — first clue is
    # worth most. All values are in [0.10, 0.90]; no negative rewards.
    _REQUIRED_STEP_REWARDS = [0.50, 0.30, 0.15]

    # action_type -> (source name, key under visible_data, key in the scenario
    # dict). Centralizes the three previously duplicated inspect branches.
    _INSPECT_ACTIONS: dict[str, tuple[str, str, str]] = {
        "inspect_logs": ("logs", "training_logs", "logs"),
        "inspect_config": ("config", "config", "config"),
        "inspect_gradients": ("gradients", "gradient_norms", "gradient_norms"),
    }

    # Action menu advertised to the agent while the episode is live.
    _ALL_ACTIONS: list[str] = ["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"]

    def __init__(self):
        # Core episode bookkeeping; replaced wholesale on every reset().
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.scenario: dict | None = None
        self.inspection_order: list[str] = []  # first-visit order; doubles as membership check
        self.max_steps: int = 0

    @property
    def state(self) -> WhyDidItFailState:
        """Snapshot of the episode for the server layer (defensive copies of lists)."""
        return WhyDidItFailState(
            episode_id=self._state.episode_id,
            step_count=self._state.step_count,
            scenario_key=self.scenario.get("failure_mode") if self.scenario else None,
            difficulty=self.scenario.get("difficulty") if self.scenario else None,
            inspection_order=list(self.inspection_order),
            required_sources=list(self.scenario.get("required_sources", [])) if self.scenario else [],
            max_steps=self.max_steps,
        )

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any) -> WhyDidItFailObservation:
        """Start a new episode.

        A scenario is chosen by explicit ``scenario_key`` kwarg when valid,
        otherwise at random (reproducibly, when ``seed`` is given). The step
        budget scales with the number of required evidence sources.
        """
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self.inspection_order = []
        scenario_key = kwargs.get("scenario_key")
        if scenario_key and scenario_key in SCENARIOS:
            self.scenario = SCENARIOS[scenario_key]
        else:
            # Use a local RNG instance rather than random.seed(): seeding the
            # module-global RNG would leak across concurrent sessions. A local
            # Random(seed) yields the same stream, so seeded selection is
            # unchanged for existing callers.
            rng = random.Random(seed) if seed is not None else random
            self.scenario = rng.choice(list(SCENARIOS.values()))
        required_sources = self.scenario.get("required_sources", ["logs"])
        # Budget: ~3 steps per required source plus slack for the submission.
        self.max_steps = len(required_sources) * 3 + 2
        return WhyDidItFailObservation(
            task_description=(
                "A training run has failed. Diagnose the root cause.\n"
                f"Difficulty: {self.scenario['difficulty']}. "
                "Available actions: inspect_logs, inspect_config, inspect_gradients, submit_diagnosis."
            ),
            visible_data={"hint": "Start by inspecting the training logs."},
            available_actions=list(self._ALL_ACTIONS),
            steps_taken=0,
            reward=0.10,
            done=False,
            feedback="Investigation started.",
        )

    def step(self, action: WhyDidItFailAction, timeout_s: Optional[float] = None, **kwargs: Any) -> WhyDidItFailObservation:
        """Advance the episode by one action.

        Raises:
            RuntimeError: if called before reset().
        """
        if self.scenario is None:
            raise RuntimeError("Environment must be reset before calling step.")
        self._state.step_count += 1

        # Hard step limit — terminate immediately; an over-budget episode
        # without a diagnosis scores the 0.10 floor.
        if self._state.step_count > self.max_steps and action.action_type != "submit_diagnosis":
            return WhyDidItFailObservation(
                task_description="Step limit reached. Episode terminated.",
                visible_data={},
                available_actions=[],
                steps_taken=self._state.step_count,
                reward=0.10,
                done=True,
                feedback=(
                    f"Step limit ({self.max_steps}) reached without a diagnosis. "
                    f"Score: 0.10. Actual failure: '{self.scenario['correct_diagnosis']}'."
                ),
            )

        required: list[str] = self.scenario.get("required_sources", ["logs"])

        if action.action_type in self._INSPECT_ACTIONS:
            source, visible_key, scenario_key = self._INSPECT_ACTIONS[action.action_type]
            step_reward = self._inspect_reward(source, required)
            # BUGFIX: feedback must be computed BEFORE recording the visit.
            # Previously the source was appended first and the feedback kwarg
            # was evaluated inside the return expression, so every *first*
            # inspection was misreported as "You already examined ...".
            feedback = self._inspect_feedback(source, required, step_reward)
            if source not in self.inspection_order:
                self.inspection_order.append(source)
            return WhyDidItFailObservation(
                task_description="Continue your investigation.",
                visible_data={visible_key: self.scenario[scenario_key]},
                available_actions=list(self._ALL_ACTIONS),
                steps_taken=self._state.step_count,
                reward=step_reward,
                done=False,
                feedback=feedback,
            )

        if action.action_type == "submit_diagnosis":
            final_reward, feedback = self._grade(action)
            return WhyDidItFailObservation(
                task_description="Diagnosis submitted.",
                visible_data={},
                available_actions=[],
                steps_taken=self._state.step_count,
                reward=final_reward,
                done=True,
                feedback=feedback,
            )

        # Unrecognized action: floor reward, episode continues.
        return WhyDidItFailObservation(
            task_description="Continue your investigation.",
            visible_data={},
            available_actions=list(self._ALL_ACTIONS),
            steps_taken=self._state.step_count,
            reward=0.10,
            done=False,
            feedback=f"Unknown action '{action.action_type}'. Minimum reward.",
        )

    def _inspect_reward(self, source: str, required: list[str]) -> float:
        """Return step reward for an inspect action.

        Required sources: progressive — 0.50 / 0.30 / 0.15 for 1st/2nd/3rd discovery.
        Irrelevant sources: 0.10 (minimum; mild penalty via contrast with required rewards).
        Re-inspection: 0.10 (minimum; waste with no new information).
        All values are strictly in [0.10, 0.90].
        """
        if source in self.inspection_order:
            return 0.10  # redundant inspection — minimum reward
        if source in required:
            n_found = sum(1 for s in self.inspection_order if s in required)
            idx = min(n_found, len(self._REQUIRED_STEP_REWARDS) - 1)
            return self._REQUIRED_STEP_REWARDS[idx]
        return 0.10  # irrelevant source — minimum reward

    def _inspect_feedback(self, source: str, required: list[str], reward: float) -> str:
        """Human-readable feedback for an inspect action.

        Must be called BEFORE ``source`` is appended to ``inspection_order``,
        so that first visits and re-visits are distinguished correctly.
        """
        label = {"logs": "training logs", "config": "hyperparameter config", "gradients": "gradient statistics"}[source]
        if source in self.inspection_order:
            return f"You already examined the {label}. No new information gained."
        if source in required:
            remaining_sources = [s for s in required if s not in self.inspection_order and s != source]
            msg = f"You examined the {label}. Relevant clue found (+{reward:.2f})."
            if remaining_sources:
                next_source = f"inspect_{remaining_sources[0]}"
                msg += f" {len(remaining_sources)} required source(s) still unexamined. Next required action: {next_source}."
            return msg
        return f"You examined the {label}. This source is not required for this failure mode."

    def _grade(self, action: WhyDidItFailAction) -> tuple[float, str]:
        """Delegate to the unified grade() function and return (reward, feedback)."""
        assert self.scenario is not None
        diagnosis = (action.diagnosis or "").strip().lower()
        suggested_fix = (action.suggested_fix or "").strip().lower() or None
        difficulty = self.scenario["difficulty"]
        reward = grade(
            diagnosis=diagnosis,
            suggested_fix=suggested_fix,
            scenario=self.scenario,
            steps_taken=self._state.step_count,
            inspection_order=self.inspection_order,
            difficulty=difficulty,
        )
        # Feedback tiers mirror the clamped score bands.
        if reward >= 0.80:
            feedback = f"Excellent diagnosis! Score: {reward:.2f}"
        elif reward >= 0.50:
            feedback = f"Partially correct. Score: {reward:.2f}. Actual failure: '{self.scenario['correct_diagnosis']}'."
        else:
            feedback = f"Incorrect diagnosis. Score: {reward:.2f}. Actual failure: '{self.scenario['correct_diagnosis']}'."
        return reward, feedback