import json
import logging
import os
import shutil
from itertools import islice
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain.storage import create_kv_docstore
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore, ByteStore
from pie_datasets import Dataset, DatasetDict

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class BasicPieDocumentStore(PieDocumentStore):
    """PIE document store that uses a LangChain key-value store client to store and
    retrieve documents."""

    def __init__(
        self,
        client: Optional[BaseStore[str, LCDocument]] = None,
        byte_store: Optional[ByteStore] = None,
    ):
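        """Initialize the document store.

        Args:
            client: a LangChain key-value store that holds the wrapped documents.
            byte_store: a LangChain byte store; if given, it is wrapped into a
                key-value document store via `create_kv_docstore` and takes
                precedence over `client`.
        """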
        if byte_store is not None:
            client = create_kv_docstore(byte_store)
        elif client is None:
            raise ValueError("You must provide either a `client` or a `byte_store`.")
        self.client = client
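
    # Minimal usage sketch (assumes LangChain's `InMemoryByteStore` as the backend;
    # any `ByteStore` implementation works the same way):
    #
    #   from langchain_core.stores import InMemoryByteStore
    #
    #   store = BasicPieDocumentStore(byte_store=InMemoryByteStore())
    #   store.mset([("doc-1", lc_document)])  # lc_document: a LangChain Document
    #   retrieved = store.mget(["doc-1"])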

    def mget(self, keys: Sequence[str]) -> List[Optional[LCDocument]]:
        """Get the documents for the given keys (None for missing keys)."""
        return self.client.mget(keys)

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given (key, document) pairs."""
        self.client.mset(items)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the documents for the given keys."""
        self.client.mdelete(keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Iterate over all keys, optionally restricted to a given prefix."""
        return self.client.yield_keys(prefix=prefix)

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
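        """Serialize the store content to `path`: the wrapped PIE documents are written
        as a PIE dataset (single "train" split), the document ids and the remaining
        per-document metadata as separate JSON files. Documents are fetched from the
        client in batches of `batch_size` (default: 1000).
        """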
        all_doc_ids = []
        all_metadata = []
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        doc_ids_iter = iter(self.client.yield_keys())
        # fetch and serialize the documents in batches to keep memory usage bounded
        while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
            all_doc_ids.extend(batch_doc_ids)
            docs = self.client.mget(batch_doc_ids)
            pie_docs = []
            for doc in docs:
                # split each entry into the wrapped PIE document and the remaining metadata
                pie_doc = doc.metadata[self.METADATA_KEY_PIE_DOCUMENT]
                pie_docs.append(pie_doc)
                all_metadata.append(
                    {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
                )
            pie_dataset = Dataset.from_documents(pie_docs)
            DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path)
        if len(all_doc_ids) > 0:
            doc_ids_path = os.path.join(path, "doc_ids.json")
            with open(doc_ids_path, "w") as f:
                json.dump(all_doc_ids, f)
        if len(all_metadata) > 0:
            metadata_path = os.path.join(path, "metadata.json")
            with open(metadata_path, "w") as f:
                json.dump(all_metadata, f)
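
    # Resulting on-disk layout (sketch, derived from the code above):
    #
    #   <path>/
    #     pie_documents/   # PIE dataset with a single "train" split, as JSON
    #     doc_ids.json     # list of all document ids, in storage order
    #     metadata.json    # per-document metadata (without the wrapped PIE document)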

    def _load_from_directory(self, path: str, **kwargs) -> None:
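        """Restore the store content from a directory written by `_save_to_directory`:
        re-wrap each PIE document with its saved metadata and write the results back
        into the client under the saved document ids.
        """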
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, not loading any documents."
            )
            return None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path)
        pie_docs = pie_dataset["train"]
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, not loading any metadata.")
            all_metadata = [{} for _ in pie_docs]
        # re-wrap the PIE documents together with their metadata into LangChain documents
        docs = [
            self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, all_metadata)
        ]
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(f"File {doc_ids_path} does not exist, not loading any document ids.")
            all_doc_ids = [doc.id for doc in pie_docs]
        self.client.mset(list(zip(all_doc_ids, docs)))
        logger.info(f"Loaded {len(docs)} documents from {path} into docstore")
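
# Round-trip sketch (assumes the public save/load entry points on `PieDocumentStore`
# delegate to the private methods above):
#
#   store = BasicPieDocumentStore(byte_store=InMemoryByteStore())
#   ...  # fill the store, e.g. via mset or a retriever pipeline
#   store._save_to_directory("docstore_dump", batch_size=500)
#
#   restored = BasicPieDocumentStore(byte_store=InMemoryByteStore())
#   restored._load_from_directory("docstore_dump")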