multi_agentic_sql_generator / agents /table_selection.py
Gonalb's picture
init commit
05e3517
import json
import openai
from config import PROJECT_ID, DATASET_ID
from utils.bigquery_utils import get_bigquery_schema_info
def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence, if present.

    Handles a missing closing fence and a language tag in any letter case,
    either on its own line (```json\\n[...]) or on the same line as the
    payload (```json [...]). Text that does not start with a fence is
    returned stripped but otherwise unchanged.
    """
    text = text.strip()
    if not text.startswith("```"):
        return text

    # Keep everything between the opening fence and the closing fence; if the
    # fence was never closed, partition() leaves the remainder intact.
    body, _, _ = text[3:].partition("```")

    first_line, newline, rest = body.partition("\n")
    if newline and first_line.strip().isalpha():
        # Language tag on its own line, e.g. "json" or "JSON".
        body = rest
    elif body[:4].lower() == "json":
        # Language tag sharing a line with the payload.
        body = body[4:]
    return body.strip()


def table_selection_agent(state):
    """Identify the tables needed to answer the natural language query.

    Builds a prompt from the query and the dataset schema, asks the OpenAI
    chat completions API for a JSON array of table names, and parses the reply.

    Args:
        state: Mapping that must contain:
            - "sql_query": the natural language query text (read as such
              here despite the key name).
            - "client": object passed to ``get_bigquery_schema_info``
              (presumably a BigQuery client -- confirm against caller), or
              ``None`` when no connection is available.

    Returns:
        ``{"relevant_tables": [...]}`` on success. On failure,
        ``{"relevant_tables": [], "error": "<message>"}`` where the failure is
        a missing client, a reply that is not valid JSON, a reply that is not
        a JSON array of strings, or any exception raised while calling OpenAI.

    Raises:
        KeyError: If ``state`` lacks "sql_query" or "client".
        Exception: Anything raised by ``get_bigquery_schema_info`` propagates,
            since the schema lookup runs outside the error-handling block.
    """
    natural_language_query = state["sql_query"]
    client = state["client"]
    if client is None:
        return {"relevant_tables": [], "error": "Failed to connect to BigQuery."}

    schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)

    # Format the schema for the prompt: one bullet per table with its columns.
    # NOTE(review): assumes schema_info maps table name -> iterable of column
    # name strings -- confirm against get_bigquery_schema_info.
    schema_text = "".join(
        f"- **{DATASET_ID}.{table_name}** ({', '.join(columns)})\n"
        for table_name, columns in schema_info.items()
    )

    prompt = f"""
Based on the following natural language query and BigQuery schema, identify the tables that would be needed to answer the query.
**Query:** "{natural_language_query}"
**BigQuery Schema:**
{schema_text}
Analyze the query and determine which tables contain the necessary information.
IMPORTANT: Return ONLY a raw JSON array of table names without any markdown formatting, code blocks, or explanations.
Example of correct response format:
["{DATASET_ID}.users", "{DATASET_ID}.orders"]
Example of INCORRECT response format:
```json
["{DATASET_ID}.users", "{DATASET_ID}.orders"]
```
DO NOT use code blocks, backticks, or any other formatting. Return ONLY the raw JSON array.
"""

    # Kept outside the try so the JSON error handler can always report it.
    raw_content = None
    try:
        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
        )
        raw_content = response.choices[0].message.content
        # The content may be None; treat that as an empty (invalid) reply so
        # it is reported as a bad response rather than an AttributeError.
        # The model sometimes wraps the array in a code fence despite the
        # instructions, so strip one before parsing.
        content = _strip_code_fence(raw_content or "")
        relevant_tables = json.loads(content)
    except json.JSONDecodeError as e:
        print(f"JSON Decode Error: {e}")
        print(f"Response content: {raw_content}")
        return {"relevant_tables": [], "error": "Invalid JSON response from OpenAI"}
    except Exception as e:
        print(f"Unexpected error: {e}")
        return {"relevant_tables": [], "error": f"Error: {str(e)}"}

    # Valid JSON is not enough: downstream expects a list of table names.
    is_name_list = isinstance(relevant_tables, list) and all(
        isinstance(name, str) for name in relevant_tables
    )
    if not is_name_list:
        print(f"Unexpected JSON shape: {relevant_tables!r}")
        return {
            "relevant_tables": [],
            "error": "OpenAI response was not a JSON array of table names",
        }

    print(f"Parsed relevant tables: {relevant_tables}")
    return {"relevant_tables": relevant_tables}