# src/langchain_modules/datasets_pie_document_store.py
import json
import logging
import os
import shutil
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple
from datasets import Dataset as HFDataset
from langchain_core.documents import Document as LCDocument
from pie_datasets import Dataset, DatasetDict, concatenate_datasets
from pytorch_ie.documents import TextBasedDocument
from .pie_document_store import PieDocumentStore
logger = logging.getLogger(__name__)
class DatasetsPieDocumentStore(PieDocumentStore):
"""PIE Document store that uses Huggingface Datasets as the backend."""
def __init__(self) -> None:
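        """Create an empty in-memory store backed by a Hugging Face dataset."""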
self._data: Optional[Dataset] = None
# keys map to indices in the dataset
self._keys: Dict[str, int] = {}
self._metadata: Dict[str, Any] = {}
def __len__(self):
return len(self._keys)
def _get_pie_docs_by_indices(self, indices: Iterable[int]) -> Sequence[TextBasedDocument]:
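        """Select the rows at the given positions from the backing dataset.

        Applies HFDataset.select through apply_hf_func so the selected rows are
        returned as PIE documents; returns an empty list if no data is stored yet.
        """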
if self._data is None:
return []
return self._data.apply_hf_func(func=HFDataset.select, indices=indices)
def mget(self, keys: Sequence[str]) -> List[LCDocument]:
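        """Fetch the LangChain documents stored under the given keys.

        Keys that are not present in the store are silently skipped, so the result
        may contain fewer entries than requested. Each PIE document is wrapped into
        a LangChain document together with its stored metadata.
        """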
if self._data is None or len(keys) == 0:
return []
keys_in_data = [key for key in keys if key in self._keys]
indices = [self._keys[key] for key in keys_in_data]
dataset = self._get_pie_docs_by_indices(indices)
metadatas = [self._metadata.get(key, {}) for key in keys_in_data]
return [self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(dataset, metadatas)]
def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
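        """Store (key, LangChain document) pairs.

        The documents are unwrapped into PIE documents and appended to the backing
        dataset, reusing the existing features to keep newly inferred features
        compatible. Re-setting an existing key points it at the newly appended row;
        the stale row is only removed lazily by _purge_invalid_entries().
        """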
if len(items) == 0:
return
keys, new_docs = zip(*items)
pie_docs, metadatas = zip(*[self.unwrap_with_metadata(doc) for doc in new_docs])
if self._data is None:
idx_start = 0
self._data = Dataset.from_documents(pie_docs)
else:
# we pass the features to the new dataset to mitigate issues caused by
# slightly different inferred features
dataset = Dataset.from_documents(pie_docs, features=self._data.features)
idx_start = len(self._data)
self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
self._keys.update(keys_dict)
self._metadata.update(
{key: metadata for key, metadata in zip(keys, metadatas) if metadata}
)
def add_pie_dataset(
self,
dataset: Dataset,
keys: Optional[List[str]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
) -> None:
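        """Append an existing PIE dataset to the store.

        If no keys are given, the document ids are used as keys. Keys must be
        unique and not None, and keys, dataset and metadatas must have the same
        length.
        """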
if len(dataset) == 0:
return
if keys is None:
keys = [doc.id for doc in dataset]
if len(keys) != len(set(keys)):
raise ValueError("Keys must be unique.")
if None in keys:
raise ValueError("Keys must not be None.")
if metadatas is None:
metadatas = [{} for _ in range(len(dataset))]
if len(keys) != len(dataset) or len(keys) != len(metadatas):
raise ValueError("Keys, dataset and metadatas must have the same length.")
if self._data is None:
idx_start = 0
self._data = dataset
else:
idx_start = len(self._data)
self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
self._keys.update(keys_dict)
metadatas_dict = {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
self._metadata.update(metadatas_dict)
def mdelete(self, keys: Sequence[str]) -> None:
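        """Remove the given keys and their metadata from the store.

        The corresponding dataset rows are not deleted immediately; they are
        dropped lazily by _purge_invalid_entries(), e.g. before saving.
        """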
for key in keys:
idx = self._keys.pop(key, None)
if idx is not None:
self._metadata.pop(key, None)
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
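        """Iterate over all stored keys, optionally restricted to a given prefix."""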
return (key for key in self._keys if prefix is None or key.startswith(prefix))
def _purge_invalid_entries(self):
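        """Drop dataset rows whose keys were removed via mdelete() or overwritten via mset()."""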
if self._data is None or len(self._keys) == len(self._data):
return
        # keep only the rows that still belong to a key; HFDataset.select re-indexes
        # the remaining rows, so the key-to-index mapping must be rebuilt as well
        self._data = self._get_pie_docs_by_indices(self._keys.values())
        self._keys = {key: new_idx for new_idx, key in enumerate(self._keys)}
def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
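        """Persist the store to a directory.

        Writes the documents as a JSON-serialized dataset to `<path>/pie_documents`,
        the document ids to `<path>/doc_ids.json` and the per-document metadata to
        `<path>/metadata.json`.
        """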
self._purge_invalid_entries()
if len(self) == 0:
logger.warning("No documents to save.")
return
all_doc_ids = list(self._keys)
all_metadatas: List[Dict[str, Any]] = [self._metadata.get(key, {}) for key in all_doc_ids]
pie_documents_path = os.path.join(path, "pie_documents")
if os.path.exists(pie_documents_path):
# remove existing directory
logger.warning(f"Removing existing directory: {pie_documents_path}")
shutil.rmtree(pie_documents_path)
os.makedirs(pie_documents_path, exist_ok=True)
DatasetDict({"train": self._data}).to_json(pie_documents_path)
doc_ids_path = os.path.join(path, "doc_ids.json")
with open(doc_ids_path, "w") as f:
json.dump(all_doc_ids, f)
metadata_path = os.path.join(path, "metadata.json")
with open(metadata_path, "w") as f:
json.dump(all_metadatas, f)
def _load_from_directory(self, path: str, **kwargs) -> None:
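        """Restore a store previously written by _save_to_directory().

        Missing files are skipped with a warning; loaded documents are appended to
        any data already in the store via add_pie_dataset().
        """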
doc_ids_path = os.path.join(path, "doc_ids.json")
if os.path.exists(doc_ids_path):
with open(doc_ids_path, "r") as f:
all_doc_ids = json.load(f)
else:
logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
all_doc_ids = None
metadata_path = os.path.join(path, "metadata.json")
if os.path.exists(metadata_path):
with open(metadata_path, "r") as f:
all_metadata = json.load(f)
else:
logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
all_metadata = None
pie_documents_path = os.path.join(path, "pie_documents")
if not os.path.exists(pie_documents_path):
logger.warning(
f"Directory {pie_documents_path} does not exist, don't load any documents."
)
return None
# If we have a dataset already loaded, we use its features to load the new dataset
# This is to mitigate issues caused by slightly different inferred features.
features = self._data.features if self._data is not None else None
pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path, features=features)
pie_docs = pie_dataset["train"]
self.add_pie_dataset(pie_docs, keys=all_doc_ids, metadatas=all_metadata)
logger.info(f"Loaded {len(pie_docs)} documents from {path} into docstore")