Spaces: phyloforfun
Commit c5e57d6 • Parent(s): 712822d
July 18 update

Files changed:
- app.py +65 -3
- img/collage.jpg +3 -0
- requirements.txt +0 -0
- vouchervision/API_validation.py +97 -13
- vouchervision/LLM_crewAI.py +176 -0
- vouchervision/LLM_local_MistralAI.py +61 -110
- vouchervision/LLM_local_custom_fine_tune.py +358 -0
- vouchervision/OCR_Florence_2.py +88 -0
- vouchervision/OCR_google_cloud_vision (DESKTOP-548UDCR's conflicted copy 2024-06-15).py +850 -0
- vouchervision/OCR_google_cloud_vision.py +37 -16
- vouchervision/VoucherVision_Config_Builder.py +11 -6
- vouchervision/fetch_data.py +1 -1
- vouchervision/generate_partner_collage.py +104 -0
- vouchervision/librarian_knowledge.json +27 -0
- vouchervision/save_dataset.py +34 -0
- vouchervision/utils_VoucherVision.py +18 -0
- vouchervision/utils_hf (DESKTOP-548UDCR's conflicted copy 2024-06-15).py +266 -0
- vouchervision/utils_hf.py +42 -6
app.py
CHANGED
@@ -144,6 +144,8 @@ if 'present_annotations' not in st.session_state:
     st.session_state['present_annotations'] = None
 if 'missing_annotations' not in st.session_state:
     st.session_state['missing_annotations'] = None
+if 'model_annotations' not in st.session_state:
+    st.session_state['model_annotations'] = None
 if 'date_of_check' not in st.session_state:
     st.session_state['date_of_check'] = None

@@ -1016,6 +1018,16 @@ def create_private_file():
     st.write("Leave keys blank if you do not intend to use that service.")
     st.info("Note: You can manually edit these API keys later by opening the /PRIVATE_DATA.yaml file in a plain text editor.")

+    st.write("---")
+    st.subheader("Hugging Face (*Required For Local LLMs*)")
+    st.markdown("VoucherVision relies on LLM models from Hugging Face. Some models are 'gated', meaning that you have to agree to the creator's usage guidelines.")
+    st.markdown("""Create a [Hugging Face account](https://huggingface.co/join). Once your account is created, in your profile settings [navigate to 'Access Tokens'](https://huggingface.co/settings/tokens) and click 'Create new token'. Create a token that has 'Read' privileges. Copy the token into the field below.""")
+
+    hugging_face_token = st.text_input(label = 'Hugging Face token', value = cfg_private['huggingface'].get('hf_token', ''),
+                                       placeholder = 'e.g. hf_GNRLIUBnvfkjvnf....',
+                                       help ="This is your Hugging Face access token. It only needs Read access. Please see https://huggingface.co/settings/tokens",
+                                       type='password')
+
     st.write("---")
     st.subheader("Google Vision (*Required*) / Google PaLM 2 / Google Gemini")
     st.markdown("VoucherVision currently uses [Google Vision API](https://cloud.google.com/vision/docs/ocr) for OCR. Generating an API key for this is more involved than the others. [Please carefully follow the instructions outlined here to create and setup your account.](https://cloud.google.com/vision/docs/setup) ")

@@ -1170,6 +1182,7 @@ def create_private_file():
     st.button("Set API Keys",type='primary', on_click=save_changes_to_API_keys,
               args=[cfg_private,
                     openai_api_key,
+                    hugging_face_token,
                     azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type,
                     google_application_credentials, google_project_location, google_project_id,
                     mistral_API_KEY,

@@ -1183,12 +1196,15 @@ def create_private_file():

 def save_changes_to_API_keys(cfg_private,
                              openai_api_key,
+                             hugging_face_token,
                              azure_openai_api_version, azure_openai_api_key, azure_openai_api_base, azure_openai_organization, azure_openai_api_type,
                              google_application_credentials, google_project_location, google_project_id,
                              mistral_API_KEY,
                              here_APP_ID, here_API_KEY):

     # Update the configuration dictionary with the new values
+    cfg_private['huggingface']['hf_token'] = hugging_face_token
+
     cfg_private['openai']['OPENAI_API_KEY'] = openai_api_key

     cfg_private['openai_azure']['OPENAI_API_VERSION'] = azure_openai_api_version

@@ -1269,8 +1285,19 @@ def display_api_key_status(ccol):
     # Convert keys to annotations (similar to what you do in check_api_key_status)
     present_annotations = []
     missing_annotations = []
+    model_annotations = []
     for key in present_keys:
-        if "Valid" in key:
+        if "[MODEL]" in key:
+            show_text = key.split(']')[1]
+            show_text = show_text.split('(')[0]
+            if 'Under Review' in key:
+                model_annotations.append((show_text, "under review", "#9C0586")) # Green for valid
+            elif 'invalid' in key:
+                model_annotations.append((show_text, "error!", "#870307")) # Green for valid
+            else:
+                model_annotations.append((show_text, "ready!", "#059c1b")) # Green for valid
+
+        elif "Valid" in key:
             show_text = key.split('(')[0]
             present_annotations.append((show_text, "ready!", "#059c1b")) # Green for valid
         elif "Invalid" in key:

@@ -1279,6 +1306,7 @@ def display_api_key_status(ccol):

     st.session_state['present_annotations'] = present_annotations
     st.session_state['missing_annotations'] = missing_annotations
+    st.session_state['model_annotations'] = model_annotations
     st.session_state['date_of_check'] = date_of_check
     st.session_state['API_checked'] = True
     # print('for')

@@ -1307,6 +1335,14 @@ def display_api_key_status(ccol):
     if 'missing_annotations' in st.session_state and st.session_state['missing_annotations']:
         annotated_text(*st.session_state['missing_annotations'])

+    if not st.session_state['is_hf']:
+        st.markdown(f"Access to Hugging Face Models")
+
+        if 'model_annotations' in st.session_state and st.session_state['model_annotations']:
+            annotated_text(*st.session_state['model_annotations'])
+
+
+


 def check_api_key_status():

@@ -1322,8 +1358,19 @@ def check_api_key_status():
     # Prepare annotations for present keys
     present_annotations = []
     missing_annotations = []
+    model_annotations = []
     for key in present_keys:
-        if "Valid" in key:
+        if "[MODEL]" in key:
+            show_text = key.split(']')[1]
+            show_text = show_text.split('(')[0]
+            if 'Under Review' in key:
+                model_annotations.append((show_text, "under review", "#9C0586")) # Green for valid
+            elif 'invalid' in key:
+                model_annotations.append((show_text, "error!", "#870307")) # Green for valid
+            else:
+                model_annotations.append((show_text, "ready!", "#059c1b")) # Green for valid
+
+        elif "Valid" in key:
             show_text = key.split('(')[0]
             present_annotations.append((show_text, "ready!", "#059c1b")) # Green for valid
         elif "Invalid" in key:

@@ -1340,6 +1387,7 @@ def check_api_key_status():

     st.session_state['present_annotations'] = present_annotations
     st.session_state['missing_annotations'] = missing_annotations
+    st.session_state['model_annotations'] = model_annotations
     st.session_state['date_of_check'] = date_of_check


@@ -1831,7 +1879,7 @@ def content_ocr_method():
     demo_text_trh = demo_text_h + '\n' + demo_text_tr
     demo_text_trp = demo_text_p + '\n' + demo_text_tr

-    options = ["Google Vision Handwritten", "Google Vision Printed", "CRAFT + trOCR","LLaVA"]
+    options = ["Google Vision Handwritten", "Google Vision Printed", "CRAFT + trOCR","LLaVA", "Florence-2"]
     options_llava = ["llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",]
     options_llava_bit = ["full", "4bit",]
     captions_llava = [

@@ -1882,6 +1930,7 @@ def content_ocr_method():
         "Google Vision Printed": 'normal',
         "CRAFT + trOCR": 'CRAFT',
         "LLaVA": 'LLaVA',
+        "Florence-2": 'Florence-2',
     }

     # Map selected options to their corresponding internal representations

@@ -1914,6 +1963,19 @@ def content_ocr_method():
         else:
             st.session_state.config['leafmachine']['project']['trOCR_model_path'] = user_input_trOCR_model_path

+
+    if "Florence-2" in selected_OCR_options:
+        default_florence_model_path = st.session_state.config['leafmachine']['project']['florence_model_path']
+        user_input_florence_model_path = st.text_input("Florence-2 Hugging Face model path. MUST be a Florence-2 version based on 'microsoft/Florence-2-large' or similar.", value=default_florence_model_path)
+
+        if st.session_state.config['leafmachine']['project']['florence_model_path'] != user_input_florence_model_path:
+            is_valid_mp = is_valid_huggingface_model_path(user_input_florence_model_path)
+            if not is_valid_mp:
+                st.error(f"The Hugging Face model path {user_input_florence_model_path} is not valid. Please revise.")
+            else:
+                st.session_state.config['leafmachine']['project']['florence_model_path'] = user_input_florence_model_path
+
+
     if 'LLaVA' in selected_OCR_options:
         OCR_option_llava = st.radio(
             "Select the LLaVA version",
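The Florence-2 block above calls a helper, is_valid_huggingface_model_path, that is not defined in this commit. A minimal sketch of what such a validator could look like, assuming it only needs to confirm that the supplied repo id resolves on the Hugging Face Hub (the function body and token handling here are illustrative assumptions, not the app's actual implementation):

# Hypothetical sketch: check that a user-supplied model path is a visible repo on the Hub.
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

def is_valid_huggingface_model_path(model_path, token=None):
    try:
        HfApi().model_info(model_path, token=token)  # raises if the repo does not exist or is not visible
        return True
    except RepositoryNotFoundError:
        return False
    except Exception:
        # Network failures, gated repos, etc. are treated as "not valid" in this sketch.
        return False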
img/collage.jpg
ADDED
Git LFS Details

requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
vouchervision/API_validation.py
CHANGED
@@ -1,4 +1,5 @@
 import os, io, openai, vertexai, json, tempfile
+import webbrowser
 from mistralai.client import MistralClient
 from mistralai.models.chat_completion import ChatMessage
 from langchain.schema import HumanMessage

@@ -9,7 +10,7 @@ from google.cloud import vision
 from google.cloud import vision_v1p3beta1 as vision_beta
 # from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_google_vertexai import VertexAI
-
+from huggingface_hub import HfApi, HfFolder

 from datetime import datetime
 # import google.generativeai as genai

@@ -17,6 +18,8 @@ from google.oauth2 import service_account
 # from googleapiclient.discovery import build


+
+
 class APIvalidation:

     def __init__(self, cfg_private, dir_home, is_hf) -> None:

@@ -25,6 +28,13 @@ class APIvalidation:
         self.is_hf = is_hf
         self.formatted_date = self.get_formatted_date()

+        self.HF_MODEL_LIST = ['microsoft/Florence-2-large','microsoft/Florence-2-base',
+                              'microsoft/trocr-base-handwritten','microsoft/trocr-large-handwritten',
+                              'google/gemma-2-9b','google/gemma-2-9b-it','google/gemma-2-27b','google/gemma-2-27b-it',
+                              'mistralai/Mistral-7B-Instruct-v0.3','mistralai/Mixtral-8x22B-v0.1','mistralai/Mixtral-8x22B-Instruct-v0.1',
+                              'unsloth/mistral-7b-instruct-v0.3-bnb-4bit'
+                              ]
+
     def get_formatted_date(self):
         # Get the current date
         current_date = datetime.now()

@@ -59,7 +69,7 @@
         try:
             # Initialize the Azure OpenAI client
             model = AzureChatOpenAI(
-                deployment_name = 'gpt-35-turbo',
+                deployment_name = 'gpt-4',#'gpt-35-turbo',
                 openai_api_version = self.cfg_private['openai_azure']['OPENAI_API_VERSION'],
                 openai_api_key = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE'],
                 azure_endpoint = self.cfg_private['openai_azure']['OPENAI_API_BASE'],

@@ -67,7 +77,7 @@
             )
             msg = HumanMessage(content="hello")
             # self.llm_object.temperature = self.config.get('temperature')
-            response = model([msg])
+            response = model.invoke([msg])

             # Check the response content (you might need to adjust this depending on how your AzureChatOpenAI class handles responses)
             if response:

@@ -85,7 +95,7 @@
             azure_organization = os.getenv('AZURE_ORGANIZATION')
             # Initialize the Azure OpenAI client
             model = AzureChatOpenAI(
-                deployment_name = 'gpt-35-turbo',
+                deployment_name = 'gpt-4',#'gpt-35-turbo',
                 openai_api_version = azure_api_version,
                 openai_api_key = azure_api_key,
                 azure_endpoint = azure_api_base,

@@ -93,7 +103,7 @@
             )
             msg = HumanMessage(content="hello")
             # self.llm_object.temperature = self.config.get('temperature')
-            response = model([msg])
+            response = model.invoke([msg])

             # Check the response content (you might need to adjust this depending on how your AzureChatOpenAI class handles responses)
             if response:

@@ -223,8 +233,55 @@

         return results

+    def test_hf_token(self, k_huggingface):
+        if not k_huggingface:
+            print("Hugging Face API token not found in environment variables.")
+            return False
+
+        # Create an instance of the API
+        api = HfApi()
+
+        try:
+            # Try to get details of a known public model
+            model_info = api.model_info("bert-base-uncased", use_auth_token=k_huggingface)
+            if model_info:
+                print("Token is valid. Accessed model details successfully.")
+                return True
+            else:
+                print("Token is valid but failed to access model details.")
+                return True
+        except Exception as e:
+            print(f"Failed to validate token: {e}")
+            return False
+
+    def check_gated_model_access(self, model_id, k_huggingface):
+        api = HfApi()
+        attempts = 0
+        max_attempts = 2
+
+        while attempts < max_attempts:
+            try:
+                model_info = api.model_info(model_id, use_auth_token=k_huggingface)
+                print(f"Access to model '{model_id}' is granted.")
+                return "valid"
+            except Exception as e:
+                error_message = str(e)
+                if 'awaiting a review' in error_message:
+                    print(f"Access to model '{model_id}' is awaiting review. (Under Review)")
+                    return "under_review"
+                print(f"Access to model '{model_id}' is denied. Please accept the terms and conditions.")
+                print(f"Error: {e}")
+                webbrowser.open(f"https://huggingface.co/{model_id}")
+                input("Press Enter after you have accepted the terms and conditions...")
+
+                attempts += 1
+
+        print(f"Failed to access model '{model_id}' after {max_attempts} attempts.")
+        return "invalid"


+
+
     def get_google_credentials(self):
         if self.is_hf:
             creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')

@@ -251,6 +308,8 @@
             k_google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
             k_project_id = os.getenv('GOOGLE_PROJECT_ID')
             k_location = os.getenv('GOOGLE_LOCATION')
+
+            k_huggingface = None

             k_mistral = os.getenv('MISTRAL_API_KEY')
             k_here = os.getenv('HERE_API_KEY')

@@ -259,6 +318,8 @@
             k_OPENAI_API_KEY = self.cfg_private['openai']['OPENAI_API_KEY']
             k_openai_azure = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE']

+            k_huggingface = self.cfg_private['huggingface']['hf_token']
+
             k_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
             k_location = self.cfg_private['google']['GOOGLE_LOCATION']
             k_google_application_credentials = self.cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS']

@@ -284,6 +345,29 @@
             present_keys.append('Google OCR Handwriting (Invalid)')
         else:
             missing_keys.append('Google OCR')
+
+        # present_keys.append('[MODEL] TEST (Under Review)')
+
+        # HF key check
+        if self.has_API_key(k_huggingface):
+            is_valid = self.test_hf_token(k_huggingface)
+            if is_valid:
+                present_keys.append('Hugging Face Local LLMs (Valid)')
+            else:
+                present_keys.append('Hugging Face Local LLMs (Invalid)')
+        else:
+            missing_keys.append('Hugging Face Local LLMs')
+
+        # List of gated models to check access for
+        for model_id in self.HF_MODEL_LIST:
+            access_status = self.check_gated_model_access(model_id, k_huggingface)
+            if access_status == "valid":
+                present_keys.append(f'[MODEL] {model_id} (Valid)')
+            elif access_status == "under_review":
+                present_keys.append(f'[MODEL] {model_id} (Under Review)')
+            else:
+                present_keys.append(f'[MODEL] {model_id} (Invalid)')
+


         # OpenAI key check

@@ -297,14 +381,14 @@
             missing_keys.append('OpenAI')

         # Azure OpenAI key check
-        if self.has_API_key(k_openai_azure):
-            is_valid = self.check_azure_openai_api_key()
-            if is_valid:
-                present_keys.append('Azure OpenAI (Valid)')
-            else:
-                present_keys.append('Azure OpenAI (Invalid)')
-        else:
-            missing_keys.append('Azure OpenAI')
+        # if self.has_API_key(k_openai_azure):
+        #     is_valid = self.check_azure_openai_api_key()
+        #     if is_valid:
+        #         present_keys.append('Azure OpenAI (Valid)')
+        #     else:
+        #         present_keys.append('Azure OpenAI (Invalid)')
+        # else:
+        #     missing_keys.append('Azure OpenAI')

         # Google PALM2/Gemini key check
         if self.has_API_key(k_google_application_credentials) and self.has_API_key(k_project_id) and self.has_API_key(k_location): ##################
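For readers who want to exercise the token check outside the APIvalidation class, here is a minimal standalone sketch using huggingface_hub directly. It is an assumption-laden illustration, not part of this commit; the environment variable name mirrors the HUGGING_FACE_KEY used in LLM_local_MistralAI.py below.

# Standalone sketch of a Hugging Face token check, similar in spirit to test_hf_token above.
import os
from huggingface_hub import HfApi

def hf_token_is_valid(token):
    if not token:
        return False
    try:
        HfApi().whoami(token=token)  # fails for a malformed or revoked token
        return True
    except Exception as e:
        print(f"Failed to validate token: {e}")
        return False

if __name__ == "__main__":
    print(hf_token_is_valid(os.getenv("HUGGING_FACE_KEY")))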
vouchervision/LLM_crewAI.py
ADDED
@@ -0,0 +1,176 @@
import time, os, json
import torch
from crewai import Agent, Task, Crew, Process
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAI
from langchain.schema import HumanMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain.output_parsers import RetryWithErrorOutputParser

class VoucherVisionWorkflow:
    MODEL = 'gpt-4o'
    SHARED_INSTRUCTIONS = """
    instructions:
    1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
    2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
    3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
    4. Duplicate dictionary fields are not allowed.
    5. Ensure all JSON keys are in camel case.
    6. Ensure new JSON field values follow sentence case capitalization.
    7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
    8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
    9. Only return a JSON dictionary represented as a string. You should not explain your answer.

    JSON structure:
    {"catalogNumber": "", "scientificName": "", "genus": "", "specificEpithet": "", "speciesNameAuthorship": "", "collectedBy": "", "collectorNumber": "", "identifiedBy": "", "verbatimCollectionDate": "", "collectionDate": "", "collectionDateEnd": "", "occurrenceRemarks": "", "habitat": "", "cultivated": "", "country": "", "stateProvince": "", "county": "", "locality": "", "verbatimCoordinates": "", "decimalLatitude": "", "decimalLongitude": "", "minimumElevationInMeters": "", "maximumElevationInMeters": "", "elevationUnits": ""}
    """

    EXPECTED_OUTPUT_STRUCTURE = """{
        "JSON_OUTPUT": {
            "catalogNumber": "", "scientificName": "", "genus": "", "specificEpithet": "",
            "speciesNameAuthorship": "", "collectedBy": "", "collectorNumber": "",
            "identifiedBy": "", "verbatimCollectionDate": "", "collectionDate": "",
            "collectionDateEnd": "", "occurrenceRemarks": "", "habitat": "", "cultivated": "",
            "country": "", "stateProvince": "", "county": "", "locality": "",
            "verbatimCoordinates": "", "decimalLatitude": "", "decimalLongitude": "",
            "minimumElevationInMeters": "", "maximumElevationInMeters": "", "elevationUnits": ""
        },
        "explanation": ""
    }"""

    def __init__(self, api_key, librarian_knowledge_path):
        self.api_key = api_key
        os.environ['OPENAI_API_KEY'] = self.api_key

        self.librarian_knowledge = self.load_librarian_knowledge(librarian_knowledge_path)
        self.worker_agent = self.create_worker_agent()
        self.supervisor_agent = self.create_supervisor_agent()

    def load_librarian_knowledge(self, path):
        with open(path) as f:
            return json.load(f)

    def query_librarian(self, guideline_field):
        print(f"query_librarian: {guideline_field}")
        return self.librarian_knowledge.get(guideline_field, "Guideline not found.")

    def create_worker_agent(self):
        return Agent(
            role="Transcriber and JSON Formatter",
            goal="Transcribe product labels accurately and format them into a structured JSON dictionary. Only return a JSON dictionary.",
            backstory="You're an AI trained to transcribe product labels and format them into JSON.",
            verbose=True,
            allow_delegation=False,
            llm=ChatOpenAI(model=self.MODEL, openai_api_key=self.api_key),
            prompt_instructions=self.SHARED_INSTRUCTIONS
        )

    def create_supervisor_agent(self):
        class SupervisorAgent(Agent):
            def correct_with_librarian(self, workflow, transcription, json_dict, guideline_field):
                guideline = workflow.query_librarian(guideline_field)
                corrected_transcription = self.correct(transcription, guideline)
                corrected_json = self.correct_json(json_dict, guideline)
                explanation = f"Corrected {json_dict} based on guideline {guideline_field}: {guideline}"
                return corrected_transcription, {"JSON_OUTPUT": corrected_json, "explanation": explanation}

        return SupervisorAgent(
            role="Corrector",
            goal="Ensure accurate transcriptions and JSON formatting according to specific guidelines. Compare the OCR text to the JSON dictionary and make any required corrections. Given your knowledge, make sure that the values in the JSON object make sense given the cumulative context of the OCR text. If you correct the provided JSON, then state the corrections. Otherwise say that the original worker was correct.",
            backstory="You're an AI trained to correct transcriptions and JSON formatting, consulting the librarian for guidance.",
            verbose=True,
            allow_delegation=False,
            llm=ChatOpenAI(model=self.MODEL, openai_api_key=self.api_key),
            prompt_instructions=self.SHARED_INSTRUCTIONS
        )

    def extract_json_from_string(self, input_string):
        json_pattern = re.compile(r'\{(?:[^{}]|(?R))*\}')
        match = json_pattern.search(input_string)
        if match:
            return match.group(0)
        return None

    def extract_json_via_api(self, text):
        self.api_key = self.api_key
        extraction_prompt = f"I only need the JSON inside this text. Please return only the JSON object.\n\n{text}"
        response = openai.ChatCompletion.create(
            model=self.MODEL,
            messages=[
                {"role": "system", "content": extraction_prompt}
            ]
        )
        return self.extract_json_from_string(response['choices'][0]['message']['content'])


    def run_workflow(self, ocr_text):
        openai_model = ChatOpenAI(api_key=self.api_key, model=self.MODEL)

        self.worker_agent.llm = openai_model
        self.supervisor_agent.llm = openai_model

        transcription_and_formatting_task = Task(
            description=f"Transcribe product label and format into JSON. OCR text: {ocr_text}",
            agent=self.worker_agent,
            inputs={"ocr_text": ocr_text},
            expected_output=self.EXPECTED_OUTPUT_STRUCTURE
        )

        crew = Crew(
            agents=[self.worker_agent],
            tasks=[transcription_and_formatting_task],
            verbose=True,
            process=Process.sequential,
        )

        # Run the transcription and formatting task
        transcription_and_formatting_result = transcription_and_formatting_task.execute()
        print("Worker Output JSON:", transcription_and_formatting_result)

        # Pass the worker's JSON output to the supervisor for correction
        correction_task = Task(
            description=f"Correct transcription and JSON format. OCR text: {ocr_text}",
            agent=self.supervisor_agent,
            inputs={"ocr_text": ocr_text, "json_dict": transcription_and_formatting_result},
            expected_output=self.EXPECTED_OUTPUT_STRUCTURE,
            workflow=self # Pass the workflow instance to the task
        )

        correction_result = correction_task.execute()

        try:
            corrected_json_with_explanation = json.loads(correction_result)
        except json.JSONDecodeError:
            # If initial parsing fails, make a call to OpenAI to extract only the JSON
            corrected_json_string = self.extract_json_via_api(correction_result)
            if not corrected_json_string:
                raise ValueError("No JSON found in the supervisor's output.")
            corrected_json_with_explanation = json.loads(corrected_json_string)

        corrected_json = corrected_json_with_explanation["JSON_OUTPUT"]
        explanation = corrected_json_with_explanation["explanation"]

        print("Supervisor Corrected JSON:", corrected_json)
        print("\nCorrection Explanation:", explanation)

        return corrected_json, explanation

if __name__ == "__main__":
    api_key = ""
    librarian_knowledge_path = "D:/Dropbox/VoucherVision/vouchervision/librarian_knowledge.json"

    ocr_text = "HERBARIUM OF MARYGROVE COLLEGE Name Carex scoparia V. condensa Fernald Locality Interlaken , Ind . Date 7/20/25 No ... ! Gerould Wilhelm & Laura Rericha \" Interlaken , \" was the site for many years of St. Joseph Novitiate , run by the Brothers of the Holy Cross . The buildings were on the west shore of Silver Lake , about 2 miles NE of Rolling Prairie , LaPorte Co. Indiana , ca. 41.688 \u00b0 N , 86.601 \u00b0 W Collector : Sister M. Vincent de Paul McGivney February 1 , 2011 THE UNIVERS Examined for the Flora of the Chicago Region OF 1817 MICH ! Ciscoparia SMVdeP University of Michigan Herbarium 1386297 copyright reserved cm Collector wortet 2010"
    workflow = VoucherVisionWorkflow(api_key, librarian_knowledge_path)
    workflow.run_workflow(ocr_text)

    ocr_text = "CM 1 2 3 QUE : Mt.Jac.Cartier Parc de la Gasp\u00e9sie 13 Aug1988 CA Vogt on Solidago MCZ - ENT OXO Bombus vagans Smith ' det C.A. Vogt 1988 UIUC USDA BBDP 021159 00817079 "
    workflow = VoucherVisionWorkflow(api_key, librarian_knowledge_path)
    workflow.run_workflow(ocr_text)

    ocr_text = "500 200 600 300 dots per inch ( optical ) 700 400 800 500 850 550 850 550 Golden Thread inches centimeters 500 200 600 300 dots per inch ( optical ) 11116 L * 39.12 65.43 49.87 44.26 b * 15.07 18.72 -22.29 22.85 4 -4.34 -13.80 3 13.24 18.11 2 1 5 9 7 11 ( A ) 10 -0.40 48.55 55.56 70.82 63.51 39.92 52.24 97.06 92.02 9.82 -33.43 34.26 11.81 -24.49 -0.35 59.60 -46.07 18.51 8 6 12 13 14 15 87.34 82.14 72.06 62.15 09.0- -0.75 -1.06 -1.19 -1.07 1.13 0.23 0.21 0.43 0.28 0.19 800 500 D50 Illuminant , 2 degree observer Density 0.04 0.09 0.15 0.22 Fam . Saurauiaceae J. G. Agardh Saurauia nepaulensis DC . S. Vietnam , Prov . Kontum . NW slopes of Ngoc Linh mountain system at 1200 m alt . near Ngoc Linh village . Secondary marshland with grasses and shrubs . Tree up to 5 m high . Flowers light rosy - pink . No VH 007 0.36 0.51 23.02.1995 International Botanical Expedition of the U.S.A. National Geographic Society ( grant No 5094-93 ) Participants : L. Averyanov , N.T. Ban , N. Q. Binh , A. Budantzev , L. Budantzev , N.T. Hiep , D.D. Huyen , P.K. Loc , N.X. Tam , G. Yakovlev BOTANICAL RESEARCH INSTITUTE OF TEXAS BRIT610199 Botanical Research Institute of Texas IMAGED 08 JUN 2021 FLORA OF VIETNAM "
    workflow = VoucherVisionWorkflow(api_key, librarian_knowledge_path)
    workflow.run_workflow(ocr_text)

    ocr_text = "Russian - Vietnamese Tropical Centre Styrax argentifolius H.L. Li SOUTHERN VIETNAM Dak Lak prov . , Lak distr . , Bong Krang municip . Chu Yang Sin National Park 10 km S from Krong Kmar village River bank N 12 \u00b0 25 ' 24 \" E 108 \u00b0 21 ' 04 \" elev . 900 m Nuraliev M.S. No 1004 part of MW 0750340 29.05.2014 Materials of complex expedition in spring 2014 BOTANICAL RESEARCH INSTITUTE OF TEXAS ( BRIT ) Styrax benzoides Craib Det . by Peter W. Fritsch , September 2017 0 1 2 3 4 5 6 7 8 9 10 BOTANICAL RESEARCH INSTITUTE OF TEXAS BOTANICAL IMAGED RESEARCH INSTITUTE OF 10 JAN 2013 BRIT402114 copyright reserved cm BOTANICAL RESEARCH INSTITUTE OF TEXAS TM P CameraTrax.com BRIT . TEXAS "
    workflow = VoucherVisionWorkflow(api_key, librarian_knowledge_path)
    workflow.run_workflow(ocr_text)
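Note that extract_json_from_string in the file above compiles a recursive pattern ((?R)), which Python's built-in re module does not support, and neither re nor openai is imported in the file as committed. A minimal, dependency-free sketch of an alternative that scans for the first balanced, parseable JSON object (the function name is illustrative, not part of the commit):

# Sketch: extract the first balanced top-level {...} block that parses as JSON.
import json

def extract_first_json_object(text):
    depth = 0
    start = None
    for i, ch in enumerate(text):
        if ch == '{':
            if depth == 0:
                start = i
            depth += 1
        elif ch == '}' and depth > 0:
            depth -= 1
            if depth == 0:
                candidate = text[start:i + 1]
                try:
                    json.loads(candidate)  # only return blocks that are valid JSON
                    return candidate
                except json.JSONDecodeError:
                    start = None  # keep scanning for the next balanced block
    return None

if __name__ == "__main__":
    sample = 'Model said: {"JSON_OUTPUT": {"catalogNumber": "1386297"}, "explanation": "ok"} end.'
    print(extract_first_json_object(sample))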
vouchervision/LLM_local_MistralAI.py
CHANGED
@@ -1,11 +1,13 @@
-import json,
+import json, os
+import torch
+import transformers
+import gc
 from transformers import BitsAndBytesConfig
-from langchain.output_parsers import
+from langchain.output_parsers.retry import RetryOutputParser
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 from huggingface_hub import hf_hub_download
-from
-
+from langchain_huggingface import HuggingFacePipeline
 from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
 from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template

@@ -14,7 +16,7 @@ Local Pipielines:
 https://python.langchain.com/docs/integrations/llms/huggingface_pipelines
 '''

-class LocalMistralHandler:
+class LocalMistralHandler:
     RETRY_DELAY = 2 # Wait 2 seconds before retrying
     MAX_RETRIES = 5 # Maximum number of retries
     STARTING_TEMP = 0.1

@@ -27,29 +29,22 @@ class LocalMistralHandler:
         self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
         self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
         self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']

         self.logger = logger
         self.has_GPU = torch.cuda.is_available()
         self.monitor = SystemLoadMonitor(logger)

         self.model_name = model_name
         self.model_id = f"mistralai/{self.model_name}"
-
-
-        self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model",filename="config.json")
-
+        self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model", filename="config.json", use_auth_token=os.getenv("HUGGING_FACE_KEY"))

         self.JSON_dict_structure = JSON_dict_structure
         self.starting_temp = float(self.STARTING_TEMP)
         self.temp_increment = float(0.2)
-        self.adjust_temp = self.starting_temp
-
-        system_prompt = "You are a helpful AI assistant who answers queries a JSON dictionary as specified by the user."
-        template = """
-        <s>[INST]{}[/INST]</s>
-
-
+        self.adjust_temp = self.starting_temp
+
+        system_prompt = "You are a helpful AI assistant who answers queries by returning a JSON dictionary as specified by the user."
+        template = "<s>[INST]{}[/INST]</s>[INST]{}[/INST]".format(system_prompt, "{query}")

         # Create a prompt from the template so we can use it with Langchain
         self.prompt = PromptTemplate(template=template, input_variables=["query"])

@@ -59,45 +54,22 @@ class LocalMistralHandler:

         self._set_config()

-
-    # def _clear_VRAM(self):
-    #     # Clear CUDA cache if it's being used
-    #     if self.has_GPU:
-    #         self.local_model = None
-    #         self.local_model_pipeline = None
-    #         del self.local_model
-    #         del self.local_model_pipeline
-    #         gc.collect() # Explicitly invoke garbage collector
-    #         torch.cuda.empty_cache()
-    #     else:
-    #         self.local_model_pipeline = None
-    #         self.local_model = None
-    #         del self.local_model_pipeline
-    #         del self.local_model
-    #         gc.collect() # Explicitly invoke garbage collector
-
-
     def _set_config(self):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # Activate nested quantization for 4-bit base models (double quantization)
-        'use_nested_quant': False,
-        }
-
-        compute_dtype = getattr(torch,self.config.get('bnb_4bit_compute_dtype') )
+        self.config = {
+            'max_new_tokens': 1024,
+            'temperature': self.starting_temp,
+            'seed': 2023,
+            'top_p': 1,
+            'top_k': 40,
+            'do_sample': True,
+            'n_ctx': 4096,
+            'use_4bit': True,
+            'bnb_4bit_compute_dtype': "float16",
+            'bnb_4bit_quant_type': "nf4",
+            'use_nested_quant': False,
+        }
+
+        compute_dtype = getattr(torch, self.config.get('bnb_4bit_compute_dtype'))

         self.bnb_config = BitsAndBytesConfig(
             load_in_4bit=self.config.get('use_4bit'),

@@ -106,123 +78,102 @@ class LocalMistralHandler:
             bnb_4bit_use_double_quant=self.config.get('use_nested_quant'),
         )

-        # Check GPU compatibility with bfloat16
         if compute_dtype == torch.float16 and self.config.get('use_4bit'):
             major, _ = torch.cuda.get_device_capability()
-            if major >= 8
-
-                # print("Your GPU supports bfloat16: accelerate training with bf16=True")
-                # print("=" * 80)
-                self.b_float_opt = torch.bfloat16
-
-            else:
-                self.b_float_opt = torch.float16
+            self.b_float_opt = torch.bfloat16 if major >= 8 else torch.float16
+
         self._build_model_chain_parser()
-

     def _adjust_config(self):
         new_temp = self.adjust_temp + self.temp_increment
-        if self.json_report:
-            self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
         self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
         self.adjust_temp += self.temp_increment

-
     def _reset_config(self):
-        if self.json_report:
-            self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
         self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
         self.adjust_temp = self.starting_temp
-

     def _build_model_chain_parser(self):
-        self.local_model_pipeline = transformers.pipeline(
-
-
-
-
-
-
-
+        self.local_model_pipeline = transformers.pipeline(
+            "text-generation",
+            model=self.model_id,
+            max_new_tokens=self.config.get('max_new_tokens'),
+            top_k=self.config.get('top_k'),
+            top_p=self.config.get('top_p'),
+            do_sample=self.config.get('do_sample'),
+            model_kwargs={"torch_dtype": self.b_float_opt, "quantization_config": self.bnb_config},
+        )
         self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)
+
         # Set up the retry parser with the runnable
-        self.retry_parser =
+        self.retry_parser = RetryOutputParser(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
         # Create an llm chain with LLM and prompt
-        self.chain = self.prompt | self.local_model
-
+        self.chain = self.prompt | self.local_model

     def call_llm_local_MistralAI(self, prompt_template, json_report, paths):
-
+        json_file_path_wiki, txt_file_path_ind_prompt = paths[-2:]
         self.json_report = json_report
         if self.json_report:
             self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
         self.monitor.start_monitoring_usage()
+
         nt_in = 0
         nt_out = 0

-        ind
-        while ind < self.MAX_RETRIES:
-            ind += 1
+        for ind in range(self.MAX_RETRIES):
             try:
-                # Dynamically set the temperature for this specific request
                 model_kwargs = {"temperature": self.adjust_temp}
-
-                # Invoke the chain to generate prompt text
                 results = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})

-                # Use retry_parser to parse the response with retry logic
                 output = self.retry_parser.parse_with_prompt(results, prompt_value=prompt_template)

                 if output is None:
                     self.logger.error(f'Failed to extract JSON from:\n{results}')
                     self._adjust_config()
                     del results
-
                 else:
                     nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                     nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)

                     output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
+
                     if output is None:
-                        self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
+                        self.logger.error(f'[Attempt {ind + 1}] Failed to extract JSON from:\n{results}')
                         self._adjust_config()
                     else:
-                        self.monitor.stop_inference_timer()
-
+                        self.monitor.stop_inference_timer() # Starts tool timer too
+
                         if self.json_report:
                             self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
-                        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(
+                        output_WFO, WFO_record, output_GEO, GEO_record = run_tools(
+                            output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki
+                        )

                         save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)

-                        self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
+                        self.logger.info(f"Formatted JSON:\n{json.dumps(output, indent=4)}")

-                        usage_report = self.monitor.stop_monitoring_report_usage()
+                        usage_report = self.monitor.stop_monitoring_report_usage()

-                        if self.adjust_temp != self.starting_temp:
+                        if self.adjust_temp != self.starting_temp:
                             self._reset_config()

                         if self.json_report:
                             self.json_report.set_text(text_main=f'LLM call successful')
                         del results
                         return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
-
             except Exception as e:
                 self.logger.error(f'{e}')
-                self._adjust_config()
-
-        self.logger.info(f"Failed to extract valid JSON after [{
+                self._adjust_config()
+
+        self.logger.info(f"Failed to extract valid JSON after [{self.MAX_RETRIES}] attempts")
         if self.json_report:
-            self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{
+            self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{self.MAX_RETRIES}] attempts')

-        self.monitor.stop_inference_timer()
-        usage_report = self.monitor.stop_monitoring_report_usage()
+        self.monitor.stop_inference_timer() # Starts tool timer too
+        usage_report = self.monitor.stop_monitoring_report_usage()
         if self.json_report:
             self.json_report.set_text(text_main=f'LLM call failed')

         self._reset_config()
-        return None, nt_in, nt_out, None, None, usage_report
-
+        return None, nt_in, nt_out, None, None, usage_report
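The _build_model_chain_parser change above wires a 4-bit quantized transformers pipeline into a LangChain prompt chain. A condensed, self-contained sketch of that pattern follows, assuming a CUDA GPU and the bitsandbytes and langchain-huggingface packages; the generation settings mirror the config values in the diff, while the prompt text and the printed query are illustrative only.

# Sketch of the quantized-pipeline-plus-LangChain pattern used in _build_model_chain_parser.
import torch
import transformers
from transformers import BitsAndBytesConfig
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFacePipeline

model_id = "mistralai/Mistral-7B-Instruct-v0.3"  # any causal LM repo you have access to

# 4-bit NF4 quantization, matching the handler's config values
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

pipe = transformers.pipeline(
    "text-generation",
    model=model_id,
    max_new_tokens=1024,
    do_sample=True,
    top_k=40,
    top_p=1,
    model_kwargs={"torch_dtype": torch.bfloat16, "quantization_config": bnb_config},
)

llm = HuggingFacePipeline(pipeline=pipe)
prompt = PromptTemplate(template="<s>[INST]{query}[/INST]", input_variables=["query"])
chain = prompt | llm  # Runnable chain: prompt formatting feeds the local pipeline

print(chain.invoke({"query": "Return {\"ok\": true} as a JSON dictionary."}))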
vouchervision/LLM_local_custom_fine_tune.py
ADDED
@@ -0,0 +1,358 @@
import os, re, json, yaml, torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

import json, torch, transformers, gc
from transformers import BitsAndBytesConfig
from langchain.output_parsers.retry import RetryOutputParser
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from huggingface_hub import hf_hub_download
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template

# MODEL_NAME = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
# sltp_version = 'HLT_MICH_Angiospermae_SLTPvA_v1-0_medium__OCR-C25-L25-E50-R05'
# LORA = "phyloforfun/mistral-7b-instruct-v2-bnb-4bit__HLT_MICH_Angiospermae_SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05"

TEXT = "HERBARIUM OF MARCUS W. LYON , JR . Tracaulon sagittatum Indiana : Porter Co. Mincral Springs edge wet subdural woods 1927 TX 11 Flowers pink UNIVERSIT HERBARIUM MICHIGAN MICH University of Michigan Herbarium 1439649 copyright reserved PERSICARIA FEB 26 1965 cm "
PARENT_MODEL = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"

class LocalFineTuneHandler:
    RETRY_DELAY = 2 # Wait 2 seconds before retrying
    MAX_RETRIES = 5 # Maximum number of retries
    STARTING_TEMP = 0.001
    TOKENIZER_NAME = None
    VENDOR = 'mistral'
    MAX_GPU_MONITORING_INTERVAL = 2 # seconds


    def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation=None):
        # self.model_id = f"phyloforfun/{self.model_name}"
        # model_name = LORA #######################################################

        # self.JSON_dict_structure = JSON_dict_structure
        # self.JSON_dict_structure_str = json.dumps(self.JSON_dict_structure, sort_keys=False, indent=4)

        self.JSON_dict_structure_str = """{"catalogNumber": "", "scientificName": "", "genus": "", "specificEpithet": "", "scientificNameAuthorship": "", "collector": "", "recordNumber": "", "identifiedBy": "", "verbatimCollectionDate": "", "collectionDate": "", "occurrenceRemarks": "", "habitat": "", "locality": "", "country": "", "stateProvince": "", "county": "", "municipality": "", "verbatimCoordinates": "", "decimalLatitude": "", "decimalLongitude": "", "minimumElevationInMeters": "", "maximumElevationInMeters": ""}"""


        self.cfg = cfg
        self.print_output = True
        self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
        self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
        self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']

        self.logger = logger

        self.has_GPU = torch.cuda.is_available()
        if self.has_GPU:
            self.device = "cuda"
        else:
            self.device = "cpu"

        self.monitor = SystemLoadMonitor(logger)

        self.model_name = model_name.split("/")[1]
        self.model_id = model_name

        # self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model",filename="config.json")


        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        self.load_in_4bit = False

        self.parser = JsonOutputParser()

        self._load_model()
        self._create_prompt()
        self._set_config()
        self._build_model_chain_parser()

    def _set_config(self):
        # self._clear_VRAM()
        self.config = {'max_new_tokens': 1024,
                       'temperature': self.starting_temp,
                       'seed': 2023,
                       'top_p': 1,
                       # 'top_k': 1,
                       # 'top_k': 40,
                       'do_sample': False,
                       'n_ctx':4096,

                       # Activate 4-bit precision base model loading
                       # 'use_4bit': True,
                       # # Compute dtype for 4-bit base models
                       # 'bnb_4bit_compute_dtype': "float16",
                       # # Quantization type (fp4 or nf4)
+
import os, re, json, yaml, torch
|
2 |
+
from peft import AutoPeftModelForCausalLM
|
3 |
+
from transformers import AutoTokenizer
|
4 |
+
|
5 |
+
import json, torch, transformers, gc
|
6 |
+
from transformers import BitsAndBytesConfig
|
7 |
+
from langchain.output_parsers.retry import RetryOutputParser
|
8 |
+
from langchain.prompts import PromptTemplate
|
9 |
+
from langchain_core.output_parsers import JsonOutputParser
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
12 |
+
|
13 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
|
14 |
+
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
15 |
+
|
16 |
+
# MODEL_NAME = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
|
17 |
+
# sltp_version = 'HLT_MICH_Angiospermae_SLTPvA_v1-0_medium__OCR-C25-L25-E50-R05'
|
18 |
+
# LORA = "phyloforfun/mistral-7b-instruct-v2-bnb-4bit__HLT_MICH_Angiospermae_SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05"
|
19 |
+
|
20 |
+
TEXT = "HERBARIUM OF MARCUS W. LYON , JR . Tracaulon sagittatum Indiana : Porter Co. Mincral Springs edge wet subdural woods 1927 TX 11 Flowers pink UNIVERSIT HERBARIUM MICHIGAN MICH University of Michigan Herbarium 1439649 copyright reserved PERSICARIA FEB 26 1965 cm "
|
21 |
+
PARENT_MODEL = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
|
22 |
+
|
23 |
+
class LocalFineTuneHandler:
|
24 |
+
RETRY_DELAY = 2 # Wait 2 seconds before retrying
|
25 |
+
MAX_RETRIES = 5 # Maximum number of retries
|
26 |
+
STARTING_TEMP = 0.001
|
27 |
+
TOKENIZER_NAME = None
|
28 |
+
VENDOR = 'mistral'
|
29 |
+
MAX_GPU_MONITORING_INTERVAL = 2 # seconds
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation=None):
|
34 |
+
# self.model_id = f"phyloforfun/{self.model_name}"
|
35 |
+
# model_name = LORA #######################################################
|
36 |
+
|
37 |
+
# self.JSON_dict_structure = JSON_dict_structure
|
38 |
+
# self.JSON_dict_structure_str = json.dumps(self.JSON_dict_structure, sort_keys=False, indent=4)
|
39 |
+
|
40 |
+
self.JSON_dict_structure_str = """{"catalogNumber": "", "scientificName": "", "genus": "", "specificEpithet": "", "scientificNameAuthorship": "", "collector": "", "recordNumber": "", "identifiedBy": "", "verbatimCollectionDate": "", "collectionDate": "", "occurrenceRemarks": "", "habitat": "", "locality": "", "country": "", "stateProvince": "", "county": "", "municipality": "", "verbatimCoordinates": "", "decimalLatitude": "", "decimalLongitude": "", "minimumElevationInMeters": "", "maximumElevationInMeters": ""}"""
|
41 |
+
|
42 |
+
|
43 |
+
self.cfg = cfg
|
44 |
+
self.print_output = True
|
45 |
+
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
46 |
+
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
47 |
+
self.tool_wikipedia = self.cfg['leafmachine']['project']['tool_wikipedia']
|
48 |
+
|
49 |
+
self.logger = logger
|
50 |
+
|
51 |
+
self.has_GPU = torch.cuda.is_available()
|
52 |
+
if self.has_GPU:
|
53 |
+
self.device = "cuda"
|
54 |
+
else:
|
55 |
+
self.device = "cpu"
|
56 |
+
|
57 |
+
self.monitor = SystemLoadMonitor(logger)
|
58 |
+
|
59 |
+
self.model_name = model_name.split("/")[1]
|
60 |
+
self.model_id = model_name
|
61 |
+
|
62 |
+
# self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model",filename="config.json")
|
63 |
+
|
64 |
+
|
65 |
+
self.starting_temp = float(self.STARTING_TEMP)
|
66 |
+
self.temp_increment = float(0.2)
|
67 |
+
self.adjust_temp = self.starting_temp
|
68 |
+
|
69 |
+
self.load_in_4bit = False
|
70 |
+
|
71 |
+
self.parser = JsonOutputParser()
|
72 |
+
|
73 |
+
self._load_model()
|
74 |
+
self._create_prompt()
|
75 |
+
self._set_config()
|
76 |
+
self._build_model_chain_parser()
|
77 |
+
|
78 |
+
def _set_config(self):
|
79 |
+
# self._clear_VRAM()
|
80 |
+
self.config = {'max_new_tokens': 1024,
|
81 |
+
'temperature': self.starting_temp,
|
82 |
+
'seed': 2023,
|
83 |
+
'top_p': 1,
|
84 |
+
# 'top_k': 1,
|
85 |
+
# 'top_k': 40,
|
86 |
+
'do_sample': False,
|
87 |
+
'n_ctx':4096,
|
88 |
+
|
89 |
+
# Activate 4-bit precision base model loading
|
90 |
+
# 'use_4bit': True,
|
91 |
+
# # Compute dtype for 4-bit base models
|
92 |
+
# 'bnb_4bit_compute_dtype': "float16",
|
93 |
+
# # Quantization type (fp4 or nf4)
|
94 |
+
# 'bnb_4bit_quant_type': "nf4",
|
95 |
+
# # Activate nested quantization for 4-bit base models (double quantization)
|
96 |
+
# 'use_nested_quant': False,
|
97 |
+
}
|
98 |
+
|
99 |
+
def _adjust_config(self):
|
100 |
+
new_temp = self.adjust_temp + self.temp_increment
|
101 |
+
if self.json_report:
|
102 |
+
self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
103 |
+
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
104 |
+
self.adjust_temp += self.temp_increment
|
105 |
+
|
106 |
+
|
107 |
+
def _reset_config(self):
|
108 |
+
if self.json_report:
|
109 |
+
self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
110 |
+
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
111 |
+
self.adjust_temp = self.starting_temp
|
112 |
+
|
113 |
+
|
114 |
+
def _load_model(self):
|
115 |
+
self.model = AutoPeftModelForCausalLM.from_pretrained(
|
116 |
+
pretrained_model_name_or_path=self.model_id, # YOUR MODEL YOU USED FOR TRAINING
|
117 |
+
load_in_4bit = self.load_in_4bit,
|
118 |
+
low_cpu_mem_usage=True,
|
119 |
+
|
120 |
+
).to(self.device)
|
121 |
+
|
122 |
+
self.tokenizer = AutoTokenizer.from_pretrained(PARENT_MODEL)
|
123 |
+
self.eos_token_id = self.tokenizer.eos_token_id
|
124 |
+
|
125 |
+
|
126 |
+
# def _build_model_chain_parser(self):
|
127 |
+
# self.local_model_pipeline = transformers.pipeline("text-generation",
|
128 |
+
# model=self.model_id,
|
129 |
+
# max_new_tokens=self.config.get('max_new_tokens'),
|
130 |
+
# # top_k=self.config.get('top_k'),
|
131 |
+
# top_p=self.config.get('top_p'),
|
132 |
+
# do_sample=self.config.get('do_sample'),
|
133 |
+
# model_kwargs={"load_in_4bit": self.load_in_4bit})
|
134 |
+
# self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)
|
135 |
+
# # Set up the retry parser with the runnable
|
136 |
+
# # self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
|
137 |
+
# self.retry_parser = RetryOutputParser(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
|
138 |
+
|
139 |
+
# # Create an llm chain with LLM and prompt
|
140 |
+
# self.chain = self.prompt | self.local_model # LCEL
|
141 |
+
def _build_model_chain_parser(self):
|
142 |
+
self.local_model_pipeline = transformers.pipeline(
|
143 |
+
"text-generation",
|
144 |
+
model=self.model_id,
|
145 |
+
max_new_tokens=self.config.get('max_new_tokens'),
|
146 |
+
top_k=self.config.get('top_k', None),
|
147 |
+
top_p=self.config.get('top_p'),
|
148 |
+
do_sample=self.config.get('do_sample'),
|
149 |
+
model_kwargs={"load_in_4bit": self.load_in_4bit},
|
150 |
+
)
|
151 |
+
self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)
|
152 |
+
self.retry_parser = RetryOutputParser(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
def _create_prompt(self):
|
157 |
+
self.alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
158 |
+
|
159 |
+
### Instruction:
|
160 |
+
{}
|
161 |
+
|
162 |
+
### Input:
|
163 |
+
{}
|
164 |
+
|
165 |
+
### Response:
|
166 |
+
{}"""
|
167 |
+
|
168 |
+
self.template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
169 |
+
|
170 |
+
### Instruction:
|
171 |
+
{}
|
172 |
+
|
173 |
+
### Input:
|
174 |
+
{}
|
175 |
+
|
176 |
+
### Response:
|
177 |
+
{}""".format("{instructions}", "{OCR_text}", "{empty}")
|
178 |
+
|
179 |
+
self.instructions_text = """Refactor the unstructured text into a valid JSON dictionary. The key names follow the Darwin Core Archive Standard. If a key lacks content, then insert an empty string. Fill in the following JSON structure as required: """
|
180 |
+
self.instructions_json = self.JSON_dict_structure_str.replace("\n ", " ").strip().replace("\n", " ")
|
181 |
+
self.instructions = ''.join([self.instructions_text, self.instructions_json])
|
182 |
+
|
183 |
+
|
184 |
+
# Create a prompt from the template so we can use it with Langchain
|
185 |
+
self.prompt = PromptTemplate(template=self.template, input_variables=["instructions", "OCR_text", "empty"])
|
186 |
+
|
187 |
+
# Set up a parser
|
188 |
+
self.parser = JsonOutputParser()
|
189 |
+
|
190 |
+
|
191 |
+
def extract_json(self, response_text):
|
192 |
+
# Assuming the response is a list with a single string entry
|
193 |
+
# response_text = response[0]
|
194 |
+
|
195 |
+
response_pattern = re.compile(r'### Response:(.*)', re.DOTALL)
|
196 |
+
response_match = response_pattern.search(response_text)
|
197 |
+
if not response_match:
|
198 |
+
raise ValueError("No '### Response:' section found in the provided text")
|
199 |
+
|
200 |
+
response_text = response_match.group(1)
|
201 |
+
|
202 |
+
# Use a regular expression to find JSON objects in the response text
|
203 |
+
json_objects = re.findall(r'\{.*?\}', response_text, re.DOTALL)
|
204 |
+
|
205 |
+
if json_objects:
|
206 |
+
# Assuming you want the first JSON object if there are multiple
|
207 |
+
json_str = json_objects[0]
|
208 |
+
# Convert the JSON string to a Python dictionary
|
209 |
+
json_dict = json.loads(json_str)
|
210 |
+
return json_str, json_dict
|
211 |
+
else:
|
212 |
+
raise ValueError("No JSON object found in the '### Response:' section")
|
213 |
+
|
214 |
+
|
215 |
+
def call_llm_local_custom_fine_tune(self, OCR_text, json_report, paths):
|
216 |
+
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
217 |
+
self.json_report = json_report
|
218 |
+
if self.json_report:
|
219 |
+
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
220 |
+
self.monitor.start_monitoring_usage()
|
221 |
+
|
222 |
+
nt_in = 0
|
223 |
+
nt_out = 0
|
224 |
+
|
225 |
+
self.inputs = self.tokenizer(
|
226 |
+
[
|
227 |
+
self.alpaca_prompt.format(
|
228 |
+
self.instructions, # instruction
|
229 |
+
OCR_text, # input
|
230 |
+
"", # output - leave this blank for generation!
|
231 |
+
)
|
232 |
+
], return_tensors = "pt").to(self.device)
|
233 |
+
|
234 |
+
ind = 0
|
235 |
+
while ind < self.MAX_RETRIES:
|
236 |
+
ind += 1
|
237 |
+
try:
|
238 |
+
# Fancy
|
239 |
+
# Dynamically set the temperature for this specific request
|
240 |
+
model_kwargs = {"temperature": self.adjust_temp}
|
241 |
+
|
242 |
+
# Invoke the chain to generate prompt text
|
243 |
+
# results = self.chain.invoke({"instructions": self.instructions, "OCR_text": OCR_text, "empty": "", "model_kwargs": model_kwargs})
|
244 |
+
|
245 |
+
# Use retry_parser to parse the response with retry logic
|
246 |
+
# output = self.retry_parser.parse_with_prompt(results, prompt_value=OCR_text)
|
247 |
+
results = self.local_model.invoke(OCR_text)
|
248 |
+
output = self.retry_parser.parse_with_prompt(results, prompt_value=OCR_text)
|
249 |
+
|
250 |
+
|
251 |
+
# Should work:
|
252 |
+
# output = self.model.generate(**self.inputs, eos_token_id=self.eos_token_id, max_new_tokens=512) # Adjust max_length as needed
|
253 |
+
|
254 |
+
# Decode the generated text
|
255 |
+
# generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
256 |
+
|
257 |
+
# json_str, json_dict = self.extract_json(generated_text)
|
258 |
+
if self.print_output:
|
259 |
+
# print("\nJSON String:")
|
260 |
+
# print(json_str)
|
261 |
+
print("\nJSON Dictionary:")
|
262 |
+
print(output)
|
263 |
+
|
264 |
+
|
265 |
+
|
266 |
+
if output is None:
|
267 |
+
self.logger.error(f'Failed to extract JSON from:\n{results}')
|
268 |
+
self._adjust_config()
|
269 |
+
del results
|
270 |
+
|
271 |
+
else:
|
272 |
+
nt_in = count_tokens(self.instructions+OCR_text, self.VENDOR, self.TOKENIZER_NAME)
|
273 |
+
nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)
|
274 |
+
|
275 |
+
output = validate_and_align_JSON_keys_with_template(output, json.loads(self.JSON_dict_structure_str))
|
276 |
+
|
277 |
+
if output is None:
|
278 |
+
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
|
279 |
+
self._adjust_config()
|
280 |
+
else:
|
281 |
+
self.monitor.stop_inference_timer() # Starts tool timer too
|
282 |
+
|
283 |
+
if self.json_report:
|
284 |
+
self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
285 |
+
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
286 |
+
|
287 |
+
save_individual_prompt(sanitize_prompt(self.instructions+OCR_text), txt_file_path_ind_prompt)
|
288 |
+
|
289 |
+
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
290 |
+
|
291 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
292 |
+
|
293 |
+
if self.adjust_temp != self.starting_temp:
|
294 |
+
self._reset_config()
|
295 |
+
|
296 |
+
if self.json_report:
|
297 |
+
self.json_report.set_text(text_main=f'LLM call successful')
|
298 |
+
del results
|
299 |
+
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
300 |
+
|
301 |
+
except Exception as e:
|
302 |
+
self.logger.error(f'{e}')
|
303 |
+
|
304 |
+
|
305 |
+
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
306 |
+
if self.json_report:
|
307 |
+
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
308 |
+
|
309 |
+
self.monitor.stop_inference_timer() # Starts tool timer too
|
310 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
311 |
+
if self.json_report:
|
312 |
+
self.json_report.set_text(text_main=f'LLM call failed')
|
313 |
+
|
314 |
+
return None, nt_in, nt_out, None, None, usage_report
|
315 |
+
|
316 |
+
|
317 |
+
|
318 |
+
# # Create a prompt from the template so we can use it with Langchain
|
319 |
+
# self.prompt = PromptTemplate(template=template, input_variables=["query"])
|
320 |
+
|
321 |
+
# # Set up a parser
|
322 |
+
# self.parser = JsonOutputParser()
|
323 |
+
|
324 |
+
|
325 |
+
|
326 |
+
|
327 |
+
|
328 |
+
|
329 |
+
|
330 |
+
|
331 |
+
|
332 |
+
|
333 |
+
|
334 |
+
|
335 |
+
|
336 |
+
|
337 |
+
|
338 |
+
|
339 |
+
|
340 |
+
|
341 |
+
|
342 |
+
|
343 |
+
model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
|
344 |
+
sltp_version = 'HLT_MICH_Angiospermae_SLTPvA_v1-0_medium__OCR-C25-L25-E50-R05'
|
345 |
+
lora_name = "phyloforfun/mistral-7b-instruct-v2-bnb-4bit__HLT_MICH_Angiospermae_SLTPvA_v1-0_medium__OCR-C25-L25-E50-R05"
|
346 |
+
|
347 |
+
OCR_test = "HERBARIUM OF MARCUS W. LYON , JR . Tracaulon sagittatum Indiana : Porter Co. Mincral Springs edge wet subdural woods 1927 TX 11 Flowers pink UNIVERSIT HERBARIUM MICHIGAN MICH University of Michigan Herbarium 1439649 copyright reserved PERSICARIA FEB 26 1965 cm "
|
348 |
+
|
349 |
+
|
350 |
+
|
351 |
+
|
352 |
+
|
353 |
+
# model.merge_and_unload()
|
354 |
+
|
355 |
+
|
356 |
+
|
357 |
+
# Generate the output
|
358 |
+
|
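LLM_local_custom_fine_tune.py wraps the fixed instructions and the OCR text in an Alpaca-style prompt and then scrapes the first JSON object out of the '### Response:' section (see _create_prompt and extract_json above). A condensed sketch of that round trip; the helper names build_prompt and parse_response are illustrative and not part of the file:

import json, re

ALPACA_PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

def build_prompt(instructions, ocr_text):
    # Leave the response slot empty so the model fills it in during generation
    return ALPACA_PROMPT.format(instructions, ocr_text, "")

def parse_response(generated_text):
    # Keep only what follows '### Response:' and grab the first JSON object found there
    match = re.search(r'### Response:(.*)', generated_text, re.DOTALL)
    if not match:
        return None
    objects = re.findall(r'\{.*?\}', match.group(1), re.DOTALL)
    return json.loads(objects[0]) if objects else None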
vouchervision/OCR_Florence_2.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
import random, os
|
2 |
+
from PIL import Image
|
3 |
+
import copy
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import matplotlib.patches as patches
|
6 |
+
from PIL import Image, ImageDraw, ImageFont
|
7 |
+
import numpy as np
|
8 |
+
import warnings
|
9 |
+
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer
|
10 |
+
from vouchervision.utils_LLM import SystemLoadMonitor
|
11 |
+
|
12 |
+
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
|
13 |
+
|
14 |
+
class FlorenceOCR:
|
15 |
+
def __init__(self, logger, model_id='microsoft/Florence-2-large'):
|
16 |
+
self.MAX_TOKENS = 1024
|
17 |
+
self.logger = logger
|
18 |
+
self.model_id = model_id
|
19 |
+
|
20 |
+
self.monitor = SystemLoadMonitor(logger)
|
21 |
+
|
22 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
|
23 |
+
self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
24 |
+
|
25 |
+
# self.model_id_clean = "mistralai/Mistral-7B-v0.3"
|
26 |
+
self.model_id_clean = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
|
27 |
+
self.tokenizer_clean = AutoTokenizer.from_pretrained(self.model_id_clean)
|
28 |
+
self.model_clean = AutoModelForCausalLM.from_pretrained(self.model_id_clean)
|
29 |
+
|
30 |
+
|
31 |
+
def ocr_florence(self, image, task_prompt='<OCR>', text_input=None):
|
32 |
+
self.monitor.start_monitoring_usage()
|
33 |
+
|
34 |
+
# Open image if a path is provided
|
35 |
+
if isinstance(image, str):
|
36 |
+
image = Image.open(image)
|
37 |
+
|
38 |
+
if text_input is None:
|
39 |
+
prompt = task_prompt
|
40 |
+
else:
|
41 |
+
prompt = task_prompt + text_input
|
42 |
+
|
43 |
+
inputs = self.processor(text=prompt, images=image, return_tensors="pt")
|
44 |
+
|
45 |
+
# Move input_ids and pixel_values to the same device as the model
|
46 |
+
inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
|
47 |
+
|
48 |
+
generated_ids = self.model.generate(
|
49 |
+
input_ids=inputs["input_ids"],
|
50 |
+
pixel_values=inputs["pixel_values"],
|
51 |
+
max_new_tokens=self.MAX_TOKENS,
|
52 |
+
early_stopping=False,
|
53 |
+
do_sample=False,
|
54 |
+
num_beams=3,
|
55 |
+
)
|
56 |
+
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
57 |
+
parsed_answer_dirty = self.processor.post_process_generation(
|
58 |
+
generated_text,
|
59 |
+
task=task_prompt,
|
60 |
+
image_size=(image.width, image.height)
|
61 |
+
)
|
62 |
+
|
63 |
+
inputs = self.tokenizer_clean(f"Insert spaces into this text to make all the words valid. This text contains scientific names of plants, locations, habitat, coordinate words: {parsed_answer_dirty[task_prompt]}", return_tensors="pt")
|
64 |
+
inputs = {key: value.to(self.model_clean.device) for key, value in inputs.items()}
|
65 |
+
|
66 |
+
outputs = self.model_clean.generate(**inputs, max_new_tokens=self.MAX_TOKENS)
|
67 |
+
parsed_answer = self.tokenizer_clean.decode(outputs[0], skip_special_tokens=True)
|
68 |
+
print(parsed_answer_dirty)
|
69 |
+
print(parsed_answer)
|
70 |
+
|
71 |
+
self.monitor.stop_inference_timer() # Starts tool timer too
|
72 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
73 |
+
|
74 |
+
return parsed_answer, parsed_answer_dirty[task_prompt], parsed_answer_dirty, usage_report
|
75 |
+
|
76 |
+
|
77 |
+
def main():
|
78 |
+
img_path = '/home/brlab/Downloads/gem_2024_06_26__02-26-02/Cropped_Images/By_Class/label/1.jpg'
|
79 |
+
# img = 'D:/D_Desktop/BR_1839468565_Ochnaceae_Campylospermum_reticulatum_label.jpg'
|
80 |
+
|
81 |
+
image = Image.open(img_path)
|
82 |
+
|
83 |
+
ocr = FlorenceOCR(logger = None)
|
84 |
+
results_text, results_dirty_text, results_dirty, usage_report = ocr.ocr_florence(image, task_prompt='<OCR>', text_input=None)
|
85 |
+
print(results_text)
|
86 |
+
|
87 |
+
if __name__ == '__main__':
|
88 |
+
main()
|
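FlorenceOCR runs a second pass over the Florence-2 transcript, prompting a Mistral instruct model to re-insert the spaces that the raw OCR string often drops. A stripped-down sketch of just that cleanup step, assuming the same model id and dependencies as the class; the input string here is invented:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Second-stage cleanup only: ask an instruct model to re-insert missing spaces.
# Model id copied from FlorenceOCR; loading it requires the same quantization deps as the class.
model_id = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

dirty = "HERBARIUMOFMARCUSW.LYONJR.Tracaulonsagittatum"   # made-up example of run-together OCR output
prompt = ("Insert spaces into this text to make all the words valid. "
          "This text contains scientific names of plants, locations, habitat, coordinate words: "
          f"{dirty}")

inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))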
vouchervision/OCR_google_cloud_vision (DESKTOP-548UDCR's conflicted copy 2024-06-15).py
ADDED
@@ -0,0 +1,850 @@
1 |
+
import os, io, sys, inspect, statistics, json, cv2
|
2 |
+
from statistics import mean
|
3 |
+
# from google.cloud import vision, storage
|
4 |
+
from google.cloud import vision
|
5 |
+
from google.cloud import vision_v1p3beta1 as vision_beta
|
6 |
+
from PIL import Image, ImageDraw, ImageFont
|
7 |
+
import colorsys
|
8 |
+
from tqdm import tqdm
|
9 |
+
from google.oauth2 import service_account
|
10 |
+
|
11 |
+
### LLaVA should only be installed if the user will actually use it.
|
12 |
+
### It requires the most recent pytorch/Python and can mess with older systems
|
13 |
+
|
14 |
+
|
15 |
+
'''
|
16 |
+
@misc{li2021trocr,
|
17 |
+
title={TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models},
|
18 |
+
author={Minghao Li and Tengchao Lv and Lei Cui and Yijuan Lu and Dinei Florencio and Cha Zhang and Zhoujun Li and Furu Wei},
|
19 |
+
year={2021},
|
20 |
+
eprint={2109.10282},
|
21 |
+
archivePrefix={arXiv},
|
22 |
+
primaryClass={cs.CL}
|
23 |
+
}
|
24 |
+
@inproceedings{baek2019character,
|
25 |
+
title={Character Region Awareness for Text Detection},
|
26 |
+
author={Baek, Youngmin and Lee, Bado and Han, Dongyoon and Yun, Sangdoo and Lee, Hwalsuk},
|
27 |
+
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
|
28 |
+
pages={9365--9374},
|
29 |
+
year={2019}
|
30 |
+
}
|
31 |
+
'''
|
32 |
+
|
33 |
+
class OCREngine:
|
34 |
+
|
35 |
+
BBOX_COLOR = "black"
|
36 |
+
|
37 |
+
def __init__(self, logger, json_report, dir_home, is_hf, path, cfg, trOCR_model_version, trOCR_model, trOCR_processor, device):
|
38 |
+
self.is_hf = is_hf
|
39 |
+
self.logger = logger
|
40 |
+
|
41 |
+
self.json_report = json_report
|
42 |
+
|
43 |
+
self.path = path
|
44 |
+
self.cfg = cfg
|
45 |
+
self.do_use_trOCR = self.cfg['leafmachine']['project']['do_use_trOCR']
|
46 |
+
self.OCR_option = self.cfg['leafmachine']['project']['OCR_option']
|
47 |
+
self.double_OCR = self.cfg['leafmachine']['project']['double_OCR']
|
48 |
+
self.dir_home = dir_home
|
49 |
+
|
50 |
+
# Initialize TrOCR components
|
51 |
+
self.trOCR_model_version = trOCR_model_version
|
52 |
+
self.trOCR_processor = trOCR_processor
|
53 |
+
self.trOCR_model = trOCR_model
|
54 |
+
self.device = device
|
55 |
+
|
56 |
+
self.hand_cleaned_text = None
|
57 |
+
self.hand_organized_text = None
|
58 |
+
self.hand_bounds = None
|
59 |
+
self.hand_bounds_word = None
|
60 |
+
self.hand_bounds_flat = None
|
61 |
+
self.hand_text_to_box_mapping = None
|
62 |
+
self.hand_height = None
|
63 |
+
self.hand_confidences = None
|
64 |
+
self.hand_characters = None
|
65 |
+
|
66 |
+
self.normal_cleaned_text = None
|
67 |
+
self.normal_organized_text = None
|
68 |
+
self.normal_bounds = None
|
69 |
+
self.normal_bounds_word = None
|
70 |
+
self.normal_text_to_box_mapping = None
|
71 |
+
self.normal_bounds_flat = None
|
72 |
+
self.normal_height = None
|
73 |
+
self.normal_confidences = None
|
74 |
+
self.normal_characters = None
|
75 |
+
|
76 |
+
self.trOCR_texts = None
|
77 |
+
self.trOCR_text_to_box_mapping = None
|
78 |
+
self.trOCR_bounds_flat = None
|
79 |
+
self.trOCR_height = None
|
80 |
+
self.trOCR_confidences = None
|
81 |
+
self.trOCR_characters = None
|
82 |
+
self.set_client()
|
83 |
+
self.init_craft()
|
84 |
+
|
85 |
+
self.multimodal_prompt = """I need you to transcribe all of the text in this image.
|
86 |
+
Place the transcribed text into a JSON dictionary with this form {"Transcription_Printed_Text": "text","Transcription_Handwritten_Text": "text"}"""
|
87 |
+
self.init_llava()
|
88 |
+
|
89 |
+
|
90 |
+
def set_client(self):
|
91 |
+
if self.is_hf:
|
92 |
+
self.client_beta = vision_beta.ImageAnnotatorClient(credentials=self.get_google_credentials())
|
93 |
+
self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
|
94 |
+
else:
|
95 |
+
self.client_beta = vision_beta.ImageAnnotatorClient(credentials=self.get_google_credentials())
|
96 |
+
self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
|
97 |
+
|
98 |
+
|
99 |
+
def get_google_credentials(self):
|
100 |
+
creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
|
101 |
+
credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
|
102 |
+
return credentials
|
103 |
+
|
104 |
+
def init_craft(self):
|
105 |
+
if 'CRAFT' in self.OCR_option:
|
106 |
+
from craft_text_detector import load_craftnet_model, load_refinenet_model
|
107 |
+
|
108 |
+
try:
|
109 |
+
self.refine_net = load_refinenet_model(cuda=True)
|
110 |
+
self.use_cuda = True
|
111 |
+
except:
|
112 |
+
self.refine_net = load_refinenet_model(cuda=False)
|
113 |
+
self.use_cuda = False
|
114 |
+
|
115 |
+
if self.use_cuda:
|
116 |
+
self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=True)
|
117 |
+
else:
|
118 |
+
self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=False)
|
119 |
+
|
120 |
+
def init_llava(self):
|
121 |
+
if 'LLaVA' in self.OCR_option:
|
122 |
+
from vouchervision.OCR_llava import OCRllava
|
123 |
+
|
124 |
+
self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
|
125 |
+
self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
|
126 |
+
|
127 |
+
if self.json_report:
|
128 |
+
self.json_report.set_text(text_main=f'Loading LLaVA model: {self.model_path} Quantization: {self.model_quant}')
|
129 |
+
|
130 |
+
if self.model_quant == '4bit':
|
131 |
+
use_4bit = True
|
132 |
+
elif self.model_quant == 'full':
|
133 |
+
use_4bit = False
|
134 |
+
else:
|
135 |
+
self.logger.info(f"Provided model quantization invlid. Using 4bit.")
|
136 |
+
use_4bit = True
|
137 |
+
|
138 |
+
self.Llava = OCRllava(self.logger, model_path=self.model_path, load_in_4bit=use_4bit, load_in_8bit=False)
|
139 |
+
|
140 |
+
def init_gemini_vision(self):
|
141 |
+
pass
|
142 |
+
|
143 |
+
def init_gpt4_vision(self):
|
144 |
+
pass
|
145 |
+
|
146 |
+
|
147 |
+
def detect_text_craft(self):
|
148 |
+
from craft_text_detector import read_image, get_prediction
|
149 |
+
|
150 |
+
# Perform prediction using CRAFT
|
151 |
+
image = read_image(self.path)
|
152 |
+
|
153 |
+
link_threshold = 0.85
|
154 |
+
text_threshold = 0.4
|
155 |
+
low_text = 0.4
|
156 |
+
|
157 |
+
if self.use_cuda:
|
158 |
+
self.prediction_result = get_prediction(
|
159 |
+
image=image,
|
160 |
+
craft_net=self.craft_net,
|
161 |
+
refine_net=self.refine_net,
|
162 |
+
text_threshold=text_threshold,
|
163 |
+
link_threshold=link_threshold,
|
164 |
+
low_text=low_text,
|
165 |
+
cuda=True,
|
166 |
+
long_size=1280
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
self.prediction_result = get_prediction(
|
170 |
+
image=image,
|
171 |
+
craft_net=self.craft_net,
|
172 |
+
refine_net=self.refine_net,
|
173 |
+
text_threshold=text_threshold,
|
174 |
+
link_threshold=link_threshold,
|
175 |
+
low_text=low_text,
|
176 |
+
cuda=False,
|
177 |
+
long_size=1280
|
178 |
+
)
|
179 |
+
|
180 |
+
# Initialize metadata structures
|
181 |
+
bounds = []
|
182 |
+
bounds_word = [] # CRAFT gives bounds for text regions, not individual words
|
183 |
+
text_to_box_mapping = []
|
184 |
+
bounds_flat = []
|
185 |
+
height_flat = []
|
186 |
+
confidences = [] # CRAFT does not provide confidences per character, so this might be uniformly set or estimated
|
187 |
+
characters = [] # Simulating as CRAFT doesn't provide character-level details
|
188 |
+
organized_text = ""
|
189 |
+
|
190 |
+
total_b = len(self.prediction_result["boxes"])
|
191 |
+
i=0
|
192 |
+
# Process each detected text region
|
193 |
+
for box in self.prediction_result["boxes"]:
|
194 |
+
i+=1
|
195 |
+
if self.json_report:
|
196 |
+
self.json_report.set_text(text_main=f'Locating text using CRAFT --- {i}/{total_b}')
|
197 |
+
|
198 |
+
vertices = [{"x": int(vertex[0]), "y": int(vertex[1])} for vertex in box]
|
199 |
+
|
200 |
+
# Simulate a mapping for the whole detected region as a word
|
201 |
+
text_to_box_mapping.append({
|
202 |
+
"vertices": vertices,
|
203 |
+
"text": "detected_text" # Placeholder, as CRAFT does not provide the text content directly
|
204 |
+
})
|
205 |
+
|
206 |
+
# Assuming each box is a word for the sake of this example
|
207 |
+
bounds_word.append({"vertices": vertices})
|
208 |
+
|
209 |
+
# For simplicity, we're not dividing text regions into characters as CRAFT doesn't provide this
|
210 |
+
# Instead, we create a single large 'character' per detected region
|
211 |
+
bounds.append({"vertices": vertices})
|
212 |
+
|
213 |
+
# Simulate flat bounds and height for each detected region
|
214 |
+
x_positions = [vertex["x"] for vertex in vertices]
|
215 |
+
y_positions = [vertex["y"] for vertex in vertices]
|
216 |
+
min_x, max_x = min(x_positions), max(x_positions)
|
217 |
+
min_y, max_y = min(y_positions), max(y_positions)
|
218 |
+
avg_height = max_y - min_y
|
219 |
+
height_flat.append(avg_height)
|
220 |
+
|
221 |
+
# Assuming uniform confidence for all detected regions
|
222 |
+
confidences.append(1.0) # Placeholder confidence
|
223 |
+
|
224 |
+
# Adding dummy character for each box
|
225 |
+
characters.append("X") # Placeholder character
|
226 |
+
|
227 |
+
# Organize text as a single string (assuming each box is a word)
|
228 |
+
# organized_text += "detected_text " # Placeholder text
|
229 |
+
|
230 |
+
# Update class attributes with processed data
|
231 |
+
self.normal_bounds = bounds
|
232 |
+
self.normal_bounds_word = bounds_word
|
233 |
+
self.normal_text_to_box_mapping = text_to_box_mapping
|
234 |
+
self.normal_bounds_flat = bounds_flat # This would be similar to bounds if not processing characters individually
|
235 |
+
self.normal_height = height_flat
|
236 |
+
self.normal_confidences = confidences
|
237 |
+
self.normal_characters = characters
|
238 |
+
self.normal_organized_text = organized_text.strip()
|
239 |
+
|
240 |
+
|
241 |
+
def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
|
242 |
+
CONFIDENCES = 0.80
|
243 |
+
MAX_NEW_TOKENS = 50
|
244 |
+
|
245 |
+
self.OCR_JSON_to_file = {}
|
246 |
+
|
247 |
+
ocr_parts = ''
|
248 |
+
if not do_use_trOCR:
|
249 |
+
if 'normal' in self.OCR_option:
|
250 |
+
self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
|
251 |
+
# logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}")
|
252 |
+
# ocr_parts = ocr_parts + f"Google_OCR_Standard:\n{self.normal_organized_text}"
|
253 |
+
ocr_parts = self.normal_organized_text
|
254 |
+
|
255 |
+
if 'hand' in self.OCR_option:
|
256 |
+
self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
|
257 |
+
# logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}")
|
258 |
+
# ocr_parts = ocr_parts + f"Google_OCR_Handwriting:\n{self.hand_organized_text}"
|
259 |
+
ocr_parts = self.hand_organized_text
|
260 |
+
|
261 |
+
# if self.OCR_option in ['both',]:
|
262 |
+
# logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}")
|
263 |
+
# return f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}"
|
264 |
+
return ocr_parts
|
265 |
+
else:
|
266 |
+
logger.info(f'Supplementing with trOCR')
|
267 |
+
|
268 |
+
self.trOCR_texts = []
|
269 |
+
original_image = Image.open(self.path).convert("RGB")
|
270 |
+
|
271 |
+
if 'normal' in self.OCR_option or 'CRAFT' in self.OCR_option:
|
272 |
+
available_bounds = self.normal_bounds_word
|
273 |
+
elif 'hand' in self.OCR_option:
|
274 |
+
available_bounds = self.hand_bounds_word
|
275 |
+
# elif self.OCR_option in ['both',]:
|
276 |
+
# available_bounds = self.hand_bounds_word
|
277 |
+
else:
|
278 |
+
raise ValueError("trOCR requires bounding boxes from the 'normal', 'hand', or 'CRAFT' OCR options")
|
279 |
+
|
280 |
+
text_to_box_mapping = []
|
281 |
+
characters = []
|
282 |
+
height = []
|
283 |
+
confidences = []
|
284 |
+
total_b = len(available_bounds)
|
285 |
+
i=0
|
286 |
+
for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
|
287 |
+
i+=1
|
288 |
+
if self.json_report:
|
289 |
+
self.json_report.set_text(text_main=f'Working on trOCR :construction: {i}/{total_b}')
|
290 |
+
|
291 |
+
vertices = bound["vertices"]
|
292 |
+
|
293 |
+
left = min([v["x"] for v in vertices])
|
294 |
+
top = min([v["y"] for v in vertices])
|
295 |
+
right = max([v["x"] for v in vertices])
|
296 |
+
bottom = max([v["y"] for v in vertices])
|
297 |
+
|
298 |
+
# Crop image based on Google's bounding box
|
299 |
+
cropped_image = original_image.crop((left, top, right, bottom))
|
300 |
+
pixel_values = self.trOCR_processor(cropped_image, return_tensors="pt").pixel_values
|
301 |
+
|
302 |
+
# Move pixel values to the appropriate device
|
303 |
+
pixel_values = pixel_values.to(self.device)
|
304 |
+
|
305 |
+
generated_ids = self.trOCR_model.generate(pixel_values, max_new_tokens=MAX_NEW_TOKENS)
|
306 |
+
extracted_text = self.trOCR_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
307 |
+
self.trOCR_texts.append(extracted_text)
|
308 |
+
|
309 |
+
# For plotting
|
310 |
+
word_length = max(vertex.get('x') for vertex in vertices) - min(vertex.get('x') for vertex in vertices)
|
311 |
+
num_symbols = len(extracted_text)
|
312 |
+
|
313 |
+
Yw = max(vertex.get('y') for vertex in vertices)
|
314 |
+
Yo = Yw - min(vertex.get('y') for vertex in vertices)
|
315 |
+
X = word_length / num_symbols if num_symbols > 0 else 0
|
316 |
+
H = int(X+(Yo*0.1))
|
317 |
+
height.append(H)
|
318 |
+
|
319 |
+
map_dict = {
|
320 |
+
"vertices": vertices,
|
321 |
+
"text": extracted_text # Use the text extracted by trOCR
|
322 |
+
}
|
323 |
+
text_to_box_mapping.append(map_dict)
|
324 |
+
|
325 |
+
characters.append(extracted_text)
|
326 |
+
confidences.append(CONFIDENCES)
|
327 |
+
|
328 |
+
median_height = statistics.median(height) if height else 0
|
329 |
+
median_heights = [median_height * 1.5] * len(characters)
|
330 |
+
|
331 |
+
self.trOCR_texts = ' '.join(self.trOCR_texts)
|
332 |
+
|
333 |
+
self.trOCR_text_to_box_mapping = text_to_box_mapping
|
334 |
+
self.trOCR_bounds_flat = available_bounds
|
335 |
+
self.trOCR_height = median_heights
|
336 |
+
self.trOCR_confidences = confidences
|
337 |
+
self.trOCR_characters = characters
|
338 |
+
|
339 |
+
if 'normal' in self.OCR_option:
|
340 |
+
self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
|
341 |
+
self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
|
342 |
+
# logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
|
343 |
+
# ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
|
344 |
+
ocr_parts = self.trOCR_texts
|
345 |
+
if 'hand' in self.OCR_option:
|
346 |
+
self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
|
347 |
+
self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
|
348 |
+
# logger.info(f"Google_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
|
349 |
+
# ocr_parts = ocr_parts + f"\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
|
350 |
+
ocr_parts = self.trOCR_texts
|
351 |
+
# if self.OCR_option in ['both',]:
|
352 |
+
# self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
|
353 |
+
# self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
|
354 |
+
# self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
|
355 |
+
# logger.info(f"Google_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}")
|
356 |
+
# ocr_parts = ocr_parts + f"\nGoogle_OCR_Standard:\n{self.normal_organized_text}\n\nGoogle_OCR_Handwriting:\n{self.hand_organized_text}\n\ntrOCR:\n{self.trOCR_texts}"
|
357 |
+
if 'CRAFT' in self.OCR_option:
|
358 |
+
# self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
|
359 |
+
self.OCR_JSON_to_file['OCR_CRAFT_trOCR'] = self.trOCR_texts
|
360 |
+
# logger.info(f"CRAFT_trOCR:\n{self.trOCR_texts}")
|
361 |
+
# ocr_parts = ocr_parts + f"\nCRAFT_trOCR:\n{self.trOCR_texts}"
|
362 |
+
ocr_parts = self.trOCR_texts
|
363 |
+
return ocr_parts
|
364 |
+
|
365 |
+
@staticmethod
|
366 |
+
def confidence_to_color(confidence):
|
367 |
+
hue = (confidence - 0.5) * 120 / 0.5
|
368 |
+
r, g, b = colorsys.hls_to_rgb(hue/360, 0.5, 1)
|
369 |
+
return (int(r*255), int(g*255), int(b*255))
|
370 |
+
|
371 |
+
|
372 |
+
def render_text_on_black_image(self, option):
|
373 |
+
bounds_flat = getattr(self, f'{option}_bounds_flat', [])
|
374 |
+
heights = getattr(self, f'{option}_height', [])
|
375 |
+
confidences = getattr(self, f'{option}_confidences', [])
|
376 |
+
characters = getattr(self, f'{option}_characters', [])
|
377 |
+
|
378 |
+
original_image = Image.open(self.path)
|
379 |
+
width, height = original_image.size
|
380 |
+
black_image = Image.new("RGB", (width, height), "black")
|
381 |
+
draw = ImageDraw.Draw(black_image)
|
382 |
+
|
383 |
+
for bound, confidence, char_height, character in zip(bounds_flat, confidences, heights, characters):
|
384 |
+
font_size = int(char_height)
|
385 |
+
try:
|
386 |
+
font = ImageFont.truetype("arial.ttf", font_size)
|
387 |
+
except:
|
388 |
+
font = ImageFont.load_default().font_variant(size=font_size)
|
389 |
+
if option == 'trOCR':
|
390 |
+
color = (0, 170, 255)
|
391 |
+
else:
|
392 |
+
color = OCREngine.confidence_to_color(confidence)
|
393 |
+
position = (bound["vertices"][0]["x"], bound["vertices"][0]["y"] - char_height)
|
394 |
+
draw.text(position, character, fill=color, font=font)
|
395 |
+
|
396 |
+
return black_image
|
397 |
+
|
398 |
+
|
399 |
+
def merge_images(self, image1, image2):
|
400 |
+
width1, height1 = image1.size
|
401 |
+
width2, height2 = image2.size
|
402 |
+
merged_image = Image.new("RGB", (width1 + width2, max([height1, height2])))
|
403 |
+
merged_image.paste(image1, (0, 0))
|
404 |
+
merged_image.paste(image2, (width1, 0))
|
405 |
+
return merged_image
|
406 |
+
|
407 |
+
|
408 |
+
def draw_boxes(self, option):
|
409 |
+
bounds = getattr(self, f'{option}_bounds', [])
|
410 |
+
bounds_word = getattr(self, f'{option}_bounds_word', [])
|
411 |
+
confidences = getattr(self, f'{option}_confidences', [])
|
412 |
+
|
413 |
+
draw = ImageDraw.Draw(self.image)
|
414 |
+
width, height = self.image.size
|
415 |
+
if min([width, height]) > 4000:
|
416 |
+
line_width_thick = int((width + height) / 2 * 0.0025) # Adjust line width for character level
|
417 |
+
line_width_thin = 1
|
418 |
+
else:
|
419 |
+
line_width_thick = int((width + height) / 2 * 0.005) # Adjust line width for character level
|
420 |
+
line_width_thin = 1 #int((width + height) / 2 * 0.001)
|
421 |
+
|
422 |
+
for bound in bounds_word:
|
423 |
+
draw.polygon(
|
424 |
+
[
|
425 |
+
bound["vertices"][0]["x"], bound["vertices"][0]["y"],
|
426 |
+
bound["vertices"][1]["x"], bound["vertices"][1]["y"],
|
427 |
+
bound["vertices"][2]["x"], bound["vertices"][2]["y"],
|
428 |
+
bound["vertices"][3]["x"], bound["vertices"][3]["y"],
|
429 |
+
],
|
430 |
+
outline=OCREngine.BBOX_COLOR,
|
431 |
+
width=line_width_thin
|
432 |
+
)
|
433 |
+
|
434 |
+
# Draw a line segment at the bottom of each handwritten character
|
435 |
+
for bound, confidence in zip(bounds, confidences):
|
436 |
+
color = OCREngine.confidence_to_color(confidence)
|
437 |
+
# Use the bottom two vertices of the bounding box for the line
|
438 |
+
bottom_left = (bound["vertices"][3]["x"], bound["vertices"][3]["y"] + line_width_thick)
|
439 |
+
bottom_right = (bound["vertices"][2]["x"], bound["vertices"][2]["y"] + line_width_thick)
|
440 |
+
draw.line([bottom_left, bottom_right], fill=color, width=line_width_thick)
|
441 |
+
|
442 |
+
return self.image
|
443 |
+
|
444 |
+
|
445 |
+
def detect_text(self):
|
446 |
+
|
447 |
+
with io.open(self.path, 'rb') as image_file:
|
448 |
+
content = image_file.read()
|
449 |
+
image = vision.Image(content=content)
|
450 |
+
response = self.client.document_text_detection(image=image)
|
451 |
+
texts = response.text_annotations
|
452 |
+
|
453 |
+
if response.error.message:
|
454 |
+
raise Exception(
|
455 |
+
'{}\nFor more info on error messages, check: '
|
456 |
+
'https://cloud.google.com/apis/design/errors'.format(
|
457 |
+
response.error.message))
|
458 |
+
|
459 |
+
bounds = []
|
460 |
+
bounds_word = []
|
461 |
+
text_to_box_mapping = []
|
462 |
+
bounds_flat = []
|
463 |
+
height_flat = []
|
464 |
+
confidences = []
|
465 |
+
characters = []
|
466 |
+
organized_text = ""
|
467 |
+
paragraph_count = 0
|
468 |
+
|
469 |
+
for text in texts[1:]:
|
470 |
+
vertices = [{"x": vertex.x, "y": vertex.y} for vertex in text.bounding_poly.vertices]
|
471 |
+
map_dict = {
|
472 |
+
"vertices": vertices,
|
473 |
+
"text": text.description
|
474 |
+
}
|
475 |
+
text_to_box_mapping.append(map_dict)
|
476 |
+
|
477 |
+
for page in response.full_text_annotation.pages:
|
478 |
+
for block in page.blocks:
|
479 |
+
# paragraph_count += 1
|
480 |
+
# organized_text += f'OCR_paragraph_{paragraph_count}:\n' # Add paragraph label
|
481 |
+
for paragraph in block.paragraphs:
|
482 |
+
|
483 |
+
avg_H_list = []
|
484 |
+
for word in paragraph.words:
|
485 |
+
Yw = max(vertex.y for vertex in word.bounding_box.vertices)
|
486 |
+
# Calculate the width of the word and divide by the number of symbols
|
487 |
+
word_length = max(vertex.x for vertex in word.bounding_box.vertices) - min(vertex.x for vertex in word.bounding_box.vertices)
|
488 |
+
num_symbols = len(word.symbols)
|
489 |
+
if num_symbols <= 3:
|
490 |
+
H = int(Yw - min(vertex.y for vertex in word.bounding_box.vertices))
|
491 |
+
else:
|
492 |
+
Yo = Yw - min(vertex.y for vertex in word.bounding_box.vertices)
|
493 |
+
X = word_length / num_symbols if num_symbols > 0 else 0
|
494 |
+
H = int(X+(Yo*0.1))
|
495 |
+
avg_H_list.append(H)
|
496 |
+
avg_H = int(mean(avg_H_list))
|
497 |
+
|
498 |
+
words_in_para = []
|
499 |
+
for word in paragraph.words:
|
500 |
+
# Get word-level bounding box
|
501 |
+
bound_word_dict = {
|
502 |
+
"vertices": [
|
503 |
+
{"x": vertex.x, "y": vertex.y} for vertex in word.bounding_box.vertices
|
504 |
+
]
|
505 |
+
}
|
506 |
+
bounds_word.append(bound_word_dict)
|
507 |
+
|
508 |
+
Y = max(vertex.y for vertex in word.bounding_box.vertices)
|
509 |
+
word_x_start = min(vertex.x for vertex in word.bounding_box.vertices)
|
510 |
+
word_x_end = max(vertex.x for vertex in word.bounding_box.vertices)
|
511 |
+
num_symbols = len(word.symbols)
|
512 |
+
symbol_width = (word_x_end - word_x_start) / num_symbols if num_symbols > 0 else 0
|
513 |
+
|
514 |
+
current_x_position = word_x_start
|
515 |
+
|
516 |
+
characters_ind = []
|
517 |
+
for symbol in word.symbols:
|
518 |
+
bound_dict = {
|
519 |
+
"vertices": [
|
520 |
+
{"x": vertex.x, "y": vertex.y} for vertex in symbol.bounding_box.vertices
|
521 |
+
]
|
522 |
+
}
|
523 |
+
bounds.append(bound_dict)
|
524 |
+
|
525 |
+
# Create flat bounds with adjusted x position
|
526 |
+
bounds_flat_dict = {
|
527 |
+
"vertices": [
|
528 |
+
{"x": current_x_position, "y": Y},
|
529 |
+
{"x": current_x_position + symbol_width, "y": Y}
|
530 |
+
]
|
531 |
+
}
|
532 |
+
bounds_flat.append(bounds_flat_dict)
|
533 |
+
current_x_position += symbol_width
|
534 |
+
|
535 |
+
height_flat.append(avg_H)
|
536 |
+
confidences.append(round(symbol.confidence, 4))
|
537 |
+
|
538 |
+
characters_ind.append(symbol.text)
|
539 |
+
characters.append(symbol.text)
|
540 |
+
|
541 |
+
words_in_para.append(''.join(characters_ind))
|
542 |
+
paragraph_text = ' '.join(words_in_para) # Join words in paragraph
|
543 |
+
organized_text += paragraph_text + ' ' #+ '\n'
|
544 |
+
|
545 |
+
# median_height = statistics.median(height_flat) if height_flat else 0
|
546 |
+
# median_heights = [median_height] * len(characters)
|
547 |
+
|
548 |
+
self.normal_cleaned_text = texts[0].description if texts else ''
|
549 |
+
self.normal_organized_text = organized_text
|
550 |
+
self.normal_bounds = bounds
|
551 |
+
self.normal_bounds_word = bounds_word
|
552 |
+
self.normal_text_to_box_mapping = text_to_box_mapping
|
553 |
+
self.normal_bounds_flat = bounds_flat
|
554 |
+
# self.normal_height = median_heights #height_flat
|
555 |
+
self.normal_height = height_flat
|
556 |
+
self.normal_confidences = confidences
|
557 |
+
self.normal_characters = characters
|
558 |
+
return self.normal_cleaned_text
|
559 |
+
|
560 |
+
|
561 |
+
def detect_handwritten_ocr(self):
|
562 |
+
|
563 |
+
with open(self.path, "rb") as image_file:
|
564 |
+
content = image_file.read()
|
565 |
+
|
566 |
+
image = vision_beta.Image(content=content)
|
567 |
+
image_context = vision_beta.ImageContext(language_hints=["en-t-i0-handwrit"])
|
568 |
+
response = self.client_beta.document_text_detection(image=image, image_context=image_context)
|
569 |
+
texts = response.text_annotations
|
570 |
+
|
571 |
+
if response.error.message:
|
572 |
+
raise Exception(
|
573 |
+
"{}\nFor more info on error messages, check: "
|
574 |
+
"https://cloud.google.com/apis/design/errors".format(response.error.message)
|
575 |
+
)
|
576 |
+
|
577 |
+
bounds = []
|
578 |
+
bounds_word = []
|
579 |
+
bounds_flat = []
|
580 |
+
height_flat = []
|
581 |
+
confidences = []
|
582 |
+
characters = []
|
583 |
+
organized_text = ""
|
584 |
+
paragraph_count = 0
|
585 |
+
text_to_box_mapping = []
|
586 |
+
|
587 |
+
for text in texts[1:]:
|
588 |
+
vertices = [{"x": vertex.x, "y": vertex.y} for vertex in text.bounding_poly.vertices]
|
589 |
+
map_dict = {
|
590 |
+
"vertices": vertices,
|
591 |
+
"text": text.description
|
592 |
+
}
|
593 |
+
text_to_box_mapping.append(map_dict)
|
594 |
+
|
595 |
+
for page in response.full_text_annotation.pages:
|
596 |
+
for block in page.blocks:
|
597 |
+
# paragraph_count += 1
|
598 |
+
# organized_text += f'\nOCR_paragraph_{paragraph_count}:\n' # Add paragraph label
|
599 |
+
for paragraph in block.paragraphs:
|
600 |
+
|
601 |
+
avg_H_list = []
|
602 |
+
for word in paragraph.words:
|
603 |
+
Yw = max(vertex.y for vertex in word.bounding_box.vertices)
|
604 |
+
# Calculate the width of the word and divide by the number of symbols
|
605 |
+
word_length = max(vertex.x for vertex in word.bounding_box.vertices) - min(vertex.x for vertex in word.bounding_box.vertices)
|
606 |
+
num_symbols = len(word.symbols)
|
607 |
+
if num_symbols <= 3:
|
608 |
+
H = int(Yw - min(vertex.y for vertex in word.bounding_box.vertices))
|
609 |
+
else:
|
610 |
+
Yo = Yw - min(vertex.y for vertex in word.bounding_box.vertices)
|
611 |
+
X = word_length / num_symbols if num_symbols > 0 else 0
|
612 |
+
H = int(X+(Yo*0.1))
|
613 |
+
avg_H_list.append(H)
|
614 |
+
avg_H = int(mean(avg_H_list))
|
615 |
+
|
616 |
+
words_in_para = []
|
617 |
+
for word in paragraph.words:
|
618 |
+
# Get word-level bounding box
|
619 |
+
bound_word_dict = {
|
620 |
+
"vertices": [
|
621 |
+
{"x": vertex.x, "y": vertex.y} for vertex in word.bounding_box.vertices
|
622 |
+
]
|
623 |
+
}
|
624 |
+
bounds_word.append(bound_word_dict)
|
625 |
+
|
626 |
+
Y = max(vertex.y for vertex in word.bounding_box.vertices)
|
627 |
+
word_x_start = min(vertex.x for vertex in word.bounding_box.vertices)
|
628 |
+
word_x_end = max(vertex.x for vertex in word.bounding_box.vertices)
|
629 |
+
num_symbols = len(word.symbols)
|
630 |
+
symbol_width = (word_x_end - word_x_start) / num_symbols if num_symbols > 0 else 0
|
631 |
+
|
632 |
+
current_x_position = word_x_start
|
633 |
+
|
634 |
+
characters_ind = []
|
635 |
+
for symbol in word.symbols:
|
636 |
+
bound_dict = {
|
637 |
+
"vertices": [
|
638 |
+
{"x": vertex.x, "y": vertex.y} for vertex in symbol.bounding_box.vertices
|
639 |
+
]
|
640 |
+
}
|
641 |
+
bounds.append(bound_dict)
|
642 |
+
|
643 |
+
# Create flat bounds with adjusted x position
|
644 |
+
bounds_flat_dict = {
|
645 |
+
"vertices": [
|
646 |
+
{"x": current_x_position, "y": Y},
|
647 |
+
{"x": current_x_position + symbol_width, "y": Y}
|
648 |
+
]
|
649 |
+
}
|
650 |
+
bounds_flat.append(bounds_flat_dict)
|
651 |
+
current_x_position += symbol_width
|
652 |
+
|
653 |
+
height_flat.append(avg_H)
|
654 |
+
confidences.append(round(symbol.confidence, 4))
|
655 |
+
|
656 |
+
characters_ind.append(symbol.text)
|
657 |
+
characters.append(symbol.text)
|
658 |
+
|
659 |
+
words_in_para.append(''.join(characters_ind))
|
660 |
+
paragraph_text = ' '.join(words_in_para) # Join words in paragraph
|
661 |
+
organized_text += paragraph_text + ' ' #+ '\n'
|
662 |
+
|
663 |
+
# median_height = statistics.median(height_flat) if height_flat else 0
|
664 |
+
# median_heights = [median_height] * len(characters)
|
665 |
+
|
666 |
+
self.hand_cleaned_text = response.text_annotations[0].description if response.text_annotations else ''
|
667 |
+
self.hand_organized_text = organized_text
|
668 |
+
self.hand_bounds = bounds
|
669 |
+
self.hand_bounds_word = bounds_word
|
670 |
+
self.hand_bounds_flat = bounds_flat
|
671 |
+
self.hand_text_to_box_mapping = text_to_box_mapping
|
672 |
+
# self.hand_height = median_heights #height_flat
|
673 |
+
self.hand_height = height_flat
|
674 |
+
self.hand_confidences = confidences
|
675 |
+
self.hand_characters = characters
|
676 |
+
return self.hand_cleaned_text
|
677 |
+
|
678 |
+
|
679 |
+
def process_image(self, do_create_OCR_helper_image, logger):
|
680 |
+
# Can stack options, so solitary if statements
|
681 |
+
self.OCR = 'OCR:\n'
|
682 |
+
if 'CRAFT' in self.OCR_option:
|
683 |
+
self.do_use_trOCR = True
|
684 |
+
self.detect_text_craft()
|
685 |
+
### Optionally add trOCR to the self.OCR for additional context
|
686 |
+
if self.double_OCR:
|
687 |
+
part_OCR = "\CRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
|
688 |
+
self.OCR = self.OCR + part_OCR + part_OCR
|
689 |
+
else:
|
690 |
+
self.OCR = self.OCR + "\CRAFT trOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
|
691 |
+
# logger.info(f"CRAFT trOCR:\n{self.OCR}")
|
692 |
+
|
693 |
+
if 'LLaVA' in self.OCR_option: # This option does not produce an OCR helper image
|
694 |
+
if self.json_report:
|
695 |
+
self.json_report.set_text(text_main=f'Working on LLaVA {self.Llava.model_path} transcription :construction:')
|
696 |
+
|
697 |
+
image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.multimodal_prompt)
|
698 |
+
self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")
|
699 |
+
|
700 |
+
try:
|
701 |
+
self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
|
702 |
+
except:
|
703 |
+
self.OCR_JSON_to_file = {}
|
704 |
+
self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
|
705 |
+
|
706 |
+
if self.double_OCR:
|
707 |
+
self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}" + f"\nLLaVA OCR:\n{str_output}"
|
708 |
+
else:
|
709 |
+
                self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}"
            # logger.info(f"LLaVA OCR:\n{self.OCR}")

        if 'normal' in self.OCR_option or 'hand' in self.OCR_option:
            if 'normal' in self.OCR_option:
                if self.double_OCR:
                    part_OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
                    self.OCR = self.OCR + part_OCR + part_OCR
                else:
                    self.OCR = self.OCR + "\nGoogle Printed OCR:\n" + self.detect_text()
            if 'hand' in self.OCR_option:
                if self.double_OCR:
                    part_OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
                    self.OCR = self.OCR + part_OCR + part_OCR
                else:
                    self.OCR = self.OCR + "\nGoogle Handwritten OCR:\n" + self.detect_handwritten_ocr()
            # if self.OCR_option not in ['normal', 'hand', 'both']:
            #     self.OCR_option = 'both'
            #     self.detect_text()
            #     self.detect_handwritten_ocr()

        ### Optionally add trOCR to the self.OCR for additional context
        if self.do_use_trOCR:
            if self.double_OCR:
                part_OCR = "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
                self.OCR = self.OCR + part_OCR + part_OCR
            else:
                self.OCR = self.OCR + "\ntrOCR:\n" + self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)
            # logger.info(f"OCR:\n{self.OCR}")
        else:
            # populate self.OCR_JSON_to_file = {}
            _ = self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)


        if do_create_OCR_helper_image and ('LLaVA' not in self.OCR_option):
            self.image = Image.open(self.path)

            if 'normal' in self.OCR_option:
                image_with_boxes_normal = self.draw_boxes('normal')
                text_image_normal = self.render_text_on_black_image('normal')
                self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_normal)

            if 'hand' in self.OCR_option:
                image_with_boxes_hand = self.draw_boxes('hand')
                text_image_hand = self.render_text_on_black_image('hand')
                self.merged_image_hand = self.merge_images(image_with_boxes_hand, text_image_hand)

            if self.do_use_trOCR:
                text_image_trOCR = self.render_text_on_black_image('trOCR')

                if 'CRAFT' in self.OCR_option:
                    image_with_boxes_normal = self.draw_boxes('normal')
                    self.merged_image_normal = self.merge_images(image_with_boxes_normal, text_image_trOCR)

            ### Merge final overlay image
            ### [original, normal bboxes, normal text]
            if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
                self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_normal)
            ### [original, hand bboxes, hand text]
            elif 'hand' in self.OCR_option:
                self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_hand)
            ### [original, normal bboxes, normal text, hand bboxes, hand text]
            else:
                self.overlay_image = self.merge_images(Image.open(self.path), self.merge_images(self.merged_image_normal, self.merged_image_hand))

            if self.do_use_trOCR:
                if 'CRAFT' in self.OCR_option:
                    heat_map_text = Image.fromarray(cv2.cvtColor(self.prediction_result["heatmaps"]["text_score_heatmap"], cv2.COLOR_BGR2RGB))
                    heat_map_link = Image.fromarray(cv2.cvtColor(self.prediction_result["heatmaps"]["link_score_heatmap"], cv2.COLOR_BGR2RGB))
                    self.overlay_image = self.merge_images(self.overlay_image, heat_map_text)
                    self.overlay_image = self.merge_images(self.overlay_image, heat_map_link)

                else:
                    self.overlay_image = self.merge_images(self.overlay_image, text_image_trOCR)

        else:
            self.merged_image_normal = None
            self.merged_image_hand = None
            self.overlay_image = Image.open(self.path)

        try:
            from craft_text_detector import empty_cuda_cache
            empty_cuda_cache()
        except:
            pass

class SafetyCheck():
    def __init__(self, is_hf) -> None:
        self.is_hf = is_hf
        self.set_client()

    def set_client(self):
        if self.is_hf:
            self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())
        else:
            self.client = vision.ImageAnnotatorClient(credentials=self.get_google_credentials())

    def get_google_credentials(self):
        creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
        credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
        return credentials

    def check_for_inappropriate_content(self, file_stream):
        try:
            LEVEL = 2
            # content = file_stream.read()
            file_stream.seek(0)  # Reset file stream position to the beginning
            content = file_stream.read()
            image = vision.Image(content=content)
            response = self.client.safe_search_detection(image=image)
            safe = response.safe_search_annotation

            likelihood_name = (
                "UNKNOWN",
                "VERY_UNLIKELY",
                "UNLIKELY",
                "POSSIBLE",
                "LIKELY",
                "VERY_LIKELY",
            )
            print("Safe search:")

            print(f" adult*: {likelihood_name[safe.adult]}")
            print(f" medical*: {likelihood_name[safe.medical]}")
            print(f" spoofed: {likelihood_name[safe.spoof]}")
            print(f" violence*: {likelihood_name[safe.violence]}")
            print(f" racy: {likelihood_name[safe.racy]}")

            # Check the levels of adult, violence, racy, etc. content.
            if (safe.adult > LEVEL or
                safe.medical > LEVEL or
                # safe.spoof > LEVEL or
                safe.violence > LEVEL #or
                # safe.racy > LEVEL
                ):
                print("Found violation")
                return True  # The image violates safe search guidelines.

            print("Found NO violation")
            return False  # The image is considered safe.
        except:
            return False  # The image is considered safe. TEMPOROARY FIX TODO
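The SafetyCheck class above is a thin wrapper around Google Cloud Vision SafeSearch. A minimal usage sketch, assuming GOOGLE_APPLICATION_CREDENTIALS holds the service-account JSON and 'specimen.jpg' is a hypothetical upload:

# Hypothetical example: screen an uploaded image before it enters the pipeline.
from io import BytesIO

checker = SafetyCheck(is_hf=True)
with open('specimen.jpg', 'rb') as f:          # 'specimen.jpg' is a placeholder path
    stream = BytesIO(f.read())
if checker.check_for_inappropriate_content(stream):
    print('Image rejected by SafeSearch thresholds')
else:
    print('Image accepted')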
vouchervision/OCR_google_cloud_vision.py
CHANGED
@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
 import colorsys
 from tqdm import tqdm
 from google.oauth2 import service_account
-
+from OCR_Florence_2 import FlorenceOCR
 ### LLaVA should only be installed if the user will actually use it.
 ### It requires the most recent pytorch/Python and can mess with older systems

@@ -43,6 +43,7 @@ class OCREngine:
         self.path = path
         self.cfg = cfg
         self.do_use_trOCR = self.cfg['leafmachine']['project']['do_use_trOCR']
+        self.do_use_florence = self.cfg['leafmachine']['project']['do_use_florence']
         self.OCR_option = self.cfg['leafmachine']['project']['OCR_option']
         self.double_OCR = self.cfg['leafmachine']['project']['double_OCR']
         self.dir_home = dir_home
@@ -53,6 +54,8 @@ class OCREngine:
         self.trOCR_model = trOCR_model
         self.device = device

+        self.OCR_JSON_to_file = {}
+
         self.hand_cleaned_text = None
         self.hand_organized_text = None
         self.hand_bounds = None
@@ -80,6 +83,7 @@ class OCREngine:
         self.trOCR_confidences = None
         self.trOCR_characters = None
         self.set_client()
+        self.init_florence()
         self.init_craft()

         self.multimodal_prompt = """I need you to transcribe all of the text in this image.
@@ -117,6 +121,10 @@ class OCREngine:
         else:
             self.craft_net = load_craftnet_model(weight_path=os.path.join(self.dir_home,'vouchervision','craft','craft_mlt_25k.pth'), cuda=False)

+    def init_florence(self):
+        if 'Florence-2' in self.OCR_option:
+            self.Florence = FlorenceOCR(logger=self.logger, model_id=self.cfg['leafmachine']['project']['florence_model_path'])
+
     def init_llava(self):
         if 'LLaVA' in self.OCR_option:
             from vouchervision.OCR_llava import OCRllava
@@ -241,8 +249,6 @@ class OCREngine:
     def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
         CONFIDENCES = 0.80
         MAX_NEW_TOKENS = 50
-
-        self.OCR_JSON_to_file = {}

         ocr_parts = ''
         if not do_use_trOCR:
@@ -677,6 +683,9 @@ class OCREngine:


     def process_image(self, do_create_OCR_helper_image, logger):
+        if 'hand' not in self.OCR_option and 'normal' not in self.OCR_option:
+            do_create_OCR_helper_image = False
+
         # Can stack options, so solitary if statements
         self.OCR = 'OCR:\n'
         if 'CRAFT' in self.OCR_option:
@@ -697,11 +706,7 @@ class OCREngine:
             image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.multimodal_prompt)
             self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")

-            try:
-                self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
-            except:
-                self.OCR_JSON_to_file = {}
-                self.OCR_JSON_to_file['OCR_LLaVA'] = str_output
+            self.OCR_JSON_to_file['OCR_LLaVA'] = str_output

             if self.double_OCR:
                 self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}" + f"\nLLaVA OCR:\n{str_output}"
@@ -709,6 +714,20 @@ class OCREngine:
                 self.OCR = self.OCR + f"\nLLaVA OCR:\n{str_output}"
             # logger.info(f"LLaVA OCR:\n{self.OCR}")

+        if 'Florence-2' in self.OCR_option: # This option does not produce an OCR helper image
+            if self.json_report:
+                self.json_report.set_text(text_main=f'Working on Florence-2 [{self.Florence.model_id}] transcription :construction:')
+
+            self.logger.info(f"Florence-2 Usage Report for Model [{self.Florence.model_id}]")
+            results_text, results_text_dirty, results, usage_report = self.Florence.ocr_florence(self.path, task_prompt='<OCR>', text_input=None)
+
+            self.OCR_JSON_to_file['OCR_Florence'] = results_text
+
+            if self.double_OCR:
+                self.OCR = self.OCR + f"\nFlorence-2 OCR:\n{results_text}" + f"\nFlorence-2 OCR:\n{results_text}"
+            else:
+                self.OCR = self.OCR + f"\nFlorence-2 OCR:\n{results_text}"
+
         if 'normal' in self.OCR_option or 'hand' in self.OCR_option:
             if 'normal' in self.OCR_option:
                 if self.double_OCR:
@@ -762,14 +781,16 @@ class OCREngine:

             ### Merge final overlay image
             ### [original, normal bboxes, normal text]
-            if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
-                self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_normal)
-            ### [original, hand bboxes, hand text]
-            elif 'hand' in self.OCR_option:
-                self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_hand)
-            ### [original, normal bboxes, normal text, hand bboxes, hand text]
-            else:
-                self.overlay_image = self.merge_images(Image.open(self.path), self.merge_images(self.merged_image_normal, self.merged_image_hand))
+            if 'hand' in self.OCR_option or 'normal' in self.OCR_option:
+                if 'CRAFT' in self.OCR_option or 'normal' in self.OCR_option:
+                    self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_normal)
+                ### [original, hand bboxes, hand text]
+                elif 'hand' in self.OCR_option:
+                    self.overlay_image = self.merge_images(Image.open(self.path), self.merged_image_hand)
+                ### [original, normal bboxes, normal text, hand bboxes, hand text]
+                else:
+                    self.overlay_image = self.merge_images(Image.open(self.path), self.merge_images(self.merged_image_normal, self.merged_image_hand))
+
             if self.do_use_trOCR:
                 if 'CRAFT' in self.OCR_option:
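For context, the new FlorenceOCR wrapper is called with task_prompt='<OCR>'. Below is a rough, hedged sketch of how Florence-2 OCR is typically invoked through transformers; the model id and image path are placeholders and the wrapper's internals in OCR_Florence_2.py may differ.

# Illustrative sketch only -- not taken from OCR_Florence_2.py.
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "microsoft/Florence-2-large"                      # default from the config below
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

image = Image.open("label.jpg")                              # placeholder image path
inputs = processor(text="<OCR>", images=image, return_tensors="pt")
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
results = processor.post_process_generation(generated_text, task="<OCR>", image_size=image.size)
print(results["<OCR>"])                                      # raw transcription string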
vouchervision/VoucherVision_Config_Builder.py
CHANGED
@@ -36,21 +36,22 @@ def build_VV_config(loaded_cfg=None):
     save_cropped_annotations = ['label','barcode']

     do_use_trOCR = False
+    do_use_florence = False
     trOCR_model_path = "microsoft/trocr-large-handwritten"
+    florence_model_path = "microsoft/Florence-2-large"
     OCR_option = 'hand'
     OCR_option_llava = 'llava-v1.6-mistral-7b' # "llava-v1.6-mistral-7b", "llava-v1.6-34b", "llava-v1.6-vicuna-13b", "llava-v1.6-vicuna-7b",
     OCR_option_llava_bit = 'full' # full or 4bit
     double_OCR = False

-
     tool_GEO = True
     tool_WFO = True
     tool_wikipedia = True

     check_for_illegal_filenames = False

-    LLM_version_user = 'Azure GPT 4' #'Azure GPT 4 Turbo 1106-preview'
-    prompt_version = '
+    LLM_version_user = 'Gemini 1.5 Flash' # 'Azure GPT 4' #'Azure GPT 4 Turbo 1106-preview'
+    prompt_version = 'SLTPvM_long.yaml' # from ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
     use_LeafMachine2_collage_images = True # Use LeafMachine2 collage images
     do_create_OCR_helper_image = True

@@ -71,7 +72,7 @@ def build_VV_config(loaded_cfg=None):
     return assemble_config(dir_home, run_name, dir_images_local,dir_output,
                     prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
                     path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
-                    prompt_version, do_create_OCR_helper_image, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
+                    prompt_version, do_create_OCR_helper_image, do_use_trOCR, do_use_florence, trOCR_model_path, florence_model_path, OCR_option, OCR_option_llava,
                     OCR_option_llava_bit, double_OCR, save_cropped_annotations,
                     tool_GEO, tool_WFO, tool_wikipedia,
                     check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
@@ -88,7 +89,9 @@ def build_VV_config(loaded_cfg=None):
        catalog_numerical_only = loaded_cfg['leafmachine']['project']['catalog_numerical_only']

        do_use_trOCR = loaded_cfg['leafmachine']['project']['do_use_trOCR']
+        do_use_florence = loaded_cfg['leafmachine']['project']['do_use_florence']
        trOCR_model_path = loaded_cfg['leafmachine']['project']['trOCR_model_path']
+        florence_model_path = loaded_cfg['leafmachine']['project']['florence_model_path']
        OCR_option = loaded_cfg['leafmachine']['project']['OCR_option']
        OCR_option_llava = loaded_cfg['leafmachine']['project']['OCR_option_llava']
        OCR_option_llava_bit = loaded_cfg['leafmachine']['project']['OCR_option_llava_bit']
@@ -118,7 +121,7 @@ def build_VV_config(loaded_cfg=None):
        return assemble_config(dir_home, run_name, dir_images_local,dir_output,
                        prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
                        path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
-                        prompt_version, do_create_OCR_helper_image, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
+                        prompt_version, do_create_OCR_helper_image, do_use_trOCR, do_use_florence, trOCR_model_path, florence_model_path, OCR_option, OCR_option_llava,
                        OCR_option_llava_bit, double_OCR, save_cropped_annotations,
                        tool_GEO, tool_WFO, tool_wikipedia,
                        check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False)
@@ -127,7 +130,7 @@ def build_VV_config(loaded_cfg=None):
def assemble_config(dir_home, run_name, dir_images_local,dir_output,
                    prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
                    path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
-                    prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, trOCR_model_path, OCR_option, OCR_option_llava,
+                    prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, do_use_florence, trOCR_model_path, florence_model_path, OCR_option, OCR_option_llava,
                    OCR_option_llava_bit, double_OCR, save_cropped_annotations,
                    tool_GEO, tool_WFO, tool_wikipedia,
                    check_for_illegal_filenames, skip_vertical, pdf_conversion_dpi, use_domain_knowledge=False):
@@ -174,7 +177,9 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
            'delete_all_temps': False,
            'delete_temps_keep_VVE': False,
            'do_use_trOCR': do_use_trOCR,
+            'do_use_florence': do_use_florence,
            'trOCR_model_path': trOCR_model_path,
+            'florence_model_path': florence_model_path,
            'OCR_option': OCR_option,
            'OCR_option_llava': OCR_option_llava,
            'OCR_option_llava_bit': OCR_option_llava_bit,
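The new keys end up in the assembled cfg under leafmachine/project. A minimal sketch of the relevant slice of that dictionary, with values mirroring the defaults in build_VV_config above:

# Illustrative slice of the assembled config; values follow the defaults above.
cfg_project = {
    'do_use_trOCR': False,
    'do_use_florence': False,
    'trOCR_model_path': 'microsoft/trocr-large-handwritten',
    'florence_model_path': 'microsoft/Florence-2-large',
    'OCR_option': 'hand',          # OCR options can be stacked, e.g. ['hand', 'CRAFT']
    'double_OCR': False,
}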
vouchervision/fetch_data.py
CHANGED
@@ -7,7 +7,7 @@ import urllib.request
 from tqdm import tqdm
 import subprocess

-VERSION = 'v-2-
+VERSION = 'v-2-3'

 def fetch_data(logger, dir_home, cfg_file_path):
     logger.name = 'Fetch Data'
vouchervision/generate_partner_collage.py
ADDED
@@ -0,0 +1,104 @@
import requests
from bs4 import BeautifulSoup
from PIL import Image
from io import BytesIO
import time

# Global variables
H = 200
ROWS = 6
PADDING = 30

# Step 1: Fetch the Images from the URL Folder
def fetch_image_urls(url):
    response = requests.get(url + '?t=' + str(time.time()))
    soup = BeautifulSoup(response.content, 'html.parser')
    images = {}
    for node in soup.find_all('a'):
        href = node.get('href')
        if href.endswith(('.png', '.jpg', '.jpeg')):
            try:
                image_index = int(href.split('__')[0])
                images[image_index] = url + '/' + href + '?t=' + str(time.time())
            except ValueError:
                print(f"Skipping invalid image: {href}")
    return images

# Step 2: Resize Images to Height H
def fetch_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content))

def resize_images(images, target_height):
    resized_images = {}
    for index, img in images.items():
        ratio = target_height / img.height
        new_width = int(img.width * ratio)
        resized_img = img.resize((new_width, target_height), Image.BICUBIC)
        resized_images[index] = resized_img
    return resized_images

# Step 3: Create a Collage with Efficient Placement Algorithm
def create_collage(image_urls, collage_path, H, ROWS, PADDING):
    images = {index: fetch_image(url) for index, url in image_urls.items()}
    resized_images = resize_images(images, H) # Resize to H pixels height

    center_image = resized_images.pop(0)
    other_images = list(resized_images.items())

    # Calculate collage size based on the number of rows
    collage_width = 3000 # 16:9 aspect ratio width
    collage_height = (H + PADDING) * ROWS + 2 * PADDING # Adjust height based on number of rows, add padding to top and bottom
    collage = Image.new('RGB', (collage_width, collage_height), (255, 255, 255))

    # Sort images by width and height
    sorted_images = sorted(other_images, key=lambda x: x[1].width * x[1].height, reverse=True)

    # Create alternate placement list and insert the center image in the middle
    alternate_images = []
    i, j = 0, len(sorted_images) - 1
    halfway_point = (len(sorted_images) + 1) // 2
    count = 0

    while i <= j:
        if count == halfway_point:
            alternate_images.append((0, center_image))
        if i == j:
            alternate_images.append(sorted_images[i])
        else:
            alternate_images.append(sorted_images[i])
            alternate_images.append(sorted_images[j])
        i += 1
        j -= 1
        count += 2

    # Calculate number of images per row
    images_per_row = len(alternate_images) // ROWS
    extra_images = len(alternate_images) % ROWS

    # Place images in rows with only padding space between them
    def place_images_in_rows(images, collage, max_width, padding, row_height, rows, images_per_row, extra_images):
        y = padding
        for current_row in range(rows):
            row_images_count = images_per_row + (1 if extra_images > 0 else 0)
            extra_images -= 1 if extra_images > 0 else 0
            row_images = images[:row_images_count]
            row_width = sum(img.width for idx, img in row_images) + padding * (row_images_count - 1)
            x = (max_width - row_width) // 2
            for idx, img in row_images:
                collage.paste(img, (x, y))
                x += img.width + padding
            y += row_height + padding
            images = images[row_images_count:]

    place_images_in_rows(alternate_images, collage, collage_width, PADDING, H, ROWS, images_per_row, extra_images)

    collage.save(collage_path)

# Define the URL folder and other constants
url_folder = 'https://leafmachine.org/partners/'
collage_path = 'img/collage.jpg'

# Fetch, Create, and Update
image_urls = fetch_image_urls(url_folder)
create_collage(image_urls, collage_path, H, ROWS, PADDING)
vouchervision/librarian_knowledge.json
ADDED
@@ -0,0 +1,27 @@
{
    "catalogNumber": "barcode identifier, at least 6 digits, fewer than 30 digits.",
    "order": "full scientific name of the Order in which the taxon is classified. Order must be capitalized.",
    "family": "full scientific name of the Family in which the taxon is classified. Family must be capitalized.",
    "scientificName": "scientific name of the taxon including Genus, specific epithet, and any lower classifications.",
    "scientificNameAuthorship": "authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.",
    "genus": "taxonomic determination to Genus, Genus must be capitalized.",
    "specificEpithet": "The name of the first or species epithet of the scientificName. Only include the species epithet.",
    "identifiedBy": "list of names of people, doctors, professors, groups, or organizations who identified, determined the taxon name to the subject organism. This is not the specimen collector.",
    "recordedBy": "list of names of people, doctors, professors, groups, or organizations.",
    "recordNumber": "identifier given to the specimen at the time it was recorded.",
    "verbatimEventDate": "The verbatim original representation of the date and time information for when the specimen was collected.",
    "eventDate": "collection date formatted as year-month-day YYYY-MM-DD.",
    "habitat": "habitat.",
    "occurrenceRemarks": "all descriptive text in the OCR rearranged into sensible sentences or sentence fragments.",
    "country": "country or major administrative unit.",
    "stateProvince": "state, province, canton, department, region, etc.",
    "county": "county, shire, department, parish etc.",
    "municipality": "city, municipality, etc.",
    "locality": "description of geographic information aiding in pinpointing the exact origin or location of the specimen.",
    "degreeOfEstablishment": "cultivated plants are intentionally grown by humans. Use either - unknown or cultivated.",
    "decimalLatitude": "latitude decimal coordinate.",
    "decimalLongitude": "longitude decimal coordinate.",
    "verbatimCoordinates": "verbatim location coordinates.",
    "minimumElevationInMeters": "minimum elevation or altitude in meters.",
    "maximumElevationInMeters": "maximum elevation or altitude in meters."
}
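A hedged sketch of how a field-definition file like this could be folded into a prompt; the loader and prompt wording here are illustrative, not taken from the repo:

import json

# Hypothetical helper: turn the field definitions into prompt guidance.
with open('vouchervision/librarian_knowledge.json', 'r', encoding='utf-8') as f:
    field_definitions = json.load(f)

guidance = '\n'.join(f"- {field}: {rule}" for field, rule in field_definitions.items())
prompt = f"Populate the following Darwin Core fields from the OCR text:\n{guidance}"
print(prompt.splitlines()[1])  # "- catalogNumber: barcode identifier, at least 6 digits, fewer than 30 digits."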
vouchervision/save_dataset.py
ADDED
@@ -0,0 +1,34 @@
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("phyloforfun/HLT_MICH_Angiospermae_SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05")

# Define the directory where you want to save the files
save_dir = "D:/Dropbox/VoucherVision/datasets/SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05"

# Save each split as a JSONL file in the specified directory
for split, split_dataset in dataset.items():
    split_dataset.to_json(f"{save_dir}/SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05-{split}.jsonl")


'''import json # convert to google

# Load the JSONL file
input_file_path = '/mnt/data/SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05-train.jsonl'
output_file_path = '/mnt/data/SLTPvC_v1-0_medium_OCR-C25-L25-E50-R05-train-converted.jsonl'

# Define the conversion function
def convert_record(record):
    return {
        "input_text": record.get('instruction', '') + ' ' + record.get('input', ''),
        "target_text": record.get('output', '')
    }

# Convert and save the new JSONL file
with open(input_file_path, 'r', encoding='utf-8') as infile, open(output_file_path, 'w', encoding='utf-8') as outfile:
    for line in infile:
        record = json.loads(line)
        converted_record = convert_record(record)
        outfile.write(json.dumps(converted_record) + '\n')

output_file_path'''
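For the commented-out converter, a single-record round trip looks like this; the example record is made up, but the field names match convert_record:

# Hypothetical record in the instruction/input/output format used above.
record = {
    "instruction": "Extract Darwin Core fields as JSON.",
    "input": "OCR: Quercus alba L. ...",
    "output": '{"scientificName": "Quercus alba"}',
}
converted = {
    "input_text": record.get('instruction', '') + ' ' + record.get('input', ''),
    "target_text": record.get('output', ''),
}
# converted["input_text"] == "Extract Darwin Core fields as JSON. OCR: Quercus alba L. ..."
# converted["target_text"] == '{"scientificName": "Quercus alba"}'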
vouchervision/utils_VoucherVision.py
CHANGED
@@ -14,6 +14,7 @@ from vouchervision.LLM_GoogleGemini import GoogleGeminiHandler
 from vouchervision.LLM_MistralAI import MistralHandler
 from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
 from vouchervision.LLM_local_MistralAI import LocalMistralHandler
+from vouchervision.LLM_local_custom_fine_tune import LocalFineTuneHandler
 from vouchervision.prompt_catalog import PromptCatalog
 from vouchervision.model_maps import ModelMaps
 from vouchervision.general_utils import get_cfg_from_full_path
@@ -449,6 +450,8 @@ class VoucherVision():
            k_openai = os.getenv('OPENAI_API_KEY')
            k_openai_azure = os.getenv('AZURE_API_VERSION')

+            k_huggingface = None
+
            k_google_project_id = os.getenv('GOOGLE_PROJECT_ID')
            k_google_location = os.getenv('GOOGLE_LOCATION')
            k_google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
@@ -464,6 +467,8 @@ class VoucherVision():
            k_openai = self.cfg_private['openai']['OPENAI_API_KEY']
            k_openai_azure = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE']

+            k_huggingface = self.cfg_private['huggingface']['hf_token']
+
            k_google_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
            k_google_location = self.cfg_private['google']['GOOGLE_LOCATION']
            k_google_application_credentials = self.cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS']
@@ -478,6 +483,8 @@ class VoucherVision():
        self.has_key_azure_openai = self.has_API_key(k_openai_azure)
        self.llm = None

+        self.has_key_huggingface = self.has_API_key(k_huggingface)
+
        self.has_key_google_project_id = self.has_API_key(k_google_project_id)
        self.has_key_google_location = self.has_API_key(k_google_location)
        self.has_key_google_application_credentials = self.has_API_key(k_google_application_credentials)
@@ -505,6 +512,11 @@ class VoucherVision():
            openai.api_key = self.cfg_private['openai']['OPENAI_API_KEY']
            os.environ["OPENAI_API_KEY"] = self.cfg_private['openai']['OPENAI_API_KEY']

+        if self.has_key_huggingface:
+            if self.is_hf:
+                pass
+            else:
+                os.environ["HUGGING_FACE_KEY"] = self.cfg_private['huggingface']['hf_token']

        ### OpenAI - Azure
        if self.has_key_azure_openai:
@@ -738,6 +750,10 @@ class VoucherVision():
                    response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_cpu_MistralAI(prompt, json_report, paths)
                else:
                    response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_MistralAI(prompt, json_report, paths)
+
+            elif "/" in ''.join(name_parts):
+                response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_custom_fine_tune(self.OCR, json_report, paths) ###
+
            else:
                response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_OpenAI(prompt, json_report, paths)

@@ -771,6 +787,8 @@ class VoucherVision():
                return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
            else:
                return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
+        elif "/" in ''.join(name_parts):
+            return LocalFineTuneHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
        else:
            if 'PALM2' in name_parts:
                return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
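The new routing rule treats any model name containing '/' as a Hugging Face repo id for a custom fine-tuned model. A small illustration of that check, assuming name_parts comes from splitting the model name (the real name_parts construction lives elsewhere in this file and may split differently):

# Illustration of the dispatch condition only.
def is_custom_fine_tune(model_name: str) -> bool:
    name_parts = model_name.split()          # assumed split for this sketch
    return "/" in ''.join(name_parts)

print(is_custom_fine_tune("phyloforfun/my-fine-tuned-model"))  # True  -> LocalFineTuneHandler (repo id is hypothetical)
print(is_custom_fine_tune("Azure GPT 4"))                      # False -> API handlers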
vouchervision/utils_hf (DESKTOP-548UDCR's conflicted copy 2024-06-15).py
ADDED
@@ -0,0 +1,266 @@
import os, json, re, datetime, tempfile, yaml
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload
from google.oauth2 import service_account
import base64
from PIL import Image
from PIL import Image
from io import BytesIO
from shutil import copyfileobj, copyfile

# from vouchervision.general_utils import get_cfg_from_full_path


def setup_streamlit_config(dir_home):
    # Define the directory path and filename
    dir_path = os.path.join(dir_home, ".streamlit")
    file_path = os.path.join(dir_path, "config.toml")

    # Check if directory exists, if not create it
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    # Create or modify the file with the provided content
    config_content = f"""
[theme]
base = "dark"
primaryColor = "#00ff00"

[server]
enableStaticServing = false
runOnSave = true
port = 8524
maxUploadSize = 5000
"""

    with open(file_path, "w") as f:
        f.write(config_content.strip())


def save_uploaded_file_local(directory_in, directory_out, img_file_name, image=None):
    if not os.path.exists(directory_out):
        os.makedirs(directory_out)

    # Assuming img_file_name includes the extension
    img_file_base, img_file_ext = os.path.splitext(img_file_name)

    full_path_out = os.path.join(directory_out, img_file_name)
    full_path_in = os.path.join(directory_in, img_file_name)

    # Check if the file extension is .pdf (or add other conditions for different file types)
    if img_file_ext.lower() == '.pdf':
        # Copy the file from the input directory to the output directory
        copyfile(full_path_in, full_path_out)
        return full_path_out
    else:
        if image is None:
            try:
                with Image.open(full_path_in) as image:
                    image.save(full_path_out, "JPEG")
                # Return the full path of the saved image
                return full_path_out
            except:
                pass
        else:
            try:
                image.save(full_path_out, "JPEG")
                return full_path_out
            except:
                pass


def save_uploaded_file(directory, img_file, image=None):
    if not os.path.exists(directory):
        os.makedirs(directory)

    full_path = os.path.join(directory, img_file.name)

    # Assuming the uploaded file is an image
    if img_file.name.lower().endswith('.pdf'):
        with open(full_path, 'wb') as out_file:
            # If img_file is a file-like object (e.g., Django's UploadedFile),
            # you can use copyfileobj or read chunks.
            # If it's a path, you'd need to open and then save it.
            if hasattr(img_file, 'read'):
                # This is a file-like object
                copyfileobj(img_file, out_file)
            else:
                # If img_file is a path string
                with open(img_file, 'rb') as fd:
                    copyfileobj(fd, out_file)
        return full_path
    else:
        if image is None:
            try:
                with Image.open(img_file) as image:
                    full_path = os.path.join(directory, img_file.name)
                    image.save(full_path, "JPEG")
                    # Return the full path of the saved image
                    return full_path
            except:
                try:
                    with Image.open(os.path.join(directory,img_file)) as image:
                        full_path = os.path.join(directory, img_file)
                        image.save(full_path, "JPEG")
                        # Return the full path of the saved image
                        return full_path
                except:
                    with Image.open(img_file.name) as image:
                        full_path = os.path.join(directory, img_file.name)
                        image.save(full_path, "JPEG")
                        # Return the full path of the saved image
                        return full_path
        else:
            try:
                full_path = os.path.join(directory, img_file.name)
                image.save(full_path, "JPEG")
                return full_path
            except:
                full_path = os.path.join(directory, img_file)
                image.save(full_path, "JPEG")
                return full_path
# def save_uploaded_file(directory, uploaded_file, image=None):
#     if not os.path.exists(directory):
#         os.makedirs(directory)

#     full_path = os.path.join(directory, uploaded_file.name)

#     # Handle PDF files
#     if uploaded_file.name.lower().endswith('.pdf'):
#         with open(full_path, 'wb') as out_file:
#             if hasattr(uploaded_file, 'read'):
#                 copyfileobj(uploaded_file, out_file)
#             else:
#                 with open(uploaded_file, 'rb') as fd:
#                     copyfileobj(fd, out_file)
#         return full_path
#     else:
#         if image is None:
#             try:
#                 with Image.open(uploaded_file) as img:
#                     img.save(full_path, "JPEG")
#             except:
#                 with Image.open(full_path) as img:
#                     img.save(full_path, "JPEG")
#         else:
#             try:
#                 image.save(full_path, "JPEG")
#             except:
#                 image.save(os.path.join(directory, uploaded_file.name), "JPEG")
#         return full_path

def save_uploaded_local(directory, img_file, image=None):
    name = img_file.split(os.path.sep)[-1]
    if not os.path.exists(directory):
        os.makedirs(directory)

    # Assuming the uploaded file is an image
    if image is None:
        with Image.open(img_file) as image:
            full_path = os.path.join(directory, name)
            image.save(full_path, "JPEG")
        # Return the full path of the saved image
        return os.path.join('uploads_small',name)
    else:
        full_path = os.path.join(directory, name)
        image.save(full_path, "JPEG")
        return os.path.join('.','uploads_small',name)

def image_to_base64(img):
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()

def check_prompt_yaml_filename(fname):
    # Check if the filename only contains letters, numbers, underscores, and dashes
    pattern = r'^[\w-]+$'

    # The \w matches any alphanumeric character and is equivalent to the character class [a-zA-Z0-9_].
    # The hyphen - is literally matched.

    if re.match(pattern, fname):
        return True
    else:
        return False

def report_violation(file_name, is_hf=True):
    # Format the current date and time
    current_time = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
    violation_file_name = f"violation_{current_time}.yaml"  # Updated variable name to avoid confusion

    # Create a temporary YAML file in text mode
    with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.yaml') as temp_file:
        # Example content - customize as needed
        content = {
            'violation_time': current_time,
            'notes': 'This is an autogenerated violation report.',
            'name_of_file': file_name,
        }
        # Write the content to the temporary YAML file in text mode
        yaml.dump(content, temp_file, default_flow_style=False)
        temp_filepath = temp_file.name

    # Now upload the temporary file
    upload_to_drive(temp_filepath, violation_file_name, is_hf=is_hf)

    # Optionally, delete the temporary file if you don't want it to remain on disk after uploading
    os.remove(temp_filepath)

# Function to upload files to Google Drive
def upload_to_drive(filepath, filename, is_hf=True, cfg_private=None, do_upload = True):
    if do_upload:
        creds = get_google_credentials(is_hf=is_hf, cfg_private=cfg_private)
        if creds:
            service = build('drive', 'v3', credentials=creds)

            # Get the folder ID from the environment variable
            if is_hf:
                folder_id = os.environ.get('GDRIVE_FOLDER_ID')  # Renamed for clarity
            else:
                folder_id = cfg_private['google']['GDRIVE_FOLDER_ID']  # Renamed for clarity


            if folder_id:
                file_metadata = {
                    'name': filename,
                    'parents': [folder_id]
                }

                # Determine the mimetype based on the file extension
                if filename.endswith('.yaml') or filename.endswith('.yml') or filepath.endswith('.yaml') or filepath.endswith('.yml'):
                    mimetype = 'application/x-yaml'
                elif filepath.endswith('.zip'):
                    mimetype = 'application/zip'
                else:
                    # Set a default mimetype if desired or handle the unsupported file type
                    print("Unsupported file type")
                    return None

                # Upload the file
                try:
                    media = MediaFileUpload(filepath, mimetype=mimetype)
                    file = service.files().create(
                        body=file_metadata,
                        media_body=media,
                        fields='id'
                    ).execute()
                    print(f"Uploaded file with ID: {file.get('id')}")
                except Exception as e:
                    msg = f"If the following error is '404 cannot find file...' then you need to share the GDRIVE folder with your Google API service account's email address. Open your Google API JSON file, find the email account that ends with '@developer.gserviceaccount.com', go to your Google Drive, share the folder with this email account. {e}"
                    print(msg)
                    raise Exception(msg)
        else:
            print("GDRIVE_API environment variable not set.")

def get_google_credentials(is_hf=True, cfg_private=None):  # Also used for google drive
    if is_hf:
        creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
        credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
        return credentials
    else:
        with open(cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS'], 'r') as file:
            data = json.load(file)
        creds_json_str = json.dumps(data)
        credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = creds_json_str
        return credentials
vouchervision/utils_hf.py
CHANGED
@@ -73,7 +73,7 @@ def save_uploaded_file(directory, img_file, image=None):
     if not os.path.exists(directory):
         os.makedirs(directory)

-    full_path = os.path.join(directory, img_file.name)
+    full_path = os.path.join(directory, img_file.name) ########## TODO THIS MUST BE MOVED TO conditional specific location

     # Assuming the uploaded file is an image
     if img_file.name.lower().endswith('.pdf'):
@@ -98,11 +98,18 @@ def save_uploaded_file(directory, img_file, image=None):
                     # Return the full path of the saved image
                     return full_path
             except:
-
-
-
-
-
+                try:
+                    with Image.open(os.path.join(directory,img_file)) as image:
+                        full_path = os.path.join(directory, img_file)
+                        image.save(full_path, "JPEG")
+                        # Return the full path of the saved image
+                        return full_path
+                except:
+                    with Image.open(img_file.name) as image:
+                        full_path = os.path.join(directory, img_file.name)
+                        image.save(full_path, "JPEG")
+                        # Return the full path of the saved image
+                        return full_path
         else:
             try:
                 full_path = os.path.join(directory, img_file.name)
@@ -112,6 +119,35 @@ def save_uploaded_file(directory, img_file, image=None):
                 full_path = os.path.join(directory, img_file)
                 image.save(full_path, "JPEG")
                 return full_path
+# def save_uploaded_file(directory, uploaded_file, image=None):
+#     if not os.path.exists(directory):
+#         os.makedirs(directory)
+
+#     full_path = os.path.join(directory, uploaded_file.name)
+
+#     # Handle PDF files
+#     if uploaded_file.name.lower().endswith('.pdf'):
+#         with open(full_path, 'wb') as out_file:
+#             if hasattr(uploaded_file, 'read'):
+#                 copyfileobj(uploaded_file, out_file)
+#             else:
+#                 with open(uploaded_file, 'rb') as fd:
+#                     copyfileobj(fd, out_file)
+#         return full_path
+#     else:
+#         if image is None:
+#             try:
+#                 with Image.open(uploaded_file) as img:
+#                     img.save(full_path, "JPEG")
+#             except:
+#                 with Image.open(full_path) as img:
+#                     img.save(full_path, "JPEG")
+#         else:
+#             try:
+#                 image.save(full_path, "JPEG")
+#             except:
+#                 image.save(os.path.join(directory, uploaded_file.name), "JPEG")
+#         return full_path

 def save_uploaded_local(directory, img_file, image=None):
     name = img_file.split(os.path.sep)[-1]
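The expanded fallback chain in save_uploaded_file now tries, in order, the file-like object itself, a path joined with the target directory, and the object's .name attribute. A hedged usage sketch with a Streamlit upload, which is presumably the usual source of img_file in this app; the directory name is a placeholder:

import streamlit as st

uploaded = st.file_uploader("Upload a specimen image", type=["jpg", "jpeg", "png", "pdf"])
if uploaded is not None:
    # 'uploads' stands in for the run directory the app actually supplies.
    saved_path = save_uploaded_file('uploads', uploaded)
    st.write(f"Saved to {saved_path}")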