# env variables needed: OPENAI_API_KEY, BRAVE_SEARCH_API_KEY
import os
import json
from dotenv import load_dotenv
from typing import Literal

from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langgraph.graph import StateGraph, START, END
from langchain_community.tools import BraveSearch, WikipediaQueryRun
# from langchain_community.utilities import WikipediaAPIWrapper

from .prompt import system_prompt
from .custom_tools import (calculator_tool, web_search, query_image, python_repl,
                           get_webdoc_content, get_website_content, extract_answer_from_content,
                           transcribe_audio, get_youtube_transcript, generate_table_from_data, check_commutative)

load_dotenv()
# load the OpenAI API key from the environment (populated from .env by load_dotenv)
openai_api_key = os.environ['OPENAI_API_KEY']
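# Expected .env layout (a sketch; the values below are placeholders, not real keys):
#   OPENAI_API_KEY=sk-...
#   BRAVE_SEARCH_API_KEY=...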

class LangGraphAgent:
    def __init__(self,
                 model_name="gpt-4.1-mini",
                 show_tools_desc=True,
                 show_prompt=True):

        # =========== LLM definition ===========
        llm = ChatOpenAI(model=model_name, temperature=0, openai_api_key=openai_api_key)
        print(f"LangGraphAgent initialized with model \"{model_name}\"")

        # =========== Augment the LLM with tools ===========
        community_tools = [
            BraveSearch.from_api_key(  # Web search (more performant than DuckDuckGo)
                api_key=os.getenv("BRAVE_SEARCH_API_KEY"),  # needs BRAVE_SEARCH_API_KEY in env
                search_kwargs={"count": 5}),
        ]
        # wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
        custom_tools = [
            calculator_tool,              # Basic math operations
            web_search,                   # Web search using Tavily
            query_image,                  # Ask anything about an image using a VLM
            python_repl,                  # Python code interpreter
            get_webdoc_content,           # Load a web document
            get_website_content,          # Load a web page
            extract_answer_from_content,  # Extract an answer from given content (e.g. a PDF or web page)
            transcribe_audio,             # Transcribe an audio file to text
            get_youtube_transcript,       # Get the transcript of a YouTube video
            generate_table_from_data,     # Generate a table from given data
            check_commutative,            # Analyze a binary operation table for commutativity
        ]
        tools = community_tools + custom_tools
        tools_by_name = {tool.name: tool for tool in tools}
        llm_with_tools = llm.bind_tools(tools)
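        # bind_tools attaches each tool's JSON schema to every model request, so the
        # LLM can reply with structured tool_calls instead of plain text when a tool fits.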
        # =========== Agent definition ===========

        # Nodes
        def llm_call(state: MessagesState):
            """LLM decides whether to call a tool or not"""
            return {
                "messages": [
                    llm_with_tools.invoke(
                        [SystemMessage(content=system_prompt)] + state["messages"]
                    )
                ]
            }
        def tool_node(state: dict):
            """Performs the tool calls requested by the last LLM message"""
            result = []
            for tool_call in state["messages"][-1].tool_calls:
                tool = tools_by_name[tool_call["name"]]
                observation = tool.invoke(tool_call["args"])
                # str() guards against tools that return non-string objects,
                # which would fail ToolMessage content validation
                result.append(ToolMessage(content=str(observation), tool_call_id=tool_call["id"]))
            return {"messages": result}
        # Conditional edge function: route to the tool node, or end, depending on whether the LLM made a tool call
        def should_continue(state: MessagesState) -> Literal["Action", END]:
            """Decide whether to continue the loop or stop, based on whether the LLM made a tool call"""
            messages = state["messages"]
            last_message = messages[-1]
            # If the LLM made a tool call, perform the action
            if last_message.tool_calls:
                return "Action"
            # Otherwise, stop (reply to the user)
            return END
        # Build workflow
        agent_builder = StateGraph(MessagesState)

        # Add nodes
        agent_builder.add_node("llm_call", llm_call)
        agent_builder.add_node("environment", tool_node)

        # Add edges to connect nodes
        agent_builder.add_edge(START, "llm_call")
        agent_builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {
                # Name returned by should_continue : name of the next node to visit
                "Action": "environment",
                END: END,
            },
        )
        agent_builder.add_edge("environment", "llm_call")

        # Compile the agent
        self.agent = agent_builder.compile()
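        # Optional: the compiled graph can render its own topology for debugging, e.g.
        #   print(self.agent.get_graph().draw_mermaid())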
        if show_tools_desc:
            for i, tool in enumerate(llm_with_tools.kwargs['tools']):
                print("\n" + "=" * 30 + f" Tool {i + 1} " + "=" * 30)
                print(json.dumps(tool[tool['type']], indent=4))

        if show_prompt:
            print("\n" + "=" * 30 + " System prompt " + "=" * 30)
            print(system_prompt)
    def __call__(self, question: str) -> str:
        print("\n\n" + "*" * 20)
        print(f"Agent received question: {question}")
        print("*" * 20)

        # Invoke the agent
        messages = [HumanMessage(content=question)]
        messages = self.agent.invoke(
            {"messages": messages},
            {"recursion_limit": 30},  # maximum number of steps before hitting a stop condition
        )
        for m in messages["messages"]:
            m.pretty_print()

        # Post-process the response: keep only what follows "FINAL ANSWER:" for the exact match
        response = str(messages["messages"][-1].content)
        if "FINAL ANSWER:" in response:
            response = response.split("FINAL ANSWER:")[-1].strip()
        else:
            print('Could not split response on "FINAL ANSWER:"')

        print("\n\n" + "-" * 50)
        print(f"Agent returning with answer: {response}")
        return response
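
# Example usage (a sketch: assumes this file lives in a package next to prompt.py
# and custom_tools.py; the package and module names below are hypothetical, and the
# env variables listed at the top must be set):
#
#   from agents.langgraph_agent import LangGraphAgent
#   agent = LangGraphAgent(model_name="gpt-4.1-mini")
#   answer = agent("What is the capital of France?")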