prozorov committed
Commit d49dc28 · verified · 1 Parent(s): 81917a3

Create agent.py

Files changed (1)
  1. agent.py +114 -0
agent.py ADDED
@@ -0,0 +1,114 @@
+ import os
+ from dotenv import load_dotenv
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition
+ from langgraph.prebuilt import ToolNode
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_groq import ChatGroq
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader
+ from langchain_community.document_loaders import ArxivLoader
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_community.retrievers import WikipediaRetriever
+ from langchain_core.tools import tool
+ from supabase.client import Client, create_client
+
+ load_dotenv()
+
+ @tool
+ def wiki_search(query: str) -> dict:
+     """Search Wikipedia for a query and return up to 2 results.
+
+     Args:
+         query: The search query."""
+     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ])
+     return {"wiki_results": formatted_search_docs}
+
+ @tool
+ def web_search(query: str) -> dict:
+     """Search Tavily for a query and return up to 3 results.
+
+     Args:
+         query: The search query."""
+     # TavilySearchResults returns a list of dicts with "url" and "content" keys
+     search_docs = TavilySearchResults(max_results=3).invoke(query)
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc["url"]}"/>\n{doc["content"]}\n</Document>'
+             for doc in search_docs
+         ])
+     return {"web_results": formatted_search_docs}
+
+ @tool
+ def arxiv_search(query: str) -> dict:
+     """Search Arxiv for a query and return up to 3 results.
+
+     Args:
+         query: The search query."""
+     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             # ArxivLoader metadata may not include "source"; fall back to the paper title
+             f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}"/>\n{doc.page_content[:1000]}\n</Document>'
+             for doc in search_docs
+         ])
+     return {"arxiv_results": formatted_search_docs}
+
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
+     system_prompt = f.read()
+
+ # System message
+ sys_msg = SystemMessage(content=system_prompt)
+
+ # Build a retriever over the Supabase vector store
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
+ supabase: Client = create_client(
+     os.environ.get("SUPABASE_URL"),
+     os.environ.get("SUPABASE_SERVICE_KEY"))
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="documents",
+     query_name="match_documents_langchain",
+ )
+
+ tools = [
+     wiki_search,
+     web_search,
+     arxiv_search,
+ ]
+
+ def build_graph():
+     llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
+     llm_with_tools = llm.bind_tools(tools)
+
+     def assistant(state: MessagesState):
+         """Assistant node"""
+         return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+     def retriever(state: MessagesState):
+         """Retriever node: prepend the system prompt and a similar stored question"""
+         similar_question = vector_store.similarity_search(state["messages"][0].content)
+         example_msg = HumanMessage(
+             content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
+         )
+         return {"messages": [sys_msg] + state["messages"] + [example_msg]}
+
+     builder = StateGraph(MessagesState)
+     builder.add_node("retriever", retriever)
+     builder.add_node("assistant", assistant)
+     builder.add_node("tools", ToolNode(tools))
+     builder.add_edge(START, "retriever")
+     builder.add_edge("retriever", "assistant")
+     # Route to the tool node when the LLM requests a tool call, otherwise end
+     builder.add_conditional_edges(
+         "assistant",
+         tools_condition,
+     )
+     builder.add_edge("tools", "assistant")
+
+     return builder.compile()
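
A minimal usage sketch for the compiled graph, assuming GROQ_API_KEY, TAVILY_API_KEY, SUPABASE_URL, and SUPABASE_SERVICE_KEY are set in the environment, system_prompt.txt sits next to agent.py, and the question string is only a placeholder:

from langchain_core.messages import HumanMessage
from agent import build_graph

graph = build_graph()
# MessagesState takes a list of messages; the retriever node prepends the
# system prompt and a similar stored question before the assistant runs.
result = graph.invoke({"messages": [HumanMessage(content="What is the capital of France?")]})
print(result["messages"][-1].content)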