# GAIA_Agent_Space / src/gaia_agent.py
# Last commit afb4047 by kylea: added tools for downloading, files, wikipedia search
from typing import Dict, List, cast

from langchain_core.messages import AIMessage
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from src.config import Configuration
from src.model import GoogleModel
from src.state import InputState, State
from src.tools import TOOLS
class GaiaAgent:
    """Tool-calling agent for GAIA tasks, built as a LangGraph graph.

    The compiled graph alternates between a model node (``call_model``) and a
    tool-execution node (``tools``). It stops when the model replies without
    requesting a tool, or when the model is on its last allowed step.
    """

    def __init__(self):
        # Build and compile the graph once; it is stored on the instance.
        self.graph = self._build_graph()

    def _build_graph(self) -> CompiledStateGraph:
        """Build and compile the agent graph.

        ``State`` is the graph's state schema, ``InputState`` the input schema
        and ``Configuration`` the per-run configuration schema.

        Returns:
            CompiledStateGraph: The compiled, runnable graph (the result of
            ``StateGraph.compile``, not the builder itself).
        """
        builder = StateGraph(State, input=InputState, config_schema=Configuration)

        # The two nodes the agent cycles between.
        builder.add_node("call_model", self._call_model)
        builder.add_node("tools", ToolNode(TOOLS))

        # Every run starts at the model node.
        builder.add_edge("__start__", "call_model")

        # tools_condition routes to the "tools" node when the latest message
        # requests a tool, and ends the run when it is a direct response.
        builder.add_conditional_edges("call_model", tools_condition)

        # Tool results are always fed back to the model.
        builder.add_edge("tools", "call_model")

        return builder.compile(name="GAIA Agent", debug=False)

    def _call_model(self, state: State) -> Dict[str, List[AIMessage]]:
        """Call the LLM powering the agent and return its reply.

        Builds a tool-bound model from the run configuration, prepends the
        system prompt to the conversation so far, and invokes the model.

        Args:
            state (State): The current graph state. Reads ``messages``,
                ``file_name``, ``task_id`` and ``is_last_step``.

        Returns:
            dict: ``{"messages": [AIMessage]}``, a single message to be added
            to the existing conversation. If this is the last step and the
            model still requests a tool, a fixed apology message is returned
            instead of the model's tool-calling response.
        """
        # NOTE(review): Configuration.from_context presumably reads the
        # per-run config, which is why the model is built on every call rather
        # than cached on the instance -- confirm before caching.
        configuration = Configuration.from_context()

        # Tool-bound model; change the model or the tool list here.
        model = GoogleModel(
            model=configuration.google_model,
            temperature=configuration.temperature,
            tools=TOOLS,
        )

        # The system prompt is used verbatim (no placeholder formatting).
        system_message = configuration.system_prompt
        if state.file_name:
            # The task has an attached file: give the model the task id so it
            # can fetch the file. NOTE(review): assumes a download tool in
            # TOOLS is keyed by task id -- confirm.
            system_message += (
                f"\n\nThe task id is {state.task_id}.\n"
                "Please use this to download the file."
            )

        response = cast(
            AIMessage,
            model.llm.invoke(
                [
                    {"role": "system", "content": system_message},
                    *state.messages,
                ]
            ),
        )

        # On the last step, a tool request would route to the tools node with
        # no step left to use the result. Return a tool-free message instead,
        # so tools_condition ends the run.
        if state.is_last_step and response.tool_calls:
            return {
                "messages": [
                    AIMessage(
                        id=response.id,
                        content="Sorry, I could not find an answer to your question in the specified number of steps.",
                    )
                ]
            }

        return {"messages": [response]}