Gonalb's picture
init commit
05e3517
raw
history blame
2.2 kB
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Optional
from agents.table_selection import table_selection_agent
from agents.data_retrieval import sample_data_retrieval_agent
from agents.sql_generation import sql_generation_agent
from agents.validation import query_validation_and_optimization
from agents.execution import execution_agent
from utils.bigquery_utils import init_bigquery_connection
# Define the state schema
SQLExecutionState = TypedDict(
    "SQLExecutionState",
    {
        # The user's natural-language request. Despite the name, this is
        # not SQL text.
        "sql_query": str,
        # BigQuery client object; populated by the "Initialize Client" node.
        "client": Optional[object],
        # Tables identified as relevant to the request.
        "relevant_tables": Optional[list],
        # Sample rows pulled from the relevant tables.
        "sample_data": Optional[dict],
        # Generated SQL statement as plain text (not JSON).
        "generated_sql": Optional[str],
        # Result of the validation step; schema not visible here. TODO confirm.
        "validation_result": Optional[dict],
        # Presumably the optimized rewrite of generated_sql. Verify in the agent.
        "optimized_sql": Optional[str],
        # Output of the execution node; schema not visible here. TODO confirm.
        "execution_result": Optional[dict],
    },
)
def initialize_client(state: SQLExecutionState) -> SQLExecutionState:
    """Open a BigQuery connection and publish it into the workflow state.

    The incoming ``state`` is not inspected. The returned mapping is a
    partial state update that contains only the ``"client"`` key.

    Args:
        state: Current workflow state (unused).

    Returns:
        ``{"client": <object returned by init_bigquery_connection()>}``.
    """
    return {"client": init_bigquery_connection()}
def create_workflow():
    """Build the linear text-to-SQL pipeline and return it compiled.

    Each stage is wired to the next, so the graph runs them in a fixed
    order with no branching:
    client initialization -> table selection -> sample-data retrieval
    -> SQL generation -> validation/optimization -> execution.

    Returns:
        The result of ``StateGraph.compile()`` for a graph whose state
        schema is ``SQLExecutionState``.
    """
    # (node label, callable) pairs, listed in execution order.
    pipeline = [
        ("Initialize Client", initialize_client),
        ("Table Selection", table_selection_agent),
        ("Sample Data Retrieval", sample_data_retrieval_agent),
        ("SQL Generation", sql_generation_agent),
        ("Query Validation & Optimization", query_validation_and_optimization),
        ("SQL Execution", execution_agent),
    ]

    workflow = StateGraph(state_schema=SQLExecutionState)
    for label, step in pipeline:
        workflow.add_node(label, step)

    # Chain START -> each stage in listed order -> END.
    route = [START] + [label for label, _ in pipeline] + [END]
    for source, target in zip(route, route[1:]):
        workflow.add_edge(source, target)

    return workflow.compile()