seanpedrickcase committed
Commit 72d517c · 1 Parent(s): 8c54223

Enabled GPU-based local model inference with the transformers package

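For reviewers, a minimal sketch of how the new transformers code path is intended to be switched on, assuming the environment variables introduced in tools/config.py below and the updated load_model signature in tools/llm_funcs.py. The call itself is illustrative and not code from this commit.

```python
# Illustrative sketch only: variable names come from tools/config.py in this commit,
# and the load_model call mirrors the updated calls in tools/llm_funcs.py.
import os

os.environ["USE_LLAMA_CPP"] = "False"           # switch from llama.cpp GGUF loading to transformers
os.environ["USE_BITSANDBYTES"] = "True"         # quantise with bitsandbytes (4-bit NF4 by default)
os.environ["MODEL_DTYPE"] = "float16"           # or "bfloat16" on Hopper-class GPUs
os.environ["COMPILE_TRANSFORMERS"] = "True"     # wrap the loaded model in torch.compile
os.environ["COMPILE_MODE"] = "reduce-overhead"  # alternatively "max-autotune"

# tools.config reads these variables at import time, so set them first.
from tools.config import (CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID,
                          LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN)
from tools.llm_funcs import load_model

# With USE_LLAMA_CPP == "False", the Gemma/GPT-OSS repo IDs in tools/config.py resolve
# to their transformers equivalents and load_model returns a transformers model + tokenizer.
local_model, tokenizer = load_model(
    local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
    repo_id=LOCAL_REPO_ID,
    model_filename=LOCAL_MODEL_FILE,
    model_dir=LOCAL_MODEL_FOLDER,
    hf_token=HF_TOKEN,
)
```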
app.py CHANGED
@@ -12,7 +12,7 @@ from tools.custom_csvlogger import CSVLogger_custom
12
  from tools.auth import authenticate_user
13
  from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, verify_titles_prompt, verify_titles_system_prompt, two_para_summary_format_prompt, single_para_summary_format_prompt
14
  from tools.verify_titles import verify_titles
15
- from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, RUN_LOCAL_MODEL, FILE_INPUT_HEIGHT, GEMINI_API_KEY, model_full_names, BATCH_SIZE_DEFAULT, CHOSEN_LOCAL_MODEL_TYPE, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, model_name_map, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES
16
 
17
  def ensure_folder_exists(output_folder:str):
18
  """Checks if the specified folder exists, creates it if not."""
@@ -148,14 +148,14 @@ with app:
148
 
149
  Instructions on use can be found in the README.md file. Try it out with this [dummy development consultation dataset](https://huggingface.co/datasets/seanpedrickcase/dummy_development_consultation/tree/main), which you can also try with [zero-shot topics](https://huggingface.co/datasets/seanpedrickcase/dummy_development_consultation/tree/main). Try also this [dummy case notes dataset](https://huggingface.co/datasets/seanpedrickcase/dummy_case_notes/tree/main).
150
 
151
- You can use an AWS Bedrock model (paid), or Gemini (a free API for Flash). The use of Gemini requires an API key. To set up your own Gemini API key, go [here](https://aistudio.google.com/app/u/1/plan_information).
152
 
153
  NOTE: Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this app **absolutely need to be checked by a human** for harmful outputs, hallucinations, and accuracy.""")
154
 
155
  with gr.Tab(label="1. Extract topics"):
156
  gr.Markdown("""### Choose a tabular data file (xlsx, csv, parquet) of open text to extract topics from.""")
157
  with gr.Row():
158
- model_choice = gr.Dropdown(value = default_model_choice, choices = model_full_names, label="LLM model", multiselect=False)
159
 
160
  with gr.Accordion("Upload xlsx or csv file", open = True):
161
  in_data_files = gr.File(height=FILE_INPUT_HEIGHT, label="Choose Excel or csv files", file_count= "multiple", file_types=['.xlsx', '.xls', '.csv', '.parquet'])
@@ -177,6 +177,9 @@ with app:
177
 
178
  sentiment_checkbox = gr.Radio(label="Choose sentiment categories to split responses", value="Negative or Positive", choices=["Negative or Positive", "Negative, Neutral, or Positive", "Do not assess sentiment"])
179
 
 
 
 
180
  if GET_COST_CODES == "True" or ENFORCE_COST_CODES == "True":
181
  with gr.Accordion("Assign task to cost code", open = True, visible=True):
182
  gr.Markdown("Please ensure that you have approval from your budget holder before using this app for redaction tasks that incur a cost.")
@@ -188,9 +191,6 @@ with app:
188
 
189
  all_in_one_btn = gr.Button("All in one - Extract topics, deduplicate, and summarise", variant="primary")
190
  extract_topics_btn = gr.Button("1. Extract topics", variant="secondary")
191
-
192
- if SHOW_EXAMPLES == "True":
193
- examples = gr.Examples(examples=[[["example_data/dummy_consultation_response.csv"]], [["example_data/combined_case_notes.csv"]]], inputs=[in_data_files])
194
 
195
  with gr.Row(equal_height=True):
196
  output_messages_textbox = gr.Textbox(value="", label="Output messages", scale=1, interactive=False)
@@ -307,6 +307,9 @@ with app:
307
  with gr.Accordion("Gemini API keys", open = False):
308
  google_api_key_textbox = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")
309
 
 
 
 
310
  with gr.Accordion("Log outputs", open = False):
311
  log_files_output = gr.File(height=FILE_INPUT_HEIGHT, label="Log file output", interactive=False)
312
  conversation_metadata_textbox = gr.Textbox(value="", label="Query metadata - usage counts and other parameters", lines=8)
@@ -350,7 +353,7 @@ with app:
350
  ###
351
 
352
  # Tabular data upload
353
- in_data_files.change(fn=put_columns_in_df, inputs=[in_data_files], outputs=[in_colnames, in_excel_sheets, original_data_file_name_textbox, join_colnames, in_group_col])
354
 
355
  # Click on cost code dataframe/dropdown fills in cost code textbox
356
  # Allow user to select items from cost code dataframe for cost code
@@ -401,6 +404,7 @@ with app:
401
  produce_structures_summary_radio,
402
  aws_access_key_textbox,
403
  aws_secret_key_textbox,
 
404
  output_folder_state],
405
  outputs=[display_topic_table_markdown,
406
  master_topic_df_state,
@@ -498,6 +502,7 @@ with app:
498
  produce_structures_summary_radio,
499
  aws_access_key_textbox,
500
  aws_secret_key_textbox,
 
501
  output_folder_state],
502
  outputs=[display_topic_table_markdown,
503
  master_topic_df_state,
 
12
  from tools.auth import authenticate_user
13
  from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, verify_titles_prompt, verify_titles_system_prompt, two_para_summary_format_prompt, single_para_summary_format_prompt
14
  from tools.verify_titles import verify_titles
15
+ from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, RUN_LOCAL_MODEL, FILE_INPUT_HEIGHT, GEMINI_API_KEY, model_full_names, BATCH_SIZE_DEFAULT, CHOSEN_LOCAL_MODEL_TYPE, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, model_name_map, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES, HF_TOKEN
16
 
17
  def ensure_folder_exists(output_folder:str):
18
  """Checks if the specified folder exists, creates it if not."""
 
148
 
149
  Instructions on use can be found in the README.md file. Try it out with this [dummy development consultation dataset](https://huggingface.co/datasets/seanpedrickcase/dummy_development_consultation/tree/main), which you can also try with [zero-shot topics](https://huggingface.co/datasets/seanpedrickcase/dummy_development_consultation/tree/main). Try also this [dummy case notes dataset](https://huggingface.co/datasets/seanpedrickcase/dummy_case_notes/tree/main).
150
 
151
+ You can use an AWS Bedrock model (paid), or Gemini (a free API for Flash). The use of Gemini requires an API key. To set up your own Gemini API key, go [here](https://aistudio.google.com/app/u/1/plan_information).
152
 
153
  NOTE: Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this app **absolutely need to be checked by a human** for harmful outputs, hallucinations, and accuracy.""")
154
 
155
  with gr.Tab(label="1. Extract topics"):
156
  gr.Markdown("""### Choose a tabular data file (xlsx, csv, parquet) of open text to extract topics from.""")
157
  with gr.Row():
158
+ model_choice = gr.Dropdown(value = default_model_choice, choices = model_full_names, label="LLM model", multiselect=False)
159
 
160
  with gr.Accordion("Upload xlsx or csv file", open = True):
161
  in_data_files = gr.File(height=FILE_INPUT_HEIGHT, label="Choose Excel or csv files", file_count= "multiple", file_types=['.xlsx', '.xls', '.csv', '.parquet'])
 
177
 
178
  sentiment_checkbox = gr.Radio(label="Choose sentiment categories to split responses", value="Negative or Positive", choices=["Negative or Positive", "Negative, Neutral, or Positive", "Do not assess sentiment"])
179
 
180
+ if SHOW_EXAMPLES == "True":
181
+ examples = gr.Examples(examples=[[["example_data/dummy_consultation_response.csv"], "Response text", "Consultation for the construction of flats on Main Street"], [["example_data/combined_case_notes.csv"], "Case Note", "Social Care case notes for young people"]], inputs=[in_data_files, in_colnames, context_textbox], example_labels=["Consultation for the construction of flats on Main Street", "Social Care case notes for young people"], label="Test with an example dataset")
182
+
183
  if GET_COST_CODES == "True" or ENFORCE_COST_CODES == "True":
184
  with gr.Accordion("Assign task to cost code", open = True, visible=True):
185
  gr.Markdown("Please ensure that you have approval from your budget holder before using this app for redaction tasks that incur a cost.")
 
191
 
192
  all_in_one_btn = gr.Button("All in one - Extract topics, deduplicate, and summarise", variant="primary")
193
  extract_topics_btn = gr.Button("1. Extract topics", variant="secondary")
 
 
 
194
 
195
  with gr.Row(equal_height=True):
196
  output_messages_textbox = gr.Textbox(value="", label="Output messages", scale=1, interactive=False)
 
307
  with gr.Accordion("Gemini API keys", open = False):
308
  google_api_key_textbox = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")
309
 
310
+ with gr.Accordion("Hugging Face API keys", open = False):
311
+ hf_api_key_textbox = gr.Textbox(value = HF_TOKEN, label="Enter Hugging Face API key (only if using Hugging Face models)", lines=1, type="password")
312
+
313
  with gr.Accordion("Log outputs", open = False):
314
  log_files_output = gr.File(height=FILE_INPUT_HEIGHT, label="Log file output", interactive=False)
315
  conversation_metadata_textbox = gr.Textbox(value="", label="Query metadata - usage counts and other parameters", lines=8)
 
353
  ###
354
 
355
  # Tabular data upload
356
+ in_data_files.upload(fn=put_columns_in_df, inputs=[in_data_files], outputs=[in_colnames, in_excel_sheets, original_data_file_name_textbox, join_colnames, in_group_col])
357
 
358
  # Click on cost code dataframe/dropdown fills in cost code textbox
359
  # Allow user to select items from cost code dataframe for cost code
 
404
  produce_structures_summary_radio,
405
  aws_access_key_textbox,
406
  aws_secret_key_textbox,
407
+ hf_api_key_textbox,
408
  output_folder_state],
409
  outputs=[display_topic_table_markdown,
410
  master_topic_df_state,
 
502
  produce_structures_summary_radio,
503
  aws_access_key_textbox,
504
  aws_secret_key_textbox,
505
+ hf_api_key_textbox,
506
  output_folder_state],
507
  outputs=[display_topic_table_markdown,
508
  master_topic_df_state,
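The new Hugging Face key field mirrors the existing Gemini accordion: a password-type textbox inside an accordion, appended to the extract-topics event inputs so the value reaches load_model. A stand-alone sketch of that wiring pattern is shown below; it is illustrative only, not the app's actual layout.

```python
# Stand-alone Gradio sketch of the pattern added in app.py above; component labels
# match the diff, the demo logic is a placeholder.
import gradio as gr

def check_key(hf_key: str) -> str:
    # In the app, this value travels through the extract-topics inputs to load_model(hf_token=...)
    return "Hugging Face token provided" if hf_key else "No Hugging Face token provided"

with gr.Blocks() as demo:
    with gr.Accordion("Hugging Face API keys", open=False):
        hf_api_key_textbox = gr.Textbox(
            label="Enter Hugging Face API key (only if using Hugging Face models)",
            lines=1, type="password")
    output_messages_textbox = gr.Textbox(label="Output messages", interactive=False)
    gr.Button("Check").click(check_key, inputs=[hf_api_key_textbox], outputs=[output_messages_textbox])

if __name__ == "__main__":
    demo.launch()
```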
requirements.txt CHANGED
@@ -18,6 +18,8 @@ python-dotenv==1.1.0
18
  # GPU
19
  torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124 # Latest compatible with CUDA 12.4
20
  https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp310-cp310-linux_x86_64.whl
 
 
21
  # CPU only (for e.g. Hugging Face CPU instances)
22
  #torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/cpu
23
  # For Hugging Face, need a python 3.10 compatible wheel for llama-cpp-python to avoid build timeouts
 
18
  # GPU
19
  torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124 # Latest compatible with CUDA 12.4
20
  https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp310-cp310-linux_x86_64.whl
21
+ bitsandbytes==0.47.0
22
+ accelerate==1.10.1
23
  # CPU only (for e.g. Hugging Face CPU instances)
24
  #torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/cpu
25
  # For Hugging Face, need a python 3.10 compatible wheel for llama-cpp-python to avoid build timeouts
requirements_gpu.txt CHANGED
@@ -17,10 +17,12 @@ python-dotenv==1.1.0
17
  # Torch and Llama CPP Python
18
  torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124 # Latest compatible with CUDA 12.4
19
  # For Linux:
20
- https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl
21
  # For Windows:
22
- #https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp311-cp311-win_amd64.whl
23
  # If above doesn't work for Windows, try looking at 'windows_install_llama-cpp-python.txt' for instructions on how to build from source
24
  # If none of the above work for you, try the following:
25
  # llama-cpp-python==0.3.16 -C cmake.args="-DGGML_CUDA=on -DGGML_CUBLAS=on"
 
 
26
 
 
17
  # Torch and Llama CPP Python
18
  torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124 # Latest compatible with CUDA 12.4
19
  # For Linux:
20
+ #https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl
21
  # For Windows:
22
+ https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp311-cp311-win_amd64.whl
23
  # If above doesn't work for Windows, try looking at 'windows_install_llama-cpp-python.txt' for instructions on how to build from source
24
  # If none of the above work for you, try the following:
25
  # llama-cpp-python==0.3.16 -C cmake.args="-DGGML_CUDA=on -DGGML_CUBLAS=on"
26
+ bitsandbytes==0.47.0
27
+ accelerate==1.10.1
28
 
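The new bitsandbytes and accelerate pins only pay off when a CUDA device is visible to torch. A quick hedged sanity check along these lines can confirm the GPU stack is importable before enabling quantisation; the snippet is illustrative and not part of the repository.

```python
# Sanity check for the GPU dependencies pinned above; illustrative only.
import torch

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))

try:
    import accelerate
    import bitsandbytes
    print("bitsandbytes", bitsandbytes.__version__, "/ accelerate", accelerate.__version__)
except ImportError as exc:
    print("Quantisation extras not installed:", exc)
```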
tools/config.py CHANGED
@@ -241,20 +241,38 @@ model_name_map = {
241
 
242
  # HF token may or may not be needed for downloading models from Hugging Face
243
  HF_TOKEN = get_or_create_env_var('HF_TOKEN', '')
 
 
244
 
245
  GEMMA2_REPO_ID = get_or_create_env_var("GEMMA2_2B_REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF")
 
 
 
 
246
  GEMMA2_MODEL_FILE = get_or_create_env_var("GEMMA2_2B_MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf")
247
  GEMMA2_MODEL_FOLDER = get_or_create_env_var("GEMMA2_2B_MODEL_FOLDER", "model/gemma")
248
 
249
  GEMMA3_REPO_ID = get_or_create_env_var("GEMMA3_REPO_ID", "unsloth/gemma-3-270m-it-qat-GGUF")
 
 
 
 
250
  GEMMA3_MODEL_FILE = get_or_create_env_var("GEMMA3_MODEL_FILE", "gemma-3-270m-it-qat-F16.gguf")
251
  GEMMA3_MODEL_FOLDER = get_or_create_env_var("GEMMA3_MODEL_FOLDER", "model/gemma")
252
 
253
  GEMMA3_4B_REPO_ID = get_or_create_env_var("GEMMA3_4B_REPO_ID", "unsloth/gemma-3-4b-it-qat-GGUF")
 
 
 
 
254
  GEMMA3_4B_MODEL_FILE = get_or_create_env_var("GEMMA3_4B_MODEL_FILE", "gemma-3-4b-it-qat-Q4_K_M.gguf")
255
  GEMMA3_4B_MODEL_FOLDER = get_or_create_env_var("GEMMA3_4B_MODEL_FOLDER", "model/gemma3_4b")
256
 
257
  GPT_OSS_REPO_ID = get_or_create_env_var("GPT_OSS_REPO_ID", "unsloth/gpt-oss-20b-GGUF")
 
 
 
 
258
  GPT_OSS_MODEL_FILE = get_or_create_env_var("GPT_OSS_MODEL_FILE", "gpt-oss-20b-F16.gguf")
259
  GPT_OSS_MODEL_FOLDER = get_or_create_env_var("GPT_OSS_MODEL_FOLDER", "model/gpt_oss")
260
 
@@ -305,6 +323,13 @@ SPECULATIVE_DECODING = get_or_create_env_var('SPECULATIVE_DECODING', 'False')
305
  NUM_PRED_TOKENS = int(get_or_create_env_var('NUM_PRED_TOKENS', '2'))
306
  REASONING_SUFFIX = get_or_create_env_var('REASONING_SUFFIX', 'Reasoning: low')
307
 
308
  MAX_GROUPS = int(get_or_create_env_var('MAX_GROUPS', '99'))
309
 
310
  ###
 
241
 
242
  # HF token may or may not be needed for downloading models from Hugging Face
243
  HF_TOKEN = get_or_create_env_var('HF_TOKEN', '')
244
+ USE_LLAMA_CPP = get_or_create_env_var('USE_LLAMA_CPP', 'True') # Llama.cpp or transformers
245
+
246
 
247
  GEMMA2_REPO_ID = get_or_create_env_var("GEMMA2_2B_REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF")
248
+ GEMMA2_REPO_TRANSFORMERS_ID = get_or_create_env_var("GEMMA2_2B_REPO_TRANSFORMERS_ID", "google/gemma-2-2b-it")
249
+ if USE_LLAMA_CPP == "False":
250
+ GEMMA2_REPO_ID = GEMMA2_REPO_TRANSFORMERS_ID
251
+
252
  GEMMA2_MODEL_FILE = get_or_create_env_var("GEMMA2_2B_MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf")
253
  GEMMA2_MODEL_FOLDER = get_or_create_env_var("GEMMA2_2B_MODEL_FOLDER", "model/gemma")
254
 
255
  GEMMA3_REPO_ID = get_or_create_env_var("GEMMA3_REPO_ID", "unsloth/gemma-3-270m-it-qat-GGUF")
256
+ GEMMA3_REPO_TRANSFORMERS_ID = get_or_create_env_var("GEMMA3_REPO_TRANSFORMERS_ID", "google/gemma-3-270m-it")
257
+ if USE_LLAMA_CPP == "False":
258
+ GEMMA3_REPO_ID = GEMMA3_REPO_TRANSFORMERS_ID
259
+
260
  GEMMA3_MODEL_FILE = get_or_create_env_var("GEMMA3_MODEL_FILE", "gemma-3-270m-it-qat-F16.gguf")
261
  GEMMA3_MODEL_FOLDER = get_or_create_env_var("GEMMA3_MODEL_FOLDER", "model/gemma")
262
 
263
  GEMMA3_4B_REPO_ID = get_or_create_env_var("GEMMA3_4B_REPO_ID", "unsloth/gemma-3-4b-it-qat-GGUF")
264
+ GEMMA3_4B_REPO_TRANSFORMERS_ID = get_or_create_env_var("GEMMA3_4B_REPO_TRANSFORMERS_ID", "google/gemma-3-4b-it")
265
+ if USE_LLAMA_CPP == "False":
266
+ GEMMA3_4B_REPO_ID = GEMMA3_4B_REPO_TRANSFORMERS_ID
267
+
268
  GEMMA3_4B_MODEL_FILE = get_or_create_env_var("GEMMA3_4B_MODEL_FILE", "gemma-3-4b-it-qat-Q4_K_M.gguf")
269
  GEMMA3_4B_MODEL_FOLDER = get_or_create_env_var("GEMMA3_4B_MODEL_FOLDER", "model/gemma3_4b")
270
 
271
  GPT_OSS_REPO_ID = get_or_create_env_var("GPT_OSS_REPO_ID", "unsloth/gpt-oss-20b-GGUF")
272
+ GPT_OSS_REPO_TRANSFORMERS_ID = get_or_create_env_var("GPT_OSS_REPO_TRANSFORMERS_ID", "openai/gpt-oss-20b")
273
+ if USE_LLAMA_CPP == "False":
274
+ GPT_OSS_REPO_ID = GPT_OSS_REPO_TRANSFORMERS_ID
275
+
276
  GPT_OSS_MODEL_FILE = get_or_create_env_var("GPT_OSS_MODEL_FILE", "gpt-oss-20b-F16.gguf")
277
  GPT_OSS_MODEL_FOLDER = get_or_create_env_var("GPT_OSS_MODEL_FOLDER", "model/gpt_oss")
278
 
 
323
  NUM_PRED_TOKENS = int(get_or_create_env_var('NUM_PRED_TOKENS', '2'))
324
  REASONING_SUFFIX = get_or_create_env_var('REASONING_SUFFIX', 'Reasoning: low')
325
 
326
+ # Transformers variables
327
+ COMPILE_TRANSFORMERS = get_or_create_env_var('COMPILE_TRANSFORMERS', 'True') # Whether to compile transformers models
328
+ USE_BITSANDBYTES = get_or_create_env_var('USE_BITSANDBYTES', 'True') # Whether to use bitsandbytes for quantization
329
+ COMPILE_MODE = get_or_create_env_var('COMPILE_MODE', 'reduce-overhead') # alternatively 'max-autotune'
330
+ MODEL_DTYPE = get_or_create_env_var('MODEL_DTYPE', 'float16') # alternatively 'bfloat16'
331
+ OFFLOAD_TO_CPU = get_or_create_env_var('OFFLOAD_TO_CPU', 'False') # Whether to offload to CPU
332
+
333
  MAX_GROUPS = int(get_or_create_env_var('MAX_GROUPS', '99'))
334
 
335
  ###
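Each local model now carries a paired *_REPO_TRANSFORMERS_ID, and the GGUF repo ID is swapped for it whenever USE_LLAMA_CPP is "False". A small sketch of overriding one of these from the environment; the env var names come from the diff above and both must be set before tools.config is imported. The override value here is illustrative (it is the same as the default).

```python
# Illustrative override, not code from the commit.
import os

os.environ["USE_LLAMA_CPP"] = "False"
os.environ["GEMMA3_4B_REPO_TRANSFORMERS_ID"] = "google/gemma-3-4b-it"  # same as the default above

from tools.config import GEMMA3_4B_REPO_ID
print(GEMMA3_4B_REPO_ID)  # with llama.cpp disabled this is the transformers repo ID, not the GGUF repo
```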
tools/dedup_summaries.py CHANGED
@@ -415,7 +415,7 @@ def sample_reference_table_summaries(reference_df:pd.DataFrame,
415
 
416
  return sampled_reference_table_df, summarised_references_markdown#, reference_df, topic_summary_df
417
 
418
- def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, model_source:str, bedrock_runtime:boto3.Session.client, local_model=list()):
419
  """
420
  Query an LLM to generate a summary of topics based on the provided prompts.
421
 
@@ -428,7 +428,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
428
  model_source (str): Source of the model (e.g. "AWS", "Gemini", "Local")
429
  bedrock_runtime (boto3.Session.client): AWS Bedrock runtime client for AWS models
430
  local_model (object, optional): Local model object if using local inference. Defaults to empty list.
431
-
432
  Returns:
433
  tuple: Contains:
434
  - response_text (str): The generated summary text
@@ -454,7 +454,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
454
  whole_conversation = [summarise_topic_descriptions_system_prompt]
455
 
456
  # Process requests to large language model
457
- responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, bedrock_runtime=bedrock_runtime, model_source=model_source, local_model=local_model, assistant_prefill=summary_assistant_prefill)
458
 
459
  print("Finished summary query")
460
 
@@ -482,7 +482,9 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
482
  aws_secret_key_textbox:str='',
483
  model_name_map:dict=model_name_map,
484
  reasoning_suffix:str=reasoning_suffix,
485
- local_model:object=list(),
 
 
486
  summarise_topic_descriptions_prompt:str=summarise_topic_descriptions_prompt,
487
  summarise_topic_descriptions_system_prompt:str=summarise_topic_descriptions_system_prompt,
488
  do_summaries:str="Yes",
@@ -572,7 +574,7 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
572
 
573
  if (model_source == "Local") & (RUN_LOCAL_MODEL == "1"):
574
  progress(0.1, f"Loading in local model: {CHOSEN_LOCAL_MODEL_TYPE}")
575
- local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER)
576
 
577
  summary_loop_description = "Revising topic-level summaries. " + str(latest_summary_completed) + " summaries completed so far."
578
  summary_loop = tqdm(range(latest_summary_completed, length_all_summaries), desc="Revising topic-level summaries", unit="summaries")
@@ -592,7 +594,7 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
592
  if "Local" in model_source and reasoning_suffix: formatted_summarise_topic_descriptions_system_prompt = formatted_summarise_topic_descriptions_system_prompt + "\n" + reasoning_suffix
593
 
594
  try:
595
- response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_topic_descriptions_system_prompt, model_source, bedrock_runtime, local_model)
596
  summarised_output = response
597
  summarised_output = re.sub(r'\n{2,}', '\n', summarised_output) # Replace multiple line breaks with a single line break
598
  summarised_output = re.sub(r'^\n{1,}', '', summarised_output) # Remove one or more line breaks at the start
@@ -697,7 +699,9 @@ def overall_summary(topic_summary_df:pd.DataFrame,
697
  aws_secret_key_textbox:str='',
698
  model_name_map:dict=model_name_map,
699
  reasoning_suffix:str=reasoning_suffix,
700
- local_model:object=list(),
 
 
701
  summarise_everything_prompt:str=summarise_everything_prompt,
702
  comprehensive_summary_format_prompt:str=comprehensive_summary_format_prompt,
703
  comprehensive_summary_format_prompt_by_group:str=comprehensive_summary_format_prompt_by_group,
@@ -721,6 +725,8 @@ def overall_summary(topic_summary_df:pd.DataFrame,
721
  model_name_map (dict, optional): Mapping of model names. Defaults to model_name_map.
722
  reasoning_suffix (str, optional): Suffix for reasoning. Defaults to reasoning_suffix.
723
  local_model (object, optional): Local model object. Defaults to empty list.
 
 
724
  summarise_everything_prompt (str, optional): Prompt for overall summary
725
  comprehensive_summary_format_prompt (str, optional): Prompt for comprehensive summary format
726
  comprehensive_summary_format_prompt_by_group (str, optional): Prompt for group summary format
@@ -795,7 +801,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,
795
 
796
  if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1"):
797
  progress(0.1, f"Loading in local model: {CHOSEN_LOCAL_MODEL_TYPE}")
798
- local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER)
799
  #print("Local model loaded:", local_model)
800
 
801
  summary_loop = tqdm(unique_groups, desc="Creating overall summary for groups", unit="groups")
@@ -806,7 +812,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,
806
 
807
  for summary_group in summary_loop:
808
 
809
- print("Creating overallsummary for group:", summary_group)
810
 
811
  summary_text = topic_summary_df.loc[topic_summary_df["Group"]==summary_group].to_markdown(index=False)
812
 
@@ -817,7 +823,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,
817
  if "Local" in model_source and reasoning_suffix: formatted_summarise_everything_system_prompt = formatted_summarise_everything_system_prompt + "\n" + reasoning_suffix
818
 
819
  try:
820
- response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_everything_system_prompt, model_source, bedrock_runtime, local_model)
821
  summarised_output_for_df = response
822
  summarised_output = response
823
  summarised_output = re.sub(r'\n{2,}', '\n', summarised_output) # Replace multiple line breaks with a single line break
 
415
 
416
  return sampled_reference_table_df, summarised_references_markdown#, reference_df, topic_summary_df
417
 
418
+ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, model_source:str, bedrock_runtime:boto3.Session.client, local_model=list(), tokenizer=list()):
419
  """
420
  Query an LLM to generate a summary of topics based on the provided prompts.
421
 
 
428
  model_source (str): Source of the model (e.g. "AWS", "Gemini", "Local")
429
  bedrock_runtime (boto3.Session.client): AWS Bedrock runtime client for AWS models
430
  local_model (object, optional): Local model object if using local inference. Defaults to empty list.
431
+ tokenizer (object, optional): Tokenizer object if using local inference. Defaults to empty list.
432
  Returns:
433
  tuple: Contains:
434
  - response_text (str): The generated summary text
 
454
  whole_conversation = [summarise_topic_descriptions_system_prompt]
455
 
456
  # Process requests to large language model
457
+ responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, bedrock_runtime=bedrock_runtime, model_source=model_source, local_model=local_model, tokenizer=tokenizer, assistant_prefill=summary_assistant_prefill)
458
 
459
  print("Finished summary query")
460
 
 
482
  aws_secret_key_textbox:str='',
483
  model_name_map:dict=model_name_map,
484
  reasoning_suffix:str=reasoning_suffix,
485
+ local_model:object=list(),
486
+ tokenizer:object=list(),
487
+ hf_api_key_textbox:str='',
488
  summarise_topic_descriptions_prompt:str=summarise_topic_descriptions_prompt,
489
  summarise_topic_descriptions_system_prompt:str=summarise_topic_descriptions_system_prompt,
490
  do_summaries:str="Yes",
 
574
 
575
  if (model_source == "Local") & (RUN_LOCAL_MODEL == "1"):
576
  progress(0.1, f"Loading in local model: {CHOSEN_LOCAL_MODEL_TYPE}")
577
+ local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER, hf_token=hf_api_key_textbox)
578
 
579
  summary_loop_description = "Revising topic-level summaries. " + str(latest_summary_completed) + " summaries completed so far."
580
  summary_loop = tqdm(range(latest_summary_completed, length_all_summaries), desc="Revising topic-level summaries", unit="summaries")
 
594
  if "Local" in model_source and reasoning_suffix: formatted_summarise_topic_descriptions_system_prompt = formatted_summarise_topic_descriptions_system_prompt + "\n" + reasoning_suffix
595
 
596
  try:
597
+ response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_topic_descriptions_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer)
598
  summarised_output = response
599
  summarised_output = re.sub(r'\n{2,}', '\n', summarised_output) # Replace multiple line breaks with a single line break
600
  summarised_output = re.sub(r'^\n{1,}', '', summarised_output) # Remove one or more line breaks at the start
 
699
  aws_secret_key_textbox:str='',
700
  model_name_map:dict=model_name_map,
701
  reasoning_suffix:str=reasoning_suffix,
702
+ local_model:object=list(),
703
+ tokenizer:object=list(),
704
+ hf_api_key_textbox:str='',
705
  summarise_everything_prompt:str=summarise_everything_prompt,
706
  comprehensive_summary_format_prompt:str=comprehensive_summary_format_prompt,
707
  comprehensive_summary_format_prompt_by_group:str=comprehensive_summary_format_prompt_by_group,
 
725
  model_name_map (dict, optional): Mapping of model names. Defaults to model_name_map.
726
  reasoning_suffix (str, optional): Suffix for reasoning. Defaults to reasoning_suffix.
727
  local_model (object, optional): Local model object. Defaults to empty list.
728
+ tokenizer (object, optional): Tokenizer object. Defaults to empty list.
729
+ hf_api_key_textbox (str, optional): Hugging Face API key. Defaults to empty string.
730
  summarise_everything_prompt (str, optional): Prompt for overall summary
731
  comprehensive_summary_format_prompt (str, optional): Prompt for comprehensive summary format
732
  comprehensive_summary_format_prompt_by_group (str, optional): Prompt for group summary format
 
801
 
802
  if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1"):
803
  progress(0.1, f"Loading in local model: {CHOSEN_LOCAL_MODEL_TYPE}")
804
+ local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER, hf_token=hf_api_key_textbox)
805
  #print("Local model loaded:", local_model)
806
 
807
  summary_loop = tqdm(unique_groups, desc="Creating overall summary for groups", unit="groups")
 
812
 
813
  for summary_group in summary_loop:
814
 
815
+ print("Creating overall summary for group:", summary_group)
816
 
817
  summary_text = topic_summary_df.loc[topic_summary_df["Group"]==summary_group].to_markdown(index=False)
818
 
 
823
  if "Local" in model_source and reasoning_suffix: formatted_summarise_everything_system_prompt = formatted_summarise_everything_system_prompt + "\n" + reasoning_suffix
824
 
825
  try:
826
+ response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_everything_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer)
827
  summarised_output_for_df = response
828
  summarised_output = response
829
  summarised_output = re.sub(r'\n{2,}', '\n', summarised_output) # Replace multiple line breaks with a single line break
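summarise_output_topics_query and its callers now thread a tokenizer (and, higher up, the Hugging Face key) through to process_requests and load_model. Below is a hedged sketch of the updated call shape for a local model; the keyword names follow the signature added above, while the prompt text and settings are placeholders.

```python
# Placeholder call, illustrative only; not code from the commit.
from tools.config import CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER
from tools.llm_funcs import load_model
from tools.dedup_summaries import summarise_output_topics_query

local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
                                    repo_id=LOCAL_REPO_ID,
                                    model_filename=LOCAL_MODEL_FILE,
                                    model_dir=LOCAL_MODEL_FOLDER)

response, conversation_history, metadata = summarise_output_topics_query(
    model_choice=CHOSEN_LOCAL_MODEL_TYPE,
    in_api_key="",                    # Gemini key, unused for local inference
    temperature=0.1,
    formatted_summary_prompt="Summarise the following topic descriptions:\n...",
    summarise_topic_descriptions_system_prompt="You are a concise summariser of consultation topics.",
    model_source="Local",
    bedrock_runtime=None,             # Bedrock client, only needed for AWS models
    local_model=local_model,
    tokenizer=tokenizer,
)
```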
tools/llm_api_call.py CHANGED
@@ -688,6 +688,7 @@ def extract_topics(in_data_file: GradioFileData,
688
  produce_structures_summary_radio:str="No",
689
  aws_access_key_textbox:str='',
690
  aws_secret_key_textbox:str='',
 
691
  max_tokens:int=max_tokens,
692
  model_name_map:dict=model_name_map,
693
  max_time_for_loop:int=max_time_for_loop,
@@ -737,6 +738,7 @@ def extract_topics(in_data_file: GradioFileData,
737
  - force_single_topic_prompt (str, optional): The prompt for forcing the model to assign only one single topic to each response.
738
  - aws_access_key_textbox (str, optional): AWS access key for account with Bedrock permissions.
739
  - aws_secret_key_textbox (str, optional): AWS secret key for account with Bedrock permissions.
 
740
  - max_tokens (int): The maximum number of tokens for the model.
741
  - model_name_map (dict, optional): A dictionary mapping full model name to shortened.
742
  - max_time_for_loop (int, optional): The number of seconds maximum that the function should run for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there).
@@ -808,7 +810,7 @@ def extract_topics(in_data_file: GradioFileData,
808
 
809
  if (model_source == "Local") & (RUN_LOCAL_MODEL == "1"):
810
  progress(0.1, f"Loading in local model: {CHOSEN_LOCAL_MODEL_TYPE}")
811
- local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER)
812
 
813
  if num_batches > 0:
814
  progress_measure = round(latest_batch_completed / num_batches, 1)
@@ -938,9 +940,9 @@ def extract_topics(in_data_file: GradioFileData,
938
  formatted_summary_prompt = structured_summary_prompt.format(response_table=normalised_simple_markdown_table,
939
  topics=unique_topics_markdown)
940
 
941
- if "gemma" in model_choice:
942
- formatted_summary_prompt = llama_cpp_prefix + formatted_system_prompt + "\n" + formatted_summary_prompt + llama_cpp_suffix
943
- full_prompt = formatted_summary_prompt
944
  else:
945
  full_prompt = formatted_system_prompt + "\n" + formatted_summary_prompt
946
 
@@ -970,7 +972,7 @@ def extract_topics(in_data_file: GradioFileData,
970
  whole_conversation = list()
971
 
972
  # Process requests to large language model
973
- responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)
974
 
975
  # Return output tables
976
  topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, new_topic_df, new_reference_df, new_topic_summary_df, master_batch_out_file_part, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structures_summary_radio, first_run=False, output_folder=output_folder)
@@ -1030,7 +1032,7 @@ def extract_topics(in_data_file: GradioFileData,
1030
  formatted_initial_table_system_prompt = initial_table_system_prompt.format(consultation_context=context_textbox, column_name=chosen_cols)
1031
 
1032
  # Prepare Gemini models before query
1033
- if "gemini" in model_choice:
1034
  print("Using Gemini model:", model_choice)
1035
  google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_initial_table_system_prompt, max_tokens=max_tokens)
1036
  elif model_choice == CHOSEN_LOCAL_MODEL_TYPE:
@@ -1062,7 +1064,7 @@ def extract_topics(in_data_file: GradioFileData,
1062
 
1063
  whole_conversation = list()
1064
 
1065
- responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_initial_table_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
1066
 
1067
  topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, topic_table_df, reference_df, new_topic_summary_df, batch_file_path_details, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structures_summary_radio, first_run=True, output_folder=output_folder)
1068
 
@@ -1271,6 +1273,7 @@ def wrapper_extract_topics_per_column_value(
1271
  produce_structures_summary_radio: str = "No",
1272
  aws_access_key_textbox:str="",
1273
  aws_secret_key_textbox:str="",
 
1274
  output_folder: str = OUTPUT_FOLDER,
1275
  force_single_topic_prompt: str = force_single_topic_prompt,
1276
  max_tokens: int = max_tokens,
@@ -1419,6 +1422,7 @@ def wrapper_extract_topics_per_column_value(
1419
  produce_structures_summary_radio=produce_structures_summary_radio,
1420
  aws_access_key_textbox=aws_access_key_textbox,
1421
  aws_secret_key_textbox=aws_secret_key_textbox,
 
1422
  max_tokens=max_tokens,
1423
  model_name_map=model_name_map,
1424
  max_time_for_loop=max_time_for_loop,
 
688
  produce_structures_summary_radio:str="No",
689
  aws_access_key_textbox:str='',
690
  aws_secret_key_textbox:str='',
691
+ hf_api_key_textbox:str='',
692
  max_tokens:int=max_tokens,
693
  model_name_map:dict=model_name_map,
694
  max_time_for_loop:int=max_time_for_loop,
 
738
  - force_single_topic_prompt (str, optional): The prompt for forcing the model to assign only one single topic to each response.
739
  - aws_access_key_textbox (str, optional): AWS access key for account with Bedrock permissions.
740
  - aws_secret_key_textbox (str, optional): AWS secret key for account with Bedrock permissions.
741
+ - hf_api_key_textbox (str, optional): Hugging Face API key for account with Hugging Face permissions.
742
  - max_tokens (int): The maximum number of tokens for the model.
743
  - model_name_map (dict, optional): A dictionary mapping full model name to shortened.
744
  - max_time_for_loop (int, optional): The number of seconds maximum that the function should run for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there).
 
810
 
811
  if (model_source == "Local") & (RUN_LOCAL_MODEL == "1"):
812
  progress(0.1, f"Loading in local model: {CHOSEN_LOCAL_MODEL_TYPE}")
813
+ local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER, hf_token=hf_api_key_textbox)
814
 
815
  if num_batches > 0:
816
  progress_measure = round(latest_batch_completed / num_batches, 1)
 
940
  formatted_summary_prompt = structured_summary_prompt.format(response_table=normalised_simple_markdown_table,
941
  topics=unique_topics_markdown)
942
 
943
+ if model_source == "Local":
944
+ #formatted_summary_prompt = llama_cpp_prefix + formatted_system_prompt + "\n" + formatted_summary_prompt + llama_cpp_suffix
945
+ full_prompt = formatted_system_prompt + "\n" + formatted_summary_prompt
946
  else:
947
  full_prompt = formatted_system_prompt + "\n" + formatted_summary_prompt
948
 
 
972
  whole_conversation = list()
973
 
974
  # Process requests to large language model
975
+ responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, tokenizer, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)
976
 
977
  # Return output tables
978
  topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, new_topic_df, new_reference_df, new_topic_summary_df, master_batch_out_file_part, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structures_summary_radio, first_run=False, output_folder=output_folder)
 
1032
  formatted_initial_table_system_prompt = initial_table_system_prompt.format(consultation_context=context_textbox, column_name=chosen_cols)
1033
 
1034
  # Prepare Gemini models before query
1035
+ if model_source == "Gemini":
1036
  print("Using Gemini model:", model_choice)
1037
  google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_initial_table_system_prompt, max_tokens=max_tokens)
1038
  elif model_choice == CHOSEN_LOCAL_MODEL_TYPE:
 
1064
 
1065
  whole_conversation = list()
1066
 
1067
+ responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_initial_table_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, tokenizer, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
1068
 
1069
  topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, topic_table_df, reference_df, new_topic_summary_df, batch_file_path_details, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structures_summary_radio, first_run=True, output_folder=output_folder)
1070
 
 
1273
  produce_structures_summary_radio: str = "No",
1274
  aws_access_key_textbox:str="",
1275
  aws_secret_key_textbox:str="",
1276
+ hf_api_key_textbox:str="",
1277
  output_folder: str = OUTPUT_FOLDER,
1278
  force_single_topic_prompt: str = force_single_topic_prompt,
1279
  max_tokens: int = max_tokens,
 
1422
  produce_structures_summary_radio=produce_structures_summary_radio,
1423
  aws_access_key_textbox=aws_access_key_textbox,
1424
  aws_secret_key_textbox=aws_secret_key_textbox,
1425
+ hf_api_key_textbox=hf_api_key_textbox,
1426
  max_tokens=max_tokens,
1427
  model_name_map=model_name_map,
1428
  max_time_for_loop=max_time_for_loop,
tools/llm_funcs.py CHANGED
@@ -18,7 +18,7 @@ full_text = "" # Define dummy source text (full text) just to enable highlight f
18
  model = list() # Define empty list for model functions to run
19
  tokenizer = list() #[] # Define empty list for model functions to run
20
 
21
- from tools.config import AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_MIN_P, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS, SPECULATIVE_DECODING, NUM_PRED_TOKENS
22
  from tools.prompts import initial_table_assistant_prefill
23
 
24
  if SPECULATIVE_DECODING == "True": SPECULATIVE_DECODING = True
@@ -192,7 +192,10 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE,
192
  torch_device:str=torch_device,
193
  repo_id=LOCAL_REPO_ID,
194
  model_filename=LOCAL_MODEL_FILE,
195
- model_dir=LOCAL_MODEL_FOLDER):
 
196
  '''
197
  Load in a model from Hugging Face hub via the transformers package, or using llama_cpp_python by downloading a GGUF file from Huggingface Hub.
198
 
@@ -206,22 +209,22 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE,
206
  repo_id (str): The Hugging Face repository ID where the model is located.
207
  model_filename (str): The specific filename of the model to download from the repository.
208
  model_dir (str): The local directory where the model will be stored or downloaded.
209
-
 
 
210
  Returns:
211
  tuple: A tuple containing:
212
- - llama_model (Llama): The loaded Llama.cpp model instance.
213
- - tokenizer (list): An empty list (tokenizer is not used with Llama.cpp directly in this setup).
214
  '''
215
  print("Loading model ", local_model_type)
216
- model_path = get_model_path(repo_id=repo_id, model_filename=model_filename, model_dir=model_dir)
217
 
218
  #print("model_path:", model_path)
219
 
220
  # Verify the device and cuda settings
221
  # Check if CUDA is enabled
222
- import torch
223
- from llama_cpp import Llama
224
- from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
225
 
226
  torch.cuda.empty_cache()
227
  print("Is CUDA enabled? ", torch.cuda.is_available())
@@ -252,41 +255,132 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE,
252
  gpu_config.update_gpu(gpu_layers)
253
  gpu_config.update_context(max_context_length)
254
 
255
- try:
256
- print("GPU load variables:" , vars(gpu_config))
257
- if speculative_decoding:
258
- llama_model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS), **vars(gpu_config))
259
- else:
260
- llama_model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, **vars(gpu_config))
 
261
 
262
- except Exception as e:
263
- print("GPU load failed due to:", e, "Loading model in CPU mode")
264
- # If fails, go to CPU mode
265
- llama_model = Llama(model_path=model_path, **vars(cpu_config))
 
266
 
267
  print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", gpu_config.n_ctx)
268
 
269
  # CPU mode
270
  else:
271
- gpu_config.update_gpu(gpu_layers)
 
272
  cpu_config.update_gpu(gpu_layers)
273
 
274
  # Update context length according to slider
275
- gpu_config.update_context(max_context_length)
276
  cpu_config.update_context(max_context_length)
277
 
278
  if speculative_decoding:
279
- llama_model = Llama(model_path=model_path, draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS), **vars(gpu_config))
280
  else:
281
- llama_model = Llama(model_path=model_path, **vars(cpu_config))
282
 
283
- print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", gpu_config.n_ctx)
 
284
 
285
- tokenizer = list()
286
 
287
  print("Finished loading model:", local_model_type)
288
  print("GPU layers assigned to cuda:", gpu_layers)
289
- return llama_model, tokenizer
290
 
291
  def call_llama_cpp_model(formatted_string:str, gen_config:str, model=model):
292
  """
@@ -506,8 +600,83 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
506
 
507
  return response
508
 
509
  # Function to send a request and update history
510
- def send_request(prompt: str, conversation_history: List[dict], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model= list(), assistant_prefill = "", progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
511
  """
512
  This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
513
  It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
@@ -516,6 +685,8 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
516
  """
517
  # Constructing the full prompt from the conversation history
518
  full_prompt = "Conversation history:\n"
 
 
519
 
520
  for entry in conversation_history:
521
  role = entry['role'].capitalize() # Assuming the history is stored with 'role' and 'parts'
@@ -573,13 +744,18 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
573
  gen_config = LlamaCPPGenerationConfig()
574
  gen_config.update_temp(temperature)
575
 
576
- response = call_llama_cpp_chatmodel(prompt, system_prompt, gen_config, model=local_model)
 
577
 
578
  #print("Successful call to local model.")
579
  break
580
  except Exception as e:
581
  # If fails, try again after X seconds in case there is a throttle limit
582
- print("Call to Gemma model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
583
 
584
  time.sleep(timeout_wait)
585
 
@@ -596,21 +772,24 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
596
  if isinstance(response, ResponseObject):
597
  response_text = response.text
598
  conversation_history.append({'role': 'assistant', 'parts': [response_text]})
599
- elif 'choices' in response:
600
  if "gpt-oss" in model_choice:
601
  response_text = response['choices'][0]['message']['content'].split('<|start|>assistant<|channel|>final<|message|>')[1]
602
  else:
603
  response_text = response['choices'][0]['message']['content']
604
  response_text = response_text.strip()
605
  conversation_history.append({'role': 'assistant', 'parts': [response_text]}) #response['choices'][0]['text']]})
606
- else:
607
  response_text = response.text
608
  response_text = response_text.strip()
609
  conversation_history.append({'role': 'assistant', 'parts': [response_text]})
 
610
 
611
- return response, conversation_history, response_text
612
 
613
- def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model = list(), master:bool = False, assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
614
  """
615
  Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
616
 
@@ -641,7 +820,7 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
641
 
642
  for prompt in prompts:
643
 
644
- response, conversation_history, response_text = send_request(prompt, conversation_history, google_client=google_client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)
645
 
646
  responses.append(response)
647
  whole_conversation.append(system_prompt)
@@ -677,9 +856,16 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
677
  whole_conversation_metadata.append(str(response.usage_metadata))
678
 
679
  elif "Local" in model_source:
680
- output_tokens = response['usage'].get('completion_tokens', 0)
681
- input_tokens = response['usage'].get('prompt_tokens', 0)
682
- whole_conversation_metadata.append(str(response['usage']))
 
683
  except KeyError as e:
684
  print(f"Key error: {e} - Check the structure of response.usage_metadata")
685
  else:
@@ -699,6 +885,7 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
699
  temperature: float,
700
  reported_batch_no: int,
701
  local_model: object,
 
702
  bedrock_runtime:boto3.Session.client,
703
  model_source:str,
704
  MAX_OUTPUT_VALIDATION_ATTEMPTS: int,
@@ -721,6 +908,7 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
721
  - temperature (float): The temperature parameter for the model.
722
  - reported_batch_no (int): The reported batch number.
723
  - local_model (object): The local model to use.
 
724
  - bedrock_runtime (boto3.Session.client): The client object for boto3 Bedrock runtime.
725
  - model_source (str): The source of the model, whether in AWS, Gemini, or local.
726
  - MAX_OUTPUT_VALIDATION_ATTEMPTS (int): The maximum number of attempts to validate the output.
@@ -743,7 +931,7 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
743
  responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(
744
  batch_prompts, system_prompt, conversation_history, whole_conversation,
745
  whole_conversation_metadata, google_client, google_config, model_choice,
746
- call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, master=master, assistant_prefill=assistant_prefill
747
  )
748
 
749
  stripped_response = response_text.strip()
 
18
  model = list() # Define empty list for model functions to run
19
  tokenizer = list() #[] # Define empty list for model functions to run
20
 
21
+ from tools.config import AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_MIN_P, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS, SPECULATIVE_DECODING, NUM_PRED_TOKENS, USE_LLAMA_CPP, COMPILE_MODE, MODEL_DTYPE, USE_BITSANDBYTES, COMPILE_TRANSFORMERS, OFFLOAD_TO_CPU
22
  from tools.prompts import initial_table_assistant_prefill
23
 
24
  if SPECULATIVE_DECODING == "True": SPECULATIVE_DECODING = True
 
192
  torch_device:str=torch_device,
193
  repo_id=LOCAL_REPO_ID,
194
  model_filename=LOCAL_MODEL_FILE,
195
+ model_dir=LOCAL_MODEL_FOLDER,
196
+ compile_mode=COMPILE_MODE,
197
+ model_dtype=MODEL_DTYPE,
198
+ hf_token=HF_TOKEN):
199
  '''
200
  Load in a model from Hugging Face hub via the transformers package, or using llama_cpp_python by downloading a GGUF file from Huggingface Hub.
201
 
 
209
  repo_id (str): The Hugging Face repository ID where the model is located.
210
  model_filename (str): The specific filename of the model to download from the repository.
211
  model_dir (str): The local directory where the model will be stored or downloaded.
212
+ compile_mode (str): The compilation mode to use for the model.
213
+ model_dtype (str): The data type to use for the model.
214
+ hf_token (str): The Hugging Face token to use for the model.
215
  Returns:
216
  tuple: A tuple containing:
217
+ - model (Llama/transformers model): The loaded Llama.cpp/transformers model instance.
218
+ - tokenizer (list/transformers tokenizer): An empty list (tokenizer is not used with Llama.cpp directly in this setup), or a transformers tokenizer.
219
  '''
220
  print("Loading model ", local_model_type)
221
+ tokenizer = list()
222
 
223
  #print("model_path:", model_path)
224
 
225
  # Verify the device and cuda settings
226
  # Check if CUDA is enabled
227
+ import torch
 
 
228
 
229
  torch.cuda.empty_cache()
230
  print("Is CUDA enabled? ", torch.cuda.is_available())
 
255
  gpu_config.update_gpu(gpu_layers)
256
  gpu_config.update_context(max_context_length)
257
 
258
+ if USE_LLAMA_CPP == "True":
259
+ from llama_cpp import Llama
260
+ from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
261
+
262
+ model_path = get_model_path(repo_id=repo_id, model_filename=model_filename, model_dir=model_dir)
263
+
264
+ try:
265
+ print("GPU load variables:" , vars(gpu_config))
266
+ if speculative_decoding:
267
+ model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS), **vars(gpu_config))
268
+ else:
269
+ model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, **vars(gpu_config))
270
 
271
+ except Exception as e:
272
+ print("GPU load failed due to:", e, "Loading model in CPU mode")
273
+ # If fails, go to CPU mode
274
+ model = Llama(model_path=model_path, **vars(cpu_config))
275
+
276
+ else:
277
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
278
+
279
+ print("Loading model from transformers")
280
+ # The Hugging Face model ID to load (repo_id from config), e.g. the official Gemma 3 4B repository
281
+ model_id = repo_id
282
+ # 1. Set Data Type (dtype)
283
+ # For H200/Hopper: 'bfloat16'
284
+ # For RTX 3060/Ampere: 'float16'
285
+ dtype_str = model_dtype
286
+ if dtype_str == "bfloat16":
287
+ torch_dtype = torch.bfloat16
288
+ elif dtype_str == "float16":
289
+ torch_dtype = torch.float16
290
+ else:
291
+ torch_dtype = torch.float32 # A safe fallback
292
+
293
+ # 2. Set Compilation Mode
294
+ # 'max-autotune' gives the best runtime on either GPU but the first compilation pass is slow.
295
+ # 'reduce-overhead' compiles faster at the cost of less aggressive optimisation.
296
+
297
+ print(f"--- System Configuration ---")
298
+ print(f"Using model id: {model_id}")
299
+ print(f"Using dtype: {torch_dtype}")
300
+ print(f"Using compile mode: {compile_mode}")
301
+ print(f"Using bitsandbytes: {USE_BITSANDBYTES}")
302
+ print("--------------------------\n")
303
+
304
+ # --- Load Tokenizer and Model ---
305
+
306
+ # Load the tokenizer first; fall back to the EOS token as the padding token if none is defined
307
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
308
+
309
+ if not tokenizer.pad_token:
310
+ tokenizer.pad_token = tokenizer.eos_token
311
+
312
+ if USE_BITSANDBYTES == "True":
313
+
314
+ if OFFLOAD_TO_CPU == "True":
315
+ # This will be very slow. Requires at least 4GB of VRAM and 32GB of RAM
316
+ print("Using bitsandbytes for quantisation to 8 bits, with offloading to CPU")
317
+ max_memory = {0: "4GB", "cpu": "32GB"} # cap GPU 0 at 4GB and allow up to 32GB of system RAM for offloaded layers
318
+ quantization_config = BitsAndBytesConfig(
319
+ load_in_8bit=True,
321
+ llm_int8_enable_fp32_cpu_offload=True # Note: if bitsandbytes has to offload to CPU, inference will be slow
322
+ )
323
+ else:
+ max_memory = None # no per-device memory cap is needed for the plain 4-bit path
324
+ # For Gemma 4B, requires at least 6GB of VRAM
325
+ print("Using bitsandbytes for quantisation to 4 bits")
326
+ quantization_config = BitsAndBytesConfig(
327
+ load_in_4bit=True,
328
+ bnb_4bit_quant_type="nf4", # Use the modern NF4 quantisation for better performance
329
+ bnb_4bit_compute_dtype=torch_dtype,
330
+ bnb_4bit_use_double_quant=True, # Optional: uses a second quantisation step to save even more memory
331
+ )
332
+
333
+ model = AutoModelForCausalLM.from_pretrained(
334
+ model_id,
335
+ torch_dtype=torch_dtype,
336
+ device_map="auto",
337
+ quantization_config=quantization_config,
+ max_memory=max_memory, # per-device memory cap (only set on the 8-bit CPU-offload path)
338
+ token=hf_token
339
+ )
340
+ else:
341
+ print(f"Loading model without quantisation in {torch_dtype} precision")
342
+ model = AutoModelForCausalLM.from_pretrained(
343
+ model_id,
344
+ torch_dtype=torch_dtype,
345
+ device_map="auto",
346
+ token=hf_token
347
+ )
348
+
349
+ # Compile the Model with the selected mode 🚀
350
+ if COMPILE_TRANSFORMERS == "True":
351
+ try:
352
+ model = torch.compile(model, mode=compile_mode, fullgraph=True)
353
+ except Exception as e:
354
+ print(f"Could not compile model: {e}. Running in eager mode.")
355
 
356
  print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", gpu_config.n_ctx)
357
 
358
  # CPU mode
359
  else:
360
+ if USE_LLAMA_CPP == "False":
361
+ raise Warning("Using transformers model in CPU mode is not supported. Please change your config variable USE_LLAMA_CPP to True if you want to do CPU inference.")
362
+
363
+ from llama_cpp import Llama
+ from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
+
+ model_path = get_model_path(repo_id=repo_id, model_filename=model_filename, model_dir=model_dir)
364
+
366
  cpu_config.update_gpu(gpu_layers)
367
 
368
  # Update context length according to slider
370
  cpu_config.update_context(max_context_length)
371
 
372
  if speculative_decoding:
373
+ model = Llama(model_path=model_path, draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS), **vars(cpu_config))
374
  else:
375
+ model = Llama(model_path=model_path, **vars(cpu_config))
376
 
377
+ print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", cpu_config.n_ctx)
378
+
379
 
 
380
 
381
  print("Finished loading model:", local_model_type)
382
  print("GPU layers assigned to cuda:", gpu_layers)
383
+ return model, tokenizer
384
 
385
  def call_llama_cpp_model(formatted_string:str, gen_config:str, model=model):
386
  """
 
600
 
601
  return response
602
 
603
+ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCPPGenerationConfig, model=model, tokenizer=tokenizer):
604
+ """
605
+ This function sends a request to a transformers model with the given prompt, system prompt, and generation configuration.
606
+ """
607
+ # 1. Define the conversation as a list of dictionaries
608
+ conversation = [
609
+ {"role": "system", "content": system_prompt},
610
+ {"role": "user", "content": prompt}
611
+ ]
612
+
613
+ # 2. Apply the chat template
614
+ # This function formats the conversation into the exact string Gemma 3 expects.
615
+ # add_generation_prompt=True adds the special tokens that tell the model it's its turn to speak.
616
+ input_ids = tokenizer.apply_chat_template(
617
+ conversation,
618
+ add_generation_prompt=True,
619
+ return_tensors="pt"
620
+ ).to("cuda")
621
+
622
+ # Warm-up run
623
+ print("Performing warm-up run...")
624
+ _ = model.generate(input_ids, max_new_tokens=50)
625
+ print("Warm-up complete.")
626
+
627
+ # Map LlamaCPP parameters to transformers parameters
628
+ generation_kwargs = {
629
+ 'max_new_tokens': gen_config.max_tokens,
630
+ 'temperature': gen_config.temperature,
631
+ 'top_p': gen_config.top_p,
632
+ 'top_k': gen_config.top_k,
633
+ 'do_sample': True,
634
+ 'pad_token_id': tokenizer.eos_token_id
635
+ }
636
+
637
+ # Remove parameters that don't exist in transformers
638
+ if hasattr(gen_config, 'repeat_penalty'):
639
+ generation_kwargs['repetition_penalty'] = gen_config.repeat_penalty
640
+
641
+ # --- Timed Inference Test ---
642
+ print("\nStarting timed inference test...")
643
+ start_time = time.time()
644
+
645
+
646
+
647
+ outputs = model.generate(
648
+ input_ids,
649
+ **generation_kwargs
650
+ )
651
+
652
+ end_time = time.time()
653
+
654
+ # --- Decode and Display Results ---
655
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
656
+ # To get only the model's reply, we can decode just the newly generated tokens
657
+ new_tokens = outputs[0][input_ids.shape[-1]:]
658
+ assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
659
+
660
+ num_input_tokens = input_ids.shape[-1] # input_ids has shape (1, seq_len); len() would return the batch size, not the token count
661
+ num_generated_tokens = len(new_tokens)
662
+ duration = end_time - start_time
663
+ tokens_per_second = num_generated_tokens / duration
664
+
665
+ print("\n--- Inference Results ---")
666
+ print(f"System Prompt: {conversation[0]['content']}")
667
+ print(f"User Prompt: {conversation[1]['content']}")
668
+ print("---")
669
+ print(f"Assistant's Reply: {assistant_reply}")
670
+ print("\n--- Performance ---")
671
+ print(f"Time taken: {duration:.2f} seconds")
672
+ print(f"Generated tokens: {num_generated_tokens}")
673
+ print(f"Tokens per second: {tokens_per_second:.2f}")
674
+
675
+ return assistant_reply, num_input_tokens, num_generated_tokens
676
+
677
+
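A minimal usage sketch for call_transformers_model as defined above. GenCfg is a hypothetical stand-in for LlamaCPPGenerationConfig that defines only the attributes the function reads (max_tokens, temperature, top_p, top_k, repeat_penalty); model and tokenizer are assumed to be the pair returned by the loader earlier in this file when USE_LLAMA_CPP is "False".

```python
# Sketch under the assumptions above: `model` and `tokenizer` come from the
# transformers loading path, and GenCfg stands in for LlamaCPPGenerationConfig.
from dataclasses import dataclass

@dataclass
class GenCfg:
    max_tokens: int = 512        # mapped to max_new_tokens
    temperature: float = 0.1
    top_p: float = 0.95
    top_k: int = 40
    repeat_penalty: float = 1.1  # mapped to repetition_penalty

reply, n_input_tokens, n_generated_tokens = call_transformers_model(
    prompt="Summarise the responses below as a markdown table.",
    system_prompt="You are an analyst summarising open-text consultation responses.",
    gen_config=GenCfg(),
    model=model,
    tokenizer=tokenizer,
)
print(f"Generated {n_generated_tokens} tokens from {n_input_tokens} input tokens")
```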
678
  # Function to send a request and update history
679
+ def send_request(prompt: str, conversation_history: List[dict], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model= list(), tokenizer=tokenizer, assistant_prefill = "", progress=Progress(track_tqdm=True)) -> Tuple[object, List[dict], str, int, int]:
680
  """
681
  This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
682
  It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
 
685
  """
686
  # Constructing the full prompt from the conversation history
687
  full_prompt = "Conversation history:\n"
688
+ num_transformer_input_tokens = 0
689
+ num_transformer_generated_tokens = 0
690
 
691
  for entry in conversation_history:
692
  role = entry['role'].capitalize() # Assuming the history is stored with 'role' and 'parts'
 
744
  gen_config = LlamaCPPGenerationConfig()
745
  gen_config.update_temp(temperature)
746
 
747
+ if USE_LLAMA_CPP == "True":
748
+ response = call_llama_cpp_chatmodel(prompt, system_prompt, gen_config, model=local_model)
749
+
750
+ else:
751
+ response, num_transformer_input_tokens, num_transformer_generated_tokens = call_transformers_model(prompt, system_prompt, gen_config, model=local_model, tokenizer=tokenizer)
752
+ response_text = response
753
 
754
  #print("Successful call to local model.")
755
  break
756
  except Exception as e:
757
  # If fails, try again after X seconds in case there is a throttle limit
758
+ print("Call to local model failed:", e, "- waiting for", str(timeout_wait), "seconds and trying again.")
759
 
760
  time.sleep(timeout_wait)
761
 
 
772
  if isinstance(response, ResponseObject):
773
  response_text = response.text
774
  conversation_history.append({'role': 'assistant', 'parts': [response_text]})
775
+ elif isinstance(response, dict) and 'choices' in response: # llama.cpp model response
776
  if "gpt-oss" in model_choice:
777
  response_text = response['choices'][0]['message']['content'].split('<|start|>assistant<|channel|>final<|message|>')[1]
778
  else:
779
  response_text = response['choices'][0]['message']['content']
780
  response_text = response_text.strip()
781
  conversation_history.append({'role': 'assistant', 'parts': [response_text]}) #response['choices'][0]['text']]})
782
+ elif model_source == "Gemini":
783
  response_text = response.text
784
  response_text = response_text.strip()
785
  conversation_history.append({'role': 'assistant', 'parts': [response_text]})
786
+ else: # Assume transformers model response
787
+ response_text = response
788
+ conversation_history.append({'role': 'assistant', 'parts': [response_text]})
789
 
790
+ return response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens
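The branching above normalises four response shapes into response_text: the ResponseObject wrapper used for AWS Bedrock, the dict returned by llama-cpp-python's chat completion, the Gemini response object, and the plain string returned by call_transformers_model. Purely as an illustration of that logic (send_request keeps it inline), the same normalisation written as a standalone helper:

```python
# Illustrative sketch of the response normalisation performed inline in send_request.
# ResponseObject is the Bedrock wrapper class used elsewhere in this module.
def extract_response_text(response, model_choice: str, model_source: str) -> str:
    if isinstance(response, ResponseObject):                  # AWS Bedrock wrapper
        return response.text
    if isinstance(response, dict) and 'choices' in response:  # llama-cpp-python chat completion
        text = response['choices'][0]['message']['content']
        if "gpt-oss" in model_choice:                         # drop the gpt-oss channel preamble
            text = text.split('<|start|>assistant<|channel|>final<|message|>')[1]
        return text.strip()
    if model_source == "Gemini":                              # google-genai response object
        return response.text.strip()
    return str(response).strip()                              # transformers path returns a plain string
```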
791
 
792
+ def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model = list(), tokenizer=tokenizer, master:bool = False, assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List[str], str]:
793
  """
794
  Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
795
 
 
820
 
821
  for prompt in prompts:
822
 
823
+ response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens = send_request(prompt, conversation_history, google_client=google_client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, tokenizer=tokenizer, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)
824
 
825
  responses.append(response)
826
  whole_conversation.append(system_prompt)
 
856
  whole_conversation_metadata.append(str(response.usage_metadata))
857
 
858
  elif "Local" in model_source:
859
+ if USE_LLAMA_CPP == "True":
860
+ output_tokens = response['usage'].get('completion_tokens', 0)
861
+ input_tokens = response['usage'].get('prompt_tokens', 0)
862
+ whole_conversation_metadata.append(str(response['usage']))
863
+
864
+ else:
865
+ input_tokens = num_transformer_input_tokens
866
+ output_tokens = num_transformer_generated_tokens
867
+ whole_conversation_metadata.append('inputTokens: ' + str(input_tokens) + ' outputTokens: ' + str(output_tokens))
868
+
869
  except KeyError as e:
870
  print(f"Key error: {e} - Check the structure of response.usage_metadata")
871
  else:
 
885
  temperature: float,
886
  reported_batch_no: int,
887
  local_model: object,
888
+ tokenizer:object,
889
  bedrock_runtime:boto3.Session.client,
890
  model_source:str,
891
  MAX_OUTPUT_VALIDATION_ATTEMPTS: int,
 
908
  - temperature (float): The temperature parameter for the model.
909
  - reported_batch_no (int): The reported batch number.
910
  - local_model (object): The local model to use.
911
+ - tokenizer (object): The tokenizer to use.
912
  - bedrock_runtime (boto3.Session.client): The client object for boto3 Bedrock runtime.
913
  - model_source (str): The source of the model, whether in AWS, Gemini, or local.
914
  - MAX_OUTPUT_VALIDATION_ATTEMPTS (int): The maximum number of attempts to validate the output.
 
931
  responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(
932
  batch_prompts, system_prompt, conversation_history, whole_conversation,
933
  whole_conversation_metadata, google_client, google_config, model_choice,
934
+ call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, tokenizer=tokenizer, master=master, assistant_prefill=assistant_prefill
935
  )
936
 
937
  stripped_response = response_text.strip()
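For local models, process_requests above records token usage in two formats: the llama.cpp path appends str(response['usage']) (a dict repr containing prompt_tokens/completion_tokens), while the transformers path appends an 'inputTokens: X outputTokens: Y' string. A hedged sketch of totalling usage across a run from whole_conversation_metadata; the helper and its regexes are illustrative and not part of this commit:

```python
import re
from typing import List, Tuple

# Sketch only: sum token counts from the metadata strings appended by process_requests.
# Assumes each entry contains either llama.cpp's usage dict repr
# ('prompt_tokens'/'completion_tokens') or the 'inputTokens: X outputTokens: Y' format
# used for the transformers path; entries in any other format are simply skipped.
def total_token_usage(whole_conversation_metadata: List[str]) -> Tuple[int, int]:
    total_in = total_out = 0
    for entry in whole_conversation_metadata:
        m_in = re.search(r"(?:inputTokens|prompt_tokens)\D*?(\d+)", entry)
        m_out = re.search(r"(?:outputTokens|completion_tokens)\D*?(\d+)", entry)
        if m_in:
            total_in += int(m_in.group(1))
        if m_out:
            total_out += int(m_out.group(1))
    return total_in, total_out
```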
tools/verify_titles.py CHANGED
@@ -448,7 +448,7 @@ def verify_titles(in_data_file,
448
  summary_whole_conversation = list()
449
 
450
  # Process requests to large language model
451
- responses, summary_conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)
452
 
453
 
454
 
@@ -549,7 +549,7 @@ def verify_titles(in_data_file,
549
 
550
  whole_conversation = [formatted_initial_table_system_prompt]
551
 
552
- responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
553
 
554
 
555
  topic_table_out_path, reference_table_out_path, unique_topics_df_out_path, topic_table_df, markdown_table, reference_df, new_unique_topics_df, batch_file_path_details, is_error = write_llm_output_and_logs_verify(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_unique_topics_df, batch_size, chosen_cols, model_name_map=model_name_map, first_run=True)
 
448
  summary_whole_conversation = list()
449
 
450
  # Process requests to large language model
451
+ responses, summary_conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, tokenizer, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)
452
 
453
 
454
 
 
549
 
550
  whole_conversation = [formatted_initial_table_system_prompt]
551
 
552
+ responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, tokenizer, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
553
 
554
 
555
  topic_table_out_path, reference_table_out_path, unique_topics_df_out_path, topic_table_df, markdown_table, reference_df, new_unique_topics_df, batch_file_path_details, is_error = write_llm_output_and_logs_verify(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_unique_topics_df, batch_size, chosen_cols, model_name_map=model_name_map, first_run=True)