import json
import logging
import os
import shutil
from itertools import islice
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain.storage import create_kv_docstore
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore, ByteStore
from pie_datasets import Dataset, DatasetDict

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class BasicPieDocumentStore(PieDocumentStore):
    """PIE document store that uses a key-value client to store and retrieve documents."""

    def __init__(
        self,
        client: Optional[BaseStore[str, LCDocument]] = None,
        byte_store: Optional[ByteStore] = None,
    ):
        if byte_store is not None:
            # Wrap the byte store so that documents are (de)serialized transparently.
            client = create_kv_docstore(byte_store)
        elif client is None:
            raise ValueError("You must pass either a `client` or a `byte_store` parameter.")

        self.client = client

    def mget(self, keys: Sequence[str]) -> List[Optional[LCDocument]]:
        """Get the documents for the given keys; missing keys yield None."""
        return self.client.mget(keys)

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given (key, document) pairs."""
        self.client.mset(items)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the documents for the given keys."""
        self.client.mdelete(keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys in the store, optionally restricted to the given prefix."""
        return self.client.yield_keys(prefix=prefix)

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        """Save the store to `path`: the PIE documents as a dataset under
        `pie_documents/`, plus `doc_ids.json` and `metadata.json`."""
        all_doc_ids = []
        all_metadata = []
        all_pie_docs = []
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        # Fetch the documents in batches to keep individual mget calls small.
        doc_ids_iter = iter(self.client.yield_keys())
        while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
            all_doc_ids.extend(batch_doc_ids)
            docs = self.client.mget(batch_doc_ids)
            for doc in docs:
                # Unwrap the PIE document; keep the remaining metadata separately.
                all_pie_docs.append(doc.metadata[self.METADATA_KEY_PIE_DOCUMENT])
                all_metadata.append(
                    {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
                )
        if len(all_pie_docs) > 0:
            # Write the dataset once after all batches have been collected; writing
            # inside the loop would overwrite the output of earlier batches.
            pie_dataset = Dataset.from_documents(all_pie_docs)
            DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path)
        if len(all_doc_ids) > 0:
            doc_ids_path = os.path.join(path, "doc_ids.json")
            with open(doc_ids_path, "w") as f:
                json.dump(all_doc_ids, f)
        if len(all_metadata) > 0:
            metadata_path = os.path.join(path, "metadata.json")
            with open(metadata_path, "w") as f:
                json.dump(all_metadata, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Load PIE documents, metadata, and document ids from `path` into the store."""
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, not loading any documents."
            )
            return
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path)
        pie_docs = pie_dataset["train"]
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, not loading any metadata.")
            all_metadata = [{} for _ in pie_docs]
        # Re-wrap each PIE document together with its remaining metadata.
        docs = [
            self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, all_metadata)
        ]
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(
                f"File {doc_ids_path} does not exist, falling back to the documents' own ids."
            )
            all_doc_ids = [doc.id for doc in pie_docs]
        # mset expects a sequence of (key, document) pairs.
        self.client.mset(list(zip(all_doc_ids, docs)))
        logger.info(f"Loaded {len(docs)} documents from {path} into docstore")
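
# Usage sketch (illustrative only, not part of the module): construct the store
# from an in-memory byte store and round-trip it to disk. `InMemoryByteStore`
# comes from `langchain_core.stores`; `pie_docs` stands for any iterable of PIE
# documents, and `wrap`/`METADATA_KEY_PIE_DOCUMENT` are assumed to be provided
# by the `PieDocumentStore` base class.
#
#     from langchain_core.stores import InMemoryByteStore
#
#     store = BasicPieDocumentStore(byte_store=InMemoryByteStore())
#     # Wrap each PIE document so its metadata carries the PIE document itself;
#     # this is what `_save_to_directory` unwraps when serializing.
#     store.mset([(pie_doc.id, store.wrap(pie_doc)) for pie_doc in pie_docs])
#     store._save_to_directory("/tmp/pie_store")
#
#     restored = BasicPieDocumentStore(byte_store=InMemoryByteStore())
#     restored._load_from_directory("/tmp/pie_store")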
|