Spaces:
Running
Running
from pathlib import Path | |
from typing import Any, Optional, List, Dict, Tuple, ClassVar, Collection | |
from langchain.schema import Document | |
from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K | |
from langchain_core.callbacks import CallbackManagerForRetrieverRun | |
from langchain_core.utils import xor_args | |
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever | |
class AdvancedVectorStoreRetriever(VectorStoreRetriever):
    """Vector-store retriever that can annotate results with scores/embeddings.

    Extends the standard search types with ``"similarity_with_embeddings"``.
    For that type and for ``"similarity_score_threshold"``, each returned
    ``Document`` receives extra metadata entries (``__similarity``, plus
    ``__embeddings`` for the former) unless those keys already exist.
    Every other search type is delegated unchanged to ``VectorStoreRetriever``.
    """

    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
        "similarity_with_embeddings",
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return documents relevant to ``query`` for the configured search type.

        Args:
            query: Text to search for.
            run_manager: Callback manager; only used by the delegated
                (base-class) search types.

        Returns:
            The matching documents, possibly with score/embedding metadata.
        """
        if self.search_type == "similarity_with_embeddings":
            # The vector store must provide ``advanced_similarity_search``
            # yielding (document, score, embedding) triples.
            triples = self.vectorstore.advanced_similarity_search(
                query, **self.search_kwargs
            )
            for document, score, vector in triples:
                document.metadata.setdefault('__embeddings', vector)
                # NOTE(review): with ChromaAdvancedRetrieval this score is the
                # raw Chroma distance (lower = closer), while the threshold
                # branch below stores a relevance score under the same key.
                document.metadata.setdefault('__similarity', score)
            return [document for document, _score, _vector in triples]

        if self.search_type == "similarity_score_threshold":
            pairs = self.vectorstore.similarity_search_with_relevance_scores(
                query, **self.search_kwargs
            )
            for document, relevance in pairs:
                document.metadata.setdefault('__similarity', relevance)
            return [document for document, _relevance in pairs]

        return super()._get_relevant_documents(query, run_manager=run_manager)
class AdvancedVectorStore(VectorStore):
    """``VectorStore`` whose ``as_retriever`` builds an ``AdvancedVectorStoreRetriever``."""

    def as_retriever(self, **kwargs: Any) -> AdvancedVectorStoreRetriever:
        """Return an ``AdvancedVectorStoreRetriever`` backed by this store.

        Args:
            **kwargs: Forwarded to ``AdvancedVectorStoreRetriever`` (for
                example ``search_type`` and ``search_kwargs``). An optional
                ``tags`` iterable is combined with this store's default
                retriever tags.

        Returns:
            A retriever whose ``vectorstore`` is ``self``.
        """
        # Build a fresh list so a caller-supplied ``tags`` list is never
        # mutated; extending it in place would accumulate duplicate
        # retriever tags each time the same list is passed again.
        tags = [*(kwargs.pop("tags", None) or []), *self._get_retriever_tags()]
        return AdvancedVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore):
    """Chroma vector store that can return stored embeddings with its results.

    Mixing in ``AdvancedVectorStore`` makes ``as_retriever`` produce an
    ``AdvancedVectorStoreRetriever``, whose ``"similarity_with_embeddings"``
    search type is served by ``advanced_similarity_search`` below.
    """

    def __init__(self, **kwargs):
        """Create the store; all keyword arguments are forwarded to ``Chroma``."""
        super().__init__(**kwargs)

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Query the chroma collection.

        Exactly one of ``query_texts`` / ``query_embeddings`` must be given;
        ``xor_args`` raises ``ValueError`` otherwise.

        Args:
            query_texts: Raw query strings.
            query_embeddings: Pre-computed query vectors.
            n_results: Number of results per query.
            where: Metadata filter.
            where_document: Document-content filter.
            **kwargs: Passed straight to ``Collection.query`` (e.g. ``include``).

        Returns:
            The raw result mapping returned by ``self._collection.query``.

        Raises:
            ValueError: If ``chromadb`` cannot be imported.
        """
        try:
            import chromadb  # noqa: F401
        except ImportError as err:
            raise ValueError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            ) from err
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def advanced_similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        """Return the ``k`` nearest documents with their distances and embeddings.

        Args:
            query: Text to search for.
            k: Number of results to return.
            filter: Metadata filter passed to Chroma as ``where``.
            **kwargs: Forwarded to
                ``similarity_search_with_scores_and_embeddings``
                (e.g. ``where_document``).

        Returns:
            ``(document, distance, embedding)`` triples.
        """
        # Forward kwargs so options such as ``where_document`` supplied via a
        # retriever's ``search_kwargs`` actually reach the query.
        return self.similarity_search_with_scores_and_embeddings(
            query, k, filter=filter, **kwargs
        )

    def similarity_search_with_scores_and_embeddings(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        """Query Chroma and return documents with distances and stored embeddings.

        Args:
            query: Text to search for.
            k: Number of results to return.
            filter: Metadata filter passed to Chroma as ``where``.
            where_document: Document-content filter.
            **kwargs: Accepted for signature compatibility; currently ignored.

        Returns:
            ``(document, distance, embedding)`` triples for the single query.
        """
        # Embeddings are requested explicitly alongside documents, metadata
        # and distances so they can be returned to the caller.
        query_args: Dict[str, Any] = dict(
            n_results=k,
            where=filter,
            where_document=where_document,
            include=['metadatas', 'documents', 'embeddings', 'distances'],
        )
        if self._embedding_function is None:
            # No client-side embedder: query with the raw text (presumably
            # embedded by the collection's own embedding function).
            results = self.__query_collection(query_texts=[query], **query_args)
        else:
            query_embedding = self._embedding_function.embed_query(query)
            results = self.__query_collection(
                query_embeddings=[query_embedding], **query_args
            )
        return _results_to_docs_scores_and_embeddings(results)
def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]:
    """Turn a raw Chroma query result into ``(document, distance, embedding)`` triples.

    Only index ``0`` of each result column is read, so ``results`` is expected
    to come from a single query. A missing (falsy) metadata entry becomes ``{}``.
    """
    columns = (
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
        results["embeddings"][0],
    )
    return [
        (Document(page_content=text, metadata=metadata or {}), distance, embedding)
        for text, metadata, distance, embedding in zip(*columns)
    ]