vishwask committed
Commit
7cff7cb
1 Parent(s): ba30854

Update app.py

Files changed (1)
app.py +28 -178
app.py CHANGED
@@ -1,24 +1,6 @@
- import time
- print('1')
- print(time.time())
-
- #__import__('pysqlite3')
- #import sys
- #sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
-
  import os
  import torch
-
-
-
- #os.system('wget -q https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.2/auto_gptq-0.4.2+cu118-cp310-cp310-linux_x86_64.whl')
- #os.system('pip install -qqq auto_gptq-0.4.2+cu118-cp310-cp310-linux_x86_64.whl --progress-bar off')
-
- #print(f"Is CUDA available: {torch.cuda.is_available()}")
- os.system('nvidia-smi')
-
  import uuid
- #import replicate
  import requests
  import streamlit as st
  from streamlit.logger import get_logger
@@ -28,7 +10,6 @@ from langchain.chains import RetrievalQA
  from langchain.document_loaders import PyPDFDirectoryLoader
  from langchain.embeddings import HuggingFaceInstructEmbeddings
  from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.vectorstores import Chroma
  from pdf2image import convert_from_path
  from transformers import AutoTokenizer, TextStreamer, pipeline
  from langchain.memory import ConversationBufferMemory
@@ -36,7 +17,6 @@ from gtts import gTTS
  from io import BytesIO
  from langchain.chains import ConversationalRetrievalChain
  import streamlit.components.v1 as components
- #from sentence_transformers import SentenceTransformer
  from langchain.document_loaders import UnstructuredMarkdownLoader
  from langchain.vectorstores.utils import filter_complex_metadata
  import fitz
@@ -50,13 +30,6 @@ logger = get_logger(__name__)
  st.set_page_config(page_title="Document QA by Dono", page_icon="🤖", )
  st.session_state.disabled = False
  st.title("Document QA by Dono")
- #st.markdown(f"""<style>
- #  .stApp {{background-image: url("https://media.istockphoto.com/id/450481545/photo/glowing-lightbulb-against-black-background.webp?b=1&s=170667a&w=0&k=20&c=fJ91chWN1UkoKTNUvwgiQwpM80DlRpVC-WlJH_78OvE=");
- #  background-attachment: fixed;
- #  background-size: cover}}
- #  </style>
- #  """, unsafe_allow_html=True)
-
  DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 
@@ -64,30 +37,14 @@ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
  def load_data():
      loader = PyPDFDirectoryLoader("/home/user/app/pdfs/")
      docs = loader.load()
-     print(len(docs))
      return docs
 
-
-
  @st.cache_resource
  def load_model(_docs):
-     #embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large",model_kwargs={"device":DEVICE})
-     #embeddings = HuggingFaceInstructEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",model_kwargs={"device":DEVICE})
      embeddings = HuggingFaceInstructEmbeddings(model_name="/home/user/app/all-MiniLM-L6-v2/",model_kwargs={"device":DEVICE})
-     print(DEVICE)
-
      text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
      texts = text_splitter.split_documents(docs)
-
-     print('embedding done')
-
-     #db = Chroma.from_documents(texts, embeddings, persist_directory="/home/user/app/db")
      db = FAISS.from_documents(texts, embeddings)
-
-     print('db done')
-
-
-     #model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
      model_name_or_path = "/home/user/app/Llama-2-13B-chat-GPTQ/"
      model_basename = "model"
 
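Note: the split-embed-index flow this hunk settles on can be exercised in isolation. A minimal sketch, assuming the same local MiniLM path and PDF directory as the diff (the query string is illustrative):

    # Sketch of the split -> embed -> index flow above; paths and
    # chunking parameters are the ones the commit keeps.
    from langchain.document_loaders import PyPDFDirectoryLoader
    from langchain.embeddings import HuggingFaceInstructEmbeddings
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.vectorstores import FAISS

    docs = PyPDFDirectoryLoader("/home/user/app/pdfs/").load()
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="/home/user/app/all-MiniLM-L6-v2/",
        model_kwargs={"device": "cpu"},  # "cuda:0" when a GPU is available, as in the app
    )
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
    texts = splitter.split_documents(docs)
    db = FAISS.from_documents(texts, embeddings)

    # Same top-5 retrieval the chain below configures (search_kwargs={"k": 5}).
    hits = db.similarity_search("When do the new rules take effect?", k=5)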
@@ -104,20 +61,18 @@ def load_model(_docs):
          quantize_config=None,
      )
 
-     print('model done')
-
      DEFAULT_SYSTEM_PROMPT = """
      You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.
      Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
      Please ensure that your responses are socially unbiased and positive in nature.
      Always provide the citation for the answer from the text.
      Try to include any section or subsection present in the text responsible for the answer.
-     Provide reference. Provide page number, section, sub section etc from which answer is taken.
+     Provide reference. Provide page number, section, sub section etc.
      If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
      Given a government document that outlines rules and regulations for a specific industry or sector, use your language model to answer questions about the rules and their applicability over time.
      The document may include provisions that take effect at different times, such as immediately upon publication, after a grace period, or on a specific date in the future.
      Your task is to identify the relevant rules and determine when they go into effect, taking into account any dependencies or exceptions that may apply.
-     The current date is 14 September, 2023. Try to extract information which is closer to this date and not in very past.
+     The current date is 14 September, 2023. Try to extract information which is closer to this date.
      Take a deep breath and work on this problem step-by-step.
      """.strip()
 
@@ -126,52 +81,45 @@ def load_model(_docs):
          return f"""[INST] <<SYS>>{system_prompt}<</SYS>>{prompt} [/INST]""".strip()
 
      streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-     text_pipeline = pipeline("text-generation",model=model,tokenizer=tokenizer,max_new_tokens=1024,
-                              temperature=0.2,top_p=0.95,repetition_penalty=1.15,streamer=streamer,)
-
+     text_pipeline = pipeline("text-generation",
+                              model=model,
+                              tokenizer=tokenizer,
+                              max_new_tokens=1024,
+                              temperature=0.2,
+                              top_p=0.95,
+                              repetition_penalty=1.15,
+                              streamer=streamer,)
      llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0.2})
 
-     print('llm done')
-
-     SYSTEM_PROMPT = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
+     SYSTEM_PROMPT = ("Use the following pieces of context to answer the question at the end. "
+                      "If you don't know the answer, just say that you don't know, "
+                      "don't try to make up an answer.")
 
      template = generate_prompt("""{context} Question: {question} """,system_prompt=SYSTEM_PROMPT,) #Enter memory here!
-
      prompt = PromptTemplate(template=template, input_variables=["context", "question"]) #Add history here
-
      qa_chain = RetrievalQA.from_chain_type(
          llm=llm,
          chain_type="stuff",
          retriever=db.as_retriever(search_kwargs={"k": 5}),
          return_source_documents=True,
          chain_type_kwargs={"prompt": prompt,
-                            "verbose": False,
-                            #"memory": ConversationBufferMemory(
-                            #memory_key="history",
-                            #input_key="question",
-                            #return_messages=True)
-                            },)
+                            "verbose": False})
 
      print('load done')
      return qa_chain
 
 
- #uploaded_file = len(docs)
- #flag = 0
- #if uploaded_file is not None:
- #    flag = 1
-
- model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
+ model_name_or_path = "Llama-2-13B-chat-GPTQ"
  model_basename = "model"
 
  st.session_state["llm_model"] = model_name_or_path
 
-
  if "messages" not in st.session_state:
      st.session_state.messages = []
-
-
+ if "image_displayed" not in st.session_state:
+     st.session_state.image_displayed = False
+ if "sound_played" not in st.session_state:
+     st.session_state.sound_played = False
 
  for message in st.session_state.messages:
      with st.chat_message(message["role"]):
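Note: generate_prompt wraps the template in Llama-2's chat format, and because return_source_documents=True the chain returns the retrieved chunks alongside the answer. A sketch of both, with an illustrative question:

    # The filled-in template the model actually sees:
    # [INST] <<SYS>>{system_prompt}<</SYS>>{context} Question: {question}  [/INST]

    result = qa_chain("When does the new rule come into force?")  # illustrative question
    answer = result["result"]              # generated answer string
    sources = result["source_documents"]   # top-k chunks; .metadata carries 'page' and 'source'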
@@ -181,7 +129,7 @@ for message in st.session_state.messages:
  def on_select():
      st.session_state.disabled = True
 
-
+
  def get_message_history():
      for message in st.session_state.messages:
          role, content = message["role"], message["content"]
@@ -191,11 +139,6 @@ def get_message_history():
  docs = load_data()
  qa_chain = load_model(docs)
 
-
-
-
- print('2')
- print(time.time())
  if prompt := st.chat_input("How can I help you today?"):
      st.session_state.messages.append({"role": "user", "content": prompt})
      with st.chat_message("user"):
@@ -204,144 +147,51 @@ if prompt := st.chat_input("How can I help you today?"):
      message_placeholder = st.empty()
      full_response = ""
      message_history = "\n".join(list(get_message_history())[-3:])
-     logger.info(f"{user_session_id} Message History: {message_history}")
-     # question = st.text_input("Ask your question", placeholder="Try to include context in your question",
-     #                          disabled=not uploaded_file,)
-     print('3')
-     print(time.time())
-     result = qa_chain(prompt)
-     print('4')
-     print(time.time())
-
+     question = st.text_input("Ask your question", placeholder="Try to include context in your question")
+     result = qa_chain(question)
      output = [result['result']]
 
-     # for item in output:
-     #     full_response += item
-     #     message_placeholder.markdown(full_response + "▌")
-     #     message_placeholder.markdown(full_response)
-     #st.write(repr(result['source_documents'][0].metadata['page']))
-     #st.write(repr(result['source_documents'][0]))
-
-     print('5')
-     print(time.time())
-
      def generate_pdf():
+         generate_audio()
          page_number = int(result['source_documents'][0].metadata['page'])
          doc = fitz.open(str(result['source_documents'][0].metadata['source']))
-
          text = str(result['source_documents'][0].page_content)
         if text != '':
              for page in doc:
-                 ### SEARCH
                  text_instances = page.search_for(text)
-
-                 ### HIGHLIGHT
                  for inst in text_instances:
                      highlight = page.add_highlight_annot(inst)
                      highlight.update()
-
-                 ### OUTPUT
          doc.save("/home/user/app/pdf2image/output.pdf", garbage=4, deflate=True, clean=True)
-
-         # pdf_to_open = repr(result['source_documents'][0].metadata['source'])
-
+
          def pdf_page_to_image(pdf_file, page_number, output_image):
-             # Open the PDF file
              pdf_document = fitz.open(pdf_file)
-
-             # Get the specific page
              page = pdf_document[page_number]
-
-             # Define the image DPI (dots per inch)
              dpi = 300 # You can adjust this as needed
-
-             # Convert the page to an image
              pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 100, dpi / 100))
-
-             # Save the image as a PNG file
              pix.save(output_image, "png")
-
-             # Close the PDF file
              pdf_document.close()
-
-
          pdf_page_to_image('/home/user/app/pdf2image/output.pdf', page_number, '/home/user/app/pdf2image/output.png')
-
          image = Image.open('/home/user/app/pdf2image/output.png')
-         st.sidebar.image(image)
+         st.image(image)
          st.session_state.image_displayed = True
 
      def generate_audio():
          sound_file = BytesIO()
          tts = gTTS(result['result'], lang='en')
          tts.write_to_fp(sound_file)
-         st.sidebar.audio(sound_file)
+         st.audio(sound_file)
          st.session_state.sound_played = True
 
 
-     #st.button(':speaker:', type='primary',on_click=generate_audio)
-     #st.button('Reference',type='primary',on_click=generate_pdf)
-
-     # Create placeholders for output
-     image_output = st.empty()
-     sound_output = st.empty()
-
-     # Create a button to display the image
-     # if st.button("Reference"):
-     #     image_output.clear()
-     #     generate_pdf()
-
-     # # Create a button to play the sound
-     # if st.button(":speaker:"):
-     #     sound_output.clear()
-     #     generate_audio()
-
-
-     # on_audio = st.checkbox(':speaker:', key="speaker")
-     # on_ref = st.checkbox('Reference', key="reference")
-     # if on_audio:
-     #     generate_audio()
-
-     # if on_ref:
-     #     generate_pdf()
-
-     # Initialize session state variables
-     if "image_displayed" not in st.session_state:
-         st.session_state.image_displayed = False
-     if "sound_played" not in st.session_state:
-         st.session_state.sound_played = False
-
-
-
-     # Create the two buttons
-     #st.button("Display Image", on_click=generate_pdf)
-     #st.button("Play Sound", on_click=generate_audio)
-
-
-
-     # # Check if the image has been displayed and display it if it has not
-     # if not st.session_state.image_displayed:
-     #     generate_pdf()
-
-     # # Check if the sound has been played and play it if it has not
-     # if not st.session_state.sound_played:
-     #     generate_audio()
-
-
      for item in output:
          full_response += item
          message_placeholder.markdown(full_response + "▌")
      message_placeholder.markdown(full_response)
-
-     st.session_state.messages.append({"role": "assistant", "content": full_response})
 
-     if st.button("Display Image"):
+     if st.toggle("Reference and Sound"):
          generate_pdf()
 
-
-     if st.button("Play Sound"):
-         generate_audio()
-
-
+     st.session_state.messages.append({"role": "assistant", "content": full_response})
 
 
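Note: pdf_page_to_image scales pages with fitz.Matrix(dpi / 100, dpi / 100); PyMuPDF's base rendering resolution is 72 dpi, so dpi / 72 is the factor that actually yields a 300-dpi image. A standalone sketch of the highlight-and-render step (paths and snippet are illustrative):

    import fitz  # PyMuPDF

    def highlight_and_render(pdf_path, page_number, snippet, out_png, dpi=300):
        doc = fitz.open(pdf_path)
        page = doc[page_number]
        # Highlight each occurrence of the retrieved snippet on the page.
        for rect in page.search_for(snippet):
            page.add_highlight_annot(rect).update()
        # Base resolution is 72 dpi, so scale by dpi / 72 for a true 300-dpi render.
        pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 72, dpi / 72))
        pix.save(out_png)
        doc.close()

    highlight_and_render("output.pdf", 0, "relevant passage", "output.png")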
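Note: generate_audio's in-memory TTS round-trip, as a self-contained sketch. Rewinding the buffer before playback is a common precaution; the app hands the buffer over as-is:

    from io import BytesIO
    from gtts import gTTS
    import streamlit as st

    def speak(text: str) -> None:
        buf = BytesIO()
        gTTS(text, lang="en").write_to_fp(buf)  # synthesize MP3 into memory
        buf.seek(0)                             # rewind so playback starts at byte 0
        st.audio(buf)

    speak("The amended rule takes effect on publication.")  # illustrative text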