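"""Create a faiss index of DPR passage embeddings for the IIC Spanish
biomedical crawled corpus using Haystack's DensePassageRetriever, and save
it to disk for later dense retrieval."""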
import argparse
import os

import faiss  # noqa: F401  # makes the faiss dependency explicit; Dataset.add_faiss_index needs it installed
from datasets import load_dataset
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever

from general_utils import embed_passages_haystack

# Limit the number of OpenMP threads faiss uses while training the index.
os.environ["OMP_NUM_THREADS"] = "8"


def create_faiss_index(args):
    minchars = 200  # index only passages longer than this many characters
    dims = 128  # OPQ output dimension used in the index factory string below (kept for reference)

    # The retriever is used here only to embed passages; Haystack also needs a
    # query encoder, so the Spanish question encoder is loaded alongside the
    # passage encoder selected on the command line.
    dpr = DensePassageRetriever(
        document_store=InMemoryDocumentStore(),
        query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
        passage_embedding_model=args.ctx_encoder_name,
        max_seq_len_query=64,
        max_seq_len_passage=256,
        batch_size=512,
    )
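    # Note: args.device is parsed below but not consumed here; recent Haystack
    # releases expose a `devices` argument on DensePassageRetriever that could
    # take it (left unchanged here since the installed version is not pinned).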

    dataset = load_dataset(
        "IIC/spanish_biomedical_crawled_corpus", split="train"
    )

    # Drop passages that are too short to be useful retrieval candidates.
    dataset = dataset.filter(lambda example: len(example["text"]) > minchars)
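
    # embed_passages_haystack comes from the local general_utils module. It is
    # expected to embed examples["text"] with the retriever and return a dict
    # with an "embeddings" key so that Dataset.map adds that column, e.g.
    # (hypothetical sketch, not the actual helper):
    #
    #     def embed_passages_haystack(dpr, examples):
    #         docs = [Document(content=text) for text in examples["text"]]
    #         return {"embeddings": dpr.embed_documents(docs)}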
    def embed_passages_retrieval(examples):
        return embed_passages_haystack(dpr, examples)

    # Compute DPR embeddings for every passage and store them in an
    # "embeddings" column.
    dataset = dataset.map(embed_passages_retrieval, batched=True, batch_size=8192)

    # Build and train a compressed faiss index over the embeddings:
    #   OPQ64_128 - rotate and reduce vectors to 128 dims for product quantization
    #   IVF4898   - inverted file with 4898 coarse clusters
    #   PQ64x4fsr - fast-scan product quantizer (64 sub-vectors, 4 bits each) on residuals
    dataset.add_faiss_index(
        column="embeddings",
        string_factory="OPQ64_128,IVF4898,PQ64x4fsr",
        train_size=len(dataset),
    )
    dataset.save_faiss_index("embeddings", args.index_file_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Creates a Faiss index file for the IIC Spanish biomedical corpus"
    )

    parser.add_argument(
        "--ctx_encoder_name",
        default="IIC/dpr-spanish-passage_encoder-squades-base",
        help="Encoding model to use for passage encoding",
    )
    parser.add_argument(
        "--index_file_name",
        default="dpr_index_bio_splitted.faiss",
        help="Faiss index file with passage embeddings",
    )
    parser.add_argument(
        "--device", default="cuda:0", help="The device to index data on."
    )

    main_args, _ = parser.parse_known_args()
    create_faiss_index(main_args)
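
# Example invocation (the file name is an assumption; adjust to this script's
# actual name):
#
#   python create_faiss_index.py \
#       --ctx_encoder_name IIC/dpr-spanish-passage_encoder-squades-base \
#       --index_file_name dpr_index_bio_splitted.faiss
#
# The saved index can later be re-attached to the embedded dataset and queried
# with a DPR question embedding, e.g. (sketch, assumes the same dataset object
# and retriever as above):
#
#   dataset.load_faiss_index("embeddings", "dpr_index_bio_splitted.faiss")
#   query_emb = dpr.embed_queries(["¿Qué tratamientos existen para la diabetes?"])[0]
#   scores, retrieved = dataset.get_nearest_examples("embeddings", query_emb, k=5)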