import os

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

from modules.pdfExtractor import PdfConverter


class EmbeddingModel:
    """Thin wrapper around a SentenceTransformer embedding model.

    On first use downloads "thenlper/gte-base" and caches it under
    ./embeddingModel; subsequent runs can load the saved copy offline
    via *model_path*.
    """

    def __init__(self, model_path=None):
        """
        Args:
            model_path: Directory of a previously saved model. When None,
                the pretrained "thenlper/gte-base" model is downloaded and
                saved to ./embeddingModel for reuse.
        """
        if model_path is None:
            # NOTE(review): trust_remote_code=True executes code shipped with
            # the model repository — only acceptable for trusted model sources.
            self.model = SentenceTransformer(
                "thenlper/gte-base",  # switch to en/zh for English or Chinese
                trust_remote_code=True,
            )
            # Cache the downloaded weights so later runs can load them locally.
            self.model.save(os.path.join(os.getcwd(), "embeddingModel"))
        else:
            self.model = SentenceTransformer(model_path)
        # Cap input length; inputs longer than this are truncated by the model.
        self.model.max_seq_length = 512

    def encode(self, texts):
        """Return embeddings for a batch of strings (one row per input)."""
        return self.model.encode(texts)


class DocumentProcessor:
    """Chunk a document, embed the chunks, and retrieve the most
    query-relevant ones by cosine similarity."""

    def __init__(self, model, chunk_size=1000, chunk_overlap=200):
        """
        Args:
            model: Object exposing ``encode(list[str])`` -> embedding matrix
                (e.g. an EmbeddingModel instance).
            chunk_size: Maximum characters per chunk.
            chunk_overlap: Characters shared between consecutive chunks.
        """
        self.model = model
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def context_chunks(self, document_text):
        """Split *document_text* into overlapping chunks; returns list[str]."""
        document = Document(page_content=document_text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        text_chunks = text_splitter.split_documents([document])
        return [chunk.page_content for chunk in text_chunks]

    def context_embedding(self, text_content_chunks):
        """Embed every chunk and return one embedding per chunk.

        Fix: the original invoked the model once per chunk (O(n) forward
        passes); all chunks are now embedded in a single batched call.
        Each returned element keeps the original (1, dim) shape so existing
        callers (e.g. similarity()) are unaffected.
        """
        if not text_content_chunks:
            return []
        embeddings = self.model.encode(text_content_chunks)
        # reshape preserves the per-chunk (1, dim) shape the old per-call
        # encode([text]) produced.
        return [embedding.reshape(1, -1) for embedding in embeddings]

    def rag_query(self, query):
        """Embed a single query string; returns a 1-element batch."""
        return self.model.encode([query])

    def similarity(self, query_embedding, text_contents_embeddings, text_content_chunks, top_k):
        """Return the *top_k* chunks most cosine-similar to the query.

        Args:
            query_embedding: 1-element batch from rag_query().
            text_contents_embeddings: Per-chunk embeddings from context_embedding().
            text_content_chunks: The chunk texts, aligned with the embeddings.
            top_k: Number of chunks to return.

        Fix: similarity scores are converted to plain floats before sorting;
        the original used raw 1x1 tensors as sort keys, which is fragile.
        """
        scored = [
            (text, float(cos_sim(embedding, query_embedding[0])))
            for text, embedding in zip(text_content_chunks, text_contents_embeddings)
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [text for text, _ in scored[:top_k]]


# Example usage:
if __name__ == "__main__":
    model = EmbeddingModel(model_path=os.path.join(os.getcwd(), "embeddingModel"))
    processor = DocumentProcessor(model=model)
    pdf_file = os.path.join(os.getcwd(), "pdfs", "test2.pdf")
    converter = PdfConverter(pdf_file)
    document_text = converter.convert_to_markdown()
    text_chunks = processor.context_chunks(document_text)
    text_embeddings = processor.context_embedding(text_chunks)
    query = "what metric used in this paper for performance evaluation?"
    query_embedding = processor.rag_query(query)
    top_results = processor.similarity(query_embedding, text_embeddings, text_chunks, top_k=5)
    print(top_results)