Spaces:
Runtime error
Runtime error
| """Chain for interacting with Elasticsearch Database.""" | |
| from __future__ import annotations | |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional | |
| from langchain_core.language_models import BaseLanguageModel | |
| from langchain_core.output_parsers import BaseLLMOutputParser | |
| from langchain_core.prompts import BasePromptTemplate | |
| from langchain_core.pydantic_v1 import Extra, root_validator | |
| from langchain.callbacks.manager import CallbackManagerForChainRun | |
| from langchain.chains.base import Chain | |
| from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT | |
| from langchain.chains.llm import LLMChain | |
| from langchain.output_parsers.json import SimpleJsonOutputParser | |
| if TYPE_CHECKING: | |
| from elasticsearch import Elasticsearch | |
| INTERMEDIATE_STEPS_KEY = "intermediate_steps" | |
class ElasticsearchDatabaseChain(Chain):
    """Chain for interacting with Elasticsearch Database.

    The chain asks an LLM to write an Elasticsearch DSL query for the user's
    question, runs that query against the configured indices, and then asks
    the LLM to phrase an answer from the raw search result.

    Example:
        .. code-block:: python

            from langchain.chains import ElasticsearchDatabaseChain
            from langchain.llms import OpenAI
            from elasticsearch import Elasticsearch

            database = Elasticsearch("http://localhost:9200")
            db_chain = ElasticsearchDatabaseChain.from_llm(OpenAI(), database)
    """

    query_chain: LLMChain
    """Chain for creating the ES query."""
    answer_chain: LLMChain
    """Chain for answering the user question."""
    database: Any
    """Elasticsearch database to connect to of type elasticsearch.Elasticsearch."""
    top_k: int = 10
    """Number of results to return from the query"""
    ignore_indices: Optional[List[str]] = None
    """Index names to leave out when listing indices."""
    include_indices: Optional[List[str]] = None
    """If set, only these index names are kept when listing indices."""
    input_key: str = "question"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    sample_documents_in_index_info: int = 3
    """Number of sample documents fetched per index for the index description."""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator()
    def validate_indices(cls, values: dict) -> dict:
        """Reject configurations that set both index filters.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            The unchanged ``values`` dict.

        Raises:
            ValueError: If both ``include_indices`` and ``ignore_indices``
                are non-empty.
        """
        # ``.get`` instead of indexing: a key may be absent here when an
        # earlier field-level validation already failed.
        if values.get("include_indices") and values.get("ignore_indices"):
            raise ValueError(
                "Cannot specify both 'include_indices' and 'ignore_indices'."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        The intermediate-steps key is appended when
        ``return_intermediate_steps`` is enabled.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _list_indices(self) -> List[str]:
        """List index names from the cluster, applying the include/ignore filters."""
        all_indices = [
            index["index"] for index in self.database.cat.indices(format="json")
        ]

        if self.include_indices:
            all_indices = [i for i in all_indices if i in self.include_indices]
        if self.ignore_indices:
            all_indices = [i for i in all_indices if i not in self.ignore_indices]

        return all_indices

    def _get_indices_infos(self, indices: List[str]) -> str:
        """Describe the mapping (and optional sample documents) of each index.

        Args:
            indices: Names of the indices to describe.

        Returns:
            One ``"Mapping for index <name>:"`` section per index returned by
            the mapping request, separated by blank lines. When
            ``sample_documents_in_index_info`` is positive, each section also
            carries that many sample documents inside a ``/* ... */`` block.
        """
        mappings = self.database.indices.get_mapping(index=",".join(indices))
        with_samples = self.sample_documents_in_index_info > 0

        sections = []
        for index, mapping in mappings.items():
            if with_samples:
                hits = self.database.search(
                    index=index,
                    query={"match_all": {}},
                    size=self.sample_documents_in_index_info,
                )["hits"]["hits"]
                samples = "\n".join(str(hit["_source"]) for hit in hits)
                # The whole per-index entry is stringified here (not only its
                # "mappings" value); kept so the prompt text stays unchanged.
                description = str(mapping) + "\n\n/*\n" + samples + "\n*/"
            else:
                description = mapping["mappings"]
            sections.append(f"Mapping for index {index}:\n{description}")

        return "\n\n".join(sections)

    def _search(self, indices: List[str], query: str) -> str:
        """Run ``query`` against ``indices`` and return the response as a string."""
        result = self.database.search(index=",".join(indices), body=query)
        return str(result)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate an ES query, execute it, and answer the question.

        Args:
            inputs: Chain inputs; the question is read from ``self.input_key``.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict with the final answer under ``self.output_key`` and, when
            ``return_intermediate_steps`` is set, the recorded steps under
            ``INTERMEDIATE_STEPS_KEY``.

        Raises:
            Exception: Any error raised while querying or answering is
                re-raised with an ``intermediate_steps`` attribute attached.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nESQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        indices = self._list_indices()
        indices_info = self._get_indices_infos(indices)
        query_inputs: dict = {
            "input": input_text,
            "top_k": str(self.top_k),
            "indices_info": indices_info,
            "stop": ["\nESResult:"],
        }
        intermediate_steps: List = []
        try:
            intermediate_steps.append(query_inputs)  # input: es generation
            es_cmd = self.query_chain.run(
                callbacks=_run_manager.get_child(),
                **query_inputs,
            )

            _run_manager.on_text(es_cmd, color="green", verbose=self.verbose)
            intermediate_steps.append(
                es_cmd
            )  # output: elasticsearch dsl generation (no checker)
            intermediate_steps.append({"es_cmd": es_cmd})  # input: ES search
            result = self._search(indices=indices, query=es_cmd)
            intermediate_steps.append(str(result))  # output: ES search

            _run_manager.on_text("\nESResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)

            _run_manager.on_text("\nAnswer:", verbose=self.verbose)
            answer_inputs: dict = {"data": result, "input": input_text}
            intermediate_steps.append(answer_inputs)  # input: final answer
            final_result = self.answer_chain.run(
                callbacks=_run_manager.get_child(),
                **answer_inputs,
            )

            intermediate_steps.append(final_result)  # output: final answer
            _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            # Bare raise re-raises the same exception without adding a frame.
            raise

    @property
    def _chain_type(self) -> str:
        """Return the identifier string of this chain type."""
        return "elasticsearch_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        database: Elasticsearch,
        *,
        query_prompt: Optional[BasePromptTemplate] = None,
        answer_prompt: Optional[BasePromptTemplate] = None,
        query_output_parser: Optional[BaseLLMOutputParser] = None,
        **kwargs: Any,
    ) -> ElasticsearchDatabaseChain:
        """Convenience method to construct ElasticsearchDatabaseChain from an LLM.

        Args:
            llm: The language model to use.
            database: The Elasticsearch db.
            query_prompt: The prompt to use for query construction.
            answer_prompt: The prompt to use for answering user question given data.
            query_output_parser: The output parser to use for parsing model-generated
                ES query. Defaults to SimpleJsonOutputParser.
            **kwargs: Additional arguments to pass to the constructor.

        Returns:
            A chain whose query and answer sub-chains both use ``llm``.
        """
        query_prompt = query_prompt or DSL_PROMPT
        query_output_parser = query_output_parser or SimpleJsonOutputParser()
        query_chain = LLMChain(
            llm=llm, prompt=query_prompt, output_parser=query_output_parser
        )
        answer_prompt = answer_prompt or ANSWER_PROMPT
        answer_chain = LLMChain(llm=llm, prompt=answer_prompt)
        return cls(
            query_chain=query_chain,
            answer_chain=answer_chain,
            database=database,
            **kwargs,
        )