Gonalb's picture
init commit
05e3517
import openai
from config import PROJECT_ID, DATASET_ID
from utils.bigquery_utils import get_bigquery_schema_info
# Language tags a model may attach to an opening code fence (compared case-insensitively).
_FENCE_LANGUAGE_TAGS = ("googlesql", "bigquery", "sql")


def _escape_cell(value):
    """Render *value* as text that is safe inside a single Markdown table cell.

    Pipes are escaped and line breaks flattened so a value containing them
    cannot split a row or shift the table's columns.
    """
    text = str(value)
    return text.replace("|", "\\|").replace("\r", " ").replace("\n", " ")


def _format_schema(schema_info, relevant_tables, dataset_id):
    """Render the columns of each relevant table as a Markdown bullet list.

    Args:
        schema_info: Mapping of table name -> iterable of column descriptions,
            as returned by ``get_bigquery_schema_info``.
        relevant_tables: Fully qualified ``"<dataset>.<table>"`` names to keep.
        dataset_id: Dataset prefix used to qualify each table name.

    Returns:
        One ``- **dataset.table** (col, col, ...)`` line per kept table.
    """
    return "".join(
        f"- **{dataset_id}.{table_name}** ({', '.join(columns)})\n"
        for table_name, columns in schema_info.items()
        if f"{dataset_id}.{table_name}" in relevant_tables
    )


def _format_sample_data(sample_data):
    """Render each table's sample rows as a Markdown table.

    Args:
        sample_data: Mapping of table name -> list of row mappings. Entries
            whose value is not a non-empty list are skipped.

    Returns:
        The concatenated Markdown tables (empty string if nothing to show).
    """
    parts = []
    for table, rows in sample_data.items():
        if not (isinstance(rows, list) and rows):
            continue
        # Column order is taken from the first row; rows missing a column
        # render that cell as an empty string.
        columns = list(rows[0].keys())
        parts.append(f"\n**Sample data from {table}:**\n")
        parts.append("| " + " | ".join(_escape_cell(col) for col in columns) + " |\n")
        parts.append("| " + " | ".join(["---"] * len(columns)) + " |\n")
        for row in rows:
            cells = (_escape_cell(row.get(col, "")) for col in columns)
            parts.append("| " + " | ".join(cells) + " |\n")
    return "".join(parts)


def _strip_code_fences(text):
    """Remove a Markdown code fence wrapped around *text*, if present.

    Handles bare fences (```), tagged fences (```sql, ```SQL, ```bigquery)
    and a missing closing fence. Only the surrounding fence is removed;
    backticks inside the query are left alone.
    """
    text = text.strip()
    if not text.startswith("```"):
        return text
    text = text[3:]
    lowered = text.lower()
    for tag in _FENCE_LANGUAGE_TAGS:
        if lowered.startswith(tag):
            text = text[len(tag):]
            break
    text = text.rstrip()
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def sql_generation_agent(state):
    """Generate a BigQuery SQL query for a natural-language question.

    Builds a prompt from the schema of the relevant tables and any sample
    rows, asks the OpenAI chat completions API for a query, and returns the
    query with any surrounding Markdown code fence removed.

    Args:
        state: Agent state mapping. Keys read:
            - ``"sql_query"`` (required): the natural-language question.
            - ``"client"`` (required): BigQuery client, or ``None`` if the
              connection failed.
            - ``"relevant_tables"`` (optional): ``"<dataset>.<table>"`` names
              whose schema is included in the prompt.
            - ``"sample_data"`` (optional): table name -> list of row mappings.

    Returns:
        ``{"generated_sql": <str>}``. On failure the value is an SQL comment
        starting with ``-- Error:`` rather than a query.

    Raises:
        KeyError: If ``"sql_query"`` or ``"client"`` is missing from *state*.
        Errors from ``get_bigquery_schema_info`` and the OpenAI call propagate.
    """
    natural_language_query = state["sql_query"]
    # `or` also covers keys that are present but explicitly set to None.
    relevant_tables = state.get("relevant_tables") or []
    sample_data = state.get("sample_data") or {}
    client = state["client"]
    if client is None:
        return {"generated_sql": "-- Error: Failed to connect to BigQuery."}

    schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)
    schema_text = _format_schema(schema_info, relevant_tables, DATASET_ID)
    sample_data_text = _format_sample_data(sample_data)

    prompt = f"""
Generate a BigQuery SQL query to answer the following question:
**Question:** "{natural_language_query}"
**Relevant Tables Schema:**
{schema_text}
**Sample Data:**
{sample_data_text}
**Rules:**
- Use only the provided tables with their full dataset.table_name format (e.g., {DATASET_ID}.users).
- Ensure correct column names as shown in the schema.
- Use appropriate joins based on the relationships visible in the sample data.
- Use BigQuery SQL syntax.
- Return ONLY the SQL query without any explanations or markdown formatting.
"""
    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
    )
    # The message content can be None (e.g. refusals), so normalise to "".
    content = response.choices[0].message.content or ""
    generated_sql = _strip_code_fences(content)
    if not generated_sql:
        # Follow the "-- Error:" convention used for the connection failure above.
        return {"generated_sql": "-- Error: The model returned an empty SQL query."}
    return {"generated_sql": generated_sql}