Spaces:
No application file
No application file
from langchain.vectorstores import FAISS | |
from typing import Any, Callable, List, Optional, Tuple, Dict | |
from langchain.docstore.document import Document | |
from langchain.docstore.base import Docstore | |
from langchain.vectorstores.utils import maximal_marginal_relevance | |
from langchain.embeddings.base import Embeddings | |
import uuid | |
from langchain.docstore.in_memory import InMemoryDocstore | |
import numpy as np | |
def dependable_faiss_import() -> Any:
    """Import and return the ``faiss`` module.

    Returns:
        The imported ``faiss`` module.

    Raises:
        ValueError: If ``faiss`` cannot be imported. The underlying
            ``ImportError`` is chained as ``__cause__`` so the real reason
            (missing package, broken native library, ...) stays visible.
    """
    try:
        import faiss
    except ImportError as err:
        raise ValueError(
            "Could not import faiss python package. "
            "Please install it with `pip install faiss` "
            "or `pip install faiss-cpu` (depending on Python version)."
        ) from err
    return faiss
class FAISSVS(FAISS):
    """langchain ``FAISS`` store whose MMR search also returns index scores.

    The maximal-marginal-relevance search methods defined here return
    ``(Document, score)`` pairs. ``score`` is the value reported by
    ``self.index.search`` for that hit.
    """

    def __init__(
        self,
        embedding_function: Callable[..., Any],
        index: Any,
        docstore: Docstore,
        index_to_docstore_id: Dict[int, str],
    ):
        """Forward all components unchanged to ``FAISS.__init__``.

        Args:
            embedding_function: Callable used to embed query text.
            index: The faiss index holding the document vectors.
            docstore: Store from which ``Document`` objects are fetched by id.
            index_to_docstore_id: Maps faiss row positions to docstore ids.
        """
        super().__init__(embedding_function, index, docstore, index_to_docstore_id)

    def max_marginal_relevance_search_by_vector(
        self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return docs selected using maximal marginal relevance, with scores.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to the MMR algorithm.
            **kwargs: Accepted for interface compatibility; currently ignored.

        Returns:
            List of ``(Document, score)`` pairs selected by maximal marginal
            relevance. ``score`` is the value ``self.index.search`` reported
            for that document. The list is empty when the index yields no
            candidates, and may be shorter than ``k``.

        Raises:
            ValueError: If a selected index id has no ``Document`` in the
                docstore.
        """
        query = np.array([embedding], dtype=np.float32)
        scores, indices = self.index.search(query, fetch_k)
        # -1 marks empty result slots when not enough docs are returned. Drop
        # them once, keeping each id paired with its score, so every position
        # used below refers to a real hit.
        candidates = [
            (int(idx), score)
            for idx, score in zip(indices[0], scores[0])
            if idx != -1
        ]
        if not candidates:
            return []

        embeddings = [self.index.reconstruct(idx) for idx, _ in candidates]
        mmr_selected = maximal_marginal_relevance(query, embeddings, k=k)

        docs = []
        for pos in mmr_selected:
            # Defensive: skip any position that does not address a candidate
            # (the previous code guarded against -1 here), rather than letting
            # negative indexing select the last candidate a second time.
            if not 0 <= pos < len(candidates):
                continue
            idx, score = candidates[pos]
            _id = self.index_to_docstore_id[idx]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            docs.append((doc, score))
        return docs

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Embed ``query`` and run maximal-marginal-relevance search with scores.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to the MMR algorithm.
            **kwargs: Forwarded to ``max_marginal_relevance_search_by_vector``.

        Returns:
            List of ``(Document, score)`` pairs selected by maximal marginal
            relevance.
        """
        embedding = self.embedding_function(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding, k, fetch_k, **kwargs
        )

    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> FAISS:
        """Build a store over ``texts`` backed by an inner-product flat index.

        NOTE(review): inside this class the name ``__from`` is mangled to
        ``_FAISSVS__from``. If the ``FAISS`` base class's constructors
        (e.g. ``from_texts``) call ``cls.__from`` from within ``FAISS``,
        that call resolves to ``_FAISS__from`` and will NOT reach this
        override. Confirm against the installed langchain version, and wire
        this in explicitly if the inner-product index is required.

        Args:
            texts: Raw texts; one ``Document`` is created per entry.
            embeddings: One vector per text. The index dimension is taken
                from the first vector, so at least one is required.
            embedding: Embeddings object whose ``embed_query`` becomes the
                store's query embedding function.
            metadatas: Optional per-text metadata dicts aligned with
                ``texts``. ``{}`` is used for every document when omitted.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A new instance of ``cls``.
        """
        faiss = dependable_faiss_import()
        vectors = np.array(embeddings, dtype=np.float32)
        index = faiss.IndexFlatIP(len(embeddings[0]))
        index.add(vectors)
        # An IVF index (faiss.IndexIVFFlat with 100 lists over an IndexFlatL2
        # quantizer, trained on ``vectors`` before ``add``) was tried here as a
        # way to speed up search.
        documents = [
            Document(page_content=text, metadata=metadatas[i] if metadatas else {})
            for i, text in enumerate(texts)
        ]
        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
        docstore = InMemoryDocstore(
            {index_to_id[i]: doc for i, doc in enumerate(documents)}
        )
        return cls(embedding.embed_query, index, docstore, index_to_id)