# debashis2007's picture
# Upload folder using huggingface_hub
# 75bea1c verified
from __future__ import annotations
"""ReACT (Reasoning + Acting) handler."""
import json
from dataclasses import dataclass, field
from typing import Any
from src.llm import LLMClient, Message, MessageRole
from src.llm.prompts import format_prompt, PromptNames
from src.tools.base import ToolRegistry
from src.utils.config import settings
from src.utils.logging import get_logger
logger = get_logger(__name__)
@dataclass
class ReACTStep:
    """A single step in the ReACT loop: one thought/action/observation triple."""
    # 1-based index of the loop iteration that produced this step.
    iteration: int
    # The model's free-text reasoning for this step.
    thought: str
    # Name of the tool invoked, or "finish" for the terminal step.
    action: str
    # Keyword arguments passed to the tool.
    action_input: dict[str, Any]
    # Tool output (JSON string) or error text; None until the action runs.
    observation: str | None = None
@dataclass
class ReACTResult:
    """Result from a ReACT loop execution."""
    # Final answer text (or a fallback message when the loop gave up).
    answer: str
    # Full trace of thought/action/observation steps taken.
    steps: list[ReACTStep] = field(default_factory=list)
    # False when the loop hit the iteration limit without finishing.
    success: bool = True
    # Human-readable failure reason; None on success.
    error: str | None = None
class ReACTHandler:
    """Handler for ReACT (Reasoning + Acting) loops.

    Alternates between LLM reasoning ("thought") and tool execution
    ("action" -> "observation") until the model emits a ``finish`` action
    or the iteration budget is exhausted.
    """

    def __init__(
        self,
        llm_client: LLMClient,
        tool_registry: ToolRegistry,
        max_iterations: int | None = None,
    ):
        """Initialize ReACT handler.

        Args:
            llm_client: LLM client for reasoning
            tool_registry: Registry of available tools
            max_iterations: Maximum iterations (defaults to settings)
        """
        self.llm = llm_client
        self.tools = tool_registry
        # `is None` (not `or`) so an explicit value of 0 is honored instead
        # of silently falling back to the configured default.
        self.max_iterations = (
            max_iterations if max_iterations is not None else settings.max_iterations
        )

    async def run(
        self,
        query: str,
        system_prompt: str,
        initial_context: dict[str, Any] | None = None,
    ) -> ReACTResult:
        """Run a ReACT loop to answer a query.

        Args:
            query: User's query
            system_prompt: System prompt for the agent
            initial_context: Optional initial context (copied; never mutated)

        Returns:
            ReACTResult with answer and step history
        """
        steps: list[ReACTStep] = []
        # Copy so the caller's dict is not mutated by the per-step updates below.
        working_memory: dict[str, Any] = dict(initial_context) if initial_context else {}
        tool_schemas = self.tools.get_schemas()

        for iteration in range(1, self.max_iterations + 1):
            logger.info("ReACT iteration %d", iteration)

            # Generate thought and action from the step history so far.
            prompt = format_prompt(
                PromptNames.REACT_REASONING,
                user_query=query,
                iteration_number=iteration,
                max_iterations=self.max_iterations,
                previous_steps=self._format_steps(steps),
                working_memory=json.dumps(working_memory),
            )
            messages = [
                Message(role=MessageRole.SYSTEM, content=system_prompt),
                Message(role=MessageRole.USER, content=prompt),
            ]
            response = await self.llm.chat(messages, tools=tool_schemas, temperature=0.5)

            # _parse_response normalizes both structured tool calls and the
            # text THOUGHT/ACTION format into (thought, action, action_input).
            thought, action, action_input = self._parse_response(response)
            logger.info("Thought: %s...", thought[:100])
            logger.info("Action: %s", action)

            # Terminal action: the model claims it has the final answer.
            if action.lower() == "finish":
                answer = action_input.get("answer", response.content or "")
                steps.append(
                    ReACTStep(
                        iteration=iteration,
                        thought=thought,
                        action="finish",
                        action_input=action_input,
                        observation=answer,
                    )
                )
                return ReACTResult(answer=answer, steps=steps, success=True)

            # Execute the chosen action. Structured tool calls were already
            # folded into (action, action_input) above, so a single execution
            # path covers both response styles.
            if action:
                result = await self.tools.execute(action, **action_input)
                observation = (
                    json.dumps(result.data) if result.success else f"Error: {result.error}"
                )
            else:
                observation = "No valid action specified"

            # Record step
            steps.append(
                ReACTStep(
                    iteration=iteration,
                    thought=thought,
                    action=action,
                    action_input=action_input,
                    observation=observation,
                )
            )
            # Update working memory with a truncated trace for the next prompt.
            working_memory[f"step_{iteration}"] = {
                "action": action,
                "observation": observation[:500],  # Truncate for memory
            }

        # Max iterations reached without a finish action.
        return ReACTResult(
            answer="I was unable to find a complete answer within the iteration limit.",
            steps=steps,
            success=False,
            error="Max iterations reached",
        )

    def _format_steps(self, steps: list[ReACTStep]) -> str:
        """Format steps as THOUGHT/ACTION/OBSERVATION blocks for the prompt.

        Args:
            steps: List of ReACT steps

        Returns:
            Formatted string (or a placeholder when there are no steps)
        """
        if not steps:
            return "No previous steps."
        formatted = [
            f"**THOUGHT {step.iteration}:** {step.thought}\n"
            f"**ACTION {step.iteration}:** {step.action}[{json.dumps(step.action_input)}]\n"
            f"**OBSERVATION {step.iteration}:** {step.observation}"
            for step in steps
        ]
        return "\n\n".join(formatted)

    def _parse_response(self, response: Any) -> tuple[str, str, dict[str, Any]]:
        """Parse thought and action from an LLM response.

        Args:
            response: LLM response (structured tool calls take precedence
                over the text THOUGHT/ACTION format)

        Returns:
            Tuple of (thought, action, action_input)
        """
        content = response.content or ""

        # Structured tool call: take the first call; the thought is whatever
        # free text precedes the ACTION marker.
        if response.has_tool_calls:
            tool_call = response.tool_calls[0]
            thought = content.split("**ACTION")[0].replace("**THOUGHT", "").strip()
            thought = thought.strip("*: \n")
            return thought, tool_call.name, tool_call.arguments

        # Text format: THOUGHT ... ACTION name[{json args}]
        thought = ""
        action = ""
        action_input: dict[str, Any] = {}
        if "THOUGHT" in content:
            thought_part = content.split("THOUGHT")[-1]
            thought = thought_part.split("**ACTION")[0].strip("*: \n")
        if "ACTION" in content:
            action_part = content.split("ACTION")[-1].strip("*: \n")
            if "[" in action_part and "]" in action_part:
                action = action_part.split("[")[0].strip()
                input_str = action_part[action_part.find("[") + 1 : action_part.rfind("]")]
                try:
                    if input_str.startswith("{"):
                        action_input = json.loads(input_str)
                    else:
                        # Non-JSON payload: treat the raw text as the answer.
                        action_input = {"answer": input_str}
                except json.JSONDecodeError:
                    action_input = {"answer": input_str}
            else:
                action = action_part.split("\n")[0].strip()
        # Normalize any "finish"-like action name to the canonical token.
        if "finish" in action.lower():
            action = "finish"
        return thought, action, action_input