EZ-Crossword / Faiss_Indexers_inf.py
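"""
FAISS-based dense vector indexers: DenseFlatIndexer (exact inner-product search)
and DenseHNSWFlatIndexer (approximate HNSW search using an inner-product -> L2
conversion via an auxiliary dimension).
"""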
import os
import time
import logging
import pickle
from typing import List, Tuple, Iterator, Union
import faiss
import numpy as np
logger = logging.getLogger()
class DenseIndexer(object):
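    """
    Abstract base class for FAISS-backed indexers: buffers (db_id, vector) pairs,
    indexes them in batches, and keeps an index_id -> db_id mapping next to the
    FAISS index. Subclasses implement _index_batch() and search_knn().
    """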
def __init__(self, buffer_size: int = 50000):
self.buffer_size = buffer_size
self.index_id_to_db_id = []
self.index = None
def index_data(self, vector_files: List[str]):
start_time = time.time()
buffer = []
for i, item in enumerate(iterate_encoded_files(vector_files)):
db_id, doc_vector = item
buffer.append((db_id, doc_vector))
            if self.buffer_size > 0 and len(buffer) == self.buffer_size:
# indexing in batches is beneficial for many faiss index types
self._index_batch(buffer)
logger.info(
"data indexed %d, used_time: %f sec.",
len(self.index_id_to_db_id),
time.time() - start_time,
)
buffer = []
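        # flush whatever is left in the buffer after the last full batch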
self._index_batch(buffer)
indexed_cnt = len(self.index_id_to_db_id)
logger.info("Total data indexed %d", indexed_cnt)
logger.info("Data indexing completed.")
def _index_batch(self, data: List[Tuple[object, np.array]]):
raise NotImplementedError
def search_knn(
self, query_vectors: np.array, top_docs: int
) -> List[Tuple[List[object], List[float]]]:
raise NotImplementedError
def serialize(self, file: str):
logger.info("Serializing index to %s", file)
if os.path.isdir(file):
index_file = os.path.join(file, "index.dpr")
meta_file = os.path.join(file, "index_meta.dpr")
else:
index_file = file + ".index.dpr"
meta_file = file + ".index_meta.dpr"
faiss.write_index(self.index, index_file)
with open(meta_file, mode="wb") as f:
pickle.dump(self.index_id_to_db_id, f)
def deserialize_from(self, file: str):
logger.info("Loading index from %s", file)
if os.path.isdir(file):
index_file = os.path.join(file, "index.dpr")
meta_file = os.path.join(file, "index_meta.dpr")
else:
index_file = file + ".index.dpr"
meta_file = file + ".index_meta.dpr"
self.index = faiss.read_index(index_file)
logger.info(
"Loaded index of type %s and size %d", type(self.index), self.index.ntotal
)
with open(meta_file, "rb") as reader:
self.index_id_to_db_id = pickle.load(reader)
assert (
len(self.index_id_to_db_id) == self.index.ntotal
), "Deserialized index_id_to_db_id should match faiss index size"
def _update_id_mapping(self, db_ids: List):
self.index_id_to_db_id.extend(db_ids)
class DenseFlatIndexer(DenseIndexer):
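    """
    Exact (brute-force) inner-product index over all indexed vectors.
    """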
def __init__(self, vector_sz: int, buffer_size: int = 50000):
super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)
#res = faiss.StandardGpuResources()
#cpu_index = faiss.IndexFlatIP(vector_sz)
#self.index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
self.index = faiss.IndexFlatIP(vector_sz)
self.all_vectors = None
def _index_batch(self, data: List[Tuple[object, np.array]]):
db_ids = [t[0] for t in data]
vectors = [np.reshape(t[1], (1, -1)) for t in data]
vectors = np.concatenate(vectors, axis=0)
self._update_id_mapping(db_ids)
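        # faiss expects a C-contiguous float32 matrix of shape (n, vector_sz)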
self.index.add(vectors)
#if self.all_vectors is None:
# self.all_vectors = vectors
#else:
# self.all_vectors = np.concatenate((self.all_vectors, vectors), axis=0)
def search_knn(
self, query_vectors: np.array, top_docs: int
) -> List[Tuple[List[object], List[float]]]:
scores, indexes = self.index.search(query_vectors, top_docs)
# convert to external ids
db_ids = [
[self.index_id_to_db_id[i] for i in query_top_idxs]
for query_top_idxs in indexes
]
result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
return result
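
# Minimal usage sketch for DenseFlatIndexer (illustrative only: the 768-dim size,
# the random vectors and the integer ids below are assumptions, not values used
# elsewhere in this file):
#
#   indexer = DenseFlatIndexer(vector_sz=768)
#   vectors = np.random.rand(100, 768).astype("float32")
#   indexer._index_batch([(i, vectors[i]) for i in range(100)])
#   db_ids, scores = indexer.search_knn(vectors[:2], top_docs=5)[0]
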
class DenseHNSWFlatIndexer(DenseIndexer):
"""
Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage
"""
def __init__(
self,
vector_sz: int,
buffer_size: int = 50000,
store_n: int = 512,
ef_search: int = 128,
ef_construction: int = 200,
):
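        # store_n: number of graph links per HNSW node; ef_search / ef_construction
        # control the search-time / build-time accuracy vs. speed trade-off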
super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size)
        # IndexHNSWFlat supports L2 similarity only,
        # so we apply a DOT -> L2 similarity space conversion with the help of an extra dimension
index = faiss.IndexHNSWFlat(vector_sz + 1, store_n)
index.hnsw.efSearch = ef_search
index.hnsw.efConstruction = ef_construction
self.index = index
self.phi = None
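    # IP -> L2 trick used below: each document d is stored as [d, sqrt(phi - ||d||^2)]
    # and each query q is searched as [q, 0], so ||q' - d'||^2 = ||q||^2 + phi - 2 * q.d;
    # minimizing L2 distance over the augmented vectors is therefore equivalent to
    # maximizing the inner product q.d.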
def index_data(self, vector_files: List[str]):
self._set_phi(vector_files)
super(DenseHNSWFlatIndexer, self).index_data(vector_files)
def _set_phi(self, vector_files: List[str]):
"""
Calculates the max norm from the whole data and assign it to self.phi: necessary to transform IP -> L2 space
:param vector_files: file names to get passages vectors from
:return:
"""
phi = 0
        for i, item in enumerate(iterate_encoded_files(vector_files)):
            db_id, doc_vector = item
            norm = (doc_vector ** 2).sum()
            phi = max(phi, norm)
logger.info("HNSWF DotProduct -> L2 space phi={}".format(phi))
self.phi = phi
def _index_batch(self, data: List[Tuple[object, np.array]]):
# max norm is required before putting all vectors in the index to convert inner product similarity to L2
if self.phi is None:
            raise RuntimeError(
                "Max norm needs to be calculated from all data at once, "
                "results will be unpredictable otherwise. "
                "Run `_set_phi()` before calling this method."
            )
db_ids = [t[0] for t in data]
vectors = [np.reshape(t[1], (1, -1)) for t in data]
norms = [(doc_vector ** 2).sum() for doc_vector in vectors]
aux_dims = [np.sqrt(self.phi - norm) for norm in norms]
hnsw_vectors = [
np.hstack((doc_vector, aux_dims[i].reshape(-1, 1)))
for i, doc_vector in enumerate(vectors)
]
hnsw_vectors = np.concatenate(hnsw_vectors, axis=0)
self._update_id_mapping(db_ids)
self.index.add(hnsw_vectors)
def search_knn(
self, query_vectors: np.array, top_docs: int
) -> List[Tuple[List[object], List[float]]]:
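        # queries get a zero auxiliary component; the ||q||^2 + phi terms in the L2
        # distance are identical for every document, so the ranking matches inner product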
        aux_dim = np.zeros(len(query_vectors), dtype="float32")
        query_hnsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1)))
        logger.info("query_hnsw_vectors %s", query_hnsw_vectors.shape)
        scores, indexes = self.index.search(query_hnsw_vectors, top_docs)
# convert to external ids
db_ids = [
[self.index_id_to_db_id[i] for i in query_top_idxs]
for query_top_idxs in indexes
]
result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
return result
def deserialize_from(self, file: str):
super(DenseHNSWFlatIndexer, self).deserialize_from(file)
# to trigger warning on subsequent indexing
self.phi = None
def iterate_encoded_files(
    vector_files: Union[str, List[str]]
) -> Iterator[Tuple[object, np.array]]:
    # accept a single pickle path or a list of paths, matching the List[str]
    # annotations used by the callers above
    if isinstance(vector_files, str):
        vector_files = [vector_files]
    for file in vector_files:
        logger.info("Reading file %s", file)
        with open(file, "rb") as reader:
            doc_vectors = pickle.load(reader)
            for doc in doc_vectors:
                db_id, doc_vector = doc
                yield db_id, doc_vector
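
# End-to-end sketch (illustrative assumptions: the pickle path, the 768-dim size and
# `query_vectors`, a float32 array of shape (n, 768), are not defined in this file):
#
#   indexer = DenseHNSWFlatIndexer(vector_sz=768)
#   indexer.index_data(["passages_0.pkl"])  # pickle holding a list of (db_id, vector) pairs
#   indexer.serialize("hnsw_index")         # writes hnsw_index.index.dpr + hnsw_index.index_meta.dpr
#   indexer.deserialize_from("hnsw_index")
#   top_ids, top_scores = indexer.search_knn(query_vectors, top_docs=10)[0]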