Spaces:
Sleeping
Sleeping
import json
import logging
import os
import re
from typing import List, Optional

from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage

from args import Args, AgentPreset
from llm_factory import LLMFactory


class IAgent:
    """LLM-backed agent configured by a system-prompt file.

    The system prompt is read from ``<cwd>/system_prompts/<sys_prompt_filename>``,
    the chat model is built by ``LLMFactory.create`` (with tools bound when any
    are given), and ``query`` sends a plain-string conversation to the model,
    executing any tool calls the model requests.
    """

    def __init__(self, sys_prompt_filename, agent_preset: AgentPreset,
                 tools: Optional[List] = None, parallel_tool_calls=False):
        """Create the agent.

        Args:
            sys_prompt_filename: File name, relative to ``<cwd>/system_prompts``,
                of the system prompt. The agent name is derived from it.
            agent_preset: Preset supplying the interface, the model name and the
                configuration passed to ``LLMFactory.create``.
            tools: Tools to bind to the model and to execute on tool calls.
                ``None`` (the default) or an empty sequence means no tools.
            parallel_tool_calls: Forwarded to ``bind_tools`` when tools are given.

        Raises:
            OSError: If the system prompt file cannot be opened.
            UnicodeDecodeError: If the system prompt file is not valid UTF-8.
        """
        self.name = self._format_name(sys_prompt_filename)
        # Stored for callers; not used within this class.
        self.interface = agent_preset.get_interface()
        # The special model name "groot" turns the agent into an offline mock.
        self.mock = agent_preset.get_model_name() == "groot"
        # Copy into a fresh list so instances never share a mutable default
        # (the former ``tools=[]`` default was one list shared by every agent).
        self.tools = list(tools) if tools else []

        # The path is resolved against the current working directory,
        # not against this module's location.
        system_prompt_path = os.path.join(os.getcwd(), "system_prompts", sys_prompt_filename)
        with open(system_prompt_path, "r", encoding="utf-8") as file:
            self.system_prompt = file.read().strip()

        llm = LLMFactory.create(agent_preset)
        if self.tools:
            self.model = llm.bind_tools(self.tools, parallel_tool_calls=parallel_tool_calls)
        else:
            self.model = llm

    @staticmethod
    def _format_name(sys_prompt_filename: str) -> str:
        """Derive the agent name from its system-prompt file name.

        The extension is dropped and any leading non-letter characters (for
        example an ordering prefix) are stripped:
        ``"01_planner.txt"`` -> ``"planner"``.
        """
        name_without_ext = os.path.splitext(sys_prompt_filename)[0]
        return re.sub(r'^[^a-zA-Z]+', '', name_without_ext)

    @staticmethod
    def _bake_roles(messages: List[str]) -> List[AnyMessage]:
        """Wrap plain strings into alternating chat roles, anchored at the end.

        The last message becomes a ``HumanMessage``, the one before it an
        ``AIMessage``, and so on, alternating backwards. The first message's
        role therefore depends on the parity of ``len(messages)``.

        Args:
            messages: Message strings in conversation order.

        Returns:
            The messages wrapped in ``HumanMessage`` / ``AIMessage``.

        Raises:
            ValueError: If ``messages`` is empty.
        """
        if not messages:
            raise ValueError("The list of messages cannot be empty !")
        total = len(messages)
        # Distance from the end: 0, 2, 4, ... are human turns; odd ones are AI turns.
        return [
            HumanMessage(content=msg) if (total - idx - 1) % 2 == 0 else AIMessage(content=msg)
            for idx, msg in enumerate(messages)
        ]

    def get_system_prompt(self) -> str:
        """Return the system prompt loaded at construction time."""
        return self.system_prompt

    @staticmethod
    def _parse_tool_call(call):
        """Extract ``(tool_name, tool_args)`` from one tool-call dict.

        Two shapes are understood:

        * raw provider format ``{"function": {"name": ..., "arguments": ...}}``,
          where ``arguments`` is a JSON string (or already a dict);
        * parsed format ``{"name": ..., "args": {...}}``.

        Returns ``(None, {})`` for any other shape. Missing or malformed JSON
        arguments yield ``{}`` so the tool is still attempted without arguments.
        """
        if "function" in call:
            function = call["function"]
            tool_name = function.get("name")
            raw_args = function.get("arguments", "{}")
            if isinstance(raw_args, dict):
                return tool_name, raw_args
            try:
                return tool_name, json.loads(raw_args)
            except (ValueError, TypeError):
                # ValueError covers json.JSONDecodeError; TypeError covers None
                # or other non-string payloads.
                return tool_name, {}
        if "name" in call and "args" in call:
            return call["name"], call["args"]
        return None, {}

    def _handle_tool_calls(self, tool_calls):
        """Execute the requested tool calls and report their results as text.

        Args:
            tool_calls: Iterable of tool-call dicts (see ``_parse_tool_call``).

        Returns:
            One line per call, joined with newlines:
            ``"[<tool>]: <result>"`` on success, or
            ``"[<tool> ERROR]: <reason>"`` when the tool is unknown or raised.
        """
        tool_results = []
        for call in tool_calls:
            tool_name, tool_args = self._parse_tool_call(call)
            tool = None
            # Without a name there is nothing to look up; matching ``None``
            # would otherwise select any tool that lacks a ``name`` attribute.
            if tool_name is not None:
                tool = next((t for t in self.tools if getattr(t, "name", None) == tool_name), None)
            if tool is None:
                tool_results.append(f"[{tool_name} ERROR]: Tool not found")
                continue
            try:
                # A lone "__arg1" key is unwrapped and passed positionally,
                # for single-argument tools.
                if isinstance(tool_args, dict) and len(tool_args) == 1 and "__arg1" in tool_args:
                    result = tool.func(tool_args["__arg1"])
                elif isinstance(tool_args, dict):
                    result = tool.func(**tool_args)
                else:
                    result = tool.func(tool_args)
                tool_results.append(f"[{tool_name}]: {result}")
            except Exception as e:
                # Tool failures are reported back as text instead of aborting the query.
                tool_results.append(f"[{tool_name} ERROR]: {e}")
        return "\n".join(tool_results)

    def query(self, messages: List[str]) -> str:
        """Send a conversation to the agent and return its reply (synchronous).

        Args:
            messages: Conversation as plain strings in order; the last entry is
                treated as the human turn (see ``_bake_roles``). The caller's
                list is never modified.

        Returns:
            The tool results when the model requested tool calls; otherwise the
            model's text content, falling back to ``str()`` of the raw output.
            In mock mode, always ``"I am GROOT !"``.

        Raises:
            RuntimeError: If ``Args.LOGGER`` has not been set.
            ValueError: If ``messages`` is empty.
        """
        if Args.LOGGER is None:
            raise RuntimeError("LOGGER must be defined before querying the agent.")
        # Fail with the documented error before the log line indexes messages[-1].
        if not messages:
            raise ValueError("The list of messages cannot be empty !")
        separator = "=============================="
        Args.LOGGER.log(logging.INFO, f"\n{separator}\nAgent '{self.name}' has been queried !\nINPUT:\n{messages}\nLAST INPUT:{messages[-1]}\n")

        if self.mock:
            response = "I am GROOT !"
            Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")
            return response

        # Work on a copy so the caller's history does not accumulate
        # "/no_think" markers across calls.
        messages = list(messages)
        if Args.MiscParams.NO_THINK:
            # Disable the thinking block for some models.
            messages[-1] += "\n/no_think"

        conversation = [SystemMessage(content=self.get_system_prompt())] + self._bake_roles(messages)
        raw_output = self.model.invoke(conversation)

        # Prefer raw provider tool calls; fall back to the parsed
        # ``tool_calls`` attribute for backends that only populate that one.
        additional_kwargs = getattr(raw_output, "additional_kwargs", None) or {}
        tool_calls = additional_kwargs.get("tool_calls") or getattr(raw_output, "tool_calls", None)

        if tool_calls:
            Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' called tools !\n")
            response = self._handle_tool_calls(tool_calls)
        elif getattr(raw_output, "content", None):
            response = str(raw_output.content)
        else:
            # Outputs without usable text content (or plain strings).
            response = str(raw_output)

        Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")
        return response