phyloforfun committed
Commit 567930d
1 Parent(s): c6a70af
.gitignore CHANGED
@@ -7,15 +7,27 @@ yolov8x-pose.pt
 yolov8n.pt
 *PRIVATE_DATA*
 
+vouchervision/LLM_MistralAI_Azure_endpoints.py
+
 # Prompts
 /custom_prompts/*
 !/custom_prompts/SLTPvA_long.yaml
 !/custom_prompts/SLTPvA_medium.yaml
 !/custom_prompts/SLTPvA_short.yaml
+!/custom_prompts/SLTPvB_long.yaml
+!/custom_prompts/SLTPvB_medium.yaml
+!/custom_prompts/SLTPvB_short.yaml
 
 # Dirs
 custom_prompts_deprecated/
 demo/demo_output/*
+
+demo/validation_images_repeat/
+demo/validation_json/
+demo/validation_figs/
+demo/validation_output/
+demo/validation_xlsx/
+
 demo/demo_configs/*
 uploads/*
 uploads_small/*
@@ -59,6 +71,9 @@ vouchervision/component_detector/runs/
 vouchervision/component_detector/architecture/
 vouchervision/component_detector/yolov5x6.pt
 
+vouchervision/vouchervision_test_all_options.py
+vouchervision/prompt_arena.py
+
 vouchervision/instructor-xl/
 vouchervision/instructor-embedding/
 
.streamlit/config.toml CHANGED
@@ -1,5 +1,11 @@
 [theme]
-primaryColor = "#00ff00"
+primaryColor = "#16a616"
 backgroundColor="#1a1a1a"
 secondaryBackgroundColor="#303030"
-textColor = "cccccc"
+textColor = "cccccc"
+
+[server]
+enableStaticServing = false
+runOnSave = true
+port = 8524
+maxUploadSize = 5000
api_status.yaml CHANGED
@@ -1,10 +1,12 @@
-date: January 26, 2024
+date: February 29, 2024
 missing_keys: []
 present_keys:
-- Google OCR (Valid)
+- Google OCR Print (Valid)
+- Google OCR Handwriting (Valid)
 - OpenAI (Valid)
 - Azure OpenAI (Valid)
 - Palm2 (Valid)
+- Palm2 LangChain (Valid)
 - Gemini (Valid)
 - Mistral (Valid)
 - HERE Geocode (Valid)
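For context, api_status.yaml is a flat key-status report, so reading it back takes only a few lines. The following is a minimal sketch, assuming PyYAML is installed and the file sits in the working directory, of how the present_keys list updated in this commit could gate a given OCR or LLM option:

import yaml

# Load the status report written by the API validation step.
with open('api_status.yaml', 'r') as f:
    status = yaml.safe_load(f)

# Each entry looks like "Google OCR Print (Valid)"; strip the trailing status.
valid_keys = {entry.rsplit(' (', 1)[0] for entry in status.get('present_keys', [])}

if 'Google OCR Handwriting' in valid_keys:
    print('Handwriting OCR is available')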
app.py CHANGED
@@ -7,7 +7,6 @@ import pandas as pd
 from io import BytesIO
 from streamlit_extras.let_it_rain import rain
 from annotated_text import annotated_text
-from transformers import AutoConfig
 
 from vouchervision.LeafMachine2_Config_Builder import write_config_file
 from vouchervision.VoucherVision_Config_Builder import build_VV_config, TestOptionsGPT, TestOptionsPalm, check_if_usable
@@ -18,6 +17,7 @@ from vouchervision.API_validation import APIvalidation
 from vouchervision.utils_hf import setup_streamlit_config, save_uploaded_file, save_uploaded_local, save_uploaded_file_local
 from vouchervision.data_project import convert_pdf_to_jpg
 from vouchervision.utils_LLM import check_system_gpus
+from vouchervision.OCR_google_cloud_vision import check_for_inappropriate_content
 
 import cProfile
 import pstats
@@ -250,14 +250,25 @@ def load_gallery(converted_files, uploaded_file):
         file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], uploaded_file, img)
         st.session_state['input_list_small'].append(file_path_small)
 
+
+
+
 @st.cache_data
 def handle_image_upload_and_gallery_hf(uploaded_files):
     if uploaded_files:
+
         # Clear input image gallery and input list
         clear_image_uploads()
 
         ind_small = 0
         for uploaded_file in uploaded_files:
+
+            if check_for_inappropriate_content(uploaded_file):
+                clear_image_uploads()
+                st.error("Warning: You have uploaded an inappropriate image")
+                return True
+
+
             # Determine the file type
             if uploaded_file.name.lower().endswith('.pdf'):
                 # Handle PDF files
@@ -305,6 +316,8 @@ def handle_image_upload_and_gallery_hf(uploaded_files):
             # If there are less than 100 images, take them all
             images_to_display = st.session_state['input_list_small']
         show_gallery_small_hf(images_to_display)
+
+    return False
 
 
 @st.cache_data
@@ -378,7 +391,7 @@ def content_input_images(col_left, col_right):
 
     with col_right:
         if st.session_state.is_hf:
-            handle_image_upload_and_gallery_hf(uploaded_files)
+            result = handle_image_upload_and_gallery_hf(uploaded_files)
 
         else:
             st.session_state['view_local_gallery'] = st.toggle("View Image Gallery",)
@@ -427,7 +440,8 @@ def count_jpg_images(directory_path):
 
 def create_download_button(zip_filepath, col, key):
     with col:
-        labal_n_images = f"Download Results for {st.session_state['processing_add_on']} Images"
+        # labal_n_images = f"Download Results for {st.session_state['processing_add_on']} Images"
+        labal_n_images = f"Download Results"
        with open(zip_filepath, 'rb') as f:
            bytes_io = BytesIO(f.read())
        st.download_button(
@@ -1067,6 +1081,11 @@ def create_private_file():
            "client_x509_cert_url": "A LONG URL",
            "universe_domain": "googleapis.com"
            })
+
+        blog_text('Google project ID', ': The project ID is the "project_id" value from the JSON file.')
+        blog_text('Google project location', ': The project location specifies the location of the Google server that your project resources will utilize. It should not really make a difference which location you choose. We use `us-central1`, but you might want to choose a location closer to where you live. [please see this page for a list of available regions](https://cloud.google.com/vertex-ai/docs/general/locations)')
+
+
        google_application_credentials = st.text_input(label = 'Full path to Google Cloud JSON API key file', value = cfg_private['google'].get('GOOGLE_APPLICATION_CREDENTIALS', ''),
                                                 placeholder = 'e.g. C:/Documents/Secret_Files/google_API/application_default_credentials.json',
                                                 help ="This API Key is in the form of a JSON file. Please save the JSON file in a safe directory. DO NOT store the JSON key inside of the VoucherVision directory.",
@@ -1127,7 +1146,7 @@ def create_private_file():
 
        st.write("---")
        st.subheader("MistralAI")
-       st.markdown('Follow these [instructions](https://platform.here.com/sign-up?step=verify-identity) to generate an API key for HERE.')
+       st.markdown('Follow these [instructions](https://console.mistral.ai/) to generate an API key for MistralAI.')
        mistral_API_KEY = st.text_input("MistralAI API Key", cfg_private['mistral'].get('MISTRAL_API_KEY', ''),
                                        help='e.g. a 32-character string',
                                        placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
@@ -1360,7 +1379,7 @@ def get_all_cost_tables():
            cost_openai[key] = cost_data.get(value,'')
        elif 'PALM2' in parts or 'GEMINI' in parts:
            cost_google[key] = cost_data.get(value,'')
-       elif 'MISTRAL' in parts:
+       elif ('MISTRAL' in parts) or ('MIXTRAL' in parts):
            cost_mistral[key] = cost_data.get(value,'')
 
    styled_cost_openai = convert_cost_dict_to_table(cost_openai, "OpenAI")
@@ -1403,9 +1422,9 @@ def content_header():
    N_STEPS = 6
 
    if check_if_usable(is_hf=st.session_state['is_hf']):
-        b_text = f"Start Processing {st.session_state['processing_add_on']} Images" if st.session_state['processing_add_on'] > 1 else f"Start Processing {st.session_state['processing_add_on']} Image"
-        if st.session_state['processing_add_on'] == 0:
-            b_text = f"Start Processing"
+        # b_text = f"Start Processing {st.session_state['processing_add_on']} Images" if st.session_state['processing_add_on'] > 1 else f"Start Processing {st.session_state['processing_add_on']} Image"
+        # if st.session_state['processing_add_on'] == 0:
+        b_text = f"Start Transcription"
        if st.button(b_text, type='primary',use_container_width=True):
            st.session_state['formatted_json'] = {}
            st.session_state['formatted_json_WFO'] = {}
@@ -1466,7 +1485,7 @@ def content_header():
            if st.session_state['zip_filepath']:
                create_download_button(st.session_state['zip_filepath'], col_run_1,key=97863332)
    else:
-       st.button("Start Processing", type='primary', disabled=True)
+       st.button("Start Transcription", type='primary', disabled=True)
        with col_run_4:
            st.error(":heavy_exclamation_mark: Required API keys not set. Please visit the 'API Keys' tab and set the Google Vision OCR API key and at least one LLM key.")
 
@@ -1482,11 +1501,11 @@ def content_header():
    ct_left, ct_right = st.columns([1,1])
    with ct_left:
        st.button("Refresh", on_click=refresh, use_container_width=True)
-    # with ct_right:
-    #     try:
-    #         st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
-    #     except:
-    #         st.page_link(os.path.join(os.path.dirname(__file__),"pages","faqs.py"), label="FAQs", icon="❔")
+    with ct_right:
+        try:
+            st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
+        except:
+            st.page_link(os.path.join(os.path.dirname(__file__),"pages","faqs.py"), label="FAQs", icon="❔")
 
 
 
@@ -1687,12 +1706,12 @@ def content_prompt_and_llm_version():
        selected_version = default_version
    st.session_state.config['leafmachine']['project']['prompt_version'] = st.selectbox("Prompt Version", available_prompts, index=available_prompts.index(selected_version),label_visibility='collapsed')
 
-    # with col_prompt_2:
-    #     # if st.button("Build Custom LLM Prompt"):
-    #     try:
-    #         st.page_link(os.path.join("pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
-    #     except:
-    #         st.page_link(os.path.join(os.path.dirname(__file__),"pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
+    with col_prompt_2:
+        # if st.button("Build Custom LLM Prompt"):
+        try:
+            st.page_link(os.path.join("pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
+        except:
+            st.page_link(os.path.join(os.path.dirname(__file__),"pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
 
 
    st.header('LLM Version')
@@ -1703,18 +1722,18 @@ def content_prompt_and_llm_version():
    st.session_state.config['leafmachine']['LLM_version'] = st.selectbox("LLM version", GUI_MODEL_LIST, index=GUI_MODEL_LIST.index(st.session_state.config['leafmachine'].get('LLM_version', ModelMaps.MODELS_GUI_DEFAULT)))
    st.markdown("""
    Based on preliminary results, the following models perform the best. We are currently running tests of all possible OCR + LLM + Prompt combinations to create recipes for different workflows.
-    - `Mistral Medium`
-    - `Mistral Small`
-    - `Mistral Tiny`
+    - Any Mistral model e.g., `Mistral Large`
    - `PaLM 2 text-bison@001`
    - `GPT 4 Turbo 1106-preview`
-    - `GPT 3.5 Instruct`
+    - `GPT 3.5 Turbo`
    - `LOCAL Mixtral 7Bx8 Instruct`
    - `LOCAL Mixtral 7B Instruct`
 
    Larger models (e.g., `GPT 4`, `GPT 4 32k`, `Gemini Pro`) do not necessarily perform better for these tasks. MistralAI models exceeded our expectations and perform extremely well. PaLM 2 text-bison@001 also seems to consistently out-perform Gemini Pro.
 
-    The `SLTPvA_short.yaml` prompt also seems to work better with smaller LLMs (e.g., Mistral Tiny). Alternatively, enable double OCR to help the LLM focus on the OCR text given a longer prompt.""")
+    The `SLTPvA_short.yaml` prompt also seems to work better with smaller LLMs (e.g., Mistral Tiny). Alternatively, enable double OCR to help the LLM focus on the OCR text given a longer prompt.
+
+    Models `GPT 3.5 Turbo` and `GPT 4 Turbo 0125-preview` enable OpenAI's [JSON mode](https://platform.openai.com/docs/guides/text-generation/json-mode), which helps prevent JSON errors. All models implement Langchain JSON parsing too, so JSON errors are rare for most models.""")
 
 
def content_api_check():
@@ -1927,6 +1946,8 @@ def content_ocr_method():
    # st.text_area(label='Handwritten/Printed + trOCR',placeholder=demo_text_trh,disabled=True, label_visibility='visible', height=150)
 
def is_valid_huggingface_model_path(model_path):
+    from transformers import AutoConfig
+
    try:
        # Attempt to load the model configuration from Hugging Face Model Hub
        config = AutoConfig.from_pretrained(model_path)
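The most visible app.py change is the new upload gate around check_for_inappropriate_content. The function itself lives in vouchervision/OCR_google_cloud_vision.py and is not part of this diff; purely as an illustration of how such a gate is commonly built, here is a hypothetical sketch using Google Cloud Vision's SafeSearch annotations. The body and thresholds below are assumptions, not the repo's actual code.

from google.cloud import vision

def check_for_inappropriate_content(uploaded_file):
    """Hypothetical sketch: return True if SafeSearch flags the image.

    The real implementation in vouchervision/OCR_google_cloud_vision.py
    is not shown in this diff; thresholds here are illustrative only.
    """
    client = vision.ImageAnnotatorClient()
    content = uploaded_file.getvalue()  # Streamlit UploadedFile exposes raw bytes
    image = vision.Image(content=content)

    response = client.safe_search_detection(image=image)
    annotation = response.safe_search_annotation

    # Likelihood is an enum from UNKNOWN (0) up to VERY_LIKELY (5).
    flagged = vision.Likelihood.LIKELY
    return (annotation.adult >= flagged
            or annotation.violence >= flagged
            or annotation.racy >= flagged)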
custom_prompts/SLTPvA_long.yaml CHANGED
@@ -28,10 +28,7 @@ rules:
   scientificNameAuthorship: The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
   genus: Taxonomic determination to genus. Genus must be capitalized. If
     genus is not present use the taxonomic family name followed by the word 'indet'.
-  subgenus: The full scientific name of the subgenus in which the taxon is classified.
-    Values should include the genus to avoid homonym confusion.
   specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
-  infraspecificEpithet: The name of the lowest or terminal infraspecific epithet of the scientificName, excluding any rank designation.
   identifiedBy: A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.
   recordedBy: A comma separated list of names of people, groups, or organizations responsible for observing, recording, collecting, or presenting the original specimen.
     The primary collector or observer should be listed first.
@@ -63,7 +60,7 @@ rules:
     the exact origin or location of the specimen.
   degreeOfEstablishment: Cultivated plants are intentionally grown by humans. In text descriptions,
     look for planting dates, garden locations, ornamental, cultivar names, garden,
-    or farm to indicate cultivated plant. Use either - unknown or cultivated.
+    or farm to indicate cultivated plant. Set to 'cultivated' if cultivated, otherwise use an empty string.
   decimalLatitude: Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform
     with the decimal degrees GPS coordinate format.
   decimalLongitude: Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform
@@ -78,35 +75,33 @@ rules:
     are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m"
     or "m." or "meters"). Round to integer.
 mapping:
-  TAXONOMY:
-  - catalogNumber
-  - order
-  - family
-  - scientificName
-  - scientificNameAuthorship
-  - genus
-  - subgenus
-  - specificEpithet
-  - infraspecificEpithet
-  GEOGRAPHY:
-  - country
-  - stateProvince
-  - county
-  - municipality
-  - decimalLatitude
-  - decimalLongitude
-  - verbatimCoordinates
-  LOCALITY:
-  - locality
-  - habitat
-  - minimumElevationInMeters
-  - maximumElevationInMeters
-  COLLECTING:
-  - identifiedBy
-  - recordedBy
-  - recordNumber
-  - verbatimEventDate
-  - eventDate
-  - degreeOfEstablishment
-  - occurrenceRemarks
-  MISC:
+  TAXONOMY:
+  - catalogNumber
+  - order
+  - family
+  - scientificName
+  - scientificNameAuthorship
+  - genus
+  - specificEpithet
+  GEOGRAPHY:
+  - country
+  - stateProvince
+  - county
+  - municipality
+  - decimalLatitude
+  - decimalLongitude
+  - verbatimCoordinates
+  LOCALITY:
+  - locality
+  - habitat
+  - minimumElevationInMeters
+  - maximumElevationInMeters
+  COLLECTING:
+  - identifiedBy
+  - recordedBy
+  - recordNumber
+  - verbatimEventDate
+  - eventDate
+  - degreeOfEstablishment
+  - occurrenceRemarks
+  MISC: []
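One subtle but load-bearing fix repeated across all three SLTPvA prompts is changing the trailing bare `MISC:` to `MISC: []`. In YAML, a key with no value parses to None, while an explicit `[]` parses to an empty list, and that difference matters to any code that iterates over every list in `mapping` (as pages/prompt_builder.py does). A minimal sketch of the difference:

import yaml
from itertools import chain

# A bare key parses to None; an explicit [] parses to an empty list.
old = yaml.safe_load("mapping:\n  COLLECTING:\n  - eventDate\n  MISC:")
new = yaml.safe_load("mapping:\n  COLLECTING:\n  - eventDate\n  MISC: []")

print(old['mapping']['MISC'])  # None
print(new['mapping']['MISC'])  # []

# Flattening the mapping, as prompt_builder.py does, works with []:
print(list(chain.from_iterable(new['mapping'].values())))  # ['eventDate']
# ...but raises TypeError ('NoneType' is not iterable) with the bare key:
# list(chain.from_iterable(old['mapping'].values()))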
custom_prompts/SLTPvA_medium.yaml CHANGED
@@ -27,9 +27,7 @@ rules:
     and any lower classifications.
   scientificNameAuthorship: The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
   genus: Taxonomic determination to genus. Genus must be capitalized.
-  subgenus: The full scientific name of the subgenus in which the taxon is classified.
   specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
-  infraspecificEpithet: The name of the lowest or terminal infraspecific epithet of the scientificName, excluding any rank designation.
   identifiedBy: A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.
   recordedBy: A comma separated list of names of people, groups, or organizations
   recordNumber: An identifier given to the specimen at the time it was recorded.
@@ -46,7 +44,7 @@ rules:
     the exact origin or location of the specimen.
   degreeOfEstablishment: Cultivated plants are intentionally grown by humans. In text descriptions,
     look for planting dates, garden locations, ornamental, cultivar names, garden,
-    or farm to indicate cultivated plant. Use either - unknown or cultivated.
+    or farm to indicate cultivated plant. Set to 'cultivated' if cultivated, otherwise use an empty string.
   decimalLatitude: Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.
   decimalLongitude: Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.
   verbatimCoordinates: Verbatim location coordinates as they appear on the label.
@@ -60,9 +58,7 @@ mapping:
   - scientificName
   - scientificNameAuthorship
   - genus
-  - subgenus
   - specificEpithet
-  - infraspecificEpithet
   GEOGRAPHY:
   - country
   - stateProvince
@@ -84,4 +80,4 @@ mapping:
   - eventDate
   - degreeOfEstablishment
   - occurrenceRemarks
-  MISC:
+  MISC: []
custom_prompts/SLTPvA_short.yaml CHANGED
@@ -26,9 +26,7 @@ rules:
   scientificName: scientific name of the taxon including Genus, specific epithet, and any lower classifications.
   scientificNameAuthorship: authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
   genus: taxonomic determination to Genus, Genus must be capitalized.
-  subgenus: name of the subgenus.
   specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
-  infraspecificEpithet: lowest or terminal infraspecific epithet of the scientificName.
   identifiedBy: list of names of people, doctors, professors, groups, or organizations who identified, determined the taxon name to the subject organism. This is not the specimen collector.
   recordedBy: list of names of people, doctors, professors, groups, or organizations.
   recordNumber: identifier given to the specimen at the time it was recorded.
@@ -41,7 +39,7 @@ rules:
   county: county, shire, department, parish etc.
   municipality: city, municipality, etc.
   locality: description of geographic information aiding in pinpointing the exact origin or location of the specimen.
-  degreeOfEstablishment: cultivated plants are intentionally grown by humans. Use either - unknown or cultivated.
+  degreeOfEstablishment: cultivated plants are intentionally grown by humans. Set to 'cultivated' if cultivated, otherwise use an empty string.
   decimalLatitude: latitude decimal coordinate.
   decimalLongitude: longitude decimal coordinate.
   verbatimCoordinates: verbatim location coordinates.
@@ -55,9 +53,7 @@ mapping:
   - scientificName
   - scientificNameAuthorship
   - genus
-  - subgenus
   - specificEpithet
-  - infraspecificEpithet
   GEOGRAPHY:
   - country
   - stateProvince
@@ -79,4 +75,4 @@ mapping:
   - eventDate
   - degreeOfEstablishment
   - occurrenceRemarks
-  MISC:
+  MISC: []
custom_prompts/SLTPvB_long.yaml ADDED
@@ -0,0 +1,107 @@
+prompt_author: Will Weaver
+prompt_author_institution: University of Michigan
+prompt_name: SLTPvB_long
+prompt_version: v-1-0
+prompt_description: Prompt developed by the University of Michigan.
+  SLTPvB prompts all have standardized column headers (fields) that were chosen due to their reliability and prevalence in herbarium records.
+  All field descriptions are based on the official Darwin Core guidelines.
+  SLTPvB_long - The most verbose prompt option. Descriptions closely follow DwC guides. Detailed rules for the LLM to follow. Works best with double or triple OCR to increase attention back to the OCR (select 'use both OCR models' or 'handwritten + printed' along with trOCR).
+  SLTPvB_medium - Shorter verion of _long.
+  SLTPvB_short - The least verbose possible prompt while still providing rules and DwC descriptions.
+LLM: General Purpose
+instructions: 1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
+  2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
+  3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
+  4. Duplicate dictionary fields are not allowed.
+  5. Ensure all JSON keys are in camel case.
+  6. Ensure new JSON field values follow sentence case capitalization.
+  7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
+  8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
+  9. Only return a JSON dictionary represented as a string. You should not explain your answer.
+json_formatting_instructions: This section provides rules for formatting each JSON value organized by the JSON key.
+rules:
+  catalogNumber: Barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits.
+  order: The full scientific name of the order in which the taxon is classified. Order must be capitalized.
+  family: The full scientific name of the family in which the taxon is classified. Family must be capitalized.
+  speciesBinomialName: The scientific name of the taxon including genus, specific epithet,
+    and any lower classifications.
+  genus: Taxonomic determination to genus. Genus must be capitalized. If
+    genus is not present use the taxonomic family name followed by the word 'indet'.
+  specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
+  speciesBinomialNameAuthorship: The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
+  collector: A comma separated list of names of people, groups, or organizations responsible for observing, recording, collecting, or presenting the original specimen.
+    The primary collector or observer should be listed first.
+  recordNumber: An identifier given to the occurrence at the time it was recorded.
+    Often serves as a link between field notes and an occurrence record, such as a specimen collector's number.
+  identifiedBy: A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.
+  verbatimCollectionDate: The verbatim original representation of the date and time information for when the specimen was collected.
+    Date of collection exactly as it appears on the label. Do not change
+    the format or correct typos.
+  collectionDate: Date the specimen was collected formatted as year-month-day, YYYY-MM_DD. If
+    specific components of the date are unknown, they should be replaced with
+    zeros. Examples "0000-00-00" if the entire date is unknown, "YYYY-00-00"
+    if only the year is known, and "YYYY-MM-00" if year and month are known
+    but day is not.
+  occurrenceRemarks: Text describing the specimen's geographic location. Text describing the appearance of the specimen.
+    A statement about the presence or absence of a taxon at a the collection location.
+    Text describing the significance of the specimen, such as a specific expedition or notable collection.
+    Description of plant features such as leaf shape, size, color,
+    stem texture, height, flower structure, scent, fruit or seed characteristics,
+    root system type, overall growth habit and form, any notable aroma or secretions,
+    presence of hairs or bristles, and any other distinguishing morphological
+    or physiological characteristics.
+  habitat: A category or description of the habitat in which the specimen collection event occurred.
+  locality: Description of geographic location, landscape, landmarks, regional
+    features, nearby places, or any contextual information aiding in pinpointing
+    the exact origin or location of the specimen.
+  isCultivated: Cultivated plants are intentionally grown by humans. In text descriptions,
+    look for planting dates, garden locations, ornamental, cultivar names, garden,
+    or farm to indicate cultivated plant. Set to 'cultivated' if cultivated, otherwise use an empty string.
+  country: The name of the country or major administrative unit in which the specimen was originally collected.
+  stateProvince: The name of the next smaller administrative region than country (state, province, canton, department, region, etc.) in which the specimen was originally collected.
+  county: The full, unabbreviated name of the next smaller administrative region than stateProvince (county, shire, department, parish etc.) in which the specimen was originally collected.
+  municipality: The full, unabbreviated name of the next smaller administrative region than county (city, municipality, etc.) in which the specimen was originally collected.
+  verbatimCoordinates: Verbatim location coordinates as they appear on the label. Do not
+    convert formats. Possible coordinate types include [Lat, Long, UTM, TRS].
+  decimalLatitude: Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform
+    with the decimal degrees GPS coordinate format.
+  decimalLongitude: Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform
+    with the decimal degrees GPS coordinate format.
+  minimumElevationInMeters: Minimum elevation or altitude in meters. Only if units are explicit
+    then convert from feet ("ft" or "ft."" or "feet") to meters ("m" or "m." or
+    "meters"). Round to integer.
+  maximumElevationInMeters: Maximum elevation or altitude in meters. If only one elevation
+    is present, then max_elevation should be set to the null_value. Only if units
+    are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m"
+    or "m." or "meters"). Round to integer.
+mapping:
+  TAXONOMY:
+  - catalogNumber
+  - order
+  - family
+  - speciesBinomialName
+  - genus
+  - specificEpithet
+  - speciesBinomialNameAuthorship
+  GEOGRAPHY:
+  - country
+  - stateProvince
+  - county
+  - municipality
+  - verbatimCoordinates
+  - decimalLatitude
+  - decimalLongitude
+  - minimumElevationInMeters
+  - maximumElevationInMeters
+  LOCALITY:
+  - occurrenceRemarks
+  - habitat
+  - locality
+  - isCultivated
+  COLLECTING:
+  - collector
+  - recordNumber
+  - identifiedBy
+  - verbatimCollectionDate
+  - collectionDate
+  MISC: []
custom_prompts/SLTPvB_medium.yaml ADDED
@@ -0,0 +1,83 @@
+prompt_author: Will Weaver
+prompt_author_institution: University of Michigan
+prompt_name: SLTPvB_medium
+prompt_version: v-1-0
+prompt_description: Prompt developed by the University of Michigan.
+  SLTPvB prompts all have standardized column headers (fields) that were chosen due to their reliability and prevalence in herbarium records.
+  All field descriptions are based on the official Darwin Core guidelines.
+  SLTPvB_long - The most verbose prompt option. Descriptions closely follow DwC guides. Detailed rules for the LLM to follow. Works best with double or triple OCR to increase attention back to the OCR (select 'use both OCR models' or 'handwritten + printed' along with trOCR).
+  SLTPvB_medium - Shorter verion of _long.
+  SLTPvB_short - The least verbose possible prompt while still providing rules and DwC descriptions.
+LLM: General Purpose
+instructions: 1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
+  2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
+  3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
+  4. Duplicate dictionary fields are not allowed.
+  5. Ensure all JSON keys are in camel case.
+  6. Ensure new JSON field values follow sentence case capitalization.
+  7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
+  8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
+  9. Only return a JSON dictionary represented as a string. You should not explain your answer.
+json_formatting_instructions: This section provides rules for formatting each JSON value organized by the JSON key.
+rules:
+  catalogNumber: Barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits.
+  order: The full scientific name of the order in which the taxon is classified. Order must be capitalized.
+  family: The full scientific name of the family in which the taxon is classified. Family must be capitalized.
+  speciesBinomialName: The scientific name of the taxon including genus, specific epithet,
+    and any lower classifications.
+  genus: Taxonomic determination to genus. Genus must be capitalized.
+  specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
+  speciesBinomialNameAuthorship: The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
+  collector: A comma separated list of names of people, groups, or organizations
+  recordNumber: An identifier given to the specimen at the time it was recorded.
+  identifiedBy: A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.
+  verbatimCollectionDate: The verbatim original representation of the date and time information for when the specimen was collected.
+  collectionDate: Date the specimen was collected formatted as year-month-day YYYY-MM-DD.
+  occurrenceRemarks: Text describing the specimen's geographic location, appearance of the specimen, presence or absence of a taxon at a the collection location, the significance of the specimen, such as a specific expedition or notable collection, plant features and descriptions.
+  habitat: A category or description of the habitat in which the specimen collection event occurred.
+  locality: Description of geographic location, landscape, landmarks, regional
+    features, nearby places, or any contextual information aiding in pinpointing
+    the exact origin or location of the specimen.
+  isCultivated: Cultivated plants are intentionally grown by humans. In text descriptions,
+    look for planting dates, garden locations, ornamental, cultivar names, garden,
+    or farm to indicate cultivated plant. Set to 'cultivated' if cultivated, otherwise use an empty string.
+  country: The name of the country or major administrative unit in which the specimen was originally collected.
+  stateProvince: The name of the next smaller administrative region than country (state, province, canton, department, region, etc.) in which the specimen was originally collected.
+  county: The full, unabbreviated name of the next smaller administrative region than stateProvince (county, shire, department, parish etc.) in which the specimen was originally collected.
+  municipality: The full, unabbreviated name of the next smaller administrative region than county (city, municipality, etc.) in which the specimen was originally collected.
+  verbatimCoordinates: Verbatim location coordinates as they appear on the label.
+  decimalLatitude: Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.
+  decimalLongitude: Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.
+  minimumElevationInMeters: Minimum elevation or altitude in meters. Only if units are explicit then convert from feet ("ft" or "ft."" or "feet") to meters ("m" or "m." or "meters"). Round to integer.
+  maximumElevationInMeters: Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m" or "m." or "meters"). Round to integer.
+mapping:
+  TAXONOMY:
+  - catalogNumber
+  - order
+  - family
+  - speciesBinomialName
+  - genus
+  - specificEpithet
+  - speciesBinomialNameAuthorship
+  GEOGRAPHY:
+  - country
+  - stateProvince
+  - county
+  - municipality
+  - verbatimCoordinates
+  - decimalLatitude
+  - decimalLongitude
+  - minimumElevationInMeters
+  - maximumElevationInMeters
+  LOCALITY:
+  - occurrenceRemarks
+  - habitat
+  - locality
+  - isCultivated
+  COLLECTING:
+  - collector
+  - recordNumber
+  - identifiedBy
+  - verbatimCollectionDate
+  - collectionDate
+  MISC: []
custom_prompts/SLTPvB_short.yaml ADDED
@@ -0,0 +1,78 @@
+prompt_author: Will Weaver
+prompt_author_institution: University of Michigan
+prompt_name: SLTPvB_short
+prompt_version: v-1-0
+prompt_description: Prompt developed by the University of Michigan.
+  SLTPvB prompts all have standardized column headers (fields) that were chosen due to their reliability and prevalence in herbarium records.
+  All field descriptions are based on the official Darwin Core guidelines.
+  SLTPvB_long - The most verbose prompt option. Descriptions closely follow DwC guides. Detailed rules for the LLM to follow. Works best with double or triple OCR to increase attention back to the OCR (select 'use both OCR models' or 'handwritten + printed' along with trOCR).
+  SLTPvB_medium - Shorter verion of _long.
+  SLTPvB_short - The least verbose possible prompt while still providing rules and DwC descriptions.
+LLM: General Purpose
+instructions: 1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
+  2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
+  3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
+  4. Duplicate dictionary fields are not allowed.
+  5. Ensure all JSON keys are in camel case.
+  6. Ensure new JSON field values follow sentence case capitalization.
+  7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
+  8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
+  9. Only return a JSON dictionary represented as a string. You should not explain your answer.
+json_formatting_instructions: This section provides rules for formatting each JSON value organized by the JSON key.
+rules:
+  catalogNumber: barcode identifier, at least 6 digits, fewer than 30 digits.
+  order: full scientific name of the Order in which the taxon is classified. Order must be capitalized.
+  family: full scientific name of the Family in which the taxon is classified. Family must be capitalized.
+  speciesBinomialName: scientific name of the taxon including Genus, specific epithet, and any lower classifications.
+  genus: taxonomic determination to Genus, Genus must be capitalized.
+  specificEpithet: The name of the first or species epithet of the scientificBinomial. Only include the species epithet.
+  speciesBinomialNameAuthorship: authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
+  collector: list of names of people, doctors, professors, groups, or organizations.
+  recordNumber: identifier given to the specimen at the time it was recorded.
+  identifiedBy: list of names of people, doctors, professors, groups, or organizations who identified, determined the taxon name to the subject organism. This is not the specimen collector.
+  verbatimCollectionDate: The verbatim original representation of the date and time information for when the specimen was collected.
+  collectionDate: collection date formatted as year-month-day YYYY-MM-DD.
+  occurrenceRemarks: all descriptive text in the OCR rearranged into sensible sentences or sentence fragments.
+  habitat: habitat description.
+  locality: description of geographic information aiding in pinpointing the exact origin or location of the specimen.
+  isCultivated: cultivated plants are intentionally grown by humans. Set to 'cultivated' if cultivated, otherwise use an empty string.
+  country: country or major administrative unit.
+  stateProvince: state, province, canton, department, region, etc.
+  county: county, shire, department, parish etc.
+  municipality: city, municipality, etc.
+  verbatimCoordinates: verbatim location coordinates.
+  decimalLatitude: latitude decimal coordinate.
+  decimalLongitude: longitude decimal coordinate.
+  minimumElevationInMeters: minimum elevation or altitude in meters.
+  maximumElevationInMeters: maximum elevation or altitude in meters.
+mapping:
+  TAXONOMY:
+  - catalogNumber
+  - order
+  - family
+  - speciesBinomialName
+  - genus
+  - specificEpithet
+  - speciesBinomialNameAuthorship
+  GEOGRAPHY:
+  - country
+  - stateProvince
+  - county
+  - municipality
+  - verbatimCoordinates
+  - decimalLatitude
+  - decimalLongitude
+  - minimumElevationInMeters
+  - maximumElevationInMeters
+  LOCALITY:
+  - occurrenceRemarks
+  - habitat
+  - locality
+  - isCultivated
+  COLLECTING:
+  - collector
+  - recordNumber
+  - identifiedBy
+  - verbatimCollectionDate
+  - collectionDate
+  MISC: []
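Since the three SLTPvB files are new, it may help to see how such a prompt definition is consumed downstream. The following is a minimal sketch, assuming PyYAML and the repo layout above: load the YAML, then build the empty JSON template the LLM is asked to fill by flattening `mapping`, the same `chain.from_iterable` pattern used in pages/prompt_builder.py.

import yaml
from itertools import chain

# Load one of the new prompt definitions (path assumes the repo root).
with open('custom_prompts/SLTPvB_short.yaml', 'r') as f:
    prompt = yaml.safe_load(f)

# Flatten the category -> field lists into one ordered field list.
# 'MISC: []' parses to an empty list, so it flattens to nothing.
fields = list(chain.from_iterable(prompt['mapping'].values()))

# The empty template the LLM is instructed to populate from OCR text.
json_template = {field: '' for field in fields}
print(json_template['speciesBinomialName'])  # ''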
pages/prompt_builder.py CHANGED
@@ -76,7 +76,9 @@ def load_prompt_yaml(filename):
     st.session_state['mapping'] = st.session_state['prompt_info'].get('mapping', {})
     st.session_state['LLM'] = st.session_state['prompt_info'].get('LLM', 'General Purpose')
 
-    # Placeholder:
+    # print(st.session_state['mapping'].values())
+    # print(chain.from_iterable(st.session_state['mapping'].values()))
+    # print(list(chain.from_iterable(st.session_state['mapping'].values())))
     st.session_state['assigned_columns'] = list(chain.from_iterable(st.session_state['mapping'].values()))
 
 
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
requirements_conda.txt ADDED
Binary file (1.97 kB).
requirements_with_versions.txt ADDED
Binary file (11.1 kB).
run_VoucherVision.py CHANGED
@@ -31,7 +31,7 @@ def resolve_path(path):
 if __name__ == "__main__":
     dir_home = os.path.dirname(__file__)
 
-    start_port = 8528
+    start_port = 8530
     try:
         free_port = find_available_port(start_port)
         sys.argv = [
@@ -42,6 +42,7 @@ if __name__ == "__main__":
             "--global.developmentMode=false",
             # "--server.port=8545",
             f"--server.port={free_port}",
+            f"--server.maxUploadSize=51200",
             # Toggle below for HF vs Local
             # "--is_hf=1",
             # "--is_hf=0",
vouchervision/LLM_GoogleGemini.py CHANGED
@@ -20,7 +20,7 @@ class GoogleGeminiHandler:
     VENDOR = 'google'
     STARTING_TEMP = 0.5
 
-    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
         self.cfg = cfg
         self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
         self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
@@ -30,10 +30,8 @@ class GoogleGeminiHandler:
         self.model_name = model_name
         self.JSON_dict_structure = JSON_dict_structure
 
-        self.starting_temp = float(self.STARTING_TEMP)
-        self.temp_increment = float(0.2)
-        self.adjust_temp = self.starting_temp
-
+        self.config_vals_for_permutation = config_vals_for_permutation
+
         self.monitor = SystemLoadMonitor(logger)
 
         self.parser = JsonOutputParser()
@@ -50,11 +48,24 @@ class GoogleGeminiHandler:
     def _set_config(self):
         # os.environ['GOOGLE_API_KEY'] # Must be set too for the retry call, set in VoucherVision class along with other API Keys
         # vertexai.init(project=os.environ['PALM_PROJECT_ID'], location=os.environ['PALM_LOCATION'])
-        self.config = {
+        if self.config_vals_for_permutation:
+            self.starting_temp = float(self.config_vals_for_permutation.get('google').get('temperature'))
+            self.config = {
+                'max_output_tokens': self.config_vals_for_permutation.get('google').get('max_output_tokens'),
+                'temperature': self.starting_temp,
+                'top_p': self.config_vals_for_permutation.get('google').get('top_p'),
+            }
+        else:
+            self.starting_temp = float(self.STARTING_TEMP)
+            self.config = {
                 "max_output_tokens": 1024,
                 "temperature": self.starting_temp,
-                "top_p": 1
+                "top_p": 1.0,
                 }
+
+        self.temp_increment = float(0.2)
+        self.adjust_temp = self.starting_temp
+
         self.safety_settings = {
             HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
@@ -65,22 +76,26 @@ class GoogleGeminiHandler:
 
     def _adjust_config(self):
         new_temp = self.adjust_temp + self.temp_increment
-        self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
+        if self.json_report:
+            self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
         self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
         self.adjust_temp += self.temp_increment
         self.config['temperature'] = self.adjust_temp
 
     def _reset_config(self):
-        self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
+        if self.json_report:
+            self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
         self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
         self.adjust_temp = self.starting_temp
         self.config['temperature'] = self.starting_temp
 
     def _build_model_chain_parser(self):
         # Instantiate the LLM class for Google Gemini
-        self.llm_model = ChatGoogleGenerativeAI(model=self.model_name)#,
-                                                # max_output_tokens=self.config.get('max_output_tokens'),
-                                                # top_p=self.config.get('top_p'))
+        self.llm_model = ChatGoogleGenerativeAI(model=self.model_name,
+                                                max_output_tokens=self.config.get('max_output_tokens'),
+                                                top_p=self.config.get('top_p'),
+                                                temperature=self.config.get('temperature')
+                                                )
         # self.llm_model = VertexAI(model='gemini-1.0-pro',
         #                           max_output_tokens=self.config.get('max_output_tokens'),
         #                           top_p=self.config.get('top_p'))
@@ -101,7 +116,8 @@ class GoogleGeminiHandler:
     def call_llm_api_GoogleGemini(self, prompt_template, json_report, paths):
         _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
         self.json_report = json_report
-        self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
+        if self.json_report:
+            self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
         self.monitor.start_monitoring_usage()
         nt_in = 0
         nt_out = 0
@@ -110,9 +126,9 @@ class GoogleGeminiHandler:
         while ind < self.MAX_RETRIES:
             ind += 1
             try:
-                model_kwargs = {"temperature": self.adjust_temp}
+                # model_kwargs = {"temperature": self.adjust_temp}
                 # Invoke the chain to generate prompt text
-                response = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})
+                response = self.chain.invoke({"query": prompt_template})
 
                 # Use retry_parser to parse the response with retry logic
                 output = self.retry_parser.parse_with_prompt(response, prompt_value=prompt_template)
@@ -131,7 +147,8 @@ class GoogleGeminiHandler:
                 else:
                     self.monitor.stop_inference_timer() # Starts tool timer too
 
-                    json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
+                    if self.json_report:
+                        json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
                     output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
                     save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
@@ -143,7 +160,8 @@ class GoogleGeminiHandler:
                     if self.adjust_temp != self.starting_temp:
                         self._reset_config()
 
-                    json_report.set_text(text_main=f'LLM call successful')
+                    if self.json_report:
+                        json_report.set_text(text_main=f'LLM call successful')
                     return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
 
             except Exception as e:
@@ -153,14 +171,16 @@ class GoogleGeminiHandler:
             time.sleep(self.RETRY_DELAY)
 
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
-        self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
+        if self.json_report:
+            self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         usage_report = self.monitor.stop_monitoring_report_usage()
         self._reset_config()
 
-        json_report.set_text(text_main=f'LLM call failed')
+        if self.json_report:
+            json_report.set_text(text_main=f'LLM call failed')
         return None, nt_in, nt_out, None, None, usage_report
 
 
131
+ response = self.chain.invoke({"query": prompt_template})#, "model_kwargs": model_kwargs})
132
 
133
  # Use retry_parser to parse the response with retry logic
134
  output = self.retry_parser.parse_with_prompt(response, prompt_value=prompt_template)
 
147
  else:
148
  self.monitor.stop_inference_timer() # Starts tool timer too
149
 
150
+ if self.json_report:
151
+ self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
152
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
153
 
154
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
160
  if self.adjust_temp != self.starting_temp:
161
  self._reset_config()
162
 
163
+ if self.json_report:
164
+ self.json_report.set_text(text_main=f'LLM call successful')
165
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
166
 
167
  except Exception as e:
 
171
  time.sleep(self.RETRY_DELAY)
172
 
173
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
174
+ if self.json_report:
175
+ self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
176
 
177
  self.monitor.stop_inference_timer() # Starts tool timer too
178
 
179
  usage_report = self.monitor.stop_monitoring_report_usage()
180
  self._reset_config()
181
 
182
+ if self.json_report:
183
+ self.json_report.set_text(text_main=f'LLM call failed')
184
  return None, nt_in, nt_out, None, None, usage_report
185
 
186
 
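Note on the pattern above: every handler in this commit now takes an optional config_vals_for_permutation dict and falls back to class defaults when it is empty. A minimal sketch of that fallback, assuming the nested {'google': {...}} shape implied by the .get('google').get(...) lookups in the diff (this helper is illustrative, not code from the commit):

import copy

GOOGLE_DEFAULTS = {"max_output_tokens": 1024, "temperature": 0.5, "top_p": 1.0}

def resolve_config(config_vals_for_permutation, defaults=GOOGLE_DEFAULTS, vendor="google"):
    """Prefer permutation overrides per key; otherwise use the class defaults."""
    config = copy.deepcopy(defaults)
    if config_vals_for_permutation:
        overrides = config_vals_for_permutation.get(vendor) or {}
        config.update({k: v for k, v in overrides.items() if k in config})
    return config

# Example: resolve_config({"google": {"temperature": 0.2, "top_p": 0.9}})
# -> {"max_output_tokens": 1024, "temperature": 0.2, "top_p": 0.9}
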
vouchervision/LLM_GooglePalm2.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, time, json
2
  # import vertexai
3
  from vertexai.language_models import TextGenerationModel
4
  from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold
@@ -10,6 +10,8 @@ from langchain.prompts import PromptTemplate
10
  from langchain_core.output_parsers import JsonOutputParser
11
  # from langchain_google_genai import ChatGoogleGenerativeAI
12
  from langchain_google_vertexai import VertexAI
 
 
13
 
14
  from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
15
  from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
@@ -31,7 +33,7 @@ class GooglePalm2Handler:
31
  VENDOR = 'google'
32
  STARTING_TEMP = 0.5
33
 
34
- def __init__(self, cfg, logger, model_name, JSON_dict_structure):
35
  self.cfg = cfg
36
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
37
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
@@ -41,9 +43,9 @@ class GooglePalm2Handler:
41
  self.model_name = model_name
42
  self.JSON_dict_structure = JSON_dict_structure
43
 
44
- self.starting_temp = float(self.STARTING_TEMP)
45
- self.temp_increment = float(0.2)
46
- self.adjust_temp = self.starting_temp
47
 
48
  self.monitor = SystemLoadMonitor(logger)
49
 
@@ -59,12 +61,26 @@ class GooglePalm2Handler:
59
 
60
  def _set_config(self):
61
  # vertexai.init(project=os.environ['PALM_PROJECT_ID'], location=os.environ['PALM_LOCATION'])
62
- self.config = {
 
 
63
  "max_output_tokens": 1024,
64
  "temperature": self.starting_temp,
 
65
  "top_p": 1.0,
66
- "top_k": 40,
67
  }
 
 
68
  self.safety_settings = {
69
  HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
70
  HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
@@ -75,13 +91,15 @@ class GooglePalm2Handler:
75
 
76
  def _adjust_config(self):
77
  new_temp = self.adjust_temp + self.temp_increment
78
- self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
 
79
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
80
  self.adjust_temp += self.temp_increment
81
  self.config['temperature'] = self.adjust_temp
82
 
83
  def _reset_config(self):
84
- self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
 
85
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
86
  self.adjust_temp = self.starting_temp
87
  self.config['temperature'] = self.starting_temp
@@ -89,7 +107,11 @@ class GooglePalm2Handler:
89
  def _build_model_chain_parser(self):
90
  # Instantiate the parser and the retry parser
91
  # self.llm_model = ChatGoogleGenerativeAI(model=self.model_name)
92
- self.llm_model = VertexAI(model=self.model_name)
 
 
 
 
93
 
94
  self.retry_parser = RetryWithErrorOutputParser.from_llm(
95
  parser=self.parser,
@@ -105,6 +127,7 @@ class GooglePalm2Handler:
105
  response = model.predict(prompt_text.text,
106
  max_output_tokens=self.config.get('max_output_tokens'),
107
  temperature=self.config.get('temperature'),
 
108
  top_p=self.config.get('top_p'))
109
  # model = GenerativeModel(self.model_name)
110
 
@@ -115,7 +138,8 @@ class GooglePalm2Handler:
115
  def call_llm_api_GooglePalm2(self, prompt_template, json_report, paths):
116
  _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
117
  self.json_report = json_report
118
- self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
 
119
  self.monitor.start_monitoring_usage()
120
  nt_in = 0
121
  nt_out = 0
@@ -124,12 +148,23 @@ class GooglePalm2Handler:
124
  while ind < self.MAX_RETRIES:
125
  ind += 1
126
  try:
127
- model_kwargs = {"temperature": self.adjust_temp}
128
  # Invoke the chain to generate prompt text
129
- response = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})
130
 
131
  # Use retry_parser to parse the response with retry logic
132
- output = self.retry_parser.parse_with_prompt(response, prompt_value=prompt_template)
 
 
133
 
134
  if output is None:
135
  self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
@@ -144,8 +179,9 @@ class GooglePalm2Handler:
144
  self._adjust_config()
145
  else:
146
  self.monitor.stop_inference_timer() # Starts tool timer too
147
-
148
- json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
 
149
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
150
 
151
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
@@ -157,7 +193,8 @@ class GooglePalm2Handler:
157
  if self.adjust_temp != self.starting_temp:
158
  self._reset_config()
159
 
160
- json_report.set_text(text_main=f'LLM call successful')
 
161
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
162
 
163
  except Exception as e:
@@ -167,11 +204,19 @@ class GooglePalm2Handler:
167
  time.sleep(self.RETRY_DELAY)
168
 
169
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
170
- self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
171
 
172
  self.monitor.stop_inference_timer() # Starts tool timer too
173
  usage_report = self.monitor.stop_monitoring_report_usage()
174
  self._reset_config()
175
 
176
- json_report.set_text(text_main=f'LLM call failed')
177
- return None, nt_in, nt_out, None, None, usage_report
 
 
1
+ import os, time, json, typing
2
  # import vertexai
3
  from vertexai.language_models import TextGenerationModel
4
  from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold
 
10
  from langchain_core.output_parsers import JsonOutputParser
11
  # from langchain_google_genai import ChatGoogleGenerativeAI
12
  from langchain_google_vertexai import VertexAI
13
+ from langchain_core.messages import BaseMessage, HumanMessage
14
+ from langchain_core.prompt_values import PromptValue as BasePromptValue
15
 
16
  from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
17
  from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
 
33
  VENDOR = 'google'
34
  STARTING_TEMP = 0.5
35
 
36
+ def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
37
  self.cfg = cfg
38
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
39
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
 
43
  self.model_name = model_name
44
  self.JSON_dict_structure = JSON_dict_structure
45
 
46
+ self.config_vals_for_permutation = config_vals_for_permutation
47
+
48
+
49
 
50
  self.monitor = SystemLoadMonitor(logger)
51
 
 
61
 
62
  def _set_config(self):
63
  # vertexai.init(project=os.environ['PALM_PROJECT_ID'], location=os.environ['PALM_LOCATION'])
64
+ if self.config_vals_for_permutation:
65
+ self.starting_temp = float(self.config_vals_for_permutation.get('google').get('temperature'))
66
+ self.config = {
67
+ 'max_output_tokens': self.config_vals_for_permutation.get('google').get('max_output_tokens'),
68
+ 'temperature': self.starting_temp,
69
+ 'top_k': self.config_vals_for_permutation.get('google').get('top_k'),
70
+ 'top_p': self.config_vals_for_permutation.get('google').get('top_p'),
71
+ }
72
+ else:
73
+ self.starting_temp = float(self.STARTING_TEMP)
74
+ self.config = {
75
  "max_output_tokens": 1024,
76
  "temperature": self.starting_temp,
77
+ "top_k": 1,
78
  "top_p": 1.0,
 
79
  }
80
+
81
+ self.temp_increment = float(0.2)
82
+ self.adjust_temp = self.starting_temp
83
+
84
  self.safety_settings = {
85
  HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
86
  HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
 
91
 
92
  def _adjust_config(self):
93
  new_temp = self.adjust_temp + self.temp_increment
94
+ if self.json_report:
95
+ self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
96
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
97
  self.adjust_temp += self.temp_increment
98
  self.config['temperature'] = self.adjust_temp
99
 
100
  def _reset_config(self):
101
+ if self.json_report:
102
+ self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
103
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
104
  self.adjust_temp = self.starting_temp
105
  self.config['temperature'] = self.starting_temp
 
107
  def _build_model_chain_parser(self):
108
  # Instantiate the parser and the retry parser
109
  # self.llm_model = ChatGoogleGenerativeAI(model=self.model_name)
110
+ self.llm_model = VertexAI(model=self.model_name,
111
+ max_output_tokens=self.config.get('max_output_tokens'),
112
+ temperature=self.config.get('temperature'),
113
+ top_k=self.config.get('top_k'),
114
+ top_p=self.config.get('top_p'))
115
 
116
  self.retry_parser = RetryWithErrorOutputParser.from_llm(
117
  parser=self.parser,
 
127
  response = model.predict(prompt_text.text,
128
  max_output_tokens=self.config.get('max_output_tokens'),
129
  temperature=self.config.get('temperature'),
130
+ top_k=self.config.get('top_k'),
131
  top_p=self.config.get('top_p'))
132
  # model = GenerativeModel(self.model_name)
133
 
 
138
  def call_llm_api_GooglePalm2(self, prompt_template, json_report, paths):
139
  _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
140
  self.json_report = json_report
141
+ if json_report:
142
+ self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
143
  self.monitor.start_monitoring_usage()
144
  nt_in = 0
145
  nt_out = 0
 
148
  while ind < self.MAX_RETRIES:
149
  ind += 1
150
  try:
151
+ # model_kwargs = {"temperature": self.adjust_temp}
152
  # Invoke the chain to generate prompt text
153
+ response = self.chain.invoke({"query": prompt_template})#, "model_kwargs": model_kwargs})
154
 
155
  # Use retry_parser to parse the response with retry logic
156
+ try:
157
+ output = self.retry_parser.parse_with_prompt(response, prompt_value=PromptValue(prompt_template))
158
+ except:
159
+ try:
160
+ output = self.retry_parser.parse_with_prompt(response, prompt_value=prompt_template)
161
+ except:
162
+ try:
163
+ output = json.loads(response)
164
+ except Exception as e:
165
+ print(e)
166
+ output = None
167
+
168
 
169
  if output is None:
170
  self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
 
179
  self._adjust_config()
180
  else:
181
  self.monitor.stop_inference_timer() # Starts tool timer too
182
+
183
+ if self.json_report:
184
+ self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
185
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
186
 
187
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
193
  if self.adjust_temp != self.starting_temp:
194
  self._reset_config()
195
 
196
+ if self.json_report:
197
+ self.json_report.set_text(text_main=f'LLM call successful')
198
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
199
 
200
  except Exception as e:
 
204
  time.sleep(self.RETRY_DELAY)
205
 
206
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
207
+ if self.json_report:
208
+ self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
209
 
210
  self.monitor.stop_inference_timer() # Starts tool timer too
211
  usage_report = self.monitor.stop_monitoring_report_usage()
212
  self._reset_config()
213
 
214
+ if self.json_report:
215
+ self.json_report.set_text(text_main=f'LLM call failed')
216
+ return None, nt_in, nt_out, None, None, usage_report
217
+
218
+ class PromptValue(BasePromptValue):
219
+ prompt_str: str
220
+
221
+ def to_string(self) -> str:
222
+ return self.prompt_str
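One caveat with the PromptValue wrapper above: langchain_core's PromptValue base class is abstract and declares both to_string() and to_messages(), so a subclass that only overrides to_string() generally cannot be instantiated. A fuller sketch of the wrapper (an assumption-level fix, not part of the commit):

from langchain_core.messages import HumanMessage
from langchain_core.prompt_values import PromptValue as BasePromptValue

class WrappedPromptValue(BasePromptValue):
    prompt_str: str

    def to_string(self) -> str:
        return self.prompt_str

    def to_messages(self):
        # Chat models receive the raw prompt as a single human message
        return [HumanMessage(content=self.prompt_str)]

Recent langchain_core releases also ship a ready-made StringPromptValue that behaves this way.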
vouchervision/LLM_MistralAI.py CHANGED
@@ -11,12 +11,12 @@ from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys
11
  class MistralHandler:
12
  RETRY_DELAY = 2 # Wait 2 seconds before retrying
13
  MAX_RETRIES = 5 # Maximum number of retries
14
- STARTING_TEMP = 0.1
15
  TOKENIZER_NAME = None
16
  VENDOR = 'mistral'
17
  RANDOM_SEED = 2023
18
 
19
- def __init__(self, cfg, logger, model_name, JSON_dict_structure):
20
  self.cfg = cfg
21
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
22
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
@@ -27,10 +27,9 @@ class MistralHandler:
27
  self.has_GPU = torch.cuda.is_available()
28
  self.model_name = model_name
29
  self.JSON_dict_structure = JSON_dict_structure
30
- self.starting_temp = float(self.STARTING_TEMP)
31
- self.temp_increment = float(0.2)
32
- self.adjust_temp = self.starting_temp
33
 
 
34
  # Set up a parser
35
  self.parser = JsonOutputParser()
36
 
@@ -44,25 +43,45 @@ class MistralHandler:
44
  self._set_config()
45
 
46
  def _set_config(self):
47
- self.config = {'max_tokens': 1024,
 
 
48
  'temperature': self.starting_temp,
49
  'random_seed': self.RANDOM_SEED,
50
  'safe_mode': False,
51
- 'top_p': 1,
52
- }
 
 
53
  self._build_model_chain_parser()
54
 
55
 
56
  def _adjust_config(self):
57
  new_temp = self.adjust_temp + self.temp_increment
58
  self.config['random_seed'] = random.randint(1, 1000)
59
- self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp} and random_seed to {self.config.get("random_seed")}')
 
60
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp} and random_seed to {self.config.get("random_seed")}')
61
  self.adjust_temp += self.temp_increment
62
  self.config['temperature'] = self.adjust_temp
63
 
64
  def _reset_config(self):
65
- self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp} and random_seed to {self.RANDOM_SEED}')
 
66
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp} and random_seed to {self.RANDOM_SEED}')
67
  self.adjust_temp = self.starting_temp
68
  self.config['temperature'] = self.starting_temp
@@ -74,7 +93,9 @@ class MistralHandler:
74
  model=self.model_name,
75
  max_tokens=self.config.get('max_tokens'),
76
  safe_mode=self.config.get('safe_mode'),
77
- top_p=self.config.get('top_p'))
 
 
78
 
79
  # Set up the retry parser with the runnable
80
  self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.llm_model, max_retries=self.MAX_RETRIES)
@@ -85,7 +106,8 @@ class MistralHandler:
85
  _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
86
 
87
  self.json_report = json_report
88
- self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
 
89
  self.monitor.start_monitoring_usage()
90
  nt_in = 0
91
  nt_out = 0
@@ -94,10 +116,10 @@ class MistralHandler:
94
  while ind < self.MAX_RETRIES:
95
  ind += 1
96
  try:
97
- model_kwargs = {"temperature": self.adjust_temp, "random_seed": self.config.get("random_seed")}
98
 
99
  # Invoke the chain to generate prompt text
100
- response = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})
101
 
102
  # Use retry_parser to parse the response with retry logic
103
  output = self.retry_parser.parse_with_prompt(response.content, prompt_value=prompt_template)
@@ -115,8 +137,9 @@ class MistralHandler:
115
  self._adjust_config()
116
  else:
117
  self.monitor.stop_inference_timer() # Starts tool timer too
118
-
119
- json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
 
120
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
121
 
122
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
@@ -128,7 +151,8 @@ class MistralHandler:
128
  if self.adjust_temp != self.starting_temp:
129
  self._reset_config()
130
 
131
- json_report.set_text(text_main=f'LLM call successful')
 
132
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
133
 
134
  except Exception as e:
@@ -138,11 +162,13 @@ class MistralHandler:
138
  time.sleep(self.RETRY_DELAY)
139
 
140
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
141
- self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
142
 
143
  self.monitor.stop_inference_timer() # Starts tool timer too
144
  usage_report = self.monitor.stop_monitoring_report_usage()
145
  self._reset_config()
146
- json_report.set_text(text_main=f'LLM call failed')
 
147
 
148
  return None, nt_in, nt_out, None, None, usage_report
 
11
  class MistralHandler:
12
  RETRY_DELAY = 2 # Wait 2 seconds before retrying
13
  MAX_RETRIES = 5 # Maximum number of retries
14
+ STARTING_TEMP = 0.5 #0.01
15
  TOKENIZER_NAME = None
16
  VENDOR = 'mistral'
17
  RANDOM_SEED = 2023
18
 
19
+ def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
20
  self.cfg = cfg
21
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
22
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
 
27
  self.has_GPU = torch.cuda.is_available()
28
  self.model_name = model_name
29
  self.JSON_dict_structure = JSON_dict_structure
 
 
30
 
31
+ self.config_vals_for_permutation = config_vals_for_permutation
32
+
33
  # Set up a parser
34
  self.parser = JsonOutputParser()
35
 
 
43
  self._set_config()
44
 
45
  def _set_config(self):
46
+ if self.config_vals_for_permutation:
47
+ self.starting_temp = float(self.config_vals_for_permutation.get('mistral').get('temperature'))
48
+ self.config = {
49
+ 'max_tokens': self.config_vals_for_permutation.get('mistral').get('max_tokens'),
50
+ 'temperature': self.starting_temp,
51
+ 'top_p': self.config_vals_for_permutation.get('mistral').get('top_p'),
52
+ 'top_k': self.config_vals_for_permutation.get('mistral').get('top_k'),
53
+ 'safe_mode': self.config_vals_for_permutation.get('mistral').get('safe_mode'),
54
+ 'random_seed': self.config_vals_for_permutation.get('mistral').get('random_seed'),
55
+ }
56
+ else:
57
+ self.starting_temp = float(self.STARTING_TEMP)
58
+ self.config = {
59
+ 'max_tokens': 1024,
60
  'temperature': self.starting_temp,
61
  'random_seed': self.RANDOM_SEED,
62
  'safe_mode': False,
63
+ 'top_p': 0.5,
64
+ 'top_k': 0.5,
65
+ }
66
+
67
+ self.temp_increment = float(0.2)
68
+ self.adjust_temp = self.starting_temp
69
+
70
  self._build_model_chain_parser()
71
 
72
 
73
  def _adjust_config(self):
74
  new_temp = self.adjust_temp + self.temp_increment
75
  self.config['random_seed'] = random.randint(1, 1000)
76
+ if self.json_report:
77
+ self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp} and random_seed to {self.config.get("random_seed")}')
78
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp} and random_seed to {self.config.get("random_seed")}')
79
  self.adjust_temp += self.temp_increment
80
  self.config['temperature'] = self.adjust_temp
81
 
82
  def _reset_config(self):
83
+ if self.json_report:
84
+ self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp} and random_seed to {self.RANDOM_SEED}')
85
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp} and random_seed to {self.RANDOM_SEED}')
86
  self.adjust_temp = self.starting_temp
87
  self.config['temperature'] = self.starting_temp
 
93
  model=self.model_name,
94
  max_tokens=self.config.get('max_tokens'),
95
  safe_mode=self.config.get('safe_mode'),
96
+ top_p=self.config.get('top_p'),
97
+ top_k=self.config.get('top_k'),
98
+ )
99
 
100
  # Set up the retry parser with the runnable
101
  self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.llm_model, max_retries=self.MAX_RETRIES)
 
106
  _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
107
 
108
  self.json_report = json_report
109
+ if self.json_report:
110
+ self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
111
  self.monitor.start_monitoring_usage()
112
  nt_in = 0
113
  nt_out = 0
 
116
  while ind < self.MAX_RETRIES:
117
  ind += 1
118
  try:
119
+ # model_kwargs = {"temperature": self.adjust_temp, "random_seed": self.config.get("random_seed")}
120
 
121
  # Invoke the chain to generate prompt text
122
+ response = self.chain.invoke({"query": prompt_template})#, "model_kwargs": model_kwargs})
123
 
124
  # Use retry_parser to parse the response with retry logic
125
  output = self.retry_parser.parse_with_prompt(response.content, prompt_value=prompt_template)
 
137
  self._adjust_config()
138
  else:
139
  self.monitor.stop_inference_timer() # Starts tool timer too
140
+
141
+ if self.json_report:
142
+ self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
143
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
144
 
145
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
151
  if self.adjust_temp != self.starting_temp:
152
  self._reset_config()
153
 
154
+ if self.json_report:
155
+ self.json_report.set_text(text_main=f'LLM call successful')
156
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
157
 
158
  except Exception as e:
 
162
  time.sleep(self.RETRY_DELAY)
163
 
164
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
165
+ if self.json_report:
166
+ self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
167
 
168
  self.monitor.stop_inference_timer() # Starts tool timer too
169
  usage_report = self.monitor.stop_monitoring_report_usage()
170
  self._reset_config()
171
+ if self.json_report:
172
+ self.json_report.set_text(text_main=f'LLM call failed')
173
 
174
  return None, nt_in, nt_out, None, None, usage_report
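All of these handlers share the same recovery loop: on a failed parse, raise the temperature by a fixed increment (and, for Mistral, reshuffle random_seed), retry up to MAX_RETRIES, then reset. A condensed sketch of that shape with generic names (not code from the repository):

import random, time

def call_with_escalation(generate, parse, max_retries=5, retry_delay=2,
                         starting_temp=0.5, temp_increment=0.2):
    temp, seed = starting_temp, 2023
    for attempt in range(1, max_retries + 1):
        try:
            output = parse(generate(temperature=temp, random_seed=seed))
            if output is not None:
                return output
        except Exception:
            pass  # fall through to the escalation step
        # Failed attempt: nudge temperature, reshuffle the seed, wait, retry
        temp += temp_increment
        seed = random.randint(1, 1000)
        time.sleep(retry_delay)
    return None
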
vouchervision/LLM_OpenAI.py CHANGED
@@ -11,11 +11,11 @@ from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys
11
  class OpenAIHandler:
12
  RETRY_DELAY = 10 # Wait 10 seconds before retrying
13
  MAX_RETRIES = 3 # Maximum number of retries
14
- STARTING_TEMP = 0.5
15
  TOKENIZER_NAME = 'gpt-4'
16
  VENDOR = 'openai'
17
 
18
- def __init__(self, cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object):
19
  self.cfg = cfg
20
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
21
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
@@ -26,14 +26,13 @@ class OpenAIHandler:
26
  self.JSON_dict_structure = JSON_dict_structure
27
  self.is_azure = is_azure
28
  self.llm_object = llm_object
29
- self.name_parts = self.model_name.split('-')
30
 
31
  self.monitor = SystemLoadMonitor(logger)
32
  self.has_GPU = torch.cuda.is_available()
33
 
34
- self.starting_temp = float(self.STARTING_TEMP)
35
- self.temp_increment = float(0.2)
36
- self.adjust_temp = self.starting_temp
37
 
38
  # Set up a parser
39
  self.parser = JsonOutputParser()
@@ -45,12 +44,44 @@ class OpenAIHandler:
45
  )
46
  self._set_config()
47
 
 
48
  def _set_config(self):
49
- self.config = {'max_new_tokens': 1024,
50
- 'temperature': self.starting_temp,
51
- 'random_seed': 2023,
52
- 'top_p': 1,
53
- }
 
 
54
  # Adjusting the LLM settings based on whether Azure is used
55
  if self.is_azure:
56
  self.llm_object.deployment_name = self.model_name
@@ -68,43 +99,84 @@ class OpenAIHandler:
68
 
69
  def _adjust_config(self):
70
  new_temp = self.adjust_temp + self.temp_increment
71
- self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
 
72
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
73
  self.adjust_temp += self.temp_increment
74
- self.config['temperature'] = self.adjust_temp
75
 
76
  def _reset_config(self):
77
- self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
 
78
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
79
  self.adjust_temp = self.starting_temp
80
- self.config['temperature'] = self.starting_temp
81
 
82
  def _build_model_chain_parser(self):
83
  if not self.is_azure and ('instruct' in self.name_parts):
 
 
84
  # Set up the retry parser with 3 retries
85
  self.retry_parser = RetryWithErrorOutputParser.from_llm(
86
- # parser=self.parser, llm=self.llm_object if self.is_azure else OpenAI(temperature=self.config.get('temperature'), model=self.model_name), max_retries=self.MAX_RETRIES
87
- parser=self.parser, llm=self.llm_object if self.is_azure else OpenAI(model=self.model_name), max_retries=self.MAX_RETRIES
 
88
  )
89
  else:
90
- # Set up the retry parser with 3 retries
 
 
91
  self.retry_parser = RetryWithErrorOutputParser.from_llm(
92
- # parser=self.parser, llm=self.llm_object if self.is_azure else ChatOpenAI(temperature=self.config.get('temperature'), model=self.model_name), max_retries=self.MAX_RETRIES
93
- parser=self.parser, llm=self.llm_object if self.is_azure else ChatOpenAI(model=self.model_name), max_retries=self.MAX_RETRIES
 
94
  )
 
95
  # Prepare the chain
96
- if not self.is_azure and ('instruct' in self.name_parts):
97
- # self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else OpenAI(temperature=self.config.get('temperature'), model=self.model_name))
98
- self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else OpenAI(model=self.model_name))
99
  else:
100
- # self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else ChatOpenAI(temperature=self.config.get('temperature'), model=self.model_name))
101
- self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else ChatOpenAI(model=self.model_name))
 
 
102
 
103
 
104
  def call_llm_api_OpenAI(self, prompt_template, json_report, paths):
105
  _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
106
  self.json_report = json_report
107
- self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
 
108
  self.monitor.start_monitoring_usage()
109
  nt_in = 0
110
  nt_out = 0
@@ -113,14 +185,20 @@ class OpenAIHandler:
113
  while ind < self.MAX_RETRIES:
114
  ind += 1
115
  try:
116
- model_kwargs = {"temperature": self.adjust_temp}
117
  # Invoke the chain to generate prompt text
118
- response = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})
119
 
120
  response_text = response.content if not isinstance(response, str) else response
121
 
122
  # Use retry_parser to parse the response with retry logic
123
- output = self.retry_parser.parse_with_prompt(response_text, prompt_value=prompt_template)
 
 
124
 
125
  if output is None:
126
  self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
@@ -136,14 +214,11 @@ class OpenAIHandler:
136
  else:
137
  self.monitor.stop_inference_timer() # Starts tool timer too
138
 
139
- json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
 
140
 
141
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
142
 
143
- # output1, WFO_record = validate_taxonomy_WFO(self.tool_WFO, output, replace_if_success_wfo=False)
144
- # output2, GEO_record = validate_coordinates_here(self.tool_GEO, output, replace_if_success_geo=False)
145
- # validate_wikipedia(self.tool_wikipedia, json_file_path_wiki, output)
146
-
147
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
148
 
149
  self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
@@ -153,7 +228,8 @@ class OpenAIHandler:
153
  if self.adjust_temp != self.starting_temp:
154
  self._reset_config()
155
 
156
- json_report.set_text(text_main=f'LLM call successful')
 
157
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
158
 
159
  except Exception as e:
@@ -163,11 +239,15 @@ class OpenAIHandler:
163
  time.sleep(self.RETRY_DELAY)
164
 
165
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
166
- self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
167
 
168
  self.monitor.stop_inference_timer() # Starts tool timer too
169
  usage_report = self.monitor.stop_monitoring_report_usage()
170
  self._reset_config()
171
 
172
- json_report.set_text(text_main=f'LLM call failed')
 
173
  return None, nt_in, nt_out, None, None, usage_report
 
 
11
  class OpenAIHandler:
12
  RETRY_DELAY = 10 # Wait 10 seconds before retrying
13
  MAX_RETRIES = 3 # Maximum number of retries
14
+ STARTING_TEMP = 0.5 # 0.5, config_vals_for_permutation
15
  TOKENIZER_NAME = 'gpt-4'
16
  VENDOR = 'openai'
17
 
18
+ def __init__(self, cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object, config_vals_for_permutation):
19
  self.cfg = cfg
20
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
21
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
 
26
  self.JSON_dict_structure = JSON_dict_structure
27
  self.is_azure = is_azure
28
  self.llm_object = llm_object
29
+ self.name_parts = self.model_name.lower().split('-')
30
 
31
  self.monitor = SystemLoadMonitor(logger)
32
  self.has_GPU = torch.cuda.is_available()
33
 
34
+ ### Config
35
+ self.config_vals_for_permutation = config_vals_for_permutation
 
36
 
37
  # Set up a parser
38
  self.parser = JsonOutputParser()
 
44
  )
45
  self._set_config()
46
 
47
+ def _can_use_json_mode(self):
48
+ if self.is_azure:
49
+ return False
50
+ # gpt-4-turbo-preview (gpt-4-0125-preview)
51
+ if ('0125' in self.name_parts) and ('4' in self.name_parts):
52
+ return True
53
+ # gpt-3.5-turbo-0125
54
+ elif ('0125' in self.name_parts) and ('3.5' in self.name_parts) and ('turbo' in self.name_parts):
55
+ return True
56
+ else:
57
+ return False
58
+
59
+
60
  def _set_config(self):
61
+ if self.config_vals_for_permutation:
62
+ self.starting_temp = float(self.config_vals_for_permutation.get('openai').get('temperature'))
63
+ self.model_kwargs = {
64
+ 'max_tokens': self.config_vals_for_permutation.get('openai').get('max_tokens'),
65
+ 'temperature': self.starting_temp,
66
+ # 'seed': self.config_vals_for_permutation.get('openai').get('seed'),
67
+ 'top_p': self.config_vals_for_permutation.get('openai').get('top_p'),
68
+ }
69
+ else:
70
+ self.starting_temp = float(self.STARTING_TEMP)
71
+ self.model_kwargs = {
72
+ 'max_tokens': 1024,
73
+ 'temperature': self.starting_temp,
74
+ # 'seed': 2023,
75
+ 'top_p': 1, # Set to 1, change temp only
76
+ }
77
+
78
+ ### Not all openai models support json mode
79
+ if self._can_use_json_mode():
80
+ self.model_kwargs.update({"response_format": {"type": "json_object"}})
81
+
82
+ self.temp_increment = float(0.2)
83
+ self.adjust_temp = self.starting_temp
84
+
85
  # Adjusting the LLM settings based on whether Azure is used
86
  if self.is_azure:
87
  self.llm_object.deployment_name = self.model_name
 
99
 
100
  def _adjust_config(self):
101
  new_temp = self.adjust_temp + self.temp_increment
102
+ if self.json_report:
103
+ self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
104
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
105
  self.adjust_temp += self.temp_increment
106
+ self.model_kwargs['temperature'] = self.adjust_temp
107
 
108
  def _reset_config(self):
109
+ if self.json_report:
110
+ self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
111
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
112
  self.adjust_temp = self.starting_temp
113
+ self.model_kwargs['temperature'] = self.starting_temp
114
 
115
  def _build_model_chain_parser(self):
116
  if not self.is_azure and ('instruct' in self.name_parts):
117
+ # Determine the LLM to use based on whether this is an Azure instance
118
+ if self.is_azure:
119
+ llm_to_use = self.llm_object
120
+ else:
121
+ llm_to_use = OpenAI(
122
+ model=self.model_name,
123
+ temperature=self.model_kwargs.get('temperature'),
124
+ top_p=self.model_kwargs.get('top_p'),
125
+ max_tokens=self.model_kwargs.get('max_tokens')
126
+ )
127
  # Set up the retry parser with 3 retries
128
  self.retry_parser = RetryWithErrorOutputParser.from_llm(
129
+ parser=self.parser,
130
+ llm=llm_to_use,
131
+ max_retries=self.MAX_RETRIES
132
  )
133
  else:
134
+ # Determine the LLM to use for non-Azure instances
135
+ if self.is_azure:
136
+ llm_to_use = self.llm_object
137
+ self.llm_object.temperature = self.model_kwargs.get('temperature')
138
+ self.llm_object.max_tokens = self.model_kwargs.get('max_tokens')
139
+ self.llm_object.model_kwargs = self.model_kwargs
140
+ else:
141
+ llm_to_use = ChatOpenAI(
142
+ model=self.model_name,
143
+ temperature=self.model_kwargs.get('temperature'),
144
+ top_p=self.model_kwargs.get('top_p'),
145
+ max_tokens=self.model_kwargs.get('max_tokens'),
146
+ )
147
+ # Set up the retry parser with 3 retries for other cases
148
  self.retry_parser = RetryWithErrorOutputParser.from_llm(
149
+ parser=self.parser,
150
+ llm=llm_to_use,
151
+ max_retries=self.MAX_RETRIES
152
  )
153
+
154
  # Prepare the chain
155
+ if self.is_azure:
156
+ chain_llm_to_use = self.format_input_for_azure
 
157
  else:
158
+ if 'instruct' in self.name_parts:
159
+ chain_llm_to_use = OpenAI(
160
+ model=self.model_name,
161
+ temperature=self.model_kwargs.get('temperature'),
162
+ top_p=self.model_kwargs.get('top_p'),
163
+ max_tokens=self.model_kwargs.get('max_tokens')
164
+ )
165
+ else:
166
+ chain_llm_to_use = ChatOpenAI(
167
+ model=self.model_name,
168
+ temperature=self.model_kwargs.get('temperature'),
169
+ top_p=self.model_kwargs.get('top_p'),
170
+ max_tokens=self.model_kwargs.get('max_tokens')
171
+ )
172
+ self.chain = self.prompt | chain_llm_to_use
173
 
174
 
175
  def call_llm_api_OpenAI(self, prompt_template, json_report, paths):
176
  _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
177
  self.json_report = json_report
178
+ if self.json_report:
179
+ self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
180
  self.monitor.start_monitoring_usage()
181
  nt_in = 0
182
  nt_out = 0
 
185
  while ind < self.MAX_RETRIES:
186
  ind += 1
187
  try:
188
+ self.logger.info(str(self.model_kwargs))
189
  # Invoke the chain to generate prompt text
190
+ response = self.chain.invoke(input={"query": prompt_template})#, **self.model_kwargs)# "model_kwargs": self.model_kwargs})
191
 
192
  response_text = response.content if not isinstance(response, str) else response
193
 
194
  # Use retry_parser to parse the response with retry logic
195
+ try:
196
+ output = self.retry_parser.parse_with_prompt(response_text, prompt_value=prompt_template)
197
+ except:
198
+ try:
199
+ output = json.loads(response_text)
200
+ except:
201
+ output = None
202
 
203
  if output is None:
204
  self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
 
214
  else:
215
  self.monitor.stop_inference_timer() # Starts tool timer too
216
 
217
+ if self.json_report:
218
+ self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
219
 
220
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
221
 
 
222
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
223
 
224
  self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
228
  if self.adjust_temp != self.starting_temp:
229
  self._reset_config()
230
 
231
+ if self.json_report:
232
+ self.json_report.set_text(text_main=f'LLM call successful')
233
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
234
 
235
  except Exception as e:
 
239
  time.sleep(self.RETRY_DELAY)
240
 
241
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
242
+ if self.json_report:
243
+ self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
244
 
245
  self.monitor.stop_inference_timer() # Starts tool timer too
246
  usage_report = self.monitor.stop_monitoring_report_usage()
247
  self._reset_config()
248
 
249
+ if self.json_report:
250
+ self.json_report.set_text(text_main=f'LLM call failed')
251
  return None, nt_in, nt_out, None, None, usage_report
252
+
253
+
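The new _can_use_json_mode() gate matters because OpenAI's response_format={"type": "json_object"} flag is only accepted by certain model revisions (and not by the Azure path in this code). A usage sketch, assuming langchain-openai's ChatOpenAI and an illustrative model name (the import path may differ by LangChain version):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="gpt-3.5-turbo-0125",  # illustrative; must be a json-mode-capable model
    temperature=0.5,
    model_kwargs={"response_format": {"type": "json_object"}},
)
# Caveat: OpenAI's json mode requires the word "json" to appear in the prompt,
# otherwise the API rejects the request.
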
vouchervision/LLM_local_MistralAI.py CHANGED
@@ -22,7 +22,7 @@ class LocalMistralHandler:
22
  VENDOR = 'mistral'
23
  MAX_GPU_MONITORING_INTERVAL = 2 # seconds
24
 
25
- def __init__(self, cfg, logger, model_name, JSON_dict_structure):
26
  self.cfg = cfg
27
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
28
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
@@ -122,13 +122,15 @@ class LocalMistralHandler:
122
 
123
  def _adjust_config(self):
124
  new_temp = self.adjust_temp + self.temp_increment
125
- self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
 
126
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
127
  self.adjust_temp += self.temp_increment
128
 
129
 
130
  def _reset_config(self):
131
- self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
 
132
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
133
  self.adjust_temp = self.starting_temp
134
 
@@ -153,7 +155,8 @@ class LocalMistralHandler:
153
  def call_llm_local_MistralAI(self, prompt_template, json_report, paths):
154
  _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
155
  self.json_report = json_report
156
- self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
 
157
  self.monitor.start_monitoring_usage()
158
 
159
  nt_in = 0
@@ -188,8 +191,9 @@ class LocalMistralHandler:
188
  self._adjust_config()
189
  else:
190
  self.monitor.stop_inference_timer() # Starts tool timer too
191
-
192
- json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
 
193
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
194
 
195
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
@@ -201,7 +205,8 @@ class LocalMistralHandler:
201
  if self.adjust_temp != self.starting_temp:
202
  self._reset_config()
203
 
204
- json_report.set_text(text_main=f'LLM call successful')
 
205
  del results
206
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
207
 
@@ -210,11 +215,13 @@ class LocalMistralHandler:
210
  self._adjust_config()
211
 
212
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
213
- self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
214
 
215
  self.monitor.stop_inference_timer() # Starts tool timer too
216
  usage_report = self.monitor.stop_monitoring_report_usage()
217
- json_report.set_text(text_main=f'LLM call failed')
 
218
 
219
  self._reset_config()
220
  return None, nt_in, nt_out, None, None, usage_report
 
22
  VENDOR = 'mistral'
23
  MAX_GPU_MONITORING_INTERVAL = 2 # seconds
24
 
25
+ def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
26
  self.cfg = cfg
27
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
28
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
 
122
 
123
  def _adjust_config(self):
124
  new_temp = self.adjust_temp + self.temp_increment
125
+ if self.json_report:
126
+ self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
127
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
128
  self.adjust_temp += self.temp_increment
129
 
130
 
131
  def _reset_config(self):
132
+ if self.json_report:
133
+ self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
134
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
135
  self.adjust_temp = self.starting_temp
136
 
 
155
  def call_llm_local_MistralAI(self, prompt_template, json_report, paths):
156
  _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
157
  self.json_report = json_report
158
+ if self.json_report:
159
+ self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
160
  self.monitor.start_monitoring_usage()
161
 
162
  nt_in = 0
 
191
  self._adjust_config()
192
  else:
193
  self.monitor.stop_inference_timer() # Starts tool timer too
194
+
195
+ if self.json_report:
196
+ self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
197
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
198
 
199
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
205
  if self.adjust_temp != self.starting_temp:
206
  self._reset_config()
207
 
208
+ if self.json_report:
209
+ self.json_report.set_text(text_main=f'LLM call successful')
210
  del results
211
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
212
 
 
215
  self._adjust_config()
216
 
217
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
218
+ if self.json_report:
219
+ self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
220
 
221
  self.monitor.stop_inference_timer() # Starts tool timer too
222
  usage_report = self.monitor.stop_monitoring_report_usage()
223
+ if self.json_report:
224
+ self.json_report.set_text(text_main=f'LLM call failed')
225
 
226
  self._reset_config()
227
  return None, nt_in, nt_out, None, None, usage_report
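The `if self.json_report:` guard that this commit threads through every handler makes the Streamlit progress reporter optional, so headless runs can pass json_report=None. The repetition could be collapsed into one helper; a sketch, not part of the commit:

def report(json_report, message):
    """Forward a status line to the UI reporter when one is attached; no-op otherwise."""
    if json_report:
        json_report.set_text(text_main=message)

# Usage inside a handler:
#   report(self.json_report, f'Sending request to {self.model_name}')
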
vouchervision/LLM_local_cpu_MistralAI.py CHANGED
@@ -30,7 +30,7 @@ class LocalCPUMistralHandler:
30
  SEED = 2023
31
 
32
 
33
- def __init__(self, cfg, logger, model_name, JSON_dict_structure):
34
  self.cfg = cfg
35
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
36
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
@@ -106,13 +106,15 @@ class LocalCPUMistralHandler:
106
 
107
  def _adjust_config(self):
108
  new_temp = self.adjust_temp + self.temp_increment
109
- self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
 
110
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
111
  self.adjust_temp += self.temp_increment
112
  self.config['temperature'] = self.adjust_temp
113
 
114
  def _reset_config(self):
115
- self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
 
116
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
117
  self.adjust_temp = self.starting_temp
118
  self.config['temperature'] = self.starting_temp
@@ -140,7 +142,8 @@ class LocalCPUMistralHandler:
140
  def call_llm_local_cpu_MistralAI(self, prompt_template, json_report, paths):
141
  _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
142
  self.json_report = json_report
143
- self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
 
144
  self.monitor.start_monitoring_usage()
145
 
146
  nt_in = 0
@@ -180,7 +183,8 @@ class LocalCPUMistralHandler:
180
  else:
181
  self.monitor.stop_inference_timer() # Starts tool timer too
182
 
183
- json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
 
184
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
185
 
186
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
@@ -192,7 +196,8 @@ class LocalCPUMistralHandler:
192
  if self.adjust_temp != self.starting_temp:
193
  self._reset_config()
194
 
195
- json_report.set_text(text_main=f'LLM call successful')
 
196
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
197
 
198
  except Exception as e:
@@ -200,13 +205,15 @@ class LocalCPUMistralHandler:
200
  self._adjust_config()
201
 
202
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
203
- self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
204
 
205
  self.monitor.stop_inference_timer() # Starts tool timer too
206
  usage_report = self.monitor.stop_monitoring_report_usage()
207
  self._reset_config()
208
 
209
- json_report.set_text(text_main=f'LLM call failed')
 
210
  return None, nt_in, nt_out, None, None, usage_report
211
 
212
 
 
30
  SEED = 2023
31
 
32
 
33
+ def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
34
  self.cfg = cfg
35
  self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
36
  self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
 
106
 
107
  def _adjust_config(self):
108
  new_temp = self.adjust_temp + self.temp_increment
109
+ if self.json_report:
110
+ self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
111
  self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
112
  self.adjust_temp += self.temp_increment
113
  self.config['temperature'] = self.adjust_temp
114
 
115
  def _reset_config(self):
116
+ if self.json_report:
117
+ self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
118
  self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
119
  self.adjust_temp = self.starting_temp
120
  self.config['temperature'] = self.starting_temp
 
142
  def call_llm_local_cpu_MistralAI(self, prompt_template, json_report, paths):
143
  _____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
144
  self.json_report = json_report
145
+ if self.json_report:
146
+ self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
147
  self.monitor.start_monitoring_usage()
148
 
149
  nt_in = 0
 
183
  else:
184
  self.monitor.stop_inference_timer() # Starts tool timer too
185
 
186
+ if self.json_report:
187
+ self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
188
  output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
189
 
190
  save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
196
  if self.adjust_temp != self.starting_temp:
197
  self._reset_config()
198
 
199
+ if self.json_report:
200
+ self.json_report.set_text(text_main=f'LLM call successful')
201
  return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
202
 
203
  except Exception as e:
 
205
  self._adjust_config()
206
 
207
  self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
208
+ if self.json_report:
209
+ self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
210
 
211
  self.monitor.stop_inference_timer() # Starts tool timer too
212
  usage_report = self.monitor.stop_monitoring_report_usage()
213
  self._reset_config()
214
 
215
+ if self.json_report:
216
+ self.json_report.set_text(text_main=f'LLM call failed')
217
  return None, nt_in, nt_out, None, None, usage_report
218
 
219
 
vouchervision/LM2_logger.py CHANGED
@@ -1,4 +1,5 @@
1
  import logging, os, psutil, torch, platform, cpuinfo, yaml #py-cpuinfo
 
2
  from vouchervision.general_utils import get_datetime, print_main_warn, print_main_info
3
 
4
  class SanitizingFileHandler(logging.FileHandler):
@@ -17,7 +18,7 @@ def start_logging(Dirs, cfg):
17
  path_log = os.path.join(Dirs.path_log, '__'.join(['LM2-log', str(get_datetime()), run_name]) + '.log')
18
 
19
  # Disable default StreamHandler
20
- logging.getLogger().handlers = []
21
 
22
  # create logger
23
  logger = logging.getLogger('Hardware Components')
@@ -27,20 +28,25 @@ def start_logging(Dirs, cfg):
27
  sanitizing_fh = SanitizingFileHandler(path_log, encoding='utf-8')
28
  sanitizing_fh.setLevel(logging.DEBUG)
29
 
 
 
30
  # create console handler and set level to debug
31
- ch = logging.StreamHandler()
32
- ch.setLevel(logging.DEBUG)
33
 
34
  # create formatter
35
  formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
36
 
37
  # add formatter to handlers
38
  sanitizing_fh.setFormatter(formatter)
39
- ch.setFormatter(formatter)
 
40
 
41
  # add handlers to logger
42
  logger.addHandler(sanitizing_fh)
43
- logger.addHandler(ch)
 
44
 
45
  # Create a logger for the file handler
46
  file_logger = logging.getLogger('file_logger')
@@ -110,6 +116,17 @@ def find_cpu_info():
110
  except:
111
  return "CPU: UNKNOWN"
112
 
 
113
 
114
  def LM2_banner():
115
  logo = """
 
1
  import logging, os, psutil, torch, platform, cpuinfo, yaml #py-cpuinfo
2
+ from tqdm import tqdm
3
  from vouchervision.general_utils import get_datetime, print_main_warn, print_main_info
4
 
5
  class SanitizingFileHandler(logging.FileHandler):
 
18
  path_log = os.path.join(Dirs.path_log, '__'.join(['LM2-log', str(get_datetime()), run_name]) + '.log')
19
 
20
  # Disable default StreamHandler
21
+ logging.getLogger().handlers = []
22
 
23
  # create logger
24
  logger = logging.getLogger('Hardware Components')
 
28
  sanitizing_fh = SanitizingFileHandler(path_log, encoding='utf-8')
29
  sanitizing_fh.setLevel(logging.DEBUG)
30
 
31
+ tqdm_handler = TqdmLoggingHandler()
32
+ tqdm_handler.setLevel(logging.DEBUG)
33
+
34
  # create console handler and set level to debug
35
+ # ch = logging.StreamHandler()
36
+ # ch.setLevel(logging.DEBUG)
37
 
38
  # create formatter
39
  formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
40
 
41
  # add formatter to handlers
42
  sanitizing_fh.setFormatter(formatter)
43
+ tqdm_handler.setFormatter(formatter)
44
+ # ch.setFormatter(formatter)
45
 
46
  # add handlers to logger
47
  logger.addHandler(sanitizing_fh)
48
+ logger.addHandler(tqdm_handler)
49
+ # logger.addHandler(ch)
50
 
51
  # Create a logger for the file handler
52
  file_logger = logging.getLogger('file_logger')
 
116
  except:
117
  return "CPU: UNKNOWN"
118
 
119
+ class TqdmLoggingHandler(logging.Handler):
120
+ def __init__(self, level=logging.NOTSET):
121
+ super().__init__(level)
122
+
123
+ def emit(self, record):
124
+ try:
125
+ msg = self.format(record)
126
+ tqdm.write(msg) # Use tqdm's write function to ensure correct output
127
+ self.flush()
128
+ except Exception:
129
+ self.handleError(record)
130
 
131
  def LM2_banner():
132
  logo = """
vouchervision/OCR_google_cloud_vision.py CHANGED
@@ -123,8 +123,9 @@ class OCREngine:
123
 
124
  self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
125
  self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
126
-
127
- self.json_report.set_text(text_main=f'Loading LLaVA model: {self.model_path} Quantization: {self.model_quant}')
 
128
 
129
  if self.model_quant == '4bit':
130
  use_4bit = True
@@ -191,7 +192,8 @@ class OCREngine:
191
  # Process each detected text region
192
  for box in self.prediction_result["boxes"]:
193
  i+=1
194
- self.json_report.set_text(text_main=f'Locating text using CRAFT --- {i}/{total_b}')
 
195
 
196
  vertices = [{"x": int(vertex[0]), "y": int(vertex[1])} for vertex in box]
197
 
@@ -283,7 +285,8 @@ class OCREngine:
283
  i=0
284
  for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
285
  i+=1
286
- self.json_report.set_text(text_main=f'Working on trOCR :construction: {i}/{total_b}')
 
287
 
288
  vertices = bound["vertices"]
289
 
@@ -688,7 +691,8 @@ class OCREngine:
688
  # logger.info(f"CRAFT trOCR:\n{self.OCR}")
689
 
690
  if 'LLaVA' in self.OCR_option: # This option does not produce an OCR helper image
691
- self.json_report.set_text(text_main=f'Working on LLaVA {self.Llava.model_path} transcription :construction:')
 
692
 
693
  image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.multimodal_prompt)
694
  self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")
@@ -786,4 +790,20 @@ class OCREngine:
786
  from craft_text_detector import empty_cuda_cache
787
  empty_cuda_cache()
788
  except:
789
- pass
123
 
124
  self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
125
  self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
126
+
127
+ if self.json_report:
128
+ self.json_report.set_text(text_main=f'Loading LLaVA model: {self.model_path} Quantization: {self.model_quant}')
129
 
130
  if self.model_quant == '4bit':
131
  use_4bit = True
 
192
  # Process each detected text region
193
  for box in self.prediction_result["boxes"]:
194
  i+=1
195
+ if self.json_report:
196
+ self.json_report.set_text(text_main=f'Locating text using CRAFT --- {i}/{total_b}')
197
 
198
  vertices = [{"x": int(vertex[0]), "y": int(vertex[1])} for vertex in box]
199
 
 
285
  i=0
286
  for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
287
  i+=1
288
+ if self.json_report:
289
+ self.json_report.set_text(text_main=f'Working on trOCR :construction: {i}/{total_b}')
290
 
291
  vertices = bound["vertices"]
292
 
 
691
  # logger.info(f"CRAFT trOCR:\n{self.OCR}")
692
 
693
  if 'LLaVA' in self.OCR_option: # This option does not produce an OCR helper image
694
+ if self.json_report:
695
+ self.json_report.set_text(text_main=f'Working on LLaVA {self.Llava.model_path} transcription :construction:')
696
 
697
  image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.multimodal_prompt)
698
  self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")
 
790
  from craft_text_detector import empty_cuda_cache
791
  empty_cuda_cache()
792
  except:
793
+ pass
794
+
795
+ def check_for_inappropriate_content(file_stream):
796
+ client = vision.ImageAnnotatorClient()
797
+
798
+ content = file_stream.read()
799
+ image = vision.Image(content=content)
800
+ response = client.safe_search_detection(image=image)
801
+ safe = response.safe_search_annotation
802
+
803
+ # Check the levels of adult, violence, racy, etc. content.
804
+ if (safe.adult > vision.Likelihood.POSSIBLE or
805
+ safe.violence > vision.Likelihood.POSSIBLE or
806
+ safe.racy > vision.Likelihood.POSSIBLE):
807
+ return True # The image violates safe search guidelines.
808
+
809
+ return False # The image is considered safe.
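
The new SafeSearch gate is appended after the CUDA-cache cleanup; note it takes only file_stream, so it reads as a standalone helper rather than an instance method. A minimal self-contained sketch of the same google-cloud-vision call (the image path is hypothetical, and credentials must already be configured via GOOGLE_APPLICATION_CREDENTIALS):

from google.cloud import vision

def is_flagged(path):
    client = vision.ImageAnnotatorClient()
    with open(path, 'rb') as f:
        image = vision.Image(content=f.read())
    safe = client.safe_search_detection(image=image).safe_search_annotation
    # Likelihood is an ordered enum, so "> POSSIBLE" means LIKELY or VERY_LIKELY
    return (safe.adult > vision.Likelihood.POSSIBLE or
            safe.violence > vision.Likelihood.POSSIBLE or
            safe.racy > vision.Likelihood.POSSIBLE)

print(is_flagged('specimen.jpg'))  # hypothetical image path
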
vouchervision/VoucherVision_Config_Builder.py CHANGED
@@ -49,7 +49,7 @@ def build_VV_config(loaded_cfg=None):
49
 
50
  check_for_illegal_filenames = False
51
 
52
- LLM_version_user = 'Azure GPT 3.5 Instruct' #'Azure GPT 4 Turbo 1106-preview'
53
  prompt_version = 'SLTPvA_long.yaml' # from ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
54
  use_LeafMachine2_collage_images = True # Use LeafMachine2 collage images
55
  do_create_OCR_helper_image = True
 
49
 
50
  check_for_illegal_filenames = False
51
 
52
+ LLM_version_user = 'Azure GPT 3.5 Turbo' #'Azure GPT 4 Turbo 1106-preview'
53
  prompt_version = 'SLTPvA_long.yaml' # from ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
54
  use_LeafMachine2_collage_images = True # Use LeafMachine2 collage images
55
  do_create_OCR_helper_image = True
vouchervision/model_maps.py CHANGED
@@ -20,9 +20,11 @@ class ModelMaps:
20
  'AZURE_GPT_3_5_INSTRUCT': '#9400D3', # Dark Violet
21
  'AZURE_GPT_3_5': '#9932CC', # Dark Orchid
22
 
23
- 'MISTRAL_TINY': '#FFA07A', # Light Salmon
24
- 'MISTRAL_SMALL': '#FF8C00', # Dark Orange
 
25
  'MISTRAL_MEDIUM': '#FF4500', # Orange Red
 
26
 
27
  'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01': '#000000', # Black
28
  'LOCAL_MISTRAL_7B_INSTRUCT_V02': '#4a4a4a', # Gray
@@ -34,14 +36,14 @@ class ModelMaps:
34
  "GPT 4 32k",
35
  "GPT 4 Turbo 0125-preview",
36
  "GPT 4 Turbo 1106-preview",
37
- "GPT 3.5",
38
  "GPT 3.5 Instruct",
39
 
40
  "Azure GPT 4",
41
  "Azure GPT 4 32k",
42
  "Azure GPT 4 Turbo 0125-preview",
43
  "Azure GPT 4 Turbo 1106-preview",
44
- "Azure GPT 3.5",
45
  "Azure GPT 3.5 Instruct",]
46
 
47
  MODELS_GOOGLE = ["PaLM 2 text-bison@001",
@@ -49,15 +51,18 @@ class ModelMaps:
49
  "PaLM 2 text-unicorn@001",
50
  "Gemini Pro"]
51
 
52
- MODELS_MISTRAL = ["Mistral Tiny",
53
- "Mistral Small",
54
- "Mistral Medium",]
 
 
 
55
 
56
  MODELS_LOCAL = ["LOCAL Mixtral 8x7B Instruct v0.1",
57
  "LOCAL Mistral 7B Instruct v0.2",
58
  "LOCAL CPU Mistral 7B Instruct v0.2 GGUF",]
59
 
60
- MODELS_GUI_DEFAULT = "Azure GPT 3.5 Instruct" # "GPT 4 Turbo 1106-preview"
61
 
62
  version_mapping_cost = {
63
  'GPT 4 32k': 'GPT_4_32K',
@@ -65,23 +70,25 @@ class ModelMaps:
65
  'GPT 4 Turbo 0125-preview': 'GPT_4_TURBO_0125',
66
  'GPT 4 Turbo 1106-preview': 'GPT_4_TURBO_1106',
67
  'GPT 3.5 Instruct': 'GPT_3_5_INSTRUCT',
68
- 'GPT 3.5': 'GPT_3_5',
69
 
70
  'Azure GPT 4 32k': 'AZURE_GPT_4_32K',
71
  'Azure GPT 4': 'AZURE_GPT_4',
72
  'Azure GPT 4 Turbo 0125-preview': 'AZURE_GPT_4_TURBO_0125',
73
  'Azure GPT 4 Turbo 1106-preview': 'AZURE_GPT_4_TURBO_1106',
74
  'Azure GPT 3.5 Instruct': 'AZURE_GPT_3_5_INSTRUCT',
75
- 'Azure GPT 3.5': 'AZURE_GPT_3_5',
76
 
77
  'Gemini Pro': 'GEMINI_PRO',
78
  'PaLM 2 text-unicorn@001': 'PALM2_TU_1',
79
  'PaLM 2 text-bison@001': 'PALM2_TB_1',
80
  'PaLM 2 text-bison@002': 'PALM2_TB_2',
81
 
 
82
  'Mistral Medium': 'MISTRAL_MEDIUM',
83
  'Mistral Small': 'MISTRAL_SMALL',
84
- 'Mistral Tiny': 'MISTRAL_TINY',
 
85
 
86
  'LOCAL Mixtral 8x7B Instruct v0.1': 'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01',
87
  'LOCAL Mistral 7B Instruct v0.2': 'LOCAL_MISTRAL_7B_INSTRUCT_V02',
@@ -97,10 +104,10 @@ class ModelMaps:
97
  'GPT 4 Turbo 0125-preview': has_key_openai,
98
  'GPT 4': has_key_openai,
99
  'GPT 4 32k': has_key_openai,
100
- 'GPT 3.5': has_key_openai,
101
  'GPT 3.5 Instruct': has_key_openai,
102
 
103
- 'Azure GPT 3.5': has_key_azure_openai,
104
  'Azure GPT 3.5 Instruct': has_key_azure_openai,
105
  'Azure GPT 4': has_key_azure_openai,
106
  'Azure GPT 4 Turbo 1106-preview': has_key_azure_openai,
@@ -112,9 +119,11 @@ class ModelMaps:
112
  'PaLM 2 text-unicorn@001': has_key_google_application_credentials,
113
  'Gemini Pro': has_key_google_application_credentials,
114
 
115
- 'Mistral Tiny': has_key_mistral,
116
  'Mistral Small': has_key_mistral,
117
  'Mistral Medium': has_key_mistral,
 
 
 
118
 
119
  'LOCAL Mixtral 8x7B Instruct v0.1': True,
120
  'LOCAL Mistral 7B Instruct v0.2': True,
@@ -127,15 +136,17 @@ class ModelMaps:
127
  def get_version_mapping_is_azure(cls, key):
128
  version_mapping_is_azure = {
129
  "GPT 4 Turbo 1106-preview": False,
 
130
  'GPT 4': False,
131
  'GPT 4 32k': False,
132
- 'GPT 3.5': False,
133
  'GPT 3.5 Instruct': False,
134
 
135
- 'Azure GPT 3.5': True,
136
  'Azure GPT 3.5 Instruct': True,
137
  'Azure GPT 4': True,
138
  'Azure GPT 4 Turbo 1106-preview': True,
 
139
  'Azure GPT 4 32k': True,
140
 
141
  'PaLM 2 text-bison@001': False,
@@ -143,9 +154,11 @@ class ModelMaps:
143
  'PaLM 2 text-unicorn@001': False,
144
  'Gemini Pro': False,
145
 
146
- 'Mistral Tiny': False,
147
  'Mistral Small': False,
148
  'Mistral Medium': False,
 
 
 
149
 
150
  'LOCAL Mixtral 8x7B Instruct v0.1': False,
151
  'LOCAL Mistral 7B Instruct v0.2': False,
@@ -159,7 +172,7 @@ class ModelMaps:
159
 
160
  ### OpenAI
161
  if key == 'GPT_3_5':
162
- return 'gpt-3.5-turbo-1106'
163
 
164
  elif key == 'GPT_3_5_INSTRUCT':
165
  return 'gpt-3.5-turbo-instruct'
@@ -178,7 +191,7 @@ class ModelMaps:
178
 
179
  ### Azure
180
  elif key == 'AZURE_GPT_3_5':
181
- return 'gpt-35-turbo-1106'
182
 
183
  elif key == 'AZURE_GPT_3_5_INSTRUCT':
184
  return 'gpt-35-turbo-instruct'
@@ -209,14 +222,20 @@ class ModelMaps:
209
  return "gemini-1.0-pro"
210
 
211
  ### Mistral
212
- elif key == 'MISTRAL_TINY':
213
- return "mistral-tiny"
 
 
 
214
 
215
  elif key == 'MISTRAL_SMALL':
216
- return 'mistral-small'
217
 
218
  elif key == 'MISTRAL_MEDIUM':
219
- return 'mistral-medium'
 
 
 
220
 
221
 
222
  ### Mistral LOCAL
 
20
  'AZURE_GPT_3_5_INSTRUCT': '#9400D3', # Dark Violet
21
  'AZURE_GPT_3_5': '#9932CC', # Dark Orchid
22
 
23
+ 'OPEN_MISTRAL_7B': '#FFA07A', # Light Salmon
24
+ 'OPEN_MIXTRAL_8X7B': '#FF8C00', # Dark Orange
25
+ 'MISTRAL_SMALL': '#FF6347', # Tomato
26
  'MISTRAL_MEDIUM': '#FF4500', # Orange Red
27
+ 'MISTRAL_LARGE': '#800000', # Maroon
28
 
29
  'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01': '#000000', # Black
30
  'LOCAL_MISTRAL_7B_INSTRUCT_V02': '#4a4a4a', # Gray
 
36
  "GPT 4 32k",
37
  "GPT 4 Turbo 0125-preview",
38
  "GPT 4 Turbo 1106-preview",
39
+ "GPT 3.5 Turbo",
40
  "GPT 3.5 Instruct",
41
 
42
  "Azure GPT 4",
43
  "Azure GPT 4 32k",
44
  "Azure GPT 4 Turbo 0125-preview",
45
  "Azure GPT 4 Turbo 1106-preview",
46
+ "Azure GPT 3.5 Turbo",
47
  "Azure GPT 3.5 Instruct",]
48
 
49
  MODELS_GOOGLE = ["PaLM 2 text-bison@001",
 
51
  "PaLM 2 text-unicorn@001",
52
  "Gemini Pro"]
53
 
54
+ MODELS_MISTRAL = ["Mistral Small",
55
+ "Mistral Medium",
56
+ "Mistral Large",
57
+ "Open Mixtral 8x7B",
58
+ "Open Mistral 7B",
59
+ ]
60
 
61
  MODELS_LOCAL = ["LOCAL Mixtral 8x7B Instruct v0.1",
62
  "LOCAL Mistral 7B Instruct v0.2",
63
  "LOCAL CPU Mistral 7B Instruct v0.2 GGUF",]
64
 
65
+ MODELS_GUI_DEFAULT = "Azure GPT 3.5 Turbo" # "GPT 4 Turbo 1106-preview"
66
 
67
  version_mapping_cost = {
68
  'GPT 4 32k': 'GPT_4_32K',
 
70
  'GPT 4 Turbo 0125-preview': 'GPT_4_TURBO_0125',
71
  'GPT 4 Turbo 1106-preview': 'GPT_4_TURBO_1106',
72
  'GPT 3.5 Instruct': 'GPT_3_5_INSTRUCT',
73
+ 'GPT 3.5 Turbo': 'GPT_3_5',
74
 
75
  'Azure GPT 4 32k': 'AZURE_GPT_4_32K',
76
  'Azure GPT 4': 'AZURE_GPT_4',
77
  'Azure GPT 4 Turbo 0125-preview': 'AZURE_GPT_4_TURBO_0125',
78
  'Azure GPT 4 Turbo 1106-preview': 'AZURE_GPT_4_TURBO_1106',
79
  'Azure GPT 3.5 Instruct': 'AZURE_GPT_3_5_INSTRUCT',
80
+ 'Azure GPT 3.5 Turbo': 'AZURE_GPT_3_5',
81
 
82
  'Gemini Pro': 'GEMINI_PRO',
83
  'PaLM 2 text-unicorn@001': 'PALM2_TU_1',
84
  'PaLM 2 text-bison@001': 'PALM2_TB_1',
85
  'PaLM 2 text-bison@002': 'PALM2_TB_2',
86
 
87
+ 'Mistral Large': 'MISTRAL_LARGE',
88
  'Mistral Medium': 'MISTRAL_MEDIUM',
89
  'Mistral Small': 'MISTRAL_SMALL',
90
+ 'Open Mixtral 8x7B': 'OPEN_MIXTRAL_8X7B',
91
+ 'Open Mistral 7B': 'OPEN_MISTRAL_7B',
92
 
93
  'LOCAL Mixtral 8x7B Instruct v0.1': 'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01',
94
  'LOCAL Mistral 7B Instruct v0.2': 'LOCAL_MISTRAL_7B_INSTRUCT_V02',
 
104
  'GPT 4 Turbo 0125-preview': has_key_openai,
105
  'GPT 4': has_key_openai,
106
  'GPT 4 32k': has_key_openai,
107
+ 'GPT 3.5 Turbo': has_key_openai,
108
  'GPT 3.5 Instruct': has_key_openai,
109
 
110
+ 'Azure GPT 3.5 Turbo': has_key_azure_openai,
111
  'Azure GPT 3.5 Instruct': has_key_azure_openai,
112
  'Azure GPT 4': has_key_azure_openai,
113
  'Azure GPT 4 Turbo 1106-preview': has_key_azure_openai,
 
119
  'PaLM 2 text-unicorn@001': has_key_google_application_credentials,
120
  'Gemini Pro': has_key_google_application_credentials,
121
 
 
122
  'Mistral Small': has_key_mistral,
123
  'Mistral Medium': has_key_mistral,
124
+ 'Mistral Large': has_key_mistral,
125
+ 'Open Mixtral 8x7B': has_key_mistral,
126
+ 'Open Mistral 7B': has_key_mistral,
127
 
128
  'LOCAL Mixtral 8x7B Instruct v0.1': True,
129
  'LOCAL Mistral 7B Instruct v0.2': True,
 
136
  def get_version_mapping_is_azure(cls, key):
137
  version_mapping_is_azure = {
138
  "GPT 4 Turbo 1106-preview": False,
139
+ "GPT 4 Turbo 0125-preview": False,
140
  'GPT 4': False,
141
  'GPT 4 32k': False,
142
+ 'GPT 3.5 Turbo': False,
143
  'GPT 3.5 Instruct': False,
144
 
145
+ 'Azure GPT 3.5 Turbo': True,
146
  'Azure GPT 3.5 Instruct': True,
147
  'Azure GPT 4': True,
148
  'Azure GPT 4 Turbo 1106-preview': True,
149
+ 'Azure GPT 4 Turbo 0125-preview': True,
150
  'Azure GPT 4 32k': True,
151
 
152
  'PaLM 2 text-bison@001': False,
 
154
  'PaLM 2 text-unicorn@001': False,
155
  'Gemini Pro': False,
156
 
 
157
  'Mistral Small': False,
158
  'Mistral Medium': False,
159
+ 'Mistral Large': False,
160
+ 'Open Mixtral 8x7B': False,
161
+ 'Open Mistral 7B': False,
162
 
163
  'LOCAL Mixtral 8x7B Instruct v0.1': False,
164
  'LOCAL Mistral 7B Instruct v0.2': False,
 
172
 
173
  ### OpenAI
174
  if key == 'GPT_3_5':
175
+ return 'gpt-3.5-turbo-0125' #'gpt-3.5-turbo-1106'
176
 
177
  elif key == 'GPT_3_5_INSTRUCT':
178
  return 'gpt-3.5-turbo-instruct'
 
191
 
192
  ### Azure
193
  elif key == 'AZURE_GPT_3_5':
194
+ return 'gpt-35-turbo-0125'
195
 
196
  elif key == 'AZURE_GPT_3_5_INSTRUCT':
197
  return 'gpt-35-turbo-instruct'
 
222
  return "gemini-1.0-pro"
223
 
224
  ### Mistral
225
+ elif key == 'OPEN_MISTRAL_7B':
226
+ return "open-mistral-7b"
227
+
228
+ elif key == 'OPEN_MIXTRAL_8X7B':
229
+ return 'open-mixtral-8x7b'
230
 
231
  elif key == 'MISTRAL_SMALL':
232
+ return 'mistral-small-latest'
233
 
234
  elif key == 'MISTRAL_MEDIUM':
235
+ return 'mistral-medium-latest'
236
+
237
+ elif key == 'MISTRAL_LARGE':
238
+ return 'mistral-large-latest'
239
 
240
 
241
  ### Mistral LOCAL
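
The Mistral entries track the provider's renamed endpoints: the deprecated mistral-tiny tier is dropped, the hosted tiers move to -latest aliases, and the open-weight models get explicit open-* names. An illustrative round-trip through the two mappings above (not repo code):

version_mapping_cost = {'Mistral Large': 'MISTRAL_LARGE'}

def get_api_name(key):
    # Mirrors the branch added above
    if key == 'MISTRAL_LARGE':
        return 'mistral-large-latest'
    raise KeyError(key)

label = 'Mistral Large'                            # GUI label
print(get_api_name(version_mapping_cost[label]))   # -> mistral-large-latest
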
vouchervision/prompt_catalog.py CHANGED
@@ -18,7 +18,7 @@ class PromptCatalog:
18
 
19
 
20
  def prompt_SLTP(self, rules_config_path, OCR=None, is_palm=False):
21
- self.OCR = OCR
22
 
23
  self.rules_config_path = rules_config_path
24
  self.rules_config = self.load_rules_config()
@@ -48,9 +48,9 @@ class PromptCatalog:
48
  The unstructured OCR text is:
49
  {self.OCR}
50
  Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
51
- {self.structure}
52
- {self.structure}
53
- {self.structure}
54
  """
55
  else:
56
  prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
@@ -62,13 +62,16 @@ class PromptCatalog:
62
  The unstructured OCR text is:
63
  {self.OCR}
64
  Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
65
- {self.structure}
66
  """
67
  # xlsx_headers = self.generate_xlsx_headers(is_palm)
68
 
69
  # return prompt, self.PromptJSONModel, self.n_fields, xlsx_headers
 
70
  return prompt, self.dictionary_structure
71
 
 
 
72
 
73
  def copy_prompt_template_to_new_dir(self, new_directory_path, rules_config_path):
74
  # Ensure the target directory exists, create it if it doesn't
@@ -102,22 +105,31 @@ class PromptCatalog:
102
  return structure_json_str
103
 
104
  def create_structure(self, is_palm=False):
105
- # Create fields for the Pydantic model dynamically
106
- fields = {key: (str, Field(default=value, description=value)) for key, value in self.rules_list.items()}
107
 
108
- # Dynamically create the Pydantic model
109
- DynamicJSONParsingModel = create_model('SLTPvA', **fields)
110
- DynamicJSONParsingModel_use = DynamicJSONParsingModel()
111
 
112
- # Define the structure for the "Dictionary" section
113
- dictionary_fields = {key: (str, Field(default='', description="")) for key in self.rules_list.keys()}
114
 
115
- # Dynamically create the "Dictionary" Pydantic model
116
- PromptJSONModel = create_model('PromptJSONModel', **dictionary_fields)
117
 
118
- # Convert the model to JSON string (for demonstration)
119
- dictionary_structure = PromptJSONModel().dict()
 
 
 
 
 
 
120
  structure_json_str = json.dumps(dictionary_structure, sort_keys=False, indent=4)
 
 
 
121
  return structure_json_str, dictionary_structure
122
 
123
 
 
18
 
19
 
20
  def prompt_SLTP(self, rules_config_path, OCR=None, is_palm=False):
21
+ self.OCR = self.remove_colons_and_double_apostrophes(OCR)
22
 
23
  self.rules_config_path = rules_config_path
24
  self.rules_config = self.load_rules_config()
 
48
  The unstructured OCR text is:
49
  {self.OCR}
50
  Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
51
+ {self.dictionary_structure}
52
+ {self.dictionary_structure}
53
+ {self.dictionary_structure}
54
  """
55
  else:
56
  prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
 
62
  The unstructured OCR text is:
63
  {self.OCR}
64
  Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
65
+ {self.dictionary_structure}
66
  """
67
  # xlsx_headers = self.generate_xlsx_headers(is_palm)
68
 
69
  # return prompt, self.PromptJSONModel, self.n_fields, xlsx_headers
70
+ # print(prompt)
71
  return prompt, self.dictionary_structure
72
 
73
+ def remove_colons_and_double_apostrophes(self, text):
74
+ return text.replace(":", "").replace("\"", "")
75
 
76
  def copy_prompt_template_to_new_dir(self, new_directory_path, rules_config_path):
77
  # Ensure the target directory exists, create it if it doesn't
 
105
  return structure_json_str
106
 
107
  def create_structure(self, is_palm=False):
108
+ # # Create fields for the Pydantic model dynamically
109
+ # fields = {key: (str, Field(default=value, description=value)) for key, value in self.rules_list.items()}
110
 
111
+ # # Dynamically create the Pydantic model
112
+ # DynamicJSONParsingModel = create_model('SLTPvA', **fields)
113
+ # DynamicJSONParsingModel_use = DynamicJSONParsingModel()
114
 
115
+ # # Define the structure for the "Dictionary" section
116
+ # dictionary_fields = {key: (str, Field(default='', description="")) for key in self.rules_list.keys()}
117
 
118
+ # # Dynamically create the "Dictionary" Pydantic model
119
+ # PromptJSONModel = create_model('PromptJSONModel', **dictionary_fields)
120
 
121
+ # # Convert the model to JSON string (for demonstration)
122
+ # dictionary_structure = PromptJSONModel().dict()
123
+ # structure_json_str = json.dumps(dictionary_structure, sort_keys=False, indent=4)
124
+
125
+ # Directly create the dictionary structure with empty strings as default values
126
+ dictionary_structure = {key: '' for key in self.rules_list.keys()}
127
+
128
+ # Convert the dictionary to JSON string for demonstration if needed
129
  structure_json_str = json.dumps(dictionary_structure, sort_keys=False, indent=4)
130
+ # print(structure_json_str)
131
+ # print(dictionary_structure)
132
+
133
  return structure_json_str, dictionary_structure
134
 
135
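
Two behavioral changes land in prompt_catalog.py: the OCR text is scrubbed before templating (despite the method name, it strips double quotes, not apostrophes), and the Pydantic model builder is replaced by a plain dict comprehension that yields the same empty-string template. A short runnable sketch of both (illustrative rules):

import json

rules_list = {'catalogNumber': 'barcode', 'scientificName': 'Latin name'}

# Equivalent of the simplified create_structure()
dictionary_structure = {key: '' for key in rules_list.keys()}
print(json.dumps(dictionary_structure, sort_keys=False, indent=4))

# Equivalent of remove_colons_and_double_apostrophes()
ocr = 'Collector: J. Doe "Holotype"'
print(ocr.replace(':', '').replace('"', ''))  # -> Collector J. Doe Holotype
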
 
vouchervision/tool_taxonomy_WFO.py CHANGED
@@ -19,12 +19,19 @@ class WFONameMatcher:
19
  self.is_enabled = tool_WFO
20
 
21
  def extract_input_string(self, record):
22
- primary_input = f"{record.get('scientificName', '').strip()} {record.get('scientificNameAuthorship', '').strip()}".strip()
23
- secondary_input = ' '.join(filter(None, [record.get('genus', '').strip(),
24
- record.get('subgenus', '').strip(),
25
- record.get('specificEpithet', '').strip(),
26
- record.get('infraspecificEpithet', '').strip()])).strip()
 
27
  return primary_input, secondary_input
29
 
30
  def query_wfo_name_matching(self, input_string, check_homonyms=True, check_rank=True, accept_single_candidate=True):
@@ -46,6 +53,8 @@ class WFONameMatcher:
46
 
47
  def query_and_process(self, record):
48
  primary_input, secondary_input = self.extract_input_string(record)
 
 
49
 
50
  # Query with primary input
51
  primary_result = self.query_wfo_name_matching(primary_input)
 
19
  self.is_enabled = tool_WFO
20
 
21
  def extract_input_string(self, record):
22
+ if 'scientificName' in record and 'scientificNameAuthorship' in record:
23
+ primary_input = f"{record.get('scientificName', '').strip()} {record.get('scientificNameAuthorship', '').strip()}".strip()
24
+ elif 'speciesBinomialName' in record and 'speciesBinomialNameAuthorship' in record:
25
+ primary_input = f"{record.get('speciesBinomialName', '').strip()} {record.get('speciesBinomialNameAuthorship', '').strip()}".strip()
26
+ else:
27
+ return None, None
28
 
29
+ if 'genus' in record and 'specificEpithet' in record:
30
+ secondary_input = ' '.join(filter(None, [record.get('genus', '').strip(),
31
+ record.get('specificEpithet', '').strip()])).strip()
32
+ else:
33
+ return None, None
34
+
35
  return primary_input, secondary_input
36
 
37
  def query_wfo_name_matching(self, input_string, check_homonyms=True, check_rank=True, accept_single_candidate=True):
 
53
 
54
  def query_and_process(self, record):
55
  primary_input, secondary_input = self.extract_input_string(record)
56
+ if primary_input is None and secondary_input is None:
57
+ return self.NULL_DICT
58
 
59
  # Query with primary input
60
  primary_result = self.query_wfo_name_matching(primary_input)
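
extract_input_string now accepts either the SLTPvA field pair (scientificName / scientificNameAuthorship) or the SLTPvB-style pair (speciesBinomialName / speciesBinomialNameAuthorship), and returns (None, None) when a record has neither, which query_and_process turns into NULL_DICT. Illustrative records run against a condensed copy of that logic:

def extract(record):
    if 'scientificName' in record and 'scientificNameAuthorship' in record:
        primary = f"{record['scientificName'].strip()} {record['scientificNameAuthorship'].strip()}".strip()
    elif 'speciesBinomialName' in record and 'speciesBinomialNameAuthorship' in record:
        primary = f"{record['speciesBinomialName'].strip()} {record['speciesBinomialNameAuthorship'].strip()}".strip()
    else:
        return None, None
    if 'genus' in record and 'specificEpithet' in record:
        secondary = ' '.join(filter(None, [record['genus'].strip(),
                                           record['specificEpithet'].strip()])).strip()
    else:
        return None, None
    return primary, secondary

print(extract({'scientificName': 'Acer rubrum', 'scientificNameAuthorship': 'L.',
               'genus': 'Acer', 'specificEpithet': 'rubrum'}))  # ('Acer rubrum L.', 'Acer rubrum')
print(extract({'verbatimLabel': 'no name fields'}))             # (None, None)
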
vouchervision/utils_LLM.py CHANGED
@@ -63,16 +63,13 @@ def run_tools(output, tool_WFO, tool_GEO, tool_wikipedia, json_file_path_wiki):
63
  return output_WFO, WFO_record, output_GEO, GEO_record
64
 
65
 
 
66
  def save_individual_prompt(prompt_template, txt_file_path_ind_prompt):
67
  with open(txt_file_path_ind_prompt, 'w',encoding='utf-8') as file:
68
  file.write(prompt_template)
69
 
70
 
71
 
72
- def remove_colons_and_double_apostrophes(text):
73
- return text.replace(":", "").replace("\"", "")
74
-
75
-
76
  def sanitize_prompt(data):
77
  if isinstance(data, dict):
78
  return {sanitize_prompt(key): sanitize_prompt(value) for key, value in data.items()}
 
63
  return output_WFO, WFO_record, output_GEO, GEO_record
64
 
65
 
66
+
67
  def save_individual_prompt(prompt_template, txt_file_path_ind_prompt):
68
  with open(txt_file_path_ind_prompt, 'w',encoding='utf-8') as file:
69
  file.write(prompt_template)
70
 
71
 
72
 
 
 
 
 
73
  def sanitize_prompt(data):
74
  if isinstance(data, dict):
75
  return {sanitize_prompt(key): sanitize_prompt(value) for key, value in data.items()}
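
Only the dict branch of sanitize_prompt is visible in this hunk; a hedged sketch of the recursion it implies follows (the string branch here, stripping non-printable characters, is an assumption for illustration rather than the repo's actual transformation):

def sanitize(data):
    if isinstance(data, dict):
        return {sanitize(k): sanitize(v) for k, v in data.items()}
    if isinstance(data, list):
        return [sanitize(item) for item in data]
    if isinstance(data, str):
        # Assumed transformation for this sketch only
        return ''.join(ch for ch in data if ch.isprintable())
    return data

print(sanitize({'a\x00': ['b\x01', {'c': 1}]}))  # {'a': ['b', {'c': 1}]}
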
vouchervision/utils_LLM_JSON_validation.py CHANGED
@@ -11,7 +11,8 @@ def validate_and_align_JSON_keys_with_template(data, JSON_dict_structure):
11
  if value is None:
12
  data[key] = ''
13
  elif isinstance(value, str):
14
- if value.lower() in ['unknown', 'not provided', 'missing', 'na', 'none', 'n/a', 'null', 'unspecified',
 
15
  'not provided in the text', 'not found in the text',
16
  'not in the text', 'not provided', 'not found',
17
  'not provided in the ocr', 'not found in the ocr',
 
11
  if value is None:
12
  data[key] = ''
13
  elif isinstance(value, str):
14
+ if value.lower() in ['unknown', 'not provided', 'missing', 'na', 'none', 'n/a', 'null', 'unspecified',
15
+ 'tbd', # the value is lowercased above, so this token must be lowercase to ever match
16
  'not provided in the text', 'not found in the text',
17
  'not in the text', 'not provided', 'not found',
18
  'not provided in the ocr', 'not found in the ocr',
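
The comparison lowercases the value before testing membership, which is why the token added here must itself be lowercase (see the fix above). A condensed, runnable sketch of the normalization this function performs (illustrative record):

NULLISH = {'unknown', 'not provided', 'missing', 'na', 'none', 'n/a',
           'null', 'unspecified', 'tbd'}

def normalize(data):
    for key, value in data.items():
        if value is None or (isinstance(value, str) and value.lower() in NULLISH):
            data[key] = ''
    return data

print(normalize({'country': 'TBD', 'genus': 'Acer', 'county': None}))
# -> {'country': '', 'genus': 'Acer', 'county': ''}
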
vouchervision/utils_VoucherVision.py CHANGED
@@ -14,7 +14,6 @@ from vouchervision.LLM_GoogleGemini import GoogleGeminiHandler
14
  from vouchervision.LLM_MistralAI import MistralHandler
15
  from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
16
  from vouchervision.LLM_local_MistralAI import LocalMistralHandler
17
- from vouchervision.utils_LLM import remove_colons_and_double_apostrophes
18
  from vouchervision.prompt_catalog import PromptCatalog
19
  from vouchervision.model_maps import ModelMaps
20
  from vouchervision.general_utils import get_cfg_from_full_path
@@ -32,7 +31,7 @@ from vouchervision.OCR_google_cloud_vision import OCREngine
32
 
33
  class VoucherVision():
34
 
35
- def __init__(self, cfg, logger, dir_home, path_custom_prompts, Project, Dirs, is_hf):
36
  self.cfg = cfg
37
  self.logger = logger
38
  self.dir_home = dir_home
@@ -43,6 +42,9 @@ class VoucherVision():
43
  self.prompt_version = None
44
  self.is_hf = is_hf
45
 
 
 
 
46
  # self.trOCR_model_version = "microsoft/trocr-large-handwritten"
47
  # self.trOCR_model_version = "microsoft/trocr-base-handwritten"
48
  # self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask" # NOPE
@@ -686,9 +688,10 @@ class VoucherVision():
686
  Copy_Prompt = PromptCatalog()
687
  Copy_Prompt.copy_prompt_template_to_new_dir(self.Dirs.transcription_prompt, self.path_custom_prompts)
688
 
689
- json_report.set_text(text_main=f'Loading {MODEL_NAME_FORMATTED}')
690
- json_report.set_JSON({}, {}, {})
691
- llm_model = self.initialize_llm_model(self.cfg, self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm)
 
692
 
693
  for i, path_to_crop in enumerate(self.img_paths):
694
  self.update_progress_report_batch(progress_report, i)
@@ -701,9 +704,11 @@ class VoucherVision():
701
  self.path_to_crop = path_to_crop
702
 
703
  filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
704
- json_report.set_text(text_main='Starting OCR')
 
705
  self.perform_OCR_and_save_results(i, json_report, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
706
- json_report.set_text(text_main='Finished OCR')
 
707
 
708
  if not self.OCR:
709
  self.n_failed_OCR += 1
@@ -713,7 +718,7 @@ class VoucherVision():
713
  else:
714
  ### Format prompt
715
  prompt = self.setup_prompt()
716
- prompt = remove_colons_and_double_apostrophes(prompt)
717
 
718
  ### Send prompt to chosen LLM
719
  self.logger.info(f'Waiting for {model_name} API call --- Using {MODEL_NAME_FORMATTED}')
@@ -747,8 +752,9 @@ class VoucherVision():
747
  final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out)
748
 
749
  self.logger.info(f'Finished LLM call')
750
-
751
- json_report.set_JSON(final_JSON_response, final_WFO_record, final_GEO_record)
 
752
 
753
  self.update_progress_report_final(progress_report)
754
  final_JSON_response = self.parse_final_json_response(final_JSON_response)
@@ -758,22 +764,22 @@ class VoucherVision():
758
  ##################################################################################################################################
759
  ################################################## LLM Helper Funcs ##############################################################
760
  ##################################################################################################################################
761
- def initialize_llm_model(self, cfg, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None):
762
  if 'LOCAL' in name_parts:
763
  if ('MIXTRAL' in name_parts) or ('MISTRAL' in name_parts):
764
  if 'CPU' in name_parts:
765
- return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure)
766
  else:
767
- return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure)
768
  else:
769
  if 'PALM2' in name_parts:
770
- return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure)
771
  elif 'GEMINI' in name_parts:
772
- return GoogleGeminiHandler(cfg, logger, model_name, JSON_dict_structure)
773
  elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
774
- return MistralHandler(cfg, logger, model_name, JSON_dict_structure)
775
  else:
776
- return OpenAIHandler(cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object)
777
 
778
  def setup_prompt(self):
779
  Catalog = PromptCatalog()
 
14
  from vouchervision.LLM_MistralAI import MistralHandler
15
  from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
16
  from vouchervision.LLM_local_MistralAI import LocalMistralHandler
 
17
  from vouchervision.prompt_catalog import PromptCatalog
18
  from vouchervision.model_maps import ModelMaps
19
  from vouchervision.general_utils import get_cfg_from_full_path
 
31
 
32
  class VoucherVision():
33
 
34
+ def __init__(self, cfg, logger, dir_home, path_custom_prompts, Project, Dirs, is_hf, config_vals_for_permutation=None):
35
  self.cfg = cfg
36
  self.logger = logger
37
  self.dir_home = dir_home
 
42
  self.prompt_version = None
43
  self.is_hf = is_hf
44
 
45
+ ### config_vals_for_permutation allows you to set the starting temperature, top_k, top_p, seed, etc.
46
+ self.config_vals_for_permutation = config_vals_for_permutation
47
+
48
  # self.trOCR_model_version = "microsoft/trocr-large-handwritten"
49
  # self.trOCR_model_version = "microsoft/trocr-base-handwritten"
50
  # self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask" # NOPE
 
688
  Copy_Prompt = PromptCatalog()
689
  Copy_Prompt.copy_prompt_template_to_new_dir(self.Dirs.transcription_prompt, self.path_custom_prompts)
690
 
691
+ if json_report:
692
+ json_report.set_text(text_main=f'Loading {MODEL_NAME_FORMATTED}')
693
+ json_report.set_JSON({}, {}, {})
694
+ llm_model = self.initialize_llm_model(self.cfg, self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm, self.config_vals_for_permutation)
695
 
696
  for i, path_to_crop in enumerate(self.img_paths):
697
  self.update_progress_report_batch(progress_report, i)
 
704
  self.path_to_crop = path_to_crop
705
 
706
  filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
707
+ if json_report:
708
+ json_report.set_text(text_main='Starting OCR')
709
  self.perform_OCR_and_save_results(i, json_report, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
710
+ if json_report:
711
+ json_report.set_text(text_main='Finished OCR')
712
 
713
  if not self.OCR:
714
  self.n_failed_OCR += 1
 
718
  else:
719
  ### Format prompt
720
  prompt = self.setup_prompt()
721
+ # prompt = remove_colons_and_double_apostrophes(prompt) # This is moved to utils_VV since it broke the json structure.
722
 
723
  ### Send prompt to chosen LLM
724
  self.logger.info(f'Waiting for {model_name} API call --- Using {MODEL_NAME_FORMATTED}')
 
752
  final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out)
753
 
754
  self.logger.info(f'Finished LLM call')
755
+
756
+ if json_report:
757
+ json_report.set_JSON(final_JSON_response, final_WFO_record, final_GEO_record)
758
 
759
  self.update_progress_report_final(progress_report)
760
  final_JSON_response = self.parse_final_json_response(final_JSON_response)
 
764
  ##################################################################################################################################
765
  ################################################## LLM Helper Funcs ##############################################################
766
  ##################################################################################################################################
767
+ def initialize_llm_model(self, cfg, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None, config_vals_for_permutation=None):
768
  if 'LOCAL' in name_parts:
769
  if ('MIXTRAL' in name_parts) or ('MISTRAL' in name_parts):
770
  if 'CPU' in name_parts:
771
+ return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
772
  else:
773
+ return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
774
  else:
775
  if 'PALM2' in name_parts:
776
+ return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
777
  elif 'GEMINI' in name_parts:
778
+ return GoogleGeminiHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
779
  elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
780
+ return MistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
781
  else:
782
+ return OpenAIHandler(cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object, config_vals_for_permutation)
783
 
784
  def setup_prompt(self):
785
  Catalog = PromptCatalog()
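
config_vals_for_permutation threads optional sampling parameters through initialize_llm_model to every handler, so batch experiments can sweep settings without editing the config file. The exact keys each handler consumes are not shown in this diff, so the names below are assumptions; the overlay pattern itself is the point:

defaults = {'temperature': 0.0, 'top_p': 1.0, 'seed': 2023}  # assumed handler defaults

def resolve(config_vals_for_permutation=None):
    params = dict(defaults)
    if config_vals_for_permutation:
        # Permutation values override the handler defaults
        params.update(config_vals_for_permutation)
    return params

print(resolve())                                  # baseline run
print(resolve({'temperature': 0.7, 'seed': 1}))   # one permutation
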
vouchervision/utils_VoucherVision_parallel.py ADDED
@@ -0,0 +1,1022 @@
1
+ import openai
2
+ import os, json, glob, shutil, yaml, torch, logging
3
+ import openpyxl
4
+ from openpyxl import Workbook, load_workbook
5
+ from tqdm import tqdm
6
+ import vertexai
7
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
8
+ from langchain_openai import AzureChatOpenAI
9
+ from google.oauth2 import service_account
10
+ from transformers import AutoTokenizer, AutoModel
11
+
12
+ from concurrent.futures import ThreadPoolExecutor, as_completed
13
+ from queue import Queue
14
+ import threading
15
+
16
+ from vouchervision.LLM_OpenAI import OpenAIHandler
17
+ from vouchervision.LLM_GooglePalm2 import GooglePalm2Handler
18
+ from vouchervision.LLM_GoogleGemini import GoogleGeminiHandler
19
+ from vouchervision.LLM_MistralAI import MistralHandler
20
+ from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
21
+ from vouchervision.LLM_local_MistralAI import LocalMistralHandler
22
+ from vouchervision.prompt_catalog import PromptCatalog
23
+ from vouchervision.model_maps import ModelMaps
24
+ from vouchervision.general_utils import get_cfg_from_full_path
25
+ from vouchervision.OCR_google_cloud_vision import OCREngine
26
+
27
+ '''
28
+ * For the prefix_removal, the image names have 'MICH-V-' prior to the barcode, so that is used for matching
29
+ but removed for output.
30
+ * There is also code active to replace the LLM-predicted "Catalog Number" with the correct number since it is known.
31
+ The LLMs usually assign the barcode to the correct field anyway, but that is not needed since it is already known.
32
+ - Look for ####################### Catalog Number pre-defined
33
+ '''
34
+
35
+
36
+
37
+ class VoucherVision():
38
+
39
+ def __init__(self, cfg, logger, dir_home, path_custom_prompts, Project, Dirs, is_hf, config_vals_for_permutation=None):
40
+ self.cfg = cfg
41
+ self.logger = logger
42
+ self.dir_home = dir_home
43
+ self.path_custom_prompts = path_custom_prompts
44
+ self.Project = Project
45
+ self.Dirs = Dirs
46
+ self.headers = None
47
+ self.prompt_version = None
48
+ self.is_hf = is_hf
49
+
50
+ ### config_vals_for_permutation allows you to set the starting temp, top_k, top_p, seed....
51
+ self.config_vals_for_permutation = config_vals_for_permutation
52
+
53
+ # self.trOCR_model_version = "microsoft/trocr-large-handwritten"
54
+ # self.trOCR_model_version = "microsoft/trocr-base-handwritten"
55
+ # self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask" # NOPE
56
+ # self.trOCR_model_version = "dh-unibe/trocr-kurrent" # NOPE
57
+ # self.trOCR_model_version = "DunnBC22/trocr-base-handwritten-OCR-handwriting_recognition_v2" # NOPE
58
+ self.trOCR_processor = None
59
+ self.trOCR_model = None
60
+
61
+ self.set_API_keys()
62
+ self.setup()
63
+
64
+
65
+ def setup(self):
66
+ self.logger.name = f'[Transcription]'
67
+ self.logger.info(f'Setting up OCR and LLM')
68
+
69
+ self.trOCR_model_version = self.cfg['leafmachine']['project']['trOCR_model_path']
70
+
71
+ self.db_name = self.cfg['leafmachine']['project']['embeddings_database_name']
72
+ self.path_domain_knowledge = self.cfg['leafmachine']['project']['path_to_domain_knowledge_xlsx']
73
+ self.build_new_db = self.cfg['leafmachine']['project']['build_new_embeddings_database']
74
+
75
+ self.continue_run_from_partial_xlsx = self.cfg['leafmachine']['project']['continue_run_from_partial_xlsx']
76
+
77
+ self.prefix_removal = self.cfg['leafmachine']['project']['prefix_removal']
78
+ self.suffix_removal = self.cfg['leafmachine']['project']['suffix_removal']
79
+ self.catalog_numerical_only = self.cfg['leafmachine']['project']['catalog_numerical_only']
80
+
81
+ self.prompt_version0 = self.cfg['leafmachine']['project']['prompt_version']
82
+ self.use_domain_knowledge = self.cfg['leafmachine']['project']['use_domain_knowledge']
83
+
84
+ self.catalog_name_options = ["Catalog Number", "catalog_number", "catalogNumber"]
85
+
86
+ self.geo_headers = ["GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
87
+ "GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
88
+ "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",]
89
+
90
+ self.usage_headers = ["current_time", "inference_time_s", "tool_time_s","max_cpu", "max_ram_gb", "n_gpus", "max_gpu_load", "max_gpu_vram_gb","total_gpu_vram_gb","capability_score",]
91
+
92
+ self.wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"]
93
+ self.wfo_headers_no_lists = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
94
+
95
+ self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + self.usage_headers + ["run_name", "prompt", "LLM", "tokens_in", "tokens_out", "LM2_collage", "OCR_method", "OCR_double", "OCR_trOCR", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
96
+ # "WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
97
+
98
+ # "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
99
+ # "GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
100
+ # "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",
101
+
102
+ # "tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
103
+
104
+ # WFO_candidate_names is separate, bc it may be type --> list
105
+
106
+ self.do_create_OCR_helper_image = self.cfg['leafmachine']['do_create_OCR_helper_image']
107
+
108
+ self.map_prompt_versions()
109
+ self.map_dir_labels()
110
+ self.map_API_options()
111
+ # self.init_embeddings()
112
+ self.init_transcription_xlsx()
113
+ self.init_trOCR_model()
114
+
115
+ '''Logging'''
116
+ self.logger.info(f'Transcribing dataset --- {self.dir_labels}')
117
+ self.logger.info(f'Saving transcription batch to --- {self.path_transcription}')
118
+ self.logger.info(f'Saving individual transcription files to --- {self.Dirs.transcription_ind}')
119
+ self.logger.info(f'Starting transcription...')
120
+ self.logger.info(f' LLM MODEL --> {self.version_name}')
121
+ self.logger.info(f' Using Azure API --> {self.is_azure}')
122
+ self.logger.info(f' Model name passed to API --> {self.model_name}')
123
+ self.logger.info(f' API access token is found in PRIVATE_DATA.yaml --> {self.has_key}')
124
+
125
+
126
+ def init_trOCR_model(self):
127
+ lgr = logging.getLogger('transformers')
128
+ lgr.setLevel(logging.ERROR)
129
+
130
+ self.trOCR_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") # usually just the "microsoft/trocr-base-handwritten"
131
+ self.trOCR_model = VisionEncoderDecoderModel.from_pretrained(self.trOCR_model_version) # This matches the model
132
+
133
+ # Check for GPU availability
134
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
135
+ self.trOCR_model.to(self.device)
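
init_trOCR_model pairs the fixed base-handwritten processor with a configurable VisionEncoderDecoderModel and moves it to GPU when one is available. A minimal inference sketch using the same Hugging Face API ('crop.jpg' is a hypothetical label crop):

import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten').to(device)

image = Image.open('crop.jpg').convert('RGB')  # hypothetical crop
pixel_values = processor(images=image, return_tensors='pt').pixel_values.to(device)
generated_ids = model.generate(pixel_values)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
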
136
+
137
+
138
+ def map_API_options(self):
139
+ self.chat_version = self.cfg['leafmachine']['LLM_version']
140
+
141
+ # Get the required values from ModelMaps
142
+ self.model_name = ModelMaps.get_version_mapping_cost(self.chat_version)
143
+ self.is_azure = ModelMaps.get_version_mapping_is_azure(self.chat_version)
144
+ self.has_key = ModelMaps.get_version_has_key(self.chat_version, self.has_key_openai, self.has_key_azure_openai, self.has_key_google_application_credentials, self.has_key_mistral)
145
+
146
+ # Check if the version is supported
147
+ if self.model_name is None:
148
+ supported_LLMs = ", ".join(ModelMaps.get_models_gui_list())
149
+ raise Exception(f"Unsupported LLM: {self.chat_version}. Requires one of: {supported_LLMs}")
150
+
151
+ self.version_name = self.chat_version
152
+
153
+
154
+ def map_prompt_versions(self):
155
+ self.prompt_version_map = {
156
+ "Version 1": "prompt_v1_verbose",
157
+ }
158
+ self.prompt_version = self.prompt_version_map.get(self.prompt_version0, self.path_custom_prompts)
159
+ self.is_predefined_prompt = self.is_in_prompt_version_map(self.prompt_version)
160
+
161
+
162
+ def is_in_prompt_version_map(self, value):
163
+ return value in self.prompt_version_map.values()
164
+
165
+
166
+ def map_dir_labels(self):
167
+ if self.cfg['leafmachine']['use_RGB_label_images']:
168
+ self.dir_labels = os.path.join(self.Dirs.save_per_annotation_class,'label')
169
+ else:
170
+ self.dir_labels = self.Dirs.save_original
171
+
172
+ # Use glob to get all image paths in the directory
173
+ self.img_paths = glob.glob(os.path.join(self.dir_labels, "*"))
174
+
175
+
176
+ def load_rules_config(self):
177
+ with open(self.path_custom_prompts, 'r') as stream:
178
+ try:
179
+ return yaml.safe_load(stream)
180
+ except yaml.YAMLError as exc:
181
+ print(exc)
182
+ return None
183
+
184
+
185
+ def generate_xlsx_headers(self):
186
+ # Extract headers from the 'Dictionary' keys in the JSON template rules
187
+ # xlsx_headers = list(self.rules_config_json['rules']["Dictionary"].keys())
188
+ xlsx_headers = list(self.rules_config_json['rules'].keys())
189
+ xlsx_headers = xlsx_headers + self.utility_headers
190
+ return xlsx_headers
191
+
192
+
193
+ def init_transcription_xlsx(self):
194
+ # Initialize output file
195
+ self.path_transcription = os.path.join(self.Dirs.transcription,"transcribed.xlsx")
196
+
197
+ # else:
198
+ if not self.is_predefined_prompt:
199
+ # Load the rules configuration
200
+ self.rules_config_json = self.load_rules_config()
201
+ # Generate the headers from the configuration
202
+ self.headers = self.generate_xlsx_headers()
203
+ # Set the headers used to the dynamically generated headers
204
+ self.headers_used = 'CUSTOM'
205
+ else:
206
+ # If it's a predefined prompt, raise an exception as we don't have further instructions
207
+ raise ValueError("Predefined prompt is not handled in this context.")
208
+
209
+ self.create_or_load_excel_with_headers(os.path.join(self.Dirs.transcription,"transcribed.xlsx"), self.headers)
210
+
211
+
212
+ def create_or_load_excel_with_headers(self, file_path, headers, show_head=False):
213
+ output_dir_names = ['Archival_Components', 'Config_File', 'Cropped_Images', 'Logs', 'Original_Images', 'Transcription']
214
+ self.completed_specimens = []
215
+
216
+ # Check if the file exists and it's not None
217
+ if self.continue_run_from_partial_xlsx is not None and os.path.isfile(self.continue_run_from_partial_xlsx):
218
+ workbook = load_workbook(filename=self.continue_run_from_partial_xlsx)
219
+ sheet = workbook.active
220
+ show_head=True
221
+ # Identify the 'path_to_crop' column
222
+ try:
223
+ path_to_crop_col = headers.index('path_to_crop') + 1
224
+ path_to_original_col = headers.index('path_to_original') + 1
225
+ path_to_content_col = headers.index('path_to_content') + 1
226
+ path_to_helper_col = headers.index('path_to_helper') + 1
227
+ # self.completed_specimens = list(sheet.iter_cols(min_col=path_to_crop_col, max_col=path_to_crop_col, values_only=True, min_row=2))
228
+ except ValueError:
229
+ print("'path_to_crop' not found in the header row.")
230
+
231
+ path_to_crop = list(sheet.iter_cols(min_col=path_to_crop_col, max_col=path_to_crop_col, values_only=True, min_row=2))
232
+ path_to_original = list(sheet.iter_cols(min_col=path_to_original_col, max_col=path_to_original_col, values_only=True, min_row=2))
233
+ path_to_content = list(sheet.iter_cols(min_col=path_to_content_col, max_col=path_to_content_col, values_only=True, min_row=2))
234
+ path_to_helper = list(sheet.iter_cols(min_col=path_to_helper_col, max_col=path_to_helper_col, values_only=True, min_row=2))
235
+ others = [path_to_crop_col, path_to_original_col, path_to_content_col, path_to_helper_col]
236
+ jsons = [path_to_content_col, path_to_helper_col]
237
+
238
+ for cell in path_to_crop[0]:
239
+ old_path = cell
240
+ new_path = file_path
241
+ for dir_name in output_dir_names:
242
+ if dir_name in old_path:
243
+ old_path_parts = old_path.split(dir_name)
244
+ new_path_parts = new_path.split('Transcription')
245
+ updated_path = new_path_parts[0] + dir_name + old_path_parts[1]
246
+ self.completed_specimens.append(os.path.basename(updated_path))
247
+ print(f"{len(self.completed_specimens)} images are already completed")
248
+
249
+ ### Copy the JSON files over
250
+ for colu in jsons:
251
+ cell = next(sheet.iter_rows(min_row=2, min_col=colu, max_col=colu))[0]
252
+ old_path = cell.value
253
+ new_path = file_path
254
+
255
+ old_path_parts = old_path.split('Transcription')
256
+ new_path_parts = new_path.split('Transcription')
257
+ updated_path = new_path_parts[0] + 'Transcription' + old_path_parts[1]
258
+
259
+ # Copy files
260
+ old_dir = os.path.dirname(old_path)
261
+ new_dir = os.path.dirname(updated_path)
262
+
263
+ # Check if old_dir exists and it's a directory
264
+ if os.path.exists(old_dir) and os.path.isdir(old_dir):
265
+ # Check if new_dir exists. If not, create it.
266
+ if not os.path.exists(new_dir):
267
+ os.makedirs(new_dir)
268
+
269
+ # Iterate through all files in old_dir and copy each to new_dir
270
+ for filename in os.listdir(old_dir):
271
+ shutil.copy2(os.path.join(old_dir, filename), new_dir) # copy2 preserves metadata
272
+
273
+ ### Update the file names
274
+ for colu in others:
275
+ for row in sheet.iter_rows(min_row=2, min_col=colu, max_col=colu):
276
+ for cell in row:
277
+ old_path = cell.value
278
+ new_path = file_path
279
+ for dir_name in output_dir_names:
280
+ if dir_name in old_path:
281
+ old_path_parts = old_path.split(dir_name)
282
+ new_path_parts = new_path.split('Transcription')
283
+ updated_path = new_path_parts[0] + dir_name + old_path_parts[1]
284
+ cell.value = updated_path
285
+ show_head=True
286
+
287
+
288
+ else:
289
+ # Create a new workbook and select the active worksheet
290
+ workbook = Workbook()
291
+ sheet = workbook.active
292
+
293
+ # Write headers in the first row
294
+ for i, header in enumerate(headers, start=1):
295
+ sheet.cell(row=1, column=i, value=header)
296
+ self.completed_specimens = []
297
+
298
+ # Save the workbook
299
+ workbook.save(file_path)
300
+
301
+ if show_head:
302
+ print("continue_run_from_partial_xlsx:")
303
+ for i, row in enumerate(sheet.iter_rows(values_only=True)):
304
+ print(row)
305
+ if i == 3: # print only the first 4 rows (0-indexed)
306
+ print("\n")
307
+ break
308
+
309
+
310
+ def add_data_to_excel_from_response(self, Dirs, path_transcription, response, WFO_record, GEO_record, usage_report,
311
+ MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
312
+
313
+
314
+ wb = openpyxl.load_workbook(path_transcription)
315
+ sheet = wb.active
316
+
317
+ # find the next empty row
318
+ next_row = sheet.max_row + 1
319
+
320
+ if isinstance(response, str):
321
+ try:
322
+ response = json.loads(response)
323
+ except json.JSONDecodeError:
324
+ print(f"Failed to parse response: {response}")
325
+ return
326
+
327
+ # iterate over headers in the first row
328
+ for i, header in enumerate(sheet[1], start=1):
329
+ # check if header value is in response keys
330
+ if (header.value in response) and (header.value not in self.catalog_name_options): ####################### Catalog Number pre-defined
331
+ # check if the response value is a dictionary
332
+ if isinstance(response[header.value], dict):
333
+ # if it is a dictionary, extract the 'value' field
334
+ cell_value = response[header.value].get('value', '')
335
+ else:
336
+ # if it's not a dictionary, use it directly
337
+ cell_value = response[header.value]
338
+
339
+ try:
340
+ # write the value to the cell
341
+ sheet.cell(row=next_row, column=i, value=cell_value)
342
+ except:
343
+ sheet.cell(row=next_row, column=i, value=cell_value[0])
344
+
345
+ elif header.value in self.catalog_name_options:
346
+ # if self.prefix_removal:
347
+ # filename_without_extension = filename_without_extension.replace(self.prefix_removal, "")
348
+ # if self.suffix_removal:
349
+ # filename_without_extension = filename_without_extension.replace(self.suffix_removal, "")
350
+ # if self.catalog_numerical_only:
351
+ # filename_without_extension = self.remove_non_numbers(filename_without_extension)
352
+ sheet.cell(row=next_row, column=i, value=filename_without_extension)
353
+ elif header.value == "path_to_crop":
354
+ sheet.cell(row=next_row, column=i, value=path_to_crop)
355
+ elif header.value == "path_to_original":
356
+ if self.cfg['leafmachine']['use_RGB_label_images']:
357
+ fname = os.path.basename(path_to_crop)
358
+ base = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(path_to_crop))))
359
+ path_to_original = os.path.join(base, 'Original_Images', fname)
360
+ sheet.cell(row=next_row, column=i, value=path_to_original)
361
+ else:
362
+ fname = os.path.basename(path_to_crop)
363
+ base = os.path.dirname(os.path.dirname(path_to_crop))
364
+ path_to_original = os.path.join(base, 'Original_Images', fname)
365
+ sheet.cell(row=next_row, column=i, value=path_to_original)
366
+ elif header.value == "path_to_content":
367
+ sheet.cell(row=next_row, column=i, value=path_to_content)
368
+ elif header.value == "path_to_helper":
369
+ sheet.cell(row=next_row, column=i, value=path_to_helper)
370
+ elif header.value == "tokens_in":
371
+ sheet.cell(row=next_row, column=i, value=nt_in)
372
+ elif header.value == "tokens_out":
373
+ sheet.cell(row=next_row, column=i, value=nt_out)
374
+ elif header.value == "filename":
375
+ sheet.cell(row=next_row, column=i, value=filename_without_extension)
376
+ elif header.value == "prompt":
377
+ sheet.cell(row=next_row, column=i, value=os.path.basename(self.path_custom_prompts))
378
+ elif header.value == "run_name":
379
+ sheet.cell(row=next_row, column=i, value=Dirs.run_name)
380
+ elif header.value == "LM2_collage":
381
+ sheet.cell(row=next_row, column=i, value=self.cfg['leafmachine']['use_RGB_label_images'])
382
+ elif header.value == "OCR_method":
383
+ value_to_insert = self.cfg['leafmachine']['project']['OCR_option']
384
+ if isinstance(value_to_insert, list):
385
+ value_to_insert = '|'.join(map(str, value_to_insert))
386
+ sheet.cell(row=next_row, column=i, value=value_to_insert)
387
+ elif header.value == "OCR_double":
388
+ sheet.cell(row=next_row, column=i, value=self.cfg['leafmachine']['project']['double_OCR'])
389
+ elif header.value == "OCR_trOCR":
390
+ sheet.cell(row=next_row, column=i, value=self.cfg['leafmachine']['project']['do_use_trOCR'])
391
+ # "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
392
+ elif header.value in self.wfo_headers_no_lists:
393
+ sheet.cell(row=next_row, column=i, value=WFO_record.get(header.value, ''))
394
+ # elif header.value == "WFO_exact_match":
395
+ # sheet.cell(row=next_row, column=i, value= WFO_record.get("WFO_exact_match",''))
396
+ # elif header.value == "WFO_exact_match_name":
397
+ # sheet.cell(row=next_row, column=i, value= WFO_record.get("WFO_exact_match_name",''))
398
+ # elif header.value == "WFO_best_match":
399
+ # sheet.cell(row=next_row, column=i, value= WFO_record.get("WFO_best_match",''))
400
+ # elif header.value == "WFO_placement":
401
+ # sheet.cell(row=next_row, column=i, value= WFO_record.get("WFO_placement",''))
402
+ elif header.value == "WFO_candidate_names":
403
+ candidate_names = WFO_record.get("WFO_candidate_names", '')
404
+ # Check if candidate_names is a list and convert to a string if it is
405
+ if isinstance(candidate_names, list):
406
+ candidate_names_str = '|'.join(candidate_names)
407
+ else:
408
+ candidate_names_str = candidate_names
409
+ sheet.cell(row=next_row, column=i, value=candidate_names_str)
410
+
411
+ # "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat", "GEO_decimal_long",
412
+ # "GEO_city", "GEO_county", "GEO_state", "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent"
413
+ elif header.value in self.geo_headers:
414
+ sheet.cell(row=next_row, column=i, value=GEO_record.get(header.value, ''))
415
+
416
+ elif header.value in self.usage_headers:
417
+ sheet.cell(row=next_row, column=i, value=usage_report.get(header.value, ''))
418
+
419
+ elif header.value == "LLM":
420
+ sheet.cell(row=next_row, column=i, value=MODEL_NAME_FORMATTED)
421
+
422
+ # save the workbook
423
+ wb.save(path_transcription)
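
add_data_to_excel_from_response writes one row per specimen by walking the header row and pulling matching keys from the response; headers with no match are simply left blank. A condensed, runnable sketch of that header-driven pattern (illustrative headers and response):

from openpyxl import Workbook

headers = ['catalogNumber', 'scientificName', 'tokens_in']
response = {'scientificName': 'Acer rubrum', 'tokens_in': 512}

wb = Workbook()
sheet = wb.active
for i, h in enumerate(headers, start=1):
    sheet.cell(row=1, column=i, value=h)

next_row = sheet.max_row + 1
for i, header in enumerate(sheet[1], start=1):
    # Unmatched headers stay blank, as in the method above
    sheet.cell(row=next_row, column=i, value=response.get(header.value, ''))
wb.save('transcribed.xlsx')
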
424
+
425
+
426
+ def has_API_key(self, val):
427
+ return isinstance(val, str) and bool(val.strip())
428
+ # if val != '':
429
+ # return True
430
+ # else:
431
+ # return False
432
+
433
+
434
+ def get_google_credentials(self): # Also used for google drive
435
+ if self.is_hf:
436
+ creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
437
+ credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
438
+ return credentials
439
+ else:
440
+ with open(self.cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS'], 'r') as file:
441
+ data = json.load(file)
442
+ creds_json_str = json.dumps(data)
443
+ credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
444
+ os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = creds_json_str
445
+ return credentials
+
+
+     def set_API_keys(self):
+         if self.is_hf:
+             self.dir_home = os.path.dirname(os.path.dirname(__file__))
+             self.path_cfg_private = None
+             self.cfg_private = None
+
+             k_openai = os.getenv('OPENAI_API_KEY')
+             k_openai_azure = os.getenv('AZURE_API_VERSION')
+
+             k_google_project_id = os.getenv('GOOGLE_PROJECT_ID')
+             k_google_location = os.getenv('GOOGLE_LOCATION')
+             k_google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
+
+             k_mistral = os.getenv('MISTRAL_API_KEY')
+             k_here = os.getenv('HERE_API_KEY')
+             k_opencage = os.getenv('open_cage_geocode')
+         else:
+             self.dir_home = os.path.dirname(os.path.dirname(__file__))
+             self.path_cfg_private = os.path.join(self.dir_home, 'PRIVATE_DATA.yaml')
+             self.cfg_private = get_cfg_from_full_path(self.path_cfg_private)
+
+             k_openai = self.cfg_private['openai']['OPENAI_API_KEY']
+             k_openai_azure = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE']
+
+             k_google_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
+             k_google_location = self.cfg_private['google']['GOOGLE_LOCATION']
+             k_google_application_credentials = self.cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS']
+
+             k_mistral = self.cfg_private['mistral']['MISTRAL_API_KEY']
+             k_here = self.cfg_private['here']['API_KEY']
+             k_opencage = self.cfg_private['open_cage_geocode']['API_KEY']
+
+         self.has_key_openai = self.has_API_key(k_openai)
+         self.has_key_azure_openai = self.has_API_key(k_openai_azure)
+         self.llm = None
+
+         self.has_key_google_project_id = self.has_API_key(k_google_project_id)
+         self.has_key_google_location = self.has_API_key(k_google_location)
+         self.has_key_google_application_credentials = self.has_API_key(k_google_application_credentials)
+
+         self.has_key_mistral = self.has_API_key(k_mistral)
+         self.has_key_here = self.has_API_key(k_here)
+         self.has_key_open_cage_geocode = self.has_API_key(k_opencage)
+
+         ### Google - OCR, Palm2, Gemini
+         if self.has_key_google_application_credentials and self.has_key_google_project_id and self.has_key_google_location:
+             if self.is_hf:
+                 vertexai.init(project=os.getenv('GOOGLE_PROJECT_ID'), location=os.getenv('GOOGLE_LOCATION'), credentials=self.get_google_credentials())
+             else:
+                 vertexai.init(project=k_google_project_id, location=k_google_location, credentials=self.get_google_credentials())
+                 os.environ['GOOGLE_API_KEY'] = self.cfg_private['google']['GOOGLE_PALM_API']
+
+         ### OpenAI
+         if self.has_key_openai:
+             if self.is_hf:
+                 openai.api_key = os.getenv('OPENAI_API_KEY')
+             else:
+                 openai.api_key = self.cfg_private['openai']['OPENAI_API_KEY']
+                 os.environ["OPENAI_API_KEY"] = self.cfg_private['openai']['OPENAI_API_KEY']
+
+         ### OpenAI - Azure
+         if self.has_key_azure_openai:
+             if self.is_hf:
+                 # Initialize the Azure OpenAI client
+                 self.llm = AzureChatOpenAI(
+                     deployment_name='gpt-35-turbo',
+                     openai_api_version=os.getenv('AZURE_API_VERSION'),
+                     openai_api_key=os.getenv('AZURE_API_KEY'),
+                     azure_endpoint=os.getenv('AZURE_API_BASE'),
+                     openai_organization=os.getenv('AZURE_ORGANIZATION'),
+                 )
+             else:
+                 # Initialize the Azure OpenAI client
+                 self.llm = AzureChatOpenAI(
+                     deployment_name='gpt-35-turbo',
+                     openai_api_version=self.cfg_private['openai_azure']['OPENAI_API_VERSION'],
+                     openai_api_key=self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE'],
+                     azure_endpoint=self.cfg_private['openai_azure']['OPENAI_API_BASE'],
+                     openai_organization=self.cfg_private['openai_azure']['OPENAI_ORGANIZATION'],
+                 )
+
+         ### Mistral
+         if self.has_key_mistral:
+             if not self.is_hf:  # Already set in the HF environment
+                 os.environ['MISTRAL_API_KEY'] = self.cfg_private['mistral']['MISTRAL_API_KEY']
+
+         ### HERE
+         if self.has_key_here:
+             if not self.is_hf:  # Already set in the HF environment
+                 os.environ['HERE_APP_ID'] = self.cfg_private['here']['APP_ID']
+                 os.environ['HERE_API_KEY'] = self.cfg_private['here']['API_KEY']
+
+         ### OpenCage Geocode
+         if self.has_key_open_cage_geocode:
+             if not self.is_hf:  # Already set in the HF environment
+                 os.environ['OPENCAGE_API_KEY'] = self.cfg_private['open_cage_geocode']['API_KEY']
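For reference, a minimal standalone sketch of a pre-flight check over the same environment variable names that `set_API_keys()` reads on the Hugging Face path. The variable names are taken from the code above; treating all of them as required is an assumption, since which services are needed depends on the deployment.

```python
# Minimal sketch (assumption: all keys below are wanted). Reports which of
# the environment variables read by set_API_keys() are missing or blank.
import os

EXPECTED = ['OPENAI_API_KEY', 'AZURE_API_VERSION', 'GOOGLE_PROJECT_ID',
            'GOOGLE_LOCATION', 'GOOGLE_APPLICATION_CREDENTIALS',
            'MISTRAL_API_KEY', 'HERE_API_KEY', 'open_cage_geocode']
missing = [k for k in EXPECTED if not (os.getenv(k) or '').strip()]
print('Missing keys:', ', '.join(missing) if missing else 'none')
```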
+
+
+     def clean_catalog_number(self, data, filename_without_extension):
+         # Cleans up the catalog number in data if it's a dict
+
+         def modify_catalog_key(catalog_key, filename_without_extension, data):
+             # Helper function to apply modifications on the catalog number
+             if catalog_key not in data:
+                 new_data = {catalog_key: None}
+                 data = {**new_data, **data}
+
+             if self.prefix_removal:
+                 filename_without_extension = filename_without_extension.replace(self.prefix_removal, "")
+             if self.suffix_removal:
+                 filename_without_extension = filename_without_extension.replace(self.suffix_removal, "")
+             if self.catalog_numerical_only:
+                 filename_without_extension = self.remove_non_numbers(data[catalog_key])
+             data[catalog_key] = filename_without_extension
+             return data
+
+         if isinstance(data, dict):
+             if self.headers_used == 'HEADERS_v1_n22':
+                 return modify_catalog_key("Catalog Number", filename_without_extension, data)
+             elif self.headers_used in ['HEADERS_v2_n26', 'CUSTOM']:
+                 return modify_catalog_key("filename", filename_without_extension, data)
+             else:
+                 raise ValueError("Invalid headers used.")
+         else:
+             raise TypeError("Data is not of type dict.")
+
+
+     def write_json_to_file(self, filepath, data):
+         '''Writes dictionary data to a JSON file.'''
+         with open(filepath, 'w') as txt_file:
+             if isinstance(data, dict):
+                 data = json.dumps(data, indent=4, sort_keys=False)
+             txt_file.write(data)
+
+
+     def remove_non_numbers(self, s):
+         return ''.join([char for char in s if char.isdigit()])
+
+
+     def create_null_row(self, filename_without_extension, path_to_crop, path_to_content, path_to_helper):
+         # Every header defaults to ''; only the path/filename fields are filled in.
+         # The WFO/GEO headers intentionally stay empty for a null row.
+         json_dict = {header: '' for header in self.headers}
+         for header in json_dict:
+             if header == "path_to_crop":
+                 json_dict[header] = path_to_crop
+             elif header == "path_to_original":
+                 fname = os.path.basename(path_to_crop)
+                 base = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(path_to_crop))))
+                 path_to_original = os.path.join(base, 'Original_Images', fname)
+                 json_dict[header] = path_to_original
+             elif header == "path_to_content":
+                 json_dict[header] = path_to_content
+             elif header == "path_to_helper":
+                 json_dict[header] = path_to_helper
+             elif header == "filename":
+                 json_dict[header] = filename_without_extension
+         return json_dict
+
+
+     ##################################################################################################################################
+     ##################################################        OCR        ##############################################################
+     ##################################################################################################################################
+     def perform_OCR_and_save_results(self, image_index, json_report, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds):
+         self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Starting OCR')
+
+         ### process_image() runs the OCR for text, handwriting, trOCR AND creates the overlay image
+         ocr_google = OCREngine(self.logger, json_report, self.dir_home, self.is_hf, self.path_to_crop, self.cfg, self.trOCR_model_version, self.trOCR_model, self.trOCR_processor, self.device)
+         ocr_google.process_image(self.do_create_OCR_helper_image, self.logger)
+         self.OCR = ocr_google.OCR
+         self.logger.info(f"Complete OCR text for LLM prompt:\n\n{self.OCR}\n\n")
+
+         self.write_json_to_file(txt_file_path_OCR, ocr_google.OCR_JSON_to_file)
+
+         self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Finished OCR')
+
+         if len(self.OCR) > 0:
+             ocr_google.overlay_image.save(jpg_file_path_OCR_helper)
+
+             OCR_bounds = {}
+             if ocr_google.hand_text_to_box_mapping is not None:
+                 OCR_bounds['OCR_bounds_handwritten'] = ocr_google.hand_text_to_box_mapping
+
+             if ocr_google.normal_text_to_box_mapping is not None:
+                 OCR_bounds['OCR_bounds_printed'] = ocr_google.normal_text_to_box_mapping
+
+             if ocr_google.trOCR_text_to_box_mapping is not None:
+                 OCR_bounds['OCR_bounds_trOCR'] = ocr_google.trOCR_text_to_box_mapping
+
+             self.write_json_to_file(txt_file_path_OCR_bounds, OCR_bounds)
+             self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Saved OCR Overlay Image')
+         else:
+             pass  # TODO: handle the case where OCR returns no text
+
+     ##################################################################################################################################
+     #######################################################  LLM Switchboard  ########################################################
+     ##################################################################################################################################
+     def send_to_LLM(self, is_azure, progress_report, json_report, model_name):
+         self.n_failed_LLM_calls = 0
+         self.n_failed_OCR = 0
+
+         final_JSON_response = None
+         final_WFO_record = None
+         final_GEO_record = None
+
+         self.initialize_token_counters()
+         self.update_progress_report_initial(progress_report)
+
+         MODEL_NAME_FORMATTED = ModelMaps.get_API_name(model_name)
+         name_parts = model_name.split("_")
+
+         self.setup_JSON_dict_structure()
+
+         Copy_Prompt = PromptCatalog()
+         Copy_Prompt.copy_prompt_template_to_new_dir(self.Dirs.transcription_prompt, self.path_custom_prompts)
+
+         if json_report:
+             json_report.set_text(text_main=f'Loading {MODEL_NAME_FORMATTED}')
+             json_report.set_JSON({}, {}, {})
+         # NOTE: llm_model is now initialized inside each worker (see send_to_LLM_worker)
+
+         results_queue = Queue()
+
+         if json_report:
+             json_report.set_text(text_main='Sending batch to OCR and LLM')
+
+         num_files = len(self.img_paths)
+         num_threads = min(num_files, 128)  # Cap the thread pool at 128 workers
+         counter = AtomicCounter()
+
+         # Setup for parallel execution
+         with ThreadPoolExecutor(max_workers=num_threads) as executor:
+             futures = [executor.submit(self.send_to_LLM_worker,
+                                        path_to_crop,
+                                        results_queue,
+                                        model_name,
+                                        MODEL_NAME_FORMATTED,
+                                        name_parts,
+                                        is_azure,
+                                        i
+                                        ) for i, path_to_crop in enumerate(self.img_paths)]
+             for future in tqdm(as_completed(futures), total=len(futures), desc="Processing", unit="task"):
+                 try:
+                     future.result()  # Forces a wait on the future and re-raises any worker exceptions
+                     current_value = counter.inc()
+                     try:
+                         if json_report:
+                             json_report.set_text(text_main=f'Completed {current_value} of {num_files}')
+                     except Exception:
+                         pass
+                 except Exception as e:
+                     # Log the error, possibly mark the task for retry, or handle it as necessary
+                     print(f"A task failed with exception: {e}")
+
+         # Process results from the queue
+         while not results_queue.empty():
+             response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report, path_to_crop, paths = results_queue.get()
+
+             self.n_failed_LLM_calls += 1 if response_candidate is None else 0
+
+             ### Estimate n tokens returned
+             self.logger.info(f'Prompt tokens IN --- {nt_in}')
+             self.logger.info(f'Prompt tokens OUT --- {nt_out}')
+
+             self.update_token_counters(nt_in, nt_out)
+
+             final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out)
+
+             self.logger.info('Finished LLM call')
+
+             if json_report:
+                 json_report.set_JSON(final_JSON_response, final_WFO_record, final_GEO_record)
+
+         if json_report:
+             json_report.set_text(text_main='Finished!')
+
+         self.update_progress_report_final(progress_report)
+         final_JSON_response = self.parse_final_json_response(final_JSON_response)
+         return final_JSON_response, final_WFO_record, final_GEO_record, self.total_tokens_in, self.total_tokens_out
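The method above fans work out through a `ThreadPoolExecutor` and collects results through a `Queue` instead of letting workers mutate shared state. A minimal standalone sketch of that submit-then-drain pattern (toy worker, not the VoucherVision worker):

```python
# Minimal sketch of the submit-then-drain pattern used by send_to_LLM():
# workers push result tuples onto a Queue; the main thread waits on every
# future (re-raising worker exceptions), then drains the queue serially.
from concurrent.futures import ThreadPoolExecutor, as_completed
from queue import Queue

results = Queue()

def worker(i, q):
    q.put((i, i * i))  # stand-in for (response_candidate, nt_in, ...)

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(worker, i, results) for i in range(8)]
    for future in as_completed(futures):
        future.result()  # re-raises any exception from the worker

while not results.empty():
    print(results.get())
```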
+
+     def send_to_LLM_worker(self, path_to_crop, queue, model_name, MODEL_NAME_FORMATTED, name_parts, is_azure, i):
+         llm_model = self.initialize_llm_model(self.cfg, self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm, self.config_vals_for_permutation)
+
+         if self.should_skip_specimen(path_to_crop):
+             self.log_skipping_specimen(path_to_crop)
+             return
+
+         paths = self.generate_paths(path_to_crop, i)
+         self.path_to_crop = path_to_crop
+
+         filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
+         self.perform_OCR_and_save_results(i, None, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
+
+         if not self.OCR:
+             # No OCR text: record a failed call with empty records so the
+             # queue.put() below does not reference unassigned names
+             self.n_failed_OCR += 1
+             response_candidate = None
+             nt_in = 0
+             nt_out = 0
+             WFO_record = None
+             GEO_record = None
+             usage_report = None
+         else:
+             ### Format prompt
+             prompt = self.setup_prompt()
+             # NOTE: remove_colons_and_double_apostrophes() was moved to utils_VV since it broke the JSON structure
+
+             ### Send prompt to chosen LLM
+             self.logger.info(f'Waiting for {model_name} API call --- Using {MODEL_NAME_FORMATTED}')
+
+             if 'PALM2' in name_parts:
+                 response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_GooglePalm2(prompt, None, paths)
+
+             elif 'GEMINI' in name_parts:
+                 response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_GoogleGemini(prompt, None, paths)
+
+             elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
+                 response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_MistralAI(prompt, None, paths)
+
+             elif 'LOCAL' in name_parts:
+                 if 'MISTRAL' in name_parts or 'MIXTRAL' in name_parts:
+                     if 'CPU' in name_parts:
+                         response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_cpu_MistralAI(prompt, None, paths)
+                     else:
+                         response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_MistralAI(prompt, None, paths)
+             else:
+                 response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_OpenAI(prompt, None, paths)
+
+         # Instead of directly updating shared resources, put the structured result in the queue
+         queue.put((response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report, path_to_crop, paths))
+
+     ##################################################################################################################################
+     ##################################################  LLM Helper Funcs  ############################################################
+     ##################################################################################################################################
+     def initialize_llm_model(self, cfg, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None, config_vals_for_permutation=None):
+         if 'LOCAL' in name_parts:
+             if ('MIXTRAL' in name_parts) or ('MISTRAL' in name_parts):
+                 if 'CPU' in name_parts:
+                     return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
+                 else:
+                     return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
+         else:
+             if 'PALM2' in name_parts:
+                 return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
+             elif 'GEMINI' in name_parts:
+                 return GoogleGeminiHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
+             elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
+                 return MistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
+             else:
+                 return OpenAIHandler(cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object, config_vals_for_permutation)
+
+     def setup_prompt(self):
+         Catalog = PromptCatalog()
+         prompt, _ = Catalog.prompt_SLTP(self.path_custom_prompts, OCR=self.OCR)
+         return prompt
+
+     def setup_JSON_dict_structure(self):
+         Catalog = PromptCatalog()
+         _, self.JSON_dict_structure = Catalog.prompt_SLTP(self.path_custom_prompts, OCR='Text')
+
+
+     def initialize_token_counters(self):
+         self.total_tokens_in = 0
+         self.total_tokens_out = 0
+
+
+     def update_progress_report_initial(self, progress_report):
+         if progress_report is not None:
+             progress_report.set_n_batches(len(self.img_paths))
+
+
+     def update_progress_report_batch(self, progress_report, batch_index):
+         if progress_report is not None:
+             progress_report.update_batch(f"Working on image {batch_index + 1} of {len(self.img_paths)}")
+
+
+     def should_skip_specimen(self, path_to_crop):
+         return os.path.basename(path_to_crop) in self.completed_specimens
+
+
+     def log_skipping_specimen(self, path_to_crop):
+         self.logger.info(f'[Skipping] specimen {os.path.basename(path_to_crop)} already processed')
+
+
+     def update_token_counters(self, nt_in, nt_out):
+         self.total_tokens_in += nt_in
+         self.total_tokens_out += nt_out
+
+
+     def update_final_response(self, response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out):
+         filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
+         # Save the JSON and XLSX files with the response and update the final JSON response.
+         # save_json_and_xlsx() already handles the response_candidate=None case, so one call covers both branches.
+         final_JSON_response_updated = self.save_json_and_xlsx(self.Dirs, response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
+         return final_JSON_response_updated, WFO_record, GEO_record
+
+
+     def update_progress_report_final(self, progress_report):
+         if progress_report is not None:
+             progress_report.reset_batch("Batch Complete")
+
+
+     def parse_final_json_response(self, final_JSON_response):
+         try:
+             return json.loads(final_JSON_response.strip('```').replace('json\n', '', 1).replace('json', '', 1))
+         except Exception:
+             return final_JSON_response
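`parse_final_json_response()` only strips a surrounding code fence and the `json` language tag before parsing; a quick illustration of the transformation it performs:

```python
# Minimal sketch of the fence-stripping done by parse_final_json_response():
# an LLM reply wrapped as ```json ... ``` is reduced to parseable JSON.
import json

raw = '```json\n{"catalogNumber": "1439824"}\n```'
cleaned = raw.strip('`').replace('json\n', '', 1).replace('json', '', 1)
print(json.loads(cleaned))  # {'catalogNumber': '1439824'}
```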
+
+
+     def generate_paths(self, path_to_crop, i):
+         filename_without_extension = os.path.splitext(os.path.basename(path_to_crop))[0]
+         txt_file_path = os.path.join(self.Dirs.transcription_ind, filename_without_extension + '.json')
+         txt_file_path_OCR = os.path.join(self.Dirs.transcription_ind_OCR, filename_without_extension + '.json')
+         txt_file_path_OCR_bounds = os.path.join(self.Dirs.transcription_ind_OCR_bounds, filename_without_extension + '.json')
+         jpg_file_path_OCR_helper = os.path.join(self.Dirs.transcription_ind_OCR_helper, filename_without_extension + '.jpg')
+         json_file_path_wiki = os.path.join(self.Dirs.transcription_ind_wiki, filename_without_extension + '.json')
+         txt_file_path_ind_prompt = os.path.join(self.Dirs.transcription_ind_prompt, filename_without_extension + '.txt')
+
+         self.logger.info(f'Working on {i+1}/{len(self.img_paths)} --- {filename_without_extension}')
+
+         return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt
+
+
+     def save_json_and_xlsx(self, Dirs, response, WFO_record, GEO_record, usage_report,
+                            MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
+         if response is None:
+             response = self.JSON_dict_structure
+             # Insert 'filename' as the first key
+             response = {'filename': filename_without_extension, **{k: v for k, v in response.items() if k != 'filename'}}
+             self.write_json_to_file(txt_file_path, response)
+
+             # Then add the null info to the spreadsheet
+             response_null = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
+             self.add_data_to_excel_from_response(Dirs, self.path_transcription, response_null, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)
+
+         ### Set completed JSON
+         else:
+             response = self.clean_catalog_number(response, filename_without_extension)
+             self.write_json_to_file(txt_file_path, response)
+             # Add to the XLSX file
+             self.add_data_to_excel_from_response(Dirs, self.path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
+         return response
+
+
+     def process_specimen_batch(self, progress_report, json_report, is_real_run=False):
+         if not self.has_key:
+             self.logger.error(f'No API key found for {self.version_name}')
+             raise Exception(f"No API key found for {self.version_name}")
+
+         try:
+             if is_real_run:
+                 progress_report.update_overall("Transcribing Labels")
+
+             final_json_response, final_WFO_record, final_GEO_record, total_tokens_in, total_tokens_out = self.send_to_LLM(self.is_azure, progress_report, json_report, self.model_name)
+
+             return final_json_response, final_WFO_record, final_GEO_record, total_tokens_in, total_tokens_out
+
+         except Exception as e:
+             self.logger.error(f"LLM call failed in process_specimen_batch: {e}")
+             if progress_report is not None:
+                 progress_report.reset_batch("Batch Failed")
+             self.close_logger_handlers()
+             raise
+
+
+     def close_logger_handlers(self):
+         for handler in self.logger.handlers[:]:
+             handler.close()
+             self.logger.removeHandler(handler)
+
+
+ # https://gist.github.com/benhoyt/8c8a8d62debe8e5aa5340373f9c509c7
+ class AtomicCounter(object):
+     """An atomic, thread-safe counter"""
+
+     def __init__(self, initial=0):
+         """Initialize a new atomic counter to the given initial value"""
+         self._value = initial
+         self._lock = threading.Lock()
+
+     def inc(self, num=1):
+         """Atomically increment the counter by num and return the new value"""
+         with self._lock:
+             self._value += num
+             return self._value
+
+     def dec(self, num=1):
+         """Atomically decrement the counter by num and return the new value"""
+         with self._lock:
+             self._value -= num
+             return self._value
+
+     @property
+     def value(self):
+         return self._value
+
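A minimal usage sketch for `AtomicCounter`, mirroring how `send_to_LLM()` uses it for progress counts:

```python
# Minimal usage sketch for AtomicCounter (defined above): many threads
# incrementing one shared counter without losing updates.
from concurrent.futures import ThreadPoolExecutor

counter = AtomicCounter()
with ThreadPoolExecutor(max_workers=8) as pool:
    for _ in range(1000):
        pool.submit(counter.inc)
print(counter.value)  # 1000
```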
+
+ def space_saver(cfg, Dirs, logger):
+     dir_out = cfg['leafmachine']['project']['dir_output']
+     run_name = Dirs.run_name
+
+     path_project = os.path.join(dir_out, run_name)
+
+     if cfg['leafmachine']['project']['delete_temps_keep_VVE']:
+         logger.name = '[DELETE TEMP FILES]'
+         logger.info("Deleting temporary files. Keeping files required for VoucherVisionEditor.")
+         delete_dirs = ['Archival_Components', 'Config_File']
+         for d in delete_dirs:
+             path_delete = os.path.join(path_project, d)
+             if os.path.exists(path_delete):
+                 shutil.rmtree(path_delete)
+
+     elif cfg['leafmachine']['project']['delete_all_temps']:
+         logger.name = '[DELETE TEMP FILES]'
+         logger.info("Deleting ALL temporary files!")
+         delete_dirs = ['Archival_Components', 'Config_File', 'Original_Images', 'Cropped_Images']
+         for d in delete_dirs:
+             path_delete = os.path.join(path_project, d)
+             if os.path.exists(path_delete):
+                 shutil.rmtree(path_delete)
+
+         # Delete the Transcription folder, but keep the xlsx
+         transcription_path = os.path.join(path_project, 'Transcription')
+         if os.path.exists(transcription_path):
+             for item in os.listdir(transcription_path):
+                 item_path = os.path.join(transcription_path, item)
+                 if os.path.isdir(item_path):  # If the item is a directory, delete it
+                     shutil.rmtree(item_path)
vouchervision/vouchervision_main.py CHANGED
@@ -14,6 +14,7 @@ from vouchervision.data_project import Project_Info
  from vouchervision.LM2_logger import start_logging
  from vouchervision.fetch_data import fetch_data
  from vouchervision.utils_VoucherVision import VoucherVision, space_saver
+ # from vouchervision.utils_VoucherVision_parallel import VoucherVision, space_saver
  from vouchervision.utils_hf import upload_to_drive
 
  def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progress_report, json_report, path_api_cost=None, test_ind = None, is_hf = True, is_real_run=False):
vouchervision/vouchervision_test_all_options_analysis.py ADDED
@@ -0,0 +1,103 @@
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+
+ def SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_GPT4_SHORT():
+     #####################
+     # Load the Excel file
+     file_path = 'D:/Dropbox/VoucherVision/demo/validation_output/summary/SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_GPT4_SHORT.xlsx'
+     save_path = 'D:/Dropbox/VoucherVision/demo/validation_output/figures/avg_L_score_analysis_SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_GPT4_SHORT.png'
+
+     df = pd.read_excel(file_path)
+
+     # Print the first few rows of the dataframe to understand its structure
+     print(df.head())
+
+     # Group by the parameters and calculate the mean avg_L_score for each group
+     grouped = df.groupby(['v_prompt_version', 'v_double_ocr', 'temperature', 'top_p'])['avg_L_score'].mean().reset_index()
+
+     # Find the group with the highest average L score
+     max_avg_L_score = grouped['avg_L_score'].max()
+     best_group = grouped[grouped['avg_L_score'] == max_avg_L_score]
+
+     print(best_group)
+
+     ### Viz
+     # Filter with a single combined boolean mask; chaining df[mask1][mask2]
+     # misaligns the second mask against the already-filtered frame
+     filtered_df = df[(df['v_prompt_version'] == 'SLTPvB_long.yaml') & (df['v_double_ocr'] == True)]
+
+     # Set up the plotting
+     plt.figure(figsize=(14, 6))
+
+     # Plot 1: avg_L_score as a function of temperature for each top_p value
+     plt.subplot(1, 2, 1)
+     sns.lineplot(data=filtered_df, x='temperature', y='avg_L_score', hue='top_p', marker='o')
+     plt.title('Average L Score by Temperature for each Top P')
+     plt.xlabel('Temperature')
+     plt.ylabel('Average L Score')
+     plt.legend(title='Top P', bbox_to_anchor=(1.05, 1), loc='upper left')
+
+     # Plot 2: avg_L_score as a function of top_p for each temperature value
+     plt.subplot(1, 2, 2)
+     sns.lineplot(data=filtered_df, x='top_p', y='avg_L_score', hue='temperature', marker='o')
+     plt.title('Average L Score by Top P for each Temperature')
+     plt.xlabel('Top P')
+     plt.ylabel('Average L Score')
+     plt.legend(title='Temperature', bbox_to_anchor=(1.05, 1), loc='upper left')
+
+     plt.tight_layout()
+     plt.savefig(save_path, dpi=600)
+
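The filtering above deliberately uses one combined boolean mask: chaining `df[mask1][mask2]` with masks built from the original frame triggers pandas index-alignment warnings and can select the wrong rows. A minimal sketch of the idiom on toy data (the column names mirror the summary spreadsheet schema this script assumes):

```python
# Minimal sketch of combined-mask filtering on toy data; column names
# mirror the summary spreadsheet this script assumes.
import pandas as pd

toy = pd.DataFrame({'v_prompt_version': ['SLTPvB_long.yaml', 'SLTPvB_short.yaml'],
                    'v_double_ocr': [True, True],
                    'avg_L_score': [0.91, 0.84]})
mask = (toy['v_prompt_version'] == 'SLTPvB_long.yaml') & toy['v_double_ocr']
print(toy[mask])
```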
+ def SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_google_SHORT():
+     #####################
+     # Load the Excel file
+     file_path = 'D:/Dropbox/VoucherVision/demo/validation_output/summary/SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_google_SHORT.xlsx'
+     save_path = 'D:/Dropbox/VoucherVision/demo/validation_output/figures/avg_L_score_analysis_SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_google_SHORT.png'
+
+     df = pd.read_excel(file_path)
+
+     # Print the first few rows of the dataframe to understand its structure
+     print(df.head())
+
+     # Group by the parameters and calculate the mean avg_L_score for each group
+     grouped = df.groupby(['v_prompt_version', 'v_double_ocr', 'temperature', 'top_p'])['avg_L_score'].mean().reset_index()
+
+     # Find the group with the highest average L score
+     max_avg_L_score = grouped['avg_L_score'].max()
+     best_group = grouped[grouped['avg_L_score'] == max_avg_L_score]
+
+     print(best_group)
+
+     ### Viz
+     # Filter with a single combined boolean mask (see note above)
+     filtered_df = df[(df['v_prompt_version'] == 'SLTPvB_long.yaml') & (df['v_double_ocr'] == True)]
+
+     # Set up the plotting
+     plt.figure(figsize=(14, 6))
+
+     # Plot 1: avg_L_score as a function of temperature for each top_p value
+     plt.subplot(1, 2, 1)
+     sns.lineplot(data=filtered_df, x='temperature', y='avg_L_score', hue='top_p', marker='o')
+     plt.title('Average L Score by Temperature for each Top P')
+     plt.xlabel('Temperature')
+     plt.ylabel('Average L Score')
+     plt.legend(title='Top P', bbox_to_anchor=(1.05, 1), loc='upper left')
+
+     # Plot 2: avg_L_score as a function of top_p for each temperature value
+     plt.subplot(1, 2, 2)
+     sns.lineplot(data=filtered_df, x='top_p', y='avg_L_score', hue='temperature', marker='o')
+     plt.title('Average L Score by Top P for each Temperature')
+     plt.xlabel('Top P')
+     plt.ylabel('Average L Score')
+     plt.legend(title='Temperature', bbox_to_anchor=(1.05, 1), loc='upper left')
+
+     plt.tight_layout()
+     plt.savefig(save_path, dpi=600)
+
+ if __name__ == '__main__':
+     # SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_GPT4_SHORT()
+     SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_google_SHORT()
vouchervision/vouchervision_test_all_options_recipes.py ADDED
@@ -0,0 +1,170 @@
+ import os, inspect, sys, shutil
2
+
3
+
4
+
5
+ class AllOptions():
6
+ a_llm = [
7
+ "GPT 4 Turbo 1106-preview",
8
+ "GPT 4 Turbo 0125-preview",
9
+ 'GPT 4',
10
+ 'GPT 4 32k',
11
+ 'GPT 3.5',
12
+ 'GPT 3.5 Instruct',
13
+
14
+ 'Azure GPT 3.5',
15
+ 'Azure GPT 3.5 Instruct',
16
+ 'Azure GPT 4',
17
+ 'Azure GPT 4 Turbo 1106-preview',
18
+ 'Azure GPT 4 Turbo 0125-preview',
19
+ 'Azure GPT 4 32k',
20
+
21
+ 'PaLM 2 text-bison@001',
22
+ 'PaLM 2 text-bison@002',
23
+ 'PaLM 2 text-unicorn@001',
24
+ 'Gemini Pro',
25
+
26
+ 'Mistral Small',
27
+ 'Mistral Medium',
28
+ 'Mistral Large',
29
+ 'Open Mixtral 8x7B',
30
+ 'Open Mistral 7B',
31
+
32
+ 'LOCAL Mixtral 8x7B Instruct v0.1',
33
+ 'LOCAL Mistral 7B Instruct v0.2',
34
+
35
+ 'LOCAL CPU Mistral 7B Instruct v0.2 GGUF',
36
+ ]
37
+
38
+ a_prompt_version = [
39
+ 'SLTPvA_long.yaml',
40
+ 'SLTPvA_medium.yaml',
41
+ 'SLTPvA_short.yaml',
42
+ 'SLTPvB_long.yaml',
43
+ 'SLTPvB_medium.yaml',
44
+ 'SLTPvB_short.yaml',
45
+ 'SLTPvB_minimal.yaml',
46
+ ]
47
+
48
+ a_LM2 = [False,] # [True, False]
49
+ a_do_use_trOCR = [False,] # [True, False]
50
+ a_trocr_path = ["microsoft/trocr-large-handwritten",]
51
+ a_ocr_option = [
52
+ 'hand',
53
+ 'normal',
54
+ 'CRAFT',
55
+ 'LLaVA',
56
+ ['hand','CRAFT'],
57
+ ['hand','LLaVA'],
58
+ ]
59
+ a_llava_option = ["llava-v1.6-mistral-7b",
60
+ "llava-v1.6-34b",
61
+ "llava-v1.6-vicuna-13b",
62
+ "llava-v1.6-vicuna-7b",]
63
+ a_llava_bit = ["full", "4bit",]
64
+ a_double_ocr = [True, False]
65
+
66
+
67
+
68
+
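These option classes are plain cross-product grids; the driver that consumes them is not part of this commit's visible diff, so the following is only a minimal sketch of how such attribute lists combine with `itertools.product`:

```python
# Minimal sketch (the real driver script is not shown in this diff):
# sweep a cross product of the option lists defined above.
from itertools import product

for llm, prompt, double_ocr in product(AllOptions.a_llm[:2],
                                       AllOptions.a_prompt_version[:2],
                                       AllOptions.a_double_ocr):
    print(llm, prompt, double_ocr)
```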
+
+ class Options_permute_llms_to_investigate_determinism_at_restrictive_settings():
+     a_llm = [
+         # "GPT 4 Turbo 1106-preview",
+         # "GPT 4 Turbo 0125-preview",
+         # 'GPT 4',
+         # 'GPT 4 32k',
+         # 'GPT 3.5 Turbo',
+         # 'GPT 3.5 Instruct',
+
+         'Azure GPT 3.5 Turbo',
+         'Azure GPT 3.5 Instruct',
+         'Azure GPT 4',
+         'Azure GPT 4 Turbo 1106-preview',
+         'Azure GPT 4 Turbo 0125-preview',
+         # 'Azure GPT 4 32k',
+
+         'PaLM 2 text-bison@001',
+         'PaLM 2 text-bison@002',
+         'PaLM 2 text-unicorn@001',
+         'Gemini Pro',
+
+         'Mistral Small',
+         'Mistral Medium',
+         'Mistral Large',
+         # 'Open Mixtral 8x7B',
+         'Open Mistral 7B',
+
+         # 'LOCAL Mixtral 8x7B Instruct v0.1',
+         # 'LOCAL Mistral 7B Instruct v0.2',
+
+         # 'LOCAL CPU Mistral 7B Instruct v0.2 GGUF',
+     ]
+
+     a_prompt_version = [
+         # 'SLTPvA_long.yaml',
+         # 'SLTPvA_short.yaml',
+         'SLTPvB_long.yaml',
+         'SLTPvB_short.yaml',
+         'SLTPvB_minimal.yaml',
+     ]
+     a_double_ocr = [True, False]
+
+     ### BELOW ARE STATIC
+     a_LM2 = [False,]
+     a_do_use_trOCR = [False,]  # [True, False]
+     a_trocr_path = ["microsoft/trocr-large-handwritten",]  # ["microsoft/trocr-large-handwritten", "microsoft/trocr-base-handwritten",]
+     a_ocr_option = ['hand',]
+     a_llava_option = ["llava-v1.6-mistral-7b",]
+     a_llava_bit = ["full",]
+
+
+ class Options_permute_llms_to_sweep_temperature_and_topP_for_GPT4_0125():
+     a_llm = [
+         # 'Azure GPT 4 Turbo 0125-preview',  # test 1
+         'Azure GPT 4',
+     ]
+
+     a_prompt_version = [
+         # 'SLTPvA_long.yaml',
+         # 'SLTPvA_short.yaml',
+         'SLTPvB_long.yaml',
+         'SLTPvB_short.yaml',
+         # 'SLTPvB_minimal.yaml',
+     ]
+     a_double_ocr = [True, False]
+
+     ### BELOW ARE STATIC
+     a_LM2 = [False,]
+     a_do_use_trOCR = [False,]  # [True, False]
+     a_trocr_path = ["microsoft/trocr-large-handwritten",]
+     a_ocr_option = ['hand',]
+     a_llava_option = ["llava-v1.6-mistral-7b",]
+     a_llava_bit = ["full",]
+
+
+ class Options_permute_llms_to_sweep_temperature_and_topP_for_google():
+     a_llm = [
+         'PaLM 2 text-bison@001',
+         'PaLM 2 text-bison@002',
+         'Gemini Pro',
+     ]
+
+     a_prompt_version = [
+         'SLTPvB_long.yaml',
+         'SLTPvB_short.yaml',
+     ]
+     a_double_ocr = [True, False]
+
+     ### BELOW ARE STATIC
+     a_LM2 = [False,]
+     a_do_use_trOCR = [False,]  # [True, False]
+     a_trocr_path = ["microsoft/trocr-large-handwritten",]
+     a_ocr_option = ['hand',]
+     a_llava_option = ["llava-v1.6-mistral-7b",]
+     a_llava_bit = ["full",]
+
+
+ if __name__ == '__main__':
+     pass