import logging
import re
import time
from typing import List, Dict, Any, Optional

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

from pydantic import BaseModel, Field

from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser

from .config import settings
from .schemas import PlannerState, KeyIssue, GraphConfig
from .prompts import get_initial_planner_prompt, KEY_ISSUE_STRUCTURING_PROMPT
from .llm_interface import get_llm, invoke_llm
from .graph_operations import (
    generate_cypher_auto, generate_cypher_guided,
    retrieve_documents, evaluate_documents
)
from .processing import process_documents

logger = logging.getLogger(__name__)


def start_planning(state: PlannerState) -> Dict[str, Any]:
    """Generates the initial plan based on the user query."""
    logger.info("Node: start_planning")
    user_query = state['user_query']
    if not user_query:
        return {"error": "User query is empty."}

    initial_prompt = get_initial_planner_prompt(settings.plan_method, user_query)
    llm = get_llm(settings.main_llm_model)
    chain = initial_prompt | llm | StrOutputParser()

    try:
        plan_text = invoke_llm(chain, {})
        logger.debug(f"Raw plan text: {plan_text}")

        # Extract the numbered steps between "Plan:" and the <END_OF_PLAN> marker.
        plan_match = re.search(r"Plan:(.*?)<END_OF_PLAN>", plan_text, re.DOTALL | re.IGNORECASE)
        if plan_match:
            plan_steps = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_match.group(1)) if step.strip()]
            logger.info(f"Extracted plan: {plan_steps}")
            return {
                "plan": plan_steps,
                "current_plan_step_index": 0,
                "messages": [AIMessage(content=plan_text)],
                "step_outputs": {}
            }
        else:
            logger.error("Could not parse plan from LLM response.")
            return {"error": "Failed to parse plan from LLM response.", "messages": [AIMessage(content=plan_text)]}
    except Exception as e:
        logger.error(f"Error during plan generation: {e}", exc_info=True)
        return {"error": f"LLM error during plan generation: {e}"}


def execute_plan_step(state: PlannerState) -> Dict[str, Any]:
    """Executes the current step of the plan (retrieval, processing)."""
    current_index = state['current_plan_step_index']
    plan = state['plan']
    user_query = state['user_query']

    if current_index >= len(plan):
        logger.warning("Plan step index out of bounds, attempting to finalize.")
        return {"error": "Plan execution finished unexpectedly."}

    step_description = plan[current_index]
    logger.info(f"Node: execute_plan_step - Step {current_index + 1}/{len(plan)}: {step_description}")

    # Combine the original query with the current step description to focus retrieval.
    query_for_retrieval = f"Regarding the query '{user_query}', focus on: {step_description}"
    logger.info(f"Query for retrieval: {query_for_retrieval}")

    # Generate the Cypher query using the configured strategy.
    cypher_query = ""
    if settings.cypher_gen_method == 'auto':
        cypher_query = generate_cypher_auto(query_for_retrieval)
    elif settings.cypher_gen_method == 'guided':
        cypher_query = generate_cypher_guided(query_for_retrieval, current_index)

    # Retrieve, evaluate, and post-process the documents for this step.
    retrieved_docs = retrieve_documents(cypher_query)
    evaluated_docs = evaluate_documents(retrieved_docs, query_for_retrieval)
    processed_docs_content = process_documents(evaluated_docs, settings.process_steps)

    step_output = "\n\n".join(processed_docs_content) if processed_docs_content else "No relevant information found for this step."
    # Copy before updating so the incoming state dict is not mutated in place.
    current_step_outputs = dict(state.get('step_outputs', {}))
    current_step_outputs[current_index] = step_output

    logger.info(f"Finished executing plan step {current_index + 1}. Stored output.")

    return {
        "current_plan_step_index": current_index + 1,
        "messages": [SystemMessage(content=f"Completed plan step {current_index + 1}. Context gathered:\n{step_output[:500]}...")],
        "step_outputs": current_step_outputs
    }


class KeyIssue(BaseModel):
    """Minimal issue model used in the JSON output schema.

    NOTE: this local definition shadows the KeyIssue imported from .schemas.
    """
    id: int
    description: str


class KeyIssueList(BaseModel):
    key_issues: List[KeyIssue] = Field(description="List of key issues")


class KeyIssueInvoke(BaseModel):
    """Richer issue model used to validate the parsed LLM output."""
    id: int
    title: str
    description: str
    challenges: List[str]
    potential_impact: Optional[str] = None


def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
    """Generates the final structured Key Issues based on all gathered context."""
    logger.info("Node: generate_structured_issues")

    user_query = state['user_query']
    step_outputs = state.get('step_outputs', {})

    # Assemble the combined context from every executed plan step.
    full_context = f"Original User Query: {user_query}\n\n"
    full_context += "Context gathered during planning:\n"
    for i, output in sorted(step_outputs.items()):
        full_context += f"--- Context from Step {i+1} ---\n{output}\n\n"

    if not step_outputs:
        full_context += "No context was gathered during the planning steps.\n"

    logger.info(f"Generating key issues using combined context (length: {len(full_context)} chars).")

    issue_llm = get_llm(settings.main_llm_model)
    output_parser = JsonOutputParser(pydantic_object=KeyIssueList)

    prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
        schema=output_parser.get_format_instructions(),
    )

    chain = prompt | issue_llm | output_parser

    try:
        structured_issues_obj = invoke_llm(chain, {
            "user_query": user_query,
            "context": full_context
        })
        logger.debug(f"structured_issues_obj => type: {type(structured_issues_obj)}, value: {structured_issues_obj}")

        # The parser may return either a dict with a 'key_issues' key or a bare list.
        if isinstance(structured_issues_obj, dict) and 'key_issues' in structured_issues_obj:
            issues_data = structured_issues_obj['key_issues']
        else:
            issues_data = structured_issues_obj

        key_issues_list = [KeyIssueInvoke(**issue_dict) for issue_dict in issues_data]

        # Re-number the issues sequentially, regardless of the ids the LLM produced.
        for i, issue in enumerate(key_issues_list):
            issue.id = i + 1

        logger.info(f"Successfully generated {len(key_issues_list)} structured key issues.")
        final_message = f"Generated {len(key_issues_list)} Key Issues based on the query '{user_query}'."
        return {
            "key_issues": key_issues_list,
            "messages": [AIMessage(content=final_message)],
            "error": None
        }
    except Exception as e:
        logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)

        # Fall back to a plain-string run so the unparsable output can be inspected.
        raw_output = "Could not retrieve raw output."
        try:
            raw_chain = prompt | issue_llm | StrOutputParser()
            raw_output = invoke_llm(raw_chain, {"user_query": user_query, "context": full_context})
            logger.debug(f"Raw output from failed JSON parsing:\n{raw_output}")
        except Exception as raw_e:
            logger.error(f"Could not even get raw output: {raw_e}")

        return {"error": f"Failed to generate structured key issues: {e}. Raw output hint: {raw_output[:500]}..."}


def should_continue_planning(state: PlannerState) -> str:
    """Determines if there are more plan steps to execute."""
    logger.debug("Edge: should_continue_planning")
    if state.get("error"):
        logger.error(f"Error state detected: {state['error']}. Ending execution.")
        return "error_state"

    current_index = state['current_plan_step_index']
    plan_length = len(state.get('plan', []))

    if current_index < plan_length:
        logger.debug(f"Continuing plan execution. Next step index: {current_index}")
        return "continue_execution"
    else:
        logger.debug("Plan finished. Proceeding to final generation.")
        return "finalize"


def build_graph():
    """Builds the LangGraph workflow."""
    workflow = StateGraph(PlannerState)

    # Nodes: plan generation, per-step execution, final structuring, and error reporting.
    workflow.add_node("start_planning", start_planning)
    workflow.add_node("execute_plan_step", execute_plan_step)
    workflow.add_node("generate_issues", generate_structured_issues)
    workflow.add_node("error_node", lambda state: {"messages": [SystemMessage(content=f"Execution failed: {state.get('error', 'Unknown error')}")]})

    workflow.set_entry_point("start_planning")
    workflow.add_edge("start_planning", "execute_plan_step")

    # Loop over plan steps until the plan is exhausted or an error occurs.
    workflow.add_conditional_edges(
        "execute_plan_step",
        should_continue_planning,
        {
            "continue_execution": "execute_plan_step",
            "finalize": "generate_issues",
            "error_state": "error_node"
        }
    )

    workflow.add_edge("generate_issues", END)
    workflow.add_edge("error_node", END)

    app_graph = workflow.compile()
    return app_graph
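

# Illustrative usage sketch: runs the compiled planner graph end to end. This assumes
# PlannerState accepts "user_query" as its only required input key (the other keys are
# filled in by the nodes above) and that the LLM and graph back-ends configured in
# .config.settings are reachable; the query below is purely hypothetical.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    graph = build_graph()
    final_state = graph.invoke({"user_query": "Identify key issues in deploying edge AI inference at scale."})

    for issue in final_state.get("key_issues", []):
        print(f"{issue.id}. {issue.title}: {issue.description}")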