Upload 13 files
- .env +34 -0
- Dockerfile +48 -0
- api.py +202 -0
- kig_core/config.py +89 -0
- kig_core/graph_client.py +91 -0
- kig_core/graph_operations.py +212 -0
- kig_core/llm_interface.py +69 -0
- kig_core/planner.py +254 -0
- kig_core/processing.py +127 -0
- kig_core/prompts.py +117 -0
- kig_core/schemas.py +55 -0
- kig_core/utils.py +41 -0
- requirements.txt +25 -0
.env
ADDED
@@ -0,0 +1,34 @@
# Neo4j Credentials
NEO4J_URI="neo4j+s://4985272f.databases.neo4j.io"
NEO4J_USERNAME="neo4j"
NEO4J_PASSWORD="xxx" # Replace with your actual password

# API Keys
OPENAI_API_KEY="sk-xxx" # Replace if using OpenAI models
GEMINI_API_KEY="xxx" # Replace with your actual key
LANGSMITH_API_KEY="lsv2_pt_xxx" # Replace with your actual key (optional but recommended for tracing)
LANGCHAIN_PROJECT="KIG_Refactored" # Optional: For LangSmith tracing

# LLM Configuration
MAIN_LLM_MODEL="gemini-2.0-flash" # Or another preferred model
EVAL_LLM_MODEL="gemini-2.0-flash"
SUMMARIZE_LLM_MODEL="gemini-2.0-flash"

# Planner Configuration
PLAN_METHOD="generation" # or "modification"
USE_DETAILED_QUERY="false" # or "true"

# Graph Operations Configuration
CYPHER_GEN_METHOD="guided" # or "auto"
VALIDATE_CYPHER="false" # or "true"
EVAL_METHOD="binary" # or "score"
EVAL_THRESHOLD="0.7"
MAX_DOCS="10"

# Processing Configuration
# Define processing steps as a JSON string or handle differently if complex needed
PROCESS_STEPS='["summarize"]' # Example: Just summarize
COMPRESSION_METHOD="llm_lingua" # if used
COMPRESS_RATE="0.5" # if used

# Add other parameters as needed
Dockerfile
ADDED
@@ -0,0 +1,48 @@
# Use a more recent, slim Python base image
FROM python:3.10-slim

# Set the working directory in the container
WORKDIR /app

# Prevent Python from writing pyc files to disc (optional)
ENV PYTHONDONTWRITEBYTECODE 1
# Ensure Python output is sent straight to terminal (useful for logs)
ENV PYTHONUNBUFFERED 1

# Upgrade pip
RUN python -m pip install --upgrade pip

# Copy the requirements file into the container
COPY requirements.txt .

# Install dependencies
# --no-cache-dir reduces image size
# --default-timeout=100 increases timeout for pip install
RUN pip install --no-cache-dir -r requirements.txt


COPY .env .

# Copy the application code into the container
# This includes the API file and the core logic directory
COPY api.py .
COPY ./kig_core ./kig_core


# Command to run the Uvicorn server
# It will look for an object named 'app' in the 'api.py' file
# Runs on port 7860 and listens on all interfaces (0.0.0.0)
# Note: --reload is not used here; add it only for local development
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]

# --- Notes ---
# Environment Variables:
# This Dockerfile assumes you will provide necessary environment variables
# (NEO4J_URI, NEO4J_PASSWORD, GEMINI_API_KEY, OPENAI_API_KEY, etc.)
# when running the container, for example using 'docker run -e VAR=value ...'
# or a docker-compose.yml file.
# DO NOT hardcode secrets directly in the Dockerfile.
#
# Cache Folders:
# Removed HF_HOME/TORCH_HOME as this app primarily uses external APIs (Gemini/OpenAI)
# and Neo4j, not local Hugging Face/PyTorch models needing specific cache dirs.
api.py
ADDED
@@ -0,0 +1,202 @@
import logging
import time
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from typing import List, Dict, Any

# Import necessary components from your kig_core library
# Ensure kig_core is in the Python path or installed as a package
try:
    from kig_core.config import settings  # Loads config on import
    from kig_core.schemas import PlannerState, KeyIssue as KigKeyIssue, GraphConfig
    from kig_core.planner import build_graph
    from kig_core.graph_client import neo4j_client  # Import the initialized client instance
    from langchain_core.messages import HumanMessage
except ImportError as e:
    print(f"Error importing kig_core components: {e}")
    print("Please ensure kig_core is in your Python path or installed.")
    # You might want to exit or raise a clearer error if imports fail
    raise

# Configure logging for the API
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Pydantic Models for API Request/Response ---

class KeyIssueRequest(BaseModel):
    """Request body containing the user's technical query."""
    query: str

class KeyIssueResponse(BaseModel):
    """Response body containing the generated key issues."""
    key_issues: List[KigKeyIssue]  # Use the KeyIssue schema from kig_core

# --- Global Variables / State ---
# Keep the graph instance global for efficiency if desired,
# but consider potential concurrency issues if graph/LLMs have state.
# Rebuilding on each request is safer for statelessness.
app_graph = None  # Will be initialized at startup

# --- Application Lifecycle (Startup/Shutdown) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handles startup and shutdown events."""
    global app_graph
    logger.info("API starting up...")
    # Initialize Neo4j client (already done on import by graph_client.py)
    # Verify connection (optional, already done by graph_client on init)
    try:
        logger.info("Verifying Neo4j connection...")
        neo4j_client._get_driver().verify_connectivity()
        logger.info("Neo4j connection verified.")
    except Exception as e:
        logger.error(f"Neo4j connection verification failed on startup: {e}", exc_info=True)
        # Decide if the app should fail to start
        # raise RuntimeError("Failed to connect to Neo4j on startup.") from e

    # Build the LangGraph application
    logger.info("Building LangGraph application...")
    try:
        app_graph = build_graph()
        logger.info("LangGraph application built successfully.")
    except Exception as e:
        logger.error(f"Failed to build LangGraph application on startup: {e}", exc_info=True)
        # Decide if the app should fail to start
        raise RuntimeError("Failed to build LangGraph on startup.") from e

    yield  # API runs here

    # --- Shutdown ---
    logger.info("API shutting down...")
    # Close Neo4j connection (handled by atexit in graph_client.py)
    # neo4j_client.close()  # Usually not needed due to atexit registration
    logger.info("Neo4j client closed (likely via atexit).")
    logger.info("API shutdown complete.")


# --- FastAPI Application ---
app = FastAPI(
    title="Key Issue Generator API",
    description="API to generate Key Issues based on a technical query using LLMs and Neo4j.",
    version="1.0.0",
    lifespan=lifespan  # Use the lifespan context manager
)

# --- API Endpoint ---
# API state check route
@app.get("/")
def read_root():
    return {"status": "ok"}

@app.post("/generate-key-issues", response_model=KeyIssueResponse)
async def generate_issues(request: KeyIssueRequest):
    """
    Accepts a technical query and returns a list of generated Key Issues.
    """
    global app_graph
    if app_graph is None:
        logger.error("Graph application is not initialized.")
        raise HTTPException(status_code=503, detail="Service Unavailable: Graph not initialized")

    user_query = request.query
    if not user_query:
        raise HTTPException(status_code=400, detail="Query cannot be empty.")

    logger.info(f"Received request to generate key issues for query: '{user_query[:100]}...'")
    start_time = time.time()

    try:
        # --- Prepare Initial State for LangGraph ---
        # Note: Ensure PlannerState aligns with what build_graph expects
        initial_state: PlannerState = {
            "user_query": user_query,
            "messages": [HumanMessage(content=user_query)],
            "plan": [],
            "current_plan_step_index": -1,  # Or as expected by your graph's entry point
            "step_outputs": {},
            "key_issues": [],
            "error": None
        }

        # --- Define Configuration (e.g., Thread ID for Memory) ---
        # Using a simple thread ID; adapt if using persistent memory
        # import hashlib
        # thread_id = hashlib.sha256(user_query.encode()).hexdigest()[:8]
        # config: GraphConfig = {"configurable": {"thread_id": thread_id}}
        # If not using memory, config can be simpler or empty based on LangGraph version
        config: GraphConfig = {"configurable": {}}  # Adjust if thread_id/memory is needed

        # --- Execute the LangGraph Workflow ---
        logger.info("Invoking LangGraph workflow...")
        # Use invoke for a single result, or stream if you need intermediate steps
        final_state = await app_graph.ainvoke(initial_state, config=config)
        # If using stream:
        # final_state = None
        # async for step_state in app_graph.astream(initial_state, config=config):
        #     # Process intermediate states if needed
        #     node_name = list(step_state.keys())[0]
        #     logger.debug(f"Graph step completed: {node_name}")
        #     final_state = step_state[node_name]  # Get the latest full state output

        end_time = time.time()
        logger.info(f"Workflow finished in {end_time - start_time:.2f} seconds.")

        # --- Process Final Results ---
        if final_state is None:
            logger.error("Workflow execution did not produce a final state.")
            raise HTTPException(status_code=500, detail="Workflow execution failed to produce a result.")

        if final_state.get("error"):
            error_msg = final_state.get("error", "Unknown error")
            logger.error(f"Workflow failed with error: {error_msg}")
            # Map internal errors to appropriate HTTP status codes
            status_code = 500  # Internal Server Error by default
            if "Neo4j" in error_msg or "connection" in error_msg.lower():
                status_code = 503  # Service Unavailable (database issue)
            elif "LLM error" in error_msg or "parse" in error_msg.lower():
                status_code = 502  # Bad Gateway (issue with upstream LLM)

            raise HTTPException(status_code=status_code, detail=f"Workflow failed: {error_msg}")

        # --- Extract Key Issues ---
        # Ensure the structure matches KeyIssueResponse and KigKeyIssue Pydantic model
        generated_issues_data = final_state.get("key_issues", [])

        # Validate and convert if necessary (Pydantic usually handles this via response_model)
        try:
            # Pydantic will validate against KeyIssueResponse -> List[KigKeyIssue]
            response_data = {"key_issues": generated_issues_data}
            logger.info(f"Successfully generated {len(generated_issues_data)} key issues.")
            return response_data
        except Exception as pydantic_error:  # Catch potential validation errors
            logger.error(f"Failed to validate final key issues against response model: {pydantic_error}", exc_info=True)
            logger.error(f"Data that failed validation: {generated_issues_data}")
            raise HTTPException(status_code=500, detail="Internal error: Failed to format key issues response.")


    except HTTPException as http_exc:
        # Re-raise HTTPExceptions directly
        raise http_exc
    except ConnectionError as e:
        logger.error(f"Connection Error during API request: {e}", exc_info=True)
        raise HTTPException(status_code=503, detail=f"Service Unavailable: {e}")
    except ValueError as e:
        logger.error(f"Value Error during API request: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Bad Request: {e}")  # Often input validation issues
    except Exception as e:
        logger.error(f"An unexpected error occurred during API request: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error: An unexpected error occurred.")


# --- How to Run ---
if __name__ == "__main__":
    # Make sure to set environment variables for config (NEO4J_URI, NEO4J_PASSWORD, GEMINI_API_KEY, etc.)
    # or have a .env file in the same directory where you run this script.
    print("Starting API server...")
    print("Ensure required environment variables (e.g., NEO4J_URI, NEO4J_PASSWORD, GEMINI_API_KEY) are set or .env file is present.")
    # Run with uvicorn: uvicorn api:app --reload --host 0.0.0.0 --port 8000
    # The --reload flag is good for development. Remove it for production.
    uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)  # Use reload=False for production
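
For reference, a minimal client sketch for the endpoint defined above (assumptions: the service is reachable at http://localhost:8000 when run via uvicorn, or on port 7860 inside the Docker image, and the example query string is purely illustrative):

import requests

# Hypothetical smoke test for the /generate-key-issues endpoint.
# Adjust the base URL to match your deployment.
resp = requests.post(
    "http://localhost:8000/generate-key-issues",
    json={"query": "What are the key issues in securing 5G network slicing?"},
    timeout=600,  # the LangGraph workflow can take a while
)
resp.raise_for_status()
for issue in resp.json()["key_issues"]:
    print(issue)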
kig_core/config.py
ADDED
@@ -0,0 +1,89 @@
import os
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field, SecretStr, HttpUrl, validator, Json
from typing import List, Optional, Literal, Union

# Helper function to load .env file if it exists
# Ensure python-dotenv is installed: pip install python-dotenv
try:
    from dotenv import load_dotenv
    print("Attempting to load .env file...")
    if load_dotenv():
        print(".env file loaded successfully.")
    else:
        print(".env file not found or empty.")
except ImportError:
    print("python-dotenv not installed, skipping .env file loading.")
    pass  # Optional: Handle missing dotenv library


class Settings(BaseSettings):
    # Load from .env file
    model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra='ignore')

    # Neo4j Credentials
    neo4j_uri: str = Field(..., validation_alias='NEO4J_URI')
    neo4j_username: str = Field("neo4j", validation_alias='NEO4J_USERNAME')
    neo4j_password: SecretStr = os.getenv("NEO4J_PASSWORD")

    # API Keys
    openai_api_key: Optional[SecretStr] = os.getenv("OPENAI_API_KEY")
    gemini_api_key: Optional[SecretStr] = os.getenv("GEMINI_API_KEY")
    langsmith_api_key: Optional[SecretStr] = os.getenv("LANGSMITH_API_KEY")
    langchain_project: Optional[str] = Field("KIG_Refactored", validation_alias='LANGCHAIN_PROJECT')

    # LLM Configuration
    main_llm_model: str = Field("gemini-1.5-flash", validation_alias='MAIN_LLM_MODEL')
    eval_llm_model: str = Field("gemini-1.5-flash", validation_alias='EVAL_LLM_MODEL')
    summarize_llm_model: str = Field("gemini-1.5-flash", validation_alias='SUMMARIZE_LLM_MODEL')
    # Add other models if needed (e.g., cypher gen, concept selection)

    # Planner Configuration
    plan_method: Literal["generation", "modification"] = Field("generation", validation_alias='PLAN_METHOD')
    use_detailed_query: bool = Field(False, validation_alias='USE_DETAILED_QUERY')

    # Graph Operations Configuration
    cypher_gen_method: Literal["guided", "auto"] = Field("guided", validation_alias='CYPHER_GEN_METHOD')
    validate_cypher: bool = Field(False, validation_alias='VALIDATE_CYPHER')
    eval_method: Literal["binary", "score"] = Field("binary", validation_alias='EVAL_METHOD')
    eval_threshold: float = Field(0.7, validation_alias='EVAL_THRESHOLD')
    max_docs: int = Field(10, validation_alias='MAX_DOCS')

    # Processing Configuration
    # Load processing steps from JSON string in .env
    process_steps: Json[List[Union[str, dict]]] = Field('["summarize"]', validation_alias='PROCESS_STEPS')
    compression_method: Optional[str] = Field(None, validation_alias='COMPRESSION_METHOD')
    compress_rate: Optional[float] = Field(0.5, validation_alias='COMPRESS_RATE')

    # Langsmith Tracing (set automatically based on key)
    langsmith_tracing_v2: str = "false"

    @validator('langsmith_tracing_v2', pre=True, always=True)
    def set_langsmith_tracing(cls, v, values):
        return "true" if values.get('langsmith_api_key') else "false"

    def configure_langsmith(self):
        """Sets Langsmith environment variables if API key is provided."""
        if self.langsmith_api_key:
            os.environ["LANGCHAIN_TRACING_V2"] = self.langsmith_tracing_v2
            os.environ["LANGCHAIN_API_KEY"] = self.langsmith_api_key.get_secret_value()
            if self.langchain_project:
                os.environ["LANGCHAIN_PROJECT"] = self.langchain_project
            print("Langsmith configured.")
        else:
            # Ensure tracing is disabled if no key
            os.environ["LANGCHAIN_TRACING_V2"] = "false"
            print("Langsmith key not found, tracing disabled.")

# Create a single instance to be imported elsewhere
settings = Settings()
# Automatically configure Langsmith on import
settings.configure_langsmith()

# Optionally set Gemini key in environment if needed by library implicitly
if settings.gemini_api_key:
    os.environ["GOOGLE_API_KEY"] = settings.gemini_api_key.get_secret_value()
    print("Set GOOGLE_API_KEY environment variable.")
if settings.openai_api_key:
    os.environ["OPENAI_API_KEY"] = settings.openai_api_key.get_secret_value()
    print("Set OPENAI_API_KEY environment variable.")
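
A small sketch of how the PROCESS_STEPS value is consumed (assumes the other required variables, e.g. the Neo4j credentials, are already present via the .env file above):

# With PROCESS_STEPS='["summarize", "compress"]' in the environment or .env,
# pydantic's Json[...] field parses the string into a real Python list.
from kig_core.config import settings

assert isinstance(settings.process_steps, list)
print(settings.process_steps)  # e.g. ['summarize', 'compress'], or the default ['summarize']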
kig_core/graph_client.py
ADDED
@@ -0,0 +1,91 @@
from neo4j import GraphDatabase, Driver, exceptions
from .config import settings
import logging
from typing import List, Dict, Any, Optional

logger = logging.getLogger(__name__)

class Neo4jClient:
    _driver: Optional[Driver] = None

    def _get_driver(self) -> Driver:
        """Initializes and returns the Neo4j driver instance."""
        if self._driver is None:
            logger.info(f"Initializing Neo4j Driver for URI: {settings.neo4j_uri}")
            try:
                self._driver = GraphDatabase.driver(
                    settings.neo4j_uri,
                    auth=(settings.neo4j_username, settings.neo4j_password.get_secret_value())
                )
                # Verify connectivity during initialization
                self._driver.verify_connectivity()
                logger.info("Neo4j Driver initialized and connection verified.")
            except exceptions.AuthError as e:
                logger.error(f"Neo4j Authentication Error: {e}", exc_info=True)
                raise ConnectionError("Neo4j Authentication Failed. Check credentials.") from e
            except exceptions.ServiceUnavailable as e:
                logger.error(f"Neo4j Service Unavailable: {e}", exc_info=True)
                raise ConnectionError(f"Could not connect to Neo4j at {settings.neo4j_uri}. Ensure DB is running and reachable.") from e
            except Exception as e:
                logger.error(f"Unexpected error initializing Neo4j Driver: {e}", exc_info=True)
                raise ConnectionError("An unexpected error occurred connecting to Neo4j.") from e
        return self._driver

    def close(self):
        """Closes the Neo4j driver connection."""
        if self._driver is not None:
            logger.info("Closing Neo4j Driver.")
            self._driver.close()
            self._driver = None

    def query(self, cypher_query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Executes a Cypher query and returns the results."""
        driver = self._get_driver()
        logger.debug(f"Executing Cypher: {cypher_query} with params: {params}")
        try:
            # Use session/transaction for robust execution
            with driver.session() as session:
                result = session.run(cypher_query, params or {})
                # Convert Neo4j Records to dictionaries
                data = [record.data() for record in result]
                logger.debug(f"Query returned {len(data)} records.")
                return data
        except (exceptions.ServiceUnavailable, exceptions.SessionExpired) as e:
            logger.error(f"Neo4j connection error during query: {e}", exc_info=True)
            # Attempt to close the potentially broken driver so it reconnects next time
            self.close()
            raise ConnectionError("Neo4j connection error during query execution.") from e
        except exceptions.CypherSyntaxError as e:
            logger.error(f"Neo4j Cypher Syntax Error: {e}\nQuery: {cypher_query}", exc_info=True)
            raise ValueError("Invalid Cypher query syntax.") from e
        except Exception as e:
            logger.error(f"Unexpected error during Neo4j query: {e}", exc_info=True)
            raise RuntimeError("An unexpected error occurred during the Neo4j query.") from e

    def get_schema(self, force_refresh: bool = False) -> Dict[str, Any]:
        """Fetches the graph schema. Placeholder - Langchain community graph has better schema fetching."""
        # For simplicity, returning empty. Implement actual schema fetching if needed.
        # Consider using langchain_community.graphs.Neo4jGraph for schema handling if complex interactions are needed.
        logger.warning("Neo4jClient.get_schema() is a placeholder. Implement if schema needed.")
        return {}  # Placeholder

    def get_concepts(self) -> List[str]:
        """Fetches all Concept names from the graph."""
        cypher = "MATCH (c:Concept) RETURN c.name AS name ORDER BY name"
        results = self.query(cypher)
        return [record['name'] for record in results if 'name' in record]

    def get_concept_description(self, concept_name: str) -> Optional[str]:
        """Fetches the description for a specific concept."""
        cypher = "MATCH (c:Concept {name: $name}) RETURN c.description AS description LIMIT 1"
        params = {"name": concept_name}
        results = self.query(cypher, params)
        return results[0]['description'] if results and 'description' in results[0] else None


# Create a single instance for the application to use
neo4j_client = Neo4jClient()

# Ensure the client is closed gracefully when the application exits
import atexit
atexit.register(neo4j_client.close)
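
A minimal usage sketch of the shared client (assumes the Neo4j credentials in the .env file are valid and the database contains Concept nodes):

from kig_core.graph_client import neo4j_client

# query() returns a list of dictionaries, one per record
rows = neo4j_client.query("MATCH (c:Concept) RETURN c.name AS name LIMIT 5")
print([row["name"] for row in rows])

# Convenience helpers built on top of query()
concepts = neo4j_client.get_concepts()
description = neo4j_client.get_concept_description(concepts[0]) if concepts else None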
kig_core/graph_operations.py
ADDED
@@ -0,0 +1,212 @@
import re
import logging
import time
from typing import List, Dict, Any, Optional, Tuple
from random import sample, shuffle

from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.pydantic_v1 import Field, BaseModel as V1BaseModel  # For grader models if needed

from .config import settings
from .graph_client import neo4j_client  # Use the central client
from .llm_interface import get_llm, invoke_llm
from .prompts import (
    CYPHER_GENERATION_PROMPT, CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT, SCORE_GRADER_PROMPT
)
from .schemas import KeyIssue  # Import if needed here, maybe not

logger = logging.getLogger(__name__)

# --- Helper Functions ---
def extract_cypher(text: str) -> str:
    """Extracts the first Cypher code block or returns the text itself."""
    pattern = r"```(?:cypher)?\s*(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else text.strip()

def format_doc_for_llm(doc: Dict[str, Any]) -> str:
    """Formats a document dictionary into a string for LLM context."""
    return "\n".join(f"**{key}**: {value}" for key, value in doc.items() if value)


# --- Cypher Generation ---
def generate_cypher_auto(question: str) -> str:
    """Generates Cypher using the 'auto' method."""
    logger.info("Generating Cypher using 'auto' method.")
    # Schema fetching needs implementation if required by the prompt/LLM
    # schema_info = neo4j_client.get_schema()  # Placeholder
    schema_info = "Schema not available."  # Default if not implemented

    cypher_llm = get_llm(settings.main_llm_model)  # Or a specific cypher model
    chain = (
        {"question": RunnablePassthrough(), "schema": lambda x: schema_info}
        | CYPHER_GENERATION_PROMPT
        | cypher_llm
        | StrOutputParser()
        | extract_cypher
    )
    return invoke_llm(chain, question)

def generate_cypher_guided(question: str, plan_step: int) -> str:
    """Generates Cypher using the 'guided' method based on concepts."""
    logger.info(f"Generating Cypher using 'guided' method for plan step {plan_step}.")
    try:
        concepts = neo4j_client.get_concepts()
        if not concepts:
            logger.warning("No concepts found in Neo4j for guided cypher generation.")
            return ""  # Or raise error

        concept_llm = get_llm(settings.main_llm_model)  # Or a specific concept model
        concept_chain = (
            CONCEPT_SELECTION_PROMPT
            | concept_llm
            | StrOutputParser()
        )
        selected_concept = invoke_llm(concept_chain, {
            "question": question,
            "concepts": "\n".join(concepts)
        }).strip()

        logger.info(f"Concept selected by LLM: {selected_concept}")

        # Basic check if the selected concept is valid
        if selected_concept not in concepts:
            logger.warning(f"LLM selected concept '{selected_concept}' not in the known list. Attempting fallback or ignoring.")
            # Optional: Add fuzzy matching or similarity search here
            # For now, we might default or return empty
            # Let's try a simple substring check as a fallback
            found_match = None
            for c in concepts:
                if selected_concept.lower() in c.lower():
                    found_match = c
                    logger.info(f"Found potential match: '{found_match}'")
                    break
            if not found_match:
                logger.error(f"Could not validate selected concept: {selected_concept}")
                return ""  # Return empty query if concept is invalid
            selected_concept = found_match


        # Determine the target node type based on plan step (example logic)
        # This mapping might need adjustment based on the actual plan structure
        if plan_step <= 1:  # Steps 0 and 1: Context gathering
            target = "(ts:TechnicalSpecification)"
            fields = "ts.title, ts.scope, ts.description"
        elif plan_step == 2:  # Step 2: Research papers?
            target = "(rp:ResearchPaper)"
            fields = "rp.title, rp.abstract"
        else:  # Later steps might involve KeyIssues themselves or other types
            target = "(n)"  # Generic fallback
            fields = "n.title, n.description"  # Assuming common fields

        # Construct Cypher query
        # Ensure selected_concept is properly escaped if needed, though parameters are safer
        cypher = f"MATCH (c:Concept {{name: $conceptName}})-[:RELATED_TO]-{target} RETURN {fields}"
        # We return the query and the parameters separately for safe execution
        # However, the planner currently expects just the string. Let's construct it directly for now.
        # Be cautious about injection if concept names can contain special chars. Binding is preferred.
        escaped_concept = selected_concept.replace("'", "\\'")  # Basic escaping
        cypher = f"MATCH (c:Concept {{name: '{escaped_concept}'}})-[:RELATED_TO]-{target} RETURN {fields}"

        logger.info(f"Generated guided Cypher: {cypher}")
        return cypher

    except Exception as e:
        logger.error(f"Error during guided cypher generation: {e}", exc_info=True)
        time.sleep(60)
        return ""  # Return empty on error


# --- Document Retrieval ---
def retrieve_documents(cypher_query: str) -> List[Dict[str, Any]]:
    """Retrieves documents from Neo4j using a Cypher query."""
    if not cypher_query:
        logger.warning("Received empty Cypher query, skipping retrieval.")
        return []
    logger.info(f"Retrieving documents with Cypher: {cypher_query} limit 10")
    try:
        # Use the centralized client's query method
        raw_results = neo4j_client.query(cypher_query + " limit 10")
        # Basic cleaning/deduplication (can be enhanced)
        processed_results = []
        seen = set()
        for doc in raw_results:
            # Create a frozenset of items for hashable representation to detect duplicates
            doc_items = frozenset(doc.items())
            if doc_items not in seen:
                processed_results.append(doc)
                seen.add(doc_items)
        logger.info(f"Retrieved {len(processed_results)} unique documents.")
        return processed_results
    except (ConnectionError, ValueError, RuntimeError) as e:
        # Errors already logged in neo4j_client
        logger.error(f"Document retrieval failed: {e}")
        return []  # Return empty list on failure


# --- Document Evaluation ---
# Define Pydantic models for structured LLM grader output (if not using built-in LCEL structured output)
class GradeDocumentsBinary(V1BaseModel):
    """Binary score for relevance check."""
    binary_score: str = Field(description="Relevant? 'yes' or 'no'")

class GradeDocumentsScore(V1BaseModel):
    """Score for relevance check."""
    rationale: str = Field(description="Rationale for the score.")
    score: float = Field(description="Relevance score (0.0 to 1.0)")

def evaluate_documents(
    docs: List[Dict[str, Any]],
    query: str
) -> List[Dict[str, Any]]:
    """Evaluates document relevance to a query using configured method."""
    if not docs:
        return []

    logger.info(f"Evaluating {len(docs)} documents for relevance to query: '{query}' using method: {settings.eval_method}")
    eval_llm = get_llm(settings.eval_llm_model)
    valid_docs_with_scores: List[Tuple[Dict[str, Any], float]] = []

    # Consider using LCEL's structured output capabilities directly if the model supports it well
    # This simplifies parsing. Example for binary:
    # binary_grader = BINARY_GRADER_PROMPT | eval_llm.with_structured_output(GradeDocumentsBinary)

    if settings.eval_method == "binary":
        binary_grader = BINARY_GRADER_PROMPT | eval_llm | StrOutputParser()  # Fallback to string parsing
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip(): continue
            try:
                result = invoke_llm(binary_grader, {"question": query, "document": formatted_doc})
                logger.debug(f"Binary grader result for doc '{doc.get('title', 'N/A')}': {result}")
                if result and 'yes' in result.lower():
                    valid_docs_with_scores.append((doc, 1.0))  # Score 1.0 for relevant
            except Exception as e:
                logger.warning(f"Binary grading failed for a document: {e}", exc_info=True)

    elif settings.eval_method == "score":
        # Using JSON parser as a robust fallback for score extraction
        score_grader = SCORE_GRADER_PROMPT | eval_llm | JsonOutputParser(pydantic_object=GradeDocumentsScore)
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip(): continue
            try:
                # JsonOutputParser yields a dict; coerce it into the grader model for attribute access
                raw_result = invoke_llm(score_grader, {"query": query, "document": formatted_doc})
                result = GradeDocumentsScore(**raw_result) if isinstance(raw_result, dict) else raw_result
                logger.debug(f"Score grader result for doc '{doc.get('title', 'N/A')}': Score={result.score}, Rationale={result.rationale}")
                if result.score >= settings.eval_threshold:
                    valid_docs_with_scores.append((doc, result.score))
            except Exception as e:
                logger.warning(f"Score grading failed for a document: {e}", exc_info=True)
                # Optionally treat as relevant on failure? Or skip? Skipping for now.

    # Sort by score if applicable, then limit
    if settings.eval_method == 'score':
        valid_docs_with_scores.sort(key=lambda item: item[1], reverse=True)

    # Limit to max_docs
    final_docs = [doc for doc, score in valid_docs_with_scores[:settings.max_docs]]
    logger.info(f"Found {len(final_docs)} relevant documents after evaluation and filtering.")

    return final_docs
kig_core/llm_interface.py
ADDED
@@ -0,0 +1,69 @@
import os
import time
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel
from .config import settings
import logging

logger = logging.getLogger(__name__)

# Store initialized models to avoid re-creating them repeatedly
_llm_cache = {}

def get_llm(model_name: str) -> BaseChatModel:
    """
    Returns an initialized LangChain chat model based on the provided name.
    Caches initialized models.
    """
    global _llm_cache
    if model_name in _llm_cache:
        return _llm_cache[model_name]

    logger.info(f"Initializing LLM: {model_name}")

    if model_name.startswith("gemini"):
        if not settings.gemini_api_key:
            raise ValueError("GEMINI_API_KEY is not configured.")
        try:
            # Use GOOGLE_API_KEY environment variable set in config.py
            llm = ChatGoogleGenerativeAI(model=model_name)
            _llm_cache[model_name] = llm
            logger.info(f"Initialized Google Generative AI model: {model_name}")
            return llm
        except Exception as e:
            logger.error(f"Failed to initialize Gemini model '{model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize Gemini model: {e}") from e

    elif model_name.startswith("gpt"):
        if not settings.openai_api_key:
            raise ValueError("OPENAI_API_KEY is not configured.")
        try:
            # Base URL can be added here if using a proxy
            # base_url = "https://your-proxy-if-needed/"
            llm = ChatOpenAI(model=model_name, api_key=settings.openai_api_key)  # Base URL optional
            _llm_cache[model_name] = llm
            logger.info(f"Initialized OpenAI model: {model_name}")
            return llm
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI model '{model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize OpenAI model: {e}") from e

    # Add other model providers (Anthropic, Groq, etc.) here if needed

    else:
        logger.error(f"Unsupported model provider for model name: {model_name}")
        raise ValueError(f"Model '{model_name}' is not supported or configuration is missing.")

def invoke_llm(runnable, parameters):
    """Invokes a runnable/chain, retrying once after 60 seconds on failure (e.g. rate limits)."""
    try:
        return runnable.invoke(parameters)
    except Exception as e:
        print(f"Error during .invoke : {e} \nwaiting 60 seconds")
        time.sleep(60)
        print("Waiting is finished")
        return runnable.invoke(parameters)

# Example usage (could be called from other modules)
# main_llm = get_llm(settings.main_llm_model)
# eval_llm = get_llm(settings.eval_llm_model)
kig_core/planner.py
ADDED
@@ -0,0 +1,254 @@
import logging
import re
import time
from typing import List, Dict, Any, Optional
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver  # Or SqliteSaver etc.

from pydantic import BaseModel, Field

from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser

from .config import settings
from .schemas import PlannerState, KeyIssue, GraphConfig  # Import schemas
from .prompts import get_initial_planner_prompt, KEY_ISSUE_STRUCTURING_PROMPT
from .llm_interface import get_llm, invoke_llm
from .graph_operations import (
    generate_cypher_auto, generate_cypher_guided,
    retrieve_documents, evaluate_documents
)
from .processing import process_documents

logger = logging.getLogger(__name__)

# --- Graph Nodes ---

def start_planning(state: PlannerState) -> Dict[str, Any]:
    """Generates the initial plan based on the user query."""
    logger.info("Node: start_planning")
    user_query = state['user_query']
    if not user_query:
        return {"error": "User query is empty."}

    initial_prompt = get_initial_planner_prompt(settings.plan_method, user_query)
    llm = get_llm(settings.main_llm_model)
    chain = initial_prompt | llm | StrOutputParser()

    try:
        plan_text = invoke_llm(chain, {})  # Prompt already includes query
        logger.debug(f"Raw plan text: {plan_text}")

        # Extract plan steps (simple regex, might need refinement)
        plan_match = re.search(r"Plan:(.*?)<END_OF_PLAN>", plan_text, re.DOTALL | re.IGNORECASE)
        if plan_match:
            plan_steps = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_match.group(1)) if step.strip()]
            logger.info(f"Extracted plan: {plan_steps}")
            return {
                "plan": plan_steps,
                "current_plan_step_index": 0,
                "messages": [AIMessage(content=plan_text)],
                "step_outputs": {}  # Initialize step outputs
            }
        else:
            logger.error("Could not parse plan from LLM response.")
            return {"error": "Failed to parse plan from LLM response.", "messages": [AIMessage(content=plan_text)]}
    except Exception as e:
        logger.error(f"Error during plan generation: {e}", exc_info=True)
        return {"error": f"LLM error during plan generation: {e}"}


def execute_plan_step(state: PlannerState) -> Dict[str, Any]:
    """Executes the current step of the plan (retrieval, processing)."""
    current_index = state['current_plan_step_index']
    plan = state['plan']
    user_query = state['user_query']  # Use original query for context

    if current_index >= len(plan):
        logger.warning("Plan step index out of bounds, attempting to finalize.")
        # This should ideally be handled by the conditional edge, but as a fallback
        return {"error": "Plan execution finished unexpectedly."}

    step_description = plan[current_index]
    logger.info(f"Node: execute_plan_step - Step {current_index + 1}/{len(plan)}: {step_description}")

    # --- Determine Query for Retrieval ---
    # Simple approach: Use step description or original query?
    # Let's use the step description combined with the original query for context.
    query_for_retrieval = f"Regarding the query '{user_query}', focus on: {step_description}"
    logger.info(f"Query for retrieval: {query_for_retrieval}")

    # --- Generate Cypher ---
    cypher_query = ""
    if settings.cypher_gen_method == 'auto':
        cypher_query = generate_cypher_auto(query_for_retrieval)
    elif settings.cypher_gen_method == 'guided':
        cypher_query = generate_cypher_guided(query_for_retrieval, current_index)
    # TODO: Add cypher validation if settings.validate_cypher is True

    # --- Retrieve Documents ---
    retrieved_docs = retrieve_documents(cypher_query)

    # --- Evaluate Documents ---
    evaluated_docs = evaluate_documents(retrieved_docs, query_for_retrieval)

    # --- Process Documents ---
    # Using configured processing steps
    processed_docs_content = process_documents(evaluated_docs, settings.process_steps)

    # --- Store Step Output ---
    # Store the processed content relevant to this step
    step_output = "\n\n".join(processed_docs_content) if processed_docs_content else "No relevant information found for this step."
    current_step_outputs = state.get('step_outputs', {})
    current_step_outputs[current_index] = step_output

    logger.info(f"Finished executing plan step {current_index + 1}. Stored output.")

    return {
        "current_plan_step_index": current_index + 1,
        "messages": [SystemMessage(content=f"Completed plan step {current_index + 1}. Context gathered:\n{step_output[:500]}...")],  # Add summary message
        "step_outputs": current_step_outputs
    }

class KeyIssue(BaseModel):
    # define your fields here
    id: int
    description: str

class KeyIssueList(BaseModel):
    key_issues: List[KeyIssue] = Field(description="List of key issues")

class KeyIssueInvoke(BaseModel):
    id: int
    title: str
    description: str
    challenges: List[str]
    potential_impact: Optional[str] = None

def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
    """Generates the final structured Key Issues based on all gathered context."""
    logger.info("Node: generate_structured_issues")

    user_query = state['user_query']
    step_outputs = state.get('step_outputs', {})

    # --- Combine Context from All Steps ---
    full_context = f"Original User Query: {user_query}\n\n"
    full_context += "Context gathered during planning:\n"
    for i, output in sorted(step_outputs.items()):
        full_context += f"--- Context from Step {i+1} ---\n{output}\n\n"

    if not step_outputs:
        full_context += "No context was gathered during the planning steps.\n"

    logger.info(f"Generating key issues using combined context (length: {len(full_context)} chars).")
    # logger.debug(f"Full Context for Key Issue Generation:\n{full_context}")  # Optional: log full context

    # --- Call LLM for Structured Output ---
    issue_llm = get_llm(settings.main_llm_model)
    # Use a JSON output parser for robust parsing
    output_parser = JsonOutputParser(pydantic_object=KeyIssueList)


    prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
        schema=output_parser.get_format_instructions(),  # Inject schema instructions if needed by prompt
    )

    chain = prompt | issue_llm | output_parser

    try:
        structured_issues_obj = invoke_llm(chain, {
            "user_query": user_query,
            "context": full_context
        })
        print(f"structured_issues_obj => type : {type(structured_issues_obj)}, value : {structured_issues_obj}")

        # If the output is a dict with a key 'key_issues', extract it
        if isinstance(structured_issues_obj, dict) and 'key_issues' in structured_issues_obj:
            issues_data = structured_issues_obj['key_issues']
        else:
            issues_data = structured_issues_obj  # Assume it's already a list of dicts

        # Always convert to KeyIssueInvoke objects
        key_issues_list = [KeyIssueInvoke(**issue_dict) for issue_dict in issues_data]

        # Ensure IDs are sequential if the LLM didn't assign them correctly
        for i, issue in enumerate(key_issues_list):
            issue.id = i + 1

        logger.info(f"Successfully generated {len(key_issues_list)} structured key issues.")
        final_message = f"Generated {len(key_issues_list)} Key Issues based on the query '{user_query}'."
        return {
            "key_issues": key_issues_list,
            "messages": [AIMessage(content=final_message)],
            "error": None
        }
    except Exception as e:
        logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)
        # Attempt to get raw output for debugging if possible
        raw_output = "Could not retrieve raw output."
        try:
            raw_chain = prompt | issue_llm | StrOutputParser()
            raw_output = invoke_llm(raw_chain, {"user_query": user_query, "context": full_context})
            logger.debug(f"Raw output from failed JSON parsing:\n{raw_output}")
        except Exception as raw_e:
            logger.error(f"Could not even get raw output: {raw_e}")

        return {"error": f"Failed to generate structured key issues: {e}. Raw output hint: {raw_output[:500]}..."}


# --- Conditional Edges ---

def should_continue_planning(state: PlannerState) -> str:
    """Determines if there are more plan steps to execute."""
    logger.debug("Edge: should_continue_planning")
    if state.get("error"):
        logger.error(f"Error state detected: {state['error']}. Ending execution.")
        return "error_state"  # Go to a potential error handling end node

    current_index = state['current_plan_step_index']
    plan_length = len(state.get('plan', []))

    if current_index < plan_length:
        logger.debug(f"Continuing plan execution. Next step index: {current_index}")
        return "continue_execution"
    else:
        logger.debug("Plan finished. Proceeding to final generation.")
        return "finalize"


# --- Build Graph ---
def build_graph():
    """Builds the LangGraph workflow."""
    workflow = StateGraph(PlannerState)

    # Add nodes
    workflow.add_node("start_planning", start_planning)
    workflow.add_node("execute_plan_step", execute_plan_step)
    workflow.add_node("generate_issues", generate_structured_issues)
    # Optional: Add an error handling node
    workflow.add_node("error_node", lambda state: {"messages": [SystemMessage(content=f"Execution failed: {state.get('error', 'Unknown error')}")]})


    # Define edges
    workflow.set_entry_point("start_planning")
    workflow.add_edge("start_planning", "execute_plan_step")  # Assume plan is always generated

    workflow.add_conditional_edges(
        "execute_plan_step",
        should_continue_planning,
        {
            "continue_execution": "execute_plan_step",  # Loop back to execute next step
            "finalize": "generate_issues",  # Move to final generation
            "error_state": "error_node"  # Go to error node
        }
    )

    workflow.add_edge("generate_issues", END)
    workflow.add_edge("error_node", END)  # End after error

    # Compile the graph with memory (optional)
    # memory = MemorySaver()  # Use if state needs persistence between runs
    # app_graph = workflow.compile(checkpointer=memory)
    app_graph = workflow.compile()
    return app_graph
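
For local testing outside the API, the compiled graph can also be invoked synchronously; a sketch (the initial state mirrors the one built in api.py, and the example query is only illustrative):

from kig_core.planner import build_graph
from langchain_core.messages import HumanMessage

graph = build_graph()
query = "What are the key issues for energy-efficient 5G base stations?"  # hypothetical query
initial_state = {
    "user_query": query,
    "messages": [HumanMessage(content=query)],
    "plan": [],
    "current_plan_step_index": -1,
    "step_outputs": {},
    "key_issues": [],
    "error": None,
}
final_state = graph.invoke(initial_state, config={"configurable": {}})
print(final_state.get("key_issues"))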
kig_core/processing.py
ADDED
@@ -0,0 +1,127 @@
import logging
from typing import List, Dict, Any, Union, Optional
from langchain_core.output_parsers import StrOutputParser

from .config import settings
from .llm_interface import get_llm, invoke_llm
from .prompts import SUMMARIZER_PROMPT
from .graph_operations import format_doc_for_llm # Reuse formatting

# Import llmlingua if compression is used
try:
    from llmlingua import PromptCompressor
    LLMLINGUA_AVAILABLE = True
except ImportError:
    LLMLINGUA_AVAILABLE = False
    PromptCompressor = None # Define as None if not available

logger = logging.getLogger(__name__)

_compressor_cache = {}

def get_compressor(method: str) -> Optional['PromptCompressor']:
    """Initializes and caches llmlingua compressor."""
    if not LLMLINGUA_AVAILABLE:
        logger.warning("LLMLingua not installed, compression unavailable.")
        return None
    if method not in _compressor_cache:
        logger.info(f"Initializing LLMLingua compressor: {method}")
        try:
            # Adjust model names and params as needed
            if method == "llm_lingua2":
                model_name = "microsoft/llmlingua-2-xlm-roberta-large-meetingbank"
                use_llmlingua2 = True
            elif method == "llm_lingua":
                model_name = "microsoft/phi-2" # Requires ~8GB RAM
                use_llmlingua2 = False
            else:
                logger.error(f"Unsupported compression method: {method}")
                return None

            _compressor_cache[method] = PromptCompressor(
                model_name=model_name,
                use_llmlingua2=use_llmlingua2,
                device_map="cpu" # Or "cuda" if GPU available
            )
        except Exception as e:
            logger.error(f"Failed to initialize LLMLingua compressor '{method}': {e}", exc_info=True)
            return None
    return _compressor_cache[method]


def summarize_document(doc_content: str) -> str:
    """Summarizes a single document using the configured LLM."""
    logger.debug("Summarizing document...")
    try:
        summarize_llm = get_llm(settings.summarize_llm_model)
        summarize_chain = SUMMARIZER_PROMPT | summarize_llm | StrOutputParser()
        summary = invoke_llm(summarize_chain, {"document": doc_content})
        logger.debug("Summarization complete.")
        return summary
    except Exception as e:
        logger.error(f"Summarization failed: {e}", exc_info=True)
        return f"Error during summarization: {e}" # Return error message instead of failing


def compress_document(doc_content: str) -> str:
    """Compresses a single document using LLMLingua."""
    logger.debug(f"Compressing document using method: {settings.compression_method}...")
    if not settings.compression_method:
        logger.warning("Compression method not configured, skipping.")
        return doc_content

    compressor = get_compressor(settings.compression_method)
    if not compressor:
        logger.warning("Compressor not available, skipping compression.")
        return doc_content

    try:
        # Adjust compression parameters as needed
        # rate = settings.compress_rate or 0.5
        # force_tokens = ['\n', '.', ',', '?', '!'] # Example tokens
        # context? instructions? question?

        # Simple compression for now:
        result = compressor.compress_prompt(doc_content, rate=settings.compress_rate or 0.5)
        compressed_text = result.get("compressed_prompt", doc_content)

        original_len = len(doc_content.split())
        compressed_len = len(compressed_text.split())
        logger.debug(f"Compression complete. Original words: {original_len}, Compressed words: {compressed_len}")
        return compressed_text
    except Exception as e:
        logger.error(f"Compression failed: {e}", exc_info=True)
        return f"Error during compression: {e}" # Return error message


def process_documents(
    docs: List[Dict[str, Any]],
    processing_steps: List[Union[str, dict]]
) -> List[str]:
    """Processes a list of documents according to the specified steps."""
    logger.info(f"Processing {len(docs)} documents with steps: {processing_steps}")
    if not docs:
        return []

    processed_outputs = []
    for i, doc in enumerate(docs):
        logger.info(f"Processing document {i+1}/{len(docs)}...")
        current_content = format_doc_for_llm(doc) # Start with formatted original doc

        for step in processing_steps:
            if step == "summarize":
                current_content = summarize_document(current_content)
            elif step == "compress":
                current_content = compress_document(current_content)
            elif isinstance(step, dict):
                # Placeholder for custom processing steps defined by dicts
                logger.warning(f"Custom processing step not implemented: {step}")
                # Add logic here if needed: extract params, call specific LLM/function
                pass
            else:
                logger.warning(f"Unknown processing step type: {step}")

        processed_outputs.append(current_content) # Add the final processed content for this doc

    logger.info("Document processing finished.")
    return processed_outputs
kig_core/prompts.py
ADDED
@@ -0,0 +1,117 @@
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from .schemas import KeyIssue # Import the Pydantic model

# --- Cypher Generation ---
CYPHER_GENERATION_TEMPLATE = """Task: Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Limit the number of elements retrieved to 10.
Schema:
{schema}

Note: Do not include explanations or apologies. Respond only with the Cypher statement.
Do not respond to questions unrelated to Cypher generation.

The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate.from_template(CYPHER_GENERATION_TEMPLATE)


# --- Concept Selection (for 'guided' cypher gen) ---
CONCEPT_SELECTION_TEMPLATE = """Task: Select the most relevant Concept from the list below for the user's question.
Instructions:
Output ONLY the name of the single most relevant concept. No explanations.

Concepts:
{concepts}

User Question:
{question}"""
CONCEPT_SELECTION_PROMPT = PromptTemplate.from_template(CONCEPT_SELECTION_TEMPLATE)


# --- Document Relevance Grading ---
BINARY_GRADER_TEMPLATE = """Assess the relevance of the retrieved document to the user question.
The goal is to filter out clearly erroneous retrievals.
If the document contains keywords or semantic meaning related to the question, grade it as relevant.
Output 'yes' or 'no'."""
BINARY_GRADER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", BINARY_GRADER_TEMPLATE),
    ("human", "Retrieved document:\n\n{document}\n\nUser question: {question}"),
])

SCORE_GRADER_TEMPLATE = """Analyze the query and the document. Quantify the relevance.
Provide rationale before the score.
Output a score between 0 (irrelevant) and 1 (completely relevant)."""
SCORE_GRADER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", SCORE_GRADER_TEMPLATE),
    ("human", "Passage:\n\n{document}\n\nUser query: {query}"),
])


# --- Planning ---
PLAN_GENERATION_TEMPLATE = """You are a standardization expert planning to identify NEW and INNOVATIVE Key Issues related to a technical requirement.
Devise a concise, step-by-step plan to achieve this.
Consider steps like: Understanding the core problem, Researching existing standards/innovations, Identifying potential gaps/challenges, Formulating Key Issues, and Refining/Detailing them.
Output the plan starting with 'Plan:' and numbering each step. End the plan with '<END_OF_PLAN>'."""

PLAN_MODIFICATION_TEMPLATE = """You are a standardization expert planning to identify NEW and INNOVATIVE Key Issues related to a technical requirement.
Adapt the following generic plan template to the specific requirement. Keep it concise.

### PLAN TEMPLATE ###
Plan:
1. **Understand Core Requirement**: Analyze the user query to define the scope.
2. **Gather Context**: Retrieve relevant specifications, standards, and recent research papers.
3. **Identify Gaps & Challenges**: Based on context, brainstorm potential new issues and challenges.
4. **Formulate Key Issues**: Structure the findings into distinct Key Issues.
5. **Refine & Detail**: Elaborate on each Key Issue, outlining specific challenges.
<END_OF_PLAN>
### END OF PLAN TEMPLATE ###

Output the adapted plan starting with 'Plan:' and numbering each step. End with '<END_OF_PLAN>'."""


# --- Document Processing ---
SUMMARIZER_TEMPLATE = """You are a 3GPP standardization expert.
Summarize the key information in the provided document in simple technical English relevant to identifying potential Key Issues. Focus on challenges, gaps, or novel aspects.

Document:
{document}"""
SUMMARIZER_PROMPT = ChatPromptTemplate.from_template(SUMMARIZER_TEMPLATE)


# --- Key Issue Structuring (New) ---
# This prompt guides the LLM to output structured Key Issues based on gathered context.
# It references the Pydantic model 'KeyIssue' for the desired format.
KEY_ISSUE_STRUCTURING_TEMPLATE = f"""Based on the provided context (summaries of relevant documents, research findings, etc.), identify and formulate distinct Key Issues related to the original user query.
For each Key Issue identified, provide the following information in the exact JSON format described below. Output a JSON list containing multiple KeyIssue objects.
JSON Schema for each Key Issue object:
[{{{{
"id": "Sequential integer ID starting from 1",
"title": "Concise title for the key issue (max 15 words)",
"description": "Detailed description of the key issue (2-4 sentences)",
"challenges": ["List of specific challenges related to this issue (strings)", "Each challenge as a separate string"],
"potential_impact": "Brief description of the potential impact if not addressed (optional, max 30 words)"
}}}}]

User Query: {{user_query}}
Context: {{context}}
Generate the JSON list of Key Issues based *only* on the provided context and user query. Ensure the output is a valid JSON list.
"""
KEY_ISSUE_STRUCTURING_PROMPT = ChatPromptTemplate.from_template(KEY_ISSUE_STRUCTURING_TEMPLATE)

# --- Initial Prompt Selection ---
def get_initial_planner_prompt(plan_method: str, user_query: str) -> ChatPromptTemplate:
    if plan_method == "generation":
        template = PLAN_GENERATION_TEMPLATE
    elif plan_method == "modification":
        template = PLAN_MODIFICATION_TEMPLATE
    else:
        raise ValueError("Invalid plan_method")

    # Return as ChatPromptTemplate for consistency
    return ChatPromptTemplate.from_messages([
        ("system", template),
        ("human", user_query)
    ])
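
Usage sketch (illustrative only, not part of the commit): composing one of these prompts into a chain, mirroring how SUMMARIZER_PROMPT is wired up in processing.py. The settings.main_llm_model attribute name is an assumption based on MAIN_LLM_MODEL in .env.

    from langchain_core.output_parsers import StrOutputParser
    from kig_core.config import settings
    from kig_core.llm_interface import get_llm, invoke_llm
    from kig_core.prompts import CONCEPT_SELECTION_PROMPT

    # Assumed attribute name; mirrors MAIN_LLM_MODEL in .env
    llm = get_llm(settings.main_llm_model)
    chain = CONCEPT_SELECTION_PROMPT | llm | StrOutputParser()
    concept = invoke_llm(chain, {
        "concepts": "Network Slicing\nEdge Computing\nEnergy Efficiency",
        "question": "How can the 5G core reduce energy consumption?",
    })
    print(concept)  # Expected: the single most relevant concept name
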
kig_core/schemas.py
ADDED
@@ -0,0 +1,55 @@
from typing import List, Dict, Any, Optional, Union, Annotated
from typing_extensions import TypedDict
from langchain_core.messages import BaseMessage
from pydantic import BaseModel, Field
from langgraph.graph.message import add_messages

# --- Pydantic Models for Structured Output ---

class KeyIssue(BaseModel):
    """Represents a single generated Key Issue."""
    id: int = Field(..., description="Sequential ID for the key issue")
    title: str = Field(..., description="A concise title for the key issue")
    description: str = Field(..., description="Detailed description of the key issue")
    challenges: List[str] = Field(default_factory=list, description="Specific challenges associated with this issue")
    potential_impact: Optional[str] = Field(None, description="Potential impact if the issue is not addressed")
    # Add source tracking if possible/needed from the processed docs
    # sources: List[str] = Field(default_factory=list, description="Source documents relevant to this issue")


# --- TypedDicts for LangGraph State ---

class GraphConfig(TypedDict):
    """Configuration passed to the graph execution."""
    thread_id: str
    # Add other config items needed at runtime if not globally available via settings

class BaseState(TypedDict):
    """Base state common across potentially multiple graphs."""
    messages: Annotated[List[BaseMessage], add_messages]
    error: Optional[str] # To store potential errors during execution

class PlannerState(BaseState):
    """State specific to the main planner graph."""
    user_query: str
    plan: List[str] # The high-level plan steps
    current_plan_step_index: int # Index of the current step being executed
    # Stored data from previous steps (e.g., summaries)
    # Use a dictionary to store context relevant to each plan step
    step_outputs: Dict[int, Any] # Stores output (e.g., processed docs) from each step
    # Final structured output
    key_issues: List[KeyIssue]


class DataRetrievalState(TypedDict):
    """State for a potential data retrieval sub-graph."""
    query_for_retrieval: str # The specific query for this retrieval step
    retrieved_docs: List[Dict[str, Any]] # Raw docs from Neo4j
    evaluated_docs: List[Dict[str, Any]] # Docs after relevance grading
    cypher_queries: List[str] # Generated Cypher queries

class ProcessingState(TypedDict):
    """State for a potential document processing sub-graph."""
    docs_to_process: List[Dict[str, Any]] # Documents passed for processing
    processed_docs: List[Union[str, Dict[str, Any]]] # Processed/summarized docs
    processing_steps_config: List[Union[str, dict]] # Configuration for processing
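
Usage sketch (illustrative only, not part of the commit): validating JSON of the kind KEY_ISSUE_STRUCTURING_PROMPT asks for against the KeyIssue model. The raw string below is a made-up example of that output.

    import json
    from kig_core.schemas import KeyIssue

    # Hypothetical LLM output following the JSON schema described in prompts.py
    raw = ('[{"id": 1, "title": "Energy-aware slice admission control", '
           '"description": "Current admission control ignores energy cost.", '
           '"challenges": ["Defining a per-slice energy metric"], '
           '"potential_impact": "Higher operating cost"}]')
    key_issues = [KeyIssue.model_validate(item) for item in json.loads(raw)]
    print(key_issues[0].title, key_issues[0].challenges)
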
kig_core/utils.py
ADDED
@@ -0,0 +1,41 @@
import pandas as pd
import io
import logging
from typing import List
from .schemas import KeyIssue # Import the Pydantic model

logger = logging.getLogger(__name__)

def key_issues_to_dataframe(key_issues: List[KeyIssue]) -> pd.DataFrame:
    """Converts a list of KeyIssue objects into a Pandas DataFrame."""
    if not key_issues:
        return pd.DataFrame()
    # Use Pydantic's .model_dump() for robust serialization
    data = [ki.model_dump() for ki in key_issues]
    df = pd.DataFrame(data)
    # Optional: Reorder or rename columns if needed
    # df = df[['id', 'title', 'description', 'challenges', 'potential_impact']] # Example reordering
    return df

def dataframe_to_excel_bytes(df: pd.DataFrame) -> bytes:
    """Converts a Pandas DataFrame to Excel format in memory (bytes)."""
    logger.info("Generating Excel file from DataFrame...")
    output = io.BytesIO()
    try:
        # Use BytesIO object as the target file
        with pd.ExcelWriter(output, engine='openpyxl') as writer:
            df.to_excel(writer, index=False, sheet_name='Key Issues')
        excel_data = output.getvalue()
        logger.info("Excel file generated successfully.")
        return excel_data
    except Exception as e:
        logger.error(f"Failed to generate Excel file: {e}", exc_info=True)
        raise RuntimeError("Failed to create Excel output.") from e

# Removed: format_df (HTML specific, less relevant for Excel output)
# Removed: init_app (handled by config.py)
# Removed: get_model (handled by llm_interface.py)
# Removed: clear_memory (handle state/memory management within LangGraph setup if needed)
# Removed: _set_env (handled by config.py and dotenv)
# Kept: format_doc (renamed to format_doc_for_llm in graph_operations.py)
# Removed: update_doc_history (reducer logic should be handled in LangGraph state definition/nodes)
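
Usage sketch (illustrative only, not part of the commit): turning a list of KeyIssue objects into an Excel file on disk; an API endpoint could instead return the bytes directly in an HTTP response.

    from kig_core.schemas import KeyIssue
    from kig_core.utils import key_issues_to_dataframe, dataframe_to_excel_bytes

    issues = [
        KeyIssue(id=1, title="Example issue", description="Example description.",
                 challenges=["Example challenge"]),
    ]
    df = key_issues_to_dataframe(issues)
    excel_bytes = dataframe_to_excel_bytes(df)  # requires openpyxl to be installed
    with open("key_issues.xlsx", "wb") as f:
        f.write(excel_bytes)
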
requirements.txt
ADDED
@@ -0,0 +1,25 @@
# Langchain Core & Ecosystem
langchain-core>=0.2.29
langchain-google-genai>=1.0.9 # For Gemini
langchain-openai>=0.1.21 # If using OpenAI
langgraph>=0.1.10
langchain-community>=0.2.10 # For Neo4jGraph if needed, other community integrations

# LLM & Processing Libraries
# llmlingua==0.2.2 # Uncomment if using compression (Ensure it's compatible)
google-generativeai>=0.7.2 # Underlying Gemini library

# Neo4j
neo4j>=5.24.0

# API Framework
fastapi>=0.110.0 # Added for FastAPI
uvicorn[standard]>=0.29.0 # Added for running FastAPI server

# Data Handling & Excel Export
pandas # Required by kig_core/utils.py (DataFrame conversion)
openpyxl # Required by kig_core/utils.py (pd.ExcelWriter engine)

# Configuration & Utilities
pydantic>=2.9.0
pydantic-settings>=2.4.0 # For BaseSettings
python-dotenv>=1.0.1 # For loading .env files

# Optional: For LangSmith Tracing
# langsmith>=0.1.100