File size: 2,871 Bytes
60fc5e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from modules.pdfExtractor import PdfConverter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document

class EmbeddingModel:
    """Thin wrapper around a SentenceTransformer embedding model.

    With no ``model_path``, downloads the default "thenlper/gte-base" model
    and caches a copy under ``./embeddingModel``; otherwise loads the model
    from the given path.
    """

    def __init__(self, model_path=None):
        """Load (or download and cache) the underlying transformer.

        Args:
            model_path: local directory of a saved model, or None to fetch
                the default model from the hub and save it locally.
        """
        if model_path is not None:
            self.model = SentenceTransformer(model_path)
        else:
            # First run: pull the default model and persist it so later runs
            # can load from disk instead of the network.
            # NOTE(review): trust_remote_code executes code shipped with the
            # model repo — confirm the source is trusted.
            self.model = SentenceTransformer(
                "thenlper/gte-base",  # switch to en/zh for English or Chinese
                trust_remote_code=True,
            )
            self.model.save(os.path.join(os.getcwd(), "embeddingModel"))

        # Cap input length at the model's supported context window.
        self.model.max_seq_length = 512

    def encode(self, texts):
        """Return embeddings for a list (or batch) of texts."""
        return self.model.encode(texts)

class DocumentProcessor:
    """Splits a document into overlapping chunks, embeds them, and retrieves
    the chunks most similar to a query embedding (simple RAG retrieval)."""

    def __init__(self, model, chunk_size=1000, chunk_overlap=200):
        """
        Args:
            model: object exposing ``encode(list[str])`` (e.g. EmbeddingModel).
            chunk_size: maximum characters per chunk.
            chunk_overlap: characters shared between consecutive chunks.
        """
        self.model = model
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def context_chunks(self, document_text):
        """Split raw document text into overlapping string chunks.

        Returns:
            list[str]: the plain-text content of each chunk.
        """
        document = Document(page_content=document_text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        text_chunks = text_splitter.split_documents([document])
        return [chunk.page_content for chunk in text_chunks]

    def context_embedding(self, text_content_chunks):
        """Embed all chunks and return one embedding per chunk.

        Fix: the original called ``encode`` once per chunk, which is much
        slower than a single batched call; batching is equivalent and lets
        the model vectorize. Returns a list of per-chunk embedding vectors.
        """
        if not text_content_chunks:  # guard: nothing to embed
            return []
        return list(self.model.encode(text_content_chunks))

    def rag_query(self, query):
        """Embed a single query string; returns a 1-element batch."""
        return self.model.encode([query])

    def similarity(self, query_embedding, text_contents_embeddings, text_content_chunks, top_k):
        """Return the ``top_k`` chunks ranked by cosine similarity to the query.

        Fix: the similarity score is converted to ``float`` so the sort key
        is a plain number rather than a tensor object, making the ordering
        explicit and robust.

        Args:
            query_embedding: 1-element batch from ``rag_query``.
            text_contents_embeddings: per-chunk embeddings from
                ``context_embedding`` (parallel to ``text_content_chunks``).
            text_content_chunks: the chunk texts themselves.
            top_k: number of best-matching chunks to return.
        """
        scored = [
            (text, float(cos_sim(embedding, query_embedding[0])))
            for text, embedding in zip(text_content_chunks, text_contents_embeddings)
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [text for text, _ in scored[:top_k]]


# Example usage:
if __name__ == "__main__":
    # Load the locally cached embedding model and build the retrieval pipeline.
    embedder = EmbeddingModel(model_path=os.path.join(os.getcwd(), "embeddingModel"))
    processor = DocumentProcessor(model=embedder)

    # Extract markdown text from the sample PDF, then chunk and embed it.
    source_pdf = os.path.join(os.getcwd(), "pdfs", "test2.pdf")
    markdown_text = PdfConverter(source_pdf).convert_to_markdown()
    chunks = processor.context_chunks(markdown_text)
    chunk_embeddings = processor.context_embedding(chunks)

    # Embed the question and print the five best-matching chunks.
    question = "what metric used in this paper for performance evaluation?"
    question_embedding = processor.rag_query(question)
    best_matches = processor.similarity(question_embedding, chunk_embeddings, chunks, top_k=5)

    print(best_matches)