grshot committed on
Commit
c5a9cfd
·
1 Parent(s): 183b832

test basic tools

Browse files
Files changed (2) hide show
  1. agent.py +87 -284
  2. app.py +7 -2
agent.py CHANGED
@@ -1,342 +1,145 @@
1
- import json
2
  import os
3
- from typing import Annotated, Dict, Optional
4
 
5
- import pandas as pd
6
- from langchain_community.document_loaders import WikipediaLoader, YoutubeLoader
7
- from langchain_community.document_loaders.youtube import TranscriptFormat
8
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
9
- from langchain_core.runnables import RunnableLambda
10
  from langchain_core.tools import tool
11
- from langchain_experimental.tools.python.tool import PythonREPLTool
12
  from langchain_groq import ChatGroq
13
- from langchain_huggingface import (
14
- ChatHuggingFace,
15
- HuggingFaceEmbeddings,
16
- HuggingFaceEndpoint,
17
- )
18
-
19
- # --- Langchain / Langraph ---
20
  from langchain_tavily import TavilySearch
21
  from langgraph.graph import END, START, MessagesState, StateGraph
22
- from langgraph.graph.message import add_messages
23
  from langgraph.prebuilt import ToolNode, tools_condition
24
 
25
 
26
- # Custom exception for tool errors
27
- class ToolExecutionError(Exception):
28
- """Custom exception for tool execution errors"""
29
-
30
- pass
31
-
32
-
33
- @tool("search_web_sources")
34
- def search_web_sources(query: Annotated[str, "Search query string"]) -> Dict[str, str]:
35
- """Performs a web search and returns up to 3 formatted documents with content and source."""
36
  try:
37
- if not os.environ.get("TAVILY_API_KEY"):
38
- raise EnvironmentError(
39
- "TAVILY_API_KEY is not set in environment variables."
40
- )
41
 
42
  search_docs = TavilySearch(max_results=3).invoke({"query": query})
43
  if not search_docs:
44
- return {"web_results": "No results found for the given query."}
45
-
46
- formatted = "\n\n---\n\n".join(
47
  [
48
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
49
  for doc in search_docs
50
  ]
51
  )
52
- return {"web_results": formatted}
53
  except Exception as e:
54
- return {"web_results": f"Error during web search: {str(e)}"}
55
 
56
 
57
  @tool
58
  def search_wikipedia(query: str) -> Dict[str, str]:
59
  """Search Wikipedia using LangChain's loader and return the first document summary."""
60
  try:
61
- # Input validation
62
- if not query or not isinstance(query, str):
63
- return {
64
- "wiki_results": "Invalid query provided. Please provide a valid search term."
65
- }
66
-
67
  loader = WikipediaLoader(query=query, lang="en", load_max_docs=2)
68
  docs = loader.load()
69
-
70
  if not docs:
71
- return {"wiki_results": f"No Wikipedia articles found for query: {query}"}
72
-
73
- formatted_docs = "---".join(
74
- [
75
- f'<WikipediaArticle title="{query}">{doc.page_content}</WikipediaArticle>'
76
- for doc in docs
77
- ]
78
  )
79
  return {"wiki_results": formatted_docs}
80
  except Exception as e:
81
- error_msg = str(e)
82
- if "Page id" in error_msg and "not found" in error_msg:
83
- return {"wiki_results": f"No Wikipedia article found for: {query}"}
84
- return {"wiki_results": f"Error searching Wikipedia: {error_msg}"}
85
 
86
 
87
- @tool
88
- def extract_youtube_transcript(video_url: str) -> dict:
89
- """Extract transcript from a YouTube video given its URL using LangChain's YouTubeLoader."""
90
- try:
91
- loader = YoutubeLoader.from_youtube_url(
92
- video_url,
93
- add_video_info=True,
94
- transcript_format=TranscriptFormat.CHUNKS,
95
- chunk_size_seconds=30,
96
- )
97
- docs = loader.load()
98
- if docs:
99
- formatted_docs = "\n\n---\n\n".join(
100
- [
101
- f'<YouTubeTranscript url="{video_url}">\n{doc.page_content}\n</YouTubeTranscript>'
102
- for doc in docs
103
- ]
104
- )
105
- return {"transcript_results": formatted_docs}
106
- else:
107
- return {"transcript_results": "No transcript found."}
108
- except Exception as e:
109
- return {"transcript_results": f"Error fetching YouTube transcript: {e}"}
110
-
111
-
112
- @tool
113
- def run_python_code(code: str) -> str:
114
- """Execute Python code and return the result.
115
- Args:
116
- code: Python code as a string.
117
- """
118
- repl = PythonREPLTool()
119
- return repl.run(code)
120
-
121
-
122
- # --- System Prompt ---
123
  system_prompt = SystemMessage(
124
- content="""
125
- You are a helpful and precise assistant with access to several tools. You will receive questions and use tools appropriately to find answers.
126
 
127
- When using tools:
128
- 1. Format tool calls correctly using the tool's exact name and required parameters
129
- 2. Validate inputs before making tool calls
130
- 3. Handle tool responses appropriately, checking for errors
131
- 4. If a tool fails, try an alternative approach or provide a clear error message
132
-
133
- Available tools:
134
- - search_web_sources: Search web for information (requires query parameter)
135
- - search_wikipedia: Search Wikipedia articles (requires query parameter)
136
- - extract_youtube_transcript: Get transcript from YouTube videos (requires video_url parameter)
137
- - run_python_code: Execute Python code (requires code parameter)
138
 
139
- Think step-by-step:
140
- 1. Understand the question
141
- 2. Choose appropriate tool(s)
142
- 3. Format tool calls correctly
143
- 4. Process tool responses
144
- 5. Formulate final answer
145
 
146
- Use this format strictly:
147
- FINAL ANSWER: [your concise answer here]
148
-
149
- Rules for your answer:
150
- - If the answer is a number, write only the number (no commas, units, or symbols unless asked)
151
- - If it's a string, avoid articles (a, an, the), don't abbreviate, and use plain text digits
152
- - If a list, follow the rules above for each element and separate with a comma and single space (e.g., "apple, orange, banana")
153
- - If there's an error, start with "Error:" followed by a clear explanation
154
-
155
- Your response must always begin with: FINAL ANSWER:
156
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  )
158
 
159
 
160
  def build_agent_graph(provider: str = "groq"):
 
 
 
161
 
162
- # Define toolset
163
- tools = [
164
- search_web_sources,
165
- search_wikipedia,
166
- extract_youtube_transcript,
167
- run_python_code,
168
- ]
169
-
170
- # Instantiate LLM with proper error handling
171
- groq_api_key = os.getenv("GROQ_API_KEY")
172
- if not groq_api_key:
173
- raise EnvironmentError("GROQ_API_KEY environment variable is not set")
174
-
175
  try:
176
  from pydantic import SecretStr
177
 
 
 
 
 
178
  llm = ChatGroq(
179
- model="qwen-qwq-32b", temperature=0, api_key=SecretStr(groq_api_key)
 
 
180
  )
181
  except Exception as e:
182
- raise Exception(f"Failed to initialize Groq LLM: {str(e)}")
183
-
184
- # Bind tools to the LLM
185
  llm_with_tools = llm.bind_tools(tools)
186
 
187
- # Assistant: reasoning step that plans next action
188
- def assistant_node(state: MessagesState) -> dict:
189
- try:
190
- # Validate input state
191
- if not isinstance(state, dict) or "messages" not in state:
192
- raise ValueError("Invalid state format")
193
-
194
- messages = state["messages"]
195
- if not messages:
196
- raise ValueError("Empty message list")
197
-
198
- # Invoke LLM
199
- response = llm_with_tools.invoke(messages)
200
- if response is None:
201
- raise ValueError("LLM returned None response")
202
-
203
- # Validate response format
204
- if not isinstance(response, (AIMessage, HumanMessage, SystemMessage)):
205
- raise ValueError(f"Invalid response type from LLM: {type(response)}")
206
-
207
- # Validate response content
208
- if not hasattr(response, "content") or response.content is None:
209
- raise ValueError("Response missing content")
210
-
211
- if not isinstance(response.content, str):
212
- raise ValueError(f"Invalid content type: {type(response.content)}")
213
-
214
- # Ensure response has content
215
- if not response.content.strip():
216
- raise ValueError("Empty response content")
217
-
218
- # Add FINAL ANSWER prefix if missing
219
- content = response.content
220
- if "FINAL ANSWER:" not in content:
221
- content = f"FINAL ANSWER: {content}"
222
- response = AIMessage(content=content)
223
-
224
- return {"messages": response}
225
- except Exception as e:
226
- error_msg = f"Error in assistant node: {str(e)}"
227
- print(f"Assistant node error: {error_msg}") # Log error for debugging
228
- return {
229
- "messages": AIMessage(
230
- content="FINAL ANSWER: Error occurred while processing request. Please try again."
231
- )
232
- }
233
-
234
- # Stubbed retriever node for future integration
235
- def retriever_node(state: MessagesState):
236
- """Retriever node"""
237
- # Example: use vector_store.similarity_search() in real use
238
- similar_question = [
239
- AIMessage(content="This is a mock similar document from the retriever.")
240
- ]
241
-
242
- if similar_question:
243
- example_msg = HumanMessage(
244
- content=f"Here I provide a similar question and answer for reference: {similar_question[0].content}",
245
- )
246
- return {"messages": [system_prompt] + state["messages"] + [example_msg]}
247
- else:
248
- return {"messages": [system_prompt] + state["messages"]}
249
 
250
- # Wrap tools with validation
251
- def wrap_tool_with_validation(tool):
252
- original_func = tool.__call__
253
-
254
- def validated_call(*args, **kwargs):
255
- response = original_func(*args, **kwargs)
256
-
257
- try:
258
- if not isinstance(response, dict):
259
- raise ValueError(
260
- f"Tool response must be a dict, got {type(response)}"
261
- )
262
-
263
- # Check for common response keys
264
- for key in ["web_results", "wiki_results", "transcript_results"]:
265
- if key in response:
266
- if not isinstance(response[key], str):
267
- raise ValueError(
268
- f"Tool response[{key}] must be string, got {type(response[key])}"
269
- )
270
- if not response[key].strip():
271
- raise ValueError(f"Tool response[{key}] is empty")
272
-
273
- return response
274
- except Exception as e:
275
- return {"error": f"Tool response validation failed: {str(e)}"}
276
-
277
- tool.__call__ = validated_call
278
- return tool
279
-
280
- # Apply validation wrapper to each tool
281
- validated_tools = [wrap_tool_with_validation(tool) for tool in tools]
282
- tool_node = ToolNode(validated_tools)
283
-
284
- # Define error handling node
285
- def error_handler_node(state: MessagesState) -> dict:
286
- """Handle errors in the graph execution"""
287
- error_msg = state.get("error", "Unknown error occurred")
288
- return {
289
- "messages": AIMessage(content=f"FINAL ANSWER: Error occurred: {error_msg}")
290
- }
291
-
292
- # Define the graph with ReAct loop and error handling
293
  builder = StateGraph(MessagesState)
294
- builder.add_node("assistant", RunnableLambda(assistant_node))
295
- builder.add_node("tools", tool_node)
296
- builder.add_node("retriever", RunnableLambda(retriever_node))
297
- builder.add_node("error_handler", RunnableLambda(error_handler_node))
298
 
299
  builder.set_entry_point("assistant")
300
  builder.add_conditional_edges("assistant", tools_condition)
301
  builder.add_edge("tools", "assistant")
302
  builder.add_edge("assistant", END)
303
 
304
- # Add error handling edges
305
- def route_by_error(state: MessagesState):
306
- """Route to error handler if error is present, otherwise continue normal flow"""
307
- if "error" in state:
308
- return "error_handler"
309
- return None
310
-
311
- builder.add_conditional_edges(
312
- "assistant",
313
- route_by_error,
314
- {
315
- "error_handler": "error_handler",
316
- },
317
- )
318
-
319
- builder.add_conditional_edges(
320
- "tools",
321
- route_by_error,
322
- {
323
- "error_handler": "error_handler",
324
- },
325
- )
326
-
327
- builder.add_edge("error_handler", END)
328
-
329
- graph = builder.compile()
330
-
331
- # Optional: test entrypoint to run the graph manually
332
- test_input = {
333
- "messages": [
334
- system_prompt,
335
- HumanMessage(content="What is the capital of France?"),
336
- ]
337
- }
338
-
339
- # result = graph.invoke(test_input)
340
- # print("\nFinal output:", result["messages"][-1].content)
341
-
342
- return graph
 
 
1
  import os
2
+ from typing import Dict
3
 
4
+ from langchain_community.document_loaders import WikipediaLoader
 
 
5
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 
6
  from langchain_core.tools import tool
 
7
  from langchain_groq import ChatGroq
 
 
 
 
 
 
 
8
  from langchain_tavily import TavilySearch
9
  from langgraph.graph import END, START, MessagesState, StateGraph
 
10
  from langgraph.prebuilt import ToolNode, tools_condition
11
 
12
 
13
+ @tool
14
+ def search_web(query: str) -> Dict[str, str]:
15
+ """Search the web using Tavily and return relevant results."""
 
 
 
 
 
 
 
16
  try:
17
+ if not os.getenv("TAVILY_API_KEY"):
18
+ return {
19
+ "error": "Tavily API key not found. Please set TAVILY_API_KEY environment variable."
20
+ }
21
 
22
  search_docs = TavilySearch(max_results=3).invoke({"query": query})
23
  if not search_docs:
24
+ return {"error": "No results found"}
25
+ formatted_docs = "\n\n---\n\n".join(
 
26
  [
27
+ f'Source: {doc.metadata["source"]}\n\n{doc.page_content}'
28
  for doc in search_docs
29
  ]
30
  )
31
+ return {"web_results": formatted_docs}
32
  except Exception as e:
33
+ return {"error": f"Error searching web: {str(e)}"}
34
 
35
 
36
  @tool
37
  def search_wikipedia(query: str) -> Dict[str, str]:
38
  """Search Wikipedia using LangChain's loader and return the first document summary."""
39
  try:
 
 
 
 
 
 
40
  loader = WikipediaLoader(query=query, lang="en", load_max_docs=2)
41
  docs = loader.load()
 
42
  if not docs:
43
+ return {"error": f"No Wikipedia articles found for query: {query}"}
44
+ formatted_docs = "\n\n---\n\n".join(
45
+ [f"Wikipedia Article: {query}\n\n{doc.page_content}" for doc in docs]
 
 
 
 
46
  )
47
  return {"wiki_results": formatted_docs}
48
  except Exception as e:
49
+ return {"error": f"Error searching Wikipedia: {str(e)}"}
 
 
 
50
 
51
 
52
+ # System prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  system_prompt = SystemMessage(
54
+ content="""You are a helpful and precise assistant. When answering questions:
 
55
 
56
+ 1. First, understand what information you need to answer the question
57
+ 2. Then, use the available tools to gather information
58
+ 3. If a tool returns an error or no results, try another tool or rephrase your query
59
+ 4. Analyze all the information and formulate a clear, concise answer
 
 
 
 
 
 
 
60
 
61
+ When using tools, follow this format exactly:
62
+ Action: tool_name
63
+ Action Input: {"parameter": "value"}
 
 
 
64
 
65
+ Available tools:
66
+ - search_wikipedia: Search Wikipedia articles
67
+ Input: {"query": "your search term"}
68
+ Returns: {"wiki_results": "results"} or {"error": "error message"}
69
+ Best for: Historical facts, definitions, general knowledge
70
+ Error handling: If no results found, try rephrasing or use web search
71
+
72
+ - search_web: Search the web for information
73
+ Input: {"query": "your search term"}
74
+ Returns: {"web_results": "results"} or {"error": "error message"}
75
+ Best for: Recent events, current information, diverse sources
76
+ Error handling: If no results found, try more specific search terms
77
+
78
+ Tool usage strategy:
79
+ 1. For historical/factual queries:
80
+ - Start with Wikipedia
81
+ - If no results, try rephrasing the query
82
+ - If still no results, switch to web search
83
+
84
+ 2. For recent events/current info:
85
+ - Start with web search
86
+ - If no results, try more specific terms
87
+ - Cross-reference with Wikipedia if needed
88
+
89
+ 3. For complex queries:
90
+ - Use both tools to gather comprehensive info
91
+ - Compare and verify information
92
+ - Note any discrepancies in your answer
93
+
94
+ 4. When both tools fail:
95
+ - Try different phrasings
96
+ - Break complex queries into simpler parts
97
+ - Be transparent about limitations in your answer
98
+
99
+ Your final answer must:
100
+ 1. Begin with "FINAL ANSWER:"
101
+ 2. Be clear and concise
102
+ 3. Directly answer the question asked
103
+ 4. Include sources if relevant
104
+ 5. Admit uncertainty when information is unclear"""
105
  )
106
 
107
 
108
  def build_agent_graph(provider: str = "groq"):
109
+ """Build the graph"""
110
+ # Initialize tools
111
+ tools = [search_wikipedia, search_web]
112
 
113
+ # Initialize LLM with error handling
 
 
 
 
 
 
 
 
 
 
 
 
114
  try:
115
  from pydantic import SecretStr
116
 
117
+ groq_api_key = os.getenv("GROQ_API_KEY")
118
+ if not groq_api_key:
119
+ raise EnvironmentError("GROQ_API_KEY environment variable is not set")
120
+
121
  llm = ChatGroq(
122
+ model="qwen-qwq-32b",
123
+ temperature=0,
124
+ api_key=SecretStr(groq_api_key),
125
  )
126
  except Exception as e:
127
+ raise Exception(f"Failed to initialize LLM: {str(e)}")
 
 
128
  llm_with_tools = llm.bind_tools(tools)
129
 
130
+ # Create nodes
131
+ def assistant(state: MessagesState):
132
+ """Assistant node"""
133
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
+ # Build graph
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  builder = StateGraph(MessagesState)
137
+ builder.add_node("assistant", assistant)
138
+ builder.add_node("tools", ToolNode(tools))
 
 
139
 
140
  builder.set_entry_point("assistant")
141
  builder.add_conditional_edges("assistant", tools_condition)
142
  builder.add_edge("tools", "assistant")
143
  builder.add_edge("assistant", END)
144
 
145
+ return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -32,11 +32,16 @@ class BasicAgent:
32
  human_msg = HumanMessage(content=question)
33
  msgs: List[AnyMessage] = [system_prompt, human_msg]
34
 
35
- # Create and cast the state
36
- input_state = cast(MessagesState, {"messages": msgs})
 
 
 
37
 
38
  # Invoke the graph with proper error handling
39
  try:
 
 
40
  result = self.graph.invoke(input_state)
41
  except Exception as e:
42
  print(f"Graph invocation error: {str(e)}")
 
32
  human_msg = HumanMessage(content=question)
33
  msgs: List[AnyMessage] = [system_prompt, human_msg]
34
 
35
+ # Create state dict that matches MessagesState structure
36
+ input_state = {"messages": msgs}
37
+
38
+ # Cast to MessagesState type
39
+ input_state = cast(MessagesState, input_state)
40
 
41
  # Invoke the graph with proper error handling
42
  try:
43
+ if not self.graph:
44
+ raise ValueError("Agent graph not initialized")
45
  result = self.graph.invoke(input_state)
46
  except Exception as e:
47
  print(f"Graph invocation error: {str(e)}")