Spaces:
Runtime error
Runtime error
File size: 4,985 Bytes
f5e5ccb c726440 f5e5ccb 458b338 f5e5ccb 458b338 b3ec1fd 458b338 b3ec1fd 458b338 f5e5ccb 458b338 f5e5ccb 458b338 f5e5ccb 458b338 f5e5ccb 337280a f5e5ccb ee5bae7 f5e5ccb ee5bae7 f5e5ccb 458b338 f5e5ccb 13cf722 f5e5ccb 13cf722 34798a4 13cf722 b3ec1fd 13cf722 f5e5ccb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import glob
import json
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTransformersTokenTextSplitter
from transformers import AutoTokenizer
from torch import cuda
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from auditqa.reports import files, report_list
from langchain.docstore.document import Document
import configparser
# Module-wide runtime settings used by the functions below:
# - device: torch device handed to the embedding model ('cuda' when a GPU
#   is visible to torch, otherwise 'cpu').
# - path_to_data: folder holding one sub-folder per report with its
#   "<name>.chunks.json" file.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
path_to_data = "./reports/"
##---------------------functions -------------------------------------------##
def getconfig(configfile_path: str):
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of .cfg file

    Returns
    ----------------
    configparser.ConfigParser holding the parsed file, or None when the file
    cannot be opened or parsed (a warning is logged in either case).
    """
    # Local import: the module never imported `logging`, so the original
    # warning path crashed with NameError instead of warning.
    import logging

    config = configparser.ConfigParser()
    try:
        # `with` closes the file handle, which the original leaked.
        with open(configfile_path) as cfg_file:
            config.read_file(cfg_file)
    except OSError:
        logging.warning("config file not found")
        return None
    except configparser.Error as err:
        # Keep the best-effort "return None" contract, but do not
        # misreport a malformed file as a missing one.
        logging.warning("config file could not be parsed: %s", err)
        return None
    return config
def open_file(filepath):
    """
    Load a JSON file and return the decoded object.

    Params
    ----------------
    filepath: path of the .json file to read

    Returns
    ----------------
    The parsed JSON value (dict/list/...).
    """
    # Read raw bytes: json.load then detects UTF-8/16/32 (including a UTF-8
    # BOM) itself instead of relying on the platform's locale encoding.
    with open(filepath, "rb") as file:
        return json.load(file)
def _documents_for_file(category, subtype, file):
    """Load one report's chunk file and wrap every chunk as a Document.

    Returns an empty list (after printing the error) when the chunk file
    cannot be read, so a broken file is skipped rather than silently
    re-using the previous file's chunks under this file's metadata.
    """
    try:
        doc_processed = open_file(path_to_data + file + "/" + file + ".chunks.json")
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError
        print("Exception: ", e)
        return []
    print("chunks in subtype:", subtype, "are:", len(doc_processed))
    return [
        Document(
            page_content=doc['content'],
            metadata={
                "source": category,
                "subtype": subtype,
                # assumes file names end in a 4-digit year — TODO confirm
                "year": file[-4:],
                "filename": file,
                "page": doc['metadata']['page'],
                "headings": doc['metadata']['headings'],
            },
        )
        for doc in doc_processed
    ]


def _collect_documents():
    """Return {category: [Document, ...]} built from the `files` mapping."""
    all_documents = {}
    # files: 'source'/category -> subtype -> list of report file names.
    # These keys are used in the UI for document selection and later for
    # filtering the database (e.g. category 'District' has district-name
    # subtypes).
    for category, subtypes in files.items():
        print("documents splitting in source:", category)
        category_docs = []
        for subtype, file_names in subtypes.items():
            print("document splitting for subtype:", subtype)
            for file in file_names:
                category_docs.extend(_documents_for_file(category, subtype, file))
        all_documents[category] = category_docs
    for key, docs in all_documents.items():
        print("length of chunks in source:", key, "are:", len(docs))
    return all_documents


def load_chunks():
    """
    Read the chunk files listed in `files` and build the vector database.

    Every chunk becomes a langchain Document tagged with source, subtype,
    year, filename, page and headings metadata. The documents of all
    categories are combined into one list and embedded into the on-disk
    Qdrant collection "allreports".

    Returns
    ----------------
    dict mapping collection name ("allreports") -> Qdrant vector store.
    """
    config = getconfig("./model_params.cfg")
    all_documents = _collect_documents()
    # One flat list holding every category's documents, in category order.
    all_documents['allreports'] = [doc for docs in all_documents.values() for doc in docs]

    # define embedding model
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': bool(int(config.get('retriever', 'NORMALIZE')))},
        model_name=config.get('retriever', 'MODEL'),
    )

    # Only the combined collection is embedded; per-category lists are used
    # just for the length report above.
    print("embeddings for:", "allreports")
    qdrant_collections = {
        'allreports': Qdrant.from_documents(
            all_documents['allreports'],
            embeddings,
            path="/data/local_qdrant",
            collection_name='allreports',
        )
    }
    print(qdrant_collections)
    print("vector embeddings done")
    return qdrant_collections
def get_local_qdrant():
    """
    Connect to the existing on-disk Qdrant store created by load_chunks.

    Returns
    ----------------
    dict mapping "allreports" -> Qdrant vector store backed by the local
    client at /data/local_qdrant.
    """
    config = getconfig("./model_params.cfg")
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        # Match the normalisation used when the collection was indexed
        # (load_chunks reads retriever.NORMALIZE); fall back to the previous
        # hard-coded True when the option is absent.
        encode_kwargs={'normalize_embeddings': bool(int(config.get('retriever', 'NORMALIZE', fallback='1')))},
        model_name=config.get('retriever', 'MODEL'),
    )
    client = QdrantClient(path="/data/local_qdrant")
    print("Collections in local Qdrant:", client.get_collections())
    qdrant_collections = {
        'allreports': Qdrant(client=client, collection_name='allreports', embeddings=embeddings),
    }
    return qdrant_collections