import abc
import logging
from copy import copy
from typing import Iterator, List, Optional, Sequence, Tuple

import pandas as pd
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore
from pytorch_ie.documents import TextBasedDocument

from .serializable_store import SerializableStore

logger = logging.getLogger(__name__)


class PieDocumentStore(SerializableStore, BaseStore[str, LCDocument], abc.ABC):
    """Abstract base class for document stores specialized in storing and retrieving pie
    documents."""

    METADATA_KEY_PIE_DOCUMENT: str = "pie_document"
    """Key for the pie document in the (langchain) document metadata."""

    def wrap(self, pie_document: TextBasedDocument, **metadata) -> LCDocument:
        """Wrap the pie document in an LCDocument."""
        return LCDocument(
            id=pie_document.id,
            page_content="",
            metadata={self.METADATA_KEY_PIE_DOCUMENT: pie_document, **metadata},
        )

    def unwrap(self, document: LCDocument) -> TextBasedDocument:
        """Get the pie document from the langchain document."""
        return document.metadata[self.METADATA_KEY_PIE_DOCUMENT]

    def unwrap_with_metadata(self, document: LCDocument) -> Tuple[TextBasedDocument, dict]:
        """Get the pie document and the remaining metadata from the langchain document."""
        metadata = copy(document.metadata)
        pie_document = metadata.pop(self.METADATA_KEY_PIE_DOCUMENT)
        return pie_document, metadata

    @abc.abstractmethod
    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Get the documents for the given keys."""
        pass

    @abc.abstractmethod
    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the documents for the given (key, document) pairs."""
        pass

    @abc.abstractmethod
    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the documents for the given keys."""
        pass

    @abc.abstractmethod
    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys in the store, optionally restricted to keys with the given prefix."""
        pass

    def __len__(self) -> int:
        return len(list(self.yield_keys()))

    def overview(self, layer_captions: dict, use_predictions: bool = False) -> pd.DataFrame:
        """Get an overview of the document store, including the number of items in each layer for
        each document in the store.

        Args:
            layer_captions: A dictionary mapping layer names to captions.
            use_predictions: Whether to additionally count the predictions of each layer.

        Returns:
            DataFrame: A pandas DataFrame containing the overview.
        """
        rows = []
        for doc_id in self.yield_keys():
            document = self.mget([doc_id])[0]
            pie_document = self.unwrap(document)
            layers = {
                caption: pie_document[layer_name]
                for layer_name, caption in layer_captions.items()
            }
            layer_sizes = {
                f"num_{caption}": len(layer) + (len(layer.predictions) if use_predictions else 0)
                for caption, layer in layers.items()
            }
            rows.append({"doc_id": doc_id, **layer_sizes})
        df = pd.DataFrame(rows)
        return df

    def as_dict(self, document: LCDocument) -> dict:
        """Convert the langchain document to a dictionary."""
        pie_document, metadata = self.unwrap_with_metadata(document)
        return {self.METADATA_KEY_PIE_DOCUMENT: pie_document.asdict(), "metadata": metadata}
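

# Illustrative sketch (not part of the library): a minimal concrete subclass backed by a
# plain dict, showing one way the abstract methods above could be implemented. The name
# InMemoryPieDocumentStore and the storage layout are assumptions for illustration only,
# and the sketch assumes SerializableStore does not require further abstract methods.
class InMemoryPieDocumentStore(PieDocumentStore):
    """Minimal in-memory example implementation of PieDocumentStore (illustrative only)."""

    def __init__(self) -> None:
        # maps key -> LCDocument
        self._store = {}

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        # Missing keys are silently skipped in this sketch.
        return [self._store[key] for key in keys if key in self._store]

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        for key, document in items:
            self._store[key] = document

    def mdelete(self, keys: Sequence[str]) -> None:
        for key in keys:
            self._store.pop(key, None)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        for key in self._store:
            if prefix is None or key.startswith(prefix):
                yield key


# Example usage (assuming `pie_doc` is a TextBasedDocument with an id):
#     store = InMemoryPieDocumentStore()
#     store.mset([(pie_doc.id, store.wrap(pie_doc, source="example"))])
#     retrieved = store.unwrap(store.mget([pie_doc.id])[0])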