import logging
import os
import re
from typing import List

from llama_index.core.agent.workflow import AgentWorkflow
from llama_index.core.tools import FunctionTool
from llama_index.core.workflow import Context

from args import Args, LLMInterface
from llm_factory import LLMFactory


class IAgent:
    def __init__(self, temperature, max_tokens, sys_prompt_filename, llm_itf: LLMInterface):
        self.name = self._format_name(sys_prompt_filename)
        self.temperature, self.max_tokens = temperature, max_tokens

        # Load the system prompt from a file
        system_prompt_path = os.path.join(os.getcwd(), "system_prompts", sys_prompt_filename)
        self.system_prompt = ""
        with open(system_prompt_path, "r") as file:
            self.system_prompt = file.read().strip()

        # Initialize the tools and the slave (sub-)agents
        self.tools = self.setup_tools()
        self.slaves: List[IAgent] = self.setup_slaves()

        # Define the LLM and the agent workflow
        self.llm = LLMFactory.create(llm_itf, self.system_prompt, temperature, max_tokens)
        self.agent = self._setup_agent()
        self.ctx = Context(self.agent)

    @staticmethod
    def _format_name(sys_prompt_filename: str) -> str:
        # Remove the file extension, e.g. "01_manager.txt" -> "01_manager"
        name_without_ext = os.path.splitext(sys_prompt_filename)[0]
        # Strip digits and special characters from the beginning, e.g. "01_manager" -> "manager"
        cleaned_name = re.sub(r'^[^a-zA-Z]+', '', name_without_ext)
        return cleaned_name

    def setup_tools(self) -> List[FunctionTool]:
        """
        Set up the tools for this agent.

        Override this method in subclasses to define custom tools (see the
        example subclasses at the bottom of this file). By default, returns
        an empty list.

        Returns:
            List[FunctionTool]: The tools this agent can use.
        """
        return []

    def setup_slaves(self) -> List["IAgent"]:
        """
        Set up the slave agents for this agent.

        Override this method in subclasses to define custom sub-agents.
        By default, returns an empty list.

        Returns:
            List[IAgent]: The slave agents this agent can delegate to.
        """
        return []

    def _setup_agent(self) -> AgentWorkflow:
        """
        Initializes and returns the agent workflow.

        Each slave agent is wrapped as a FunctionTool that forwards a query to
        its `query` method, these tools are appended to `self.tools`, and an
        AgentWorkflow is built from the combined tool list and `self.llm`.

        Returns:
            AgentWorkflow: The workflow configured with the combined tools and the language model.
        """
        # Create tools from slaves: each tool calls slave.query(question) asynchronously
        slave_tools = []
        for slave in self.slaves:
            slave_tool = FunctionTool.from_defaults(
                name=f"call_{slave.name}",
                description=f"Calls agent {slave.name} with a given query.",
                fn=slave.query
            )
            slave_tools.append(slave_tool)
        self.tools.extend(slave_tools)

        return AgentWorkflow.from_tools_or_functions(
            self.tools,
            llm=self.llm
        )

    def get_system_prompt(self) -> str:
        """
        Retrieves the system prompt.

        Returns:
            str: The system prompt string.
        """
        return self.system_prompt

    async def query(self, question: str, has_context: bool = True) -> str:
        """
        Asynchronously queries the agent with a given question and returns the
        response (see the usage sketch at the bottom of this file).

        Args:
            question (str): The question to be sent to the agent.
            has_context (bool): If True, run with this agent's persistent context
                so conversation history is kept; if False, run statelessly.

        Returns:
            str: The response from the agent as a string.
        """
        if Args.LOGGER is None:
            raise RuntimeError("LOGGER must be defined before querying the agent.")

        separator = "=============================="
        Args.LOGGER.log(logging.INFO, f"\n{separator}\nAgent '{self.name}' has been queried!\nINPUT:\n{question}\n")

        if has_context:
            response = await self.agent.run(question, ctx=self.ctx)
        else:
            response = await self.agent.run(question)
        response = str(response)

        Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")
        return response

    def clear_context(self):
        """
        Clears the current context of the agent, resetting any conversation history.
        This is useful when starting a new conversation or when the context needs to be refreshed.
        """
        if self.ctx is not None:
            self.ctx = Context(self.agent)
        # Also reset every slave agent's context recursively.
        for slave in self.slaves:
            slave.clear_context()
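

# ---------------------------------------------------------------------------
# Example (not part of the original module): a minimal sketch of how the
# override hooks above are meant to be used. `MathAgent`, `ManagerAgent`,
# the `add` helper, and the prompt filename "10_math.txt" are hypothetical
# names chosen for illustration; only the IAgent API defined above is assumed.
# ---------------------------------------------------------------------------

class MathAgent(IAgent):
    """A leaf agent that exposes a plain Python function as a tool."""

    def setup_tools(self) -> List[FunctionTool]:
        def add(a: float, b: float) -> float:
            """Add two numbers and return the result."""
            return a + b

        # from_defaults derives the tool name and description from the function.
        return [FunctionTool.from_defaults(fn=add)]


class ManagerAgent(IAgent):
    """A coordinating agent that delegates to MathAgent through a slave tool."""

    def __init__(self, temperature, max_tokens, sys_prompt_filename, llm_itf: LLMInterface):
        # Remember the backend choice so setup_slaves (called from IAgent.__init__)
        # can pass it down to the sub-agent.
        self._slave_llm_itf = llm_itf
        super().__init__(temperature, max_tokens, sys_prompt_filename, llm_itf)

    def setup_slaves(self) -> List[IAgent]:
        return [MathAgent(self.temperature, self.max_tokens, "10_math.txt", self._slave_llm_itf)]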
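

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): querying an agent from an
# asyncio entry point. "00_generic.txt" is a hypothetical prompt filename,
# the LLMInterface member is picked as a placeholder (this assumes
# LLMInterface is an Enum), and Args.LOGGER must have been configured
# elsewhere, since query() refuses to run without it.
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import asyncio

    async def main() -> None:
        # A bare IAgent has no tools or slaves; the workflow only uses the LLM
        # created by LLMFactory with the system prompt loaded from the file.
        agent = IAgent(
            temperature=0.2,
            max_tokens=1024,
            sys_prompt_filename="00_generic.txt",  # hypothetical prompt file
            llm_itf=list(LLMInterface)[0],         # placeholder backend choice
        )
        # Contextful query: conversation history is kept in agent.ctx.
        answer = await agent.query("Summarize what you can do in one sentence.")
        print(answer)
        # Reset the history before starting an unrelated conversation.
        agent.clear_context()

    asyncio.run(main())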