Spaces:
Running
Running
import os | |
from textwrap import dedent | |
from typing import TypedDict, Annotated, Optional, Any, Callable, Sequence, Union | |
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage | |
from langchain_core.tools import BaseTool | |
from langchain_openai import ChatOpenAI | |
from langchain_tavily import TavilySearch | |
from langgraph.constants import START | |
from langgraph.errors import GraphRecursionError | |
from langgraph.graph import add_messages, StateGraph | |
from langgraph.prebuilt import ToolNode, tools_condition | |
from langgraph.pregel import PregelProtocol | |
from loguru import logger | |
from pydantic import SecretStr | |
from tools.excel_to_text import excel_to_text | |
from tools.execute_python_code_from_file import execute_python_code_from_file | |
from tools.maths import add_integers | |
from tools.produce_classifier import produce_classifier | |
from tools.sort_words_alphabetically import sort_words_alphabetically | |
from tools.transcribe_audio import transcribe_audio | |
from tools.web_page_information_extractor import web_page_information_extractor | |
from tools.wikipedia_search import wikipedia_search | |
from tools.youtube_transcript import youtube_transcript | |
# Graph state shared by every node: a single "messages" channel holding the
# conversation. The ``add_messages`` reducer merges each node's returned
# messages into the existing list instead of overwriting it.
AgentState = TypedDict(
    "AgentState",
    {"messages": Annotated[list[AnyMessage], add_messages]},
)
class ShrewdAgent:
    """Question-answering agent that loops a tool-bound OpenAI chat model over tools.

    Internally this is a compiled LangGraph graph with an ``assistant`` node (the
    LLM) and a ``tools`` node, built by ``_build_state_graph``. Calling an instance
    with a question runs the graph and returns the content of the final assistant
    message, or a fixed fallback string when no final answer was produced.
    """

    # System prompt sent as the first message of every run.
    # NOTE(review): the guidelines ask for step-by-step explanations while the last
    # paragraph demands a bare answer only; the model must reconcile the two.
    message_system = dedent("""
        You are a general AI assistant equipped with a suite of external tools. Your task is to
        answer the following question as accurately and helpfully as possible by using the tools
        provided. Do not write or execute code yourself. For any operation requiring computation,
        data retrieval, or external access, explicitly invoke the appropriate tool.
        Follow these guidelines:
        - Clearly explain your reasoning step by step.
        - Justify your choice of tool(s) at each step.
        - If multiple interpretations are possible, outline them and explain your reasoning for selecting one.
        - If the answer requires external data or inference, retrieve or deduce it via the available tools.
        Important: Your final output MUST be only a number or a word with no additional text or explanation,
        unless the response format is explicitly specified in the question. Do not include reasoning,
        commentary, or any other content beyond the requested answer.""")

    # Recursion limit passed to the graph on each question. Once it is exceeded,
    # LangGraph raises GraphRecursionError and the fallback answer is returned.
    # Override on a subclass or an instance to tune it.
    recursion_limit: int = 18

    def __init__(self):
        """Build the tool list, the tool-bound chat model and the compiled graph.

        Raises:
            KeyError: if the ``OPENAI_API_KEY`` environment variable is not set.
        """
        self.tools = [
            TavilySearch(),
            wikipedia_search,
            web_page_information_extractor,
            youtube_transcript,
            produce_classifier,
            sort_words_alphabetically,
            excel_to_text,
            execute_python_code_from_file,
            add_integers,
            transcribe_audio,
        ]
        self.llm = ChatOpenAI(
            model="gpt-4.1",
            temperature=0,
            api_key=SecretStr(os.environ['OPENAI_API_KEY'])
        ).bind_tools(self.tools)

        def assistant_node(state: AgentState):
            # One LLM call over the whole conversation so far. The returned message
            # is merged into state["messages"] by the add_messages reducer.
            return {
                "messages": [self.llm.invoke(state["messages"])],
            }

        self.agent = _build_state_graph(AgentState, assistant_node, self.tools)
        logger.info(f"Agent initialized with tools: {[tool.name for tool in self.tools]}")
        logger.debug(f"system message:\n{self.message_system}")

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer.

        Each assistant and tool message is logged at DEBUG level as it streams in.
        A GraphRecursionError (recursion limit exceeded) is logged rather than
        propagated.

        Returns:
            The content of the last assistant message when it is a final reply
            (one that requests no further tool calls). Otherwise the fallback
            string "I couldn't find the answer", for example when the stream
            produced nothing or was cut off mid tool loop.
        """
        logger.info(f"Agent received question:\n{question}")
        final_answer = "I couldn't find the answer"
        # Only the last streamed chunk decides the answer, so there is no need to
        # keep every chunk around.
        last_chunk = None
        try:
            for chunk in self.agent.stream(
                {"messages": [
                    SystemMessage(self.message_system),
                    HumanMessage(question),
                ]},
                {"recursion_limit": self.recursion_limit},
            ):
                assistant = chunk.get("assistant")
                if assistant:
                    logger.debug(f"\n{assistant.get('messages')[0].pretty_repr()}")
                tools = chunk.get("tools")
                if tools:
                    logger.debug(f"\n{tools.get('messages')[0].pretty_repr()}")
                last_chunk = chunk
        except GraphRecursionError as e:
            logger.error(f"GraphRecursionError: {e}")

        # Guard against an empty stream (the original indexed [-1] unconditionally
        # and raised IndexError here).
        if last_chunk is not None and last_chunk.get("assistant"):
            last_message = last_chunk["assistant"]['messages'][-1]
            # An assistant message that still carries tool calls is an intermediate
            # step (e.g. the one the recursion limit cut off), not the answer; its
            # content is typically empty, so keep the fallback instead.
            if not getattr(last_message, "tool_calls", None):
                final_answer = last_message.content
        logger.info(f"Agent returning answer: {final_answer}")
        return final_answer
def _build_state_graph(
        state_schema: Optional[type[Any]],
        assistant: Callable,
        tools: Sequence[Union[BaseTool, Callable]]) -> PregelProtocol:  # concretely a CompiledStateGraph
    """Compile the two-node assistant/tools loop used by the agent.

    Topology: START -> "assistant"; "assistant" -> "tools" or end (decided by
    langgraph's prebuilt ``tools_condition``); "tools" -> "assistant".

    Args:
        state_schema: State type for the graph (e.g. a TypedDict with a messages channel).
        assistant: Node callable that receives the state and returns a state update.
        tools: Tools to wrap in a prebuilt ``ToolNode`` registered as "tools".

    Returns:
        The compiled, runnable graph.
    """
    builder = StateGraph(state_schema)

    # Nodes.
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Edges: enter at the assistant, branch on whether tools were requested,
    # and always come back to the assistant after running tools.
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()