KaiserShultz's picture
Update src/utils/utils.py
f247400 verified
from typing import Iterable, Optional
from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from src.schemas import ComplexityLevel, ExecutionReport, PlannerPlan
from src.prompts.prompts import COMPLEXITY_ASSESSOR_PROMPT
from src.state import AgentState
def log_stage(title: str, subtitle: Optional[str] = None, icon: str = "🚀") -> None:
    """Print a framed banner announcing the current execution stage.

    The banner is preceded by a blank line and consists of a top rule, the
    space-padded title, and a bottom rule. Each line starts with *icon*. The
    rule is as wide as the padded title, but never narrower than 20 characters.

    Args:
        title: Stage name; surrounding whitespace is stripped before display.
        subtitle: Optional extra line printed under the banner when non-empty.
        icon: Prefix placed at the start of every banner line.
    """
    label = " " + title.strip() + " "
    rule = "═" * max(20, len(label))
    banner_lines = [
        "",                      # leading blank line separates stages visually
        icon + " " + rule,
        f"{icon} {label}",
        f"{icon} {rule}",
    ]
    print("\n".join(banner_lines))
    if subtitle:
        print(f"{icon} {subtitle}")
def log_key_values(pairs: Iterable[tuple[str, str]]) -> None:
    """Print each (key, value) pair on its own bulleted diagnostic line.

    Args:
        pairs: Iterable of two-item tuples; each value is rendered with ``str``.
    """
    for entry in pairs:
        name, val = entry
        print(f" • {name}: {val}")
def format_plan_overview(plan: PlannerPlan) -> str:
    """Summarize a plan as one ``"<id>: <goal> [<tool>]"`` line per step.

    Args:
        plan: Planner output; reads ``plan.steps``. Each step exposes ``id``,
            ``goal`` and ``tool``. A falsy tool is shown as ``no tool``.

    Returns:
        The newline-joined step lines, or a placeholder string when *plan* is
        falsy or has no steps.
    """
    steps = plan.steps if plan else None
    if not steps:
        return "(no steps – direct response)"
    return "\n".join(
        f"{step.id}: {step.goal} [{step.tool or 'no tool'}]" for step in steps
    )
def display_plan(plan: PlannerPlan) -> None:
    """Print plan contents in a compact, readable form.

    Prints a "PLANNER OUTPUT" banner, then the task type and summary, the
    assumptions (if any), every step, and the answer guidelines (if any).

    Args:
        plan: Planner output. This function reads ``task_type``, ``summary``,
            ``assumptions``, ``steps`` and ``answer_guidelines``. Each step is
            read for ``id``, ``goal``, ``tool``, ``inputs``,
            ``expected_result`` and ``on_fail``.
    """
    log_stage("PLANNER OUTPUT", icon="🧭")
    print(f"Task type: {plan.task_type}")
    print(f"Summary: {plan.summary}")
    if plan.assumptions:
        print("Assumptions:")
        for item in plan.assumptions:
            print(f" - {item}")
    print("Steps:")
    # A missing step list means "no steps" (as in format_plan_overview), not a crash.
    for step in plan.steps or ():
        # Separate id and goal: without it, id 1 + "Search" rendered as "1Search".
        # The "id: goal" layout matches format_plan_overview.
        print(f" {step.id}: {step.goal}")
        if step.tool:
            print(f" tool: {step.tool}")
        else:
            print(" tool: (none)")
        if step.inputs:
            print(f" inputs: {step.inputs}")
        print(f" expected: {step.expected_result}")
        if step.on_fail:
            print(f" on_fail: {step.on_fail}")
    if plan.answer_guidelines:
        print(f"Answer guidelines: {plan.answer_guidelines}")
def clean_message_history(messages):
    """Remove incomplete tool-call/tool-response cycles from a message history.

    An assistant message with ``tool_calls`` is kept only if every one of its
    calls is answered. The answers must come from the contiguous run of
    ``ToolMessage`` objects directly after it. If any call is unanswered, the
    assistant message and that whole run are dropped.

    Inside a kept block, a ``ToolMessage`` whose ``tool_call_id`` does not
    match one of that message's calls is dropped. A ``ToolMessage`` with no
    preceding tool-calling message (an orphan) is dropped too. Such results
    have no matching call and would leave the history inconsistent.

    Args:
        messages: Sequence of LangChain message objects, in conversation order.

    Returns:
        A new list with the retained messages in their original order.
    """
    cleaned_messages = []
    total = len(messages)
    i = 0
    while i < total:
        msg = messages[i]
        tool_calls = getattr(msg, "tool_calls", None)
        if tool_calls:
            tool_call_ids = {tc["id"] for tc in tool_calls}
            # Find the contiguous run of ToolMessages that follows the call.
            j = i + 1
            while j < total and isinstance(messages[j], ToolMessage):
                j += 1
            run = messages[i + 1:j]
            responses = [m for m in run if m.tool_call_id in tool_call_ids]
            found_responses = {m.tool_call_id for m in responses}
            if found_responses == tool_call_ids:
                # Complete block: keep the call and only the results that answer it.
                cleaned_messages.append(msg)
                cleaned_messages.extend(responses)
                unmatched = [m.tool_call_id for m in run if m.tool_call_id not in tool_call_ids]
                if unmatched:
                    print(f"Removing unmatched tool messages: {unmatched}")
            else:
                # Skip the incomplete block (call plus its partial results).
                print(f"Removing incomplete tool call block: {tool_call_ids - found_responses}")
            i = j
        elif isinstance(msg, ToolMessage):
            # A tool result with no preceding tool call in its run is an orphan.
            print(f"Removing orphan tool message: {msg.tool_call_id}")
            i += 1
        else:
            # Ordinary message: keep as-is.
            cleaned_messages.append(msg)
            i += 1
    return cleaned_messages
def format_final_answer(
    report: "ExecutionReport",
    complexity: "ComplexityLevel | dict",
    max_sources: int = 5,
) -> str:
    """Format the final answer based on complexity and report content.

    For a ``"simple"`` complexity level, only the ``FINAL ANSWER:`` line is
    returned. Any other level produces these sections: the answer, a summary,
    key findings, up to *max_sources* sources (if any), and limitations (if
    any).

    Args:
        report: Execution report; reads ``final_answer``, ``query_summary``,
            ``key_findings``, ``data_sources`` and ``limitations``.
        complexity: Complexity assessment, given either as an object with a
            ``level`` attribute or as a mapping with a ``"level"`` key.
        max_sources: Maximum number of entries listed under ``SOURCES``.

    Returns:
        The formatted answer text.
    """
    # The annotation used to say ``dict`` while the code read ``.level``; accept both.
    if isinstance(complexity, dict):
        level = complexity.get("level")
    else:
        level = getattr(complexity, "level", None)
    # Enum-valued levels compare by their underlying value.
    level = getattr(level, "value", level)
    if level == "simple":
        return f"FINAL ANSWER: {report.final_answer}"

    def _bullets(items):
        # One "• item" line per entry.
        return "\n".join(f"• {item}" for item in items)

    sections = [
        f"FINAL ANSWER: {report.final_answer}",
        "SUMMARY:",
        f"{report.query_summary}",
        "KEY FINDINGS:",
        _bullets(report.key_findings or ()),
    ]
    if report.data_sources:
        sections += ["SOURCES:", _bullets(report.data_sources[:max_sources])]
    if report.limitations:
        sections += ["LIMITATIONS:", _bullets(report.limitations)]
    return "\n".join(sections)
def complexity_assessor(state: AgentState) -> AgentState:
    """Classify the query's complexity and decide whether planning is required.

    Sends the assessor system prompt plus ``"Query: <state['query']>"`` to
    ``gpt-4o-mini`` (temperature 0.25). The model is bound to structured output
    of type ``ComplexityLevel``. The assessment's level, planning flag and
    reasoning are printed.

    Args:
        state: Agent state; reads the ``"query"`` and ``"messages"`` entries.

    Returns:
        A partial state update with ``"complexity_assessment"`` (the structured
        result) and ``"messages"`` (existing messages plus the two prompt
        messages sent to the assessor).
    """
    print("=== COMPLEXITY ASSESSMENT ===")
    structured_llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0.25,
    ).with_structured_output(ComplexityLevel)
    prompt_messages = [
        SystemMessage(content=COMPLEXITY_ASSESSOR_PROMPT.strip()),
        HumanMessage(content=f"Query: {state['query']}"),
    ]
    result = structured_llm.invoke(prompt_messages)
    for label, attr in (
        ("Complexity", "level"),
        ("Needs planning", "needs_planning"),
        ("Reasoning", "reasoning"),
    ):
        print(f"{label}: {getattr(result, attr)}")
    # NOTE(review): the assessor's own SystemMessage is appended to the shared
    # history, so later nodes will see it — confirm this is intended.
    return {
        "complexity_assessment": result,
        "messages": state["messages"] + prompt_messages,
    }
def trim(s: str, max_len: int = 10_000, suffix: str = "... [truncated]") -> str:
    """Truncate *s* to its first *max_len* characters and append *suffix* when cut.

    The suffix is added after the kept prefix, so a truncated result is
    ``max_len + len(suffix)`` characters long, which is longer than *max_len*.
    Strings no longer than *max_len*, and falsy inputs such as ``""`` or
    ``None``, are returned unchanged.

    Args:
        s: Text to shorten.
        max_len: Number of leading characters to keep when truncating.
        suffix: Marker appended to a truncated string.

    Returns:
        The original string, or its truncated prefix followed by *suffix*.
    """
    if s and len(s) > max_len:
        return s[:max_len] + suffix
    return s