Spaces: Running on Zero

Commit · 72d517c
1 Parent(s): 8c54223

Enabled GPU-based local model inference with the transformers package

Files changed:
- app.py +12 -7
- requirements.txt +2 -0
- requirements_gpu.txt +4 -2
- tools/config.py +25 -0
- tools/dedup_summaries.py +16 -10
- tools/llm_api_call.py +11 -7
- tools/llm_funcs.py +226 -38
- tools/verify_titles.py +2 -2
app.py
CHANGED
@@ -12,7 +12,7 @@ from tools.custom_csvlogger import CSVLogger_custom
 from tools.auth import authenticate_user
 from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, verify_titles_prompt, verify_titles_system_prompt, two_para_summary_format_prompt, single_para_summary_format_prompt
 from tools.verify_titles import verify_titles
-from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, RUN_LOCAL_MODEL, FILE_INPUT_HEIGHT, GEMINI_API_KEY, model_full_names, BATCH_SIZE_DEFAULT, CHOSEN_LOCAL_MODEL_TYPE, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, model_name_map, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES
+from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, RUN_LOCAL_MODEL, FILE_INPUT_HEIGHT, GEMINI_API_KEY, model_full_names, BATCH_SIZE_DEFAULT, CHOSEN_LOCAL_MODEL_TYPE, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, model_name_map, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES, HF_TOKEN
 
 def ensure_folder_exists(output_folder:str):
     """Checks if the specified folder exists, creates it if not."""
@@ -148,14 +148,14 @@ with app:
 
 Instructions on use can be found in the README.md file. Try it out with this [dummy development consultation dataset](https://huggingface.co/datasets/seanpedrickcase/dummy_development_consultation/tree/main), which you can also try with [zero-shot topics](https://huggingface.co/datasets/seanpedrickcase/dummy_development_consultation/tree/main). Try also this [dummy case notes dataset](https://huggingface.co/datasets/seanpedrickcase/dummy_case_notes/tree/main).
 
-You can use an AWS Bedrock model (paid), or Gemini (a free API for Flash). The use of Gemini requires an API key. To set up your own Gemini API key, go [here](https://aistudio.google.com/app/u/1/plan_information).
+You can use an AWS Bedrock model (paid), or Gemini (a free API for Flash). The use of Gemini requires an API key. To set up your own Gemini API key, go [here](https://aistudio.google.com/app/u/1/plan_information).
 
 NOTE: Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this app **absolutely need to be checked by a human** to check for harmful outputs, hallucinations, and accuracy.""")
 
 with gr.Tab(label="1. Extract topics"):
     gr.Markdown("""### Choose a tabular data file (xlsx, csv, parquet) of open text to extract topics from.""")
     with gr.Row():
-        model_choice = gr.Dropdown(value = default_model_choice, choices = model_full_names, label="LLM model", multiselect=False)
+        model_choice = gr.Dropdown(value = default_model_choice, choices = model_full_names, label="LLM model", multiselect=False)
 
     with gr.Accordion("Upload xlsx or csv file", open = True):
         in_data_files = gr.File(height=FILE_INPUT_HEIGHT, label="Choose Excel or csv files", file_count= "multiple", file_types=['.xlsx', '.xls', '.csv', '.parquet'])
@@ -177,6 +177,9 @@ with app:
 
 sentiment_checkbox = gr.Radio(label="Choose sentiment categories to split responses", value="Negative or Positive", choices=["Negative or Positive", "Negative, Neutral, or Positive", "Do not assess sentiment"])
 
+if SHOW_EXAMPLES == "True":
+    examples = gr.Examples(examples=[[["example_data/dummy_consultation_response.csv"], "Response text", "Consultation for the construction of flats on Main Street"], [["example_data/combined_case_notes.csv"], "Case Note", "Social Care case notes for young people"]], inputs=[in_data_files, in_colnames, context_textbox], example_labels=["Consultation for the construction of flats on Main Street", "Social Care case notes for young people"], label="Test with an example dataset")
+
 if GET_COST_CODES == "True" or ENFORCE_COST_CODES == "True":
     with gr.Accordion("Assign task to cost code", open = True, visible=True):
         gr.Markdown("Please ensure that you have approval from your budget holder before using this app for redaction tasks that incur a cost.")
@@ -188,9 +191,6 @@ with app:
 
 all_in_one_btn = gr.Button("All in one - Extract topics, deduplicate, and summarise", variant="primary")
 extract_topics_btn = gr.Button("1. Extract topics", variant="secondary")
-
-if SHOW_EXAMPLES == "True":
-    examples = gr.Examples(examples=[[["example_data/dummy_consultation_response.csv"]], [["example_data/combined_case_notes.csv"]]], inputs=[in_data_files])
 
 with gr.Row(equal_height=True):
     output_messages_textbox = gr.Textbox(value="", label="Output messages", scale=1, interactive=False)
@@ -307,6 +307,9 @@ with app:
 with gr.Accordion("Gemini API keys", open = False):
     google_api_key_textbox = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")
 
+with gr.Accordion("Hugging Face API keys", open = False):
+    hf_api_key_textbox = gr.Textbox(value = HF_TOKEN, label="Enter Hugging Face API key (only if using Hugging Face models)", lines=1, type="password")
+
 with gr.Accordion("Log outputs", open = False):
     log_files_output = gr.File(height=FILE_INPUT_HEIGHT, label="Log file output", interactive=False)
     conversation_metadata_textbox = gr.Textbox(value="", label="Query metadata - usage counts and other parameters", lines=8)
@@ -350,7 +353,7 @@ with app:
 ###
 
 # Tabular data upload
-in_data_files.
+in_data_files.upload(fn=put_columns_in_df, inputs=[in_data_files], outputs=[in_colnames, in_excel_sheets, original_data_file_name_textbox, join_colnames, in_group_col])
 
 # Click on cost code dataframe/dropdown fills in cost code textbox
 # Allow user to select items from cost code dataframe for cost code
@@ -401,6 +404,7 @@ with app:
 produce_structures_summary_radio,
 aws_access_key_textbox,
 aws_secret_key_textbox,
+hf_api_key_textbox,
 output_folder_state],
 outputs=[display_topic_table_markdown,
 master_topic_df_state,
@@ -498,6 +502,7 @@ with app:
 produce_structures_summary_radio,
 aws_access_key_textbox,
 aws_secret_key_textbox,
+hf_api_key_textbox,
 output_folder_state],
 outputs=[display_topic_table_markdown,
 master_topic_df_state,
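The UI changes above add a Hugging Face token field and pass it to the topic-extraction event alongside the existing AWS and Gemini credentials, and they move the example datasets into a gr.Examples block that pre-fills the file, column and context inputs. Below is a minimal, self-contained sketch of the token-wiring pattern only; the component names and the stub handler are illustrative stand-ins, not the app's real extract_topics function or layout.

# Sketch only: a password textbox whose value is forwarded as an extra input to a click
# handler, mirroring how hf_api_key_textbox is added to the extract-topics event inputs.
import gradio as gr

def extract_topics_stub(files, hf_token):
    # Stand-in for tools.llm_api_call.extract_topics; it only reports what it received.
    token_state = "provided" if hf_token else "not provided"
    return f"{len(files or [])} file(s) uploaded, Hugging Face token {token_state}"

with gr.Blocks() as demo:
    in_data_files = gr.File(label="Choose Excel or csv files", file_count="multiple",
                            file_types=[".xlsx", ".xls", ".csv", ".parquet"])
    hf_api_key_textbox = gr.Textbox(label="Hugging Face API key", type="password")
    output_messages_textbox = gr.Textbox(label="Output messages", interactive=False)
    extract_topics_btn = gr.Button("Extract topics")

    # The token travels with the other inputs; the handler decides whether it is needed.
    extract_topics_btn.click(fn=extract_topics_stub,
                             inputs=[in_data_files, hf_api_key_textbox],
                             outputs=[output_messages_textbox])

if __name__ == "__main__":
    demo.launch()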
requirements.txt
CHANGED
@@ -18,6 +18,8 @@ python-dotenv==1.1.0
 # GPU
 torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124 # Latest compatible with CUDA 12.4
 https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp310-cp310-linux_x86_64.whl
+bitsandbytes==0.47.0
+accelerate==1.10.1
 # CPU only (for e.g. Hugging Face CPU instances)
 #torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/cpu
 # For Hugging Face, need a python 3.10 compatible wheel for llama-cpp-python to avoid build timeouts
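bitsandbytes and accelerate are pinned here because the transformers inference path added in this commit quantises the model and places its layers across devices. A small, repo-independent check like the following can confirm that the GPU stack from this requirements file is usable before switching the local backend; it is an illustrative script, not part of the repository.

# Sanity check (illustrative): verify CUDA is visible and the transformers-path packages import.
import importlib.util

import torch

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

# bitsandbytes and accelerate are only needed for the quantised transformers path.
for pkg in ("transformers", "bitsandbytes", "accelerate"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'installed' if found else 'missing'}")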
requirements_gpu.txt
CHANGED
@@ -17,10 +17,12 @@ python-dotenv==1.1.0
 # Torch and Llama CPP Python
 torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124 # Latest compatible with CUDA 12.4
 # For Linux:
-https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl
+#https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.16-cu124/llama_cpp_python-0.3.16-cp311-cp311-linux_x86_64.whl
 # For Windows:
-
+https://github.com/seanpedrick-case/llama-cpp-python-whl-builder/releases/download/v0.1.0/llama_cpp_python-0.3.16-cp311-cp311-win_amd64.whl
 # If above doesn't work for Windows, try looking at 'windows_install_llama-cpp-python.txt' for instructions on how to build from source
 # If none of the above work for you, try the following:
 # llama-cpp-python==0.3.16 -C cmake.args="-DGGML_CUDA=on -DGGML_CUBLAS=on"
+bitsandbytes==0.47.0
+accelerate==1.10.1
 
tools/config.py
CHANGED
@@ -241,20 +241,38 @@ model_name_map = {
 
 # HF token may or may not be needed for downloading models from Hugging Face
 HF_TOKEN = get_or_create_env_var('HF_TOKEN', '')
+USE_LLAMA_CPP = get_or_create_env_var('USE_LLAMA_CPP', 'True') # Llama.cpp or transformers
+
 
 GEMMA2_REPO_ID = get_or_create_env_var("GEMMA2_2B_REPO_ID", "lmstudio-community/gemma-2-2b-it-GGUF")
+GEMMA2_REPO_TRANSFORMERS_ID = get_or_create_env_var("GEMMA2_2B_REPO_TRANSFORMERS_ID", "google/gemma-2-2b-it")
+if USE_LLAMA_CPP == "False":
+    GEMMA2_REPO_ID = GEMMA2_REPO_TRANSFORMERS_ID
+
 GEMMA2_MODEL_FILE = get_or_create_env_var("GEMMA2_2B_MODEL_FILE", "gemma-2-2b-it-Q8_0.gguf")
 GEMMA2_MODEL_FOLDER = get_or_create_env_var("GEMMA2_2B_MODEL_FOLDER", "model/gemma")
 
 GEMMA3_REPO_ID = get_or_create_env_var("GEMMA3_REPO_ID", "unsloth/gemma-3-270m-it-qat-GGUF")
+GEMMA3_REPO_TRANSFORMERS_ID = get_or_create_env_var("GEMMA3_REPO_TRANSFORMERS_ID", "google/gemma-3-270m-it")
+if USE_LLAMA_CPP == "False":
+    GEMMA3_REPO_ID = GEMMA3_REPO_TRANSFORMERS_ID
+
 GEMMA3_MODEL_FILE = get_or_create_env_var("GEMMA3_MODEL_FILE", "gemma-3-270m-it-qat-F16.gguf")
 GEMMA3_MODEL_FOLDER = get_or_create_env_var("GEMMA3_MODEL_FOLDER", "model/gemma")
 
 GEMMA3_4B_REPO_ID = get_or_create_env_var("GEMMA3_4B_REPO_ID", "unsloth/gemma-3-4b-it-qat-GGUF")
+GEMMA3_4B_REPO_TRANSFORMERS_ID = get_or_create_env_var("GEMMA3_4B_REPO_TRANSFORMERS_ID", "google/gemma-3-4b-it")
+if USE_LLAMA_CPP == "False":
+    GEMMA3_4B_REPO_ID = GEMMA3_4B_REPO_TRANSFORMERS_ID
+
 GEMMA3_4B_MODEL_FILE = get_or_create_env_var("GEMMA3_4B_MODEL_FILE", "gemma-3-4b-it-qat-Q4_K_M.gguf")
 GEMMA3_4B_MODEL_FOLDER = get_or_create_env_var("GEMMA3_4B_MODEL_FOLDER", "model/gemma3_4b")
 
 GPT_OSS_REPO_ID = get_or_create_env_var("GPT_OSS_REPO_ID", "unsloth/gpt-oss-20b-GGUF")
+GPT_OSS_REPO_TRANSFORMERS_ID = get_or_create_env_var("GPT_OSS_REPO_TRANSFORMERS_ID", "openai/gpt-oss-20b")
+if USE_LLAMA_CPP == "False":
+    GPT_OSS_REPO_ID = GPT_OSS_REPO_TRANSFORMERS_ID
+
 GPT_OSS_MODEL_FILE = get_or_create_env_var("GPT_OSS_MODEL_FILE", "gpt-oss-20b-F16.gguf")
 GPT_OSS_MODEL_FOLDER = get_or_create_env_var("GPT_OSS_MODEL_FOLDER", "model/gpt_oss")
 
@@ -305,6 +323,13 @@ SPECULATIVE_DECODING = get_or_create_env_var('SPECULATIVE_DECODING', 'False')
 NUM_PRED_TOKENS = int(get_or_create_env_var('NUM_PRED_TOKENS', '2'))
 REASONING_SUFFIX = get_or_create_env_var('REASONING_SUFFIX', 'Reasoning: low')
 
+# Transformers variables
+COMPILE_TRANSFORMERS = get_or_create_env_var('COMPILE_TRANSFORMERS', 'True') # Whether to compile transformers models
+USE_BITSANDBYTES = get_or_create_env_var('USE_BITSANDBYTES', 'True') # Whether to use bitsandbytes for quantization
+COMPILE_MODE = get_or_create_env_var('COMPILE_MODE', 'reduce-overhead') # alternatively 'max-autotune'
+MODEL_DTYPE = get_or_create_env_var('MODEL_DTYPE', 'float16') # alternatively 'bfloat16'
+OFFLOAD_TO_CPU = get_or_create_env_var('OFFLOAD_TO_CPU', 'False') # Whether to offload to CPU
+
 MAX_GROUPS = int(get_or_create_env_var('MAX_GROUPS', '99'))
 
 ###
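The config changes keep a single *_REPO_ID variable per model and swap it to the plain transformers repository when USE_LLAMA_CPP is "False", so downstream code keeps reading one name regardless of backend. A minimal sketch of the same toggle is below; get_or_create_env_var_stub is a stand-in for the repo's helper, whose implementation is not part of this diff.

# Sketch of the backend toggle pattern added to tools/config.py (stand-in helper).
import os

def get_or_create_env_var_stub(name: str, default: str) -> str:
    # The real helper also persists the value; here we only read with a fallback.
    return os.environ.get(name, default)

USE_LLAMA_CPP = get_or_create_env_var_stub("USE_LLAMA_CPP", "True")  # "True" -> GGUF via llama.cpp

# GGUF repo used by llama.cpp, and the plain transformers repo used otherwise.
GEMMA3_4B_REPO_ID = get_or_create_env_var_stub("GEMMA3_4B_REPO_ID", "unsloth/gemma-3-4b-it-qat-GGUF")
GEMMA3_4B_REPO_TRANSFORMERS_ID = get_or_create_env_var_stub("GEMMA3_4B_REPO_TRANSFORMERS_ID", "google/gemma-3-4b-it")

if USE_LLAMA_CPP == "False":
    # When the transformers backend is selected, point the shared variable at the
    # non-GGUF model so AutoModelForCausalLM can load it directly.
    GEMMA3_4B_REPO_ID = GEMMA3_4B_REPO_TRANSFORMERS_ID

print("Selected repo:", GEMMA3_4B_REPO_ID)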
tools/dedup_summaries.py
CHANGED
@@ -415,7 +415,7 @@ def sample_reference_table_summaries(reference_df:pd.DataFrame,
 
     return sampled_reference_table_df, summarised_references_markdown#, reference_df, topic_summary_df
 
-def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, model_source:str, bedrock_runtime:boto3.Session.client, local_model=list()):
+def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, model_source:str, bedrock_runtime:boto3.Session.client, local_model=list(), tokenizer=list()):
     """
     Query an LLM to generate a summary of topics based on the provided prompts.
 
@@ -428,7 +428,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
     model_source (str): Source of the model (e.g. "AWS", "Gemini", "Local")
     bedrock_runtime (boto3.Session.client): AWS Bedrock runtime client for AWS models
     local_model (object, optional): Local model object if using local inference. Defaults to empty list.
-
+    tokenizer (object, optional): Tokenizer object if using local inference. Defaults to empty list.
     Returns:
     tuple: Contains:
         - response_text (str): The generated summary text
@@ -454,7 +454,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
     whole_conversation = [summarise_topic_descriptions_system_prompt]
 
     # Process requests to large language model
-    responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, bedrock_runtime=bedrock_runtime, model_source=model_source, local_model=local_model, assistant_prefill=summary_assistant_prefill)
+    responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, bedrock_runtime=bedrock_runtime, model_source=model_source, local_model=local_model, tokenizer=tokenizer, assistant_prefill=summary_assistant_prefill)
 
     print("Finished summary query")
 
@@ -482,7 +482,9 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
     aws_secret_key_textbox:str='',
     model_name_map:dict=model_name_map,
     reasoning_suffix:str=reasoning_suffix,
-    local_model:object=list(),
+    local_model:object=list(),
+    tokenizer:object=list(),
+    hf_api_key_textbox:str='',
     summarise_topic_descriptions_prompt:str=summarise_topic_descriptions_prompt,
     summarise_topic_descriptions_system_prompt:str=summarise_topic_descriptions_system_prompt,
     do_summaries:str="Yes",
@@ -572,7 +574,7 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
 
     if (model_source == "Local") & (RUN_LOCAL_MODEL == "1"):
         progress(0.1, f"Loading in local model: {CHOSEN_LOCAL_MODEL_TYPE}")
-        local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER)
+        local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER, hf_token=hf_api_key_textbox)
 
     summary_loop_description = "Revising topic-level summaries. " + str(latest_summary_completed) + " summaries completed so far."
     summary_loop = tqdm(range(latest_summary_completed, length_all_summaries), desc="Revising topic-level summaries", unit="summaries")
@@ -592,7 +594,7 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
     if "Local" in model_source and reasoning_suffix: formatted_summarise_topic_descriptions_system_prompt = formatted_summarise_topic_descriptions_system_prompt + "\n" + reasoning_suffix
 
     try:
-        response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_topic_descriptions_system_prompt, model_source, bedrock_runtime, local_model)
+        response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_topic_descriptions_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer)
         summarised_output = response
         summarised_output = re.sub(r'\n{2,}', '\n', summarised_output) # Replace multiple line breaks with a single line break
         summarised_output = re.sub(r'^\n{1,}', '', summarised_output) # Remove one or more line breaks at the start
@@ -697,7 +699,9 @@ def overall_summary(topic_summary_df:pd.DataFrame,
     aws_secret_key_textbox:str='',
     model_name_map:dict=model_name_map,
     reasoning_suffix:str=reasoning_suffix,
-    local_model:object=list(),
+    local_model:object=list(),
+    tokenizer:object=list(),
+    hf_api_key_textbox:str='',
     summarise_everything_prompt:str=summarise_everything_prompt,
     comprehensive_summary_format_prompt:str=comprehensive_summary_format_prompt,
     comprehensive_summary_format_prompt_by_group:str=comprehensive_summary_format_prompt_by_group,
@@ -721,6 +725,8 @@ def overall_summary(topic_summary_df:pd.DataFrame,
     model_name_map (dict, optional): Mapping of model names. Defaults to model_name_map.
     reasoning_suffix (str, optional): Suffix for reasoning. Defaults to reasoning_suffix.
     local_model (object, optional): Local model object. Defaults to empty list.
+    tokenizer (object, optional): Tokenizer object. Defaults to empty list.
+    hf_api_key_textbox (str, optional): Hugging Face API key. Defaults to empty string.
     summarise_everything_prompt (str, optional): Prompt for overall summary
     comprehensive_summary_format_prompt (str, optional): Prompt for comprehensive summary format
     comprehensive_summary_format_prompt_by_group (str, optional): Prompt for group summary format
@@ -795,7 +801,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,
 
     if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1"):
         progress(0.1, f"Loading in local model: {CHOSEN_LOCAL_MODEL_TYPE}")
-        local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER)
+        local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER, hf_token=hf_api_key_textbox)
         #print("Local model loaded:", local_model)
 
     summary_loop = tqdm(unique_groups, desc="Creating overall summary for groups", unit="groups")
@@ -806,7 +812,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,
 
     for summary_group in summary_loop:
 
-        print("Creating
+        print("Creating overall summary for group:", summary_group)
 
         summary_text = topic_summary_df.loc[topic_summary_df["Group"]==summary_group].to_markdown(index=False)
 
@@ -817,7 +823,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,
     if "Local" in model_source and reasoning_suffix: formatted_summarise_everything_system_prompt = formatted_summarise_everything_system_prompt + "\n" + reasoning_suffix
 
     try:
-        response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_everything_system_prompt, model_source, bedrock_runtime, local_model)
+        response, conversation_history, metadata = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_everything_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer)
        summarised_output_for_df = response
        summarised_output = response
        summarised_output = re.sub(r'\n{2,}', '\n', summarised_output) # Replace multiple line breaks with a single line break
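The changes in this file mostly thread two new values through the summarisation call chain: the tokenizer returned by load_model, and the Hugging Face key entered in the UI (hf_api_key_textbox), which is forwarded to load_model as hf_token. A simplified sketch of that plumbing pattern is below; the stub names are illustrative and not the repo's API.

# Sketch only: a loader that returns (model, tokenizer) and a query function that accepts
# the tokenizer, mirroring how dedup_summaries.py now passes tokenizer=tokenizer along.
from typing import Any, Tuple

def load_model_stub(use_llama_cpp: bool, hf_token: str = "") -> Tuple[Any, Any]:
    # llama.cpp path: tokenizer stays an empty list; transformers path: a real tokenizer object.
    if use_llama_cpp:
        return object(), []
    return object(), object()

def summarise_query_stub(prompt: str, local_model: Any, tokenizer: Any = None) -> str:
    # The transformers backend needs a real tokenizer; llama.cpp passes an empty list through.
    uses_transformers = bool(tokenizer) and not isinstance(tokenizer, list)
    backend = "transformers" if uses_transformers else "llama.cpp"
    return f"[{backend}] summary of: {prompt[:40]}..."

model, tokenizer = load_model_stub(use_llama_cpp=False, hf_token="hf_xxx")  # hypothetical token
print(summarise_query_stub("Summarise the main topics raised by respondents", model, tokenizer=tokenizer))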
tools/llm_api_call.py
CHANGED
@@ -688,6 +688,7 @@ def extract_topics(in_data_file: GradioFileData,
     produce_structures_summary_radio:str="No",
     aws_access_key_textbox:str='',
     aws_secret_key_textbox:str='',
+    hf_api_key_textbox:str='',
     max_tokens:int=max_tokens,
     model_name_map:dict=model_name_map,
     max_time_for_loop:int=max_time_for_loop,
@@ -737,6 +738,7 @@ def extract_topics(in_data_file: GradioFileData,
     - force_single_topic_prompt (str, optional): The prompt for forcing the model to assign only one single topic to each response.
     - aws_access_key_textbox (str, optional): AWS access key for account with Bedrock permissions.
     - aws_secret_key_textbox (str, optional): AWS secret key for account with Bedrock permissions.
+    - hf_api_key_textbox (str, optional): Hugging Face API key for account with Hugging Face permissions.
     - max_tokens (int): The maximum number of tokens for the model.
     - model_name_map (dict, optional): A dictionary mapping full model name to shortened.
     - max_time_for_loop (int, optional): The number of seconds maximum that the function should run for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there).
@@ -808,7 +810,7 @@ def extract_topics(in_data_file: GradioFileData,
 
     if (model_source == "Local") & (RUN_LOCAL_MODEL == "1"):
         progress(0.1, f"Loading in local model: {CHOSEN_LOCAL_MODEL_TYPE}")
-        local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER)
+        local_model, tokenizer = load_model(local_model_type=CHOSEN_LOCAL_MODEL_TYPE, repo_id=LOCAL_REPO_ID, model_filename=LOCAL_MODEL_FILE, model_dir=LOCAL_MODEL_FOLDER, hf_token=hf_api_key_textbox)
 
     if num_batches > 0:
         progress_measure = round(latest_batch_completed / num_batches, 1)
@@ -938,9 +940,9 @@ def extract_topics(in_data_file: GradioFileData,
     formatted_summary_prompt = structured_summary_prompt.format(response_table=normalised_simple_markdown_table,
         topics=unique_topics_markdown)
 
-    if "
-        formatted_summary_prompt = llama_cpp_prefix + formatted_system_prompt + "\n" + formatted_summary_prompt + llama_cpp_suffix
-        full_prompt = formatted_summary_prompt
+    if model_source == "Local":
+        #formatted_summary_prompt = llama_cpp_prefix + formatted_system_prompt + "\n" + formatted_summary_prompt + llama_cpp_suffix
+        full_prompt = formatted_system_prompt + "\n" + formatted_summary_prompt
     else:
         full_prompt = formatted_system_prompt + "\n" + formatted_summary_prompt
 
@@ -970,7 +972,7 @@ def extract_topics(in_data_file: GradioFileData,
     whole_conversation = list()
 
     # Process requests to large language model
-    responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill,
+    responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, tokenizer, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)
 
     # Return output tables
     topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, new_topic_df, new_reference_df, new_topic_summary_df, master_batch_out_file_part, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structures_summary_radio, first_run=False, output_folder=output_folder)
@@ -1030,7 +1032,7 @@ def extract_topics(in_data_file: GradioFileData,
     formatted_initial_table_system_prompt = initial_table_system_prompt.format(consultation_context=context_textbox, column_name=chosen_cols)
 
     # Prepare Gemini models before query
-    if "
+    if model_source == "Gemini":
         print("Using Gemini model:", model_choice)
         google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_initial_table_system_prompt, max_tokens=max_tokens)
     elif model_choice == CHOSEN_LOCAL_MODEL_TYPE:
@@ -1062,7 +1064,7 @@ def extract_topics(in_data_file: GradioFileData,
 
     whole_conversation = list()
 
-    responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_initial_table_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
+    responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_initial_table_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, tokenizer, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
 
     topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, topic_table_df, reference_df, new_topic_summary_df, batch_file_path_details, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structures_summary_radio, first_run=True, output_folder=output_folder)
 
@@ -1271,6 +1273,7 @@ def wrapper_extract_topics_per_column_value(
     produce_structures_summary_radio: str = "No",
     aws_access_key_textbox:str="",
     aws_secret_key_textbox:str="",
+    hf_api_key_textbox:str="",
     output_folder: str = OUTPUT_FOLDER,
     force_single_topic_prompt: str = force_single_topic_prompt,
     max_tokens: int = max_tokens,
@@ -1419,6 +1422,7 @@ def wrapper_extract_topics_per_column_value(
     produce_structures_summary_radio=produce_structures_summary_radio,
     aws_access_key_textbox=aws_access_key_textbox,
     aws_secret_key_textbox=aws_secret_key_textbox,
+    hf_api_key_textbox=hf_api_key_textbox,
     max_tokens=max_tokens,
     model_name_map=model_name_map,
     max_time_for_loop=max_time_for_loop,
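For local models, extract_topics no longer wraps the prompt in llama.cpp-specific prefix and suffix strings; both branches now concatenate the system prompt and the user prompt and leave any chat formatting to whichever backend runs the model. A tiny sketch of that change in isolation is below (the function name is illustrative, not the repo's API).

# Sketch of the prompt-assembly change: the llama.cpp wrapper is dropped, so the "Local"
# branch builds the same plain string as the hosted-model branch.
def build_full_prompt(model_source: str, formatted_system_prompt: str, formatted_summary_prompt: str) -> str:
    if model_source == "Local":
        # Previously: llama_cpp_prefix + system + "\n" + prompt + llama_cpp_suffix
        return formatted_system_prompt + "\n" + formatted_summary_prompt
    else:
        return formatted_system_prompt + "\n" + formatted_summary_prompt

print(build_full_prompt("Local", "You are a careful analyst.", "Summarise the responses in a markdown table."))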
tools/llm_funcs.py
CHANGED
@@ -18,7 +18,7 @@ full_text = "" # Define dummy source text (full text) just to enable highlight f
|
|
18 |
model = list() # Define empty list for model functions to run
|
19 |
tokenizer = list() #[] # Define empty list for model functions to run
|
20 |
|
21 |
-
from tools.config import AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_MIN_P, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS, SPECULATIVE_DECODING, NUM_PRED_TOKENS
|
22 |
from tools.prompts import initial_table_assistant_prefill
|
23 |
|
24 |
if SPECULATIVE_DECODING == "True": SPECULATIVE_DECODING = True
|
@@ -192,7 +192,10 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE,
|
|
192 |
torch_device:str=torch_device,
|
193 |
repo_id=LOCAL_REPO_ID,
|
194 |
model_filename=LOCAL_MODEL_FILE,
|
195 |
-
model_dir=LOCAL_MODEL_FOLDER
|
|
|
|
|
|
|
196 |
'''
|
197 |
Load in a model from Hugging Face hub via the transformers package, or using llama_cpp_python by downloading a GGUF file from Huggingface Hub.
|
198 |
|
@@ -206,22 +209,22 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE,
|
|
206 |
repo_id (str): The Hugging Face repository ID where the model is located.
|
207 |
model_filename (str): The specific filename of the model to download from the repository.
|
208 |
model_dir (str): The local directory where the model will be stored or downloaded.
|
209 |
-
|
|
|
|
|
210 |
Returns:
|
211 |
tuple: A tuple containing:
|
212 |
-
-
|
213 |
-
- tokenizer (list): An empty list (tokenizer is not used with Llama.cpp directly in this setup).
|
214 |
'''
|
215 |
print("Loading model ", local_model_type)
|
216 |
-
|
217 |
|
218 |
#print("model_path:", model_path)
|
219 |
|
220 |
# Verify the device and cuda settings
|
221 |
# Check if CUDA is enabled
|
222 |
-
import torch
|
223 |
-
from llama_cpp import Llama
|
224 |
-
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
|
225 |
|
226 |
torch.cuda.empty_cache()
|
227 |
print("Is CUDA enabled? ", torch.cuda.is_available())
|
@@ -252,41 +255,132 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE,
|
|
252 |
gpu_config.update_gpu(gpu_layers)
|
253 |
gpu_config.update_context(max_context_length)
|
254 |
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
|
267 |
print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", gpu_config.n_ctx)
|
268 |
|
269 |
# CPU mode
|
270 |
else:
|
271 |
-
|
|
|
|
|
|
|
|
|
|
|
272 |
cpu_config.update_gpu(gpu_layers)
|
273 |
|
274 |
# Update context length according to slider
|
275 |
-
gpu_config.update_context(max_context_length)
|
276 |
cpu_config.update_context(max_context_length)
|
277 |
|
278 |
if speculative_decoding:
|
279 |
-
|
280 |
else:
|
281 |
-
|
282 |
|
283 |
-
print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of",
|
|
|
284 |
|
285 |
-
tokenizer = list()
|
286 |
|
287 |
print("Finished loading model:", local_model_type)
|
288 |
print("GPU layers assigned to cuda:", gpu_layers)
|
289 |
-
return
|
290 |
|
291 |
def call_llama_cpp_model(formatted_string:str, gen_config:str, model=model):
|
292 |
"""
|
@@ -506,8 +600,83 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
|
|
506 |
|
507 |
return response
|
508 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
509 |
# Function to send a request and update history
|
510 |
-
def send_request(prompt: str, conversation_history: List[dict], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model= list(), assistant_prefill = "", progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
|
511 |
"""
|
512 |
This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
|
513 |
It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
|
@@ -516,6 +685,8 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
|
|
516 |
"""
|
517 |
# Constructing the full prompt from the conversation history
|
518 |
full_prompt = "Conversation history:\n"
|
|
|
|
|
519 |
|
520 |
for entry in conversation_history:
|
521 |
role = entry['role'].capitalize() # Assuming the history is stored with 'role' and 'parts'
|
@@ -573,13 +744,18 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
|
|
573 |
gen_config = LlamaCPPGenerationConfig()
|
574 |
gen_config.update_temp(temperature)
|
575 |
|
576 |
-
|
|
|
|
|
|
|
|
|
|
|
577 |
|
578 |
#print("Successful call to local model.")
|
579 |
break
|
580 |
except Exception as e:
|
581 |
# If fails, try again after X seconds in case there is a throttle limit
|
582 |
-
print("Call to
|
583 |
|
584 |
time.sleep(timeout_wait)
|
585 |
|
@@ -596,21 +772,24 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
|
|
596 |
if isinstance(response, ResponseObject):
|
597 |
response_text = response.text
|
598 |
conversation_history.append({'role': 'assistant', 'parts': [response_text]})
|
599 |
-
elif 'choices' in response:
|
600 |
if "gpt-oss" in model_choice:
|
601 |
response_text = response['choices'][0]['message']['content'].split('<|start|>assistant<|channel|>final<|message|>')[1]
|
602 |
else:
|
603 |
response_text = response['choices'][0]['message']['content']
|
604 |
response_text = response_text.strip()
|
605 |
conversation_history.append({'role': 'assistant', 'parts': [response_text]}) #response['choices'][0]['text']]})
|
606 |
-
|
607 |
response_text = response.text
|
608 |
response_text = response_text.strip()
|
609 |
conversation_history.append({'role': 'assistant', 'parts': [response_text]})
|
|
|
|
|
|
|
610 |
|
611 |
-
return response, conversation_history, response_text
|
612 |
|
613 |
-
def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model = list(), master:bool = False, assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
|
614 |
"""
|
615 |
Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
|
616 |
|
@@ -641,7 +820,7 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
|
|
641 |
|
642 |
for prompt in prompts:
|
643 |
|
644 |
-
response, conversation_history, response_text = send_request(prompt, conversation_history, google_client=google_client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)
|
645 |
|
646 |
responses.append(response)
|
647 |
whole_conversation.append(system_prompt)
|
@@ -677,9 +856,16 @@ def process_requests(prompts: List[str], system_prompt: str, conversation_histor
|
|
677 |
whole_conversation_metadata.append(str(response.usage_metadata))
|
678 |
|
679 |
elif "Local" in model_source:
|
680 |
-
|
681 |
-
|
682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
683 |
except KeyError as e:
|
684 |
print(f"Key error: {e} - Check the structure of response.usage_metadata")
|
685 |
else:
|
@@ -699,6 +885,7 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
|
|
699 |
temperature: float,
|
700 |
reported_batch_no: int,
|
701 |
local_model: object,
|
|
|
702 |
bedrock_runtime:boto3.Session.client,
|
703 |
model_source:str,
|
704 |
MAX_OUTPUT_VALIDATION_ATTEMPTS: int,
|
@@ -721,6 +908,7 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
|
|
721 |
- temperature (float): The temperature parameter for the model.
|
722 |
- reported_batch_no (int): The reported batch number.
|
723 |
- local_model (object): The local model to use.
|
|
|
724 |
- bedrock_runtime (boto3.Session.client): The client object for boto3 Bedrock runtime.
|
725 |
- model_source (str): The source of the model, whether in AWS, Gemini, or local.
|
726 |
- MAX_OUTPUT_VALIDATION_ATTEMPTS (int): The maximum number of attempts to validate the output.
|
@@ -743,7 +931,7 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
|
|
743 |
responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(
|
744 |
batch_prompts, system_prompt, conversation_history, whole_conversation,
|
745 |
whole_conversation_metadata, google_client, google_config, model_choice,
|
746 |
-
call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, master=master, assistant_prefill=assistant_prefill
|
747 |
)
|
748 |
|
749 |
stripped_response = response_text.strip()
|
|
|
18 |
model = list() # Define empty list for model functions to run
|
19 |
tokenizer = list() #[] # Define empty list for model functions to run
|
20 |
|
21 |
+
from tools.config import AWS_REGION, LLM_TEMPERATURE, LLM_TOP_K, LLM_MIN_P, LLM_TOP_P, LLM_REPETITION_PENALTY, LLM_LAST_N_TOKENS, LLM_MAX_NEW_TOKENS, LLM_SEED, LLM_RESET, LLM_STREAM, LLM_THREADS, LLM_BATCH_SIZE, LLM_CONTEXT_LENGTH, LLM_SAMPLE, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, MAX_COMMENT_CHARS, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, HF_TOKEN, LLM_SEED, LLM_MAX_GPU_LAYERS, SPECULATIVE_DECODING, NUM_PRED_TOKENS, USE_LLAMA_CPP, COMPILE_MODE, MODEL_DTYPE, USE_BITSANDBYTES, COMPILE_TRANSFORMERS, OFFLOAD_TO_CPU
|
22 |
from tools.prompts import initial_table_assistant_prefill
|
23 |
|
24 |
if SPECULATIVE_DECODING == "True": SPECULATIVE_DECODING = True
|
|
|
192         torch_device:str=torch_device,
193         repo_id=LOCAL_REPO_ID,
194         model_filename=LOCAL_MODEL_FILE,
195 +       model_dir=LOCAL_MODEL_FOLDER,
196 +       compile_mode=COMPILE_MODE,
197 +       model_dtype=MODEL_DTYPE,
198 +       hf_token=HF_TOKEN):
199     '''
200     Load in a model from Hugging Face hub via the transformers package, or using llama_cpp_python by downloading a GGUF file from Huggingface Hub.
201

209     repo_id (str): The Hugging Face repository ID where the model is located.
210     model_filename (str): The specific filename of the model to download from the repository.
211     model_dir (str): The local directory where the model will be stored or downloaded.
212 +   compile_mode (str): The compilation mode to use for the model.
213 +   model_dtype (str): The data type to use for the model.
214 +   hf_token (str): The Hugging Face token to use for the model.
215     Returns:
216     tuple: A tuple containing:
217 +       - model (Llama/transformers model): The loaded Llama.cpp/transformers model instance.
218 +       - tokenizer (list/transformers tokenizer): An empty list (tokenizer is not used with Llama.cpp directly in this setup), or a transformers tokenizer.
219     '''
220     print("Loading model ", local_model_type)
221 +   tokenizer = list()
222
223     #print("model_path:", model_path)
224
225     # Verify the device and cuda settings
226     # Check if CUDA is enabled
227 +   import torch
228
229     torch.cuda.empty_cache()
230     print("Is CUDA enabled? ", torch.cuda.is_available())
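Further down, get_model_path() resolves repo_id, model_filename and model_dir into a local file before llama-cpp-python loads it. A rough sketch of how such a resolver is commonly written with huggingface_hub (an assumed dependency for GGUF downloads); the actual implementation in tools/llm_funcs.py may differ:

import os
from huggingface_hub import hf_hub_download

def get_model_path_sketch(repo_id: str, model_filename: str, model_dir: str) -> str:
    # Reuse a previously downloaded file if it is already in the model directory
    local_path = os.path.join(model_dir, model_filename)
    if os.path.exists(local_path):
        return local_path
    # Otherwise download the single GGUF file from the Hugging Face Hub into model_dir
    return hf_hub_download(repo_id=repo_id, filename=model_filename, local_dir=model_dir)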
255         gpu_config.update_gpu(gpu_layers)
256         gpu_config.update_context(max_context_length)
257
258 +       if USE_LLAMA_CPP == "True":
259 +           from llama_cpp import Llama
260 +           from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
261 +
262 +           model_path = get_model_path(repo_id=repo_id, model_filename=model_filename, model_dir=model_dir)
263 +
264 +           try:
265 +               print("GPU load variables:" , vars(gpu_config))
266 +               if speculative_decoding:
267 +                   model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS), **vars(gpu_config))
268 +               else:
269 +                   model = Llama(model_path=model_path, type_k=8, type_v=8, flash_attn=True, **vars(gpu_config))
270
271 +           except Exception as e:
272 +               print("GPU load failed due to:", e, "Loading model in CPU mode")
273 +               # If fails, go to CPU mode
274 +               model = Llama(model_path=model_path, **vars(cpu_config))
275 +
276 +       else:
277 +           from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
278 +
279 +           print("Loading model from transformers")
280 +           # Use the official model ID for Gemma 3 4B
281 +           model_id = repo_id
282 +           # 1. Set Data Type (dtype)
283 +           # For H200/Hopper: 'bfloat16'
284 +           # For RTX 3060/Ampere: 'float16'
285 +           dtype_str = model_dtype #os.environ.get("MODEL_DTYPE", "bfloat16").lower()
286 +           if dtype_str == "bfloat16":
287 +               torch_dtype = torch.bfloat16
288 +           elif dtype_str == "float16":
289 +               torch_dtype = torch.float16
290 +           else:
291 +               torch_dtype = torch.float32 # A safe fallback
292 +
293 +           # 2. Set Compilation Mode
294 +           # 'max-autotune' is great for both but can be slow initially.
295 +           # 'reduce-overhead' is a faster alternative for compiling.
296 +
297 +           print(f"--- System Configuration ---")
298 +           print(f"Using model id: {model_id}")
299 +           print(f"Using dtype: {torch_dtype}")
300 +           print(f"Using compile mode: {compile_mode}")
301 +           print(f"Using bitsandbytes: {USE_BITSANDBYTES}")
302 +           print("--------------------------\n")
303 +
304 +           # --- Load Tokenizer and Model ---
305 +
306 +           # Load Tokenizer and Model
307 +           tokenizer = AutoTokenizer.from_pretrained(model_id)
308 +
309 +           if not tokenizer.pad_token:
310 +               tokenizer.pad_token = tokenizer.eos_token
311 +
312 +           if USE_BITSANDBYTES == "True":
313 +
314 +               if OFFLOAD_TO_CPU == "True":
315 +                   # This will be very slow. Requires at least 4GB of VRAM and 32GB of RAM
316 +                   print("Using bitsandbytes for quantisation to 8 bits, with offloading to CPU")
317 +                   max_memory={0: "4GB", "cpu": "32GB"}
318 +                   quantization_config = BitsAndBytesConfig(
319 +                       load_in_8bit=True,
320 +                       max_memory=max_memory,
321 +                       llm_int8_enable_fp32_cpu_offload=True # Note: if bitsandbytes has to offload to CPU, inference will be slow
322 +                   )
323 +               else:
324 +                   # For Gemma 4B, requires at least 6GB of VRAM
325 +                   print("Using bitsandbytes for quantisation to 4 bits")
326 +                   quantization_config = BitsAndBytesConfig(
327 +                       load_in_4bit=True,
328 +                       bnb_4bit_quant_type="nf4", # Use the modern NF4 quantisation for better performance
329 +                       bnb_4bit_compute_dtype=torch_dtype,
330 +                       bnb_4bit_use_double_quant=True, # Optional: uses a second quantisation step to save even more memory
331 +                   )
332 +
333 +               model = AutoModelForCausalLM.from_pretrained(
334 +                   model_id,
335 +                   torch_dtype=torch_dtype,
336 +                   device_map="auto",
337 +                   quantization_config=quantization_config,
338 +                   token=hf_token
339 +               )
340 +           else:
341 +               print("Using fp16 precision for model")
342 +               model = AutoModelForCausalLM.from_pretrained(
343 +                   model_id,
344 +                   torch_dtype=torch_dtype,
345 +                   device_map="auto",
346 +                   token=hf_token
347 +               )
348 +
349 +           # Compile the Model with the selected mode 🚀
350 +           if COMPILE_TRANSFORMERS == "True":
351 +               try:
352 +                   model = torch.compile(model, mode=compile_mode, fullgraph=True)
353 +               except Exception as e:
354 +                   print(f"Could not compile model: {e}. Running in eager mode.")
355
356         print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", gpu_config.n_ctx)
357
358     # CPU mode
359     else:
360 +       if USE_LLAMA_CPP == "False":
361 +           raise Warning("Using transformers model in CPU mode is not supported. Please change your config variable USE_LLAMA_CPP to True if you want to do CPU inference.")
362 +
363 +       model_path = get_model_path(repo_id=repo_id, model_filename=model_filename, model_dir=model_dir)
364 +
365 +       #gpu_config.update_gpu(gpu_layers)
366         cpu_config.update_gpu(gpu_layers)
367
368         # Update context length according to slider
369 +       #gpu_config.update_context(max_context_length)
370         cpu_config.update_context(max_context_length)
371
372         if speculative_decoding:
373 +           model = Llama(model_path=model_path, draft_model=LlamaPromptLookupDecoding(num_pred_tokens=NUM_PRED_TOKENS), **vars(cpu_config))
374         else:
375 +           model = Llama(model_path=model_path, **vars(cpu_config))
376
377 +       print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU and a maximum context length of", cpu_config.n_ctx)
378 +
379

380
381     print("Finished loading model:", local_model_type)
382     print("GPU layers assigned to cuda:", gpu_layers)
383 +   return model, tokenizer
384
385 def call_llama_cpp_model(formatted_string:str, gen_config:str, model=model):
386     """
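Both Llama() calls in the hunk above expand a config object with **vars(...), so every attribute of gpu_config or cpu_config becomes a keyword argument to llama-cpp-python. A small sketch of that pattern; the class, its attribute defaults and the update methods here are hypothetical stand-ins, chosen only so the attribute names match real llama_cpp.Llama keyword arguments (n_gpu_layers, n_ctx, seed):

class LlamaConfigSketch:
    # Every attribute name must match a llama_cpp.Llama keyword argument
    def __init__(self, n_gpu_layers: int = -1, n_ctx: int = 4096, seed: int = 42):
        self.n_gpu_layers = n_gpu_layers
        self.n_ctx = n_ctx
        self.seed = seed

    def update_gpu(self, n_gpu_layers: int):
        self.n_gpu_layers = n_gpu_layers

    def update_context(self, n_ctx: int):
        self.n_ctx = n_ctx

gpu_config = LlamaConfigSketch()
gpu_config.update_gpu(-1)          # -1 offloads all layers to the GPU in llama-cpp-python
gpu_config.update_context(8192)
print(vars(gpu_config))            # {'n_gpu_layers': -1, 'n_ctx': 8192, 'seed': 42}
# model = Llama(model_path=model_path, **vars(gpu_config))  # equivalent to passing each field by name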
600
601     return response
602
603 + def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCPPGenerationConfig, model=model, tokenizer=tokenizer):
604 +     """
605 +     This function sends a request to a transformers model with the given prompt, system prompt, and generation configuration.
606 +     """
607 +     # 1. Define the conversation as a list of dictionaries
608 +     conversation = [
609 +         {"role": "system", "content": system_prompt},
610 +         {"role": "user", "content": prompt}
611 +     ]
612 +
613 +     # 2. Apply the chat template
614 +     # This function formats the conversation into the exact string Gemma 3 expects.
615 +     # add_generation_prompt=True adds the special tokens that tell the model it's its turn to speak.
616 +     input_ids = tokenizer.apply_chat_template(
617 +         conversation,
618 +         add_generation_prompt=True,
619 +         return_tensors="pt"
620 +     ).to("cuda")
621 +
622 +     # Warm-up run
623 +     print("Performing warm-up run...")
624 +     _ = model.generate(input_ids, max_new_tokens=50)
625 +     print("Warm-up complete.")
626 +
627 +     # Map LlamaCPP parameters to transformers parameters
628 +     generation_kwargs = {
629 +         'max_new_tokens': gen_config.max_tokens,
630 +         'temperature': gen_config.temperature,
631 +         'top_p': gen_config.top_p,
632 +         'top_k': gen_config.top_k,
633 +         'do_sample': True,
634 +         'pad_token_id': tokenizer.eos_token_id
635 +     }
636 +
637 +     # Remove parameters that don't exist in transformers
638 +     if hasattr(gen_config, 'repeat_penalty'):
639 +         generation_kwargs['repetition_penalty'] = gen_config.repeat_penalty
640 +
641 +     # --- Timed Inference Test ---
642 +     print("\nStarting timed inference test...")
643 +     start_time = time.time()
644 +
645 +
646 +
647 +     outputs = model.generate(
648 +         input_ids,
649 +         **generation_kwargs
650 +     )
651 +
652 +     end_time = time.time()
653 +
654 +     # --- Decode and Display Results ---
655 +     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
656 +     # To get only the model's reply, we can decode just the newly generated tokens
657 +     new_tokens = outputs[0][input_ids.shape[-1]:]
658 +     assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
659 +
660 +     num_input_tokens = len(input_ids)
661 +     num_generated_tokens = len(new_tokens)
662 +     duration = end_time - start_time
663 +     tokens_per_second = num_generated_tokens / duration
664 +
665 +     print("\n--- Inference Results ---")
666 +     print(f"System Prompt: {conversation[0]['content']}")
667 +     print(f"User Prompt: {conversation[1]['content']}")
668 +     print("---")
669 +     print(f"Assistant's Reply: {assistant_reply}")
670 +     print("\n--- Performance ---")
671 +     print(f"Time taken: {duration:.2f} seconds")
672 +     print(f"Generated tokens: {num_generated_tokens}")
673 +     print(f"Tokens per second: {tokens_per_second:.2f}")
674 +
675 +     return assistant_reply, num_input_tokens, num_generated_tokens
676 +
677 +
678   # Function to send a request and update history
679 + def send_request(prompt: str, conversation_history: List[dict], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model= list(), tokenizer=tokenizer, assistant_prefill = "", progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
680       """
681       This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
682       It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.

685       """
686       # Constructing the full prompt from the conversation history
687       full_prompt = "Conversation history:\n"
688 +     num_transformer_input_tokens = 0
689 +     num_transformer_generated_tokens = 0
690
691       for entry in conversation_history:
692           role = entry['role'].capitalize() # Assuming the history is stored with 'role' and 'parts'
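call_transformers_model() above builds its generation_kwargs dictionary by hand from the llama.cpp-style gen_config. An equivalent, slightly more structured sketch using transformers.GenerationConfig; the attribute names read from gen_config (max_tokens, temperature, top_p, top_k, repeat_penalty) are taken from the mapping in the hunk above, and the commented generate() call is only a usage hint:

from transformers import GenerationConfig

def to_transformers_generation_config(gen_config, eos_token_id: int) -> GenerationConfig:
    # Translate llama.cpp-style sampling names to their transformers equivalents
    return GenerationConfig(
        max_new_tokens=gen_config.max_tokens,
        temperature=gen_config.temperature,
        top_p=gen_config.top_p,
        top_k=gen_config.top_k,
        repetition_penalty=getattr(gen_config, "repeat_penalty", 1.0),
        do_sample=True,
        pad_token_id=eos_token_id,
    )

# outputs = model.generate(input_ids, generation_config=to_transformers_generation_config(gen_config, tokenizer.eos_token_id))

Note that the reply is sliced out with input_ids.shape[-1], the length of the prompt along the sequence dimension, which is the same quantity a tokens-per-second figure should be based on.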
744             gen_config = LlamaCPPGenerationConfig()
745             gen_config.update_temp(temperature)
746
747 +           if USE_LLAMA_CPP == "True":
748 +               response = call_llama_cpp_chatmodel(prompt, system_prompt, gen_config, model=local_model)
749 +
750 +           else:
751 +               response, num_transformer_input_tokens, num_transformer_generated_tokens = call_transformers_model(prompt, system_prompt, gen_config, model=local_model, tokenizer=tokenizer)
752 +               response_text = response
753
754             #print("Successful call to local model.")
755             break
756         except Exception as e:
757             # If fails, try again after X seconds in case there is a throttle limit
758 +           print("Call to local model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
759
760             time.sleep(timeout_wait)
761
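The local-model branch above sits inside a retry loop that simply sleeps for timeout_wait seconds after a failure and tries again. A generic sketch of that wait-and-retry pattern; the attempt cap, wait time and the lambda in the usage comment are illustrative, not the app's exact values:

import time

def call_with_retries(call, max_attempts: int = 3, timeout_wait: float = 5.0):
    # Try the call up to max_attempts times, sleeping between failures
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            return call()
        except Exception as e:
            last_error = e
            print(f"Attempt {attempt} failed: {e}. Waiting {timeout_wait} seconds before retrying.")
            time.sleep(timeout_wait)
    raise RuntimeError(f"All {max_attempts} attempts failed") from last_error

# response = call_with_retries(lambda: call_llama_cpp_chatmodel(prompt, system_prompt, gen_config, model=local_model))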
772     if isinstance(response, ResponseObject):
773         response_text = response.text
774         conversation_history.append({'role': 'assistant', 'parts': [response_text]})
775 +   elif 'choices' in response: # LLama.cpp model response
776         if "gpt-oss" in model_choice:
777             response_text = response['choices'][0]['message']['content'].split('<|start|>assistant<|channel|>final<|message|>')[1]
778         else:
779             response_text = response['choices'][0]['message']['content']
780         response_text = response_text.strip()
781         conversation_history.append({'role': 'assistant', 'parts': [response_text]}) #response['choices'][0]['text']]})
782 +   elif model_source == "Gemini":
783         response_text = response.text
784         response_text = response_text.strip()
785         conversation_history.append({'role': 'assistant', 'parts': [response_text]})
786 +   else: # Assume transformers model response
787 +       response_text = response
788 +       conversation_history.append({'role': 'assistant', 'parts': [response_text]})
789
790 +   return response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens
791
792 + def process_requests(prompts: List[str], system_prompt: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, batch_no:int = 1, local_model = list(), tokenizer=tokenizer, master:bool = False, assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
793       """
794       Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
795
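The gpt-oss branch above splits the llama.cpp reply on the '<|start|>assistant<|channel|>final<|message|>' marker and indexes element [1], which raises IndexError if the marker is ever missing. A defensive variant of that extraction, purely illustrative and not the repository's code:

FINAL_CHANNEL_MARKER = "<|start|>assistant<|channel|>final<|message|>"

def extract_final_channel(content: str) -> str:
    # gpt-oss style outputs wrap the user-facing answer in a "final" channel;
    # fall back to the whole string if the marker is not present
    if FINAL_CHANNEL_MARKER in content:
        return content.split(FINAL_CHANNEL_MARKER, 1)[1].strip()
    return content.strip()

print(extract_final_channel("analysis...<|start|>assistant<|channel|>final<|message|>The answer is 42"))
# -> "The answer is 42"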
820
821     for prompt in prompts:
822
823 +       response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens = send_request(prompt, conversation_history, google_client=google_client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, tokenizer=tokenizer, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)
824
825         responses.append(response)
826         whole_conversation.append(system_prompt)

856             whole_conversation_metadata.append(str(response.usage_metadata))
857
858         elif "Local" in model_source:
859 +           if USE_LLAMA_CPP == "True":
860 +               output_tokens = response['usage'].get('completion_tokens', 0)
861 +               input_tokens = response['usage'].get('prompt_tokens', 0)
862 +               whole_conversation_metadata.append(str(response['usage']))
863 +
864 +           if USE_LLAMA_CPP == "False":
865 +               input_tokens = num_transformer_input_tokens
866 +               output_tokens = num_transformer_generated_tokens
867 +               whole_conversation_metadata.append('inputTokens: ' + str(input_tokens) + ' outputTokens: ' + str(output_tokens))
868 +
869     except KeyError as e:
870         print(f"Key error: {e} - Check the structure of response.usage_metadata")
871     else:
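Token usage therefore lands in whole_conversation_metadata in two shapes: the llama.cpp usage dict (prompt_tokens / completion_tokens) rendered with str(), or the 'inputTokens: X outputTokens: Y' string built above for the transformers path. A rough sketch of totalling both shapes after the fact; the parsing is illustrative rather than the app's own accounting:

import re

def total_tokens(whole_conversation_metadata):
    # Sum input/output token counts across mixed metadata formats
    input_total, output_total = 0, 0
    for entry in whole_conversation_metadata:
        in_match = re.search(r"(?:inputTokens|'prompt_tokens'):\s*(\d+)", entry)
        out_match = re.search(r"(?:outputTokens|'completion_tokens'):\s*(\d+)", entry)
        input_total += int(in_match.group(1)) if in_match else 0
        output_total += int(out_match.group(1)) if out_match else 0
    return input_total, output_total

print(total_tokens(["inputTokens: 120 outputTokens: 45",
                    "{'prompt_tokens': 80, 'completion_tokens': 30}"]))
# -> (200, 75)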
885         temperature: float,
886         reported_batch_no: int,
887         local_model: object,
888 +       tokenizer:object,
889         bedrock_runtime:boto3.Session.client,
890         model_source:str,
891         MAX_OUTPUT_VALIDATION_ATTEMPTS: int,

908     - temperature (float): The temperature parameter for the model.
909     - reported_batch_no (int): The reported batch number.
910     - local_model (object): The local model to use.
911 +   - tokenizer (object): The tokenizer to use.
912     - bedrock_runtime (boto3.Session.client): The client object for boto3 Bedrock runtime.
913     - model_source (str): The source of the model, whether in AWS, Gemini, or local.
914     - MAX_OUTPUT_VALIDATION_ATTEMPTS (int): The maximum number of attempts to validate the output.

931     responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(
932         batch_prompts, system_prompt, conversation_history, whole_conversation,
933         whole_conversation_metadata, google_client, google_config, model_choice,
934 +       call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, tokenizer=tokenizer, master=master, assistant_prefill=assistant_prefill
935     )
936
937     stripped_response = response_text.strip()
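call_llm_with_markdown_table_checks() retries the request up to MAX_OUTPUT_VALIDATION_ATTEMPTS times until the response contains a usable markdown table; the validation logic itself is not shown in these hunks. The sketch below only illustrates the general shape of such a check-and-retry loop, with a naive pipe-row heuristic standing in for the real check:

def looks_like_markdown_table(text: str) -> bool:
    # Naive check: at least a header row, a separator row, and one data row made of pipe-delimited cells
    rows = [line for line in text.splitlines() if line.strip().startswith("|")]
    return len(rows) >= 3 and any(set(row.strip()) <= set("|-: ") for row in rows)

def call_until_table(call_llm, max_attempts: int = 3) -> str:
    for attempt in range(max_attempts):
        response_text = call_llm()
        if looks_like_markdown_table(response_text):
            return response_text
        print(f"Attempt {attempt + 1}: no markdown table found, retrying.")
    return response_text  # caller decides how to handle a still-invalid response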
tools/verify_titles.py
CHANGED
@@ -448,7 +448,7 @@ def verify_titles(in_data_file,
448     summary_whole_conversation = list()
449
450     # Process requests to large language model
451 -   responses, summary_conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)
452
453
454

@@ -549,7 +549,7 @@ def verify_titles(in_data_file,
549
550     whole_conversation = [formatted_initial_table_system_prompt]
551
552 -   responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
553
554
555     topic_table_out_path, reference_table_out_path, unique_topics_df_out_path, topic_table_df, markdown_table, reference_df, new_unique_topics_df, batch_file_path_details, is_error = write_llm_output_and_logs_verify(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_unique_topics_df, batch_size, chosen_cols, model_name_map=model_name_map, first_run=True)

448     summary_whole_conversation = list()
449
450     # Process requests to large language model
451 +   responses, summary_conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, tokenizer=tokenizer, master = True)
452
453
454

549
550     whole_conversation = [formatted_initial_table_system_prompt]
551
552 +   responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill, tokenizer=tokenizer)
553
554
555     topic_table_out_path, reference_table_out_path, unique_topics_df_out_path, topic_table_df, markdown_table, reference_df, new_unique_topics_df, batch_file_path_details, is_error = write_llm_output_and_logs_verify(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_unique_topics_df, batch_size, chosen_cols, model_name_map=model_name_map, first_run=True)