|
import os
|
|
from sentence_transformers import SentenceTransformer
|
|
from sentence_transformers.util import cos_sim
|
|
from modules.pdfExtractor import PdfConverter
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
from langchain.schema import Document
|
|
|
|
class EmbeddingModel:
    """Thin wrapper around a SentenceTransformer text-embedding model.

    When ``model_path`` is given, loads the model from that local directory.
    Otherwise downloads ``model_name`` from the Hugging Face hub and caches
    a copy under ``save_dir`` so later runs can load it offline.
    """

    # Default hub model id used when no local path is supplied.
    DEFAULT_MODEL_NAME = "thenlper/gte-base"

    def __init__(self, model_path=None, model_name=None, save_dir=None,
                 max_seq_length=512):
        """Load or download the embedding model.

        Args:
            model_path: local directory of a previously saved model; if
                ``None``, the model is fetched from the hub instead.
            model_name: hub model id to download (defaults to
                ``DEFAULT_MODEL_NAME``). Ignored when ``model_path`` is set.
            save_dir: where to cache the downloaded model (defaults to
                ``<cwd>/embeddingModel``). Ignored when ``model_path`` is set.
            max_seq_length: token limit applied to the encoder; longer
                inputs are truncated.
        """
        if model_path is None:
            # NOTE(review): trust_remote_code executes code from the model
            # repository — only use with model ids you trust.
            self.model = SentenceTransformer(
                model_name or self.DEFAULT_MODEL_NAME,
                trust_remote_code=True,
            )
            # Cache a local copy so subsequent runs can pass model_path
            # and skip the hub download.
            self.model.save(save_dir or os.path.join(os.getcwd(), "embeddingModel"))
        else:
            self.model = SentenceTransformer(model_path)

        # Cap the encoder's input length; inputs beyond this are truncated.
        self.model.max_seq_length = max_seq_length

    def encode(self, texts):
        """Return embeddings for ``texts`` (a list of strings)."""
        return self.model.encode(texts)
|
|
|
|
class DocumentProcessor:
    """Splits documents into overlapping chunks, embeds them, and
    retrieves the chunks most similar to a query embedding."""

    def __init__(self, model, chunk_size=1000, chunk_overlap=200):
        """
        Args:
            model: embedding model exposing ``encode(list[str])``.
            chunk_size: maximum characters per chunk.
            chunk_overlap: characters shared between consecutive chunks.
        """
        self.model = model
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def context_chunks(self, document_text):
        """Split raw document text into overlapping character chunks.

        Returns a list of plain chunk strings.
        """
        document = Document(page_content=document_text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        text_chunks = text_splitter.split_documents([document])
        return [chunk.page_content for chunk in text_chunks]

    def context_embedding(self, text_content_chunks):
        """Embed each chunk; returns one single-item embedding batch per chunk.

        NOTE(review): this issues one encode call per chunk; batching with
        ``self.model.encode(text_content_chunks)`` would be faster but
        changes the per-item shape — confirm before switching.
        """
        return [self.model.encode([text]) for text in text_content_chunks]

    def rag_query(self, query):
        """Embed the user query as a batch of one."""
        return self.model.encode([query])

    def similarity(self, query_embedding, text_contents_embeddings, text_content_chunks, top_k):
        """Return the ``top_k`` chunk texts most similar to the query.

        Args:
            query_embedding: batch-of-one query embedding (first row used).
            text_contents_embeddings: per-chunk embeddings, aligned with
                ``text_content_chunks``.
            text_content_chunks: chunk texts.
            top_k: number of best-matching chunks to return.

        Fix: cast the 1x1 cosine-similarity tensor to a plain float before
        sorting, so the sort key is an unambiguous Python scalar rather
        than a tensor compared element-wise on every comparison.
        """
        scored = [
            (text, float(cos_sim(embedding, query_embedding[0])))
            for text, embedding in zip(text_content_chunks, text_contents_embeddings)
        ]

        # Highest similarity first; slice off the requested number of hits.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [text for text, _ in scored[:top_k]]
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Load the locally cached embedding model and wire up the processor.
    embedding_model = EmbeddingModel(model_path=os.path.join(os.getcwd(), "embeddingModel"))
    doc_processor = DocumentProcessor(model=embedding_model)

    # Convert the PDF to markdown text, then chunk and embed it.
    pdf_path = os.path.join(os.getcwd(), "pdfs", "test2.pdf")
    markdown_text = PdfConverter(pdf_path).convert_to_markdown()
    chunks = doc_processor.context_chunks(markdown_text)
    chunk_embeddings = doc_processor.context_embedding(chunks)

    # Embed the question and retrieve the best-matching chunks.
    query = "what metric used in this paper for performance evaluation?"
    query_embedding = doc_processor.rag_query(query)
    top_results = doc_processor.similarity(query_embedding, chunk_embeddings, chunks, top_k=5)

    print(top_results)