Spaces:
Running
Running
| """ | |
| Reward function for the SQL Debug Environment. | |
| Reward is computed at every step (not just end of episode). | |
| This provides dense, meaningful signal for RL training. | |
| Reward components: | |
| - correctness: 0.0β0.6 (row-level match vs expected) | |
| - efficiency: 0.0β0.2 (bonus for solving quickly) | |
| - syntax_progress: 0.0β0.1 (valid SQL even if wrong content) | |
| - schema_bonus: 0.0β0.1 (correct tables/columns referenced) | |
| - penalty: 0.0 to 0.2 (deduction for bad actions) | |
| Total range: 0.0 to 1.0 (clamped to [0.0, 1.0]) | |
| """ | |
| from typing import Optional, List, Dict, Any | |
| from .models import SQLDebugReward | |
| MIN_STRICT_SCORE = 0.001 | |
| MAX_STRICT_SCORE = 0.999 | |
| def _strict_score(value: float) -> float: | |
| return round(min(MAX_STRICT_SCORE, max(MIN_STRICT_SCORE, value)), 4) | |
| def compute_reward( | |
| action_type: str, | |
| query_result: Optional[Dict[str, Any]], | |
| grade_score: float, | |
| steps_taken: int, | |
| max_steps: int, | |
| previous_best_score: float, | |
| schema_tables: List[str], | |
| submitted_query: Optional[str] = None, | |
| ) -> SQLDebugReward: | |
| """ | |
| Compute the full reward for a step. | |
| Args: | |
| action_type: The action taken this step | |
| query_result: Result dict from EpisodeDatabase.execute_query() | |
| grade_score: strict (0, 1) score from task grader | |
| steps_taken: How many steps have been used (1-indexed) | |
| max_steps: Maximum steps for this task | |
| previous_best_score: Best grade score seen so far | |
| schema_tables: List of valid table names in this task's DB | |
| submitted_query: The SQL query string (if action was submit_query) | |
| """ | |
| correctness = 0.0 | |
| efficiency = 0.0 | |
| syntax_progress = 0.0 | |
| schema_bonus = 0.0 | |
| penalty = 0.0 # deduction magnitude (non-negative) | |
| if action_type == "submit_query": | |
| # Correctness: primary signal | |
| correctness = min(0.6, grade_score * 0.6) | |
| # Syntax progress: reward for at least getting a valid query | |
| if query_result and query_result.get("success"): | |
| syntax_progress = 0.1 | |
| elif query_result and not query_result.get("success"): | |
| # Partially reward if it's getting closer (fewer errors) | |
| error = query_result.get("error_message", "") | |
| if "no such column" in error.lower(): | |
| syntax_progress = 0.03 # Structure is right but wrong column | |
| elif "no such table" in error.lower(): | |
| syntax_progress = 0.01 | |
| else: | |
| syntax_progress = 0.0 | |
| # Schema bonus: correct table references | |
| if submitted_query and schema_tables: | |
| query_upper = submitted_query.upper() | |
| tables_referenced = sum( | |
| 1 for t in schema_tables if t.upper() in query_upper | |
| ) | |
| schema_bonus = min(0.1, (tables_referenced / len(schema_tables)) * 0.1) | |
| # Efficiency bonus: reward solving with fewer steps | |
| if grade_score >= 0.95: # Near-perfect solution | |
| steps_fraction = steps_taken / max_steps | |
| if steps_fraction <= 0.3: | |
| efficiency = 0.2 | |
| elif steps_fraction <= 0.5: | |
| efficiency = 0.15 | |
| elif steps_fraction <= 0.7: | |
| efficiency = 0.1 | |
| else: | |
| efficiency = 0.05 | |
| # Penalty: if score went DOWN from previous best (regressed) | |
| if grade_score < previous_best_score - 0.1: | |
| penalty = 0.05 | |
| elif action_type == "reset_query": | |
| # Penalize resetting β agent should be making progress | |
| penalty = 0.05 | |
| elif action_type in ("inspect_schema", "inspect_error", "inspect_sample"): | |
| # Free information actions β small positive for using schema info | |
| # (encourages agents to explore rather than blindly guess) | |
| syntax_progress = 0.01 | |
| # Penalty: approaching step limit (urgency signal) | |
| steps_remaining = max_steps - steps_taken | |
| if steps_remaining <= 2 and grade_score < 0.5: | |
| penalty += 0.03 | |
| total_raw = correctness + efficiency + syntax_progress + schema_bonus - penalty | |
| total = _strict_score(total_raw) | |
| breakdown = ( | |
| f"correctness={correctness:.3f} + " | |
| f"efficiency={efficiency:.3f} + " | |
| f"syntax={syntax_progress:.3f} + " | |
| f"schema={schema_bonus:.3f} + " | |
| f"penalty={penalty:.3f} = {total:.4f}" | |
| ) | |
| return SQLDebugReward( | |
| value=total, | |
| correctness=correctness, | |
| efficiency=efficiency, | |
| syntax_progress=syntax_progress, | |
| schema_bonus=schema_bonus, | |
| penalty=penalty, | |
| breakdown=breakdown | |
| ) | |