Spaces:
Runtime error
Runtime error
File size: 9,264 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
"""Retriever that generates and executes structured queries over its own data source."""
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from langchain_core.vectorstores import VectorStore
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.query_constructor.base import load_query_constructor_runnable
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.dashvector import DashvectorTranslator
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
from langchain.retrievers.self_query.milvus import MilvusTranslator
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.retrievers.self_query.opensearch import OpenSearchTranslator
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from langchain.retrievers.self_query.redis import RedisTranslator
from langchain.retrievers.self_query.supabase import SupabaseVectorTranslator
from langchain.retrievers.self_query.timescalevector import TimescaleVectorTranslator
from langchain.retrievers.self_query.vectara import VectaraTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
from langchain.vectorstores import (
Chroma,
DashVector,
DeepLake,
ElasticsearchStore,
Milvus,
MyScale,
OpenSearchVectorSearch,
Pinecone,
Qdrant,
Redis,
SupabaseVectorStore,
TimescaleVector,
Vectara,
Weaviate,
)
logger = logging.getLogger(__name__)
def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
"""Get the translator class corresponding to the vector store class."""
BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
Pinecone: PineconeTranslator,
Chroma: ChromaTranslator,
DashVector: DashvectorTranslator,
Weaviate: WeaviateTranslator,
Vectara: VectaraTranslator,
Qdrant: QdrantTranslator,
MyScale: MyScaleTranslator,
DeepLake: DeepLakeTranslator,
ElasticsearchStore: ElasticsearchTranslator,
Milvus: MilvusTranslator,
SupabaseVectorStore: SupabaseVectorTranslator,
TimescaleVector: TimescaleVectorTranslator,
OpenSearchVectorSearch: OpenSearchTranslator,
}
if isinstance(vectorstore, Qdrant):
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
elif isinstance(vectorstore, MyScale):
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
elif isinstance(vectorstore, Redis):
return RedisTranslator.from_vectorstore(vectorstore)
elif vectorstore.__class__ in BUILTIN_TRANSLATORS:
return BUILTIN_TRANSLATORS[vectorstore.__class__]()
else:
raise ValueError(
f"Self query retriever with Vector Store type {vectorstore.__class__}"
f" not supported."
)
class SelfQueryRetriever(BaseRetriever, BaseModel):
"""Retriever that uses a vector store and an LLM to generate
the vector store queries."""
vectorstore: VectorStore
"""The underlying vector store from which documents will be retrieved."""
query_constructor: Runnable[dict, StructuredQuery] = Field(alias="llm_chain")
"""The query constructor chain for generating the vector store queries.
llm_chain is legacy name kept for backwards compatibility."""
search_type: str = "similarity"
"""The search type to perform on the vector store."""
search_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass in to the vector store search."""
structured_query_translator: Visitor
"""Translator for turning internal query language into vectorstore search params."""
verbose: bool = False
use_original_query: bool = False
"""Use original query instead of the revised new query from LLM"""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
allow_population_by_field_name = True
@root_validator(pre=True)
def validate_translator(cls, values: Dict) -> Dict:
"""Validate translator."""
if "structured_query_translator" not in values:
values["structured_query_translator"] = _get_builtin_translator(
values["vectorstore"]
)
return values
@property
def llm_chain(self) -> Runnable:
"""llm_chain is legacy name kept for backwards compatibility."""
return self.query_constructor
def _prepare_query(
self, query: str, structured_query: StructuredQuery
) -> Tuple[str, Dict[str, Any]]:
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query
)
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
if self.use_original_query:
new_query = query
search_kwargs = {**self.search_kwargs, **new_kwargs}
return new_query, search_kwargs
def _get_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
) -> List[Document]:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs
async def _aget_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
) -> List[Document]:
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
return docs
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
structured_query = self.query_constructor.invoke(
{"query": query}, config={"callbacks": run_manager.get_child()}
)
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs)
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
structured_query = await self.query_constructor.ainvoke(
{"query": query}, config={"callbacks": run_manager.get_child()}
)
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs)
return docs
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
vectorstore: VectorStore,
document_contents: str,
metadata_field_info: Sequence[Union[AttributeInfo, dict]],
structured_query_translator: Optional[Visitor] = None,
chain_kwargs: Optional[Dict] = None,
enable_limit: bool = False,
use_original_query: bool = False,
**kwargs: Any,
) -> "SelfQueryRetriever":
if structured_query_translator is None:
structured_query_translator = _get_builtin_translator(vectorstore)
chain_kwargs = chain_kwargs or {}
if (
"allowed_comparators" not in chain_kwargs
and structured_query_translator.allowed_comparators is not None
):
chain_kwargs[
"allowed_comparators"
] = structured_query_translator.allowed_comparators
if (
"allowed_operators" not in chain_kwargs
and structured_query_translator.allowed_operators is not None
):
chain_kwargs[
"allowed_operators"
] = structured_query_translator.allowed_operators
query_constructor = load_query_constructor_runnable(
llm,
document_contents,
metadata_field_info,
enable_limit=enable_limit,
**chain_kwargs,
)
return cls(
query_constructor=query_constructor,
vectorstore=vectorstore,
use_original_query=use_original_query,
structured_query_translator=structured_query_translator,
**kwargs,
)
|