import argparse
import os

# OMP_NUM_THREADS must be set before faiss (and its OpenMP runtime) is first
# loaded, so configure it ahead of the heavy imports below.
os.environ["OMP_NUM_THREADS"] = "8"

from datasets import load_dataset
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever

from general_utils import embed_passages_haystack


def create_faiss_index(args):
    minchars = 200  # drop passages too short to be useful for retrieval

    # DPR uses two distinct encoders: the question encoder embeds queries,
    # while the context (passage) encoder embeds the documents being indexed.
    dpr = DensePassageRetriever(
        document_store=InMemoryDocumentStore(),
        query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
        passage_embedding_model=args.ctx_encoder_name,
        max_seq_len_query=64,
        max_seq_len_passage=256,
        batch_size=512,
        devices=[args.device],  # assumes Haystack v1.x, which accepts a device list
    )

    dataset = load_dataset("IIC/spanish_biomedical_crawled_corpus", split="train")
    dataset = dataset.filter(lambda example: len(example["text"]) > minchars)

    def embed_passages_retrieval(examples):
        return embed_passages_haystack(dpr, examples)

    # Embed passages in large batches, then build a compressed FAISS index:
    # OPQ64_128 rotates and reduces vectors to 128 dims, IVF4898 partitions
    # them into 4898 clusters, and PQ64x4fsr product-quantizes the residuals.
    dataset = dataset.map(embed_passages_retrieval, batched=True, batch_size=8192)
    dataset.add_faiss_index(
        column="embeddings",
        string_factory="OPQ64_128,IVF4898,PQ64x4fsr",
        train_size=len(dataset),
    )
    dataset.save_faiss_index("embeddings", args.index_file_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Creates a FAISS index file for the Spanish biomedical corpus"
    )
    parser.add_argument(
        "--ctx_encoder_name",
        default="IIC/dpr-spanish-passage_encoder-squades-base",
        help="Encoding model to use for passage encoding",
    )
    parser.add_argument(
        "--index_file_name",
        default="dpr_index_bio_splitted.faiss",
        help="Faiss index file with passage embeddings",
    )
    parser.add_argument(
        "--device", default="cuda:0", help="The device to index data on."
    )
    main_args, _ = parser.parse_known_args()
    create_faiss_index(main_args)
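
# Example invocation, as a rough sketch (the filename "create_faiss_index.py"
# is hypothetical; substitute whatever this script is saved as):
#
#   python create_faiss_index.py \
#       --ctx_encoder_name IIC/dpr-spanish-passage_encoder-squades-base \
#       --index_file_name dpr_index_bio_splitted.faiss \
#       --device cuda:0
#
# The resulting .faiss file can later be reattached to the same dataset with
# Dataset.load_faiss_index("embeddings", path) for nearest-neighbor search.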