import re

import openai

from config import PROJECT_ID, DATASET_ID
from utils.bigquery_utils import get_bigquery_schema_info

# Opening markdown fence, optionally followed by a language tag.
# A generic tag ("sql", "SQL", "bigquery", ...) is only consumed when it sits
# alone on the fence line, so a reply like "```SELECT 1```" keeps its SELECT.
# The common SQL dialect tags are also accepted inline ("```sql SELECT 1```").
_OPENING_FENCE_RE = re.compile(
    r"^```[ \t]*(?:[A-Za-z][\w+-]*[ \t]*(?=\r?\n)|(?:sql|bigquery|googlesql)\b)?",
    re.IGNORECASE,
)
# Closing markdown fence at the very end of the reply.
_CLOSING_FENCE_RE = re.compile(r"```\s*$")


def _format_schema(schema_info, relevant_tables):
    """Render schema entries as markdown bullets, one line per table.

    A table is included when "<DATASET_ID>.<table_name>" appears in
    ``relevant_tables``. If ``relevant_tables`` is empty, every table is
    included so the model is never prompted without any schema.
    """
    wanted = set(relevant_tables)
    lines = []
    for table_name, columns in schema_info.items():
        qualified_name = f"{DATASET_ID}.{table_name}"
        if wanted and qualified_name not in wanted:
            continue
        lines.append(f"- **{qualified_name}** ({', '.join(columns)})\n")
    return "".join(lines)


def _escape_cell(value):
    """Make a value safe to place inside a single markdown table cell."""
    # A raw "|" would start a new column, and a newline would end the row.
    return (
        str(value)
        .replace("|", "\\|")
        .replace("\r", " ")
        .replace("\n", " ")
    )


def _format_sample_data(sample_data):
    """Render each table's sample rows as a markdown table.

    ``sample_data`` maps a table name to its rows. Entries whose value is not
    a non-empty list are skipped. Rows are assumed to be dict-like, and the
    column order comes from the first row's keys.
    """
    parts = []
    for table, rows in sample_data.items():
        if not (isinstance(rows, list) and rows):
            continue
        columns = list(rows[0].keys())
        parts.append(f"\n**Sample data from {table}:**\n")
        parts.append("| " + " | ".join(_escape_cell(col) for col in columns) + " |\n")
        parts.append("| " + " | ".join(["---"] * len(columns)) + " |\n")
        for row in rows:
            cells = (_escape_cell(row.get(col, "")) for col in columns)
            parts.append("| " + " | ".join(cells) + " |\n")
    return "".join(parts)


def _build_prompt(question, schema_text, sample_data_text):
    """Assemble the LLM prompt from the question and the formatted context."""
    return "\n".join(
        [
            "Generate a BigQuery SQL query to answer the following question:",
            "",
            f'**Question:** "{question}"',
            "",
            "**Relevant Tables Schema:**",
            schema_text,
            "",
            "**Sample Data:**",
            sample_data_text,
            "",
            "**Rules:**",
            "- Use only the provided tables with their full dataset.table_name "
            f"format (e.g., {DATASET_ID}.users).",
            "- Ensure correct column names as shown in the schema.",
            "- Use appropriate joins based on the relationships visible in the sample data.",
            "- Use BigQuery SQL syntax.",
            "- Return ONLY the SQL query without any explanations or markdown formatting.",
            "",
        ]
    )


def _strip_code_fences(text):
    """Remove a surrounding markdown code fence (with any language tag)."""
    text = text.strip()
    if text.startswith("```"):
        text = _OPENING_FENCE_RE.sub("", text, count=1)
        text = _CLOSING_FENCE_RE.sub("", text)
    return text.strip()


def sql_generation_agent(state):
    """Generate a SQL query from the natural-language query and sample data.

    Args:
        state: Agent state dict. Keys read:
            - "sql_query" (required): the natural-language question.
            - "relevant_tables": "<dataset>.<table>" names to include in the
              schema. If empty or missing, all tables are included.
            - "sample_data": mapping of table name to a list of row dicts.
            - "client": the BigQuery client passed to
              ``get_bigquery_schema_info``.

    Returns:
        ``{"generated_sql": <str>}``. The value is the model's SQL with any
        markdown fence removed. If there is no client or the model returns no
        content, the value is a "-- Error: ..." SQL comment instead.

    Raises:
        KeyError: if "sql_query" is missing from ``state``.
        Errors from the schema lookup or the OpenAI call are not caught.
    """
    natural_language_query = state["sql_query"]
    relevant_tables = state.get("relevant_tables") or []
    sample_data = state.get("sample_data") or {}
    client = state.get("client")
    if client is None:
        return {"generated_sql": "-- Error: Failed to connect to BigQuery."}

    schema_info = get_bigquery_schema_info(client, PROJECT_ID, DATASET_ID)
    prompt = _build_prompt(
        natural_language_query,
        _format_schema(schema_info, relevant_tables),
        _format_sample_data(sample_data),
    )

    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
    )
    # message.content may be None (e.g. refusals), so guard before stripping.
    content = response.choices[0].message.content
    if content is None:
        return {"generated_sql": "-- Error: The language model returned an empty response."}

    return {"generated_sql": _strip_code_fences(content)}