clementsan commited on
Commit
5be8df6
·
1 Parent(s): 529bde4

Add PDF chatbot application

Browse files
Files changed (1) hide show
  1. app.py +236 -4
app.py CHANGED
@@ -1,7 +1,239 @@
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
 
4
+ from langchain.document_loaders import PyPDFLoader
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.vectorstores import Chroma
7
+ from langchain.chains import ConversationalRetrievalChain
8
+ from langchain.embeddings import HuggingFaceEmbeddings
9
+ from langchain.llms import HuggingFacePipeline
10
+ from langchain.chains import ConversationChain
11
+ from langchain.memory import ConversationBufferMemory
12
+ from langchain.llms import HuggingFaceHub
13
 
14
+ from transformers import AutoTokenizer
15
+ import transformers
16
+ import torch
17
+ import tqdm
18
+ import accelerate
19
+
20
+
21
+ default_persist_directory = './chroma_HF/'
22
+ default_llm_name1 = "tiiuae/falcon-7b-instruct"
23
+ default_llm_name2 = "google/flan-t5-xxl"
24
+ default_llm_name3 = "mosaicml/mpt-7b-instruct"
25
+ default_llm_name4 = "meta-llama/Llama-2-7b-chat-hf"
26
+ default_llm_name5 = "mistralai/Mistral-7B-Instruct-v0.1"
27
+ list_llm = [default_llm_name1, default_llm_name2, default_llm_name3, default_llm_name4, default_llm_name5]
28
+
29
+
30
+ # Load PDF document and create doc splits
31
+ def load_doc(list_file_path, chunk_size, chunk_overlap):
32
+ # Processing for one document only
33
+ # loader = PyPDFLoader(file_path)
34
+ # pages = loader.load()
35
+ loaders = [PyPDFLoader(x) for x in list_file_path]
36
+ pages = []
37
+ for loader in loaders:
38
+ pages.extend(loader.load())
39
+ # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
40
+ text_splitter = RecursiveCharacterTextSplitter(
41
+ chunk_size = chunk_size,
42
+ chunk_overlap = chunk_overlap)
43
+ doc_splits = text_splitter.split_documents(pages)
44
+ return doc_splits
45
+
46
+
47
+ # Create vector database
48
+ def create_db(splits):
49
+ embedding = HuggingFaceEmbeddings()
50
+ vectordb = Chroma.from_documents(
51
+ documents=splits,
52
+ embedding=embedding,
53
+ persist_directory=default_persist_directory
54
+ )
55
+ return vectordb
56
+
57
+
58
+ # Load vector database
59
+ def load_db():
60
+ embedding = HuggingFaceEmbeddings()
61
+ vectordb = Chroma(
62
+ persist_directory=default_persist_directory,
63
+ embedding_function=embedding)
64
+ return vectordb
65
+
66
+
67
+ # Initialize langchain LLM chain
68
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
69
+ progress(0.1, desc="Initializing HF tokenizer...")
70
+ # HuggingFacePipeline uses local model
71
+ # Warning: it will download model locally...
72
+ # tokenizer=AutoTokenizer.from_pretrained(llm_model)
73
+ # progress(0.5, desc="Initializing HF pipeline...")
74
+ # pipeline=transformers.pipeline(
75
+ # "text-generation",
76
+ # model=llm_model,
77
+ # tokenizer=tokenizer,
78
+ # torch_dtype=torch.bfloat16,
79
+ # trust_remote_code=True,
80
+ # device_map="auto",
81
+ # # max_length=1024,
82
+ # max_new_tokens=max_tokens,
83
+ # do_sample=True,
84
+ # top_k=top_k,
85
+ # num_return_sequences=1,
86
+ # eos_token_id=tokenizer.eos_token_id
87
+ # )
88
+ # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
89
+
90
+ # HuggingFaceHub uses HF inference endpoints
91
+ progress(0.5, desc="Initializing HF Hub...")
92
+ llm = HuggingFaceHub(
93
+ repo_id=llm_model,
94
+ model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
95
+ )
96
+
97
+ progress(0.5, desc="Defining buffer memory...")
98
+ memory = ConversationBufferMemory(
99
+ memory_key="chat_history",
100
+ return_messages=True
101
+ )
102
+ # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
103
+ retriever=vector_db.as_retriever()
104
+ progress(0.8, desc="Defining retrieval chain...")
105
+ global qa_chain
106
+ qa_chain = ConversationalRetrievalChain.from_llm(
107
+ llm,
108
+ retriever=retriever,
109
+ chain_type="stuff",
110
+ memory=memory,
111
+ # combine_docs_chain_kwargs={"prompt": your_prompt})
112
+ # return_source_documents=True,
113
+ # return_generated_question=True,
114
+ # verbose=True,
115
+ )
116
+ progress(0.9, desc="Done!")
117
+ # return qa_chain
118
+
119
+
120
+ # Initialize all elements
121
+ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
122
+ # Create list of documents (when valid)
123
+ #file_path = file_obj.name
124
+ list_file_path = [x.name for x in list_file_obj if x is not None]
125
+ print('list_file_path', list_file_path)
126
+ progress(0.25, desc="Loading document...")
127
+ # Load document and create splits
128
+ doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
129
+ # Create or load Vector database
130
+ progress(0.5, desc="Generating vector database...")
131
+ # global vector_db
132
+ vector_db = create_db(doc_splits)
133
+ progress(0.9, desc="Done!")
134
+ return vector_db, "Complete!"
135
+ #return qa_chain
136
+
137
+
138
+ def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
139
+ print("llm_option",llm_option)
140
+ llm_name = list_llm[llm_option]
141
+ print("llm_name",llm_name)
142
+ initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
143
+ return "Complete!"
144
+ #return qa_chain
145
+
146
+
147
+ def format_chat_history(message, chat_history):
148
+ formatted_chat_history = []
149
+ for user_message, bot_message in chat_history:
150
+ formatted_chat_history.append(f"User: {user_message}")
151
+ formatted_chat_history.append(f"Assistant: {bot_message}")
152
+ return formatted_chat_history
153
+
154
+
155
+ def conversation(message, history):
156
+ formatted_chat_history = format_chat_history(message, history)
157
+ #print("formatted_chat_history",formatted_chat_history)
158
+
159
+ # Generate response using QA chain
160
+ response = qa_chain({"question": message, "chat_history": formatted_chat_history})
161
+ # return response['answer']
162
+
163
+ # Append user message and response to chat history
164
+ new_history = history + [(message, response["answer"])]
165
+ return gr.update(value=""), new_history
166
+
167
+
168
+ def upload_file(file_obj):
169
+ list_file_path = []
170
+ for idx, file in enumerate(file_obj):
171
+ file_path = file_obj.name
172
+ list_file_path.append(file_path)
173
+ # print(file_path)
174
+ # initialize_database(file_path, progress)
175
+ return list_file_path
176
+
177
+
178
+ def demo():
179
+ with gr.Blocks(theme="base") as demo:
180
+ vector_db = gr.Variable()
181
+ # qa_chain = gr.Variable()
182
+
183
+ gr.Markdown(
184
+ """<center><h2> Document-based chatbot</center></h2>
185
+ <h3>Ask any questions about your PDF documents (single or multiple)</h3>
186
+ <i>Note: chatbot performs question-answering using Langchain and LLMs</i>
187
+ """)
188
+ with gr.Tab("Step 1 - Document pre-processing"):
189
+ with gr.Row():
190
+ document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF Documents")
191
+ # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
192
+ with gr.Row():
193
+ db_btn = gr.Radio(["ChromaDB"], label="Vector database", value = "ChromaDB", type="index", info="Choose your vector database")
194
+ with gr.Accordion("Advanced options - Text splitter", open=False):
195
+ with gr.Row():
196
+ slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
197
+ with gr.Row():
198
+ slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=50, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
199
+ with gr.Row():
200
+ db_progress = gr.Textbox(label="Database Initialization", value="None")
201
+ with gr.Row():
202
+ db_btn = gr.Button("Generating vector database...")
203
+
204
+ with gr.Tab("Step 2 - Initializing QA chain"):
205
+ with gr.Row():
206
+ llm_btn = gr.Radio(["falcon-7b-instruct", "flan-t5-xxl", "mpt-7b-instruct", "Llama-2-7b-chat-hf", "Mistral-7B-Instruct-v0.1"], \
207
+ label="LLM", value = "falcon-7b-instruct", type="index", info="Choose your LLM model")
208
+ with gr.Accordion("Advanced options - LLM", open=False):
209
+ slider_temperature = gr.Slider(minimum = 0.0, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
210
+ slider_maxtokens = gr.Slider(minimum = 256, maximum = 4096, value=1024, step=24, label="Max Tokens", info="Model max tokens", interactive=True)
211
+ slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
212
+ with gr.Row():
213
+ llm_progress = gr.Textbox(value="None",label="QA chain Initialization")
214
+ with gr.Row():
215
+ qachain_btn = gr.Button("QA chain Initialization...")
216
+
217
+ with gr.Tab("Step 3 - Conversation"):
218
+ chatbot = gr.Chatbot(height=600)
219
+ with gr.Row():
220
+ msg = gr.Textbox(placeholder="Type message", container=True)
221
+ with gr.Row():
222
+ submit_btn = gr.Button("Submit")
223
+ clear_btn = gr.ClearButton([msg, chatbot])
224
+
225
+ # Preprocessing events
226
+ #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
227
+ db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, db_progress])
228
+ qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[llm_progress]).then(lambda: None, None, chatbot, queue=False)
229
+
230
+ # Chatbot events
231
+ msg.submit(conversation, [msg, chatbot], [msg, chatbot], queue=False)
232
+ submit_btn.click(conversation, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False)
233
+ clear_btn.click(lambda: None, None, chatbot, queue=False)
234
+ demo.queue(concurrency_count=20).launch(debug=True)
235
+
236
+
237
+ if __name__ == "__main__":
238
+ demo()
239
+