Gokulnath2003 committed on
Commit e9fa814
1 Parent(s): b0d4cdc

Update app.py

Files changed (1)
  1. app.py +260 -261
app.py CHANGED
@@ -2,52 +2,51 @@ import gradio as gr
 import os
 import re
 from pathlib import Path
-
 from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
 from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.llms import HuggingFacePipeline
-from langchain.chains import ConversationChain
-from langchain.memory import ConversationBufferMemory
 from langchain_community.llms import HuggingFaceEndpoint
 
 from pathlib import Path
 import chromadb
 from unidecode import unidecode
 
-from transformers import AutoTokenizer
-import transformers
-import torch
-import tqdm
-import accelerate
-import re
 
 
-# default_persist_directory = './chroma_HF/'
-list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \
-    "google/gemma-7b-it","google/gemma-2b-it", \
-    "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1", \
-    "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
-    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
-    "google/flan-t5-xxl"
 ]
-list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
 # Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
-    # Processing for one document only
-    # loader = PyPDFLoader(file_path)
-    # pages = loader.load()
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
-    # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size = chunk_size,
-        chunk_overlap = chunk_overlap)
 
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
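Both versions of load_doc hand the chunk_size and chunk_overlap slider values straight to RecursiveCharacterTextSplitter. A minimal sketch of what those two parameters do, using the UI defaults below (600 and 40) and assuming langchain's splitter API (the exact import path varies across langchain releases):

# Illustrative only -- not part of this commit.
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=40)
chunks = splitter.split_text("some long stretch of document text " * 200)
# Every chunk stays at or below ~600 characters, and consecutive chunks overlap by ~40.
print(len(chunks), max(len(c) for c in chunks))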
 
@@ -58,322 +57,322 @@ def create_db(splits, collection_name):
         embedding=embedding,
         client=new_client,
         collection_name=collection_name,
-        # persist_directory=default_persist_directory
     )
     return vectordb


-# Load vector database
-def load_db():
-    embedding = HuggingFaceEmbeddings()
-    vectordb = Chroma(
-        # persist_directory=default_persist_directory,
-        embedding_function=embedding)
-    return vectordb


 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    progress(0.1, desc="Initializing HF tokenizer...")
-    # HuggingFacePipeline uses local model
-    # Note: it will download model locally...
-    # tokenizer=AutoTokenizer.from_pretrained(llm_model)
-    # progress(0.5, desc="Initializing HF pipeline...")
-    # pipeline=transformers.pipeline(
-    #     "text-generation",
-    #     model=llm_model,
-    #     tokenizer=tokenizer,
-    #     torch_dtype=torch.bfloat16,
-    #     trust_remote_code=True,
-    #     device_map="auto",
-    #     # max_length=1024,
-    #     max_new_tokens=max_tokens,
-    #     do_sample=True,
-    #     top_k=top_k,
-    #     num_return_sequences=1,
-    #     eos_token_id=tokenizer.eos_token_id
-    # )
-    # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
-
-    # HuggingFaceHub uses HF inference endpoints
-    progress(0.5, desc="Initializing HF Hub...")
-    # Use of trust_remote_code as model_kwargs
-    # Warning: langchain issue
-    # URL: https://github.com/langchain-ai/langchain/issues/6080
-    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-            load_in_8bit = True,
-        )
-    elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1","mosaicml/mpt-7b-instruct"]:
-        raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-        )
-    elif llm_model == "microsoft/phi-2":
-        # raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-            trust_remote_code = True,
-            torch_dtype = "auto",
-        )
-    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
-            temperature = temperature,
-            max_new_tokens = 250,
-            top_k = top_k,
-        )
-    elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
-        raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-        )
-    else:
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-        )

-    progress(0.75, desc="Defining buffer memory...")
     memory = ConversationBufferMemory(
         memory_key="chat_history",
         output_key='answer',
         return_messages=True
     )
-    # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
-    retriever=vector_db.as_retriever()
-    progress(0.8, desc="Defining retrieval chain...")
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
         chain_type="stuff",
         memory=memory,
-        # combine_docs_chain_kwargs={"prompt": your_prompt})
         return_source_documents=True,
-        #return_generated_question=False,
         verbose=False,
     )
-    progress(0.9, desc="Done!")
     return qa_chain


 # Generate collection name for vector database
-# - Use filepath as input, ensuring unicode text
 def create_collection_name(filepath):
-    # Extract filename without extension
     collection_name = Path(filepath).stem
-    # Fix potential issues from naming convention
-    ## Remove space
-    collection_name = collection_name.replace(" ","-")
-    ## ASCII transliterations of Unicode text
-    collection_name = unidecode(collection_name)
-    ## Remove special characters
-    #collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
-    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
-    ## Limit length to 50 characters
-    collection_name = collection_name[:50]
-    ## Minimum length of 3 characters
     if len(collection_name) < 3:
         collection_name = collection_name + 'xyz'
-    ## Enforce start and end as alphanumeric character
     if not collection_name[0].isalnum():
         collection_name = 'A' + collection_name[1:]
     if not collection_name[-1].isalnum():
         collection_name = collection_name[:-1] + 'Z'
-    print('Filepath: ', filepath)
-    print('Collection name: ', collection_name)
     return collection_name


 # Initialize database
 def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
-    # Create list of documents (when valid)
     list_file_path = [x.name for x in list_file_obj if x is not None]
-    # Create collection_name for vector database
-    progress(0.1, desc="Creating collection name...")
     collection_name = create_collection_name(list_file_path[0])
-    progress(0.25, desc="Loading document...")
-    # Load document and create splits
     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-    # Create or load vector database
-    progress(0.5, desc="Generating vector database...")
-    # global vector_db
     vector_db = create_db(doc_splits, collection_name)
-    progress(0.9, desc="Done!")
-    return vector_db, collection_name, "Complete!"


 def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    # print("llm_option",llm_option)
-    llm_name = list_llm[llm_option]
-    print("llm_name: ",llm_name)
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
     return qa_chain, "Complete!"

-
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
         formatted_chat_history.append(f"User: {user_message}")
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
-


 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
-    #print("formatted_chat_history",formatted_chat_history)
-
-    # Generate response using QA chain
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-    response_answer = response["answer"]
-    if response_answer.find("Helpful Answer:") != -1:
-        response_answer = response_answer.split("Helpful Answer:")[-1]
     response_sources = response["source_documents"]
-    response_source1 = response_sources[0].page_content.strip()
-    response_source2 = response_sources[1].page_content.strip()
-    response_source3 = response_sources[2].page_content.strip()
-    # Langchain sources are zero-based
-    response_source1_page = response_sources[0].metadata["page"] + 1
-    response_source2_page = response_sources[1].metadata["page"] + 1
-    response_source3_page = response_sources[2].metadata["page"] + 1
-    # print ('chat response: ', response_answer)
-    # print('DB source', response_sources)
-
-    # Append user message and response to chat history
     new_history = history + [(message, response_answer)]
-    # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-


-def upload_file(file_obj):
-    list_file_path = []
-    for idx, file in enumerate(file_obj):
-        file_path = file_obj.name
-        list_file_path.append(file_path)
-        # print(file_path)
-    # initialize_database(file_path, progress)
-    return list_file_path


 def demo():
-    with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
         collection_name = gr.State()

         gr.Markdown(
-            """<center><h2>PDF-based chatbot</center></h2>
             <h3>Ask any questions about your PDF documents</h3>""")
-        gr.Markdown(
-            """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
-            The user interface explicitely shows multiple steps to help understand the RAG workflow.
-            This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
-            <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
-            """)

-        with gr.Tab("Step 1 - Upload PDF"):
-            with gr.Row():
-                document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-                # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)

-        with gr.Tab("Step 2 - Process document"):
-            with gr.Row():
-                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
-            with gr.Accordion("Advanced options - Document text splitter", open=False):
-                with gr.Row():
-                    slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
-                with gr.Row():
-                    slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
-            with gr.Row():
-                db_progress = gr.Textbox(label="Vector database initialization", value="None")
-            with gr.Row():
-                db_btn = gr.Button("Generate vector database")

-        with gr.Tab("Step 3 - Initialize QA chain"):
-            with gr.Row():
-                llm_btn = gr.Radio(list_llm_simple, \
-                    label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
-            with gr.Accordion("Advanced options - LLM model", open=False):
-                with gr.Row():
-                    slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
-                with gr.Row():
-                    slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
-                with gr.Row():
-                    slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
-            with gr.Row():
-                llm_progress = gr.Textbox(value="None",label="QA chain initialization")
-            with gr.Row():
-                qachain_btn = gr.Button("Initialize Question Answering chain")
-
-        with gr.Tab("Step 4 - Chatbot"):
             chatbot = gr.Chatbot(height=300)
-            with gr.Accordion("Advanced - Document references", open=False):
-                with gr.Row():
-                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                    source1_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                    source2_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                    source3_page = gr.Number(label="Page", scale=1)
-            with gr.Row():
-                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
-            with gr.Row():
-                submit_btn = gr.Button("Submit message")
-                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

-        # Preprocessing events
-        #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
-        db_btn.click(initialize_database, \
-            inputs=[document, slider_chunk_size, slider_chunk_overlap], \
-            outputs=[vector_db, collection_name, db_progress])
-        qachain_btn.click(initialize_LLM, \
-            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
-            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
-            inputs=None, \
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-
-        # Chatbot events
-        msg.submit(conversation, \
-            inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-        submit_btn.click(conversation, \
-            inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-        clear_btn.click(lambda:[None,"",0,"",0,"",0], \
-            inputs=None, \
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-    demo.queue().launch(debug=True)


 if __name__ == "__main__":
     demo()
 
 import os
 import re
 from pathlib import Path
 from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
 from langchain_community.embeddings import HuggingFaceEmbeddings
+
 from langchain_community.llms import HuggingFaceEndpoint
 
 from pathlib import Path
 import chromadb
 from unidecode import unidecode
 
+# List of allowed models
+allowed_llms = [
+    "mistralai/Mistral-7B-Instruct-v0.2",
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "mistralai/Mistral-7B-Instruct-v0.1",
+    "google/gemma-7b-it",
+    "google/gemma-2b-it",
+    "HuggingFaceH4/zephyr-7b-beta",
+    "HuggingFaceH4/zephyr-7b-gemma-v0.1",
+    "meta-llama/Llama-2-7b-chat-hf"
+
 ]
+list_llm_simple = [os.path.basename(llm) for llm in allowed_llms]
 
 # Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
+
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
 
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap
+    )
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
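Note that this header no longer imports ConversationBufferMemory (removed above), while the new initialize_llmchain further down still constructs one, and load_doc calls PyPDFLoader and RecursiveCharacterTextSplitter, which are not imported in either version of the header shown in these hunks. Unless they are imported in a part of the file outside the diff, the new app.py would likely need something like the following (a sketch only; import paths differ between langchain releases):

# Illustrative only -- not part of this commit.
from langchain_community.document_loaders import PyPDFLoader          # used by load_doc()
from langchain.text_splitter import RecursiveCharacterTextSplitter    # used by load_doc()
from langchain.memory import ConversationBufferMemory                 # still used by initialize_llmchain()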
 
 
         embedding=embedding,
         client=new_client,
         collection_name=collection_name,
+
     )
     return vectordb


+
 
 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    llm = HuggingFaceEndpoint(
+        repo_id=llm_model,
+        temperature=temperature,
+        max_new_tokens=max_tokens,
+        top_k=top_k,
+        load_in_8bit=True,
+    )
+
     memory = ConversationBufferMemory(
         memory_key="chat_history",
         output_key='answer',
         return_messages=True
     )
+    retriever = vector_db.as_retriever()
+
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
         chain_type="stuff",
         memory=memory,
+
         return_source_documents=True,
+
         verbose=False,
     )
+
     return qa_chain
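The rewritten initialize_llmchain sends the same keyword arguments, including load_in_8bit=True, to every repo and drops the progress() updates and the per-model guards of the previous version (only Mixtral asked for 8-bit loading before, and gated or oversized models raised a gr.Error instead of failing later on the free inference endpoint). A sketch of a guarded variant under those same assumptions; the helper name is hypothetical and not part of this commit:

# Illustrative only -- not part of this commit.
import gradio as gr
from langchain_community.llms import HuggingFaceEndpoint

def build_endpoint_guarded(llm_model, temperature, max_tokens, top_k):
    # Restore the guards the previous version placed around endpoint creation.
    if llm_model == "meta-llama/Llama-2-7b-chat-hf":
        raise gr.Error("Llama-2-7b-chat-hf requires a Pro subscription...")
    endpoint_kwargs = dict(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        endpoint_kwargs["load_in_8bit"] = True  # only Mixtral used 8-bit loading before
    return HuggingFaceEndpoint(**endpoint_kwargs)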
 
 # Generate collection name for vector database
+
 def create_collection_name(filepath):
+
     collection_name = Path(filepath).stem
+    collection_name = unidecode(collection_name).replace(" ", "-")
+    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)[:50]
+
     if len(collection_name) < 3:
         collection_name = collection_name + 'xyz'
+
     if not collection_name[0].isalnum():
         collection_name = 'A' + collection_name[1:]
     if not collection_name[-1].isalnum():
         collection_name = collection_name[:-1] + 'Z'
+
     return collection_name


 # Initialize database
 def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
+
     list_file_path = [x.name for x in list_file_obj if x is not None]
+
     collection_name = create_collection_name(list_file_path[0])
+
     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+
     vector_db = create_db(doc_splits, collection_name)
 
+    return vector_db, collection_name, "Complete!"
 
+# Initialize LLM
 def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    llm_name = allowed_llms[llm_option]
+
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
     return qa_chain, "Complete!"
 
+# Format chat history
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
         formatted_chat_history.append(f"User: {user_message}")
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history


+# Conversation handling
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
+
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
+    response_answer = response["answer"].split("Helpful Answer:")[-1]
+
     response_sources = response["source_documents"]
+
     new_history = history + [(message, response_answer)]
+    response_details = [(src.page_content.strip(), src.metadata["page"] + 1) for src in response_sources[:3]]
+    return qa_chain, gr.update(value=""), new_history, *sum(response_details, ())
+

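In the new conversation(), sum(response_details, ()) concatenates the per-source (text, page) tuples into one flat tuple, so the function still returns nine values when the retriever yields three source documents; with fewer documents the return simply gets shorter, whereas the old code would have raised an IndexError on response_sources[2]. A small sketch of the same flattening with the arity kept fixed (the helper is hypothetical, not part of this commit):

# Illustrative only -- not part of this commit.
def flatten_sources(response_sources, n_refs=3):
    # (text, 1-based page) for up to n_refs sources, padded so the output arity is stable.
    pairs = [(src.page_content.strip(), src.metadata["page"] + 1)
             for src in response_sources[:n_refs]]
    pairs += [("", 0)] * (n_refs - len(pairs))
    return sum(pairs, ())  # -> (text1, page1, text2, page2, text3, page3)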
+# Gradio Interface
 def demo():
+    with gr.Blocks(theme="default") as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
         collection_name = gr.State()

         gr.Markdown(
+            """<center><h2>PDF-based Chatbot</h2></center>
             <h3>Ask any questions about your PDF documents</h3>""")
+

+        with gr.Tab("Upload PDF"):
+            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF Documents")
+

+        with gr.Tab("Process Document"):
+            db_btn = gr.Radio(["ChromaDB"], label="Vector Database", value="ChromaDB", type="index")
+            with gr.Accordion("Advanced Options", open=False):
+                slider_chunk_size = gr.Slider(100, 1000, 600, 20, label="Chunk Size", interactive=True)
+                slider_chunk_overlap = gr.Slider(10, 200, 40, 10, label="Chunk Overlap", interactive=True)
+            db_progress = gr.Textbox(label="Database Initialization Status", value="None")
+            db_btn = gr.Button("Generate Database")
+

+        with gr.Tab("Initialize QA Chain"):
+            llm_btn = gr.Radio(list_llm_simple, label="LLM Models", value=list_llm_simple[0], type="index")
+            with gr.Accordion("Advanced Options", open=False):
+                slider_temperature = gr.Slider(0.01, 1.0, 0.7, 0.1, label="Temperature", interactive=True)
+                slider_maxtokens = gr.Slider(224, 4096, 1024, 32, label="Max Tokens", interactive=True)
+                slider_topk = gr.Slider(1, 10, 3, 1, label="Top-k Samples", interactive=True)
+            llm_progress = gr.Textbox(value="None", label="QA Chain Initialization Status")
+            qachain_btn = gr.Button("Initialize QA Chain")
+
+        with gr.Tab("Chatbot"):
             chatbot = gr.Chatbot(height=300)
+            with gr.Accordion("Document References", open=False):
+                for i in range(1, 4):
+                    gr.Row([gr.Textbox(label=f"Reference {i}", lines=2, container=True, scale=20), gr.Number(label="Page", scale=1)])
+            msg = gr.Textbox(placeholder="Type message here...", container=True)
+            gr.Row([gr.Button("Submit"), gr.Button("Clear Conversation")])
+

+        # Define Interactions
+        db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
+        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress])
+        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot] + [None] * 6)
+

+    demo.launch(debug=True)

 if __name__ == "__main__":
     demo()
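A few points in the new wiring differ from the previous version and may matter at runtime: conversation() still returns nine values, but msg.submit lists only three components plus six None placeholders as outputs; the reference Textbox and Number widgets are created inside the for loop through gr.Row([...]) without being bound to names, so they cannot be referenced as outputs (and gr.Row is normally used as a context manager rather than being passed a list of children); the new "Submit" and "Clear Conversation" buttons have no event handlers; and demo.queue() is gone. A sketch of wiring closer to the previous version, assuming it replaces the loop and the button row inside the "Chatbot" tab of demo() above and reuses the components defined there:

# Illustrative only -- not part of this commit.
with gr.Accordion("Document References", open=False):
    doc_source1 = gr.Textbox(label="Reference 1", lines=2, scale=20)
    source1_page = gr.Number(label="Page", scale=1)
    doc_source2 = gr.Textbox(label="Reference 2", lines=2, scale=20)
    source2_page = gr.Number(label="Page", scale=1)
    doc_source3 = gr.Textbox(label="Reference 3", lines=2, scale=20)
    source3_page = gr.Number(label="Page", scale=1)
msg = gr.Textbox(placeholder="Type message here...", container=True)
submit_btn = gr.Button("Submit")
clear_btn = gr.ClearButton([msg, chatbot], value="Clear Conversation")

chat_outputs = [qa_chain, msg, chatbot,
                doc_source1, source1_page, doc_source2, source2_page,
                doc_source3, source3_page]
msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=chat_outputs, queue=False)
submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=chat_outputs, queue=False)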