"""Chain for interacting with Elasticsearch Database."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseLLMOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Extra, root_validator

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT
from langchain.chains.llm import LLMChain
from langchain.output_parsers.json import SimpleJsonOutputParser

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


class ElasticsearchDatabaseChain(Chain):
    """Chain for interacting with Elasticsearch Database.

    Example:
        .. code-block:: python

            from langchain.chains import ElasticsearchDatabaseChain
            from langchain.llms import OpenAI
            from elasticsearch import Elasticsearch

            database = Elasticsearch("http://localhost:9200")
            db_chain = ElasticsearchDatabaseChain.from_llm(OpenAI(), database)
    """

    query_chain: LLMChain
    """Chain for creating the ES query."""
    answer_chain: LLMChain
    """Chain for answering the user question."""
    database: Any
    """Elasticsearch database to connect to of type elasticsearch.Elasticsearch."""
    top_k: int = 10
    """Number of results to return from the query"""
    ignore_indices: Optional[List[str]] = None
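    """If set, these indices are skipped when listing the available indices. Mutually exclusive with include_indices."""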
    include_indices: Optional[List[str]] = None
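    """If set, only these indices are considered when listing the available indices. Mutually exclusive with ignore_indices."""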
    input_key: str = "question"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    sample_documents_in_index_info: int = 3
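    """Number of sample documents from each index to append to the index info shown to the query LLM. Set to 0 to include mappings only."""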
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator()
    def validate_indices(cls, values: dict) -> dict:
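        """Ensure that include_indices and ignore_indices are not both provided."""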
        if values["include_indices"] and values["ignore_indices"]:
            raise ValueError(
                "Cannot specify both 'include_indices' and 'ignore_indices'."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _list_indices(self) -> List[str]:
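        """Return the index names in the cluster, filtered by include_indices/ignore_indices."""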
        all_indices = [
            index["index"] for index in self.database.cat.indices(format="json")
        ]

        if self.include_indices:
            all_indices = [i for i in all_indices if i in self.include_indices]
        if self.ignore_indices:
            all_indices = [i for i in all_indices if i not in self.ignore_indices]

        return all_indices

    def _get_indices_infos(self, indices: List[str]) -> str:
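        """Summarize index mappings, optionally with sample documents, for use in the query prompt."""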
        mappings = self.database.indices.get_mapping(index=",".join(indices))
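        # Optionally append a few sample documents to each mapping so the LLM sees example field values.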
        if self.sample_documents_in_index_info > 0:
            for k, v in mappings.items():
                hits = self.database.search(
                    index=k,
                    query={"match_all": {}},
                    size=self.sample_documents_in_index_info,
                )["hits"]["hits"]
                hits = [str(hit["_source"]) for hit in hits]
                mappings[k]["mappings"] = str(v) + "\n\n/*\n" + "\n".join(hits) + "\n*/"
        return "\n\n".join(
            [
                "Mapping for index {}:\n{}".format(index, mappings[index]["mappings"])
                for index in mappings
            ]
        )

    def _search(self, indices: List[str], query: str) -> str:
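        """Execute the generated query against the given indices and return the raw response as a string."""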
        result = self.database.search(index=",".join(indices), body=query)
        return str(result)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
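        """Generate an ES query from the question, run it, and answer based on the results."""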
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nESQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        indices = self._list_indices()
        indices_info = self._get_indices_infos(indices)
        query_inputs: dict = {
            "input": input_text,
            "top_k": str(self.top_k),
            "indices_info": indices_info,
            "stop": ["\nESResult:"],
        }
        intermediate_steps: List = []
        try:
            intermediate_steps.append(query_inputs)  # input: es generation
            es_cmd = self.query_chain.run(
                callbacks=_run_manager.get_child(),
                **query_inputs,
            )

            _run_manager.on_text(es_cmd, color="green", verbose=self.verbose)
            intermediate_steps.append(
                es_cmd
            )  # output: elasticsearch dsl generation (no checker)
            intermediate_steps.append({"es_cmd": es_cmd})  # input: ES search
            result = self._search(indices=indices, query=es_cmd)
            intermediate_steps.append(str(result))  # output: ES search

            _run_manager.on_text("\nESResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)

            _run_manager.on_text("\nAnswer:", verbose=self.verbose)
            answer_inputs: dict = {"data": result, "input": input_text}
            intermediate_steps.append(answer_inputs)  # input: final answer
            final_result = self.answer_chain.run(
                callbacks=_run_manager.get_child(),
                **answer_inputs,
            )

            intermediate_steps.append(final_result)  # output: final answer
            _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Attach intermediate steps to the exception to aid in logging and
            # later improvement of few-shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise exc

    @property
    def _chain_type(self) -> str:
        return "elasticsearch_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        database: Elasticsearch,
        *,
        query_prompt: Optional[BasePromptTemplate] = None,
        answer_prompt: Optional[BasePromptTemplate] = None,
        query_output_parser: Optional[BaseLLMOutputParser] = None,
        **kwargs: Any,
    ) -> ElasticsearchDatabaseChain:
        """Convenience method to construct ElasticsearchDatabaseChain from an LLM.

        Args:
            llm: The language model to use.
            database: The Elasticsearch database client to query.
            query_prompt: The prompt to use for query construction.
            answer_prompt: The prompt to use for answering the user question given the retrieved data.
            query_output_parser: The output parser to use for parsing model-generated
                ES query. Defaults to SimpleJsonOutputParser.
            **kwargs: Additional arguments to pass to the constructor.
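
        Example:
            .. code-block:: python

                from elasticsearch import Elasticsearch
                from langchain.llms import OpenAI

                # A minimal usage sketch; the host URL and LLM settings here
                # are assumptions, adjust them to your deployment.
                database = Elasticsearch("http://localhost:9200")
                chain = ElasticsearchDatabaseChain.from_llm(
                    OpenAI(temperature=0), database, return_intermediate_steps=True
                )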
        """
        query_prompt = query_prompt or DSL_PROMPT
        query_output_parser = query_output_parser or SimpleJsonOutputParser()
        query_chain = LLMChain(
            llm=llm, prompt=query_prompt, output_parser=query_output_parser
        )
        answer_prompt = answer_prompt or ANSWER_PROMPT
        answer_chain = LLMChain(llm=llm, prompt=answer_prompt)
        return cls(
            query_chain=query_chain,
            answer_chain=answer_chain,
            database=database,
            **kwargs,
        )