import os
import os.path

import torch
from datasets import DatasetDict, load_dataset
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

from src.retrievers.base_retriever import Retriever
from src.utils.log import get_logger
from src.utils.preprocessing import remove_formulas

# Hacky fix for FAISS error on macOS
# See https://stackoverflow.com/a/63374568/4545692
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

logger = get_logger()


class FaissRetriever(Retriever):
    """A class used to retrieve relevant documents based on some query.

    Based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(
        self,
        dataset: DatasetDict,
        embedding_path: str = "./src/models/paragraphs_embedding.faiss",
    ) -> None:
        torch.set_grad_enabled(False)

        # Context encoding and tokenization
        self.ctx_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )
        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )

        # Question encoding and tokenization
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )

        self.dataset = dataset
        self.embedding_path = embedding_path
        self.index = self._init_index()

    def _init_index(self, force_new_embedding: bool = False):
        """Load the FAISS index from `embedding_path` if it already exists,
        unless `force_new_embedding` is set; otherwise embed every paragraph,
        build a new index, and save it to disk."""
        ds = self.dataset["train"]
        ds = ds.map(remove_formulas)

        if not force_new_embedding and os.path.exists(self.embedding_path):
            ds.load_faiss_index(
                "embeddings", self.embedding_path)  # type: ignore
            return ds
        else:
            def embed(row):
                # Inline helper function to embed a single paragraph
                p = row["text"]
                tok = self.ctx_tokenizer(
                    p, return_tensors="pt", truncation=True)
                enc = self.ctx_encoder(**tok)[0][0].numpy()
                return {"embeddings": enc}

            # Add FAISS embeddings
            index = ds.map(embed)  # type: ignore
            index.add_faiss_index(column="embeddings")

            # Save dataset with embeddings
            os.makedirs("./src/models/", exist_ok=True)
            index.save_faiss_index("embeddings", self.embedding_path)
            return index

    def retrieve(self, query: str, k: int = 50):
        """Return the `k` documents closest to the query in embedding space."""

        def embed(q):
            # Inline helper function to embed the question
            tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
            return self.q_encoder(**tok)[0][0].numpy()

        question_embedding = embed(query)
        scores, results = self.index.get_nearest_examples(
            "embeddings", question_embedding, k=k
        )
        return scores, results
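

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # Assumes a JSON file whose records carry a "text" field per paragraph;
    # the file path and the example query below are made up for demonstration.
    # Requires faiss-cpu or faiss-gpu to be installed.
    dataset = load_dataset(
        "json", data_files={"train": "./data/paragraphs.json"})
    retriever = FaissRetriever(dataset)
    scores, results = retriever.retrieve("What is a vector space?", k=5)
    for score, text in zip(scores, results["text"]):
        logger.info(f"{score:.3f}  {text[:80]}")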