Gokulnath2003 committed
Commit: b0d4cdc
Parent: 85a72e5

Update app.py

Files changed (1): app.py (+278, -71)
app.py CHANGED
@@ -3,170 +3,377 @@ import os
import re
from pathlib import Path

- from langchain_community.document_loaders import PyPDFLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
- from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
import chromadb
from unidecode import unidecode

- # List of allowed models
- allowed_llms = [
-     "mistralai/Mistral-7B-Instruct-v0.2",
-     "mistralai/Mixtral-8x7B-Instruct-v0.1",
-     "mistralai/Mistral-7B-Instruct-v0.1",
-     "google/gemma-7b-it",
-     "google/gemma-2b-it",
-     "HuggingFaceH4/zephyr-7b-beta",
-     "HuggingFaceH4/zephyr-7b-gemma-v0.1",
-     "meta-llama/Llama-2-7b-chat-hf"
]
- list_llm_simple = [os.path.basename(llm) for llm in allowed_llms]

# Load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=chunk_size,
-         chunk_overlap=chunk_overlap
-     )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
-     new_client = chromadb.EphemeralClient()
-     vectordb = Chroma.from_documents(
-         documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb

# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-     llm = HuggingFaceEndpoint(
-         repo_id=llm_model,
-         temperature=temperature,
-         max_new_tokens=max_tokens,
-         top_k=top_k,
-         load_in_8bit=True,
-     )

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
-     retriever = vector_db.as_retriever()
-
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain

# Generate collection name for vector database
def create_collection_name(filepath):
    collection_name = Path(filepath).stem
-     collection_name = unidecode(collection_name).replace(" ", "-")
-     collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name

# Initialize database
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    collection_name = create_collection_name(list_file_path[0])
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"

- # Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-     llm_name = allowed_llms[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"

- # Format chat history
def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

- # Conversation handling
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-     response_answer = response["answer"].split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    new_history = history + [(message, response_answer)]
-     response_details = [(src.page_content.strip(), src.metadata["page"] + 1) for src in response_sources[:3]]
-     return gr.update(value=""), new_history, *sum(response_details, ())

- # Gradio Interface
def demo():
-     with gr.Blocks(theme="default") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        gr.Markdown(
-             """<center><h2>PDF-based Chatbot</h2></center>
            <h3>Ask any questions about your PDF documents</h3>""")

-         with gr.Tab("Upload PDF"):
-             document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF Documents")

-         with gr.Tab("Process Document"):
-             db_btn = gr.Radio(["ChromaDB"], label="Vector Database", value="ChromaDB", type="index")
-             with gr.Accordion("Advanced Options", open=False):
-                 slider_chunk_size = gr.Slider(100, 1000, 600, 20, label="Chunk Size", interactive=True)
-                 slider_chunk_overlap = gr.Slider(10, 200, 40, 10, label="Chunk Overlap", interactive=True)
-             db_progress = gr.Textbox(label="Database Initialization Status", value="None")
-             db_btn = gr.Button("Generate Database")

-         with gr.Tab("Initialize QA Chain"):
-             llm_btn = gr.Radio(list_llm_simple, label="LLM Models", value=list_llm_simple[0], type="index")
-             with gr.Accordion("Advanced Options", open=False):
-                 slider_temperature = gr.Slider(0.01, 1.0, 0.7, 0.1, label="Temperature", interactive=True)
-                 slider_maxtokens = gr.Slider(224, 4096, 1024, 32, label="Max Tokens", interactive=True)
-                 slider_topk = gr.Slider(1, 10, 3, 1, label="Top-k Samples", interactive=True)
-             llm_progress = gr.Textbox(value="None", label="QA Chain Initialization Status")
-             qachain_btn = gr.Button("Initialize QA Chain")
-
-         with gr.Tab("Chatbot"):
            chatbot = gr.Chatbot(height=300)
-             with gr.Accordion("Document References", open=False):
-                 for i in range(1, 4):
-                     gr.Row([gr.Textbox(label=f"Reference {i}", lines=2, container=True, scale=20), gr.Number(label="Page", scale=1)])
-             msg = gr.Textbox(placeholder="Type message here...", container=True)
-             gr.Row([gr.Button("Submit"), gr.Button("Clear Conversation")])

-         # Define Interactions
-         db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
-         qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress])
-         msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[msg, chatbot] + [None] * 6)

-     demo.launch(debug=True)

if __name__ == "__main__":
-     demo()
 
import re
from pathlib import Path

from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFacePipeline
+ from langchain.chains import ConversationChain
+ from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
+
+ from pathlib import Path
import chromadb
from unidecode import unidecode

+ from transformers import AutoTokenizer
+ import transformers
+ import torch
+ import tqdm
+ import accelerate
+ import re
+
+
+ # default_persist_directory = './chroma_HF/'
+ list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \
+     "google/gemma-7b-it", "google/gemma-2b-it", \
+     "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1", \
+     "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
+     "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
+     "google/flan-t5-xxl"
]
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# Load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
+     # Processing for one document only
+     # loader = PyPDFLoader(file_path)
+     # pages = loader.load()
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
+     # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
    text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size = chunk_size,
+         chunk_overlap = chunk_overlap)
+
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

+
# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
+         # persist_directory=default_persist_directory
    )
    return vectordb

+
+ # Load vector database
+ def load_db():
+     embedding = HuggingFaceEmbeddings()
+     vectordb = Chroma(
+         # persist_directory=default_persist_directory,
+         embedding_function=embedding)
+     return vectordb
+
+
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+     progress(0.1, desc="Initializing HF tokenizer...")
+     # HuggingFacePipeline uses local model
+     # Note: it will download model locally...
+     # tokenizer=AutoTokenizer.from_pretrained(llm_model)
+     # progress(0.5, desc="Initializing HF pipeline...")
+     # pipeline=transformers.pipeline(
+     #     "text-generation",
+     #     model=llm_model,
+     #     tokenizer=tokenizer,
+     #     torch_dtype=torch.bfloat16,
+     #     trust_remote_code=True,
+     #     device_map="auto",
+     #     # max_length=1024,
+     #     max_new_tokens=max_tokens,
+     #     do_sample=True,
+     #     top_k=top_k,
+     #     num_return_sequences=1,
+     #     eos_token_id=tokenizer.eos_token_id
+     # )
+     # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
+
+     # HuggingFaceHub uses HF inference endpoints
+     progress(0.5, desc="Initializing HF Hub...")
+     # Use of trust_remote_code as model_kwargs
+     # Warning: langchain issue
+     # URL: https://github.com/langchain-ai/langchain/issues/6080
+     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+             load_in_8bit = True,
+         )
+     elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1", "mosaicml/mpt-7b-instruct"]:
+         raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+         )
+     elif llm_model == "microsoft/phi-2":
+         # raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+             trust_remote_code = True,
+             torch_dtype = "auto",
+         )
+     elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+             temperature = temperature,
+             max_new_tokens = 250,
+             top_k = top_k,
+         )
+     elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
+         raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+         )
+     else:
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+         )

+     progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
+     # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
+     retriever = vector_db.as_retriever()
+     progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
+         # combine_docs_chain_kwargs={"prompt": your_prompt})
        return_source_documents=True,
+         # return_generated_question=False,
        verbose=False,
    )
+     progress(0.9, desc="Done!")
    return qa_chain

+
# Generate collection name for vector database
+ # - Use filepath as input, ensuring unicode text
def create_collection_name(filepath):
+     # Extract filename without extension
    collection_name = Path(filepath).stem
+     # Fix potential issues from naming convention
+     ## Remove space
+     collection_name = collection_name.replace(" ", "-")
+     ## ASCII transliterations of Unicode text
+     collection_name = unidecode(collection_name)
+     ## Remove special characters
+     # collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
+     collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+     ## Limit length to 50 characters
+     collection_name = collection_name[:50]
+     ## Minimum length of 3 characters
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
+     ## Enforce start and end as alphanumeric character
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
+     print('Filepath: ', filepath)
+     print('Collection name: ', collection_name)
    return collection_name
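
# Example (illustrative, not part of the commit): for a hypothetical upload named
# "Annual Report (2023).pdf", the sanitization above gives "Annual-Report-2023-"
# after the regex step and "Annual-Report-2023Z" after the final check that the
# name ends in an alphanumeric character.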

+
# Initialize database
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
+     # Create list of documents (when valid)
    list_file_path = [x.name for x in list_file_obj if x is not None]
+     # Create collection_name for vector database
+     progress(0.1, desc="Creating collection name...")
    collection_name = create_collection_name(list_file_path[0])
+     progress(0.25, desc="Loading document...")
+     # Load document and create splits
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+     # Create or load vector database
+     progress(0.5, desc="Generating vector database...")
+     # global vector_db
    vector_db = create_db(doc_splits, collection_name)
+     progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"

+
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+     # print("llm_option", llm_option)
+     llm_name = list_llm[llm_option]
+     print("llm_name: ", llm_name)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"

+
def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
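
# Example (illustrative, not part of the commit):
#   format_chat_history("next question", [("Hello", "Hi there")])
#   -> ["User: Hello", "Assistant: Hi there"]
# The current message is not included here; it is passed to the chain separately
# as "question" in conversation() below.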
+
+

def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
+     # print("formatted_chat_history", formatted_chat_history)
+
+     # Generate response using QA chain
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
+     response_answer = response["answer"]
+     if response_answer.find("Helpful Answer:") != -1:
+         response_answer = response_answer.split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
+     response_source1 = response_sources[0].page_content.strip()
+     response_source2 = response_sources[1].page_content.strip()
+     response_source3 = response_sources[2].page_content.strip()
+     # Langchain sources are zero-based
+     response_source1_page = response_sources[0].metadata["page"] + 1
+     response_source2_page = response_sources[1].metadata["page"] + 1
+     response_source3_page = response_sources[2].metadata["page"] + 1
+     # print('chat response: ', response_answer)
+     # print('DB source', response_sources)
+
+     # Append user message and response to chat history
    new_history = history + [(message, response_answer)]
+     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
+     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
+
+
+ def upload_file(file_obj):
+     list_file_path = []
+     for idx, file in enumerate(file_obj):
+         file_path = file_obj.name
+         list_file_path.append(file_path)
+         # print(file_path)
+     # initialize_database(file_path, progress)
+     return list_file_path
+
+

def demo():
+     with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

288
+ """<center><h2>PDF-based chatbot</center></h2>
289
  <h3>Ask any questions about your PDF documents</h3>""")
290
+ gr.Markdown(
291
+ """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
292
+ The user interface explicitely shows multiple steps to help understand the RAG workflow.
293
+ This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
294
+ <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
295
+ """)
296
 
297
+ with gr.Tab("Step 1 - Upload PDF"):
298
+ with gr.Row():
299
+ document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
300
+ # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
301
 
302
+ with gr.Tab("Step 2 - Process document"):
303
+ with gr.Row():
304
+ db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
305
+ with gr.Accordion("Advanced options - Document text splitter", open=False):
306
+ with gr.Row():
307
+ slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
308
+ with gr.Row():
309
+ slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
310
+ with gr.Row():
311
+ db_progress = gr.Textbox(label="Vector database initialization", value="None")
312
+ with gr.Row():
313
+ db_btn = gr.Button("Generate vector database")
314
 
315
+ with gr.Tab("Step 3 - Initialize QA chain"):
316
+ with gr.Row():
317
+ llm_btn = gr.Radio(list_llm_simple, \
318
+ label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
319
+ with gr.Accordion("Advanced options - LLM model", open=False):
320
+ with gr.Row():
321
+ slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
322
+ with gr.Row():
323
+ slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
324
+ with gr.Row():
325
+ slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
326
+ with gr.Row():
327
+ llm_progress = gr.Textbox(value="None",label="QA chain initialization")
328
+ with gr.Row():
329
+ qachain_btn = gr.Button("Initialize Question Answering chain")
330
+
331
+ with gr.Tab("Step 4 - Chatbot"):
332
  chatbot = gr.Chatbot(height=300)
333
+ with gr.Accordion("Advanced - Document references", open=False):
334
+ with gr.Row():
335
+ doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
336
+ source1_page = gr.Number(label="Page", scale=1)
337
+ with gr.Row():
338
+ doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
339
+ source2_page = gr.Number(label="Page", scale=1)
340
+ with gr.Row():
341
+ doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
342
+ source3_page = gr.Number(label="Page", scale=1)
343
+ with gr.Row():
344
+ msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
345
+ with gr.Row():
346
+ submit_btn = gr.Button("Submit message")
347
+ clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
348
 
349
+ # Preprocessing events
350
+ #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
351
+ db_btn.click(initialize_database, \
352
+ inputs=[document, slider_chunk_size, slider_chunk_overlap], \
353
+ outputs=[vector_db, collection_name, db_progress])
354
+ qachain_btn.click(initialize_LLM, \
355
+ inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
356
+ outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
357
+ inputs=None, \
358
+ outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
359
+ queue=False)
360
+
361
+ # Chatbot events
362
+ msg.submit(conversation, \
363
+ inputs=[qa_chain, msg, chatbot], \
364
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
365
+ queue=False)
366
+ submit_btn.click(conversation, \
367
+ inputs=[qa_chain, msg, chatbot], \
368
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
369
+ queue=False)
370
+ clear_btn.click(lambda:[None,"",0,"",0,"",0], \
371
+ inputs=None, \
372
+ outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
373
+ queue=False)
374
+ demo.queue().launch(debug=True)
375
+
376
 
 
377
 
378
  if __name__ == "__main__":
379
+ demo()
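
For reference, a minimal sketch (not part of the commit) of how the helper functions above chain together outside the Gradio UI. It assumes a hypothetical local file "sample.pdf" and a Hugging Face API token configured for HuggingFaceEndpoint; it mirrors the calls made in create_db and initialize_llmchain.

splits = load_doc(["sample.pdf"], chunk_size=600, chunk_overlap=40)
vector_db = create_db(splits, create_collection_name("sample.pdf"))
llm = HuggingFaceEndpoint(repo_id="mistralai/Mistral-7B-Instruct-v0.2",
                          temperature=0.7, max_new_tokens=1024, top_k=3)
memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever=vector_db.as_retriever(),
                                                 chain_type="stuff", memory=memory,
                                                 return_source_documents=True)
result = qa_chain({"question": "What is this document about?", "chat_history": []})
print(result["answer"])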