Spaces:
Sleeping
Sleeping
File size: 6,432 Bytes
e4f6727 e3e865e d5ce935 e3e865e d5ce935 e3e865e f49023b e3e865e 4fb4269 d5ce935 4fb4269 e4f6727 3d648f2 4fb4269 e3e865e d5ce935 e3e865e 4fb4269 d5ce935 e3e865e e4f6727 e3e865e 3d648f2 e4f6727 e3e865e e4f6727 d5ce935 58afc3a d5ce935 e4f6727 4fb4269 3d648f2 e4f6727 3d648f2 58afc3a 3d648f2 d5ce935 e3e865e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import json
import logging
import os
import re
from typing import List, Optional

from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage

from args import Args, AgentPreset
from llm_factory import LLMFactory
class IAgent:
    """
    Base LLM agent.

    Loads a system prompt from ``<cwd>/system_prompts/<file>``, builds a chat
    model through ``LLMFactory`` (binding tools when given), and answers
    queries built from a list of message strings.
    """

    def __init__(self, sys_prompt_filename, agent_preset: AgentPreset, tools: Optional[List] = None, parallel_tool_calls=False):
        """
        Args:
            sys_prompt_filename: File name, relative to ``<cwd>/system_prompts``,
                of the system prompt. It is also used to derive ``self.name``.
            agent_preset: Preset supplying the interface and model name, and
                used by ``LLMFactory`` to build the LLM.
            tools: Tools to bind to the model and to execute when the model
                requests tool calls. ``None`` (the default) means no tools. The
                default is ``None`` rather than ``[]`` to avoid a shared mutable
                default.
            parallel_tool_calls: Forwarded to ``bind_tools`` when tools are given.
        Raises:
            OSError: If the system prompt file cannot be opened or read.
        """
        self.name = self._format_name(sys_prompt_filename)
        self.interface = agent_preset.get_interface()
        # Model name "groot" switches query() into mock mode (no LLM invocation).
        self.mock = (agent_preset.get_model_name() == "groot")
        # Keep a private copy so later changes to the caller's list do not alter
        # the set of tools this agent can execute.
        self.tools = list(tools) if tools else []

        system_prompt_path = os.path.join(os.getcwd(), "system_prompts", sys_prompt_filename)
        with open(system_prompt_path, "r", encoding="utf-8") as file:
            self.system_prompt = file.read().strip()

        llm = LLMFactory.create(agent_preset)
        if self.tools:
            self.model = llm.bind_tools(self.tools, parallel_tool_calls=parallel_tool_calls)
        else:
            self.model = llm

    @staticmethod
    def _format_name(sys_prompt_filename: str) -> str:
        """
        Derive the agent name from the prompt file name.

        Drops the file extension, then strips any leading run of non-letter
        characters. For example, ``"01_planner.txt"`` becomes ``"planner"``.
        """
        name_without_ext = os.path.splitext(sys_prompt_filename)[0]
        return re.sub(r'^[^a-zA-Z]+', '', name_without_ext)

    @staticmethod
    def _bake_roles(messages: List[str]) -> List[AnyMessage]:
        """
        Assign roles to messages in reverse order: the last message is a
        HumanMessage, the one before it an AIMessage, and so on, alternating
        backwards.

        Args:
            messages (List[str]): List of message strings.
        Returns:
            List[AnyMessage]: Messages wrapped with the appropriate role classes.
        Raises:
            ValueError: If messages is empty.
        """
        if not messages:
            raise ValueError("The list of messages cannot be empty !")
        last_idx = len(messages) - 1
        # Even distance from the end -> Human, odd -> AI.
        return [
            HumanMessage(content=msg) if (last_idx - idx) % 2 == 0 else AIMessage(content=msg)
            for idx, msg in enumerate(messages)
        ]

    def get_system_prompt(self) -> str:
        """
        Retrieve the system prompt.

        Returns:
            str: The system prompt string, as loaded in ``__init__``.
        """
        return self.system_prompt

    @staticmethod
    def _parse_tool_call(call) -> tuple:
        """
        Extract ``(tool_name, tool_args)`` from a single tool-call dict.

        Two formats are supported:
          - Raw OpenAI-compatible (e.g. Qwen):
            ``{"function": {"name": ..., "arguments": "<json string>"}}``.
          - Parsed: ``{"name": ..., "args": {...}}``.

        Returns ``(None, {})`` when the call matches neither format. Arguments
        that are missing or are not valid JSON fall back to ``{}``.
        """
        if "function" in call:
            function = call["function"]
            name = function.get("name")
            raw_args = function.get("arguments", "{}")
            if isinstance(raw_args, dict):
                # Some backends return arguments already decoded.
                return name, raw_args
            try:
                return name, json.loads(raw_args or "{}")
            except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
                return name, {}
        if "name" in call and "args" in call:
            return call["name"], call["args"]
        return None, {}

    def _handle_tool_calls(self, tool_calls) -> str:
        """
        Execute the requested tool calls and collect their results.

        Args:
            tool_calls: Iterable of tool-call dicts (formats described in
                ``_parse_tool_call``).
        Returns:
            str: One line per call, joined with newlines. A successful call
            produces ``"[name]: result"``. A failed or unknown tool produces
            ``"[name ERROR]: message"``.
        """
        tool_results = []
        for call in tool_calls:
            tool_name, tool_args = self._parse_tool_call(call)
            # Skip the lookup for unrecognized calls: with name None it would
            # match any tool object that lacks a `name` attribute.
            tool = None
            if tool_name is not None:
                tool = next((t for t in self.tools if getattr(t, "name", None) == tool_name), None)
            if tool is None:
                tool_results.append(f"[{tool_name} ERROR]: Tool not found")
                continue
            try:
                if isinstance(tool_args, dict) and len(tool_args) == 1 and "__arg1" in tool_args:
                    # "__arg1" carries the positional argument of single-argument tools.
                    result = tool.func(tool_args["__arg1"])
                elif isinstance(tool_args, dict):
                    result = tool.func(**tool_args)
                else:
                    result = tool.func(tool_args)
                tool_results.append(f"[{tool_name}]: {result}")
            except Exception as e:
                # Report the failure in the output rather than aborting the remaining calls.
                tool_results.append(f"[{tool_name} ERROR]: {str(e)}")
        return "\n".join(tool_results)

    def query(self, messages: List[str]) -> str:
        """
        Query the agent synchronously with a conversation and return its reply.

        Args:
            messages (List[str]): Conversation turns, oldest first. The last
                entry is sent as the human turn, and roles alternate backwards
                (see ``_bake_roles``). The caller's list is not modified.
        Returns:
            str: If the model requested tool calls, the newline-joined tool
            results. Otherwise the model's text content, or ``str()`` of the
            raw output when it has no content. In mock mode, always
            ``"I am GROOT !"``.
        Raises:
            RuntimeError: If ``Args.LOGGER`` has not been configured.
            ValueError: If ``messages`` is empty.
        """
        if Args.LOGGER is None:
            raise RuntimeError("LOGGER must be defined before querying the agent.")
        # Validate before logging: the log line below indexes messages[-1].
        if not messages:
            raise ValueError("The list of messages cannot be empty !")

        separator = "=============================="
        Args.LOGGER.log(logging.INFO, f"\n{separator}\nAgent '{self.name}' has been queried !\nINPUT:\n{messages}\nLAST INPUT:{messages[-1]}\n")

        if self.mock:
            response = "I am GROOT !"
            Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")
            return response

        # Work on a copy so the "/no_think" suffix never leaks into the caller's
        # list. Otherwise repeated queries with the same list would accumulate suffixes.
        messages = list(messages)
        if Args.MiscParams.NO_THINK:
            # Disable the thinking block for models that honor this directive.
            messages[-1] += "\n/no_think"

        conversation = [SystemMessage(content=self.get_system_prompt())] + self._bake_roles(messages)
        raw_output = self.model.invoke(conversation)

        # 1. Tool calls: prefer the raw provider payload in additional_kwargs.
        #    Otherwise fall back to the parsed `tool_calls` attribute
        #    ({"name", "args"} dicts), which _handle_tool_calls also understands.
        tool_calls = (
            getattr(raw_output, "additional_kwargs", {}).get("tool_calls")
            or getattr(raw_output, "tool_calls", None)
        )
        if tool_calls:
            Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' called tools !\n")
            response = self._handle_tool_calls(tool_calls)
        # 2. Otherwise, use the standard LLM text output if present.
        elif hasattr(raw_output, "content") and raw_output.content:
            response = str(raw_output.content)
        # 3. Fallback: string conversion of whatever the model returned.
        else:
            response = str(raw_output)

        Args.LOGGER.log(logging.INFO, f"\nAgent '{self.name}' produced OUTPUT:\n{response}\n{separator}\n")
        return response
|