import pickle  # Used to load the pre-built dense document vectors.

import torch
from sentence_transformers import SentenceTransformer
import datasets  # Used to create and store the Faiss index.
import faiss
import numpy as np


class WitIndex:
    """Search WIT (Wikipedia Image-Text) snippets for a free-text query.

    Wraps a Faiss index of sentence embeddings plus the WIT metadata needed
    to map a matched snippet back to its wiki page / image URL.
    """

    # Class-level cache so the (large) pickled dataset is loaded only once
    # per process, even when several WitIndex instances are created.
    wit_dataset = None

    def __init__(self, wit_index_path: str, model_name: str,
                 wit_dataset_path: str, gpu: bool = True):
        """Load the Faiss index, the embedding model, and the WIT metadata.

        Args:
            wit_index_path: Path to a serialized Faiss index file.
            model_name: SentenceTransformer model name or local path.
            wit_dataset_path: Path to the pickled WIT metadata dict
                (expected to hold 'desc2image_map' and 'image_info').
            gpu: If True and CUDA is available, move both the model and
                the index to GPU 0.
        """
        self.index = faiss.read_index(wit_index_path)
        self.model = SentenceTransformer(model_name)
        if WitIndex.wit_dataset is None:
            # NOTE(security): pickle.load executes arbitrary code — only
            # load dataset files from a trusted source.
            # Use a context manager so the file handle is always closed
            # (the original `pickle.load(open(...))` leaked it).
            with open(wit_dataset_path, "rb") as f:
                WitIndex.wit_dataset = pickle.load(f)
        print(f"Gpu: {gpu}")
        if gpu and torch.cuda.is_available():
            print("Cuda is available")
            self.model = self.model.to(torch.device("cuda"))
            res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)

    def search(self, text, top_k: int = 6):
        """Return the top_k nearest WIT snippets for `text`.

        Args:
            text: Query string to embed and search for.
            top_k: Number of nearest neighbours to retrieve.

        Returns:
            A (distances, indices, image_info) tuple of plain Python lists:
            distances/indices are the flattened Faiss search results, and
            image_info holds the metadata records mapped from each index.
        """
        print(f"> Search: {text}")
        embedding = self.model.encode(text, convert_to_numpy=True,
                                      show_progress_bar=False)
        # Faiss expects a 2-D (n_queries, dim) array; wrap the single
        # query vector accordingly.
        distance, index = self.index.search(np.array([embedding]), k=top_k)
        distance, index = distance.flatten().tolist(), index.flatten().tolist()
        # Map each snippet index to its image URL key, then to the full
        # metadata record.
        index_url = [WitIndex.wit_dataset['desc2image_map'][i] for i in index]
        image_info = [WitIndex.wit_dataset['image_info'][i] for i in index_url]
        print(f"> URL: {image_info[0]}")
        return distance, index, image_info