import json
import logging
import os
import shutil
from itertools import islice
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain.storage import create_kv_docstore
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore, ByteStore
from pie_datasets import Dataset, DatasetDict

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class BasicPieDocumentStore(PieDocumentStore):
    """PIE document store that uses a key-value client to store and retrieve documents."""

    def __init__(
        self,
        client: Optional[BaseStore[str, LCDocument]] = None,
        byte_store: Optional[ByteStore] = None,
    ):
        if byte_store is not None:
            client = create_kv_docstore(byte_store)
        elif client is None:
            raise ValueError("You must pass either a `client` or a `byte_store` parameter.")
        self.client = client
    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Get the documents for the given keys from the underlying client."""
        return self.client.mget(keys)

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given (key, document) pairs in the underlying client."""
        self.client.mset(items)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the documents for the given keys from the underlying client."""
        self.client.mdelete(keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys in the store, optionally filtered by a prefix."""
        return self.client.yield_keys(prefix=prefix)
    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        """Save all documents, their metadata, and their ids to the given directory."""
        all_doc_ids = []
        all_metadata = []
        all_pie_docs = []
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove the existing directory to avoid mixing old and new documents
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        # fetch the documents in batches to keep the individual mget calls small
        doc_ids_iter = iter(self.client.yield_keys())
        while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
            all_doc_ids.extend(batch_doc_ids)
            docs = self.client.mget(batch_doc_ids)
            for doc in docs:
                # split the stored entry into the wrapped PIE document and the remaining metadata
                all_pie_docs.append(doc.metadata[self.METADATA_KEY_PIE_DOCUMENT])
                all_metadata.append(
                    {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
                )
        # write all collected PIE documents as a single dataset
        if len(all_pie_docs) > 0:
            pie_dataset = Dataset.from_documents(all_pie_docs)
            DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path)
        if len(all_doc_ids) > 0:
            doc_ids_path = os.path.join(path, "doc_ids.json")
            with open(doc_ids_path, "w") as f:
                json.dump(all_doc_ids, f)
        if len(all_metadata) > 0:
            metadata_path = os.path.join(path, "metadata.json")
            with open(metadata_path, "w") as f:
                json.dump(all_metadata, f)
    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Load documents, metadata, and document ids from the given directory into the store."""
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, no documents will be loaded."
            )
            return None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path)
        pie_docs = pie_dataset["train"]
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, no metadata will be loaded.")
            all_metadata = [{} for _ in pie_docs]
        # re-wrap the PIE documents together with their metadata into LangChain documents
        docs = [
            self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, all_metadata)
        ]
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(
                f"File {doc_ids_path} does not exist, falling back to the PIE document ids."
            )
            all_doc_ids = [doc.id for doc in pie_docs]
        self.client.mset(list(zip(all_doc_ids, docs)))
        logger.info(f"Loaded {len(docs)} documents from {path} into the docstore")