lfoppiano committed
Commit 137e5e2
1 Parent(s): 320f843
document_qa/document_qa_engine.py CHANGED
@@ -23,7 +23,13 @@ class DocumentQAEngine:
     embeddings_map_from_md5 = {}
     embeddings_map_to_md5 = {}

-    def __init__(self, llm, embedding_function, qa_chain_type="stuff", embeddings_root_path=None, grobid_url=None):
+    def __init__(self,
+                 llm,
+                 embedding_function,
+                 qa_chain_type="stuff",
+                 embeddings_root_path=None,
+                 grobid_url=None,
+                 ):
         self.embedding_function = embedding_function
         self.llm = llm
         self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
@@ -81,14 +87,14 @@ class DocumentQAEngine:
         return self.embeddings_map_from_md5[md5]

     def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
-                       verbose=False) -> (
+                       verbose=False, memory=None) -> (
             Any, str):
         # self.load_embeddings(self.embeddings_root_path)

         if verbose:
             print(query)

-        response = self._run_query(doc_id, query, context_size=context_size)
+        response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
         response = response['output_text'] if 'output_text' in response else response

         if verbose:
@@ -138,9 +144,15 @@ class DocumentQAEngine:

         return parsed_output

-    def _run_query(self, doc_id, query, context_size=4):
+    def _run_query(self, doc_id, query, memory=None, context_size=4):
         relevant_documents = self._get_context(doc_id, query, context_size)
-        return self.chain.run(input_documents=relevant_documents, question=query)
+        if memory:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query)
+        else:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query,
+                                  memory=memory)
         # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)

     def _get_context(self, doc_id, query, context_size=4):
@@ -150,6 +162,7 @@ class DocumentQAEngine:
         return relevant_documents

     def get_all_context_by_document(self, doc_id):
+        """Return the full context from the document"""
         db = self.embeddings_dict[doc_id]
         docs = db.get()
         return docs['documents']
@@ -161,6 +174,7 @@ class DocumentQAEngine:
         return relevant_documents

     def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
+        """Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
         if verbose:
             print("File", pdf_file_path)
         filename = Path(pdf_file_path).stem
@@ -215,12 +229,11 @@ class DocumentQAEngine:
         self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
                                                        collection_name=hash)

-
         self.embeddings_root_path = None

         return hash

-    def create_embeddings(self, pdfs_dir_path: Path):
+    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
         input_files = []
         for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
             for file_ in files:
@@ -238,7 +251,8 @@ class DocumentQAEngine:
                 print(data_path, "exists. Skipping it ")
                 continue

-            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=500, perc_overlap=0.1)
+            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
+                                                               perc_overlap=perc_overlap)
             filename = metadata[0]['filename']

             vector_db_document = Chroma.from_texts(texts,
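
The core of this commit is threading an optional conversational `memory` from `query_document` down into `chain.run`. As committed, though, the two branches of `_run_query` look inverted: `memory` is only forwarded in the `else` branch, i.e. when it is falsy. A minimal sketch of the presumably intended logic, reusing the exact `chain.run` call shape the commit itself uses:

    def _run_query(self, doc_id, query, memory=None, context_size=4):
        relevant_documents = self._get_context(doc_id, query, context_size)
        if memory:
            # Presumed intent: forward the conversation memory when one is supplied.
            return self.chain.run(input_documents=relevant_documents,
                                  question=query,
                                  memory=memory)
        # No memory: plain single-turn QA over the retrieved chunks.
        return self.chain.run(input_documents=relevant_documents,
                              question=query)

A caller would then pass something like LangChain's `ConversationBufferMemory` through the new keyword, e.g. `engine.query_document(question, doc_id, memory=memory)`.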
streamlit_app.py CHANGED
@@ -97,6 +97,7 @@ def init_qa(model, api_key=None):
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
+        return

     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
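
The added bare `return` makes the failure path of `init_qa` explicit: `st.stop()` already halts the Streamlit script run, so the `return` simply signals to readers that no engine is produced on this path. A sketch of the resulting shape of the function, where `load_model` is a hypothetical stand-in for the model-loading logic that precedes this `else` branch (not shown in the diff):

    import os
    import streamlit as st

    def init_qa(model, api_key=None):
        # Hypothetical stand-in for the branch that loads `chat` and `embeddings`.
        chat, embeddings = load_model(model, api_key)
        if chat is None or embeddings is None:
            st.error("The model was not loaded properly. Try reloading. ")
            st.stop()
            return  # st.stop() halts execution; return makes the dead end explicit

        return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])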