Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 4,879 Bytes
f5e5ccb c726440 f5e5ccb 458b338 f5e5ccb 458b338 f5e5ccb 458b338 f5e5ccb 458b338 f5e5ccb 458b338 f5e5ccb 337280a f5e5ccb ee5bae7 f5e5ccb ee5bae7 f5e5ccb 458b338 f5e5ccb 13cf722 f5e5ccb 13cf722 f5e5ccb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import glob
import json
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from transformers import AutoTokenizer
from torch import cuda
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from auditqa.reports import files, report_list
from langchain.docstore.document import Document
import configparser
# read all the necessary variables
device = 'cuda' if cuda.is_available() else 'cpu'
path_to_data = "./reports/"
##---------------------fucntions -------------------------------------------##
def getconfig(configfile_path:str):
"""
configfile_path: file path of .cfg file
"""
config = configparser.ConfigParser()
try:
config.read_file(open(configfile_path))
return config
except:
logging.warning("config file not found")
def open_file(filepath):
with open(filepath) as file:
simple_json = json.load(file)
return simple_json
def load_chunks():
"""
this method reads through the files and report_list to create the vector database
"""
# we iterate through the files which contain information about its
# 'source'=='category', 'subtype', these are used in UI for document selection
# which will be used later for filtering database
config = getconfig("./model_params.cfg")
all_documents = {}
categories = list(files.keys())
# iterate through 'source'
for category in categories:
print("documents splitting in source:",category)
all_documents[category] = []
subtypes = list(files[category].keys())
# iterate through 'subtype' within the source
# example source/category == 'District', has subtypes which is district names
for subtype in subtypes:
print("document splitting for subtype:",subtype)
for file in files[category][subtype]:
# load the chunks
try:
doc_processed = open_file(path_to_data + file + "/"+ file+ ".chunks.json" )
except Exception as e:
print("Exception: ", e)
print("chunks in subtype:",subtype, "are:",len(doc_processed))
# add metadata information
chunks_list = []
for doc in doc_processed:
chunks_list.append(Document(page_content= doc['content'],
metadata={"source": category,
"subtype":subtype,
"year":file[-4:],
"filename":file,
"page":doc['metadata']['page'],
"headings":doc['metadata']['headings']}))
all_documents[category].append(chunks_list)
# convert list of list to flat list
for key, docs_processed in all_documents.items():
docs_processed = [item for sublist in docs_processed for item in sublist]
print("length of chunks in source:",key, "are:",len(docs_processed))
all_documents[key] = docs_processed
all_documents['allreports'] = [sublist for key,sublist in all_documents.items()]
all_documents['allreports'] = [item for sublist in all_documents['allreports'] for item in sublist]
# define embedding model
embeddings = HuggingFaceEmbeddings(
model_kwargs = {'device': device},
encode_kwargs = {'normalize_embeddings': bool(int(config.get('retriever','NORMALIZE')))},
model_name=config.get('retriever','MODEL')
)
# placeholder for collection
qdrant_collections = {}
for file,value in all_documents.items():
if file == "allreports":
print("emebddings for:",file)
qdrant_collections[file] = Qdrant.from_documents(
value,
embeddings,
path="/data/local_qdrant",
collection_name=file,
)
print(qdrant_collections)
print("vector embeddings done")
return qdrant_collections
def get_local_qdrant():
config = getconfig("./model_params.cfg")
qdrant_collections = {}
embeddings = HuggingFaceEmbeddings(
model_kwargs = {'device': device},
encode_kwargs = {'normalize_embeddings': True},
model_name=config.get('retriever','MODEL'))
#list_ = ['Consolidated','District','Ministry','allreports']
#for val in list_:
client = QdrantClient(path="/data/local_qdrant")
print(client.get_collections())
qdrant_collections['allreports'] = Qdrant(client=client, collection_name='allreports', embeddings=embeddings, )
return qdrant_collections |