File size: 1,932 Bytes
dc75be1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a424cae
 
 
dc75be1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from datasets import load_dataset
from config import CONFIG
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util


class Retriever:
    """Retrieve text chunks from a dataset by BM25 or by embedding similarity.

    Typical call order: ``load_and_prepare_dataset()`` first, then
    ``prepare_bm25()`` and/or ``compute_embeddings()``, and finally one of the
    ``retrieve_documents_*`` methods. All tunables (dataset name, record
    limit, chunk size, number of results) are read from ``CONFIG``.
    """

    def __init__(self):
        # Flat list of text chunks; filled by load_and_prepare_dataset().
        self.corpus = None
        # BM25Okapi index over the corpus; filled by prepare_bm25().
        self.bm25 = None
        # SentenceTransformer model; filled by compute_embeddings().
        self.model = None
        # Embeddings of every chunk in the corpus; filled by compute_embeddings().
        self.chunk_embeddings = None

    def load_and_prepare_dataset(self):
        """Load the configured dataset and flatten its abstracts into chunks.

        Reads at most ``CONFIG['MAX_NUM_OF_RECORDS']`` records from the
        ``train`` split of ``CONFIG['DATASET']``, splits each record's
        ``abstract`` field with :meth:`chunk_text`, and stores all resulting
        chunks, in order, in ``self.corpus``.
        """
        train_split = load_dataset(CONFIG['DATASET'])['train']
        # Clamp the limit: select() raises on indices past the end of the
        # split, so a small dataset must not be asked for more than it has.
        record_count = min(CONFIG['MAX_NUM_OF_RECORDS'], len(train_split))
        dataset = train_split.select(range(record_count))
        dataset = dataset.map(lambda x: {'chunks': self.chunk_text(x['abstract'])})
        self.corpus = [chunk for chunks in dataset["chunks"] for chunk in chunks]

    def prepare_bm25(self):
        """Build the BM25 index over ``self.corpus`` (whitespace tokenised)."""
        # split() without an argument never yields empty-string tokens; for
        # chunks produced by chunk_text() it is identical to split(" ").
        tokenized_corpus = [doc.split() for doc in self.corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def compute_embeddings(self):
        """Load the ``all-MiniLM-L6-v2`` model and embed every corpus chunk."""
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        tokenizer = self.model._first_module().tokenizer
        # Fall back to the EOS token when the tokenizer defines no pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.chunk_embeddings = self.model.encode(self.corpus, convert_to_tensor=True)

    def chunk_text(self, text, chunk_size=CONFIG['CHUNK_SIZE']):
        """Split ``text`` into consecutive chunks of ``chunk_size`` words.

        Args:
            text: The text to split. ``None`` or an empty string yields no
                chunks.
            chunk_size: Maximum number of whitespace-separated words per
                chunk; the last chunk may be shorter.

        Returns:
            A list of chunks, each a single-space-joined run of words.
        """
        # Records with a missing abstract would otherwise crash on None.split().
        if not text:
            return []
        words = text.split()
        return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

    def retrieve_documents_bm25(self, query):
        """Return the chunks with the highest BM25 score for ``query``.

        Args:
            query: Free-text query; tokenised on whitespace.

        Returns:
            Up to ``CONFIG['TOP_DOCS']`` chunks, best match first.
        """
        # split() ignores repeated/leading/trailing whitespace in the query,
        # so no empty-string tokens are scored.
        tokenized_query = query.split()
        scores = self.bm25.get_scores(tokenized_query)
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        # Slicing tolerates a corpus smaller than TOP_DOCS.
        return [self.corpus[i] for i in ranked[:CONFIG['TOP_DOCS']]]

    def retrieve_documents_semantic(self, query):
        """Return the chunks whose embeddings are most similar to ``query``.

        Similarity is the cosine similarity between the query embedding and
        the precomputed chunk embeddings.

        Args:
            query: Free-text query to embed.

        Returns:
            Up to ``CONFIG['TOP_DOCS']`` chunks, best match first.
        """
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        scores = util.pytorch_cos_sim(query_embedding, self.chunk_embeddings)[0]
        # topk() raises when k exceeds the number of scores, so clamp k to
        # what is available (mirrors the slicing in the BM25 path).
        k = min(CONFIG['TOP_DOCS'], scores.shape[0])
        # Convert to plain ints so the indices are ordinary list subscripts.
        top_chunks = scores.topk(k).indices.tolist()
        return [self.corpus[i] for i in top_chunks]