seanpedrickcase committed
Commit c61bb70 · Parent: 49faa78

Added GPT-OSS 20B support. Moved to the llama-cpp-python chat_completion function.
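The switch to the llama-cpp-python chat-completion API changes how local-model responses are parsed (see the tools/dedup_summaries.py and tools/llm_api_call.py hunks below). A minimal sketch of that call and parsing follows, assuming a plain llama_cpp.Llama instance; the model path and parameter values are illustrative defaults taken from tools/config.py, and the app's actual wrapper in tools/llm_funcs.py may differ.

# Minimal sketch (not the app's exact wrapper): load a GGUF model and call the
# chat-completion API that this commit moves to.
from llama_cpp import Llama

llm = Llama(
    model_path="model/gpt_oss/gpt-oss-20b-F16.gguf",  # GPT_OSS_MODEL_FOLDER / GPT_OSS_MODEL_FILE defaults
    n_gpu_layers=-1,   # LLM_MAX_GPU_LAYERS: offload as many layers as possible
    n_ctx=16384,       # LLM_CONTEXT_LENGTH
    seed=42,           # LLM_SEED
)

response = llm.create_chat_completion(
    messages=[
        # For local models the REASONING_SUFFIX ("Reasoning: low") is appended to the system prompt
        {"role": "system", "content": "You are an AI assistant that extracts topics.\nReasoning: low"},
        {"role": "user", "content": "Extract topics from the following response table..."},
    ],
    temperature=0.1,   # LLM_TEMPERATURE
    max_tokens=4096,   # LLM_MAX_NEW_TOKENS
)

# Chat completions return text under ['message']['content'] rather than ['text'],
# which is why the response parsing in tools/dedup_summaries.py and
# tools/llm_api_call.py changes in this commit.
response_text = response["choices"][0]["message"]["content"]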

app.py CHANGED
@@ -109,7 +109,7 @@ with app:
109
  master_unique_topics_df_revised_summaries_state = gr.Dataframe(value=pd.DataFrame(), headers=None, col_count=0, row_count = (0, "dynamic"), label="master_unique_topics_df_revised_summaries_state", visible=False, type="pandas")
110
  summarised_output_df = gr.Dataframe(value=pd.DataFrame(), headers=None, col_count=0, row_count = (0, "dynamic"), label="summarised_output_df", visible=False, type="pandas")
111
  summarised_references_markdown = gr.Markdown("", visible=False)
112
- summarised_outputs_list = gr.Dropdown(value=[], choices=[], visible=False, label="List of summarised outputs", allow_custom_value=True)
113
  latest_summary_completed_num = gr.Number(0, visible=False)
114
 
115
  original_data_file_name_textbox = gr.Textbox(label = "Reference data file name", value="", visible=False)
@@ -147,7 +147,7 @@ with app:
147
  gr.Markdown("""### Choose a tabular data file (xlsx or csv) of open text to extract topics from.""")
148
  with gr.Row():
149
  model_choice = gr.Dropdown(value = default_model_choice, choices = model_full_names, label="LLM model", multiselect=False)
150
- in_api_key = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")
151
 
152
  with gr.Accordion("Upload xlsx or csv file", open = True):
153
  in_data_files = gr.File(height=FILE_INPUT_HEIGHT, label="Choose Excel or csv files", file_count= "multiple", file_types=['.xlsx', '.xls', '.csv', '.parquet', '.csv.gz'])
@@ -308,6 +308,9 @@ with app:
308
  aws_access_key_textbox = gr.Textbox(label="AWS access key", interactive=False, lines=1, type="password")
309
  aws_secret_key_textbox = gr.Textbox(label="AWS secret key", interactive=False, lines=1, type="password")
310
 
 
 
 
311
  # Invisible text box to hold the session hash/username just for logging purposes
312
  session_hash_textbox = gr.Textbox(label = "Session hash", value="", visible=False)
313
 
@@ -315,19 +318,6 @@ with app:
315
  total_number_of_batches = gr.Number(label = "Current batch number", value = 1, precision=0, visible=False)
316
 
317
  text_output_logs = gr.Textbox(label = "Output summary logs", visible=False)
318
-
319
- # AWS options - not yet implemented
320
- # with gr.Tab(label="Advanced options"):
321
- # with gr.Accordion(label = "AWS data access", open = True):
322
- # aws_password_box = gr.Textbox(label="Password for AWS data access (ask the Data team if you don't have this)")
323
- # with gr.Row():
324
- # in_aws_file = gr.Dropdown(label="Choose file to load from AWS (only valid for API Gateway app)", choices=["None", "Lambeth borough plan"])
325
- # load_aws_data_button = gr.Button(value="Load data from AWS", variant="secondary")
326
-
327
- # aws_log_box = gr.Textbox(label="AWS data load status")
328
-
329
- # ### Loading AWS data ###
330
- # load_aws_data_button.click(fn=load_data_from_aws, inputs=[in_aws_file, aws_password_box], outputs=[in_file, aws_log_box])
331
 
332
  ###
333
  # INTERACTIVE ELEMENT FUNCTIONS
@@ -364,7 +354,7 @@ with app:
364
  display_topic_table_markdown,
365
  original_data_file_name_textbox,
366
  total_number_of_batches,
367
- in_api_key,
368
  temperature_slide,
369
  in_colnames,
370
  model_choice,
@@ -411,7 +401,7 @@ with app:
411
  output_tokens_num,
412
  number_of_calls_num],
413
  api_name="extract_topics")
414
-
415
  ###
416
  # DEDUPLICATION AND SUMMARISATION FUNCTIONS
417
  ###
@@ -430,14 +420,14 @@ with app:
430
  success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
431
  success(load_in_previous_data_files, inputs=[summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
432
  success(sample_reference_table_summaries, inputs=[master_reference_df_state, random_seed], outputs=[summary_reference_table_sample_state, summarised_references_markdown], api_name="sample_summaries").\
433
- success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, in_api_key, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox, aws_access_key_textbox, aws_secret_key_textbox], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number], api_name="summarise_topics")
434
 
435
- # latest_summary_completed_num.change(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, in_api_key, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num], scroll_to_output=True)
436
 
437
  # SUMMARISE WHOLE TABLE PAGE
438
  overall_summarise_previous_data_btn.click(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
439
  success(load_in_previous_data_files, inputs=[overall_summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
440
- success(overall_summary, inputs=[master_unique_topics_df_state, model_choice, in_api_key, temperature_slide, unique_topics_table_file_name_textbox, output_folder_state, in_colnames, context_textbox, aws_access_key_textbox, aws_secret_key_textbox], outputs=[overall_summary_output_files, overall_summarised_output_markdown, summarised_output_df, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number], scroll_to_output=True, api_name="overall_summary")
441
 
442
  ###
443
  # CONTINUE PREVIOUS TOPIC EXTRACTION PAGE
@@ -512,7 +502,13 @@ with app:
512
  usage_callback.setup([session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox, input_tokens_num,
513
  output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], USAGE_LOGS_FOLDER)
514
 
515
- conversation_metadata_textbox.change(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False, api_name="usage_logs").\
 
 
 
 
 
 
516
  success(fn = upload_file_to_s3, inputs=[usage_logs_state, usage_s3_logs_loc_state, s3_log_bucket_name, aws_access_key_textbox, aws_secret_key_textbox], outputs=[s3_logs_output_textbox])
517
 
518
  # User submitted feedback
 
109
  master_unique_topics_df_revised_summaries_state = gr.Dataframe(value=pd.DataFrame(), headers=None, col_count=0, row_count = (0, "dynamic"), label="master_unique_topics_df_revised_summaries_state", visible=False, type="pandas")
110
  summarised_output_df = gr.Dataframe(value=pd.DataFrame(), headers=None, col_count=0, row_count = (0, "dynamic"), label="summarised_output_df", visible=False, type="pandas")
111
  summarised_references_markdown = gr.Markdown("", visible=False)
112
+ summarised_outputs_list = gr.Dropdown(value= list(), choices= list(), visible=False, label="List of summarised outputs", allow_custom_value=True)
113
  latest_summary_completed_num = gr.Number(0, visible=False)
114
 
115
  original_data_file_name_textbox = gr.Textbox(label = "Reference data file name", value="", visible=False)
 
147
  gr.Markdown("""### Choose a tabular data file (xlsx or csv) of open text to extract topics from.""")
148
  with gr.Row():
149
  model_choice = gr.Dropdown(value = default_model_choice, choices = model_full_names, label="LLM model", multiselect=False)
150
+
151
 
152
  with gr.Accordion("Upload xlsx or csv file", open = True):
153
  in_data_files = gr.File(height=FILE_INPUT_HEIGHT, label="Choose Excel or csv files", file_count= "multiple", file_types=['.xlsx', '.xls', '.csv', '.parquet', '.csv.gz'])
 
308
  aws_access_key_textbox = gr.Textbox(label="AWS access key", interactive=False, lines=1, type="password")
309
  aws_secret_key_textbox = gr.Textbox(label="AWS secret key", interactive=False, lines=1, type="password")
310
 
311
+ with gr.Accordion("Enter Gemini API keys", open = False):
312
+ google_api_key_textbox = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")
313
+
314
  # Invisible text box to hold the session hash/username just for logging purposes
315
  session_hash_textbox = gr.Textbox(label = "Session hash", value="", visible=False)
316
 
 
318
  total_number_of_batches = gr.Number(label = "Current batch number", value = 1, precision=0, visible=False)
319
 
320
  text_output_logs = gr.Textbox(label = "Output summary logs", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  ###
323
  # INTERACTIVE ELEMENT FUNCTIONS
 
354
  display_topic_table_markdown,
355
  original_data_file_name_textbox,
356
  total_number_of_batches,
357
+ google_api_key_textbox,
358
  temperature_slide,
359
  in_colnames,
360
  model_choice,
 
401
  output_tokens_num,
402
  number_of_calls_num],
403
  api_name="extract_topics")
404
+
405
  ###
406
  # DEDUPLICATION AND SUMMARISATION FUNCTIONS
407
  ###
 
420
  success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
421
  success(load_in_previous_data_files, inputs=[summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
422
  success(sample_reference_table_summaries, inputs=[master_reference_df_state, random_seed], outputs=[summary_reference_table_sample_state, summarised_references_markdown], api_name="sample_summaries").\
423
+ success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox, aws_access_key_textbox, aws_secret_key_textbox], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number], api_name="summarise_topics")
424
 
425
+ # latest_summary_completed_num.change(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num], scroll_to_output=True)
426
 
427
  # SUMMARISE WHOLE TABLE PAGE
428
  overall_summarise_previous_data_btn.click(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
429
  success(load_in_previous_data_files, inputs=[overall_summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
430
+ success(overall_summary, inputs=[master_unique_topics_df_state, model_choice, google_api_key_textbox, temperature_slide, unique_topics_table_file_name_textbox, output_folder_state, in_colnames, context_textbox, aws_access_key_textbox, aws_secret_key_textbox], outputs=[overall_summary_output_files, overall_summarised_output_markdown, summarised_output_df, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number], scroll_to_output=True, api_name="overall_summary")
431
 
432
  ###
433
  # CONTINUE PREVIOUS TOPIC EXTRACTION PAGE
 
502
  usage_callback.setup([session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox, input_tokens_num,
503
  output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], USAGE_LOGS_FOLDER)
504
 
505
+ def conversation_metadata_textbox_change(textbox_value):
506
+ print("conversation_metadata_textbox_change:", textbox_value)
507
+ return textbox_value
508
+
509
+ number_of_calls_num.change(conversation_metadata_textbox_change, inputs=[conversation_metadata_textbox], outputs=[conversation_metadata_textbox])
510
+
511
+ number_of_calls_num.change(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False, api_name="usage_logs").\
512
  success(fn = upload_file_to_s3, inputs=[usage_logs_state, usage_s3_logs_loc_state, s3_log_bucket_name, aws_access_key_textbox, aws_secret_key_textbox], outputs=[s3_logs_output_textbox])
513
 
514
  # User submitted feedback
requirements_gpu.txt CHANGED
@@ -1,5 +1,5 @@
1
  pandas==2.3.1
2
- gradio==5.42.0
3
  transformers==4.55.2
4
  spaces==0.40.0
5
  boto3==1.40.11
@@ -17,7 +17,7 @@ python-dotenv==1.1.0
17
  # Torch and Llama CPP Python
18
  torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124 # Latest compatible with CUDA 12.4
19
  # For Linux:
20
- #https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl
21
  # For Windows:
22
  #llama-cpp-python==0.3.16 -C cmake.args="-DGGML_CUDA=on" --verbose
23
  # If above doesn't work for Windows, try looking at 'windows_install_llama-cpp-python.txt'
 
1
  pandas==2.3.1
2
+ gradio==5.44.0
3
  transformers==4.55.2
4
  spaces==0.40.0
5
  boto3==1.40.11
 
17
  # Torch and Llama CPP Python
18
  torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124 # Latest compatible with CUDA 12.4
19
  # For Linux:
20
+ https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl
21
  # For Windows:
22
  #llama-cpp-python==0.3.16 -C cmake.args="-DGGML_CUDA=on" --verbose
23
  # If above doesn't work for Windows, try looking at 'windows_install_llama-cpp-python.txt'
tools/config.py CHANGED
@@ -222,7 +222,7 @@ model_full_names = []
222
  model_short_names = []
223
  model_source = []
224
 
225
- CHOSEN_LOCAL_MODEL_TYPE = get_or_create_env_var("CHOSEN_LOCAL_MODEL_TYPE", "Gemma 3 4B") # Gemma 3 1B # "Gemma 2b"
226
 
227
  if RUN_LOCAL_MODEL == "1" and CHOSEN_LOCAL_MODEL_TYPE:
228
  model_full_names.append(CHOSEN_LOCAL_MODEL_TYPE)
@@ -252,18 +252,28 @@ model_name_map = {
252
  # HF token may or may not be needed for downloading models from Hugging Face
253
  HF_TOKEN = get_or_create_env_var('HF_TOKEN', '')
254
 
255
- GEMMA2_REPO_ID = get_or_create_env_var("GEMMA2_2B_REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF")# "bartowski/Llama-3.2-3B-Instruct-GGUF") # "lmstudio-community/gemma-2-2b-it-GGUF")#"QuantFactory/Phi-3-mini-128k-instruct-GGUF")
256
- GEMMA2_MODEL_FILE = get_or_create_env_var("GEMMA2_2B_MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf") # )"Llama-3.2-3B-Instruct-Q5_K_M.gguf") #"gemma-2-2b-it-Q8_0.gguf") #"Phi-3-mini-128k-instruct.Q4_K_M.gguf")
257
- GEMMA2_MODEL_FOLDER = get_or_create_env_var("GEMMA2_2B_MODEL_FOLDER", "model/gemma") #"model/phi" # Assuming this is your intended directory
258
 
259
- GEMMA3_REPO_ID = get_or_create_env_var("GEMMA3_REPO_ID", "ggml-org/gemma-3-1b-it-GGUF")# "bartowski/Llama-3.2-3B-Instruct-GGUF") # "lmstudio-community/gemma-2-2b-it-GGUF")#"QuantFactory/Phi-3-mini-128k-instruct-GGUF")
260
- GEMMA3_MODEL_FILE = get_or_create_env_var("GEMMA3_MODEL_FILE", "gemma-3-1b-it-Q8_0.gguf") # )"Llama-3.2-3B-Instruct-Q5_K_M.gguf") #"gemma-2-2b-it-Q8_0.gguf") #"Phi-3-mini-128k-instruct.Q4_K_M.gguf")
261
  GEMMA3_MODEL_FOLDER = get_or_create_env_var("GEMMA3_MODEL_FOLDER", "model/gemma")
262
 
263
- GEMMA3_4B_REPO_ID = get_or_create_env_var("GEMMA3_4B_REPO_ID", "ggml-org/gemma-3-4b-it-GGUF")# "bartowski/Llama-3.2-3B-Instruct-GGUF") # "lmstudio-community/gemma-2-2b-it-GGUF")#"QuantFactory/Phi-3-mini-128k-instruct-GGUF")
264
- GEMMA3_4B_MODEL_FILE = get_or_create_env_var("GEMMA3_4B_MODEL_FILE", "gemma-3-4b-it-Q4_K_M.gguf") # )"Llama-3.2-3B-Instruct-Q5_K_M.gguf") #"gemma-2-2b-it-Q8_0.gguf") #"Phi-3-mini-128k-instruct.Q4_K_M.gguf")
265
  GEMMA3_4B_MODEL_FOLDER = get_or_create_env_var("GEMMA3_4B_MODEL_FOLDER", "model/gemma3_4b")
266
 
 
 
 
 
 
 
 
 
 
 
267
  if CHOSEN_LOCAL_MODEL_TYPE == "Gemma 2b":
268
  LOCAL_REPO_ID = GEMMA2_REPO_ID
269
  LOCAL_MODEL_FILE = GEMMA2_MODEL_FILE
@@ -280,25 +290,34 @@ elif CHOSEN_LOCAL_MODEL_TYPE == "Gemma 3 4B":
280
  LOCAL_MODEL_FILE = GEMMA3_4B_MODEL_FILE
281
  LOCAL_MODEL_FOLDER = GEMMA3_4B_MODEL_FOLDER
282
 
 
 
 
 
 
283
  print("CHOSEN_LOCAL_MODEL_TYPE:", CHOSEN_LOCAL_MODEL_TYPE)
284
  print("LOCAL_REPO_ID:", LOCAL_REPO_ID)
285
  print("LOCAL_MODEL_FILE:", LOCAL_MODEL_FILE)
286
  print("LOCAL_MODEL_FOLDER:", LOCAL_MODEL_FOLDER)
287
 
288
- LLM_MAX_GPU_LAYERS = int(get_or_create_env_var('LLM_MAX_GPU_LAYERS','-1'))
289
  LLM_TEMPERATURE = float(get_or_create_env_var('LLM_TEMPERATURE', '0.1'))
290
- LLM_TOP_K = int(get_or_create_env_var('LLM_TOP_K','3'))
291
- LLM_TOP_P = float(get_or_create_env_var('LLM_TOP_P', '1'))
292
- LLM_REPETITION_PENALTY = float(get_or_create_env_var('LLM_REPETITION_PENALTY', '1.2')) # Mild repetition penalty to prevent repeating table rows
 
293
  LLM_LAST_N_TOKENS = int(get_or_create_env_var('LLM_LAST_N_TOKENS', '512'))
294
  LLM_MAX_NEW_TOKENS = int(get_or_create_env_var('LLM_MAX_NEW_TOKENS', '4096'))
295
  LLM_SEED = int(get_or_create_env_var('LLM_SEED', '42'))
296
  LLM_RESET = get_or_create_env_var('LLM_RESET', 'True')
297
  LLM_STREAM = get_or_create_env_var('LLM_STREAM', 'False')
298
  LLM_THREADS = int(get_or_create_env_var('LLM_THREADS', '4'))
299
- LLM_BATCH_SIZE = int(get_or_create_env_var('LLM_BATCH_SIZE', '256'))
300
  LLM_CONTEXT_LENGTH = int(get_or_create_env_var('LLM_CONTEXT_LENGTH', '16384'))
301
  LLM_SAMPLE = get_or_create_env_var('LLM_SAMPLE', 'True')
 
 
 
302
 
303
  MAX_GROUPS = int(get_or_create_env_var('MAX_GROUPS', '99'))
304
 
 
222
  model_short_names = []
223
  model_source = []
224
 
225
+ CHOSEN_LOCAL_MODEL_TYPE = get_or_create_env_var("CHOSEN_LOCAL_MODEL_TYPE", "gpt-oss-20b") # Gemma 3 1B # "Gemma 2b" # "Gemma 3 4B"
226
 
227
  if RUN_LOCAL_MODEL == "1" and CHOSEN_LOCAL_MODEL_TYPE:
228
  model_full_names.append(CHOSEN_LOCAL_MODEL_TYPE)
 
252
  # HF token may or may not be needed for downloading models from Hugging Face
253
  HF_TOKEN = get_or_create_env_var('HF_TOKEN', '')
254
 
255
+ GEMMA2_REPO_ID = get_or_create_env_var("GEMMA2_2B_REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF")
256
+ GEMMA2_MODEL_FILE = get_or_create_env_var("GEMMA2_2B_MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf")
257
+ GEMMA2_MODEL_FOLDER = get_or_create_env_var("GEMMA2_2B_MODEL_FOLDER", "model/gemma")
258
 
259
+ GEMMA3_REPO_ID = get_or_create_env_var("GEMMA3_REPO_ID", "unsloth/gemma-3-270m-it-qat-GGUF")
260
+ GEMMA3_MODEL_FILE = get_or_create_env_var("GEMMA3_MODEL_FILE", "gemma-3-270m-it-qat-F16.gguf")
261
  GEMMA3_MODEL_FOLDER = get_or_create_env_var("GEMMA3_MODEL_FOLDER", "model/gemma")
262
 
263
+ GEMMA3_4B_REPO_ID = get_or_create_env_var("GEMMA3_4B_REPO_ID", "unsloth/gemma-3-4b-it-qat-GGUF")
264
+ GEMMA3_4B_MODEL_FILE = get_or_create_env_var("GEMMA3_4B_MODEL_FILE", "gemma-3-4b-it-qat-Q4_K_M.gguf")
265
  GEMMA3_4B_MODEL_FOLDER = get_or_create_env_var("GEMMA3_4B_MODEL_FOLDER", "model/gemma3_4b")
266
 
267
+ GPT_OSS_REPO_ID = get_or_create_env_var("GPT_OSS_REPO_ID", "unsloth/gpt-oss-20b-GGUF")
268
+ GPT_OSS_MODEL_FILE = get_or_create_env_var("GPT_OSS_MODEL_FILE", "gpt-oss-20b-F16.gguf")
269
+ GPT_OSS_MODEL_FOLDER = get_or_create_env_var("GPT_OSS_MODEL_FOLDER", "model/gpt_oss")
270
+
271
+ USE_SPECULATIVE_DECODING = get_or_create_env_var("USE_SPECULATIVE_DECODING", "False")
272
+
273
+ GEMMA3_DRAFT_MODEL_LOC = get_or_create_env_var("GEMMA3_DRAFT_MODEL_LOC", ".cache/llama.cpp/unsloth_gemma-3-270m-it-qat-GGUF_gemma-3-270m-it-qat-F16.gguf")
274
+
275
+ GEMMA3_4B_DRAFT_MODEL_LOC = get_or_create_env_var("GEMMA3_4B_DRAFT_MODEL_LOC", ".cache/llama.cpp/unsloth_gemma-3-4b-it-qat-GGUF_gemma-3-4b-it-qat-Q4_K_M.gguf")
276
+
277
  if CHOSEN_LOCAL_MODEL_TYPE == "Gemma 2b":
278
  LOCAL_REPO_ID = GEMMA2_REPO_ID
279
  LOCAL_MODEL_FILE = GEMMA2_MODEL_FILE
 
290
  LOCAL_MODEL_FILE = GEMMA3_4B_MODEL_FILE
291
  LOCAL_MODEL_FOLDER = GEMMA3_4B_MODEL_FOLDER
292
 
293
+ elif CHOSEN_LOCAL_MODEL_TYPE == "gpt-oss-20b":
294
+ LOCAL_REPO_ID = GPT_OSS_REPO_ID
295
+ LOCAL_MODEL_FILE = GPT_OSS_MODEL_FILE
296
+ LOCAL_MODEL_FOLDER = GPT_OSS_MODEL_FOLDER
297
+
298
  print("CHOSEN_LOCAL_MODEL_TYPE:", CHOSEN_LOCAL_MODEL_TYPE)
299
  print("LOCAL_REPO_ID:", LOCAL_REPO_ID)
300
  print("LOCAL_MODEL_FILE:", LOCAL_MODEL_FILE)
301
  print("LOCAL_MODEL_FOLDER:", LOCAL_MODEL_FOLDER)
302
 
303
+ LLM_MAX_GPU_LAYERS = int(get_or_create_env_var('LLM_MAX_GPU_LAYERS','-1')) # Maximum possible
304
  LLM_TEMPERATURE = float(get_or_create_env_var('LLM_TEMPERATURE', '0.1'))
305
+ LLM_TOP_K = int(get_or_create_env_var('LLM_TOP_K','96'))
306
+ LLM_MIN_P = float(get_or_create_env_var('LLM_MIN_P', '0'))
307
+ LLM_TOP_P = float(get_or_create_env_var('LLM_TOP_P', '0.95'))
308
+ LLM_REPETITION_PENALTY = float(get_or_create_env_var('LLM_REPETITION_PENALTY', '1.0'))
309
  LLM_LAST_N_TOKENS = int(get_or_create_env_var('LLM_LAST_N_TOKENS', '512'))
310
  LLM_MAX_NEW_TOKENS = int(get_or_create_env_var('LLM_MAX_NEW_TOKENS', '4096'))
311
  LLM_SEED = int(get_or_create_env_var('LLM_SEED', '42'))
312
  LLM_RESET = get_or_create_env_var('LLM_RESET', 'True')
313
  LLM_STREAM = get_or_create_env_var('LLM_STREAM', 'False')
314
  LLM_THREADS = int(get_or_create_env_var('LLM_THREADS', '4'))
315
+ LLM_BATCH_SIZE = int(get_or_create_env_var('LLM_BATCH_SIZE', '128'))
316
  LLM_CONTEXT_LENGTH = int(get_or_create_env_var('LLM_CONTEXT_LENGTH', '16384'))
317
  LLM_SAMPLE = get_or_create_env_var('LLM_SAMPLE', 'True')
318
+ SPECULATIVE_DECODING = get_or_create_env_var('SPECULATIVE_DECODING', 'False')
319
+ NUM_PRED_TOKENS = int(get_or_create_env_var('NUM_PRED_TOKENS', '2'))
320
+ REASONING_SUFFIX = get_or_create_env_var('REASONING_SUFFIX', 'Reasoning: low')
321
 
322
  MAX_GROUPS = int(get_or_create_env_var('MAX_GROUPS', '99'))
323
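The new USE_SPECULATIVE_DECODING, NUM_PRED_TOKENS and draft-model settings above are not wired up anywhere in this diff. One plausible sketch is shown below, assuming the prompt-lookup draft model bundled with llama-cpp-python is used; how the GEMMA3_*_DRAFT_MODEL_LOC paths are consumed is an assumption, so treat this purely as an illustration.

import os
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

# Assumed wiring (not shown in this commit): enable the bindings' built-in
# prompt-lookup speculative decoding when USE_SPECULATIVE_DECODING is "True".
draft_model = None
if os.environ.get("USE_SPECULATIVE_DECODING", "False") == "True":
    draft_model = LlamaPromptLookupDecoding(
        num_pred_tokens=int(os.environ.get("NUM_PRED_TOKENS", "2"))
    )

llm = Llama(
    model_path="model/gpt_oss/gpt-oss-20b-F16.gguf",  # illustrative; see GPT_OSS_* defaults above
    n_gpu_layers=-1,
    n_ctx=16384,
    draft_model=draft_model,
)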
 
tools/custom_csvlogger.py CHANGED
@@ -106,7 +106,7 @@ class CSVLogger_custom(FlaggingCallback):
106
  )
107
  latest_num = int(re.findall(r"\d+", latest_file.stem)[0])
108
 
109
- with open(latest_file, newline="", encoding="utf-8") as csvfile:
110
  reader = csv.reader(csvfile)
111
  existing_headers = next(reader, None)
112
 
@@ -122,7 +122,7 @@ class CSVLogger_custom(FlaggingCallback):
122
 
123
  if not Path(self.dataset_filepath).exists():
124
  with open(
125
- self.dataset_filepath, "w", newline="", encoding="utf-8"
126
  ) as csvfile:
127
  writer = csv.writer(csvfile)
128
  writer.writerow(utils.sanitize_list_for_csv(headers))
@@ -202,15 +202,12 @@ class CSVLogger_custom(FlaggingCallback):
202
 
203
  if save_to_csv:
204
  with self.lock:
205
- with open(self.dataset_filepath, "a", newline="", encoding="utf-8") as csvfile:
206
  writer = csv.writer(csvfile)
207
  writer.writerow(utils.sanitize_list_for_csv(csv_data))
208
- with open(self.dataset_filepath, encoding="utf-8") as csvfile:
209
  line_count = len(list(csv.reader(csvfile))) - 1
210
 
211
-
212
- print("save_to_dynamodb:", save_to_dynamodb)
213
- print("save_to_dynamodb == True:", save_to_dynamodb == True)
214
  if save_to_dynamodb == True:
215
  print("Saving to DynamoDB")
216
 
 
106
  )
107
  latest_num = int(re.findall(r"\d+", latest_file.stem)[0])
108
 
109
+ with open(latest_file, newline="", encoding="utf-8-sig") as csvfile:
110
  reader = csv.reader(csvfile)
111
  existing_headers = next(reader, None)
112
 
 
122
 
123
  if not Path(self.dataset_filepath).exists():
124
  with open(
125
+ self.dataset_filepath, "w", newline="", encoding="utf-8-sig"
126
  ) as csvfile:
127
  writer = csv.writer(csvfile)
128
  writer.writerow(utils.sanitize_list_for_csv(headers))
 
202
 
203
  if save_to_csv:
204
  with self.lock:
205
+ with open(self.dataset_filepath, "a", newline="", encoding="utf-8-sig") as csvfile:
206
  writer = csv.writer(csvfile)
207
  writer.writerow(utils.sanitize_list_for_csv(csv_data))
208
+ with open(self.dataset_filepath, encoding="utf-8-sig") as csvfile:
209
  line_count = len(list(csv.reader(csvfile))) - 1
210
 
 
 
 
211
  if save_to_dynamodb == True:
212
  print("Saving to DynamoDB")
213
 
tools/dedup_summaries.py CHANGED
@@ -436,7 +436,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
436
  if isinstance(responses[-1], ResponseObject):
437
  response_texts = [resp.text for resp in responses]
438
  elif "choices" in responses[-1]:
439
- response_texts = [resp["choices"][0]['text'] for resp in responses]
440
  else:
441
  response_texts = [resp.text for resp in responses]
442
 
 
436
  if isinstance(responses[-1], ResponseObject):
437
  response_texts = [resp.text for resp in responses]
438
  elif "choices" in responses[-1]:
439
+ response_texts = [resp["choices"][0]['message']['content'] for resp in responses] #resp["choices"][0]['text'] for resp in responses]
440
  else:
441
  response_texts = [resp.text for resp in responses]
442
 
tools/llm_api_call.py CHANGED
@@ -17,7 +17,7 @@ GradioFileData = gr.FileData
17
  from tools.prompts import initial_table_prompt, prompt2, prompt3, initial_table_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt, add_existing_topics_assistant_prefill, initial_table_assistant_prefill, structured_summary_prompt
18
  from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files
19
  from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata
20
- from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, RUN_AWS_FUNCTIONS, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS
21
  from tools.aws_functions import connect_to_bedrock_runtime
22
 
23
  if RUN_LOCAL_MODEL == "1":
@@ -31,11 +31,12 @@ batch_size_default = BATCH_SIZE_DEFAULT
31
  deduplication_threshold = DEDUPLICATION_THRESHOLD
32
  max_comment_character_length = MAX_COMMENT_CHARS
33
  random_seed = LLM_SEED
 
34
 
35
  # if RUN_AWS_FUNCTIONS == '1':
36
  # bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
37
  # else:
38
- # bedrock_runtime = []
39
 
40
 
41
 
@@ -130,7 +131,7 @@ def clean_markdown_table(text: str):
130
  lines = text.splitlines()
131
 
132
  # Step 1: Identify table structure and process line continuations
133
- table_rows = []
134
  current_row = None
135
 
136
  for line in lines:
@@ -174,7 +175,7 @@ def clean_markdown_table(text: str):
174
  max_columns = max(max_columns, len(cells))
175
 
176
  # Now format each row
177
- formatted_rows = []
178
  for row in table_rows:
179
  # Ensure the row starts and ends with pipes
180
  if not row.startswith('|'):
@@ -354,7 +355,7 @@ def write_llm_output_and_logs(responses: List[ResponseObject],
354
  - first_run (bool): A boolean indicating if this is the first run through this function in this process. Defaults to False.
355
  - output_folder (str): The name of the folder where output files are saved.
356
  """
357
- topic_summary_df_out_path = []
358
  topic_table_out_path = "topic_table_error.csv"
359
  reference_table_out_path = "reference_table_error.csv"
360
  topic_summary_df_out_path = "unique_topic_table_error.csv"
@@ -390,7 +391,7 @@ def write_llm_output_and_logs(responses: List[ResponseObject],
390
  log_files_output_paths.append(whole_conversation_path_meta)
391
 
392
  if isinstance(responses[-1], ResponseObject): response_text = responses[-1].text
393
- elif "choices" in responses[-1]: response_text = responses[-1]["choices"][0]['text']
394
  else: response_text = responses[-1].text
395
 
396
  # Convert response text to a markdown table
@@ -426,7 +427,7 @@ def write_llm_output_and_logs(responses: List[ResponseObject],
426
  topic_table_out_path = output_folder + batch_file_path_details + "_topic_table_" + model_choice_clean + ".csv"
427
 
428
  # Table to map references to topics
429
- reference_data = []
430
 
431
  batch_basic_response_df["Reference"] = batch_basic_response_df["Reference"].astype(str)
432
 
@@ -614,13 +615,7 @@ def generate_zero_shot_topics_df(zero_shot_topics:pd.DataFrame,
614
  if force_zero_shot_radio == "Yes":
615
  zero_shot_topics_gen_topics_list.append("")
616
  zero_shot_topics_subtopics_list.append("No relevant topic")
617
- zero_shot_topics_description_list.append("")
618
-
619
- # This process was abandoned (revising the general topics) as it didn't seem to work
620
- # if create_revised_general_topics == True:
621
- # pass
622
- # else:
623
- # pass
624
 
625
  # Add description or not
626
  zero_shot_topics_df = pd.DataFrame(data={
@@ -646,9 +641,9 @@ def extract_topics(in_data_file: GradioFileData,
646
  model_choice:str,
647
  candidate_topics: GradioFileData = None,
648
  latest_batch_completed:int=0,
649
- out_message:List=[],
650
- out_file_paths:List = [],
651
- log_files_output_paths:List = [],
652
  first_loop_state:bool=False,
653
  whole_conversation_metadata_str:str="",
654
  initial_table_prompt:str=initial_table_prompt,
@@ -663,7 +658,7 @@ def extract_topics(in_data_file: GradioFileData,
663
  time_taken:float = 0,
664
  sentiment_checkbox:str = "Negative, Neutral, or Positive",
665
  force_zero_shot_radio:str = "No",
666
- in_excel_sheets:List[str] = [],
667
  force_single_topic_radio:str = "No",
668
  output_folder:str=OUTPUT_FOLDER,
669
  force_single_topic_prompt:str=force_single_topic_prompt,
@@ -675,6 +670,7 @@ def extract_topics(in_data_file: GradioFileData,
675
  model_name_map:dict=model_name_map,
676
  max_time_for_loop:int=max_time_for_loop,
677
  CHOSEN_LOCAL_MODEL_TYPE:str=CHOSEN_LOCAL_MODEL_TYPE,
 
678
  progress=Progress(track_tqdm=True)):
679
 
680
  '''
@@ -723,31 +719,36 @@ def extract_topics(in_data_file: GradioFileData,
723
  - model_name_map (dict, optional): A dictionary mapping full model name to shortened.
724
  - max_time_for_loop (int, optional): The number of seconds maximum that the function should run for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there).
725
  - CHOSEN_LOCAL_MODEL_TYPE (str, optional): The name of the chosen local model.
 
726
  - progress (Progress): A progress tracker.
727
  '''
728
 
729
  tic = time.perf_counter()
730
- google_client = []
731
  google_config = {}
732
  final_time = 0.0
733
- whole_conversation_metadata = []
734
  is_error = False
735
  create_revised_general_topics = False
736
- local_model = []
737
- tokenizer = []
738
  zero_shot_topics_df = pd.DataFrame()
739
  missing_df = pd.DataFrame()
740
  new_reference_df = pd.DataFrame(columns=["Response References", "General topic", "Subtopic", "Sentiment", "Start row of group", "Group" ,"Topic_number", "Summary"])
741
  new_topic_summary_df = pd.DataFrame(columns=["General topic","Subtopic","Sentiment","Group","Number of responses","Summary"])
742
  new_topic_df = pd.DataFrame()
743
- #llama_system_prefix = "<|start_header_id|>system<|end_header_id|>\n" #"<start_of_turn>user\n"
744
- #llama_system_suffix = "<|eot_id|>" #"<end_of_turn>\n<start_of_turn>model\n"
745
- #llama_cpp_prefix = "<|start_header_id|>system<|end_header_id|>\nYou are an AI assistant that follows instruction extremely well. Help as much as you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n" #"<start_of_turn>user\n"
746
- #llama_cpp_suffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" #"<end_of_turn>\n<start_of_turn>model\n"
747
- #llama_cpp_prefix = "<|user|>\n" # This is for phi 3.5
748
- #llama_cpp_suffix = "<|end|>\n<|assistant|>|" # This is for phi 3.5
749
- llama_cpp_prefix = "<start_of_turn>user\n"
750
- llama_cpp_suffix = "<end_of_turn>\n<start_of_turn>model\n"
 
 
 
 
751
 
752
  #print("output_folder:", output_folder)
753
 
@@ -773,8 +774,8 @@ def extract_topics(in_data_file: GradioFileData,
773
  print("This is the first time through the loop, resetting latest_batch_completed to 0")
774
  if (latest_batch_completed == 999) | (latest_batch_completed == 0):
775
  latest_batch_completed = 0
776
- out_message = []
777
- out_file_paths = []
778
  final_time = 0
779
 
780
  if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1"):
@@ -795,7 +796,7 @@ def extract_topics(in_data_file: GradioFileData,
795
  out_message = [out_message]
796
 
797
  if not out_file_paths:
798
- out_file_paths = []
799
 
800
 
801
  if "anthropic.claude-3-sonnet" in model_choice and file_data.shape[1] > 300:
@@ -819,7 +820,7 @@ def extract_topics(in_data_file: GradioFileData,
819
  simplified_csv_table_path, normalised_simple_markdown_table, start_row, end_row, batch_basic_response_df = data_file_to_markdown_table(file_data, file_name, chosen_cols, latest_batch_completed, batch_size)
820
 
821
  # Conversation history
822
- conversation_history = []
823
 
824
  # If the latest batch of responses contains at least one instance of text
825
  if not batch_basic_response_df.empty:
@@ -933,13 +934,15 @@ def extract_topics(in_data_file: GradioFileData,
933
  except Exception as e:
934
  print(f"Error writing prompt to file {formatted_prompt_output_path}: {e}")
935
 
936
- if "gemma" in model_choice:
937
- summary_prompt_list = [full_prompt] # Includes system prompt
938
- else:
939
- summary_prompt_list = [formatted_summary_prompt]
 
 
940
 
941
- conversation_history = []
942
- whole_conversation = []
943
 
944
  # Process requests to large language model
945
  responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)
@@ -953,13 +956,16 @@ def extract_topics(in_data_file: GradioFileData,
953
 
954
  if isinstance(responses[-1], ResponseObject):
955
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
956
- f.write(responses[-1].text)
 
957
  elif "choices" in responses[-1]:
958
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
959
- f.write(responses[-1]["choices"][0]['text'])
 
960
  else:
961
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
962
- f.write(responses[-1].text)
 
963
 
964
  except Exception as e:
965
  print("Error in returning model response:", e)
@@ -1020,26 +1026,23 @@ def extract_topics(in_data_file: GradioFileData,
1020
  if prompt3: formatted_prompt3 = prompt3.format(response_table=normalised_simple_markdown_table, sentiment_choices=sentiment_prompt)
1021
  else: formatted_prompt3 = prompt3
1022
 
1023
- if "gemma" in model_choice:
1024
- formatted_initial_table_prompt = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_initial_table_prompt + llama_cpp_suffix
1025
- formatted_prompt2 = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_prompt2 + llama_cpp_suffix
1026
- formatted_prompt3 = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_prompt3 + llama_cpp_suffix
 
 
1027
 
1028
- batch_prompts = [formatted_initial_table_prompt, formatted_prompt2, formatted_prompt3][:number_of_prompts_used] # Adjust this list to send fewer requests
1029
 
1030
- whole_conversation = []
1031
 
1032
  responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_initial_table_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
1033
 
1034
  topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, topic_table_df, reference_df, new_topic_summary_df, batch_file_path_details, is_error = write_llm_output_and_logs(responses, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, group_name, produce_structures_summary_radio, first_run=True, output_folder=output_folder)
1035
 
1036
  # If error in table parsing, leave function
1037
- if is_error == True:
1038
- raise Exception("Error in output table parsing")
1039
- # unique_table_df_display_table_markdown, new_topic_df, new_topic_summary_df, new_reference_df, out_file_paths, out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, out_file_paths#, final_message_out
1040
-
1041
-
1042
- #all_topic_tables_df.append(topic_table_df)
1043
 
1044
  topic_table_df.to_csv(topic_table_out_path, index=None)
1045
  out_file_paths.append(topic_table_out_path)
@@ -1056,28 +1059,28 @@ def extract_topics(in_data_file: GradioFileData,
1056
  new_topic_summary_df.to_csv(topic_summary_df_out_path, index=None)
1057
  out_file_paths.append(topic_summary_df_out_path)
1058
 
1059
- #all_markdown_topic_tables.append(markdown_table)
1060
 
1061
  whole_conversation_metadata.append(whole_conversation_metadata_str)
1062
  whole_conversation_metadata_str = '. '.join(whole_conversation_metadata)
1063
 
1064
- # Write final output to text file also
1065
  # Write final output to text file for logging purposes
1066
  try:
1067
  final_table_output_path = output_folder + batch_file_path_details + "_full_response_" + model_choice_clean + ".txt"
1068
 
1069
  if isinstance(responses[-1], ResponseObject):
1070
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
1071
- f.write(responses[-1].text)
 
1072
  elif "choices" in responses[-1]:
1073
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
1074
- f.write(responses[-1]["choices"][0]['text'])
 
1075
  else:
1076
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
1077
- f.write(responses[-1].text)
 
1078
 
1079
- except Exception as e:
1080
- print("Error in returning model response:", e)
1081
 
1082
  new_topic_df = topic_table_df
1083
  new_reference_df = reference_df
@@ -1122,7 +1125,7 @@ def extract_topics(in_data_file: GradioFileData,
1122
  # Set to a very high number so as not to mess with subsequent file processing by the user
1123
  #latest_batch_completed = 999
1124
 
1125
- join_file_paths = []
1126
 
1127
  toc = time.perf_counter()
1128
  final_time = (toc - tic) + time_taken
@@ -1198,8 +1201,12 @@ def extract_topics(in_data_file: GradioFileData,
1198
 
1199
  print("latest_batch_completed at end of batch iterations to return is", latest_batch_completed)
1200
 
 
 
1201
  return unique_table_df_display_table_markdown, existing_topics_table, final_out_topic_summary_df, existing_reference_df, final_out_file_paths, final_out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, final_out_file_paths, final_out_file_paths, modifiable_topic_summary_df, final_out_file_paths, join_file_paths, existing_reference_df_pivot, missing_df
1202
 
 
 
1203
  return unique_table_df_display_table_markdown, existing_topics_table, existing_topic_summary_df, existing_reference_df, out_file_paths, out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, out_file_paths, out_file_paths, modifiable_topic_summary_df, out_file_paths, join_file_paths, existing_reference_df_pivot, missing_df # gr.Dataframe(value=modifiable_topic_summary_df, headers=None, col_count=(modifiable_topic_summary_df.shape[1], "fixed"), row_count = (modifiable_topic_summary_df.shape[0], "fixed"), visible=True, type="pandas"),
1204
 
1205
  def wrapper_extract_topics_per_column_value(
@@ -1237,7 +1244,7 @@ def wrapper_extract_topics_per_column_value(
1237
  context_textbox: str = "",
1238
  sentiment_checkbox: str = "Negative, Neutral, or Positive",
1239
  force_zero_shot_radio: str = "No",
1240
- in_excel_sheets: List[str] = [],
1241
  force_single_topic_radio: str = "No",
1242
  produce_structures_summary_radio: str = "No",
1243
  aws_access_key_textbox:str="",
@@ -1279,9 +1286,9 @@ def wrapper_extract_topics_per_column_value(
1279
  acc_missing_df = pd.DataFrame()
1280
 
1281
  # Lists are extended
1282
- acc_out_file_paths = []
1283
- acc_log_files_output_paths = []
1284
- acc_join_file_paths = [] # join_file_paths seems to be overwritten, so maybe last one or extend? Let's extend.
1285
 
1286
  # Single value outputs - typically the last one is most relevant, or sum for time
1287
  acc_markdown_output = initial_unique_table_df_display_table_markdown
@@ -1353,9 +1360,9 @@ def wrapper_extract_topics_per_column_value(
1353
  num_batches=current_num_batches,
1354
  latest_batch_completed=current_latest_batch_completed, # Reset for each new segment's internal batching
1355
  first_loop_state=current_first_loop_state, # True only for the very first iteration of wrapper
1356
- out_message=[], # Fresh for each call
1357
- out_file_paths=[],# Fresh for each call
1358
- log_files_output_paths=[],# Fresh for each call
1359
  whole_conversation_metadata_str="", # Fresh for each call
1360
  time_taken=0, # Time taken for this specific call, wrapper sums it.
1361
  # Pass through other parameters
@@ -1450,8 +1457,12 @@ def wrapper_extract_topics_per_column_value(
1450
  unique_table_df_display_table = acc_topic_summary_df.apply(lambda col: col.map(lambda x: wrap_text(x, max_text_length=500)))
1451
  acc_markdown_output = unique_table_df_display_table[["General topic", "Subtopic", "Sentiment", "Number of responses", "Summary", "Group"]].to_markdown(index=False)
1452
 
 
 
1453
  acc_input_tokens, acc_output_tokens, acc_number_of_calls = calculate_tokens_from_metadata(acc_whole_conversation_metadata, model_choice, model_name_map)
1454
 
 
 
1455
  print(f"\nWrapper finished processing all segments. Total time: {acc_total_time_taken:.2f}s")
1456
 
1457
  # The return signature should match extract_topics.
@@ -1537,7 +1548,7 @@ def modify_existing_output_tables(original_topic_summary_df:pd.DataFrame, modifi
1537
  reference_file_path = os.path.basename(reference_files[0]) if reference_files else None
1538
  unique_table_file_path = os.path.basename(unique_files[0]) if unique_files else None
1539
 
1540
- output_file_list = []
1541
 
1542
  if reference_file_path and unique_table_file_path:
1543
 
 
17
  from tools.prompts import initial_table_prompt, prompt2, prompt3, initial_table_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt, add_existing_topics_assistant_prefill, initial_table_assistant_prefill, structured_summary_prompt
18
  from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files
19
  from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata
20
+ from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, RUN_AWS_FUNCTIONS, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX
21
  from tools.aws_functions import connect_to_bedrock_runtime
22
 
23
  if RUN_LOCAL_MODEL == "1":
 
31
  deduplication_threshold = DEDUPLICATION_THRESHOLD
32
  max_comment_character_length = MAX_COMMENT_CHARS
33
  random_seed = LLM_SEED
34
+ reasoning_suffix = REASONING_SUFFIX
35
 
36
  # if RUN_AWS_FUNCTIONS == '1':
37
  # bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
38
  # else:
39
+ # bedrock_runtime = list()
40
 
41
 
42
 
 
131
  lines = text.splitlines()
132
 
133
  # Step 1: Identify table structure and process line continuations
134
+ table_rows = list()
135
  current_row = None
136
 
137
  for line in lines:
 
175
  max_columns = max(max_columns, len(cells))
176
 
177
  # Now format each row
178
+ formatted_rows = list()
179
  for row in table_rows:
180
  # Ensure the row starts and ends with pipes
181
  if not row.startswith('|'):
 
355
  - first_run (bool): A boolean indicating if this is the first run through this function in this process. Defaults to False.
356
  - output_folder (str): The name of the folder where output files are saved.
357
  """
358
+ topic_summary_df_out_path = list()
359
  topic_table_out_path = "topic_table_error.csv"
360
  reference_table_out_path = "reference_table_error.csv"
361
  topic_summary_df_out_path = "unique_topic_table_error.csv"
 
391
  log_files_output_paths.append(whole_conversation_path_meta)
392
 
393
  if isinstance(responses[-1], ResponseObject): response_text = responses[-1].text
394
+ elif "choices" in responses[-1]: response_text = responses[-1]['choices'][0]['message']['content'] #responses[-1]["choices"][0]['text']
395
  else: response_text = responses[-1].text
396
 
397
  # Convert response text to a markdown table
 
427
  topic_table_out_path = output_folder + batch_file_path_details + "_topic_table_" + model_choice_clean + ".csv"
428
 
429
  # Table to map references to topics
430
+ reference_data = list()
431
 
432
  batch_basic_response_df["Reference"] = batch_basic_response_df["Reference"].astype(str)
433
 
 
615
  if force_zero_shot_radio == "Yes":
616
  zero_shot_topics_gen_topics_list.append("")
617
  zero_shot_topics_subtopics_list.append("No relevant topic")
618
+ zero_shot_topics_description_list.append("")
 
 
 
 
 
 
619
 
620
  # Add description or not
621
  zero_shot_topics_df = pd.DataFrame(data={
 
641
  model_choice:str,
642
  candidate_topics: GradioFileData = None,
643
  latest_batch_completed:int=0,
644
+ out_message:List= list(),
645
+ out_file_paths:List = list(),
646
+ log_files_output_paths:List = list(),
647
  first_loop_state:bool=False,
648
  whole_conversation_metadata_str:str="",
649
  initial_table_prompt:str=initial_table_prompt,
 
658
  time_taken:float = 0,
659
  sentiment_checkbox:str = "Negative, Neutral, or Positive",
660
  force_zero_shot_radio:str = "No",
661
+ in_excel_sheets:List[str] = list(),
662
  force_single_topic_radio:str = "No",
663
  output_folder:str=OUTPUT_FOLDER,
664
  force_single_topic_prompt:str=force_single_topic_prompt,
 
670
  model_name_map:dict=model_name_map,
671
  max_time_for_loop:int=max_time_for_loop,
672
  CHOSEN_LOCAL_MODEL_TYPE:str=CHOSEN_LOCAL_MODEL_TYPE,
673
+ reasoning_suffix:str=reasoning_suffix,
674
  progress=Progress(track_tqdm=True)):
675
 
676
  '''
 
719
  - model_name_map (dict, optional): A dictionary mapping full model name to shortened.
720
  - max_time_for_loop (int, optional): The number of seconds maximum that the function should run for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there).
721
  - CHOSEN_LOCAL_MODEL_TYPE (str, optional): The name of the chosen local model.
722
+ - reasoning_suffix (str, optional): The suffix for the reasoning system prompt.
723
  - progress (Progress): A progress tracker.
724
  '''
725
 
726
  tic = time.perf_counter()
727
+ google_client = list()
728
  google_config = {}
729
  final_time = 0.0
730
+ whole_conversation_metadata = list()
731
  is_error = False
732
  create_revised_general_topics = False
733
+ local_model = list()
734
+ tokenizer = list()
735
  zero_shot_topics_df = pd.DataFrame()
736
  missing_df = pd.DataFrame()
737
  new_reference_df = pd.DataFrame(columns=["Response References", "General topic", "Subtopic", "Sentiment", "Start row of group", "Group" ,"Topic_number", "Summary"])
738
  new_topic_summary_df = pd.DataFrame(columns=["General topic","Subtopic","Sentiment","Group","Number of responses","Summary"])
739
  new_topic_df = pd.DataFrame()
740
+
741
+ # For Gemma models
742
+ #llama_cpp_prefix = "<start_of_turn>user\n"
743
+ #llama_cpp_suffix = "<end_of_turn>\n<start_of_turn>model\n"
744
+
745
+ # For GPT OSS
746
+ #llama_cpp_prefix = "<|start|>assistant<|channel|>analysis<|message|>\n"
747
+ #llama_cpp_suffix = "<|start|>assistant<|channel|>final<|message|>"
748
+
749
+ # Blank
750
+ llama_cpp_prefix = ""
751
+ llama_cpp_suffix = ""
752
 
753
  #print("output_folder:", output_folder)
754
 
 
774
  print("This is the first time through the loop, resetting latest_batch_completed to 0")
775
  if (latest_batch_completed == 999) | (latest_batch_completed == 0):
776
  latest_batch_completed = 0
777
+ out_message = list()
778
+ out_file_paths = list()
779
  final_time = 0
780
 
781
  if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1"):
 
796
  out_message = [out_message]
797
 
798
  if not out_file_paths:
799
+ out_file_paths = list()
800
 
801
 
802
  if "anthropic.claude-3-sonnet" in model_choice and file_data.shape[1] > 300:
 
820
  simplified_csv_table_path, normalised_simple_markdown_table, start_row, end_row, batch_basic_response_df = data_file_to_markdown_table(file_data, file_name, chosen_cols, latest_batch_completed, batch_size)
821
 
822
  # Conversation history
823
+ conversation_history = list()
824
 
825
  # If the latest batch of responses contains at least one instance of text
826
  if not batch_basic_response_df.empty:
 
934
  except Exception as e:
935
  print(f"Error writing prompt to file {formatted_prompt_output_path}: {e}")
936
 
937
+ #if "Local" in model_source:
938
+ # summary_prompt_list = [full_prompt] # Includes system prompt
939
+ #else:
940
+ summary_prompt_list = [formatted_summary_prompt]
941
+
942
+ if "Local" in model_source and reasoning_suffix: formatted_system_prompt = formatted_system_prompt + "\n" + reasoning_suffix
943
 
944
+ conversation_history = list()
945
+ whole_conversation = list()
946
 
947
  # Process requests to large language model
948
  responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)
 
956
 
957
  if isinstance(responses[-1], ResponseObject):
958
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
959
+ #f.write(responses[-1].text)
960
+ f.write(response_text)
961
  elif "choices" in responses[-1]:
962
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
963
+ #f.write(responses[-1]["choices"][0]['text'])
964
+ f.write(response_text)
965
  else:
966
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
967
+ #f.write(responses[-1].text)
968
+ f.write(response_text)
969
 
970
  except Exception as e:
971
  print("Error in returning model response:", e)
 
1026
  if prompt3: formatted_prompt3 = prompt3.format(response_table=normalised_simple_markdown_table, sentiment_choices=sentiment_prompt)
1027
  else: formatted_prompt3 = prompt3
1028
 
1029
+ #if "Local" in model_source:
1030
+ #formatted_initial_table_prompt = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_initial_table_prompt + llama_cpp_suffix
1031
+ #formatted_prompt2 = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_prompt2 + llama_cpp_suffix
1032
+ #formatted_prompt3 = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_prompt3 + llama_cpp_suffix
1033
+
1034
+ batch_prompts = [formatted_initial_table_prompt, formatted_prompt2, formatted_prompt3][:number_of_prompts_used] # Adjust this list to send fewer requests
1035
 
1036
+ if "Local" in model_source and reasoning_suffix: formatted_initial_table_system_prompt = formatted_initial_table_system_prompt + "\n" + reasoning_suffix
1037
 
1038
+ whole_conversation = list()
1039
 
1040
  responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_initial_table_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
1041
 
1042
  topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, topic_table_df, reference_df, new_topic_summary_df, batch_file_path_details, is_error = write_llm_output_and_logs(responses, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, group_name, produce_structures_summary_radio, first_run=True, output_folder=output_folder)
1043
 
1044
  # If error in table parsing, leave function
1045
+ if is_error: raise Exception("Error in output table parsing")
1046
 
1047
  topic_table_df.to_csv(topic_table_out_path, index=None)
1048
  out_file_paths.append(topic_table_out_path)
 
1059
  new_topic_summary_df.to_csv(topic_summary_df_out_path, index=None)
1060
  out_file_paths.append(topic_summary_df_out_path)
1061
 
 
1062
 
1063
  whole_conversation_metadata.append(whole_conversation_metadata_str)
1064
  whole_conversation_metadata_str = '. '.join(whole_conversation_metadata)
1065
 
 
1066
  # Write final output to text file for logging purposes
1067
  try:
1068
  final_table_output_path = output_folder + batch_file_path_details + "_full_response_" + model_choice_clean + ".txt"
1069
 
1070
  if isinstance(responses[-1], ResponseObject):
1071
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
1072
+ #f.write(responses[-1].text)
1073
+ f.write(response_text)
1074
  elif "choices" in responses[-1]:
1075
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
1076
+ #f.write(responses[-1]["choices"][0]['text'])
1077
+ f.write(response_text)
1078
  else:
1079
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
1080
+ #f.write(responses[-1].text)
1081
+ f.write(response_text)
1082
 
1083
+ except Exception as e: print("Error in returning model response:", e)
 
1084
 
1085
  new_topic_df = topic_table_df
1086
  new_reference_df = reference_df
 
1125
  # Set to a very high number so as not to mess with subsequent file processing by the user
1126
  #latest_batch_completed = 999
1127
 
1128
+ join_file_paths = list()
1129
 
1130
  toc = time.perf_counter()
1131
  final_time = (toc - tic) + time_taken
 
1201
 
1202
  print("latest_batch_completed at end of batch iterations to return is", latest_batch_completed)
1203
 
1204
+ print("whole_conversation_metadata_str at end of batch iterations to return is", whole_conversation_metadata_str)
1205
+
1206
  return unique_table_df_display_table_markdown, existing_topics_table, final_out_topic_summary_df, existing_reference_df, final_out_file_paths, final_out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, final_out_file_paths, final_out_file_paths, modifiable_topic_summary_df, final_out_file_paths, join_file_paths, existing_reference_df_pivot, missing_df
1207
 
1208
+ print("whole_conversation_metadata_str at end of batch iterations to return is", whole_conversation_metadata_str)
1209
+
1210
  return unique_table_df_display_table_markdown, existing_topics_table, existing_topic_summary_df, existing_reference_df, out_file_paths, out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, out_file_paths, out_file_paths, modifiable_topic_summary_df, out_file_paths, join_file_paths, existing_reference_df_pivot, missing_df # gr.Dataframe(value=modifiable_topic_summary_df, headers=None, col_count=(modifiable_topic_summary_df.shape[1], "fixed"), row_count = (modifiable_topic_summary_df.shape[0], "fixed"), visible=True, type="pandas"),
1211
 
1212
  def wrapper_extract_topics_per_column_value(
 
1244
  context_textbox: str = "",
1245
  sentiment_checkbox: str = "Negative, Neutral, or Positive",
1246
  force_zero_shot_radio: str = "No",
1247
+ in_excel_sheets: List[str] = list(),
1248
  force_single_topic_radio: str = "No",
1249
  produce_structures_summary_radio: str = "No",
1250
  aws_access_key_textbox:str="",
 
1286
  acc_missing_df = pd.DataFrame()
1287
 
1288
  # Lists are extended
1289
+ acc_out_file_paths = list()
1290
+ acc_log_files_output_paths = list()
1291
+ acc_join_file_paths = list() # join_file_paths appears to be rebuilt on each call, so extend to keep every segment's paths
1292
 
1293
  # Single value outputs - typically the last one is most relevant, or sum for time
1294
  acc_markdown_output = initial_unique_table_df_display_table_markdown
 
1360
  num_batches=current_num_batches,
1361
  latest_batch_completed=current_latest_batch_completed, # Reset for each new segment's internal batching
1362
  first_loop_state=current_first_loop_state, # True only for the very first iteration of wrapper
1363
+ out_message= list(), # Fresh for each call
1364
+ out_file_paths= list(),# Fresh for each call
1365
+ log_files_output_paths= list(),# Fresh for each call
1366
  whole_conversation_metadata_str="", # Fresh for each call
1367
  time_taken=0, # Time taken for this specific call, wrapper sums it.
1368
  # Pass through other parameters
 
1457
  unique_table_df_display_table = acc_topic_summary_df.apply(lambda col: col.map(lambda x: wrap_text(x, max_text_length=500)))
1458
  acc_markdown_output = unique_table_df_display_table[["General topic", "Subtopic", "Sentiment", "Number of responses", "Summary", "Group"]].to_markdown(index=False)
1459
 
1460
+ print("acc_whole_conversation_metadata at end of wrapper is", acc_whole_conversation_metadata)
1461
+
1462
  acc_input_tokens, acc_output_tokens, acc_number_of_calls = calculate_tokens_from_metadata(acc_whole_conversation_metadata, model_choice, model_name_map)
1463
 
1464
+ print("acc_input_tokens, acc_output_tokens, acc_number_of_calls at end of wrapper is", acc_input_tokens, acc_output_tokens, acc_number_of_calls)
1465
+
1466
  print(f"\nWrapper finished processing all segments. Total time: {acc_total_time_taken:.2f}s")
1467
 
1468
  # The return signature should match extract_topics.
 
1548
  reference_file_path = os.path.basename(reference_files[0]) if reference_files else None
1549
  unique_table_file_path = os.path.basename(unique_files[0]) if unique_files else None
1550
 
1551
+ output_file_list = list()
1552
 
1553
  if reference_file_path and unique_table_file_path:
1554
 
tools/llm_funcs.py CHANGED
@@ -7,6 +7,7 @@ import pandas as pd
7
  import json
8
  from tqdm import tqdm
9
  from huggingface_hub import hf_hub_download
 
10
  from typing import List, Tuple, TypeVar
11
  from google import genai as ai
12
  from google.genai import types
@@ -17,12 +18,18 @@ torch.cuda.empty_cache()
17
 
18
  model_type = None # global variable setup
19
  full_text = "" # Define dummy source text (full text) just to enable highlight function to load
20
- model = [] # Define empty list for model functions to run
21
- tokenizer = [] #[] # Define empty list for model functions to run
22
 
23
- from tools.config import RUN_AWS_FUNCTIONS, AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, RUN_LOCAL_MODEL, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS
24
  from tools.prompts import initial_table_assistant_prefill
25
 
26
  # Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
27
  # Check for torch cuda
28
  print("Is CUDA enabled? ", torch.cuda.is_available())
@@ -54,13 +61,9 @@ batch_size_default = BATCH_SIZE_DEFAULT
54
  deduplication_threshold = DEDUPLICATION_THRESHOLD
55
  max_comment_character_length = MAX_COMMENT_CHARS
56
 
57
- # if RUN_AWS_FUNCTIONS == '1':
58
- # bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
59
- # else:
60
- # bedrock_runtime = []
61
 
62
  if not LLM_THREADS:
63
- threads = torch.get_num_threads() # 8
64
  else: threads = LLM_THREADS
65
  print("CPU threads:", threads)
66
 
@@ -76,7 +79,8 @@ else: sample = False
76
  temperature = LLM_TEMPERATURE
77
  top_k = LLM_TOP_K
78
  top_p = LLM_TOP_P
79
- repetition_penalty = LLM_REPETITION_PENALTY # Mild repetition penalty to prevent repeating table rows
 
80
  last_n_tokens = LLM_LAST_N_TOKENS
81
  max_new_tokens: int = LLM_MAX_NEW_TOKENS
82
  seed: int = LLM_SEED
@@ -86,6 +90,7 @@ threads: int = threads
86
  batch_size:int = LLM_BATCH_SIZE
87
  context_length:int = LLM_CONTEXT_LENGTH
88
  sample = LLM_SAMPLE
 
89
 
90
  class llama_cpp_init_config_gpu:
91
  def __init__(self,
@@ -122,6 +127,7 @@ cpu_config = llama_cpp_init_config_cpu()
122
  class LlamaCPPGenerationConfig:
123
  def __init__(self, temperature=temperature,
124
  top_k=top_k,
 
125
  top_p=top_p,
126
  repeat_penalty=repetition_penalty,
127
  seed=seed,
@@ -164,6 +170,7 @@ def get_model_path(repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model
164
  if hf_token:
165
  downloaded_model_path = hf_hub_download(repo_id=repo_id, token=hf_token, filename=model_filename)
166
  else:
 
167
  downloaded_model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
168
 
169
  return downloaded_model_path
@@ -185,13 +192,16 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE, gpu_layers:int=gpu_
185
 
186
  try:
187
  print("GPU load variables:" , vars(gpu_config))
188
- llama_model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, **vars(gpu_config)) # type_k=8, type_v = 8, flash_attn=True,
189
-
190
  except Exception as e:
191
  print("GPU load failed due to:", e)
192
- llama_model = Llama(model_path=model_path, type_k=8, **vars(cpu_config)) # type_v = 8, flash_attn=True,
193
 
194
- print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU. And a maximum context length of ", gpu_config.n_ctx)
195
 
196
  # CPU mode
197
  else:
@@ -202,11 +212,14 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE, gpu_layers:int=gpu_
202
  gpu_config.update_context(max_context_length)
203
  cpu_config.update_context(max_context_length)
204
 
205
- llama_model = Llama(model_path=model_path, type_k=8, **vars(cpu_config)) # type_v = 8, flash_attn=True,
 
 
 
206
 
207
- print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU. And a maximum context length of ", gpu_config.n_ctx)
208
 
209
- tokenizer = []
210
 
211
  print("Finished loading model:", local_model_type)
212
  print("GPU layers assigned to cuda:", gpu_layers)
@@ -244,6 +257,47 @@ def call_llama_cpp_model(formatted_string:str, gen_config:str, model=model):
244
 
245
  return output
246
247
  # This function is not used in this app
248
  def llama_cpp_streaming(history, full_prompt, temperature=temperature):
249
 
@@ -392,7 +446,7 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
392
  return response
393
 
394
  # Function to send a request and update history
395
- def send_request(prompt: str, conversation_history: List[dict], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model=[], assistant_prefill = "", progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
396
  """
397
  This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
398
  It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
@@ -421,16 +475,9 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
421
  for i in progress_bar:
422
  try:
423
  print("Calling Gemini model, attempt", i + 1)
424
- #print("google_client:", google_client)
425
- #print("model_choice:", model_choice)
426
- #print("full_prompt:", full_prompt)
427
- #print("generation_config:", config)
428
 
429
  response = google_client.models.generate_content(model=model_choice, contents=full_prompt, config=config)
430
 
431
- #progress_bar.close()
432
- #tqdm._instances.clear()
433
-
434
  print("Successful call to Gemini model.")
435
  break
436
  except Exception as e:
@@ -447,9 +494,6 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
447
  print("Calling AWS Claude model, attempt", i + 1)
448
  response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice, bedrock_runtime=bedrock_runtime, assistant_prefill=assistant_prefill)
449
 
450
- #progress_bar.close()
451
- #tqdm._instances.clear()
452
-
453
  print("Successful call to Claude model.")
454
  break
455
  except Exception as e:
@@ -468,10 +512,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
468
  gen_config = LlamaCPPGenerationConfig()
469
  gen_config.update_temp(temperature)
470
 
471
- response = call_llama_cpp_model(prompt, gen_config, model=local_model)
472
-
473
- #progress_bar.close()
474
- #tqdm._instances.clear()
475
 
476
  print("Successful call to local model. Response:", response)
477
  break
@@ -492,7 +533,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
492
  if isinstance(response, ResponseObject):
493
  conversation_history.append({'role': 'assistant', 'parts': [response.text]})
494
  elif 'choices' in response:
495
- conversation_history.append({'role': 'assistant', 'parts': [response['choices'][0]['text']]})
496
  else:
497
  conversation_history.append({'role': 'assistant', 'parts': [response.text]})
498
 
@@ -501,7 +542,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
501
 
502
  return response, conversation_history
503
 
504
- def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model = [], master:bool = False, assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
505
  """
506
  Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
507
 
@@ -525,21 +566,19 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
525
  Returns:
526
  Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
527
  """
528
- responses = []
529
 
530
  # Clear any existing progress bars
531
  tqdm._instances.clear()
532
 
533
  for prompt in prompts:
534
 
535
- #print("prompt to LLM:", prompt)
536
-
537
  response, conversation_history = send_request(prompt, conversation_history, google_client=google_client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)
538
 
539
  if isinstance(response, ResponseObject):
540
  response_text = response.text
541
  elif 'choices' in response:
542
- response_text = response['choices'][0]['text']
543
  else:
544
  response_text = response.text
545
 
@@ -550,9 +589,10 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
550
 
551
  # Create conversation metadata
552
  if master == False:
553
- whole_conversation_metadata.append(f"Query batch {batch_no} prompt {len(responses)} metadata:")
554
  else:
555
- whole_conversation_metadata.append(f"Query summary metadata:")
 
556
 
557
  if not isinstance(response, str):
558
  try:
@@ -571,6 +611,9 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
571
  elif "gemini" in model_choice:
572
  whole_conversation_metadata.append(str(response.usage_metadata))
573
  else:
574
  whole_conversation_metadata.append(str(response['usage']))
575
  except KeyError as e:
576
  print(f"Key error: {e} - Check the structure of response.usage_metadata")
@@ -638,10 +681,14 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
638
  call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, master=master, assistant_prefill=assistant_prefill
639
  )
640
 
641
- if model_choice != CHOSEN_LOCAL_MODEL_TYPE:
642
- stripped_response = responses[-1].text.strip()
643
- else:
644
- stripped_response = responses[-1]['choices'][0]['text'].strip()
645
 
646
  # Check if response meets our criteria (length and contains table)
647
  if len(stripped_response) > 120 and '|' in stripped_response:
@@ -656,7 +703,7 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
656
  else: # This runs if no break occurred (all attempts failed)
657
  print(f"Failed to get valid response after {MAX_OUTPUT_VALIDATION_ATTEMPTS} attempts")
658
 
659
- return responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text
660
 
661
  def create_missing_references_df(basic_response_df: pd.DataFrame, existing_reference_df: pd.DataFrame) -> pd.DataFrame:
662
  """
 
7
  import json
8
  from tqdm import tqdm
9
  from huggingface_hub import hf_hub_download
10
+ from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
11
  from typing import List, Tuple, TypeVar
12
  from google import genai as ai
13
  from google.genai import types
 
18
 
19
  model_type = None # global variable setup
20
  full_text = "" # Define dummy source text (full text) just to enable highlight function to load
21
+ model = list() # Define empty list for model functions to run
22
+ tokenizer = list() #[] # Define empty list for model functions to run
23
 
24
+ from tools.config import RUN_AWS_FUNCTIONS, AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_MIN_P, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, RUN_LOCAL_MODEL, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_MAX_GPU_LAYERS, SPECULATIVE_DECODING, NUM_PRED_TOKENS
25
  from tools.prompts import initial_table_assistant_prefill
26
 
27
+ if str(SPECULATIVE_DECODING).lower() == "true": SPECULATIVE_DECODING = True
28
+ else: SPECULATIVE_DECODING = False
29
+
30
+ if isinstance(NUM_PRED_TOKENS, str): NUM_PRED_TOKENS = int(NUM_PRED_TOKENS)
31
+
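
A minimal sketch, not taken from this commit, of how these two settings feed the llama.cpp loader below: prompt-lookup speculative decoding simply supplies a draft model that proposes NUM_PRED_TOKENS candidate tokens per step from text already present in the prompt.

    # Illustrative only; model_path and gpu_config are the names used later in this file
    draft = LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS) if SPECULATIVE_DECODING else None
    # llama_model = Llama(model_path=model_path, draft_model=draft, **vars(gpu_config))
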
33
  # Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
34
  # Check for torch cuda
35
  print("Is CUDA enabled? ", torch.cuda.is_available())
 
61
  deduplication_threshold = DEDUPLICATION_THRESHOLD
62
  max_comment_character_length = MAX_COMMENT_CHARS
63
 
64
 
65
  if not LLM_THREADS:
66
+ threads = torch.get_num_threads()
67
  else: threads = LLM_THREADS
68
  print("CPU threads:", threads)
69
 
 
79
  temperature = LLM_TEMPERATURE
80
  top_k = LLM_TOP_K
81
  top_p = LLM_TOP_P
82
+ min_p = LLM_MIN_P
83
+ repetition_penalty = LLM_REPETITION_PENALTY
84
  last_n_tokens = LLM_LAST_N_TOKENS
85
  max_new_tokens: int = LLM_MAX_NEW_TOKENS
86
  seed: int = LLM_SEED
 
90
  batch_size:int = LLM_BATCH_SIZE
91
  context_length:int = LLM_CONTEXT_LENGTH
92
  sample = LLM_SAMPLE
93
+ speculative_decoding = SPECULATIVE_DECODING
94
 
95
  class llama_cpp_init_config_gpu:
96
  def __init__(self,
 
127
  class LlamaCPPGenerationConfig:
128
  def __init__(self, temperature=temperature,
129
  top_k=top_k,
130
+ min_p=min_p,
131
  top_p=top_p,
132
  repeat_penalty=repetition_penalty,
133
  seed=seed,
 
170
  if hf_token:
171
  downloaded_model_path = hf_hub_download(repo_id=repo_id, token=hf_token, filename=model_filename)
172
  else:
173
+ print("No HF token found, downloading model from Hugging Face Hub without token")
174
  downloaded_model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
175
 
176
  return downloaded_model_path
 
192
 
193
  try:
194
  print("GPU load variables:" , vars(gpu_config))
195
+ if speculative_decoding:
196
+ llama_model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS), **vars(gpu_config))
197
+ else:
198
+ llama_model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, **vars(gpu_config))
199
+
200
  except Exception as e:
201
  print("GPU load failed due to:", e)
202
+ llama_model = Llama(model_path=model_path, type_k=8, **vars(cpu_config))
203
 
204
+ print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", gpu_config.n_ctx)
205
 
206
  # CPU mode
207
  else:
 
212
  gpu_config.update_context(max_context_length)
213
  cpu_config.update_context(max_context_length)
214
 
215
+ if speculative_decoding:
216
+ llama_model = Llama(model_path=model_path, type_k=8, draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS), **vars(cpu_config))
217
+ else:
218
+ llama_model = Llama(model_path=model_path, type_k=8, **vars(cpu_config))
219
 
220
+ print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", cpu_config.n_ctx)
221
 
222
+ tokenizer = list()
223
 
224
  print("Finished loading model:", local_model_type)
225
  print("GPU layers assigned to cuda:", gpu_layers)
 
257
 
258
  return output
259
 
260
+ def call_llama_cpp_chatmodel(formatted_string:str, system_prompt:str, gen_config:LlamaCPPGenerationConfig, model=model):
261
+ """
262
+ Calls a llama.cpp chat model with a formatted user message and system prompt,
263
+ using generation parameters from the LlamaCPPGenerationConfig object.
264
+
265
+ Args:
266
+ formatted_string (str): The formatted input text for the user's message.
267
+ system_prompt (str): The system-level instructions for the model.
268
+ gen_config (LlamaCPPGenerationConfig): An object containing generation parameters.
269
+ model (Llama): The Llama.cpp model instance to use for chat completion.
270
+
+ Returns:
+ The chat completion result from create_chat_completion: an OpenAI-style dict when stream=False, otherwise an iterator of streamed chunks.
+ """
271
+ # Extracting parameters from the gen_config object
272
+ temperature = gen_config.temperature
273
+ top_k = gen_config.top_k
274
+ top_p = gen_config.top_p
275
+ repeat_penalty = gen_config.repeat_penalty
276
+ seed = gen_config.seed
277
+ max_tokens = gen_config.max_tokens
278
+ stream = gen_config.stream
279
+
280
+ # Call the model's chat completion method with the extracted parameters:
281
+ output = model.create_chat_completion(
282
+ messages=[
283
+ {"role": "system", "content": system_prompt},
284
+ {
285
+ "role": "user",
286
+ "content": formatted_string
287
+ }
288
+ ],
289
+ temperature=temperature,
290
+ top_k=top_k,
291
+ top_p=top_p,
292
+ repeat_penalty=repeat_penalty,
293
+ seed=seed,
294
+ max_tokens=max_tokens,
295
+ stream=stream
296
+ #stop=["<|eot_id|>", "\n\n"]
297
+ )
298
+
299
+ return output
300
+
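
A minimal usage sketch, not taken from this commit, for the new chat-completion path. It assumes local_model is the Llama instance returned by load_model and that streaming is off, so create_chat_completion returns an OpenAI-style dictionary.

    # Hypothetical example only
    gen_config = LlamaCPPGenerationConfig()   # defaults defined above
    gen_config.update_temp(0.2)               # same helper used in send_request
    result = call_llama_cpp_chatmodel("List three topics as a markdown table.", "You are a helpful analyst.", gen_config, model=local_model)
    print(result['choices'][0]['message']['content'])  # assistant text
    print(result['usage'])                             # prompt_tokens / completion_tokens / total_tokens
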
301
  # This function is not used in this app
302
  def llama_cpp_streaming(history, full_prompt, temperature=temperature):
303
 
 
446
  return response
447
 
448
  # Function to send a request and update history
449
+ def send_request(prompt: str, conversation_history: List[dict], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model= list(), assistant_prefill = "", progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
450
  """
451
  This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
452
  It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
 
475
  for i in progress_bar:
476
  try:
477
  print("Calling Gemini model, attempt", i + 1)
 
 
 
 
478
 
479
  response = google_client.models.generate_content(model=model_choice, contents=full_prompt, config=config)
480
 
 
 
 
481
  print("Successful call to Gemini model.")
482
  break
483
  except Exception as e:
 
494
  print("Calling AWS Claude model, attempt", i + 1)
495
  response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice, bedrock_runtime=bedrock_runtime, assistant_prefill=assistant_prefill)
496
 
 
 
 
497
  print("Successful call to Claude model.")
498
  break
499
  except Exception as e:
 
512
  gen_config = LlamaCPPGenerationConfig()
513
  gen_config.update_temp(temperature)
514
 
515
+ response = call_llama_cpp_chatmodel(prompt, system_prompt, gen_config, model=local_model)
 
 
 
516
 
517
  print("Successful call to local model. Response:", response)
518
  break
 
533
  if isinstance(response, ResponseObject):
534
  conversation_history.append({'role': 'assistant', 'parts': [response.text]})
535
  elif 'choices' in response:
536
+ conversation_history.append({'role': 'assistant', 'parts': [response['choices'][0]['message']['content']]}) #response['choices'][0]['text']]})
537
  else:
538
  conversation_history.append({'role': 'assistant', 'parts': [response.text]})
539
 
 
542
 
543
  return response, conversation_history
544
 
545
+ def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model = list(), master:bool = False, assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
546
  """
547
  Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
548
 
 
566
  Returns:
567
  Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
568
  """
569
+ responses = list()
570
 
571
  # Clear any existing progress bars
572
  tqdm._instances.clear()
573
 
574
  for prompt in prompts:
575
 
 
 
576
  response, conversation_history = send_request(prompt, conversation_history, google_client=google_client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)
577
 
578
  if isinstance(response, ResponseObject):
579
  response_text = response.text
580
  elif 'choices' in response:
581
+ response_text = response['choices'][0]['message']['content'] # response['choices'][0]['text']
582
  else:
583
  response_text = response.text
584
 
 
589
 
590
  # Create conversation metadata
591
  if master == False:
592
+ whole_conversation_metadata.append(f"Batch {batch_no}:")
593
  else:
594
+ #whole_conversation_metadata.append(f"Query summary metadata:")
595
+ whole_conversation_metadata.append(f"Batch {batch_no}:")
596
 
597
  if not isinstance(response, str):
598
  try:
 
611
  elif "gemini" in model_choice:
612
  whole_conversation_metadata.append(str(response.usage_metadata))
613
  else:
614
+ print("Adding usage metadata to whole conversation metadata:", response['usage'])
615
+ output_tokens = response['usage'].get('completion_tokens', 0)
616
+ input_tokens = response['usage'].get('prompt_tokens', 0)
617
  whole_conversation_metadata.append(str(response['usage']))
618
  except KeyError as e:
619
  print(f"Key error: {e} - Check the structure of response.usage_metadata")
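
For reference, a sketch (with made-up numbers, not from this commit) of the usage block those .get() calls read; llama-cpp-python chat completions report token counts in an OpenAI-style dictionary:

    usage = {"prompt_tokens": 512, "completion_tokens": 128, "total_tokens": 640}
    input_tokens = usage.get('prompt_tokens', 0)      # tokens sent to the model
    output_tokens = usage.get('completion_tokens', 0) # tokens generated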
 
681
  call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, master=master, assistant_prefill=assistant_prefill
682
  )
683
 
684
+ #if model_source != "Local":
685
+ #stripped_response = responses[-1].text.strip()
686
+ #stripped_response = response_text.strip()
687
+ #else:
688
+ #stripped_response = response['choices'][0]['message']['content'].strip()
689
+ #stripped_response = response_text.strip()
690
+
691
+ stripped_response = response_text.strip()
692
 
693
  # Check if response meets our criteria (length and contains table)
694
  if len(stripped_response) > 120 and '|' in stripped_response:
 
703
  else: # This runs if no break occurred (all attempts failed)
704
  print(f"Failed to get valid response after {MAX_OUTPUT_VALIDATION_ATTEMPTS} attempts")
705
 
706
+ return responses, conversation_history, whole_conversation, whole_conversation_metadata, stripped_response
707
 
708
  def create_missing_references_df(basic_response_df: pd.DataFrame, existing_reference_df: pd.DataFrame) -> pd.DataFrame:
709
  """
tools/verify_titles.py CHANGED
@@ -101,7 +101,7 @@ def write_llm_output_and_logs_verify(responses: List[ResponseObject],
101
  log_files_output_paths.append(whole_conversation_path_meta)
102
 
103
  if isinstance(responses[-1], ResponseObject): response_text = responses[-1].text
104
- elif "choices" in responses[-1]: response_text = responses[-1]["choices"][0]['text']
105
  else: response_text = responses[-1].text
106
 
107
  # Convert response text to a markdown table
@@ -464,13 +464,16 @@ def verify_titles(in_data_file,
464
 
465
  if isinstance(responses[-1], ResponseObject):
466
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
467
- f.write(responses[-1].text)
 
468
  elif "choices" in responses[-1]:
469
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
470
- f.write(responses[-1]["choices"][0]['text'])
 
471
  else:
472
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
473
- f.write(responses[-1].text)
 
474
 
475
  except Exception as e:
476
  print("Error in returning model response:", e)
@@ -581,15 +584,18 @@ def verify_titles(in_data_file,
581
 
582
  if isinstance(responses[-1], ResponseObject):
583
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
584
- f.write(responses[-1].text)
 
585
  unique_table_df_display_table_markdown = responses[-1].text
586
  elif "choices" in responses[-1]:
587
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
588
- f.write(responses[-1]["choices"][0]['text'])
589
- unique_table_df_display_table_markdown =responses[-1]["choices"][0]['text']
 
590
  else:
591
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
592
- f.write(responses[-1].text)
 
593
  unique_table_df_display_table_markdown = responses[-1].text
594
 
595
  log_files_output_paths.append(final_table_output_path)
 
101
  log_files_output_paths.append(whole_conversation_path_meta)
102
 
103
  if isinstance(responses[-1], ResponseObject): response_text = responses[-1].text
104
+ elif "choices" in responses[-1]: response_text = responses[-1]['choices'][0]['message']['content'] #responses[-1]["choices"][0]['text']
105
  else: response_text = responses[-1].text
106
 
107
  # Convert response text to a markdown table
 
464
 
465
  if isinstance(responses[-1], ResponseObject):
466
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
467
+ #f.write(responses[-1].text)
468
+ f.write(response_text)
469
  elif "choices" in responses[-1]:
470
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
471
+ #f.write(responses[-1]["choices"][0]['text'])
472
+ f.write(response_text)
473
  else:
474
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
475
+ #f.write(responses[-1].text)
476
+ f.write(response_text)
477
 
478
  except Exception as e:
479
  print("Error in returning model response:", e)
 
584
 
585
  if isinstance(responses[-1], ResponseObject):
586
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
587
+ #f.write(responses[-1].text)
588
+ f.write(response_text)
589
  unique_table_df_display_table_markdown = responses[-1].text
590
  elif "choices" in responses[-1]:
591
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
592
+ #f.write(responses[-1]["choices"][0]['text'])
593
+ f.write(response_text)
594
+ unique_table_df_display_table_markdown = responses[-1]["choices"][0]['message']['content'] #responses[-1]["choices"][0]['text']
595
  else:
596
  with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
597
+ #f.write(responses[-1].text)
598
+ f.write(response_text)
599
  unique_table_df_display_table_markdown = responses[-1].text
600
 
601
  log_files_output_paths.append(final_table_output_path)