# duplicate_llm/modules/ragoop.py
import os

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

from modules.pdfExtractor import PdfConverter
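
# Minimal RAG retrieval pipeline: a PDF is converted to markdown, split into
# overlapping character chunks, embedded with a GTE sentence-transformer model,
# and the chunks most similar to a query are returned.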


class EmbeddingModel:
    """Thin wrapper around a SentenceTransformer embedding model."""

    def __init__(self, model_path=None):
        if model_path is None:
            # First run: download the base GTE model and cache it locally so later
            # runs can load it from disk instead of the hub.
            self.model = SentenceTransformer(
                "thenlper/gte-base",  # switch to en/zh for English or Chinese
                trust_remote_code=True
            )
            self.model.save(os.path.join(os.getcwd(), "embeddingModel"))
        else:
            self.model = SentenceTransformer(model_path)
        self.model.max_seq_length = 512  # truncate inputs longer than 512 tokens

    def encode(self, texts):
        return self.model.encode(texts)


class DocumentProcessor:
    """Chunk a document, embed the chunks, and rank them against a query."""

    def __init__(self, model, chunk_size=1000, chunk_overlap=200):
        self.model = model
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def context_chunks(self, document_text):
        """Split the raw document text into overlapping character chunks."""
        document = Document(page_content=document_text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap
        )
        text_chunks = text_splitter.split_documents([document])
        text_content_chunks = [chunk.page_content for chunk in text_chunks]
        return text_content_chunks

    def context_embedding(self, text_content_chunks):
        """Embed each chunk individually; every entry has shape (1, dim)."""
        return [self.model.encode([text]) for text in text_content_chunks]

    def rag_query(self, query):
        """Embed the user query; the result has shape (1, dim)."""
        return self.model.encode([query])

    def similarity(self, query_embedding, text_contents_embeddings, text_content_chunks, top_k):
        """Return the top_k chunks ranked by cosine similarity to the query."""
        similarities = [
            (text, cos_sim(embedding, query_embedding[0]).item())  # scalar score per chunk
            for text, embedding in zip(text_content_chunks, text_contents_embeddings)
        ]
        similarities_sorted = sorted(similarities, key=lambda x: x[1], reverse=True)
        top_k_texts = [text for text, _ in similarities_sorted[:top_k]]
        return top_k_texts
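
    # Illustrative sketch, not part of the original pipeline (the name
    # `similarity_batched` is ours): a batched variant of `similarity` that stacks
    # the per-chunk embeddings and scores them against the query with a single
    # cos_sim call instead of a Python-level loop.
    def similarity_batched(self, query_embedding, text_contents_embeddings, text_content_chunks, top_k):
        import numpy as np  # numpy is already a dependency of sentence-transformers

        # context_embedding() returns one (1, dim) array per chunk; stack to (n_chunks, dim).
        chunk_matrix = np.vstack(text_contents_embeddings)
        scores = cos_sim(query_embedding, chunk_matrix)[0]  # tensor of shape (n_chunks,)
        top_indices = scores.argsort(descending=True)[:top_k].tolist()
        return [text_content_chunks[i] for i in top_indices]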


# Example usage:
if __name__ == "__main__":
    # Load the locally cached embedding model and set up the processor.
    model = EmbeddingModel(model_path=os.path.join(os.getcwd(), "embeddingModel"))
    processor = DocumentProcessor(model=model)

    # Convert the PDF to markdown text, then chunk and embed it.
    pdf_file = os.path.join(os.getcwd(), "pdfs", "test2.pdf")
    converter = PdfConverter(pdf_file)
    document_text = converter.convert_to_markdown()
    text_chunks = processor.context_chunks(document_text)
    text_embeddings = processor.context_embedding(text_chunks)

    # Embed the query and print the five most similar chunks.
    query = "What metric is used in this paper for performance evaluation?"
    query_embedding = processor.rag_query(query)
    top_results = processor.similarity(query_embedding, text_embeddings, text_chunks, top_k=5)
    print(top_results)