# Spaces status: Build error
import os

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

from modules.pdfExtractor import PdfConverter
class EmbeddingModel:
    """Thin wrapper around a ``SentenceTransformer`` used to embed text.

    The model is either loaded from a local directory (``model_path``) or,
    when no path is given, downloaded by name and saved to disk so that later
    runs can load it from that directory instead.
    """

    def __init__(
        self,
        model_path=None,
        model_name="thenlper/gte-base",
        save_dir=None,
        trust_remote_code=True,
        max_seq_length=512,
    ):
        """Load a saved model or download and persist a fresh one.

        Args:
            model_path: Directory of a previously saved model. When ``None``,
                ``model_name`` is downloaded instead.
            model_name: Model id used only when ``model_path`` is ``None``.
            save_dir: Where a downloaded model is saved. Defaults to
                ``<cwd>/embeddingModel``, resolved at call time.
            trust_remote_code: Forwarded to ``SentenceTransformer`` in the
                download branch. Enabling it allows code shipped with the model
                repository to execute locally.
                NOTE(review): kept ``True`` only for backward compatibility;
                confirm whether the chosen model needs it and prefer ``False``.
            max_seq_length: Token limit assigned to the loaded model.
        """
        if model_path is None:
            self.model = SentenceTransformer(
                model_name,
                trust_remote_code=trust_remote_code,
            )
            if save_dir is None:
                save_dir = os.path.join(os.getcwd(), "embeddingModel")
            # Persist the download so later runs can pass it as ``model_path``.
            self.model.save(save_dir)
        else:
            self.model = SentenceTransformer(model_path)
        # Applied in both branches so loaded and downloaded models agree.
        self.model.max_seq_length = max_seq_length

    def encode(self, texts, **kwargs):
        """Embed ``texts`` with the wrapped model.

        Extra keyword arguments are forwarded unchanged to the model's
        ``encode`` method (e.g. ``batch_size``).
        """
        return self.model.encode(texts, **kwargs)
class DocumentProcessor:
    """Chunk a document, embed the chunks, and rank them against a query."""

    def __init__(self, model, chunk_size=1000, chunk_overlap=200):
        """
        Args:
            model: Object exposing ``encode(list_of_texts)``, e.g. ``EmbeddingModel``.
            chunk_size: Maximum chunk size passed to the text splitter.
            chunk_overlap: Overlap between consecutive chunks passed to the splitter.
        """
        self.model = model
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def context_chunks(self, document_text):
        """Split ``document_text`` into overlapping chunks and return their text."""
        document = Document(page_content=document_text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        text_chunks = text_splitter.split_documents([document])
        return [chunk.page_content for chunk in text_chunks]

    def context_embedding(self, text_content_chunks):
        """Embed all chunks with a single batched ``encode`` call.

        Returns:
            A list with one entry per chunk. Each entry is a one-row slice of
            the batch, i.e. the same layout ``encode([chunk])`` produced when
            chunks were embedded one at a time. Values may differ in the last
            floating-point digits because the chunks are now batched together.
        """
        chunks = list(text_content_chunks)
        if not chunks:
            return []
        # One batched call instead of len(chunks) separate model invocations.
        embeddings = self.model.encode(chunks)
        # Slice (not index) so every element keeps its leading batch axis.
        return [embeddings[i:i + 1] for i in range(len(embeddings))]

    def rag_query(self, query):
        """Embed a single query string; returns a one-row batch."""
        return self.model.encode([query])

    def similarity(self, query_embedding, text_contents_embeddings, text_content_chunks, top_k):
        """Return the chunks most cosine-similar to the query, best first.

        Args:
            query_embedding: One-row batch as returned by ``rag_query``; only
                row 0 is used.
            text_contents_embeddings: Per-chunk embeddings aligned with
                ``text_content_chunks``.
            text_content_chunks: Chunk strings.
            top_k: Number of chunks to return. ``None`` returns every chunk;
                values <= 0 return an empty list.

        Raises:
            ValueError: If the number of chunks and embeddings differ.
        """
        chunks = list(text_content_chunks)
        embeddings = list(text_contents_embeddings)
        # zip() would silently drop the unmatched tail and misrank results.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"got {len(chunks)} chunks but {len(embeddings)} embeddings"
            )
        # A negative slice bound would mean "all but the last |top_k|".
        if top_k is not None and top_k <= 0:
            return []
        query_vector = query_embedding[0]
        # .item() turns the single-element similarity matrix into a float,
        # so sorting compares plain numbers rather than tensors.
        scored = [
            (text, cos_sim(embedding, query_vector).item())
            for text, embedding in zip(chunks, embeddings)
        ]
        # Stable sort: chunks with equal scores keep their document order.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [text for text, _ in scored[:top_k]]
# Example usage: embed a PDF and print the chunks most relevant to a question.
if __name__ == "__main__":
    model_dir = os.path.join(os.getcwd(), "embeddingModel")
    # Load the saved model when it exists; otherwise let EmbeddingModel
    # download it (its default save location is this same directory).
    model = EmbeddingModel(
        model_path=model_dir if os.path.isdir(model_dir) else None
    )
    processor = DocumentProcessor(model=model)

    pdf_file = os.path.join(os.getcwd(), "pdfs", "test2.pdf")
    converter = PdfConverter(pdf_file)
    document_text = converter.convert_to_markdown()

    text_chunks = processor.context_chunks(document_text)
    text_embeddings = processor.context_embedding(text_chunks)

    query = "what metric used in this paper for performance evaluation?"
    query_embedding = processor.rag_query(query)
    top_results = processor.similarity(
        query_embedding, text_embeddings, text_chunks, top_k=5
    )
    print(top_results)