vtony commited on
Commit
db80308
·
verified ·
1 Parent(s): ec64ddf

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +244 -0
agent.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ from dotenv import load_dotenv
5
+ from langgraph.graph import StateGraph, END
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_community.tools import DuckDuckGoSearchRun
8
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
9
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
10
+ from langchain_core.tools import tool
11
+ from tenacity import retry, stop_after_attempt, wait_exponential
12
+
13
+ # Load environment variables
14
+ load_dotenv()
15
+ google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
16
+ if not google_api_key:
17
+ raise ValueError("Missing GOOGLE_API_KEY environment variable")
18
+
19
+ # --- Math Tools ---
20
+ @tool
21
+ def multiply(a: int, b: int) -> int:
22
+ """Multiply two integers."""
23
+ return a * b
24
+
25
+ @tool
26
+ def add(a: int, b: int) -> int:
27
+ """Add two integers."""
28
+ return a + b
29
+
30
+ @tool
31
+ def subtract(a: int, b: int) -> int:
32
+ """Subtract b from a."""
33
+ return a - b
34
+
35
+ @tool
36
+ def divide(a: int, b: int) -> float:
37
+ """Divide a by b, error on zero."""
38
+ if b == 0:
39
+ raise ValueError("Cannot divide by zero.")
40
+ return a / b
41
+
42
+ @tool
43
+ def modulus(a: int, b: int) -> int:
44
+ """Compute a mod b."""
45
+ return a % b
46
+
47
+ # --- Browser Tools ---
48
+ @tool
49
+ def wiki_search(query: str) -> str:
50
+ """Search Wikipedia and return up to 3 relevant documents."""
51
+ try:
52
+ docs = WikipediaLoader(query=query, load_max_docs=3).load()
53
+ if not docs:
54
+ return "No Wikipedia results found."
55
+
56
+ results = []
57
+ for doc in docs:
58
+ title = doc.metadata.get('title', 'Unknown Title')
59
+ content = doc.page_content[:2000] # Limit content length
60
+ results.append(f"Title: {title}\nContent: {content}")
61
+
62
+ return "\n\n---\n\n".join(results)
63
+ except Exception as e:
64
+ return f"Wikipedia search error: {str(e)}"
65
+
66
+ @tool
67
+ def arxiv_search(query: str) -> str:
68
+ """Search Arxiv and return up to 3 relevant papers."""
69
+ try:
70
+ docs = ArxivLoader(query=query, load_max_docs=3).load()
71
+ if not docs:
72
+ return "No arXiv papers found."
73
+
74
+ results = []
75
+ for doc in docs:
76
+ title = doc.metadata.get('Title', 'Unknown Title')
77
+ authors = ", ".join(doc.metadata.get('Authors', []))
78
+ content = doc.page_content[:2000] # Limit content length
79
+ results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
80
+
81
+ return "\n\n---\n\n".join(results)
82
+ except Exception as e:
83
+ return f"arXiv search error: {str(e)}"
84
+
85
+ @tool
86
+ def web_search(query: str) -> str:
87
+ """Search the web using DuckDuckGo and return top results."""
88
+ try:
89
+ search = DuckDuckGoSearchRun()
90
+ result = search.run(query)
91
+ return f"Web search results for '{query}':\n{result[:2000]}" # Limit content length
92
+ except Exception as e:
93
+ return f"Web search error: {str(e)}"
94
+
95
+ # --- Load system prompt ---
96
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
97
+ system_prompt = f.read()
98
+
99
+ # --- System message ---
100
+ sys_msg = SystemMessage(content=system_prompt)
101
+
102
+ # --- Tool Setup ---
103
+ tools = [
104
+ multiply,
105
+ add,
106
+ subtract,
107
+ divide,
108
+ modulus,
109
+ wiki_search,
110
+ arxiv_search,
111
+ web_search,
112
+ ]
113
+
114
+ # --- Graph Builder ---
115
+ def build_graph():
116
+ # Initialize model with Gemini 2.5 Flash
117
+ llm = ChatGoogleGenerativeAI(
118
+ model="gemini-2.5-flash",
119
+ temperature=0.3,
120
+ google_api_key=google_api_key,
121
+ max_retries=3
122
+ )
123
+
124
+ # Bind tools to LLM
125
+ llm_with_tools = llm.bind_tools(tools)
126
+
127
+ # 使用 TypedDict 定义状态而不是自定义类
128
+ from typing import TypedDict, Annotated, Sequence
129
+ import operator
130
+
131
+ class AgentState(TypedDict):
132
+ messages: Annotated[Sequence[dict], operator.add]
133
+
134
+ # Node definitions with error handling
135
+ def agent_node(state: AgentState):
136
+ """Main agent node that processes messages with retry logic"""
137
+ try:
138
+ # Add rate limiting
139
+ time.sleep(1) # 1 second delay between requests
140
+
141
+ # Add retry logic for API quota issues
142
+ @retry(stop=stop_after_attempt(3),
143
+ wait=wait_exponential(multiplier=1, min=4, max=10))
144
+ def invoke_llm_with_retry():
145
+ return llm_with_tools.invoke(state["messages"])
146
+
147
+ response = invoke_llm_with_retry()
148
+ return {"messages": [response]}
149
+
150
+ except Exception as e:
151
+ # Handle specific errors
152
+ error_type = "UNKNOWN"
153
+ if "429" in str(e):
154
+ error_type = "QUOTA_EXCEEDED"
155
+ elif "400" in str(e):
156
+ error_type = "INVALID_REQUEST"
157
+
158
+ error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
159
+ return {"messages": [AIMessage(content=error_msg)]}
160
+
161
+ # Tool node
162
+ def tool_node(state: AgentState):
163
+ """Execute tools based on agent's request"""
164
+ last_message = state["messages"][-1]
165
+ tool_calls = last_message.additional_kwargs.get("tool_calls", [])
166
+
167
+ tool_responses = []
168
+ for tool_call in tool_calls:
169
+ tool_name = tool_call["function"]["name"]
170
+ tool_args = tool_call["function"].get("arguments", {})
171
+
172
+ # Find the tool
173
+ tool_func = next((t for t in tools if t.name == tool_name), None)
174
+ if not tool_func:
175
+ tool_responses.append(f"Tool {tool_name} not found")
176
+ continue
177
+
178
+ try:
179
+ # Execute the tool
180
+ if isinstance(tool_args, str):
181
+ # Parse JSON if arguments are in string format
182
+ tool_args = json.loads(tool_args)
183
+
184
+ result = tool_func.invoke(tool_args)
185
+ tool_responses.append(f"Tool {tool_name} result: {result}")
186
+ except Exception as e:
187
+ tool_responses.append(f"Tool {tool_name} error: {str(e)}")
188
+
189
+ tool_response_content = "\n".join(tool_responses)
190
+ return {"messages": [AIMessage(content=tool_response_content)]}
191
+
192
+ # Custom condition function
193
+ def should_continue(state: AgentState):
194
+ last_message = state["messages"][-1]
195
+
196
+ # If there was an error, end
197
+ if "AGENT ERROR" in last_message.content:
198
+ return "end"
199
+
200
+ # Check for tool calls
201
+ if hasattr(last_message, "tool_calls") and last_message.tool_calls:
202
+ return "tools"
203
+
204
+ # Check for final answer
205
+ if "FINAL ANSWER" in last_message.content:
206
+ return "end"
207
+
208
+ # Otherwise, continue to agent
209
+ return "agent"
210
+
211
+ # Build the graph
212
+ workflow = StateGraph(AgentState)
213
+
214
+ # Add nodes
215
+ workflow.add_node("agent", agent_node)
216
+ workflow.add_node("tools", tool_node)
217
+
218
+ # Set entry point
219
+ workflow.set_entry_point("agent")
220
+
221
+ # Define edges
222
+ workflow.add_conditional_edges(
223
+ "agent",
224
+ should_continue,
225
+ {
226
+ "agent": "agent",
227
+ "tools": "tools",
228
+ "end": END
229
+ }
230
+ )
231
+
232
+ workflow.add_conditional_edges(
233
+ "tools",
234
+ # Always go back to agent after using tools
235
+ lambda state: "agent",
236
+ {
237
+ "agent": "agent"
238
+ }
239
+ )
240
+
241
+ return workflow.compile()
242
+
243
+ # Initialize the agent graph
244
+ agent_graph = build_graph()