cosmetics
Files changed:
- document_qa/document_qa_engine.py (+22 -8)
- streamlit_app.py (+1 -0)
document_qa/document_qa_engine.py
CHANGED
@@ -23,7 +23,13 @@ class DocumentQAEngine:
     embeddings_map_from_md5 = {}
     embeddings_map_to_md5 = {}
 
-    def __init__(self, llm, embedding_function, qa_chain_type="stuff", embeddings_root_path=None, grobid_url=None):
+    def __init__(self,
+                 llm,
+                 embedding_function,
+                 qa_chain_type="stuff",
+                 embeddings_root_path=None,
+                 grobid_url=None,
+                 ):
         self.embedding_function = embedding_function
         self.llm = llm
         self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
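The constructor body is unchanged; only the signature is reflowed so each keyword argument sits on its own line. A hypothetical construction, mirroring the call in streamlit_app.py below (`chat_model` and `embedding_fn` stand in for real LangChain objects):

    # Hypothetical setup; a real GROBID server URL would be passed in practice.
    engine = DocumentQAEngine(chat_model,
                              embedding_fn,
                              qa_chain_type="stuff",
                              grobid_url="http://localhost:8070")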
@@ -81,14 +87,14 @@ class DocumentQAEngine:
         return self.embeddings_map_from_md5[md5]
 
     def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
-                       verbose=False) -> (
+                       verbose=False, memory=None) -> (
             Any, str):
         # self.load_embeddings(self.embeddings_root_path)
 
         if verbose:
             print(query)
 
-        response = self._run_query(doc_id, query, context_size=context_size)
+        response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
         response = response['output_text'] if 'output_text' in response else response
 
         if verbose:
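query_document now accepts an optional memory object and threads it through to _run_query. A hypothetical usage sketch, assuming a LangChain ConversationBufferMemory and an engine instance created elsewhere:

    # Hypothetical caller code; `engine` and `doc_id` come from the application.
    from langchain.memory import ConversationBufferMemory

    memory = ConversationBufferMemory()
    response = engine.query_document("What is the main contribution?", doc_id,
                                     context_size=4, memory=memory)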
@@ -138,9 +144,15 @@ class DocumentQAEngine:
 
         return parsed_output
 
-    def _run_query(self, doc_id, query, context_size=4):
+    def _run_query(self, doc_id, query, memory=None, context_size=4):
         relevant_documents = self._get_context(doc_id, query, context_size)
-        return self.chain.run(input_documents=relevant_documents, question=query)
+        if memory:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query)
+        else:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query,
+                                  memory=memory)
         # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
 
     def _get_context(self, doc_id, query, context_size=4):
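Note that the committed branch appears inverted: when a memory object is supplied the chain runs without it, and when memory is None it is passed as memory=None. A minimal sketch of the presumed intent, forwarding memory only when one is present:

    # Sketch only: forward `memory` to the chain when the caller supplied one.
    def _run_query(self, doc_id, query, memory=None, context_size=4):
        relevant_documents = self._get_context(doc_id, query, context_size)
        if memory:
            return self.chain.run(input_documents=relevant_documents,
                                  question=query,
                                  memory=memory)
        return self.chain.run(input_documents=relevant_documents,
                              question=query)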
@@ -150,6 +162,7 @@ class DocumentQAEngine:
         return relevant_documents
 
     def get_all_context_by_document(self, doc_id):
+        """Return the full context from the document"""
         db = self.embeddings_dict[doc_id]
         docs = db.get()
         return docs['documents']
@@ -161,6 +174,7 @@ class DocumentQAEngine:
         return relevant_documents
 
     def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
+        """Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
         if verbose:
             print("File", pdf_file_path)
         filename = Path(pdf_file_path).stem
@@ -215,12 +229,11 @@ class DocumentQAEngine:
         self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
                                                        collection_name=hash)
 
-
         self.embeddings_root_path = None
 
         return hash
 
-    def create_embeddings(self, pdfs_dir_path: Path):
+    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
         input_files = []
         for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
             for file_ in files:
@@ -238,7 +251,8 @@ class DocumentQAEngine:
                 print(data_path, "exists. Skipping it ")
                 continue
 
-            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=500, perc_overlap=0.1)
+            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
+                                                               perc_overlap=perc_overlap)
            filename = metadata[0]['filename']
 
             vector_db_document = Chroma.from_texts(texts,
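create_embeddings now exposes the chunking parameters instead of fixing them inside the method, so batch indexing and single-document indexing can share the same settings. With the defaults, chunk_size=500 and perc_overlap=0.1 give roughly 50 overlapping units between consecutive chunks. A hypothetical call:

    # Hypothetical usage: re-index a directory of PDFs with larger chunks.
    from pathlib import Path

    engine.create_embeddings(Path("./pdfs"), chunk_size=1000, perc_overlap=0.1)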
streamlit_app.py
CHANGED
@@ -97,6 +97,7 @@ def init_qa(model, api_key=None):
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
+        return
 
     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
 
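The added return is defensive: st.stop() already halts the Streamlit script run, so execution never reaches the DocumentQAEngine construction on the error path. A minimal sketch of the same guard pattern, using a hypothetical helper:

    # Sketch of the guard pattern: st.stop() ends the script run, and the
    # explicit return documents that nothing below executes on error.
    import streamlit as st

    def require_model(model):
        if model is None:
            st.error("The model was not loaded properly. Try reloading. ")
            st.stop()
            return None  # unreachable after st.stop(); kept as an explicit guard
        return model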