Spaces:
Sleeping
Sleeping
import openai | |
from config import PROJECT_ID, DATASET_ID | |
from utils.bigquery_utils import get_bigquery_schema_info | |
def _format_schema(schema_info, relevant_tables):
    """Render the relevant tables as Markdown bullets for the prompt.

    Only tables whose qualified ``DATASET_ID.table_name`` appears in
    *relevant_tables* are included. *schema_info* is iterated with
    ``.items()`` and each value is joined with ``", "``, so it is expected
    to map table names to iterables of column-name strings.
    """
    return "".join(
        f"- **{DATASET_ID}.{table_name}** ({', '.join(columns)})\n"
        for table_name, columns in schema_info.items()
        if f"{DATASET_ID}.{table_name}" in relevant_tables
    )


def _format_sample_data(sample_data):
    """Render each table's sample rows as a Markdown table for the prompt."""
    parts = []
    for table, rows in sample_data.items():
        # Skip tables whose sample is empty or is not a list of rows.
        if not (isinstance(rows, list) and rows):
            continue
        # The first row's keys fix the column order for the whole table;
        # later rows that lack a key get an empty cell.
        columns = list(rows[0].keys())
        parts.append(f"\n**Sample data from {table}:**\n")
        parts.append("| " + " | ".join(columns) + " |\n")
        parts.append("| " + " | ".join(["---"] * len(columns)) + " |\n")
        parts.extend(
            "| " + " | ".join(str(row.get(col, "")) for col in columns) + " |\n"
            for row in rows
        )
    return "".join(parts)


def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence, if any.

    Handles a bare opening fence as well as one tagged ``sql`` in any
    letter case, and discards anything that follows the closing fence.
    Text that does not start with a fence is only whitespace-stripped.
    """
    fence = "```"
    cleaned = text.strip()
    if not cleaned.startswith(fence):
        return cleaned
    fenced = cleaned[len(fence):]
    # Drop the optional language tag that directly follows the opening fence.
    if fenced[:3].lower() == "sql":
        fenced = fenced[3:]
    # Keep only what precedes the closing fence; if the model never closed
    # the fence, partition() leaves the whole remainder in place.
    sql, _, _ = fenced.partition(fence)
    return sql.strip()


def sql_generation_agent(state):
    """Generate a BigQuery SQL query for the question stored in *state*.

    Args:
        state: Mapping that must contain ``"sql_query"`` (the question text
            inserted into the prompt) and ``"client"`` (passed to
            ``get_bigquery_schema_info``). ``"relevant_tables"`` and
            ``"sample_data"`` are optional; a missing or ``None`` value is
            treated as empty.

    Returns:
        A dict with the single key ``"generated_sql"``. When the client is
        ``None`` the value is an SQL comment describing the connection
        failure; otherwise it is the model's reply with any Markdown code
        fence removed.

    Raises:
        KeyError: If ``"sql_query"`` or ``"client"`` is missing from *state*.
    """
    natural_language_query = state["sql_query"]
    # ``or`` also covers keys that are present but explicitly set to None.
    relevant_tables = state.get("relevant_tables") or []
    sample_data = state.get("sample_data") or {}
    client = state["client"]
    if client is None:
        return {"generated_sql": "-- Error: Failed to connect to BigQuery."}

    schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)
    schema_text = _format_schema(schema_info, relevant_tables)
    sample_data_text = _format_sample_data(sample_data)

    prompt = f"""
Generate a BigQuery SQL query to answer the following question:
**Question:** "{natural_language_query}"
**Relevant Tables Schema:**
{schema_text}
**Sample Data:**
{sample_data_text}
**Rules:**
- Use only the provided tables with their full dataset.table_name format (e.g., {DATASET_ID}.users).
- Ensure correct column names as shown in the schema.
- Use appropriate joins based on the relationships visible in the sample data.
- Use BigQuery SQL syntax.
- Return ONLY the SQL query without any explanations or markdown formatting.
"""
    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
    )
    # The message content may be None; fall back to an empty string so the
    # fence stripping below never fails on a missing reply.
    content = response.choices[0].message.content or ""
    return {"generated_sql": _strip_code_fence(content)}