|
|
|
|
import json
import os
import re
import time
from typing import Any, Callable, Dict, List, Optional

from langchain_core.output_parsers import PydanticOutputParser

from groq_api import grok_get_llm_response, API_llama_get_llm_response, open_oss_get_llm_response, openai_get_llm_response, deepseekapi_get_llm_response
from local_templates import llama3_get_llm_response, mistral_get_llm_response, qwen_get_llm_response, deepseek_get_llm_response, grape_get_llm_response
|
|
|
|
# Default upper bound on thought/action/observe iterations in one
# ImprovedAgent.run call.
max_steps = 15

# Absolute directory that contains this module; the system_prompt_*.txt
# files are loaded relative to it.
base_dir = os.path.split(os.path.abspath(__file__))[0]
|
|
|
|
def select_model(model_type: str):
    """Look up the LLM response callable registered under ``model_type``.

    Args:
        model_type: Backend key, e.g. ``"groq_api"`` or ``"qwen3"``.

    Returns:
        The matching ``*_get_llm_response`` function.

    Raises:
        ValueError: if ``model_type`` is not a registered key.
    """
    registry = {
        # Remote / hosted API backends.
        "groq_api": grok_get_llm_response,
        "llama_api": API_llama_get_llm_response,
        "oss_api": open_oss_get_llm_response,
        "openai_api": openai_get_llm_response,
        "deepseek_api": deepseekapi_get_llm_response,
        # Backends from local_templates.
        "llama3": llama3_get_llm_response,
        "mistral": mistral_get_llm_response,
        "qwen3": qwen_get_llm_response,
        "deepseek": deepseek_get_llm_response,
        "grape": grape_get_llm_response,
    }

    response_fn = registry.get(model_type)
    if response_fn is None:
        raise ValueError(f"Unknown model_type: {model_type}")
    return response_fn
|
|
|
|
def format_gaia_response(model_type, last_observation, question_out):
    """Ask the model to turn the agent's last observation into a final answer.

    Args:
        model_type: Backend key passed to ``select_model``.
        last_observation: Text of the agent's last observation.
        question_out: The original user question.

    Returns:
        Whatever the selected ``get_llm_response`` callable returns for the
        final-answer prompt (presumably the answer text -- TODO confirm).

    Raises:
        ValueError: if ``model_type`` is unknown (from ``select_model``).
        OSError: if ``system_prompt_final.txt`` cannot be read.
    """
    get_llm_response = select_model(model_type)

    # Read explicitly as UTF-8 so behavior does not depend on the OS locale.
    prompt_path = os.path.join(base_dir, "system_prompt_final.txt")
    with open(prompt_path, "r", encoding="utf-8") as f:
        final_sys_prompt = f.read()

    # NOTE(review): final_sys_prompt is sent both as the system prompt and
    # prepended to the user prompt below -- confirm the duplication is intended.
    gaia_prompt = (
        f"{final_sys_prompt}\n\n"
        f"User Question:\n{question_out}\n\n"
        f"Last Observation:\n{last_observation}\n\n"
        "Please review user questions and the last observation and respond with the correct answer, in the correct format. No extra text, just the answer."
    )

    # reasoning_format='hidden' is forwarded to the backend; presumably it
    # suppresses reasoning text in the reply -- TODO confirm per backend.
    return get_llm_response(final_sys_prompt, gaia_prompt, reasoning_format='hidden')
|
|
|
|
class ImprovedAgent:
    """Plan -> (thought -> action -> observe) agent driven by a pluggable LLM.

    ``run`` first asks the model for a plan. It then repeatedly:

    * asks for a thought;
    * turns the thought into a tool call and executes the tool;
    * asks the model to interpret the tool output.

    It stops when the observation stage returns a ``final_answer`` or the step
    limit is reached. Every parsed stage output is appended to
    ``self.history``, which is serialized back into later prompts.
    """

    def __init__(self, tools: Dict[str, Callable], model_type: str):
        """Create an agent.

        Args:
            tools: Mapping of tool name to callable. The action stage picks a
                tool by name and it is called as ``tool(**action_input)``.
            model_type: Backend key accepted by ``select_model``.

        Raises:
            ValueError: if ``model_type`` is unknown (from ``select_model``).
        """
        self.tools = tools
        self.history = []
        self.get_llm_response = select_model(model_type)

        # One system prompt per stage, read from files next to this module.
        self.system_prompt_plan = self.load_prompt(os.path.join(base_dir, "system_prompt_planning.txt"))
        self.system_prompt_thought = self.load_prompt(os.path.join(base_dir, "system_prompt_thought.txt"))
        self.system_prompt_action = self.load_prompt(os.path.join(base_dir, "system_prompt_action.txt"))
        self.system_prompt_observe = self.load_prompt(os.path.join(base_dir, "system_prompt_observe.txt"))

    def load_prompt(self, filepath: str) -> str:
        """Return the full text of the prompt file at ``filepath`` (read as UTF-8)."""
        with open(filepath, "r", encoding="utf-8") as f:
            return f.read()

    def reset(self):
        """Clear the accumulated history so a new query starts fresh."""
        self.history = []

    def strip_markdown_code_block(self, text: str) -> str:
        """
        Remove leading/trailing markdown code block markers like ```json or ```
        """
        text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE)
        text = re.sub(r"\s*```$", "", text)
        return text.strip()

    def _strip_think_block(self, text: str) -> str:
        """Return ``text`` without a leading ``<think>...</think>`` section.

        Everything up to and including the first ``</think>`` is discarded, so
        braces inside a reasoning model's scratchpad cannot confuse JSON
        extraction. Text without the closing tag is only whitespace-stripped.
        """
        _, sep, after = text.partition("</think>")
        return (after if sep else text).strip()

    def parse_json_response(self, response_text: str) -> Dict:
        """Parse an LLM reply into a JSON object without raising.

        The reply is cleaned before decoding: a ``<think>...</think>`` section
        is dropped, markdown code fences are removed, and the outermost
        ``{...}`` span is extracted.

        Returns:
            The decoded dict, or ``{"error": "..."}`` when the reply is not a
            string, is not valid JSON, or does not decode to a JSON object.
        """
        if not isinstance(response_text, str):
            message = f"Expected a text response, got {type(response_text).__name__}"
            print(f"[ERROR] {message}")
            return {"error": message}
        try:
            cleaned = self.strip_markdown_code_block(self._strip_think_block(response_text))
            json_text = self.extract_json_string(cleaned)
            # Models often emit \' which is not a legal JSON escape.
            json_text = json_text.replace("\\'", "'")
            parsed = json.loads(json_text)
        except json.JSONDecodeError as e:
            print(f"[ERROR] JSON Parse Error: {e}")
            print(f"[DEBUG] Raw response: {response_text}")
            return {"error": f"Invalid JSON response: {str(e)}"}
        # Callers use `in` / [] on the result; a JSON string, list or null
        # would misbehave there, so only objects are accepted.
        if not isinstance(parsed, dict):
            message = f"Expected a JSON object, got {type(parsed).__name__}"
            print(f"[ERROR] {message}")
            print(f"[DEBUG] Raw response: {response_text}")
            return {"error": message}
        return parsed

    def extract_json_string(self, text: str) -> str:
        """Return the span from the first '{' to the last '}' (across lines), or ``text`` if none."""
        match = re.search(r'\{.*\}', text, re.DOTALL)
        return match.group(0) if match else text

    def build_prompt_from_history(self, query: str) -> str:
        """Render the query plus the JSON-serialized history as the thought-stage input."""
        return (
            f"User Query: {query}\n"
            f"History: {json.dumps(self.history, indent=2)}\n"
        )

    def _plan(self, query: str) -> Dict:
        """Run the planning stage for ``query`` and return its parsed JSON."""
        print("-----Stage Plan-----")
        plan_response = self.get_llm_response(self.system_prompt_plan, f"User Query: {query}")
        print("-----Plan Text-----")
        print(plan_response)
        print("-------------------")
        print("-----Plan Parsed-----")
        parsed_plan = self.parse_json_response(plan_response)
        print(parsed_plan)
        print("---------------------")
        return parsed_plan

    def _think(self, current_input: str) -> Dict:
        """Run the thought stage on ``current_input`` and return its parsed JSON."""
        print("-----Stage Thought-----")
        thought_response = self.get_llm_response(self.system_prompt_thought, current_input)
        print(thought_response)
        parsed_thought = self.parse_json_response(thought_response)
        print("-----Thought Parsed-----")
        print(parsed_thought)
        print("-----------------")
        return parsed_thought

    def _act(self, thought: Any) -> Dict:
        """Ask the action stage to turn ``thought`` into a tool call (parsed JSON)."""
        print("-----Stage Action-----")
        action_input = json.dumps({"thought": thought})
        action_response_text = self.get_llm_response(self.system_prompt_action, action_input)
        parsed_action = self.parse_json_response(action_response_text)
        print(parsed_action)
        print("-----------------")
        return parsed_action

    def _execute_tool(self, tool_name: Any, tool_args: Any) -> str:
        """Call ``self.tools[tool_name](**tool_args)`` and describe the outcome.

        Never raises. An unknown, missing or non-string tool name and any
        exception raised by the tool (including bad arguments) are reported
        as ``[ERROR]`` text, so the LLM can react to them.
        """
        if not isinstance(tool_name, str) or tool_name not in self.tools:
            return f"[ERROR] Invalid or missing tool: {tool_name}"
        try:
            result = self.tools[tool_name](**tool_args)
            observation = f"Tool `{tool_name}` executed successfully. Output: {result}"
        except Exception as e:  # deliberate: tool failures are fed back to the LLM
            observation = f"[ERROR] Tool `{tool_name}` execution failed: {str(e)}"
            status = "Fail"
        else:
            status = "OK"
        print(f"-----Tool Observation {status}-----")
        print(observation)
        print("-----------------")
        return observation

    def _observe(self, query: str, observation: str):
        """Run the observation stage on the latest tool output.

        Returns:
            ``(raw_response_text, parsed_response_dict)``.
        """
        observation_input = (
            f"User Query: {query}\n"
            f"Plan: {json.dumps(self.history[0], indent=2)}\n"
            f"History: {json.dumps(self.history, indent=2)}\n"
            f"Tool Output: {observation}\n"
        )
        print("-----Stage Observe-----")
        observation_response_text = self.get_llm_response(self.system_prompt_observe, observation_input)
        print("-----Observation Parsed-----")
        parsed_observation = self.parse_json_response(observation_response_text)
        print(parsed_observation)
        print("-----------------")
        return observation_response_text, parsed_observation

    def run(self, query: str, step_limit: Optional[int] = None):
        """Answer ``query`` by planning, then iterating thought/action/observe.

        Args:
            query: The user question.
            step_limit: Maximum number of thought/action/observe iterations.
                Defaults to the module-level ``max_steps``.

        Returns:
            Always a 3-tuple ``(history, last_observation_text, answer)``:

            * On success, ``answer`` is the observation stage's ``final_answer``.
            * If the thought stage omits ``"thought"``, ``last_observation_text``
              is an ``[ERROR]`` message and ``answer`` is ``""``.
            * If the step limit is hit, ``last_observation_text`` has a
              best-estimate instruction appended and ``answer`` is the last
              parsed observation dict (``{}`` if no iteration ran).
        """
        if step_limit is None:
            step_limit = max_steps
        self.reset()

        # history[0] is always the plan; _observe relies on that.
        self.history.append(self._plan(query))
        current_input = self.build_prompt_from_history(query)

        # Initialized so the loop-limit return is valid even with step_limit <= 0.
        observation_response_text = ""
        parsed_observation: Dict = {}

        for step in range(step_limit):
            print(f"-----Iteration {step}-----")

            parsed_thought = self._think(current_input)
            self.history.append(parsed_thought)
            if "thought" not in parsed_thought:
                # Same 3-tuple shape as every other return path.
                return self.history, "[ERROR] Thought agent did not return 'thought'. Ending.", ""

            parsed_action = self._act(parsed_thought["thought"])
            self.history.append(parsed_action)

            tool_name = parsed_action.get("action")
            tool_args = parsed_action.get("action_input", {})
            observation = self._execute_tool(tool_name, tool_args)
            # The raw tool output is passed to the observe stage only;
            # history records just which tool was called and with what.
            self.history.append({
                "tool_name": tool_name,
                "tool_args": tool_args,
            })

            observation_response_text, parsed_observation = self._observe(query, observation)
            self.history.append(parsed_observation)

            if "final_answer" in parsed_observation:
                print(parsed_observation["final_answer"])
                return self.history, observation_response_text, parsed_observation["final_answer"]

            current_input = self.build_prompt_from_history(query)

        print('ERROR LOOP LIMIT REACHED')
        return (
            self.history,
            f"{observation_response_text}\n\nThis is our last observation. Make your best estimation given the question.",
            parsed_observation,
        )