Sean-Case commited on
Commit
6a76923
·
1 Parent(s): d53332d

Added temperature slider, more stringent checks for document relevance

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. chatfuncs/chatfuncs.py +35 -57
app.py CHANGED
@@ -237,6 +237,7 @@ with block:
237
 
238
  with gr.Tab("Advanced features"):
239
  out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
 
240
  with gr.Row():
241
  model_choice = gr.Radio(label="Choose a chat model", value="Flan Alpaca (small, fast)", choices = ["Flan Alpaca (small, fast)", "Mistral Open Orca (larger, slow)"])
242
  change_model_button = gr.Button(value="Load model", scale=0)
@@ -281,14 +282,14 @@ with block:
281
  # Click/enter to send message action
282
  response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False, api_name="retrieval").\
283
  then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
284
- then(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state], outputs=chatbot)
285
  response_click.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
286
  then(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
287
  then(lambda: chatf.restore_interactivity(), None, [message], queue=False)
288
 
289
  response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False).\
290
  then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
291
- then(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state], chatbot)
292
  response_enter.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
293
  then(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
294
  then(lambda: chatf.restore_interactivity(), None, [message], queue=False)
 
237
 
238
  with gr.Tab("Advanced features"):
239
  out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
240
+ temp_slide = gr.Slider(minimum=0.1, value = 0.1, maximum=1, step=0.1, label="Choose temperature setting for response generation.")
241
  with gr.Row():
242
  model_choice = gr.Radio(label="Choose a chat model", value="Flan Alpaca (small, fast)", choices = ["Flan Alpaca (small, fast)", "Mistral Open Orca (larger, slow)"])
243
  change_model_button = gr.Button(value="Load model", scale=0)
 
282
  # Click/enter to send message action
283
  response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False, api_name="retrieval").\
284
  then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
285
+ then(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide], outputs=chatbot)
286
  response_click.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
287
  then(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
288
  then(lambda: chatf.restore_interactivity(), None, [message], queue=False)
289
 
290
  response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages], outputs=[chat_history_state, sources, instruction_prompt_out], queue=False).\
291
  then(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
292
+ then(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide], chatbot)
293
  response_enter.then(chatf.highlight_found_text, [chatbot, sources], [sources]).\
294
  then(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
295
  then(lambda: chatf.restore_interactivity(), None, [message], queue=False)
chatfuncs/chatfuncs.py CHANGED
@@ -158,6 +158,9 @@ class CtransGenGenerationConfig:
158
  self.batch_size = batch_size
159
  self.reset = reset
160
 
 
 
 
161
  # Vectorstore funcs
162
 
163
  def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
@@ -220,23 +223,6 @@ QUESTION: {question}
220
 
221
  Response:"""
222
 
223
- instruction_prompt_template_openllama = """Answer the QUESTION using information from the following CONTENT.
224
- QUESTION - {question}
225
- CONTENT - {summaries}
226
- Answer:"""
227
-
228
- instruction_prompt_template_platypus = """### Instruction:
229
- Answer the QUESTION using information from the following CONTENT.
230
- CONTENT: {summaries}
231
- QUESTION: {question}
232
- ### Response:"""
233
-
234
- instruction_prompt_template_wizard_orca_quote = """### HUMAN:
235
- Quote text from the CONTENT to answer the QUESTION below.
236
- CONTENT - {summaries}
237
- QUESTION - {question}
238
- ### RESPONSE:
239
- """
240
 
241
  instruction_prompt_template_wizard_orca = """### HUMAN:
242
  Answer the QUESTION below based on the CONTENT. Only refer to CONTENT that directly answers the question.
@@ -266,15 +252,6 @@ CONTENT: {summaries}
266
  ### Response:
267
  """
268
 
269
- instruction_prompt_template_orca_rev = """
270
- ### System:
271
- You are an AI assistant that follows instruction extremely well. Help as much as you can.
272
- ### User:
273
- Answer the QUESTION with a short response using information from the following CONTENT.
274
- QUESTION: {question}
275
- CONTENT: {summaries}
276
-
277
- ### Response:"""
278
 
279
  instruction_prompt_mistral_orca = """<|im_start|>system\n
280
  You are an AI assistant that follows instruction extremely well. Help as much as you can.
@@ -284,23 +261,6 @@ CONTENT: {summaries}
284
  QUESTION: {question}\n
285
  Answer:<|im_end|>"""
286
 
287
- instruction_prompt_tinyllama_orca = """<|im_start|>system\n
288
- You are an AI assistant that follows instruction extremely well. Help as much as you can.
289
- <|im_start|>user\n
290
- Answer the QUESTION using information from the following CONTENT. Only quote text that directly answers the question and nothing more. If you can't find an answer to the question, respond with "Sorry, I can't find an answer to that question.".
291
- CONTENT: {summaries}
292
- QUESTION: {question}\n
293
- Answer:<|im_end|>"""
294
-
295
- instruction_prompt_marx = """
296
- ### HUMAN:
297
- Answer the QUESTION using information from the following CONTENT.
298
- CONTENT: {summaries}
299
- QUESTION: {question}
300
-
301
- ### RESPONSE:
302
- """
303
-
304
  if model_type == "Flan Alpaca (small, fast)":
305
  INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_template_alpaca, input_variables=['question', 'summaries'])
306
  elif model_type == "Mistral Open Orca (larger, slow)":
@@ -322,9 +282,16 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
322
 
323
 
324
  docs_keep_as_doc, doc_df, docs_keep_out = hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val = 25, out_passages = out_passages,
325
- vec_score_cut_off = 1, vec_weight = 1, bm25_weight = 1, svm_weight = 1)#,
326
  #vectorstore=globals()["vectorstore"], embeddings=globals()["embeddings"])
327
 
 
 
 
 
 
 
 
328
  # Expand the found passages to the neighbouring context
329
  file_type = determine_file_type(doc_df['meta_url'][0])
330
 
@@ -332,8 +299,6 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
332
  if (file_type != ".csv") & (file_type != ".xlsx"):
333
  docs_keep_as_doc, doc_df = get_expanded_passages(vectorstore, docs_keep_out, width=3)
334
 
335
- if docs_keep_as_doc == []:
336
- {"answer": "I'm sorry, I couldn't find a relevant answer to this question.", "sources":"I'm sorry, I couldn't find a relevant source for this question."}
337
 
338
 
339
  # Build up sources content to add to user display
@@ -380,11 +345,21 @@ def create_full_prompt(user_input, history, extracted_memory, vectorstore, embed
380
 
381
  print("Output history is:")
382
  print(history)
 
 
 
383
 
384
  return history, docs_content_string, instruction_prompt_out
385
 
386
  # Chat functions
387
- def produce_streaming_answer_chatbot(history, full_prompt, model_type):
 
 
 
 
 
 
 
388
  #print("Model type is: ", model_type)
389
 
390
  #if not full_prompt.strip():
@@ -410,6 +385,9 @@ def produce_streaming_answer_chatbot(history, full_prompt, model_type):
410
  temperature=temperature,
411
  top_k=top_k
412
  )
 
 
 
413
  t = Thread(target=model.generate, kwargs=generate_kwargs)
414
  t.start()
415
 
@@ -437,6 +415,7 @@ def produce_streaming_answer_chatbot(history, full_prompt, model_type):
437
  tokens = model.tokenize(full_prompt)
438
 
439
  gen_config = CtransGenGenerationConfig()
 
440
 
441
  print(vars(gen_config))
442
 
@@ -502,6 +481,8 @@ def create_doc_df(docs_keep_out):
502
  page_section=[]
503
  score=[]
504
 
 
 
505
 
506
 
507
  for item in docs_keep_out:
@@ -530,6 +511,7 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
530
 
531
  #vectorstore=globals()["vectorstore"]
532
  #embeddings=globals()["embeddings"]
 
533
 
534
 
535
  docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
@@ -545,21 +527,15 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
545
  score_more_limit = pd.Series(docs_scores) < vec_score_cut_off
546
  docs_keep = list(compress(docs, score_more_limit))
547
 
548
- if docs_keep == []:
549
- docs_keep_as_doc = []
550
- docs_content = []
551
- docs_url = []
552
- return docs_keep_as_doc, docs_content, docs_url
553
 
554
  # Only keep sources that are at least 100 characters long
555
  length_more_limit = pd.Series(docs_len) >= 100
556
  docs_keep = list(compress(docs_keep, length_more_limit))
557
 
558
- if docs_keep == []:
559
- docs_keep_as_doc = []
560
- docs_content = []
561
- docs_url = []
562
- return docs_keep_as_doc, docs_content, docs_url
563
 
564
  docs_keep_as_doc = [x[0] for x in docs_keep]
565
  docs_keep_length = len(docs_keep_as_doc)
@@ -763,6 +739,8 @@ def get_expanded_passages(vectorstore, docs, width):
763
  expanded_doc = (Document(page_content=content_str[0], metadata=meta_full[0]), score)
764
  expanded_docs.append(expanded_doc)
765
 
 
 
766
  doc_df = create_doc_df(expanded_docs) # Assuming you've defined the 'create_doc_df' function elsewhere
767
 
768
  return expanded_docs, doc_df
 
158
  self.batch_size = batch_size
159
  self.reset = reset
160
 
161
+ def update_temp(self, new_value):
162
+ self.temperature = new_value
163
+
164
  # Vectorstore funcs
165
 
166
  def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
 
223
 
224
  Response:"""
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  instruction_prompt_template_wizard_orca = """### HUMAN:
228
  Answer the QUESTION below based on the CONTENT. Only refer to CONTENT that directly answers the question.
 
252
  ### Response:
253
  """
254
 
 
 
 
 
 
 
 
 
 
255
 
256
  instruction_prompt_mistral_orca = """<|im_start|>system\n
257
  You are an AI assistant that follows instruction extremely well. Help as much as you can.
 
261
  QUESTION: {question}\n
262
  Answer:<|im_end|>"""
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  if model_type == "Flan Alpaca (small, fast)":
265
  INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_template_alpaca, input_variables=['question', 'summaries'])
266
  elif model_type == "Mistral Open Orca (larger, slow)":
 
282
 
283
 
284
  docs_keep_as_doc, doc_df, docs_keep_out = hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val = 25, out_passages = out_passages,
285
+ vec_score_cut_off = 0.85, vec_weight = 1, bm25_weight = 1, svm_weight = 1)#,
286
  #vectorstore=globals()["vectorstore"], embeddings=globals()["embeddings"])
287
 
288
+ #print(docs_keep_as_doc)
289
+ #print(doc_df)
290
+ if (not docs_keep_as_doc) | (doc_df.empty):
291
+ sorry_prompt = """Say 'Sorry, there is no relevant information to answer this question.'.
292
+ RESPONSE:"""
293
+ return sorry_prompt, "No relevant sources found.", new_question_kworded
294
+
295
  # Expand the found passages to the neighbouring context
296
  file_type = determine_file_type(doc_df['meta_url'][0])
297
 
 
299
  if (file_type != ".csv") & (file_type != ".xlsx"):
300
  docs_keep_as_doc, doc_df = get_expanded_passages(vectorstore, docs_keep_out, width=3)
301
 
 
 
302
 
303
 
304
  # Build up sources content to add to user display
 
345
 
346
  print("Output history is:")
347
  print(history)
348
+
349
+ print("Final prompt to model is:")
350
+ print(instruction_prompt_out)
351
 
352
  return history, docs_content_string, instruction_prompt_out
353
 
354
  # Chat functions
355
+ def produce_streaming_answer_chatbot(history, full_prompt, model_type,
356
+ temperature=temperature,
357
+ max_new_tokens=max_new_tokens,
358
+ sample=sample,
359
+ repetition_penalty=repetition_penalty,
360
+ top_p=top_p,
361
+ top_k=top_k
362
+ ):
363
  #print("Model type is: ", model_type)
364
 
365
  #if not full_prompt.strip():
 
385
  temperature=temperature,
386
  top_k=top_k
387
  )
388
+
389
+ print(generate_kwargs)
390
+
391
  t = Thread(target=model.generate, kwargs=generate_kwargs)
392
  t.start()
393
 
 
415
  tokens = model.tokenize(full_prompt)
416
 
417
  gen_config = CtransGenGenerationConfig()
418
+ gen_config.update_temp(temperature)
419
 
420
  print(vars(gen_config))
421
 
 
481
  page_section=[]
482
  score=[]
483
 
484
+ doc_df = pd.DataFrame()
485
+
486
 
487
 
488
  for item in docs_keep_out:
 
511
 
512
  #vectorstore=globals()["vectorstore"]
513
  #embeddings=globals()["embeddings"]
514
+ doc_df = pd.DataFrame()
515
 
516
 
517
  docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
 
527
  score_more_limit = pd.Series(docs_scores) < vec_score_cut_off
528
  docs_keep = list(compress(docs, score_more_limit))
529
 
530
+ if not docs_keep:
531
+ return [], pd.DataFrame(), []
 
 
 
532
 
533
  # Only keep sources that are at least 100 characters long
534
  length_more_limit = pd.Series(docs_len) >= 100
535
  docs_keep = list(compress(docs_keep, length_more_limit))
536
 
537
+ if not docs_keep:
538
+ return [], pd.DataFrame(), []
 
 
 
539
 
540
  docs_keep_as_doc = [x[0] for x in docs_keep]
541
  docs_keep_length = len(docs_keep_as_doc)
 
739
  expanded_doc = (Document(page_content=content_str[0], metadata=meta_full[0]), score)
740
  expanded_docs.append(expanded_doc)
741
 
742
+ doc_df = pd.DataFrame()
743
+
744
  doc_df = create_doc_df(expanded_docs) # Assuming you've defined the 'create_doc_df' function elsewhere
745
 
746
  return expanded_docs, doc_df