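"""Datasets-backed PIE document store.

A minimal usage sketch (not a test): ``wrap``/``unwrap_with_metadata`` come from the
``PieDocumentStore`` base class, and the stored documents are assumed to be of a
concrete ``TextBasedDocument`` type that ``Dataset.from_documents`` accepts.

    from pytorch_ie.documents import TextBasedDocument

    store = DatasetsPieDocumentStore()
    pie_doc = TextBasedDocument(text="Hello world.", id="doc-1")
    store.mset([(pie_doc.id, store.wrap(pie_doc))])
    retrieved = store.mget([pie_doc.id])
"""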
import json
import logging
import os
import shutil
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple

from datasets import Dataset as HFDataset
from langchain_core.documents import Document as LCDocument
from pie_datasets import Dataset, DatasetDict, concatenate_datasets
from pytorch_ie.documents import TextBasedDocument

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class DatasetsPieDocumentStore(PieDocumentStore):
    """PIE document store that uses Hugging Face Datasets as the backend."""

    def __init__(self) -> None:
        # backing dataset that stores the PIE documents (created on first insert)
        self._data: Optional[Dataset] = None
        # maps each key to the row index of its document in self._data
        self._keys: Dict[str, int] = {}
        # LangChain metadata per key (only non-empty metadata is kept)
        self._metadata: Dict[str, Any] = {}

    def __len__(self):
        return len(self._keys)

    def _get_pie_docs_by_indices(self, indices: Iterable[int]) -> Sequence[TextBasedDocument]:
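        """Select the PIE documents at the given row indices from the backing dataset."""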
        if self._data is None:
            return []
        return self._data.apply_hf_func(func=HFDataset.select, indices=indices)

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
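        """Return the wrapped documents for the given keys; unknown keys are skipped."""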
        if self._data is None or len(keys) == 0:
            return []
        keys_in_data = [key for key in keys if key in self._keys]
        indices = [self._keys[key] for key in keys_in_data]
        dataset = self._get_pie_docs_by_indices(indices)
        metadatas = [self._metadata.get(key, {}) for key in keys_in_data]
        return [self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(dataset, metadatas)]

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
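        """Unwrap the given (key, document) pairs and append them to the backing dataset."""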
        if len(items) == 0:
            return
        keys, new_docs = zip(*items)
        pie_docs, metadatas = zip(*[self.unwrap_with_metadata(doc) for doc in new_docs])
        if self._data is None:
            idx_start = 0
            self._data = Dataset.from_documents(pie_docs)
        else:
            # build the new rows with the existing features so they can be concatenated
            dataset = Dataset.from_documents(pie_docs, features=self._data.features)
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
        self._keys.update(keys_dict)
        self._metadata.update(
            {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
        )

    def add_pie_dataset(
        self,
        dataset: Dataset,
        keys: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
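        """Append a whole PIE dataset; keys default to the document ids."""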
        if len(dataset) == 0:
            return
        if keys is None:
            keys = [doc.id for doc in dataset]
        if len(keys) != len(set(keys)):
            raise ValueError("Keys must be unique.")
        if None in keys:
            raise ValueError("Keys must not be None.")
        if metadatas is None:
            metadatas = [{} for _ in range(len(dataset))]
        if len(keys) != len(dataset) or len(keys) != len(metadatas):
            raise ValueError("Keys, dataset and metadatas must have the same length.")

        if self._data is None:
            idx_start = 0
            self._data = dataset
        else:
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
        self._keys.update(keys_dict)
        metadatas_dict = {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
        self._metadata.update(metadatas_dict)

    def mdelete(self, keys: Sequence[str]) -> None:
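        """Remove keys and their metadata; backing rows are only dropped lazily on save."""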
        for key in keys:
            idx = self._keys.pop(key, None)
            if idx is not None:
                self._metadata.pop(key, None)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        return (key for key in self._keys if prefix is None or key.startswith(prefix))

    def _purge_invalid_entries(self):
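        """Drop rows from the backing dataset whose keys have been deleted."""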
        if self._data is None or len(self._keys) == len(self._data):
            return
        self._data = self._get_pie_docs_by_indices(self._keys.values())
        # the selected rows are now in key order, so re-index the key -> row mapping
        self._keys = {key: idx for idx, key in enumerate(self._keys)}

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
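        """Persist the dataset together with document ids and metadata under ``path``."""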
        self._purge_invalid_entries()
        if len(self) == 0:
            logger.warning("No documents to save.")
            return

        all_doc_ids = list(self._keys)
        all_metadatas: List[Dict[str, Any]] = [self._metadata.get(key, {}) for key in all_doc_ids]
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        DatasetDict({"train": self._data}).to_json(pie_documents_path)
        doc_ids_path = os.path.join(path, "doc_ids.json")
        with open(doc_ids_path, "w") as f:
            json.dump(all_doc_ids, f)
        metadata_path = os.path.join(path, "metadata.json")
        with open(metadata_path, "w") as f:
            json.dump(all_metadatas, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
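        """Load a previously saved store from ``path``, skipping missing files with a warning."""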
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(f"File {doc_ids_path} does not exist, not loading any document ids.")
            all_doc_ids = None
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, not loading any metadata.")
            all_metadata = None
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, not loading any documents."
            )
            return None

        # reuse existing features (if any) so the loaded documents are compatible
        features = self._data.features if self._data is not None else None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path, features=features)
        pie_docs = pie_dataset["train"]
        self.add_pie_dataset(pie_docs, keys=all_doc_ids, metadatas=all_metadata)
        logger.info(f"Loaded {len(pie_docs)} documents from {path} into docstore")