Spaces:
Runtime error
Runtime error
File size: 7,233 Bytes
f5e5ccb aaa979a c726440 f5e5ccb 458b338 f5e5ccb 458b338 b3ec1fd 458b338 b3ec1fd 458b338 f5e5ccb 458b338 f5e5ccb 458b338 f5e5ccb 458b338 f5e5ccb 337280a f5e5ccb ee5bae7 f5e5ccb ee5bae7 f5e5ccb 458b338 f5e5ccb 13cf722 f5e5ccb 13cf722 dc8b4b0 a9598e5 dc8b4b0 07a01c3 dc8b4b0 a8235cb dc8b4b0 13cf722 34798a4 13cf722 b3ec1fd 13cf722 f5e5ccb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import configparser
import glob
import json
import logging
import os

import pandas as pd
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from torch import cuda
from transformers import AutoTokenizer

from auditqa.reports import files, report_list
# module-level settings shared by the functions below
# run the embedding model on the GPU when torch reports one, otherwise on the CPU
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# base folder the chunk files are loaded from
path_to_data = "./reports/"
##---------------------functions -------------------------------------------##
def getconfig(configfile_path: str):
    """
    Read the config file

    Params
    ----------------
    configfile_path: file path of .cfg file

    Returns
    ----------------
    The populated configparser.ConfigParser, or None when the file cannot be
    opened or parsed (a warning is logged in both cases).
    """
    config = configparser.ConfigParser()
    try:
        # the context manager closes the handle even when parsing fails
        with open(configfile_path) as config_file:
            config.read_file(config_file)
    except OSError:
        # missing / unreadable file: keep the best-effort contract and return None
        logging.warning("config file not found")
        return None
    except configparser.Error as err:
        logging.warning("config file could not be parsed: %s", err)
        return None
    return config
def open_file(filepath):
    """
    Load and return the parsed content of a JSON file

    Params
    ----------------
    filepath: path of the JSON file to read

    Returns
    ----------------
    The Python object produced by json.load for that file
    """
    # JSON is UTF-8 by specification; do not depend on the platform's default encoding
    with open(filepath, encoding="utf-8") as json_file:
        return json.load(json_file)
def load_chunks():
    """
    Create the 'allreports' vector database from the per-report chunk files

    The module-level `files` mapping (category -> subtype -> report names) is
    walked; for every report the file
    path_to_data + name + "/" + name + ".chunks.json" is loaded and each chunk
    becomes a Document with the metadata keys 'source' (category), 'subtype',
    'year' (last four characters of the report name), 'filename', 'page' and
    'headings'. These are used in the UI for document selection and later for
    filtering the database.

    A report whose chunk file cannot be loaded is reported on stdout and
    skipped.

    Returns
    ----------------
    dict with the single key 'allreports' mapped to the Qdrant vector store
    built from all chunks of all categories
    """
    config = getconfig("./model_params.cfg")
    all_documents = {}
    # iterate through 'source'
    for category, subtypes in files.items():
        print("documents splitting in source:",category)
        category_documents = []
        # iterate through 'subtype' within the source
        # example source/category == 'District', has subtypes which is district names
        for subtype, report_names in subtypes.items():
            print("document splitting for subtype:",subtype)
            for report_name in report_names:
                # load the chunks
                try:
                    doc_processed = open_file(path_to_data + report_name + "/" + report_name + ".chunks.json")
                except Exception as e:
                    print("Exception: ", e)
                    # skip this report: falling through would either hit an
                    # undefined name or re-index the previous report's chunks
                    # under this report's metadata
                    continue
                print("chunks in subtype:",subtype, "are:",len(doc_processed))
                # add metadata information
                category_documents.extend(
                    Document(page_content=doc['content'],
                             metadata={"source": category,
                                       "subtype": subtype,
                                       "year": report_name[-4:],
                                       "filename": report_name,
                                       "page": doc['metadata']['page'],
                                       "headings": doc['metadata']['headings']})
                    for doc in doc_processed
                )
        all_documents[category] = category_documents

    for category, documents in all_documents.items():
        print("length of chunks in source:",category, "are:",len(documents))

    # a single collection holding the chunks of every category
    all_reports = [doc for documents in all_documents.values() for doc in documents]

    # define embedding model
    embeddings = HuggingFaceEmbeddings(
        model_kwargs = {'device': device},
        encode_kwargs = {'normalize_embeddings': bool(int(config.get('retriever','NORMALIZE')))},
        model_name=config.get('retriever','MODEL')
    )
    # placeholder for collection
    qdrant_collections = {}
    print("emebddings for:","allreports")
    qdrant_collections["allreports"] = Qdrant.from_documents(
        all_reports,
        embeddings,
        path="/data/local_qdrant",
        collection_name="allreports",
    )
    print(qdrant_collections)
    print("vector embeddings done")
    return qdrant_collections
def load_new_chunks():
    """
    Create the 'allreports' vector database from the chunk index file

    Reads "./axa_processed_chunks_update.json" into a DataFrame; the code uses
    its columns 'chunks_filepath', 'filename', 'category' and 'year'. For each
    row the chunk file with the same base name is loaded from
    path_to_data + "/chunks/" and every entry of its 'paragraphs' list becomes
    a Document with the metadata keys 'source', 'subtype' (filename without
    extension), 'year', 'filename', 'page' and 'headings'.

    A row whose chunk file cannot be loaded is reported on stdout and skipped.

    Returns
    ----------------
    dict with the single key 'allreports' mapped to the Qdrant vector store
    """
    config = getconfig("./model_params.cfg")
    # named `reports` so the module-level `files` import is not shadowed
    reports = pd.read_json("./axa_processed_chunks_update.json")
    all_documents = []
    # iterate over the index labels rather than assuming they are 0..n-1
    for i in reports.index:
        filename = reports.loc[i,'filename']
        # load the chunks
        try:
            doc_processed = open_file(path_to_data + "/chunks/"+ os.path.basename(reports.loc[i,'chunks_filepath']))
            doc_processed = doc_processed['paragraphs']
        except Exception as e:
            print("Exception: ", e)
            # skip this row: falling through would either hit an undefined
            # name or re-index the previous file's chunks under this row's metadata
            continue
        print("chunks in subtype:", filename, "are:",len(doc_processed))
        # add metadata information
        for doc in doc_processed:
            all_documents.append(Document(page_content=str(doc['content']),
                                          metadata={"source": reports.loc[i,'category'],
                                                    "subtype": os.path.splitext(filename)[0],
                                                    "year": str(reports.loc[i,'year']),
                                                    # filename of the current row (was always row 0)
                                                    "filename": filename,
                                                    "page": doc['metadata']['page'],
                                                    "headings": doc['metadata']['headings']}))
    print("length of chunks:",len(all_documents))
    # define embedding model
    embeddings = HuggingFaceEmbeddings(
        model_kwargs = {'device': device},
        encode_kwargs = {'normalize_embeddings': bool(int(config.get('retriever','NORMALIZE')))},
        model_name=config.get('retriever','MODEL')
    )
    # placeholder for collection
    qdrant_collections = {}
    qdrant_collections['allreports'] = Qdrant.from_documents(
        all_documents,
        embeddings,
        path="/data/local_qdrant",
        collection_name='allreports',
    )
    print(qdrant_collections)
    print("vector embeddings done")
    return qdrant_collections
def get_local_qdrant():
    """
    Connect to the already-built local Qdrant storage

    Opens the on-disk storage at "/data/local_qdrant", prints the collections
    it contains and wraps the 'allreports' collection in a langchain Qdrant
    vector store that embeds with the model named in the config.

    Returns
    ----------------
    dict with the single key 'allreports' mapped to that vector store
    """
    config = getconfig("./model_params.cfg")
    model_name = config.get('retriever','MODEL')
    # NOTE(review): normalization is fixed to True here, while the functions that
    # build the index read 'NORMALIZE' from the config - confirm they should agree
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name=model_name,
    )
    client = QdrantClient(path="/data/local_qdrant")
    print("Collections in local Qdrant:",client.get_collections())
    vector_store = Qdrant(
        client=client,
        collection_name='allreports',
        embeddings=embeddings,
    )
    return {'allreports': vector_store}