Final_Assignment / tools /encyclopedia.py
Alfred828's picture
Create tools/encyclopedia.py
422ae9b verified
import os
import datasets
from langchain.docstore.document import Document
from langchain_community.retrievers import BM25Retriever
from pydantic import BaseModel, Field
class EncyclopediaDataSets:
@staticmethod
def gaia(base_path: str) -> list[Document]:
datasets_dir = os.path.join(base_path, "tools/.datasets/gaia")
try:
gaia_dataset: (
datasets.DatasetDict
| datasets.Dataset
| datasets.IterableDatasetDict
| datasets.IterableDataset
) = datasets.load_from_disk(datasets_dir)
# print("load local")
except Exception as e:
# print(f"{e}load online")
gaia_dataset: (
datasets.DatasetDict
| datasets.Dataset
| datasets.IterableDatasetDict
| datasets.IterableDataset
) = datasets.load_dataset(
"gaia-benchmark/GAIA",
"2023_all",
)
gaia_dataset.save_to_disk(datasets_dir)
# dict_keys(['task_id', 'Question', 'Level', 'Final answer', 'file_name', 'file_path', 'Annotator Metadata'])
gaia_dataset_list = (
gaia_dataset["test"].to_list() + gaia_dataset["validation"].to_list()
)
# Convert dataset entries into Document objects
docs: list[Document] = [
Document(
page_content="\n".join(
[
f"task_id: {gdl['task_id']}",
f"Question: {gdl['Question']}",
f"Final answer: {gdl['Final answer']}",
]
),
metadata={"Question": gdl["Question"]},
)
for gdl in gaia_dataset_list
]
return docs
class EncyclopediaRetrieveInput(BaseModel):
question: str = Field(description="使用者欲搜尋的完整問題。")
class EncyclopediaRetriever:
def __init__(self, needed_doc_names: list[str], base_path: str):
self.bm25_retriever = BM25Retriever.from_documents(
self.prepare_docs(needed_doc_names, base_path)
)
def prepare_docs(self, needed_doc_names: list[str], base_path: str):
"""
準備所需的 Document 文件列表。
Args:
needed_doc_names (list[str]): 需要載入的百科資料集合名稱列表。
base_path (str): 存放本地資料集的基礎路徑。
Returns:
list[Document]: 經由所有指定來源整合而成的 Document 物件列表。
說明:
根據傳入的資料集名稱逐一載入相關文件,支援多來源文檔的彙整。
目前僅支援 "gaia" 資料集,其它來源可根據需求擴充。
"""
docs = []
for ndn in needed_doc_names:
if ndn == "gaia":
docs.extend(EncyclopediaDataSets.gaia(base_path))
return docs
def get_related_question(self, question: str) -> str:
"""
依據輸入問題檢索相關百科內容。
Args:
question (str): 使用者欲搜尋的完整問題。
Returns:
str: 與問題最相關的百科內容(文本格式),如無符合則傳回提示訊息。
說明:
本方法會使用 BM25 向量檢索器對 Document 集合進行檢索,回傳結果內容合併為字串輸出。
"""
results: list[Document] = self.bm25_retriever.invoke(question)
results_in_str: str = "No matching guest information found."
if results:
results_in_str: str = "\n\n".join([doc.page_content for doc in results])
return results_in_str