|
|
from typing import Iterable, Optional |
|
|
from langchain_openai import ChatOpenAI |
|
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage |
|
|
|
|
|
from src.schemas import ComplexityLevel, ExecutionReport, PlannerPlan |
|
|
from src.prompts.prompts import COMPLEXITY_ASSESSOR_PROMPT |
|
|
from src.state import AgentState |
|
|
|
|
|
def log_stage(title: str, subtitle: Optional[str] = None, icon: str = "🚀") -> None:
    """Print a banner, framed by two horizontal rules, for an execution stage.

    Args:
        title: Stage name; surrounding whitespace is stripped before display.
        subtitle: Optional extra line printed under the banner when truthy.
        icon: Prefix placed at the start of every banner line.
    """
    heading = f" {title.strip()} "
    # The rule is never shorter than 20 characters and grows with the title.
    rule = icon + " " + "═" * max(len(heading), 20)

    # Top rule, title line and bottom rule are emitted in a single print call.
    print(f"\n{rule}\n{icon} {heading}\n{rule}")

    if subtitle:
        print(f"{icon} {subtitle}")
|
|
|
|
|
|
|
|
def log_key_values(pairs: Iterable[tuple[str, str]]) -> None:
    """Print each (key, value) pair on its own bulleted line.

    Args:
        pairs: Iterable of two-item tuples; items are consumed and printed
            one at a time in iteration order.
    """
    for label, content in pairs:
        print(f" • {label}: {content}")
|
|
|
|
|
|
|
|
def format_plan_overview(plan: PlannerPlan) -> str:
    """Summarise the steps of a plan as one line per step.

    Args:
        plan: Plan whose ``steps`` are listed; each step is read for its
            ``id``, ``goal`` and ``tool`` attributes.

    Returns:
        A newline-separated ``"<id>: <goal> [<tool>]"`` listing, with
        ``no tool`` substituted for a falsy tool, or a placeholder string
        when the plan is falsy or has no steps.
    """
    # Guard clause: nothing to list for a missing plan or an empty step list.
    if not plan or not plan.steps:
        return "(no steps – direct response)"

    return "\n".join(
        f"{entry.id}: {entry.goal} [{entry.tool or 'no tool'}]"
        for entry in plan.steps
    )
|
|
|
|
|
|
|
|
def display_plan(plan: PlannerPlan) -> None:
    """Print the contents of a plan to stdout in a compact layout.

    Output order: a stage banner, the task type and summary, the optional
    assumptions list, every step with its details, and finally the optional
    answer guidelines.

    Args:
        plan: Plan to render; read for ``task_type``, ``summary``,
            ``assumptions``, ``steps`` and ``answer_guidelines``.
    """
    log_stage("PLANNER OUTPUT", icon="🧭")
    print(f"Task type: {plan.task_type}")
    print(f"Summary: {plan.summary}")

    if plan.assumptions:
        print("Assumptions:")
        for assumption in plan.assumptions:
            print(f" - {assumption}")

    print("Steps:")
    for entry in plan.steps:
        print(f" {entry.id} → {entry.goal}")
        # A step without a tool is still reported, with an explicit marker.
        print(f" tool: {entry.tool}" if entry.tool else " tool: (none)")
        if entry.inputs:
            print(f" inputs: {entry.inputs}")
        # The expected result is always shown; inputs and on_fail only when set.
        print(f" expected: {entry.expected_result}")
        if entry.on_fail:
            print(f" on_fail: {entry.on_fail}")

    if plan.answer_guidelines:
        print(f"Answer guidelines: {plan.answer_guidelines}")
|
|
|
|
|
|
|
|
def clean_message_history(messages):
    """Strip incomplete tool-call cycles from a message history.

    A message with a truthy ``tool_calls`` attribute is kept, together with
    the run of ``ToolMessage`` objects that directly follows it, only when
    every one of its tool-call ids is answered within that run. Otherwise
    the message and that whole run are dropped and a notice is printed.
    All other messages are passed through unchanged.

    Args:
        messages: Indexable sequence of message objects.

    Returns:
        A new list containing the retained messages in their original order.
    """
    kept = []
    total = len(messages)
    pos = 0

    while pos < total:
        current = messages[pos]
        calls = getattr(current, 'tool_calls', None)

        # Messages that request no tools are copied through untouched.
        if not calls:
            kept.append(current)
            pos += 1
            continue

        expected_ids = {call['id'] for call in calls}

        # Find the end of the run of ToolMessages that directly follows.
        end = pos + 1
        while end < total and isinstance(messages[end], ToolMessage):
            end += 1

        answered_ids = {
            messages[k].tool_call_id
            for k in range(pos + 1, end)
            if messages[k].tool_call_id in expected_ids
        }

        if answered_ids == expected_ids:
            # Complete cycle: keep the request and every response in the run.
            kept.append(current)
            kept.extend(messages[k] for k in range(pos + 1, end))
        else:
            print(f"Removing incomplete tool call block: {expected_ids - answered_ids}")

        # Either way, resume scanning after the run of tool responses.
        pos = end

    return kept
|
|
|
|
|
def format_final_answer(report: ExecutionReport, complexity: dict) -> str:
    """Format the final answer based on complexity and report content.

    Args:
        report: Execution report; read for ``final_answer``,
            ``query_summary``, ``key_findings``, ``data_sources`` and
            ``limitations``.
        complexity: Complexity assessment, either a mapping with a
            ``"level"`` key or an object exposing a ``level`` attribute.

    Returns:
        For the ``'simple'`` level, only the ``FINAL ANSWER`` line.
        Otherwise a multi-section text: final answer, summary and key
        findings, followed by up to five sources and the limitations when
        those are present. Sections are separated by one blank line.
    """
    # The parameter is annotated as ``dict`` but the level was read as an
    # attribute, which raises AttributeError on a real dict. Support both.
    if isinstance(complexity, dict):
        level = complexity.get("level")
    else:
        level = complexity.level

    if level == 'simple':
        return f"FINAL ANSWER: {report.final_answer}"

    sections = [
        f"FINAL ANSWER: {report.final_answer}",
        f"SUMMARY:\n{report.query_summary}",
        "KEY FINDINGS:\n" + "\n".join(f"• {finding}" for finding in report.key_findings),
    ]

    if report.data_sources:
        # Only the first five sources are listed to keep the answer compact.
        sections.append(
            "SOURCES:\n" + "\n".join(f"• {source}" for source in report.data_sources[:5])
        )

    if report.limitations:
        sections.append(
            "LIMITATIONS:\n" + "\n".join(f"• {limitation}" for limitation in report.limitations)
        )

    return "\n\n".join(sections)
|
|
|
|
|
|
|
|
def complexity_assessor(state: AgentState) -> AgentState:
    """Ask an LLM to rate the query's complexity and report the verdict.

    Builds a system + human prompt from ``state['query']``, invokes a chat
    model configured for ``ComplexityLevel`` structured output, and prints
    the ``level``, ``needs_planning`` and ``reasoning`` fields of the result.

    Args:
        state: Agent state; read for ``query`` and ``messages``.

    Returns:
        A dict with the assessment under ``complexity_assessment`` and, under
        ``messages``, the existing messages followed by the two prompt
        messages sent to the model.
    """
    print("=== COMPLEXITY ASSESSMENT ===")

    chat_model = ChatOpenAI(model="gpt-4o-mini", temperature=0.25)
    structured_llm = chat_model.with_structured_output(ComplexityLevel)

    system_prompt = SystemMessage(content=COMPLEXITY_ASSESSOR_PROMPT.strip())
    user_prompt = HumanMessage(content=f"Query: {state['query']}")
    prompt_messages = [system_prompt, user_prompt]

    verdict = structured_llm.invoke(prompt_messages)

    print(f"Complexity: {verdict.level}")
    print(f"Needs planning: {verdict.needs_planning}")
    print(f"Reasoning: {verdict.reasoning}")

    state_update = {"complexity_assessment": verdict}
    # The prompt messages are appended so later nodes can see what was asked.
    state_update["messages"] = state["messages"] + prompt_messages
    return state_update
|
|
|
|
|
|
|
|
def trim(s: str, max_len: int = 10_000) -> str:
    """Cut ``s`` down to ``max_len`` characters and mark the cut.

    Args:
        s: Text to shorten; falsy values (empty string, ``None``) are
            returned unchanged.
        max_len: Maximum number of characters kept from ``s``.

    Returns:
        ``s`` itself when it is falsy or no longer than ``max_len``;
        otherwise the first ``max_len`` characters followed by
        ``"... [truncated]"`` (so the result is longer than ``max_len``).
    """
    fits = not s or len(s) <= max_len
    return s if fits else s[:max_len] + "... [truncated]"
|
|
|