# GeoQuery / backend/core/query_planner.py
# Last commit 4851501 by GerardCB: "Deploy to Spaces (Final Clean)"
"""
Multi-Step Query Planner
Detects complex queries that require multiple datasets or operations,
decomposes them into atomic steps, and orchestrates execution.
"""
import json
import logging
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Any, Optional
# Module-level logger; QueryPlanner reports planning failures through it.
logger = logging.getLogger(__name__)
class StepType(Enum):
    """Kinds of atomic step that can appear in a QueryPlan.

    Members (string value in parentheses):
        DATA_QUERY ("data_query"): simple data retrieval.
        AGGREGATION ("aggregation"): COUNT / SUM / GROUP BY style work.
        COMPARISON ("comparison"): compares results from previous steps.
        SPATIAL_JOIN ("spatial_join"): joins datasets spatially.
        COMBINE ("combine"): merges or combines step results.
    """

    DATA_QUERY = "data_query"
    AGGREGATION = "aggregation"
    COMPARISON = "comparison"
    SPATIAL_JOIN = "spatial_join"
    COMBINE = "combine"
@dataclass
class QueryStep:
    """One atomic unit of work inside a query plan.

    Fields:
        step_id: Identifier of this step; other steps list it in depends_on.
        step_type: Kind of work this step performs.
        description: Human-readable summary of the step.
        tables_needed: Names of the tables this step needs.
        sql_template: Optional SQL hint for the step.
        depends_on: step_ids that must be scheduled before this step.
        result_name: Name for the step's intermediate result.
    """
    step_id: str
    step_type: StepType
    description: str
    tables_needed: List[str]
    sql_template: Optional[str] = None
    depends_on: List[str] = field(default_factory=list)
    result_name: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict, with step_type as its string value."""
        ordered_keys = (
            "step_id",
            "step_type",
            "description",
            "tables_needed",
            "sql_template",
            "depends_on",
            "result_name",
        )
        payload = {key: getattr(self, key) for key in ordered_keys}
        # Overwriting keeps the key's position; only the enum is converted.
        payload["step_type"] = self.step_type.value
        return payload
@dataclass
class QueryPlan:
    """Complete execution plan for a query.

    Fields:
        original_query: The user query the plan was built for.
        is_complex: Whether the query was planned as multiple steps.
        steps: The QueryStep objects that make up the plan.
        parallel_groups: Batches of step_ids; steps within one batch can run
            in parallel.
        final_combination_logic: How the step results are combined.
    """
    original_query: str
    is_complex: bool
    steps: List[QueryStep] = field(default_factory=list)
    parallel_groups: List[List[str]] = field(default_factory=list)
    final_combination_logic: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the plan, converting each step with its own to_dict()."""
        serialized_steps = []
        for step in self.steps:
            serialized_steps.append(step.to_dict())
        return dict(
            original_query=self.original_query,
            is_complex=self.is_complex,
            steps=serialized_steps,
            parallel_groups=self.parallel_groups,
            final_combination_logic=self.final_combination_logic,
        )
class QueryPlanner:
    """
    Multi-step query planning service.

    Analyzes queries to determine complexity and decomposes
    complex queries into executable atomic steps.

    ``QueryPlanner()`` always returns the same object: ``__new__`` caches the
    instance on the class and ``__init__`` only performs setup once.
    """

    _instance = None

    # Keywords that often indicate multi-step queries.
    # Matched at the start of a word (see _mentions), so "ratio" does not fire
    # inside "operation"/"generation" and "vs" does not fire inside other words.
    COMPLEXITY_INDICATORS = [
        "compare", "comparison", "versus", "vs",
        "more than", "less than", "higher than", "lower than",
        "both", "and also", "as well as",
        "ratio", "percentage", "proportion",
        "correlation", "relationship between",
        "combine", "merge", "together with",
        "relative to", "compared to",
        "difference between", "gap between"
    ]

    # Keywords indicating multiple distinct data types, keyed by domain.
    # Also matched at word starts: plurals such as "hospitals" still count,
    # but "abroad" no longer counts as "road" nor "president" as "resident".
    MULTI_DOMAIN_KEYWORDS = {
        "health": ["hospital", "clinic", "healthcare", "health", "medical"],
        "education": ["school", "university", "education", "college", "student"],
        "infrastructure": ["road", "bridge", "infrastructure", "building"],
        "environment": ["forest", "water", "environment", "park", "protected"],
        "population": ["population", "demographic", "census", "people", "resident"]
    }

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every QueryPlanner() call; only the first one sets up.
        if self.initialized:
            return
        self.initialized = True

    @staticmethod
    def _mentions(text: str, term: str) -> bool:
        """Return True if *term* occurs in *text* starting at a word boundary.

        Only the start is anchored, so suffixed forms ("hospitals", "parks",
        "compared") still match while occurrences buried inside other words
        ("ratio" in "operation", "road" in "abroad") do not.
        """
        return re.search(r"\b" + re.escape(term), text) is not None

    def detect_complexity(self, query: str) -> Dict[str, Any]:
        """
        Analyze a query to determine if it requires multi-step planning.

        A query is complex when it touches at least three data domains, or
        at least two domains together with a comparison/aggregation phrase.

        Args:
            query: The user's natural-language question (None is treated as "").

        Returns:
            {
                "is_complex": bool,
                "reason": str,              # "" when not complex
                "detected_domains": List[str],
                "complexity_indicators": List[str]
            }
        """
        query_lower = (query or "").lower()

        # Check for complexity indicators
        found_indicators = [
            ind for ind in self.COMPLEXITY_INDICATORS
            if self._mentions(query_lower, ind)
        ]

        # Check for multiple data domains
        found_domains = [
            domain
            for domain, keywords in self.MULTI_DOMAIN_KEYWORDS.items()
            if any(self._mentions(query_lower, kw) for kw in keywords)
        ]

        # Strong words such as "compare", "ratio", "correlation", "versus" and
        # "vs" are themselves indicators, so they need no separate clause.
        is_complex = len(found_domains) >= 3 or (
            len(found_domains) >= 2 and bool(found_indicators)
        )

        reason = ""
        if is_complex:
            # Every path to is_complex requires at least two domains.
            reason = f"Query involves multiple data domains: {', '.join(found_domains)}"
            if found_indicators:
                reason += f". Contains comparison/aggregation keywords: {', '.join(found_indicators[:3])}"

        return {
            "is_complex": is_complex,
            "reason": reason,
            "detected_domains": found_domains,
            "complexity_indicators": found_indicators
        }

    async def plan_query(
        self,
        query: str,
        available_tables: List[str],
        llm_gateway
    ) -> QueryPlan:
        """
        Create an execution plan for a complex query.

        Uses the LLM to decompose the query into atomic steps.

        Args:
            query: The user's natural-language question.
            available_tables: Table names listed in the planning prompt.
            llm_gateway: Object with an async ``generate_response`` method,
                called here with the prompt and an empty list, that returns
                the model's raw text.

        Returns:
            A ``QueryPlan`` with ``is_complex=True`` built from the JSON in
            the LLM's answer. If the call or parsing fails, the error and
            traceback are logged and an empty plan (``is_complex=False``, no
            steps, no groups) is returned instead of raising.
        """
        from backend.core.prompts import QUERY_PLANNING_PROMPT

        # Build table context
        table_list = "\n".join(f"- {t}" for t in available_tables)
        prompt = QUERY_PLANNING_PROMPT.format(
            user_query=query,
            available_tables=table_list
        )

        try:
            response = await llm_gateway.generate_response(prompt, [])
            plan_data = json.loads(self._extract_json_object(response))

            steps = [
                self._build_step(index, step_data)
                for index, step_data in enumerate(plan_data.get("steps") or [], start=1)
            ]

            # Batch steps whose dependencies are satisfied by earlier batches.
            parallel_groups = self._compute_parallel_groups(steps)

            return QueryPlan(
                original_query=query,
                is_complex=True,
                steps=steps,
                parallel_groups=parallel_groups,
                final_combination_logic=plan_data.get("combination_logic") or ""
            )
        except Exception as e:
            # Boundary handler: any LLM or parsing failure degrades to an empty,
            # non-complex plan (no steps) rather than propagating.
            # logger.exception records the traceback along with the message.
            logger.exception("Query planning failed: %s", e)
            return QueryPlan(
                original_query=query,
                is_complex=False,
                steps=[],
                parallel_groups=[],
                final_combination_logic=""
            )

    @staticmethod
    def _extract_json_object(text: str) -> str:
        """Return the part of an LLM reply that should contain the JSON plan.

        Replies are often wrapped in a Markdown fence (```json, ```JSON, ```)
        or surrounded by a sentence of prose. Slicing from the first "{" to
        the last "}" removes all of that. If no braces are present, the
        stripped text is returned unchanged and json.loads reports the error.
        """
        cleaned = text.strip()
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start != -1 and end > start:
            return cleaned[start:end + 1]
        return cleaned

    @staticmethod
    def _as_list(value: Any) -> List[Any]:
        """Normalize an LLM field that should be a list.

        None becomes [], a list/tuple is copied, and any other single value
        (notably a bare string such as ``"depends_on": "step_1"``) becomes a
        one-item list, so it is never iterated character by character.
        """
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    @staticmethod
    def _parse_step_type(raw: Any) -> StepType:
        """Map the LLM's step "type" onto a StepType, defaulting to DATA_QUERY.

        Matching is case-insensitive ("Aggregation" -> AGGREGATION). A missing
        or unrecognized type is logged and treated as a data query rather than
        discarding the whole plan.
        """
        if not raw:
            return StepType.DATA_QUERY
        try:
            return StepType(str(raw).strip().lower())
        except ValueError:
            logger.warning(
                "Unknown step type %r from planner; treating it as %s",
                raw, StepType.DATA_QUERY.value
            )
            return StepType.DATA_QUERY

    @classmethod
    def _build_step(cls, index: int, step_data: Dict[str, Any]) -> QueryStep:
        """Convert one step dict from the LLM into a QueryStep ``step_<index>``."""
        return QueryStep(
            step_id=f"step_{index}",
            step_type=cls._parse_step_type(step_data.get("type")),
            description=step_data.get("description") or "",
            tables_needed=cls._as_list(step_data.get("tables")),
            sql_template=step_data.get("sql_hint"),
            depends_on=cls._as_list(step_data.get("depends_on")),
            result_name=step_data.get("result_name") or f"result_{index}"
        )

    def _compute_parallel_groups(self, steps: List[QueryStep]) -> List[List[str]]:
        """
        Compute which steps can be executed in parallel.

        Returns successive batches of step_ids. Each batch contains every
        pending step whose dependencies all appear in earlier batches.

        If no step is fully unblocked, one step is run on its own. This happens
        when a dependency names an id that is not in the plan, or when
        dependencies form a cycle. The chosen step is preferably one that does
        not wait on another still-pending step, so known orderings are kept.
        For a genuine cycle, the first pending step is used.
        """
        if not steps:
            return []

        groups: List[List[str]] = []
        executed = set()
        remaining = {s.step_id: s for s in steps}

        while remaining:
            # Find steps whose dependencies are all satisfied
            ready = [
                step_id for step_id, step in remaining.items()
                if all(dep in executed for dep in step.depends_on)
            ]
            if not ready:
                not_waiting_on_pending = [
                    step_id for step_id, step in remaining.items()
                    if not any(dep in remaining for dep in step.depends_on)
                ]
                ready = (not_waiting_on_pending or list(remaining))[:1]
                logger.warning(
                    "Unresolvable step dependencies; forcing %s to run alone",
                    ready[0]
                )
            groups.append(ready)
            for step_id in ready:
                executed.add(step_id)
                del remaining[step_id]

        return groups

    def create_simple_plan(self, query: str) -> QueryPlan:
        """Create a simple single-step plan for non-complex queries."""
        return QueryPlan(
            original_query=query,
            is_complex=False,
            steps=[
                QueryStep(
                    step_id="step_1",
                    step_type=StepType.DATA_QUERY,
                    description="Execute query directly",
                    tables_needed=[],
                    depends_on=[]
                )
            ],
            parallel_groups=[["step_1"]]
        )
# Singleton accessor: module-level cache of the shared QueryPlanner.
_query_planner: Optional[QueryPlanner] = None


def get_query_planner() -> QueryPlanner:
    """Return the shared QueryPlanner, constructing it on the first call."""
    global _query_planner
    if _query_planner is not None:
        return _query_planner
    _query_planner = QueryPlanner()
    return _query_planner