ABVM committed on
Commit
40f02cf
·
verified ·
1 Parent(s): c4a1609

Upload multi_agent.py

Browse files
Files changed (1) hide show
  1. multi_agent.py +197 -0
multi_agent.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import (
2
+ CodeAgent,
3
+ VisitWebpageTool,
4
+ WebSearchTool,
5
+ WikipediaSearchTool,
6
+ PythonInterpreterTool,
7
+ FinalAnswerTool,
8
+ )
9
+ from groq import Groq
10
+ from vision_tool import image_reasoning_tool
11
+ import os
12
+ import time
13
+
14
+
15
+ # ---- TOOLS ----
16
+
17
+
18
# ---- GROQ MODEL WRAPPER ----
class GroqModel:
    """Thin callable wrapper around the Groq chat-completions API.

    Accepts either a bare prompt string or a pre-built ``messages`` list
    and returns the assistant's text content. Retries up to three times
    on rate-limit errors with a linearly increasing back-off.
    """

    def __init__(self, model_name=""):
        self.model_name = model_name
        # API key comes from the environment; the client raises later if absent.
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

    def __call__(self, prompt, max_tokens=8096):
        """Send *prompt* to the model and return the response text.

        Parameters
        ----------
        prompt : str | list[dict]
            A plain user prompt, or a full ``messages`` list in the
            OpenAI/Groq chat format.
        max_tokens : int
            Completion token budget forwarded to the API.

        Raises
        ------
        Exception
            Re-raises any non-rate-limit API error, or a rate-limit
            error that persists through all retry attempts.
        """
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        else:
            messages = prompt

        response = None
        for attempt in range(3):
            try:
                response = self.client.chat.completions.create(
                    messages=messages,
                    model=self.model_name,
                    stream=False,
                    max_tokens=max_tokens,
                )
                break
            except Exception as e:
                msg = str(e).lower()
                if "rate limit" in msg and attempt < 2:
                    # Back off 10s, then 20s, before retrying.
                    time.sleep(10 * (attempt + 1))
                    continue
                raise
        # The loop above always breaks with a response or raises, so the
        # original "if response is None: re-request" branch was unreachable
        # dead code and has been removed. The unused token-usage read
        # (`_ = response.usage.total_tokens`) was likewise dropped.

        choice = response.choices[0]
        if hasattr(choice, "message"):
            content = choice.message.content
        elif hasattr(choice, "text"):
            # Fallback for text-only completions.
            content = choice.text
        elif isinstance(choice, str):
            content = choice
        else:
            content = str(choice)

        return content

    def generate(self, prompt, max_tokens=8096, **kwargs):
        """Alias for ``__call__``; kept for agent-framework compatibility.

        Extra keyword arguments are accepted and ignored, matching the
        original behavior.
        """
        return self.__call__(prompt, max_tokens=max_tokens)
76
+
77
+
78
# ---- MULTI-AGENT SYSTEM ----
class MultyAgentSystem:
    """Planner/worker multi-agent pipeline built on smolagents CodeAgents.

    A manager agent delegates tasks to a web-browsing agent and an
    info/compute agent, produces an initial answer with Qwen-32B, and
    verifies long or high-stakes answers with DeepSeek-70B. When calls
    become slow, the manager's model is swapped for a fallback model.
    """

    def __init__(self):
        self.primary_model_name = "deepseek-r1-distill-llama-70b"
        # NOTE(review): "llama3-70b-8k" may not be a valid Groq model id
        # (Groq's catalog lists "llama3-70b-8192") — confirm before relying
        # on the fallback path.
        self.fallback_model_name = "llama3-70b-8k"

        self.deepseek_model = GroqModel(self.primary_model_name)
        qwen_model = GroqModel("qwen-qwq-32b")
        # Answers longer than this many words trigger verification.
        self.verification_limit = int(os.getenv("VERIFY_WORD_LIMIT", "75"))

        # --- Web agent: browsing / retrieval tasks ---
        self.web_agent = CodeAgent(
            model=qwen_model,
            tools=[WebSearchTool(), VisitWebpageTool(), WikipediaSearchTool()],
            name="web_agent",
            description=(
                "You are a web browsing agent. Whenever the given {task} involves browsing "
                "the web or a specific website such as Wikipedia or YouTube, you will use "
                "the provided tools. For web-based factual and retrieval tasks, be as precise and source-reliable as possible."
            ),
            additional_authorized_imports=[
                "markdownify",
                "json",
                "requests",
                "urllib.request",
                "urllib.parse",
                "wikipedia-api",
            ],
            verbosity_level=0,
            max_steps=10,
        )

        # --- Info agent: math, code, data manipulation, OCR/vision ---
        self.info_agent = CodeAgent(
            model=qwen_model,
            tools=[PythonInterpreterTool(), image_reasoning_tool],
            name="info_agent",
            description=(
                "You are an agent tasked with cleaning, parsing, calculating information, and performing OCR if images are provided in the {task}. "
                "You can also analyze images using a vision model. You handle all math, code, and data manipulation. Use numpy, math, and available libraries. "
                "For image or chess tasks, use pytesseract, PIL, chess, or the image_reasoning_tool as required."
            ),
            additional_authorized_imports=[
                "numpy",
                "math",
                "pytesseract",
                "PIL",
                "chess",
            ],
        )

        # --- Manager agent: plans, delegates, and finalizes answers ---
        manager_planning_interval = int(os.getenv("MANAGER_PLANNING_INTERVAL", "3"))
        manager_max_steps = int(os.getenv("MANAGER_MAX_STEPS", "8"))

        self.manager_agent = CodeAgent(
            model=qwen_model,
            tools=[FinalAnswerTool()],
            managed_agents=[self.web_agent, self.info_agent],
            name="manager_agent",
            description=(
                "You are the manager. Given a {task}, plan which agent to use: "
                "If web data is needed, delegate to web_agent. If math, parsing, image reasoning, or code is needed, use info_agent. "
                "After collecting outputs, optionally cross-validate and check correctness, then finalize and submit the best answer using FinalAnswerTool. "
                "For each task, explicitly explain your planning steps and reasons for choosing which agent, and always prefer the most accurate and complete answer possible."
            ),
            additional_authorized_imports=[
                "json",
                "pandas",
                "numpy",
            ],
            planning_interval=manager_planning_interval,
            verbosity_level=2,
            max_steps=manager_max_steps,
        )

        # Runtime tracking used to decide when to switch to the fallback model.
        self.total_runtime = 0.0
        self.first_call_duration = None
        self.model_switched = False

    def _switch_to_fallback(self):
        """Swap the manager's model for the fallback model (idempotent)."""
        if self.model_switched:
            return
        self.manager_agent.model = GroqModel(self.fallback_model_name)
        self.model_switched = True

    def run(self, question, high_stakes: bool = False, **kwargs):
        """Answer *question*, verifying long or high-stakes answers.

        Parameters
        ----------
        question : str
            The task/question forwarded to the manager agent.
        high_stakes : bool
            When True, always verify with the DeepSeek model.
        **kwargs
            Passed through to the manager agent call.

        Returns
        -------
        The (possibly verified) answer. Also accumulates call durations
        and switches the manager to the fallback model when the first
        call exceeds 30s or total runtime exceeds 300s.
        """
        start_time = time.time()
        print("Generating initial answer with Qwen-32B")
        initial_answer = self.manager_agent(question, **kwargs)
        call_duration = time.time() - start_time

        answer = initial_answer
        # Robustness fix: the agent may return a non-str payload, so coerce
        # before the word-count check instead of crashing on .split().
        if high_stakes or len(str(initial_answer).split()) > self.verification_limit:
            print("Verifying answer using DeepSeek-70B")
            verification_prompt = (
                "Review the following answer for accuracy and rewrite if needed:"
                f"\n\n{initial_answer}"
            )
            try:
                answer = self.deepseek_model(verification_prompt)
            except Exception as e:
                # Best-effort verification: fall back to the unverified answer.
                print(f"Verification failed: {e}. Using initial answer.")
                answer = initial_answer

        # Record the first call's duration once; a slow first call triggers
        # the fallback model for subsequent runs.
        if self.first_call_duration is None:
            self.first_call_duration = call_duration
        if self.first_call_duration > 30:
            self._switch_to_fallback()

        self.total_runtime += call_duration
        if self.total_runtime > 300 and not self.model_switched:
            self._switch_to_fallback()

        return answer

    def __call__(self, question, high_stakes: bool = False, **kwargs):
        """Convenience wrapper delegating to :meth:`run`."""
        return self.run(question, high_stakes=high_stakes, **kwargs)