Spaces:
Running
Running
phyloforfun
commited on
Commit
•
b8abf64
1
Parent(s):
fdfdfc3
Major update. Support for 15 LLMs, World Flora Online taxonomy validation, geolocation, 2 OCR methods, significant UI changes, stability improvements, consistent JSON parsing
Browse files- .gitignore +2 -0
- app.py +416 -325
- demo/demo_images/{MICH_7574789_Cyperaceae_Carex_scoparia.jpg → MICH_16205594_Poaceae_Jouvea_pilosa.jpg} +2 -2
- demo/google/google_api_0.PNG +3 -0
- demo/google/google_api_00.PNG +3 -0
- demo/google/google_api_1.PNG +3 -0
- demo/google/google_api_10.PNG +3 -0
- demo/google/google_api_11.PNG +3 -0
- demo/google/google_api_2.PNG +3 -0
- demo/google/google_api_3.PNG +3 -0
- demo/google/google_api_4.PNG +3 -0
- demo/google/google_api_5.PNG +3 -0
- demo/google/google_api_6.PNG +3 -0
- demo/google/google_api_7.PNG +3 -0
- demo/google/google_api_8.PNG +3 -0
- demo/google/google_api_9.PNG +3 -0
- requirements.txt +0 -0
- run_VoucherVision.py +3 -1
- vouchervision/API_validation.py +29 -3
- vouchervision/LLM_GoogleGemini.py +8 -3
- vouchervision/LLM_GooglePalm2.py +3 -1
- vouchervision/OCR_google_cloud_vision.py +4 -3
- vouchervision/utils_VoucherVision.py +27 -42
- vouchervision/utils_hf.py +33 -4
.gitignore
CHANGED
@@ -19,6 +19,8 @@ venv_LM2_38/
|
|
19 |
venv_LM2/
|
20 |
venv_VV/
|
21 |
tests/
|
|
|
|
|
22 |
.vscode/
|
23 |
runs/
|
24 |
KP_Test/
|
|
|
19 |
venv_LM2/
|
20 |
venv_VV/
|
21 |
tests/
|
22 |
+
uploads/
|
23 |
+
uploads_small/
|
24 |
.vscode/
|
25 |
runs/
|
26 |
KP_Test/
|
app.py
CHANGED
@@ -15,12 +15,30 @@ from vouchervision.vouchervision_main import voucher_vision, voucher_vision_OCR_
|
|
15 |
from vouchervision.general_utils import test_GPU, get_cfg_from_full_path, summarize_expense_report, create_google_ocr_yaml_config, validate_dir
|
16 |
from vouchervision.model_maps import ModelMaps
|
17 |
from vouchervision.API_validation import APIvalidation
|
18 |
-
from vouchervision.utils_hf import upload_to_drive, image_to_base64, setup_streamlit_config, save_uploaded_file, check_prompt_yaml_filename
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
########################################################################################################
|
22 |
### ADDED FOR HUGGING FACE ####
|
23 |
########################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
if 'uploader_idk' not in st.session_state:
|
25 |
st.session_state['uploader_idk'] = 1
|
26 |
if 'input_list_small' not in st.session_state:
|
@@ -31,76 +49,157 @@ if 'user_clicked_load_prompt_yaml' not in st.session_state:
|
|
31 |
st.session_state['user_clicked_load_prompt_yaml'] = None
|
32 |
if 'new_prompt_yaml_filename' not in st.session_state:
|
33 |
st.session_state['new_prompt_yaml_filename'] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
-
MAX_GALLERY_IMAGES =
|
36 |
-
GALLERY_IMAGE_SIZE =
|
37 |
|
38 |
|
39 |
|
40 |
-
def
|
41 |
st.write('---')
|
42 |
-
col1, col2 = st.columns([2,8])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
st.write("Run name will be the name of the final zipped folder.")
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
st.session_state['
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
st.
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
|
|
|
|
|
|
|
|
|
|
94 |
|
|
|
95 |
|
96 |
|
97 |
def create_download_button(zip_filepath, col):
|
98 |
with col:
|
99 |
-
|
100 |
with open(zip_filepath, 'rb') as f:
|
101 |
bytes_io = BytesIO(f.read())
|
102 |
st.download_button(
|
103 |
-
label=
|
104 |
type='primary',
|
105 |
data=bytes_io,
|
106 |
file_name=os.path.basename(zip_filepath),
|
@@ -130,9 +229,17 @@ def use_test_image():
|
|
130 |
st.info(f"Processing images from {os.path.join(st.session_state.dir_home,'demo','demo_images')}")
|
131 |
st.session_state.config['leafmachine']['project']['dir_images_local'] = os.path.join(st.session_state.dir_home,'demo','demo_images')
|
132 |
n_images = len([f for f in os.listdir(st.session_state.config['leafmachine']['project']['dir_images_local']) if os.path.isfile(os.path.join(st.session_state.config['leafmachine']['project']['dir_images_local'], f))])
|
133 |
-
st.session_state['processing_add_on'] =
|
134 |
clear_image_gallery()
|
135 |
st.session_state['uploader_idk'] += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
|
138 |
def create_download_button_yaml(file_path, selected_yaml_file):
|
@@ -409,7 +516,7 @@ class JSONReport:
|
|
409 |
|
410 |
|
411 |
def does_private_file_exist():
|
412 |
-
dir_home = os.path.dirname(
|
413 |
path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
|
414 |
return os.path.exists(path_cfg_private)
|
415 |
|
@@ -613,16 +720,32 @@ def get_prompt_versions(LLM_version):
|
|
613 |
|
614 |
|
615 |
def get_private_file():
|
616 |
-
dir_home = os.path.dirname(
|
617 |
path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
|
618 |
return get_cfg_from_full_path(path_cfg_private)
|
619 |
|
620 |
-
|
621 |
-
|
622 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
623 |
st.session_state.proceed_to_main = False
|
624 |
st.title("VoucherVision")
|
625 |
-
col_private,
|
|
|
|
|
626 |
|
627 |
if st.session_state.private_file:
|
628 |
cfg_private = get_private_file()
|
@@ -632,219 +755,219 @@ def create_private_file(): #####################################################
|
|
632 |
cfg_private['openai']['OPENAI_API_KEY'] =''
|
633 |
|
634 |
cfg_private['openai_azure'] = {}
|
635 |
-
cfg_private['openai_azure']['
|
636 |
-
cfg_private['openai_azure']['
|
637 |
-
cfg_private['openai_azure']['
|
638 |
-
cfg_private['openai_azure']['
|
639 |
-
cfg_private['openai_azure']['
|
640 |
-
|
641 |
-
cfg_private['
|
642 |
-
cfg_private['
|
643 |
-
|
644 |
-
cfg_private['
|
645 |
-
cfg_private['
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
646 |
|
647 |
|
648 |
with col_private:
|
649 |
st.header("Set API keys")
|
650 |
-
st.info("***Note:*** There is a known bug with tabs in Streamlit. If you update an input field it may take you back to the 'Project Settings' tab. Changes that you made are saved, it's just an annoying glitch. We are aware of this issue and will fix it as soon as we can.")
|
651 |
st.warning("To commit changes to API keys you must press the 'Set API Keys' button at the bottom of the page.")
|
652 |
st.write("Before using VoucherVision you must set your API keys. All keys are stored locally on your computer and are never made public.")
|
653 |
st.write("API keys are stored in `../VoucherVision/PRIVATE_DATA.yaml`.")
|
654 |
-
st.write("Deleting this file will allow you to reset API keys. Alternatively, you can edit the keys in the user interface.")
|
655 |
st.write("Leave keys blank if you do not intend to use that service.")
|
656 |
|
657 |
st.write("---")
|
658 |
-
st.subheader("Google Vision (*Required*)")
|
659 |
st.markdown("VoucherVision currently uses [Google Vision API](https://cloud.google.com/vision/docs/ocr) for OCR. Generating an API key for this is more involved than the others. [Please carefully follow the instructions outlined here to create and setup your account.](https://cloud.google.com/vision/docs/setup) ")
|
660 |
-
st.markdown("""
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
683 |
|
684 |
|
685 |
st.write("---")
|
686 |
st.subheader("OpenAI")
|
687 |
st.markdown("API key for first-party OpenAI API. Create an account with OpenAI [here](https://platform.openai.com/signup), then create an API key [here](https://platform.openai.com/account/api-keys).")
|
688 |
-
|
689 |
-
c_in_openai, c_button_openai = st.columns([10,2])
|
690 |
-
with c_in_openai:
|
691 |
-
openai_api_key = st.text_input("openai_api_key", cfg_private['openai'].get('OPENAI_API_KEY', ''),
|
692 |
help='The actual API key. Likely to be a string of 2 character, a dash, and then a 48-character string: sk-XXXXXXXX...',
|
693 |
placeholder = 'e.g. sk-XXXXXXXX...',
|
694 |
type='password')
|
695 |
-
|
696 |
-
st.empty()
|
697 |
|
698 |
st.write("---")
|
699 |
st.subheader("OpenAI - Azure")
|
700 |
st.markdown("This version OpenAI relies on Azure servers directly as is intended for private enterprise instances of OpenAI's services, such as [UM-GPT](https://its.umich.edu/computing/ai). Administrators will provide you with the following information.")
|
701 |
-
azure_openai_api_version = st.text_input("
|
702 |
help='API Version e.g. "2023-05-15"',
|
703 |
placeholder = 'e.g. 2023-05-15',
|
704 |
type='password')
|
705 |
-
azure_openai_api_key = st.text_input("
|
706 |
-
help='The actual API key. Likely to be a 32-character string',
|
707 |
placeholder = 'e.g. 12333333333333333333333333333332',
|
708 |
type='password')
|
709 |
-
azure_openai_api_base = st.text_input("
|
710 |
help='The base url for the API e.g. "https://api.umgpt.umich.edu/azure-openai-api"',
|
711 |
placeholder = 'e.g. https://api.umgpt.umich.edu/azure-openai-api',
|
712 |
type='password')
|
713 |
-
azure_openai_organization = st.text_input("
|
714 |
-
help='Your organization code. Likely a short string',
|
715 |
placeholder = 'e.g. 123456',
|
716 |
type='password')
|
717 |
-
azure_openai_api_type = st.text_input("
|
718 |
help='The API type. Typically "azure"',
|
719 |
placeholder = 'e.g. azure',
|
720 |
type='password')
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
726 |
st.write("---")
|
727 |
-
st.subheader("
|
728 |
-
st.markdown('Follow these [instructions](https://
|
729 |
-
|
730 |
-
|
731 |
-
with c_in_palm:
|
732 |
-
google_palm = st.text_input("Google PaLM 2 API Key", cfg_private['google_palm'].get('google_palm_api', ''),
|
733 |
-
help='The MakerSuite API key e.g. a 32-character string',
|
734 |
placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
|
735 |
type='password')
|
736 |
-
|
737 |
-
with st.container():
|
738 |
-
with c_button_ocr:
|
739 |
-
st.write("##")
|
740 |
-
st.button("Test OCR", on_click=test_API, args=['google_vision',c_in_ocr, cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key,
|
741 |
-
azure_openai_api_base,azure_openai_organization,azure_openai_api_type,google_vision,google_palm])
|
742 |
-
|
743 |
-
with st.container():
|
744 |
-
with c_button_openai:
|
745 |
-
st.write("##")
|
746 |
-
st.button("Test OpenAI", on_click=test_API, args=['openai',c_in_openai, cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key,
|
747 |
-
azure_openai_api_base,azure_openai_organization,azure_openai_api_type,google_vision,google_palm])
|
748 |
-
|
749 |
-
with st.container():
|
750 |
-
with c_button_azure:
|
751 |
-
st.write("##")
|
752 |
-
st.button("Test Azure OpenAI", on_click=test_API, args=['azure_openai',c_in_azure, cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key,
|
753 |
-
azure_openai_api_base,azure_openai_organization,azure_openai_api_type,google_vision,google_palm])
|
754 |
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
760 |
|
761 |
|
762 |
st.button("Set API Keys",type='primary', on_click=save_changes_to_API_keys, args=[cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key,
|
763 |
-
azure_openai_api_base,azure_openai_organization,azure_openai_api_type,
|
|
|
|
|
764 |
if st.button('Proceed to VoucherVision'):
|
|
|
765 |
st.session_state.proceed_to_private = False
|
766 |
st.session_state.proceed_to_main = True
|
767 |
-
|
768 |
-
def test_API(api, message_loc, cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key, azure_openai_api_base,azure_openai_organization,azure_openai_api_type,google_vision,google_palm):
|
769 |
-
# Save the API keys
|
770 |
-
save_changes_to_API_keys(cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key,azure_openai_api_base,azure_openai_organization,azure_openai_api_type,google_vision,google_palm)
|
771 |
-
|
772 |
-
with st.spinner('Performing validation checks...'):
|
773 |
-
if api == 'google_vision':
|
774 |
-
print("*** Google Vision OCR API Key ***")
|
775 |
-
try:
|
776 |
-
demo_config_path = os.path.join(st.session_state.dir_home,'demo','validation_configs','google_vision_ocr_test.yaml')
|
777 |
-
demo_images_path = os.path.join(st.session_state.dir_home, 'demo', 'demo_images')
|
778 |
-
demo_out_path = os.path.join(st.session_state.dir_home, 'demo', 'demo_output','run_name')
|
779 |
-
create_google_ocr_yaml_config(demo_config_path, demo_images_path, demo_out_path)
|
780 |
-
voucher_vision_OCR_test(demo_config_path, st.session_state.dir_home, None, demo_images_path)
|
781 |
-
with message_loc:
|
782 |
-
st.success("Google Vision OCR API Key Valid :white_check_mark:")
|
783 |
-
return True
|
784 |
-
except Exception as e:
|
785 |
-
with message_loc:
|
786 |
-
st.error(f"Google Vision OCR API Key Failed! {e}")
|
787 |
-
return False
|
788 |
-
|
789 |
-
elif api == 'openai':
|
790 |
-
print("*** OpenAI API Key ***")
|
791 |
-
try:
|
792 |
-
if run_api_tests('openai'):
|
793 |
-
with message_loc:
|
794 |
-
st.success("OpenAI API Key Valid :white_check_mark:")
|
795 |
-
else:
|
796 |
-
with message_loc:
|
797 |
-
st.error("OpenAI API Key Failed:exclamation:")
|
798 |
-
return False
|
799 |
-
except Exception as e:
|
800 |
-
with message_loc:
|
801 |
-
st.error(f"OpenAI API Key Failed:exclamation: {e}")
|
802 |
-
|
803 |
-
elif api == 'azure_openai':
|
804 |
-
print("*** Azure OpenAI API Key ***")
|
805 |
-
try:
|
806 |
-
if run_api_tests('azure_openai'):
|
807 |
-
with message_loc:
|
808 |
-
st.success("Azure OpenAI API Key Valid :white_check_mark:")
|
809 |
-
else:
|
810 |
-
with message_loc:
|
811 |
-
st.error(f"Azure OpenAI API Key Failed:exclamation:")
|
812 |
-
return False
|
813 |
-
except Exception as e:
|
814 |
-
with message_loc:
|
815 |
-
st.error(f"Azure OpenAI API Key Failed:exclamation: {e}")
|
816 |
-
elif api == 'palm':
|
817 |
-
print("*** Google PaLM 2 API Key ***")
|
818 |
-
try:
|
819 |
-
if run_api_tests('palm'):
|
820 |
-
with message_loc:
|
821 |
-
st.success("Google PaLM 2 API Key Valid :white_check_mark:")
|
822 |
-
else:
|
823 |
-
with message_loc:
|
824 |
-
st.error("Google PaLM 2 API Key Failed:exclamation:")
|
825 |
-
return False
|
826 |
-
except Exception as e:
|
827 |
-
with message_loc:
|
828 |
-
st.error(f"Google PaLM 2 API Key Failed:exclamation: {e}")
|
829 |
|
830 |
|
831 |
def save_changes_to_API_keys(cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key,
|
832 |
-
|
|
|
|
|
|
|
833 |
# Update the configuration dictionary with the new values
|
834 |
cfg_private['openai']['OPENAI_API_KEY'] = openai_api_key
|
835 |
|
836 |
-
cfg_private['openai_azure']['
|
837 |
-
cfg_private['openai_azure']['
|
838 |
-
cfg_private['openai_azure']['
|
839 |
-
cfg_private['openai_azure']['
|
840 |
-
cfg_private['openai_azure']['
|
|
|
|
|
|
|
|
|
841 |
|
842 |
-
cfg_private['
|
843 |
|
844 |
-
cfg_private['
|
|
|
845 |
# Call the function to write the updated configuration to the YAML file
|
846 |
write_config_file(cfg_private, st.session_state.dir_home, filename="PRIVATE_DATA.yaml")
|
847 |
-
st.session_state.private_file = does_private_file_exist()
|
848 |
|
849 |
# Function to load a YAML file and update session_state
|
850 |
def load_prompt_yaml(filename):
|
@@ -1588,14 +1711,18 @@ def content_header():
|
|
1588 |
# st.subheader('Run VoucherVision')
|
1589 |
N_STEPS = 6
|
1590 |
|
1591 |
-
|
1592 |
-
|
1593 |
-
|
1594 |
-
|
1595 |
-
|
|
|
1596 |
|
1597 |
if check_if_usable(is_hf=st.session_state['is_hf']):
|
1598 |
-
if st.
|
|
|
|
|
|
|
1599 |
st.session_state['formatted_json'] = {}
|
1600 |
st.session_state['formatted_json_WFO'] = {}
|
1601 |
st.session_state['formatted_json_GEO'] = {}
|
@@ -1750,28 +1877,16 @@ def content_header():
|
|
1750 |
|
1751 |
|
1752 |
|
1753 |
-
def content_project_settings():
|
1754 |
-
|
1755 |
-
|
1756 |
-
|
1757 |
-
|
1758 |
-
with col_project_1:
|
1759 |
st.session_state.config['leafmachine']['project']['run_name'] = st.text_input("Run name", st.session_state.config['leafmachine']['project'].get('run_name', ''),key=63456)
|
1760 |
st.session_state.config['leafmachine']['project']['dir_output'] = st.text_input("Output directory", st.session_state.config['leafmachine']['project'].get('dir_output', ''))
|
1761 |
-
|
1762 |
-
|
1763 |
-
|
1764 |
-
def content_input_images():
|
1765 |
-
st.header('Input Images')
|
1766 |
-
col_local_1, col_local_2 = st.columns([11,1])
|
1767 |
-
with col_local_1:
|
1768 |
-
### Input Images Local
|
1769 |
-
st.session_state.config['leafmachine']['project']['dir_images_local'] = st.text_input("Input images directory", st.session_state.config['leafmachine']['project'].get('dir_images_local', ''))
|
1770 |
-
st.session_state.config['leafmachine']['project']['continue_run_from_partial_xlsx'] = st.text_input("Continue run from partially completed project XLSX", st.session_state.config['leafmachine']['project'].get('continue_run_from_partial_xlsx', ''), disabled=True)
|
1771 |
|
1772 |
|
1773 |
|
1774 |
-
|
1775 |
def content_llm_cost():
|
1776 |
st.write("---")
|
1777 |
st.header('LLM Cost Calculator')
|
@@ -1881,10 +1996,11 @@ def content_api_check():
|
|
1881 |
st.session_state['API_rechecked'] = True
|
1882 |
st.rerun()
|
1883 |
# with col_llm_2c:
|
1884 |
-
if st.
|
1885 |
-
st.
|
1886 |
-
|
1887 |
-
|
|
|
1888 |
|
1889 |
|
1890 |
|
@@ -1940,8 +2056,7 @@ def content_collage_overlay():
|
|
1940 |
st.session_state.config['leafmachine']['do_create_OCR_helper_image'] = do_create_OCR_helper_image
|
1941 |
|
1942 |
|
1943 |
-
|
1944 |
-
st.session_state.config['leafmachine']['project']['do_use_trOCR'] = do_use_trOCR
|
1945 |
|
1946 |
# Get the current OCR option from session state
|
1947 |
OCR_option = st.session_state.config['leafmachine']['project']['OCR_option']
|
@@ -1972,6 +2087,11 @@ def content_collage_overlay():
|
|
1972 |
OCR_option = 'both'
|
1973 |
else:
|
1974 |
raise
|
|
|
|
|
|
|
|
|
|
|
1975 |
|
1976 |
st.session_state.config['leafmachine']['project']['OCR_option'] = OCR_option
|
1977 |
st.markdown("Below is an example of what the LLM would see given the choice of OCR ensemble. One, two, or three version of OCR can be fed into the LLM prompt. Typically, 'printed + handwritten' works well. If you have a GPU then you can enable trOCR.")
|
@@ -2267,19 +2387,21 @@ def main():
|
|
2267 |
content_header()
|
2268 |
|
2269 |
|
|
|
|
|
|
|
2270 |
|
2271 |
-
if st.session_state['is_hf']:
|
2272 |
-
|
2273 |
-
|
2274 |
-
else:
|
2275 |
-
|
2276 |
-
|
2277 |
-
|
2278 |
-
|
2279 |
-
|
2280 |
|
2281 |
|
2282 |
-
st.write("---")
|
2283 |
col3, col4 = st.columns([1,1])
|
2284 |
with col3:
|
2285 |
content_prompt_and_llm_version()
|
@@ -2295,15 +2417,7 @@ def main():
|
|
2295 |
|
2296 |
|
2297 |
|
2298 |
-
#################################################################################################################################################
|
2299 |
-
# Initializations ###############################################################################################################################
|
2300 |
-
#################################################################################################################################################
|
2301 |
-
|
2302 |
-
st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VoucherVision')
|
2303 |
|
2304 |
-
# Parse the 'is_hf' argument and set it in session state
|
2305 |
-
if 'is_hf' not in st.session_state:
|
2306 |
-
st.session_state['is_hf'] = True
|
2307 |
|
2308 |
|
2309 |
#################################################################################################################################################
|
@@ -2311,47 +2425,42 @@ if 'is_hf' not in st.session_state:
|
|
2311 |
#################################################################################################################################################
|
2312 |
|
2313 |
|
2314 |
-
|
2315 |
-
|
2316 |
-
|
2317 |
-
st.session_state.config, st.session_state.dir_home = build_VV_config(loaded_cfg=None)
|
2318 |
-
setup_streamlit_config(st.session_state.dir_home)
|
2319 |
|
2320 |
|
2321 |
if st.session_state['is_hf']:
|
2322 |
if 'proceed_to_main' not in st.session_state:
|
2323 |
st.session_state.proceed_to_main = True
|
2324 |
-
print(f"proceed_to_main {st.session_state['proceed_to_main']}")
|
2325 |
|
2326 |
if 'proceed_to_private' not in st.session_state:
|
2327 |
st.session_state.proceed_to_private = False
|
2328 |
-
print(f"proceed_to_private {st.session_state['proceed_to_private']}")
|
2329 |
|
2330 |
if 'private_file' not in st.session_state:
|
2331 |
st.session_state.private_file = True
|
2332 |
-
print(f"private_file {st.session_state['private_file']}")
|
2333 |
|
2334 |
else:
|
2335 |
if 'proceed_to_main' not in st.session_state:
|
2336 |
-
st.session_state.proceed_to_main =
|
2337 |
-
|
2338 |
if 'private_file' not in st.session_state:
|
2339 |
st.session_state.private_file = does_private_file_exist()
|
2340 |
if st.session_state.private_file:
|
2341 |
st.session_state.proceed_to_main = True
|
2342 |
-
print(f"private_file2 {st.session_state['private_file']}")
|
2343 |
-
print(f"proceed_to_main2 {st.session_state['proceed_to_main']}")
|
2344 |
|
2345 |
if 'proceed_to_private' not in st.session_state:
|
2346 |
st.session_state.proceed_to_private = False # New state variable to control the flow
|
2347 |
-
|
2348 |
|
2349 |
|
2350 |
if 'proceed_to_build_llm_prompt' not in st.session_state:
|
2351 |
st.session_state.proceed_to_build_llm_prompt = False # New state variable to control the flow
|
2352 |
|
|
|
2353 |
if 'processing_add_on' not in st.session_state:
|
2354 |
-
st.session_state['processing_add_on'] =
|
|
|
2355 |
|
2356 |
if 'formatted_json' not in st.session_state:
|
2357 |
st.session_state['formatted_json'] = None
|
@@ -2360,9 +2469,11 @@ if 'formatted_json_WFO' not in st.session_state:
|
|
2360 |
if 'formatted_json_GEO' not in st.session_state:
|
2361 |
st.session_state['formatted_json_GEO'] = None
|
2362 |
|
|
|
2363 |
if 'lacks_GPU' not in st.session_state:
|
2364 |
st.session_state['lacks_GPU'] = not torch.cuda.is_available()
|
2365 |
|
|
|
2366 |
if 'API_key_validation' not in st.session_state:
|
2367 |
st.session_state['API_key_validation'] = False
|
2368 |
if 'present_annotations' not in st.session_state:
|
@@ -2376,18 +2487,15 @@ if 'API_checked' not in st.session_state:
|
|
2376 |
if 'API_rechecked' not in st.session_state:
|
2377 |
st.session_state['API_rechecked'] = False
|
2378 |
|
|
|
2379 |
if 'json_report' not in st.session_state:
|
2380 |
st.session_state['json_report'] = False
|
2381 |
if 'hold_output' not in st.session_state:
|
2382 |
st.session_state['hold_output'] = False
|
2383 |
|
2384 |
-
if 'dir_uploaded_images' not in st.session_state:
|
2385 |
-
st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
|
2386 |
-
validate_dir(os.path.join(st.session_state.dir_home,'uploads'))
|
2387 |
|
2388 |
-
|
2389 |
-
|
2390 |
-
validate_dir(os.path.join(st.session_state.dir_home,'uploads_small'))
|
2391 |
|
2392 |
if 'cost_openai' not in st.session_state:
|
2393 |
st.session_state['cost_openai'] = None
|
@@ -2400,6 +2508,7 @@ if 'cost_mistral' not in st.session_state:
|
|
2400 |
if 'cost_local' not in st.session_state:
|
2401 |
st.session_state['cost_local'] = None
|
2402 |
|
|
|
2403 |
if 'settings_filename' not in st.session_state:
|
2404 |
st.session_state['settings_filename'] = None
|
2405 |
if 'loaded_settings_filename' not in st.session_state:
|
@@ -2407,16 +2516,13 @@ if 'loaded_settings_filename' not in st.session_state:
|
|
2407 |
if 'zip_filepath' not in st.session_state:
|
2408 |
st.session_state['zip_filepath'] = None
|
2409 |
|
|
|
2410 |
# Initialize session_state variables if they don't exist
|
2411 |
if 'prompt_info' not in st.session_state:
|
2412 |
st.session_state['prompt_info'] = {}
|
2413 |
if 'rules' not in st.session_state:
|
2414 |
st.session_state['rules'] = {}
|
2415 |
-
|
2416 |
-
st.session_state['required_fields'] = ['catalogNumber','order','family','scientificName',
|
2417 |
-
'scientificNameAuthorship','genus','subgenus','specificEpithet','infraspecificEpithet',
|
2418 |
-
'verbatimEventDate','eventDate',
|
2419 |
-
'country','stateProvince','county','municipality','locality','decimalLatitude','decimalLongitude','verbatimCoordinates',]
|
2420 |
|
2421 |
# These are the fields that are in SLTPvA that are not required by another parsing valication function:
|
2422 |
# "identifiedBy": "M.W. Lyon, Jr.",
|
@@ -2427,7 +2533,11 @@ if 'required_fields' not in st.session_state:
|
|
2427 |
# "degreeOfEstablishment": "",
|
2428 |
# "minimumElevationInMeters": "",
|
2429 |
# "maximumElevationInMeters": ""
|
2430 |
-
|
|
|
|
|
|
|
|
|
2431 |
|
2432 |
|
2433 |
if 'proceed_to_build_llm_prompt' not in st.session_state:
|
@@ -2441,46 +2551,27 @@ if 'proceed_to_api_keys' not in st.session_state:
|
|
2441 |
if 'proceed_to_space_saver' not in st.session_state:
|
2442 |
st.session_state.proceed_to_space_saver = False
|
2443 |
|
|
|
2444 |
#################################################################################################################################################
|
2445 |
# Main ##########################################################################################################################################
|
2446 |
#################################################################################################################################################
|
2447 |
-
|
2448 |
-
|
2449 |
-
|
2450 |
-
|
2451 |
-
|
2452 |
-
|
2453 |
-
|
2454 |
-
|
2455 |
-
|
2456 |
-
|
2457 |
-
|
2458 |
-
|
2459 |
-
|
2460 |
-
|
2461 |
-
main()
|
2462 |
-
|
2463 |
-
|
2464 |
-
|
2465 |
-
|
2466 |
-
|
2467 |
|
2468 |
|
2469 |
-
# print(f"proceed_to_main3 {st.session_state['proceed_to_main']}")
|
2470 |
-
# print(f"is_hf3 {st.session_state['is_hf']}")
|
2471 |
-
# print(f"private_file3 {st.session_state['private_file']}")
|
2472 |
-
# print(f"proceed_to_build_llm_prompt3 {st.session_state['proceed_to_build_llm_prompt']}")
|
2473 |
-
# print(f"proceed_to_private3 {st.session_state['proceed_to_private']}")
|
2474 |
|
2475 |
-
# # if not st.session_state.private_file and not st.session_state['is_hf']:
|
2476 |
-
# # create_private_file()
|
2477 |
-
# # elif st.session_state.proceed_to_build_llm_prompt:
|
2478 |
-
# if st.session_state.proceed_to_build_llm_prompt:
|
2479 |
-
# build_LLM_prompt_config()
|
2480 |
-
# # elif st.session_state.proceed_to_private and not st.session_state['is_hf']:
|
2481 |
-
# # create_private_file()
|
2482 |
-
# elif st.session_state.proceed_to_main:
|
2483 |
-
# main()
|
2484 |
|
2485 |
|
2486 |
|
|
|
15 |
from vouchervision.general_utils import test_GPU, get_cfg_from_full_path, summarize_expense_report, create_google_ocr_yaml_config, validate_dir
|
16 |
from vouchervision.model_maps import ModelMaps
|
17 |
from vouchervision.API_validation import APIvalidation
|
18 |
+
from vouchervision.utils_hf import upload_to_drive, image_to_base64, setup_streamlit_config, save_uploaded_file, check_prompt_yaml_filename, save_uploaded_local
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
#################################################################################################################################################
|
23 |
+
# Initializations ###############################################################################################################################
|
24 |
+
#################################################################################################################################################
|
25 |
+
|
26 |
+
st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VoucherVision')
|
27 |
+
|
28 |
+
# Parse the 'is_hf' argument and set it in session state
|
29 |
+
if 'is_hf' not in st.session_state:
|
30 |
+
st.session_state['is_hf'] = True
|
31 |
|
32 |
|
33 |
########################################################################################################
|
34 |
### ADDED FOR HUGGING FACE ####
|
35 |
########################################################################################################
|
36 |
+
print(f"is_hf {st.session_state['is_hf']}")
|
37 |
+
# Default YAML file path
|
38 |
+
if 'config' not in st.session_state:
|
39 |
+
st.session_state.config, st.session_state.dir_home = build_VV_config(loaded_cfg=None)
|
40 |
+
setup_streamlit_config(st.session_state.dir_home)
|
41 |
+
|
42 |
if 'uploader_idk' not in st.session_state:
|
43 |
st.session_state['uploader_idk'] = 1
|
44 |
if 'input_list_small' not in st.session_state:
|
|
|
49 |
st.session_state['user_clicked_load_prompt_yaml'] = None
|
50 |
if 'new_prompt_yaml_filename' not in st.session_state:
|
51 |
st.session_state['new_prompt_yaml_filename'] = None
|
52 |
+
if 'view_local_gallery' not in st.session_state:
|
53 |
+
st.session_state['view_local_gallery'] = False
|
54 |
+
if 'dir_images_local_TEMP' not in st.session_state:
|
55 |
+
st.session_state['dir_images_local_TEMP'] = False
|
56 |
+
if 'dir_uploaded_images' not in st.session_state:
|
57 |
+
st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
|
58 |
+
validate_dir(os.path.join(st.session_state.dir_home,'uploads'))
|
59 |
+
if 'dir_uploaded_images_small' not in st.session_state:
|
60 |
+
st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
|
61 |
+
validate_dir(os.path.join(st.session_state.dir_home,'uploads_small'))
|
62 |
|
63 |
+
MAX_GALLERY_IMAGES = 20
|
64 |
+
GALLERY_IMAGE_SIZE = 96
|
65 |
|
66 |
|
67 |
|
68 |
+
def content_input_images(col_left, col_right):
|
69 |
st.write('---')
|
70 |
+
# col1, col2 = st.columns([2,8])
|
71 |
+
with col_left:
|
72 |
+
st.header('Input Images')
|
73 |
+
if not st.session_state.is_hf:
|
74 |
+
|
75 |
+
### Input Images Local
|
76 |
+
st.session_state.config['leafmachine']['project']['dir_images_local'] = st.text_input("Input images directory", st.session_state.config['leafmachine']['project'].get('dir_images_local', ''))
|
77 |
+
|
78 |
+
st.session_state.config['leafmachine']['project']['continue_run_from_partial_xlsx'] = st.text_input("Continue run from partially completed project XLSX", st.session_state.config['leafmachine']['project'].get('continue_run_from_partial_xlsx', ''), disabled=True)
|
79 |
+
else:
|
80 |
+
pass
|
81 |
+
|
82 |
+
with col_left:
|
83 |
+
if st.session_state.is_hf:
|
84 |
+
st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
|
85 |
+
st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
|
86 |
+
uploaded_files = st.file_uploader("Upload Images", type=['jpg', 'jpeg'], accept_multiple_files=True, key=st.session_state['uploader_idk'])
|
87 |
+
st.button("Use Test Image",help="This will clear any uploaded images and load the 1 provided test image.",on_click=use_test_image)
|
88 |
+
|
89 |
+
with col_right:
|
90 |
+
if st.session_state.is_hf:
|
91 |
+
if uploaded_files:
|
92 |
+
# Clear input image gallery and input list
|
93 |
+
clear_image_gallery()
|
94 |
|
95 |
+
# Process the new iamges
|
96 |
+
for uploaded_file in uploaded_files:
|
97 |
+
file_path = save_uploaded_file(st.session_state['dir_uploaded_images'], uploaded_file)
|
98 |
+
st.session_state['input_list'].append(file_path)
|
|
|
99 |
|
100 |
+
img = Image.open(file_path)
|
101 |
+
img.thumbnail((GALLERY_IMAGE_SIZE, GALLERY_IMAGE_SIZE), Image.Resampling.LANCZOS)
|
102 |
+
file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], uploaded_file, img)
|
103 |
+
st.session_state['input_list_small'].append(file_path_small)
|
104 |
+
print(uploaded_file.name)
|
105 |
+
|
106 |
+
# Set the local images to the uploaded images
|
107 |
+
st.session_state.config['leafmachine']['project']['dir_images_local'] = st.session_state['dir_uploaded_images']
|
108 |
+
|
109 |
+
n_images = len([f for f in os.listdir(st.session_state.config['leafmachine']['project']['dir_images_local']) if os.path.isfile(os.path.join(st.session_state.config['leafmachine']['project']['dir_images_local'], f))])
|
110 |
+
st.session_state['processing_add_on'] = n_images
|
111 |
+
uploaded_files = None
|
112 |
+
st.session_state['uploader_idk'] += 1
|
113 |
+
st.info(f"Processing **{n_images}** images from {st.session_state.config['leafmachine']['project']['dir_images_local']}")
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
if st.session_state['input_list_small']:
|
118 |
+
if len(st.session_state['input_list_small']) > MAX_GALLERY_IMAGES:
|
119 |
+
# Only take the first 100 images from the list
|
120 |
+
images_to_display = st.session_state['input_list_small'][:MAX_GALLERY_IMAGES]
|
121 |
+
else:
|
122 |
+
# If there are less than 100 images, take them all
|
123 |
+
images_to_display = st.session_state['input_list_small']
|
124 |
+
st.image(images_to_display)
|
125 |
+
|
126 |
+
else:
|
127 |
+
st.session_state['view_local_gallery'] = st.toggle("View Image Gallery",)
|
128 |
+
|
129 |
+
if st.session_state['view_local_gallery'] and st.session_state['input_list_small'] and (st.session_state['dir_images_local_TEMP'] == st.session_state.config['leafmachine']['project']['dir_images_local']):
|
130 |
+
if MAX_GALLERY_IMAGES <= st.session_state['processing_add_on']:
|
131 |
+
info_txt = f"Showing {MAX_GALLERY_IMAGES} out of {st.session_state['processing_add_on']} images"
|
132 |
+
else:
|
133 |
+
info_txt = f"Showing {st.session_state['processing_add_on']} out of {st.session_state['processing_add_on']} images"
|
134 |
+
st.info(info_txt)
|
135 |
+
try:
|
136 |
+
st.image(st.session_state['input_list_small'], width=GALLERY_IMAGE_SIZE)
|
137 |
+
except:
|
138 |
+
pass
|
139 |
+
|
140 |
+
elif not st.session_state['view_local_gallery'] and st.session_state['input_list_small'] and (st.session_state['dir_images_local_TEMP'] == st.session_state.config['leafmachine']['project']['dir_images_local']):
|
141 |
+
pass
|
142 |
+
elif not st.session_state['view_local_gallery'] and not st.session_state['input_list_small'] and (st.session_state['dir_images_local_TEMP'] == st.session_state.config['leafmachine']['project']['dir_images_local']):
|
143 |
+
pass
|
144 |
+
elif st.session_state['input_list_small'] and (st.session_state['dir_images_local_TEMP'] != st.session_state.config['leafmachine']['project']['dir_images_local']):
|
145 |
+
dir_images_local = st.session_state.config['leafmachine']['project']['dir_images_local']
|
146 |
+
count_n_imgs = list_jpg_files(dir_images_local)
|
147 |
+
st.session_state['processing_add_on'] = count_n_imgs
|
148 |
+
# print(st.session_state['processing_add_on'])
|
149 |
+
st.session_state['dir_images_local_TEMP'] = st.session_state.config['leafmachine']['project']['dir_images_local']
|
150 |
+
print("rerun")
|
151 |
+
st.rerun()
|
152 |
+
|
153 |
+
|
154 |
+
def list_jpg_files(directory_path):
|
155 |
+
jpg_count = 0
|
156 |
+
clear_image_gallery()
|
157 |
+
st.session_state['input_list_small'] = []
|
158 |
+
|
159 |
+
if not os.path.isdir(directory_path):
|
160 |
+
return None
|
161 |
+
|
162 |
+
jpg_count = count_jpg_images(directory_path)
|
163 |
+
|
164 |
+
jpg_files = []
|
165 |
+
for root, dirs, files in os.walk(directory_path):
|
166 |
+
for file in files:
|
167 |
+
if file.lower().endswith('.jpg'):
|
168 |
+
jpg_files.append(os.path.join(root, file))
|
169 |
+
if len(jpg_files) == MAX_GALLERY_IMAGES:
|
170 |
+
break
|
171 |
+
if len(jpg_files) == MAX_GALLERY_IMAGES:
|
172 |
+
break
|
173 |
+
|
174 |
+
for simg in jpg_files:
|
175 |
|
176 |
+
simg2 = Image.open(simg)
|
177 |
+
simg2.thumbnail((GALLERY_IMAGE_SIZE, GALLERY_IMAGE_SIZE), Image.Resampling.LANCZOS)
|
178 |
+
file_path_small = save_uploaded_local(st.session_state['dir_uploaded_images_small'], simg, simg2)
|
179 |
+
st.session_state['input_list_small'].append(file_path_small)
|
180 |
+
return jpg_count
|
181 |
+
|
182 |
+
|
183 |
+
def count_jpg_images(directory_path):
|
184 |
+
if not os.path.isdir(directory_path):
|
185 |
+
return None
|
186 |
|
187 |
+
jpg_count = 0
|
188 |
+
for root, dirs, files in os.walk(directory_path):
|
189 |
+
for file in files:
|
190 |
+
if file.lower().endswith('.jpg'):
|
191 |
+
jpg_count += 1
|
192 |
|
193 |
+
return jpg_count
|
194 |
|
195 |
|
196 |
def create_download_button(zip_filepath, col):
|
197 |
with col:
|
198 |
+
labal_n_images = f"Download Results for {st.session_state['processing_add_on']} Images"
|
199 |
with open(zip_filepath, 'rb') as f:
|
200 |
bytes_io = BytesIO(f.read())
|
201 |
st.download_button(
|
202 |
+
label=labal_n_images,
|
203 |
type='primary',
|
204 |
data=bytes_io,
|
205 |
file_name=os.path.basename(zip_filepath),
|
|
|
229 |
st.info(f"Processing images from {os.path.join(st.session_state.dir_home,'demo','demo_images')}")
|
230 |
st.session_state.config['leafmachine']['project']['dir_images_local'] = os.path.join(st.session_state.dir_home,'demo','demo_images')
|
231 |
n_images = len([f for f in os.listdir(st.session_state.config['leafmachine']['project']['dir_images_local']) if os.path.isfile(os.path.join(st.session_state.config['leafmachine']['project']['dir_images_local'], f))])
|
232 |
+
st.session_state['processing_add_on'] = n_images
|
233 |
clear_image_gallery()
|
234 |
st.session_state['uploader_idk'] += 1
|
235 |
+
for file in os.listdir(st.session_state.config['leafmachine']['project']['dir_images_local']):
|
236 |
+
file_path = save_uploaded_file(os.path.join(st.session_state.dir_home,'demo','demo_images'), file)
|
237 |
+
st.session_state['input_list'].append(file_path)
|
238 |
+
|
239 |
+
img = Image.open(file_path)
|
240 |
+
img.thumbnail((GALLERY_IMAGE_SIZE, GALLERY_IMAGE_SIZE), Image.Resampling.LANCZOS)
|
241 |
+
file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], file, img)
|
242 |
+
st.session_state['input_list_small'].append(file_path_small)
|
243 |
|
244 |
|
245 |
def create_download_button_yaml(file_path, selected_yaml_file):
|
|
|
516 |
|
517 |
|
518 |
def does_private_file_exist():
|
519 |
+
dir_home = os.path.dirname(__file__)
|
520 |
path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
|
521 |
return os.path.exists(path_cfg_private)
|
522 |
|
|
|
720 |
|
721 |
|
722 |
def get_private_file():
|
723 |
+
dir_home = os.path.dirname(__file__)
|
724 |
path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
|
725 |
return get_cfg_from_full_path(path_cfg_private)
|
726 |
|
727 |
+
def blog_text_and_image(text=None, fullpath=None, width=700):
|
728 |
+
if text:
|
729 |
+
st.markdown(f"{text}")
|
730 |
+
if fullpath:
|
731 |
+
st.session_state.logo = Image.open(fullpath)
|
732 |
+
st.image(st.session_state.logo, width=width)
|
733 |
+
|
734 |
+
def blog_text(text_bold, text):
|
735 |
+
st.markdown(f"- **{text_bold}**{text}")
|
736 |
+
def blog_text_plain(text_bold, text):
|
737 |
+
st.markdown(f"**{text_bold}** {text}")
|
738 |
+
|
739 |
+
def create_private_file():
|
740 |
+
section_left = 2
|
741 |
+
section_mid = 6
|
742 |
+
section_right = 2
|
743 |
+
|
744 |
st.session_state.proceed_to_main = False
|
745 |
st.title("VoucherVision")
|
746 |
+
_, col_private,__= st.columns([section_left,section_mid, section_right])
|
747 |
+
|
748 |
+
|
749 |
|
750 |
if st.session_state.private_file:
|
751 |
cfg_private = get_private_file()
|
|
|
755 |
cfg_private['openai']['OPENAI_API_KEY'] =''
|
756 |
|
757 |
cfg_private['openai_azure'] = {}
|
758 |
+
cfg_private['openai_azure']['OPENAI_API_KEY_AZURE'] = ''
|
759 |
+
cfg_private['openai_azure']['OPENAI_API_VERSION'] = ''
|
760 |
+
cfg_private['openai_azure']['OPENAI_API_BASE'] =''
|
761 |
+
cfg_private['openai_azure']['OPENAI_ORGANIZATION'] =''
|
762 |
+
cfg_private['openai_azure']['OPENAI_API_TYPE'] =''
|
763 |
+
|
764 |
+
cfg_private['google'] = {}
|
765 |
+
cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS'] =''
|
766 |
+
cfg_private['google']['GOOGLE_PALM_API'] =''
|
767 |
+
cfg_private['google']['GOOGLE_PROJECT_ID'] =''
|
768 |
+
cfg_private['google']['GOOGLE_LOCATION'] =''
|
769 |
+
|
770 |
+
cfg_private['mistral'] = {}
|
771 |
+
cfg_private['mistral']['MISTRAL_API_KEY'] =''
|
772 |
+
|
773 |
+
cfg_private['here'] = {}
|
774 |
+
cfg_private['here']['APP_ID'] =''
|
775 |
+
cfg_private['here']['API_KEY'] =''
|
776 |
+
|
777 |
+
cfg_private['open_cage_geocode'] = {}
|
778 |
+
cfg_private['open_cage_geocode']['API_KEY'] =''
|
779 |
|
780 |
|
781 |
with col_private:
|
782 |
st.header("Set API keys")
|
|
|
783 |
st.warning("To commit changes to API keys you must press the 'Set API Keys' button at the bottom of the page.")
|
784 |
st.write("Before using VoucherVision you must set your API keys. All keys are stored locally on your computer and are never made public.")
|
785 |
st.write("API keys are stored in `../VoucherVision/PRIVATE_DATA.yaml`.")
|
786 |
+
st.write("Deleting this file will allow you to reset API keys. Alternatively, you can edit the keys in the user interface or by manually editing the `.yaml` file in a text editor.")
|
787 |
st.write("Leave keys blank if you do not intend to use that service.")
|
788 |
|
789 |
st.write("---")
|
790 |
+
st.subheader("Google Vision (*Required*) / Google PaLM 2 / Google Gemini")
|
791 |
st.markdown("VoucherVision currently uses [Google Vision API](https://cloud.google.com/vision/docs/ocr) for OCR. Generating an API key for this is more involved than the others. [Please carefully follow the instructions outlined here to create and setup your account.](https://cloud.google.com/vision/docs/setup) ")
|
792 |
+
st.markdown("""Once your account is created, [visit this page](https://console.cloud.google.com) and create a project. Then follow these instructions:""")
|
793 |
+
|
794 |
+
with st.expander("**View Google API Instructions**"):
|
795 |
+
|
796 |
+
blog_text_and_image(text="Select your project, then in the search bar, search for `vertex ai` and select the option in the photo below.",
|
797 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_00.png'))
|
798 |
+
|
799 |
+
blog_text_and_image(text="On the main overview page, click `Enable All Recommended APIs`. Sometimes this button may be hidden. In that case, enable all of the suggested APIs listed on this page.",
|
800 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_0.png'))
|
801 |
+
|
802 |
+
blog_text_and_image(text="Sometimes this button may be hidden. In that case, enable all of the suggested APIs listed on this page.",
|
803 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_2.png'))
|
804 |
+
|
805 |
+
blog_text_and_image(text="Make sure that all APIs are enabled.",
|
806 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_1.png'))
|
807 |
+
|
808 |
+
blog_text_and_image(text="Find the `Vision AI API` service and go to its page.",
|
809 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_3.png'))
|
810 |
+
|
811 |
+
blog_text_and_image(text="Find the `Vision AI API` service and go to its page. This is the API service required to use OCR in VoucherVision and must be enabled.",
|
812 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_6.png'))
|
813 |
+
|
814 |
+
blog_text_and_image(text="You can also search for the Vertex AI Vision service.",
|
815 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_4.png'))
|
816 |
+
|
817 |
+
blog_text_and_image(text=None,
|
818 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_5.png'))
|
819 |
+
|
820 |
+
st.subheader("Getting a Google JSON authentication key")
|
821 |
+
st.write("Google uses a JSON file to store additional authentication information. Save this file in a safe, private location and assign the `GOOGLE_APPLICATION_CREDENTIALS` value to the file path. For Hugging Face, copy the contents of the JSON file including the `\{\}` and paste it as the secret value.")
|
822 |
+
st.write("To download your JSON key...")
|
823 |
+
blog_text_and_image(text="Open the navigation menu. Click on the hamburger menu (three horizontal lines) in the top left corner. Go to IAM & Admin. ",
|
824 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_7.png'),width=300)
|
825 |
+
|
826 |
+
blog_text_and_image(text="In the navigation pane, hover over `IAM & Admin` and then click on `Service accounts`.",
|
827 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_8.png'))
|
828 |
+
|
829 |
+
blog_text_and_image(text="Find the default Compute Engine service account, select it.",
|
830 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_9.png'))
|
831 |
+
|
832 |
+
blog_text_and_image(text="Click `Add Key`.",
|
833 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_10.png'))
|
834 |
+
|
835 |
+
blog_text_and_image(text="Select `JSON` and click create. This will download your key. Store this in a safe location. The file path to this safe location is the value that you enter into the `GOOGLE_APPLICATION_CREDENTIALS` value.",
|
836 |
+
fullpath=os.path.join(st.session_state.dir_home, 'demo','google','google_api_11.png'))
|
837 |
+
|
838 |
+
blog_text(text_bold="Store Safely", text=": This file contains sensitive data that can be used to authenticate and bill your Google Cloud account. Never commit it to public repositories or expose it in any way. Always keep it safe and secure.")
|
839 |
+
|
840 |
+
st.write("Below is an example of the JSON key.")
|
841 |
+
st.json({
|
842 |
+
"type": "service_account",
|
843 |
+
"project_id": "NAME OF YOUR PROJECT",
|
844 |
+
"private_key_id": "XXXXXXXXXXXXXXXXXXXXXXXX",
|
845 |
+
"private_key": "-----BEGIN PRIVATE KEY-----\naaaaaaaaaaa\n-----END PRIVATE KEY-----\n",
|
846 |
+
"client_email": "EMAIL-ADDRESS@developer.gserviceaccount.com",
|
847 |
+
"client_id": "ID NUMBER",
|
848 |
+
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
849 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
850 |
+
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
851 |
+
"client_x509_cert_url": "A LONG URL",
|
852 |
+
"universe_domain": "googleapis.com"
|
853 |
+
})
|
854 |
+
google_application_credentials = st.text_input(label = 'Full path to Google Cloud JSON API key file', value = cfg_private['google'].get('GOOGLE_APPLICATION_CREDENTIALS', ''),
|
855 |
+
placeholder = 'e.g. C:/Documents/Secret_Files/google_API/application_default_credentials.json',
|
856 |
+
help ="This API Key is in the form of a JSON file. Please save the JSON file in a safe directory. DO NOT store the JSON key inside of the VoucherVision directory.",
|
857 |
+
type='password')
|
858 |
+
google_project_location = st.text_input(label = 'Google project location', value = cfg_private['google'].get('GOOGLE_LOCATION', ''),
|
859 |
+
placeholder = 'e.g. us-central1',
|
860 |
+
help ="This is the location of where your Google services are operating.",
|
861 |
+
type='password')
|
862 |
+
google_project_id = st.text_input(label = 'Google project ID', value = cfg_private['google'].get('GOOGLE_PROJECT_ID', ''),
|
863 |
+
placeholder = 'e.g. my-project-name',
|
864 |
+
help ="This is the value in the `project_id` field in your JSON key.",
|
865 |
+
type='password')
|
866 |
|
867 |
|
868 |
st.write("---")
|
869 |
st.subheader("OpenAI")
|
870 |
st.markdown("API key for first-party OpenAI API. Create an account with OpenAI [here](https://platform.openai.com/signup), then create an API key [here](https://platform.openai.com/account/api-keys).")
|
871 |
+
openai_api_key = st.text_input("openai_api_key", cfg_private['openai'].get('OPENAI_API_KEY', ''),
|
|
|
|
|
|
|
872 |
help='The actual API key. Likely to be a string of 2 character, a dash, and then a 48-character string: sk-XXXXXXXX...',
|
873 |
placeholder = 'e.g. sk-XXXXXXXX...',
|
874 |
type='password')
|
875 |
+
|
|
|
876 |
|
877 |
st.write("---")
|
878 |
st.subheader("OpenAI - Azure")
|
879 |
st.markdown("This version OpenAI relies on Azure servers directly as is intended for private enterprise instances of OpenAI's services, such as [UM-GPT](https://its.umich.edu/computing/ai). Administrators will provide you with the following information.")
|
880 |
+
azure_openai_api_version = st.text_input("OPENAI_API_VERSION", cfg_private['openai_azure'].get('OPENAI_API_VERSION', ''),
|
881 |
help='API Version e.g. "2023-05-15"',
|
882 |
placeholder = 'e.g. 2023-05-15',
|
883 |
type='password')
|
884 |
+
azure_openai_api_key = st.text_input("OPENAI_API_KEY_AZURE", cfg_private['openai_azure'].get('OPENAI_API_KEY_AZURE', ''),
|
885 |
+
help='The actual API key. Likely to be a 32-character string. This might also be called "endpoint."',
|
886 |
placeholder = 'e.g. 12333333333333333333333333333332',
|
887 |
type='password')
|
888 |
+
azure_openai_api_base = st.text_input("OPENAI_API_BASE", cfg_private['openai_azure'].get('OPENAI_API_BASE', ''),
|
889 |
help='The base url for the API e.g. "https://api.umgpt.umich.edu/azure-openai-api"',
|
890 |
placeholder = 'e.g. https://api.umgpt.umich.edu/azure-openai-api',
|
891 |
type='password')
|
892 |
+
azure_openai_organization = st.text_input("OPENAI_ORGANIZATION", cfg_private['openai_azure'].get('OPENAI_ORGANIZATION', ''),
|
893 |
+
help='Your organization code. Likely a short string.',
|
894 |
placeholder = 'e.g. 123456',
|
895 |
type='password')
|
896 |
+
azure_openai_api_type = st.text_input("OPENAI_API_TYPE", cfg_private['openai_azure'].get('OPENAI_API_TYPE', ''),
|
897 |
help='The API type. Typically "azure"',
|
898 |
placeholder = 'e.g. azure',
|
899 |
type='password')
|
900 |
+
|
901 |
+
# st.write("---")
|
902 |
+
# st.subheader("Google PaLM 2 (Deprecated)")
|
903 |
+
# st.write("Plea")
|
904 |
+
# st.markdown('Follow these [instructions](https://developers.generativeai.google/tutorials/setup) to generate an API key for PaLM 2. You may need to also activate an account with [MakerSuite](https://makersuite.google.com/app/apikey) and enable "early access." If this is deprecated, then use the full Google API instructions above.')
|
905 |
+
|
906 |
+
# google_palm = st.text_input("Google PaLM 2 API Key", cfg_private['google'].get('GOOGLE_PALM_API', ''),
|
907 |
+
# help='The MakerSuite API key e.g. a 32-character string',
|
908 |
+
# placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
|
909 |
+
# type='password')
|
910 |
+
|
911 |
+
|
912 |
st.write("---")
|
913 |
+
st.subheader("MistralAI")
|
914 |
+
st.markdown('Follow these [instructions](https://platform.here.com/sign-up?step=verify-identity) to generate an API key for HERE.')
|
915 |
+
mistral_API_KEY = st.text_input("MistralAI API Key", cfg_private['mistral'].get('MISTRAL_API_KEY', ''),
|
916 |
+
help='e.g. a 32-character string',
|
|
|
|
|
|
|
917 |
placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
|
918 |
type='password')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
919 |
|
920 |
+
|
921 |
+
st.write("---")
|
922 |
+
st.subheader("HERE Geocoding")
|
923 |
+
st.markdown('Follow these [instructions](https://platform.here.com/sign-up?step=verify-identity) to generate an API key for HERE.')
|
924 |
+
hre_APP_ID = st.text_input("HERE Geocoding App ID", cfg_private['here'].get('APP_ID', ''),
|
925 |
+
help='e.g. a 32-character string',
|
926 |
+
placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
|
927 |
+
type='password')
|
928 |
+
hre_API_KEY = st.text_input("HERE Geocoding API Key", cfg_private['here'].get('API_KEY', ''),
|
929 |
+
help='e.g. a 32-character string',
|
930 |
+
placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
|
931 |
+
type='password')
|
932 |
+
|
933 |
|
934 |
|
935 |
st.button("Set API Keys",type='primary', on_click=save_changes_to_API_keys, args=[cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key,
|
936 |
+
azure_openai_api_base,azure_openai_organization,azure_openai_api_type,
|
937 |
+
google_application_credentials, google_project_location, google_project_id,
|
938 |
+
mistral_API_KEY, hre_APP_ID, hre_API_KEY])
|
939 |
if st.button('Proceed to VoucherVision'):
|
940 |
+
st.session_state.private_file = does_private_file_exist()
|
941 |
st.session_state.proceed_to_private = False
|
942 |
st.session_state.proceed_to_main = True
|
943 |
+
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
944 |
|
945 |
|
946 |
def save_changes_to_API_keys(cfg_private,openai_api_key,azure_openai_api_version,azure_openai_api_key,
|
947 |
+
azure_openai_api_base,azure_openai_organization,azure_openai_api_type,
|
948 |
+
google_application_credentials, google_project_location, google_project_id,
|
949 |
+
mistral_API_KEY, hre_APP_ID, hre_API_KEY):
|
950 |
+
|
951 |
# Update the configuration dictionary with the new values
|
952 |
cfg_private['openai']['OPENAI_API_KEY'] = openai_api_key
|
953 |
|
954 |
+
cfg_private['openai_azure']['OPENAI_API_VERSION'] = azure_openai_api_version
|
955 |
+
cfg_private['openai_azure']['OPENAI_API_KEY_AZURE'] = azure_openai_api_key
|
956 |
+
cfg_private['openai_azure']['OPENAI_API_BASE'] = azure_openai_api_base
|
957 |
+
cfg_private['openai_azure']['OPENAI_ORGANIZATION'] = azure_openai_organization
|
958 |
+
cfg_private['openai_azure']['OPENAI_API_TYPE'] = azure_openai_api_type
|
959 |
+
|
960 |
+
cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS'] = google_application_credentials
|
961 |
+
cfg_private['google']['GOOGLE_PROJECT_ID'] = google_project_location
|
962 |
+
cfg_private['google']['GOOGLE_LOCATION'] = google_project_id
|
963 |
|
964 |
+
cfg_private['mistral']['MISTRAL_API_KEY'] = mistral_API_KEY
|
965 |
|
966 |
+
cfg_private['here']['APP_ID'] = hre_APP_ID
|
967 |
+
cfg_private['here']['API_KEY'] = hre_API_KEY
|
968 |
# Call the function to write the updated configuration to the YAML file
|
969 |
write_config_file(cfg_private, st.session_state.dir_home, filename="PRIVATE_DATA.yaml")
|
970 |
+
# st.session_state.private_file = does_private_file_exist()
|
971 |
|
972 |
# Function to load a YAML file and update session_state
|
973 |
def load_prompt_yaml(filename):
|
|
|
1711 |
# st.subheader('Run VoucherVision')
|
1712 |
N_STEPS = 6
|
1713 |
|
1714 |
+
# if st.session_state.is_hf:
|
1715 |
+
# count_n_imgs = determine_n_images()
|
1716 |
+
# if count_n_imgs > 0:
|
1717 |
+
# st.session_state['processing_add_on'] = count_n_imgs
|
1718 |
+
# else:
|
1719 |
+
# st.session_state['processing_add_on'] = 0
|
1720 |
|
1721 |
if check_if_usable(is_hf=st.session_state['is_hf']):
|
1722 |
+
b_text = f"Start Processing {st.session_state['processing_add_on']} Images" if st.session_state['processing_add_on'] > 1 else f"Start Processing {st.session_state['processing_add_on']} Image"
|
1723 |
+
if st.session_state['processing_add_on'] == 0:
|
1724 |
+
b_text = f"Start Processing"
|
1725 |
+
if st.button(b_text, type='primary',use_container_width=True):
|
1726 |
st.session_state['formatted_json'] = {}
|
1727 |
st.session_state['formatted_json_WFO'] = {}
|
1728 |
st.session_state['formatted_json_GEO'] = {}
|
|
|
1877 |
|
1878 |
|
1879 |
|
1880 |
+
def content_project_settings(col):
|
1881 |
+
### Project
|
1882 |
+
with col:
|
1883 |
+
st.header('Project Settings')
|
1884 |
+
|
|
|
1885 |
st.session_state.config['leafmachine']['project']['run_name'] = st.text_input("Run name", st.session_state.config['leafmachine']['project'].get('run_name', ''),key=63456)
|
1886 |
st.session_state.config['leafmachine']['project']['dir_output'] = st.text_input("Output directory", st.session_state.config['leafmachine']['project'].get('dir_output', ''))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1887 |
|
1888 |
|
1889 |
|
|
|
1890 |
def content_llm_cost():
|
1891 |
st.write("---")
|
1892 |
st.header('LLM Cost Calculator')
|
|
|
1996 |
st.session_state['API_rechecked'] = True
|
1997 |
st.rerun()
|
1998 |
# with col_llm_2c:
|
1999 |
+
if not st.session_state.is_hf:
|
2000 |
+
if st.button("Edit API Keys"):
|
2001 |
+
st.session_state.proceed_to_private = True
|
2002 |
+
st.rerun()
|
2003 |
+
|
2004 |
|
2005 |
|
2006 |
|
|
|
2056 |
st.session_state.config['leafmachine']['do_create_OCR_helper_image'] = do_create_OCR_helper_image
|
2057 |
|
2058 |
|
2059 |
+
|
|
|
2060 |
|
2061 |
# Get the current OCR option from session state
|
2062 |
OCR_option = st.session_state.config['leafmachine']['project']['OCR_option']
|
|
|
2087 |
OCR_option = 'both'
|
2088 |
else:
|
2089 |
raise
|
2090 |
+
|
2091 |
+
st.write("Supplement Google Vision OCR with trOCR (handwriting OCR) using `microsoft/trocr-base-handwritten`. This option requires Google Vision API and a GPU.")
|
2092 |
+
do_use_trOCR = st.checkbox("Enable trOCR", value=st.session_state.config['leafmachine']['project']['do_use_trOCR'],disabled=st.session_state['lacks_GPU'])
|
2093 |
+
st.session_state.config['leafmachine']['project']['do_use_trOCR'] = do_use_trOCR
|
2094 |
+
|
2095 |
|
2096 |
st.session_state.config['leafmachine']['project']['OCR_option'] = OCR_option
|
2097 |
st.markdown("Below is an example of what the LLM would see given the choice of OCR ensemble. One, two, or three version of OCR can be fed into the LLM prompt. Typically, 'printed + handwritten' works well. If you have a GPU then you can enable trOCR.")
|
|
|
2387 |
content_header()
|
2388 |
|
2389 |
|
2390 |
+
col_input, col_gallery = st.columns([4,8])
|
2391 |
+
content_project_settings(col_input)
|
2392 |
+
content_input_images(col_input, col_gallery)
|
2393 |
|
2394 |
+
# if st.session_state['is_hf']:
|
2395 |
+
# content_project_settings()
|
2396 |
+
# content_input_images_hf()
|
2397 |
+
# else:
|
2398 |
+
# col1, col2 = st.columns([1,1])
|
2399 |
+
# with col1:
|
2400 |
+
# content_project_settings()
|
2401 |
+
# with col2:
|
2402 |
+
# content_input_images()
|
2403 |
|
2404 |
|
|
|
2405 |
col3, col4 = st.columns([1,1])
|
2406 |
with col3:
|
2407 |
content_prompt_and_llm_version()
|
|
|
2417 |
|
2418 |
|
2419 |
|
|
|
|
|
|
|
|
|
|
|
2420 |
|
|
|
|
|
|
|
2421 |
|
2422 |
|
2423 |
#################################################################################################################################################
|
|
|
2425 |
#################################################################################################################################################
|
2426 |
|
2427 |
|
2428 |
+
|
2429 |
+
|
2430 |
+
|
|
|
|
|
2431 |
|
2432 |
|
2433 |
if st.session_state['is_hf']:
|
2434 |
if 'proceed_to_main' not in st.session_state:
|
2435 |
st.session_state.proceed_to_main = True
|
|
|
2436 |
|
2437 |
if 'proceed_to_private' not in st.session_state:
|
2438 |
st.session_state.proceed_to_private = False
|
|
|
2439 |
|
2440 |
if 'private_file' not in st.session_state:
|
2441 |
st.session_state.private_file = True
|
|
|
2442 |
|
2443 |
else:
|
2444 |
if 'proceed_to_main' not in st.session_state:
|
2445 |
+
st.session_state.proceed_to_main = False # New state variable to control the flow
|
2446 |
+
|
2447 |
if 'private_file' not in st.session_state:
|
2448 |
st.session_state.private_file = does_private_file_exist()
|
2449 |
if st.session_state.private_file:
|
2450 |
st.session_state.proceed_to_main = True
|
|
|
|
|
2451 |
|
2452 |
if 'proceed_to_private' not in st.session_state:
|
2453 |
st.session_state.proceed_to_private = False # New state variable to control the flow
|
2454 |
+
|
2455 |
|
2456 |
|
2457 |
if 'proceed_to_build_llm_prompt' not in st.session_state:
|
2458 |
st.session_state.proceed_to_build_llm_prompt = False # New state variable to control the flow
|
2459 |
|
2460 |
+
|
2461 |
if 'processing_add_on' not in st.session_state:
|
2462 |
+
st.session_state['processing_add_on'] = 0
|
2463 |
+
|
2464 |
|
2465 |
if 'formatted_json' not in st.session_state:
|
2466 |
st.session_state['formatted_json'] = None
|
|
|
2469 |
if 'formatted_json_GEO' not in st.session_state:
|
2470 |
st.session_state['formatted_json_GEO'] = None
|
2471 |
|
2472 |
+
|
2473 |
if 'lacks_GPU' not in st.session_state:
|
2474 |
st.session_state['lacks_GPU'] = not torch.cuda.is_available()
|
2475 |
|
2476 |
+
|
2477 |
if 'API_key_validation' not in st.session_state:
|
2478 |
st.session_state['API_key_validation'] = False
|
2479 |
if 'present_annotations' not in st.session_state:
|
|
|
2487 |
if 'API_rechecked' not in st.session_state:
|
2488 |
st.session_state['API_rechecked'] = False
|
2489 |
|
2490 |
+
|
2491 |
if 'json_report' not in st.session_state:
|
2492 |
st.session_state['json_report'] = False
|
2493 |
if 'hold_output' not in st.session_state:
|
2494 |
st.session_state['hold_output'] = False
|
2495 |
|
|
|
|
|
|
|
2496 |
|
2497 |
+
|
2498 |
+
|
|
|
2499 |
|
2500 |
if 'cost_openai' not in st.session_state:
|
2501 |
st.session_state['cost_openai'] = None
|
|
|
2508 |
if 'cost_local' not in st.session_state:
|
2509 |
st.session_state['cost_local'] = None
|
2510 |
|
2511 |
+
|
2512 |
if 'settings_filename' not in st.session_state:
|
2513 |
st.session_state['settings_filename'] = None
|
2514 |
if 'loaded_settings_filename' not in st.session_state:
|
|
|
2516 |
if 'zip_filepath' not in st.session_state:
|
2517 |
st.session_state['zip_filepath'] = None
|
2518 |
|
2519 |
+
|
2520 |
# Initialize session_state variables if they don't exist
|
2521 |
if 'prompt_info' not in st.session_state:
|
2522 |
st.session_state['prompt_info'] = {}
|
2523 |
if 'rules' not in st.session_state:
|
2524 |
st.session_state['rules'] = {}
|
2525 |
+
|
|
|
|
|
|
|
|
|
2526 |
|
2527 |
# These are the fields that are in SLTPvA that are not required by another parsing valication function:
|
2528 |
# "identifiedBy": "M.W. Lyon, Jr.",
|
|
|
2533 |
# "degreeOfEstablishment": "",
|
2534 |
# "minimumElevationInMeters": "",
|
2535 |
# "maximumElevationInMeters": ""
|
2536 |
+
if 'required_fields' not in st.session_state:
|
2537 |
+
st.session_state['required_fields'] = ['catalogNumber','order','family','scientificName',
|
2538 |
+
'scientificNameAuthorship','genus','subgenus','specificEpithet','infraspecificEpithet',
|
2539 |
+
'verbatimEventDate','eventDate',
|
2540 |
+
'country','stateProvince','county','municipality','locality','decimalLatitude','decimalLongitude','verbatimCoordinates',]
|
2541 |
|
2542 |
|
2543 |
if 'proceed_to_build_llm_prompt' not in st.session_state:
|
|
|
2551 |
if 'proceed_to_space_saver' not in st.session_state:
|
2552 |
st.session_state.proceed_to_space_saver = False
|
2553 |
|
2554 |
+
|
2555 |
#################################################################################################################################################
|
2556 |
# Main ##########################################################################################################################################
|
2557 |
#################################################################################################################################################
|
2558 |
+
if st.session_state['is_hf']:
|
2559 |
+
if st.session_state.proceed_to_build_llm_prompt:
|
2560 |
+
build_LLM_prompt_config()
|
2561 |
+
elif st.session_state.proceed_to_main:
|
2562 |
+
main()
|
2563 |
+
else:
|
2564 |
+
if not st.session_state.private_file:
|
2565 |
+
create_private_file()
|
2566 |
+
elif st.session_state.proceed_to_build_llm_prompt:
|
2567 |
+
build_LLM_prompt_config()
|
2568 |
+
elif st.session_state.proceed_to_private and not st.session_state['is_hf']:
|
2569 |
+
create_private_file()
|
2570 |
+
elif st.session_state.proceed_to_main:
|
2571 |
+
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
2572 |
|
2573 |
|
|
|
|
|
|
|
|
|
|
|
2574 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2575 |
|
2576 |
|
2577 |
|
demo/demo_images/{MICH_7574789_Cyperaceae_Carex_scoparia.jpg → MICH_16205594_Poaceae_Jouvea_pilosa.jpg}
RENAMED
File without changes
|
demo/google/google_api_0.PNG
ADDED
Git LFS Details
|
demo/google/google_api_00.PNG
ADDED
Git LFS Details
|
demo/google/google_api_1.PNG
ADDED
Git LFS Details
|
demo/google/google_api_10.PNG
ADDED
Git LFS Details
|
demo/google/google_api_11.PNG
ADDED
Git LFS Details
|
demo/google/google_api_2.PNG
ADDED
Git LFS Details
|
demo/google/google_api_3.PNG
ADDED
Git LFS Details
|
demo/google/google_api_4.PNG
ADDED
Git LFS Details
|
demo/google/google_api_5.PNG
ADDED
Git LFS Details
|
demo/google/google_api_6.PNG
ADDED
Git LFS Details
|
demo/google/google_api_7.PNG
ADDED
Git LFS Details
|
demo/google/google_api_8.PNG
ADDED
Git LFS Details
|
demo/google/google_api_9.PNG
ADDED
Git LFS Details
|
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|
run_VoucherVision.py
CHANGED
@@ -9,6 +9,8 @@ import os, sys
|
|
9 |
# st.write("filename:", uploaded_file.name)
|
10 |
# st.write(bytes_data)
|
11 |
|
|
|
|
|
12 |
|
13 |
def resolve_path(path):
|
14 |
resolved_path = os.path.abspath(os.path.join(os.getcwd(), path))
|
@@ -29,7 +31,7 @@ if __name__ == "__main__":
|
|
29 |
# "--server.port=8545",
|
30 |
"--server.port=8546",
|
31 |
# Toggle below for HF vs Local
|
32 |
-
"--is_hf=1",
|
33 |
# "--is_hf=0",
|
34 |
]
|
35 |
sys.exit(stcli.main())
|
|
|
9 |
# st.write("filename:", uploaded_file.name)
|
10 |
# st.write(bytes_data)
|
11 |
|
12 |
+
# pip install protobuf==3.20.0
|
13 |
+
|
14 |
|
15 |
def resolve_path(path):
|
16 |
resolved_path = os.path.abspath(os.path.join(os.getcwd(), path))
|
|
|
31 |
# "--server.port=8545",
|
32 |
"--server.port=8546",
|
33 |
# Toggle below for HF vs Local
|
34 |
+
# "--is_hf=1",
|
35 |
# "--is_hf=0",
|
36 |
]
|
37 |
sys.exit(stcli.main())
|
vouchervision/API_validation.py
CHANGED
@@ -7,6 +7,9 @@ from vertexai.language_models import TextGenerationModel
|
|
7 |
from vertexai.preview.generative_models import GenerativeModel
|
8 |
from google.cloud import vision
|
9 |
from google.cloud import vision_v1p3beta1 as vision_beta
|
|
|
|
|
|
|
10 |
|
11 |
from datetime import datetime
|
12 |
import google.generativeai as genai
|
@@ -57,7 +60,7 @@ class APIvalidation:
|
|
57 |
model = AzureChatOpenAI(
|
58 |
deployment_name = 'gpt-35-turbo',#'gpt-35-turbo',
|
59 |
openai_api_version = self.cfg_private['openai_azure']['OPENAI_API_VERSION'],
|
60 |
-
openai_api_key = self.cfg_private['openai_azure']['
|
61 |
azure_endpoint = self.cfg_private['openai_azure']['OPENAI_API_BASE'],
|
62 |
openai_organization = self.cfg_private['openai_azure']['OPENAI_ORGANIZATION'],
|
63 |
)
|
@@ -171,7 +174,8 @@ class APIvalidation:
|
|
171 |
|
172 |
|
173 |
def check_google_vertex_genai_api_key(self):
|
174 |
-
results = {"palm2": False, "gemini": False}
|
|
|
175 |
|
176 |
try:
|
177 |
model = TextGenerationModel.from_pretrained("text-bison@001")
|
@@ -186,6 +190,24 @@ class APIvalidation:
|
|
186 |
except Exception as e:
|
187 |
# print(f"palm2 fail2 [{e}]")
|
188 |
print(f"palm2 fail2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
|
191 |
try:
|
@@ -238,7 +260,7 @@ class APIvalidation:
|
|
238 |
k_opencage = os.getenv('OPENCAGE_API_KEY')
|
239 |
else:
|
240 |
k_OPENAI_API_KEY = self.cfg_private['openai']['OPENAI_API_KEY']
|
241 |
-
k_openai_azure = self.cfg_private['openai_azure']['
|
242 |
|
243 |
k_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
|
244 |
k_location = self.cfg_private['google']['GOOGLE_LOCATION']
|
@@ -295,6 +317,10 @@ class APIvalidation:
|
|
295 |
present_keys.append('Palm2 (Valid)')
|
296 |
else:
|
297 |
present_keys.append('Palm2 (Invalid)')
|
|
|
|
|
|
|
|
|
298 |
if google_results['gemini']:
|
299 |
present_keys.append('Gemini (Valid)')
|
300 |
else:
|
|
|
7 |
from vertexai.preview.generative_models import GenerativeModel
|
8 |
from google.cloud import vision
|
9 |
from google.cloud import vision_v1p3beta1 as vision_beta
|
10 |
+
# from langchain_google_genai import ChatGoogleGenerativeAI
|
11 |
+
from langchain_google_vertexai import VertexAI
|
12 |
+
|
13 |
|
14 |
from datetime import datetime
|
15 |
import google.generativeai as genai
|
|
|
60 |
model = AzureChatOpenAI(
|
61 |
deployment_name = 'gpt-35-turbo',#'gpt-35-turbo',
|
62 |
openai_api_version = self.cfg_private['openai_azure']['OPENAI_API_VERSION'],
|
63 |
+
openai_api_key = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE'],
|
64 |
azure_endpoint = self.cfg_private['openai_azure']['OPENAI_API_BASE'],
|
65 |
openai_organization = self.cfg_private['openai_azure']['OPENAI_ORGANIZATION'],
|
66 |
)
|
|
|
174 |
|
175 |
|
176 |
def check_google_vertex_genai_api_key(self):
|
177 |
+
results = {"palm2": False, "gemini": False, "palm2_langchain": False}
|
178 |
+
|
179 |
|
180 |
try:
|
181 |
model = TextGenerationModel.from_pretrained("text-bison@001")
|
|
|
190 |
except Exception as e:
|
191 |
# print(f"palm2 fail2 [{e}]")
|
192 |
print(f"palm2 fail2")
|
193 |
+
|
194 |
+
try:
|
195 |
+
# https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm
|
196 |
+
# os.environ['GOOGLE_API_KEY'] = "AIzaSyAHOH1w1qV7C3jS4W7QFyoaTGUwZIgS5ig"
|
197 |
+
# genai.configure(api_key='AIzaSyC8xvu6t9fb5dTah3hpgg_rwwR5G5kianI')
|
198 |
+
# model = ChatGoogleGenerativeAI(model="text-bison@001")
|
199 |
+
model = VertexAI(model="text-bison@001", max_output_tokens=10)
|
200 |
+
response = model.predict("Hello")
|
201 |
+
test_response_palm2 = response
|
202 |
+
if test_response_palm2:
|
203 |
+
results["palm2_langchain"] = True
|
204 |
+
print(f"palm2_langchain pass [{test_response_palm2}]")
|
205 |
+
else:
|
206 |
+
print(f"palm2_langchain fail [{test_response_palm2}]")
|
207 |
+
|
208 |
+
except Exception as e:
|
209 |
+
print(f"palm2 fail2 [{e}]")
|
210 |
+
print(f"palm2_langchain fail2")
|
211 |
|
212 |
|
213 |
try:
|
|
|
260 |
k_opencage = os.getenv('OPENCAGE_API_KEY')
|
261 |
else:
|
262 |
k_OPENAI_API_KEY = self.cfg_private['openai']['OPENAI_API_KEY']
|
263 |
+
k_openai_azure = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE']
|
264 |
|
265 |
k_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
|
266 |
k_location = self.cfg_private['google']['GOOGLE_LOCATION']
|
|
|
317 |
present_keys.append('Palm2 (Valid)')
|
318 |
else:
|
319 |
present_keys.append('Palm2 (Invalid)')
|
320 |
+
if google_results['palm2_langchain']:
|
321 |
+
present_keys.append('Palm2 LangChain (Valid)')
|
322 |
+
else:
|
323 |
+
present_keys.append('Palm2 LangChain (Invalid)')
|
324 |
if google_results['gemini']:
|
325 |
present_keys.append('Gemini (Valid)')
|
326 |
else:
|
vouchervision/LLM_GoogleGemini.py
CHANGED
@@ -7,6 +7,7 @@ from langchain.schema import HumanMessage
|
|
7 |
from langchain.prompts import PromptTemplate
|
8 |
from langchain_core.output_parsers import JsonOutputParser
|
9 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
|
10 |
|
11 |
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
12 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
@@ -74,9 +75,13 @@ class GoogleGeminiHandler:
|
|
74 |
|
75 |
def _build_model_chain_parser(self):
|
76 |
# Instantiate the LLM class for Google Gemini
|
77 |
-
self.llm_model = ChatGoogleGenerativeAI(model='gemini-pro',
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
80 |
# Set up the retry parser with the runnable
|
81 |
self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.llm_model, max_retries=self.MAX_RETRIES)
|
82 |
# Prepare the chain
|
|
|
7 |
from langchain.prompts import PromptTemplate
|
8 |
from langchain_core.output_parsers import JsonOutputParser
|
9 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
10 |
+
from langchain_google_vertexai import VertexAI
|
11 |
|
12 |
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
13 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
|
|
75 |
|
76 |
def _build_model_chain_parser(self):
|
77 |
# Instantiate the LLM class for Google Gemini
|
78 |
+
# self.llm_model = ChatGoogleGenerativeAI(model='gemini-pro',
|
79 |
+
# max_output_tokens=self.config.get('max_output_tokens'),
|
80 |
+
# top_p=self.config.get('top_p'))
|
81 |
+
self.llm_model = VertexAI(model='gemini-pro',
|
82 |
+
max_output_tokens=self.config.get('max_output_tokens'),
|
83 |
+
top_p=self.config.get('top_p'))
|
84 |
+
|
85 |
# Set up the retry parser with the runnable
|
86 |
self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.llm_model, max_retries=self.MAX_RETRIES)
|
87 |
# Prepare the chain
|
vouchervision/LLM_GooglePalm2.py
CHANGED
@@ -9,6 +9,7 @@ from langchain.schema import HumanMessage
|
|
9 |
from langchain.prompts import PromptTemplate
|
10 |
from langchain_core.output_parsers import JsonOutputParser
|
11 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
|
12 |
|
13 |
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
14 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
@@ -84,7 +85,8 @@ class GooglePalm2Handler:
|
|
84 |
|
85 |
def _build_model_chain_parser(self):
|
86 |
# Instantiate the parser and the retry parser
|
87 |
-
self.llm_model = ChatGoogleGenerativeAI(model=self.model_name)
|
|
|
88 |
|
89 |
self.retry_parser = RetryWithErrorOutputParser.from_llm(
|
90 |
parser=self.parser,
|
|
|
9 |
from langchain.prompts import PromptTemplate
|
10 |
from langchain_core.output_parsers import JsonOutputParser
|
11 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
12 |
+
from langchain_google_vertexai import VertexAI
|
13 |
|
14 |
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
15 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
|
|
85 |
|
86 |
def _build_model_chain_parser(self):
|
87 |
# Instantiate the parser and the retry parser
|
88 |
+
# self.llm_model = ChatGoogleGenerativeAI(model=self.model_name)
|
89 |
+
self.llm_model = VertexAI(model=self.model_name)
|
90 |
|
91 |
self.retry_parser = RetryWithErrorOutputParser.from_llm(
|
92 |
parser=self.parser,
|
vouchervision/OCR_google_cloud_vision.py
CHANGED
@@ -77,8 +77,8 @@ class OCRGoogle:
|
|
77 |
self.client_beta = vision_beta.ImageAnnotatorClient(credentials=self.get_google_credentials())
|
78 |
self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
|
79 |
else:
|
80 |
-
self.client_beta = vision_beta.ImageAnnotatorClient()
|
81 |
-
self.client = vision.ImageAnnotatorClient()
|
82 |
|
83 |
|
84 |
def get_google_credentials(self):
|
@@ -86,7 +86,7 @@ class OCRGoogle:
|
|
86 |
credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
|
87 |
return credentials
|
88 |
|
89 |
-
|
90 |
def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
|
91 |
CONFIDENCES = 0.80
|
92 |
MAX_NEW_TOKENS = 50
|
@@ -517,6 +517,7 @@ class OCRGoogle:
|
|
517 |
|
518 |
### Optionally add trOCR to the self.OCR for additional context
|
519 |
self.OCR = self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
|
|
|
520 |
|
521 |
if do_create_OCR_helper_image:
|
522 |
self.image = Image.open(self.path)
|
|
|
77 |
self.client_beta = vision_beta.ImageAnnotatorClient(credentials=self.get_google_credentials())
|
78 |
self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
|
79 |
else:
|
80 |
+
self.client_beta = vision_beta.ImageAnnotatorClient(credentials=self.get_google_credentials())
|
81 |
+
self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
|
82 |
|
83 |
|
84 |
def get_google_credentials(self):
|
|
|
86 |
credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
|
87 |
return credentials
|
88 |
|
89 |
+
|
90 |
def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
|
91 |
CONFIDENCES = 0.80
|
92 |
MAX_NEW_TOKENS = 50
|
|
|
517 |
|
518 |
### Optionally add trOCR to the self.OCR for additional context
|
519 |
self.OCR = self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
|
520 |
+
logger.info(f"OCR:\n{self.OCR}")
|
521 |
|
522 |
if do_create_OCR_helper_image:
|
523 |
self.image = Image.open(self.path)
|
vouchervision/utils_VoucherVision.py
CHANGED
@@ -72,14 +72,23 @@ class VoucherVision():
|
|
72 |
|
73 |
self.catalog_name_options = ["Catalog Number", "catalog_number", "catalogNumber"]
|
74 |
|
75 |
-
self.
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
-
"GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
79 |
-
"GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
80 |
-
"GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",
|
81 |
|
82 |
-
"tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
|
|
|
|
|
83 |
|
84 |
self.do_create_OCR_helper_image = self.cfg['leafmachine']['do_create_OCR_helper_image']
|
85 |
|
@@ -100,6 +109,7 @@ class VoucherVision():
|
|
100 |
self.logger.info(f' Model name passed to API --> {self.model_name}')
|
101 |
self.logger.info(f' API access token is found in PRIVATE_DATA.yaml --> {self.has_key}')
|
102 |
|
|
|
103 |
def init_trOCR_model(self):
|
104 |
lgr = logging.getLogger('transformers')
|
105 |
lgr.setLevel(logging.ERROR)
|
@@ -111,13 +121,14 @@ class VoucherVision():
|
|
111 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
112 |
self.trOCR_model.to(self.device)
|
113 |
|
|
|
114 |
def map_API_options(self):
|
115 |
self.chat_version = self.cfg['leafmachine']['LLM_version']
|
116 |
|
117 |
# Get the required values from ModelMaps
|
118 |
self.model_name = ModelMaps.get_version_mapping_cost(self.chat_version)
|
119 |
self.is_azure = ModelMaps.get_version_mapping_is_azure(self.chat_version)
|
120 |
-
self.has_key = ModelMaps.get_version_has_key(self.chat_version, self.has_key_openai, self.has_key_azure_openai, self.
|
121 |
|
122 |
# Check if the version is supported
|
123 |
if self.model_name is None:
|
@@ -126,28 +137,18 @@ class VoucherVision():
|
|
126 |
|
127 |
self.version_name = self.chat_version
|
128 |
|
|
|
129 |
def map_prompt_versions(self):
|
130 |
self.prompt_version_map = {
|
131 |
"Version 1": "prompt_v1_verbose",
|
132 |
-
"Version 1 No Domain Knowledge": "prompt_v1_verbose_noDomainKnowledge",
|
133 |
-
"Version 2": "prompt_v2_json_rules",
|
134 |
-
"Version 1 PaLM 2": 'prompt_v1_palm2',
|
135 |
-
"Version 1 PaLM 2 No Domain Knowledge": 'prompt_v1_palm2_noDomainKnowledge',
|
136 |
-
"Version 2 PaLM 2": 'prompt_v2_palm2',
|
137 |
}
|
138 |
self.prompt_version = self.prompt_version_map.get(self.prompt_version0, self.path_custom_prompts)
|
139 |
self.is_predefined_prompt = self.is_in_prompt_version_map(self.prompt_version)
|
140 |
|
|
|
141 |
def is_in_prompt_version_map(self, value):
|
142 |
return value in self.prompt_version_map.values()
|
143 |
|
144 |
-
# def init_embeddings(self):
|
145 |
-
# if self.use_domain_knowledge:
|
146 |
-
# self.logger.info(f'*** USING DOMAIN KNOWLEDGE ***')
|
147 |
-
# self.logger.info(f'*** Initializing vector embeddings database ***')
|
148 |
-
# self.initialize_embeddings()
|
149 |
-
# else:
|
150 |
-
# self.Voucher_Vision_Embedding = None
|
151 |
|
152 |
def map_dir_labels(self):
|
153 |
if self.cfg['leafmachine']['use_RGB_label_images']:
|
@@ -158,6 +159,7 @@ class VoucherVision():
|
|
158 |
# Use glob to get all image paths in the directory
|
159 |
self.img_paths = glob.glob(os.path.join(self.dir_labels, "*"))
|
160 |
|
|
|
161 |
def load_rules_config(self):
|
162 |
with open(self.path_custom_prompts, 'r') as stream:
|
163 |
try:
|
@@ -166,6 +168,7 @@ class VoucherVision():
|
|
166 |
print(exc)
|
167 |
return None
|
168 |
|
|
|
169 |
def generate_xlsx_headers(self):
|
170 |
# Extract headers from the 'Dictionary' keys in the JSON template rules
|
171 |
# xlsx_headers = list(self.rules_config_json['rules']["Dictionary"].keys())
|
@@ -173,21 +176,10 @@ class VoucherVision():
|
|
173 |
xlsx_headers = xlsx_headers + self.utility_headers
|
174 |
return xlsx_headers
|
175 |
|
|
|
176 |
def init_transcription_xlsx(self):
|
177 |
-
# self.HEADERS_v1_n22 = ["Catalog Number","Genus","Species","subspecies","variety","forma","Country","State","County","Locality Name","Min Elevation","Max Elevation","Elevation Units","Verbatim Coordinates","Datum","Cultivated","Habitat","Collectors","Collector Number","Verbatim Date","Date","End Date"]
|
178 |
-
# self.HEADERS_v2_n26 = ["catalog_number","genus","species","subspecies","variety","forma","country","state","county","locality_name","min_elevation","max_elevation","elevation_units","verbatim_coordinates","decimal_coordinates","datum","cultivated","habitat","plant_description","collectors","collector_number","determined_by","multiple_names","verbatim_date","date","end_date"]
|
179 |
-
# self.HEADERS_v1_n22 = self.HEADERS_v1_n22 + self.utility_headers
|
180 |
-
# self.HEADERS_v2_n26 = self.HEADERS_v2_n26 + self.utility_headers
|
181 |
# Initialize output file
|
182 |
self.path_transcription = os.path.join(self.Dirs.transcription,"transcribed.xlsx")
|
183 |
-
|
184 |
-
# if self.prompt_version in ['prompt_v2_json_rules','prompt_v2_palm2']:
|
185 |
-
# self.headers = self.HEADERS_v2_n26
|
186 |
-
# self.headers_used = 'HEADERS_v2_n26'
|
187 |
-
|
188 |
-
# elif self.prompt_version in ['prompt_v1_verbose', 'prompt_v1_verbose_noDomainKnowledge','prompt_v1_palm2', 'prompt_v1_palm2_noDomainKnowledge']:
|
189 |
-
# self.headers = self.HEADERS_v1_n22
|
190 |
-
# self.headers_used = 'HEADERS_v1_n22'
|
191 |
|
192 |
# else:
|
193 |
if not self.is_predefined_prompt:
|
@@ -223,7 +215,6 @@ class VoucherVision():
|
|
223 |
except ValueError:
|
224 |
print("'path_to_crop' not found in the header row.")
|
225 |
|
226 |
-
|
227 |
path_to_crop = list(sheet.iter_cols(min_col=path_to_crop_col, max_col=path_to_crop_col, values_only=True, min_row=2))
|
228 |
path_to_original = list(sheet.iter_cols(min_col=path_to_original_col, max_col=path_to_original_col, values_only=True, min_row=2))
|
229 |
path_to_content = list(sheet.iter_cols(min_col=path_to_content_col, max_col=path_to_content_col, values_only=True, min_row=2))
|
@@ -303,14 +294,8 @@ class VoucherVision():
|
|
303 |
break
|
304 |
|
305 |
|
306 |
-
|
307 |
def add_data_to_excel_from_response(self, path_transcription, response, WFO_record, GEO_record, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
|
308 |
-
geo_headers = ["GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
309 |
-
"GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
310 |
-
"GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",]
|
311 |
|
312 |
-
# WFO_candidate_names is separate, bc it may be type --> list
|
313 |
-
wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
|
314 |
|
315 |
wb = openpyxl.load_workbook(path_transcription)
|
316 |
sheet = wb.active
|
@@ -376,7 +361,7 @@ class VoucherVision():
|
|
376 |
sheet.cell(row=next_row, column=i, value=filename_without_extension)
|
377 |
|
378 |
# "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
|
379 |
-
elif header.value in
|
380 |
sheet.cell(row=next_row, column=i, value=WFO_record.get(header.value, ''))
|
381 |
# elif header.value == "WFO_exact_match":
|
382 |
# sheet.cell(row=next_row, column=i, value= WFO_record.get("WFO_exact_match",''))
|
@@ -397,7 +382,7 @@ class VoucherVision():
|
|
397 |
|
398 |
# "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat", "GEO_decimal_long",
|
399 |
# "GEO_city", "GEO_county", "GEO_state", "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent"
|
400 |
-
elif header.value in geo_headers:
|
401 |
sheet.cell(row=next_row, column=i, value=GEO_record.get(header.value, ''))
|
402 |
|
403 |
# save the workbook
|
@@ -447,7 +432,7 @@ class VoucherVision():
|
|
447 |
self.cfg_private = get_cfg_from_full_path(self.path_cfg_private)
|
448 |
|
449 |
k_openai = self.cfg_private['openai']['OPENAI_API_KEY']
|
450 |
-
k_openai_azure = self.cfg_private['openai_azure']['
|
451 |
|
452 |
k_google_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
|
453 |
k_google_location = self.cfg_private['google']['GOOGLE_LOCATION']
|
@@ -505,7 +490,7 @@ class VoucherVision():
|
|
505 |
self.llm = AzureChatOpenAI(
|
506 |
deployment_name = 'gpt-35-turbo',#'gpt-35-turbo',
|
507 |
openai_api_version = self.cfg_private['openai_azure']['OPENAI_API_VERSION'],
|
508 |
-
openai_api_key = self.cfg_private['openai_azure']['
|
509 |
azure_endpoint = self.cfg_private['openai_azure']['OPENAI_API_BASE'],
|
510 |
openai_organization = self.cfg_private['openai_azure']['OPENAI_ORGANIZATION'],
|
511 |
)
|
|
|
72 |
|
73 |
self.catalog_name_options = ["Catalog Number", "catalog_number", "catalogNumber"]
|
74 |
|
75 |
+
self.geo_headers = ["GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
76 |
+
"GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
77 |
+
"GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",]
|
78 |
+
|
79 |
+
self.wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"]
|
80 |
+
self.wfo_headers_no_lists = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
|
81 |
+
|
82 |
+
self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + ["tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
|
83 |
+
# "WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
|
84 |
|
85 |
+
# "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
86 |
+
# "GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
87 |
+
# "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",
|
88 |
|
89 |
+
# "tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
|
90 |
+
|
91 |
+
# WFO_candidate_names is separate, bc it may be type --> list
|
92 |
|
93 |
self.do_create_OCR_helper_image = self.cfg['leafmachine']['do_create_OCR_helper_image']
|
94 |
|
|
|
109 |
self.logger.info(f' Model name passed to API --> {self.model_name}')
|
110 |
self.logger.info(f' API access token is found in PRIVATE_DATA.yaml --> {self.has_key}')
|
111 |
|
112 |
+
|
113 |
def init_trOCR_model(self):
|
114 |
lgr = logging.getLogger('transformers')
|
115 |
lgr.setLevel(logging.ERROR)
|
|
|
121 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
122 |
self.trOCR_model.to(self.device)
|
123 |
|
124 |
+
|
125 |
def map_API_options(self):
|
126 |
self.chat_version = self.cfg['leafmachine']['LLM_version']
|
127 |
|
128 |
# Get the required values from ModelMaps
|
129 |
self.model_name = ModelMaps.get_version_mapping_cost(self.chat_version)
|
130 |
self.is_azure = ModelMaps.get_version_mapping_is_azure(self.chat_version)
|
131 |
+
self.has_key = ModelMaps.get_version_has_key(self.chat_version, self.has_key_openai, self.has_key_azure_openai, self.has_key_google_application_credentials, self.has_key_mistral)
|
132 |
|
133 |
# Check if the version is supported
|
134 |
if self.model_name is None:
|
|
|
137 |
|
138 |
self.version_name = self.chat_version
|
139 |
|
140 |
+
|
141 |
def map_prompt_versions(self):
|
142 |
self.prompt_version_map = {
|
143 |
"Version 1": "prompt_v1_verbose",
|
|
|
|
|
|
|
|
|
|
|
144 |
}
|
145 |
self.prompt_version = self.prompt_version_map.get(self.prompt_version0, self.path_custom_prompts)
|
146 |
self.is_predefined_prompt = self.is_in_prompt_version_map(self.prompt_version)
|
147 |
|
148 |
+
|
149 |
def is_in_prompt_version_map(self, value):
|
150 |
return value in self.prompt_version_map.values()
|
151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
def map_dir_labels(self):
|
154 |
if self.cfg['leafmachine']['use_RGB_label_images']:
|
|
|
159 |
# Use glob to get all image paths in the directory
|
160 |
self.img_paths = glob.glob(os.path.join(self.dir_labels, "*"))
|
161 |
|
162 |
+
|
163 |
def load_rules_config(self):
|
164 |
with open(self.path_custom_prompts, 'r') as stream:
|
165 |
try:
|
|
|
168 |
print(exc)
|
169 |
return None
|
170 |
|
171 |
+
|
172 |
def generate_xlsx_headers(self):
|
173 |
# Extract headers from the 'Dictionary' keys in the JSON template rules
|
174 |
# xlsx_headers = list(self.rules_config_json['rules']["Dictionary"].keys())
|
|
|
176 |
xlsx_headers = xlsx_headers + self.utility_headers
|
177 |
return xlsx_headers
|
178 |
|
179 |
+
|
180 |
def init_transcription_xlsx(self):
|
|
|
|
|
|
|
|
|
181 |
# Initialize output file
|
182 |
self.path_transcription = os.path.join(self.Dirs.transcription,"transcribed.xlsx")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
# else:
|
185 |
if not self.is_predefined_prompt:
|
|
|
215 |
except ValueError:
|
216 |
print("'path_to_crop' not found in the header row.")
|
217 |
|
|
|
218 |
path_to_crop = list(sheet.iter_cols(min_col=path_to_crop_col, max_col=path_to_crop_col, values_only=True, min_row=2))
|
219 |
path_to_original = list(sheet.iter_cols(min_col=path_to_original_col, max_col=path_to_original_col, values_only=True, min_row=2))
|
220 |
path_to_content = list(sheet.iter_cols(min_col=path_to_content_col, max_col=path_to_content_col, values_only=True, min_row=2))
|
|
|
294 |
break
|
295 |
|
296 |
|
|
|
297 |
def add_data_to_excel_from_response(self, path_transcription, response, WFO_record, GEO_record, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
|
|
|
|
|
|
|
298 |
|
|
|
|
|
299 |
|
300 |
wb = openpyxl.load_workbook(path_transcription)
|
301 |
sheet = wb.active
|
|
|
361 |
sheet.cell(row=next_row, column=i, value=filename_without_extension)
|
362 |
|
363 |
# "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
|
364 |
+
elif header.value in self.wfo_headers_no_lists:
|
365 |
sheet.cell(row=next_row, column=i, value=WFO_record.get(header.value, ''))
|
366 |
# elif header.value == "WFO_exact_match":
|
367 |
# sheet.cell(row=next_row, column=i, value= WFO_record.get("WFO_exact_match",''))
|
|
|
382 |
|
383 |
# "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat", "GEO_decimal_long",
|
384 |
# "GEO_city", "GEO_county", "GEO_state", "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent"
|
385 |
+
elif header.value in self.geo_headers:
|
386 |
sheet.cell(row=next_row, column=i, value=GEO_record.get(header.value, ''))
|
387 |
|
388 |
# save the workbook
|
|
|
432 |
self.cfg_private = get_cfg_from_full_path(self.path_cfg_private)
|
433 |
|
434 |
k_openai = self.cfg_private['openai']['OPENAI_API_KEY']
|
435 |
+
k_openai_azure = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE']
|
436 |
|
437 |
k_google_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
|
438 |
k_google_location = self.cfg_private['google']['GOOGLE_LOCATION']
|
|
|
490 |
self.llm = AzureChatOpenAI(
|
491 |
deployment_name = 'gpt-35-turbo',#'gpt-35-turbo',
|
492 |
openai_api_version = self.cfg_private['openai_azure']['OPENAI_API_VERSION'],
|
493 |
+
openai_api_key = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE'],
|
494 |
azure_endpoint = self.cfg_private['openai_azure']['OPENAI_API_BASE'],
|
495 |
openai_organization = self.cfg_private['openai_azure']['OPENAI_ORGANIZATION'],
|
496 |
)
|
vouchervision/utils_hf.py
CHANGED
@@ -42,15 +42,44 @@ def save_uploaded_file(directory, img_file, image=None):
|
|
42 |
os.makedirs(directory)
|
43 |
# Assuming the uploaded file is an image
|
44 |
if image is None:
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
full_path = os.path.join(directory, img_file.name)
|
47 |
image.save(full_path, "JPEG")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
# Return the full path of the saved image
|
49 |
-
return
|
50 |
else:
|
51 |
-
full_path = os.path.join(directory,
|
52 |
image.save(full_path, "JPEG")
|
53 |
-
return
|
54 |
|
55 |
def image_to_base64(img):
|
56 |
buffered = BytesIO()
|
|
|
42 |
os.makedirs(directory)
|
43 |
# Assuming the uploaded file is an image
|
44 |
if image is None:
|
45 |
+
try:
|
46 |
+
with Image.open(img_file) as image:
|
47 |
+
full_path = os.path.join(directory, img_file.name)
|
48 |
+
image.save(full_path, "JPEG")
|
49 |
+
# Return the full path of the saved image
|
50 |
+
return full_path
|
51 |
+
except:
|
52 |
+
with Image.open(os.path.join(directory,img_file)) as image:
|
53 |
+
full_path = os.path.join(directory, img_file)
|
54 |
+
image.save(full_path, "JPEG")
|
55 |
+
# Return the full path of the saved image
|
56 |
+
return full_path
|
57 |
+
else:
|
58 |
+
try:
|
59 |
full_path = os.path.join(directory, img_file.name)
|
60 |
image.save(full_path, "JPEG")
|
61 |
+
return full_path
|
62 |
+
except:
|
63 |
+
full_path = os.path.join(directory, img_file)
|
64 |
+
image.save(full_path, "JPEG")
|
65 |
+
return full_path
|
66 |
+
|
67 |
+
def save_uploaded_local(directory, img_file, image=None):
|
68 |
+
name = img_file.split(os.path.sep)[-1]
|
69 |
+
if not os.path.exists(directory):
|
70 |
+
os.makedirs(directory)
|
71 |
+
|
72 |
+
# Assuming the uploaded file is an image
|
73 |
+
if image is None:
|
74 |
+
with Image.open(img_file) as image:
|
75 |
+
full_path = os.path.join(directory, name)
|
76 |
+
image.save(full_path, "JPEG")
|
77 |
# Return the full path of the saved image
|
78 |
+
return os.path.join('uploads_small',name)
|
79 |
else:
|
80 |
+
full_path = os.path.join(directory, name)
|
81 |
image.save(full_path, "JPEG")
|
82 |
+
return os.path.join('.','uploads_small',name)
|
83 |
|
84 |
def image_to_base64(img):
|
85 |
buffered = BytesIO()
|