# Hosting-page header captured together with the file (not part of the program):
#   Sagar Sanghani
#   start building csv cache
#   8c269dc
#   raw
#   history blame
#   4.81 kB
from dotenv import load_dotenv, find_dotenv
import os
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_tavily import TavilySearch
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain.tools import tool
from langchain.prompts import ChatPromptTemplate
from langchain.agents import AgentExecutor, create_tool_calling_agent
from prompt import get_prompt
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
import re
# --- Define Tools ---
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    # Docstring left verbatim: the @tool decorator is expected to expose it
    # as the tool description, so it is effectively runtime text.
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    # Docstring left verbatim: the @tool decorator is expected to expose it
    # as the tool description, so it is effectively runtime text.
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    # Docstring left verbatim: the @tool decorator is expected to expose it
    # as the tool description, so it is effectively runtime text.
    difference = a - b
    return difference
@tool
def divide(a: int, b: int) -> float:
    """Divide a by b, error on zero."""
    # Docstring left verbatim: the @tool decorator is expected to expose it
    # as the tool description, so it is effectively runtime text.
    if b != 0:
        return a / b
    # A zero divisor is reported as ValueError with a readable message
    # instead of letting the division raise on its own.
    raise ValueError("Cannot divide by zero.")
@tool
def modulus(a: int, b: int) -> int:
    """Compute a mod b."""
    # Docstring left verbatim: the @tool decorator is expected to expose it
    # as the tool description, so it is effectively runtime text.
    # Note: unlike divide, a zero divisor is not guarded here.
    remainder = a % b
    return remainder
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia and return up to 2 documents."""
    # Docstring left verbatim: the @tool decorator is expected to expose it
    # as the tool description, so it is effectively runtime text.
    loader = WikipediaLoader(query=query, load_max_docs=2)

    # Render every loaded document as a pseudo-XML header followed by its text.
    rendered = []
    for doc in loader.load():
        source = doc.metadata['source']
        page = doc.metadata.get('page', '')
        rendered.append(
            f"<Document source=\"{source}\" page=\"{page}\"/>\n{doc.page_content}"
        )

    # Documents are separated by a '---' line inside a single string value.
    return {"wiki_results": "\n---\n".join(rendered)}
@tool
def web_search(query: str) -> dict:
    """Do a web search with Tavily and return up to 4 results."""
    # Docstring left verbatim: the @tool decorator is expected to expose it
    # as the tool description, so it is effectively runtime text.
    #
    # Use TavilySearch, which this module already imports; the previously
    # referenced TavilySearchResults name was never imported and raised
    # NameError on the first call. The tool input is passed positionally as a
    # dict because Runnable.invoke has no `query` keyword.
    response = TavilySearch(max_results=4).invoke({"query": query})

    # Tavily answers with a dict whose "results" entry is a list of plain
    # dicts (url / title / content / ...), not Document objects. Anything else
    # (e.g. an error payload) is treated as "no results".
    hits = response.get("results", []) if isinstance(response, dict) else []

    results = []
    for hit in hits:
        source = hit.get("url", "")
        content = hit.get("content", "")
        # Same header layout as the other search tools; web hits have no page.
        results.append(f"<Document source=\"{source}\" page=\"\"/>\n{content}")

    return {"web_results": "\n---\n".join(results)}
@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv and return up to 3 docs."""
    # Docstring left verbatim: the @tool decorator is expected to expose it
    # as the tool description, so it is effectively runtime text.
    docs = ArxivLoader(query=query, load_max_docs=3).load()

    results = []
    for doc in docs:
        meta = doc.metadata
        # ArxivLoader's default metadata (Published / Title / Authors /
        # Summary) has no "source" key, so indexing it directly raised
        # KeyError. Prefer "source" when present, then the entry id, then the
        # paper title, so the header always identifies the document.
        source = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
        page = meta.get("page", "")
        # Only the first 1000 characters of each paper are returned to keep
        # the tool output small.
        results.append(
            f"<Document source=\"{source}\" page=\"{page}\"/>\n{doc.page_content[:1000]}"
        )

    return {"arxiv_results": "\n---\n".join(results)}
class Model:
    """Tool-calling agent that answers a question and extracts its final answer.

    On construction the Hugging Face token and the system prompt are read and
    an agent executor is built. ``get_answer`` runs the executor and parses
    the ``FINAL_ANSWER:"..."`` marker out of its output.
    """

    # Marker the final answer is parsed from: FINAL_ANSWER:"<value>".
    _FINAL_ANSWER_PATTERN = r'FINAL_ANSWER:"(.*?)"'

    def __init__(self):
        """Read configuration from the environment and build the agent executor."""
        # Environment variables are expected to be loaded by the caller
        # (main() calls load_dotenv before constructing the model).
        self.token = os.getenv("HF_TOKEN")
        self.system_prompt = get_prompt()
        print(f"system_prompt: {self.system_prompt}")
        self.agent_executor = self.setup_model()

    def get_answer(self, question: str) -> str:
        """Run the agent on ``question`` and return the extracted final answer.

        Args:
            question: The question forwarded to the agent as its ``input``.

        Returns:
            The text captured between the quotes of ``FINAL_ANSWER:"..."`` in
            the agent output, or the string ``"ERROR"`` when the agent fails,
            produces no textual output, or the marker is missing.
        """
        try:
            result = self.agent_executor.invoke({"input": question})
        except Exception as e:
            # Catch Exception (not BaseException) so Ctrl-C / SystemExit still
            # propagate, and return right away: the previous fallback dict had
            # no 'output' key and crashed with KeyError a few lines later.
            print(f"An error occurred: {e}")
            return "ERROR"

        # The final answer is expected under the 'output' key of the result.
        final_answer = result.get("output") if isinstance(result, dict) else None
        if not isinstance(final_answer, str):
            print(f"ERROR: Agent returned no textual output: {result!r}")
            return "ERROR"

        # DOTALL lets the quoted answer span multiple lines.
        match = re.search(self._FINAL_ANSWER_PATTERN, final_answer, re.DOTALL)
        if match is None:
            print(f"ERROR: Pattern not found.: {final_answer!r}")
            return "ERROR"

        final_answer_value = match.group(1)
        print(f"The extracted FINAL_ANSWER is: {final_answer_value}")
        return final_answer_value

    def setup_model(self) -> "AgentExecutor":
        """Build the tool-calling agent and wrap it in an ``AgentExecutor``.

        Returns:
            An executor configured with the arithmetic tools, the Wikipedia
            and Arxiv search tools, and a Tavily web-search tool.
        """
        tavily_search_tool = TavilySearch(
            api_key=os.getenv("TAVILY_API_KEY"),
            max_results=5,
            topic="general",
        )

        # Tools the agent is allowed to call.
        tools = [
            multiply,
            add,
            subtract,
            divide,
            modulus,
            wiki_search,
            tavily_search_tool,
            arxiv_search,
        ]

        llm = HuggingFaceEndpoint(
            repo_id="Qwen/Qwen3-Next-80B-A3B-Thinking",
            huggingfacehub_api_token=self.token,
            temperature=0,
        )
        chat = ChatHuggingFace(llm=llm).bind_tools(tools)

        # Prompt for the tool-calling agent. The scratchpad (the agent's own
        # tool calls and their results) must come after the human question so
        # the conversation reads: system -> question -> intermediate steps.
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                ("human", "{input}"),
                ("placeholder", "{agent_scratchpad}"),
            ]
        )

        agent = create_tool_calling_agent(chat, tools, prompt)

        return AgentExecutor(
            agent=agent,
            tools=tools,
            verbose=True,
            handle_parsing_errors=True,
        )
def main():
    """Load the .env file, build the model, and print its answer to a sample question."""
    # Populate the environment first; Model reads its tokens via os.getenv.
    dotenv_path = find_dotenv()
    load_dotenv(dotenv_path)

    question = "Where were the Vietnamese specimens described by Kuznetzov in Nedoshivina's 2010 paper eventually deposited? Just give me the city name without abbreviations."
    response = Model().get_answer(question)
    print(f"the output is: {response}")


# Run the sample query only when executed as a script, not on import.
if __name__ == "__main__":
    main()