File size: 4,574 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""Retriever wrapper for Google Cloud Document AI Warehouse."""
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain_core.pydantic_v1 import root_validator
from langchain_core.retrievers import BaseRetriever

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document
from langchain.utilities.vertexai import get_client_info
from langchain.utils import get_from_dict_or_env

if TYPE_CHECKING:
    from google.cloud.contentwarehouse_v1 import (
        DocumentServiceClient,
        RequestMetadata,
        SearchDocumentsRequest,
    )
    from google.cloud.contentwarehouse_v1.services.document_service.pagers import (
        SearchDocumentsPager,
    )


class GoogleDocumentAIWarehouseRetriever(BaseRetriever):
    """A retriever based on Document AI Warehouse.

    Documents must be created and uploaded to Document AI Warehouse in a
    separate flow; this retriever only searches them. When ``schema_id`` is
    set, the search is restricted to that Document AI Warehouse schema;
    otherwise all documents in the project are searched.

    Every query must be issued with a ``user_ldap`` keyword argument; it is
    sent to the service as the user id ``user:<user_ldap>`` in the request
    metadata.

    More info: https://cloud.google.com/document-ai-warehouse.
    """

    location: str = "us"
    """Google Cloud location where Document AI Warehouse is placed."""
    project_number: str = ""
    """Google Cloud project number, should contain digits only.
    If it is not passed (or is empty), it is read from the ``PROJECT_NUMBER``
    environment variable."""
    schema_id: Optional[str] = None
    """Document AI Warehouse schema to query against.
    If nothing is provided, all documents in the project will be searched."""
    qa_size_limit: int = 5
    """The limit on the number of documents returned."""
    client: "DocumentServiceClient" = None  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the project number and create the Document AI Warehouse client.

        ``project_number`` is taken from the constructor arguments, falling
        back to the ``PROJECT_NUMBER`` environment variable; errors from
        ``get_from_dict_or_env`` propagate if neither provides a value.

        Raises:
            ImportError: If ``google-cloud-contentwarehouse`` is not installed.
        """
        try:
            from google.cloud.contentwarehouse_v1 import DocumentServiceClient
        except ImportError as exc:
            raise ImportError(
                "google.cloud.contentwarehouse is not installed. "
                "Please install it with pip install google-cloud-contentwarehouse"
            ) from exc

        # The field defaults to "" so that omitting it does not fail field
        # validation; an empty value makes this lookup fall back to the env.
        values["project_number"] = get_from_dict_or_env(
            values, "project_number", "PROJECT_NUMBER"
        )
        values["client"] = DocumentServiceClient(
            client_info=get_client_info(module="document-ai-warehouse")
        )
        return values

    def _prepare_request_metadata(self, user_ldap: str) -> "RequestMetadata":
        """Build request metadata identifying the caller as ``user:<user_ldap>``."""
        from google.cloud.contentwarehouse_v1 import RequestMetadata, UserInfo

        user_info = UserInfo(id=f"user:{user_ldap}")
        return RequestMetadata(user_info=user_info)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Search Document AI Warehouse for ``query`` and return matching documents.

        Args:
            query: Natural-language query text.
            run_manager: Callback manager for this retriever run (unused here).
            **kwargs: Must contain ``user_ldap``; see ``_prepare_search_request``.

        Returns:
            One ``Document`` per matching warehouse document.
        """
        request = self._prepare_search_request(query, **kwargs)
        response = self.client.search_documents(request=request)
        return self._parse_search_response(response=response)

    def _prepare_search_request(
        self, query: str, **kwargs: Any
    ) -> "SearchDocumentsRequest":
        """Build a natural-language ``SearchDocumentsRequest`` for ``query``.

        Args:
            query: Natural-language query text.
            **kwargs: Must contain ``user_ldap``, the identity the search is
                performed on behalf of.

        Raises:
            ValueError: If ``user_ldap`` is not provided.
        """
        from google.cloud.contentwarehouse_v1 import (
            DocumentQuery,
            SearchDocumentsRequest,
        )

        try:
            user_ldap = kwargs["user_ldap"]
        except KeyError:
            # The KeyError adds nothing beyond this message; drop its context.
            raise ValueError("Argument user_ldap should be provided!") from None

        request_metadata = self._prepare_request_metadata(user_ldap=user_ldap)
        # An empty schema list means the search is not restricted to a schema.
        schemas = []
        if self.schema_id:
            schemas.append(
                self.client.document_schema_path(
                    project=self.project_number,
                    location=self.location,
                    document_schema=self.schema_id,
                )
            )
        return SearchDocumentsRequest(
            parent=self.client.common_location_path(self.project_number, self.location),
            request_metadata=request_metadata,
            document_query=DocumentQuery(
                query=query, is_nl_query=True, document_schema_names=schemas
            ),
            qa_size_limit=self.qa_size_limit,
        )

    def _parse_search_response(
        self, response: "SearchDocumentsPager"
    ) -> List[Document]:
        """Convert search results into LangChain ``Document`` objects.

        Each result's ``search_text_snippet`` becomes the page content; the
        document title and raw document path are kept as ``title`` and
        ``source`` metadata.
        """
        # NOTE(review): attribute access on the pager presumably reads only the
        # first page of results; iterating the pager itself would fetch further
        # pages. Confirm which behavior is intended.
        return [
            Document(
                page_content=doc.search_text_snippet,
                metadata={
                    "title": doc.document.title,
                    "source": doc.document.raw_document_path,
                },
            )
            for doc in response.matching_documents
        ]