debug / features /query_optimizer.py
jayendra26's picture
Initial commit
f8b106f
Raw
History Blame Contribute Delete
2.54 kB
"""
features/query_optimizer.py
============================
Detects and fixes inefficient SQL queries.
Currently handles:
- COUNT(*) β†’ SUM(column) replacement when task requires it
- General structural improvements
This is the "query optimization" step in the RL action sequence.
"""
class QueryOptimizer:
"""
Optimizes a SQL query based on the task's known issue type.
Uses rule-based pattern matching β€” no external libraries needed.
"""
def optimize(self, query: str, task: dict) -> str:
"""
Optimize the query if the task requires it.
Args:
query (str): The current SQL query (may be already partially fixed).
task (dict): The task dict describing what the issue is.
Returns:
str: The optimized SQL query.
"""
# Only apply optimization if the task specifies it's needed
if not task.get("needs_optimization", False):
# No optimization needed β€” return the correct query directly
# This simulates the agent "understanding" the fix is just the correct one
return task.get("correct_query", query)
# --- OPTIMIZATION RULE: COUNT(*) β†’ SUM(amount) ---
if task["issue_type"] == "aggregation_bug":
optimized = self._fix_count_to_sum(query)
return optimized
# Default: return the task's known correct query
return task.get("correct_query", query)
def _fix_count_to_sum(self, query: str) -> str:
"""
Replace COUNT(*) or COUNT(*) AS alias with SUM(amount) AS alias.
Args:
query (str): The SQL query string.
Returns:
str: Query with COUNT replaced by SUM(amount).
"""
import re
# Pattern: COUNT(*) followed by optional alias AS <name>
pattern = r"COUNT\(\*\)(\s+AS\s+\w+)?"
replacement = r"SUM(amount)\1"
optimized = re.sub(pattern, replacement, query, flags=re.IGNORECASE)
return optimized
def describe_optimization(self, original: str, optimized: str) -> str:
"""
Generate a human-readable description of what was changed.
Args:
original (str): Original query.
optimized (str): Optimized query.
Returns:
str: Description of the optimization applied.
"""
if original == optimized:
return "No optimization needed."
return f"Replaced COUNT(*) with SUM(amount) to compute correct revenue totals."