Groove-GPT / test.py
LordFarquaad42's picture
please work
58964c1
import chromadb
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from pypdf import PdfReader as reader
import os
# Embedding model and vector-store configuration.
# experiment with larger models
# NOTE(review): SFR-Embedding-Mistral appears to be Mistral-7B based, so the
# download is likely much larger than the "~1.2 gb" once noted here -- confirm.
MODEL_NAME = "Salesforce/SFR-Embedding-Mistral"
DISTANCE_FUNCTION = "cosine"
COLLECTION_NAME = "scheme"
EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=MODEL_NAME)
client = chromadb.PersistentClient(path="./chromadb_linux_two/")
print("Getting Collection")
# get_or_create (rather than create) so re-running the script against the
# persisted store reuses the existing collection instead of raising; the
# records are later written with upsert, which is meant for repeat runs.
# "hnsw:space" is how Chroma selects the distance metric; without it
# DISTANCE_FUNCTION would be unused and Chroma's default (L2) would apply.
schemer = client.get_or_create_collection(
    name=COLLECTION_NAME,
    embedding_function=EMBEDDING_FUNC,
    metadata={"hnsw:space": DISTANCE_FUNCTION},
)
print(f"Number enteries in collection: {schemer.count()}")
###########################################################################
def get_text(pdf_path: str) -> str:
    """Return the text of every page of the PDF at *pdf_path*, concatenated.

    Page texts are joined with no separator between them.

    Args:
        pdf_path: Filesystem path of the PDF to read.

    Returns:
        The extracted text of all pages, in page order.
    """
    doc = reader(pdf_path)
    # A single join avoids repeated string += and keeps the loop variable
    # bound to one kind of object (a page) throughout.
    return "".join(page.extract_text() for page in doc.pages)
def clean_text(text: str) -> str:
    """Flatten *text* onto one line: every newline becomes exactly one space."""
    pieces = text.split("\n")
    return " ".join(pieces)
# Read every PDF in the data directory into `dataset`, one text string per file.
DATA_DIR = "data"
# Sorted because os.listdir order is arbitrary, and a document's position in
# `dataset` is used to build its chunk IDs; an unstable order would make a
# re-run upsert one document's text under another document's IDs.
files = sorted(os.listdir(DATA_DIR))
dataset = []
for file in files:
    if file.endswith(".pdf"):
        dataset.append(get_text(os.path.join(DATA_DIR, file)))
        print(file)
# Split every document into fixed-size character chunks and stage them
# (text, id, metadata) for insertion into the collection.
batch_size = 1024  # chunk length in characters, not a number of chunks
padding_element = '.'  # filler used to pad the short final chunk of a document
batch_documents = []
batch_ids = []
batch_metadata = []
for i, document in enumerate(dataset):
    for j in range(0, len(document), batch_size):
        batch = document[j:j + batch_size]  # slicing clamps at len(document)
        # Pad the final, shorter chunk so every chunk is exactly batch_size chars.
        # NOTE(review): because of this padding the "length" metadata below is
        # always batch_size -- confirm whether the unpadded length was intended.
        batch += padding_element * (batch_size - len(batch))
        print(f"Doc {i+1}/{len(dataset)}: Batch {j}/{len(document)}")
        batch_documents.append(clean_text(batch))
        # The "_" separator keeps IDs unique: plain concatenation of doc index
        # and offset is ambiguous (doc 2 @ 10240 vs doc 21024 @ 0), and a
        # duplicate ID makes the later upsert fail.
        batch_ids.append(f"batch{i}_{j}")
        batch_metadata.append({"length": len(batch)})
print("Upserting into collection")
# Upsert in slices: Chroma rejects a single call carrying more records than its
# maximum batch size, which a large PDF set can exceed. An empty staging list
# results in no calls at all (Chroma also rejects an empty ID list).
UPSERT_CHUNK = 1000  # presumably well under Chroma's limit -- verify for your backend
for start in range(0, len(batch_ids), UPSERT_CHUNK):
    stop = start + UPSERT_CHUNK
    schemer.upsert(
        ids=batch_ids[start:stop],
        metadatas=batch_metadata[start:stop],
        documents=batch_documents[start:stop],
    )