Luca Foppiano committed
Commit cc3e97d
2 Parent(s): 320f843 b19b313

Merge pull request #18 from lfoppiano/add-memory

Files changed (3)
  1. README.md +7 -4
  2. document_qa/document_qa_engine.py +22 -8
  3. streamlit_app.py +24 -2
README.md CHANGED
@@ -16,11 +16,14 @@ license: apache-2.0
 
 ## Introduction
 
-Question/Answering on scientific documents using LLMs (OpenAI, Mistral, ~~LLama2,~~ etc..).
-This application is the frontend for testing the RAG (Retrieval Augmented Generation) on scientific documents, that we are developing at NIMS.
-Differently to most of the project, we focus on scientific articles. We target only the full-text using [Grobid](https://github.com/kermitt2/grobid) that provide and cleaner results than the raw PDF2Text converter (which is comparable with most of other solutions).
-
-**NER in LLM response**: The responses from the LLMs are post-processed to extract <span stype="color:yellow">physical quantities, measurements</span> (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span stype="color:blue">materials</span> mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
+Question/Answering on scientific documents using LLMs: ChatGPT-3.5-turbo, Mistral-7b-instruct and Zephyr-7b-beta.
+The Streamlit application demonstrates the RAG (Retrieval Augmented Generation) implementation for scientific documents that we are developing at NIMS (National Institute for Materials Science) in Tsukuba, Japan.
+Unlike most projects, we focus on scientific articles.
+We target only the full text, using [Grobid](https://github.com/kermitt2/grobid), which provides cleaner results than the raw PDF2Text converter (comparable with most other solutions).
+
+Additionally, this frontend visualises named entities in the LLM responses: <span style="color:yellow">physical quantities, measurements</span> (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span style="color:blue">materials</span> mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
+
+The conversation is backed by a sliding-window memory (the 4 most recent messages) that helps the assistant refer to information previously discussed in the chat.
 
 **Demos**:
 - (on HuggingFace spaces): https://lfoppiano-document-qa.hf.space/
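
The sliding-window memory described above corresponds to LangChain's `ConversationBufferWindowMemory`, which this PR wires into `streamlit_app.py` below. A minimal sketch of how such a window behaves (the questions and answers are illustrative only):

```python
from langchain.memory import ConversationBufferWindowMemory

# Keep only the k most recent question/answer exchanges.
memory = ConversationBufferWindowMemory(k=4)

# Each save_context() call records one exchange; the content here is
# purely illustrative.
memory.save_context({"input": "Which material is studied?"},
                    {"output": "MgB2 thin films."})
memory.save_context({"input": "What is its critical temperature?"},
                    {"output": "The paper reports Tc = 39 K."})

# Only the most recent k exchanges are returned; older ones are dropped.
print(memory.load_memory_variables({})["history"])
```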
document_qa/document_qa_engine.py CHANGED
@@ -23,7 +23,13 @@ class DocumentQAEngine:
     embeddings_map_from_md5 = {}
     embeddings_map_to_md5 = {}
 
-    def __init__(self, llm, embedding_function, qa_chain_type="stuff", embeddings_root_path=None, grobid_url=None):
+    def __init__(self,
+                 llm,
+                 embedding_function,
+                 qa_chain_type="stuff",
+                 embeddings_root_path=None,
+                 grobid_url=None,
+                 ):
         self.embedding_function = embedding_function
         self.llm = llm
         self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
@@ -81,14 +87,14 @@ class DocumentQAEngine:
         return self.embeddings_map_from_md5[md5]
 
     def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
-                       verbose=False) -> (
+                       verbose=False, memory=None) -> (
             Any, str):
         # self.load_embeddings(self.embeddings_root_path)
 
         if verbose:
             print(query)
 
-        response = self._run_query(doc_id, query, context_size=context_size)
+        response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
         response = response['output_text'] if 'output_text' in response else response
 
         if verbose:
@@ -138,9 +144,15 @@ class DocumentQAEngine:
 
         return parsed_output
 
-    def _run_query(self, doc_id, query, context_size=4):
+    def _run_query(self, doc_id, query, memory=None, context_size=4):
         relevant_documents = self._get_context(doc_id, query, context_size)
-        return self.chain.run(input_documents=relevant_documents, question=query)
+        if memory:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query)
+        else:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query,
+                                  memory=memory)
         # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
 
     def _get_context(self, doc_id, query, context_size=4):
@@ -150,6 +162,7 @@ class DocumentQAEngine:
         return relevant_documents
 
     def get_all_context_by_document(self, doc_id):
+        """Return the full context from the document"""
         db = self.embeddings_dict[doc_id]
         docs = db.get()
         return docs['documents']
@@ -161,6 +174,7 @@ class DocumentQAEngine:
         return relevant_documents
 
     def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
+        """Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
         if verbose:
             print("File", pdf_file_path)
         filename = Path(pdf_file_path).stem
@@ -215,12 +229,11 @@ class DocumentQAEngine:
         self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
                                                        collection_name=hash)
 
-
         self.embeddings_root_path = None
 
         return hash
 
-    def create_embeddings(self, pdfs_dir_path: Path):
+    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
         input_files = []
         for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
             for file_ in files:
@@ -238,7 +251,8 @@ class DocumentQAEngine:
                 print(data_path, "exists. Skipping it ")
                 continue
 
-            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=500, perc_overlap=0.1)
+            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
+                                                               perc_overlap=perc_overlap)
             filename = metadata[0]['filename']
 
             vector_db_document = Chroma.from_texts(texts,
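
Putting the engine changes together, here is a rough usage sketch of the updated call path, based only on the signatures visible in this diff. The chat model, embeddings, Grobid URL, PDF directory and `doc_id` below are placeholders, and an OpenAI key is assumed to be available in the environment:

```python
from pathlib import Path

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferWindowMemory

from document_qa.document_qa_engine import DocumentQAEngine

# Placeholder model/embedding choices, mirroring how the app builds the engine.
engine = DocumentQAEngine(ChatOpenAI(temperature=0),
                          OpenAIEmbeddings(),
                          grobid_url="http://localhost:8070")

# Chunking parameters are now configurable when indexing a directory of PDFs.
engine.create_embeddings(Path("./pdfs"), chunk_size=500, perc_overlap=0.1)

# The sliding-window memory added by this PR is passed per call and threaded
# through query_document() into _run_query().
memory = ConversationBufferWindowMemory(k=4)
_, answer = engine.query_document("What is the main finding?",
                                  "<document-md5>",  # placeholder doc_id
                                  context_size=4,
                                  memory=memory)
print(answer)
```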
streamlit_app.py CHANGED
@@ -6,6 +6,7 @@ from tempfile import NamedTemporaryFile
 import dotenv
 from grobid_quantities.quantities import QuantitiesAPI
 from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain.memory import ConversationBufferWindowMemory
 
 dotenv.load_dotenv(override=True)
 
@@ -51,6 +52,9 @@ if 'ner_processing' not in st.session_state:
 if 'uploaded' not in st.session_state:
     st.session_state['uploaded'] = False
 
+if 'memory' not in st.session_state:
+    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
+
 st.set_page_config(
     page_title="Scientific Document Insights Q/A",
     page_icon="📝",
@@ -67,6 +71,11 @@ def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
     st.session_state['uploaded'] = True
+    st.session_state['memory'].clear()
+
+
+def clear_memory():
+    st.session_state['memory'].clear()
 
 
 # @st.cache_resource
@@ -97,6 +106,7 @@ def init_qa(model, api_key=None):
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
+        return
 
     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
 
@@ -168,7 +178,7 @@ with st.sidebar:
                          disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
 
     st.markdown(
-        ":warning: Mistral and Zephyr are free to use, however requests might hit limits of the huggingface free API and fail. :warning: ")
+        ":warning: Mistral and Zephyr are **FREE** to use. Requests might fail anytime. Use at your own risk. :warning: ")
 
     if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
         if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
@@ -205,6 +215,11 @@ with st.sidebar:
     # else:
    #     is_api_key_provided = st.session_state['api_key']
 
+    st.button(
+        'Reset chat memory.',
+        on_click=clear_memory(),
+        help="Clear the conversational memory. Currently implemented to retain the 4 most recent messages.")
+
 st.title("📝 Scientific Document Insights Q/A")
 st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
 
@@ -297,7 +312,8 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
     elif mode == "LLM":
         with st.spinner("Generating response..."):
            _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-                                                                             context_size=context_size)
+                                                                             context_size=context_size,
+                                                                             memory=st.session_state.memory)
 
        if not text_response:
            st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
@@ -316,5 +332,11 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
         st.write(text_response)
         st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
+    for id in range(0, len(st.session_state.messages), 2):
+        question = st.session_state.messages[id]['content']
+        if len(st.session_state.messages) > id + 1:
+            answer = st.session_state.messages[id + 1]['content']
+            st.session_state.memory.save_context({"input": question}, {"output": answer})
+
 elif st.session_state.loaded_embeddings and st.session_state.doc_id:
     play_old_messages()
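
One note on the new sidebar button: Streamlit's `on_click` parameter expects a callable, so passing `clear_memory` itself defers the call until the button is pressed, whereas `clear_memory()` runs during every script rerun. A minimal sketch of that pattern, assuming the same session-state layout as the app:

```python
import streamlit as st
from langchain.memory import ConversationBufferWindowMemory

# Same session-state layout as in streamlit_app.py.
if 'memory' not in st.session_state:
    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)


def clear_memory():
    st.session_state['memory'].clear()


# Pass the callable itself; Streamlit invokes it only when the button is clicked.
st.button('Reset chat memory.',
          on_click=clear_memory,
          help="Clear the conversational memory (the 4 most recent messages).")
```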