Spaces:
Sleeping
Sleeping
import json | |
import openai | |
from config import PROJECT_ID, DATASET_ID | |
from utils.bigquery_utils import get_bigquery_schema_info | |
def _strip_code_fence(raw: str) -> str:
    """Return the text of a model reply with any Markdown code fence removed.

    If the reply contains a fenced block (three backticks ... three backticks),
    the content of the first fenced block is returned, with an optional leading
    ``json`` language tag (any letter case) dropped. Replies without a complete
    fence are returned stripped but otherwise unchanged.
    """
    text = raw.strip()
    segments = text.split("```")
    # A complete fence yields at least three segments: before, inside, after.
    if len(segments) < 3:
        return text
    fenced = segments[1].strip()
    if fenced[:4].lower() == "json":
        fenced = fenced[4:]
    return fenced.strip()


def _parse_table_list(raw: str) -> list:
    """Parse a model reply into a list of table-name strings.

    Raises:
        json.JSONDecodeError: if the reply (after fence removal) is not JSON.
        TypeError: if the JSON value is not an array of strings.
    """
    tables = json.loads(_strip_code_fence(raw))
    if not isinstance(tables, list) or not all(isinstance(t, str) for t in tables):
        raise TypeError(
            f"expected a JSON array of table names, got {type(tables).__name__}"
        )
    return tables


def table_selection_agent(state):
    """Identify the tables needed to answer the natural language query.

    Reads the user's question from ``state["sql_query"]`` and the BigQuery
    client from ``state["client"]``, builds a prompt from the dataset schema
    returned by ``get_bigquery_schema_info`` and asks the OpenAI chat model
    which tables are relevant.

    Args:
        state: Mapping with the keys ``"sql_query"`` and ``"client"``.

    Returns:
        A dict with ``"relevant_tables"`` (a list of table-name strings). On
        any failure the list is empty and an ``"error"`` message is included.
    """
    natural_language_query = state["sql_query"]
    client = state["client"]
    if client is None:
        return {"relevant_tables": [], "error": "Failed to connect to BigQuery."}

    schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)

    # Format the schema for the prompt: one bullet line per table.
    schema_text = "".join(
        f"- **{DATASET_ID}.{table_name}** ({', '.join(columns)})\n"
        for table_name, columns in schema_info.items()
    )

    prompt = f"""
    Based on the following natural language query and BigQuery schema, identify the tables that would be needed to answer the query.
    **Query:** "{natural_language_query}"
    **BigQuery Schema:**
    {schema_text}
    Analyze the query and determine which tables contain the necessary information.
    IMPORTANT: Return ONLY a raw JSON array of table names without any markdown formatting, code blocks, or explanations.
    Example of correct response format:
    ["{DATASET_ID}.users", "{DATASET_ID}.orders"]
    Example of INCORRECT response format:
    ```json
    ["{DATASET_ID}.users", "{DATASET_ID}.orders"]
    ```
    DO NOT use code blocks, backticks, or any other formatting. Return ONLY the raw JSON array.
    """

    # Kept outside the try so the error handlers can always report it.
    raw_content = ""
    try:
        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0
        )
        # The message content may be None; treat that as an empty
        # (and therefore invalid) reply rather than crashing on .strip().
        raw_content = response.choices[0].message.content or ""
        relevant_tables = _parse_table_list(raw_content)
    except json.JSONDecodeError as e:
        print(f"JSON Decode Error: {e}")
        print(f"Response content: {raw_content}")
        return {"relevant_tables": [], "error": "Invalid JSON response from OpenAI"}
    except TypeError as e:
        # Valid JSON but not an array of strings: do not pass it downstream.
        print(f"Unexpected response shape: {e}")
        print(f"Response content: {raw_content}")
        return {"relevant_tables": [], "error": "Invalid JSON response from OpenAI"}
    except Exception as e:
        print(f"Unexpected error: {e}")
        return {"relevant_tables": [], "error": f"Error: {str(e)}"}

    print(f"Parsed relevant tables: {relevant_tables}")
    return {"relevant_tables": relevant_tables}