Spaces:
Sleeping
Sleeping
from langgraph.graph import StateGraph, START, END | |
from typing import TypedDict, Optional | |
from agents.table_selection import table_selection_agent | |
from agents.data_retrieval import sample_data_retrieval_agent | |
from agents.sql_generation import sql_generation_agent | |
from agents.validation import query_validation_and_optimization | |
from agents.execution import execution_agent | |
from utils.bigquery_utils import init_bigquery_connection | |
# Shared state schema for every node in the text-to-SQL workflow graph.
# Declared with TypedDict's functional syntax; keys, value types, key order,
# and totality are the same as the equivalent class-based declaration.
SQLExecutionState = TypedDict(
    "SQLExecutionState",
    {
        # The user's natural-language question (despite the key name, this is
        # not SQL text).
        "sql_query": str,
        # BigQuery client handle, filled in by the "Initialize Client" node.
        "client": Optional[object],
        # Tables judged relevant to the question.
        "relevant_tables": Optional[list],
        # Sample rows pulled from the relevant tables.
        "sample_data": Optional[dict],
        # Plain SQL text produced by the generation step (not a JSON payload).
        "generated_sql": Optional[str],
        # NOTE(review): the three keys below are presumably written by the
        # validation/optimization and execution nodes — confirm against agents/.
        "validation_result": Optional[dict],
        "optimized_sql": Optional[str],
        "execution_result": Optional[dict],
    },
)
def initialize_client(state: SQLExecutionState) -> dict:
    """Create the BigQuery client and publish it into the workflow state.

    This is the graph's entry node. It does not read the incoming state; it
    returns a *partial* state update containing only the ``"client"`` key,
    which LangGraph merges into the running state for downstream nodes.

    Args:
        state: Current workflow state. Unused here, but LangGraph calls every
            node with the state, and the annotation lets LangGraph infer the
            node's input schema.

    Returns:
        ``{"client": client}``, where ``client`` is whatever
        ``init_bigquery_connection()`` returns.
    """
    # Return type is ``dict`` rather than ``SQLExecutionState``: the mapping
    # holds only one of the schema's keys, so it is not a complete (total)
    # SQLExecutionState and a type checker would reject the old annotation.
    client = init_bigquery_connection()
    return {"client": client}
def create_workflow():
    """Build and compile the linear text-to-SQL LangGraph pipeline.

    Stages are wired in a fixed straight line with unconditional edges:
    START -> Initialize Client -> Table Selection -> Sample Data Retrieval
    -> SQL Generation -> Query Validation & Optimization -> SQL Execution -> END.

    Returns:
        The object produced by ``StateGraph.compile()`` for this graph.
    """
    # Pipeline stages in execution order: (node name, node callable).
    stages = [
        ("Initialize Client", initialize_client),
        ("Table Selection", table_selection_agent),
        ("Sample Data Retrieval", sample_data_retrieval_agent),
        ("SQL Generation", sql_generation_agent),
        ("Query Validation & Optimization", query_validation_and_optimization),
        ("SQL Execution", execution_agent),
    ]

    workflow = StateGraph(state_schema=SQLExecutionState)
    for node_name, node_fn in stages:
        workflow.add_node(node_name, node_fn)

    # Chain START, every stage (in order), then END with consecutive edges.
    route = [START] + [node_name for node_name, _ in stages] + [END]
    for source, target in zip(route, route[1:]):
        workflow.add_edge(source, target)

    return workflow.compile()