sql-debug-env / server /reward.py
md896's picture
Harden strict (0,1) scoring boundaries across runtime and config.
9b71d1b
"""
Reward function for the SQL Debug Environment.
Reward is computed at every step (not just end of episode).
This provides dense, meaningful signal for RL training.
Reward components:
- correctness: 0.0–0.6 (row-level match vs expected)
- efficiency: 0.0–0.2 (bonus for solving quickly)
- syntax_progress: 0.0–0.1 (valid SQL even if wrong content)
- schema_bonus: 0.0–0.1 (correct tables/columns referenced)
- penalty: 0.0 to 0.2 (deduction for bad actions)
Total range: 0.0 to 1.0 (clamped to [0.0, 1.0])
"""
import re
from typing import Optional, List, Dict, Any

from .models import SQLDebugReward
MIN_STRICT_SCORE = 0.001
MAX_STRICT_SCORE = 0.999
def _strict_score(value: float) -> float:
return round(min(MAX_STRICT_SCORE, max(MIN_STRICT_SCORE, value)), 4)
def _syntax_progress_for(query_result: Optional[Dict[str, Any]]) -> float:
    """Return partial credit for how close a submitted query came to executing."""
    if not query_result:
        return 0.0
    if query_result.get("success"):
        return 0.1
    # ``error_message`` may be missing or explicitly None; treat both as empty.
    error = str(query_result.get("error_message") or "").lower()
    if "no such column" in error:
        return 0.03  # Tables resolved; only a column reference is wrong.
    if "no such table" in error:
        return 0.01
    return 0.0


def _schema_bonus_for(submitted_query: Optional[str], schema_tables: List[str]) -> float:
    """Return up to 0.1, proportional to the fraction of schema tables the query names."""
    if not submitted_query or not schema_tables:
        return 0.0
    query_upper = submitted_query.upper()
    referenced = sum(
        1
        for table in schema_tables
        # Whole-identifier match, so table "USER" is not credited for "USERS".
        if table
        and re.search(
            rf"(?<!\w){re.escape(table.upper())}(?!\w)",
            query_upper,
        )
    )
    return min(0.1, (referenced / len(schema_tables)) * 0.1)


def _efficiency_bonus_for(steps_taken: int, max_steps: int) -> float:
    """Return the speed bonus for a near-perfect solve, tiered by step budget used."""
    # A non-positive budget would divide by zero; treat it as "whole budget used".
    steps_fraction = steps_taken / max_steps if max_steps > 0 else 1.0
    if steps_fraction <= 0.3:
        return 0.2
    if steps_fraction <= 0.5:
        return 0.15
    if steps_fraction <= 0.7:
        return 0.1
    return 0.05


def compute_reward(
    action_type: str,
    query_result: Optional[Dict[str, Any]],
    grade_score: float,
    steps_taken: int,
    max_steps: int,
    previous_best_score: float,
    schema_tables: List[str],
    submitted_query: Optional[str] = None,
) -> SQLDebugReward:
    """
    Compute the full reward for a step.

    Args:
        action_type: The action taken this step (e.g. "submit_query",
            "reset_query", "inspect_schema", "inspect_error", "inspect_sample").
        query_result: Result dict from EpisodeDatabase.execute_query(); the
            "success" and "error_message" keys are read.
        grade_score: Strict (0, 1) score from the task grader.
        steps_taken: How many steps have been used (1-indexed).
        max_steps: Maximum steps for this task.
        previous_best_score: Best grade score seen so far.
        schema_tables: Valid table names in this task's DB.
        submitted_query: The SQL query string (if the action was submit_query).

    Returns:
        SQLDebugReward whose ``value`` is the component sum minus the penalty,
        clamped by ``_strict_score``, together with each component and a
        human-readable ``breakdown`` string.
    """
    correctness = 0.0
    efficiency = 0.0
    syntax_progress = 0.0
    schema_bonus = 0.0
    penalty = 0.0  # deduction magnitude (non-negative); subtracted from the total

    if action_type == "submit_query":
        # Correctness is the primary signal, capped at 0.6.
        correctness = min(0.6, grade_score * 0.6)
        syntax_progress = _syntax_progress_for(query_result)
        schema_bonus = _schema_bonus_for(submitted_query, schema_tables)
        if grade_score >= 0.95:  # Near-perfect solution
            efficiency = _efficiency_bonus_for(steps_taken, max_steps)
        # Regression penalty: score dropped well below the best seen so far.
        if grade_score < previous_best_score - 0.1:
            penalty = 0.05
    elif action_type == "reset_query":
        # Penalize resetting — the agent should be making forward progress.
        penalty = 0.05
    elif action_type in ("inspect_schema", "inspect_error", "inspect_sample"):
        # Information-gathering actions earn a small positive signal to
        # encourage exploration over blind guessing.
        syntax_progress = 0.01

    # Urgency penalty (any action type): few steps left and still far from solved.
    steps_remaining = max_steps - steps_taken
    if steps_remaining <= 2 and grade_score < 0.5:
        penalty += 0.03

    total_raw = correctness + efficiency + syntax_progress + schema_bonus - penalty
    total = _strict_score(total_raw)
    # The penalty is subtracted, so it is shown with a minus sign. ``total`` is
    # the clamped value and may differ slightly from the raw arithmetic sum.
    breakdown = (
        f"correctness={correctness:.3f} + "
        f"efficiency={efficiency:.3f} + "
        f"syntax={syntax_progress:.3f} + "
        f"schema={schema_bonus:.3f} - "
        f"penalty={penalty:.3f} = {total:.4f}"
    )
    return SQLDebugReward(
        value=total,
        correctness=correctness,
        efficiency=efficiency,
        syntax_progress=syntax_progress,
        schema_bonus=schema_bonus,
        penalty=penalty,
        breakdown=breakdown,
    )