phyloforfun committed on
Commit
524a99c
1 Parent(s): 0560c52

Major update. Support for 15 LLMs, World Flora Online taxonomy validation, geolocation, 2 OCR methods, significant UI changes, stability improvements, consistent JSON parsing
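
As a hedged illustration of the "consistent JSON parsing" item above (a common pattern, not necessarily this commit's exact implementation), replies from any of the supported LLMs are typically reduced to one dictionary shape like so:

import json, re

def parse_llm_json(raw_text):
    """Illustrative only: extract the first {...} block from an LLM reply and
    parse it, returning None if no valid JSON object can be recovered."""
    match = re.search(r'\{.*\}', raw_text, re.DOTALL)
    if match is None:
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None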

app.py CHANGED
@@ -2,7 +2,6 @@ import streamlit as st
 import yaml, os, json, random, time, re, torch, random, warnings, shutil, sys
 import seaborn as sns
 import plotly.graph_objs as go
- from itertools import chain
 from PIL import Image
 import pandas as pd
 from io import BytesIO
@@ -15,30 +14,190 @@ from vouchervision.vouchervision_main import voucher_vision
 from vouchervision.general_utils import test_GPU, get_cfg_from_full_path, summarize_expense_report, validate_dir
 from vouchervision.model_maps import ModelMaps
 from vouchervision.API_validation import APIvalidation
- from vouchervision.utils_hf import setup_streamlit_config, save_uploaded_file, check_prompt_yaml_filename, save_uploaded_local

 #################################################################################################################################################
 # Initializations ###############################################################################################################################
 #################################################################################################################################################
- st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VoucherVision')

 # Parse the 'is_hf' argument and set it in session state
 if 'is_hf' not in st.session_state:
- st.session_state['is_hf'] = True

- ########################################################################################################
- ### ADDED FOR HUGGING FACE ####
- ########################################################################################################
- print(f"is_hf {st.session_state['is_hf']}")
 # Default YAML file path
 if 'config' not in st.session_state:
 st.session_state.config, st.session_state.dir_home = build_VV_config(loaded_cfg=None)
 setup_streamlit_config(st.session_state.dir_home)

 if 'uploader_idk' not in st.session_state:
 st.session_state['uploader_idk'] = 1
 if 'input_list_small' not in st.session_state:
@@ -60,11 +219,12 @@ if 'dir_uploaded_images_small' not in st.session_state:
 st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
 validate_dir(os.path.join(st.session_state.dir_home,'uploads_small'))

- MAX_GALLERY_IMAGES = 20
- GALLERY_IMAGE_SIZE = 96

 def content_input_images(col_left, col_right):
 st.write('---')
 # col1, col2 = st.columns([2,8])
@@ -83,7 +243,7 @@ def content_input_images(col_left, col_right):
 if st.session_state.is_hf:
 st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
 st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
- uploaded_files = st.file_uploader("Upload Images", type=['jpg', 'jpeg'], accept_multiple_files=True, key=st.session_state['uploader_idk'])
 st.button("Use Test Image",help="This will clear any uploaded images and load the 1 provided test image.",on_click=use_test_image)

 with col_right:
@@ -92,27 +252,37 @@ def content_input_images(col_left, col_right):
 # Clear input image gallery and input list
 clear_image_gallery()

- # Process the new images
 for uploaded_file in uploaded_files:
- file_path = save_uploaded_file(st.session_state['dir_uploaded_images'], uploaded_file)
- st.session_state['input_list'].append(file_path)
- img = Image.open(file_path)
- img.thumbnail((GALLERY_IMAGE_SIZE, GALLERY_IMAGE_SIZE), Image.Resampling.LANCZOS)
- file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], uploaded_file, img)
- st.session_state['input_list_small'].append(file_path_small)
- print(uploaded_file.name)
- # Set the local images to the uploaded images
- st.session_state.config['leafmachine']['project']['dir_images_local'] = st.session_state['dir_uploaded_images']
- n_images = len([f for f in os.listdir(st.session_state.config['leafmachine']['project']['dir_images_local']) if os.path.isfile(os.path.join(st.session_state.config['leafmachine']['project']['dir_images_local'], f))])
- st.session_state['processing_add_on'] = n_images
- uploaded_files = None
- st.session_state['uploader_idk'] += 1
- st.info(f"Processing **{n_images}** images from {st.session_state.config['leafmachine']['project']['dir_images_local']}")

 if st.session_state['input_list_small']:
 if len(st.session_state['input_list_small']) > MAX_GALLERY_IMAGES:
@@ -150,7 +320,6 @@ def content_input_images(col_left, col_right):
 st.session_state['dir_images_local_TEMP'] = st.session_state.config['leafmachine']['project']['dir_images_local']
 print("rerun")
 st.rerun()

 def list_jpg_files(directory_path):
 jpg_count = 0
@@ -243,39 +412,14 @@ def use_test_image():
 st.session_state['input_list_small'].append(file_path_small)

- def create_download_button_yaml(file_path, selected_yaml_file, key_val):
- file_label = f"Download {selected_yaml_file}"
- with open(file_path, 'rb') as f:
- st.download_button(
- label=file_label,
- data=f,
- file_name=os.path.basename(file_path),
- mime='application/x-yaml',use_container_width=True,key=key_val,
- )

- def upload_local_prompt_to_server(dir_prompt):
- uploaded_file = st.file_uploader("Upload a custom prompt file", type=['yaml'])
- if uploaded_file is not None:
- # Check the file extension
- file_name = uploaded_file.name
- if file_name.endswith('.yaml'):
- file_path = os.path.join(dir_prompt, file_name)
- # Save the file
- with open(file_path, 'wb') as f:
- f.write(uploaded_file.getbuffer())
- st.success(f"Saved file {file_name} in {dir_prompt}")
- else:
- st.error("Please upload a .yaml file that you previously created using this Prompt Builder tool.")

 def refresh():
 st.session_state['uploader_idk'] += 1
 st.write('')

 # def display_image_gallery():
 # # Initialize the container
 # con_image = st.empty()
@@ -516,10 +660,7 @@ class JSONReport:

- def does_private_file_exist():
- dir_home = os.path.dirname(__file__)
- path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
- return os.path.exists(path_cfg_private)

@@ -971,534 +1112,14 @@ def save_changes_to_API_keys(cfg_private,openai_api_key,azure_openai_api_version
 # st.session_state.private_file = does_private_file_exist()

 # Function to load a YAML file and update session_state
- def load_prompt_yaml(filename):
- st.session_state['user_clicked_load_prompt_yaml'] = filename
- with open(filename, 'r') as file:
- st.session_state['prompt_info'] = yaml.safe_load(file)
- st.session_state['prompt_author'] = st.session_state['prompt_info'].get('prompt_author', st.session_state['default_prompt_author'])
- st.session_state['prompt_author_institution'] = st.session_state['prompt_info'].get('prompt_author_institution', st.session_state['default_prompt_author_institution'])
- st.session_state['prompt_name'] = st.session_state['prompt_info'].get('prompt_name', st.session_state['default_prompt_name'])
- st.session_state['prompt_version'] = st.session_state['prompt_info'].get('prompt_version', st.session_state['default_prompt_version'])
- st.session_state['prompt_description'] = st.session_state['prompt_info'].get('prompt_description', st.session_state['default_prompt_description'])
- st.session_state['instructions'] = st.session_state['prompt_info'].get('instructions', st.session_state['default_instructions'])
- st.session_state['json_formatting_instructions'] = st.session_state['prompt_info'].get('json_formatting_instructions', st.session_state['default_json_formatting_instructions'])
- st.session_state['rules'] = st.session_state['prompt_info'].get('rules', {})
- st.session_state['mapping'] = st.session_state['prompt_info'].get('mapping', {})
- st.session_state['LLM'] = st.session_state['prompt_info'].get('LLM', 'General Purpose')
- # Placeholder:
- st.session_state['assigned_columns'] = list(chain.from_iterable(st.session_state['mapping'].values()))
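
As context for the `assigned_columns` line above, a minimal sketch (with a hypothetical `mapping`) of how `chain.from_iterable` flattens the category-to-column mapping into one list:

from itertools import chain

# Hypothetical 'mapping' block, using categories defined later in this file
mapping = {
    'TAXONOMY': ['scientificName', 'genus'],
    'GEOGRAPHY': ['country', 'stateProvince'],
}
assigned_columns = list(chain.from_iterable(mapping.values()))
# -> ['scientificName', 'genus', 'country', 'stateProvince']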

 ### Updated to match HF version
 # def save_prompt_yaml(filename):
- def save_prompt_yaml(filename, col):
- yaml_content = {
- 'prompt_author': st.session_state['prompt_author'],
- 'prompt_author_institution': st.session_state['prompt_author_institution'],
- 'prompt_name': st.session_state['prompt_name'],
- 'prompt_version': st.session_state['prompt_version'],
- 'prompt_description': st.session_state['prompt_description'],
- 'LLM': st.session_state['LLM'],
- 'instructions': st.session_state['instructions'],
- 'json_formatting_instructions': st.session_state['json_formatting_instructions'],
- 'rules': st.session_state['rules'],
- 'mapping': st.session_state['mapping'],
- }
- dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
- filepath = os.path.join(dir_prompt, f"{filename}.yaml")
- with open(filepath, 'w') as file:
- yaml.safe_dump(dict(yaml_content), file, sort_keys=False)
- st.success(f"Prompt saved as '{filename}.yaml'.")
- with col: # added
- create_download_button_yaml(filepath, filename,key_val=2456237465) # added
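
For reference, a runnable sketch of the file that `save_prompt_yaml` produces; the keys mirror `yaml_content` above, while the values here are hypothetical defaults:

import yaml

# Hypothetical minimal prompt; the key set matches yaml_content above
prompt = {
    'prompt_author': 'unknown',
    'prompt_author_institution': 'unknown',
    'prompt_name': 'custom_prompt',
    'prompt_version': 'v-1-0',
    'prompt_description': 'unknown',
    'LLM': 'General Purpose',
    'instructions': '1. Refactor the unstructured OCR text into a dictionary...',
    'json_formatting_instructions': 'This section provides rules for formatting each JSON value organized by the JSON key.',
    'rules': {'catalogNumber': 'REPLACE WITH DESCRIPTION'},
    'mapping': {'TAXONOMY': ['scientificName'], 'GEOGRAPHY': ['country']},
}
print(yaml.safe_dump(prompt, sort_keys=False))  # same call the app uses to write the .yaml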

- def check_unique_mapping_assignments():
- print(st.session_state['assigned_columns'])
- if len(st.session_state['assigned_columns']) != len(set(st.session_state['assigned_columns'])):
- st.error("Each column name must be assigned to only one category.")
- return False
- elif not st.session_state['assigned_columns']:
- st.error("No columns have been mapped.")
- return False
- elif len(st.session_state['assigned_columns']) != len(st.session_state['rules'].keys()):
- incomplete = [item for item in list(st.session_state['rules'].keys()) if item not in st.session_state['assigned_columns']]
- st.warning(f"These columns have been mapped: {st.session_state['assigned_columns']}")
- st.error(f"However, these columns must be mapped before the prompt is complete: {incomplete}")
- return False
- else:
- st.success("Mapping confirmed.")
- return True

- def check_prompt_yaml_filename(fname):
- # Check if the filename only contains letters, numbers, underscores, and dashes
- pattern = r'^[\w-]+$'
- # The \w matches any alphanumeric character and is equivalent to the character class [a-zA-Z0-9_].
- # The hyphen - is literally matched.
- if re.match(pattern, fname):
- return True
- else:
- return False
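
A quick illustration of the `^[\w-]+$` check above, with hypothetical filenames:

import re

def is_valid_prompt_filename(fname):  # same pattern as check_prompt_yaml_filename
    return bool(re.match(r'^[\w-]+$', fname))

print(is_valid_prompt_filename('SLTPvA_custom-v1'))  # True
print(is_valid_prompt_filename('my prompt.yaml'))    # False: space and '.' are rejected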

- def btn_load_prompt(selected_yaml_file, dir_prompt):
- if selected_yaml_file:
- yaml_file_path = os.path.join(dir_prompt, selected_yaml_file)
- load_prompt_yaml(yaml_file_path)
- elif not selected_yaml_file:
- # Directly assigning default values since no file is selected
- st.session_state['prompt_info'] = {}
- st.session_state['prompt_author'] = st.session_state['default_prompt_author']
- st.session_state['prompt_author_institution'] = st.session_state['default_prompt_author_institution']
- st.session_state['prompt_name'] = st.session_state['prompt_name']
- st.session_state['prompt_version'] = st.session_state['prompt_version']
- st.session_state['prompt_description'] = st.session_state['default_prompt_description']
- st.session_state['instructions'] = st.session_state['default_instructions']
- st.session_state['json_formatting_instructions'] = st.session_state['default_json_formatting_instructions']
- st.session_state['rules'] = {}
- st.session_state['LLM'] = 'General Purpose'
- st.session_state['assigned_columns'] = []
- st.session_state['prompt_info'] = {
- 'prompt_author': st.session_state['prompt_author'],
- 'prompt_author_institution': st.session_state['prompt_author_institution'],
- 'prompt_name': st.session_state['prompt_name'],
- 'prompt_version': st.session_state['prompt_version'],
- 'prompt_description': st.session_state['prompt_description'],
- 'instructions': st.session_state['instructions'],
- 'json_formatting_instructions': st.session_state['json_formatting_instructions'],
- 'rules': st.session_state['rules'],
- 'mapping': st.session_state['mapping'],
- 'LLM': st.session_state['LLM']
- }

- def build_LLM_prompt_config():
- col_main1, col_main2 = st.columns([10,2])
- with col_main1:
- st.session_state.logo_path = os.path.join(st.session_state.dir_home, 'img','logo.png')
- st.session_state.logo = Image.open(st.session_state.logo_path)
- st.image(st.session_state.logo, width=250)
- with col_main2:
- if st.button('Exit',key='exist button 2'):
- st.session_state.proceed_to_build_llm_prompt = False
- st.session_state.proceed_to_main = True
- st.rerun()

- st.session_state['assigned_columns'] = []
- st.session_state['default_prompt_author'] = 'unknown'
- st.session_state['default_prompt_author_institution'] = 'unknown'
- st.session_state['default_prompt_name'] = 'custom_prompt'
- st.session_state['default_prompt_version'] = 'v-1-0'
- st.session_state['default_prompt_author_institution'] = 'unknown'
- st.session_state['default_prompt_description'] = 'unknown'
- st.session_state['default_LLM'] = 'General Purpose'
- st.session_state['default_instructions'] = """1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
- 2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
- 3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
- 4. Duplicate dictionary fields are not allowed.
- 5. Ensure all JSON keys are in camel case.
- 6. Ensure new JSON field values follow sentence case capitalization.
- 7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
- 8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
- 9. Only return a JSON dictionary represented as a string. You should not explain your answer."""
- st.session_state['default_json_formatting_instructions'] = """This section provides rules for formatting each JSON value organized by the JSON key."""

- # Start building the Streamlit app
- col_prompt_main_left, ___, col_prompt_main_right = st.columns([6,1,3])

- with col_prompt_main_left:
- st.title("Custom LLM Prompt Builder")
- st.subheader('About')
- st.write("This form allows you to craft a prompt for your specific task. You can also edit the JSON yaml files directly, but please try loading the prompt back into this form to ensure that the formatting is correct. If this form cannot load your manually edited JSON yaml file, then it will not work in VoucherVision.")
- st.subheader(':rainbow[How it Works]')
- st.write("1. Edit this page until you are happy with your instructions. We recommend looking at the basic structure, writing down your prompt information in a Word document so that it does not randomly disappear, and then copying and pasting that info into this form once your whole prompt structure is defined.")
- st.write("2. After you enter all of your prompt instructions, click 'Save' and give your file a name.")
- st.write("3. This file will be saved as a yaml configuration file in the `..VoucherVision/custom_prompts` folder.")
- st.write("4. When you go back to the main VoucherVision page you will now see your custom prompt available in the 'Prompt Version' dropdown menu.")
- st.write("5. The LLM ***only*** sees information from the 'instructions', 'rules', and 'json_formatting_instructions' sections. All other information is for versioning and integration with VoucherVisionEditor.")

- st.write("---")
- st.header('Load an Existing Prompt Template')
- st.write("By default, this form loads the minimum required transcription fields but does not provide rules for each field. You can also load an existing prompt as a template, editing or deleting values as needed.")

- dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
- yaml_files = [f for f in os.listdir(dir_prompt) if f.endswith('.yaml')]
- col_load_text, col_load_btn, col_load_btn2 = st.columns([8,2,2])
- with col_load_text:
- # Dropdown for selecting a YAML file
- st.session_state['selected_yaml_file'] = st.selectbox('Select a prompt .YAML file to load:', [''] + yaml_files)
- with col_load_btn:
- st.write('##')
- # Button to load the selected prompt
- st.button('Load Prompt', on_click=btn_load_prompt, args=[st.session_state['selected_yaml_file'], dir_prompt],use_container_width=True)

- with col_load_btn2:
- if st.session_state['selected_yaml_file']:
- # Construct the full path to the file
- download_file_path = os.path.join(dir_prompt, st.session_state['selected_yaml_file'])
- # Create the download button
- st.write('##')
- create_download_button_yaml(download_file_path, st.session_state['selected_yaml_file'],key_val=345798)

- # Prompt Author Information
- st.write("---")
- st.header("Prompt Author Information")
- st.write("We value community contributions! Please provide your name(s) (or pseudonym if you prefer) for credit. If you leave this field blank, it will say 'unknown'.")
- if 'prompt_author' not in st.session_state:# != st.session_state['default_prompt_author']:
- st.session_state['prompt_author'] = st.text_input("Enter names of prompt author(s)", value=st.session_state['default_prompt_author'],key=1111)
- else:
- st.session_state['prompt_author'] = st.text_input("Enter names of prompt author(s)", value=st.session_state['prompt_author'],key=1112)

- # Institution
- st.write("Please provide your institution name. If you leave this field blank, it will say 'unknown'.")
- if 'prompt_author_institution' not in st.session_state:
- st.session_state['prompt_author_institution'] = st.text_input("Enter name of institution", value=st.session_state['default_prompt_author_institution'],key=1113)
- else:
- st.session_state['prompt_author_institution'] = st.text_input("Enter name of institution", value=st.session_state['prompt_author_institution'],key=1114)

- # Prompt name
- st.write("Please provide a simple name for your prompt. If you leave this field blank, it will say 'custom_prompt'.")
- if 'prompt_name' not in st.session_state:
- st.session_state['prompt_name'] = st.text_input("Enter prompt name", value=st.session_state['default_prompt_name'],key=1115)
- else:
- st.session_state['prompt_name'] = st.text_input("Enter prompt name", value=st.session_state['prompt_name'],key=1116)

- # Prompt version
- st.write("Please provide a version identifier for your prompt. If you leave this field blank, it will say 'v-1-0'.")
- if 'prompt_version' not in st.session_state:
- st.session_state['prompt_version'] = st.text_input("Enter prompt version", value=st.session_state['default_prompt_version'],key=1117)
- else:
- st.session_state['prompt_version'] = st.text_input("Enter prompt version", value=st.session_state['prompt_version'],key=1118)

- st.write("Please provide a description of your prompt and its intended task. Is it designed for a specific collection? Taxa? Database structure?")
- if 'prompt_description' not in st.session_state:
- st.session_state['prompt_description'] = st.text_input("Enter description of prompt", value=st.session_state['default_prompt_description'],key=1119)
- else:
- st.session_state['prompt_description'] = st.text_input("Enter description of prompt", value=st.session_state['prompt_description'],key=11111)

- st.write('---')
- st.header("Set LLM Model Type")
- # Define the options for the dropdown
- llm_options_general = ["General Purpose",
- "OpenAI GPT Models","Google PaLM2 Models","Google Gemini Models","MistralAI Models",]
- llm_options_all = ModelMaps.get_models_gui_list()

- if 'LLM' not in st.session_state:
- st.session_state['LLM'] = st.session_state['default_LLM']

- if st.session_state['LLM']:
- llm_options = llm_options_general + llm_options_all + [st.session_state['LLM']]
- else:
- llm_options = llm_options_general + llm_options_all
- # Create the dropdown and set the value to session_state['LLM']
- st.write("Which LLM is this prompt designed for? This will not restrict its use to a specific LLM, but some prompts will behave differently across models.")
- st.write("SLTPvA prompts have been validated with all supported LLMs, but performance may vary. If you design a prompt to work best with a specific model, then you can indicate the model here.")
- st.write("For general purpose prompts (like the SLTPvA prompts) just use the 'General Purpose' option.")
- st.session_state['LLM'] = st.selectbox('Set LLM', llm_options, index=llm_options.index(st.session_state.get('LLM', 'General Purpose')))
1208
- st.write('---')
1209
- # Instructions Section
1210
- st.header("Instructions")
1211
- st.write("These are the general instructions that guide the LLM through the transcription task. We recommend using the default instructions unless you have a specific reason to change them.")
1212
-
1213
- if 'instructions' not in st.session_state:
1214
- st.session_state['instructions'] = st.text_area("Enter guiding instructions", value=st.session_state['default_instructions'].strip(), height=350,key=111112)
1215
- else:
1216
- st.session_state['instructions'] = st.text_area("Enter guiding instructions", value=st.session_state['instructions'].strip(), height=350,key=111112)
1217
-
1218
-
1219
- st.write('---')
1220
-
1221
- # Column Instructions Section
1222
- st.header("JSON Formatting Instructions")
1223
- st.write("The following section tells the LLM how we want to structure the JSON dictionary. We do not recommend changing this section because it would likely result in unstable and inconsistent behavior.")
1224
- if 'json_formatting_instructions' not in st.session_state:
1225
- st.session_state['json_formatting_instructions'] = st.text_area("Enter general JSON guidelines", value=st.session_state['default_json_formatting_instructions'],key=111114)
1226
- else:
1227
- st.session_state['json_formatting_instructions'] = st.text_area("Enter general JSON guidelines", value=st.session_state['json_formatting_instructions'],key=111115)
1228
-
1229
-
1230
-
1231
-
1232
-
1233
-
1234
- st.write('---')
1235
- col_left, col_right = st.columns([6,4])
1236
-
1237
- null_value_rules = ''
1238
- c_name = "EXAMPLE_COLUMN_NAME"
1239
- c_value = "REPLACE WITH DESCRIPTION"
1240
-
1241
- with col_left:
1242
- st.subheader('Add/Edit Columns')
1243
- st.markdown("The pre-populated fields are REQUIRED for downstream validation steps. They must be in all prompts.")
1244
-
1245
- # Initialize rules in session state if not already present
1246
- if 'rules' not in st.session_state or not st.session_state['rules']:
1247
- for required_col in st.session_state['required_fields']:
1248
- st.session_state['rules'][required_col] = c_value
1249
-
1250
-
1251
-
1252
-
1253
- # Layout for adding a new column name
1254
- # col_text, col_textbtn = st.columns([8, 2])
1255
- # with col_text:
1256
- st.session_state['new_column_name'] = st.text_input("Enter a new column name:")
1257
- # with col_textbtn:
1258
- # st.write('##')
1259
- if st.button("Add New Column") and st.session_state['new_column_name']:
1260
- if st.session_state['new_column_name'] not in st.session_state['rules']:
1261
- st.session_state['rules'][st.session_state['new_column_name']] = c_value
1262
- st.success(f"New column '{st.session_state['new_column_name']}' added. Now you can edit its properties.")
1263
- st.session_state['new_column_name'] = ''
1264
- else:
1265
- st.error("Column name already exists. Please enter a unique column name.")
1266
- st.session_state['new_column_name'] = ''
1267
-

- # Get columns excluding the protected "catalogNumber"
- st.write('#')
- # required_columns = [col for col in st.session_state['rules'] if col not in st.session_state['required_fields']]
- editable_columns = [col for col in st.session_state['rules'] if col not in ["catalogNumber"]]
- removable_columns = [col for col in st.session_state['rules'] if col not in st.session_state['required_fields']]

- st.session_state['current_rule'] = st.selectbox("Select a column to edit:", [""] + editable_columns)
- # column_name = st.selectbox("Select a column to edit:", editable_columns)

- # if 'current_rule' not in st.session_state:
- # st.session_state['current_rule'] = current_rule

- # Form for input fields
- with st.form(key='rule_form'):
- # format_options = ["verbatim transcription", "spell check transcription", "boolean yes no", "boolean 1 0", "integer", "[list]", "yyyy-mm-dd"]
- # current_rule["format"] = st.selectbox("Format:", format_options, index=format_options.index(current_rule["format"]) if current_rule["format"] else 0)
- # current_rule["null_value"] = st.text_input("Null value:", value=current_rule["null_value"])
- if st.session_state['current_rule']:
- current_rule_description = st.text_area("Description of category:", value=st.session_state['rules'][st.session_state['current_rule']])
- else:
- current_rule_description = ''
- commit_button = st.form_submit_button("Commit Column")

- # default_rule = {
- # "format": format_options[0], # default format
- # "null_value": "", # default null value
- # "description": "", # default description
- # }
- # if st.session_state['current_rule'] != st.session_state['current_rule']:
- # # Column has changed. Update the session_state selected column.
- # st.session_state['current_rule'] = st.session_state['current_rule']
- # # Reset the current rule to the default for this new column, or a blank rule if not set.
- # current_rule = st.session_state['rules'][st.session_state['current_rule']].get(current_rule, c_value)

- # Handle commit action
- if commit_button and st.session_state['current_rule']:
- # Commit the rules to the session state.
- st.session_state['rules'][st.session_state['current_rule']] = current_rule_description
- st.success(f"Column '{st.session_state['current_rule']}' added/updated in rules.")

- # Force the form to reset by clearing the fields from the session state
- st.session_state.pop('current_rule', None) # Clear the selected column to force reset

- # st.session_state['rules'][column_name] = current_rule
- # st.success(f"Column '{column_name}' added/updated in rules.")

- # # Reset current_rule to default values for the next input
- # current_rule["format"] = default_rule["format"]
- # current_rule["null_value"] = default_rule["null_value"]
- # current_rule["description"] = default_rule["description"]

- # # To ensure that the form fields are reset, we can clear them from the session state
- # for key in current_rule.keys():
- # st.session_state[key] = default_rule[key]

- # Layout for removing an existing column
- # del_col, del_colbtn = st.columns([8, 2])
- # with del_col:
- delete_column_name = st.selectbox("Select a column to delete:", [""] + removable_columns)
- # with del_colbtn:
- # st.write('##')
- if st.button("Delete Column") and delete_column_name:
- del st.session_state['rules'][delete_column_name]
- st.success(f"Column '{delete_column_name}' removed from rules.")

- with col_right:
- # Display the current state of the JSON rules
- st.subheader('Formatted Columns')
- st.json(st.session_state['rules'])

- # st.subheader('All Prompt Info')
- # st.json(st.session_state['prompt_info'])

- st.write('---')

- col_left_mapping, col_right_mapping = st.columns([6,4])
- with col_left_mapping:
- st.header("Mapping")
- st.write("Assign each column name to a single category.")
- st.session_state['refresh_mapping'] = False

- # Dynamically create a list of all column names that can be assigned
- # This assumes that the column names are the keys in the dictionary under 'rules'
- all_column_names = list(st.session_state['rules'].keys())

- categories = ['TAXONOMY', 'GEOGRAPHY', 'LOCALITY', 'COLLECTING', 'MISC']
- if ('mapping' not in st.session_state) or (st.session_state['mapping'] == {}):
- st.session_state['mapping'] = {category: [] for category in categories}
- for category in categories:
- # Filter out the already assigned columns
- available_columns = [col for col in all_column_names if col not in st.session_state['assigned_columns'] or col in st.session_state['mapping'].get(category, [])]

- # Ensure the current mapping is a subset of the available options
- current_mapping = [col for col in st.session_state['mapping'].get(category, []) if col in available_columns]

- # Provide a safe default if the current mapping is empty or contains invalid options
- safe_default = current_mapping if all(col in available_columns for col in current_mapping) else []

- # Create a multi-select widget for the category with a safe default
- selected_columns = st.multiselect(
- f"Select columns for {category}:",
- available_columns,
- default=safe_default,
- key=f"mapping_{category}"
- )
- # Update the assigned_columns based on the selections
- for col in current_mapping:
- if col not in selected_columns and col in st.session_state['assigned_columns']:
- st.session_state['assigned_columns'].remove(col)
- st.session_state['refresh_mapping'] = True

- for col in selected_columns:
- if col not in st.session_state['assigned_columns']:
- st.session_state['assigned_columns'].append(col)
- st.session_state['refresh_mapping'] = True

- # Update the mapping in session state when there's a change
- st.session_state['mapping'][category] = selected_columns
- if st.session_state['refresh_mapping']:
- st.session_state['refresh_mapping'] = False

- # Button to confirm and save the mapping configuration
- if st.button('Confirm Mapping'):
- if check_unique_mapping_assignments():
- # Proceed with further actions since the mapping is confirmed and unique
- pass

- with col_right_mapping:
- # Display the current state of the JSON rules
- st.subheader('Formatted Column Maps')
- st.json(st.session_state['mapping'])

- col_left_save, col_right_save = st.columns([6,4])
- with col_left_save:
- # Input for new file name
- new_filename = st.text_input("Enter filename to save your prompt as a configuration YAML:",placeholder='my_prompt_name')
- # Button to save the new YAML file
- if st.button('Save YAML', type='primary'):
- if new_filename:
- if check_unique_mapping_assignments():
- if check_prompt_yaml_filename(new_filename):
- save_prompt_yaml(new_filename, col_left_save)
- else:
- st.error("File name can only contain letters, numbers, underscores, and dashes. Cannot contain spaces.")
- else:
- st.error("Mapping contains an error. Make sure that each column is assigned to only ***one*** category.")
- else:
- st.error("Please enter a filename.")

- if st.button('Exit'):
- st.session_state.proceed_to_build_llm_prompt = False
- st.session_state.proceed_to_main = True
- st.rerun()

- # st.write('---')
- # st.header("Save and Download Custom Prompt")
- # st.write('Once you click save, validation checks will verify the formatting and then a download button will appear so that you can ***save a local copy of your custom prompt.***')
- # col_left_save, col_right_save, _ = st.columns([2,2,8])
- # with col_left_save:
- # # Button to save the new YAML file
- # if st.button('Save YAML', type='primary',key=3450798):
- # if st.session_state['prompt_name']:
- # if check_unique_mapping_assignments():
- # if check_prompt_yaml_filename(st.session_state['prompt_name']):
- # save_prompt_yaml(st.session_state['prompt_name'], col_right_save)
- # else:
- # st.error("File name can only contain letters, numbers, underscores, and dashes. Cannot contain spaces.")
- # else:
- # st.error("Mapping contains an error. Make sure that each column is assigned to only ***one*** category.")
- # else:
- # st.error("Please enter a filename.")

- # with col_prompt_main_right:
- # st.subheader('All Prompt Components')
- # st.session_state['prompt_info'] = {
- # 'prompt_author': st.session_state['prompt_author'],
- # 'prompt_author_institution': st.session_state['prompt_author_institution'],
- # 'prompt_name': st.session_state['prompt_name'],
- # 'prompt_version': st.session_state['prompt_version'],
- # 'prompt_description': st.session_state['prompt_description'],
- # 'LLM': st.session_state['LLM'],
- # 'instructions': st.session_state['instructions'],
- # 'json_formatting_instructions': st.session_state['json_formatting_instructions'],
- # 'rules': st.session_state['rules'],
- # 'mapping': st.session_state['mapping'],
- # }
- # st.json(st.session_state['prompt_info'])
- with col_prompt_main_right:
- if st.session_state['user_clicked_load_prompt_yaml'] is None: # see if user has loaded a yaml to edit
- st.session_state['show_prompt_name_e'] = f"Prompt Status :arrow_forward: Building prompt from scratch"
- if st.session_state['prompt_name']:
- st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: {st.session_state['prompt_name']}.yaml"
- else:
- st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: [PLEASE SET NAME]"
- else:
- st.session_state['show_prompt_name_e'] = f"Prompt Status: Editing :arrow_forward: {st.session_state['selected_yaml_file']}"
- if st.session_state['prompt_name']:
- st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: {st.session_state['prompt_name']}.yaml"
- else:
- st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: [PLEASE SET NAME]"

- st.subheader(f'Full Prompt')
- st.write(st.session_state['show_prompt_name_e'])
- st.write(st.session_state['show_prompt_name_w'])
- st.write("---")
- st.session_state['prompt_info'] = {
- 'prompt_author': st.session_state['prompt_author'],
- 'prompt_author_institution': st.session_state['prompt_author_institution'],
- 'prompt_name': st.session_state['prompt_name'],
- 'prompt_version': st.session_state['prompt_version'],
- 'prompt_description': st.session_state['prompt_description'],
- 'LLM': st.session_state['LLM'],
- 'instructions': st.session_state['instructions'],
- 'json_formatting_instructions': st.session_state['json_formatting_instructions'],
- 'rules': st.session_state['rules'],
- 'mapping': st.session_state['mapping'],
- }
- st.json(st.session_state['prompt_info'])
 def show_header_welcome():
 st.session_state.logo_path = os.path.join(st.session_state.dir_home, 'img','logo.png')
 st.session_state.logo = Image.open(st.session_state.logo_path)
@@ -1676,7 +1297,7 @@ def content_header():
 with col_run_4:
 with st.expander("View Messages and Updates"):
 st.info("***Note:*** If you use VoucherVision frequently, you can change the default values that are auto-populated in the form below. In a text editor or IDE, edit the first few rows in the file `../VoucherVision/vouchervision/VoucherVision_Config_Builder.py`")

 col_test = st.container()

@@ -1686,13 +1307,6 @@ def content_header():
 col_json, col_json_WFO, col_json_GEO, col_json_map = st.columns([2, 2, 2, 2])

 with col_run_info_1:
- # Progress
- # Progress
- # st.subheader('Project')
- # bar = st.progress(0)
- # new_text = st.empty() # Placeholder for current step name
- # progress_report = ProgressReportVV(bar, new_text, n_images=10)
 # Progress
 overall_progress_bar = st.progress(0)
 text_overall = st.empty() # Placeholder for current step name
@@ -1700,23 +1314,14 @@ def content_header():
 batch_progress_bar = st.progress(0)
 text_batch = st.empty() # Placeholder for current step name
 progress_report = ProgressReport(overall_progress_bar, batch_progress_bar, text_overall, text_batch)
- # st.session_state['json_report'] = JSONReport(col_updates_1, col_json, col_json_WFO, col_json_GEO, col_json_map)
 st.session_state['hold_output'] = st.toggle('View Final Transcription')

 with col_logo:
 show_header_welcome()

 with col_run_1:
- # st.subheader('Run VoucherVision')
 N_STEPS = 6

- # if st.session_state.is_hf:
- # count_n_imgs = determine_n_images()
- # if count_n_imgs > 0:
- # st.session_state['processing_add_on'] = count_n_imgs
- # else:
- # st.session_state['processing_add_on'] = 0
 if check_if_usable(is_hf=st.session_state['is_hf']):
 b_text = f"Start Processing {st.session_state['processing_add_on']} Images" if st.session_state['processing_add_on'] > 1 else f"Start Processing {st.session_state['processing_add_on']} Image"
 if st.session_state['processing_add_on'] == 0:
@@ -1740,21 +1345,28 @@ def content_header():
 total_cost = 0.00
 n_failed_OCR = 0
 n_failed_LLM_calls = 0
- try:
- st.session_state['formatted_json'], st.session_state['formatted_json_WFO'], st.session_state['formatted_json_GEO'], total_cost, n_failed_OCR, n_failed_LLM_calls, st.session_state['zip_filepath'] = voucher_vision(None,
- st.session_state.dir_home,
- path_custom_prompts,
- None,
- progress_report,
- st.session_state['json_report'],
- path_api_cost=os.path.join(st.session_state.dir_home,'api_cost','api_cost.yaml'),
- is_hf = st.session_state['is_hf'],
- is_real_run=True)
- st.balloons()
- except Exception as e:
- with col_run_4:
- st.error(f"Transcription failed. Error: {e}")

 if n_failed_OCR > 0:
 with col_run_4:
@@ -1791,8 +1403,13 @@ def content_header():
 with ct_left:
 st.button("Refresh", on_click=refresh, use_container_width=True)
 with ct_right:
- if st.button('FAQs', use_container_width=True):
- pass

 # with col_run_2:
 # if st.button("Test GPT"):
@@ -1869,14 +1486,6 @@ def content_header():

 def content_project_settings(col):
 ### Project
 with col:
@@ -1966,9 +1575,10 @@ def content_prompt_and_llm_version():
 st.session_state.config['leafmachine']['project']['prompt_version'] = st.selectbox("Prompt Version", available_prompts, index=available_prompts.index(selected_version),label_visibility='collapsed')

 with col_prompt_2:
- if st.button("Build Custom LLM Prompt"):
- st.session_state.proceed_to_build_llm_prompt = True
- st.rerun()

 st.header('LLM Version')
 col_llm_1, col_llm_2 = st.columns([4,2])
@@ -2004,13 +1614,66 @@ def content_api_check():
 st.rerun()

- def content_collage_overlay():
 st.write("---")
- col_collage, col_overlay = st.columns([4,4])

 demo_text_h = f"Google_OCR_Handwriting:\nHERBARIUM OF MARCUS W. LYON , JR . Tracaulon sagittatum Indiana : Porter Co. incal Springs edge wet subdunal woods 1927 TX 11 Ilowers pink UNIVERSITE HERBARIUM MICH University of Michigan Herbarium 1439649 copyright reserved PERSICARIA FEB 2 6 1965 cm "
 demo_text_tr = f"trOCR:\nherbarium of marcus w. lyon jr. : : : tracaulon sagittatum indiana porter co. incal springs TX 11 Ilowers pink 1439649 copyright reserved D H U Q "
 demo_text_p = f"Google_OCR_Printed:\nTracaulon sagittatum Indiana : Porter Co. incal Springs edge wet subdunal woods 1927 Ilowers pink 1439649 copyright reserved PERSICARIA FEB 2 6 1965 cm "
@@ -2019,11 +1682,125 @@ def content_collage_overlay():
 demo_text_trh = demo_text_h + '\n' + demo_text_tr
 demo_text_trp = demo_text_p + '\n' + demo_text_tr

 with col_collage:
 st.header('LeafMachine2 Label Collage')
 default_crops = st.session_state.config['leafmachine']['cropped_components']['save_cropped_annotations']
 st.write("Prior to transcription, use LeafMachine2 to crop all labels from input images to create label collages for each specimen image. Showing just the text labels to the OCR algorithms significantly improves performance. This runs slowly on the free Hugging Face Space, but runs quickly with a fast CPU or any GPU.")
- st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox("Use LeafMachine2 label collage for transcriptions", st.session_state.config['leafmachine'].get('use_RGB_label_images', False))

 option_selected_crops = st.multiselect(label="Components to crop",
@@ -2040,76 +1817,14 @@ def content_collage_overlay():
 with st.expander(":frame_with_picture: View an example of the LeafMachine2 collage image"):
 st.image(st.session_state["demo_collage"], caption='LeafMachine2 Collage', output_format="PNG")
 # st.image(st.session_state["demo_collage"], caption='LeafMachine2 Collage', output_format="JPEG")

 with col_overlay:
 st.header('OCR Overlay Image')
- options = [":rainbow[Printed + Handwritten]", "Printed", "Use both models"]
- captions = [
- "Works well for both printed and handwritten text",
- "Works for printed text",
- "Adds both OCR versions to the LLM prompt"
- ]

 st.write('This will plot bounding boxes around all text that Google Vision was able to detect. If there are no boxes around text, then the OCR failed, so that missing text will not be seen by the LLM when it is creating the JSON object. The created image will be viewable in the VoucherVisionEditor.')

 do_create_OCR_helper_image = st.checkbox("Create image showing an overlay of the OCR detections",value=st.session_state.config['leafmachine']['do_create_OCR_helper_image'],disabled=True)
 st.session_state.config['leafmachine']['do_create_OCR_helper_image'] = do_create_OCR_helper_image

- # Get the current OCR option from session state
- OCR_option = st.session_state.config['leafmachine']['project']['OCR_option']
- # Map the OCR option to the index in options list
- # You need to define the mapping based on your application's logic
- option_to_index = {
- 'hand': 0,
- 'normal': 1,
- 'both': 2,
- }
- default_index = option_to_index.get(OCR_option, 0) # Default to 0 if option not found
- # Create the radio button
- OCR_option_select = st.radio(
- "Select the Google Vision OCR version.",
- options,
- index=default_index,
- help="",captions=captions,
- )
- st.session_state.config['leafmachine']['project']['OCR_option'] = OCR_option_select
- if OCR_option_select == ":rainbow[Printed + Handwritten]":
- OCR_option = 'hand'
- elif OCR_option_select == "Printed":
- OCR_option = 'normal'
- elif OCR_option_select == "Use both models":
- OCR_option = 'both'
- else:
- raise
- st.write("Supplement Google Vision OCR with trOCR (handwriting OCR) using `microsoft/trocr-base-handwritten`. This option requires Google Vision API and a GPU.")
- do_use_trOCR = st.checkbox("Enable trOCR", value=st.session_state.config['leafmachine']['project']['do_use_trOCR'])#,disabled=st.session_state['lacks_GPU'])
- st.session_state.config['leafmachine']['project']['do_use_trOCR'] = do_use_trOCR
- st.session_state.config['leafmachine']['project']['OCR_option'] = OCR_option
- st.markdown("Below is an example of what the LLM would see given the choice of OCR ensemble. One, two, or three versions of OCR can be fed into the LLM prompt. Typically, 'printed + handwritten' works well. If you have a GPU then you can enable trOCR.")
- if (OCR_option == 'hand') and not do_use_trOCR:
- st.text_area(label='Handwritten/Printed',placeholder=demo_text_h,disabled=True, label_visibility='visible', height=150)
- elif (OCR_option == 'normal') and not do_use_trOCR:
- st.text_area(label='Printed',placeholder=demo_text_p,disabled=True, label_visibility='visible', height=150)
- elif (OCR_option == 'both') and not do_use_trOCR:
- st.text_area(label='Handwritten/Printed + Printed',placeholder=demo_text_b,disabled=True, label_visibility='visible', height=150)
- elif (OCR_option == 'both') and do_use_trOCR:
- st.text_area(label='Handwritten/Printed + Printed + trOCR',placeholder=demo_text_trb,disabled=True, label_visibility='visible', height=150)
- elif (OCR_option == 'normal') and do_use_trOCR:
- st.text_area(label='Printed + trOCR',placeholder=demo_text_trp,disabled=True, label_visibility='visible', height=150)
- elif (OCR_option == 'hand') and do_use_trOCR:
- st.text_area(label='Handwritten/Printed + trOCR',placeholder=demo_text_trh,disabled=True, label_visibility='visible', height=150)

 if "demo_overlay" not in st.session_state:
 # ocr = os.path.join(st.session_state.dir_home,'demo', 'ba','ocr.png')
@@ -2159,6 +1874,8 @@ def content_processing_options():
 st.subheader('Compute Options')
 st.session_state.config['leafmachine']['project']['num_workers'] = st.number_input("Number of CPU workers", value=st.session_state.config['leafmachine']['project'].get('num_workers', 1), disabled=False)
 st.session_state.config['leafmachine']['project']['batch_size'] = st.number_input("Batch size", value=st.session_state.config['leafmachine']['project'].get('batch_size', 500), help='Sets the batch size for the LeafMachine2 cropping. If computer RAM is filled, lower this value to ~100.')

 with col_processing_2:
 st.subheader('Filename Prefix Handling')
 st.session_state.config['leafmachine']['project']['prefix_removal'] = st.text_input("Remove prefix from catalog number", st.session_state.config['leafmachine']['project'].get('prefix_removal', ''),placeholder="e.g. MICH-V-")
@@ -2167,18 +1884,21 @@ def content_processing_options():

 ### Logging and Image Validation - col_v1
 st.write("---")
- st.header('Logging and Image Validation')
 col_v1, col_v2 = st.columns(2)

 with col_v1:
 option_check_illegal = st.checkbox("Check for illegal filenames", value=st.session_state.config['leafmachine']['do']['check_for_illegal_filenames'])
 st.session_state.config['leafmachine']['do']['check_for_illegal_filenames'] = option_check_illegal
 st.session_state.config['leafmachine']['do']['check_for_corrupt_images_make_vertical'] = st.checkbox("Check for corrupt images", st.session_state.config['leafmachine']['do'].get('check_for_corrupt_images_make_vertical', True),disabled=True)

 st.session_state.config['leafmachine']['print']['verbose'] = st.checkbox("Print verbose", st.session_state.config['leafmachine']['print'].get('verbose', True))
 st.session_state.config['leafmachine']['print']['optional_warnings'] = st.checkbox("Show optional warnings", st.session_state.config['leafmachine']['print'].get('optional_warnings', True))
- with col_v2:
 log_level = st.session_state.config['leafmachine']['logging'].get('log_level', None)
 log_level_display = log_level if log_level is not None else 'default'
 selected_log_level = st.selectbox("Logging Level", ['default', 'DEBUG', 'INFO', 'WARNING', 'ERROR'], index=['default', 'DEBUG', 'INFO', 'WARNING', 'ERROR'].index(log_level_display))
@@ -2188,6 +1908,28 @@ def content_processing_options():
 else:
 st.session_state.config['leafmachine']['logging']['log_level'] = selected_log_level

 def content_tab_domain():
@@ -2254,7 +1996,9 @@ def render_expense_report_summary():
 expense_report = st.session_state.expense_report
 st.header('Expense Report Summary')

- if expense_summary:
 st.metric(label="Total Cost", value=f"${round(expense_summary['total_cost_sum'], 4):,}")
 col1, col2 = st.columns(2)

@@ -2348,19 +2092,21 @@ def render_expense_report_summary():
 pie_chart.update_traces(marker=dict(colors=colors),)
 st.plotly_chart(pie_chart, use_container_width=True)

- else:
- st.error('No expense report data available.')

 def content_less_used():
 st.write('---')
 st.write(':octagonal_sign: ***NOTE:*** Settings below are not relevant for most projects. Some settings below may not be reflected in saved settings files and would need to be set each time.')

 #################################################################################################################################################
 # Sidebar #######################################################################################################################################
 #################################################################################################################################################
 def sidebar_content():
 if not os.path.exists(os.path.join(st.session_state.dir_home,'expense_report')):
 validate_dir(os.path.join(st.session_state.dir_home,'expense_report'))
 expense_report_path = os.path.join(st.session_state.dir_home, 'expense_report', 'expense_report.csv')
@@ -2377,7 +2123,6 @@ def sidebar_content():
 st.write('Available after first run...')

 #################################################################################################################################################
 # Routing Function ##############################################################################################################################
 #################################################################################################################################################
@@ -2387,28 +2132,20 @@ def main():
 sidebar_content()
 # Main App
 content_header()

 col_input, col_gallery = st.columns([4,8])
 content_project_settings(col_input)
 content_input_images(col_input, col_gallery)

- # if st.session_state['is_hf']:
- # content_project_settings()
- # content_input_images_hf()
- # else:
- # col1, col2 = st.columns([1,1])
- # with col1:
- # content_project_settings()
- # with col2:
- # content_input_images()

 col3, col4 = st.columns([1,1])
 with col3:
 content_prompt_and_llm_version()
 with col4:
 content_api_check()
 content_collage_overlay()
 content_llm_cost()
 content_processing_options()
@@ -2418,155 +2155,20 @@ def main():
 content_space_saver()

- #################################################################################################################################################
- # Initializations ###############################################################################################################################
- #################################################################################################################################################

- if st.session_state['is_hf']:
- if 'proceed_to_main' not in st.session_state:
- st.session_state.proceed_to_main = True
- if 'proceed_to_private' not in st.session_state:
- st.session_state.proceed_to_private = False
- if 'private_file' not in st.session_state:
- st.session_state.private_file = True
- else:
- if 'proceed_to_main' not in st.session_state:
- st.session_state.proceed_to_main = False # New state variable to control the flow
- if 'private_file' not in st.session_state:
- st.session_state.private_file = does_private_file_exist()
- if st.session_state.private_file:
- st.session_state.proceed_to_main = True
- if 'proceed_to_private' not in st.session_state:
- st.session_state.proceed_to_private = False # New state variable to control the flow

- if 'proceed_to_build_llm_prompt' not in st.session_state:
- st.session_state.proceed_to_build_llm_prompt = False # New state variable to control the flow

- if 'processing_add_on' not in st.session_state:
- st.session_state['processing_add_on'] = 0

- if 'formatted_json' not in st.session_state:
- st.session_state['formatted_json'] = None
- if 'formatted_json_WFO' not in st.session_state:
- st.session_state['formatted_json_WFO'] = None
- if 'formatted_json_GEO' not in st.session_state:
- st.session_state['formatted_json_GEO'] = None

- if 'lacks_GPU' not in st.session_state:
- st.session_state['lacks_GPU'] = not torch.cuda.is_available()

- if 'API_key_validation' not in st.session_state:
- st.session_state['API_key_validation'] = False
- if 'present_annotations' not in st.session_state:
- st.session_state['present_annotations'] = None
- if 'missing_annotations' not in st.session_state:
- st.session_state['missing_annotations'] = None
- if 'date_of_check' not in st.session_state:
- st.session_state['date_of_check'] = None
- if 'API_checked' not in st.session_state:
- st.session_state['API_checked'] = False
- if 'API_rechecked' not in st.session_state:
- st.session_state['API_rechecked'] = False

- if 'json_report' not in st.session_state:
- st.session_state['json_report'] = False
- if 'hold_output' not in st.session_state:
- st.session_state['hold_output'] = False

- if 'cost_openai' not in st.session_state:
- st.session_state['cost_openai'] = None
- if 'cost_azure' not in st.session_state:
- st.session_state['cost_azure'] = None
- if 'cost_google' not in st.session_state:
- st.session_state['cost_google'] = None
- if 'cost_mistral' not in st.session_state:
- st.session_state['cost_mistral'] = None
- if 'cost_local' not in st.session_state:
- st.session_state['cost_local'] = None

- if 'settings_filename' not in st.session_state:
- st.session_state['settings_filename'] = None
- if 'loaded_settings_filename' not in st.session_state:
- st.session_state['loaded_settings_filename'] = None
- if 'zip_filepath' not in st.session_state:
- st.session_state['zip_filepath'] = None

- # Initialize session_state variables if they don't exist
- if 'prompt_info' not in st.session_state:
- st.session_state['prompt_info'] = {}
- if 'rules' not in st.session_state:
- st.session_state['rules'] = {}

- # These are the fields that are in SLTPvA that are not required by another parsing validation function:
- # "identifiedBy": "M.W. Lyon, Jr.",
- # "recordedBy": "University of Michigan Herbarium",
- # "recordNumber": "",
- # "habitat": "wet subdunal woods",
- # "occurrenceRemarks": "Indiana : Porter Co.",
- # "degreeOfEstablishment": "",
- # "minimumElevationInMeters": "",
- # "maximumElevationInMeters": ""
- if 'required_fields' not in st.session_state:
- st.session_state['required_fields'] = ['catalogNumber','order','family','scientificName',
- 'scientificNameAuthorship','genus','subgenus','specificEpithet','infraspecificEpithet',
- 'verbatimEventDate','eventDate',
- 'country','stateProvince','county','municipality','locality','decimalLatitude','decimalLongitude','verbatimCoordinates',]
2543
-
2544
-
2545
- if 'proceed_to_build_llm_prompt' not in st.session_state:
2546
- st.session_state.proceed_to_build_llm_prompt = False
2547
- if 'proceed_to_component_detector' not in st.session_state:
2548
- st.session_state.proceed_to_component_detector = False
2549
- if 'proceed_to_parsing_options' not in st.session_state:
2550
- st.session_state.proceed_to_parsing_options = False
2551
- if 'proceed_to_api_keys' not in st.session_state:
2552
- st.session_state.proceed_to_api_keys = False
2553
- if 'proceed_to_space_saver' not in st.session_state:
2554
- st.session_state.proceed_to_space_saver = False
2555
-
2556
-
2557
 #################################################################################################################################################
 # Main ##########################################################################################################################################
 #################################################################################################################################################
 if st.session_state['is_hf']:
-    if st.session_state.proceed_to_build_llm_prompt:
-        build_LLM_prompt_config()
-    elif st.session_state.proceed_to_main:
         main()
 
 else:
     if not st.session_state.private_file:
         create_private_file()
-    elif st.session_state.proceed_to_build_llm_prompt:
-        build_LLM_prompt_config()
     elif st.session_state.proceed_to_private and not st.session_state['is_hf']:
         create_private_file()
     elif st.session_state.proceed_to_main:
 import yaml, os, json, random, time, re, torch, random, warnings, shutil, sys
 import seaborn as sns
 import plotly.graph_objs as go
 from PIL import Image
 import pandas as pd
 from io import BytesIO
 from vouchervision.general_utils import test_GPU, get_cfg_from_full_path, summarize_expense_report, validate_dir
 from vouchervision.model_maps import ModelMaps
 from vouchervision.API_validation import APIvalidation
+from vouchervision.utils_hf import setup_streamlit_config, save_uploaded_file, save_uploaded_local
+from vouchervision.data_project import convert_pdf_to_jpg
+from vouchervision.utils_LLM import check_system_gpus
 
 
 #################################################################################################################################################
 # Initializations ###############################################################################################################################
 #################################################################################################################################################
+st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VoucherVision', initial_sidebar_state="collapsed")
 
 # Parse the 'is_hf' argument and set it in session state
 if 'is_hf' not in st.session_state:
+    try:
+        # Treat '1', 'true', or 'True' as enabling Hugging Face mode; any other value means local mode
+        is_hf_os = os.getenv('IS_HF')
+        if is_hf_os is not None and is_hf_os.strip().lower() in ('1', 'true'):
+            st.session_state['is_hf'] = True
+        else:
+            st.session_state['is_hf'] = False
+    except:
+        st.session_state['is_hf'] = False
+print(f"is_hf {st.session_state['is_hf']}")
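The flag can be exercised locally before launching the app; a small illustrative check (the variable name `IS_HF` comes from the code above, and the accepted values follow the corrected condition):

    import os

    os.environ['IS_HF'] = 'true'   # simulate the Hugging Face Space environment
    is_hf_os = os.getenv('IS_HF')
    assert is_hf_os is not None and is_hf_os.strip().lower() in ('1', 'true')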
 
 # Default YAML file path
 if 'config' not in st.session_state:
     st.session_state.config, st.session_state.dir_home = build_VV_config(loaded_cfg=None)
     setup_streamlit_config(st.session_state.dir_home)
 
+########################################################################################################
+### Global constants ####
+########################################################################################################
+MAX_GALLERY_IMAGES = 20
+GALLERY_IMAGE_SIZE = 96
+
+
+########################################################################################################
+### Init funcs ####
+########################################################################################################
+def does_private_file_exist():
+    dir_home = os.path.dirname(__file__)
+    path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
+    return os.path.exists(path_cfg_private)
+
+
+########################################################################################################
+### Streamlit inits [FOR SAVE FILE] ####
+########################################################################################################
+
+
+########################################################################################################
+### Streamlit inits [routing] ####
+########################################################################################################
+if st.session_state['is_hf']:
+    if 'proceed_to_main' not in st.session_state:
+        st.session_state.proceed_to_main = True
+
+    if 'proceed_to_private' not in st.session_state:
+        st.session_state.proceed_to_private = False
+
+    if 'private_file' not in st.session_state:
+        st.session_state.private_file = True
+else:
+    if 'proceed_to_main' not in st.session_state:
+        st.session_state.proceed_to_main = False  # New state variable to control the flow
+
+    if 'private_file' not in st.session_state:
+        st.session_state.private_file = does_private_file_exist()
+        if st.session_state.private_file:
+            st.session_state.proceed_to_main = True
+
+    if 'proceed_to_private' not in st.session_state:
+        st.session_state.proceed_to_private = False  # New state variable to control the flow
+
+
+if 'proceed_to_build_llm_prompt' not in st.session_state:
+    st.session_state.proceed_to_build_llm_prompt = False  # New state variable to control the flow
+if 'proceed_to_component_detector' not in st.session_state:
+    st.session_state.proceed_to_component_detector = False
+if 'proceed_to_parsing_options' not in st.session_state:
+    st.session_state.proceed_to_parsing_options = False
+if 'proceed_to_api_keys' not in st.session_state:
+    st.session_state.proceed_to_api_keys = False
+if 'proceed_to_space_saver' not in st.session_state:
+    st.session_state.proceed_to_space_saver = False
+if 'proceed_to_faqs' not in st.session_state:
+    st.session_state.proceed_to_faqs = False
+
+
+########################################################################################################
+### Streamlit inits [basics] ####
+########################################################################################################
+if 'processing_add_on' not in st.session_state:
+    st.session_state['processing_add_on'] = 0
+
+
+if 'capability_score' not in st.session_state:
+    st.session_state['num_gpus'], st.session_state['gpu_dict'], st.session_state['total_vram_gb'], st.session_state['capability_score'] = check_system_gpus()
+
+
+if 'formatted_json' not in st.session_state:
+    st.session_state['formatted_json'] = None
+if 'formatted_json_WFO' not in st.session_state:
+    st.session_state['formatted_json_WFO'] = None
+if 'formatted_json_GEO' not in st.session_state:
+    st.session_state['formatted_json_GEO'] = None
+
+
+if 'lacks_GPU' not in st.session_state:
+    st.session_state['lacks_GPU'] = not torch.cuda.is_available()
+
+
+if 'API_key_validation' not in st.session_state:
+    st.session_state['API_key_validation'] = False
+if 'API_checked' not in st.session_state:
+    st.session_state['API_checked'] = False
+if 'API_rechecked' not in st.session_state:
+    st.session_state['API_rechecked'] = False
+
+
+if 'present_annotations' not in st.session_state:
+    st.session_state['present_annotations'] = None
+if 'missing_annotations' not in st.session_state:
+    st.session_state['missing_annotations'] = None
+if 'date_of_check' not in st.session_state:
+    st.session_state['date_of_check'] = None
+
+
+if 'json_report' not in st.session_state:
+    st.session_state['json_report'] = False
+if 'hold_output' not in st.session_state:
+    st.session_state['hold_output'] = False
+
+
+if 'cost_openai' not in st.session_state:
+    st.session_state['cost_openai'] = None
+if 'cost_azure' not in st.session_state:
+    st.session_state['cost_azure'] = None
+if 'cost_google' not in st.session_state:
+    st.session_state['cost_google'] = None
+if 'cost_mistral' not in st.session_state:
+    st.session_state['cost_mistral'] = None
+if 'cost_local' not in st.session_state:
+    st.session_state['cost_local'] = None
+
+
+if 'settings_filename' not in st.session_state:
+    st.session_state['settings_filename'] = None
+if 'loaded_settings_filename' not in st.session_state:
+    st.session_state['loaded_settings_filename'] = None
+if 'zip_filepath' not in st.session_state:
+    st.session_state['zip_filepath'] = None
+
+
+########################################################################################################
+### Streamlit inits [prompt builder] ####
+########################################################################################################
+# These are the fields that are in SLTPvA that are not required by another parsing validation function:
+#    "identifiedBy": "M.W. Lyon, Jr.",
+#    "recordedBy": "University of Michigan Herbarium",
+#    "recordNumber": "",
+#    "habitat": "wet subdunal woods",
+#    "occurrenceRemarks": "Indiana : Porter Co.",
+#    "degreeOfEstablishment": "",
+#    "minimumElevationInMeters": "",
+#    "maximumElevationInMeters": ""
+if 'required_fields' not in st.session_state:
+    st.session_state['required_fields'] = ['catalogNumber','order','family','scientificName',
+                                           'scientificNameAuthorship','genus','subgenus','specificEpithet','infraspecificEpithet',
+                                           'verbatimEventDate','eventDate',
+                                           'country','stateProvince','county','municipality','locality','decimalLatitude','decimalLongitude','verbatimCoordinates',]
+if 'prompt_info' not in st.session_state:
+    st.session_state['prompt_info'] = {}
+if 'rules' not in st.session_state:
+    st.session_state['rules'] = {}
+
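A hedged sketch of how a downstream parser might use `required_fields` to guarantee every key exists in a transcribed record; the helper name is hypothetical, not part of VoucherVision:

    def ensure_required_fields(record, required_fields):
        # Fill any missing required key with an empty string, as the prompt rules permit
        for key in required_fields:
            record.setdefault(key, "")
        return record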
+
+########################################################################################################
+### Streamlit inits [gallery] ####
+########################################################################################################
 if 'uploader_idk' not in st.session_state:
     st.session_state['uploader_idk'] = 1
 if 'input_list_small' not in st.session_state:
 
     st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
     validate_dir(os.path.join(st.session_state.dir_home,'uploads_small'))
 
 
+########################################################################################################
+### CONTENT [] ####
+########################################################################################################
 def content_input_images(col_left, col_right):
     st.write('---')
     # col1, col2 = st.columns([2,8])
 
     if st.session_state.is_hf:
         st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
         st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
+        uploaded_files = st.file_uploader("Upload Images", type=['jpg', 'jpeg', 'pdf'], accept_multiple_files=True, key=st.session_state['uploader_idk'])
         st.button("Use Test Image", help="This will clear any uploaded images and load the 1 provided test image.", on_click=use_test_image)
 
     with col_right:
 
         # Clear input image gallery and input list
         clear_image_gallery()
 
         for uploaded_file in uploaded_files:
+            # Determine the file type
+            if uploaded_file.name.lower().endswith('.pdf'):
+                # Handle PDF files
+                file_path = save_uploaded_file(st.session_state['dir_uploaded_images'], uploaded_file)
+                # Convert each page of the PDF to an image
+                n_pages = convert_pdf_to_jpg(file_path, st.session_state['dir_uploaded_images'], dpi=st.session_state.config['leafmachine']['project']['pdf_conversion_dpi'])
+                # Update the input list for each page image
+                converted_files = os.listdir(st.session_state['dir_uploaded_images'])
+
+                for file_name in converted_files:
+                    if file_name.lower().endswith('.jpg'):
+                        jpg_file_path = os.path.join(st.session_state['dir_uploaded_images'], file_name)
+                        st.session_state['input_list'].append(jpg_file_path)
+
+                        # Optionally, create a thumbnail for the gallery
+                        img = Image.open(jpg_file_path)
+                        img.thumbnail((GALLERY_IMAGE_SIZE, GALLERY_IMAGE_SIZE), Image.Resampling.LANCZOS)
+                        file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], uploaded_file, img)
+                        st.session_state['input_list_small'].append(file_path_small)
+            else:
+                # Handle JPG/JPEG files (existing process)
+                file_path = save_uploaded_file(st.session_state['dir_uploaded_images'], uploaded_file)
+                st.session_state['input_list'].append(file_path)
+                img = Image.open(file_path)
+                img.thumbnail((GALLERY_IMAGE_SIZE, GALLERY_IMAGE_SIZE), Image.Resampling.LANCZOS)
+                file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], uploaded_file, img)
+                st.session_state['input_list_small'].append(file_path_small)
+
+        # After processing all files
+        st.info(f"Processing images from {st.session_state.config['leafmachine']['project']['dir_images_local']}")
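`convert_pdf_to_jpg` is imported from `vouchervision.data_project`; a minimal sketch of what such a conversion can look like with PyMuPDF (which `install_dependencies.sh` installs). The function name and return value here are illustrative assumptions, not the actual implementation:

    import os
    import fitz  # PyMuPDF

    def pdf_pages_to_jpgs(pdf_path, dir_out, dpi=100):
        # Render each page at the requested DPI (PDF user space is 72 points per inch) and save it as a JPG
        doc = fitz.open(pdf_path)
        zoom = dpi / 72
        stem = os.path.splitext(os.path.basename(pdf_path))[0]
        for page in doc:
            pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
            pix.save(os.path.join(dir_out, f"{stem}_page_{page.number + 1}.jpg"))
        return doc.page_count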
 
         if st.session_state['input_list_small']:
             if len(st.session_state['input_list_small']) > MAX_GALLERY_IMAGES:
 
     st.session_state['dir_images_local_TEMP'] = st.session_state.config['leafmachine']['project']['dir_images_local']
     print("rerun")
     st.rerun()
 
 
 def list_jpg_files(directory_path):
     jpg_count = 0
 
     st.session_state['input_list_small'].append(file_path_small)
 
 def refresh():
     st.session_state['uploader_idk'] += 1
     st.write('')
+
+
 # def display_image_gallery():
 #     # Initialize the container
 #     con_image = st.empty()
 
+
 # st.session_state.private_file = does_private_file_exist()
 
 # Function to load a YAML file and update session_state
+
 
 ### Updated to match HF version
 # def save_prompt_yaml(filename):
 
 def show_header_welcome():
     st.session_state.logo_path = os.path.join(st.session_state.dir_home, 'img','logo.png')
     st.session_state.logo = Image.open(st.session_state.logo_path)
 
     with col_run_4:
         with st.expander("View Messages and Updates"):
             st.info("***Note:*** If you use VoucherVision frequently, you can change the default values that are auto-populated in the form below. In a text editor or IDE, edit the first few rows in the file `../VoucherVision/vouchervision/VoucherVision_Config_Builder.py`")
+            st.info("Please enable the LeafMachine2 collage for full-sized images of herbarium vouchers; you will get better results!")
 
     col_test = st.container()
 
     col_json, col_json_WFO, col_json_GEO, col_json_map = st.columns([2, 2, 2, 2])
 
     with col_run_info_1:
         # Progress
         overall_progress_bar = st.progress(0)
         text_overall = st.empty()  # Placeholder for current step name
         batch_progress_bar = st.progress(0)
         text_batch = st.empty()  # Placeholder for current step name
         progress_report = ProgressReport(overall_progress_bar, batch_progress_bar, text_overall, text_batch)
         st.session_state['hold_output'] = st.toggle('View Final Transcription')
 
     with col_logo:
         show_header_welcome()
 
     with col_run_1:
         N_STEPS = 6
 
         if check_if_usable(is_hf=st.session_state['is_hf']):
             b_text = f"Start Processing {st.session_state['processing_add_on']} Images" if st.session_state['processing_add_on'] > 1 else f"Start Processing {st.session_state['processing_add_on']} Image"
             if st.session_state['processing_add_on'] == 0:
                     total_cost = 0.00
                     n_failed_OCR = 0
                     n_failed_LLM_calls = 0
+                    # try:
+                    voucher_vision_output = voucher_vision(None,
+                                                           st.session_state.dir_home,
+                                                           path_custom_prompts,
+                                                           None,
+                                                           progress_report,
+                                                           st.session_state['json_report'],
+                                                           path_api_cost=os.path.join(st.session_state.dir_home,'api_cost','api_cost.yaml'),
+                                                           is_hf=st.session_state['is_hf'],
+                                                           is_real_run=True)
+                    st.session_state['formatted_json'] = voucher_vision_output['last_JSON_response']
+                    st.session_state['formatted_json_WFO'] = voucher_vision_output['final_WFO_record']
+                    st.session_state['formatted_json_GEO'] = voucher_vision_output['final_GEO_record']
+                    total_cost = voucher_vision_output['total_cost']
+                    n_failed_OCR = voucher_vision_output['n_failed_OCR']
+                    n_failed_LLM_calls = voucher_vision_output['n_failed_LLM_calls']
+                    st.session_state['zip_filepath'] = voucher_vision_output['zip_filepath']
+                    # st.balloons()
+
+                    # except Exception as e:
+                    #     with col_run_4:
+                    #         st.error(f"Transcription failed. Error: {e}")
 
                     if n_failed_OCR > 0:
                         with col_run_4:
 
             with ct_left:
                 st.button("Refresh", on_click=refresh, use_container_width=True)
             with ct_right:
+                # st.page_link(os.path.join(os.path.dirname(__file__),"pages","faqs.py"), label="FAQs", icon="❔")
+                st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
+
+                # if st.button('FAQs', use_container_width=True):
+                #     st.session_state.proceed_to_faqs = True
+                #     st.session_state.proceed_to_main = False
+                #     st.rerun()
 
         # with col_run_2:
         #     if st.button("Test GPT"):
 
 def content_project_settings(col):
     ### Project
     with col:
 
         st.session_state.config['leafmachine']['project']['prompt_version'] = st.selectbox("Prompt Version", available_prompts, index=available_prompts.index(selected_version), label_visibility='collapsed')
 
     with col_prompt_2:
+        # if st.button("Build Custom LLM Prompt"):
+        #     st.page_link(os.path.join(os.path.dirname(__file__),"pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
+        st.page_link(os.path.join("pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
 
     st.header('LLM Version')
     col_llm_1, col_llm_2 = st.columns([4,2])
 
     st.rerun()
 
 
+def adjust_ocr_options_based_on_capability(capability_score):
+    llava_models_requirements = {
+        "liuhaotian/llava-v1.6-mistral-7b": {"full": 18, "4bit": 9},
+        "liuhaotian/llava-v1.6-34b": {"full": 70, "4bit": 25},
+        "liuhaotian/llava-v1.6-vicuna-13b": {"full": 33, "4bit": 15},
+        "liuhaotian/llava-v1.6-vicuna-7b": {"full": 20, "4bit": 10},
+    }
+    if capability_score == 'no_gpu':
+        return False
+    else:
+        capability_score_n = int(capability_score.split("_")[1].split("GB")[0])
+        supported_models = [model for model, reqs in llava_models_requirements.items()
+                            if reqs["full"] <= capability_score_n or reqs["4bit"] <= capability_score_n]
+
+        # If no models are supported, disable the LLaVA option
+        if not supported_models:
+            # Assuming the LLaVA option is the last in your list
+            return False  # Indicate LLaVA is not supported
+        return True  # Indicate LLaVA is supported
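Illustrative calls, assuming the capability score string produced by `check_system_gpus()` embeds the VRAM like `'class_16GB'` (only the split on `'_'` and `'GB'` is certain from the code above; the prefix is a guess):

    adjust_ocr_options_based_on_capability('no_gpu')      # False: no GPU present
    adjust_ocr_options_based_on_capability('class_8GB')   # False: below even the 9 GB 4-bit requirement
    adjust_ocr_options_based_on_capability('class_16GB')  # True: llava-v1.6-mistral-7b fits in 4-bit (9 GB)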
+
+
+def content_ocr_method():
     st.write("---")
+    st.header('OCR Methods')
+    with st.expander("Read about available OCR methods"):
+        st.subheader("Overview")
+        st.markdown("""VoucherVision can use the `Google Vision API`, `CRAFT` text detection + `trOCR`, and all `LLaVA v1.6` models.
+                    VoucherVision sends the OCR inside of the LLM prompt. We have found that sending multiple copies, or multiple versions, of
+                    the OCR text to the LLM helps the LLM maintain focus on the OCR text -- our prompts are quite long and the OCR text is relatively short.
+                    Below you can choose the OCR method(s). You can 'stack' all of the methods if you want, which may improve results because
+                    different OCR methods have different strengths, giving the LLM more information to work with. Alternatively, you can select a single method and
+                    send 2 copies to the LLM by enabling that option below.""")
+        st.subheader("Google Vision API")
+        st.markdown("""`Google Vision API` provides several OCR methods. We use the `document_text_detection()` service, designed to handle dense text blocks.
+                    The `Handwritten` option CAN also be used for printed and mixed labels, but it is also optimized for handwriting. `Handwritten` uses the Google Vision Beta service.
+                    This is the recommended default OCR method. `Printed` uses the regular Google Vision service and works well for general use.
+                    You can also supplement Google Vision OCR by enabling trOCR, which is optimized for handwriting. trOCR requires segmented word images, which are provided as part
+                    of the Google Vision metadata. trOCR does not require a GPU, but it runs *much* faster with a GPU.""")
+        st.subheader("LLaVA")
+        st.markdown("""`LLaVA` can replace the Google Vision APIs. It requires the use of the LeafMachine2 collage, or images that are majority text. It may struggle with very
+                    long texts. LLaVA models are multimodal, meaning that we can upload the image and the model will transcribe (and even parse) the text all at once. With VoucherVision, we
+                    support 4 different LLaVA models of varying sizes; some are much more capable than others. These models tend to outperform all other OCR methods for handwriting.
+                    LLaVA models are run locally and require powerful GPUs to implement. While LLaVA models are capable of handling both the OCR and text parsing tasks all in one step,
+                    this option only uses LLaVA to transcribe all of the text in the image and still uses a separate LLM to parse text into categories.""")
+        st.subheader("CRAFT + trOCR")
+        st.markdown("""This pairing can replace the Google Vision APIs and is computationally lighter than LLaVA. `CRAFT` locates text, segments lines of text, and feeds the segmentations
+                    to the `trOCR` transformer model. This pairing requires at least an 8 GB GPU. trOCR is a Microsoft model optimized for handwriting. The base model is not as accurate as
+                    LLaVA or Google Vision, but if you have a trOCR-based model, let us know and we will add support.""")
+
+    c1, c2 = st.columns([4,4])
+
+    # Check if LLaVA models are supported based on the capability score
+    llava_supported = adjust_ocr_options_based_on_capability(st.session_state.capability_score)
+    if llava_supported:
+        st.success("LLaVA models are supported on this computer")
+    else:
+        st.warning("LLaVA models are NOT supported on this computer. Requires a GPU with at least 12 GB of VRAM.")
+
     demo_text_h = f"Google_OCR_Handwriting:\nHERBARIUM OF MARCUS W. LYON , JR . Tracaulon sagittatum Indiana : Porter Co. incal Springs edge wet subdunal woods 1927 TX 11 Ilowers pink UNIVERSITE HERBARIUM MICH University of Michigan Herbarium 1439649 copyright reserved PERSICARIA FEB 2 6 1965 cm "
     demo_text_tr = f"trOCR:\nherbarium of marcus w. lyon jr. : : : tracaulon sagittatum indiana porter co. incal springs TX 11 Ilowers pink 1439649 copyright reserved D H U Q "
     demo_text_p = f"Google_OCR_Printed:\nTracaulon sagittatum Indiana : Porter Co. incal Springs edge wet subdunal woods 1927 Ilowers pink 1439649 copyright reserved PERSICARIA FEB 2 6 1965 cm "
     demo_text_trh = demo_text_h + '\n' + demo_text_tr
     demo_text_trp = demo_text_p + '\n' + demo_text_tr
 
+    options = ["Google Vision Handwritten", "Google Vision Printed", "CRAFT + trOCR", "LLaVA"]
+    options_llava = ["llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",]
+    options_llava_bit = ["full", "4bit",]
+    captions_llava = [
+        "Full Model: 18 GB VRAM, 4-bit: 9 GB VRAM",
+        "Full Model: 70 GB VRAM, 4-bit: 25 GB VRAM",
+        "Full Model: 33 GB VRAM, 4-bit: 15 GB VRAM",
+        "Full Model: 20 GB VRAM, 4-bit: 10 GB VRAM",
+    ]
+    captions_llava_bit = ["Full Model", "4-bit Quantization",]
+    # Get the current OCR option from session state
+    OCR_option = st.session_state.config['leafmachine']['project']['OCR_option']
+    OCR_option_llava = st.session_state.config['leafmachine']['project']['OCR_option_llava']
+    OCR_option_llava_bit = st.session_state.config['leafmachine']['project']['OCR_option_llava_bit']
+    double_OCR = st.session_state.config['leafmachine']['project']['double_OCR']
+
+    # Map the OCR option to the index in the options list
+    default_index = 0  # Default to 0 if option not found
+    default_index_llava = 0  # Default to 0 if option not found
+    default_index_llava_bit = 0
+    with c1:
+        st.subheader("API Methods (Google Vision)")
+        st.write("Using APIs for OCR allows VoucherVision to run on most computers.")
+
+        st.session_state.config['leafmachine']['project']['double_OCR'] = st.checkbox(label="Send 2 copies of the OCR to the LLM",
+                                                                                      help="This can help the LLMs focus attention on the OCR and not get lost in the longer instruction text",
+                                                                                      value=double_OCR)
+
+        # Create the radio button
+        # OCR_option_select = st.radio(
+        #     "Select the OCR Method",
+        #     options,
+        #     index=default_index,
+        #     help="", captions=captions,
+        # )
+        default_values = [options[default_index]]
+        OCR_option_select = st.multiselect(
+            "Select the OCR Method(s)",
+            options=options,
+            default=default_values,
+            help="Select one or more OCR methods."
+        )
+        # st.session_state.config['leafmachine']['project']['OCR_option'] = OCR_option_select
+
+        # Handling multiple selections (Example logic)
+        OCR_options = {
+            "Google Vision Handwritten": 'hand',
+            "Google Vision Printed": 'normal',
+            "CRAFT + trOCR": 'CRAFT',
+            "LLaVA": 'LLaVA',
+        }
+
+        # Map selected options to their corresponding internal representations
+        selected_OCR_options = [OCR_options[option] for option in OCR_option_select]
+
+        # Assuming you need to use these mapped values elsewhere in your application
+        st.session_state.config['leafmachine']['project']['OCR_option'] = selected_OCR_options
+
+
+    with c2:
+        st.subheader("Local Methods")
+        st.write("Local methods are free, but require a capable GPU.")
+
+        st.write("Supplement Google Vision OCR with trOCR (handwriting OCR) using `microsoft/trocr-base-handwritten`. This option requires the Google Vision API and a GPU.")
+        if 'CRAFT' in selected_OCR_options:
+            do_use_trOCR = st.checkbox("Enable trOCR", value=True, key="Enable trOCR1", disabled=True)  #,disabled=st.session_state['lacks_GPU'])
+        else:
+            do_use_trOCR = st.checkbox("Enable trOCR", value=st.session_state.config['leafmachine']['project']['do_use_trOCR'], key="Enable trOCR2")  #,disabled=st.session_state['lacks_GPU'])
+        st.session_state.config['leafmachine']['project']['do_use_trOCR'] = do_use_trOCR
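The `microsoft/trocr-base-handwritten` checkpoint named above is a standard Hugging Face model; a self-contained sketch of running it on one segmented word image (how VoucherVision wires it to the Google Vision word crops is not shown here):

    from PIL import Image
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel

    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

    image = Image.open("word_crop.jpg").convert("RGB")   # one segmented word image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(text)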
+
+        if 'LLaVA' in selected_OCR_options:
+            OCR_option_llava = st.radio(
+                "Select the LLaVA version",
+                options_llava,
+                index=default_index_llava,
+                help="", captions=captions_llava,
+            )
+            st.session_state.config['leafmachine']['project']['OCR_option_llava'] = OCR_option_llava
+
+            OCR_option_llava_bit = st.radio(
+                "Select the LLaVA quantization level",
+                options_llava_bit,
+                index=default_index_llava_bit,
+                help="", captions=captions_llava_bit,
+            )
+            st.session_state.config['leafmachine']['project']['OCR_option_llava_bit'] = OCR_option_llava_bit
+
+
+    # st.markdown("Below is an example of what the LLM would see given the choice of OCR ensemble. One, two, or three versions of OCR can be fed into the LLM prompt. Typically, 'printed + handwritten' works well. If you have a GPU then you can enable trOCR.")
+    # if (OCR_option == 'hand') and not do_use_trOCR:
+    #     st.text_area(label='Handwritten/Printed', placeholder=demo_text_h, disabled=True, label_visibility='visible', height=150)
+    # elif (OCR_option == 'normal') and not do_use_trOCR:
+    #     st.text_area(label='Printed', placeholder=demo_text_p, disabled=True, label_visibility='visible', height=150)
+    # elif (OCR_option == 'both') and not do_use_trOCR:
+    #     st.text_area(label='Handwritten/Printed + Printed', placeholder=demo_text_b, disabled=True, label_visibility='visible', height=150)
+    # elif (OCR_option == 'both') and do_use_trOCR:
+    #     st.text_area(label='Handwritten/Printed + Printed + trOCR', placeholder=demo_text_trb, disabled=True, label_visibility='visible', height=150)
+    # elif (OCR_option == 'normal') and do_use_trOCR:
+    #     st.text_area(label='Printed + trOCR', placeholder=demo_text_trp, disabled=True, label_visibility='visible', height=150)
+    # elif (OCR_option == 'hand') and do_use_trOCR:
+    #     st.text_area(label='Handwritten/Printed + trOCR', placeholder=demo_text_trh, disabled=True, label_visibility='visible', height=150)
+
+def content_collage_overlay():
+    st.write("---")
+    col_collage, col_overlay = st.columns([4,4])
+
+
     with col_collage:
         st.header('LeafMachine2 Label Collage')
+        st.info("NOTE: We strongly recommend enabling LeafMachine2 cropping if your images are full-sized herbarium sheets. Often, the OCR algorithm struggles with full sheets but works well with the collage images. We have disabled the collage by default for this Hugging Face Space because the Space lacks a GPU and the collage creation takes a bit longer.")
         default_crops = st.session_state.config['leafmachine']['cropped_components']['save_cropped_annotations']
         st.write("Prior to transcription, use LeafMachine2 to crop all labels from input images to create label collages for each specimen image. Showing just the text labels to the OCR algorithms significantly improves performance. This runs slowly on the free Hugging Face Space, but runs quickly with a fast CPU or any GPU.")
+        st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox(":rainbow[Use LeafMachine2 label collage for transcriptions]", st.session_state.config['leafmachine'].get('use_RGB_label_images', False))
 
         option_selected_crops = st.multiselect(label="Components to crop",
 
         with st.expander(":frame_with_picture: View an example of the LeafMachine2 collage image"):
             st.image(st.session_state["demo_collage"], caption='LeafMachine2 Collage', output_format="PNG")
             # st.image(st.session_state["demo_collage"], caption='LeafMachine2 Collage', output_format="JPEG")
 
     with col_overlay:
         st.header('OCR Overlay Image')
 
         st.write('This will plot bounding boxes around all text that Google Vision was able to detect. If there are no boxes around text, then the OCR failed, so that missing text will not be seen by the LLM when it is creating the JSON object. The created image will be viewable in the VoucherVisionEditor.')
 
         do_create_OCR_helper_image = st.checkbox("Create image showing an overlay of the OCR detections", value=st.session_state.config['leafmachine']['do_create_OCR_helper_image'], disabled=True)
         st.session_state.config['leafmachine']['do_create_OCR_helper_image'] = do_create_OCR_helper_image
 
     if "demo_overlay" not in st.session_state:
         # ocr = os.path.join(st.session_state.dir_home,'demo', 'ba','ocr.png')
 
     st.subheader('Compute Options')
     st.session_state.config['leafmachine']['project']['num_workers'] = st.number_input("Number of CPU workers", value=st.session_state.config['leafmachine']['project'].get('num_workers', 1), disabled=False)
     st.session_state.config['leafmachine']['project']['batch_size'] = st.number_input("Batch size", value=st.session_state.config['leafmachine']['project'].get('batch_size', 500), help='Sets the batch size for the LeafMachine2 cropping. If computer RAM is filled, lower this value to ~100.')
+    st.session_state.config['leafmachine']['project']['pdf_conversion_dpi'] = st.number_input("PDF conversion DPI", value=st.session_state.config['leafmachine']['project'].get('pdf_conversion_dpi', 100), help='DPI of the JPG created from the page of a PDF. 100 should be fine for most cases, but 200 or 300 might be better for large images.')
+
     with col_processing_2:
         st.subheader('Filename Prefix Handling')
         st.session_state.config['leafmachine']['project']['prefix_removal'] = st.text_input("Remove prefix from catalog number", st.session_state.config['leafmachine']['project'].get('prefix_removal', ''), placeholder="e.g. MICH-V-")
 
     ### Logging and Image Validation - col_v1
     st.write("---")
     col_v1, col_v2 = st.columns(2)
+
     with col_v1:
+        st.header('Logging and Image Validation')
         option_check_illegal = st.checkbox("Check for illegal filenames", value=st.session_state.config['leafmachine']['do']['check_for_illegal_filenames'])
         st.session_state.config['leafmachine']['do']['check_for_illegal_filenames'] = option_check_illegal
+
+        option_skip_vertical = st.checkbox("Skip vertical image requirement (e.g. horizontal PDFs)", value=st.session_state.config['leafmachine']['do']['skip_vertical'], help='The LeafMachine2 label collage requires images to have vertical aspect ratios for stability. If your input images have a horizontal aspect ratio, try skipping the vertical requirement first, look for strange behavior, and then reassess. If your images/PDFs are already closeups and you do not need the collage, then skipping the vertical requirement is the right choice.')
+        st.session_state.config['leafmachine']['do']['skip_vertical'] = option_skip_vertical
+
         st.session_state.config['leafmachine']['do']['check_for_corrupt_images_make_vertical'] = st.checkbox("Check for corrupt images", st.session_state.config['leafmachine']['do'].get('check_for_corrupt_images_make_vertical', True), disabled=True)
 
         st.session_state.config['leafmachine']['print']['verbose'] = st.checkbox("Print verbose", st.session_state.config['leafmachine']['print'].get('verbose', True))
         st.session_state.config['leafmachine']['print']['optional_warnings'] = st.checkbox("Show optional warnings", st.session_state.config['leafmachine']['print'].get('optional_warnings', True))
 
         log_level = st.session_state.config['leafmachine']['logging'].get('log_level', None)
         log_level_display = log_level if log_level is not None else 'default'
         selected_log_level = st.selectbox("Logging Level", ['default', 'DEBUG', 'INFO', 'WARNING', 'ERROR'], index=['default', 'DEBUG', 'INFO', 'WARNING', 'ERROR'].index(log_level_display))
 
     else:
         st.session_state.config['leafmachine']['logging']['log_level'] = selected_log_level
 
+    with col_v2:
+        print(f"Number of GPUs: {st.session_state.num_gpus}")
+        print(f"GPU Details: {st.session_state.gpu_dict}")
+        print(f"Total VRAM: {st.session_state.total_vram_gb} GB")
+        print(f"Capability Score: {st.session_state.capability_score}")
+
+        st.header('System GPU Information')
+        st.markdown(f"**Torch CUDA:** {torch.cuda.is_available()}")
+        st.markdown(f"**Number of GPUs:** {st.session_state.num_gpus}")
+
+        if st.session_state.num_gpus > 0:
+            st.markdown("**GPU Details:**")
+            for gpu_id, vram in st.session_state.gpu_dict.items():
+                st.text(f"{gpu_id}: {vram}")
+
+            st.markdown(f"**Total VRAM:** {st.session_state.total_vram_gb} GB")
+            st.markdown(f"**Capability Score:** {st.session_state.capability_score}")
+        else:
+            st.warning("No GPUs detected in the system.")
+
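`check_system_gpus()` lives in `vouchervision.utils_LLM`; a rough sketch of such a probe using GPUtil (installed by `install_dependencies.sh`). The return shape mirrors the four values unpacked at startup, but the implementation and the score format are assumptions:

    import GPUtil

    def check_system_gpus_sketch():
        gpus = GPUtil.getGPUs()  # empty list when no NVIDIA GPU is visible
        gpu_dict = {f"GPU {g.id}": f"{g.memoryTotal / 1024:.0f} GB" for g in gpus}
        total_vram_gb = sum(g.memoryTotal for g in gpus) / 1024
        capability_score = 'no_gpu' if not gpus else f"class_{int(total_vram_gb)}GB"
        return len(gpus), gpu_dict, total_vram_gb, capability_score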
1934
 
1935
  def content_tab_domain():
 
1996
  expense_report = st.session_state.expense_report
1997
  st.header('Expense Report Summary')
1998
 
1999
+ if not expense_summary:
2000
+ st.warning('No expense report data available.')
2001
+ else:
2002
  st.metric(label="Total Cost", value=f"${round(expense_summary['total_cost_sum'], 4):,}")
2003
  col1, col2 = st.columns(2)
2004
 
 
2092
  pie_chart.update_traces(marker=dict(colors=colors),)
2093
  st.plotly_chart(pie_chart, use_container_width=True)
2094
 
 
 
 
 
2095
 
2096
  def content_less_used():
2097
  st.write('---')
2098
  st.write(':octagonal_sign: ***NOTE:*** Settings below are not relevant for most projects. Some settings below may not be reflected in saved settings files and would need to be set each time.')
2099
 
2100
+
2101
  #################################################################################################################################################
2102
  # Sidebar #######################################################################################################################################
2103
  #################################################################################################################################################
2104
  def sidebar_content():
2105
+ # st.page_link(os.path.join(os.path.dirname(__file__),'app.py'), label="Home", icon="🏠")
2106
+ # st.page_link(os.path.join(os.path.dirname(__file__),"pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
2107
+ # st.page_link("pages/page_2.py", label="Page 2", icon="2️⃣", disabled=True)
2108
+ # st.page_link("http://www.google.com", label="Google", icon="🌎")
2109
+
2110
  if not os.path.exists(os.path.join(st.session_state.dir_home,'expense_report')):
2111
  validate_dir(os.path.join(st.session_state.dir_home,'expense_report'))
2112
  expense_report_path = os.path.join(st.session_state.dir_home, 'expense_report', 'expense_report.csv')
 
2123
  st.write('Available after first run...')
2124
 
2125
 
 
2126
  #################################################################################################################################################
2127
  # Routing Function ##############################################################################################################################
2128
  #################################################################################################################################################
 
2132
  sidebar_content()
2133
  # Main App
2134
  content_header()
 
2135
 
2136
  col_input, col_gallery = st.columns([4,8])
2137
  content_project_settings(col_input)
2138
  content_input_images(col_input, col_gallery)
2139
 
 
 
 
 
 
 
 
 
 
 
2140
 
2141
  col3, col4 = st.columns([1,1])
2142
  with col3:
2143
  content_prompt_and_llm_version()
2144
  with col4:
2145
  content_api_check()
2146
+
2147
+ content_ocr_method()
2148
+
2149
  content_collage_overlay()
2150
  content_llm_cost()
2151
  content_processing_options()
 
2155
  content_space_saver()
2156
 
 #################################################################################################################################################
 # Main ##########################################################################################################################################
 #################################################################################################################################################
 if st.session_state['is_hf']:
+    # if st.session_state.proceed_to_build_llm_prompt:
+    #     build_LLM_prompt_config()
+    if st.session_state.proceed_to_main:
         main()
+
 else:
     if not st.session_state.private_file:
         create_private_file()
+    # elif st.session_state.proceed_to_build_llm_prompt:
+    #     build_LLM_prompt_config()
     elif st.session_state.proceed_to_private and not st.session_state['is_hf']:
         create_private_file()
     elif st.session_state.proceed_to_main:
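With the new `pages/` directory, Streamlit discovers the extra pages automatically; a sketch of the expected layout (names taken from the page links and the files added below; `report_bugs.py` is referenced but not shown in this commit view):

    VoucherVision/
    ├── app.py                 # main entry point: `streamlit run app.py`
    ├── pages/
    │   ├── faqs.py
    │   ├── prompt_builder.py
    │   └── report_bugs.py
    └── vouchervision/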
install_dependencies.sh ADDED
@@ -0,0 +1,85 @@
+#!/bin/bash
+
+# List of packages to be installed
+packages=(
+    wheel
+    gputil
+    streamlit
+    streamlit-extras
+    streamlit-elements==0.1.*
+    plotly
+    google-api-python-client
+    wikipedia
+    PyMuPDF
+    craft-text-detector
+    pyyaml
+    Pillow
+    bitsandbytes
+    accelerate
+    mapboxgl
+    pandas
+    matplotlib
+    matplotlib-inline
+    tqdm
+    openai
+    langchain
+    langchain-community
+    langchain-core
+    langchain_mistralai
+    langchain_openai
+    langchain_google_genai
+    langchain_experimental
+    jsonformer
+    vertexai
+    ctransformers
+    google-cloud-aiplatform
+    tiktoken
+    llama-cpp-python
+    openpyxl
+    google-generativeai
+    google-cloud-storage
+    google-cloud-vision
+    opencv-python
+    chromadb
+    chroma-migrate
+    InstructorEmbedding
+    transformers
+    sentence-transformers
+    seaborn
+    dask
+    psutil
+    py-cpuinfo
+    Levenshtein
+    fuzzywuzzy
+    opencage
+    geocoder
+    pycountry_convert
+)
+
+# Function to install a single package
+install_package() {
+    package=$1
+    echo "Installing $package..."
+    pip3 install "$package"
+    if [ $? -ne 0 ]; then
+        echo "Failed to install $package"
+        exit 1
+    fi
+}
+
+# Install each package individually
+for package in "${packages[@]}"; do
+    install_package "$package"
+done
+
+echo "All packages installed successfully."
+echo "Cloning and installing LLaVA..."
+
+cd vouchervision
+git clone https://github.com/haotian-liu/LLaVA.git
+cd LLaVA  # Assuming you want to run pip install in the LLaVA directory
+pip install -e .
+git pull
+pip install -e .
+echo "LLaVA ready"
pages/faqs.py ADDED
@@ -0,0 +1,38 @@
+import os
+import streamlit as st
+import streamlit.components.v1 as components
+
+st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VV FAQs', initial_sidebar_state="collapsed")
+
+def display_faqs():
+    c1, c2, c3 = st.columns([4,6,1])
+    with c3:
+        # st.page_link(os.path.join(os.path.dirname(os.path.dirname(__file__)),'app.py'), label="Home", icon="🏠")
+        # st.page_link(os.path.join(os.path.dirname(os.path.dirname(__file__)),"pages","faqs.py"), label="FAQs", icon="❔")
+        # st.page_link(os.path.join(os.path.dirname(os.path.dirname(__file__)),"pages","report_bugs.py"), label="Report a Bug", icon="⚠️")
+        st.page_link('app.py', label="Home", icon="🏠")
+        st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
+        st.page_link(os.path.join("pages","report_bugs.py"), label="Report a Bug", icon="⚠️")
+    with c2:
+        st.write('If you would like to get more involved, have questions, or would like to see additional features, then please fill out this [Google Form](https://docs.google.com/forms/d/e/1FAIpQLSe2E9zU1bPJ1BW4PMakEQFsRmLbQ0WTBI2UXHIMEFm4WbnAVw/viewform?usp=sf_link)')
+        components.iframe(f"https://docs.google.com/forms/d/e/1FAIpQLSe2E9zU1bPJ1BW4PMakEQFsRmLbQ0WTBI2UXHIMEFm4WbnAVw/viewform?embedded=true", height=900, scrolling=True, width=640)
+
+    with c1:
+        st.header('FAQs')
+        st.subheader('Lead Institution')
+        st.write('- University of Michigan')
+
+        st.subheader('Partner Institutions')
+        st.write('- Oregon State University')
+        st.write('- University of Colorado Boulder')
+        st.write('- Botanical Research Institute of Texas')
+        st.write('- Smithsonian National Museum of Natural History')
+        st.write('- South African National Biodiversity Institute')
+        st.write('- Botanischer Garten Berlin')
+        st.write('- Freie Universität Berlin')
+        st.write('- Morton Arboretum')
+        st.write('- Florida Museum')
+        st.write('- iDigBio')
+        st.write('**More soon!**')
+
+display_faqs()
pages/prompt_builder.py ADDED
@@ -0,0 +1,478 @@
+import os, yaml
+import streamlit as st
+from PIL import Image
+from itertools import chain
+
+from vouchervision.model_maps import ModelMaps
+from vouchervision.utils_hf import check_prompt_yaml_filename
+
+st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VV Prompt Builder', initial_sidebar_state="collapsed")
+
+def create_download_button_yaml(file_path, selected_yaml_file, key_val):
+    file_label = f"Download {selected_yaml_file}"
+    with open(file_path, 'rb') as f:
+        st.download_button(
+            label=file_label,
+            data=f,
+            file_name=os.path.basename(file_path),
+            mime='application/x-yaml', use_container_width=True, key=key_val,
+        )
+
+
+def upload_local_prompt_to_server(dir_prompt):
+    uploaded_file = st.file_uploader("Upload a custom prompt file", type=['yaml'])
+    if uploaded_file is not None:
+        # Check the file extension
+        file_name = uploaded_file.name
+        if file_name.endswith('.yaml'):
+            file_path = os.path.join(dir_prompt, file_name)
+
+            # Save the file
+            with open(file_path, 'wb') as f:
+                f.write(uploaded_file.getbuffer())
+            st.success(f"Saved file {file_name} in {dir_prompt}")
+        else:
+            st.error("Please upload a .yaml file that you previously created using this Prompt Builder tool.")
+
+
+def save_prompt_yaml(filename, col):
+    yaml_content = {
+        'prompt_author': st.session_state['prompt_author'],
+        'prompt_author_institution': st.session_state['prompt_author_institution'],
+        'prompt_name': st.session_state['prompt_name'],
+        'prompt_version': st.session_state['prompt_version'],
+        'prompt_description': st.session_state['prompt_description'],
+        'LLM': st.session_state['LLM'],
+        'instructions': st.session_state['instructions'],
+        'json_formatting_instructions': st.session_state['json_formatting_instructions'],
+        'rules': st.session_state['rules'],
+        'mapping': st.session_state['mapping'],
+    }
+
+    dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
+    filepath = os.path.join(dir_prompt, f"{filename}.yaml")
+
+    with open(filepath, 'w') as file:
+        yaml.safe_dump(dict(yaml_content), file, sort_keys=False)
+
+    st.success(f"Prompt saved as '{filename}.yaml'.")
+
+    with col:  # added
+        create_download_button_yaml(filepath, filename, key_val=2456237465)  # added
+
+
+def load_prompt_yaml(filename):
+    st.session_state['user_clicked_load_prompt_yaml'] = filename
+    with open(filename, 'r') as file:
+        st.session_state['prompt_info'] = yaml.safe_load(file)
+    st.session_state['prompt_author'] = st.session_state['prompt_info'].get('prompt_author', st.session_state['default_prompt_author'])
+    st.session_state['prompt_author_institution'] = st.session_state['prompt_info'].get('prompt_author_institution', st.session_state['default_prompt_author_institution'])
+    st.session_state['prompt_name'] = st.session_state['prompt_info'].get('prompt_name', st.session_state['default_prompt_name'])
+    st.session_state['prompt_version'] = st.session_state['prompt_info'].get('prompt_version', st.session_state['default_prompt_version'])
+    st.session_state['prompt_description'] = st.session_state['prompt_info'].get('prompt_description', st.session_state['default_prompt_description'])
+    st.session_state['instructions'] = st.session_state['prompt_info'].get('instructions', st.session_state['default_instructions'])
+    st.session_state['json_formatting_instructions'] = st.session_state['prompt_info'].get('json_formatting_instructions', st.session_state['default_json_formatting_instructions'])
+    st.session_state['rules'] = st.session_state['prompt_info'].get('rules', {})
+    st.session_state['mapping'] = st.session_state['prompt_info'].get('mapping', {})
+    st.session_state['LLM'] = st.session_state['prompt_info'].get('LLM', 'General Purpose')
+
+    # Placeholder:
+    st.session_state['assigned_columns'] = list(chain.from_iterable(st.session_state['mapping'].values()))
+
+
+def btn_load_prompt(selected_yaml_file, dir_prompt):
+    if selected_yaml_file:
+        yaml_file_path = os.path.join(dir_prompt, selected_yaml_file)
+        load_prompt_yaml(yaml_file_path)
+    elif not selected_yaml_file:
+        # Directly assigning default values since no file is selected
+        st.session_state['prompt_info'] = {}
+        st.session_state['prompt_author'] = st.session_state['default_prompt_author']
+        st.session_state['prompt_author_institution'] = st.session_state['default_prompt_author_institution']
+        st.session_state['prompt_name'] = st.session_state['prompt_name']
+        st.session_state['prompt_version'] = st.session_state['prompt_version']
+        st.session_state['prompt_description'] = st.session_state['default_prompt_description']
+        st.session_state['instructions'] = st.session_state['default_instructions']
+        st.session_state['json_formatting_instructions'] = st.session_state['default_json_formatting_instructions']
+        st.session_state['rules'] = {}
+        st.session_state['LLM'] = 'General Purpose'
+
+        st.session_state['assigned_columns'] = []
+
+        st.session_state['prompt_info'] = {
+            'prompt_author': st.session_state['prompt_author'],
+            'prompt_author_institution': st.session_state['prompt_author_institution'],
+            'prompt_name': st.session_state['prompt_name'],
+            'prompt_version': st.session_state['prompt_version'],
+            'prompt_description': st.session_state['prompt_description'],
+            'instructions': st.session_state['instructions'],
+            'json_formatting_instructions': st.session_state['json_formatting_instructions'],
+            'rules': st.session_state['rules'],
+            'mapping': st.session_state['mapping'],
+            'LLM': st.session_state['LLM']
+        }
+
+
+def check_unique_mapping_assignments():
+    print(st.session_state['assigned_columns'])
+    if len(st.session_state['assigned_columns']) != len(set(st.session_state['assigned_columns'])):
+        st.error("Each column name must be assigned to only one category.")
+        return False
+    elif not st.session_state['assigned_columns']:
+        st.error("No columns have been mapped.")
+        return False
+    elif len(st.session_state['assigned_columns']) != len(st.session_state['rules'].keys()):
+        incomplete = [item for item in list(st.session_state['rules'].keys()) if item not in st.session_state['assigned_columns']]
+        st.warning(f"These columns have been mapped: {st.session_state['assigned_columns']}")
+        st.error(f"However, these columns must be mapped before the prompt is complete: {incomplete}")
+        return False
+    else:
+        st.success("Mapping confirmed.")
+        return True
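An illustrative round trip through the mapping check (hypothetical values): `mapping` groups output columns by category, and flattening it exposes duplicate assignments:

    from itertools import chain

    mapping = {"TAXONOMY": ["order", "family"], "GEOGRAPHY": ["country", "stateProvince"]}
    assigned_columns = list(chain.from_iterable(mapping.values()))
    # ['order', 'family', 'country', 'stateProvince'] -- all unique, so the uniqueness check passes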
+
+
+def build_LLM_prompt_config():
+    col_main1, col_main2 = st.columns([10,2])
+    with col_main1:
+        st.session_state.logo_path = os.path.join(st.session_state.dir_home, 'img','logo.png')
+        st.session_state.logo = Image.open(st.session_state.logo_path)
+        st.image(st.session_state.logo, width=250)
+    with col_main2:
+        st.page_link('app.py', label="Home", icon="🏠")
+        st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
+        st.page_link(os.path.join("pages","report_bugs.py"), label="Report a Bug", icon="⚠️")
+        # st.page_link("pages/page_2.py", label="Page 2", icon="2️⃣", disabled=True)
+        # st.page_link("http://www.google.com", label="Google", icon="🌎")
+
+    st.session_state['assigned_columns'] = []
+    st.session_state['default_prompt_author'] = 'unknown'
+    st.session_state['default_prompt_author_institution'] = 'unknown'
+    st.session_state['default_prompt_name'] = 'custom_prompt'
+    st.session_state['default_prompt_version'] = 'v-1-0'
+    st.session_state['default_prompt_description'] = 'unknown'
+    st.session_state['default_LLM'] = 'General Purpose'
+    st.session_state['default_instructions'] = """1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
+2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
+3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
+4. Duplicate dictionary fields are not allowed.
+5. Ensure all JSON keys are in camel case.
+6. Ensure new JSON field values follow sentence case capitalization.
+7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
+8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
+9. Only return a JSON dictionary represented as a string. You should not explain your answer."""
+    st.session_state['default_json_formatting_instructions'] = """This section provides rules for formatting each JSON value organized by the JSON key."""
+
166
+ # Start building the Streamlit app
167
+ col_prompt_main_left, ___, col_prompt_main_right = st.columns([6,1,3])
168
+
169
+
170
+ with col_prompt_main_left:
171
+
172
+ st.title("Custom LLM Prompt Builder")
173
+ st.subheader('About')
174
+ st.write("This form allows you to craft a prompt for your specific task. You can also edit the JSON yaml files directly, but please try loading the prompt back into this form to ensure that the formatting is correct. If this form cannot load your manually edited JSON yaml file, then it will not work in VoucherVision.")
175
+ st.subheader(':rainbow[How it Works]')
176
+ st.write("1. Edit this page until you are happy with your instructions. We recommend looking at the basic structure, writing down your prompt inforamtion in a Word document so that it does not randomly disappear, and then copying and pasting that info into this form once your whole prompt structure is defined.")
177
+ st.write("2. After you enter all of your prompt instructions, click 'Save' and give your file a name.")
178
+ st.write("3. This file will be saved as a yaml configuration file in the `..VoucherVision/custom_prompts` folder.")
179
+ st.write("4. When you go back the main VoucherVision page you will now see your custom prompt available in the 'Prompt Version' dropdown menu.")
180
+ st.write("5. The LLM ***only*** sees information from the 'instructions', 'rules', and 'json_formatting_instructions' sections. All other information is for versioning and integration with VoucherVisionEditor.")
181
+
182
+ st.write("---")
183
+ st.header('Load an Existing Prompt Template')
184
+ st.write("By default, this form loads the minimum required transcription fields but does not provide rules for each field. You can also load an existing prompt as a template, editing or deleting values as needed.")
185
+
186
+ dir_prompt = os.path.join(st.session_state.dir_home, 'custom_prompts')
187
+ yaml_files = [f for f in os.listdir(dir_prompt) if f.endswith('.yaml')]
188
+ col_load_text, col_load_btn, col_load_btn2 = st.columns([8,2,2])
189
+ with col_load_text:
190
+ # Dropdown for selecting a YAML file
191
+ st.session_state['selected_yaml_file'] = st.selectbox('Select a prompt .YAML file to load:', [''] + yaml_files)
192
+ with col_load_btn:
193
+ st.write('##')
194
+ # Button to load the selected prompt
195
+ st.button('Load Prompt', on_click=btn_load_prompt, args=[st.session_state['selected_yaml_file'], dir_prompt],use_container_width=True)
196
+
197
+ with col_load_btn2:
198
+ if st.session_state['selected_yaml_file']:
199
+ # Construct the full path to the file
200
+ download_file_path = os.path.join(dir_prompt, st.session_state['selected_yaml_file'] )
201
+ # Create the download button
202
+ st.write('##')
203
+ create_download_button_yaml(download_file_path, st.session_state['selected_yaml_file'],key_val=345798)
204
+
205
+ # Prompt Author Information
206
+ st.write("---")
207
+ st.header("Prompt Author Information")
208
+ st.write("We value community contributions! Please provide your name(s) (or pseudonym if you prefer) for credit. If you leave this field blank, it will say 'unknown'.")
209
+ if 'prompt_author' not in st.session_state:# != st.session_state['default_prompt_author']:
210
+ st.session_state['prompt_author'] = st.text_input("Enter names of prompt author(s)", value=st.session_state['default_prompt_author'],key=1111)
211
+ else:
212
+ st.session_state['prompt_author'] = st.text_input("Enter names of prompt author(s)", value=st.session_state['prompt_author'],key=1112)
213
+
214
+ # Institution
215
+ st.write("Please provide your institution name. If you leave this field blank, it will say 'unknown'.")
216
+ if 'prompt_author_institution' not in st.session_state:
217
+ st.session_state['prompt_author_institution'] = st.text_input("Enter name of institution", value=st.session_state['default_prompt_author_institution'],key=1113)
218
+ else:
219
+ st.session_state['prompt_author_institution'] = st.text_input("Enter name of institution", value=st.session_state['prompt_author_institution'],key=1114)
220
+
221
+ # Prompt name
222
+ st.write("Please provide a simple name for your prompt. If you leave this field blank, it will say 'custom_prompt'.")
223
+ if 'prompt_name' not in st.session_state:
224
+ st.session_state['prompt_name'] = st.text_input("Enter prompt name", value=st.session_state['default_prompt_name'],key=1115)
225
+ else:
226
+ st.session_state['prompt_name'] = st.text_input("Enter prompt name", value=st.session_state['prompt_name'],key=1116)
227
+
228
+ # Prompt verion
229
+ st.write("Please provide a version identifier for your prompt. If you leave this field blank, it will say 'v-1-0'.")
230
+ if 'prompt_version' not in st.session_state:
231
+ st.session_state['prompt_version'] = st.text_input("Enter prompt version", value=st.session_state['default_prompt_version'],key=1117)
232
+ else:
233
+ st.session_state['prompt_version'] = st.text_input("Enter prompt version", value=st.session_state['prompt_version'],key=1118)
234
+
235
+
236
+ st.write("Please provide a description of your prompt and its intended task. Is it designed for a specific collection? Taxa? Database structure?")
237
+ if 'prompt_description' not in st.session_state:
238
+ st.session_state['prompt_description'] = st.text_input("Enter description of prompt", value=st.session_state['default_prompt_description'],key=1119)
239
+ else:
240
+ st.session_state['prompt_description'] = st.text_input("Enter description of prompt", value=st.session_state['prompt_description'],key=11111)
241
+
242
+ st.write('---')
243
+ st.header("Set LLM Model Type")
244
+ # Define the options for the dropdown
245
+ llm_options_general = ["General Purpose",
246
+ "OpenAI GPT Models","Google PaLM2 Models","Google Gemini Models","MistralAI Models",]
247
+ llm_options_all = ModelMaps.get_models_gui_list()
248
+
249
+ if 'LLM' not in st.session_state:
250
+ st.session_state['LLM'] = st.session_state['default_LLM']
251
+
252
+ if st.session_state['LLM']:
253
+ llm_options = llm_options_general + llm_options_all + [st.session_state['LLM']]
254
+ else:
255
+ llm_options = llm_options_general + llm_options_all
256
+ # Create the dropdown and set the value to session_state['LLM']
257
+ st.write("Which LLM is this prompt designed for? This will not restrict its use to a specific LLM, but some prompts will behave differently across models.")
258
+ st.write("SLTPvA prompts have been validated with all supported LLMs, but perfornce may vary. If you design a prompt to work best with a specific model, then you can indicate the model here.")
259
+ st.write("For general purpose prompts (like the SLTPvA prompts) just use the 'General Purpose' option.")
260
+ st.session_state['LLM'] = st.selectbox('Set LLM', llm_options, index=llm_options.index(st.session_state.get('LLM', 'General Purpose')))
261
+
262
+ st.write('---')
263
+ # Instructions Section
264
+ st.header("Instructions")
265
+ st.write("These are the general instructions that guide the LLM through the transcription task. We recommend using the default instructions unless you have a specific reason to change them.")
266
+
267
+ if 'instructions' not in st.session_state:
268
+ st.session_state['instructions'] = st.text_area("Enter guiding instructions", value=st.session_state['default_instructions'].strip(), height=350,key=111112)
269
+ else:
270
+ st.session_state['instructions'] = st.text_area("Enter guiding instructions", value=st.session_state['instructions'].strip(), height=350,key=111112)
271
+
272
+
273
+ st.write('---')
274
+
275
+ # Column Instructions Section
276
+ st.header("JSON Formatting Instructions")
277
+ st.write("The following section tells the LLM how we want to structure the JSON dictionary. We do not recommend changing this section because it would likely result in unstable and inconsistent behavior.")
278
+ if 'json_formatting_instructions' not in st.session_state:
279
+ st.session_state['json_formatting_instructions'] = st.text_area("Enter general JSON guidelines", value=st.session_state['default_json_formatting_instructions'],key=111114)
280
+ else:
281
+ st.session_state['json_formatting_instructions'] = st.text_area("Enter general JSON guidelines", value=st.session_state['json_formatting_instructions'],key=111115)
282
+
283
+
284
+
285
+
286
+
287
+
288
+ st.write('---')
289
+ col_left, col_right = st.columns([6,4])
290
+
291
+ null_value_rules = ''
292
+ c_name = "EXAMPLE_COLUMN_NAME"
293
+ c_value = "REPLACE WITH DESCRIPTION"
294
+
295
+ with col_left:
296
+ st.subheader('Add/Edit Columns')
297
+ st.markdown("The pre-populated fields are REQUIRED for downstream validation steps. They must be in all prompts.")
298
+
299
+ # Initialize rules in session state if not already present
300
+ if 'rules' not in st.session_state or not st.session_state['rules']:
301
+ for required_col in st.session_state['required_fields']:
302
+ st.session_state['rules'][required_col] = c_value
303
+
304
+
305
+
306
+
307
+ # Layout for adding a new column name
308
+ # col_text, col_textbtn = st.columns([8, 2])
309
+ # with col_text:
310
+ st.session_state['new_column_name'] = st.text_input("Enter a new column name:")
311
+ # with col_textbtn:
312
+ # st.write('##')
313
+ if st.button("Add New Column") and st.session_state['new_column_name']:
314
+ if st.session_state['new_column_name'] not in st.session_state['rules']:
315
+ st.session_state['rules'][st.session_state['new_column_name']] = c_value
316
+ st.success(f"New column '{st.session_state['new_column_name']}' added. Now you can edit its properties.")
317
+ st.session_state['new_column_name'] = ''
318
+ else:
319
+ st.error("Column name already exists. Please enter a unique column name.")
320
+ st.session_state['new_column_name'] = ''
321
+
322
+ # Get columns excluding the protected "catalogNumber"
323
+ st.write('#')
324
+ # required_columns = [col for col in st.session_state['rules'] if col not in st.session_state['required_fields']]
325
+ editable_columns = [col for col in st.session_state['rules'] if col not in ["catalogNumber"]]
326
+ removable_columns = [col for col in st.session_state['rules'] if col not in st.session_state['required_fields']]
327
+
328
+ st.session_state['current_rule'] = st.selectbox("Select a column to edit:", [""] + editable_columns)
329
+ # column_name = st.selectbox("Select a column to edit:", editable_columns)
330
+
331
+ # Form for input fields
332
+ with st.form(key='rule_form'):
333
+ # format_options = ["verbatim transcription", "spell check transcription", "boolean yes no", "boolean 1 0", "integer", "[list]", "yyyy-mm-dd"]
334
+ # current_rule["format"] = st.selectbox("Format:", format_options, index=format_options.index(current_rule["format"]) if current_rule["format"] else 0)
335
+ # current_rule["null_value"] = st.text_input("Null value:", value=current_rule["null_value"])
336
+ if st.session_state['current_rule']:
337
+ current_rule_description = st.text_area("Description of category:", value=st.session_state['rules'][st.session_state['current_rule']])
338
+ else:
339
+ current_rule_description = ''
340
+ commit_button = st.form_submit_button("Commit Column")
341
+
342
+ # Handle commit action
343
+ if commit_button and st.session_state['current_rule']:
344
+ # Commit the rules to the session state.
345
+ st.session_state['rules'][st.session_state['current_rule']] = current_rule_description
346
+ st.success(f"Column '{st.session_state['current_rule']}' added/updated in rules.")
347
+
348
+ # Force the form to reset by clearing the fields from the session state
349
+ st.session_state.pop('current_rule', None) # Clear the selected column to force reset
350
+
351
+ delete_column_name = st.selectbox("Select a column to delete:", [""] + removable_columns)
352
+ # with del_colbtn:
353
+ # st.write('##')
354
+ if st.button("Delete Column") and delete_column_name:
355
+ del st.session_state['rules'][delete_column_name]
356
+ st.success(f"Column '{delete_column_name}' removed from rules.")
357
+
358
+ with col_right:
359
+ # Display the current state of the JSON rules
360
+ st.subheader('Formatted Columns')
361
+ st.json(st.session_state['rules'])
362
+
363
+ st.write('---')
364
+
365
+ col_left_mapping, col_right_mapping = st.columns([6,4])
366
+ with col_left_mapping:
367
+ st.header("Mapping")
368
+ st.write("Assign each column name to a single category.")
369
+ st.session_state['refresh_mapping'] = False
370
+
371
+ # Dynamically create a list of all column names that can be assigned
372
+ # This assumes that the column names are the keys in the dictionary under 'rules'
373
+ all_column_names = list(st.session_state['rules'].keys())
374
+
375
+ categories = ['TAXONOMY', 'GEOGRAPHY', 'LOCALITY', 'COLLECTING', 'MISC']
376
+ if ('mapping' not in st.session_state) or (st.session_state['mapping'] == {}):
377
+ st.session_state['mapping'] = {category: [] for category in categories}
378
+ for category in categories:
379
+ # Filter out the already assigned columns
380
+ available_columns = [col for col in all_column_names if col not in st.session_state['assigned_columns'] or col in st.session_state['mapping'].get(category, [])]
381
+
382
+ # Ensure the current mapping is a subset of the available options
383
+ current_mapping = [col for col in st.session_state['mapping'].get(category, []) if col in available_columns]
384
+
385
+ # Provide a safe default if the current mapping is empty or contains invalid options
386
+ safe_default = current_mapping if all(col in available_columns for col in current_mapping) else []
387
+
388
+ # Create a multi-select widget for the category with a safe default
389
+ selected_columns = st.multiselect(
390
+ f"Select columns for {category}:",
391
+ available_columns,
392
+ default=safe_default,
393
+ key=f"mapping_{category}"
394
+ )
395
+ # Update the assigned_columns based on the selections
396
+ for col in current_mapping:
397
+ if col not in selected_columns and col in st.session_state['assigned_columns']:
398
+ st.session_state['assigned_columns'].remove(col)
399
+ st.session_state['refresh_mapping'] = True
400
+
401
+ for col in selected_columns:
402
+ if col not in st.session_state['assigned_columns']:
403
+ st.session_state['assigned_columns'].append(col)
404
+ st.session_state['refresh_mapping'] = True
405
+
406
+ # Update the mapping in session state when there's a change
407
+ st.session_state['mapping'][category] = selected_columns
408
+ if st.session_state['refresh_mapping']:
409
+ st.session_state['refresh_mapping'] = False
410
+
411
+ # Button to confirm and save the mapping configuration
412
+ if st.button('Confirm Mapping'):
413
+ if check_unique_mapping_assignments():
414
+ # Proceed with further actions since the mapping is confirmed and unique
415
+ pass
416
+
417
+ with col_right_mapping:
418
+ # Display the current state of the JSON rules
419
+ st.subheader('Formatted Column Maps')
420
+ st.json(st.session_state['mapping'])
421
+
422
+
423
+ col_left_save, col_right_save = st.columns([6,4])
424
+ with col_left_save:
425
+ # Input for new file name
426
+ new_filename = st.text_input("Enter filename to save your prompt as a configuration YAML:",placeholder='my_prompt_name')
427
+ # Button to save the new YAML file
428
+ if st.button('Save YAML', type='primary'):
429
+ if new_filename:
430
+ if check_unique_mapping_assignments():
431
+ if check_prompt_yaml_filename(new_filename):
432
+ save_prompt_yaml(new_filename, col_left_save)
433
+ else:
434
+ st.error("File name can only contain letters, numbers, underscores, and dashes. Cannot contain spaces.")
435
+ else:
436
+ st.error("Mapping contains an error. Make sure that each column is assigned to only ***one*** category.")
437
+ else:
438
+ st.error("Please enter a filename.")
439
+
440
+ if st.button('Exit'):
441
+ st.session_state.proceed_to_build_llm_prompt = False
442
+ st.session_state.proceed_to_main = True
443
+ st.rerun()
444
+
445
+
446
+ with col_prompt_main_right:
447
+ if st.session_state['user_clicked_load_prompt_yaml'] is None: # see if user has loaded a yaml to edit
448
+ st.session_state['show_prompt_name_e'] = f"Prompt Status :arrow_forward: Building prompt from scratch"
449
+ if st.session_state['prompt_name']:
450
+ st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: {st.session_state['prompt_name']}.yaml"
451
+ else:
452
+ st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: [PLEASE SET NAME]"
453
+ else:
454
+ st.session_state['show_prompt_name_e'] = f"Prompt Status: Editing :arrow_forward: {st.session_state['selected_yaml_file']}"
455
+ if st.session_state['prompt_name']:
456
+ st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: {st.session_state['prompt_name']}.yaml"
457
+ else:
458
+ st.session_state['show_prompt_name_w'] = f"New Prompt Name :arrow_forward: [PLEASE SET NAME]"
459
+
460
+ st.subheader(f'Full Prompt')
461
+ st.write(st.session_state['show_prompt_name_e'])
462
+ st.write(st.session_state['show_prompt_name_w'])
463
+ st.write("---")
464
+ st.session_state['prompt_info'] = {
465
+ 'prompt_author': st.session_state['prompt_author'],
466
+ 'prompt_author_institution': st.session_state['prompt_author_institution'],
467
+ 'prompt_name': st.session_state['prompt_name'],
468
+ 'prompt_version': st.session_state['prompt_version'],
469
+ 'prompt_description': st.session_state['prompt_description'],
470
+ 'LLM': st.session_state['LLM'],
471
+ 'instructions': st.session_state['instructions'],
472
+ 'json_formatting_instructions': st.session_state['json_formatting_instructions'],
473
+ 'rules': st.session_state['rules'],
474
+ 'mapping': st.session_state['mapping'],
475
+ }
476
+ st.json(st.session_state['prompt_info'])
477
+
478
+ build_LLM_prompt_config()
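
A minimal sketch of how a file saved by the builder above could be re-validated outside Streamlit. The key list mirrors the form fields assembled in build_LLM_prompt_config(); the loader function itself is illustrative and not part of this commit.

import os
import yaml

def load_custom_prompt(dir_prompt, filename):
    # Hypothetical helper: the expected keys come from st.session_state['prompt_info'] above
    expected = ['prompt_author', 'prompt_author_institution', 'prompt_name',
                'prompt_version', 'prompt_description', 'LLM', 'instructions',
                'json_formatting_instructions', 'rules', 'mapping']
    with open(os.path.join(dir_prompt, filename), 'r', encoding='utf-8') as f:
        prompt_info = yaml.safe_load(f)
    missing = [k for k in expected if k not in prompt_info]
    if missing:
        raise ValueError(f"Prompt YAML is missing keys: {missing}")
    return prompt_info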
pages/report_bugs.py ADDED
@@ -0,0 +1,19 @@
+ import os
+ import streamlit as st
+ import streamlit.components.v1 as components
+
+ st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VV Report Bugs', initial_sidebar_state="collapsed")
+
+ def display_report():
+     c1, c2, c3 = st.columns([4,6,1])
+     with c3:
+         st.page_link('app.py', label="Home", icon="🏠")
+         st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
+         st.page_link(os.path.join("pages","report_bugs.py"), label="Report a Bug", icon="⚠️")
+
+     with c2:
+         st.write('To report a bug or request a new feature please fill out this [Google Form](https://docs.google.com/forms/d/e/1FAIpQLSdtW1z9Q1pGZTo5W9UeCa6PlQanP-b88iNKE6zsusRI78Itsw/viewform?usp=sf_link)')
+         components.iframe("https://docs.google.com/forms/d/e/1FAIpQLSdtW1z9Q1pGZTo5W9UeCa6PlQanP-b88iNKE6zsusRI78Itsw/viewform?embedded=true", height=700, scrolling=True, width=640)
+
+
+ display_report()
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
run_VoucherVision.py CHANGED
@@ -1,15 +1,10 @@
  import streamlit.web.cli as stcli
  import os, sys
 
- # Insert a file uploader that accepts multiple files at a time:
- # import streamlit as st
- # uploaded_files = st.file_uploader("Choose a CSV file", accept_multiple_files=True)
- # for uploaded_file in uploaded_files:
- #     bytes_data = uploaded_file.read()
- #     st.write("filename:", uploaded_file.name)
- #     st.write(bytes_data)
-
  # pip install protobuf==3.20.0
+ # pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 nope
+ # pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
+
 
 
  def resolve_path(path):
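
For context, resolve_path feeds the standard pattern for launching a Streamlit app through stcli. A hedged sketch of that pattern follows; the argv construction is an assumption, not this file's verbatim tail.

import os, sys
import streamlit.web.cli as stcli

def resolve_path(path):
    # Make the app path absolute so the launcher works from any directory
    return os.path.abspath(os.path.join(os.getcwd(), path))

if __name__ == "__main__":
    sys.argv = ["streamlit", "run", resolve_path("app.py")]  # illustrative only
    sys.exit(stcli.main())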
vouchervision/LLM_crew_OpenAI.py ADDED
@@ -0,0 +1,130 @@
+ import os
+ from crewai import Agent, Task, Crew, Process
+ from langchain_community.tools import DuckDuckGoSearchRun
+ from langchain_openai import ChatOpenAI
+
+
+ class AIResearchCrew:
+     def __init__(self, openai_api_key, OCR, JSON_rules, search_tool=None, llm=None):
+         # Set the OpenAI API key
+         os.environ["OPENAI_API_KEY"] = openai_api_key
+
+         # Initialize the search tool, defaulting to DuckDuckGoSearchRun if not provided
+         self.search_tool = search_tool if search_tool is not None else DuckDuckGoSearchRun()
+
+         # Initialize the LLM (large language model), if provided
+         self.llm = llm
+
+         # Define the agents
+         self.transcriber = Agent(
+             role='Expert Text Parser',
+             goal='Parse and rearrange unstructured OCR text into a standardized JSON dictionary',
+             backstory="""You work at a museum transcribing specimen labels.
+             Your expertise lies in precisely transcribing text and placing the text into the appropriate category.""",
+             verbose=True,
+             allow_delegation=False
+             # Optionally include llm=self.llm here if an LLM was provided
+         )
+
+         self.spell_check = Agent(
+             role='Spell Checker',
+             goal='Correct any typos in the JSON key values',
+             backstory="""Your job is to look at the JSON key values and use your knowledge to verify spelling. Your corrections should be incorporated into the JSON object that will be passed to the next employee, so return the spell-checked JSON dictionary or the previous JSON dictionary if no changes are required.""",
+             verbose=True,
+             allow_delegation=True,
+             # Optionally include llm=self.llm here if an LLM was provided
+         )
+
+         self.fact_check = Agent(
+             role='Fact Checker',
+             goal='Verify the accuracy of taxonomy and location names',
+             backstory="""Your job is to verify the plant taxonomy and geographic locations contained within the key values are accurate. You can use internet searches to check these fields. Your corrections should be incorporated into a new JSON object that will be passed to the next employee, so return the corrected JSON dictionary or the previous JSON dictionary if no changes are required.""",
+             verbose=True,
+             allow_delegation=True,
+             tools=[self.search_tool]
+             # Optionally include llm=self.llm here if an LLM was provided
+         )
+
+         self.validator = Agent(
+             role='Synthesis',
+             goal='Create a final museum JSON record',
+             backstory="""You must produce a final JSON dictionary only.""",
+             verbose=True,
+             allow_delegation=True,
+         )
+
+         # Define the tasks
+         self.task1 = Task(
+             description=f"Use your knowledge to reformat, transform, and rearrange the unstructured text to fit the following requirements:{JSON_rules}. For null values, use an empty string. This is the unformatted OCR text: {OCR}",
+             agent=self.transcriber
+         )
+
+         self.task2 = Task(
+             description="The original text is OCR text, which may contain minor typos. Your job is to check all of the key values and fix any minor typos or spelling mistakes. You should remove any extraneous characters that should not belong in an official museum record.",
+             agent=self.spell_check
+         )
+
+         self.task3 = Task(
+             description="""Use your knowledge or search the internet to verify the information contained within the JSON dictionary.
+             For taxonomy, use the information contained in these keys: order, family, scientificName, scientificNameAuthorship, genus, specificEpithet, infraspecificEpithet.
+             For geography, use the information contained in these keys: country, stateProvince, municipality, decimalLatitude, decimalLongitude.""",
+             agent=self.fact_check
+         )
+
+         self.task4 = Task(
+             description="Verify that the JSON dictionary is valid. If not, correct the error. Then print out the final JSON dictionary only without explanations.",
+             agent=self.validator
+         )
+
+         # Create the crew
+         # self.crew = Crew(
+         #     agents=[self.transcriber, self.spell_check, self.fact_check, self.validator],
+         #     tasks=[self.task1, self.task2, self.task3, self.task4],
+         #     verbose=2,  # You can set it to 1 or 2 for different logging levels
+         #     manager_llm=ChatOpenAI(temperature=0, model="gpt-4-1106-preview"),
+         #     process=Process.hierarchical,
+         # )
+         self.crew = Crew(
+             agents=[self.transcriber, self.validator],
+             tasks=[self.task1, self.task4],
+             manager_llm=ChatOpenAI(temperature=0, model="gpt-4-1106-preview"),
+             process=Process.sequential,
+             verbose=2  # You can set it to 1 or 2 for different logging levels
+         )
+
+     def execute_tasks(self):
+         # Kick off the process and return the result
+         result = self.crew.kickoff()
+         print("######################")
+         print(result)
+         return result
+
+ if __name__ == "__main__":
+     openai_api_key = ""
+     OCR = "HERBARIUM OF MARYGROVE COLLEGE Name Carex scoparia V. condensa Fernald Locality Interlaken , Ind . Date 7/20/25 No ... ! Gerould Wilhelm & Laura Rericha \" Interlaken , \" was the site for many years of St. Joseph Novitiate , run by the Brothers of the Holy Cross . The buildings were on the west shore of Silver Lake , about 2 miles NE of Rolling Prairie , LaPorte Co. Indiana , ca. 41.688 \u00b0 N , 86.601 \u00b0 W Collector : Sister M. Vincent de Paul McGivney February 1 , 2011 THE UNIVERS Examined for the Flora of the Chicago Region OF 1817 MICH ! Ciscoparia SMVdeP University of Michigan Herbarium 1386297 copyright reserved cm Collector wortet 2010"
+     JSON_rules = """This is the JSON template that includes instructions for each key
+     {'catalogNumber': barcode identifier, at least 6 digits, fewer than 30 digits.,
+     'order': full scientific name of the Order in which the taxon is classified. Order must be capitalized.,
+     'family': full scientific name of the Family in which the taxon is classified. Family must be capitalized.,
+     'scientificName': scientific name of the taxon including Genus, specific epithet, and any lower classifications.,
+     'scientificNameAuthorship': authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.,
+     'genus': taxonomic determination to Genus, Genus must be capitalized.,
+     'subgenus': name of the subgenus.,
+     'specificEpithet': The name of the first or species epithet of the scientificName. Only include the species epithet.,
+     'infraspecificEpithet': lowest or terminal infraspecific epithet of the scientificName.,
+     'identifiedBy': list of names of people, doctors, professors, groups, or organizations who identified, determined the taxon name to the subject organism. This is not the specimen collector., recordedBy list of names of people, doctors, professors, groups, or organizations.,
+     'recordNumber': identifier given to the specimen at the time it was recorded.,
+     'verbatimEventDate': The verbatim original representation of the date and time information for when the specimen was collected.,
+     'eventDate': collection date formatted as year-month-day YYYY-MM-DD., habitat habitat.,
+     'occurrenceRemarks': all descriptive text in the OCR rearranged into sensible sentences or sentence fragments.,
+     'country': country or major administrative unit.,
+     'stateProvince': state, province, canton, department, region, etc., county county, shire, department, parish etc.,
+     'municipality': city, municipality, etc., locality description of geographic information aiding in pinpointing the exact origin or location of the specimen.,
+     'degreeOfEstablishment': cultivated plants are intentionally grown by humans. Use either - unknown or cultivated.,
+     'decimalLatitude': latitude decimal coordinate.,
+     'decimalLongitude': longitude decimal coordinate., verbatimCoordinates verbatim location coordinates.,
+     'minimumElevationInMeters': minimum elevation or altitude in meters.,
+     'maximumElevationInMeters': maximum elevation or altitude in meters.}"""
+     ai_research_crew = AIResearchCrew(openai_api_key, OCR, JSON_rules)
+     result = ai_research_crew.execute_tasks()
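
Since __init__ accepts an optional llm, a caller can pin the agents to a specific model. A hedged usage sketch follows; the key and argument values are placeholders, and note that the class currently stores self.llm without passing it to the Agents (see the "Optionally include llm=self.llm" comments above).

from langchain_openai import ChatOpenAI

custom_llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")  # example model only
crew = AIResearchCrew(openai_api_key="YOUR_KEY",
                      OCR="...raw OCR text...",
                      JSON_rules="...template with per-key instructions...",
                      llm=custom_llm)
result = crew.execute_tasks()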
vouchervision/LLM_local_cpu_MistralAI.py CHANGED
@@ -56,8 +56,6 @@ class LocalCPUMistralHandler:
              raise f"Unsupported GGUF model name"
 
          # self.model_id = f"mistralai/{self.model_name}"
-         self.gpu_usage = {'max_load': 0, 'max_memory_usage': 0, 'monitoring': True}
-
          self.starting_temp = float(self.STARTING_TEMP)
          self.temp_increment = float(0.2)
          self.adjust_temp = self.starting_temp
vouchervision/OCR_CRAFT.py ADDED
@@ -0,0 +1,55 @@
+ # import Craft class
+ from craft_text_detector import read_image, load_craftnet_model, load_refinenet_model, get_prediction, export_detected_regions, export_extra_results, empty_cuda_cache
+
+ def main2():
+     # set image path and export folder directory
+     # image = 'D:/Dropbox/SLTP/benchmark_datasets/SLTP_B50_MICH_Angiospermae2/img/MICH_7375774_Polygonaceae_Persicaria_.jpg'  # can be filepath, PIL image or numpy array
+     # image = 'C:/Users/Will/Downloads/test_2024_02_07__14-59-52/Original_Images/SJRw 00891 - 01141__10001.jpg'
+     image = 'D:/Dropbox/VoucherVision/demo/demo_images/MICH_16205594_Poaceae_Jouvea_pilosa.jpg'
+     output_dir = 'D:/D_Desktop/test_out_CRAFT'
+
+     # read image
+     image = read_image(image)
+
+     # load models
+     refine_net = load_refinenet_model(cuda=True)
+     craft_net = load_craftnet_model(weight_path='D:/Dropbox/VoucherVision/vouchervision/craft/craft_mlt_25k.pth', cuda=True)
+
+     # perform prediction
+     prediction_result = get_prediction(
+         image=image,
+         craft_net=craft_net,
+         refine_net=refine_net,
+         text_threshold=0.4,
+         link_threshold=0.7,
+         low_text=0.4,
+         cuda=True,
+         long_size=1280
+     )
+
+     # export detected text regions
+     exported_file_paths = export_detected_regions(
+         image=image,
+         regions=prediction_result["boxes"],
+         output_dir=output_dir,
+         rectify=True
+     )
+
+     # export heatmap, detection points, box visualization
+     export_extra_results(
+         image=image,
+         regions=prediction_result["boxes"],
+         heatmaps=prediction_result["heatmaps"],
+         output_dir=output_dir
+     )
+
+     # unload models from gpu
+     empty_cuda_cache()
+
+
+ if __name__ == '__main__':
+     # main()
+     main2()
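
For reference, the three detection knobs passed to get_prediction map to CRAFT's score maps (descriptions paraphrase the craft_text_detector documentation; the numbers above are simply what this demo uses): text_threshold is the minimum region score for a pixel to be treated as text, link_threshold is the minimum affinity score for joining neighboring characters into one region, and low_text is the lower bound used when expanding a detected region outward. Raising link_threshold therefore tends to split text into more, smaller boxes.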
vouchervision/OCR_google_cloud_vision.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, io, sys, inspect, statistics, json
2
  from statistics import mean
3
  # from google.cloud import vision, storage
4
  from google.cloud import vision
@@ -8,10 +8,16 @@ import colorsys
8
  from tqdm import tqdm
9
  from google.oauth2 import service_account
10
 
11
- # currentdir = os.path.dirname(os.path.abspath(
12
- # inspect.getfile(inspect.currentframe())))
13
- # parentdir = os.path.dirname(currentdir)
14
- # sys.path.append(parentdir)
 
 
 
 
 
 
15
 
16
 
17
  '''
@@ -23,19 +29,31 @@ from google.oauth2 import service_account
23
  archivePrefix={arXiv},
24
  primaryClass={cs.CL}
25
  }
 
 
 
 
 
 
 
26
  '''
27
 
28
- class OCRGoogle:
29
 
30
  BBOX_COLOR = "black"
31
 
32
- def __init__(self, is_hf, path, cfg, trOCR_model_version, trOCR_model, trOCR_processor, device):
33
  self.is_hf = is_hf
 
 
 
34
 
35
  self.path = path
36
  self.cfg = cfg
37
  self.do_use_trOCR = self.cfg['leafmachine']['project']['do_use_trOCR']
38
  self.OCR_option = self.cfg['leafmachine']['project']['OCR_option']
 
 
39
 
40
  # Initialize TrOCR components
41
  self.trOCR_model_version = trOCR_model_version
@@ -70,6 +88,9 @@ class OCRGoogle:
70
  self.trOCR_confidences = None
71
  self.trOCR_characters = None
72
  self.set_client()
 
 
 
73
 
74
 
75
  def set_client(self):
@@ -86,6 +107,131 @@ class OCRGoogle:
86
  credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
87
  return credentials
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
91
  CONFIDENCES = 0.80
@@ -93,33 +239,36 @@ class OCRGoogle:
93
 
94
  self.OCR_JSON_to_file = {}
95
 
 
96
  if not do_use_trOCR:
97
- if self.OCR_option in ['normal',]:
98
  self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
99
  logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}")
100
- return f"Google_OCR_Standard:\n{self.normal_organized_text}"
 
101
 
102
- if self.OCR_option in ['hand',]:
103
  self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
104
  logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}")
105
- return f"Google_OCR_Handwriting:\n{self.hand_organized_text}"
106
-
107
- if self.OCR_option in ['both',]:
108
- logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}")
109
- return f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}"
110
 
 
 
 
 
111
  else:
112
  logger.info(f'Supplementing with trOCR')
113
 
114
  self.trOCR_texts = []
115
  original_image = Image.open(self.path).convert("RGB")
116
 
117
- if self.OCR_option in ['normal',]:
118
  available_bounds = self.normal_bounds_word
119
- elif self.OCR_option in ['hand',]:
120
- available_bounds = self.hand_bounds_word
121
- elif self.OCR_option in ['both',]:
122
  available_bounds = self.hand_bounds_word
 
 
123
  else:
124
  raise
125
 
@@ -127,9 +276,13 @@ class OCRGoogle:
127
  characters = []
128
  height = []
129
  confidences = []
 
 
130
  for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
 
 
 
131
  vertices = bound["vertices"]
132
-
133
 
134
  left = min([v["x"] for v in vertices])
135
  top = min([v["y"] for v in vertices])
@@ -177,24 +330,31 @@ class OCRGoogle:
177
  self.trOCR_confidences = confidences
178
  self.trOCR_characters = characters
179
 
180
- if self.OCR_option in ['normal',]:
181
  self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
182
  self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
183
  logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
184
- return f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
185
- if self.OCR_option in ['hand',]:
 
186
  self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
187
  self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
188
  logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
189
- return f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
190
- if self.OCR_option in ['both',]:
191
- self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
192
- self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
193
- self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
194
- logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
195
- return f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
196
- else:
197
- raise
 
 
 
 
 
 
198
 
199
  @staticmethod
200
  def confidence_to_color(confidence):
@@ -220,7 +380,7 @@ class OCRGoogle:
220
  if option == 'trOCR':
221
  color = (0, 170, 255)
222
  else:
223
- color = OCRGoogle.confidence_to_color(confidence)
224
  position = (bound["vertices"][0]["x"], bound["vertices"][0]["y"] - char_height)
225
  draw.text(position, character, fill=color, font=font)
226
 
@@ -258,13 +418,13 @@ class OCRGoogle:
258
  bound["vertices"][2]["x"], bound["vertices"][2]["y"],
259
  bound["vertices"][3]["x"], bound["vertices"][3]["y"],
260
  ],
261
- outline=OCRGoogle.BBOX_COLOR,
262
  width=line_width_thin
263
  )
264
 
265
  # Draw a line segment at the bottom of each handwritten character
266
  for bound, confidence in zip(bounds, confidences):
267
- color = OCRGoogle.confidence_to_color(confidence)
268
  # Use the bottom two vertices of the bounding box for the line
269
  bottom_left = (bound["vertices"][3]["x"], bound["vertices"][3]["y"] + line_width_thick)
270
  bottom_right = (bound["vertices"][2]["x"], bound["vertices"][2]["y"] + line_width_thick)
@@ -386,6 +546,7 @@ class OCRGoogle:
386
  self.normal_height = height_flat
387
  self.normal_confidences = confidences
388
  self.normal_characters = characters
 
389
 
390
 
391
  def detect_handwritten_ocr(self):
@@ -503,56 +664,112 @@ class OCRGoogle:
503
  self.hand_height = height_flat
504
  self.hand_confidences = confidences
505
  self.hand_characters = characters
 
506
 
507
 
508
  def process_image(self, do_create_OCR_helper_image, logger):
509
- if self.OCR_option in ['normal', 'both']:
510
- self.detect_text()
511
- if self.OCR_option in ['hand', 'both']:
512
- self.detect_handwritten_ocr()
513
- if self.OCR_option not in ['normal', 'hand', 'both']:
514
- self.OCR_option = 'both'
515
- self.detect_text()
516
- self.detect_handwritten_ocr()
517
-
518
- ### Optionally add trOCR to the self.OCR for additional context
519
- self.OCR = self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
520
- logger.info(f"OCR:\n{self.OCR}")
521
-
522
- if do_create_OCR_helper_image:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  self.image = Image.open(self.path)
524
 
525
- if self.OCR_option in ['normal', 'both']:
526
  image_with_boxes_normal = self.draw_boxes('normal')
527
  text_image_normal = self.render_text_on_black_image('normal')
528
  self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_normal)
529
 
530
- if self.OCR_option in ['hand', 'both']:
531
  image_with_boxes_hand = self.draw_boxes('hand')
532
  text_image_hand = self.render_text_on_black_image('hand')
533
  self.merged_image_hand = self.merge_images(image_with_boxes_hand, text_image_hand)
534
 
535
  if self.do_use_trOCR:
536
- text_image_trOCR = self.render_text_on_black_image('trOCR')
 
 
 
 
 
 
537
 
538
  ### Merge final overlay image
539
  ### [original, normal bboxes, normal text]
540
- if self.OCR_option in ['normal']:
541
  self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_normal)
542
  ### [original, hand bboxes, hand text]
543
- elif self.OCR_option in ['hand']:
544
  self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_hand)
545
  ### [original, normal bboxes, normal text, hand bboxes, hand text]
546
  else:
547
  self.overlay_image = self.merge_images(Image.open(self.path), self.merge_images(self.merged_image_normal, self.merged_image_hand))
548
 
549
  if self.do_use_trOCR:
550
- self.overlay_image = self.merge_images(self.overlay_image, text_image_trOCR)
 
 
 
 
 
 
 
551
 
552
  else:
553
  self.merged_image_normal = None
554
  self.merged_image_hand = None
555
  self.overlay_image = Image.open(self.path)
 
 
 
 
 
 
556
 
557
 
558
  '''
 
1
+ import os, io, sys, inspect, statistics, json, cv2
2
  from statistics import mean
3
  # from google.cloud import vision, storage
4
  from google.cloud import vision
 
8
  from tqdm import tqdm
9
  from google.oauth2 import service_account
10
 
11
+ ### LLaVA should only be installed if the user will actually use it.
12
+ ### It requires the most recent pytorch/Python and can mess with older systems
13
+ try:
14
+ from craft_text_detector import read_image, load_craftnet_model, load_refinenet_model, get_prediction, export_detected_regions, export_extra_results, empty_cuda_cache
15
+ except:
16
+ pass
17
+ try:
18
+ from OCR_llava import OCRllava
19
+ except:
20
+ pass
21
 
22
 
23
  '''
 
29
  archivePrefix={arXiv},
30
  primaryClass={cs.CL}
31
  }
32
+ @inproceedings{baek2019character,
33
+ title={Character Region Awareness for Text Detection},
34
+ author={Baek, Youngmin and Lee, Bado and Han, Dongyoon and Yun, Sangdoo and Lee, Hwalsuk},
35
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
36
+ pages={9365--9374},
37
+ year={2019}
38
+ }
39
  '''
40
 
41
+ class OCREngine:
42
 
43
  BBOX_COLOR = "black"
44
 
45
+ def __init__(self, logger, json_report, dir_home, is_hf, path, cfg, trOCR_model_version, trOCR_model, trOCR_processor, device):
46
  self.is_hf = is_hf
47
+ self.logger = logger
48
+
49
+ self.json_report = json_report
50
 
51
  self.path = path
52
  self.cfg = cfg
53
  self.do_use_trOCR = self.cfg['leafmachine']['project']['do_use_trOCR']
54
  self.OCR_option = self.cfg['leafmachine']['project']['OCR_option']
55
+ self.double_OCR = self.cfg['leafmachine']['project']['double_OCR']
56
+ self.dir_home = dir_home
57
 
58
  # Initialize TrOCR components
59
  self.trOCR_model_version = trOCR_model_version
 
88
  self.trOCR_confidences = None
89
  self.trOCR_characters = None
90
  self.set_client()
91
+ self.init_craft()
92
+ if 'LLaVA' in self.OCR_option:
93
+ self.init_llava()
94
 
95
 
96
  def set_client(self):
 
107
  credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
108
  return credentials
109
 
110
+ def init_craft(self):
111
+ if 'CRAFT' in self.OCR_option:
112
+ try:
113
+ self.refine_net = load_refinenet_model(cuda=True)
114
+ self.use_cuda = True
115
+ except:
116
+ self.refine_net = load_refinenet_model(cuda=False)
117
+ self.use_cuda = False
118
+
119
+ if self.use_cuda:
120
+ self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=True)
121
+ else:
122
+ self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=False)
123
+
124
+ def init_llava(self):
125
+
126
+ self.llava_prompt = """I need you to transcribe all of the text in this image.
127
+ Place the transcribed text into a JSON dictionary with this form {"Transcription_Printed_Text": "text","Transcription_Handwritten_Text": "text"}"""
128
+
129
+ self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
130
+ self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
131
+
132
+ self.json_report.set_text(text_main=f'Loading LLaVA model: {self.model_path} Quantization: {self.model_quant}')
133
+
134
+ if self.model_quant == '4bit':
135
+ use_4bit = True
136
+ elif self.model_quant == 'full':
137
+ use_4bit = False
138
+ else:
139
+ self.logger.info(f"Provided model quantization invlid. Using 4bit.")
140
+ use_4bit = True
141
+
142
+ self.Llava = OCRllava(self.logger, model_path=self.model_path, load_in_4bit=use_4bit, load_in_8bit=False)
143
+
144
+
145
+ def detect_text_craft(self):
146
+ # Perform prediction using CRAFT
147
+ image = read_image(self.path)
148
+
149
+ link_threshold = 0.85
150
+ text_threshold = 0.4
151
+ low_text = 0.4
152
+
153
+ if self.use_cuda:
154
+ self.prediction_result = get_prediction(
155
+ image=image,
156
+ craft_net=self.craft_net,
157
+ refine_net=self.refine_net,
158
+ text_threshold=text_threshold,
159
+ link_threshold=link_threshold,
160
+ low_text=low_text,
161
+ cuda=True,
162
+ long_size=1280
163
+ )
164
+ else:
165
+ self.prediction_result = get_prediction(
166
+ image=image,
167
+ craft_net=self.craft_net,
168
+ refine_net=self.refine_net,
169
+ text_threshold=text_threshold,
170
+ link_threshold=link_threshold,
171
+ low_text=low_text,
172
+ cuda=False,
173
+ long_size=1280
174
+ )
175
+
176
+ # Initialize metadata structures
177
+ bounds = []
178
+ bounds_word = [] # CRAFT gives bounds for text regions, not individual words
179
+ text_to_box_mapping = []
180
+ bounds_flat = []
181
+ height_flat = []
182
+ confidences = [] # CRAFT does not provide confidences per character, so this might be uniformly set or estimated
183
+ characters = [] # Simulating as CRAFT doesn't provide character-level details
184
+ organized_text = ""
185
+
186
+ total_b = len(self.prediction_result["boxes"])
187
+ i=0
188
+ # Process each detected text region
189
+ for box in self.prediction_result["boxes"]:
190
+ i+=1
191
+ self.json_report.set_text(text_main=f'Locating text using CRAFT --- {i}/{total_b}')
192
+
193
+ vertices = [{"x": int(vertex[0]), "y": int(vertex[1])} for vertex in box]
194
+
195
+ # Simulate a mapping for the whole detected region as a word
196
+ text_to_box_mapping.append({
197
+ "vertices": vertices,
198
+ "text": "detected_text" # Placeholder, as CRAFT does not provide the text content directly
199
+ })
200
+
201
+ # Assuming each box is a word for the sake of this example
202
+ bounds_word.append({"vertices": vertices})
203
+
204
+ # For simplicity, we're not dividing text regions into characters as CRAFT doesn't provide this
205
+ # Instead, we create a single large 'character' per detected region
206
+ bounds.append({"vertices": vertices})
207
+
208
+ # Simulate flat bounds and height for each detected region
209
+ x_positions = [vertex["x"] for vertex in vertices]
210
+ y_positions = [vertex["y"] for vertex in vertices]
211
+ min_x, max_x = min(x_positions), max(x_positions)
212
+ min_y, max_y = min(y_positions), max(y_positions)
213
+ avg_height = max_y - min_y
214
+ height_flat.append(avg_height)
215
+
216
+ # Assuming uniform confidence for all detected regions
217
+ confidences.append(1.0) # Placeholder confidence
218
+
219
+ # Adding dummy character for each box
220
+ characters.append("X") # Placeholder character
221
+
222
+ # Organize text as a single string (assuming each box is a word)
223
+ # organized_text += "detected_text " # Placeholder text
224
+
225
+ # Update class attributes with processed data
226
+ self.normal_bounds = bounds
227
+ self.normal_bounds_word = bounds_word
228
+ self.normal_text_to_box_mapping = text_to_box_mapping
229
+ self.normal_bounds_flat = bounds_flat # This would be similar to bounds if not processing characters individually
230
+ self.normal_height = height_flat
231
+ self.normal_confidences = confidences
232
+ self.normal_characters = characters
233
+ self.normal_organized_text = organized_text.strip()
234
+
235
 
236
  def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
237
  CONFIDENCES = 0.80
 
239
 
240
  self.OCR_JSON_to_file = {}
241
 
242
+ ocr_parts = ''
243
  if not do_use_trOCR:
244
+ if 'normal' in self.OCR_option:
245
  self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
246
  logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}")
247
+ # ocr_parts = ocr_parts + f"Google_OCR_Standard:\n{self.normal_organized_text}"
248
+ ocr_parts = self.normal_organized_text
249
 
250
+ if 'hand' in self.OCR_option:
251
  self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
252
  logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}")
253
+ # ocr_parts = ocr_parts + f"Google_OCR_Handwriting:\n{self.hand_organized_text}"
254
+ ocr_parts = self.hand_organized_text
 
 
 
255
 
256
+ # if self.OCR_option in ['both',]:
257
+ # logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}")
258
+ # return f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}"
259
+ return ocr_parts
260
  else:
261
  logger.info(f'Supplementing with trOCR')
262
 
263
  self.trOCR_texts = []
264
  original_image = Image.open(self.path).convert("RGB")
265
 
266
+ if 'normal' in self.OCR_option or 'CRAFT' in self.OCR_option:
267
  available_bounds = self.normal_bounds_word
268
+ elif 'hand' in self.OCR_option:
 
 
269
  available_bounds = self.hand_bounds_word
270
+ # elif self.OCR_option in ['both',]:
271
+ # available_bounds = self.hand_bounds_word
272
  else:
273
  raise
274
 
 
276
  characters = []
277
  height = []
278
  confidences = []
279
+ total_b = len(available_bounds)
280
+ i=0
281
  for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
282
+ i+=1
283
+ self.json_report.set_text(text_main=f'Working on trOCR :construction: {i}/{total_b}')
284
+
285
  vertices = bound["vertices"]
 
286
 
287
  left = min([v["x"] for v in vertices])
288
  top = min([v["y"] for v in vertices])
 
330
  self.trOCR_confidences = confidences
331
  self.trOCR_characters = characters
332
 
333
+ if 'normal' in self.OCR_option:
334
  self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
335
  self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
336
  logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
337
+ # ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
338
+ ocr_parts = self.trOCR_texts
339
+ if 'hand' in self.OCR_option:
340
  self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
341
  self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
342
  logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
343
+ # ocr_parts = ocr_parts + f"\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
344
+ ocr_parts = self.trOCR_texts
345
+ # if self.OCR_option in ['both',]:
346
+ # self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
347
+ # self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
348
+ # self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
349
+ # logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
350
+ # ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
351
+ if 'CRAFT' in self.OCR_option:
352
+ # self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
353
+ self.OCR_JSON_to_file['OCR_CRAFT_trOCR'] = self.trOCR_texts
354
+ logger.info(f"CRAFT_trOCR:\n{self.trOCR_texts}")
355
+ # ocr_parts = ocr_parts + f"\nCRAFT_trOCR:\n{self.trOCR_texts}"
356
+ ocr_parts = self.trOCR_texts
357
+ return ocr_parts
358
 
359
  @staticmethod
360
  def confidence_to_color(confidence):
 
380
  if option == 'trOCR':
381
  color = (0, 170, 255)
382
  else:
383
+ color = OCREngine.confidence_to_color(confidence)
384
  position = (bound["vertices"][0]["x"], bound["vertices"][0]["y"] - char_height)
385
  draw.text(position, character, fill=color, font=font)
386
 
 
418
  bound["vertices"][2]["x"], bound["vertices"][2]["y"],
419
  bound["vertices"][3]["x"], bound["vertices"][3]["y"],
420
  ],
421
+                    outline=OCREngine.BBOX_COLOR,
                     width=line_width_thin
                 )
 
         # Draw a line segment at the bottom of each handwritten character
         for bound, confidence in zip(bounds, confidences):
+            color = OCREngine.confidence_to_color(confidence)
             # Use the bottom two vertices of the bounding box for the line
             bottom_left = (bound["vertices"][3]["x"], bound["vertices"][3]["y"] + line_width_thick)
             bottom_right = (bound["vertices"][2]["x"], bound["vertices"][2]["y"] + line_width_thick)
 
         self.normal_height = height_flat
         self.normal_confidences = confidences
         self.normal_characters = characters
+        return self.normal_cleaned_text
 
 
     def detect_handwritten_ocr(self):
 
         self.hand_height = height_flat
         self.hand_confidences = confidences
         self.hand_characters = characters
+        return self.hand_cleaned_text
 
 
     def process_image(self, do_create_OCR_helper_image, logger):
+        # Options can be stacked, so use solitary if statements (not elif)
+        self.OCR = 'OCR:\n'
+        if 'CRAFT' in self.OCR_option:
+            self.do_use_trOCR = True
+            self.detect_text_craft()
+            ### Optionally add trOCR to the self.OCR for additional context
+            if self.double_OCR:
+                part_OCR = "\nCRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
+                self.OCR = self.OCR + part_OCR + part_OCR
+            else:
+                self.OCR = self.OCR + "\nCRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
+            logger.info(f"CRAFT trOCR:\n{self.OCR}")
+
+        if 'LLaVA' in self.OCR_option:  # This option does not produce an OCR helper image
+            self.json_report.set_text(text_main=f'Working on LLaVA {self.Llava.model_path} transcription :construction:')
+
+            image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.llava_prompt)
+            self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")
+
+            try:
+                self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
+            except:
+                self.OCR_JSON_to_file = {}
+                self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
+
+            if self.double_OCR:
+                self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}" + f"\nLLaVA OCR:\n{str_output}"
+            else:
+                self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}"
+            logger.info(f"LLaVA OCR:\n{self.OCR}")
+
+        if 'normal' in self.OCR_option or 'hand' in self.OCR_option:
+            if 'normal' in self.OCR_option:
+                self.OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
+            if 'hand' in self.OCR_option:
+                self.OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
+            # if self.OCR_option not in ['normal', 'hand', 'both']:
+            #     self.OCR_option = 'both'
+            #     self.detect_text()
+            #     self.detect_handwritten_ocr()
+
+            ### Optionally add trOCR to the self.OCR for additional context
+            if self.double_OCR:
+                part_OCR = "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
+                self.OCR = self.OCR + part_OCR + part_OCR
+            else:
+                self.OCR = self.OCR + "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
+            logger.info(f"OCR:\n{self.OCR}")
+
+        if do_create_OCR_helper_image and ('LLaVA' not in self.OCR_option):
             self.image = Image.open(self.path)
 
+            if 'normal' in self.OCR_option:
                 image_with_boxes_normal = self.draw_boxes('normal')
                 text_image_normal = self.render_text_on_black_image('normal')
                 self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_normal)
 
+            if 'hand' in self.OCR_option:
                 image_with_boxes_hand = self.draw_boxes('hand')
                 text_image_hand = self.render_text_on_black_image('hand')
                 self.merged_image_hand = self.merge_images(image_with_boxes_hand, text_image_hand)
 
             if self.do_use_trOCR:
+                text_image_trOCR = self.render_text_on_black_image('trOCR')
+
+                if 'CRAFT' in self.OCR_option:
+                    image_with_boxes_normal = self.draw_boxes('normal')
+                    self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_trOCR)
 
             ### Merge final overlay image
             ### [original, normal bboxes, normal text]
+            if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
                 self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_normal)
             ### [original, hand bboxes, hand text]
+            elif 'hand' in self.OCR_option:
                 self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_hand)
             ### [original, normal bboxes, normal text, hand bboxes, hand text]
             else:
                 self.overlay_image = self.merge_images(Image.open(self.path), self.merge_images(self.merged_image_normal, self.merged_image_hand))
 
             if self.do_use_trOCR:
+                if 'CRAFT' in self.OCR_option:
+                    heat_map_text = Image.fromarray(cv2.cvtColor(self.prediction_result["heatmaps"]["text_score_heatmap"], cv2.COLOR_BGR2RGB))
+                    heat_map_link = Image.fromarray(cv2.cvtColor(self.prediction_result["heatmaps"]["link_score_heatmap"], cv2.COLOR_BGR2RGB))
+                    self.overlay_image = self.merge_images(self.overlay_image, heat_map_text)
+                    self.overlay_image = self.merge_images(self.overlay_image, heat_map_link)
+
+                else:
+                    self.overlay_image = self.merge_images(self.overlay_image, text_image_trOCR)
 
         else:
             self.merged_image_normal = None
             self.merged_image_hand = None
             self.overlay_image = Image.open(self.path)
+
+        try:
+            empty_cuda_cache()
+        except:
+            pass
 
 
 '''
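
Note on the stacking design: process_image() treats OCR_option as a collection that can name several engines at once ('normal', 'hand', 'CRAFT', 'LLaVA'); each engine's text is appended to one running string, and double_OCR simply appends each block twice. A minimal, self-contained sketch of that accumulation pattern (the engine functions here are placeholders, not the real OCREngine methods):

# Sketch of the OCR_option stacking pattern; engine functions are stand-ins.
def run_printed_ocr():
    return "printed text from Google Vision"

def run_handwritten_ocr():
    return "handwritten text from Google Vision"

ENGINES = {
    'normal': ("Google Printed OCR", run_printed_ocr),
    'hand': ("Google Handwritten OCR", run_handwritten_ocr),
}

def stack_ocr(OCR_option, double_OCR=False):
    OCR = 'OCR:\n'
    for key, (label, engine) in ENGINES.items():
        if key in OCR_option:  # solitary if per engine, so options stack
            part = f"\n{label}:\n{engine()}"
            OCR += part + part if double_OCR else part
    return OCR

print(stack_ocr(['normal', 'hand'], double_OCR=True))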
vouchervision/OCR_llava.py ADDED
@@ -0,0 +1,324 @@
+import os, re, logging
+import requests
+from PIL import Image
+from io import BytesIO
+import torch
+from transformers import AutoTokenizer, BitsAndBytesConfig, TextStreamer
+
+from langchain.prompts import PromptTemplate
+from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+from LLaVA.llava.model import LlavaLlamaForCausalLM
+from LLaVA.llava.model.builder import load_pretrained_model
+from LLaVA.llava.conversation import conv_templates, SeparatorStyle
+from LLaVA.llava.utils import disable_torch_init
+from LLaVA.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_PLACEHOLDER
+from LLaVA.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images
+
+from utils_LLM import SystemLoadMonitor
+
+'''
+Performance expectations system:
+    GPUs:
+        2x RTX 6000 Ada
+    CPU:
+        AMD Ryzen Threadripper PRO 5975WX, 32 cores / 64 threads
+    RAM:
+        512 GB
+    OS:
+        Ubuntu 22.04.3 LTS
+
+LLaVA Models:
+    "liuhaotian/llava-v1.6-mistral-7b" --- Model is 20 GB in size --- Mistral-7B
+        --- Full
+            --- Inference time ~6 sec
+            --- VRAM ~18 GB
+
+        --- 8bit (don't use; author says there is a problem right now, 2024-02-08; anecdotally worse results too)
+            --- Inference time ~37 sec
+            --- VRAM ~18 GB
+
+        --- 4bit
+            --- Inference time ~15 sec
+            --- VRAM ~9 GB
+
+    "liuhaotian/llava-v1.6-34b" --- Model is 100 GB in size --- Hermes-Yi-34B
+        --- Full
+            --- Inference time ~21 sec
+            --- VRAM ~70 GB
+
+        --- 8bit (don't use; author says there is a problem right now, 2024-02-08; anecdotally worse results too)
+            --- Inference time ~52 sec
+            --- VRAM ~42 GB
+
+        --- 4bit
+            --- Inference time ~23 sec
+            --- VRAM ~25 GB
+
+    "liuhaotian/llava-v1.6-vicuna-13b" --- Model is 30 GB in size --- Vicuna-13B
+        --- Full
+            --- Inference time ~8 sec
+            --- VRAM ~33 GB
+
+        --- 8bit (don't use; author says there is a problem right now, 2024-02-08; anecdotally worse results too, has lots of ALL CAPS and mistakes)
+            --- Inference time ~32 sec
+            --- VRAM ~23 GB
+
+        --- 4bit
+            --- Inference time ~12 sec
+            --- VRAM ~15 GB
+
+    "liuhaotian/llava-v1.6-vicuna-7b" --- Model is 15 GB in size --- Vicuna-7B
+        --- Full
+            --- Inference time ~7 sec
+            --- VRAM ~20 GB
+
+        --- 8bit (don't use; author says there is a problem right now, 2024-02-08; anecdotally worse results too)
+            --- Inference time ~27 sec
+            --- VRAM ~14 GB
+
+        --- 4bit
+            --- Inference time ~10 sec
+            --- VRAM ~10 GB
+'''
+
+# OCR_Llava = OCRllava()
+# image, caption = OCR_Llava.transcribe_image("path/to/image.jpg", "Describe this image.")
+# print(caption)
+
+# Define the desired data structure for the transcription.
+class Transcription(BaseModel):
+    Transcription: str = Field(description="The transcription of all text in the image.")
+
+class OCRllava:
+    def __init__(self, logger, model_path="liuhaotian/llava-v1.6-34b", load_in_4bit=False, load_in_8bit=False):
+        self.monitor = SystemLoadMonitor(logger)
+
+        # self.model_path = "liuhaotian/llava-v1.6-mistral-7b"
+        # self.model_path = "liuhaotian/llava-v1.6-34b"
+        # self.model_path = "liuhaotian/llava-v1.6-vicuna-13b"
+
+        self.model_path = model_path
+
+        # kwargs = {"device_map": "auto", "load_in_4bit": load_in_4bit, "quantization_config": BitsAndBytesConfig(
+        #     load_in_4bit=load_in_4bit,
+        #     bnb_4bit_compute_dtype=torch.float16,
+        #     bnb_4bit_use_double_quant=load_in_4bit,
+        #     bnb_4bit_quant_type='nf4'
+        # )}
+
+        if "llama-2" in self.model_path.lower():  # this is borrowed from def eval_model(args): in run_llava.py
+            self.conv_mode = "llava_llama_2"
+        elif "mistral" in self.model_path.lower():
+            self.conv_mode = "mistral_instruct"
+        elif "v1.6-34b" in self.model_path.lower():
+            self.conv_mode = "chatml_direct"
+        elif "v1" in self.model_path.lower():
+            self.conv_mode = "llava_v1"
+        elif "mpt" in self.model_path.lower():
+            self.conv_mode = "mpt"
+        else:
+            self.conv_mode = "llava_v0"
+
+        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(self.model_path, None,
+                                                                                                   model_name=get_model_name_from_path(self.model_path),
+                                                                                                   load_8bit=load_in_8bit, load_4bit=load_in_4bit)
+
+        # self.model = LlavaLlamaForCausalLM.from_pretrained(self.model_path, low_cpu_mem_usage=True, **kwargs)
+        # self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
+        # self.vision_tower = self.model.get_vision_tower()
+        # if not self.vision_tower.is_loaded:
+        #     self.vision_tower.load_model()
+        # self.vision_tower.to(device='cuda')
+        # self.image_processor = self.vision_tower.image_processor
+        self.parser = JsonOutputParser(pydantic_object=Transcription)
+
+    def image_parser(self):
+        sep = ","
+        out = self.image_file.split(sep)
+        return out
+
+    def load_image(self, image_file):
+        if image_file.startswith("http") or image_file.startswith("https"):
+            response = requests.get(image_file)
+            image = Image.open(BytesIO(response.content)).convert("RGB")
+        else:
+            image = Image.open(image_file).convert("RGB")
+        return image
+
+    def load_images(self, image_files):
+        out = []
+        for image_file in image_files:
+            image = self.load_image(image_file)
+            out.append(image)
+        return out
+
+    def combine_json_values(self, data, separator=" "):
+        """
+        Recursively traverses a JSON-like dictionary or list,
+        combining all the values into a single string with a given separator.
+
+        :return: A single string containing all values from the input.
+        """
+        # Base case for strings: return the string directly
+        if isinstance(data, str):
+            return data
+
+        # If the data is a dictionary, iterate through its values
+        elif isinstance(data, dict):
+            combined_string = separator.join(self.combine_json_values(v, separator) for v in data.values())
+
+        # If the data is a list, iterate through its elements
+        elif isinstance(data, list):
+            combined_string = separator.join(self.combine_json_values(item, separator) for item in data)
+
+        # For other data types (e.g., numbers), convert to string directly
+        else:
+            combined_string = str(data)
+
+        return combined_string
+
+    def transcribe_image(self, image_file, prompt, max_new_tokens=512, temperature=0.1, top_p=None, num_beams=1):
+        self.monitor.start_monitoring_usage()
+
+        self.image_file = image_file
+        if image_file.startswith('http') or image_file.startswith('https'):
+            response = requests.get(image_file)
+            image = Image.open(BytesIO(response.content)).convert('RGB')
+        else:
+            image = Image.open(image_file).convert('RGB')
+        disable_torch_init()
+
+        qs = prompt
+        image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+        if IMAGE_PLACEHOLDER in qs:
+            if self.model.config.mm_use_im_start_end:
+                qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
+            else:
+                qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
+        else:
+            if self.model.config.mm_use_im_start_end:
+                qs = image_token_se + "\n" + qs
+            else:
+                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+        conv = conv_templates[self.conv_mode].copy()
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        image_files = self.image_parser()
+        images = self.load_images(image_files)
+        image_sizes = [x.size for x in images]
+        images_tensor = process_images(
+            images,
+            self.image_processor,
+            self.model.config
+        ).to(self.model.device, dtype=torch.float16)
+
+        input_ids = (
+            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+            .unsqueeze(0)
+            .cuda()
+        )
+
+        with torch.inference_mode():
+            output_ids = self.model.generate(
+                input_ids,
+                images=images_tensor,
+                image_sizes=image_sizes,
+                do_sample=True if temperature > 0 else False,
+                temperature=temperature,
+                # top_p=top_p,
+                num_beams=num_beams,
+                max_new_tokens=max_new_tokens,
+                use_cache=True,
+            )
+
+        direct_output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+
+        # Parse the output to JSON format using the specified schema.
+        try:
+            json_output = self.parser.parse(direct_output)
+        except:
+            json_output = direct_output
+
+        try:
+            str_output = self.combine_json_values(json_output)
+        except:
+            str_output = direct_output
+
+        self.monitor.stop_inference_timer()  # Starts tool timer too
+        usage_report = self.monitor.stop_monitoring_report_usage()
+
+        return image, json_output, direct_output, str_output, usage_report
+
+
+PROMPT_OCR = """I need you to transcribe all of the text in this image. Place the transcribed text into a JSON dictionary with this form {"Transcription": "text"}"""
+PROMPT_ALL = """1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
+2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
+3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
+4. Duplicate dictionary fields are not allowed.
+5. Ensure all JSON keys are in camel case.
+6. Ensure new JSON field values follow sentence case capitalization.
+7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
+8. Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
+9. Only return a JSON dictionary represented as a string. You should not explain your answer.
+This section provides rules for formatting each JSON value organized by the JSON key.
+{catalogNumber Barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits., order The full scientific name of the order in which the taxon is classified. Order must be capitalized., family The full scientific name of the family in which the taxon is classified. Family must be capitalized., scientificName The scientific name of the taxon including genus, specific epithet, and any lower classifications., scientificNameAuthorship The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode., genus Taxonomic determination to genus. Genus must be capitalized. If genus is not present use the taxonomic family name followed by the word 'indet'., subgenus The full scientific name of the subgenus in which the taxon is classified. Values should include the genus to avoid homonym confusion., specificEpithet The name of the first or species epithet of the scientificName. Only include the species epithet., infraspecificEpithet The name of the lowest or terminal infraspecific epithet of the scientificName, excluding any rank designation., identifiedBy A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector., recordedBy A comma separated list of names of people, groups, or organizations responsible for observing, recording, collecting, or presenting the original specimen. The primary collector or observer should be listed first., recordNumber An identifier given to the occurrence at the time it was recorded. Often serves as a link between field notes and an occurrence record, such as a specimen collector's number., verbatimEventDate The verbatim original representation of the date and time information for when the specimen was collected. Date of collection exactly as it appears on the label. Do not change the format or correct typos., eventDate Date the specimen was collected formatted as year-month-day, YYYY-MM-DD. If specific components of the date are unknown, they should be replaced with zeros. Examples: '0000-00-00' if the entire date is unknown, 'YYYY-00-00' if only the year is known, and 'YYYY-MM-00' if year and month are known but day is not., habitat A category or description of the habitat in which the specimen collection event occurred., occurrenceRemarks Text describing the specimen's geographic location. Text describing the appearance of the specimen. A statement about the presence or absence of a taxon at the collection location. Text describing the significance of the specimen, such as a specific expedition or notable collection. Description of plant features such as leaf shape, size, color, stem texture, height, flower structure, scent, fruit or seed characteristics, root system type, overall growth habit and form, any notable aroma or secretions, presence of hairs or bristles, and any other distinguishing morphological or physiological characteristics., country The name of the country or major administrative unit in which the specimen was originally collected., stateProvince The name of the next smaller administrative region than country (state, province, canton, department, region, etc.) in which the specimen was originally collected., county The full, unabbreviated name of the next smaller administrative region than stateProvince (county, shire, department, parish etc.) in which the specimen was originally collected., municipality The full, unabbreviated name of the next smaller administrative region than county (city, municipality, etc.) in which the specimen was originally collected., locality Description of geographic location, landscape, landmarks, regional features, nearby places, or any contextual information aiding in pinpointing the exact origin or location of the specimen., degreeOfEstablishment Cultivated plants are intentionally grown by humans. In text descriptions, look for planting dates, garden locations, ornamental, cultivar names, garden, or farm to indicate cultivated plant. Use either - unknown or cultivated., decimalLatitude Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format., decimalLongitude Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format., verbatimCoordinates Verbatim location coordinates as they appear on the label. Do not convert formats. Possible coordinate types include [Lat, Long, UTM, TRS]., minimumElevationInMeters Minimum elevation or altitude in meters. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer., maximumElevationInMeters Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer.}
+Please populate the following JSON dictionary based on the rules and the unformatted OCR text
+{
+catalogNumber ,
+order ,
+family ,
+scientificName ,
+scientificNameAuthorship ,
+genus ,
+subgenus ,
+specificEpithet ,
+infraspecificEpithet ,
+identifiedBy ,
+recordedBy ,
+recordNumber ,
+verbatimEventDate ,
+eventDate ,
+habitat ,
+occurrenceRemarks ,
+country ,
+stateProvince ,
+county ,
+municipality ,
+locality ,
+degreeOfEstablishment ,
+decimalLatitude ,
+decimalLongitude ,
+verbatimCoordinates ,
+minimumElevationInMeters ,
+maximumElevationInMeters
+}
+"""
+if __name__ == '__main__':
+    logger = logging.getLogger('LLaVA')
+    logger.setLevel(logging.DEBUG)
+
+    OCR_Llava = OCRllava(logger)
+    image, json_output, direct_output, str_output, usage_report = OCR_Llava.transcribe_image("/home/brlab/Dropbox/VoucherVision/demo/demo_images/MICH_16205594_Poaceae_Jouvea_pilosa.jpg",
+                                                                                             PROMPT_OCR)
+    print('json_output')
+    print(json_output)
+    print('direct_output')
+    print(direct_output)
+    print('str_output')
+    print(str_output)
+    print('usage_report')
+    print(usage_report)
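
Note: transcribe_image() hands whatever the JsonOutputParser returns to combine_json_values(), so the flattening is worth seeing in isolation. A self-contained version of the same recursion, extracted from the class for illustration:

# Standalone sketch of combine_json_values(); mirrors the method above.
def combine_json_values(data, separator=" "):
    if isinstance(data, str):
        return data                      # base case: strings pass through
    if isinstance(data, dict):
        return separator.join(combine_json_values(v, separator) for v in data.values())
    if isinstance(data, list):
        return separator.join(combine_json_values(item, separator) for item in data)
    return str(data)                     # numbers and other scalars

# A parsed LLaVA response collapses to a flat string for the downstream OCR field:
print(combine_json_values({"Transcription": {"header": "MICH", "lines": ["Jouvea", "pilosa"]}}))
# -> MICH Jouvea pilosa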
vouchervision/VoucherVision_Config_Builder.py CHANGED
@@ -37,6 +37,10 @@ def build_VV_config(loaded_cfg=None):
 
     do_use_trOCR = False
     OCR_option = 'hand'
+    OCR_option_llava = 'llava-v1.6-mistral-7b' # "llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",
+    OCR_option_llava_bit = 'full' # full or 4bit
+    double_OCR = False
+
     check_for_illegal_filenames = False
 
     LLM_version_user = 'Azure GPT 3.5 Instruct' #'Azure GPT 4 Turbo 1106-preview'
@@ -47,6 +51,9 @@ def build_VV_config(loaded_cfg=None):
     batch_size = 500
     num_workers = 8
 
+    skip_vertical = False
+    pdf_conversion_dpi = 100
+
     path_domain_knowledge = os.path.join(dir_home,'domain_knowledge','SLTP_UM_AllAsiaMinimalInRegion.xlsx')
     embeddings_database_name = os.path.splitext(os.path.basename(path_domain_knowledge))[0]
 
@@ -58,8 +65,8 @@ def build_VV_config(loaded_cfg=None):
     return assemble_config(dir_home, run_name, dir_images_local,dir_output,
         prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
         path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
-        prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, save_cropped_annotations,
-        check_for_illegal_filenames, use_domain_knowledge=False)
+        prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, OCR_option_llava, OCR_option_llava_bit, double_OCR, save_cropped_annotations,
+        check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
     else:
         dir_home = os.path.dirname(os.path.dirname(__file__))
         run_name = loaded_cfg['leafmachine']['project']['run_name']
@@ -74,6 +81,11 @@ def build_VV_config(loaded_cfg=None):
 
     do_use_trOCR = loaded_cfg['leafmachine']['project']['do_use_trOCR']
     OCR_option = loaded_cfg['leafmachine']['project']['OCR_option']
+    OCR_option_llava = loaded_cfg['leafmachine']['project']['OCR_option_llava']
+    OCR_option_llava_bit = loaded_cfg['leafmachine']['project']['OCR_option_llava_bit']
+    double_OCR = loaded_cfg['leafmachine']['project']['double_OCR']
+
+    pdf_conversion_dpi = loaded_cfg['leafmachine']['project']['pdf_conversion_dpi']
 
     LLM_version_user = loaded_cfg['leafmachine']['LLM_version']
     prompt_version = loaded_cfg['leafmachine']['project']['prompt_version']
@@ -88,19 +100,20 @@ def build_VV_config(loaded_cfg=None):
 
     save_cropped_annotations = loaded_cfg['leafmachine']['cropped_components']['save_cropped_annotations']
     check_for_illegal_filenames = loaded_cfg['leafmachine']['do']['check_for_illegal_filenames']
+    skip_vertical = loaded_cfg['leafmachine']['do']['skip_vertical']
 
     return assemble_config(dir_home, run_name, dir_images_local,dir_output,
         prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
         path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
-        prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, save_cropped_annotations,
-        check_for_illegal_filenames, use_domain_knowledge=False)
+        prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, OCR_option_llava, OCR_option_llava_bit, double_OCR, save_cropped_annotations,
+        check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
 
 
 def assemble_config(dir_home, run_name, dir_images_local,dir_output,
     prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
     path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
-    prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, OCR_option, save_cropped_annotations,
-    check_for_illegal_filenames, use_domain_knowledge=False):
+    prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, OCR_option, OCR_option_llava, OCR_option_llava_bit, double_OCR, save_cropped_annotations,
+    check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False):
 
 
     # Initialize the base structure
@@ -112,6 +125,7 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
     do_section = {
         'check_for_illegal_filenames': check_for_illegal_filenames,
         'check_for_corrupt_images_make_vertical': True,
+        'skip_vertical': skip_vertical,
     }
 
     print_section = {
@@ -144,6 +158,10 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
         'delete_temps_keep_VVE': False,
         'do_use_trOCR': do_use_trOCR,
        'OCR_option': OCR_option,
+        'OCR_option_llava': OCR_option_llava,
+        'OCR_option_llava_bit': OCR_option_llava_bit,
+        'double_OCR': double_OCR,
+        'pdf_conversion_dpi': pdf_conversion_dpi,
     }
 
     modules_section = {
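
Note: the new options all flow through assemble_config() into the nested config dict. A sketch of where they land, using the defaults from build_VV_config() above (a subset of the full structure):

# Where the new keys sit in the assembled config (values are the defaults above).
project_section = {
    'do_use_trOCR': False,
    'OCR_option': 'hand',
    'OCR_option_llava': 'llava-v1.6-mistral-7b',
    'OCR_option_llava_bit': 'full',   # 'full' or '4bit'
    'double_OCR': False,
    'pdf_conversion_dpi': 100,
}
do_section = {
    'check_for_illegal_filenames': False,
    'check_for_corrupt_images_make_vertical': True,
    'skip_vertical': False,
}
cfg = {'leafmachine': {'project': project_section, 'do': do_section}}
print(cfg['leafmachine']['project']['OCR_option_llava'])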
vouchervision/data_project.py CHANGED
@@ -12,6 +12,19 @@ from vouchervision.download_from_GBIF_all_images_in_file import download_all_ima
 from PIL import Image
 from tqdm import tqdm
 from pathlib import Path
+import fitz
+
+def convert_pdf_to_jpg(source_pdf, destination_dir, dpi=100):
+    doc = fitz.open(source_pdf)
+    for page_num in range(len(doc)):
+        page = doc.load_page(page_num)  # Load the current page
+        pix = page.get_pixmap(dpi=dpi)  # Render page to an image
+        output_filename = f"{os.path.splitext(os.path.basename(source_pdf))[0]}__{10000 + page_num + 1}.jpg"
+        output_filepath = os.path.join(destination_dir, output_filename)
+        pix.save(output_filepath)  # Save the image
+    length_doc = len(doc)
+    doc.close()
+    return length_doc
 
 @dataclass
 class Project_Info():
@@ -39,6 +52,7 @@ class Project_Info():
         self.Dirs = Dirs
         logger.name = 'Project Info'
         logger.info("Gathering Images and Image Metadata")
+        self.logger = logger
 
         self.batch_size = cfg['leafmachine']['project']['batch_size']
 
@@ -90,15 +104,28 @@ class Project_Info():
     def remove_non_numbers(self, s):
         return ''.join([char for char in s if char.isdigit()])
 
+    # def copy_images_to_project_dir(self, dir_images, Dirs):
+    #     n_total = len(os.listdir(dir_images))
+    #     for file in tqdm(os.listdir(dir_images), desc=f'{bcolors.HEADER} Copying images to working directory{bcolors.ENDC}',colour="white",position=0,total = n_total):
+    #         # Copy og image to new dir
+    #         # Copied image will be used for all downstream applications
+    #         source = os.path.join(dir_images, file)
+    #         destination = os.path.join(Dirs.save_original, file)
+    #         shutil.copy(source, destination)
     def copy_images_to_project_dir(self, dir_images, Dirs):
         n_total = len(os.listdir(dir_images))
-        for file in tqdm(os.listdir(dir_images), desc=f'{bcolors.HEADER} Copying images to working directory{bcolors.ENDC}',colour="white",position=0,total = n_total):
-            # Copy og image to new dir
-            # Copied image will be used for all downstream applications
+        for file in tqdm(os.listdir(dir_images), desc=f'{bcolors.HEADER} Copying images to working directory{bcolors.ENDC}', colour="white", position=0, total=n_total):
             source = os.path.join(dir_images, file)
-            destination = os.path.join(Dirs.save_original, file)
-            shutil.copy(source, destination)
-
+            # Check if file is a PDF
+            if file.lower().endswith('.pdf'):
+                # Convert PDF pages to JPG images
+                n_pages = convert_pdf_to_jpg(source, Dirs.save_original)
+                self.logger.info(f"Converted {n_pages} pages to JPG from PDF: {file}")
+            else:
+                # Copy non-PDF files directly
+                destination = os.path.join(Dirs.save_original, file)
+                shutil.copy(source, destination)
+
     def make_file_names_custom(self, dir_images, cfg, Dirs):
         n_total = len(os.listdir(dir_images))
         for file in tqdm(os.listdir(dir_images), desc=f'{bcolors.HEADER} Creating Catalog Number from file name{bcolors.ENDC}',colour="green",position=0,total = n_total):
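
Note: the 10000 + page_num + 1 offset gives fixed-width page suffixes (__10001, __10002, ...) so converted pages sort correctly next to single-image files. A hypothetical call to the new helper (the paths here are placeholders, and the import assumes the module path above):

# Hypothetical usage of convert_pdf_to_jpg(); requires PyMuPDF (import fitz).
from vouchervision.data_project import convert_pdf_to_jpg

n_pages = convert_pdf_to_jpg("scans/specimen_batch.pdf", "uploads", dpi=100)
print(f"Wrote {n_pages} JPG pages")  # e.g., specimen_batch__10001.jpg, ...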
vouchervision/general_utils.py CHANGED
@@ -437,6 +437,7 @@ def split_into_batches(Project, logger, cfg):
     return Project, n_batches, m
 
 def make_images_in_dir_vertical(dir_images_unprocessed, cfg):
+    skip_vertical = cfg['leafmachine']['do']['skip_vertical']
    if cfg['leafmachine']['do']['check_for_corrupt_images_make_vertical']:
         n_rotate = 0
         n_corrupt = 0
@@ -445,10 +446,11 @@ def make_images_in_dir_vertical(dir_images_unprocessed, cfg):
             if image_name_jpg.endswith((".jpg",".JPG",".jpeg",".JPEG")):
                 try:
                     image = cv2.imread(os.path.join(dir_images_unprocessed, image_name_jpg))
-                    h, w, img_c = image.shape
-                    image, img_h, img_w, did_rotate = make_image_vertical(image, h, w, do_rotate_180=False)
-                    if did_rotate:
-                        n_rotate += 1
+                    if not skip_vertical:
+                        h, w, img_c = image.shape
+                        image, img_h, img_w, did_rotate = make_image_vertical(image, h, w, do_rotate_180=False)
+                        if did_rotate:
+                            n_rotate += 1
                     cv2.imwrite(os.path.join(dir_images_unprocessed,image_name_jpg), image)
                 except:
                     n_corrupt +=1
@@ -457,10 +459,11 @@ def make_images_in_dir_vertical(dir_images_unprocessed, cfg):
             elif image_name_jpg.endswith((".tiff",".tif",".png",".PNG",".TIFF",".TIF",".jp2",".JP2",".bmp",".BMP",".dib",".DIB")):
                 try:
                     image = cv2.imread(os.path.join(dir_images_unprocessed, image_name_jpg))
-                    h, w, img_c = image.shape
-                    image, img_h, img_w, did_rotate = make_image_vertical(image, h, w, do_rotate_180=False)
-                    if did_rotate:
-                        n_rotate += 1
+                    if not skip_vertical:
+                        h, w, img_c = image.shape
+                        image, img_h, img_w, did_rotate = make_image_vertical(image, h, w, do_rotate_180=False)
+                        if did_rotate:
+                            n_rotate += 1
                     image_name_jpg = '.'.join([image_name_jpg.split('.')[0], 'jpg'])
                     cv2.imwrite(os.path.join(dir_images_unprocessed,image_name_jpg), image)
                 except:
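
Note: with skip_vertical set, make_images_in_dir_vertical() skips the portrait rotation but still re-reads and re-writes every file, so the corrupt-image check is preserved. A minimal sketch of the gate (make_image_vertical here is a stand-in for the real helper):

import numpy as np

def make_image_vertical(image, h, w, do_rotate_180=False):
    # Stand-in: rotate 90 degrees when the image is wider than tall.
    if w > h:
        return np.rot90(image), w, h, True
    return image, h, w, False

def process(image, skip_vertical):
    did_rotate = False
    if not skip_vertical:
        h, w = image.shape[:2]
        image, h, w, did_rotate = make_image_vertical(image, h, w)
    return image, did_rotate

img = np.zeros((100, 200, 3), dtype=np.uint8)  # landscape dummy image
print(process(img, skip_vertical=False)[1])    # True  (rotated)
print(process(img, skip_vertical=True)[1])     # False (left as-is)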
vouchervision/llava_test.py ADDED
@@ -0,0 +1,34 @@
+from LLaVA.llava.model.builder import load_pretrained_model
+from LLaVA.llava.mm_utils import get_model_name_from_path
+from LLaVA.llava.eval.run_llava import eval_model
+
+# model_path = "liuhaotian/llava-v1.5-7b"
+
+# tokenizer, model, image_processor, context_len = load_pretrained_model(
+#     model_path=model_path,
+#     model_base=None,
+#     model_name=get_model_name_from_path(model_path)
+# )
+
+# model_path = "liuhaotian/llava-v1.5-7b"
+# model_path = "liuhaotian/llava-v1.6-mistral-7b"
+model_path = "liuhaotian/llava-v1.6-34b"
+prompt = """I need you to transcribe all of the text in this image. Place the transcribed text into a JSON dictionary with this form {"Transcription": "text"}"""
+# image_file = "https://llava-vl.github.io/static/images/view.jpg"
+image_file = "/home/brlab/Dropbox/VoucherVision/demo/demo_images/MICH_16205594_Poaceae_Jouvea_pilosa.jpg"
+args = type('Args', (), {
+    "model_path": model_path,
+    "model_base": None,
+    "model_name": get_model_name_from_path(model_path),
+    "query": prompt,
+    "conv_mode": None,
+    "image_file": image_file,
+    "sep": ",",
+    "temperature": 0,
+    "top_p": None,
+    "num_beams": 1,
+    "max_new_tokens": 512,
+    # "load_8_bit": True,
+})()
+
+eval_model(args)
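
Note: the args object above is built with type('Args', (), {...})(), a throwaway class whose attributes mimic the argparse.Namespace that eval_model() expects. types.SimpleNamespace does the same job more readably; a sketch with illustrative values:

from types import SimpleNamespace

args = SimpleNamespace(
    model_path="liuhaotian/llava-v1.6-34b",
    model_base=None,
    model_name="llava-v1.6-34b",  # normally get_model_name_from_path(model_path)
    query="Transcribe all of the text in this image.",
    conv_mode=None,
    image_file="demo.jpg",        # placeholder path
    sep=",",
    temperature=0,
    top_p=None,
    num_beams=1,
    max_new_tokens=512,
)
print(args.model_path, args.max_new_tokens)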
vouchervision/utils_LLM.py CHANGED
@@ -49,7 +49,7 @@ class SystemLoadMonitor():
     def __init__(self, logger) -> None:
         self.monitoring_thread = None
         self.logger = logger
-        self.gpu_usage = {'max_cpu_usage': 0, 'max_load': 0, 'max_vram_usage': 0, "max_ram_usage": 0, 'monitoring': True}
+        self.gpu_usage = {'max_cpu_usage': 0, 'max_load': 0, 'max_vram_usage': 0, "max_ram_usage": 0, 'n_gpus': 0, 'monitoring': True}
         self.start_time = None
         self.tool_start_time = None
         self.has_GPU = torch.cuda.is_available()
@@ -71,11 +71,17 @@ class SystemLoadMonitor():
             # GPU monitoring
             if self.has_GPU:
                 GPUs = GPUtil.getGPUs()
+                self.gpu_usage['n_gpus'] = len(GPUs)  # Count the number of GPUs
+                total_load = 0
+                total_memory_usage_gb = 0
                 for gpu in GPUs:
-                    self.gpu_usage['max_load'] = max(self.gpu_usage['max_load'], gpu.load)
-                    # Convert memory usage to GB
-                    memory_usage_gb = gpu.memoryUsed / 1024.0
-                    self.gpu_usage['max_vram_usage'] = max(self.gpu_usage.get('max_vram_usage', 0), memory_usage_gb)
+                    total_load += gpu.load
+                    total_memory_usage_gb += gpu.memoryUsed / 1024.0
+
+                if self.gpu_usage['n_gpus'] > 0:  # Avoid division by zero
+                    # Track the average load and the total memory usage across all GPUs
+                    self.gpu_usage['max_load'] = max(self.gpu_usage['max_load'], total_load / self.gpu_usage['n_gpus'])
+                    self.gpu_usage['max_vram_usage'] = max(self.gpu_usage['max_vram_usage'], total_memory_usage_gb)
 
             # RAM monitoring
             ram_usage = psutil.virtual_memory().used / (1024.0 ** 3)  # Get RAM usage in GB
@@ -94,46 +100,91 @@ class SystemLoadMonitor():
             return datetime_iso
 
     def stop_monitoring_report_usage(self):
-        report = {}
-
         self.gpu_usage['monitoring'] = False
         self.monitoring_thread.join()
-        # Calculate tool time by checking if tool_start_time is set
-        if self.tool_start_time:
-            tool_time = time.time() - self.tool_start_time
-        else:
-            tool_time = 0
-
-        report = {'inference_time_s': str(round(self.inference_time,2)),
-                  'tool_time_s': str(round(tool_time, 2)),
-                  'max_cpu': str(round(self.gpu_usage['max_cpu_usage'],2)),
-                  'max_ram_gb': str(round(self.gpu_usage['max_ram_usage'],2)),
-                  'current_time': self.get_current_datetime(),
-                  }
+        tool_time = time.time() - self.tool_start_time if self.tool_start_time else 0
+
+        num_gpus, gpu_dict, total_vram_gb, capability_score = check_system_gpus()
+
+        report = {
+            'inference_time_s': str(round(self.inference_time, 2)),
+            'tool_time_s': str(round(tool_time, 2)),
+            'max_cpu': str(round(self.gpu_usage['max_cpu_usage'], 2)),
+            'max_ram_gb': str(round(self.gpu_usage['max_ram_usage'], 2)),
+            'current_time': self.get_current_datetime(),
+            'n_gpus': self.gpu_usage['n_gpus'],
+            'total_gpu_vram_gb': total_vram_gb,
+            'capability_score': capability_score,
+        }
+
         self.logger.info(f"Inference Time: {round(self.inference_time,2)} seconds")
         self.logger.info(f"Tool Time: {round(tool_time,2)} seconds")
-
         self.logger.info(f"Max CPU Usage: {round(self.gpu_usage['max_cpu_usage'],2)}%")
         self.logger.info(f"Max RAM Usage: {round(self.gpu_usage['max_ram_usage'],2)}GB")
-
         if self.has_GPU:
-            report.update({'max_gpu_load': str(round(self.gpu_usage['max_load']*100,2))})
-            report.update({'max_gpu_vram_gb': str(round(self.gpu_usage['max_vram_usage'],2))})
-
-            self.logger.info(f"Max GPU Load: {round(self.gpu_usage['max_load']*100,2)}%")
-            self.logger.info(f"Max GPU Memory Usage: {round(self.gpu_usage['max_vram_usage'],2)}GB")
+            report.update({'max_gpu_load': str(round(self.gpu_usage['max_load'] * 100, 2))})
+            report.update({'max_gpu_vram_gb': str(round(self.gpu_usage['max_vram_usage'], 2))})
+            self.logger.info(f"Max GPU Load: {round(self.gpu_usage['max_load'] * 100, 2)}%")
+            self.logger.info(f"Max GPU Memory Usage: {round(self.gpu_usage['max_vram_usage'], 2)}GB")
         else:
-            report.update({'max_gpu_load': str(0)})
-            report.update({'max_gpu_vram_gb': str(0)})
+            report.update({'max_gpu_load': '0'})
+            report.update({'max_gpu_vram_gb': '0'})
 
         return report
 
 
+def check_system_gpus():
+    print(f"Torch CUDA: {torch.cuda.is_available()}")
+    # if not torch.cuda.is_available():
+    #     return 0, {}, 0, "no_gpu"
+
+    GPUs = GPUtil.getGPUs()
+    num_gpus = len(GPUs)
+    gpu_dict = {}
+    total_vram = 0
+
+    for i, gpu in enumerate(GPUs):
+        gpu_vram = gpu.memoryTotal  # VRAM in MB
+        gpu_dict[f"GPU_{i}"] = f"{gpu_vram / 1024} GB"  # Convert to GB
+        total_vram += gpu_vram
+
+    total_vram_gb = total_vram / 1024  # Convert total VRAM to GB
+
+    capability_score_map = {
+        "no_gpu": 0,
+        "class_8GB": 10,
+        "class_12GB": 14,
+        "class_16GB": 18,
+        "class_24GB": 26,
+        "class_48GB": 50,
+        "class_96GB": 100,
+        "class_96GBplus": float('inf'),  # Use infinity to represent any value greater than 96GB
+    }
+
+    # Determine the capability score based on the total VRAM
+    capability_score = "no_gpu"
+    for score, vram in capability_score_map.items():
+        if total_vram_gb <= vram:
+            capability_score = score
+            break
+    else:
+        capability_score = "class_max"
+
+    return num_gpus, gpu_dict, total_vram_gb, capability_score
+
+
+if __name__ == '__main__':
+    num_gpus, gpu_dict, total_vram_gb, capability_score = check_system_gpus()
+    print(f"Number of GPUs: {num_gpus}")
+    print(f"GPU Details: {gpu_dict}")
+    print(f"Total VRAM: {total_vram_gb} GB")
+    print(f"Capability Score: {capability_score}")
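
Note: in check_system_gpus() the capability_score_map values act as upper VRAM thresholds (with headroom, so an 8.5 GB card still lands in class_8GB because the threshold is 10), and the float('inf') sentinel makes the for-else "class_max" branch unreachable. A standalone sketch of the bucketing:

# Sketch of the VRAM bucketing in check_system_gpus().
capability_score_map = {
    "no_gpu": 0, "class_8GB": 10, "class_12GB": 14, "class_16GB": 18,
    "class_24GB": 26, "class_48GB": 50, "class_96GB": 100,
    "class_96GBplus": float('inf'),
}

def bucket(total_vram_gb):
    for score, vram in capability_score_map.items():
        if total_vram_gb <= vram:
            return score
    return "class_max"  # unreachable thanks to the float('inf') sentinel

for vram in (0, 8, 24, 48, 80, 200):
    print(vram, "->", bucket(vram))
# 0 -> no_gpu, 8 -> class_8GB, 24 -> class_24GB,
# 48 -> class_48GB, 80 -> class_96GB, 200 -> class_96GBplus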
vouchervision/utils_LLM_JSON_validation.py CHANGED
@@ -11,7 +11,7 @@ def validate_and_align_JSON_keys_with_template(data, JSON_dict_structure):
         if value is None:
             data[key] = ''
         elif isinstance(value, str):
-            if value.lower() in ['unknown', 'not provided', 'missing', 'na', 'none', 'n/a', 'null',
+            if value.lower() in ['unknown', 'not provided', 'missing', 'na', 'none', 'n/a', 'null', 'unspecified',
                                  'not provided in the text', 'not found in the text',
                                  'not in the text', 'not provided', 'not found',
                                  'not provided in the ocr', 'not found in the ocr',
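
Note: this growing alias list normalizes LLM filler values to empty strings; 'unspecified' is the newly added alias. A simplified standalone version of the normalization:

NULL_ALIASES = {'unknown', 'not provided', 'missing', 'na', 'none', 'n/a', 'null', 'unspecified'}

def normalize(data):
    # Replace None and known "no data" strings with empty strings, in place.
    for key, value in data.items():
        if value is None or (isinstance(value, str) and value.lower() in NULL_ALIASES):
            data[key] = ''
    return data

print(normalize({'genus': 'Jouvea', 'county': 'Unspecified', 'habitat': None}))
# -> {'genus': 'Jouvea', 'county': '', 'habitat': ''}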
vouchervision/utils_VoucherVision.py CHANGED
@@ -5,10 +5,8 @@ from openpyxl import Workbook, load_workbook
5
  import vertexai
6
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
7
  from langchain_openai import AzureChatOpenAI
8
- from OCR_google_cloud_vision import OCRGoogle
9
- # import google.generativeai as genai
10
  from google.oauth2 import service_account
11
- # from googleapiclient.discovery import build
12
 
13
  from vouchervision.LLM_OpenAI import OpenAIHandler
14
  from vouchervision.LLM_GooglePalm2 import GooglePalm2Handler
@@ -20,6 +18,7 @@ from vouchervision.utils_LLM import remove_colons_and_double_apostrophes
20
  from vouchervision.prompt_catalog import PromptCatalog
21
  from vouchervision.model_maps import ModelMaps
22
  from vouchervision.general_utils import get_cfg_from_full_path
 
23
 
24
  '''
25
  * For the prefix_removal, the image names have 'MICH-V-' prior to the barcode, so that is used for matching
@@ -44,9 +43,11 @@ class VoucherVision():
44
  self.prompt_version = None
45
  self.is_hf = is_hf
46
 
47
- # self.trOCR_model_version = "microsoft/trocr-large-handwritten"
48
- self.trOCR_model_version = "microsoft/trocr-base-handwritten"
49
- # self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask"
 
 
50
  self.trOCR_processor = None
51
  self.trOCR_model = None
52
 
@@ -77,12 +78,12 @@ class VoucherVision():
77
  "GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
78
  "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",]
79
 
80
- self.usage_headers = ["current_time", "inference_time_s", "tool_time_s","max_cpu", "max_ram_gb", "max_gpu_load", "max_gpu_vram_gb",]
81
 
82
  self.wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"]
83
  self.wfo_headers_no_lists = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
84
 
85
- self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + self.usage_headers + ["prompt", "LLM", "tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
86
  # "WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
87
 
88
  # "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
@@ -117,8 +118,8 @@ class VoucherVision():
117
  lgr = logging.getLogger('transformers')
118
  lgr.setLevel(logging.ERROR)
119
 
120
- self.trOCR_processor = TrOCRProcessor.from_pretrained(self.trOCR_model_version)
121
- self.trOCR_model = VisionEncoderDecoderModel.from_pretrained(self.trOCR_model_version)
122
 
123
  # Check for GPU availability
124
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -297,7 +298,7 @@ class VoucherVision():
297
  break
298
 
299
 
300
- def add_data_to_excel_from_response(self, path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
301
 
302
 
303
  wb = openpyxl.load_workbook(path_transcription)
@@ -364,6 +365,8 @@ class VoucherVision():
364
  sheet.cell(row=next_row, column=i, value=filename_without_extension)
365
  elif header.value == "prompt":
366
  sheet.cell(row=next_row, column=i, value=os.path.basename(self.path_custom_prompts))
 
 
367
 
368
  # "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
369
  elif header.value in self.wfo_headers_no_lists:
@@ -613,12 +616,12 @@ class VoucherVision():
613
  ##################################################################################################################################
614
  ################################################## OCR ##################################################################
615
  ##################################################################################################################################
616
- def perform_OCR_and_save_results(self, image_index, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds):
617
  self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Starting OCR')
618
  # self.OCR - None
619
 
620
  ### Process_image() runs the OCR for text, handwriting, trOCR AND creates the overlay image
621
- ocr_google = OCRGoogle(self.is_hf, self.path_to_crop, self.cfg, self.trOCR_model_version, self.trOCR_model, self.trOCR_processor, self.device)
622
  ocr_google.process_image(self.do_create_OCR_helper_image, self.logger)
623
  self.OCR = ocr_google.OCR
624
 
@@ -682,7 +685,7 @@ class VoucherVision():
682
 
683
  filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
684
  json_report.set_text(text_main='Starting OCR')
685
- self.perform_OCR_and_save_results(i, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
686
  json_report.set_text(text_main='Finished OCR')
687
 
688
  if not self.OCR:
@@ -797,10 +800,10 @@ class VoucherVision():
797
  filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
798
  # Saving the JSON and XLSX files with the response and updating the final JSON response
799
  if response_candidate is not None:
800
- final_JSON_response_updated = self.save_json_and_xlsx(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
801
  return final_JSON_response_updated, WFO_record, GEO_record
802
  else:
803
- final_JSON_response_updated = self.save_json_and_xlsx(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
804
  return final_JSON_response_updated, WFO_record, GEO_record
805
 
806
 
@@ -836,7 +839,7 @@ class VoucherVision():
836
  return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt
837
 
838
 
839
- def save_json_and_xlsx(self, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
840
  if response is None:
841
  response = self.JSON_dict_structure
842
  # Insert 'filename' as the first key
@@ -845,14 +848,14 @@ class VoucherVision():
845
 
846
  # Then add the null info to the spreadsheet
847
  response_null = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
848
- self.add_data_to_excel_from_response(self.path_transcription, response_null, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)
849
 
850
  ### Set completed JSON
851
  else:
852
  response = self.clean_catalog_number(response, filename_without_extension)
853
  self.write_json_to_file(txt_file_path, response)
854
  # add to the xlsx file
855
- self.add_data_to_excel_from_response(self.path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
856
  return response
857
 
858
 
 
5
  import vertexai
6
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
7
  from langchain_openai import AzureChatOpenAI
 
 
8
  from google.oauth2 import service_account
9
+ from transformers import AutoTokenizer, AutoModel
10
 
11
  from vouchervision.LLM_OpenAI import OpenAIHandler
12
  from vouchervision.LLM_GooglePalm2 import GooglePalm2Handler
 
18
  from vouchervision.prompt_catalog import PromptCatalog
19
  from vouchervision.model_maps import ModelMaps
20
  from vouchervision.general_utils import get_cfg_from_full_path
21
+ from vouchervision.OCR_google_cloud_vision import OCREngine
22
 
23
  '''
24
  * For the prefix_removal, the image names have 'MICH-V-' prior to the barcode, so that is used for matching
 
43
  self.prompt_version = None
44
  self.is_hf = is_hf
45
 
46
+ self.trOCR_model_version = "microsoft/trocr-large-handwritten"
47
+ # self.trOCR_model_version = "microsoft/trocr-base-handwritten"
48
+ # self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask" # NOPE
49
+ # self.trOCR_model_version = "dh-unibe/trocr-kurrent" # NOPE
50
+ # self.trOCR_model_version = "DunnBC22/trocr-base-handwritten-OCR-handwriting_recognition_v2" # NOPE
51
  self.trOCR_processor = None
52
  self.trOCR_model = None
53
 
 
78
  "GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
79
  "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",]
80
 
81
+ self.usage_headers = ["current_time", "inference_time_s", "tool_time_s","max_cpu", "max_ram_gb", "n_gpus", "max_gpu_load", "max_gpu_vram_gb","total_gpu_vram_gb","capability_score",]
82
 
83
  self.wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"]
84
  self.wfo_headers_no_lists = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
85
 
86
+ self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + self.usage_headers + ["run_name", "prompt", "LLM", "tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
87
  # "WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
88
 
89
  # "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
 
118
  lgr = logging.getLogger('transformers')
119
  lgr.setLevel(logging.ERROR)
120
 
121
+ self.trOCR_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") # usually just the "microsoft/trocr-base-handwritten"
122
+ self.trOCR_model = VisionEncoderDecoderModel.from_pretrained(self.trOCR_model_version) # This matches the model
123
 
124
  # Check for GPU availability
125
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
298
  break
299
 
300
 
301
+ def add_data_to_excel_from_response(self, Dirs, path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
302
 
303
 
304
  wb = openpyxl.load_workbook(path_transcription)
 
365
  sheet.cell(row=next_row, column=i, value=filename_without_extension)
366
  elif header.value == "prompt":
367
  sheet.cell(row=next_row, column=i, value=os.path.basename(self.path_custom_prompts))
368
+ elif header.value == "run_name":
369
+ sheet.cell(row=next_row, column=i, value=Dirs.run_name)
370
 
371
  # "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
372
  elif header.value in self.wfo_headers_no_lists:
 
616
  ##################################################################################################################################
617
  ################################################## OCR ##################################################################
618
  ##################################################################################################################################
619
+ def perform_OCR_and_save_results(self, image_index, json_report, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds):
620
  self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Starting OCR')
621
  # self.OCR - None
622
 
623
  ### Process_image() runs the OCR for text, handwriting, trOCR AND creates the overlay image
624
+ ocr_google = OCREngine(self.logger, json_report, self.dir_home, self.is_hf, self.path_to_crop, self.cfg, self.trOCR_model_version, self.trOCR_model, self.trOCR_processor, self.device)
625
  ocr_google.process_image(self.do_create_OCR_helper_image, self.logger)
626
  self.OCR = ocr_google.OCR
627
 
 
685
 
686
  filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
687
  json_report.set_text(text_main='Starting OCR')
688
+ self.perform_OCR_and_save_results(i, json_report, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
689
  json_report.set_text(text_main='Finished OCR')
690
 
691
  if not self.OCR:
 
800
  filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
801
  # Saving the JSON and XLSX files with the response and updating the final JSON response
802
  if response_candidate is not None:
803
+ final_JSON_response_updated = self.save_json_and_xlsx(self.Dirs, response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
804
  return final_JSON_response_updated, WFO_record, GEO_record
805
  else:
806
+ final_JSON_response_updated = self.save_json_and_xlsx(self.Dirs, response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
807
  return final_JSON_response_updated, WFO_record, GEO_record
808
 
809
 
 
  return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt


+ def save_json_and_xlsx(self, Dirs, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
  if response is None:
  response = self.JSON_dict_structure
  # Insert 'filename' as the first key


  # Then add the null info to the spreadsheet
  response_null = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
+ self.add_data_to_excel_from_response(Dirs, self.path_transcription, response_null, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)

  ### Set completed JSON
  else:
  response = self.clean_catalog_number(response, filename_without_extension)
  self.write_json_to_file(txt_file_path, response)
  # add to the xlsx file
+ self.add_data_to_excel_from_response(Dirs, self.path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
  return response
 
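save_json_and_xlsx above makes a failed LLM call non-fatal: a None response is replaced by the template JSON structure and a null spreadsheet row is written with zeroed token counts, so every input image keeps exactly one row. A condensed sketch of that control flow (the names below are illustrative stand-ins, not the repo's helpers):

NULL_TEMPLATE = {}  # stand-in for self.JSON_dict_structure

def save_row(response, filename, nt_in, nt_out, make_null_row, write_row):
    if response is None:
        # Failed LLM call: keep one row per image, with token counts zeroed.
        response = dict(NULL_TEMPLATE)
        write_row(make_null_row(filename), nt_in=0, nt_out=0)
    else:
        # Normal path: persist the response and the real token counts.
        write_row(response, nt_in=nt_in, nt_out=nt_out)
    return response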
vouchervision/vouchervision_main.py CHANGED
@@ -3,10 +3,10 @@ VoucherVision - based on LeafMachine2 Processes
  '''
  import os, inspect, sys, shutil
  from time import perf_counter
- currentdir = os.path.dirname(os.path.dirname(inspect.getfile(inspect.currentframe())))
- parentdir = os.path.dirname(currentdir)
- sys.path.append(parentdir)
- sys.path.append(currentdir)
+ # currentdir = os.path.dirname(os.path.dirname(inspect.getfile(inspect.currentframe())))
+ # parentdir = os.path.dirname(currentdir)
+ # sys.path.append(parentdir)
+ # sys.path.append(currentdir)
  from vouchervision.component_detector.component_detector import detect_plant_components, detect_archival_components
  from vouchervision.general_utils import save_token_info_as_csv, print_main_start, check_for_subdirs_VV, load_config_file, load_config_file_testing, report_config, save_config_file, crop_detections_from_images_VV
  from vouchervision.directory_structure_VV import Dir_Structure
@@ -90,7 +90,14 @@ def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progr
  else:
  upload_to_drive(zip_filepath, zip_filename, is_hf, cfg_private=Voucher_Vision.cfg_private, do_upload=False) ##################################### TODO Make this configurable

- return last_JSON_response, final_WFO_record, final_GEO_record, total_cost, Voucher_Vision.n_failed_OCR, Voucher_Vision.n_failed_LLM_calls, zip_filepath
+ return {'last_JSON_response': last_JSON_response,
+         'final_WFO_record': final_WFO_record,
+         'final_GEO_record': final_GEO_record,
+         'total_cost': total_cost,
+         'n_failed_OCR': Voucher_Vision.n_failed_OCR,
+         'n_failed_LLM_calls': Voucher_Vision.n_failed_LLM_calls,
+         'zip_filepath': zip_filepath,
+         }

  def make_zipfile(base_dir, output_filename):
  # Determine the directory where the zip file should be saved
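
Because voucher_vision() now returns a dict instead of a 7-tuple, downstream call sites must switch from positional unpacking to named keys; new fields can then be added without silently shifting positions. A hedged sketch of the caller-side change (the call's argument list is elided, since only its opening parameters are visible in the hunk header above):

# Before (positional unpacking against the removed tuple return):
# last_JSON_response, final_WFO_record, final_GEO_record, total_cost, \
#     n_failed_OCR, n_failed_LLM_calls, zip_filepath = voucher_vision(...)

# After (keyed access against the new dict return):
results = voucher_vision(
    cfg_file_path, dir_home, path_custom_prompts, cfg_test,
    # ... remaining arguments unchanged at the existing call site ...
)
last_JSON_response = results['last_JSON_response']
total_cost = results['total_cost']
zip_filepath = results['zip_filepath']
n_failed = results['n_failed_OCR'] + results['n_failed_LLM_calls']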