audit_assistant / auditqa /doc_process.py
ppsingh's picture
Update auditqa/doc_process.py
3b6db6d verified
raw
history blame
4.58 kB
import glob
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from transformers import AutoTokenizer
from torch import cuda
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from auditqa.reports import files, report_list
device = 'cuda' if cuda.is_available() else 'cpu'
# path to the pdf files
path_to_data = "./data/pdf/"
def process_pdf():
"""
this method reads through the files and report_list to create the vector database
"""
# load all the files using PyMuPDFfLoader
docs = {}
for file in report_list:
try:
docs[file] = PyMuPDFLoader(path_to_data + file + '.pdf').load()
except Exception as e:
print("Exception: ", e)
# text splitter based on the tokenizer of a model of your choosing
# to make texts fit exactly a transformer's context window size
# langchain text splitters: https://python.langchain.com/docs/modules/data_connection/document_transformers/
chunk_size = 256
text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5"),
chunk_size=chunk_size,
chunk_overlap=10,
add_start_index=True,
strip_whitespace=True,
separators=["\n\n", "\n"],
)
# we iterate through the files which contain information about its
# 'source'=='category', 'subtype', these are used in UI for document selection
# which will be used later for filtering database
all_documents = {}
categories = list(files.keys())
# iterate through 'source'
for category in categories:
print("documents splitting in source:",category)
all_documents[category] = []
subtypes = list(files[category].keys())
# iterate through 'subtype' within the source
# example source/category == 'District', has subtypes which is district names
for subtype in subtypes:
print("document splitting for subtype:",subtype)
for file in files[category][subtype]:
# create the chunks
doc_processed = text_splitter.split_documents(docs[file])
print("chunks in subtype:",subtype, "are:",len(doc_processed))
# add metadata information
for doc in doc_processed:
doc.metadata["source"] = category
doc.metadata["subtype"] = subtype
doc.metadata["year"] = file[-4:]
doc.metadata["filename"] = file
all_documents[category].append(doc_processed)
# convert list of list to flat list
for key, docs_processed in all_documents.items():
docs_processed = [item for sublist in docs_processed for item in sublist]
print("length of chunks in source:",key, "are:",len(docs_processed))
all_documents[key] = docs_processed
all_documents['allreports'] = [sublist for key,sublist in all_documents.items()]
all_documents['allreports'] = [item for sublist in all_documents['allreports'] for item in sublist]
# define embedding model
embeddings = HuggingFaceEmbeddings(
model_kwargs = {'device': device},
encode_kwargs = {'normalize_embeddings': True},
model_name="BAAI/bge-small-en-v1.5"
)
# placeholder for collection
qdrant_collections = {}
for file,value in all_documents.items():
print("emebddings for:",file)
qdrant_collections[file] = Qdrant.from_documents(
value,
embeddings,
location=":memory:",
collection_name=file,
)
print(qdrant_collections)
print("vector embeddings done")
return qdrant_collections
def get_local_qdrant():
qdrant_collections = {}
embeddings = HuggingFaceEmbeddings(
model_kwargs = {'device': device},
encode_kwargs = {'normalize_embeddings': True},
model_name="BAAI/bge-small-en-v1.5")
list_ = ['Consolidated','District','Ministry','allreports']
for val in list_:
client = QdrantClient(path=f"./data/{val}")
print(client.get_collections())
qdrant_collections[val] = Qdrant(client=client, collection_name=val, embeddings=embeddings, )
return qdrant_collections