import json
import logging
import os
import shutil
from itertools import islice
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain.storage import create_kv_docstore
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore, ByteStore
from pie_datasets import Dataset, DatasetDict

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class BasicPieDocumentStore(PieDocumentStore):
    """PIE Document store that uses a client to store and retrieve documents."""

    def __init__(
        self,
        client: Optional[BaseStore[str, LCDocument]] = None,
        byte_store: Optional[ByteStore] = None,
    ):
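        """Initialize the store with either a key-value `client` or a `byte_store`.

        Args:
            client: key-value store mapping document ids to LangChain documents.
            byte_store: byte store; if provided, it is wrapped with `create_kv_docstore`
                and takes precedence over `client`.
        """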
        if byte_store is not None:
            client = create_kv_docstore(byte_store)
        elif client is None:
            raise ValueError("You must pass either a `byte_store` or a `client` parameter.")

        self.client = client

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
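        """Get the LangChain documents for the given keys."""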
        return self.client.mget(keys)

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
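        """Store the given (key, document) pairs."""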
        self.client.mset(items)

    def mdelete(self, keys: Sequence[str]) -> None:
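        """Delete the documents for the given keys."""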
        self.client.mdelete(keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
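        """Yield all stored keys, optionally filtered by prefix."""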
        return self.client.yield_keys(prefix=prefix)

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
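        """Persist the stored documents to `path`: the PIE documents as a JSON dataset under
        `pie_documents/`, the document ids to `doc_ids.json`, and the remaining per-document
        metadata to `metadata.json`. Documents are fetched from the client in batches of
        `batch_size` (default 1000)."""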
        all_doc_ids = []
        all_metadata = []
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        doc_ids_iter = iter(self.client.yield_keys())
        while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
            all_doc_ids.extend(batch_doc_ids)
            docs = self.client.mget(batch_doc_ids)
            pie_docs = []
            for doc in docs:
                pie_doc = doc.metadata[self.METADATA_KEY_PIE_DOCUMENT]
                pie_docs.append(pie_doc)
                all_metadata.append(
                    {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
                )
            pie_dataset = Dataset.from_documents(pie_docs)
            DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path)
        if len(all_doc_ids) > 0:
            doc_ids_path = os.path.join(path, "doc_ids.json")
            with open(doc_ids_path, "w") as f:
                json.dump(all_doc_ids, f)
        if len(all_metadata) > 0:
            metadata_path = os.path.join(path, "metadata.json")
            with open(metadata_path, "w") as f:
                json.dump(all_metadata, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
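        """Load documents previously saved with `_save_to_directory` from `path` and add them
        to the client. If the metadata or document id files are missing, fall back to empty
        metadata and to the PIE document ids, respectively."""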
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, not loading any documents."
            )
            return None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path)
        pie_docs = pie_dataset["train"]
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
            all_metadata = [{} for _ in pie_docs]
        docs = [
            self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, all_metadata)
        ]
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
            all_doc_ids = [doc.id for doc in pie_docs]
        self.client.mset(list(zip(all_doc_ids, docs)))
        logger.info(f"Loaded {len(docs)} documents from {path} into docstore")
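

# Usage sketch (illustrative only, not exercised by this module): assuming the base
# class's `wrap` helper and an in-memory byte store from `langchain.storage`, the
# store could be populated, persisted, and reloaded roughly like this:
#
#   from langchain.storage import InMemoryByteStore
#
#   store = BasicPieDocumentStore(byte_store=InMemoryByteStore())
#   store.mset([(pie_doc.id, store.wrap(pie_doc)) for pie_doc in pie_documents])
#   store._save_to_directory("path/to/docstore")
#
#   restored = BasicPieDocumentStore(byte_store=InMemoryByteStore())
#   restored._load_from_directory("path/to/docstore")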