SarayuSree commited on
Commit
ded5270
·
verified ·
1 Parent(s): 81917a3

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +140 -0
agent.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # smol_agent.py
2
+
3
+ import os
4
+ from dotenv import load_dotenv
5
+ from smolagents import CodeAgent, tool
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_huggingface import HuggingFaceEmbeddings
8
+ from supabase.client import create_client
9
+ from langchain_community.vectorstores import SupabaseVectorStore
10
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
11
+ from langchain_community.tools.tavily_search import TavilySearchResults
12
+
13
+ load_dotenv()
14
+
15
+ # ============== Arithmetic Tools ==============
16
+
17
+ @tool
18
+ def multiply(a: int, b: int) -> int:
19
+ """Multiply two numbers."""
20
+ return a * b
21
+
22
+ @tool
23
+ def add(a: int, b: int) -> int:
24
+ """Add two numbers."""
25
+ return a + b
26
+
27
+ @tool
28
+ def subtract(a: int, b: int) -> int:
29
+ """Subtract two numbers."""
30
+ return a - b
31
+
32
+ @tool
33
+ def divide(a: int, b: int) -> float:
34
+ """Divide two numbers."""
35
+ if b == 0:
36
+ raise ValueError("Cannot divide by zero.")
37
+ return a / b
38
+
39
+ @tool
40
+ def modulus(a: int, b: int) -> int:
41
+ """Get modulus of two numbers."""
42
+ return a % b
43
+
44
+ # ============== Wikipedia Search Tool ==============
45
+
46
+ @tool
47
+ def wiki_search(query: str) -> str:
48
+ """Search Wikipedia for a query and return max 2 results."""
49
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
50
+ formatted_search_docs = "\n\n---\n\n".join(
51
+ [
52
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
53
+ for doc in search_docs
54
+ ])
55
+ return formatted_search_docs
56
+
57
+ # ============== Tavily Search Tool ==============
58
+
59
+ @tool
60
+ def web_search(query: str) -> str:
61
+ """Search Tavily for a query and return max 3 results."""
62
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
63
+ formatted_search_docs = "\n\n---\n\n".join(
64
+ [
65
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
66
+ for doc in search_docs
67
+ ])
68
+ return formatted_search_docs
69
+
70
+ # ============== Arxiv Search Tool ==============
71
+
72
+ @tool
73
+ def arxiv_search(query: str) -> str:
74
+ """Search Arxiv for a query and return max 3 results."""
75
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
76
+ formatted_search_docs = "\n\n---\n\n".join(
77
+ [
78
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
79
+ for doc in search_docs
80
+ ])
81
+ return formatted_search_docs
82
+
83
+ # ============== Retriever Tool ==============
84
+
85
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
86
+ supabase = create_client(
87
+ os.environ.get("SUPABASE_URL"),
88
+ os.environ.get("SUPABASE_SERVICE_KEY")
89
+ )
90
+ vector_store = SupabaseVectorStore(
91
+ client=supabase,
92
+ embedding=embeddings,
93
+ table_name="documents",
94
+ query_name="match_documents_langchain",
95
+ )
96
+
97
+ @tool
98
+ def retrieve_similar_question(query: str) -> str:
99
+ """Retrieve similar question from vector store."""
100
+ similar_doc = vector_store.similarity_search(query, k=1)[0]
101
+ content = similar_doc.page_content
102
+ if "Final answer :" in content:
103
+ answer = content.split("Final answer :")[-1].strip()
104
+ else:
105
+ answer = content.strip()
106
+ return answer
107
+
108
+ # ============== Load System Prompt ==============
109
+
110
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
111
+ system_prompt = f.read()
112
+
113
+ # ============== Initialize LLM ==============
114
+
115
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
116
+
117
+ # ============== Build SmolAgent ==============
118
+
119
+ agent = CodeAgent(
120
+ llm=llm,
121
+ tools=[
122
+ multiply,
123
+ add,
124
+ subtract,
125
+ divide,
126
+ modulus,
127
+ wiki_search,
128
+ web_search,
129
+ arxiv_search,
130
+ retrieve_similar_question
131
+ ],
132
+ system_prompt=system_prompt
133
+ )
134
+
135
+ # ============== Example Run ==============
136
+
137
+ if __name__ == "__main__":
138
+ user_query = "Find recent arxiv papers on diffusion models."
139
+ response = agent.run(user_query)
140
+ print(response)