"""Retriever that generates and executes structured queries over its own data source.

NOTE: This code is adapted from the original implementation in the LangChain
repo, but has been modified to work with the KTH QA system.
"""
import logging
import re
from typing import Any, Dict, List, Optional, Type, cast

from langchain.vectorstores import Pinecone, VectorStore
from langchain.schema import BaseRetriever, Document
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.base_language import BaseLanguageModel
from langchain import LLMChain
from pydantic import BaseModel, Field, root_validator

# NOTE(review): this is the root logger; kept as-is so existing log routing
# does not change. Consider ``logging.getLogger(__name__)`` if per-module
# filtering is wanted.
logger = logging.getLogger()

COURSE_PATTERN = r"[a-zA-Z]{2,3}\d{3,4}\w?"  # e.g. DD1315


def make_uppercase(matchobj: "re.Match[str]") -> str:
    """Return the full text of a regex match converted to upper case.

    Intended as the replacement callback for ``re.sub``/``re.subn``.
    """
    return matchobj.group(0).upper()


def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor:
    """Get the translator instance corresponding to the vector store class.

    Args:
        vectorstore_cls: The class of the vector store in use.

    Returns:
        A new instance of the translator registered for ``vectorstore_cls``.

    Raises:
        ValueError: If no translator is registered for ``vectorstore_cls``.
    """
    BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
        Pinecone: PineconeTranslator
    }
    if vectorstore_cls not in BUILTIN_TRANSLATORS:
        raise ValueError(
            f"Self query retriever with Vector Store type {vectorstore_cls}"
            f" not supported."
        )
    return BUILTIN_TRANSLATORS[vectorstore_cls]()


class SelfQueryRetriever(BaseRetriever, BaseModel):
    """Retriever that wraps around a vector store and uses an LLM to generate
    the vector store queries.

    The LLM is only consulted when the query contains something matching
    ``COURSE_PATTERN``; otherwise a plain vector store search is performed.
    """

    vectorstore: VectorStore
    """The underlying vector store from which documents will be retrieved."""
    llm_chain: LLMChain
    """The LLMChain for generating the vector store queries."""
    search_type: str = "similarity"
    """The search type to perform on the vector store."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass in to the vector store search."""
    structured_query_translator: Visitor
    """Translator for turning internal query language into vectorstore search params."""
    verbose: bool = False
    """Whether to log the structured query produced by the LLM chain."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_translator(cls, values: Dict) -> Dict:
        """Fill in a default translator when none was supplied.

        The default is looked up from the class of ``values["vectorstore"]``.
        """
        if "structured_query_translator" not in values:
            vectorstore_cls = values["vectorstore"].__class__
            values["structured_query_translator"] = _get_builtin_translator(
                vectorstore_cls
            )
        return values

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Every substring matching ``COURSE_PATTERN`` is upper-cased first. If at
        least one such substring was found, the LLM chain builds a structured
        query whose translated search parameters are merged over
        ``self.search_kwargs``; otherwise ``self.search_kwargs`` is used as-is.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        # One pass both normalises the course codes and tells us whether any
        # were present (the count is non-zero exactly when the pattern matched).
        query, course_hits = re.subn(COURSE_PATTERN, make_uppercase, query)
        if not course_hits:
            return self.vectorstore.search(
                query, self.search_type, **self.search_kwargs
            )

        inputs = self.llm_chain.prep_inputs(query)
        structured_query = cast(
            StructuredQuery,
            self.llm_chain.predict_and_parse(callbacks=None, **inputs),
        )
        if self.verbose:
            logger.info("Found course pattern in query, using structured query:")
            logger.info(structured_query)

        # NOTE(review): the query text returned by the translator is discarded
        # and the (upper-cased) user query is searched instead; only the
        # translated search parameters are used. Confirm this is intended.
        _, new_kwargs = self.structured_query_translator.visit_structured_query(
            structured_query
        )
        # Translated parameters take precedence over the configured defaults.
        search_kwargs = {**self.search_kwargs, **new_kwargs}
        return self.vectorstore.search(query, self.search_type, **search_kwargs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not supported.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstore: VectorStore,
        document_contents: str,
        metadata_field_info: List[AttributeInfo],
        structured_query_translator: Optional[Visitor] = None,
        chain_kwargs: Optional[Dict] = None,
        **kwargs: Any,
    ) -> "SelfQueryRetriever":
        """Build a retriever, constructing the query-constructor chain from an LLM.

        Args:
            llm: Language model passed to ``load_query_constructor_chain``.
            vectorstore: Vector store to search.
            document_contents: Passed through to ``load_query_constructor_chain``.
            metadata_field_info: Passed through to
                ``load_query_constructor_chain``.
            structured_query_translator: Translator to use; when ``None`` the
                built-in translator for the vector store's class is used.
            chain_kwargs: Extra keyword arguments for
                ``load_query_constructor_chain``. ``allowed_comparators`` and
                ``allowed_operators`` default to the translator's values when
                not given. The caller's dict is not modified.
            **kwargs: Additional fields forwarded to the class constructor.

        Returns:
            A configured ``SelfQueryRetriever``.

        Raises:
            ValueError: If no translator is given and the vector store class
                has no built-in translator.
        """
        if structured_query_translator is None:
            structured_query_translator = _get_builtin_translator(
                vectorstore.__class__
            )
        # Work on a copy so the defaults added below never leak into the
        # dict object owned by the caller.
        chain_kwargs = dict(chain_kwargs or {})
        if "allowed_comparators" not in chain_kwargs:
            chain_kwargs[
                "allowed_comparators"
            ] = structured_query_translator.allowed_comparators
        if "allowed_operators" not in chain_kwargs:
            chain_kwargs[
                "allowed_operators"
            ] = structured_query_translator.allowed_operators
        llm_chain = load_query_constructor_chain(
            llm, document_contents, metadata_field_info, **chain_kwargs
        )
        return cls(
            llm_chain=llm_chain,
            vectorstore=vectorstore,
            structured_query_translator=structured_query_translator,
            **kwargs,
        )