# FastAPI_KIG / api.py
# adrienbrdne's picture
# Update api.py
# fcf819f verified
import logging
import time
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from typing import List, Dict, Any
# Import necessary components from your kig_core library
# Ensure kig_core is in the Python path or installed as a package
try:
    # Importing settings loads the configuration as a side effect of the import.
    from kig_core.config import settings
    from kig_core.schemas import (
        PlannerState,
        KeyIssue as KigKeyIssue,
        GraphConfig,
    )
    from kig_core.planner import build_graph
    # neo4j_client is the already-initialized client instance, not a class.
    from kig_core.graph_client import neo4j_client
    from langchain_core.messages import HumanMessage
except ImportError as import_error:
    # Tell the operator what is missing, then abort the import of this module:
    # the API cannot work without these components.
    print(f"Error importing kig_core components: {import_error}")
    print("Please ensure kig_core is in your Python path or installed.")
    raise
# Configure root logging for the API process: INFO level and above, with each
# record rendered as "timestamp - logger name - level - message".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger, named after this module and used throughout the file.
logger = logging.getLogger(__name__)
# --- Pydantic Models for API Request/Response ---
class KeyIssueRequest(BaseModel):
    """Request payload: the user's technical query to generate key issues for."""

    # Free-text technical query supplied by the caller.
    query: str
class KeyIssueResponse(BaseModel):
    """Response payload wrapping the list of generated key issues."""

    # Each item follows the KeyIssue schema imported from kig_core.schemas.
    key_issues: List[KigKeyIssue]
# --- Global Variables / State ---
# The graph instance is kept in a module-level global so it is built once and
# reused by every request. Consider potential concurrency issues if the graph
# or its LLMs hold state; rebuilding on each request would be the safer,
# stateless alternative.
app_graph = None  # Set by lifespan() at startup; the endpoint answers 503 while it is still None
# --- Application Lifecycle (Startup/Shutdown) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the startup and shutdown logic that surrounds the API's lifetime.

    Startup:
        1. Probes the shared Neo4j driver with ``verify_connectivity()``.
           A failure is logged but deliberately does NOT abort startup.
        2. Builds the LangGraph application with ``build_graph()`` and stores
           it in the module-level ``app_graph`` used by the endpoints.

    Shutdown:
        Only logs. The Neo4j client is not closed here; per the note in this
        file, ``graph_client`` registers an ``atexit`` handler for that.

    Args:
        app: The FastAPI application being started (required by the lifespan
            protocol; not otherwise used here).

    Raises:
        RuntimeError: If ``build_graph()`` fails, so the server does not start
            without a usable graph. The original exception is chained.
    """
    global app_graph
    logger.info("API starting up...")

    # Best-effort connectivity probe. The client itself was created when
    # graph_client was imported; this only surfaces connection problems early.
    try:
        logger.info("Verifying Neo4j connection...")
        neo4j_client._get_driver().verify_connectivity()
        logger.info("Neo4j connection verified.")
    except Exception as e:
        logger.error(f"Neo4j connection verification failed on startup: {e}", exc_info=True)
        # Intentionally not re-raised: the app is allowed to start without a
        # verified connection. Raise a RuntimeError here to make it fatal.

    # Build the LangGraph application; unlike the probe above this is fatal,
    # because no request can be served without the graph.
    logger.info("Building LangGraph application...")
    try:
        app_graph = build_graph()
    except Exception as e:
        logger.error(f"Failed to build LangGraph application on startup: {e}", exc_info=True)
        raise RuntimeError("Failed to build LangGraph on startup.") from e
    logger.info("LangGraph application built successfully.")

    yield  # API runs here

    # --- Shutdown ---
    logger.info("API shutting down...")
    # Nothing is closed at this point, so the log must not claim otherwise:
    # closing the client is left to the atexit hook in graph_client.py.
    logger.info("Neo4j client shutdown is deferred to the atexit handler registered by graph_client.")
    logger.info("API shutdown complete.")
# --- FastAPI Application ---
app = FastAPI(
    title="Key Issue Generator API",
    description="API to generate Key Issues based on a technical query using LLMs and Neo4j.",
    version="1.0.0",
    lifespan=lifespan,  # startup/shutdown logic lives in lifespan()
)
# --- API Endpoint ---
# API state check route
@app.get("/")
def read_root():
    """Health-check route: always answers with a static OK status."""
    health_payload = {"status": "ok"}
    return health_payload
def _status_code_for_workflow_error(error_msg: str) -> int:
    """Map an error message reported by the workflow to an HTTP status code."""
    lowered = error_msg.lower()
    if "Neo4j" in error_msg or "connection" in lowered:
        return 503  # Service Unavailable (database issue)
    if "LLM error" in error_msg or "parse" in lowered:
        return 502  # Bad Gateway (issue with upstream LLM)
    return 500  # Internal Server Error by default


@app.post("/generate-key-issues", response_model=KeyIssueResponse)
async def generate_issues(request: KeyIssueRequest):
    """Accept a technical query and return a list of generated Key Issues.

    Args:
        request: Request body carrying the user's technical query.

    Returns:
        A dict of the form ``{"key_issues": [...]}``; FastAPI validates it
        against ``KeyIssueResponse`` when serializing the response.

    Raises:
        HTTPException: 503 if the graph has not been initialized or a
            ``ConnectionError`` occurs; 400 if the query is empty/blank or a
            ``ValueError`` is raised while processing; 502/503/500 when the
            workflow reports an error in its final state (see
            ``_status_code_for_workflow_error``); 500 for any other failure.
    """
    # Read-only access to the module-level graph; no ``global`` needed.
    if app_graph is None:
        logger.error("Graph application is not initialized.")
        raise HTTPException(status_code=503, detail="Service Unavailable: Graph not initialized")

    user_query = request.query
    # A whitespace-only query carries no content either, so reject it as empty.
    if not user_query or not user_query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty.")

    logger.info(f"Received request to generate key issues for query: '{user_query[:100]}...'")
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    try:
        # --- Prepare Initial State for LangGraph ---
        # Note: PlannerState must align with what build_graph expects.
        initial_state: PlannerState = {
            "user_query": user_query,
            "messages": [HumanMessage(content=user_query)],
            "plan": [],
            "current_plan_step_index": -1,  # Or as expected by the graph's entry point
            "step_outputs": {},
            "key_issues": [],
            "error": None,
        }

        # No memory/checkpointing is used, so no thread_id is supplied. If
        # persistent memory is added, put a "thread_id" under "configurable".
        config: GraphConfig = {"configurable": {}}

        # --- Execute the LangGraph Workflow ---
        logger.info("Invoking LangGraph workflow...")
        final_state = await app_graph.ainvoke(initial_state, config=config)
        elapsed = time.perf_counter() - start_time
        logger.info(f"Workflow finished in {elapsed:.2f} seconds.")

        # --- Process Final Results ---
        if final_state is None:
            logger.error("Workflow execution did not produce a final state.")
            raise HTTPException(status_code=500, detail="Workflow execution failed to produce a result.")

        workflow_error = final_state.get("error")
        if workflow_error:
            # Coerce to str so a non-string error value (e.g. an exception
            # object) can still be classified instead of breaking the mapping.
            error_msg = str(workflow_error)
            logger.error(f"Workflow failed with error: {error_msg}")
            raise HTTPException(
                status_code=_status_code_for_workflow_error(error_msg),
                detail=f"Workflow failed: {error_msg}",
            )

        # --- Extract Key Issues ---
        # ``or []`` also covers a state where the key is present but None.
        # Validation against KeyIssueResponse happens in FastAPI after this
        # function returns, so it cannot be caught here.
        generated_issues_data = final_state.get("key_issues") or []
        logger.info(f"Successfully generated {len(generated_issues_data)} key issues.")
        return {"key_issues": generated_issues_data}

    except HTTPException:
        # Already an HTTP-level error: propagate unchanged, traceback intact.
        raise
    except ConnectionError as e:
        logger.error(f"Connection Error during API request: {e}", exc_info=True)
        raise HTTPException(status_code=503, detail=f"Service Unavailable: {e}") from e
    except ValueError as e:
        logger.error(f"Value Error during API request: {e}", exc_info=True)
        # Often input validation issues
        raise HTTPException(status_code=400, detail=f"Bad Request: {e}") from e
    except Exception as e:
        logger.error(f"An unexpected error occurred during API request: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="Internal Server Error: An unexpected error occurred.",
        ) from e
# --- How to Run ---
if __name__ == "__main__":
    # Configuration (NEO4J_URI, NEO4J_PASSWORD, GEMINI_API_KEY, etc.) must come
    # from environment variables or a .env file in the directory this script is
    # run from.
    print("Starting API server...")
    print("Ensure required environment variables (e.g., NEO4J_URI, NEO4J_PASSWORD, GEMINI_API_KEY) are set or .env file is present.")
    # Equivalent command line: uvicorn api:app --reload --host 0.0.0.0 --port 8000
    # Auto-reload is a development convenience; use reload=False in production.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("api:app", **server_options)