import os
import time
import logging
import pickle
from typing import List, Tuple, Iterator

import faiss
import numpy as np

logger = logging.getLogger()

class DenseIndexer(object):
    def __init__(self, buffer_size: int = 50000):
        self.buffer_size = buffer_size
        self.index_id_to_db_id = []
        self.index = None

    def index_data(self, vector_files: List[str]):
        start_time = time.time()
        buffer = []
        for i, item in enumerate(iterate_encoded_files(vector_files)):
            db_id, doc_vector = item
            buffer.append((db_id, doc_vector))
            if 0 < self.buffer_size == len(buffer):
                # indexing in batches is beneficial for many faiss index types
                self._index_batch(buffer)
                logger.info(
                    "data indexed %d, used_time: %f sec.",
                    len(self.index_id_to_db_id),
                    time.time() - start_time,
                )
                buffer = []
        # index whatever is left in the buffer; skip the call if it is empty
        if buffer:
            self._index_batch(buffer)

        indexed_cnt = len(self.index_id_to_db_id)
        logger.info("Total data indexed %d", indexed_cnt)
        logger.info("Data indexing completed.")
    def _index_batch(self, data: List[Tuple[object, np.array]]):
        raise NotImplementedError

    def search_knn(
        self, query_vectors: np.array, top_docs: int
    ) -> List[Tuple[List[object], List[float]]]:
        raise NotImplementedError

    def serialize(self, file: str):
        logger.info("Serializing index to %s", file)

        if os.path.isdir(file):
            index_file = os.path.join(file, "index.dpr")
            meta_file = os.path.join(file, "index_meta.dpr")
        else:
            index_file = file + ".index.dpr"
            meta_file = file + ".index_meta.dpr"

        faiss.write_index(self.index, index_file)
        with open(meta_file, mode="wb") as f:
            pickle.dump(self.index_id_to_db_id, f)
    def deserialize_from(self, file: str):
        logger.info("Loading index from %s", file)

        if os.path.isdir(file):
            index_file = os.path.join(file, "index.dpr")
            meta_file = os.path.join(file, "index_meta.dpr")
        else:
            index_file = file + ".index.dpr"
            meta_file = file + ".index_meta.dpr"

        self.index = faiss.read_index(index_file)
        logger.info(
            "Loaded index of type %s and size %d", type(self.index), self.index.ntotal
        )

        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert (
            len(self.index_id_to_db_id) == self.index.ntotal
        ), "Deserialized index_id_to_db_id should match faiss index size"

    def _update_id_mapping(self, db_ids: List):
        self.index_id_to_db_id.extend(db_ids)

class DenseFlatIndexer(DenseIndexer):
    def __init__(self, vector_sz: int, buffer_size: int = 50000):
        super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)
        # res = faiss.StandardGpuResources()
        # cpu_index = faiss.IndexFlatIP(vector_sz)
        # self.index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
        self.index = faiss.IndexFlatIP(vector_sz)
        self.all_vectors = None

    def _index_batch(self, data: List[Tuple[object, np.array]]):
        db_ids = [t[0] for t in data]
        vectors = [np.reshape(t[1], (1, -1)) for t in data]
        vectors = np.concatenate(vectors, axis=0)
        self._update_id_mapping(db_ids)
        self.index.add(vectors)
        # if self.all_vectors is None:
        #     self.all_vectors = vectors
        # else:
        #     self.all_vectors = np.concatenate((self.all_vectors, vectors), axis=0)
    def search_knn(
        self, query_vectors: np.array, top_docs: int
    ) -> List[Tuple[List[object], List[float]]]:
        scores, indexes = self.index.search(query_vectors, top_docs)
        # convert to external ids
        db_ids = [
            [self.index_id_to_db_id[i] for i in query_top_idxs]
            for query_top_idxs in indexes
        ]
        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
        return result

class DenseHNSWFlatIndexer(DenseIndexer):
    """
    Efficient index for retrieval. Note: default settings are for high accuracy but also high RAM usage
    """

    def __init__(
        self,
        vector_sz: int,
        buffer_size: int = 50000,
        store_n: int = 512,
        ef_search: int = 128,
        ef_construction: int = 200,
    ):
        super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size)
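        # store_n is forwarded to faiss.IndexHNSWFlat as the HNSW "M" parameter
        # (number of neighbor links stored per node); ef_search and ef_construction
        # control how many candidates are explored at query time and at graph build
        # time. Larger values improve recall at the cost of speed and memory.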
        # IndexHNSWFlat supports L2 similarity only
        # so we have to apply DOT -> L2 similarity space conversion with the help of an extra dimension
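        # Concretely: each document vector d is stored as [d, sqrt(phi - ||d||^2)] and
        # each query q is searched as [q, 0], so the squared L2 distance equals
        # ||q||^2 + phi - 2 * (q . d). Both ||q||^2 and phi are constant for a given
        # query, so the nearest neighbors under L2 are exactly the maximum inner
        # product documents. phi is the maximum squared document norm (see _set_phi).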
        index = faiss.IndexHNSWFlat(vector_sz + 1, store_n)
        index.hnsw.efSearch = ef_search
        index.hnsw.efConstruction = ef_construction
        self.index = index
        self.phi = None
    def index_data(self, vector_files: List[str]):
        self._set_phi(vector_files)
        super(DenseHNSWFlatIndexer, self).index_data(vector_files)

    def _set_phi(self, vector_files: List[str]):
        """
        Calculates the maximum squared vector norm over the whole data and assigns it to self.phi:
        necessary to transform IP -> L2 space
        :param vector_files: file names to get passage vectors from
        :return:
        """
        phi = 0
        for _, doc_vector in iterate_encoded_files(vector_files):
            norm = (doc_vector ** 2).sum()
            phi = max(phi, norm)
        logger.info("HNSWF DotProduct -> L2 space phi=%f", phi)
        self.phi = phi
    def _index_batch(self, data: List[Tuple[object, np.array]]):
        # max norm is required before putting all vectors in the index to convert inner product similarity to L2
        if self.phi is None:
            raise RuntimeError(
                "Max norm needs to be calculated from all data at once, "
                "results will be unpredictable otherwise. "
                "Run `_set_phi()` before calling this method."
            )
        db_ids = [t[0] for t in data]
        vectors = [np.reshape(t[1], (1, -1)) for t in data]

        norms = [(doc_vector ** 2).sum() for doc_vector in vectors]
        aux_dims = [np.sqrt(self.phi - norm) for norm in norms]
        hnsw_vectors = [
            np.hstack((doc_vector, aux_dims[i].reshape(-1, 1)))
            for i, doc_vector in enumerate(vectors)
        ]
        hnsw_vectors = np.concatenate(hnsw_vectors, axis=0)

        self._update_id_mapping(db_ids)
        self.index.add(hnsw_vectors)
    def search_knn(
        self, query_vectors: np.array, top_docs: int
    ) -> List[Tuple[List[object], List[float]]]:
        aux_dim = np.zeros(len(query_vectors), dtype="float32")
        query_hnsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1)))
        logger.info("query_hnsw_vectors %s", query_hnsw_vectors.shape)
        scores, indexes = self.index.search(query_hnsw_vectors, top_docs)
        # convert to external ids
        db_ids = [
            [self.index_id_to_db_id[i] for i in query_top_idxs]
            for query_top_idxs in indexes
        ]
        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
        return result
    def deserialize_from(self, file: str):
        super(DenseHNSWFlatIndexer, self).deserialize_from(file)
        # reset phi so that indexing more data without recomputing it raises an error
        self.phi = None

def iterate_encoded_files(vector_files: List[str]) -> Iterator[Tuple[object, np.array]]:
    for file in vector_files:
        logger.info("Reading file %s", file)
        with open(file, "rb") as reader:
            doc_vectors = pickle.load(reader)
            for doc in doc_vectors:
                db_id, doc_vector = doc
                yield db_id, doc_vector
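
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original module).
# It assumes embedding files are pickles holding a list of (db_id, vector) pairs,
# which is the format iterate_encoded_files() reads; the file name below is made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    vector_sz = 768
    rng = np.random.default_rng(0)

    # Write a small synthetic embedding file in the expected pickle format.
    sample_file = "sample_vectors.pkl"
    sample = [(str(i), rng.random(vector_sz, dtype=np.float32)) for i in range(100)]
    with open(sample_file, "wb") as f:
        pickle.dump(sample, f)

    # Build a flat inner-product index over the file and run a search.
    indexer = DenseFlatIndexer(vector_sz=vector_sz)
    indexer.index_data([sample_file])

    queries = rng.random((2, vector_sz), dtype=np.float32)
    for db_ids, scores in indexer.search_knn(queries, top_docs=5):
        print(db_ids, scores)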