|
import json |
|
import openai |
|
from config import PROJECT_ID, DATASET_ID |
|
from utils.bigquery_utils import get_bigquery_schema_info |
|
|
|
def _strip_code_fence(text):
    """Return *text* with a surrounding Markdown code fence removed, if present.

    Handles an optional ``json`` language tag (any case), an unterminated
    fence, and trailing text after the closing fence. Text that does not
    start with a fence is returned with surrounding whitespace stripped.
    """
    text = text.strip()
    if not text.startswith("```"):
        return text
    # Keep only what lies between the opening fence and the closing one (if any);
    # an unterminated fence keeps everything after the opening backticks.
    inner = text[3:].split("```", 1)[0]
    # Drop the optional language tag written right after the opening fence.
    if inner[:4].lower() == "json":
        inner = inner[4:]
    return inner.strip()


def _build_table_selection_prompt(natural_language_query, schema_info):
    """Render the LLM prompt asking which tables answer *natural_language_query*.

    ``schema_info`` is iterated as a mapping of table name -> iterable of
    column strings (the columns are joined with ``', '``).
    """
    schema_text = "".join(
        f"- **{DATASET_ID}.{table_name}** ({', '.join(columns)})\n"
        for table_name, columns in schema_info.items()
    )
    return f"""
Based on the following natural language query and BigQuery schema, identify the tables that would be needed to answer the query.

**Query:** "{natural_language_query}"

**BigQuery Schema:**
{schema_text}

Analyze the query and determine which tables contain the necessary information.

IMPORTANT: Return ONLY a raw JSON array of table names without any markdown formatting, code blocks, or explanations.

Example of correct response format:
["{DATASET_ID}.users", "{DATASET_ID}.orders"]

Example of INCORRECT response format:
```json
["{DATASET_ID}.users", "{DATASET_ID}.orders"]
```

DO NOT use code blocks, backticks, or any other formatting. Return ONLY the raw JSON array.
"""


def table_selection_agent(state):
    """Identifies relevant tables for the natural language query based on schema.

    Fetches the dataset schema, asks the LLM which tables are needed, and
    parses the model's reply as a JSON array of table names.

    Args:
        state: Agent state mapping. Reads ``state["sql_query"]`` (used as the
            natural-language question in the prompt) and ``state["client"]``
            (presumably a BigQuery client; ``None`` means no connection).

    Returns:
        ``{"relevant_tables": [...]}`` with the table-name strings returned by
        the model, or ``{"relevant_tables": [], "error": <message>}`` on any
        failure (no client, schema fetch error, empty schema, LLM error,
        unparseable or wrongly-shaped reply).
    """
    natural_language_query = state["sql_query"]
    client = state["client"]

    if client is None:
        return {"relevant_tables": [], "error": "Failed to connect to BigQuery."}

    # Report schema failures through the same error dict as every other
    # failure path instead of letting the exception escape the agent.
    try:
        schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)
    except Exception as e:
        print(f"Schema retrieval error: {e}")
        return {"relevant_tables": [], "error": f"Failed to fetch BigQuery schema: {e}"}

    # Guards both an empty mapping and a None result; there is nothing for the
    # model to choose from, so skip the LLM call.
    if not schema_info:
        return {"relevant_tables": [], "error": "No tables found in BigQuery dataset."}

    prompt = _build_table_selection_prompt(natural_language_query, schema_info)

    # Initialised up front so the JSONDecodeError handler can always print it.
    raw_content = ""
    try:
        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
        )
        # A None content (e.g. a refusal) is treated as an empty, invalid reply.
        raw_content = response.choices[0].message.content or ""
        # The prompt forbids code fences, but models still sometimes add them.
        relevant_tables = json.loads(_strip_code_fence(raw_content))
    except json.JSONDecodeError as e:
        print(f"JSON Decode Error: {e}")
        print(f"Response content: {raw_content}")
        return {"relevant_tables": [], "error": "Invalid JSON response from OpenAI"}
    except Exception as e:
        print(f"Unexpected error: {e}")
        return {"relevant_tables": [], "error": f"Error: {e}"}

    # Valid JSON is not enough: downstream code expects a list of table names.
    if not isinstance(relevant_tables, list) or not all(
        isinstance(table, str) for table in relevant_tables
    ):
        print(f"Unexpected response shape: {relevant_tables!r}")
        return {
            "relevant_tables": [],
            "error": "Expected a JSON array of table names from OpenAI",
        }

    print(f"Parsed relevant tables: {relevant_tables}")
    return {"relevant_tables": relevant_tables}