phyloforfun committed
Commit ae215ea
1 Parent(s): 37a138a

Major update. Support for 15 LLMs, World Flora Online taxonomy validation, geolocation, 2 OCR methods, significant UI changes, stability improvements, consistent JSON parsing
app.py CHANGED
@@ -7,6 +7,7 @@ import pandas as pd
 from io import BytesIO
 from streamlit_extras.let_it_rain import rain
 from annotated_text import annotated_text
+from transformers import AutoConfig
 
 from vouchervision.LeafMachine2_Config_Builder import write_config_file
 from vouchervision.VoucherVision_Config_Builder import build_VV_config, TestOptionsGPT, TestOptionsPalm, check_if_usable
@@ -999,7 +1000,8 @@ def create_private_file():
     st.write("API keys are stored in `../VoucherVision/PRIVATE_DATA.yaml`.")
     st.write("Deleting this file will allow you to reset API keys. Alternatively, you can edit the keys in the user interface or by manually editing the `.yaml` file in a text editor.")
     st.write("Leave keys blank if you do not intend to use that service.")
-
+    st.info("Note: You can manually edit these API keys later by opening the /PRIVATE_DATA.yaml file in a plain text editor.")
+
     st.write("---")
     st.subheader("Google Vision (*Required*) / Google PaLM 2 / Google Gemini")
     st.markdown("VoucherVision currently uses [Google Vision API](https://cloud.google.com/vision/docs/ocr) for OCR. Generating an API key for this is more involved than the others. [Please carefully follow the instructions outlined here to create and setup your account.](https://cloud.google.com/vision/docs/setup) ")
@@ -1008,46 +1010,46 @@ def create_private_file():
     with st.expander("**View Google API Instructions**"):
 
         blog_text_and_image(text="Select your project, then in the search bar, search for `vertex ai` and select the option in the photo below.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_00.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_00.PNG'))
 
         blog_text_and_image(text="On the main overview page, click `Enable All Recommended APIs`. Sometimes this button may be hidden. In that case, enable all of the suggested APIs listed on this page.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_0.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_0.PNG'))
 
         blog_text_and_image(text="Sometimes this button may be hidden. In that case, enable all of the suggested APIs listed on this page.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_2.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_2.PNG'))
 
         blog_text_and_image(text="Make sure that all APIs are enabled.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_1.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_1.PNG'))
 
         blog_text_and_image(text="Find the `Vision AI API` service and go to its page.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_3.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_3.PNG'))
 
         blog_text_and_image(text="Find the `Vision AI API` service and go to its page. This is the API service required to use OCR in VoucherVision and must be enabled.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_6.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_6.PNG'))
 
         blog_text_and_image(text="You can also search for the Vertex AI Vision service.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_4.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_4.PNG'))
 
         blog_text_and_image(text=None,
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_5.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_5.PNG'))
 
         st.subheader("Getting a Google JSON authentication key")
         st.write("Google uses a JSON file to store additional authentication information. Save this file in a safe, private location and assign the `GOOGLE_APPLICATION_CREDENTIALS` value to the file path. For Hugging Face, copy the contents of the JSON file including the `\{\}` and paste it as the secret value.")
         st.write("To download your JSON key...")
         blog_text_and_image(text="Open the navigation menu. Click on the hamburger menu (three horizontal lines) in the top left corner. Go to IAM & Admin. ",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_7.png'),width=300)
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_7.PNG'),width=300)
 
         blog_text_and_image(text="In the navigation pane, hover over `IAM & Admin` and then click on `Service accounts`.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_8.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_8.PNG'))
 
         blog_text_and_image(text="Find the default Compute Engine service account, select it.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_9.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_9.PNG'))
 
         blog_text_and_image(text="Click `Add Key`.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_10.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_10.PNG'))
 
         blog_text_and_image(text="Select `JSON` and click create. This will download your key. Store this in a safe location. The file path to this safe location is the value that you enter into the `GOOGLE_APPLICATION_CREDENTIALS` value.",
-                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_11.png'))
+                            fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_11.PNG'))
 
         blog_text(text_bold="Store Safely", text=": This file contains sensitive data that can be used to authenticate and bill your Google Cloud account. Never commit it to public repositories or expose it in any way. Always keep it safe and secure.")
 
@@ -1135,21 +1137,24 @@ def create_private_file():
     st.write("---")
     st.subheader("HERE Geocoding")
     st.markdown('Follow these [instructions](https://platform.here.com/sign-up?step=verify-identity) to generate an API key for HERE.')
-    hre_APP_ID = st.text_input("HERE Geocoding App ID", cfg_private['here'].get('APP_ID', ''),
+    here_APP_ID = st.text_input("HERE Geocoding App ID", cfg_private['here'].get('APP_ID', ''),
                                help='e.g. a 32-character string',
                                placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
                                type='password')
-    hre_API_KEY = st.text_input("HERE Geocoding API Key", cfg_private['here'].get('API_KEY', ''),
+    here_API_KEY = st.text_input("HERE Geocoding API Key", cfg_private['here'].get('API_KEY', ''),
                                help='e.g. a 32-character string',
                                placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
                                type='password')
 
 
 
-    st.button("Set API Keys",type='primary', on_click=save_changes_to_API_keys, args=[cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key,
-                                                                                      azure_openai_api_base,azure_openai_organization,azure_openai_api_type,
-                                                                                      google_application_credentials, google_project_location, google_project_id,
-                                                                                      mistral_API_KEY, hre_APP_ID, hre_API_KEY])
+    st.button("Set API Keys",type='primary', on_click=save_changes_to_API_keys,
+              args=[cfg_private,
+                    openai_api_key,
+                    azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type,
+                    google_application_credentials, google_project_location, google_project_id,
+                    mistral_API_KEY,
+                    here_APP_ID, here_API_KEY])
     if st.button('Proceed to VoucherVision'):
         st.session_state.private_file = does_private_file_exist()
         st.session_state.proceed_to_private = False
@@ -1157,10 +1162,12 @@ def create_private_file():
         st.rerun()
 
 
-def save_changes_to_API_keys(cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key,
-                             azure_openai_api_base,azure_openai_organization,azure_openai_api_type,
-                             google_application_credentials, google_project_location, google_project_id,
-                             mistral_API_KEY, hre_APP_ID, hre_API_KEY):
+def save_changes_to_API_keys(cfg_private,
+                             openai_api_key,
+                             azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type,
+                             google_application_credentials, google_project_location, google_project_id,
+                             mistral_API_KEY,
+                             here_APP_ID, here_API_KEY):
 
     # Update the configuration dictionary with the new values
     cfg_private['openai']['OPENAI_API_KEY'] = openai_api_key
@@ -1172,15 +1179,16 @@ def save_changes_to_API_keys(cfg_private,openai_api_key,azure_openai_api_version
     cfg_private['openai_azure']['OPENAI_API_TYPE'] = azure_openai_api_type
 
     cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS'] = google_application_credentials
-    cfg_private['google']['GOOGLE_PROJECT_ID'] = google_project_location
-    cfg_private['google']['GOOGLE_LOCATION'] = google_project_id
+    cfg_private['google']['GOOGLE_PROJECT_ID'] = google_project_id
+    cfg_private['google']['GOOGLE_LOCATION'] = google_project_location
 
     cfg_private['mistral']['MISTRAL_API_KEY'] = mistral_API_KEY
 
-    cfg_private['here']['APP_ID'] = hre_APP_ID
-    cfg_private['here']['API_KEY'] = hre_API_KEY
+    cfg_private['here']['APP_ID'] = here_APP_ID
+    cfg_private['here']['API_KEY'] = here_API_KEY
     # Call the function to write the updated configuration to the YAML file
     write_config_file(cfg_private, st.session_state.dir_home, filename="PRIVATE_DATA.yaml")
+    st.success(f"API Keys saved to {os.path.join(st.session_state.dir_home, 'PRIVATE_DATA.yaml')}")
     # st.session_state.private_file = does_private_file_exist()
 
 # Function to load a YAML file and update session_state
@@ -1568,6 +1576,25 @@ def content_project_settings(col):
     st.session_state.config['leafmachine']['project']['dir_output'] = st.text_input("Output directory", st.session_state.config['leafmachine']['project'].get('dir_output', ''))
 
 
+def content_tools():
+    st.write("---")
+    st.header('Validation Tools')
+
+    tool_WFO = st.session_state.config['leafmachine']['project']['tool_WFO']
+    st.session_state.config['leafmachine']['project']['tool_WFO'] = st.checkbox(label="Enable World Flora Online taxonomy verification",
+                                                                                help="",
+                                                                                value=tool_WFO)
+
+    tool_GEO = st.session_state.config['leafmachine']['project']['tool_GEO']
+    st.session_state.config['leafmachine']['project']['tool_GEO'] = st.checkbox(label="Enable HERE geolocation hints",
+                                                                                help="",
+                                                                                value=tool_GEO)
+
+    tool_wikipedia = st.session_state.config['leafmachine']['project']['tool_wikipedia']
+    st.session_state.config['leafmachine']['project']['tool_wikipedia'] = st.checkbox(label="Enable Wikipedia verification",
+                                                                                      help="",
+                                                                                      value=tool_wikipedia)
+
 def content_llm_cost():
     st.write("---")
     st.header('LLM Cost Calculator')
@@ -1855,6 +1882,17 @@ def content_ocr_method():
     do_use_trOCR = st.checkbox("Enable trOCR", value=st.session_state.config['leafmachine']['project']['do_use_trOCR'],key="Enable trOCR2")#,disabled=st.session_state['lacks_GPU'])
     st.session_state.config['leafmachine']['project']['do_use_trOCR'] = do_use_trOCR
 
+    if do_use_trOCR:
+        # st.session_state.config['leafmachine']['project']['trOCR_model_path'] = "microsoft/trocr-large-handwritten"
+        default_trOCR_model_path = st.session_state.config['leafmachine']['project']['trOCR_model_path']
+        user_input_trOCR_model_path = st.text_input("trOCR Hugging Face model path. MUST be a fine-tuned version of 'microsoft/trocr-base-handwritten' or 'microsoft/trocr-large-handwritten', or a microsoft trOCR model.", value=default_trOCR_model_path)
+        if st.session_state.config['leafmachine']['project']['trOCR_model_path'] != user_input_trOCR_model_path:
+            is_valid_mp = is_valid_huggingface_model_path(user_input_trOCR_model_path)
+            if not is_valid_mp:
+                st.error(f"The Hugging Face model path {user_input_trOCR_model_path} is not valid. Please revise.")
+            else:
+                st.session_state.config['leafmachine']['project']['trOCR_model_path'] = user_input_trOCR_model_path
+
     if 'LLaVA' in selected_OCR_options:
         OCR_option_llava = st.radio(
             "Select the LLaVA version",
@@ -1888,6 +1926,15 @@ def content_ocr_method():
     # elif (OCR_option == 'hand') and do_use_trOCR:
    #     st.text_area(label='Handwritten/Printed + trOCR',placeholder=demo_text_trh,disabled=True, label_visibility='visible', height=150)
 
+def is_valid_huggingface_model_path(model_path):
+    try:
+        # Attempt to load the model configuration from Hugging Face Model Hub
+        config = AutoConfig.from_pretrained(model_path)
+        return True # If the configuration loads successfully, the model path is valid
+    except Exception as e:
+        # If loading the model configuration fails, the model path is not valid
+        return False
+
 @st.cache_data
 def show_collage():
     # Load the image only if it's not already in the session state
@@ -1920,7 +1967,12 @@ def content_collage_overlay():
     st.info("NOTE: We strongly recommend enabling LeafMachine2 cropping if your images are full sized herbarium sheet. Often, the OCR algorithm struggles with full sheets, but works well with the collage images. We have disabled the collage by default for this Hugging Face Space because the Space lacks a GPU and the collage creation takes a bit longer.")
     default_crops = st.session_state.config['leafmachine']['cropped_components']['save_cropped_annotations']
     st.markdown("Prior to transcription, use LeafMachine2 to crop all labels from input images to create label collages for each specimen image. Showing just the text labels to the OCR algorithms significantly improves performance. This runs slowly on the free Hugging Face Space, but runs quickly with a fast CPU or any GPU.")
-    st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox(":rainbow[Use LeafMachine2 label collage for transcriptions]", st.session_state.config['leafmachine'].get('use_RGB_label_images', False))
+    st.markdown("Images that are mostly text (like a scanned notecard, or already cropped images) do not require LM2 collage.")
+
+    if st.session_state.is_hf:
+        st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox(":rainbow[Use LeafMachine2 label collage for transcriptions]", st.session_state.config['leafmachine'].get('use_RGB_label_images', False), key='do make collage hf')
+    else:
+        st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox(":rainbow[Use LeafMachine2 label collage for transcriptions]", st.session_state.config['leafmachine'].get('use_RGB_label_images', True), key='do make collage local')
 
 
     option_selected_crops = st.multiselect(label="Components to crop",
@@ -2247,6 +2299,7 @@ def main():
     content_ocr_method()
 
     content_collage_overlay()
+    content_tools()
     content_llm_cost()
     content_processing_options()
     content_less_used()
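
The trOCR option added above validates a user-supplied Hugging Face model path with AutoConfig.from_pretrained (hence the new `from transformers import AutoConfig` import at the top of app.py). A minimal standalone sketch of the same check, runnable outside Streamlit:

    from transformers import AutoConfig

    def is_valid_huggingface_model_path(model_path):
        # A model path is considered valid if its config.json can be fetched from the Hub.
        try:
            AutoConfig.from_pretrained(model_path)
            return True
        except Exception:
            return False

    print(is_valid_huggingface_model_path("microsoft/trocr-base-handwritten"))  # True (requires network)
    print(is_valid_huggingface_model_path("not-a-real/model-path"))             # False

Since this performs a network request, the app only re-validates when the text input actually changes.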
run_VoucherVision.py CHANGED
@@ -31,7 +31,7 @@ def resolve_path(path):
 if __name__ == "__main__":
     dir_home = os.path.dirname(__file__)
 
-    start_port = 8529
+    start_port = 8528
     try:
         free_port = find_available_port(start_port)
         sys.argv = [
@@ -41,7 +41,7 @@ if __name__ == "__main__":
             # resolve_path(os.path.join(dir_home,"vouchervision", "VoucherVision_GUI.py")),
             "--global.developmentMode=false",
             # "--server.port=8545",
-            "--server.port=8546",
+            f"--server.port={free_port}",
             # Toggle below for HF vs Local
             # "--is_hf=1",
             # "--is_hf=0",
vouchervision/API_validation.py CHANGED
@@ -36,10 +36,11 @@ class APIvalidation:
 
 
     def has_API_key(self, val):
-        if val:
-            return True
-        else:
-            return False
+        return isinstance(val, str) and bool(val.strip())
+        # if val:
+        #     return True
+        # else:
+        #     return False
 
     def check_openai_api_key(self):
         if self.is_hf:
@@ -192,10 +193,6 @@ class APIvalidation:
             print(f"palm2 fail2")
 
         try:
-            # https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm
-            # os.environ['GOOGLE_API_KEY'] = "AIzaSyAHOH1w1qV7C3jS4W7QFyoaTGUwZIgS5ig"
-            # genai.configure(api_key='AIzaSyC8xvu6t9fb5dTah3hpgg_rwwR5G5kianI')
-            # model = ChatGoogleGenerativeAI(model="text-bison@001")
             model = VertexAI(model="text-bison@001", max_output_tokens=10)
             response = model.predict("Hello")
             test_response_palm2 = response
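
The rewritten `has_API_key` rejects non-strings and whitespace-only values instead of relying on bare truthiness. For illustration, a standalone version of the same one-liner:

    def has_API_key(val):
        return isinstance(val, str) and bool(val.strip())

    assert has_API_key("sk-abc123")
    assert not has_API_key("")        # empty string
    assert not has_API_key("   ")     # whitespace-only: truthy, so the old version returned True
    assert not has_API_key(None)      # non-string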
vouchervision/LLM_GoogleGemini.py CHANGED
@@ -6,14 +6,11 @@ from langchain.output_parsers import RetryWithErrorOutputParser
 # from langchain.schema import HumanMessage
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
-# from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_google_vertexai import VertexAI
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 class GoogleGeminiHandler:
 
@@ -23,7 +20,12 @@ class GoogleGeminiHandler:
     VENDOR = 'google'
     STARTING_TEMP = 0.5
 
-    def __init__(self, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.model_name = model_name
         self.JSON_dict_structure = JSON_dict_structure
@@ -76,13 +78,13 @@ class GoogleGeminiHandler:
 
     def _build_model_chain_parser(self):
         # Instantiate the LLM class for Google Gemini
-        # self.llm_model = ChatGoogleGenerativeAI(model='gemini-pro',
-        #                                         max_output_tokens=self.config.get('max_output_tokens'),
-        #                                         top_p=self.config.get('top_p'))
-        self.llm_model = VertexAI(model='gemini-pro',
-                                  max_output_tokens=self.config.get('max_output_tokens'),
-                                  top_p=self.config.get('top_p'))
-
+        self.llm_model = ChatGoogleGenerativeAI(model=self.model_name)#,
+                                                # max_output_tokens=self.config.get('max_output_tokens'),
+                                                # top_p=self.config.get('top_p'))
+        # self.llm_model = VertexAI(model='gemini-1.0-pro',
+        #                           max_output_tokens=self.config.get('max_output_tokens'),
+        #                           top_p=self.config.get('top_p'))
+
         # Set up the retry parser with the runnable
         self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.llm_model, max_retries=self.MAX_RETRIES)
         # Prepare the chain
@@ -90,10 +92,10 @@ class GoogleGeminiHandler:
 
     # Define a function to format the input for Google Gemini call
     def call_google_gemini(self, prompt_text):
-        model = GenerativeModel(self.model_name)
-        response = model.generate_content(prompt_text.text,
-                                          generation_config=self.config,
-                                          safety_settings=self.safety_settings)
+        model = GenerativeModel(self.model_name)#,
+                                # generation_config=self.config,
+                                # safety_settings=self.safety_settings)
+        response = model.generate_content(prompt_text.text)
         return response.text
 
     def call_llm_api_GoogleGemini(self, prompt_template, json_report, paths):
@@ -130,13 +132,9 @@ class GoogleGeminiHandler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-        output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-        Wiki = WikipediaLinks(json_file_path_wiki)
-        Wiki.gather_wikipedia_results(output)
-
-        save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -156,6 +154,8 @@ class GoogleGeminiHandler:
 
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
+
+        self.monitor.stop_inference_timer() # Starts tool timer too
 
         usage_report = self.monitor.stop_monitoring_report_usage()
         self._reset_config()
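
Every handler in this commit swaps its inline WFO/HERE/Wikipedia calls for a single `run_tools(...)` imported from `vouchervision.utils_LLM`, gated by the three new config flags. utils_LLM.py is not part of this diff, so the sketch below is a reconstruction inferred from the call sites here and from the commented-out per-tool calls left in LLM_OpenAI.py; every name inside the body is an assumption:

    # Hypothetical reconstruction of run_tools(); utils_LLM.py is not shown in this commit.
    from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
    from vouchervision.utils_geolocate_HERE import validate_coordinates_here
    # validate_wikipedia is assumed to wrap the WikipediaLinks helper removed from the handlers.

    def run_tools(output, tool_WFO, tool_GEO, tool_wikipedia, json_file_path_wiki):
        output_WFO, WFO_record = validate_taxonomy_WFO(tool_WFO, output, replace_if_success_wfo=False)
        output_GEO, GEO_record = validate_coordinates_here(tool_GEO, output, replace_if_success_geo=False)
        validate_wikipedia(tool_wikipedia, json_file_path_wiki, output)
        return output_WFO, WFO_record, output_GEO, GEO_record

Passing the flags through rather than branching at the call site suggests each validator no-ops internally when its tool is disabled.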
vouchervision/LLM_GooglePalm2.py CHANGED
@@ -11,11 +11,8 @@ from langchain_core.output_parsers import JsonOutputParser
 # from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_google_vertexai import VertexAI
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 #https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk
 #pip install --upgrade google-cloud-aiplatform
@@ -34,7 +31,12 @@ class GooglePalm2Handler:
     VENDOR = 'google'
     STARTING_TEMP = 0.5
 
-    def __init__(self, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.model_name = model_name
         self.JSON_dict_structure = JSON_dict_structure
@@ -144,13 +146,9 @@ class GooglePalm2Handler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-        output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
-
-        Wiki = WikipediaLinks(json_file_path_wiki)
-        Wiki.gather_wikipedia_results(output)
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-        save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -171,6 +169,7 @@ class GooglePalm2Handler:
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
+        self.monitor.stop_inference_timer() # Starts tool timer too
         usage_report = self.monitor.stop_monitoring_report_usage()
         self._reset_config()
 
vouchervision/LLM_MistralAI.py CHANGED
@@ -4,11 +4,8 @@ from langchain.output_parsers import RetryWithErrorOutputParser
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 
 class MistralHandler:
@@ -19,7 +16,12 @@ class MistralHandler:
     VENDOR = 'mistral'
     RANDOM_SEED = 2023
 
-    def __init__(self, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.monitor = SystemLoadMonitor(logger)
         self.has_GPU = torch.cuda.is_available()
@@ -115,13 +117,9 @@ class MistralHandler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-        output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
-
-        Wiki = WikipediaLinks(json_file_path_wiki)
-        Wiki.gather_wikipedia_results(output)
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-        save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -142,6 +140,7 @@ class MistralHandler:
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
+        self.monitor.stop_inference_timer() # Starts tool timer too
         usage_report = self.monitor.stop_monitoring_report_usage()
         self._reset_config()
         json_report.set_text(text_main=f'LLM call failed')
vouchervision/LLM_OpenAI.py CHANGED
@@ -5,11 +5,8 @@ from langchain.schema import HumanMessage
 from langchain_core.output_parsers import JsonOutputParser
 from langchain.output_parsers import RetryWithErrorOutputParser
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 class OpenAIHandler:
     RETRY_DELAY = 10 # Wait 10 seconds before retrying
@@ -18,7 +15,12 @@ class OpenAIHandler:
     TOKENIZER_NAME = 'gpt-4'
     VENDOR = 'openai'
 
-    def __init__(self, logger, model_name, JSON_dict_structure, is_azure, llm_object):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.model_name = model_name
         self.JSON_dict_structure = JSON_dict_structure
@@ -135,13 +137,14 @@ class OpenAIHandler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-        output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
+
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-        Wiki = WikipediaLinks(json_file_path_wiki)
-        Wiki.gather_wikipedia_results(output)
+        # output1, WFO_record = validate_taxonomy_WFO(self.tool_WFO, output, replace_if_success_wfo=False)
+        # output2, GEO_record = validate_coordinates_here(self.tool_GEO, output, replace_if_success_geo=False)
+        # validate_wikipedia(self.tool_wikipedia, json_file_path_wiki, output)
 
-        save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -162,6 +165,7 @@ class OpenAIHandler:
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
+        self.monitor.stop_inference_timer() # Starts tool timer too
         usage_report = self.monitor.stop_monitoring_report_usage()
         self._reset_config()
 
vouchervision/LLM_local_MistralAI.py CHANGED
@@ -6,11 +6,8 @@ from langchain_core.output_parsers import JsonOutputParser
 from huggingface_hub import hf_hub_download
 from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 '''
 Local Pipielines:
@@ -25,7 +22,12 @@ class LocalMistralHandler:
     VENDOR = 'mistral'
     MAX_GPU_MONITORING_INTERVAL = 2 # seconds
 
-    def __init__(self, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.has_GPU = torch.cuda.is_available()
         self.monitor = SystemLoadMonitor(logger)
@@ -188,13 +190,9 @@ class LocalMistralHandler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-        output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
-
-        Wiki = WikipediaLinks(json_file_path_wiki)
-        Wiki.gather_wikipedia_results(output)
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-        save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -214,6 +212,7 @@ class LocalMistralHandler:
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
+        self.monitor.stop_inference_timer() # Starts tool timer too
         usage_report = self.monitor.stop_monitoring_report_usage()
         json_report.set_text(text_main=f'LLM call failed')
 
vouchervision/LLM_local_cpu_MistralAI.py CHANGED
@@ -18,11 +18,8 @@ from langchain.callbacks.base import BaseCallbackHandler
 from huggingface_hub import hf_hub_download
 
 
-from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
+from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
-from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
-from vouchervision.utils_geolocate_HERE import validate_coordinates_here
-from vouchervision.tool_wikipedia import WikipediaLinks
 
 class LocalCPUMistralHandler:
     RETRY_DELAY = 2 # Wait 2 seconds before retrying
@@ -33,7 +30,12 @@ class LocalCPUMistralHandler:
     SEED = 2023
 
 
-    def __init__(self, logger, model_name, JSON_dict_structure):
+    def __init__(self, cfg, logger, model_name, JSON_dict_structure):
+        self.cfg = cfg
+        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
+        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
+        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
+
         self.logger = logger
         self.monitor = SystemLoadMonitor(logger)
         self.has_GPU = torch.cuda.is_available()
@@ -179,13 +181,9 @@ class LocalCPUMistralHandler:
         self.monitor.stop_inference_timer() # Starts tool timer too
 
         json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-        output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
-        output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
-
-        Wiki = WikipediaLinks(json_file_path_wiki)
-        Wiki.gather_wikipedia_results(output)
+        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
 
-        save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
+        save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
 
         self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
 
@@ -204,6 +202,7 @@ class LocalCPUMistralHandler:
         self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
         self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
 
+        self.monitor.stop_inference_timer() # Starts tool timer too
         usage_report = self.monitor.stop_monitoring_report_usage()
         self._reset_config()
 
vouchervision/OCR_Gemini.py CHANGED
@@ -145,16 +145,16 @@ maximumElevationInMeters
 }
 """
 def _get_google_credentials():
-    with open('D:/Dropbox/Servers/google_API/vouchervision-hf-a2c361d5d29d.json', 'r') as file:
+    with open('', 'r') as file:
         data = json.load(file)
         creds_json_str = json.dumps(data)
         credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
         os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = creds_json_str
-        os.environ['GOOGLE_API_KEY'] = 'AIzaSyAHOH1w1qV7C3jS4W7QFyoaTGUwZIgS5ig'
+        os.environ['GOOGLE_API_KEY'] = ''
         return credentials
 
 if __name__ == '__main__':
-    vertexai.init(project='vouchervision-hf', location='us-central1', credentials=_get_google_credentials())
+    vertexai.init(project='', location='', credentials=_get_google_credentials())
 
     logger = logging.getLogger('LLaVA')
     logger.setLevel(logging.DEBUG)
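
This change scrubs a hard-coded service-account path, a Google API key, and the project/location literals from the test harness, leaving empty strings. A sketch of the same setup driven by environment variables instead of literals (the variable names below are suggestions, not part of the commit):

    import json
    import os

    import vertexai
    from google.oauth2 import service_account

    def get_google_credentials():
        # The service-account JSON path comes from the environment, never from code.
        with open(os.environ["GOOGLE_APPLICATION_CREDENTIALS"], "r") as f:
            info = json.load(f)
        return service_account.Credentials.from_service_account_info(info)

    vertexai.init(project=os.environ["GOOGLE_PROJECT_ID"],
                  location=os.environ.get("GOOGLE_LOCATION", "us-central1"),
                  credentials=get_google_credentials())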
vouchervision/OCR_google_cloud_vision.py CHANGED
@@ -10,14 +10,6 @@ from google.oauth2 import service_account
10
 
11
  ### LLaVA should only be installed if the user will actually use it.
12
  ### It requires the most recent pytorch/Python and can mess with older systems
13
- try:
14
- from craft_text_detector import read_image, load_craftnet_model, load_refinenet_model, get_prediction, export_detected_regions, export_extra_results, empty_cuda_cache
15
- except:
16
- pass
17
- try:
18
- from OCR_llava import OCRllava
19
- except:
20
- pass
21
 
22
 
23
  '''
@@ -92,9 +84,7 @@ class OCREngine:
92
 
93
  self.multimodal_prompt = """I need you to transcribe all of the text in this image.
94
  Place the transcribed text into a JSON dictionary with this form {"Transcription_Printed_Text": "text","Transcription_Handwritten_Text": "text"}"""
95
-
96
- if 'LLaVA' in self.OCR_option:
97
- self.init_llava()
98
 
99
 
100
  def set_client(self):
@@ -113,6 +103,8 @@ class OCREngine:
113
 
114
  def init_craft(self):
115
  if 'CRAFT' in self.OCR_option:
 
 
116
  try:
117
  self.refine_net = load_refinenet_model(cuda=True)
118
  self.use_cuda = True
@@ -126,21 +118,23 @@ class OCREngine:
126
  self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=False)
127
 
128
  def init_llava(self):
 
 
129
 
130
- self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
131
- self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
132
 
133
- self.json_report.set_text(text_main=f'Loading LLaVA model: {self.model_path} Quantization: {self.model_quant}')
134
 
135
- if self.model_quant == '4bit':
136
- use_4bit = True
137
- elif self.model_quant == 'full':
138
- use_4bit = False
139
- else:
140
- self.logger.info(f"Provided model quantization invlid. Using 4bit.")
141
- use_4bit = True
142
 
143
- self.Llava = OCRllava(self.logger, model_path=self.model_path, load_in_4bit=use_4bit, load_in_8bit=False)
144
 
145
  def init_gemini_vision(self):
146
  pass
@@ -150,6 +144,8 @@ class OCREngine:
150
 
151
 
152
  def detect_text_craft(self):
 
 
153
  # Perform prediction using CRAFT
154
  image = read_image(self.path)
155
 
@@ -250,13 +246,13 @@ class OCREngine:
250
  if not do_use_trOCR:
251
  if 'normal' in self.OCR_option:
252
  self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
253
- logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}")
254
  # ocr_parts = ocr_parts + f"Google_OCR_Standard:\n{self.normal_organized_text}"
255
  ocr_parts = self.normal_organized_text
256
 
257
  if 'hand' in self.OCR_option:
258
  self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
259
- logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}")
260
  # ocr_parts = ocr_parts + f"Google_OCR_Handwriting:\n{self.hand_organized_text}"
261
  ocr_parts = self.hand_organized_text
262
 
@@ -340,13 +336,13 @@ class OCREngine:
340
  if 'normal' in self.OCR_option:
341
  self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
342
  self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
343
- logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
344
  # ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
345
  ocr_parts = self.trOCR_texts
346
  if 'hand' in self.OCR_option:
347
  self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
348
  self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
349
- logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
350
  # ocr_parts = ocr_parts + f"\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
351
  ocr_parts = self.trOCR_texts
352
  # if self.OCR_option in ['both',]:
@@ -358,7 +354,7 @@ class OCREngine:
358
  if 'CRAFT' in self.OCR_option:
359
  # self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
360
  self.OCR_JSON_to_file['OCR_CRAFT_trOCR'] = self.trOCR_texts
361
- logger.info(f"CRAFT_trOCR:\n{self.trOCR_texts}")
362
  # ocr_parts = ocr_parts + f"\nCRAFT_trOCR:\n{self.trOCR_texts}"
363
  ocr_parts = self.trOCR_texts
364
  return ocr_parts
@@ -383,7 +379,10 @@ class OCREngine:
383
 
384
  for bound, confidence, char_height, character in zip(bounds_flat, confidences, heights, characters):
385
  font_size = int(char_height)
386
- font = ImageFont.load_default().font_variant(size=font_size)
 
 
 
387
  if option == 'trOCR':
388
  color = (0, 170, 255)
389
  else:
@@ -686,7 +685,7 @@ class OCREngine:
686
  self.OCR = self.OCR + part_OCR + part_OCR
687
  else:
688
  self.OCR = self.OCR + "\nCRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
689
- logger.info(f"CRAFT trOCR:\n{self.OCR}")
690
 
691
  if 'LLaVA' in self.OCR_option: # This option does not produce an OCR helper image
692
  self.json_report.set_text(text_main=f'Working on LLaVA {self.Llava.model_path} transcription :construction:')
@@ -704,25 +703,34 @@ class OCREngine:
704
  self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}" + f"\nLLaVA OCR:\n{str_output}"
705
  else:
706
  self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}"
707
- logger.info(f"LLaVA OCR:\n{self.OCR}")
708
 
709
  if 'normal' in self.OCR_option or 'hand' in self.OCR_option:
710
  if 'normal' in self.OCR_option:
711
- self.OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
 
 
 
 
712
  if 'hand' in self.OCR_option:
713
- self.OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
 
 
 
 
714
  # if self.OCR_option not in ['normal', 'hand', 'both']:
715
  # self.OCR_option = 'both'
716
  # self.detect_text()
717
  # self.detect_handwritten_ocr()
718
 
719
  ### Optionally add trOCR to the self.OCR for additional context
720
- if self.double_OCR:
721
- part_OCR = "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
722
- self.OCR = self.OCR + part_OCR + part_OCR
723
- else:
724
- self.OCR = self.OCR + "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
725
- logger.info(f"OCR:\n{self.OCR}")
 
726
 
727
  if do_create_OCR_helper_image and ('LLaVA' not in self.OCR_option):
728
  self.image = Image.open(self.path)
@@ -744,8 +752,6 @@ class OCREngine:
744
  image_with_boxes_normal = self.draw_boxes('normal')
745
  self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_trOCR)
746
 
747
-
748
-
749
  ### Merge final overlay image
750
  ### [original, normal bboxes, normal text]
751
  if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
@@ -773,241 +779,7 @@ class OCREngine:
773
  self.overlay_image = Image.open(self.path)
774
 
775
  try:
 
776
  empty_cuda_cache()
777
  except:
778
- pass
779
-
780
-
781
-
782
- '''
783
- BBOX_COLOR = "black" # green cyan
784
-
785
- def render_text_on_black_image(image_path, handwritten_char_bounds_flat, handwritten_char_confidences, handwritten_char_heights, characters):
786
- # Load the original image to get its dimensions
787
- original_image = Image.open(image_path)
788
- width, height = original_image.size
789
-
790
- # Create a black image of the same size
791
- black_image = Image.new("RGB", (width, height), "black")
792
- draw = ImageDraw.Draw(black_image)
793
-
794
- # Loop through each character
795
- for bound, confidence, char_height, character in zip(handwritten_char_bounds_flat, handwritten_char_confidences, handwritten_char_heights, characters):
796
- # Determine the font size based on the height of the character
797
- font_size = int(char_height)
798
- font = ImageFont.load_default().font_variant(size=font_size)
799
-
800
- # Color of the character
801
- color = confidence_to_color(confidence)
802
-
803
- # Position of the text (using the bottom-left corner of the bounding box)
804
- position = (bound["vertices"][0]["x"], bound["vertices"][0]["y"] - char_height)
805
-
806
- # Draw the character
807
- draw.text(position, character, fill=color, font=font)
808
-
809
- return black_image
810
-
811
- def merge_images(image1, image2):
812
- # Assuming both images are of the same size
813
- width, height = image1.size
814
- merged_image = Image.new("RGB", (width * 2, height))
815
- merged_image.paste(image1, (0, 0))
816
- merged_image.paste(image2, (width, 0))
817
- return merged_image
818
-
819
- def draw_boxes(image, bounds, color):
820
- if bounds:
821
- draw = ImageDraw.Draw(image)
822
- width, height = image.size
823
- line_width = int((width + height) / 2 * 0.001) # This sets the line width as 0.5% of the average dimension
824
-
825
- for bound in bounds:
826
- draw.polygon(
827
- [
828
- bound["vertices"][0]["x"], bound["vertices"][0]["y"],
829
- bound["vertices"][1]["x"], bound["vertices"][1]["y"],
830
- bound["vertices"][2]["x"], bound["vertices"][2]["y"],
831
- bound["vertices"][3]["x"], bound["vertices"][3]["y"],
832
- ],
833
- outline=color,
834
- width=line_width
835
- )
836
- return image
837
-
838
- def detect_text(path):
839
- client = vision.ImageAnnotatorClient()
840
- with io.open(path, 'rb') as image_file:
841
- content = image_file.read()
842
- image = vision.Image(content=content)
843
- response = client.document_text_detection(image=image)
844
- texts = response.text_annotations
845
-
846
- if response.error.message:
847
- raise Exception(
848
- '{}\nFor more info on error messages, check: '
849
- 'https://cloud.google.com/apis/design/errors'.format(
850
- response.error.message))
851
-
852
- # Extract bounding boxes
853
- bounds = []
854
- text_to_box_mapping = {}
855
- for text in texts[1:]: # Skip the first entry, as it represents the entire detected text
856
- # Convert BoundingPoly to dictionary
857
- bound_dict = {
858
- "vertices": [
859
- {"x": vertex.x, "y": vertex.y} for vertex in text.bounding_poly.vertices
860
- ]
861
- }
862
- bounds.append(bound_dict)
863
- text_to_box_mapping[str(bound_dict)] = text.description
864
-
865
- if texts:
866
- # cleaned_text = texts[0].description.replace("\n", " ").replace("\t", " ").replace("|", " ")
867
- cleaned_text = texts[0].description
868
- return cleaned_text, bounds, text_to_box_mapping
869
- else:
870
- return '', None, None
871
-
872
- def confidence_to_color(confidence):
873
- """Convert confidence level to a color ranging from red (low confidence) to green (high confidence)."""
874
- # Using HSL color space, where Hue varies from red to green
875
- hue = (confidence - 0.5) * 120 / 0.5 # Scale confidence to range 0-120 (red to green in HSL)
876
- r, g, b = colorsys.hls_to_rgb(hue/360, 0.5, 1) # Convert to RGB
877
- return (int(r*255), int(g*255), int(b*255))
878
-
879
- def overlay_boxes_on_image(path, typed_bounds, handwritten_char_bounds, handwritten_char_confidences, do_create_OCR_helper_image):
880
- if do_create_OCR_helper_image:
881
- image = Image.open(path)
882
- draw = ImageDraw.Draw(image)
883
- width, height = image.size
884
- line_width = int((width + height) / 2 * 0.005) # Adjust line width for character level
885
-
886
- # Draw boxes for typed text
887
- for bound in typed_bounds:
888
- draw.polygon(
889
- [
890
- bound["vertices"][0]["x"], bound["vertices"][0]["y"],
891
- bound["vertices"][1]["x"], bound["vertices"][1]["y"],
892
- bound["vertices"][2]["x"], bound["vertices"][2]["y"],
893
- bound["vertices"][3]["x"], bound["vertices"][3]["y"],
894
- ],
895
- outline=BBOX_COLOR,
896
- width=1
897
- )
898
-
899
- # Draw a line segment at the bottom of each handwritten character
900
- for bound, confidence in zip(handwritten_char_bounds, handwritten_char_confidences):
901
- color = confidence_to_color(confidence)
902
- # Use the bottom two vertices of the bounding box for the line
903
- bottom_left = (bound["vertices"][3]["x"], bound["vertices"][3]["y"] + line_width)
904
- bottom_right = (bound["vertices"][2]["x"], bound["vertices"][2]["y"] + line_width)
905
- draw.line([bottom_left, bottom_right], fill=color, width=line_width)
906
-
907
- text_image = render_text_on_black_image(path, handwritten_char_bounds, handwritten_char_confidences)
908
- merged_image = merge_images(image, text_image) # Assuming 'overlayed_image' is the image with lines
909
-
910
-
911
- return merged_image
912
- else:
913
- return Image.open(path)
914
-
915
- def detect_handwritten_ocr(path):
916
- """Detects handwritten characters in a local image and returns their bounding boxes and confidence levels.
917
-
918
- Args:
919
- path: The path to the local file.
920
-
921
- Returns:
922
- A tuple of (text, bounding_boxes, confidences)
923
- """
924
- client = vision_beta.ImageAnnotatorClient()
925
-
926
- with open(path, "rb") as image_file:
927
- content = image_file.read()
928
-
929
- image = vision_beta.Image(content=content)
930
- image_context = vision_beta.ImageContext(language_hints=["en-t-i0-handwrit"])
931
- response = client.document_text_detection(image=image, image_context=image_context)
932
-
933
- if response.error.message:
934
- raise Exception(
935
- "{}\nFor more info on error messages, check: "
936
- "https://cloud.google.com/apis/design/errors".format(response.error.message)
937
- )
938
-
939
- bounds = []
940
- bounds_flat = []
941
- height_flat = []
942
- confidences = []
943
- character = []
944
- for page in response.full_text_annotation.pages:
945
- for block in page.blocks:
946
- for paragraph in block.paragraphs:
947
- for word in paragraph.words:
948
- # Get the bottom Y-location (max Y) for the whole word
949
- Y = max(vertex.y for vertex in word.bounding_box.vertices)
950
-
951
- # Get the height of the word's bounding box
952
- H = Y - min(vertex.y for vertex in word.bounding_box.vertices)
953
-
954
- for symbol in word.symbols:
955
- # Collecting bounding box for each symbol
956
- bound_dict = {
957
- "vertices": [
958
- {"x": vertex.x, "y": vertex.y} for vertex in symbol.bounding_box.vertices
959
- ]
960
- }
961
- bounds.append(bound_dict)
962
-
963
- # Bounds with same bottom y height
964
- bounds_flat_dict = {
965
- "vertices": [
966
- {"x": vertex.x, "y": Y} for vertex in symbol.bounding_box.vertices
967
- ]
968
- }
969
- bounds_flat.append(bounds_flat_dict)
970
-
971
- # Add the word's height
972
- height_flat.append(H)
973
-
974
- # Collecting confidence for each symbol
975
- symbol_confidence = round(symbol.confidence, 4)
976
- confidences.append(symbol_confidence)
977
- character.append(symbol.text)
978
-
979
- cleaned_text = response.full_text_annotation.text
980
-
981
- return cleaned_text, bounds, bounds_flat, height_flat, confidences, character
982
-
983
-
984
-
985
- def process_image(path, do_create_OCR_helper_image):
986
- typed_text, typed_bounds, _ = detect_text(path)
987
- handwritten_text, handwritten_bounds, _ = detect_handwritten_ocr(path)
988
-
989
- overlayed_image = overlay_boxes_on_image(path, typed_bounds, handwritten_bounds, do_create_OCR_helper_image)
990
- return typed_text, handwritten_text, overlayed_image
991
-
992
- '''
993
-
994
- # ''' Google Vision'''
995
- # def detect_text(path):
996
- # """Detects text in the file located in the local filesystem."""
997
- # client = vision.ImageAnnotatorClient()
998
-
999
- # with io.open(path, 'rb') as image_file:
1000
- # content = image_file.read()
1001
-
1002
- # image = vision.Image(content=content)
1003
-
1004
- # response = client.document_text_detection(image=image)
1005
- # texts = response.text_annotations
1006
-
1007
- # if response.error.message:
1008
- # raise Exception(
1009
- # '{}\nFor more info on error messages, check: '
1010
- # 'https://cloud.google.com/apis/design/errors'.format(
1011
- # response.error.message))
1012
-
1013
- # return texts[0].description if texts else ''
 
10
 
11
  ### LLaVA should only be installed if the user will actually use it.
12
  ### It requires the most recent PyTorch/Python and can break older environments
 
 
13
 
14
 
15
  '''
 
84
 
85
  self.multimodal_prompt = """I need you to transcribe all of the text in this image.
86
  Place the transcribed text into a JSON dictionary with this form {"Transcription_Printed_Text": "text","Transcription_Handwritten_Text": "text"}"""
87
+ self.init_llava()
 
 
88
 
89
 
90
  def set_client(self):
 
103
 
104
  def init_craft(self):
105
  if 'CRAFT' in self.OCR_option:
106
+ from craft_text_detector import load_craftnet_model, load_refinenet_model
107
+
108
  try:
109
  self.refine_net = load_refinenet_model(cuda=True)
110
  self.use_cuda = True
 
118
  self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=False)
119
 
120
  def init_llava(self):
121
+ if 'LLaVA' in self.OCR_option:
122
+ from vouchervision.OCR_llava import OCRllava
123
 
124
+ self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
125
+ self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
126
 
127
+ self.json_report.set_text(text_main=f'Loading LLaVA model: {self.model_path} Quantization: {self.model_quant}')
128
 
129
+ if self.model_quant == '4bit':
130
+ use_4bit = True
131
+ elif self.model_quant == 'full':
132
+ use_4bit = False
133
+ else:
134
+ self.logger.info(f"Provided model quantization invalid. Using 4bit.")
135
+ use_4bit = True
136
 
137
+ self.Llava = OCRllava(self.logger, model_path=self.model_path, load_in_4bit=use_4bit, load_in_8bit=False)
138
 
139
  def init_gemini_vision(self):
140
  pass
 
144
 
145
 
146
  def detect_text_craft(self):
147
+ from craft_text_detector import read_image, get_prediction
148
+
149
  # Perform prediction using CRAFT
150
  image = read_image(self.path)
151
 
 
246
  if not do_use_trOCR:
247
  if 'normal' in self.OCR_option:
248
  self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
249
+ # logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}")
250
  # ocr_parts = ocr_parts + f"Google_OCR_Standard:\n{self.normal_organized_text}"
251
  ocr_parts = self.normal_organized_text
252
 
253
  if 'hand' in self.OCR_option:
254
  self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
255
+ # logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}")
256
  # ocr_parts = ocr_parts + f"Google_OCR_Handwriting:\n{self.hand_organized_text}"
257
  ocr_parts = self.hand_organized_text
258
 
 
336
  if 'normal' in self.OCR_option:
337
  self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
338
  self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
339
+ # logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
340
  # ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
341
  ocr_parts = self.trOCR_texts
342
  if 'hand' in self.OCR_option:
343
  self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
344
  self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
345
+ # logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
346
  # ocr_parts = ocr_parts + f"\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
347
  ocr_parts = self.trOCR_texts
348
  # if self.OCR_option in ['both',]:
 
354
  if 'CRAFT' in self.OCR_option:
355
  # self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
356
  self.OCR_JSON_to_file['OCR_CRAFT_trOCR'] = self.trOCR_texts
357
+ # logger.info(f"CRAFT_trOCR:\n{self.trOCR_texts}")
358
  # ocr_parts = ocr_parts + f"\nCRAFT_trOCR:\n{self.trOCR_texts}"
359
  ocr_parts = self.trOCR_texts
360
  return ocr_parts
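The branches above populate `self.OCR_JSON_to_file` with one key per OCR engine before returning the combined `ocr_parts` string. A minimal sketch of the resulting structure (key names come from the code above; the values are illustrative only):

    # Illustrative shape of self.OCR_JSON_to_file after this method runs:
    OCR_JSON_to_file = {
        'OCR_printed': 'PLANTS OF MICHIGAN ...',       # set when 'normal' is in OCR_option
        'OCR_handwritten': 'coll. E. Smith 1921 ...',  # set when 'hand' is in OCR_option
        'OCR_trOCR': '...',                            # set when do_use_trOCR is True
        'OCR_CRAFT_trOCR': '...',                      # set when 'CRAFT' is in OCR_option
    }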
 
379
 
380
  for bound, confidence, char_height, character in zip(bounds_flat, confidences, heights, characters):
381
  font_size = int(char_height)
382
+ try:
383
+ font = ImageFont.truetype("arial.ttf", font_size)
384
+ except:
385
+ font = ImageFont.load_default().font_variant(size=font_size)
386
  if option == 'trOCR':
387
  color = (0, 170, 255)
388
  else:
 
685
  self.OCR = self.OCR + part_OCR + part_OCR
686
  else:
687
  self.OCR = self.OCR + "\nCRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
688
+ # logger.info(f"CRAFT trOCR:\n{self.OCR}")
689
 
690
  if 'LLaVA' in self.OCR_option: # This option does not produce an OCR helper image
691
  self.json_report.set_text(text_main=f'Working on LLaVA {self.Llava.model_path} transcription :construction:')
 
703
  self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}" + f"\nLLaVA OCR:\n{str_output}"
704
  else:
705
  self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}"
706
+ # logger.info(f"LLaVA OCR:\n{self.OCR}")
707
 
708
  if 'normal' in self.OCR_option or 'hand' in self.OCR_option:
709
  if 'normal' in self.OCR_option:
710
+ if self.double_OCR:
711
+ part_OCR = "\nGoogle Printed OCR:\n" + self.detect_text()
712
+ self.OCR = self.OCR + part_OCR + part_OCR
713
+ else:
714
+ self.OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
715
  if 'hand' in self.OCR_option:
716
+ if self.double_OCR:
717
+ part_OCR = "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
718
+ self.OCR = self.OCR + part_OCR + part_OCR
719
+ else:
720
+ self.OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
721
  # if self.OCR_option not in ['normal', 'hand', 'both']:
722
  # self.OCR_option = 'both'
723
  # self.detect_text()
724
  # self.detect_handwritten_ocr()
725
 
726
  ### Optionally add trOCR to the self.OCR for additional context
727
+ if self.do_use_trOCR:
728
+ if self.double_OCR:
729
+ part_OCR = "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
730
+ self.OCR = self.OCR + part_OCR + part_OCR
731
+ else:
732
+ self.OCR = self.OCR + "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
733
+ # logger.info(f"OCR:\n{self.OCR}")
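Each of the branches above follows the same append pattern: build a labeled text block, then concatenate it once, or twice when `double_OCR` is set so the transcription appears twice in the LLM prompt. A compact sketch of that pattern (hypothetical helper; the class inlines this logic):

    def append_ocr_block(current, label, text, double_OCR):
        # With double_OCR, the labeled block is appended twice for emphasis.
        part = f"\n{label}:\n{text}"
        return current + (part + part if double_OCR else part)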
734
 
735
  if do_create_OCR_helper_image and ('LLaVA' not in self.OCR_option):
736
  self.image = Image.open(self.path)
 
752
  image_with_boxes_normal = self.draw_boxes('normal')
753
  self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_trOCR)
754
 
 
 
755
  ### Merge final overlay image
756
  ### [original, normal bboxes, normal text]
757
  if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
 
779
  self.overlay_image = Image.open(self.path)
780
 
781
  try:
782
+ from craft_text_detector import empty_cuda_cache
783
  empty_cuda_cache()
784
  except:
785
+ pass
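Note the pattern running through this file: the heavy optional dependencies (`craft_text_detector` and the LLaVA stack) are now imported inside `init_craft`, `init_llava`, `detect_text_craft`, and the final cleanup, so they load only when the corresponding OCR option is selected. A minimal sketch of the idea, reusing the loader calls shown above (the surrounding class is simplified and illustrative):

    class LazyCraftLoader:
        def __init__(self, OCR_option, weight_path):
            self.OCR_option = OCR_option
            self.weight_path = weight_path

        def init_craft(self):
            if 'CRAFT' not in self.OCR_option:
                return  # dependency is never imported, so it need not be installed
            from craft_text_detector import load_craftnet_model, load_refinenet_model
            try:
                self.refine_net = load_refinenet_model(cuda=True)
                self.use_cuda = True
            except Exception:
                self.refine_net = load_refinenet_model(cuda=False)
                self.use_cuda = False
            self.craft_net = load_craftnet_model(weight_path=self.weight_path, cuda=self.use_cuda)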
 
 
vouchervision/OCR_llava.py CHANGED
@@ -3,20 +3,20 @@ import requests
3
  from PIL import Image
4
  from io import BytesIO
5
  import torch
6
- from transformers import AutoTokenizer, BitsAndBytesConfig, TextStreamer
7
 
8
- from langchain.prompts import PromptTemplate
9
  from langchain_core.output_parsers import JsonOutputParser
10
  from langchain_core.pydantic_v1 import BaseModel, Field
11
 
12
- from LLaVA.llava.model import LlavaLlamaForCausalLM
13
- from LLaVA.llava.model.builder import load_pretrained_model
14
- from LLaVA.llava.conversation import conv_templates, SeparatorStyle
15
- from LLaVA.llava.utils import disable_torch_init
16
- from LLaVA.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_PLACEHOLDER
17
- from LLaVA.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images
18
 
19
- from utils_LLM import SystemLoadMonitor
20
 
21
  '''
22
  Performance expectations system:
 
3
  from PIL import Image
4
  from io import BytesIO
5
  import torch
6
+ # from transformers import AutoTokenizer, BitsAndBytesConfig, TextStreamer
7
 
8
+ # from langchain.prompts import PromptTemplate
9
  from langchain_core.output_parsers import JsonOutputParser
10
  from langchain_core.pydantic_v1 import BaseModel, Field
11
 
12
+ # from vouchervision.LLaVA.llava.model import LlavaLlamaForCausalLM
13
+ from vouchervision.LLaVA.llava.model.builder import load_pretrained_model
14
+ from vouchervision.LLaVA.llava.conversation import conv_templates#, SeparatorStyle
15
+ from vouchervision.LLaVA.llava.utils import disable_torch_init
16
+ from vouchervision.LLaVA.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_PLACEHOLDER
17
+ from vouchervision.LLaVA.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images #KeywordsStoppingCriteria
18
 
19
+ from vouchervision.utils_LLM import SystemLoadMonitor
20
 
21
  '''
22
  Performance expectations system:
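Two things happen in this hunk: the LLaVA imports become package-qualified (`vouchervision.LLaVA...`), so the module resolves whenever the repo root is on `sys.path`, and the `JsonOutputParser`/pydantic imports are retained because they drive structured parsing of the model's JSON output. A minimal sketch of that parsing setup, with field names matching the multimodal prompt earlier in this commit (the `Transcription` schema class is hypothetical):

    from langchain_core.output_parsers import JsonOutputParser
    from langchain_core.pydantic_v1 import BaseModel, Field

    class Transcription(BaseModel):
        Transcription_Printed_Text: str = Field(description="printed text found in the image")
        Transcription_Handwritten_Text: str = Field(description="handwritten text found in the image")

    parser = JsonOutputParser(pydantic_object=Transcription)
    # parser.get_format_instructions() can be appended to the prompt, and the
    # raw model output is recovered as a dict with:
    data = parser.parse('{"Transcription_Printed_Text": "x", "Transcription_Handwritten_Text": "y"}')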
vouchervision/VoucherVision_Config_Builder.py CHANGED
@@ -36,16 +36,22 @@ def build_VV_config(loaded_cfg=None):
36
  save_cropped_annotations = ['label','barcode']
37
 
38
  do_use_trOCR = False
 
39
  OCR_option = 'hand'
40
  OCR_option_llava = 'llava-v1.6-mistral-7b' # "llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",
41
  OCR_option_llava_bit = 'full' # full or 4bit
42
  double_OCR = False
43
 
 
44
  check_for_illegal_filenames = False
45
 
46
  LLM_version_user = 'Azure GPT 3.5 Instruct' #'Azure GPT 4 Turbo 1106-preview'
47
- prompt_version = 'version_5.yaml' # from ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
48
- use_LeafMachine2_collage_images = False # Use LeafMachine2 collage images
49
  do_create_OCR_helper_image = True
50
 
51
  batch_size = 500
@@ -54,8 +60,8 @@ def build_VV_config(loaded_cfg=None):
54
  skip_vertical = False
55
  pdf_conversion_dpi = 100
56
 
57
- path_domain_knowledge = os.path.join(dir_home,'domain_knowledge','SLTP_UM_AllAsiaMinimalInRegion.xlsx')
58
- embeddings_database_name = os.path.splitext(os.path.basename(path_domain_knowledge))[0]
59
 
60
  #############################################
61
  #############################################
@@ -65,7 +71,9 @@ def build_VV_config(loaded_cfg=None):
65
  return assemble_config(dir_home, run_name, dir_images_local,dir_output,
66
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
67
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
68
- prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, OCR_option_llava, OCR_option_llava_bit, double_OCR, save_cropped_annotations,
 
 
69
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
70
  else:
71
  dir_home = os.path.dirname(os.path.dirname(__file__))
@@ -80,11 +88,16 @@ def build_VV_config(loaded_cfg=None):
80
  catalog_numerical_only = loaded_cfg['leafmachine']['project']['catalog_numerical_only']
81
 
82
  do_use_trOCR = loaded_cfg['leafmachine']['project']['do_use_trOCR']
 
83
  OCR_option = loaded_cfg['leafmachine']['project']['OCR_option']
84
  OCR_option_llava = loaded_cfg['leafmachine']['project']['OCR_option_llava']
85
  OCR_option_llava_bit = loaded_cfg['leafmachine']['project']['OCR_option_llava_bit']
86
  double_OCR = loaded_cfg['leafmachine']['project']['double_OCR']
87
 
 
 
 
 
88
  pdf_conversion_dpi = loaded_cfg['leafmachine']['project']['pdf_conversion_dpi']
89
 
90
  LLM_version_user = loaded_cfg['leafmachine']['LLM_version']
@@ -105,14 +118,18 @@ def build_VV_config(loaded_cfg=None):
105
  return assemble_config(dir_home, run_name, dir_images_local,dir_output,
106
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
107
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
108
- prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, OCR_option_llava, OCR_option_llava_bit, double_OCR, save_cropped_annotations,
 
 
109
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
110
 
111
 
112
  def assemble_config(dir_home, run_name, dir_images_local,dir_output,
113
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
114
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
115
- prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, OCR_option, OCR_option_llava, OCR_option_llava_bit, double_OCR, save_cropped_annotations,
 
 
116
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False):
117
 
118
 
@@ -157,11 +174,15 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
157
  'delete_all_temps': False,
158
  'delete_temps_keep_VVE': False,
159
  'do_use_trOCR': do_use_trOCR,
 
160
  'OCR_option': OCR_option,
161
  'OCR_option_llava': OCR_option_llava,
162
  'OCR_option_llava_bit': OCR_option_llava_bit,
163
  'double_OCR': double_OCR,
164
  'pdf_conversion_dpi': pdf_conversion_dpi,
 
 
 
165
  }
166
 
167
  modules_section = {
 
36
  save_cropped_annotations = ['label','barcode']
37
 
38
  do_use_trOCR = False
39
+ trOCR_model_path = "microsoft/trocr-large-handwritten"
40
  OCR_option = 'hand'
41
  OCR_option_llava = 'llava-v1.6-mistral-7b' # "llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",
42
  OCR_option_llava_bit = 'full' # full or 4bit
43
  double_OCR = False
44
 
45
+
46
+ tool_GEO = True
47
+ tool_WFO = True
48
+ tool_wikipedia = True
49
+
50
  check_for_illegal_filenames = False
51
 
52
  LLM_version_user = 'Azure GPT 3.5 Instruct' #'Azure GPT 4 Turbo 1106-preview'
53
+ prompt_version = 'SLTPvA_long.yaml'
54
+ use_LeafMachine2_collage_images = True # Use LeafMachine2 collage images
55
  do_create_OCR_helper_image = True
56
 
57
  batch_size = 500
 
60
  skip_vertical = False
61
  pdf_conversion_dpi = 100
62
 
63
+ path_domain_knowledge = '' #os.path.join(dir_home,'domain_knowledge','SLTP_UM_AllAsiaMinimalInRegion.xlsx')
64
+ embeddings_database_name = '' #os.path.splitext(os.path.basename(path_domain_knowledge))[0]
65
 
66
  #############################################
67
  #############################################
 
71
  return assemble_config(dir_home, run_name, dir_images_local,dir_output,
72
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
73
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
74
+ prompt_version, do_create_OCR_helper_image, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
75
+ OCR_option_llava_bit, double_OCR, save_cropped_annotations,
76
+ tool_GEO, tool_WFO, tool_wikipedia,
77
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
78
  else:
79
  dir_home = os.path.dirname(os.path.dirname(__file__))
 
88
  catalog_numerical_only = loaded_cfg['leafmachine']['project']['catalog_numerical_only']
89
 
90
  do_use_trOCR = loaded_cfg['leafmachine']['project']['do_use_trOCR']
91
+ trOCR_model_path = loaded_cfg['leafmachine']['project']['trOCR_model_path']
92
  OCR_option = loaded_cfg['leafmachine']['project']['OCR_option']
93
  OCR_option_llava = loaded_cfg['leafmachine']['project']['OCR_option_llava']
94
  OCR_option_llava_bit = loaded_cfg['leafmachine']['project']['OCR_option_llava_bit']
95
  double_OCR = loaded_cfg['leafmachine']['project']['double_OCR']
96
 
97
+ tool_GEO = loaded_cfg['leafmachine']['project']['tool_GEO']
98
+ tool_WFO = loaded_cfg['leafmachine']['project']['tool_WFO']
99
+ tool_wikipedia = loaded_cfg['leafmachine']['project']['tool_wikipedia']
100
+
101
  pdf_conversion_dpi = loaded_cfg['leafmachine']['project']['pdf_conversion_dpi']
102
 
103
  LLM_version_user = loaded_cfg['leafmachine']['LLM_version']
 
118
  return assemble_config(dir_home, run_name, dir_images_local,dir_output,
119
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
120
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
121
+ prompt_version, do_create_OCR_helper_image, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
122
+ OCR_option_llava_bit, double_OCR, save_cropped_annotations,
123
+ tool_GEO, tool_WFO, tool_wikipedia,
124
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
125
 
126
 
127
  def assemble_config(dir_home, run_name, dir_images_local,dir_output,
128
  prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
129
  path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
130
+ prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
131
+ OCR_option_llava_bit, double_OCR, save_cropped_annotations,
132
+ tool_GEO, tool_WFO, tool_wikipedia,
133
  check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False):
134
 
135
 
 
174
  'delete_all_temps': False,
175
  'delete_temps_keep_VVE': False,
176
  'do_use_trOCR': do_use_trOCR,
177
+ 'trOCR_model_path': trOCR_model_path,
178
  'OCR_option': OCR_option,
179
  'OCR_option_llava': OCR_option_llava,
180
  'OCR_option_llava_bit': OCR_option_llava_bit,
181
  'double_OCR': double_OCR,
182
  'pdf_conversion_dpi': pdf_conversion_dpi,
183
+ 'tool_GEO': tool_GEO,
184
+ 'tool_WFO': tool_WFO,
185
+ 'tool_wikipedia': tool_wikipedia,
186
  }
187
 
188
  modules_section = {
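The new arguments thread through `build_VV_config` into `assemble_config` and land in the `project` section of the generated config. A sketch of the keys this section now carries (names from the diff; values are the defaults set above):

    project_section_excerpt = {
        'do_use_trOCR': False,
        'trOCR_model_path': 'microsoft/trocr-large-handwritten',
        'OCR_option': 'hand',
        'OCR_option_llava': 'llava-v1.6-mistral-7b',
        'OCR_option_llava_bit': 'full',
        'double_OCR': False,
        'pdf_conversion_dpi': 100,
        'tool_GEO': True,
        'tool_WFO': True,
        'tool_wikipedia': True,
    }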
vouchervision/model_maps.py CHANGED
@@ -206,7 +206,7 @@ class ModelMaps:
206
  return "text-unicorn@001"
207
 
208
  elif key == 'GEMINI_PRO':
209
- return "gemini-pro"
210
 
211
  ### Mistral
212
  elif key == 'MISTRAL_TINY':
 
206
  return "text-unicorn@001"
207
 
208
  elif key == 'GEMINI_PRO':
209
+ return "gemini-1.0-pro"
210
 
211
  ### Mistral
212
  elif key == 'MISTRAL_TINY':
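The single functional change here swaps the unpinned `gemini-pro` alias for the explicitly versioned `gemini-1.0-pro` id, presumably so behavior stays stable if Google later repoints the alias at a newer model. A minimal sketch of the branch (standalone function; the class above inlines it):

    def gemini_model_id(key):
        # Return the versioned id rather than the floating alias.
        if key == 'GEMINI_PRO':
            return "gemini-1.0-pro"  # previously "gemini-pro"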
vouchervision/tool_geolocate_HERE.py ADDED
@@ -0,0 +1,321 @@
 
 
1
+ import os, requests
2
+ import pycountry_convert as pc
3
+ import unicodedata
4
+ import pycountry_convert as pc
5
+ import warnings
6
+
7
+
8
+ def normalize_country_name(name):
9
+ return unicodedata.normalize('NFKD', name).encode('ASCII', 'ignore').decode('ASCII')
10
+
11
+ def get_continent(country_name):
12
+ warnings.filterwarnings("ignore", category=UserWarning, module='pycountry')
13
+
14
+ continent_code_to_name = {
15
+ "AF": "Africa",
16
+ "NA": "North America",
17
+ "OC": "Oceania",
18
+ "AN": "Antarctica",
19
+ "AS": "Asia",
20
+ "EU": "Europe",
21
+ "SA": "South America"
22
+ }
23
+
24
+ try:
25
+ normalized_country_name = normalize_country_name(country_name)
26
+ # Get country alpha2 code
27
+ country_code = pc.country_name_to_country_alpha2(normalized_country_name)
28
+ # Get continent code from country alpha2 code
29
+ continent_code = pc.country_alpha2_to_continent_code(country_code)
30
+ # Map the continent code to continent name
31
+ return continent_code_to_name.get(continent_code, '')
32
+ except Exception as e:
33
+ print(str(e))
34
+ return ''
35
+
36
+ def validate_coordinates_here(tool_GEO, record, replace_if_success_geo=False):
37
+ forward_url = 'https://geocode.search.hereapi.com/v1/geocode'
38
+ reverse_url = 'https://revgeocode.search.hereapi.com/v1/revgeocode'
39
+
40
+ pinpoint = ['GEO_city','GEO_county','GEO_state','GEO_country',]
41
+ GEO_dict_null = {
42
+ 'GEO_override_OCR': False,
43
+ 'GEO_method': '',
44
+ 'GEO_formatted_full_string': '',
45
+ 'GEO_decimal_lat': '',
46
+ 'GEO_decimal_long': '',
47
+ 'GEO_city': '',
48
+ 'GEO_county': '',
49
+ 'GEO_state': '',
50
+ 'GEO_state_code': '',
51
+ 'GEO_country': '',
52
+ 'GEO_country_code': '',
53
+ 'GEO_continent': '',
54
+ }
55
+ GEO_dict = {
56
+ 'GEO_override_OCR': False,
57
+ 'GEO_method': '',
58
+ 'GEO_formatted_full_string': '',
59
+ 'GEO_decimal_lat': '',
60
+ 'GEO_decimal_long': '',
61
+ 'GEO_city': '',
62
+ 'GEO_county': '',
63
+ 'GEO_state': '',
64
+ 'GEO_state_code': '',
65
+ 'GEO_country': '',
66
+ 'GEO_country_code': '',
67
+ 'GEO_continent': '',
68
+ }
69
+ GEO_dict_rev = {
70
+ 'GEO_override_OCR': False,
71
+ 'GEO_method': '',
72
+ 'GEO_formatted_full_string': '',
73
+ 'GEO_decimal_lat': '',
74
+ 'GEO_decimal_long': '',
75
+ 'GEO_city': '',
76
+ 'GEO_county': '',
77
+ 'GEO_state': '',
78
+ 'GEO_state_code': '',
79
+ 'GEO_country': '',
80
+ 'GEO_country_code': '',
81
+ 'GEO_continent': '',
82
+ }
83
+ GEO_dict_rev_verbatim = {
84
+ 'GEO_override_OCR': False,
85
+ 'GEO_method': '',
86
+ 'GEO_formatted_full_string': '',
87
+ 'GEO_decimal_lat': '',
88
+ 'GEO_decimal_long': '',
89
+ 'GEO_city': '',
90
+ 'GEO_county': '',
91
+ 'GEO_state': '',
92
+ 'GEO_state_code': '',
93
+ 'GEO_country': '',
94
+ 'GEO_country_code': '',
95
+ 'GEO_continent': '',
96
+ }
97
+ GEO_dict_forward = {
98
+ 'GEO_override_OCR': False,
99
+ 'GEO_method': '',
100
+ 'GEO_formatted_full_string': '',
101
+ 'GEO_decimal_lat': '',
102
+ 'GEO_decimal_long': '',
103
+ 'GEO_city': '',
104
+ 'GEO_county': '',
105
+ 'GEO_state': '',
106
+ 'GEO_state_code': '',
107
+ 'GEO_country': '',
108
+ 'GEO_country_code': '',
109
+ 'GEO_continent': '',
110
+ }
111
+ GEO_dict_forward_locality = {
112
+ 'GEO_override_OCR': False,
113
+ 'GEO_method': '',
114
+ 'GEO_formatted_full_string': '',
115
+ 'GEO_decimal_lat': '',
116
+ 'GEO_decimal_long': '',
117
+ 'GEO_city': '',
118
+ 'GEO_county': '',
119
+ 'GEO_state': '',
120
+ 'GEO_state_code': '',
121
+ 'GEO_country': '',
122
+ 'GEO_country_code': '',
123
+ 'GEO_continent': '',
124
+ }
125
+
126
+ if not tool_GEO:
127
+ return record, GEO_dict_null
128
+ else:
129
+ # For production
130
+ query_forward = ', '.join(filter(None, [record.get('municipality', '').strip(),
131
+ record.get('county', '').strip(),
132
+ record.get('stateProvince', '').strip(),
133
+ record.get('country', '').strip()])).strip()
134
+ query_forward_locality = ', '.join(filter(None, [record.get('locality', '').strip(),
135
+ record.get('municipality', '').strip(),
136
+ record.get('county', '').strip(),
137
+ record.get('stateProvince', '').strip(),
138
+ record.get('country', '').strip()])).strip()
139
+ query_reverse = ','.join(filter(None, [record.get('decimalLatitude', '').strip(),
140
+ record.get('decimalLongitude', '').strip()])).strip()
141
+ query_reverse_verbatim = record.get('verbatimCoordinates', '').strip()
142
+
143
+
144
+ '''
145
+ #For testing
146
+ # query_forward = 'Ann bor, michign'
147
+ query_forward = 'michigan'
148
+ query_forward_locality = 'Ann bor, michign'
149
+ # query_gps = "42 N,-83 W" # cannot have any spaces
150
+ # query_reverse_verbatim = "42.278366,-83.744718" # cannot have any spaces
151
+ query_reverse_verbatim = "42,-83" # cannot have any spaces
152
+ query_reverse = "42,-83" # cannot have any spaces
153
+ # params = {
154
+ # 'q': query_loc,
155
+ # 'apiKey': os.environ['HERE_API_KEY'],
156
+ # }'''
157
+
158
+
159
+ params_rev = {
160
+ 'at': query_reverse,
161
+ 'apiKey': os.environ['HERE_API_KEY'],
162
+ 'lang': 'en',
163
+ }
164
+ params_reverse_verbatim = {
165
+ 'at': query_reverse_verbatim,
166
+ 'apiKey': os.environ['HERE_API_KEY'],
167
+ 'lang': 'en',
168
+ }
169
+ params_forward = {
170
+ 'q': query_forward,
171
+ 'apiKey': os.environ['HERE_API_KEY'],
172
+ 'lang': 'en',
173
+ }
174
+ params_forward_locality = {
175
+ 'q': query_forward_locality,
176
+ 'apiKey': os.environ['HERE_API_KEY'],
177
+ 'lang': 'en',
178
+ }
179
+
180
+ ### REVERSE
181
+ # If the record has decimal coordinates, try a reverse lookup on them first
182
+ response = requests.get(reverse_url, params=params_rev)
183
+ if response.status_code == 200:
184
+ data = response.json()
185
+ if data.get('items'):
186
+ first_result = data['items'][0]
187
+ GEO_dict_rev['GEO_method'] = 'HERE_Geocode_reverse'
188
+ GEO_dict_rev['GEO_formatted_full_string'] = first_result.get('title', '')
189
+ GEO_dict_rev['GEO_decimal_lat'] = first_result['position']['lat']
190
+ GEO_dict_rev['GEO_decimal_long'] = first_result['position']['lng']
191
+
192
+ address = first_result.get('address', {})
193
+ GEO_dict_rev['GEO_city'] = address.get('city', '')
194
+ GEO_dict_rev['GEO_county'] = address.get('county', '')
195
+ GEO_dict_rev['GEO_state'] = address.get('state', '')
196
+ GEO_dict_rev['GEO_state_code'] = address.get('stateCode', '')
197
+ GEO_dict_rev['GEO_country'] = address.get('countryName', '')
198
+ GEO_dict_rev['GEO_country_code'] = address.get('countryCode', '')
199
+ GEO_dict_rev['GEO_continent'] = get_continent(address.get('countryName', ''))
200
+
201
+ ### REVERSE Verbatim
202
+ # If the decimal reverse lookup failed, retry using the verbatim coordinate string
203
+ if GEO_dict_rev['GEO_city']: # If the reverse was successful, pass
204
+ GEO_dict = GEO_dict_rev
205
+ else:
206
+ response = requests.get(reverse_url, params=params_reverse_verbatim)
207
+ if response.status_code == 200:
208
+ data = response.json()
209
+ if data.get('items'):
210
+ first_result = data['items'][0]
211
+ GEO_dict_rev_verbatim['GEO_method'] = 'HERE_Geocode_reverse_verbatimCoordinates'
212
+ GEO_dict_rev_verbatim['GEO_formatted_full_string'] = first_result.get('title', '')
213
+ GEO_dict_rev_verbatim['GEO_decimal_lat'] = first_result['position']['lat']
214
+ GEO_dict_rev_verbatim['GEO_decimal_long'] = first_result['position']['lng']
215
+
216
+ address = first_result.get('address', {})
217
+ GEO_dict_rev_verbatim['GEO_city'] = address.get('city', '')
218
+ GEO_dict_rev_verbatim['GEO_county'] = address.get('county', '')
219
+ GEO_dict_rev_verbatim['GEO_state'] = address.get('state', '')
220
+ GEO_dict_rev_verbatim['GEO_state_code'] = address.get('stateCode', '')
221
+ GEO_dict_rev_verbatim['GEO_country'] = address.get('countryName', '')
222
+ GEO_dict_rev_verbatim['GEO_country_code'] = address.get('countryCode', '')
223
+ GEO_dict_rev_verbatim['GEO_continent'] = get_continent(address.get('countryName', ''))
224
+
225
+ ### FORWARD
226
+ ### Try forward geocoding only if both reverse lookups (decimal, then verbatim) failed
227
+ if GEO_dict_rev['GEO_city']: # If the reverse was successful, pass
228
+ GEO_dict = GEO_dict_rev
229
+ elif GEO_dict_rev_verbatim['GEO_city']:
230
+ GEO_dict = GEO_dict_rev_verbatim
231
+ else:
232
+ response = requests.get(forward_url, params=params_forward)
233
+ if response.status_code == 200:
234
+ data = response.json()
235
+ if data.get('items'):
236
+ first_result = data['items'][0]
237
+ GEO_dict_forward['GEO_method'] = 'HERE_Geocode_forward'
238
+ GEO_dict_forward['GEO_formatted_full_string'] = first_result.get('title', '')
239
+ GEO_dict_forward['GEO_decimal_lat'] = first_result['position']['lat']
240
+ GEO_dict_forward['GEO_decimal_long'] = first_result['position']['lng']
241
+
242
+ address = first_result.get('address', {})
243
+ GEO_dict_forward['GEO_city'] = address.get('city', '')
244
+ GEO_dict_forward['GEO_county'] = address.get('county', '')
245
+ GEO_dict_forward['GEO_state'] = address.get('state', '')
246
+ GEO_dict_forward['GEO_state_code'] = address.get('stateCode', '')
247
+ GEO_dict_forward['GEO_country'] = address.get('countryName', '')
248
+ GEO_dict_forward['GEO_country_code'] = address.get('countryCode', '')
249
+ GEO_dict_forward['GEO_continent'] = get_continent(address.get('countryName', ''))
250
+
251
+ ### FORWARD locality
252
+ ### Try forward geocoding with locality only if reverse (decimal, verbatim) and plain forward all failed
253
+ if GEO_dict_rev['GEO_city']: # If the reverse was successful, pass
254
+ GEO_dict = GEO_dict_rev
255
+ elif GEO_dict_rev_verbatim['GEO_city']:
256
+ GEO_dict = GEO_dict_rev_verbatim
257
+ elif GEO_dict_forward['GEO_city']:
258
+ GEO_dict = GEO_dict_forward
259
+ else:
260
+ response = requests.get(forward_url, params=params_forward_locality)
261
+ if response.status_code == 200:
262
+ data = response.json()
263
+ if data.get('items'):
264
+ first_result = data['items'][0]
265
+ GEO_dict_forward_locality['GEO_method'] = 'HERE_Geocode_forward_locality'
266
+ GEO_dict_forward_locality['GEO_formatted_full_string'] = first_result.get('title', '')
267
+ GEO_dict_forward_locality['GEO_decimal_lat'] = first_result['position']['lat']
268
+ GEO_dict_forward_locality['GEO_decimal_long'] = first_result['position']['lng']
269
+
270
+ address = first_result.get('address', {})
271
+ GEO_dict_forward_locality['GEO_city'] = address.get('city', '')
272
+ GEO_dict_forward_locality['GEO_county'] = address.get('county', '')
273
+ GEO_dict_forward_locality['GEO_state'] = address.get('state', '')
274
+ GEO_dict_forward_locality['GEO_state_code'] = address.get('stateCode', '')
275
+ GEO_dict_forward_locality['GEO_country'] = address.get('countryName', '')
276
+ GEO_dict_forward_locality['GEO_country_code'] = address.get('countryCode', '')
277
+ GEO_dict_forward_locality['GEO_continent'] = get_continent(address.get('countryName', ''))
278
+
279
+
280
+ # print(json.dumps(GEO_dict,indent=4))
281
+
282
+
283
+ # Pick the most detailed version
284
+ # if GEO_dict_rev['GEO_formatted_full_string'] and GEO_dict_forward['GEO_formatted_full_string']:
285
+ for loc in pinpoint:
286
+ rev = GEO_dict_rev.get(loc,'')
287
+ forward = GEO_dict_forward.get(loc,'')
288
+ forward_locality = GEO_dict_forward_locality.get(loc,'')
289
+ rev_verbatim = GEO_dict_rev_verbatim.get(loc,'')
290
+
291
+ if not rev and not forward and not forward_locality and not rev_verbatim:
292
+ pass
293
+ elif rev:
294
+ GEO_dict = GEO_dict_rev
295
+ break
296
+ elif forward:
297
+ GEO_dict = GEO_dict_forward
298
+ break
299
+ elif forward_locality:
300
+ GEO_dict = GEO_dict_forward_locality
301
+ break
302
+ elif rev_verbatim:
303
+ GEO_dict = GEO_dict_rev_verbatim
304
+ break
305
+ else:
306
+ GEO_dict = GEO_dict_null
307
+
308
+
309
+ if GEO_dict['GEO_formatted_full_string'] and replace_if_success_geo:
310
+ GEO_dict['GEO_override_OCR'] = True
311
+ record['country'] = GEO_dict.get('GEO_country')
312
+ record['stateProvince'] = GEO_dict.get('GEO_state')
313
+ record['county'] = GEO_dict.get('GEO_county')
314
+ record['municipality'] = GEO_dict.get('GEO_city')
315
+
316
+ # print(json.dumps(GEO_dict,indent=4))
317
+ return record, GEO_dict
318
+
319
+
320
+ if __name__ == "__main__":
321
+ validate_coordinates_here(False, {})  # quick check: a disabled tool returns the record and the null GEO dict
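Typical use of the new geolocation tool, assuming `HERE_API_KEY` is set in the environment (the record keys mirror the queries built above; the specimen values are invented):

    record = {
        'municipality': 'Ann Arbor', 'county': 'Washtenaw',
        'stateProvince': 'Michigan', 'country': 'USA',
        'locality': '', 'decimalLatitude': '42.278', 'decimalLongitude': '-83.744',
        'verbatimCoordinates': '',
    }
    record, GEO_dict = validate_coordinates_here(True, record, replace_if_success_geo=True)
    # GEO_dict['GEO_method'] records which lookup succeeded
    # (reverse, reverse-verbatim, forward, or forward-locality).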
vouchervision/tool_taxonomy_WFO.py ADDED
@@ -0,0 +1,324 @@
 
 
1
+ import requests
2
+ from urllib.parse import urlencode
3
+ from Levenshtein import ratio
4
+ from fuzzywuzzy import fuzz
5
+
6
+ class WFONameMatcher:
7
+ def __init__(self, tool_WFO):
8
+ self.base_url = "https://list.worldfloraonline.org/matching_rest.php?"
9
+ self.N_BEST_CANDIDATES = 10
10
+ self.NULL_DICT = {
11
+ "WFO_exact_match": False,
12
+ "WFO_exact_match_name": "",
13
+ "WFO_candidate_names": "",
14
+ "WFO_best_match": "",
15
+ "WFO_placement": "",
16
+ "WFO_override_OCR": False,
17
+ }
18
+ self.SEP = '|'
19
+ self.is_enabled = tool_WFO
20
+
21
+ def extract_input_string(self, record):
22
+ primary_input = f"{record.get('scientificName', '').strip()} {record.get('scientificNameAuthorship', '').strip()}".strip()
23
+ secondary_input = ' '.join(filter(None, [record.get('genus', '').strip(),
24
+ record.get('subgenus', '').strip(),
25
+ record.get('specificEpithet', '').strip(),
26
+ record.get('infraspecificEpithet', '').strip()])).strip()
27
+
28
+ return primary_input, secondary_input
29
+
30
+ def query_wfo_name_matching(self, input_string, check_homonyms=True, check_rank=True, accept_single_candidate=True):
31
+ params = {
32
+ "input_string": input_string,
33
+ "check_homonyms": check_homonyms,
34
+ "check_rank": check_rank,
35
+ "method": "full",
36
+ "accept_single_candidate": accept_single_candidate,
37
+ }
38
+
39
+ full_url = self.base_url + urlencode(params)
40
+
41
+ response = requests.get(full_url)
42
+ if response.status_code == 200:
43
+ return response.json()
44
+ else:
45
+ return {"error": True, "message": "Failed to fetch data from WFO API"}
46
+
47
+ def query_and_process(self, record):
48
+ primary_input, secondary_input = self.extract_input_string(record)
49
+
50
+ # Query with primary input
51
+ primary_result = self.query_wfo_name_matching(primary_input)
52
+ primary_processed, primary_ranked_candidates = self.process_wfo_response(primary_result, primary_input)
53
+
54
+ if primary_processed.get('WFO_exact_match'):
55
+ print("Selected Primary --- Exact Primary & Unchecked Secondary")
56
+ return primary_processed
57
+ else:
58
+ # Query with secondary input
59
+ secondary_result = self.query_wfo_name_matching(secondary_input)
60
+ secondary_processed, secondary_ranked_candidates = self.process_wfo_response(secondary_result, secondary_input)
61
+
62
+ if secondary_processed.get('WFO_exact_match'):
63
+ print("Selected Secondary --- Unchecked Primary & Exact Secondary")
64
+ return secondary_processed
65
+
66
+ else:
67
+ # Both failed, just return the first failure
68
+ if (primary_processed.get("WFO_candidate_names") == '') and (secondary_processed.get("WFO_candidate_names") == ''):
69
+ print("Selected Primary --- Failed Primary & Failed Secondary")
70
+ return primary_processed
71
+
72
+ # 1st failed, just return the second
73
+ elif (primary_processed.get("WFO_candidate_names") == '') and (len(secondary_processed.get("WFO_candidate_names")) > 0):
74
+ print("Selected Secondary --- Failed Primary & Partial Secondary")
75
+ return secondary_processed
76
+
77
+ # 2nd failed, just return the first
78
+ elif (len(primary_processed.get("WFO_candidate_names")) > 0) and (secondary_processed.get("WFO_candidate_names") == ''):
79
+ print("Selected Primary --- Partial Primary & Failed Secondary")
80
+ return primary_processed
81
+
82
+ # Both have partial matches, compare and rerank
83
+ elif (len(primary_processed.get("WFO_candidate_names")) > 0) and (len(secondary_processed.get("WFO_candidate_names")) > 0):
84
+ # Combine and sort results, ensuring no duplicates
85
+ combined_candidates = list(set(primary_ranked_candidates + secondary_ranked_candidates))
86
+ combined_candidates.sort(key=lambda x: (x[1], x[0]), reverse=True) # Sort by similarity score, then name
87
+
88
+ # Replace candidates with combined_candidates and combined best match
89
+ best_score_primary = primary_processed["WFO_candidate_names"][0][1]
90
+ best_score_secondary = secondary_processed["WFO_candidate_names"][0][1]
91
+
92
+ # Extracting only the candidate names from the top candidates
93
+ top_candidates = combined_candidates[:self.N_BEST_CANDIDATES]
94
+ cleaned_candidates = [cand[0] for cand in top_candidates]
95
+
96
+ if best_score_primary >= best_score_secondary:
97
+
98
+ primary_processed["WFO_candidate_names"] = cleaned_candidates
99
+ primary_processed["WFO_best_match"] = cleaned_candidates[0]
100
+
101
+ response_placement = self.query_wfo_name_matching(primary_processed["WFO_best_match"])
102
+ placement_exact_match = response_placement.get("match")
103
+ primary_processed["WFO_placement"] = placement_exact_match.get("placement", '')
104
+
105
+ print("Selected Primary --- Partial Primary & Partial Secondary")
106
+ return primary_processed
107
+ else:
108
+ secondary_processed["WFO_candidate_names"] = cleaned_candidates
109
+ secondary_processed["WFO_best_match"] = cleaned_candidates[0]
110
+
111
+ response_placement = self.query_wfo_name_matching(secondary_processed["WFO_best_match"])
112
+ placement_exact_match = response_placement.get("match")
113
+ secondary_processed["WFO_placement"] = placement_exact_match.get("placement", '')
114
+
115
+ print("Selected Secondary --- Partial Primary & Partial Secondary")
116
+ return secondary_processed
117
+ else:
118
+ return self.NULL_DICT
119
+
120
+ def process_wfo_response(self, response, query):
121
+ simplified_response = {}
122
+ ranked_candidates = None
123
+
124
+ exact_match = response.get("match")
125
+ simplified_response["WFO_exact_match"] = bool(exact_match)
126
+
127
+ candidates = response.get("candidates", [])
128
+ candidate_names = [candidate["full_name_plain"] for candidate in candidates] if candidates else []
129
+
130
+ if not exact_match and candidate_names:
131
+ cleaned_candidates, ranked_candidates = self._rank_candidates_by_similarity(query, candidate_names)
132
+ simplified_response["WFO_candidate_names"] = cleaned_candidates
133
+ simplified_response["WFO_best_match"] = cleaned_candidates[0] if cleaned_candidates else ''
134
+ elif exact_match:
135
+ simplified_response["WFO_candidate_names"] = exact_match.get("full_name_plain")
136
+ simplified_response["WFO_best_match"] = exact_match.get("full_name_plain")
137
+ else:
138
+ simplified_response["WFO_candidate_names"] = ''
139
+ simplified_response["WFO_best_match"] = ''
140
+
141
+ # Call WFO again to update placement using WFO_best_match
142
+ try:
143
+ response_placement = self.query_wfo_name_matching(simplified_response["WFO_best_match"])
144
+ placement_exact_match = response_placement.get("match")
145
+ simplified_response["WFO_placement"] = placement_exact_match.get("placement", '')
146
+ except:
147
+ simplified_response["WFO_placement"] = ''
148
+
149
+ return simplified_response, ranked_candidates
150
+
151
+ def _rank_candidates_by_similarity(self, query, candidates):
152
+ string_similarities = []
153
+ fuzzy_similarities = {candidate: fuzz.ratio(query, candidate) for candidate in candidates}
154
+ query_words = query.split()
155
+
156
+ for candidate in candidates:
157
+ candidate_words = candidate.split()
158
+ # Calculate word similarities and sum them up
159
+ word_similarities = [ratio(query_word, candidate_word) for query_word, candidate_word in zip(query_words, candidate_words)]
160
+ total_word_similarity = sum(word_similarities)
161
+
162
+ # Calculate combined similarity score (average of word and fuzzy similarities)
163
+ fuzzy_similarity = fuzzy_similarities[candidate]
164
+ combined_similarity = (total_word_similarity + fuzzy_similarity) / 2
165
+ string_similarities.append((candidate, combined_similarity))
166
+
167
+ # Sort the candidates based on combined similarity, higher scores first
168
+ ranked_candidates = sorted(string_similarities, key=lambda x: x[1], reverse=True)
169
+
170
+ # Extracting only the candidate names from the top candidates
171
+ top_candidates = ranked_candidates[:self.N_BEST_CANDIDATES]
172
+ cleaned_candidates = [cand[0] for cand in top_candidates]
173
+
174
+ return cleaned_candidates, ranked_candidates
175
+
176
+ def check_WFO(self, record, replace_if_success_wfo):
177
+ if not self.is_enabled:
178
+ return record, self.NULL_DICT
179
+
180
+ else:
181
+ self.replace_if_success_wfo = replace_if_success_wfo
182
+
183
+ # "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
184
+ simplified_response = self.query_and_process(record)
185
+ simplified_response['WFO_override_OCR'] = False
186
+
187
+ # best_match
188
+ if simplified_response.get('WFO_exact_match'):
189
+ simplified_response['WFO_exact_match_name'] = simplified_response.get('WFO_best_match')
190
+ else:
191
+ simplified_response['WFO_exact_match_name'] = ''
192
+
193
+ # placement
194
+ wfo_placement = simplified_response.get('WFO_placement', '')
195
+ if wfo_placement:
196
+ parts = wfo_placement.split('/')[1:]
197
+ simplified_response['WFO_placement'] = self.SEP.join(parts)
198
+ else:
199
+ simplified_response['WFO_placement'] = ''
200
+
201
+ if simplified_response.get('WFO_exact_match') and replace_if_success_wfo:
202
+ simplified_response['WFO_override_OCR'] = True
203
+ name_parts = simplified_response.get('WFO_placement').split('$')[0]
204
+ name_parts = name_parts.split(self.SEP)
205
+ record['order'] = name_parts[3]
206
+ record['family'] = name_parts[4]
207
+ record['genus'] = name_parts[5]
208
+ record['specificEpithet'] = name_parts[6]
209
+ record['scientificName'] = simplified_response.get('WFO_exact_match_name')
210
+
211
+ return record, simplified_response
212
+
213
+ def validate_taxonomy_WFO(tool_WFO, record_dict, replace_if_success_wfo=False):
214
+ Matcher = WFONameMatcher(tool_WFO)
215
+ try:
216
+ record_dict, WFO_dict = Matcher.check_WFO(record_dict, replace_if_success_wfo)
217
+ return record_dict, WFO_dict
218
+ except:
219
+ return record_dict, Matcher.NULL_DICT
220
+
221
+ '''
222
+ if __name__ == "__main__":
223
+ Matcher = WFONameMatcher(tool_WFO=True)
224
+ # input_string = "Rhopalocarpus alterfolius"
225
+ record_exact_match ={
226
+ "order": "Malpighiales",
227
+ "family": "Hypericaceae",
228
+ "scientificName": "Hypericum prolificum",
229
+ "scientificNameAuthorship": "",
230
+
231
+ "genus": "Hypericum",
232
+ "subgenus": "",
233
+ "specificEpithet": "prolificum",
234
+ "infraspecificEpithet": "",
235
+ }
236
+ record_partialPrimary_exactSecondary ={
237
+ "order": "Malpighiales",
238
+ "family": "Hypericaceae",
239
+ "scientificName": "Hyperic prolificum",
240
+ "scientificNameAuthorship": "",
241
+
242
+ "genus": "Hypericum",
243
+ "subgenus": "",
244
+ "specificEpithet": "prolificum",
245
+ "infraspecificEpithet": "",
246
+ }
247
+ record_exactPrimary_partialSecondary ={
248
+ "order": "Malpighiales",
249
+ "family": "Hypericaceae",
250
+ "scientificName": "Hypericum prolificum",
251
+ "scientificNameAuthorship": "",
252
+
253
+ "genus": "Hyperic",
254
+ "subgenus": "",
255
+ "specificEpithet": "prolificum",
256
+ "infraspecificEpithet": "",
257
+ }
258
+ record_partialPrimary_partialSecondary ={
259
+ "order": "Malpighiales",
260
+ "family": "Hypericaceae",
261
+ "scientificName": "Hyperic prolificum",
262
+ "scientificNameAuthorship": "",
263
+
264
+ "genus": "Hypericum",
265
+ "subgenus": "",
266
+ "specificEpithet": "prolific",
267
+ "infraspecificEpithet": "",
268
+ }
269
+ record_partialPrimary_partialSecondary_swap ={
270
+ "order": "Malpighiales",
271
+ "family": "Hypericaceae",
272
+ "scientificName": "Hypericum prolific",
273
+ "scientificNameAuthorship": "",
274
+
275
+ "genus": "Hyperic",
276
+ "subgenus": "",
277
+ "specificEpithet": "prolificum",
278
+ "infraspecificEpithet": "",
279
+ }
280
+ record_errorPrimary_partialSecondary ={
281
+ "order": "Malpighiales",
282
+ "family": "Hypericaceae",
283
+ "scientificName": "ricum proli",
284
+ "scientificNameAuthorship": "",
285
+
286
+ "genus": "Hyperic",
287
+ "subgenus": "",
288
+ "specificEpithet": "prolificum",
289
+ "infraspecificEpithet": "",
290
+ }
291
+ record_partialPrimary_errorSecondary ={
292
+ "order": "Malpighiales",
293
+ "family": "Hypericaceae",
294
+ "scientificName": "Hyperic prolificum",
295
+ "scientificNameAuthorship": "",
296
+
297
+ "genus": "ricum",
298
+ "subgenus": "",
299
+ "specificEpithet": "proli",
300
+ "infraspecificEpithet": "",
301
+ }
302
+ record_errorPrimary_errorSecondary ={
303
+ "order": "Malpighiales",
304
+ "family": "Hypericaceae",
305
+ "scientificName": "ricum proli",
306
+ "scientificNameAuthorship": "",
307
+
308
+ "genus": "ricum",
309
+ "subgenus": "",
310
+ "specificEpithet": "proli",
311
+ "infraspecificEpithet": "",
312
+ }
313
+ options = [record_exact_match,
314
+ record_partialPrimary_exactSecondary,
315
+ record_exactPrimary_partialSecondary,
316
+ record_partialPrimary_partialSecondary,
317
+ record_partialPrimary_partialSecondary_swap,
318
+ record_errorPrimary_partialSecondary,
319
+ record_partialPrimary_errorSecondary,
320
+ record_errorPrimary_errorSecondary]
321
+ for opt in options:
322
+ simplified_response = Matcher.check_WFO(opt)
323
+ print(json.dumps(simplified_response, indent=4))
324
+ '''
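Note for reviewers: the new `validate_taxonomy_WFO` wrapper swallows every exception from `Matcher.check_WFO` with a bare `except` and falls back to `Matcher.NULL_DICT`, so callers cannot distinguish "no match" from "request failed". A minimal usage sketch follows, assuming the WFO lookup service is reachable and that passing `True` as `tool_WFO` enables the matcher (the record reuses the commented-out test data above):

```python
# Minimal sketch: exercise the new wrapper outside the pipeline.
# Assumes the WFO service is reachable; the record mirrors the test data above.
from vouchervision.tool_taxonomy_WFO import validate_taxonomy_WFO

record = {
    "order": "Malpighiales",
    "family": "Hypericaceae",
    "scientificName": "Hypericum prolificum",
    "scientificNameAuthorship": "",
    "genus": "Hypericum",
    "subgenus": "",
    "specificEpithet": "prolificum",
    "infraspecificEpithet": "",
}

# replace_if_success_wfo=False returns the WFO verdict alongside the record
# without overwriting the transcribed names.
record, WFO_dict = validate_taxonomy_WFO(True, record, replace_if_success_wfo=False)
print(WFO_dict)
```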
vouchervision/tool_wikipedia.py CHANGED
@@ -8,7 +8,8 @@ import pstats
 class WikipediaLinks():
 
 
-    def __init__(self, json_file_path_wiki) -> None:
+    def __init__(self, tool_wikipedia, json_file_path_wiki) -> None:
+        self.is_enabled = tool_wikipedia
         self.json_file_path_wiki = json_file_path_wiki
         self.wiki_wiki = wikipediaapi.Wikipedia(
             user_agent='VoucherVision (merlin@example.com)',
@@ -466,54 +467,56 @@ class WikipediaLinks():
         self.info_packet['WIKI_GEO'] = {}
         self.info_packet['WIKI_LOCALITY'] = {}
 
-        municipality = output.get('municipality','')
-        county = output.get('county','')
-        stateProvince = output.get('stateProvince','')
-        country = output.get('country','')
-
-        locality = output.get('locality','')
-
-        order = output.get('order','')
-        family = output.get('family','')
-        scientificName = output.get('scientificName','')
-        genus = output.get('genus','')
-        specificEpithet = output.get('specificEpithet','')
-
-        query_geo = ' '.join([municipality, county, stateProvince, country]).strip()
-        query_locality = locality.strip()
-        query_taxa_primary = scientificName.strip()
-        query_taxa_secondary = ' '.join([genus, specificEpithet]).strip()
-        query_taxa_tertiary = ' '.join([order, family, genus, specificEpithet]).strip()
-
-        # query_taxa = "Tracaulon sagittatum Tracaulon sagittatum"
-        # query_geo = "Indiana Porter Co."
-        # query_locality = "Mical Springs edge"
-
-        if query_geo:
-            try:
-                self.gather_geo(query_geo)
-            except:
-                pass
-
-        if query_locality:
-            try:
-                self.gather_geo(query_locality,'locality')
-            except:
-                pass
-
-        queries_taxa = [query_taxa_primary, query_taxa_secondary, query_taxa_tertiary]
-        for q in queries_taxa:
-            if q:
-                try:
-                    self.gather_taxonomy(q)
-                    break
-                except:
-                    pass
-
-        # print(self.info_packet)
-        # return self.info_packet
-        # self.gather_geo(query_geo)
+        if self.is_enabled:
+            municipality = output.get('municipality','')
+            county = output.get('county','')
+            stateProvince = output.get('stateProvince','')
+            country = output.get('country','')
+
+            locality = output.get('locality','')
+
+            order = output.get('order','')
+            family = output.get('family','')
+            scientificName = output.get('scientificName','')
+            genus = output.get('genus','')
+            specificEpithet = output.get('specificEpithet','')
+
+            query_geo = ' '.join([municipality, county, stateProvince, country]).strip()
+            query_locality = locality.strip()
+            query_taxa_primary = scientificName.strip()
+            query_taxa_secondary = ' '.join([genus, specificEpithet]).strip()
+            query_taxa_tertiary = ' '.join([order, family, genus, specificEpithet]).strip()
+
+            # query_taxa = "Tracaulon sagittatum Tracaulon sagittatum"
+            # query_geo = "Indiana Porter Co."
+            # query_locality = "Mical Springs edge"
+
+            if query_geo:
+                try:
+                    self.gather_geo(query_geo)
+                except:
+                    pass
+
+            if query_locality:
+                try:
+                    self.gather_geo(query_locality,'locality')
+                except:
+                    pass
+
+            queries_taxa = [query_taxa_primary, query_taxa_secondary, query_taxa_tertiary]
+            for q in queries_taxa:
+                if q:
+                    try:
+                        self.gather_taxonomy(q)
+                        break
+                    except:
+                        pass
+
+            # print(self.info_packet)
+            # return self.info_packet
+            # self.gather_geo(query_geo)
         try:
             with open(self.json_file_path_wiki, 'w', encoding='utf-8') as file:
                 json.dump(self.info_packet, file, indent=4)
@@ -547,6 +550,13 @@ class WikipediaLinks():
         return clean_text
 
 
+
+def validate_wikipedia(tool_wikipedia, json_file_path_wiki, output):
+    Wiki = WikipediaLinks(tool_wikipedia, json_file_path_wiki)
+    Wiki.gather_wikipedia_results(output)
+
+
+
 if __name__ == '__main__':
     test_output = {
         "filename": "MICH_7375774_Polygonaceae_Persicaria_",
vouchervision/utils_LLM.py CHANGED
@@ -8,6 +8,60 @@ import psutil
 import threading
 import torch
 from datetime import datetime
+from vouchervision.tool_taxonomy_WFO import validate_taxonomy_WFO, WFONameMatcher
+from vouchervision.tool_geolocate_HERE import validate_coordinates_here
+from vouchervision.tool_wikipedia import validate_wikipedia
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+def run_tools(output, tool_WFO, tool_GEO, tool_wikipedia, json_file_path_wiki):
+    # Define a function that will catch and return the results of your functions
+    def task(func, *args, **kwargs):
+        return func(*args, **kwargs)
+
+    # List of tasks to run in separate threads
+    tasks = [
+        (validate_taxonomy_WFO, (tool_WFO, output, False)),
+        (validate_coordinates_here, (tool_GEO, output, False)),
+        (validate_wikipedia, (tool_wikipedia, json_file_path_wiki, output)),
+    ]
+
+    # Results storage
+    results = {}
+
+    # Use ThreadPoolExecutor to execute each function in its own thread
+    with ThreadPoolExecutor() as executor:
+        future_to_func = {executor.submit(task, func, *args): func.__name__ for func, args in tasks}
+        for future in as_completed(future_to_func):
+            func_name = future_to_func[future]
+            try:
+                # Collecting results
+                results[func_name] = future.result()
+            except Exception as exc:
+                print(f'{func_name} generated an exception: {exc}')
+
+    # Here, all threads have completed
+    # Extracting results
+    Matcher = WFONameMatcher(tool_WFO)
+    GEO_dict_null = {
+        'GEO_override_OCR': False,
+        'GEO_method': '',
+        'GEO_formatted_full_string': '',
+        'GEO_decimal_lat': '',
+        'GEO_decimal_long': '',
+        'GEO_city': '',
+        'GEO_county': '',
+        'GEO_state': '',
+        'GEO_state_code': '',
+        'GEO_country': '',
+        'GEO_country_code': '',
+        'GEO_continent': '',
+    }
+    output_WFO, WFO_record = results.get('validate_taxonomy_WFO', (output, Matcher.NULL_DICT))
+    output_GEO, GEO_record = results.get('validate_coordinates_here', (output, GEO_dict_null))
+
+    return output_WFO, WFO_record, output_GEO, GEO_record
+
 
 def save_individual_prompt(prompt_template, txt_file_path_ind_prompt):
     with open(txt_file_path_ind_prompt, 'w',encoding='utf-8') as file:
@@ -19,6 +73,16 @@ def remove_colons_and_double_apostrophes(text):
     return text.replace(":", "").replace("\"", "")
 
 
+def sanitize_prompt(data):
+    if isinstance(data, dict):
+        return {sanitize_prompt(key): sanitize_prompt(value) for key, value in data.items()}
+    elif isinstance(data, list):
+        return [sanitize_prompt(element) for element in data]
+    elif isinstance(data, str):
+        return data.encode('utf-8', 'ignore').decode('utf-8')
+    else:
+        return data
+
 
 def count_tokens(string, vendor, model_name):
     full_string = string + JSON_FORMAT_INSTRUCTIONS
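The `run_tools` helper above is a standard fan-out/fan-in pattern: each validator runs in its own thread, results are keyed by `func.__name__` (which assumes each function appears at most once in `tasks`), and `results.get(...)` supplies a null-record fallback when a task raised. A self-contained sketch of the same pattern, with `check_a`/`check_b` as stand-ins for the real validators:

```python
# Stand-alone illustration of the fan-out/fan-in pattern used by run_tools.
from concurrent.futures import ThreadPoolExecutor, as_completed

def check_a(x):
    return ('A', x)

def check_b(x):
    raise RuntimeError('service down')  # simulate a failed external tool

tasks = [(check_a, (1,)), (check_b, (2,))]
results = {}
with ThreadPoolExecutor() as executor:
    future_to_func = {executor.submit(func, *args): func.__name__ for func, args in tasks}
    for future in as_completed(future_to_func):
        name = future_to_func[future]
        try:
            results[name] = future.result()
        except Exception as exc:
            print(f'{name} generated an exception: {exc}')

print(results.get('check_a', ('fallback',)))  # ('A', 1)
print(results.get('check_b', ('fallback',)))  # ('fallback',) -- null-record path
```

A related note on `sanitize_prompt`: since Python `str` objects are already valid Unicode, the `encode('utf-8', 'ignore').decode('utf-8')` round-trip is a no-op except for lone surrogates (e.g., from mis-decoded file names), which it silently drops.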
vouchervision/utils_VoucherVision.py CHANGED
@@ -43,7 +43,7 @@ class VoucherVision():
         self.prompt_version = None
         self.is_hf = is_hf
 
-        self.trOCR_model_version = "microsoft/trocr-large-handwritten"
+        # self.trOCR_model_version = "microsoft/trocr-large-handwritten"
         # self.trOCR_model_version = "microsoft/trocr-base-handwritten"
         # self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask" # NOPE
         # self.trOCR_model_version = "dh-unibe/trocr-kurrent" # NOPE
@@ -59,6 +59,8 @@ class VoucherVision():
         self.logger.name = f'[Transcription]'
         self.logger.info(f'Setting up OCR and LLM')
 
+        self.trOCR_model_version = self.cfg['leafmachine']['project']['trOCR_model_path']
+
         self.db_name = self.cfg['leafmachine']['project']['embeddings_database_name']
         self.path_domain_knowledge = self.cfg['leafmachine']['project']['path_to_domain_knowledge_xlsx']
         self.build_new_db = self.cfg['leafmachine']['project']['build_new_embeddings_database']
@@ -83,7 +85,7 @@ class VoucherVision():
         self.wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"]
         self.wfo_headers_no_lists = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
 
-        self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + self.usage_headers + ["run_name", "prompt", "LLM", "tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
+        self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + self.usage_headers + ["run_name", "prompt", "LLM", "tokens_in", "tokens_out", "LM2_collage", "OCR_method", "OCR_double", "OCR_trOCR", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
         # "WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
 
         # "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
@@ -298,7 +300,8 @@ class VoucherVision():
                 break
 
 
-    def add_data_to_excel_from_response(self, Dirs, path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
+    def add_data_to_excel_from_response(self, Dirs, path_transcription, response, WFO_record, GEO_record, usage_report,
+                                        MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
 
 
         wb = openpyxl.load_workbook(path_transcription)
@@ -367,7 +370,17 @@ class VoucherVision():
                 sheet.cell(row=next_row, column=i, value=os.path.basename(self.path_custom_prompts))
             elif header.value == "run_name":
                 sheet.cell(row=next_row, column=i, value=Dirs.run_name)
-
+            elif header.value == "LM2_collage":
+                sheet.cell(row=next_row, column=i, value=self.cfg['leafmachine']['use_RGB_label_images'])
+            elif header.value == "OCR_method":
+                value_to_insert = self.cfg['leafmachine']['project']['OCR_option']
+                if isinstance(value_to_insert, list):
+                    value_to_insert = '|'.join(map(str, value_to_insert))
+                sheet.cell(row=next_row, column=i, value=value_to_insert)
+            elif header.value == "OCR_double":
+                sheet.cell(row=next_row, column=i, value=self.cfg['leafmachine']['project']['double_OCR'])
+            elif header.value == "OCR_trOCR":
+                sheet.cell(row=next_row, column=i, value=self.cfg['leafmachine']['project']['do_use_trOCR'])
             # "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
             elif header.value in self.wfo_headers_no_lists:
                 sheet.cell(row=next_row, column=i, value=WFO_record.get(header.value, ''))
@@ -404,10 +417,11 @@ class VoucherVision():
 
 
     def has_API_key(self, val):
-        if val != '':
-            return True
-        else:
-            return False
+        return isinstance(val, str) and bool(val.strip())
+        # if val != '':
+        #     return True
+        # else:
+        #     return False
 
 
     def get_google_credentials(self): # Also used for google drive
@@ -460,6 +474,7 @@ class VoucherVision():
 
         self.has_key_openai = self.has_API_key(k_openai)
         self.has_key_azure_openai = self.has_API_key(k_openai_azure)
+        self.llm = None
 
         self.has_key_google_project_id = self.has_API_key(k_google_project_id)
         self.has_key_google_location = self.has_API_key(k_google_location)
@@ -470,12 +485,15 @@ class VoucherVision():
         self.has_key_open_cage_geocode = self.has_API_key(k_opencage)
 
 
+
         ### Google - OCR, Palm2, Gemini
        if self.has_key_google_application_credentials and self.has_key_google_project_id and self.has_key_google_location:
            if self.is_hf:
                vertexai.init(project=os.getenv('GOOGLE_PROJECT_ID'), location=os.getenv('GOOGLE_LOCATION'), credentials=self.get_google_credentials())
            else:
                vertexai.init(project=k_google_project_id, location=k_google_location, credentials=self.get_google_credentials())
+                os.environ['GOOGLE_API_KEY'] = self.cfg_private['google']['GOOGLE_PALM_API']
+
 
         ### OpenAI
         if self.has_key_openai:
@@ -497,7 +515,6 @@ class VoucherVision():
                     azure_endpoint = os.getenv('AZURE_API_BASE'),
                     openai_organization = os.getenv('AZURE_ORGANIZATION'),
                 )
-                self.has_key_azure_openai = True
 
             else:
                 # Initialize the Azure OpenAI client
@@ -508,7 +525,6 @@ class VoucherVision():
                     azure_endpoint = self.cfg_private['openai_azure']['OPENAI_API_BASE'],
                     openai_organization = self.cfg_private['openai_azure']['OPENAI_ORGANIZATION'],
                 )
-                self.has_key_azure_openai = True
 
 
         ### Mistral
@@ -624,6 +640,7 @@ class VoucherVision():
             ocr_google = OCREngine(self.logger, json_report, self.dir_home, self.is_hf, self.path_to_crop, self.cfg, self.trOCR_model_version, self.trOCR_model, self.trOCR_processor, self.device)
             ocr_google.process_image(self.do_create_OCR_helper_image, self.logger)
             self.OCR = ocr_google.OCR
+            self.logger.info(f"Complete OCR text for LLM prompt:\n\n{self.OCR}\n\n")
 
             self.write_json_to_file(txt_file_path_OCR, ocr_google.OCR_JSON_to_file)
@@ -671,7 +688,7 @@ class VoucherVision():
 
         json_report.set_text(text_main=f'Loading {MODEL_NAME_FORMATTED}')
         json_report.set_JSON({}, {}, {})
-        llm_model = self.initialize_llm_model(self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm)
+        llm_model = self.initialize_llm_model(self.cfg, self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm)
 
         for i, path_to_crop in enumerate(self.img_paths):
             self.update_progress_report_batch(progress_report, i)
@@ -729,7 +746,7 @@ class VoucherVision():
 
                 final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out)
 
-                self.log_completion_info(final_JSON_response)
+                self.logger.info(f'Finished LLM call')
 
                 json_report.set_JSON(final_JSON_response, final_WFO_record, final_GEO_record)
 
@@ -741,22 +758,22 @@ class VoucherVision():
     ##################################################################################################################################
     ################################################## LLM Helper Funcs ##############################################################
     ##################################################################################################################################
-    def initialize_llm_model(self, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None):
+    def initialize_llm_model(self, cfg, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None):
         if 'LOCAL' in name_parts:
             if ('MIXTRAL' in name_parts) or ('MISTRAL' in name_parts):
                 if 'CPU' in name_parts:
-                    return LocalCPUMistralHandler(logger, model_name, JSON_dict_structure)
+                    return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure)
                 else:
-                    return LocalMistralHandler(logger, model_name, JSON_dict_structure)
+                    return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure)
         else:
             if 'PALM2' in name_parts:
-                return GooglePalm2Handler(logger, model_name, JSON_dict_structure)
+                return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure)
            elif 'GEMINI' in name_parts:
-                return GoogleGeminiHandler(logger, model_name, JSON_dict_structure)
+                return GoogleGeminiHandler(cfg, logger, model_name, JSON_dict_structure)
            elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
-                return MistralHandler(logger, model_name, JSON_dict_structure)
+                return MistralHandler(cfg, logger, model_name, JSON_dict_structure)
            else:
-                return OpenAIHandler(logger, model_name, JSON_dict_structure, is_azure, llm_object)
+                return OpenAIHandler(cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object)
 
     def setup_prompt(self):
         Catalog = PromptCatalog()
@@ -807,11 +824,6 @@ class VoucherVision():
         return final_JSON_response_updated, WFO_record, GEO_record
 
 
-    def log_completion_info(self, final_JSON_response):
-        self.logger.info(f'Formatted JSON\n{final_JSON_response}')
-        self.logger.info(f'Finished API calls\n')
-
-
     def update_progress_report_final(self, progress_report):
         if progress_report is not None:
             progress_report.reset_batch("Batch Complete")
@@ -839,7 +851,8 @@ class VoucherVision():
         return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt
 
 
-    def save_json_and_xlsx(self, Dirs, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
+    def save_json_and_xlsx(self, Dirs, response, WFO_record, GEO_record, usage_report,
+                           MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
         if response is None:
             response = self.JSON_dict_structure
         # Insert 'filename' as the first key
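Finally, the `has_API_key` rewrite changes behavior, not just style: the old `val != ''` check returned `True` for whitespace-only strings and for non-string values such as `None`, while the new check requires a non-blank string. A quick sketch of the difference:

```python
# Behavior of the rewritten check from the diff above.
def has_API_key(val):
    return isinstance(val, str) and bool(val.strip())

print(has_API_key('sk-abc123'))  # True
print(has_API_key('   '))        # False (old `val != ''` check: True)
print(has_API_key(None))         # False (old check: True)
```

This matters for the Azure path in particular, since the two removed `self.has_key_azure_openai = True` assignments mean the flag is now set once, solely from the key check.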