vishwask committed on
Commit
b1754ef
•
1 Parent(s): 9ace069

Create app.py

Files changed (1)
  1. app.py +292 -0
app.py ADDED
@@ -0,0 +1,292 @@
+ import os
+
+ # Download the prebuilt AutoGPTQ wheel (CUDA 11.8, CPython 3.10) and install it,
+ # then install poppler-utils, which pdf2image needs to render PDFs.
+ os.system('wget -q https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.4.1/auto_gptq-0.4.1+cu118-cp310-cp310-linux_x86_64.whl')
+ os.system('pip install -qqq auto_gptq-0.4.1+cu118-cp310-cp310-linux_x86_64.whl --progress-bar off')
+ os.system('sudo apt-get install -y poppler-utils')
+
+ import uuid
+ #import replicate
+ import requests
+ import streamlit as st
+ from streamlit.logger import get_logger
+ import torch
+ from auto_gptq import AutoGPTQForCausalLM
+ from langchain import HuggingFacePipeline, PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain.document_loaders import PyPDFDirectoryLoader
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import Chroma
+ from pdf2image import convert_from_path
+ from transformers import AutoTokenizer, TextStreamer, pipeline
+ from langchain.memory import ConversationBufferMemory
+ from gtts import gTTS
+ from io import BytesIO
+ from langchain.chains import ConversationalRetrievalChain
+ from streamlit_modal import Modal
+ import streamlit.components.v1 as components
+ #from sentence_transformers import SentenceTransformer
+ from langchain.document_loaders import UnstructuredMarkdownLoader
+ from langchain.vectorstores.utils import filter_complex_metadata
+ import fitz
+ from PIL import Image
+
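+ # App flow: index the PDFs from /pdfs/ in Chroma, answer each chat question
+ # with a GPTQ-quantised Llama-2 13B RetrievalQA chain, then display the
+ # highlighted source page and a gTTS reading of the answer.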
+ user_session_id = uuid.uuid4()
+
+ logger = get_logger(__name__)
+ st.set_page_config(page_title="Document QA by Dono", page_icon="🤖")
+ # Initialise the feedback-radio flag once; on_select() sets it to True, and an
+ # unconditional assignment here would reset it on every rerun.
+ if "disabled" not in st.session_state:
+     st.session_state.disabled = False
+ st.title("Document QA by Dono")
+ st.markdown(f"""<style>
+ .stApp {{background-image: url("https://media.istockphoto.com/id/450481545/photo/glowing-lightbulb-against-black-background.webp?b=1&s=170667a&w=0&k=20&c=fJ91chWN1UkoKTNUvwgiQwpM80DlRpVC-WlJH_78OvE=");
+         background-attachment: fixed;
+         background-size: cover}}
+ </style>
+ """, unsafe_allow_html=True)
+
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ # Load every PDF mounted under /pdfs/ up front.
+ loader = PyPDFDirectoryLoader("/pdfs/")
+ docs = loader.load()
+ #len(docs)
+
+
+ @st.cache_resource
+ def load_model():
+     # Embed 1024-character chunks (256 overlap) and index them in a persistent Chroma store.
+     #embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large", model_kwargs={"device": DEVICE})
+     embeddings = HuggingFaceInstructEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": DEVICE})
+
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
+     texts = text_splitter.split_documents(docs)
+
+     db = Chroma.from_documents(texts, embeddings, persist_directory="db")
+
+     model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
+     model_basename = "model"
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
+
+     model = AutoGPTQForCausalLM.from_quantized(
+         model_name_or_path,
+         revision="gptq-4bit-128g-actorder_True",
+         model_basename=model_basename,
+         use_safetensors=True,
+         trust_remote_code=True,
+         inject_fused_attention=False,
+         device=DEVICE,
+         quantize_config=None,
+     )
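+     # The "gptq-4bit-128g-actorder_True" revision selects the 4-bit,
+     # group-size-128, act-order quantisation branch of the repo;
+     # inject_fused_attention is disabled, trading some inference speed
+     # for broader kernel compatibility.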
+
+     DEFAULT_SYSTEM_PROMPT = """
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. Always cite the text that supports your answer, including the page number and any section or subsection it was taken from.
+
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
+ """.strip()
+
+     def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
+         # Wrap the prompt in the Llama-2 chat format.
+         return f"""[INST] <<SYS>>{system_prompt}<</SYS>>{prompt} [/INST]""".strip()
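+     # For example, generate_prompt("What is X?") produces:
+     #   [INST] <<SYS>>{system prompt}<</SYS>>What is X? [/INST]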
+
+     streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+     text_pipeline = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         max_new_tokens=1024,
+         temperature=0.2,
+         top_p=0.95,
+         repetition_penalty=1.15,
+         streamer=streamer,
+     )
+
+     llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0.2})
+
+     SYSTEM_PROMPT = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
+
+     template = generate_prompt("""{context} Question: {question} """, system_prompt=SYSTEM_PROMPT)  # Enter memory here!
+
+     prompt = PromptTemplate(template=template, input_variables=["context", "question"])  # Add history here
+
+     # "stuff" concatenates the top-k retrieved chunks directly into the prompt context.
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",
+         retriever=db.as_retriever(search_kwargs={"k": 2}),
+         return_source_documents=True,
+         chain_type_kwargs={
+             "prompt": prompt,
+             "verbose": False,
+             #"memory": ConversationBufferMemory(
+             #    memory_key="history",
+             #    input_key="question",
+             #    return_messages=True)
+         },
+     )
+     return qa_chain
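+ # Because load_model() is wrapped in @st.cache_resource, the embeddings, the
+ # Chroma index and the quantised model are built once per process and reused
+ # across Streamlit reruns rather than reloaded on every interaction.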
+
+
+ # Proceed only if the loader actually found at least one document.
+ uploaded_file = len(docs)
+ flag = 0
+ if uploaded_file > 0:
+     flag = 1
+
+ model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
+ model_basename = "model"
+
+ st.session_state["llm_model"] = model_name_or_path
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+
+ # Replay the conversation so far on every rerun.
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+
+ def on_select():
+     # Disable the feedback radio once the user has voted.
+     st.session_state.disabled = True
+
+
+ def get_message_history():
+     for message in st.session_state.messages:
+         role, content = message["role"], message["content"]
+         yield f"{role.title()}: {content}"
+
+
+ if prompt := st.chat_input("How can I help you today?"):
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     with st.chat_message("user"):
+         st.markdown(prompt)
+     with st.chat_message("assistant"):
+         message_placeholder = st.empty()
+         full_response = ""
+         message_history = "\n".join(list(get_message_history())[-3:])
+         logger.info(f"{user_session_id} Message History: {message_history}")
+         qa_chain = load_model()
+         # question = st.text_input("Ask your question", placeholder="Try to include context in your question",
+         #                          disabled=not uploaded_file,)
+         result = qa_chain(prompt)
+         # Synthesise the answer as speech so it can be played back below.
+         sound_file = BytesIO()
+         tts = gTTS(result['result'], lang='en')
+         tts.write_to_fp(sound_file)
+         output = [result['result']]
+
+         for item in output:
+             full_response += item
+             message_placeholder.markdown(full_response + "▌")
+         message_placeholder.markdown(full_response)
+         #st.write(repr(result['source_documents'][0].metadata['page']))
+         #st.write(repr(result['source_documents'][0]))
+
+         ### READ IN PDF
+         # Open the source PDF of the top-ranked chunk and highlight the retrieved text.
+         page_number = int(result['source_documents'][0].metadata['page'])
+         doc = fitz.open(str(result['source_documents'][0].metadata['source']))
+
+         text = str(result['source_documents'][0].page_content)
+         if text != '':
+             for page in doc:
+                 ### SEARCH
+                 text_instances = page.search_for(text)
+
+                 ### HIGHLIGHT
+                 for inst in text_instances:
+                     highlight = page.add_highlight_annot(inst)
+                     highlight.update()
+
+         ### OUTPUT
+         doc.save("/pdf2image/output.pdf", garbage=4, deflate=True, clean=True)
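+         # Note: page.search_for() matches text runs within a single page, so a
+         # long retrieved chunk that wraps lines or spans pages may yield no
+         # highlight instances; the saved page image below is shown either way.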
+
+         # pdf_to_open = repr(result['source_documents'][0].metadata['source'])
+
+         def pdf_page_to_image(pdf_file, page_number, output_image):
+             # Open the PDF file
+             pdf_document = fitz.open(pdf_file)
+
+             # Get the specific page
+             page = pdf_document[page_number]
+
+             # Define the image DPI (dots per inch)
+             dpi = 300  # You can adjust this as needed
+
+             # Convert the page to an image (PyMuPDF renders at a 72 dpi base,
+             # so the zoom factor is dpi / 72)
+             pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 72, dpi / 72))
+
+             # Save the image as a PNG file
+             pix.save(output_image, "png")
+
+             # Close the PDF file
+             pdf_document.close()
+
+         pdf_page_to_image('/pdf2image/output.pdf', page_number, '/pdf2image/output.png')
+
+         # Show the highlighted source page and play back the spoken answer.
+         image = Image.open('/pdf2image/output.png')
+         st.image(image)
+         st.audio(sound_file)
+
+         # Earlier experiments with explicit References / audio buttons, kept for reference:
+         # if 'clickedR' not in st.session_state:
+         #     st.session_state.clickedR = False
+
+         # def click_buttonR():
+         #     st.session_state.clickedR = True
+         #     if st.session_state.clickedR:
+         #         message_placeholder.markdown(full_response + repr(result['source_documents'][0]))
+
+         # ref = st.button('References', on_click=click_buttonR)
+
+         # if 'clickedA' not in st.session_state:
+         #     st.session_state.clickedA = False
+
+         # def click_buttonA():
+         #     st.session_state.clickedA = True
+         #     if st.session_state.clickedA:
+         #         sound_file = BytesIO()
+         #         tts = gTTS(result['result'], lang='en')
+         #         tts.write_to_fp(sound_file)
+         #         st.audio(sound_file)
+
+         # ref = st.button(':speaker:', on_click=click_buttonA)
+
+         #st.session_state.clickedR = False
+
+         # #if ref:
+         #     message_placeholder.markdown(full_response + repr(result['source_documents'][0]))
+         # #if sound:
+         #     sound_file = BytesIO()
+         #     tts = gTTS(result['result'], lang='en')
+         #     tts.write_to_fp(sound_file)
+         #     html_string = """
+         #         <audio controls autoplay>
+         #             <source src="/content/sound_file" type="audio/wav">
+         #         </audio>
+         #     """
+         #     message_placeholder.markdown(html_string, unsafe_allow_html=True)  # displays an audio player with the sound specified in the "src" of html_string and autoplays it
+         # #time.sleep(5)  # wait for the audio to finish playing
+         response_sentiment = st.radio(
+             "How was the Assistant's response?",
+             ["😁", "😕", "😒"],
+             key="response_sentiment",
+             disabled=st.session_state.disabled,
+             horizontal=True,
+             index=1,
+             help="This helps us improve the model.",
+             # disable the radio once a rating is selected; pass the
+             # callback itself, not the result of calling it
+             on_change=on_select,
+         )
+         logger.info(f"{user_session_id} | {full_response} | {response_sentiment}")
+
+         # # Logging to FastAPI Endpoint
+         # headers = {"Authorization": f"Bearer {secret_token}"}
+         # log_data = {"log": f"{user_session_id} | {full_response} | {response_sentiment}"}
+         # response = requests.post(fastapi_endpoint, json=log_data, headers=headers, timeout=10)
+         # if response.status_code == 200:
+         #     logger.info("Query logged successfully")
+
+         st.session_state.messages.append({"role": "assistant", "content": full_response})
+