# Page-scrape header (not part of the program) -- Spaces: Sleeping; File size: 2,615 Bytes; 05e3517
import openai
from config import PROJECT_ID, DATASET_ID
from utils.bigquery_utils import get_bigquery_schema_info
def _format_sample_data(sample_data):
    """Render each table's sample rows as a Markdown table for the prompt.

    Entries whose value is not a non-empty list are skipped. The header is
    taken from the keys of the first row, so rows are assumed to be
    dict-like (TODO confirm against the producer of ``sample_data``).
    """
    parts = []
    for table, rows in sample_data.items():
        if not (isinstance(rows, list) and rows):
            continue
        # The first row defines the column order for the whole table.
        columns = list(rows[0].keys())
        parts.append(f"\n**Sample data from {table}:**\n")
        parts.append("| " + " | ".join(columns) + " |\n")
        parts.append("| " + " | ".join(["---"] * len(columns)) + " |\n")
        for row in rows:
            cells = [str(row.get(col, "")) for col in columns]
            parts.append("| " + " | ".join(cells) + " |\n")
    return "".join(parts)


def _strip_markdown_fence(text):
    """Return *text* without a surrounding Markdown code fence, if present.

    Handles a bare fence as well as a ``sql`` language tag in any letter
    case, and discards anything that follows the closing fence.
    """
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    body = cleaned[3:]
    # Drop an optional "sql" language tag, but only when it stands alone
    # (followed by whitespace, a fence, or nothing).
    if body[:3].lower() == "sql" and (len(body) == 3 or not body[3].isalnum()):
        body = body[3:]
    # Keep only the text before the closing fence; when there is no closing
    # fence, partition() returns the whole remainder.
    return body.partition("```")[0].strip()


def sql_generation_agent(state):
    """Generate a BigQuery SQL query from a natural language question.

    Args:
        state: Mapping with the keys
            ``"sql_query"`` (required): the natural language question;
            ``"client"`` (required): BigQuery client, or ``None`` when the
            connection failed;
            ``"relevant_tables"`` (optional): ``dataset.table`` names whose
            schema is included in the prompt;
            ``"sample_data"`` (optional): mapping of table name to a list
            of sample rows.

    Returns:
        A dict with the single key ``"generated_sql"``. Its value is the
        SQL text returned by the model with any Markdown code fence
        removed, or an SQL comment describing the error when ``client``
        is ``None``.

    Raises:
        KeyError: If ``"sql_query"`` or ``"client"`` is missing from
            ``state``.
    """
    natural_language_query = state["sql_query"]
    relevant_tables = state.get("relevant_tables", [])
    sample_data = state.get("sample_data", {})
    client = state["client"]
    if client is None:
        return {"generated_sql": "-- Error: Failed to connect to BigQuery."}

    # Presumably a mapping of table name -> list of column-name strings;
    # verify against utils.bigquery_utils.
    schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)

    # Only tables listed in relevant_tables are described in the prompt.
    schema_text = "".join(
        f"- **{DATASET_ID}.{table_name}** ({', '.join(columns)})\n"
        for table_name, columns in schema_info.items()
        if f"{DATASET_ID}.{table_name}" in relevant_tables
    )
    sample_data_text = _format_sample_data(sample_data)

    prompt = f"""
Generate a BigQuery SQL query to answer the following question:
**Question:** "{natural_language_query}"
**Relevant Tables Schema:**
{schema_text}
**Sample Data:**
{sample_data_text}
**Rules:**
- Use only the provided tables with their full dataset.table_name format (e.g., {DATASET_ID}.users).
- Ensure correct column names as shown in the schema.
- Use appropriate joins based on the relationships visible in the sample data.
- Use BigQuery SQL syntax.
- Return ONLY the SQL query without any explanations or markdown formatting.
"""
    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
    )
    # The message content can be None; treat that as an empty reply
    # instead of failing on .strip().
    content = response.choices[0].message.content or ""
    return {"generated_sql": _strip_markdown_fence(content)}