Spaces:
Sleeping
Sleeping
def compute_reward(task: dict, agent_query: str, run_result: dict) -> dict:
    """Score an agent-submitted SQL query for a task.

    Args:
        task: One of the TASK dicts from ``tasks/``. Keys read here:
            ``expected_rows`` (list of rows, or ``None`` when the task has no
            fixed answer), ``check_plan`` (truthy to grade query structure)
            and ``good_patterns`` (SQL fragments that earn plan credit).
        agent_query: The SQL string the agent submitted.
        run_result: Output of ``runner.run_query()``. Keys read here are
            ``error`` (``None`` when the query ran) and ``rows``.

    Returns:
        A dict ``{value, syntax_ok, result_match_pct, plan_score, message}``
        where ``value`` is the final reward in ``[0, 1]`` rounded to 3 places.

    Scoring:
        A query that fails to run earns a flat 0.05. A query that runs earns
        0.2 for running. On tasks that check the plan, it also earns
        0.6 * row-match + 0.2 * plan score. On tasks that do not check the
        plan, the plan's 0.2 weight moves to correctness (0.8 * row-match),
        so a fully correct answer can reach 1.0. Queries longer than 800
        characters lose 0.05 per 100 extra characters.
    """
    # --- Step 1: did the query even run? ------------------------------------
    error = run_result.get("error")
    if error is not None:
        # Small non-zero credit (not 0.0) so a failed attempt still gives the
        # agent some gradient signal. str() guards against non-string errors
        # (e.g. an exception object) that cannot be sliced.
        return {
            "value": 0.05,
            "syntax_ok": False,
            "result_match_pct": 0.0,
            "plan_score": 0.0,
            "message": f"Syntax error: {str(error)[:100]}",
        }

    # --- Step 2: did we get the right rows? ---------------------------------
    expected = task.get("expected_rows")
    if expected is None:
        # No fixed answer for this task: full row credit if the query runs.
        result_match_pct = 1.0
    else:
        result_match_pct = _score_rows(expected, run_result.get("rows") or [])

    # --- Step 3: is the query plan good? (only tasks with check_plan) -------
    check_plan = bool(task.get("check_plan"))
    if check_plan:
        plan_score = _score_plan(agent_query, task.get("good_patterns") or [])
    else:
        plan_score = 0.0

    # --- Step 4: combine into the final score -------------------------------
    if check_plan:
        # Weights: running 20% + correctness 60% + plan 20%.
        base_score = 0.2 + 0.6 * result_match_pct + 0.2 * plan_score
    else:
        # No plan to grade: fold its 20% into correctness. Otherwise a fully
        # correct answer would cap at 0.8 and never be reported "perfect".
        base_score = 0.2 + 0.8 * result_match_pct

    # Penalize very long queries: -0.05 per 100 characters beyond 800.
    length_penalty = max(0.0, (len(agent_query) - 800) / 2000)
    final = max(0.0, min(1.0, base_score - length_penalty))

    if final >= 0.99:
        status = "perfect"
    elif final > 0.2:
        status = "partial"
    else:
        status = "wrong"
    msg = f"{status} | rows matched: {result_match_pct:.0%} | plan: {plan_score:.0%}"
    return {
        "value": round(final, 3),
        "syntax_ok": True,
        "result_match_pct": result_match_pct,
        "plan_score": plan_score,
        "message": msg,
    }


def _row_key(row):
    """Return ``row`` as a tuple if it is a list or tuple, else unchanged.

    This lets an expected row written as a list (e.g. loaded from JSON)
    compare equal to the same values returned as a tuple, because Python
    treats ``[1, "a"]`` and ``(1, "a")`` as unequal.
    """
    return tuple(row) if isinstance(row, (list, tuple)) else row


def _score_rows(expected: list, got: list) -> float:
    """Return the fraction of ``expected`` rows present in ``got`` (0.0-1.0).

    An empty ``expected`` means the correct answer is "no rows". Returning no
    rows then scores 1.0, and returning any rows scores 0.0. Results with
    more than twice as many rows as expected are scaled by 0.7.
    """
    got_keys = [_row_key(row) for row in got]
    if not expected:
        return 0.0 if got_keys else 1.0
    matches = sum(1 for row in expected if _row_key(row) in got_keys)
    pct = matches / len(expected)
    # Returning far more rows than expected suggests a wrong (too broad) query.
    if len(got_keys) > len(expected) * 2:
        pct *= 0.7
    return pct


def _score_plan(agent_query: str, good_patterns: list) -> float:
    """Return the fraction of ``good_patterns`` found in the query (0.0-1.0).

    Matching is case-insensitive. Any run of whitespace (spaces, tabs,
    newlines) compares as a single space, so ``GROUP`` and ``BY`` split
    across lines still match the pattern ``"GROUP BY"``. The score is
    multiplied by 0.3 when the query contains both ``WHERE`` and
    ``SELECT AVG``, a heuristic for a leftover correlated ``AVG`` subquery.
    """
    import re

    def normalize(text: str) -> str:
        # Upper-case and collapse each whitespace run to one space.
        return re.sub(r"\s+", " ", text.upper())

    query = normalize(agent_query)
    found = sum(1 for pattern in good_patterns if normalize(pattern) in query)
    score = found / max(len(good_patterns), 1)
    if "WHERE" in query and "SELECT AVG" in query:
        score *= 0.3  # Heavy penalty: the query was not really optimized.
    return score