Fecalisboa committed
Commit 289ac0c
1 Parent(s): 6d330c9

Update app.py

Files changed (1)
  1. app.py +114 -160
app.py CHANGED
@@ -49,14 +49,7 @@ from llama_index.core.schema import BaseNode, TextNode
  api_token = os.getenv("HF_TOKEN")


- # default_persist_directory = './chroma_HF/'
- list_llm = ["mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1",
-     "google/gemma-7b-it","google/gemma-2b-it",
-     "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
-     "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
-     "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct",
-     "google/flan-t5-xxl"
- ]
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]

  # Load PDF document and create doc splits
@@ -94,55 +87,18 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr

      progress(0.5, desc="Initializing HF Hub...")

-     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.3":
-         llm = HuggingFaceEndpoint(
-             repo_id=llm_model,
-             api_key=api_token,
-             temperature = temperature,
-             max_new_tokens = max_tokens,
-             top_k = top_k,
-             load_in_8bit = True,
-         )
-     elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1","mosaicml/mpt-7b-instruct"]:
-         raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
-         llm = HuggingFaceEndpoint(
-             repo_id=llm_model,
-             api_key=api_token,
-             temperature = temperature,
-             max_new_tokens = max_tokens,
-             top_k = top_k,
-         )
-     elif llm_model == "microsoft/phi-2":
-         llm = HuggingFaceEndpoint(
-             repo_id=llm_model,
-             api_key=api_token,
-             temperature = temperature,
-             max_new_tokens = max_tokens,
-             top_k = top_k,
-             trust_remote_code = True,
-             torch_dtype = "auto",
-         )
-     elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
          llm = HuggingFaceEndpoint(
              repo_id=llm_model,
-             api_key=api_token,
-             temperature = temperature,
-             max_new_tokens = 250,
-             top_k = top_k,
-         )
-     elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
-         raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
-         llm = HuggingFaceEndpoint(
-             repo_id=llm_model,
-             api_key=api_token,
              temperature = temperature,
              max_new_tokens = max_tokens,
              top_k = top_k,
          )
      else:
          llm = HuggingFaceEndpoint(
              repo_id=llm_model,
-             api_key=api_token,
              temperature = temperature,
              max_new_tokens = max_tokens,
              top_k = top_k,
@@ -209,6 +165,7 @@ def format_chat_history(message, chat_history):
          formatted_chat_history.append(f"Assistant: {bot_message}")
      return formatted_chat_history

  def conversation(qa_chain, message, history):
      formatted_chat_history = format_chat_history(message, history)

@@ -234,121 +191,118 @@ def upload_file(file_obj):
      list_file_path.append(file_path)
      return list_file_path

- # Initialize LlamaIndex parsing
- def initialize_llama_index(file_obj):
-     documents = LlamaParse(result_type="markdown", api_key=api_token).load_data(file_obj[0].name)
-     node_parser = MarkdownElementNodeParser(llm=None, num_workers=8)
-     nodes = node_parser.get_nodes_from_documents(documents)
-     base_nodes, objects = node_parser.get_nodes_and_objects(nodes)
-     index_with_obj = VectorStoreIndex(nodes=base_nodes + objects)
-     index_ret = index_with_obj.as_retriever(top_k=15)
-     recursive_query_engine = RetrieverQueryEngine.from_args(index_ret, node_postprocessors=[FlagEmbeddingReranker(
-         top_n=5,
-         model="BAAI/bge-reranker-large",
-     )], verbose=False)
-     return recursive_query_engine, "LlamaIndex parsing complete"

- def demo():
-     with gr.Blocks(theme="base") as demo:
-         vector_db = gr.State()
-         qa_chain = gr.State()
-         collection_name = gr.State()
-         llama_index_engine = gr.State()
-
-         gr.Markdown(
-             """<center><h2>PDF-based chatbot</center></h2>
-             <h3>Ask any questions about your PDF documents</h3>""")
-         gr.Markdown(
-             """<b>Note:</b> Esta é a lucIAna, primeira Versão da IA para seus PDF documentos. \
-             Este chatbot leva em consideração perguntas anteriores ao gerar respostas (por meio de memória conversacional) e inclui referências a documentos para fins de clareza.
-             """)
-
-         with gr.Tab("Step 1 - Upload PDF"):
-             with gr.Row():
-                 document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-
-         with gr.Tab("Step 2 - Process document"):
-             with gr.Row():
-                 db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")
-             with gr.Accordion("Advanced options - Document text splitter", open=False):
-                 with gr.Row():
-                     slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
-                 with gr.Row():
-                     slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
-             with gr.Row():
-                 db_progress = gr.Textbox(label="Vector database initialization", value="None")
-             with gr.Row():
-                 db_btn = gr.Button("Generate vector database")
-
-         with gr.Tab("Step 3 - Initialize QA chain"):
-             with gr.Row():
-                 llm_btn = gr.Radio(list_llm_simple,
-                     label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model")
-             with gr.Accordion("Advanced options - LLM model", open=False):
-                 with gr.Row():
-                     slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
-                 with gr.Row():
-                     slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
-                 with gr.Row():
-                     slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
-             with gr.Row():
-                 llm_progress = gr.Textbox(value="None", label="QA chain initialization")
-             with gr.Row():
-                 qachain_btn = gr.Button("Initialize Question Answering chain")

-         with gr.Tab("Step 4 - LlamaIndex parsing"):
-             with gr.Row():
-                 llama_index_btn = gr.Button("Parse with LlamaIndex")
-             with gr.Row():
-                 llama_index_progress = gr.Textbox(label="LlamaIndex parsing status", value="None")

-         with gr.Tab("Step 5 - Chatbot"):
-             chatbot = gr.Chatbot(height=300)
-             with gr.Accordion("Advanced - Document references", open=False):
-                 with gr.Row():
-                     doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                     source1_page = gr.Number(label="Page", scale=1)
-                 with gr.Row():
-                     doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                     source2_page = gr.Number(label="Page", scale=1)
-                 with gr.Row():
-                     doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                     source3_page = gr.Number(label="Page", scale=1)
-             with gr.Row():
-                 msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
-             with gr.Row():
-                 submit_btn = gr.Button("Submit message")
-                 clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
-
-         # Preprocessing events
-         db_btn.click(initialize_database,
-             inputs=[document, slider_chunk_size, slider_chunk_overlap],
-             outputs=[vector_db, collection_name, db_progress])
-         qachain_btn.click(initialize_LLM,
-             inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
-             outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
-             inputs=None,
-             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False)
-         llama_index_btn.click(initialize_llama_index,
-             inputs=[document],
-             outputs=[llama_index_engine, llama_index_progress])

-         # Chatbot events
-         msg.submit(conversation,
-             inputs=[qa_chain, msg, chatbot],
-             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False)
-         submit_btn.click(conversation,
-             inputs=[qa_chain, msg, chatbot],
-             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False)
-         clear_btn.click(lambda:[None,"",0,"",0,"",0],
-             inputs=None,
-             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False)
-     demo.queue().launch(debug=True)


- if __name__ == "__main__":
-     demo()

  api_token = os.getenv("HF_TOKEN")


+ list_llm = ["mistralai/Miceli", "mistralai/Mistral-7B-Instruct-v0.3"]
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]

  # Load PDF document and create doc splits
 

      progress(0.5, desc="Initializing HF Hub...")

+     if llm_model == "mistralai/Mistral-7B-Instruct-v0.2":
          llm = HuggingFaceEndpoint(
              repo_id=llm_model,
+             huggingfacehub_api_token = api_token,
              temperature = temperature,
              max_new_tokens = max_tokens,
              top_k = top_k,
          )
      else:
          llm = HuggingFaceEndpoint(
+             huggingfacehub_api_token = api_token,
              repo_id=llm_model,
              temperature = temperature,
              max_new_tokens = max_tokens,
              top_k = top_k,
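Aside from collapsing the old if/elif ladder into a single check for "mistralai/Mistral-7B-Instruct-v0.2", the functional change in this hunk is the authentication keyword: `HuggingFaceEndpoint` now receives the token as `huggingfacehub_api_token` instead of `api_key`. A minimal sketch of the new call shape outside the app (the import path depends on the installed langchain version, and the model id here is only illustrative, not taken from the commit):

```python
import os
# May also live in langchain_huggingface, depending on the installed version.
from langchain_community.llms import HuggingFaceEndpoint

# Same keyword the commit switches to; HF_TOKEN must be set in the environment.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    temperature=0.7,
    max_new_tokens=1024,
    top_k=3,
)
print(llm.invoke("Say hello in one short sentence."))
```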
 
          formatted_chat_history.append(f"Assistant: {bot_message}")
      return formatted_chat_history

+ 
  def conversation(qa_chain, message, history):
      formatted_chat_history = format_chat_history(message, history)

 
      list_file_path.append(file_path)
      return list_file_path

+ list_llm = ["mistralai/Miceli", "mistralai/Mistral-7B-Instruct-v0.3"]
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+ 
+ # Load PDF document and create doc splits
+ def load_doc(list_file_path, chunk_size, chunk_overlap):
+     loaders = [PyPDFLoader(x) for x in list_file_path]
+     pages = []
+     for loader in loaders:
+         pages.extend(loader.load())
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size = chunk_size, chunk_overlap = chunk_overlap)
+     doc_splits = text_splitter.split_documents(pages)
+     return doc_splits
+ 
+ # Create vector database
+ def create_db(splits, collection_name):
+     embedding = HuggingFaceEmbeddings()
+     new_client = chromadb.EphemeralClient()
+     vectordb = Chroma.from_documents(
+         documents=splits,
+         embedding=embedding,
+         client=new_client,
+         collection_name=collection_name,
+     )
+     return vectordb

+ # Load vector database
+ def load_db():
+     embedding = HuggingFaceEmbeddings()
+     vectordb = Chroma(
+         embedding_function=embedding)
+     return vectordb

+ # Initialize langchain LLM chain
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+     progress(0.1, desc="Initializing HF tokenizer...")
+ 
+     progress(0.5, desc="Initializing HF Hub...")
+ 
+     if llm_model == "mistralai/Mistral-7B-Instruct-v0.2":
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             huggingfacehub_api_token = api_token,
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+         )
+     else:
+         llm = HuggingFaceEndpoint(
+             huggingfacehub_api_token = api_token,
+             repo_id=llm_model,
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+         )
+ 
+     progress(0.75, desc="Defining buffer memory...")
+     memory = ConversationBufferMemory(
+         memory_key="chat_history",
+         output_key='answer',
+         return_messages=True
+     )
+     retriever=vector_db.as_retriever()
+     progress(0.8, desc="Defining retrieval chain...")
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm,
+         retriever=retriever,
+         chain_type="stuff",
+         memory=memory,
+         return_source_documents=True,
+         verbose=False,
+     )
+     progress(0.9, desc="Done!")
+     return qa_chain

+ # Generate collection name for vector database
+ def create_collection_name(filepath):
+     collection_name = Path(filepath).stem
+     collection_name = collection_name.replace(" ","-")
+     collection_name = unidecode(collection_name)
+     collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+     collection_name = collection_name[:50]
+     if len(collection_name) < 3:
+         collection_name = collection_name + 'xyz'
+     if not collection_name[0].isalnum():
+         collection_name = 'A' + collection_name[1:]
+     if not collection_name[-1].isalnum():
+         collection_name = collection_name[:-1] + 'Z'
+     print('Filepath: ', filepath)
+     print('Collection name: ', collection_name)
+     return collection_name

+ # Initialize database
+ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
+     list_file_path = [x.name for x in list_file_obj if x is not None]
+     progress(0.1, desc="Creating collection name...")
+     collection_name = create_collection_name(list_file_path[0])
+     progress(0.25, desc="Loading document...")
+     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+     progress(0.5, desc="Generating vector database...")
+     vector_db = create_db(doc_splits, collection_name)
+     progress(0.9, desc="Done!")
+     return vector_db, collection_name, "Complete!"

+ def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+     llm_name = list_llm[llm_option]
+     print("llm_name: ",llm_name)
+     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
+     return qa_chain, "Complete!"

+ def format_chat_history(message, chat_history):
+     formatted_chat_history = []
+     for user_message, bot_message in chat_history:
+         formatted_chat_history.append(f"User: {user_message}")
+         formatted_chat_history.append(f"Assistant: {bot_message}")
+     return formatted_chat_history
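Taken together, the helper functions appended in this hunk re-create the ingestion and QA pipeline: split the PDFs, embed the chunks into Chroma, and wrap a `HuggingFaceEndpoint` LLM in a `ConversationalRetrievalChain` with buffer memory. A rough standalone sketch of that flow follows; the file name, model id, chunk sizes, and import paths are placeholders and assumptions, not part of the commit:

```python
import os

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain

# Mirror load_doc(): load PDF pages and split them (600/40 are the UI defaults).
pages = PyPDFLoader("example.pdf").load()
splits = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=40).split_documents(pages)

# Mirror create_db(): embed the chunks into an in-memory Chroma collection.
vector_db = Chroma.from_documents(splits, HuggingFaceEmbeddings(), collection_name="example-pdf")

# Mirror initialize_llmchain(): endpoint LLM + buffer memory + retrieval chain.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    temperature=0.7, max_new_tokens=1024, top_k=3,
)
memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
qa_chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=vector_db.as_retriever(),
    chain_type="stuff",
    memory=memory,
    return_source_documents=True,
)

# Memory supplies chat_history, so only the question is passed in.
result = qa_chain.invoke({"question": "What is this document about?"})
print(result["answer"])
```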