SiraH committed on
Commit
6e78b52
1 Parent(s): 21f34d2

streamlit app

Files changed (1)
app.py.py +275 -0
app.py.py ADDED
@@ -0,0 +1,275 @@
+ import os
+ import pathlib
+ import time
+
+ import streamlit as st
+ from PyPDF2 import PdfReader
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ from langchain import PromptTemplate, LLMChain
+ from langchain.llms import HuggingFacePipeline, LlamaCpp
+ from langchain.callbacks.manager import CallbackManager
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import FAISS
+ from langchain.chains import RetrievalQA, ConversationalRetrievalChain
+ from langchain.chains.question_answering import load_qa_chain
+ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
+ from langchain.document_loaders import (
+     TextLoader,
+     PyPDFLoader,
+     Docx2txtLoader,
+     UnstructuredHTMLLoader,
+     UnstructuredPowerPointLoader,
+ )
+ from langchain.document_loaders.image import UnstructuredImageLoader
+ from langchain.memory import ConversationBufferMemory, ConversationBufferWindowMemory
+ from langchain.memory.chat_message_histories.streamlit import StreamlitChatMessageHistory
+
+ class UploadDoc:
+     """Load every supported file under a directory into LangChain documents."""
+     def __init__(self, path_data):
+         self.path_data = path_data
+
+     def prepare_filetype(self):
+         # One bucket of file paths per supported extension.
+         extension_lists = {
+             ".docx": [],
+             ".pdf": [],
+             ".html": [],
+             ".png": [],
+             ".pptx": [],
+             ".txt": [],
+         }
+
+         # Walk the data directory and collect every file path.
+         path_list = []
+         for path, subdirs, files in os.walk(self.path_data):
+             for name in files:
+                 path_list.append(os.path.join(path, name))
+
+         # Categorize the files by extension.
+         for filename in path_list:
+             file_extension = pathlib.Path(filename).suffix
+             if file_extension in extension_lists:
+                 extension_lists[file_extension].append(filename)
+         return extension_lists
+
+     def upload_docx(self, extension_lists):
+         # Word documents
+         data_docxs = []
+         for doc in extension_lists[".docx"]:
+             loader = Docx2txtLoader(doc)
+             data = loader.load()
+             data_docxs.extend(data)
+         return data_docxs
+
+     def upload_pdf(self, extension_lists):
+         # PDF files, loaded and split into pages
+         data_pdf = []
+         for doc in extension_lists[".pdf"]:
+             loader = PyPDFLoader(doc)
+             data = loader.load_and_split()
+             data_pdf.extend(data)
+         return data_pdf
+
+     def upload_html(self, extension_lists):
+         # HTML pages
+         data_html = []
+         for doc in extension_lists[".html"]:
+             loader = UnstructuredHTMLLoader(doc)
+             data = loader.load()
+             data_html.extend(data)
+         return data_html
+
+     def upload_png_ocr(self, extension_lists):
+         # PNG images, read via OCR
+         data_png = []
+         for doc in extension_lists[".png"]:
+             loader = UnstructuredImageLoader(doc)
+             data = loader.load()
+             data_png.extend(data)
+         return data_png
+
+     def upload_pptx(self, extension_lists):
+         # PowerPoint slides
+         data_pptx = []
+         for doc in extension_lists[".pptx"]:
+             loader = UnstructuredPowerPointLoader(doc)
+             data = loader.load()
+             data_pptx.extend(data)
+         return data_pptx
+
+     def upload_txt(self, extension_lists):
+         # Plain-text files
+         data_txt = []
+         for doc in extension_lists[".txt"]:
+             loader = TextLoader(doc)
+             data = loader.load()
+             data_txt.extend(data)
+         return data_txt
+
+     def count_files(self, extension_lists):
+         # Report how many files of each type were found.
+         file_extension_counts = {ext: len(file_list) for ext, file_list in extension_lists.items()}
+         print(f"Number of files: {file_extension_counts}")
+         return file_extension_counts
+
+     def create_document(self):
+         documents = []
+         extension_lists = self.prepare_filetype()
+         self.count_files(extension_lists)
+
+         upload_functions = {
+             ".docx": self.upload_docx,
+             ".pdf": self.upload_pdf,
+             ".html": self.upload_html,
+             ".png": self.upload_png_ocr,
+             ".pptx": self.upload_pptx,
+             ".txt": self.upload_txt,
+         }
+
+         for extension, upload_function in upload_functions.items():
+             if len(extension_lists[extension]) > 0:
+                 data = upload_function(extension_lists)
+                 documents.extend(data)
+
+         return documents
+
+ def split_docs(documents, chunk_size=500):
+     # Split documents into overlapping chunks for embedding.
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=50)
+     sp_docs = text_splitter.split_documents(documents)
+     return sp_docs
+
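Taken together, UploadDoc and split_docs form the ingestion path: walk a directory, load every supported file, then chunk the results for embedding. A minimal standalone sketch (the "data" folder is an assumption matching the directory main() uses below):

# Sketch: load everything under "data/" and chunk it.
docs = UploadDoc("data").create_document()
sp_docs = split_docs(docs, chunk_size=500)
print(f"{len(docs)} documents -> {len(sp_docs)} chunks (~500 chars each, 50 overlap)")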
+ @st.cache_resource
+ def load_llama2_llamaCpp():
+     # Load a local quantized Llama 2 7B chat model via llama.cpp.
+     core_model_name = "llama-2-7b-chat.ggmlv3.q4_0.bin"
+     n_gpu_layers = 32  # layers offloaded to the GPU
+     n_batch = 512      # tokens processed in parallel
+     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+     llm = LlamaCpp(
+         model_path=core_model_name,
+         n_gpu_layers=n_gpu_layers,
+         n_batch=n_batch,
+         callback_manager=callback_manager,
+         verbose=True,
+         n_ctx=4096,
+         temperature=0.1,
+         max_tokens=256,
+     )
+     return llm
+
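LlamaCpp resolves the weights path relative to the working directory, and newer llama-cpp-python releases expect GGUF rather than GGML files, so this model name implies an older pinned dependency. A hypothetical guard (not part of the commit) makes the missing-file case explicit:

# Sketch: fail fast if the quantized model is absent (hypothetical guard).
model_path = "llama-2-7b-chat.ggmlv3.q4_0.bin"
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Place the quantized Llama 2 weights at ./{model_path}")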
+ def set_custom_prompt():
+     custom_prompt_template = """Use the following pieces of information to answer the user's question.
+ If you don't know the answer, please just say that you don't know; don't try to make up an answer.
+
+ Context : {context}
+ chat_history : {chat_history}
+ Question : {question}
+
+ Only return the helpful answer below and nothing else.
+ Helpful answer:
+ """
+     prompt = PromptTemplate(template=custom_prompt_template,
+                             input_variables=['context', 'question', 'chat_history'])
+     return prompt
+
+ @st.cache_resource
+ def load_embeddings():
+     # Sentence-transformer embeddings, computed on CPU.
+     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
+                                        model_kwargs={'device': 'cpu'})
+     return embeddings
+
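These embeddings back the FAISS index that main() builds. In isolation, the retrieval step looks roughly like this (a sketch reusing sp_docs from the ingestion example above; the query string is illustrative):

# Sketch: embed chunks, index them, and retrieve the closest matches.
embeddings = load_embeddings()
db = FAISS.from_documents(sp_docs, embeddings)
retriever = db.as_retriever(search_type="similarity_score_threshold",
                            search_kwargs={'k': 3, "score_threshold": 0.7})
for hit in retriever.get_relevant_documents("example question"):
    print(hit.metadata['source'])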
+ def main():
+     # Initializing the history also creates st.session_state.langchain_messages,
+     # which the expander at the bottom inspects.
+     msgs = StreamlitChatMessageHistory(key="langchain_messages")
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+
+     data = []
+     # DB_FAISS_UPLOAD_PATH = "vectorstores/db_faiss"
+     st.header("DOCUMENT QUESTION ANSWERING IS2")
+     directory = "data"
+     data_dir = UploadDoc(directory).create_document()
+     data.extend(data_dir)
+
+     # Build the vector store from the loaded documents.
+     if len(data) > 0:
+         sp_docs = split_docs(documents=data)
+         st.write(f"The documents were split into {len(sp_docs)} chunks")
+         embeddings = load_embeddings()
+         with st.spinner('Creating the vector store...'):
+             db = FAISS.from_documents(sp_docs, embeddings)
+         # db.save_local(DB_FAISS_UPLOAD_PATH)
+     else:
+         st.warning("No documents found in the data directory.")
+         st.stop()
+
+     llm = load_llama2_llamaCpp()
+     qa_prompt = set_custom_prompt()
+     # k=0 keeps no previous turns in the window; raise k to carry chat history.
+     memory = ConversationBufferWindowMemory(k=0, return_messages=True, input_key='question',
+                                             output_key='answer', memory_key="chat_history")
+     doc_chain = load_qa_chain(llm, chain_type="stuff", prompt=qa_prompt)
+     question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
+     qa_chain = ConversationalRetrievalChain(
+         retriever=db.as_retriever(search_type="similarity_score_threshold",
+                                   search_kwargs={'k': 3, "score_threshold": 0.7}),
+         question_generator=question_generator,
+         combine_docs_chain=doc_chain,
+         return_source_documents=True,
+         memory=memory,
+     )
+
+     # Replay the conversation so far.
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+     # Accept user input
+     if query := st.chat_input("What is up?"):
+         # Display the user message and add it to the chat history.
+         with st.chat_message("user"):
+             st.markdown(query)
+         st.session_state.messages.append({"role": "user", "content": query})
+
+         start = time.time()
+         response = qa_chain({'question': query})
+         end = time.time()
+
+         url_list = set([i.metadata['source'] for i in response['source_documents']])
+         st.write("Response time:", int(end - start), "sec")
+
+         # Display the assistant response and add it to the chat history.
+         with st.chat_message("assistant"):
+             st.markdown(response['answer'])
+         st.session_state.messages.append({"role": "assistant", "content": response['answer']})
+
+         with st.expander("See the related documents"):
+             for count, url in enumerate(url_list):
+                 st.write(str(count + 1) + ":", url)
+
+     view_messages = st.expander("View the message contents in session state")
+     with view_messages:
+         view_messages.json(st.session_state.langchain_messages)
+
+     clear_button = st.button("Start a new conversation")
+     if clear_button:
+         st.session_state.messages = []
+         qa_chain.memory.chat_memory.clear()
+
+
+ if __name__ == '__main__':
+     main()
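The app assumes a data/ folder of documents and the quantized model file next to the script, and is started with streamlit run app.py.py. Outside the UI, one question-answer round trip through the same components might look like this (a sketch under those assumptions; the question text is illustrative):

# Sketch: wire the same components together and ask one question, no Streamlit.
llm = load_llama2_llamaCpp()
db = FAISS.from_documents(split_docs(UploadDoc("data").create_document()), load_embeddings())
qa_chain = ConversationalRetrievalChain(
    retriever=db.as_retriever(search_kwargs={'k': 3}),
    question_generator=LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT),
    combine_docs_chain=load_qa_chain(llm, chain_type="stuff", prompt=set_custom_prompt()),
    return_source_documents=True,
)
response = qa_chain({'question': "What are these documents about?", 'chat_history': []})
print(response['answer'])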