File size: 6,202 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from __future__ import annotations

from typing import Any, Dict, List, Optional, cast
from uuid import uuid4

from langchain_core.pydantic_v1 import root_validator
from langchain_core.retrievers import BaseRetriever

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document


class WeaviateHybridSearchRetriever(BaseRetriever):
    """`Weaviate hybrid search` retriever.

    Runs Weaviate's ``hybrid`` query (keyword + vector search) against a single
    class and returns the matches as LangChain ``Document`` objects.

    See the documentation:
      https://weaviate.io/blog/hybrid-search-explained
    """

    client: Any
    """A ``weaviate.Client`` instance (type-checked in ``validate_client``)."""
    index_name: str
    """The name of the Weaviate class (index) to query and write to."""
    text_key: str
    """The object property that holds the document text."""
    alpha: float = 0.5
    """Default ``alpha`` passed to ``with_hybrid``: 1.0 is pure vector search,
    0.0 is pure keyword (BM25) search."""
    k: int = 4
    """The number of results to return."""
    attributes: List[str]
    """The properties to return in the results; ``text_key`` is always included."""
    create_schema_if_missing: bool = True
    """Whether to create the class (with a ``text2vec-openai`` vectorizer) if it
    doesn't exist."""

    @root_validator(pre=True)
    def validate_client(
        cls,
        values: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Check the client, normalize ``attributes`` and ensure the schema exists.

        Runs before field validation. It does the following:

        * Verifies that ``client`` is a ``weaviate.Client``.
        * Stores a copy of ``attributes`` that is guaranteed to contain
          ``text_key``.
        * If ``create_schema_if_missing`` is true (the default), creates the
          ``index_name`` class when the client reports that it does not exist.

        Raises:
            ImportError: If the ``weaviate`` package is not installed.
            ValueError: If ``client`` is missing or is not a ``weaviate.Client``.
        """
        try:
            import weaviate
        except ImportError as exc:
            raise ImportError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            ) from exc

        # .get(): a missing client should surface as the ValueError below
        # (which pydantic reports as a validation error), not a raw KeyError.
        client = values.get("client")
        if not isinstance(client, weaviate.Client):
            raise ValueError(
                f"client should be an instance of weaviate.Client, got {type(client)}"
            )

        # Build a new list instead of appending in place: mutating the caller's
        # list would leak text_key into it (and duplicate it on reuse).
        text_key = values["text_key"]
        attributes = list(values.get("attributes") or [])
        if text_key not in attributes:
            attributes.append(text_key)
        values["attributes"] = attributes

        if values.get("create_schema_if_missing", True):
            class_obj = {
                "class": values["index_name"],
                "properties": [{"name": text_key, "dataType": ["text"]}],
                "vectorizer": "text2vec-openai",
            }

            if not client.schema.exists(values["index_name"]):
                client.schema.create_class(class_obj)

        return values

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
        """Upload documents to Weaviate in a single client batch.

        Each document becomes an object of class ``self.index_name``. Its
        properties are ``{self.text_key: doc.page_content}`` merged with the
        document's metadata. A metadata key equal to ``text_key`` takes
        precedence over the page content.

        Args:
            docs: Documents to upload.
            **kwargs: Optional ``uuids``, a sequence with one id per document.
                An object whose id already exists is replaced by the new one.
                When omitted or ``None``, a random UUID is generated for each
                document.

        Returns:
            The ids of the uploaded objects, in the same order as ``docs``.

        Raises:
            ValueError: If ``uuids`` has fewer entries than ``docs``.
        """
        from weaviate.util import get_valid_uuid

        uuids = kwargs.get("uuids")
        # Validate before opening the batch so a short id list cannot leave a
        # partially queued upload behind.
        if uuids is not None and len(uuids) < len(docs):
            raise ValueError(
                f"Got {len(uuids)} uuids for {len(docs)} documents; "
                "provide one uuid per document."
            )

        ids = []
        with self.client.batch as batch:
            for i, doc in enumerate(docs):
                metadata = doc.metadata or {}
                data_properties = {self.text_key: doc.page_content, **metadata}

                # If an object with this UUID already exists, it is replaced
                # by the new object.
                _id = uuids[i] if uuids is not None else get_valid_uuid(uuid4())

                batch.add_data_object(data_properties, self.index_name, _id)
                ids.append(_id)
        return ids

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        where_filter: Optional[Dict[str, object]] = None,
        score: bool = False,
        hybrid_search_kwargs: Optional[Dict[str, object]] = None,
    ) -> List[Document]:
        """Look up similar documents in Weaviate.

        query: The query to search for relevant documents
         using Weaviate hybrid search.

        where_filter: A filter to apply to the query.
            https://weaviate.io/developers/weaviate/guides/querying/#filtering

        score: Whether to include the score, and score explanation
            in the returned Documents' metadata.

        hybrid_search_kwargs: Used to pass additional arguments
         to the .with_hybrid() method. An "alpha" entry overrides self.alpha
         for this query.
            The primary use cases for this are:
            1)  Search specific properties only -
                specify which properties to be used during hybrid search portion.
                Note: this is not the same as the (self.attributes) to be returned.
                Example - hybrid_search_kwargs={"properties": ["question", "answer"]}
            https://weaviate.io/developers/weaviate/search/hybrid#selected-properties-only

            2) Weight boosted searched properties -
                Boost the weight of certain properties during the hybrid search portion.
                Example - hybrid_search_kwargs={"properties": ["question^2", "answer"]}
            https://weaviate.io/developers/weaviate/search/hybrid#weight-boost-searched-properties

            3) Search with a custom vector - Define a different vector
                to be used during the hybrid search portion.
                Example - hybrid_search_kwargs={"vector": [0.1, 0.2, 0.3, ...]}
            https://weaviate.io/developers/weaviate/search/hybrid#with-a-custom-vector

            4) Use Fusion ranking method
                Example - from weaviate.gql.get import HybridFusion
                hybrid_search_kwargs={"fusion_type": HybridFusion.RELATIVE_SCORE}
            https://weaviate.io/developers/weaviate/search/hybrid#fusion-ranking-method

        Returns:
            Up to ``self.k`` Documents. The ``text_key`` property becomes the
            page content, and the remaining returned fields become metadata.

        Raises:
            ValueError: If the Weaviate response contains "errors".
        """
        query_obj = self.client.query.get(self.index_name, self.attributes)
        if where_filter:
            query_obj = query_obj.with_where(where_filter)

        if score:
            query_obj = query_obj.with_additional(["score", "explainScore"])

        # self.alpha is only the default; merging (rather than passing alpha
        # separately) lets callers override it without a duplicate-keyword error.
        hybrid_kwargs: Dict[str, Any] = {
            "alpha": self.alpha,
            **(hybrid_search_kwargs or {}),
        }

        result = (
            query_obj.with_hybrid(query, **hybrid_kwargs).with_limit(self.k).do()
        )
        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        docs = []
        # Defensive `or []`: treat a null result list as "no matches".
        for res in result["data"]["Get"][self.index_name] or []:
            text = res.pop(self.text_key)
            docs.append(Document(page_content=text, metadata=res))
        return docs