NLP_RAG / model /retriever.py
droushb's picture
no message
a424cae
from datasets import load_dataset
from config import CONFIG
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util
class Retriever:
def __init__(self):
self.corpus = None
self.bm25 = None
self.model = None
self.chunk_embeddings = None
def load_and_prepare_dataset(self):
dataset = load_dataset(CONFIG['DATASET'])
dataset = dataset['train'].select(range(CONFIG['MAX_NUM_OF_RECORDS']))
dataset = dataset.map(lambda x: {'chunks': self.chunk_text(x['abstract'])})
self.corpus = [chunk for chunks in dataset["chunks"] for chunk in chunks]
def prepare_bm25(self):
tokenized_corpus = [doc.split(" ") for doc in self.corpus]
self.bm25 = BM25Okapi(tokenized_corpus)
def compute_embeddings(self):
self.model = SentenceTransformer('all-MiniLM-L6-v2')
tokenizer = self.model._first_module().tokenizer
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
self.chunk_embeddings = self.model.encode(self.corpus, convert_to_tensor=True)
def chunk_text(self, text, chunk_size=CONFIG['CHUNK_SIZE']):
words = text.split()
return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
def retrieve_documents_bm25(self, query):
tokenized_query = query.split(" ")
scores = self.bm25.get_scores(tokenized_query)
top_docs = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:CONFIG['TOP_DOCS']]
return [self.corpus[i] for i in top_docs]
def retrieve_documents_semantic(self, query):
query_embedding = self.model.encode(query, convert_to_tensor=True)
scores = util.pytorch_cos_sim(query_embedding, self.chunk_embeddings)[0]
top_chunks = scores.topk(CONFIG['TOP_DOCS']).indices
return [self.corpus[i] for i in top_chunks]