Spaces:
Sleeping
Sleeping
File size: 2,898 Bytes
05e3517 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import json
import openai
from config import PROJECT_ID, DATASET_ID
from utils.bigquery_utils import get_bigquery_schema_info
def _strip_code_fence(content):
    """Return *content* without a surrounding markdown code fence, if present.

    Handles a fenced reply with or without a closing fence, and drops a
    leading ``json`` language tag that follows the opening fence.
    """
    text = content.strip()
    if not text.startswith("```"):
        return text
    # Drop the opening fence and keep everything up to the closing fence;
    # partition() returns the whole remainder when the closing fence is missing.
    inner = text[3:].partition("```")[0].strip()
    if inner.startswith("json"):
        inner = inner[len("json"):].strip()
    return inner


def table_selection_agent(state):
    """Identify the tables relevant to the natural language query based on schema.

    Args:
        state: Mapping that must contain ``"sql_query"`` (used here as the
            natural-language question inserted into the prompt) and
            ``"client"`` (passed to ``get_bigquery_schema_info``; ``None``
            is treated as a failed BigQuery connection).

    Returns:
        A dict with ``"relevant_tables"`` (the list parsed from the model's
        reply). On any failure the list is empty and an ``"error"`` message
        is included.

    Raises:
        KeyError: If ``state`` lacks ``"sql_query"`` or ``"client"``.
        Exception: Whatever ``get_bigquery_schema_info`` raises is propagated.
    """
    natural_language_query = state["sql_query"]
    client = state["client"]
    if client is None:
        return {"relevant_tables": [], "error": "Failed to connect to BigQuery."}

    schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)

    # Format the schema for the prompt: one bullet per table with its columns.
    schema_text = "".join(
        f"- **{DATASET_ID}.{table_name}** ({', '.join(columns)})\n"
        for table_name, columns in schema_info.items()
    )

    # The prompt body is kept flush-left because it is sent to the model verbatim.
    prompt = f"""
Based on the following natural language query and BigQuery schema, identify the tables that would be needed to answer the query.
**Query:** "{natural_language_query}"
**BigQuery Schema:**
{schema_text}
Analyze the query and determine which tables contain the necessary information.
IMPORTANT: Return ONLY a raw JSON array of table names without any markdown formatting, code blocks, or explanations.
Example of correct response format:
["{DATASET_ID}.users", "{DATASET_ID}.orders"]
Example of INCORRECT response format:
```json
["{DATASET_ID}.users", "{DATASET_ID}.orders"]
```
DO NOT use code blocks, backticks, or any other formatting. Return ONLY the raw JSON array.
"""

    # Step 1: call the model. Any failure here is reported, never raised.
    try:
        response = openai.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
        )
        raw_content = response.choices[0].message.content
    except Exception as e:
        print(f"Unexpected error: {e}")
        return {"relevant_tables": [], "error": f"Error: {str(e)}"}

    # Step 2: parse the reply. The content may be None, which is treated as an
    # empty (and therefore invalid) JSON document rather than a crash.
    try:
        relevant_tables = json.loads(_strip_code_fence(raw_content or ""))
    except json.JSONDecodeError as e:
        print(f"JSON Decode Error: {e}")
        print(f"Response content: {raw_content}")
        return {"relevant_tables": [], "error": "Invalid JSON response from OpenAI"}

    # Step 3: valid JSON is not enough; the contract is an array of table names.
    is_name_list = isinstance(relevant_tables, list) and all(
        isinstance(name, str) for name in relevant_tables
    )
    if not is_name_list:
        print(f"Unexpected response shape: {relevant_tables!r}")
        return {
            "relevant_tables": [],
            "error": "OpenAI response was not a JSON array of table names",
        }

    print(f"Parsed relevant tables: {relevant_tables}")
    return {"relevant_tables": relevant_tables}