Spaces:
Sleeping
Sleeping
feat: upgrading the system and user prompt, upgrading the _make_env() function
Browse files- inference.py +33 -14
inference.py
CHANGED
|
@@ -71,18 +71,25 @@ SYSTEM_PROMPT = textwrap.dedent("""
|
|
| 71 |
|
| 72 |
Examples:
|
| 73 |
{"action_type": "inspect_logs"}
|
| 74 |
-
{"action_type": "submit_diagnosis", "diagnosis": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
RULES:
|
| 77 |
- submit_diagnosis MUST include all three fields: diagnosis, suggested_fix, reasoning.
|
| 78 |
- diagnosis is the short failure mode label β it is REQUIRED, never omit it.
|
| 79 |
-
- reasoning must cite specific values from the data you inspected (loss values, lr, gradient norms, etc.).
|
| 80 |
- Use exact failure mode phrasing for diagnosis: "exploding gradients", "overfitting", "underfitting",
|
| 81 |
"learning rate too high", "learning rate too low", "vanishing gradients",
|
| 82 |
"dying relu", "missing regularization", "batch size too small",
|
| 83 |
"optimizer misconfiguration", "bad weight initialization", "lr scheduler misconfiguration".
|
| 84 |
-
-
|
| 85 |
-
- If
|
|
|
|
| 86 |
- Never inspect the same source twice.
|
| 87 |
""").strip()
|
| 88 |
|
|
@@ -98,6 +105,8 @@ def _user_prompt(step: int, obs_summary: str, history: List[str]) -> str:
|
|
| 98 |
Recent history:
|
| 99 |
{history_block}
|
| 100 |
|
|
|
|
|
|
|
| 101 |
Respond with a JSON action.
|
| 102 |
""").strip()
|
| 103 |
|
|
@@ -136,9 +145,23 @@ def _get_action(client: OpenAI, step: int, obs_summary: str, history: List[str])
|
|
| 136 |
|
| 137 |
# ββ episode runner ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 138 |
|
| 139 |
-
async def
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
obs = result.observation
|
| 143 |
history: List[str] = []
|
| 144 |
rewards: List[float] = []
|
|
@@ -192,7 +215,7 @@ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -
|
|
| 192 |
print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", flush=True)
|
| 193 |
|
| 194 |
success = score >= SUCCESS_THRESHOLD
|
| 195 |
-
return {"scenario_key": scenario_key, "score": score, "steps": len(rewards), "success": success}
|
| 196 |
|
| 197 |
|
| 198 |
# ββ task runners ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -206,7 +229,7 @@ async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEn
|
|
| 206 |
|
| 207 |
results = []
|
| 208 |
for key in scenario_keys:
|
| 209 |
-
res = await run_episode(env, client, key)
|
| 210 |
results.append(res)
|
| 211 |
print(f"[RESULT] scenario={res['scenario_key']} score={res['score']:.3f} steps={res['steps']} success={str(res['success']).lower()}", flush=True)
|
| 212 |
|
|
@@ -219,11 +242,7 @@ async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEn
|
|
| 219 |
|
| 220 |
async def main() -> None:
|
| 221 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 222 |
-
env = (
|
| 223 |
-
await WhyDidItFailEnv.from_docker_image(IMAGE_NAME)
|
| 224 |
-
if IMAGE_NAME
|
| 225 |
-
else WhyDidItFailEnv(base_url=SERVER_URL)
|
| 226 |
-
)
|
| 227 |
|
| 228 |
try:
|
| 229 |
await run_task("easy", EASY_SCENARIOS, env, client)
|
|
|
|
| 71 |
|
| 72 |
Examples:
|
| 73 |
{"action_type": "inspect_logs"}
|
| 74 |
+
{"action_type": "submit_diagnosis", "diagnosis": "overfitting", "suggested_fix": "add dropout=0.3 and weight_decay=0.01", "reasoning": "train_loss fell to 0.03 by epoch 20 while val_loss rose to 2.34; train_acc=0.99 vs val_acc=0.54 β clear generalization gap. Config shows dropout=0.0 and weight_decay=0.0."}
|
| 75 |
+
|
| 76 |
+
DIAGNOSIS PROCESS β follow this every episode:
|
| 77 |
+
1. Call inspect_logs first β always.
|
| 78 |
+
2. Read the Data field carefully. Note the exact numeric values (loss, acc, lr, gradient norms, model).
|
| 79 |
+
3. If Feedback says "Next required action: inspect_X" β call that action next, no exceptions.
|
| 80 |
+
4. When no required actions remain, form your diagnosis based ONLY on values you actually saw in Data.
|
| 81 |
+
5. Your reasoning MUST quote specific numbers from the Data you received (e.g. "val_loss=2.34 at epoch 20, train_acc=0.99"). If you cannot quote a specific number from the Data, you have not read it β do not submit yet.
|
| 82 |
|
| 83 |
RULES:
|
| 84 |
- submit_diagnosis MUST include all three fields: diagnosis, suggested_fix, reasoning.
|
| 85 |
- diagnosis is the short failure mode label β it is REQUIRED, never omit it.
|
|
|
|
| 86 |
- Use exact failure mode phrasing for diagnosis: "exploding gradients", "overfitting", "underfitting",
|
| 87 |
"learning rate too high", "learning rate too low", "vanishing gradients",
|
| 88 |
"dying relu", "missing regularization", "batch size too small",
|
| 89 |
"optimizer misconfiguration", "bad weight initialization", "lr scheduler misconfiguration".
|
| 90 |
+
- CRITICAL: If Feedback contains "Next required action: inspect_X", you MUST call that action before submitting. Do not submit while any required source is unexamined.
|
| 91 |
+
- If Feedback says "This source is not required for this failure mode." β submit your diagnosis on the very next step. Do NOT inspect other sources.
|
| 92 |
+
- If Feedback says "Relevant clue found" with no "Next required action" β all sources are covered. Submit on the next step.
|
| 93 |
- Never inspect the same source twice.
|
| 94 |
""").strip()
|
| 95 |
|
|
|
|
| 105 |
Recent history:
|
| 106 |
{history_block}
|
| 107 |
|
| 108 |
+
Before responding: read the Data above carefully. What specific numeric values do you see?
|
| 109 |
+
Quote at least one value from the Data in your reasoning before submitting a diagnosis.
|
| 110 |
Respond with a JSON action.
|
| 111 |
""").strip()
|
| 112 |
|
|
|
|
| 145 |
|
| 146 |
# ββ episode runner ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 147 |
|
| 148 |
+
async def _make_env() -> WhyDidItFailEnv:
|
| 149 |
+
return (
|
| 150 |
+
await WhyDidItFailEnv.from_docker_image(IMAGE_NAME)
|
| 151 |
+
if IMAGE_NAME
|
| 152 |
+
else WhyDidItFailEnv(base_url=SERVER_URL)
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -> tuple[dict, WhyDidItFailEnv]:
|
| 157 |
+
"""Run one full episode for a specific scenario. Returns (result dict, env).
|
| 158 |
+
env may be a fresh reconnected instance if the WebSocket dropped between episodes."""
|
| 159 |
+
try:
|
| 160 |
+
result = await env.reset(scenario_key=scenario_key)
|
| 161 |
+
except ConnectionClosedError:
|
| 162 |
+
print(f" [WARN] scenario={scenario_key} reconnecting WebSocket...", flush=True)
|
| 163 |
+
env = await _make_env()
|
| 164 |
+
result = await env.reset(scenario_key=scenario_key)
|
| 165 |
obs = result.observation
|
| 166 |
history: List[str] = []
|
| 167 |
rewards: List[float] = []
|
|
|
|
| 215 |
print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", flush=True)
|
| 216 |
|
| 217 |
success = score >= SUCCESS_THRESHOLD
|
| 218 |
+
return {"scenario_key": scenario_key, "score": score, "steps": len(rewards), "success": success}, env
|
| 219 |
|
| 220 |
|
| 221 |
# ββ task runners ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 229 |
|
| 230 |
results = []
|
| 231 |
for key in scenario_keys:
|
| 232 |
+
res, env = await run_episode(env, client, key)
|
| 233 |
results.append(res)
|
| 234 |
print(f"[RESULT] scenario={res['scenario_key']} score={res['score']:.3f} steps={res['steps']} success={str(res['success']).lower()}", flush=True)
|
| 235 |
|
|
|
|
| 242 |
|
| 243 |
async def main() -> None:
|
| 244 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 245 |
+
env = await _make_env()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
try:
|
| 248 |
await run_task("easy", EASY_SCENARIOS, env, client)
|