|
|
|
from datetime import date |
|
|
|
import nest_asyncio |
|
from llama_index.core.agent.workflow import AgentWorkflow, ReActAgent |
|
from llama_index.core.tools import FunctionTool |
|
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI |
|
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec |
|
|
|
from src.agent_hackathon.consts import PROJECT_ROOT_DIR |
|
|
|
|
|
from src.agent_hackathon.generate_arxiv_responses import ArxivResponseGenerator |
|
from src.agent_hackathon.logger import get_logger |
|
|
|
# Apply the nest_asyncio patch at import time, before any workflow is run.
# NOTE(review): presumably needed so the async workflow can be awaited from an
# environment that already has a running event loop — confirm against callers.
nest_asyncio.apply()

# Module-wide logger; log files are placed under "<project root>/logs".
logger = get_logger(
    log_name="multiagent",
    log_dir=PROJECT_ROOT_DIR / "logs",
)
|
|
|
|
|
class MultiAgentWorkflow:
    """Multi-agent workflow for retrieving research papers and related events.

    Builds two ReAct agents that share one LLM -- an arXiv retrieval agent and
    a DuckDuckGo web-search agent -- and combines them in an ``AgentWorkflow``
    whose root agent is the arXiv agent.
    """

    def __init__(self) -> None:
        """Initialize the workflow with LLM, tools, and generator."""
        logger.info("Initializing MultiAgentWorkflow.")

        # A single LLM instance is shared by both agents below.
        self.llm = HuggingFaceInferenceAPI(
            model="meta-llama/Llama-3.3-70B-Instruct",
            provider="auto",
            temperature=0.1,
            top_p=0.95,
            system_prompt="Don't just plan, but execute the plan until failure.",
        )

        # Backend used by the ``arxiv_rag`` tool to look up papers.
        self._generator = ArxivResponseGenerator(
            vector_store_path=PROJECT_ROOT_DIR / "db/arxiv_docs.db"
        )

        # Expose ``_arxiv_rag`` to the agent as a callable tool.
        self._arxiv_rag_tool = FunctionTool.from_defaults(
            fn=self._arxiv_rag,
            name="arxiv_rag",
            description="Retrieves arxiv research papers.",
            return_direct=False,
        )

        # Keep only the full-search tool out of everything the DuckDuckGo
        # tool spec provides; the result is a list (possibly empty if no tool
        # carries that name).
        self._duckduckgo_search_tool = [
            tool
            for tool in DuckDuckGoSearchToolSpec().to_tool_list()
            if tool.metadata.name == "duckduckgo_full_search"
        ]

        self._arxiv_agent = ReActAgent(
            name="arxiv_agent",
            description="Retrieves information about arxiv research papers",
            system_prompt="You are arxiv research paper agent, who retrieves information "
            "about arxiv research papers.",
            tools=[self._arxiv_rag_tool],
            llm=self.llm,
        )

        self._websearch_agent = ReActAgent(
            name="web_search",
            description="Searches the web",
            system_prompt="You are search engine who searches the web using duckduckgo tool",
            tools=self._duckduckgo_search_tool,
            llm=self.llm,
        )

        # The arXiv agent is the entry point of the workflow.
        # NOTE(review): timeout is presumably expressed in seconds -- confirm.
        self._workflow = AgentWorkflow(
            agents=[self._arxiv_agent, self._websearch_agent],
            root_agent="arxiv_agent",
            timeout=180,
        )

        logger.info("MultiAgentWorkflow initialized.")

    def _arxiv_rag(self, query: str) -> str:
        """Retrieve research papers from arXiv based on the query.

        Thin wrapper around the generator so that the lookup can be registered
        as a function tool.

        Args:
            query (str): The search query.

        Returns:
            str: Retrieved research papers as a string.
        """
        return self._generator.retrieve_arxiv_papers(query=query)

    def _clean_response(self, result: str) -> str:
        """Remove everything up to and including the first ``</think>`` tag.

        Args:
            result (str): The result, possibly containing <think></think>
                content at its start.

        Returns:
            str: The text following the first ``</think>`` tag, or ``result``
            unchanged when the tag does not occur.
        """
        # ``str.find`` returns -1 (truthy) when the tag is absent and 0
        # (falsy) when it is at the very start, so it must not be used as a
        # boolean. ``partition`` reports explicitly whether the tag was found.
        _, closing_tag, remainder = result.partition("</think>")
        return remainder if closing_tag else result

    async def run(self, user_query: str) -> str:
        """Run the multi-agent workflow for a given user query.

        Args:
            user_query (str): The user's search query.

        Returns:
            str: The output of the workflow run, returned as produced by the
            workflow without any conversion.
            NOTE(review): the annotation says ``str`` but no conversion is
            performed here -- confirm what callers expect.

        Raises:
            Exception: Any error raised by the workflow run is logged and
                re-raised unchanged.
        """
        logger.info("Running multi-agent workflow.")

        # The prompt asks for papers first, then for web results restricted
        # to the current calendar year.
        user_msg = (
            f"First, give me arxiv research papers about: {user_query}."
            f"Then search with web search agent for any events related to : {user_query}.\n"
            f"The web search results should be relevant to the current year: {date.today().year}."
            "Return all the content from all the agents."
        )

        try:
            results = await self._workflow.run(user_msg=user_msg)
        except Exception as err:
            # Log for diagnostics, then let the caller decide how to recover.
            logger.error(f"Workflow run failed: {err}")
            raise

        logger.info("Workflow run completed successfully.")
        return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|