# Nexa_Labs / agent/controller.py
# Allanatrix — Upload 57 files (commit d8328bf, verified)
"""Agent controller orchestrating the LLM ↔ tool server interaction loop."""
from __future__ import annotations
import json
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Sequence
from tools.schemas import ToolCall, ToolResult
from .client_llm import Message, NexaSciModelClient
from .client_llm_remote import RemoteNexaSciClient
from .tool_client import ToolClient
# Fenced blocks the model uses to request a tool call or to emit its final
# answer. The payload is captured lazily up to the next closing fence, and
# ``re.DOTALL`` lets it span multiple lines.
TOOLCALL_REGEX = re.compile(r"~~~toolcall(.*?)~~~", flags=re.DOTALL)
FINAL_REGEX = re.compile(r"~~~final(.*?)~~~", flags=re.DOTALL)
@dataclass
class AgentRunResult:
    """Outcome of a completed ``AgentController.run`` call.

    Attributes
    ----------
    final_response:
        Parsed payload of the model's final-response block.
    messages:
        Conversation accumulated during the run (user, assistant and tool
        messages).
    tool_results:
        Results of the tool calls dispatched during the run, in order.
    """

    final_response: Dict[str, Any]
    messages: Sequence[Message]
    tool_results: Sequence[ToolResult] = field(default_factory=list)

    def pretty(self, indent: int = 2) -> str:
        """Return a readable JSON representation of the final response.

        Parameters
        ----------
        indent:
            Indentation passed to ``json.dumps``. The default of 2 preserves
            the previous output layout.

        Non-ASCII characters are emitted verbatim rather than escaped
        (``ensure_ascii=False``). This keeps the output human-readable and
        matches how the controller serializes tool outputs.
        """
        return json.dumps(self.final_response, indent=indent, ensure_ascii=False)
class AgentController:
    """Drive the LLM <-> tool-server loop until the model emits a final answer.

    Each turn sends the accumulated conversation to the LLM client. Tool-call
    blocks found in the reply are dispatched through the tool client, and
    their outputs are appended as ``tool`` messages. When a reply contains no
    tool calls, a final-response block (if present) ends the run.
    """

    def __init__(
        self,
        llm_client: NexaSciModelClient | RemoteNexaSciClient | None = None,
        tool_client: ToolClient | None = None,
        *,
        max_turns: int = 8,
        use_remote_model: bool = False,
        model_server_url: str = "http://127.0.0.1:8001",
    ) -> None:
        """Initialize the agent controller.

        Parameters
        ----------
        llm_client:
            LLM client to use. When ``None``, one is built here. It is a
            ``RemoteNexaSciClient`` pointed at ``model_server_url`` if
            ``use_remote_model`` is true, otherwise a lazily loaded
            ``NexaSciModelClient``.
        tool_client:
            Tool client to use. When falsy, ``ToolClient.from_config()`` is used.
        max_turns:
            Maximum number of LLM calls made by ``run``.
        use_remote_model:
            Only consulted when ``llm_client`` is ``None``.
        model_server_url:
            Base URL of the model server. Only used when a remote client is
            created here.
        """
        if llm_client is None:
            llm_client = (
                RemoteNexaSciClient(base_url=model_server_url)
                if use_remote_model
                else NexaSciModelClient(lazy_load=True)
            )
        self.llm_client = llm_client
        self.tool_client = tool_client if tool_client else ToolClient.from_config()
        self.max_turns = max_turns

    def run(self, user_prompt: str) -> AgentRunResult:
        """Run the agent loop for *user_prompt* and return the final result.

        Raises
        ------
        RuntimeError
            If no final response is produced within ``max_turns`` LLM calls.
        """
        history: List[Message] = [Message(role="user", content=user_prompt)]
        collected: List[ToolResult] = []

        for _turn in range(self.max_turns):
            reply = self.llm_client.generate(history)
            history.append(Message(role="assistant", content=reply))

            requested = _extract_tool_calls(reply)
            if not requested:
                # No tool work requested: a final block ends the run;
                # otherwise the loop moves on to the next turn.
                payload = _extract_final_response(reply)
                if payload is not None:
                    return AgentRunResult(
                        final_response=payload,
                        messages=history,
                        tool_results=collected,
                    )
                continue

            # Tool calls take precedence over any final block in the same
            # reply; each result is fed back to the model as a "tool" message.
            for tool_call in requested:
                outcome = self._dispatch_tool(tool_call)
                collected.append(outcome)
                history.append(
                    Message(
                        role="tool",
                        content=json.dumps(outcome.output, ensure_ascii=False),
                    )
                )

        raise RuntimeError("Agent did not produce a final response within the maximum number of turns.")

    def _dispatch_tool(self, call: ToolCall) -> ToolResult:
        """Forward *call* to the tool client and return its result unchanged."""
        return self.tool_client.call_tool(call)
def _extract_tool_calls(text: str) -> List[ToolCall]:
    """Parse tool call JSON payloads embedded in the assistant response.

    Every ``~~~toolcall ... ~~~`` block in *text* is decoded as JSON and turned
    into a ``ToolCall``. Parsing is best-effort, so one malformed block cannot
    abort the whole agent run. A block is skipped when it:

    - is empty;
    - is not valid JSON;
    - does not decode to a JSON object;
    - is rejected by the ``ToolCall`` constructor.

    Parameters
    ----------
    text:
        Raw assistant output.

    Returns
    -------
    List[ToolCall]
        Successfully parsed calls, in the order they appear in *text*.
    """
    tool_calls: List[ToolCall] = []
    for match in TOOLCALL_REGEX.findall(text):
        snippet = match.strip()
        if not snippet:
            continue
        try:
            payload = json.loads(snippet)
        except json.JSONDecodeError:
            continue
        # ``ToolCall(**payload)`` requires a mapping. A JSON array, string or
        # number would otherwise raise TypeError and crash the run.
        if not isinstance(payload, dict):
            continue
        try:
            tool_calls.append(ToolCall(**payload))
        except (TypeError, ValueError):
            # Unexpected keyword arguments raise TypeError. Validation
            # failures are assumed to raise ValueError (e.g. pydantic's
            # ValidationError subclasses ValueError).
            # NOTE(review): confirm against tools.schemas.ToolCall.
            continue
    return tool_calls
def _extract_final_response(text: str) -> Dict[str, Any] | None:
    """Parse the final response JSON block from the assistant output.

    Only the first ``~~~final ... ~~~`` block in *text* is considered.

    Parameters
    ----------
    text:
        Raw assistant output.

    Returns
    -------
    Dict[str, Any] | None
        - The decoded JSON object when the block is present and valid.
        - ``{}`` when the block is present but empty.
        - ``None`` when there is no final block, or when its contents are not
          valid JSON or not a JSON object.

        Mirroring the best-effort handling of tool-call blocks, a malformed
        final block is treated as "no final response yet" instead of raising
        ``json.JSONDecodeError`` out of the agent loop.
    """
    match = FINAL_REGEX.search(text)
    if not match:
        return None
    snippet = match.group(1).strip()
    if not snippet:
        return {}
    try:
        payload = json.loads(snippet)
    except json.JSONDecodeError:
        return None
    # Callers store this as ``final_response: Dict[str, Any]``; reject arrays
    # and scalars rather than returning a value of the wrong type.
    if not isinstance(payload, dict):
        return None
    return payload
# Names exported by ``from agent.controller import *``.
__all__ = [
    "AgentController",
    "AgentRunResult",
]