from typing import List

from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore
from langchain_core.vectorstores import VectorStore

from langchain.callbacks.manager import CallbackManagerForRetrieverRun


class MultiVectorRetriever(BaseRetriever):
    """Retrieve from a set of multiple embeddings for the same document."""

    vectorstore: VectorStore
    """The underlying vectorstore to use to store small chunks
    and their embedding vectors"""
    docstore: BaseStore[str, Document]
    """The storage layer for the parent documents"""
    id_key: str = "doc_id"
    """Metadata key on each vectorstore hit whose value is the id used to
    look the parent document up in the docstore."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Runs a similarity search against the vectorstore, collects the
        distinct parent ids found under ``id_key`` in the hits' metadata
        (in the order the hits were returned), and fetches those parents
        from the docstore.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents. Ids that the docstore cannot resolve
            (``mget`` yields ``None``) are dropped, as are hits whose
            metadata does not contain ``id_key``.
        """
        sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # parents come back in the same order as the hits, without
        # the quadratic ``x not in list`` scan. Hits lacking ``id_key`` are
        # skipped instead of raising KeyError and failing the whole query.
        ids = list(
            dict.fromkeys(
                d.metadata[self.id_key]
                for d in sub_docs
                if self.id_key in d.metadata
            )
        )

        docs = self.docstore.mget(ids)
        return [d for d in docs if d is not None]