Spaces:
Sleeping
Sleeping
import os | |
import datasets | |
from langchain.docstore.document import Document | |
from langchain_community.retrievers import BM25Retriever | |
from pydantic import BaseModel, Field | |
class EncyclopediaDataSets: | |
def gaia(base_path: str) -> list[Document]: | |
datasets_dir = os.path.join(base_path, "tools/.datasets/gaia") | |
try: | |
gaia_dataset: ( | |
datasets.DatasetDict | |
| datasets.Dataset | |
| datasets.IterableDatasetDict | |
| datasets.IterableDataset | |
) = datasets.load_from_disk(datasets_dir) | |
# print("load local") | |
except Exception as e: | |
# print(f"{e}load online") | |
gaia_dataset: ( | |
datasets.DatasetDict | |
| datasets.Dataset | |
| datasets.IterableDatasetDict | |
| datasets.IterableDataset | |
) = datasets.load_dataset( | |
"gaia-benchmark/GAIA", | |
"2023_all", | |
) | |
gaia_dataset.save_to_disk(datasets_dir) | |
# dict_keys(['task_id', 'Question', 'Level', 'Final answer', 'file_name', 'file_path', 'Annotator Metadata']) | |
gaia_dataset_list = ( | |
gaia_dataset["test"].to_list() + gaia_dataset["validation"].to_list() | |
) | |
# Convert dataset entries into Document objects | |
docs: list[Document] = [ | |
Document( | |
page_content="\n".join( | |
[ | |
f"task_id: {gdl['task_id']}", | |
f"Question: {gdl['Question']}", | |
f"Final answer: {gdl['Final answer']}", | |
] | |
), | |
metadata={"Question": gdl["Question"]}, | |
) | |
for gdl in gaia_dataset_list | |
] | |
return docs | |
class EncyclopediaRetrieveInput(BaseModel): | |
question: str = Field(description="使用者欲搜尋的完整問題。") | |
class EncyclopediaRetriever: | |
def __init__(self, needed_doc_names: list[str], base_path: str): | |
self.bm25_retriever = BM25Retriever.from_documents( | |
self.prepare_docs(needed_doc_names, base_path) | |
) | |
def prepare_docs(self, needed_doc_names: list[str], base_path: str): | |
""" | |
準備所需的 Document 文件列表。 | |
Args: | |
needed_doc_names (list[str]): 需要載入的百科資料集合名稱列表。 | |
base_path (str): 存放本地資料集的基礎路徑。 | |
Returns: | |
list[Document]: 經由所有指定來源整合而成的 Document 物件列表。 | |
說明: | |
根據傳入的資料集名稱逐一載入相關文件,支援多來源文檔的彙整。 | |
目前僅支援 "gaia" 資料集,其它來源可根據需求擴充。 | |
""" | |
docs = [] | |
for ndn in needed_doc_names: | |
if ndn == "gaia": | |
docs.extend(EncyclopediaDataSets.gaia(base_path)) | |
return docs | |
def get_related_question(self, question: str) -> str: | |
""" | |
依據輸入問題檢索相關百科內容。 | |
Args: | |
question (str): 使用者欲搜尋的完整問題。 | |
Returns: | |
str: 與問題最相關的百科內容(文本格式),如無符合則傳回提示訊息。 | |
說明: | |
本方法會使用 BM25 向量檢索器對 Document 集合進行檢索,回傳結果內容合併為字串輸出。 | |
""" | |
results: list[Document] = self.bm25_retriever.invoke(question) | |
results_in_str: str = "No matching guest information found." | |
if results: | |
results_in_str: str = "\n\n".join([doc.page_content for doc in results]) | |
return results_in_str | |