don_Simon / vectorstore.py
adhinojosa's picture
Create vectorstore.py
cc3f463 verified
raw
history blame
No virus
3.06 kB
import os
import re
import uuid

import chromadb
import fitz
import torch
from chromadb.utils import embedding_functions
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
from sentence_transformers import CrossEncoder
# Hugging Face id of the embedding model; file_to_splits uses it to build the
# token-based text splitter.
emb_model_name = "sentence-transformers/all-mpnet-base-v2"

# Embedding function attached to the Chroma collection below.
# NOTE(review): this passes the short name "all-mpnet-base-v2" instead of
# emb_model_name; presumably both resolve to the same checkpoint — confirm.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-mpnet-base-v2",
)

# Cross-encoder used by similarity_search to re-rank retrieved chunks.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Persistent Chroma store in '.vectorstore' (relative to the working
# directory) holding a single 'huerto' collection whose HNSW index uses
# cosine distance.
client = chromadb.PersistentClient(path='.vectorstore')
collection = client.get_or_create_collection(
    name='huerto',
    embedding_function=sentence_transformer_ef,
    metadata={"hnsw:space": "cosine"},
)
# Regexes used by parse_pdf, compiled once at import time.
_HYPHENATED_WORD_RE = re.compile(r"(\w+)-\n(\w+)")
_INNER_NEWLINE_RE = re.compile(r"(?<!\n\s)\n(?!\s\n)")
_BLANK_LINES_RE = re.compile(r"\n\s*\n")


def parse_pdf(file):
    """Extract the cleaned-up text of every page of a document.

    The document is opened with PyMuPDF (``fitz.open``) and is always closed
    before the function returns, even if text extraction fails part-way.

    Args:
        file: Document to open; passed unchanged to ``fitz.open``.

    Returns:
        A list with one string per page, in page order.
    """
    output = []
    # The context manager releases the document handle; the previous version
    # opened the document and never closed it.
    with fitz.open(file) as pdf:
        for page in pdf:
            text = page.get_text()
            # Merge words hyphenated across a line break ("exam-\nple" -> "example").
            text = _HYPHENATED_WORD_RE.sub(r"\1\2", text)
            # Replace newlines with spaces unless they border a
            # "\n<whitespace>\n" sequence.
            # NOTE(review): a bare "\n\n" is also replaced here, so plain
            # blank-line paragraph breaks are not preserved — confirm intent.
            text = _INNER_NEWLINE_RE.sub(" ", text.strip())
            # Collapse any run of blank/whitespace-only lines into "\n\n".
            text = _BLANK_LINES_RE.sub("\n\n", text)
            output.append(text)
    return output
def file_to_splits(file, tokens_per_chunk, chunk_overlap):
    """Split a document into token-bounded chunks annotated with metadata.

    Each page text returned by ``parse_pdf`` is split independently with a
    ``SentenceTransformersTokenTextSplitter`` configured for
    ``emb_model_name``.

    Args:
        file: Path of the document (``str`` or ``os.PathLike``).
        tokens_per_chunk: Maximum number of tokens per chunk.
        chunk_overlap: Number of tokens shared by consecutive chunks.

    Returns:
        A list of ``[chunk_text, metadata, id]`` triples, where ``metadata`` is
        ``{"source": <file name>, "page": <1-based page number>,
        "chunk": <1-based chunk index within the page>}`` and ``id`` is a
        random UUID4 string.
    """
    # NOTE(review): the splitter is rebuilt on every call, which presumably
    # reloads the model's tokenizer each time; consider caching it.
    text_splitter = SentenceTransformersTokenTextSplitter(
        model_name=emb_model_name,
        tokens_per_chunk=tokens_per_chunk,
        chunk_overlap=chunk_overlap,
    )
    # fspath() accepts Path objects (file.split('/') crashed on them) and
    # basename() also honours the platform's separators (e.g. '\\' on Windows).
    source = os.path.basename(os.fspath(file))
    return [
        [chunk, {"source": source, "page": page_no, "chunk": chunk_no}, str(uuid.uuid4())]
        for page_no, page_text in enumerate(parse_pdf(file), start=1)
        for chunk_no, chunk in enumerate(text_splitter.split_text(page_text), start=1)
    ]
def file_to_vs(file, tokens_per_chunk, chunk_overlap):
    """Chunk a document and add the chunks to the module-level ``collection``.

    Args:
        file: Path of the document to index (see ``file_to_splits``).
        tokens_per_chunk: Maximum number of tokens per chunk.
        chunk_overlap: Number of tokens shared by consecutive chunks.

    Returns:
        ``'Files uploaded successfully'`` on success. If no chunks could be
        produced, a message saying so. If splitting or ``collection.add``
        raises, the exception's text (``str(e)``) is returned instead of
        being raised.
    """
    try:
        splits = file_to_splits(file, tokens_per_chunk, chunk_overlap)
        if not splits:
            # zip(*[]) is empty, so the old code's splits[0] raised IndexError
            # and the caller only saw "list index out of range".
            return 'No text could be extracted from the file'
        # Transpose [[doc, meta, id], ...] into three parallel lists.
        documents, metadatas, ids = (list(column) for column in zip(*splits))
        collection.add(documents=documents, metadatas=metadatas, ids=ids)
        return 'Files uploaded successfully'
    except Exception as e:
        # Errors are reported to the caller as a string rather than raised.
        return str(e)
def similarity_search(query, k, n_candidates=20):
    """Retrieve candidate chunks from ``collection`` and re-rank them.

    The vector store returns its nearest neighbours for ``query``. Each
    candidate is then scored with ``cross_encoder`` using a sigmoid
    activation, and the best ``k`` are returned.

    Args:
        query: Query text.
        k: Maximum number of re-ranked results to return.
        n_candidates: Number of neighbours requested from the vector store
            before re-ranking; raised to ``k`` when smaller. Defaults to 20,
            the value that used to be hard-coded.

    Returns:
        A dict mapping ``"0"``, ``"1"``, ... (at most ``k`` entries; fewer
        when fewer candidates are retrieved) to
        ``{"page_content", "metadata", "similarity"}`` dicts. The entries are
        ordered by decreasing ``similarity``, which is the sigmoid score times
        100, rounded to 2 decimals.
    """
    ss_out = collection.query(query_texts=[query], n_results=max(n_candidates, k))
    documents = ss_out['documents'][0]
    metadatas = ss_out['metadatas'][0]
    if not documents:
        return {}

    # One batched forward pass instead of one predict() call per document.
    scores = cross_encoder.predict(
        [[query, doc] for doc in documents],
        activation_fct=torch.nn.Sigmoid(),
    )
    candidates = [
        {"page_content": doc, "metadata": meta, "similarity": round(float(score) * 100, 2)}
        for doc, meta, score in zip(documents, metadatas, scores)
    ]
    # sorted() stays stable with reverse=True, so ties keep vector-store order.
    ranked = sorted(candidates, key=lambda c: c["similarity"], reverse=True)
    # Slicing (instead of indexing range(k)) avoids IndexError when fewer
    # than k candidates were retrieved.
    return {str(rank): cand for rank, cand in enumerate(ranked[:max(k, 0)])}