Fecalisboa committed on
Commit
04a7754
1 Parent(s): 510c455

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -68
app.py CHANGED
@@ -15,7 +15,6 @@ from langchain.chains import ConversationChain
15
  from langchain.memory import ConversationBufferMemory
16
  from langchain_community.llms import HuggingFaceEndpoint
17
  import torch
18
-
19
  api_token = os.getenv("HF_TOKEN")
20
 
21
  list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
@@ -98,26 +97,40 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, in
98
  progress(0.9, desc="Done!")
99
  return qa_chain
100
 
101
- def initialize_llm_no_doc(llm_model, temperature, max_tokens, top_k, initial_prompt, progress=gr.Progress()):
102
- progress(0.1, desc="Initializing HF tokenizer...")
103
- progress(0.5, desc="Initializing HF Hub...")
104
- llm = HuggingFaceEndpoint(
105
- repo_id=llm_model,
106
- huggingfacehub_api_token=api_token,
107
- temperature=temperature,
108
- max_new_tokens=max_tokens,
109
- top_k=top_k,
110
- )
111
- progress(0.75, desc="Defining buffer memory...")
112
- memory = ConversationBufferMemory(
113
- memory_key="chat_history",
114
- output_key='answer',
115
- return_messages=True
116
- )
117
- conversation_chain = ConversationChain(llm=llm, memory=memory, verbose=False)
118
- conversation_chain({"question": initial_prompt})
 
 
 
 
 
 
 
 
119
  progress(0.9, desc="Done!")
120
- return conversation_chain
 
 
 
 
 
 
121
 
122
  def format_chat_history(message, chat_history):
123
  formatted_chat_history = []
@@ -143,6 +156,30 @@ def conversation(qa_chain, message, history):
143
  new_history = history + [(message, response_answer)]
144
  return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def conversation_no_doc(llm, message, history):
147
  formatted_chat_history = format_chat_history(message, history)
148
  response = llm({"question": message, "chat_history": formatted_chat_history})
@@ -156,33 +193,6 @@ def upload_file(file_obj):
156
  list_file_path.append(file.name)
157
  return list_file_path
158
 
159
- def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
160
- list_file_path = [x.name for x in list_file_obj if x is not None]
161
- progress(0.1, desc="Creating collection name...")
162
- collection_name = create_collection_name(list_file_path[0])
163
- progress(0.25, desc="Loading document...")
164
- doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
165
- progress(0.5, desc="Generating vector database...")
166
- vector_db = create_db(doc_splits, collection_name, db_type)
167
- progress(0.9, desc="Done!")
168
- return vector_db, collection_name, "Complete!"
169
-
170
- def create_collection_name(filepath):
171
- collection_name = Path(filepath).stem
172
- collection_name = collection_name.replace(" ", "-")
173
- collection_name = unidecode(collection_name)
174
- collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
175
- collection_name = collection_name[:50]
176
- if len(collection_name) < 3:
177
- collection_name = collection_name + 'xyz'
178
- if not collection_name[0].isalnum():
179
- collection_name = 'A' + collection_name[1:]
180
- if not collection_name[-1].isalnum():
181
- collection_name = collection_name[:-1] + 'Z'
182
- print('Filepath: ', filepath)
183
- print('Collection name: ', collection_name)
184
- return collection_name
185
-
186
  def demo():
187
  with gr.Blocks(theme="base") as demo:
188
  vector_db = gr.State()
@@ -257,20 +267,6 @@ def demo():
257
  clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
258
 
259
  with gr.Tab("Step 6 - Chatbot without document"):
260
- with gr.Row():
261
- llm_no_doc_btn = gr.Radio(list_llm_simple,
262
- label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model for chatbot without document")
263
- with gr.Accordion("Advanced options - LLM model", open=False):
264
- with gr.Row():
265
- slider_temperature_no_doc = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
266
- with gr.Row():
267
- slider_maxtokens_no_doc = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
268
- with gr.Row():
269
- slider_topk_no_doc = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
270
- with gr.Row():
271
- llm_no_doc_progress = gr.Textbox(value="None", label="LLM initialization for chatbot without document")
272
- with gr.Row():
273
- llm_no_doc_init_btn = gr.Button("Initialize LLM for Chatbot without document")
274
  chatbot_no_doc = gr.Chatbot(height=300)
275
  with gr.Row():
276
  msg_no_doc = gr.Textbox(placeholder="Type message to chat with lucIAna", container=True)
@@ -282,10 +278,10 @@ def demo():
282
  db_btn.click(initialize_database,
283
  inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
284
  outputs=[vector_db, collection_name, db_progress])
285
- set_prompt_btn.click(lambda prompt: gr.update(value=prompt),
286
  inputs=prompt_input,
287
  outputs=initial_prompt)
288
- qachain_btn.click(initialize_llmchain,
289
  inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
290
  outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
291
  inputs=None,
@@ -306,11 +302,7 @@ def demo():
306
  outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
307
  queue=False)
308
 
309
- # Initialize LLM without document for conversation
310
- llm_no_doc_init_btn.click(initialize_llm_no_doc,
311
- inputs=[llm_no_doc_btn, slider_temperature_no_doc, slider_maxtokens_no_doc, slider_topk_no_doc, initial_prompt],
312
- outputs=[llm_no_doc, llm_no_doc_progress])
313
-
314
  submit_btn_no_doc.click(conversation_no_doc,
315
  inputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
316
  outputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
@@ -320,6 +312,27 @@ def demo():
320
  outputs=[chatbot_no_doc, msg_no_doc],
321
  queue=False)
322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  demo.queue().launch(debug=True, share=True)
324
 
325
  if __name__ == "__main__":
 
15
  from langchain.memory import ConversationBufferMemory
16
  from langchain_community.llms import HuggingFaceEndpoint
17
  import torch
 
18
  api_token = os.getenv("HF_TOKEN")
19
 
20
  list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
 
97
  progress(0.9, desc="Done!")
98
  return qa_chain
99
 
100
# Generate collection name for vector database
def create_collection_name(filepath):
    """Derive a sanitized vector-store collection name from *filepath*.

    The name is the file stem with spaces dashed, transliterated to
    ASCII, restricted to alphanumerics/dashes, capped at 50 characters,
    padded to at least 3 characters, and forced to start and end with an
    alphanumeric character (typical Chroma-style naming constraints —
    presumably; confirm against the vector store in use).
    """
    name = unidecode(Path(filepath).stem.replace(" ", "-"))
    # Collapse every run of non-alphanumeric characters into a dash,
    # then enforce the maximum length.
    name = re.sub('[^A-Za-z0-9]+', '-', name)[:50]
    # Enforce the minimum length.
    if len(name) < 3:
        name = name + 'xyz'
    # First and last characters must be alphanumeric.
    if not name[0].isalnum():
        name = 'A' + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + 'Z'
    print('Filepath: ', filepath)
    print('Collection name: ', name)
    return name
116
+
117
# Initialize database
def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
    """Build a vector database from the uploaded document files.

    Returns a ``(vector_db, collection_name, status)`` triple, where the
    status string is ``"Complete!"`` on success. Progress updates are
    reported through the Gradio *progress* tracker.
    """
    file_paths = [f.name for f in list_file_obj if f is not None]
    progress(0.1, desc="Creating collection name...")
    # The collection is named after the first uploaded file.
    collection_name = create_collection_name(file_paths[0])
    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(file_paths, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name, db_type)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"
128
+
129
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
    """Resolve the selected model index to its repo id and build the QA chain.

    *llm_option* is an index into the module-level ``list_llm``. Returns
    the initialized chain together with a ``"Complete!"`` status string.
    """
    model_name = list_llm[llm_option]
    print("llm_name: ", model_name)
    chain = initialize_llmchain(model_name, llm_temperature, max_tokens, top_k, vector_db, initial_prompt, progress)
    return chain, "Complete!"
134
 
135
  def format_chat_history(message, chat_history):
136
  formatted_chat_history = []
 
156
  new_history = history + [(message, response_answer)]
157
  return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
158
 
159
def initialize_llm_no_doc(llm_model, temperature, max_tokens, top_k, initial_prompt, progress=gr.Progress()):
    """Create a document-free conversation chain backed by a HF endpoint.

    Parameters:
        llm_model: Hugging Face repo id to serve via the Inference API.
        temperature, max_tokens, top_k: sampling parameters forwarded to
            the endpoint (``max_tokens`` maps to ``max_new_tokens``).
        initial_prompt: first user turn used to prime the conversation.
        progress: Gradio progress tracker.

    Returns:
        A primed ``ConversationChain``.
    """
    progress(0.1, desc="Initializing HF tokenizer...")
    progress(0.5, desc="Initializing HF Hub...")
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    progress(0.75, desc="Defining buffer memory...")
    # BUG FIX: ConversationChain's default prompt exposes the memory under
    # the "history" variable and the user turn under "input". The previous
    # keys (memory_key="chat_history", output_key='answer', invoking with
    # {"question": ...}) do not match that contract and fail LangChain's
    # prompt/memory validation.
    memory = ConversationBufferMemory(
        memory_key="history",
        return_messages=True,
    )
    conversation_chain = ConversationChain(llm=llm, memory=memory, verbose=False)
    # Prime the conversation with the configured initial prompt.
    conversation_chain({"input": initial_prompt})
    # NOTE(review): conversation_no_doc still calls this chain with
    # "question"/"chat_history" keys; it should pass {"input": message} —
    # confirm and align the caller.
    progress(0.9, desc="Done!")
    return conversation_chain
182
+
183
  def conversation_no_doc(llm, message, history):
184
  formatted_chat_history = format_chat_history(message, history)
185
  response = llm({"question": message, "chat_history": formatted_chat_history})
 
193
  list_file_path.append(file.name)
194
  return list_file_path
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  def demo():
197
  with gr.Blocks(theme="base") as demo:
198
  vector_db = gr.State()
 
267
  clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
268
 
269
  with gr.Tab("Step 6 - Chatbot without document"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  chatbot_no_doc = gr.Chatbot(height=300)
271
  with gr.Row():
272
  msg_no_doc = gr.Textbox(placeholder="Type message to chat with lucIAna", container=True)
 
278
  db_btn.click(initialize_database,
279
  inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
280
  outputs=[vector_db, collection_name, db_progress])
281
+ set_prompt_btn.click(lambda prompt: prompt,
282
  inputs=prompt_input,
283
  outputs=initial_prompt)
284
+ qachain_btn.click(initialize_LLM,
285
  inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
286
  outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
287
  inputs=None,
 
302
  outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
303
  queue=False)
304
 
305
+ # Chatbot events without document
 
 
 
 
306
  submit_btn_no_doc.click(conversation_no_doc,
307
  inputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
308
  outputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
 
312
  outputs=[chatbot_no_doc, msg_no_doc],
313
  queue=False)
314
 
315
+ # Initialize LLM without document for conversation
316
+ with gr.Tab("Initialize LLM for Chatbot without document"):
317
+ with gr.Row():
318
+ llm_no_doc_btn = gr.Radio(list_llm_simple,
319
+ label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model for chatbot without document")
320
+ with gr.Accordion("Advanced options - LLM model", open=False):
321
+ with gr.Row():
322
+ slider_temperature_no_doc = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
323
+ with gr.Row():
324
+ slider_maxtokens_no_doc = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
325
+ with gr.Row():
326
+ slider_topk_no_doc = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
327
+ with gr.Row():
328
+ llm_no_doc_progress = gr.Textbox(value="None", label="LLM initialization for chatbot without document")
329
+ with gr.Row():
330
+ llm_no_doc_init_btn = gr.Button("Initialize LLM for Chatbot without document")
331
+
332
+ llm_no_doc_init_btn.click(initialize_llm_no_doc,
333
+ inputs=[llm_no_doc_btn, slider_temperature_no_doc, slider_maxtokens_no_doc, slider_topk_no_doc, initial_prompt],
334
+ outputs=[llm_no_doc, llm_no_doc_progress])
335
+
336
  demo.queue().launch(debug=True, share=True)
337
 
338
  if __name__ == "__main__":