grshot committed on
Commit
b1dbad7
·
1 Parent(s): 7528f0e

Add Retriever

Browse files
Files changed (2) hide show
  1. agent.py +197 -121
  2. requirements.txt +3 -1
agent.py CHANGED
@@ -1,16 +1,28 @@
 
1
  import os
2
  from typing import Dict, List, Sequence, TypedDict, cast
3
 
4
  from dotenv import load_dotenv
 
5
  from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
 
6
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
7
  from langchain_core.tools import tool
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 
 
 
 
 
10
  from langchain_tavily import TavilySearch
11
  from langgraph.graph import END, START, MessagesState, StateGraph
12
  from langgraph.prebuilt import ToolNode, tools_condition
13
  from pydantic import BaseModel
 
 
 
 
14
 
15
 
16
  class WebSearchInput(BaseModel):
@@ -25,31 +37,26 @@ class ArxivSearchInput(BaseModel):
25
  query: str
26
 
27
 
28
- @tool(args_schema=WebSearchInput)
29
- def search_web(query: str) -> Dict[str, str]:
30
  """Search the web using Tavily and return relevant results."""
31
- try:
32
- if not os.getenv("TAVILY_API_KEY"):
33
- return {
34
- "error": "Tavily API key not found. Please set TAVILY_API_KEY environment variable."
35
- }
36
 
37
- search_docs = TavilySearch(max_results=3).invoke({"query": query})
38
- if not search_docs:
39
- return {"error": "No results found"}
40
- formatted_docs = "\n\n---\n\n".join(
41
- [
42
- f'Source: {doc.metadata["source"]}\n\n{doc.page_content}'
43
- for doc in search_docs
44
- ]
45
- )
46
- return {"web_results": formatted_docs}
47
- except Exception as e:
48
- return {"error": f"Error searching web: {str(e)}"}
49
 
50
 
51
- @tool(args_schema=WikipediaSearchInput)
52
- def search_wikipedia(query: str) -> Dict[str, str]:
53
  """Search Wikipedia using LangChain's loader and return the first document summary."""
54
  try:
55
  loader = WikipediaLoader(query=query, lang="en", load_max_docs=2)
@@ -64,8 +71,8 @@ def search_wikipedia(query: str) -> Dict[str, str]:
64
  return {"error": f"Error searching Wikipedia: {str(e)}"}
65
 
66
 
67
- @tool(args_schema=ArxivSearchInput)
68
- def arxiv_search(query: str) -> Dict[str, str]:
69
  """Search Arxiv for a query and return maximum 3 result.
70
  Args:
71
  query: The search query."""
@@ -79,84 +86,140 @@ def arxiv_search(query: str) -> Dict[str, str]:
79
  return {"arxiv_results": formatted_search_docs}
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # System prompt
83
  system_prompt = SystemMessage(
84
- content="""You are a helpful and precise assistant. When answering questions:
85
-
86
- 1. First, understand what information you need to answer the question
87
- 2. Then, use the available tools to gather information
88
- 3. If a tool returns an error or no results, try another tool or rephrase your query
89
- 4. Analyze all the information and formulate a clear, concise answer
90
-
91
- When using tools, follow this format exactly:
92
- Action: tool_name
93
- Action Input: {"parameter": "value"}
94
-
95
- Available tools:
96
- - search_wikipedia: Search Wikipedia articles
97
- Input: {"query": "your search term"}
98
- Returns: {"wiki_results": "results"} or {"error": "error message"}
99
- Best for: Historical facts, definitions, general knowledge
100
- Error handling: If no results found, try rephrasing or use web search
101
-
102
- - search_web: Search the web for information
103
- Input: {"query": "your search term"}
104
- Returns: {"web_results": "results"} or {"error": "error message"}
105
- Best for: Recent events, current information, diverse sources
106
- Error handling: If no results found, try more specific search terms
107
-
108
- - arxiv_search: Search scholarly papers on arXiv
109
- Input: {"query": "topic or keywords"}
110
- Returns: {"arxiv_results": "paper summaries with title, authors, abstract"} or {"error": "error message"}
111
- Best for: Academic research, recent papers in science and technology
112
- Error handling: If no results, simplify keywords or broaden the topic
113
-
114
- Tool usage strategy:
115
- 1. For historical/factual queries:
116
- - Start with Wikipedia
117
- - If no results, try rephrasing the query
118
- - If still no results, switch to web search
119
-
120
- 2. For recent events/current info:
121
- - Start with web search
122
- - If no results, try more specific terms
123
- - Cross-reference with Wikipedia if needed
124
-
125
- 3. For academic/scientific questions:
126
- - Use arxiv_search to find recent papers
127
- - Summarize key findings, topics, or citations
128
- - Cross-check with web or Wikipedia if needed
129
-
130
- 4. For complex queries:
131
- - Use all tools to gather comprehensive info
132
- - Compare and verify information
133
- - Note any discrepancies in your answer
134
-
135
- 5. Whenall tools fail:
136
- - Try different phrasings
137
- - Break complex queries into simpler parts
138
- - Be transparent about limitations in your answer
139
-
140
- Your final answer must:
141
- 1. Begin with "FINAL ANSWER:"
142
- 2. Be clear and concise
143
- 3. Directly answer the question asked
144
- 4. Include sources if relevant
145
- 5. Admit uncertainty when information is unclear"""
146
  )
147
 
148
- # Initialize tools
149
- tools = [search_wikipedia, search_web, arxiv_search]
150
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- def build_agent_graph(provider: str = "gemini"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  """Build the graph"""
154
 
155
  # Initialize LLM class
156
  try:
157
  gemini_api_key = os.getenv("GEMINI_API_KEY")
158
-
159
- if provider == "gemini":
 
 
 
 
160
  chat_model = ChatGoogleGenerativeAI(
161
  model="gemini-2.5-pro",
162
  temperature=1.0,
@@ -165,11 +228,7 @@ def build_agent_graph(provider: str = "gemini"):
165
  )
166
  elif provider == "huggingface":
167
  llm = HuggingFaceEndpoint(
168
- repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
169
- task="text-generation",
170
- max_new_tokens=1024,
171
- do_sample=False,
172
- repetition_penalty=1.03,
173
  temperature=0,
174
  )
175
  chat_model = ChatHuggingFace(llm=llm, verbose=True)
@@ -181,32 +240,47 @@ def build_agent_graph(provider: str = "gemini"):
181
  llm_with_tools = chat_model.bind_tools(tools)
182
 
183
  # Create nodes
184
- def assistant(state: MessagesState) -> Dict[str, List[AIMessage]]:
185
  """Assistant node"""
186
- try:
187
- # Get last message
188
- messages = state.get("messages", [])
189
- if not messages:
190
- return {"messages": [AIMessage(content="Error: No messages found")]}
191
 
192
- # Run LLM and ensure AIMessage response
193
- response = llm_with_tools.invoke(messages)
194
- if isinstance(response, AIMessage):
195
- return {"messages": [response]}
196
- return {"messages": [AIMessage(content=str(response))]}
 
 
 
 
 
 
197
 
198
- except Exception as e:
199
- return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
 
 
 
 
 
 
 
200
 
201
  # Build graph
202
  builder = StateGraph(MessagesState)
203
- builder.add_node("assistant", assistant)
204
- builder.add_node("tools", ToolNode(tools))
205
-
206
- builder.set_entry_point("assistant")
207
- builder.add_conditional_edges("assistant", tools_condition)
208
- builder.add_edge("tools", "assistant")
209
- builder.add_edge("assistant", END)
 
 
 
 
 
 
210
 
211
  return builder.compile()
212
 
@@ -214,8 +288,6 @@ def build_agent_graph(provider: str = "gemini"):
214
  # Manual test function
215
  def test_agent():
216
  """Run a manual test of the agent"""
217
- # Load environment variables from .env file
218
- load_dotenv()
219
  print("\n" + "=" * 50)
220
  print("Starting Agent Test")
221
  print("=" * 50)
@@ -230,16 +302,19 @@ def test_agent():
230
  if not os.getenv("TAVILY_API_KEY"):
231
  print("\nWarning: TAVILY_API_KEY not set - web search will be unavailable")
232
 
 
 
 
233
  print("\nInitializing agent...")
234
  try:
235
- graph = build_agent_graph(provider="gemini")
236
  print("Agent initialized successfully")
237
  except Exception as e:
238
  print(f"Failed to initialize agent: {str(e)}")
239
  return
240
 
241
  # Test a single question
242
- question = "What is the surname of the equine veterinarian mentioned in 1.E Exercises from the chemistry materials licensed by Marisa Alviar-Agnew & Henry Agnew under the CK-12 license in LibreText's Introductory Chemistry materials as compiled 08/21/2023?"
243
  print("\nTesting question:", question)
244
  print("-" * 50)
245
 
@@ -253,6 +328,7 @@ def test_agent():
253
 
254
  # Get answer
255
  if result and "messages" in result and result["messages"]:
 
256
  answer = result["messages"][-1].content
257
  print("\nResponse received:")
258
  print("-" * 20)
 
1
+ import cmath
2
  import os
3
  from typing import Dict, List, Sequence, TypedDict, cast
4
 
5
  from dotenv import load_dotenv
6
+ from langchain.tools.retriever import create_retriever_tool
7
  from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
8
+ from langchain_community.vectorstores import SupabaseVectorStore
9
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
10
  from langchain_core.tools import tool
11
  from langchain_google_genai import ChatGoogleGenerativeAI
12
+ from langchain_groq import ChatGroq
13
+ from langchain_huggingface import (
14
+ ChatHuggingFace,
15
+ HuggingFaceEmbeddings,
16
+ HuggingFaceEndpoint,
17
+ )
18
  from langchain_tavily import TavilySearch
19
  from langgraph.graph import END, START, MessagesState, StateGraph
20
  from langgraph.prebuilt import ToolNode, tools_condition
21
  from pydantic import BaseModel
22
+ from supabase.client import Client, create_client
23
+
24
+ # Load environment variables from .env file
25
+ load_dotenv()
26
 
27
 
28
  class WebSearchInput(BaseModel):
 
37
  query: str
38
 
39
 
40
+ @tool
41
def search_web(query: str) -> Dict[str, str]:
    """Search the web using Tavily and return relevant results.

    At most three hits are requested from Tavily.

    Args:
        query: The search query.

    Returns:
        A dict with a single ``"web_results"`` key holding the formatted
        hits, separated by ``---`` lines (an empty string when nothing
        was found).
    """
    response = TavilySearch(max_results=3).invoke({"query": query})

    # TavilySearch answers with a response dict whose hits are plain dicts
    # under "results" (keys such as "url" and "content"), not Document
    # objects; a bare list of hit dicts is tolerated as well.
    if isinstance(response, dict):
        hits = response.get("results") or []
    elif isinstance(response, list):
        hits = response
    else:
        hits = []

    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{hit.get("url", "")}" page=""/>\n{hit.get("content", "")}\n</Document>'
        for hit in hits
        if isinstance(hit, dict)
    )
    return {"web_results": formatted_search_docs}
56
 
57
 
58
+ @tool
59
+ def search_wikipedia(query: str) -> str:
60
  """Search Wikipedia using LangChain's loader and return the first document summary."""
61
  try:
62
  loader = WikipediaLoader(query=query, lang="en", load_max_docs=2)
 
71
  return {"error": f"Error searching Wikipedia: {str(e)}"}
72
 
73
 
74
+ @tool
75
+ def arxiv_search(query: str) -> str:
76
  """Search Arxiv for a query and return maximum 3 result.
77
  Args:
78
  query: The search query."""
 
86
  return {"arxiv_results": formatted_search_docs}
87
 
88
 
89
+ @tool
90
def power(a: float, b: float) -> float:
    """
    Raise a number to a power.
    Args:
        a (float): the base
        b (float): the exponent
    """
    # Two-argument pow() is the function form of the ** operator.
    return pow(a, b)
98
+
99
+
100
+ @tool
101
+ def square_root(a: float) -> float | complex:
102
+ """
103
+ Get the square root of a number.
104
+ Args:
105
+ a (float): the number to get the square root of
106
+ """
107
+ if a >= 0:
108
+ return a**0.5
109
+ return cmath.sqrt(a)
110
+
111
+
112
+ @tool
113
def multiply(a: int, b: int) -> int:
    """Return the product of two numbers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
120
+
121
+
122
+ @tool
123
def add(a: int, b: int) -> int:
    """Return the sum of two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
131
+
132
+
133
+ @tool
134
def subtract(a: int, b: int) -> int:
    """Return the first number minus the second.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
142
+
143
+
144
+ @tool
145
def divide(a: float, b: float) -> float:
    """
    Divide one number by another.
    Args:
        a (float): the dividend
        b (float): the divisor
    Returns:
        The quotient a / b.
    Raises:
        ValueError: if b is zero.
    """
    # Reject a zero divisor up front so the caller gets a readable message
    # instead of a bare ZeroDivisionError.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
155
+
156
+
157
+ @tool
158
def modulus(a: int, b: int) -> int:
    """Return the remainder of dividing the first number by the second.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
166
+
167
+
168
  # System prompt
169
  system_prompt = SystemMessage(
170
+ content="""You are a helpful assistant tasked with answering questions using a set of tools.
171
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
172
+ FINAL ANSWER: [YOUR FINAL ANSWER].
173
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, Apply the rules above for each element (number or string), ensure there is exactly one space after each comma.
174
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer. """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  )
176
 
177
# Supabase connection settings, read from the environment.
supabase_url = os.environ.get("SUPABASE_URL")
supabase_service_key = os.environ.get("SUPABASE_SERVICE_KEY")

# Build the vector store that backs the retriever.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)  # dim=768
supabase: Client = create_client(supabase_url, supabase_service_key)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)

# Give the tool its own name: assigning the result back to
# `create_retriever_tool` would replace the imported factory function with a
# tool instance, so the factory could never be called again.
# The tool name avoids spaces because chat-model function-calling APIs
# typically restrict tool names to letters, digits, underscores and hyphens.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="question_search",
    description="A tool to retrieve similar questions from a vector store.",
)

# Initialize tools
tools = [
    search_wikipedia,
    search_web,
    arxiv_search,
    power,
    square_root,
    multiply,
    divide,
    subtract,
    add,
    modulus,
]
209
+
210
+
211
+ def build_agent_graph(provider: str = "groq"):
212
  """Build the graph"""
213
 
214
  # Initialize LLM class
215
  try:
216
  gemini_api_key = os.getenv("GEMINI_API_KEY")
217
+ if provider == "groq":
218
+ # Groq https://console.groq.com/docs/models
219
+ chat_model = ChatGroq(
220
+ model="qwen-qwq-32b", temperature=0
221
+ ) # optional : qwen-qwq-32b gemma2-9b-it
222
+ elif provider == "gemini":
223
  chat_model = ChatGoogleGenerativeAI(
224
  model="gemini-2.5-pro",
225
  temperature=1.0,
 
228
  )
229
  elif provider == "huggingface":
230
  llm = HuggingFaceEndpoint(
231
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
 
 
 
 
232
  temperature=0,
233
  )
234
  chat_model = ChatHuggingFace(llm=llm, verbose=True)
 
240
  llm_with_tools = chat_model.bind_tools(tools)
241
 
242
  # Create nodes
243
+ def assistant(state: MessagesState):
244
  """Assistant node"""
245
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
246
 
247
+ def retriever(state: MessagesState):
248
+ query = state["messages"][-1].content
249
+ results = vector_store.similarity_search(query, k=1)
250
+
251
+ if not results:
252
+ print(f"[retriever] No similar documents found for query: {query}")
253
+ return {
254
+ "messages": [
255
+ AIMessage(content="I couldn't find any similar content in memory.")
256
+ ]
257
+ }
258
 
259
+ similar_doc = results[0]
260
+ content = similar_doc.page_content
261
+
262
+ if "Final answer :" in content:
263
+ answer = content.split("Final answer :")[-1].strip()
264
+ else:
265
+ answer = content.strip()
266
+
267
+ return {"messages": [AIMessage(content=answer)]}
268
 
269
  # Build graph
270
  builder = StateGraph(MessagesState)
271
+ builder.add_node("retriever", retriever)
272
+ # builder.add_node("assistant", assistant)
273
+ # builder.add_node("tools", ToolNode(tools))
274
+ # builder.add_edge(START, "retriever")
275
+ # builder.add_edge("retriever", "assistant")
276
+ # builder.add_conditional_edges(
277
+ # "assistant",
278
+ # tools_condition,
279
+ # )
280
+ # builder.add_edge("tools", "assistant")
281
+
282
+ builder.set_entry_point("retriever")
283
+ builder.set_finish_point("retriever")
284
 
285
  return builder.compile()
286
 
 
288
  # Manual test function
289
  def test_agent():
290
  """Run a manual test of the agent"""
 
 
291
  print("\n" + "=" * 50)
292
  print("Starting Agent Test")
293
  print("=" * 50)
 
302
  if not os.getenv("TAVILY_API_KEY"):
303
  print("\nWarning: TAVILY_API_KEY not set - web search will be unavailable")
304
 
305
+ if not os.getenv("SUPABASE_URL"):
306
+ print("\nWarning: SUPABASE_URL not set - web search will be unavailable")
307
+
308
  print("\nInitializing agent...")
309
  try:
310
+ graph = build_agent_graph(provider="groq")
311
  print("Agent initialized successfully")
312
  except Exception as e:
313
  print(f"Failed to initialize agent: {str(e)}")
314
  return
315
 
316
  # Test a single question
317
+ question = "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""
318
  print("\nTesting question:", question)
319
  print("-" * 50)
320
 
 
328
 
329
  # Get answer
330
  if result and "messages" in result and result["messages"]:
331
+
332
  answer = result["messages"][-1].content
333
  print("\nResponse received:")
334
  print("-" * 20)
requirements.txt CHANGED
@@ -14,4 +14,6 @@ pytube>=15.0.0
14
  langchain_huggingface
15
  langchain-google-genai
16
  pymupdf
17
- arxiv
 
 
 
14
  langchain_huggingface
15
  langchain-google-genai
16
  pymupdf
17
+ arxiv
18
+ supabase
19
+ pgvector