samrat-rm committed on
Commit
a818334
·
1 Parent(s): 66d62a2

feat(grade): inspected is upgraded to inspection_order. It rewards inspecting sources in order

Browse files
server/WhyDidItFail_environment.py CHANGED
@@ -26,7 +26,7 @@ class WhyDidItFailEnvironment(Environment):
26
  def __init__(self):
27
  self._state = State(episode_id=str(uuid4()), step_count=0)
28
  self.scenario: dict | None = None
29
- self.inspected: set[str] = set()
30
 
31
  @property
32
  def state(self) -> State:
@@ -34,7 +34,7 @@ class WhyDidItFailEnvironment(Environment):
34
 
35
  def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any) -> WhyDidItFailObservation:
36
  self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
37
- self.inspected = set()
38
 
39
  scenario_key = kwargs.get("scenario_key")
40
  if scenario_key and scenario_key in SCENARIOS:
@@ -66,7 +66,8 @@ class WhyDidItFailEnvironment(Environment):
66
 
67
  if action.action_type == "inspect_logs":
68
  step_reward = self._inspect_reward("logs", required)
69
- self.inspected.add("logs")
 
70
  return WhyDidItFailObservation(
71
  task_description="Continue your investigation.",
72
  visible_data={"training_logs": self.scenario["logs"]},
@@ -79,7 +80,8 @@ class WhyDidItFailEnvironment(Environment):
79
 
80
  elif action.action_type == "inspect_config":
81
  step_reward = self._inspect_reward("config", required)
82
- self.inspected.add("config")
 
83
  return WhyDidItFailObservation(
84
  task_description="Continue your investigation.",
85
  visible_data={"config": self.scenario["config"]},
@@ -92,7 +94,8 @@ class WhyDidItFailEnvironment(Environment):
92
 
93
  elif action.action_type == "inspect_gradients":
94
  step_reward = self._inspect_reward("gradients", required)
95
- self.inspected.add("gradients")
 
96
  return WhyDidItFailObservation(
97
  task_description="Continue your investigation.",
98
  visible_data={"gradient_norms": self.scenario["gradient_norms"]},
@@ -128,35 +131,51 @@ class WhyDidItFailEnvironment(Environment):
128
 
129
  # ── helpers ──────────────────────────────────────────────────────────────
130
 
 
 
 
131
  def _inspect_reward(self, source: str, required: list[str]) -> float:
132
- """Return step reward for an inspect action."""
133
- if source in self.inspected:
 
 
 
 
 
134
  return -0.05 # redundant inspection
 
135
  if source in required:
136
- return +0.05 # useful evidence
137
- return -0.05 # irrelevant source
 
 
 
138
 
139
  def _inspect_feedback(self, source: str, required: list[str], reward: float) -> str:
140
  label = {"logs": "training logs", "config": "hyperparameter config", "gradients": "gradient statistics"}[source]
141
- if source in self.inspected:
142
  return f"You already examined the {label}. No new information gained."
143
- if reward > 0:
144
- return f"You examined the {label}. This looks relevant."
145
- return f"You examined the {label}. This may not be relevant to the failure."
 
 
 
 
146
 
147
  def _grade(self, action: WhyDidItFailAction) -> tuple[float, str]:
148
  """Delegate to the unified grade() function and return (reward, feedback)."""
149
  assert self.scenario is not None
150
- diagnosis = (action.diagnosis or "").strip().lower()
151
  suggested_fix = (action.suggested_fix or "").strip().lower() or None
152
- difficulty = self.scenario["difficulty"]
153
 
154
  reward = grade(
155
  diagnosis=diagnosis,
156
  suggested_fix=suggested_fix,
157
  scenario=self.scenario,
158
  steps_taken=self._state.step_count,
159
- inspected=self.inspected,
160
  difficulty=difficulty,
161
  )
162
 
 
26
  def __init__(self):
27
  self._state = State(episode_id=str(uuid4()), step_count=0)
28
  self.scenario: dict | None = None
29
+ self.inspection_order: list[str] = [] # first-visit order; doubles as membership check
30
 
31
  @property
32
  def state(self) -> State:
 
34
 
35
  def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any) -> WhyDidItFailObservation:
36
  self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
37
+ self.inspection_order = []
38
 
39
  scenario_key = kwargs.get("scenario_key")
40
  if scenario_key and scenario_key in SCENARIOS:
 
66
 
67
  if action.action_type == "inspect_logs":
68
  step_reward = self._inspect_reward("logs", required)
69
+ if "logs" not in self.inspection_order:
70
+ self.inspection_order.append("logs")
71
  return WhyDidItFailObservation(
72
  task_description="Continue your investigation.",
73
  visible_data={"training_logs": self.scenario["logs"]},
 
80
 
81
  elif action.action_type == "inspect_config":
82
  step_reward = self._inspect_reward("config", required)
83
+ if "config" not in self.inspection_order:
84
+ self.inspection_order.append("config")
85
  return WhyDidItFailObservation(
86
  task_description="Continue your investigation.",
87
  visible_data={"config": self.scenario["config"]},
 
94
 
95
  elif action.action_type == "inspect_gradients":
96
  step_reward = self._inspect_reward("gradients", required)
97
+ if "gradients" not in self.inspection_order:
98
+ self.inspection_order.append("gradients")
99
  return WhyDidItFailObservation(
100
  task_description="Continue your investigation.",
101
  visible_data={"gradient_norms": self.scenario["gradient_norms"]},
 
131
 
132
  # ── helpers ──────────────────────────────────────────────────────────────
133
 
134
+ # Rewards decay as more required sources are discovered — first clue is worth most.
135
+ _REQUIRED_STEP_REWARDS = [0.10, 0.07, 0.05]
136
+
137
  def _inspect_reward(self, source: str, required: list[str]) -> float:
138
+ """Return step reward for an inspect action.
139
+
140
+ Required sources: progressive — +0.10 / +0.07 / +0.05 for 1st/2nd/3rd discovery.
141
+ Irrelevant sources: -0.03 (mild; some exploration is acceptable).
142
+ Re-inspection: -0.05 (waste).
143
+ """
144
+ if source in self.inspection_order:
145
  return -0.05 # redundant inspection
146
+
147
  if source in required:
148
+ n_found = sum(1 for s in self.inspection_order if s in required)
149
+ idx = min(n_found, len(self._REQUIRED_STEP_REWARDS) - 1)
150
+ return self._REQUIRED_STEP_REWARDS[idx]
151
+
152
+ return -0.03 # irrelevant source
153
 
154
  def _inspect_feedback(self, source: str, required: list[str], reward: float) -> str:
155
  label = {"logs": "training logs", "config": "hyperparameter config", "gradients": "gradient statistics"}[source]
156
+ if source in self.inspection_order:
157
  return f"You already examined the {label}. No new information gained."
158
+ if source in required:
159
+ remaining = len(set(required) - set(self.inspection_order) - {source})
160
+ msg = f"You examined the {label}. Relevant clue found (+{reward:.2f})."
161
+ if remaining > 0:
162
+ msg += f" {remaining} required source(s) still unexamined."
163
+ return msg
164
+ return f"You examined the {label}. This source is not required for this failure mode."
165
 
166
  def _grade(self, action: WhyDidItFailAction) -> tuple[float, str]:
167
  """Delegate to the unified grade() function and return (reward, feedback)."""
168
  assert self.scenario is not None
169
+ diagnosis = (action.diagnosis or "").strip().lower()
170
  suggested_fix = (action.suggested_fix or "").strip().lower() or None
171
+ difficulty = self.scenario["difficulty"]
172
 
173
  reward = grade(
174
  diagnosis=diagnosis,
175
  suggested_fix=suggested_fix,
176
  scenario=self.scenario,
177
  steps_taken=self._state.step_count,
178
+ inspection_order=self.inspection_order,
179
  difficulty=difficulty,
180
  )
181
 
server/graders.py CHANGED
@@ -75,21 +75,20 @@ def _diagnosis_score(diagnosis: str, scenario: dict) -> float:
75
  return max(0.0, min(0.7, score))
76
 
77
 
78
- def _evidence_score(inspected: set[str], required: set[str]) -> float:
79
  """
80
- +0.05 per required source the agent inspected (max +0.15 for 3 sources)
81
- −0.05 per irrelevant source the agent wasted a step on
82
- Clamped to [−0.10, +0.15].
 
83
  """
84
- relevant = len(inspected & required)
85
- irrelevant = len(inspected - required)
86
- score = (relevant * 0.06) - (irrelevant * 0.03)
 
87
 
88
- # small bonus if agent explored more than minimum but not excessively
89
- if len(inspected) > len(required):
90
- score += 0.02
91
-
92
- return max(-0.10, min(0.15, score))
93
 
94
 
95
  def _efficiency_score(steps_taken: int, min_steps: int) -> float:
@@ -140,7 +139,7 @@ def grade(
140
  suggested_fix: str | None = None,
141
  scenario: dict | None = None,
142
  steps_taken: int = 0,
143
- inspected: set[str] | None = None,
144
  difficulty: str = "easy", # kept for API compat — not used in scoring logic
145
  ) -> float:
146
  """
@@ -152,13 +151,13 @@ def grade(
152
  Max achievable without fix: 0.70 + 0.15 + 0.15 = 1.00
153
  Max achievable with fix: 0.70 + 0.15 + 0.15 + 0.15 = 1.00 (capped)
154
  """
155
- scenario = scenario or {}
156
- inspected = inspected or set()
157
- required = set(scenario.get("required_sources", ["logs"]))
158
- min_steps = len(required) + 1 # inspect all required sources + submit
159
 
160
  d_score = _diagnosis_score(diagnosis, scenario)
161
- e_score = _evidence_score(inspected, required)
162
  f_score = _efficiency_score(steps_taken, min_steps)
163
  b_score = _fix_bonus(suggested_fix, scenario)
164
 
 
75
  return max(0.0, min(0.7, score))
76
 
77
 
78
+ def _evidence_score(inspection_order: list[str], required: set[str]) -> float:
79
  """
80
+ +0.08 per required source inspected (max +0.24 for 3 sources)
81
+ −0.06 per required source NOT inspected at submit time
82
+ βˆ’0.02 per irrelevant source inspected
83
+ Clamped to [−0.15, +0.25].
84
  """
85
+ inspected_set = set(inspection_order)
86
+ relevant = inspected_set & required
87
+ missing = required - inspected_set
88
+ irrelevant = inspected_set - required
89
 
90
+ score = (len(relevant) * 0.08) - (len(missing) * 0.06) - (len(irrelevant) * 0.02)
91
+ return max(-0.15, min(0.25, score))
 
 
 
92
 
93
 
94
  def _efficiency_score(steps_taken: int, min_steps: int) -> float:
 
139
  suggested_fix: str | None = None,
140
  scenario: dict | None = None,
141
  steps_taken: int = 0,
142
+ inspection_order: list[str] | None = None,
143
  difficulty: str = "easy", # kept for API compat — not used in scoring logic
144
  ) -> float:
145
  """
 
151
  Max achievable without fix: 0.70 + 0.15 + 0.15 = 1.00
152
  Max achievable with fix: 0.70 + 0.15 + 0.15 + 0.15 = 1.00 (capped)
153
  """
154
+ scenario = scenario or {}
155
+ inspection_order = inspection_order or []
156
+ required = set(scenario.get("required_sources", ["logs"]))
157
+ min_steps = len(required) + 1 # inspect all required sources + submit
158
 
159
  d_score = _diagnosis_score(diagnosis, scenario)
160
+ e_score = _evidence_score(inspection_order, required)
161
  f_score = _efficiency_score(steps_taken, min_steps)
162
  b_score = _fix_bonus(suggested_fix, scenario)
163