File size: 5,540 Bytes
5ea5bac
2091d19
 
5ea5bac
2091d19
5ea5bac
2091d19
 
 
c737a31
 
 
 
2091d19
 
5ea5bac
 
2091d19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ea5bac
2091d19
 
 
 
5ea5bac
 
 
 
 
 
2091d19
 
 
 
 
5ea5bac
 
 
 
 
 
 
 
2091d19
 
 
 
 
 
 
 
 
5ea5bac
 
2091d19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ea5bac
 
 
 
2091d19
5ea5bac
2091d19
5ea5bac
2091d19
 
 
 
 
5ea5bac
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# import asyncio
from datetime import date

import nest_asyncio
from llama_index.core.agent.workflow import AgentWorkflow, ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec

from src.agent_hackathon.consts import PROJECT_ROOT_DIR

# from dotenv import find_dotenv, load_dotenv
from src.agent_hackathon.generate_arxiv_responses import ArxivResponseGenerator
from src.agent_hackathon.logger import get_logger

# Patch asyncio at import time so an already-running event loop can be
# re-entered (nest_asyncio's purpose); this affects the whole process.
nest_asyncio.apply()

# Disabled: loading environment variables from a .env file via python-dotenv.
# _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=False), override=True)

# Module-level logger used by MultiAgentWorkflow below. log_dir presumably
# receives the "multiagent" log files — confirm against get_logger.
logger = get_logger(log_name="multiagent", log_dir=PROJECT_ROOT_DIR / "logs")


class MultiAgentWorkflow:
    """Multi-agent workflow for retrieving research papers and related events.

    Two ReAct agents share one Hugging Face Inference API LLM:

    * ``arxiv_agent`` (the root agent) answers via the ``arxiv_rag`` tool,
      which queries the arXiv vector store at ``db/arxiv_docs.db``.
    * ``web_search`` searches the web with DuckDuckGo's
      ``duckduckgo_full_search`` tool.

    Both agents are wired into a single ``AgentWorkflow`` that starts at
    ``arxiv_agent``.
    """

    def __init__(self) -> None:
        """Initialize the workflow with LLM, tools, generator, and agents."""
        logger.info("Initializing MultiAgentWorkflow.")
        self.llm = HuggingFaceInferenceAPI(
            model="meta-llama/Llama-3.3-70B-Instruct",
            provider="auto",
            # Alternative provider configuration (Nebius), currently disabled:
            # provider="nebius",
            temperature=0.1,
            top_p=0.95,
            # api_key=os.getenv(key="NEBIUS_API_KEY"),
            # base_url="https://api.studio.nebius.com/v1/",
            system_prompt="Don't just plan, but execute the plan until failure.",
        )
        self._generator = ArxivResponseGenerator(
            vector_store_path=PROJECT_ROOT_DIR / "db/arxiv_docs.db"
        )
        # Expose the bound retrieval method to the arxiv agent as a tool.
        self._arxiv_rag_tool = FunctionTool.from_defaults(
            fn=self._arxiv_rag,
            name="arxiv_rag",
            description="Retrieves arxiv research papers.",
            return_direct=False,
        )
        # The DuckDuckGo spec yields several tools; keep only the full search.
        # Kept as a list because it is passed directly as the agent's tools.
        self._duckduckgo_search_tool = [
            tool
            for tool in DuckDuckGoSearchToolSpec().to_tool_list()
            if tool.metadata.name == "duckduckgo_full_search"
        ]
        if not self._duckduckgo_search_tool:
            # Without this tool the web_search agent cannot search at all;
            # surface that now rather than failing silently at query time.
            logger.warning(
                "DuckDuckGo tool 'duckduckgo_full_search' not found; "
                "the web_search agent will have no tools."
            )
        self._arxiv_agent = ReActAgent(
            name="arxiv_agent",
            description="Retrieves information about arxiv research papers",
            system_prompt="You are arxiv research paper agent, who retrieves information "
            "about arxiv research papers.",
            tools=[self._arxiv_rag_tool],
            llm=self.llm,
        )
        self._websearch_agent = ReActAgent(
            name="web_search",
            description="Searches the web",
            system_prompt="You are search engine who searches the web using duckduckgo tool",
            tools=self._duckduckgo_search_tool,
            llm=self.llm,
        )

        # root_agent must match the `name` of one of the agents above.
        self._workflow = AgentWorkflow(
            agents=[self._arxiv_agent, self._websearch_agent],
            root_agent="arxiv_agent",
            timeout=180,
        )

        logger.info("MultiAgentWorkflow initialized.")

    def _arxiv_rag(self, query: str) -> str:
        """Retrieve research papers from arXiv based on the query.

        Used as the ``arxiv_rag`` tool; delegates to the ArxivResponseGenerator.

        Args:
            query (str): The search query.

        Returns:
            str: Retrieved research papers as a string.
        """
        return self._generator.retrieve_arxiv_papers(query=query)

    def _clean_response(self, result: str) -> str:
        """Remove everything up to and including the first ``</think>`` tag.

        Responses without a ``</think>`` tag are returned unchanged.

        Args:
            result (str): The result, possibly containing <think></think> content.

        Returns:
            str: The result without the <think></think> content.
        """
        # str.find returns -1 (truthy) when absent and 0 (falsy) at the start,
        # so it cannot be used as a boolean; partition reports presence via sep.
        _, sep, tail = result.partition("</think>")
        return tail if sep else result

    async def run(self, user_query: str) -> str:
        """Run the multi-agent workflow for a given user query.

        The query is embedded in one instruction asking for arXiv papers first
        and then web-searched events related to the current year.

        Args:
            user_query (str): The user's search query.

        Returns:
            str: The output of the workflow run.
            NOTE(review): awaiting AgentWorkflow.run presumably yields an agent
            output object rather than a plain str — confirm the annotation.

        Raises:
            Exception: Any error from the workflow is logged and re-raised.
        """
        logger.info("Running multi-agent workflow.")
        # Each sentence ends with a newline so the implicitly concatenated
        # literals do not run together (e.g. "...about: nlp.Then search...").
        user_msg = (
            f"First, give me arxiv research papers about: {user_query}.\n"
            f"Then search with web search agent for any events related to: {user_query}.\n"
            f"The web search results should be relevant to the current year: {date.today().year}.\n"
            "Return all the content from all the agents."
        )
        try:
            results = await self._workflow.run(user_msg=user_msg)
        except Exception as err:
            logger.error(f"Workflow run failed: {err}")
            raise
        else:
            logger.info("Workflow run completed successfully.")
            return results


# if __name__ == "__main__":
#     USER_QUERY = "i want to learn more about nlp"
#     workflow = MultiAgentWorkflow()
#     logger.info("Starting workflow for user query.")
#     try:
#         result = asyncio.run(workflow.run(user_query=USER_QUERY))
#         logger.info("Workflow finished. Output below:")
#         print(result)
#     except Exception as err:
#         logger.error(f"Error during workflow execution: {err}")