from typing import Dict, List, cast

from langchain_core.messages import AIMessage
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from src.config import Configuration
from src.model import GoogleModel
from src.state import InputState, State
from src.tools import TOOLS


class GaiaAgent:
    def __init__(self):
        self.graph = self._build_graph()

    def _build_graph(self) -> CompiledStateGraph:
        builder = StateGraph(State, input=InputState, config_schema=Configuration)

        # Define the two nodes we will cycle between
        builder.add_node("call_model", self._call_model)
        builder.add_node("tools", ToolNode(TOOLS))

        # Set the entrypoint as `call_model`:
        # this node is the first one called
        builder.add_edge("__start__", "call_model")

        builder.add_conditional_edges(
            "call_model",
            # tools_condition routes to "tools" when the latest message
            # contains tool calls; otherwise the run ends with the model's
            # direct response
            tools_condition,
        )

        # After a tool runs, hand its output back to the model
        builder.add_edge("tools", "call_model")

        graph = builder.compile(name="GAIA Agent", debug=False)
        return graph
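
    # Hedged convenience wrapper, not part of the original module: a minimal
    # sketch of one way to drive the compiled graph. The "task_id" and
    # "file_name" keys are assumptions based on the InputState fields read in
    # _call_model below; adjust them if your schema differs.
    def run(self, question: str, task_id: str = "", file_name: str = "") -> str:
        """Invoke the graph once and return the final answer text."""
        result = self.graph.invoke(
            {
                "messages": [{"role": "user", "content": question}],
                "task_id": task_id,
                "file_name": file_name,
            }
        )
        return result["messages"][-1].content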

    def _call_model(self, state: State) -> Dict[str, List[AIMessage]]:
        """Call the LLM powering our "agent".

        This function prepares the prompt, initializes the model, and
        processes the response.

        Args:
            state (State): The current state of the conversation.

        Returns:
            dict: A dictionary containing the model's response message.
        """
        configuration = Configuration.from_context()

        # Initialize the model with tool binding. Change the model or add more tools here.
        model = GoogleModel(
            model=configuration.google_model,
            temperature=configuration.temperature,
            tools=TOOLS,
        )

        # Format the system prompt. Customize this to change the agent's behavior.
        system_message = configuration.system_prompt
        if state.file_name:
            file_prompt = (
                f"\n\nThis task has an attached file; the task id is {state.task_id}.\n"
                f"Use the task id to download the file."
            )
            system_message += file_prompt

        # Get the model's response
        response = cast(
            AIMessage,
            model.llm.invoke(
                [
                    {"role": "system", "content": system_message},
                    *state.messages,
                ]
            ),
        )

        # Handle the case when it's the last step and the model still wants to use a tool
        if state.is_last_step and response.tool_calls:
            return {
                "messages": [
                    AIMessage(
                        id=response.id,
                        content="Sorry, I could not find an answer to your question in the specified number of steps.",
                    )
                ]
            }

        # Return the model's response as a list to be added to existing messages
        return {"messages": [response]}