File size: 3,996 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
from typing import Any, Callable, Dict, Iterator, List, Optional

from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator

# Per-module logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class TensorflowDatasets(BaseModel):
    """Access to the TensorFlow Datasets.

    The current implementation works only with datasets that fit in memory.

    `TensorFlow Datasets` is a collection of datasets ready to use with TensorFlow
    or other Python ML frameworks, such as Jax. All datasets are exposed
    as `tf.data.Datasets`.
    To get started, see the guide: https://www.tensorflow.org/datasets/overview
    and the list of datasets:
    https://www.tensorflow.org/datasets/catalog/overview#all_datasets

    You have to provide the ``sample_to_document_function``: a function that
    converts a sample from the dataset-specific format to a Document.

    The requested split is loaded with ``tensorflow_datasets.load`` when the
    object is constructed (in the root validator); ``lazy_load`` and ``load``
    only convert samples of that already-loaded split.

    Attributes:
        dataset_name: the name of the dataset to load.
        split_name: the name of the split to load. Defaults to "train".
        load_max_docs: a limit to the number of loaded documents; passed to
          ``tf.data.Dataset.take``. Defaults to 100.
        sample_to_document_function: a function that converts a dataset sample
          to a Document. Required; construction fails if it is not provided.

    Example:
        .. code-block:: python

            from langchain.utilities import TensorflowDatasets

            # ``decode_to_str`` and ``MAX_DOCS`` are placeholders that the
            # caller must define.
            def mlqaen_example_to_document(example: dict) -> Document:
                return Document(
                    page_content=decode_to_str(example["context"]),
                    metadata={
                        "id": decode_to_str(example["id"]),
                        "title": decode_to_str(example["title"]),
                        "question": decode_to_str(example["question"]),
                        "answer": decode_to_str(example["answers"]["text"][0]),
                    },
                )

            tsds_client = TensorflowDatasets(
                    dataset_name="mlqa/en",
                    split_name="train",
                    load_max_docs=MAX_DOCS,
                    sample_to_document_function=mlqaen_example_to_document,
                )

    """

    dataset_name: str = ""
    split_name: str = "train"
    load_max_docs: int = 100
    sample_to_document_function: Optional[Callable[[Dict], Document]] = None
    dataset: Any  #: :meta private:

    # skip_on_failure: if any field failed validation, pydantic omits it from
    # ``values``; skipping this validator then surfaces the real field error
    # instead of a KeyError raised from the lookups below.
    @root_validator(skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Check dependencies and the converter, then load the dataset split.

        Stores the result of ``tensorflow_datasets.load`` in ``values["dataset"]``.

        Raises:
            ImportError: if ``tensorflow`` or ``tensorflow_datasets`` is missing.
            ValueError: if ``sample_to_document_function`` is not provided.
        """
        try:
            import tensorflow  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "Could not import tensorflow python package. "
                "Please install it with `pip install tensorflow`."
            ) from e
        try:
            import tensorflow_datasets
        except ImportError as e:
            raise ImportError(
                "Could not import tensorflow_datasets python package. "
                "Please install it with `pip install tensorflow-datasets`."
            ) from e
        if values.get("sample_to_document_function") is None:
            raise ValueError(
                "sample_to_document_function is None. "
                "Please provide a function that converts a dataset sample to"
                " a Document."
            )
        values["dataset"] = tensorflow_datasets.load(
            values["dataset_name"], split=values["split_name"]
        )

        return values

    def lazy_load(self) -> Iterator[Document]:
        """Lazily convert samples of the loaded split into Documents.

        At most ``load_max_docs`` samples are taken from ``self.dataset`` (via
        ``Dataset.take``) and each is passed to ``sample_to_document_function``.

        Returns: an iterator of Documents.

        Raises:
            ValueError: if ``sample_to_document_function`` is None (for example,
              because it was reassigned after construction).
        """
        convert = self.sample_to_document_function
        if convert is None:
            # Fail loudly rather than silently yielding no documents.
            raise ValueError(
                "sample_to_document_function is None. "
                "Please provide a function that converts a dataset sample to"
                " a Document."
            )
        # A generator expression (not a generator function) keeps the
        # ``take`` call eager, so it runs when lazy_load() is called.
        return (convert(sample) for sample in self.dataset.take(self.load_max_docs))

    def load(self) -> List[Document]:
        """Convert samples of the loaded split into a list of Documents.

        Returns: a list of Documents (at most ``load_max_docs`` of them).

        """
        return list(self.lazy_load())