Pavan178 committed
Commit ee43a37 · verified · 1 Parent(s): 03c98cf

Update app.py

Files changed (1)
  1. app.py +225 -103

app.py CHANGED
@@ -1,111 +1,162 @@
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
-#from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
import spaces
from pathlib import Path
import chromadb
from unidecode import unidecode
-import os
-from huggingface_hub import login
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
-

import transformers
import torch
import re

-# List of models
-list_llm = [
-    "mistralai/Mistral-7B-Instruct-v0.2",
-    "HuggingFaceH4/zephyr-7b-beta",
-    "microsoft/phi-2",
-    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-    # Add more GPU-compatible models here
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

@spaces.GPU
-@spaces.GPU
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size = chunk_size,
        chunk_overlap = chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

# Create vector database
def create_db(splits, collection_name):
-    # Set CUDA_VISIBLE_DEVICES if GPU is available
-    if torch.cuda.is_available():
-        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-
-    embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb

-def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    progress(0.1, desc="Initializing HF tokenizer...")
-
-    # Retrieve the Hugging Face token from environment variables
-    hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
-    if not hf_token:
-        raise ValueError("Hugging Face token not found. Please set the HF_TOKEN environment variable.")

-    # Log in to Hugging Face
-    login(token=hf_token)

-    # Initialize tokenizer and model with the token
-    tokenizer = AutoTokenizer.from_pretrained(llm_model, use_auth_token=hf_token)

-    progress(0.3, desc="Loading model...")
-    try:
-        model = AutoModelForCausalLM.from_pretrained(
-            llm_model,
-            use_auth_token=hf_token,
-            torch_dtype=torch.float16,
-            device_map="auto"
        )
-    except RuntimeError as e:
-        if "CUDA out of memory" in str(e):
-            raise gr.Error("GPU memory exceeded. Try a smaller model or reduce batch size.")
-        else:
-            raise e
-
-    progress(0.5, desc="Initializing HF pipeline...")
-    pipeline = transformers.pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        max_new_tokens=max_tokens,
-        do_sample=True,
-        top_k=top_k,
-        num_return_sequences=1,
-        eos_token_id=tokenizer.eos_token_id
-    )
-    llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
-

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
@@ -113,58 +164,90 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
        output_key='answer',
        return_messages=True
    )
-    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain


def create_collection_name(filepath):
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ","-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name


def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    progress(0.1, desc="Creating collection name...")
    collection_name = create_collection_name(list_file_path[0])
    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"


def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"


def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history


def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if response_answer.find("Helpful Answer:") != -1:
@@ -173,12 +256,28 @@ def conversation(qa_chain, message, history):
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source3 = response_sources[2].page_content.strip()
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1

    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page


def demo():
    with gr.Blocks(theme="base") as demo:
@@ -187,71 +286,94 @@ def demo():
        collection_name = gr.State()

        gr.Markdown(
-        """<center><h2>GPU-Accelerated PDF-based Chatbot</center></h2>
        <h3>Ask any questions about your PDF documents</h3>""")
        gr.Markdown(
-        """<b>Note:</b> This AI assistant uses GPU acceleration for faster processing.
-        It performs retrieval-augmented generation (RAG) from your PDF documents using Langchain and open-source LLMs.
-        The chatbot takes past questions into account and includes document references.""")

        with gr.Tab("Step 1 - Upload PDF"):
-            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")

        with gr.Tab("Step 2 - Process document"):
-            db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
            with gr.Accordion("Advanced options - Document text splitter", open=False):
-                slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
-                slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
-            db_progress = gr.Textbox(label="Vector database initialization", value="None")
-            db_btn = gr.Button("Generate vector database")

        with gr.Tab("Step 3 - Initialize QA chain"):
-            llm_btn = gr.Radio(list_llm_simple, label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
            with gr.Accordion("Advanced options - LLM model", open=False):
-                slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
-                slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
-                slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
-            llm_progress = gr.Textbox(value="None",label="QA chain initialization")
-            qachain_btn = gr.Button("Initialize Question Answering chain")

        with gr.Tab("Step 4 - Chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Advanced - Document references", open=False):
-                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                source1_page = gr.Number(label="Page", scale=1)
-                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                source2_page = gr.Number(label="Page", scale=1)
-                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                source3_page = gr.Number(label="Page", scale=1)
-            msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
-            submit_btn = gr.Button("Submit message")
-            clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

        # Preprocessing events
-        db_btn.click(initialize_database,
-            inputs=[document, slider_chunk_size, slider_chunk_overlap],
            outputs=[vector_db, collection_name, db_progress])
-        qachain_btn.click(initialize_LLM,
-            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
-            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
-            inputs=None,
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False)

        # Chatbot events
-        msg.submit(conversation,
-            inputs=[qa_chain, msg, chatbot],
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False)
-        submit_btn.click(conversation,
-            inputs=[qa_chain, msg, chatbot],
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False)
-        clear_btn.click(lambda:[None,"",0,"",0,"",0],
-            inputs=None,
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False)
    demo.queue().launch(debug=True)

if __name__ == "__main__":
-    demo()
app.py (updated version):

import gradio as gr
import os
+
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
+from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
+from langchain_community.llms import HuggingFaceEndpoint
import spaces
from pathlib import Path
import chromadb
from unidecode import unidecode

+from transformers import AutoTokenizer
import transformers
import torch
+import tqdm
+import accelerate
import re

+
+
+# default_persist_directory = './chroma_HF/'
+list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \
+    "google/gemma-7b-it","google/gemma-2b-it", \
+    "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1", \
+    "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
+    "google/flan-t5-xxl"
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

@spaces.GPU
+# Load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
+    # Processing for one document only
+    # loader = PyPDFLoader(file_path)
+    # pages = loader.load()
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
+    # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size = chunk_size,
        chunk_overlap = chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
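    # Note (not part of this commit): chunk sizes here are counted in characters, so with the UI
    # defaults further below (chunk_size=600, chunk_overlap=40) each page is split into roughly
    # 600-character chunks that overlap by 40 characters.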

+
# Create vector database
def create_db(splits, collection_name):
+    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
+        # persist_directory=default_persist_directory
    )
    return vectordb
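    # Note (not part of this commit): HuggingFaceEmbeddings() is called without a model_name here,
    # so it falls back to the library default (sentence-transformers/all-MiniLM-L6-v2 at the time of
    # writing); pass model_name=... to pin a specific embedding model.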

+# Load vector database
+def load_db():
+    embedding = HuggingFaceEmbeddings()
+    vectordb = Chroma(
+        # persist_directory=default_persist_directory,
+        embedding_function=embedding)
+    return vectordb
+

+# Initialize langchain LLM chain
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    progress(0.1, desc="Initializing HF tokenizer...")
+    # HuggingFacePipeline uses local model
+    # Note: it will download model locally...
+    # tokenizer=AutoTokenizer.from_pretrained(llm_model)
+    # progress(0.5, desc="Initializing HF pipeline...")
+    # pipeline=transformers.pipeline(
+    #     "text-generation",
+    #     model=llm_model,
+    #     tokenizer=tokenizer,
+    #     torch_dtype=torch.bfloat16,
+    #     trust_remote_code=True,
+    #     device_map="auto",
+    #     # max_length=1024,
+    #     max_new_tokens=max_tokens,
+    #     do_sample=True,
+    #     top_k=top_k,
+    #     num_return_sequences=1,
+    #     eos_token_id=tokenizer.eos_token_id
+    # )
+    # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})

+    # HuggingFaceHub uses HF inference endpoints
+    progress(0.5, desc="Initializing HF Hub...")
+    # Use of trust_remote_code as model_kwargs
+    # Warning: langchain issue
+    # URL: https://github.com/langchain-ai/langchain/issues/6080
+    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
+            load_in_8bit = True,
+        )
+    elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1","mosaicml/mpt-7b-instruct"]:
+        raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
+        )
+    elif llm_model == "microsoft/phi-2":
+        # raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
+            trust_remote_code = True,
+            torch_dtype = "auto",
+        )
+    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+            temperature = temperature,
+            max_new_tokens = 250,
+            top_k = top_k,
+        )
+    elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
+        raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
+        )
+    else:
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
        )

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(

        output_key='answer',
        return_messages=True
    )
+    # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
+    retriever=vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
+        # combine_docs_chain_kwargs={"prompt": your_prompt})
        return_source_documents=True,
+        #return_generated_question=False,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain
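# Minimal usage sketch (illustrative only, not part of this commit; assumes a local "sample.pdf"
# and a Hugging Face API token available in the environment for the inference endpoint):
#   vector_db = create_db(load_doc(["sample.pdf"], 600, 40), "sample")
#   qa_chain = initialize_llmchain(list_llm[0], 0.7, 1024, 3, vector_db)
#   result = qa_chain({"question": "What is this document about?", "chat_history": []})
#   print(result["answer"])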

+
+# Generate collection name for vector database
+#  - Use filepath as input, ensuring unicode text
def create_collection_name(filepath):
+    # Extract filename without extension
    collection_name = Path(filepath).stem
+    # Fix potential issues from naming convention
+    ## Remove space
    collection_name = collection_name.replace(" ","-")
+    ## ASCII transliterations of Unicode text
    collection_name = unidecode(collection_name)
+    ## Remove special characters
+    #collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+    ## Limit length to 50 characters
    collection_name = collection_name[:50]
+    ## Minimum length of 3 characters
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
+    ## Enforce start and end as alphanumeric character
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
+    print('Filepath: ', filepath)
+    print('Collection name: ', collection_name)
    return collection_name
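    # For instance (illustrative, not part of this commit), a file named "My Report (2024) é.pdf"
    # would yield the collection name "My-Report-2024-e".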

+
+# Initialize database
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
+    # Create list of documents (when valid)
    list_file_path = [x.name for x in list_file_obj if x is not None]
+    # Create collection_name for vector database
    progress(0.1, desc="Creating collection name...")
    collection_name = create_collection_name(list_file_path[0])
    progress(0.25, desc="Loading document...")
+    # Load document and create splits
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+    # Create or load vector database
    progress(0.5, desc="Generating vector database...")
+    # global vector_db
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"

+
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    # print("llm_option",llm_option)
    llm_name = list_llm[llm_option]
+    print("llm_name: ",llm_name)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"

+
def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
+

def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
+    #print("formatted_chat_history",formatted_chat_history)
+
+    # Generate response using QA chain
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if response_answer.find("Helpful Answer:") != -1:

    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source3 = response_sources[2].page_content.strip()
+    # Langchain sources are zero-based
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1
+    # print ('chat response: ', response_answer)
+    # print('DB source', response_sources)

+    # Append user message and response to chat history
    new_history = history + [(message, response_answer)]
+    # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
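    # Note (not part of this commit): the 9-tuple above lines up with the outputs wired in demo():
    # the chain state, the cleared message box, the updated chat history, then three source
    # passages each followed by its (1-based) page number.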
+
+
+def upload_file(file_obj):
+    list_file_path = []
+    for idx, file in enumerate(file_obj):
+        file_path = file_obj.name
+        list_file_path.append(file_path)
+        # print(file_path)
+    # initialize_database(file_path, progress)
+    return list_file_path
+

def demo():
    with gr.Blocks(theme="base") as demo:

        collection_name = gr.State()

        gr.Markdown(
+        """<center><h2>PDF-based chatbot</center></h2>
        <h3>Ask any questions about your PDF documents</h3>""")
        gr.Markdown(
+        """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
+        The user interface explicitly shows multiple steps to help understand the RAG workflow.
+        This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
+        <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
+        """)

        with gr.Tab("Step 1 - Upload PDF"):
+            with gr.Row():
+                document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+            # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)

        with gr.Tab("Step 2 - Process document"):
+            with gr.Row():
+                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
            with gr.Accordion("Advanced options - Document text splitter", open=False):
+                with gr.Row():
+                    slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
+                with gr.Row():
+                    slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
+            with gr.Row():
+                db_progress = gr.Textbox(label="Vector database initialization", value="None")
+            with gr.Row():
+                db_btn = gr.Button("Generate vector database")

        with gr.Tab("Step 3 - Initialize QA chain"):
+            with gr.Row():
+                llm_btn = gr.Radio(list_llm_simple, \
+                    label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
            with gr.Accordion("Advanced options - LLM model", open=False):
+                with gr.Row():
+                    slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
+                with gr.Row():
+                    slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
+                with gr.Row():
+                    slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
+            with gr.Row():
+                llm_progress = gr.Textbox(value="None",label="QA chain initialization")
+            with gr.Row():
+                qachain_btn = gr.Button("Initialize Question Answering chain")

        with gr.Tab("Step 4 - Chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Advanced - Document references", open=False):
+                with gr.Row():
+                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                    source1_page = gr.Number(label="Page", scale=1)
+                with gr.Row():
+                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                    source2_page = gr.Number(label="Page", scale=1)
+                with gr.Row():
+                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                    source3_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
+            with gr.Row():
+                submit_btn = gr.Button("Submit message")
+                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

        # Preprocessing events
+        #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
+        db_btn.click(initialize_database, \
+            inputs=[document, slider_chunk_size, slider_chunk_overlap], \
            outputs=[vector_db, collection_name, db_progress])
+        qachain_btn.click(initialize_LLM, \
+            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
+            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
+            inputs=None, \
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)

        # Chatbot events
+        msg.submit(conversation, \
+            inputs=[qa_chain, msg, chatbot], \
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
+        submit_btn.click(conversation, \
+            inputs=[qa_chain, msg, chatbot], \
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
+        clear_btn.click(lambda:[None,"",0,"",0,"",0], \
+            inputs=None, \
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
    demo.queue().launch(debug=True)

+
if __name__ == "__main__":
+    demo()