Commit c61bb70 · Parent: 49faa78
Added GPT-OSS 20b support. Moved to llama-cpp-python chat_completion function.

Files changed:
- app.py (+17 -21)
- requirements_gpu.txt (+2 -2)
- tools/config.py (+32 -13)
- tools/custom_csvlogger.py (+4 -7)
- tools/dedup_summaries.py (+1 -1)
- tools/llm_api_call.py (+82 -71)
- tools/llm_funcs.py (+91 -44)
- tools/verify_titles.py (+14 -8)
app.py CHANGED

@@ -109,7 +109,7 @@ with app:
 master_unique_topics_df_revised_summaries_state = gr.Dataframe(value=pd.DataFrame(), headers=None, col_count=0, row_count = (0, "dynamic"), label="master_unique_topics_df_revised_summaries_state", visible=False, type="pandas")
 summarised_output_df = gr.Dataframe(value=pd.DataFrame(), headers=None, col_count=0, row_count = (0, "dynamic"), label="summarised_output_df", visible=False, type="pandas")
 summarised_references_markdown = gr.Markdown("", visible=False)
-summarised_outputs_list = gr.Dropdown(value=
+summarised_outputs_list = gr.Dropdown(value= list(), choices= list(), visible=False, label="List of summarised outputs", allow_custom_value=True)
 latest_summary_completed_num = gr.Number(0, visible=False)

 original_data_file_name_textbox = gr.Textbox(label = "Reference data file name", value="", visible=False)

@@ -147,7 +147,7 @@ with app:
 gr.Markdown("""### Choose a tabular data file (xlsx or csv) of open text to extract topics from.""")
 with gr.Row():
 model_choice = gr.Dropdown(value = default_model_choice, choices = model_full_names, label="LLM model", multiselect=False)
-
+

 with gr.Accordion("Upload xlsx or csv file", open = True):
 in_data_files = gr.File(height=FILE_INPUT_HEIGHT, label="Choose Excel or csv files", file_count= "multiple", file_types=['.xlsx', '.xls', '.csv', '.parquet', '.csv.gz'])

@@ -308,6 +308,9 @@ with app:
 aws_access_key_textbox = gr.Textbox(label="AWS access key", interactive=False, lines=1, type="password")
 aws_secret_key_textbox = gr.Textbox(label="AWS secret key", interactive=False, lines=1, type="password")

+with gr.Accordion("Enter Gemini API keys", open = False):
+    google_api_key_textbox = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")
+
 # Invisible text box to hold the session hash/username just for logging purposes
 session_hash_textbox = gr.Textbox(label = "Session hash", value="", visible=False)

@@ -315,19 +318,6 @@ with app:
 total_number_of_batches = gr.Number(label = "Current batch number", value = 1, precision=0, visible=False)

 text_output_logs = gr.Textbox(label = "Output summary logs", visible=False)
-
-# AWS options - not yet implemented
-# with gr.Tab(label="Advanced options"):
-# with gr.Accordion(label = "AWS data access", open = True):
-# aws_password_box = gr.Textbox(label="Password for AWS data access (ask the Data team if you don't have this)")
-# with gr.Row():
-# in_aws_file = gr.Dropdown(label="Choose file to load from AWS (only valid for API Gateway app)", choices=["None", "Lambeth borough plan"])
-# load_aws_data_button = gr.Button(value="Load data from AWS", variant="secondary")
-
-# aws_log_box = gr.Textbox(label="AWS data load status")
-
-# ### Loading AWS data ###
-# load_aws_data_button.click(fn=load_data_from_aws, inputs=[in_aws_file, aws_password_box], outputs=[in_file, aws_log_box])

 ###
 # INTERACTIVE ELEMENT FUNCTIONS

@@ -364,7 +354,7 @@ with app:
 display_topic_table_markdown,
 original_data_file_name_textbox,
 total_number_of_batches,
-
+google_api_key_textbox,
 temperature_slide,
 in_colnames,
 model_choice,

@@ -411,7 +401,7 @@ with app:
 output_tokens_num,
 number_of_calls_num],
 api_name="extract_topics")
-
+
 ###
 # DEDUPLICATION AND SUMMARISATION FUNCTIONS
 ###

@@ -430,14 +420,14 @@ with app:
 success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
 success(load_in_previous_data_files, inputs=[summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
 success(sample_reference_table_summaries, inputs=[master_reference_df_state, random_seed], outputs=[summary_reference_table_sample_state, summarised_references_markdown], api_name="sample_summaries").\
-success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice,
+success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox, aws_access_key_textbox, aws_secret_key_textbox], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number], api_name="summarise_topics")

-# latest_summary_completed_num.change(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice,
+# latest_summary_completed_num.change(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num], scroll_to_output=True)

 # SUMMARISE WHOLE TABLE PAGE
 overall_summarise_previous_data_btn.click(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
 success(load_in_previous_data_files, inputs=[overall_summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
-success(overall_summary, inputs=[master_unique_topics_df_state, model_choice,
+success(overall_summary, inputs=[master_unique_topics_df_state, model_choice, google_api_key_textbox, temperature_slide, unique_topics_table_file_name_textbox, output_folder_state, in_colnames, context_textbox, aws_access_key_textbox, aws_secret_key_textbox], outputs=[overall_summary_output_files, overall_summarised_output_markdown, summarised_output_df, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number], scroll_to_output=True, api_name="overall_summary")

 ###
 # CONTINUE PREVIOUS TOPIC EXTRACTION PAGE

@@ -512,7 +502,13 @@ with app:
 usage_callback.setup([session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox, input_tokens_num,
 output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], USAGE_LOGS_FOLDER)

-
+def conversation_metadata_textbox_change(textbox_value):
+    print("conversation_metadata_textbox_change:", textbox_value)
+    return textbox_value
+
+number_of_calls_num.change(conversation_metadata_textbox_change, inputs=[conversation_metadata_textbox], outputs=[conversation_metadata_textbox])
+
+number_of_calls_num.change(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False, api_name="usage_logs").\
 success(fn = upload_file_to_s3, inputs=[usage_logs_state, usage_s3_logs_loc_state, s3_log_bucket_name, aws_access_key_textbox, aws_secret_key_textbox], outputs=[s3_logs_output_textbox])

 # User submitted feedback
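Note on the wiring pattern: app.py chains handlers with .click().success(), so each downstream step (loading previous files, sampling, summarising) only runs if the previous handler finished without raising. Below is a minimal, self-contained sketch of that pattern, with illustrative component and function names rather than the app's own:

import gradio as gr

def check_cost_code(code: str) -> str:
    # Gatekeeper step: raising stops the chain, returning lets it continue.
    if not code:
        raise gr.Error("Enter a cost code before running.")
    return code

def run_summarisation(code: str) -> str:
    return f"Summarisation started under cost code {code}"

with gr.Blocks() as demo:
    cost_code = gr.Textbox(label="Cost code")
    status = gr.Textbox(label="Status")
    run_btn = gr.Button("Run")
    # The second handler fires only if the first completed successfully.
    run_btn.click(check_cost_code, inputs=cost_code, outputs=cost_code).\
        success(run_summarisation, inputs=cost_code, outputs=status)

if __name__ == "__main__":
    demo.launch()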
requirements_gpu.txt CHANGED

@@ -1,5 +1,5 @@
 pandas==2.3.1
-gradio==5.
+gradio==5.44.0
 transformers==4.55.2
 spaces==0.40.0
 boto3==1.40.11

@@ -17,7 +17,7 @@ python-dotenv==1.1.0
 # Torch and Llama CPP Python
 torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124 # Latest compatible with CUDA 12.4
 # For Linux:
-
+https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl
 # For Windows:
 #llama-cpp-python==0.3.16 -C cmake.args="-DGGML_CUDA=on" --verbose
 # If above doesn't work for Windows, try looking at'windows_install_llama-cpp-python.txt'
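The Linux line now points at a prebuilt CUDA 12.4 wheel of llama-cpp-python 0.3.16 (CPython 3.11, x86_64) instead of a source build. One way to sanity-check that the installed wheel actually supports GPU offload is the small llama-cpp-python binding below; this is a general check, not something added by this commit:

# Sanity check that llama-cpp-python was installed with GPU offload enabled.
from llama_cpp import llama_supports_gpu_offload

print("GPU offload available:", bool(llama_supports_gpu_offload()))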
tools/config.py CHANGED

@@ -222,7 +222,7 @@ model_full_names = []
 model_short_names = []
 model_source = []

-CHOSEN_LOCAL_MODEL_TYPE = get_or_create_env_var("CHOSEN_LOCAL_MODEL_TYPE", "
+CHOSEN_LOCAL_MODEL_TYPE = get_or_create_env_var("CHOSEN_LOCAL_MODEL_TYPE", "gpt-oss-20b") # Gemma 3 1B # "Gemma 2b" # "Gemma 3 4B"

 if RUN_LOCAL_MODEL == "1" and CHOSEN_LOCAL_MODEL_TYPE:
 model_full_names.append(CHOSEN_LOCAL_MODEL_TYPE)

@@ -252,18 +252,28 @@ model_name_map = {
 # HF token may or may not be needed for downloading models from Hugging Face
 HF_TOKEN = get_or_create_env_var('HF_TOKEN', '')

-GEMMA2_REPO_ID = get_or_create_env_var("GEMMA2_2B_REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF")
-GEMMA2_MODEL_FILE = get_or_create_env_var("GEMMA2_2B_MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf")
-GEMMA2_MODEL_FOLDER = get_or_create_env_var("GEMMA2_2B_MODEL_FOLDER", "model/gemma")
+GEMMA2_REPO_ID = get_or_create_env_var("GEMMA2_2B_REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF")
+GEMMA2_MODEL_FILE = get_or_create_env_var("GEMMA2_2B_MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf")
+GEMMA2_MODEL_FOLDER = get_or_create_env_var("GEMMA2_2B_MODEL_FOLDER", "model/gemma")

-GEMMA3_REPO_ID = get_or_create_env_var("GEMMA3_REPO_ID", "
-GEMMA3_MODEL_FILE = get_or_create_env_var("GEMMA3_MODEL_FILE", "gemma-3-
+GEMMA3_REPO_ID = get_or_create_env_var("GEMMA3_REPO_ID", "unsloth/gemma-3-270m-it-qat-GGUF")
+GEMMA3_MODEL_FILE = get_or_create_env_var("GEMMA3_MODEL_FILE", "gemma-3-270m-it-qat-F16.gguf")
 GEMMA3_MODEL_FOLDER = get_or_create_env_var("GEMMA3_MODEL_FOLDER", "model/gemma")

-GEMMA3_4B_REPO_ID = get_or_create_env_var("GEMMA3_4B_REPO_ID", "
-GEMMA3_4B_MODEL_FILE = get_or_create_env_var("GEMMA3_4B_MODEL_FILE", "gemma-3-4b-it-
+GEMMA3_4B_REPO_ID = get_or_create_env_var("GEMMA3_4B_REPO_ID", "unsloth/gemma-3-4b-it-qat-GGUF")
+GEMMA3_4B_MODEL_FILE = get_or_create_env_var("GEMMA3_4B_MODEL_FILE", "gemma-3-4b-it-qat-Q4_K_M.gguf")
 GEMMA3_4B_MODEL_FOLDER = get_or_create_env_var("GEMMA3_4B_MODEL_FOLDER", "model/gemma3_4b")

+GPT_OSS_REPO_ID = get_or_create_env_var("GPT_OSS_REPO_ID", "unsloth/gpt-oss-20b-GGUF")
+GPT_OSS_MODEL_FILE = get_or_create_env_var("GPT_OSS_MODEL_FILE", "gpt-oss-20b-F16.gguf")
+GPT_OSS_MODEL_FOLDER = get_or_create_env_var("GPT_OSS_MODEL_FOLDER", "model/gpt_oss")
+
+USE_SPECULATIVE_DECODING = get_or_create_env_var("USE_SPECULATIVE_DECODING", "False")
+
+GEMMA3_DRAFT_MODEL_LOC = get_or_create_env_var("GEMMA3_DRAFT_MODEL_LOC", ".cache/llama.cpp/unsloth_gemma-3-270m-it-qat-GGUF_gemma-3-270m-it-qat-F16.gguf")
+
+GEMMA3_4B_DRAFT_MODEL_LOC = get_or_create_env_var("GEMMA3_4B_DRAFT_MODEL_LOC", ".cache/llama.cpp/unsloth_gemma-3-4b-it-qat-GGUF_gemma-3-4b-it-qat-Q4_K_M.gguf")
+
 if CHOSEN_LOCAL_MODEL_TYPE == "Gemma 2b":
 LOCAL_REPO_ID = GEMMA2_REPO_ID
 LOCAL_MODEL_FILE = GEMMA2_MODEL_FILE

@@ -280,25 +290,34 @@ elif CHOSEN_LOCAL_MODEL_TYPE == "Gemma 3 4B":
 LOCAL_MODEL_FILE = GEMMA3_4B_MODEL_FILE
 LOCAL_MODEL_FOLDER = GEMMA3_4B_MODEL_FOLDER

+elif CHOSEN_LOCAL_MODEL_TYPE == "gpt-oss-20b":
+    LOCAL_REPO_ID = GPT_OSS_REPO_ID
+    LOCAL_MODEL_FILE = GPT_OSS_MODEL_FILE
+    LOCAL_MODEL_FOLDER = GPT_OSS_MODEL_FOLDER
+
 print("CHOSEN_LOCAL_MODEL_TYPE:", CHOSEN_LOCAL_MODEL_TYPE)
 print("LOCAL_REPO_ID:", LOCAL_REPO_ID)
 print("LOCAL_MODEL_FILE:", LOCAL_MODEL_FILE)
 print("LOCAL_MODEL_FOLDER:", LOCAL_MODEL_FOLDER)

-LLM_MAX_GPU_LAYERS = int(get_or_create_env_var('LLM_MAX_GPU_LAYERS','-1'))
+LLM_MAX_GPU_LAYERS = int(get_or_create_env_var('LLM_MAX_GPU_LAYERS','-1')) # Maximum possible
 LLM_TEMPERATURE = float(get_or_create_env_var('LLM_TEMPERATURE', '0.1'))
-LLM_TOP_K = int(get_or_create_env_var('LLM_TOP_K','
-
-
+LLM_TOP_K = int(get_or_create_env_var('LLM_TOP_K','96'))
+LLM_MIN_P = float(get_or_create_env_var('LLM_MIN_P', '0'))
+LLM_TOP_P = float(get_or_create_env_var('LLM_TOP_P', '0.95'))
+LLM_REPETITION_PENALTY = float(get_or_create_env_var('LLM_REPETITION_PENALTY', '1.0'))
 LLM_LAST_N_TOKENS = int(get_or_create_env_var('LLM_LAST_N_TOKENS', '512'))
 LLM_MAX_NEW_TOKENS = int(get_or_create_env_var('LLM_MAX_NEW_TOKENS', '4096'))
 LLM_SEED = int(get_or_create_env_var('LLM_SEED', '42'))
 LLM_RESET = get_or_create_env_var('LLM_RESET', 'True')
 LLM_STREAM = get_or_create_env_var('LLM_STREAM', 'False')
 LLM_THREADS = int(get_or_create_env_var('LLM_THREADS', '4'))
-LLM_BATCH_SIZE = int(get_or_create_env_var('LLM_BATCH_SIZE', '
+LLM_BATCH_SIZE = int(get_or_create_env_var('LLM_BATCH_SIZE', '128'))
 LLM_CONTEXT_LENGTH = int(get_or_create_env_var('LLM_CONTEXT_LENGTH', '16384'))
 LLM_SAMPLE = get_or_create_env_var('LLM_SAMPLE', 'True')
+SPECULATIVE_DECODING = get_or_create_env_var('SPECULATIVE_DECODING', 'False')
+NUM_PRED_TOKENS = int(get_or_create_env_var('NUM_PRED_TOKENS', '2'))
+REASONING_SUFFIX = get_or_create_env_var('REASONING_SUFFIX', 'Reasoning: low')

 MAX_GROUPS = int(get_or_create_env_var('MAX_GROUPS', '99'))
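These settings come together when the local GGUF model is constructed. The actual loader lives in tools/llm_funcs.py (not shown in this diff), so the following is only a minimal sketch of how gpt-oss-20b could be loaded with llama-cpp-python from the defaults above; the function name is illustrative, and the prompt-lookup draft model is one possible reading of NUM_PRED_TOKENS, not a statement of what the app does:

from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

def load_local_model_sketch() -> Llama:
    # Values mirror the env-var defaults added above; in the app they come from get_or_create_env_var.
    repo_id = "unsloth/gpt-oss-20b-GGUF"   # GPT_OSS_REPO_ID
    model_file = "gpt-oss-20b-F16.gguf"    # GPT_OSS_MODEL_FILE

    # Optional speculative decoding via prompt lookup (NUM_PRED_TOKENS);
    # the separate *_DRAFT_MODEL_LOC paths suggest other draft setups handled elsewhere.
    draft_model = LlamaPromptLookupDecoding(num_pred_tokens=2)

    return Llama.from_pretrained(
        repo_id=repo_id,
        filename=model_file,
        n_gpu_layers=-1,      # LLM_MAX_GPU_LAYERS: offload as many layers as possible
        n_ctx=16384,          # LLM_CONTEXT_LENGTH
        n_batch=128,          # LLM_BATCH_SIZE
        n_threads=4,          # LLM_THREADS
        seed=42,              # LLM_SEED
        draft_model=draft_model,
    )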
tools/custom_csvlogger.py CHANGED

@@ -106,7 +106,7 @@ class CSVLogger_custom(FlaggingCallback):
 )
 latest_num = int(re.findall(r"\d+", latest_file.stem)[0])

-with open(latest_file, newline="", encoding="utf-8") as csvfile:
+with open(latest_file, newline="", encoding="utf-8-sig") as csvfile:
 reader = csv.reader(csvfile)
 existing_headers = next(reader, None)

@@ -122,7 +122,7 @@ class CSVLogger_custom(FlaggingCallback):

 if not Path(self.dataset_filepath).exists():
 with open(
-self.dataset_filepath, "w", newline="", encoding="utf-8"
+self.dataset_filepath, "w", newline="", encoding="utf-8-sig"
 ) as csvfile:
 writer = csv.writer(csvfile)
 writer.writerow(utils.sanitize_list_for_csv(headers))

@@ -202,15 +202,12 @@ class CSVLogger_custom(FlaggingCallback):

 if save_to_csv:
 with self.lock:
-with open(self.dataset_filepath, "a", newline="", encoding="utf-8") as csvfile:
+with open(self.dataset_filepath, "a", newline="", encoding="utf-8-sig") as csvfile:
 writer = csv.writer(csvfile)
 writer.writerow(utils.sanitize_list_for_csv(csv_data))
-with open(self.dataset_filepath, encoding="utf-8") as csvfile:
+with open(self.dataset_filepath, encoding="utf-8-sig") as csvfile:
 line_count = len(list(csv.reader(csvfile))) - 1

-
-print("save_to_dynamodb:", save_to_dynamodb)
-print("save_to_dynamodb == True:", save_to_dynamodb == True)
 if save_to_dynamodb == True:
 print("Saving to DynamoDB")
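Besides dropping two debug prints, the only functional change here is switching the log CSVs from "utf-8" to "utf-8-sig". The "-sig" codec writes a byte-order mark on write and strips it on read, which mainly matters because spreadsheet tools such as Excel then detect the encoding correctly. A minimal illustration (the file name is just for the example):

import csv

# "utf-8-sig" prepends a BOM on write and strips it on read, so non-ASCII
# characters in the usage log survive a round trip through Excel.
with open("example_log.csv", "w", newline="", encoding="utf-8-sig") as f:
    csv.writer(f).writerow(["session_hash", "file_name", "café.xlsx"])

with open("example_log.csv", newline="", encoding="utf-8-sig") as f:
    print(next(csv.reader(f)))  # ['session_hash', 'file_name', 'café.xlsx']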
tools/dedup_summaries.py CHANGED

@@ -436,7 +436,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
 if isinstance(responses[-1], ResponseObject):
 response_texts = [resp.text for resp in responses]
 elif "choices" in responses[-1]:
-response_texts = [resp["choices"][0]['text'] for resp in responses]
+response_texts = [resp["choices"][0]['message']['content'] for resp in responses] #resp["choices"][0]['text'] for resp in responses]
 else:
 response_texts = [resp.text for resp in responses]
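This one-line change follows from the move to llama-cpp-python's chat_completion path: a text-completion response exposes its output at choices[0]["text"], while a chat-completion response nests it at choices[0]["message"]["content"]. A minimal sketch of the two shapes (the response dicts below are hand-written examples, not captured output):

# Text-completion style response (old parsing path):
completion_response = {"choices": [{"text": "| General topic | Subtopic | ... |"}]}
text_old = completion_response["choices"][0]["text"]

# Chat-completion style response (new parsing path):
chat_response = {"choices": [{"message": {"role": "assistant",
                                          "content": "| General topic | Subtopic | ... |"}}]}
text_new = chat_response["choices"][0]["message"]["content"]

assert text_old == text_new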
tools/llm_api_call.py CHANGED

@@ -17,7 +17,7 @@ GradioFileData = gr.FileData
 from tools.prompts import initial_table_prompt, prompt2, prompt3, initial_table_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt, add_existing_topics_assistant_prefill, initial_table_assistant_prefill, structured_summary_prompt
 from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files
 from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata
-from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, RUN_AWS_FUNCTIONS, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS
+from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, RUN_AWS_FUNCTIONS, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX
 from tools.aws_functions import connect_to_bedrock_runtime

 if RUN_LOCAL_MODEL == "1":

@@ -31,11 +31,12 @@ batch_size_default = BATCH_SIZE_DEFAULT
 deduplication_threshold = DEDUPLICATION_THRESHOLD
 max_comment_character_length = MAX_COMMENT_CHARS
 random_seed = LLM_SEED
+reasoning_suffix = REASONING_SUFFIX

 # if RUN_AWS_FUNCTIONS == '1':
 # bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
 # else:
-# bedrock_runtime =
+# bedrock_runtime = list()

@@ -130,7 +131,7 @@ def clean_markdown_table(text: str):
 lines = text.splitlines()

 # Step 1: Identify table structure and process line continuations
-table_rows =
+table_rows = list()
 current_row = None

 for line in lines:

@@ -174,7 +175,7 @@ def clean_markdown_table(text: str):
 max_columns = max(max_columns, len(cells))

 # Now format each row
-formatted_rows =
+formatted_rows = list()
 for row in table_rows:
 # Ensure the row starts and ends with pipes
 if not row.startswith('|'):

@@ -354,7 +355,7 @@ def write_llm_output_and_logs(responses: List[ResponseObject],
 - first_run (bool): A boolean indicating if this is the first run through this function in this process. Defaults to False.
 - output_folder (str): The name of the folder where output files are saved.
 """
-topic_summary_df_out_path =
+topic_summary_df_out_path = list()
 topic_table_out_path = "topic_table_error.csv"
 reference_table_out_path = "reference_table_error.csv"
 topic_summary_df_out_path = "unique_topic_table_error.csv"

@@ -390,7 +391,7 @@ def write_llm_output_and_logs(responses: List[ResponseObject],
 log_files_output_paths.append(whole_conversation_path_meta)

 if isinstance(responses[-1], ResponseObject): response_text = responses[-1].text
-elif "choices" in responses[-1]: response_text = responses[-1]["choices"][0]['text']
+elif "choices" in responses[-1]: response_text = responses[-1]['choices'][0]['message']['content'] #responses[-1]["choices"][0]['text']
 else: response_text = responses[-1].text

 # Convert response text to a markdown table

@@ -426,7 +427,7 @@ def write_llm_output_and_logs(responses: List[ResponseObject],
 topic_table_out_path = output_folder + batch_file_path_details + "_topic_table_" + model_choice_clean + ".csv"

 # Table to map references to topics
-reference_data =
+reference_data = list()

 batch_basic_response_df["Reference"] = batch_basic_response_df["Reference"].astype(str)

@@ -614,13 +615,7 @@ def generate_zero_shot_topics_df(zero_shot_topics:pd.DataFrame,
 if force_zero_shot_radio == "Yes":
 zero_shot_topics_gen_topics_list.append("")
 zero_shot_topics_subtopics_list.append("No relevant topic")
-zero_shot_topics_description_list.append("")
-
-# This process was abandoned (revising the general topics) as it didn't seem to work
-# if create_revised_general_topics == True:
-# pass
-# else:
-# pass
+zero_shot_topics_description_list.append("")

 # Add description or not
 zero_shot_topics_df = pd.DataFrame(data={

@@ -646,9 +641,9 @@ def extract_topics(in_data_file: GradioFileData,
 model_choice:str,
 candidate_topics: GradioFileData = None,
 latest_batch_completed:int=0,
-out_message:List=
-out_file_paths:List =
-log_files_output_paths:List =
+out_message:List= list(),
+out_file_paths:List = list(),
+log_files_output_paths:List = list(),
 first_loop_state:bool=False,
 whole_conversation_metadata_str:str="",
 initial_table_prompt:str=initial_table_prompt,

@@ -663,7 +658,7 @@ def extract_topics(in_data_file: GradioFileData,
 time_taken:float = 0,
 sentiment_checkbox:str = "Negative, Neutral, or Positive",
 force_zero_shot_radio:str = "No",
-in_excel_sheets:List[str] =
+in_excel_sheets:List[str] = list(),
 force_single_topic_radio:str = "No",
 output_folder:str=OUTPUT_FOLDER,
 force_single_topic_prompt:str=force_single_topic_prompt,

@@ -675,6 +670,7 @@ def extract_topics(in_data_file: GradioFileData,
 model_name_map:dict=model_name_map,
 max_time_for_loop:int=max_time_for_loop,
 CHOSEN_LOCAL_MODEL_TYPE:str=CHOSEN_LOCAL_MODEL_TYPE,
+reasoning_suffix:str=reasoning_suffix,
 progress=Progress(track_tqdm=True)):

 '''

@@ -723,31 +719,36 @@ def extract_topics(in_data_file: GradioFileData,
 - model_name_map (dict, optional): A dictionary mapping full model name to shortened.
 - max_time_for_loop (int, optional): The number of seconds maximum that the function should run for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there).
 - CHOSEN_LOCAL_MODEL_TYPE (str, optional): The name of the chosen local model.
+- reasoning_suffix (str, optional): The suffix for the reasoning system prompt.
 - progress (Progress): A progress tracker.
 '''

 tic = time.perf_counter()
-google_client =
+google_client = list()
 google_config = {}
 final_time = 0.0
-whole_conversation_metadata =
+whole_conversation_metadata = list()
 is_error = False
 create_revised_general_topics = False
-local_model =
-tokenizer =
+local_model = list()
+tokenizer = list()
 zero_shot_topics_df = pd.DataFrame()
 missing_df = pd.DataFrame()
 new_reference_df = pd.DataFrame(columns=["Response References", "General topic", "Subtopic", "Sentiment", "Start row of group", "Group" ,"Topic_number", "Summary"])
 new_topic_summary_df = pd.DataFrame(columns=["General topic","Subtopic","Sentiment","Group","Number of responses","Summary"])
 new_topic_df = pd.DataFrame()
-
-#
-#llama_cpp_prefix = "
-#llama_cpp_suffix = "
-
-#
-llama_cpp_prefix = "
-llama_cpp_suffix = "
+
+# For Gemma models
+#llama_cpp_prefix = "<start_of_turn>user\n"
+#llama_cpp_suffix = "<end_of_turn>\n<start_of_turn>model\n"
+
+# For GPT OSS
+#llama_cpp_prefix = "<|start|>assistant<|channel|>analysis<|message|>\n"
+#llama_cpp_suffix = "<|start|>assistant<|channel|>final<|message|>"
+
+# Blank
+llama_cpp_prefix = ""
+llama_cpp_suffix = ""

 #print("output_folder:", output_folder)

@@ -773,8 +774,8 @@ def extract_topics(in_data_file: GradioFileData,
 print("This is the first time through the loop, resetting latest_batch_completed to 0")
 if (latest_batch_completed == 999) | (latest_batch_completed == 0):
 latest_batch_completed = 0
-out_message =
-out_file_paths =
+out_message = list()
+out_file_paths = list()
 final_time = 0

 if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1"):

@@ -795,7 +796,7 @@ def extract_topics(in_data_file: GradioFileData,
 out_message = [out_message]

 if not out_file_paths:
-out_file_paths =
+out_file_paths = list()

 if "anthropic.claude-3-sonnet" in model_choice and file_data.shape[1] > 300:

@@ -819,7 +820,7 @@ def extract_topics(in_data_file: GradioFileData,
 simplified_csv_table_path, normalised_simple_markdown_table, start_row, end_row, batch_basic_response_df = data_file_to_markdown_table(file_data, file_name, chosen_cols, latest_batch_completed, batch_size)

 # Conversation history
-conversation_history =
+conversation_history = list()

 # If the latest batch of responses contains at least one instance of text
 if not batch_basic_response_df.empty:

@@ -933,13 +934,15 @@ def extract_topics(in_data_file: GradioFileData,
 except Exception as e:
 print(f"Error writing prompt to file {formatted_prompt_output_path}: {e}")

-if "
-
-else:
-
+#if "Local" in model_source:
+# summary_prompt_list = [full_prompt] # Includes system prompt
+#else:
+summary_prompt_list = [formatted_summary_prompt]
+
+if "Local" in model_source and reasoning_suffix: formatted_system_prompt = formatted_system_prompt + "\n" + reasoning_suffix

-conversation_history =
-whole_conversation =
+conversation_history = list()
+whole_conversation = list()

 # Process requests to large language model
 responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)

@@ -953,13 +956,16 @@ def extract_topics(in_data_file: GradioFileData,

 if isinstance(responses[-1], ResponseObject):
 with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
-f.write(responses[-1].text)
+#f.write(responses[-1].text)
+f.write(response_text)
 elif "choices" in responses[-1]:
 with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
-f.write(responses[-1]["choices"][0]['text'])
+#f.write(responses[-1]["choices"][0]['text'])
+f.write(response_text)
 else:
 with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
-f.write(responses[-1].text)
+#f.write(responses[-1].text)
+f.write(response_text)

 except Exception as e:
 print("Error in returning model response:", e)

@@ -1020,26 +1026,23 @@ def extract_topics(in_data_file: GradioFileData,
 if prompt3: formatted_prompt3 = prompt3.format(response_table=normalised_simple_markdown_table, sentiment_choices=sentiment_prompt)
 else: formatted_prompt3 = prompt3

-if "
-
-
-
+#if "Local" in model_source:
+#formatted_initial_table_prompt = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_initial_table_prompt + llama_cpp_suffix
+#formatted_prompt2 = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_prompt2 + llama_cpp_suffix
+#formatted_prompt3 = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_prompt3 + llama_cpp_suffix
+
+batch_prompts = [formatted_initial_table_prompt, formatted_prompt2, formatted_prompt3][:number_of_prompts_used] # Adjust this list to send fewer requests

-
+if "Local" in model_source and reasoning_suffix: formatted_initial_table_system_prompt = formatted_initial_table_system_prompt + "\n" + reasoning_suffix

-whole_conversation =
+whole_conversation = list()

 responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_initial_table_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)

 topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, topic_table_df, reference_df, new_topic_summary_df, batch_file_path_details, is_error = write_llm_output_and_logs(responses, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, group_name, produce_structures_summary_radio, first_run=True, output_folder=output_folder)

 # If error in table parsing, leave function
-if is_error == True:
-raise Exception("Error in output table parsing")
-# unique_table_df_display_table_markdown, new_topic_df, new_topic_summary_df, new_reference_df, out_file_paths, out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, out_file_paths#, final_message_out
-
-
-#all_topic_tables_df.append(topic_table_df)
+if is_error == True: raise Exception("Error in output table parsing")

 topic_table_df.to_csv(topic_table_out_path, index=None)
 out_file_paths.append(topic_table_out_path)

@@ -1056,28 +1059,28 @@ def extract_topics(in_data_file: GradioFileData,
 new_topic_summary_df.to_csv(topic_summary_df_out_path, index=None)
 out_file_paths.append(topic_summary_df_out_path)

-#all_markdown_topic_tables.append(markdown_table)

 whole_conversation_metadata.append(whole_conversation_metadata_str)
 whole_conversation_metadata_str = '. '.join(whole_conversation_metadata)

-# Write final output to text file also
 # Write final output to text file for logging purposes
 try:
 final_table_output_path = output_folder + batch_file_path_details + "_full_response_" + model_choice_clean + ".txt"

 if isinstance(responses[-1], ResponseObject):
 with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
-f.write(responses[-1].text)
+#f.write(responses[-1].text)
+f.write(response_text)
 elif "choices" in responses[-1]:
 with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
-f.write(responses[-1]["choices"][0]['text'])
+#f.write(responses[-1]["choices"][0]['text'])
+f.write(response_text)
 else:
 with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
-f.write(responses[-1].text)
+#f.write(responses[-1].text)
+f.write(response_text)

-except Exception as e:
-print("Error in returning model response:", e)
+except Exception as e: print("Error in returning model response:", e)

 new_topic_df = topic_table_df
 new_reference_df = reference_df

@@ -1122,7 +1125,7 @@ def extract_topics(in_data_file: GradioFileData,
 # Set to a very high number so as not to mess with subsequent file processing by the user
 #latest_batch_completed = 999

-join_file_paths =
+join_file_paths = list()

 toc = time.perf_counter()
 final_time = (toc - tic) + time_taken

@@ -1198,8 +1201,12 @@ def extract_topics(in_data_file: GradioFileData,

 print("latest_batch_completed at end of batch iterations to return is", latest_batch_completed)

+print("whole_conversation_metadata_str at end of batch iterations to return is", whole_conversation_metadata_str)
+
 return unique_table_df_display_table_markdown, existing_topics_table, final_out_topic_summary_df, existing_reference_df, final_out_file_paths, final_out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, final_out_file_paths, final_out_file_paths, modifiable_topic_summary_df, final_out_file_paths, join_file_paths, existing_reference_df_pivot, missing_df

+print("whole_conversation_metadata_str at end of batch iterations to return is", whole_conversation_metadata_str)
+
 return unique_table_df_display_table_markdown, existing_topics_table, existing_topic_summary_df, existing_reference_df, out_file_paths, out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, out_file_paths, out_file_paths, modifiable_topic_summary_df, out_file_paths, join_file_paths, existing_reference_df_pivot, missing_df # gr.Dataframe(value=modifiable_topic_summary_df, headers=None, col_count=(modifiable_topic_summary_df.shape[1], "fixed"), row_count = (modifiable_topic_summary_df.shape[0], "fixed"), visible=True, type="pandas"),

 def wrapper_extract_topics_per_column_value(

@@ -1237,7 +1244,7 @@ def wrapper_extract_topics_per_column_value(
 context_textbox: str = "",
 sentiment_checkbox: str = "Negative, Neutral, or Positive",
 force_zero_shot_radio: str = "No",
-in_excel_sheets: List[str] =
+in_excel_sheets: List[str] = list(),
 force_single_topic_radio: str = "No",
 produce_structures_summary_radio: str = "No",
 aws_access_key_textbox:str="",

@@ -1279,9 +1286,9 @@ def wrapper_extract_topics_per_column_value(
 acc_missing_df = pd.DataFrame()

 # Lists are extended
-acc_out_file_paths =
-acc_log_files_output_paths =
-acc_join_file_paths =
+acc_out_file_paths = list()
+acc_log_files_output_paths = list()
+acc_join_file_paths = list() # join_file_paths seems to be overwritten, so maybe last one or extend? Let's extend.

 # Single value outputs - typically the last one is most relevant, or sum for time
 acc_markdown_output = initial_unique_table_df_display_table_markdown

@@ -1353,9 +1360,9 @@ def wrapper_extract_topics_per_column_value(
 num_batches=current_num_batches,
 latest_batch_completed=current_latest_batch_completed, # Reset for each new segment's internal batching
 first_loop_state=current_first_loop_state, # True only for the very first iteration of wrapper
-out_message=
-out_file_paths=
-log_files_output_paths=
+out_message= list(), # Fresh for each call
+out_file_paths= list(), # Fresh for each call
+log_files_output_paths= list(), # Fresh for each call
 whole_conversation_metadata_str="", # Fresh for each call
 time_taken=0, # Time taken for this specific call, wrapper sums it.
 # Pass through other parameters

@@ -1450,8 +1457,12 @@ def wrapper_extract_topics_per_column_value(
 unique_table_df_display_table = acc_topic_summary_df.apply(lambda col: col.map(lambda x: wrap_text(x, max_text_length=500)))
 acc_markdown_output = unique_table_df_display_table[["General topic", "Subtopic", "Sentiment", "Number of responses", "Summary", "Group"]].to_markdown(index=False)

+print("acc_whole_conversation_metadata at end of wrapper is", acc_whole_conversation_metadata)
+
 acc_input_tokens, acc_output_tokens, acc_number_of_calls = calculate_tokens_from_metadata(acc_whole_conversation_metadata, model_choice, model_name_map)

+print("acc_input_tokens, acc_output_tokens, acc_number_of_calls at end of wrapper is", acc_input_tokens, acc_output_tokens, acc_number_of_calls)
+
 print(f"\nWrapper finished processing all segments. Total time: {acc_total_time_taken:.2f}s")

@@ -1537,7 +1548,7 @@ def modify_existing_output_tables(original_topic_summary_df:pd.DataFrame, modifi
 reference_file_path = os.path.basename(reference_files[0]) if reference_files else None
 unique_table_file_path = os.path.basename(unique_files[0]) if unique_files else None

-output_file_list =

 if reference_file_path and unique_table_file_path:
|
1467 |
|
1468 |
# The return signature should match extract_topics.
|
|
|
1548 |
reference_file_path = os.path.basename(reference_files[0]) if reference_files else None
|
1549 |
unique_table_file_path = os.path.basename(unique_files[0]) if unique_files else None
|
1550 |
|
1551 |
+
output_file_list = list()
|
1552 |
|
1553 |
if reference_file_path and unique_table_file_path:
|
1554 |
|
tools/llm_funcs.py
CHANGED
@@ -7,6 +7,7 @@ import pandas as pd
import json
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from typing import List, Tuple, TypeVar
from google import genai as ai
from google.genai import types

@@ -17,12 +18,18 @@ torch.cuda.empty_cache()

model_type = None # global variable setup
full_text = "" # Define dummy source text (full text) just to enable highlight function to load
- model =
- tokenizer =

- from tools.config import RUN_AWS_FUNCTIONS, AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, RUN_LOCAL_MODEL, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS
from tools.prompts import initial_table_assistant_prefill

# Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
# Check for torch cuda
print("Is CUDA enabled? ", torch.cuda.is_available())

@@ -54,13 +61,9 @@ batch_size_default = BATCH_SIZE_DEFAULT
deduplication_threshold = DEDUPLICATION_THRESHOLD
max_comment_character_length = MAX_COMMENT_CHARS

- # if RUN_AWS_FUNCTIONS == '1':
- # bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_REGION)
- # else:
- # bedrock_runtime = []

if not LLM_THREADS:
- threads = torch.get_num_threads()
else: threads = LLM_THREADS
print("CPU threads:", threads)

@@ -76,7 +79,8 @@ else: sample = False
temperature = LLM_TEMPERATURE
top_k = LLM_TOP_K
top_p = LLM_TOP_P
-
last_n_tokens = LLM_LAST_N_TOKENS
max_new_tokens: int = LLM_MAX_NEW_TOKENS
seed: int = LLM_SEED

@@ -86,6 +90,7 @@ threads: int = threads
batch_size:int = LLM_BATCH_SIZE
context_length:int = LLM_CONTEXT_LENGTH
sample = LLM_SAMPLE

class llama_cpp_init_config_gpu:
def __init__(self,

@@ -122,6 +127,7 @@ cpu_config = llama_cpp_init_config_cpu()
class LlamaCPPGenerationConfig:
def __init__(self, temperature=temperature,
top_k=top_k,
top_p=top_p,
repeat_penalty=repetition_penalty,
seed=seed,

@@ -164,6 +170,7 @@ def get_model_path(repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model
if hf_token:
downloaded_model_path = hf_hub_download(repo_id=repo_id, token=hf_token, filename=model_filename)
else:
downloaded_model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

return downloaded_model_path

@@ -185,13 +192,16 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE, gpu_layers:int=gpu_

try:
print("GPU load variables:" , vars(gpu_config))
-
-
except Exception as e:
print("GPU load failed due to:", e)
- llama_model = Llama(model_path=model_path, type_k=8, **vars(cpu_config))

- print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU

# CPU mode
else:

@@ -202,11 +212,14 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE, gpu_layers:int=gpu_
gpu_config.update_context(max_context_length)
cpu_config.update_context(max_context_length)

-

- print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU

- tokenizer =

print("Finished loading model:", local_model_type)
print("GPU layers assigned to cuda:", gpu_layers)

@@ -244,6 +257,47 @@ def call_llama_cpp_model(formatted_string:str, gen_config:str, model=model):

return output

# This function is not used in this app
def llama_cpp_streaming(history, full_prompt, temperature=temperature):

@@ -392,7 +446,7 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
return response

# Function to send a request and update history
- def send_request(prompt: str, conversation_history: List[dict], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model=
"""
This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.

@@ -421,16 +475,9 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
for i in progress_bar:
try:
print("Calling Gemini model, attempt", i + 1)
- #print("google_client:", google_client)
- #print("model_choice:", model_choice)
- #print("full_prompt:", full_prompt)
- #print("generation_config:", config)

response = google_client.models.generate_content(model=model_choice, contents=full_prompt, config=config)

- #progress_bar.close()
- #tqdm._instances.clear()
-
print("Successful call to Gemini model.")
break
except Exception as e:

@@ -447,9 +494,6 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
print("Calling AWS Claude model, attempt", i + 1)
response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice, bedrock_runtime=bedrock_runtime, assistant_prefill=assistant_prefill)

- #progress_bar.close()
- #tqdm._instances.clear()
-
print("Successful call to Claude model.")
break
except Exception as e:

@@ -468,10 +512,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
gen_config = LlamaCPPGenerationConfig()
gen_config.update_temp(temperature)

- response =
-
- #progress_bar.close()
- #tqdm._instances.clear()

print("Successful call to local model. Response:", response)
break

@@ -492,7 +533,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
if isinstance(response, ResponseObject):
conversation_history.append({'role': 'assistant', 'parts': [response.text]})
elif 'choices' in response:
- conversation_history.append({'role': 'assistant', 'parts': [response['choices'][0]['text']]})
else:
conversation_history.append({'role': 'assistant', 'parts': [response.text]})

@@ -501,7 +542,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a

return response, conversation_history

- def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model =
"""
Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.

@@ -525,21 +566,19 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
Returns:
Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
"""
- responses =

# Clear any existing progress bars
tqdm._instances.clear()

for prompt in prompts:

- #print("prompt to LLM:", prompt)
-
response, conversation_history = send_request(prompt, conversation_history, google_client=google_client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)

if isinstance(response, ResponseObject):
response_text = response.text
elif 'choices' in response:
- response_text = response['choices'][0]['text']
else:
response_text = response.text

@@ -550,9 +589,10 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor

# Create conversation metadata
if master == False:
- whole_conversation_metadata.append(f"
else:
- whole_conversation_metadata.append(f"Query summary metadata:")

if not isinstance(response, str):
try:

@@ -571,6 +611,9 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
elif "gemini" in model_choice:
whole_conversation_metadata.append(str(response.usage_metadata))
else:
whole_conversation_metadata.append(str(response['usage']))
except KeyError as e:
print(f"Key error: {e} - Check the structure of response.usage_metadata")

@@ -638,10 +681,14 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, master=master, assistant_prefill=assistant_prefill
)

- if
- stripped_response = responses[-1].text.strip()
-
-

# Check if response meets our criteria (length and contains table)
if len(stripped_response) > 120 and '|' in stripped_response:

@@ -656,7 +703,7 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
else: # This runs if no break occurred (all attempts failed)
print(f"Failed to get valid response after {MAX_OUTPUT_VALIDATION_ATTEMPTS} attempts")

- return responses, conversation_history, whole_conversation, whole_conversation_metadata,

def create_missing_references_df(basic_response_df: pd.DataFrame, existing_reference_df: pd.DataFrame) -> pd.DataFrame:
"""

import json
from tqdm import tqdm
from huggingface_hub import hf_hub_download
+ from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
from typing import List, Tuple, TypeVar
from google import genai as ai
from google.genai import types
...
model_type = None # global variable setup
full_text = "" # Define dummy source text (full text) just to enable highlight function to load
+ model = list() # Define empty list for model functions to run
+ tokenizer = list() #[] # Define empty list for model functions to run

+ from tools.config import RUN_AWS_FUNCTIONS, AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_MIN_P, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, RUN_LOCAL_MODEL, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS, SPECULATIVE_DECODING, NUM_PRED_TOKENS
from tools.prompts import initial_table_assistant_prefill

+ if SPECULATIVE_DECODING == "True": SPECULATIVE_DECODING = True
+ else: SPECULATIVE_DECODING = False

+ if isinstance(NUM_PRED_TOKENS, str): NUM_PRED_TOKENS = int(NUM_PRED_TOKENS)
+ else: NUM_PRED_TOKENS = NUM_PRED_TOKENS

# Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
# Check for torch cuda
print("Is CUDA enabled? ", torch.cuda.is_available())
...
deduplication_threshold = DEDUPLICATION_THRESHOLD
max_comment_character_length = MAX_COMMENT_CHARS

if not LLM_THREADS:
+ threads = torch.get_num_threads()
else: threads = LLM_THREADS
print("CPU threads:", threads)
...
temperature = LLM_TEMPERATURE
top_k = LLM_TOP_K
top_p = LLM_TOP_P
+ min_p = LLM_MIN_P
+ repetition_penalty = LLM_REPETITION_PENALTY
last_n_tokens = LLM_LAST_N_TOKENS
max_new_tokens: int = LLM_MAX_NEW_TOKENS
seed: int = LLM_SEED
...
batch_size:int = LLM_BATCH_SIZE
context_length:int = LLM_CONTEXT_LENGTH
sample = LLM_SAMPLE
+ speculative_decoding = SPECULATIVE_DECODING

class llama_cpp_init_config_gpu:
def __init__(self,
...
class LlamaCPPGenerationConfig:
def __init__(self, temperature=temperature,
top_k=top_k,
+ min_p=min_p,
top_p=top_p,
repeat_penalty=repetition_penalty,
seed=seed,
...
if hf_token:
downloaded_model_path = hf_hub_download(repo_id=repo_id, token=hf_token, filename=model_filename)
else:
+ print("No HF token found, downloading model from Hugging Face Hub without token")
downloaded_model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

return downloaded_model_path
...
try:
print("GPU load variables:" , vars(gpu_config))
+ if speculative_decoding:
+     llama_model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS), **vars(gpu_config))
+ else:
+     llama_model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, **vars(gpu_config))

except Exception as e:
print("GPU load failed due to:", e)
+ llama_model = Llama(model_path=model_path, type_k=8, **vars(cpu_config))

+ print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", gpu_config.n_ctx)

# CPU mode
else:
...
gpu_config.update_context(max_context_length)
cpu_config.update_context(max_context_length)

+ if speculative_decoding:
+     llama_model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS), **vars(gpu_config))
+ else:
+     llama_model = Llama(model_path=model_path, type_k=8, **vars(cpu_config))

+ print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", gpu_config.n_ctx)

+ tokenizer = list()

print("Finished loading model:", local_model_type)
print("GPU layers assigned to cuda:", gpu_layers)
...
return output

+ def call_llama_cpp_chatmodel(formatted_string:str, system_prompt:str, gen_config:LlamaCPPGenerationConfig, model=model):
+     """
+     Calls your Llama.cpp chat model with a formatted user message and system prompt,
+     using generation parameters from the LlamaCPPGenerationConfig object.
+
+     Args:
+         formatted_string (str): The formatted input text for the user's message.
+         system_prompt (str): The system-level instructions for the model.
+         gen_config (LlamaCPPGenerationConfig): An object containing generation parameters.
+         model (Llama): The Llama.cpp model instance to use for chat completion.
+     """
+     # Extracting parameters from the gen_config object
+     temperature = gen_config.temperature
+     top_k = gen_config.top_k
+     top_p = gen_config.top_p
+     repeat_penalty = gen_config.repeat_penalty
+     seed = gen_config.seed
+     max_tokens = gen_config.max_tokens
+     stream = gen_config.stream
+
+     # Now you can call your model directly, passing the parameters:
+     output = model.create_chat_completion(
+         messages=[
+             {"role": "system", "content": system_prompt},
+             {
+                 "role": "user",
+                 "content": formatted_string
+             }
+         ],
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+         repeat_penalty=repeat_penalty,
+         seed=seed,
+         max_tokens=max_tokens,
+         stream=stream
+         #stop=["<|eot_id|>", "\n\n"]
+     )
+
+     return output
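A rough usage sketch of the new chat-completion path. It assumes a model already loaded by load_model() (here called llama_model) and stream=False; with those assumptions, llama-cpp-python returns an OpenAI-style dict.

# Hypothetical call; llama_model and the prompt text are stand-ins.
gen_config = LlamaCPPGenerationConfig()
gen_config.update_temp(0.1)

result = call_llama_cpp_chatmodel(
    formatted_string="Summarise the following responses into a markdown table: ...",
    system_prompt="You are an assistant that outputs markdown tables.",
    gen_config=gen_config,
    model=llama_model,
)

# With stream=False, the result is a chat-completion dict:
text = result['choices'][0]['message']['content']
usage = result.get('usage', {})  # prompt_tokens / completion_tokens / total_tokens
print(text, usage)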

# This function is not used in this app
def llama_cpp_streaming(history, full_prompt, temperature=temperature):
...
return response

# Function to send a request and update history
+ def send_request(prompt: str, conversation_history: List[dict], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model= list(), assistant_prefill = "", progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
"""
This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
...
for i in progress_bar:
try:
print("Calling Gemini model, attempt", i + 1)

response = google_client.models.generate_content(model=model_choice, contents=full_prompt, config=config)

print("Successful call to Gemini model.")
break
except Exception as e:
...
print("Calling AWS Claude model, attempt", i + 1)
response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice, bedrock_runtime=bedrock_runtime, assistant_prefill=assistant_prefill)

print("Successful call to Claude model.")
break
except Exception as e:
...
gen_config = LlamaCPPGenerationConfig()
gen_config.update_temp(temperature)

+ response = call_llama_cpp_chatmodel(prompt, system_prompt, gen_config, model=local_model)

print("Successful call to local model. Response:", response)
break
...
if isinstance(response, ResponseObject):
conversation_history.append({'role': 'assistant', 'parts': [response.text]})
elif 'choices' in response:
+ conversation_history.append({'role': 'assistant', 'parts': [response['choices'][0]['message']['content']]}) #response['choices'][0]['text']]})
else:
conversation_history.append({'role': 'assistant', 'parts': [response.text]})
...
return response, conversation_history

+ def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model = list(), master:bool = False, assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
"""
Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
...
Returns:
Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
"""
+ responses = list()

# Clear any existing progress bars
tqdm._instances.clear()

for prompt in prompts:

response, conversation_history = send_request(prompt, conversation_history, google_client=google_client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)

if isinstance(response, ResponseObject):
response_text = response.text
elif 'choices' in response:
+ response_text = response['choices'][0]['message']['content'] # response['choices'][0]['text']
else:
response_text = response.text
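The same three-way branch (Gemini/Bedrock ResponseObject, llama-cpp chat-completion dict, plain .text fallback) now appears in several places across the commit. A small helper along the lines of the hypothetical sketch below captures the pattern; it is not part of the commit itself.

# Hypothetical helper mirroring the branching above.
def get_response_text(response) -> str:
    if isinstance(response, ResponseObject):
        # Wrapper object used in this repo for Gemini/Claude responses
        return response.text
    elif 'choices' in response:
        # llama-cpp-python create_chat_completion dict
        return response['choices'][0]['message']['content']
    else:
        # Any other object exposing a .text attribute
        return response.text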
...
# Create conversation metadata
if master == False:
+ whole_conversation_metadata.append(f"Batch {batch_no}:")
else:
+ #whole_conversation_metadata.append(f"Query summary metadata:")
+ whole_conversation_metadata.append(f"Batch {batch_no}:")

if not isinstance(response, str):
try:
...
elif "gemini" in model_choice:
whole_conversation_metadata.append(str(response.usage_metadata))
else:
+ print("Adding usage metadata to whole conversation metadata:", response['usage'])
+ output_tokens = response['usage'].get('completion_tokens', 0)
+ input_tokens = response['usage'].get('prompt_tokens', 0)
whole_conversation_metadata.append(str(response['usage']))
except KeyError as e:
print(f"Key error: {e} - Check the structure of response.usage_metadata")
...
call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, master=master, assistant_prefill=assistant_prefill
)

+ #if model_source != "Local":
+ #stripped_response = responses[-1].text.strip()
+ #stripped_response = response_text.strip()
+ #else:
+ #stripped_response = response['choices'][0]['message']['content'].strip()
+ #stripped_response = response_text.strip()

+ stripped_response = response_text.strip()

# Check if response meets our criteria (length and contains table)
if len(stripped_response) > 120 and '|' in stripped_response:
...
else: # This runs if no break occurred (all attempts failed)
print(f"Failed to get valid response after {MAX_OUTPUT_VALIDATION_ATTEMPTS} attempts")

+ return responses, conversation_history, whole_conversation, whole_conversation_metadata, stripped_response
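The validity check used above is deliberately loose (a minimum length plus at least one pipe character). A stricter variant, shown here only as a hedged alternative and not as the repo's method, could also require a markdown header separator row:

import re

def looks_like_markdown_table(text: str) -> bool:
    # Require a pipe-delimited row plus a header separator such as |---|---|
    has_row = '|' in text
    has_separator = re.search(r'\|\s*:?-{3,}:?\s*\|', text) is not None
    return has_row and has_separator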

def create_missing_references_df(basic_response_df: pd.DataFrame, existing_reference_df: pd.DataFrame) -> pd.DataFrame:
"""
tools/verify_titles.py
CHANGED
@@ -101,7 +101,7 @@ def write_llm_output_and_logs_verify(responses: List[ResponseObject],
log_files_output_paths.append(whole_conversation_path_meta)

if isinstance(responses[-1], ResponseObject): response_text = responses[-1].text
- elif "choices" in responses[-1]: response_text = responses[-1]["choices"][0]['text']
else: response_text = responses[-1].text

# Convert response text to a markdown table

@@ -464,13 +464,16 @@ def verify_titles(in_data_file,

if isinstance(responses[-1], ResponseObject):
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
- f.write(responses[-1].text)
elif "choices" in responses[-1]:
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
- f.write(responses[-1]["choices"][0]['text'])
else:
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
- f.write(responses[-1].text)

except Exception as e:
print("Error in returning model response:", e)

@@ -581,15 +584,18 @@ def verify_titles(in_data_file,

if isinstance(responses[-1], ResponseObject):
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
- f.write(responses[-1].text)
unique_table_df_display_table_markdown = responses[-1].text
elif "choices" in responses[-1]:
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
- f.write(responses[-1]["choices"][0]['text'])
-
else:
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
- f.write(responses[-1].text)
unique_table_df_display_table_markdown = responses[-1].text

log_files_output_paths.append(final_table_output_path)

log_files_output_paths.append(whole_conversation_path_meta)

if isinstance(responses[-1], ResponseObject): response_text = responses[-1].text
+ elif "choices" in responses[-1]: response_text = responses[-1]['choices'][0]['message']['content'] #responses[-1]["choices"][0]['text']
else: response_text = responses[-1].text

# Convert response text to a markdown table
...
if isinstance(responses[-1], ResponseObject):
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
+ #f.write(responses[-1].text)
+ f.write(response_text)
elif "choices" in responses[-1]:
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
+ #f.write(responses[-1]["choices"][0]['text'])
+ f.write(response_text)
else:
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
+ #f.write(responses[-1].text)
+ f.write(response_text)

except Exception as e:
print("Error in returning model response:", e)
...
if isinstance(responses[-1], ResponseObject):
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
+ #f.write(responses[-1].text)
+ f.write(response_text)
unique_table_df_display_table_markdown = responses[-1].text
elif "choices" in responses[-1]:
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
+ #f.write(responses[-1]["choices"][0]['text'])
+ f.write(response_text)
+ unique_table_df_display_table_markdown = responses[-1]["choices"][0]['message']['content'] #responses[-1]["choices"][0]['text']
else:
with open(final_table_output_path, "w", encoding='utf-8-sig', errors='replace') as f:
+ #f.write(responses[-1].text)
+ f.write(response_text)
unique_table_df_display_table_markdown = responses[-1].text

log_files_output_paths.append(final_table_output_path)