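"""FAISS-backed similarity search utilities for LangChain.

Defines `SimilaritySearch`, a thin subclass of `langchain.vectorstores.FAISS`
that can reload a persisted index from disk and build an index from raw texts,
falling back to chunked embedding when a text is too long to embed in one call.
"""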
import pickle
import uuid
from typing import Any, Callable, List, Optional

import faiss
import numpy as np
from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.embeddings.base import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from tqdm import tqdm


def return_on_failure(value):
    """Decorator factory: return `value` instead of raising when the wrapped call fails."""

    def decorate(f):
        def applicator(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except Exception as e:
                print(f'Error "{e}" in {f.__name__}')
                return value

        return applicator

    return decorate
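
# Example: a decorated function swallows exceptions and hands back the fallback
# value instead, e.g. `@return_on_failure(None)` below makes `load_from_disk`
# print the error and return None when the index files are missing or unreadable.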


class SimilaritySearch(FAISS):
    @classmethod
    @return_on_failure(None)
    def load_from_disk(cls, embedding_function: Callable, data_dir: Optional[str] = None):
        """Load a FAISS index persisted as `index.faiss` + `index.pkl` in `data_dir`."""
        with open(f"{data_dir}/index.pkl", "rb") as file:
            docstore, index_to_docstore_id = pickle.load(file)
        index_cpu = faiss.read_index(f"{data_dir}/index.faiss")
        # To search on GPU, move the index over first:
        # index_gpu = faiss.index_cpu_to_gpu(GPU_RESOURCE, 0, index_cpu)
        # return cls(embedding_function, index_gpu, docstore, index_to_docstore_id)
        return cls(embedding_function, index_cpu, docstore, index_to_docstore_id)
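
    # Usage sketch (hedged): `data_dir` is a folder previously written by
    # `FAISS.save_local`, which produces the `index.faiss`/`index.pkl` pair
    # read above:
    #
    #   store = SimilaritySearch.load_from_disk(embeddings.embed_query, data_dir="indexes/my_index")
    #   if store is not None:  # None signals a failed load, via @return_on_failure
    #       docs = store.similarity_search("query", k=4)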

    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> FAISS:
        print("embeddings:", len(embeddings), len(texts), len(metadatas) if metadatas else 0)
        # Exact inner-product index; normalize vectors upstream if cosine
        # similarity is intended.
        index = faiss.IndexFlatIP(len(embeddings[0]))
        index.add(np.array(embeddings, dtype=np.float32))
        documents = []
        for i, text in tqdm(enumerate(texts), total=len(texts)):
            metadata = metadatas[i] if metadatas else {}
            documents.append(Document(page_content=text, metadata=metadata))
        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
        docstore = InMemoryDocstore(
            {index_to_id[i]: doc for i, doc in enumerate(documents)}
        )
        return cls(embedding.embed_query, index, docstore, index_to_id, **kwargs)

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> FAISS:
        """Construct FAISS wrapper from raw documents.

        This is a user-friendly interface that:
            1. Embeds documents.
            2. Creates an in-memory docstore.
            3. Initializes the FAISS database.

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import FAISS
                from langchain.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                faiss = FAISS.from_texts(texts, embeddings)
        """
        # Embed texts one at a time (rather than a single
        # `embedding.embed_documents(texts)` call) so that oversized texts can
        # be chunked and retried on failure.
        final_texts, final_metadatas = [], []
        embeddings = []
        for i, text in tqdm(enumerate(texts), total=len(texts)):
            try:
                embeddings.append(embedding._embedding_func(text))
                final_texts.append(text)
                if metadatas:
                    final_metadatas.append(metadatas[i])
            except Exception:
                # The text likely exceeds the embedding model's limit: split it
                # into overlapping chunks, embed each chunk, and replicate the
                # parent metadata across the chunks.
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=4096, chunk_overlap=128)
                splitted_texts = text_splitter.split_text(text)
                embeddings.extend(embedding.embed_documents(splitted_texts))
                final_texts.extend(splitted_texts)
                if metadatas:
                    final_metadatas.extend([metadatas[i]] * len(splitted_texts))
        return cls.__from(
            final_texts,
            embeddings,
            embedding,
            metadatas=final_metadatas,
            # ids=ids,
            **kwargs,
        )
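

if __name__ == "__main__":
    # Hedged end-to-end sketch: build, persist, and reload an index. Assumes an
    # OpenAI API key in the environment; `OpenAIEmbeddings` is a stand-in for
    # whatever embeddings class this project actually uses, which must expose
    # both `embed_documents` and the single-text `_embedding_func` that
    # `from_texts` relies on.
    from langchain.embeddings import OpenAIEmbeddings

    sample_texts = [
        "FAISS is a library for efficient similarity search.",
        "LangChain wraps vector stores behind a common interface.",
    ]
    sample_metadatas = [{"source": "doc-1"}, {"source": "doc-2"}]

    embeddings = OpenAIEmbeddings()
    store = SimilaritySearch.from_texts(sample_texts, embeddings, metadatas=sample_metadatas)
    store.save_local("indexes/sample")  # writes index.faiss and index.pkl

    reloaded = SimilaritySearch.load_from_disk(embeddings.embed_query, data_dir="indexes/sample")
    if reloaded is not None:  # None means the load failed (see @return_on_failure)
        print(reloaded.similarity_search("similarity search", k=1))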