File size: 3,884 Bytes
72db888
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import Annotated, Any, Sequence, TypedDict

from langchain.tools import StructuredTool
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables.base import Runnable
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel, Field

from config import settings
from tools.tool_collection_wiki import ToolsCollection as WikiTool


class AgentState(TypedDict):
    """State schema handed to ``StateGraph`` when the agent graph is built."""

    # Conversation messages carried through the graph. ``add_messages`` is
    # attached as ``Annotated`` metadata; presumably LangGraph uses it as the
    # reducer that merges each node's returned messages into this list
    # instead of overwriting it — confirm against the LangGraph docs.
    messages: Annotated[list[AnyMessage], add_messages]


class WikiAgent:
    """Tool-calling agent that answers questions with the project's wiki tools.

    The agent is a two-node LangGraph graph: an ``assistant`` node that calls
    a tool-bound chat model, and a ``tools`` node that executes the tools the
    model requested. Control loops between the two until ``tools_condition``
    stops routing to ``tools``.
    """

    def __init__(self):
        """Create the chat model, load the wiki tools and compile the graph."""
        chat = ChatOpenAI(model="gpt-4o", verbose=True)
        self.tools: list[StructuredTool] = WikiTool.get_tools(
            [
                "wikipedia_opensearch",
                "get_page_title_excerpt_sections",
                "get_page_section_content",
            ]
        )
        # The model is bound to the tools so it can emit tool calls for them.
        self.chat_with_tools: Runnable[
            PromptValue
            | str
            | Sequence[
                BaseMessage | list[str] | tuple[str, str] | str | dict[str, Any]
            ],
            BaseMessage,
        ] = chat.bind_tools(self.tools)

        self.agent = self.build_agent()

    async def assistant(self, state: AgentState) -> dict[str, list[BaseMessage]]:
        """Run the tool-bound chat model on the messages accumulated so far.

        Args:
            state: Current graph state; only ``state["messages"]`` is read.

        Returns:
            A state update whose ``"messages"`` entry holds the single message
            produced by the model.
        """
        result_message: BaseMessage = await self.chat_with_tools.ainvoke(
            state["messages"]
        )

        return {
            "messages": [result_message],
        }

    def build_agent(self) -> CompiledStateGraph:
        """Assemble and compile the assistant/tools graph.

        Returns:
            The compiled graph, ready to be invoked with an ``AgentState``.
        """
        builder = StateGraph(AgentState)

        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(self.tools))

        # Define edges: these determine how the control flow moves
        builder.add_edge(START, "assistant")

        # ``tools_condition`` decides where to go after the assistant node
        # (per LangGraph: to "tools" when the last message carries tool calls,
        # otherwise the run ends).
        builder.add_conditional_edges(source="assistant", path=tools_condition)

        # Tool results are always handed back to the assistant.
        builder.add_edge("tools", "assistant")

        return builder.compile()

    async def ainvoke(self, message: str) -> str | list[str | dict]:
        """Run the agent on one user message and return the final content.

        The graph is started with a fixed system prompt followed by the user
        message, and is traced through ``settings.LANGFUSE_HANDLER``.

        Args:
            message: The user's question.

        Returns:
            The ``content`` of the last message in the final graph state.
            For a plain text reply this is a ``str``; LangChain also allows
            message content to be a list of strings/dicts.
        """
        response = await self.agent.ainvoke(
            {
                "messages": [
                    SystemMessage(
                        content="""
                        你是一個專門搜尋wikipedia的AI Agent,
                        步驟一:使用 wikipedia_opensearch 工具找出與問題相關的頁面
                        步驟二:使用 get_page_title_excerpt_sections 工具找出頁面的 excerpt 和 sections
                        步驟三:根據步驟二的 excerpt 和 sections 結合用戶問題,判斷哪些 section 會有需要的答案,呼叫 get_page_section_content 工具取得這些 section 的所有內容。
                        步驟四:總和前述步驟找出答案。
                        """
                    ),
                    HumanMessage(content=message),
                ]
            },
            config={"callbacks": [settings.LANGFUSE_HANDLER]},
        )

        return response["messages"][-1].content


# Pydantic input model with a single ``question`` field.
# NOTE(review): presumably meant as the args schema for
# ``wikipedia_en_tool_agent`` when it is registered as a tool; that wiring is
# not visible in this file — confirm at the registration site.
class WikipediaEnToolAgentInput(BaseModel):
    # The user's question; the ``description`` text is stored on the field.
    question: str = Field(description="The user question in natural language.")


def wikipedia_en_tool_agent(question: str) -> str:
    """
    Answer a question by running a freshly built WikiAgent to completion.

    A new ``WikiAgent`` is created on every call and its ``ainvoke``
    coroutine is driven synchronously with ``asyncio.run``. Because
    ``asyncio.run`` starts its own event loop, this function must be called
    from code that is not already running inside one.

    Args:
        question (str): The user question in natural language.

    Returns:
        str: Whatever ``WikiAgent.ainvoke`` returns for the question.
    """

    import asyncio

    wiki_agent = WikiAgent()
    pending_answer = wiki_agent.ainvoke(question)
    return asyncio.run(pending_answer)