File size: 3,765 Bytes
422ae9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os

import datasets
from langchain.docstore.document import Document
from langchain_community.retrievers import BM25Retriever
from pydantic import BaseModel, Field


class EncyclopediaDataSets:
    @staticmethod
    def gaia(base_path: str) -> list[Document]:
        datasets_dir = os.path.join(base_path, "tools/.datasets/gaia")

        try:
            gaia_dataset: (
                datasets.DatasetDict
                | datasets.Dataset
                | datasets.IterableDatasetDict
                | datasets.IterableDataset
            ) = datasets.load_from_disk(datasets_dir)
            # print("load local")
        except Exception as e:
            # print(f"{e}load online")
            gaia_dataset: (
                datasets.DatasetDict
                | datasets.Dataset
                | datasets.IterableDatasetDict
                | datasets.IterableDataset
            ) = datasets.load_dataset(
                "gaia-benchmark/GAIA",
                "2023_all",
            )

            gaia_dataset.save_to_disk(datasets_dir)

        # dict_keys(['task_id', 'Question', 'Level', 'Final answer', 'file_name', 'file_path', 'Annotator Metadata'])
        gaia_dataset_list = (
            gaia_dataset["test"].to_list() + gaia_dataset["validation"].to_list()
        )

        # Convert dataset entries into Document objects
        docs: list[Document] = [
            Document(
                page_content="\n".join(
                    [
                        f"task_id: {gdl['task_id']}",
                        f"Question: {gdl['Question']}",
                        f"Final answer: {gdl['Final answer']}",
                    ]
                ),
                metadata={"Question": gdl["Question"]},
            )
            for gdl in gaia_dataset_list
        ]

        return docs


class EncyclopediaRetrieveInput(BaseModel):
    question: str = Field(description="使用者欲搜尋的完整問題。")


class EncyclopediaRetriever:
    def __init__(self, needed_doc_names: list[str], base_path: str):
        self.bm25_retriever = BM25Retriever.from_documents(
            self.prepare_docs(needed_doc_names, base_path)
        )

    def prepare_docs(self, needed_doc_names: list[str], base_path: str):
        """
        準備所需的 Document 文件列表。

        Args:
            needed_doc_names (list[str]): 需要載入的百科資料集合名稱列表。
            base_path (str): 存放本地資料集的基礎路徑。

        Returns:
            list[Document]: 經由所有指定來源整合而成的 Document 物件列表。

        說明:
            根據傳入的資料集名稱逐一載入相關文件,支援多來源文檔的彙整。
            目前僅支援 "gaia" 資料集,其它來源可根據需求擴充。
        """

        docs = []

        for ndn in needed_doc_names:
            if ndn == "gaia":
                docs.extend(EncyclopediaDataSets.gaia(base_path))

        return docs

    def get_related_question(self, question: str) -> str:
        """
        依據輸入問題檢索相關百科內容。

        Args:
            question (str): 使用者欲搜尋的完整問題。

        Returns:
            str: 與問題最相關的百科內容(文本格式),如無符合則傳回提示訊息。

        說明:
            本方法會使用 BM25 向量檢索器對 Document 集合進行檢索,回傳結果內容合併為字串輸出。
        """

        results: list[Document] = self.bm25_retriever.invoke(question)

        results_in_str: str = "No matching guest information found."
        if results:
            results_in_str: str = "\n\n".join([doc.page_content for doc in results])

        return results_in_str