Merge pull request #18 from lfoppiano/add-memory
Files changed:
- README.md (+7, -4)
- document_qa/document_qa_engine.py (+22, -8)
- streamlit_app.py (+24, -2)
README.md
CHANGED
@@ -16,11 +16,14 @@ license: apache-2.0
 
 ## Introduction
 
-Question/Answering on scientific documents using LLMs
-
-Differently to most of the …
+Question/Answering on scientific documents using LLMs: ChatGPT-3.5-turbo, Mistral-7b-instruct and Zephyr-7b-beta.
+The Streamlit application demonstrates the implementation of RAG (Retrieval Augmented Generation) on scientific documents that we are developing at NIMS (National Institute for Materials Science) in Tsukuba, Japan.
+Unlike most similar projects, we focus on scientific articles.
+We target only the full text, using [Grobid](https://github.com/kermitt2/grobid), which provides cleaner results than the raw PDF2Text converter (comparable to what most other solutions use).
 
-…
+Additionally, this frontend provides the visualisation of named entities on LLM responses to extract <span style="color:yellow">physical quantities, measurements</span> (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span style="color:blue">materials</span> mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
+
+The conversation is backed by a sliding-window memory (the 4 most recent messages) that helps refer to information previously discussed in the chat.
 
 **Demos**:
 - (on HuggingFace spaces): https://lfoppiano-document-qa.hf.space/
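The sliding-window memory described above is provided by LangChain's `ConversationBufferWindowMemory`, which this PR wires into `streamlit_app.py`. Below is a minimal standalone sketch (illustrative only, not code from this PR) of how a `k=4` window behaves:

```python
# Standalone sketch (not part of this PR) of the k=4 sliding window used for the chat memory.
from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(k=4)

# Store six question/answer exchanges; only the last four are kept in the window.
for i in range(6):
    memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})

# Only exchanges 2..5 remain in the buffered history.
print(memory.load_memory_variables({})["history"])

memory.clear()  # what the app does when a new document is uploaded
```

Older exchanges simply fall out of the window, so the prompt never grows beyond the four most recent question/answer pairs.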
document_qa/document_qa_engine.py
CHANGED
@@ -23,7 +23,13 @@ class DocumentQAEngine:
     embeddings_map_from_md5 = {}
     embeddings_map_to_md5 = {}
 
-    def __init__(self, …
+    def __init__(self,
+                 llm,
+                 embedding_function,
+                 qa_chain_type="stuff",
+                 embeddings_root_path=None,
+                 grobid_url=None,
+                 ):
         self.embedding_function = embedding_function
         self.llm = llm
         self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
@@ -81,14 +87,14 @@ class DocumentQAEngine:
         return self.embeddings_map_from_md5[md5]
 
     def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
-                       verbose=False) -> (
+                       verbose=False, memory=None) -> (
             Any, str):
         # self.load_embeddings(self.embeddings_root_path)
 
         if verbose:
             print(query)
 
-        response = self._run_query(doc_id, query, context_size=context_size)
+        response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
         response = response['output_text'] if 'output_text' in response else response
 
         if verbose:
@@ -138,9 +144,15 @@ class DocumentQAEngine:
 
         return parsed_output
 
-    def _run_query(self, doc_id, query, context_size=4):
+    def _run_query(self, doc_id, query, memory=None, context_size=4):
         relevant_documents = self._get_context(doc_id, query, context_size)
-        …
+        if memory:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query)
+        else:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query,
+                                  memory=memory)
         # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
 
     def _get_context(self, doc_id, query, context_size=4):
@@ -150,6 +162,7 @@ class DocumentQAEngine:
         return relevant_documents
 
     def get_all_context_by_document(self, doc_id):
+        """Return the full context from the document"""
         db = self.embeddings_dict[doc_id]
         docs = db.get()
         return docs['documents']
@@ -161,6 +174,7 @@ class DocumentQAEngine:
         return relevant_documents
 
     def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
+        """Extract text from documents using Grobid; if chunk_size is < 0, each paragraph is kept separately"""
         if verbose:
             print("File", pdf_file_path)
         filename = Path(pdf_file_path).stem
@@ -215,12 +229,11 @@ class DocumentQAEngine:
         self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
                                                        collection_name=hash)
 
-
         self.embeddings_root_path = None
 
         return hash
 
-    def create_embeddings(self, pdfs_dir_path: Path):
+    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
         input_files = []
         for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
             for file_ in files:
@@ -238,7 +251,8 @@ class DocumentQAEngine:
                 print(data_path, "exists. Skipping it ")
                 continue
 
-            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=…
+            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
+                                                               perc_overlap=perc_overlap)
            filename = metadata[0]['filename']
 
            vector_db_document = Chroma.from_texts(texts,
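A note on the `_run_query` change above: as merged, the `if memory:` branch omits the `memory` argument while the `else:` branch (where `memory` is `None`) passes it, so the conversational history never actually reaches the chain. The following is a hedged sketch of the presumably intended branching, reusing the PR's own `self.chain.run(...)` call rather than a verified API; it is not the code merged in this PR:

```python
def _run_query(self, doc_id, query, memory=None, context_size=4):
    # Hypothetical correction: forward the memory only when one is provided.
    # Whether the loaded QA chain consumes a `memory` keyword passed at run time
    # depends on the chain; LangChain memories are usually attached at construction.
    relevant_documents = self._get_context(doc_id, query, context_size)
    if memory:
        return self.chain.run(input_documents=relevant_documents,
                              question=query,
                              memory=memory)
    else:
        return self.chain.run(input_documents=relevant_documents,
                              question=query)
```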
streamlit_app.py
CHANGED
@@ -6,6 +6,7 @@ from tempfile import NamedTemporaryFile
 import dotenv
 from grobid_quantities.quantities import QuantitiesAPI
 from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain.memory import ConversationBufferWindowMemory
 
 dotenv.load_dotenv(override=True)
 
@@ -51,6 +52,9 @@ if 'ner_processing' not in st.session_state:
 if 'uploaded' not in st.session_state:
     st.session_state['uploaded'] = False
 
+if 'memory' not in st.session_state:
+    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
+
 st.set_page_config(
     page_title="Scientific Document Insights Q/A",
     page_icon="π",
@@ -67,6 +71,11 @@ def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
     st.session_state['uploaded'] = True
+    st.session_state['memory'].clear()
+
+
+def clear_memory():
+    st.session_state['memory'].clear()
 
 
 # @st.cache_resource
@@ -97,6 +106,7 @@ def init_qa(model, api_key=None):
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
+        return
 
     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
 
@@ -168,7 +178,7 @@ with st.sidebar:
                            disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
 
     st.markdown(
-        ":warning: Mistral and Zephyr are …
+        ":warning: Mistral and Zephyr are **FREE** to use. Requests might fail anytime. Use at your own risk. :warning: ")
 
     if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
         if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
@@ -205,6 +215,11 @@ with st.sidebar:
     # else:
     #     is_api_key_provided = st.session_state['api_key']
 
+    st.button(
+        'Reset chat memory.',
+        on_click=clear_memory(),
+        help="Clear the conversational memory. Currently implemented to retain the 4 most recent messages.")
+
 st.title("π Scientific Document Insights Q/A")
 st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
 
@@ -297,7 +312,8 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
     elif mode == "LLM":
         with st.spinner("Generating response..."):
             _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-                                                                             context_size=context_size)
+                                                                             context_size=context_size,
+                                                                             memory=st.session_state.memory)
 
         if not text_response:
             st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
@@ -316,5 +332,11 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
             st.write(text_response)
             st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
+    for id in range(0, len(st.session_state.messages), 2):
+        question = st.session_state.messages[id]['content']
+        if len(st.session_state.messages) > id + 1:
+            answer = st.session_state.messages[id + 1]['content']
+            st.session_state.memory.save_context({"input": question}, {"output": answer})
+
 elif st.session_state.loaded_embeddings and st.session_state.doc_id:
     play_old_messages()