Spaces:
Running
Running
change is_inside (move `is_inside` from modules/OCR.py to modules/utils.py)
Browse files
- modules/OCR.py +1 -15
- modules/streamlit_utils.py +8 -4
- modules/toXML.py +6 -6
- modules/utils.py +6 -0
modules/OCR.py
CHANGED
@@ -8,9 +8,8 @@ import numpy as np
|
|
8 |
import networkx as nx
|
9 |
from modules.utils import class_dict, proportion_inside
|
10 |
import json
|
11 |
-
from modules.utils import rescale_boxes as rescale
|
12 |
import streamlit as st
|
13 |
-
from modules.utils import is_vertical
|
14 |
|
15 |
VISION_KEY = os.getenv("VISION_KEY")
|
16 |
VISION_ENDPOINT = os.getenv("VISION_ENDPOINT")
|
@@ -133,13 +132,6 @@ def min_distance_between_boxes(box1, box2):
|
|
133 |
min_dist = dist
|
134 |
return min_dist
|
135 |
|
136 |
-
|
137 |
-
def is_inside(box1, box2):
|
138 |
-
"""Check if the center of box1 is inside box2."""
|
139 |
-
x_center = (box1[0] + box1[2]) / 2
|
140 |
-
y_center = (box1[1] + box1[3]) / 2
|
141 |
-
return box2[0] <= x_center <= box2[2] and box2[1] <= y_center <= box2[3]
|
142 |
-
|
143 |
def are_close(box1, box2, threshold=50):
|
144 |
"""Determines if boxes are close based on their corners and center points."""
|
145 |
corners1 = np.array([
|
@@ -307,12 +299,6 @@ def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, p
|
|
307 |
|
308 |
def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0.6,scale=1.0, iou_threshold=0.5):
|
309 |
|
310 |
-
########### REFAIRE CETTE FONCTION ###########
|
311 |
-
#refaire la fonction pour qu'elle prenne en premier les elements qui sont dans les task et ensuite prendre un seuil de distance pour les autres elements
|
312 |
-
#ou sinon faire la distance entre les elements et non pas seulement les tasks
|
313 |
-
|
314 |
-
|
315 |
-
# Example usage
|
316 |
boxes = rescale(scale, full_pred['boxes'])
|
317 |
|
318 |
min_dist = 200
|
|
|
8 |
import networkx as nx
|
9 |
from modules.utils import class_dict, proportion_inside
|
10 |
import json
|
11 |
+
from modules.utils import rescale_boxes as rescale, is_vertical
|
12 |
import streamlit as st
|
|
|
13 |
|
14 |
VISION_KEY = os.getenv("VISION_KEY")
|
15 |
VISION_ENDPOINT = os.getenv("VISION_ENDPOINT")
|
|
|
132 |
min_dist = dist
|
133 |
return min_dist
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
def are_close(box1, box2, threshold=50):
|
136 |
"""Determines if boxes are close based on their corners and center points."""
|
137 |
corners1 = np.array([
|
|
|
299 |
|
300 |
def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0.6,scale=1.0, iou_threshold=0.5):
|
301 |
|
|
|
|
|
|
|
|
|
|
|
|
|
302 |
boxes = rescale(scale, full_pred['boxes'])
|
303 |
|
304 |
min_dist = 200
|
modules/streamlit_utils.py
CHANGED
@@ -81,6 +81,7 @@ def load_models():
|
|
81 |
torch.save(model_arrow.state_dict(), output_arrow)
|
82 |
elif 'model_arrow' not in st.session_state and Path(output_arrow).exists():
|
83 |
model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
|
|
|
84 |
st.session_state.model_arrow = model_arrow
|
85 |
print('Model arrow loaded from local file')
|
86 |
|
@@ -95,8 +96,9 @@ def load_models():
|
|
95 |
torch.save(model_object.state_dict(), output_object)
|
96 |
elif 'model_object' not in st.session_state and Path(output_object).exists():
|
97 |
model_object.load_state_dict(torch.load(output_object, map_location=device))
|
|
|
98 |
st.session_state.model_object = model_object
|
99 |
-
print('Model object loaded from local file')
|
100 |
|
101 |
|
102 |
# Move models to device
|
@@ -184,14 +186,16 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
|
|
184 |
width = screen_width//2
|
185 |
image_placeholder.image(uploaded_image, caption='Original Image', width=width)
|
186 |
|
187 |
-
# Prediction
|
188 |
-
_, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
|
189 |
-
|
190 |
# Perform OCR on the uploaded image
|
191 |
ocr_results = text_prediction(uploaded_image)
|
192 |
|
193 |
# Filter and map OCR results to prediction results
|
194 |
st.session_state.text_pred = filter_text(ocr_results, threshold=0.6)
|
|
|
|
|
|
|
|
|
|
|
195 |
st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
|
196 |
|
197 |
# Remove the original image display
|
|
|
81 |
torch.save(model_arrow.state_dict(), output_arrow)
|
82 |
elif 'model_arrow' not in st.session_state and Path(output_arrow).exists():
|
83 |
model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
|
84 |
+
print()
|
85 |
st.session_state.model_arrow = model_arrow
|
86 |
print('Model arrow loaded from local file')
|
87 |
|
|
|
96 |
torch.save(model_object.state_dict(), output_object)
|
97 |
elif 'model_object' not in st.session_state and Path(output_object).exists():
|
98 |
model_object.load_state_dict(torch.load(output_object, map_location=device))
|
99 |
+
print()
|
100 |
st.session_state.model_object = model_object
|
101 |
+
print('Model object loaded from local file\n')
|
102 |
|
103 |
|
104 |
# Move models to device
|
|
|
186 |
width = screen_width//2
|
187 |
image_placeholder.image(uploaded_image, caption='Original Image', width=width)
|
188 |
|
|
|
|
|
|
|
189 |
# Perform OCR on the uploaded image
|
190 |
ocr_results = text_prediction(uploaded_image)
|
191 |
|
192 |
# Filter and map OCR results to prediction results
|
193 |
st.session_state.text_pred = filter_text(ocr_results, threshold=0.6)
|
194 |
+
|
195 |
+
# Prediction
|
196 |
+
_, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
|
197 |
+
|
198 |
+
#Mapping text to prediction
|
199 |
st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
|
200 |
|
201 |
# Remove the original image display
|
modules/toXML.py
CHANGED
@@ -103,7 +103,7 @@ def expand_pool_bounding_boxes(modified_pred, pred, size_elements):
|
|
103 |
|
104 |
position = find_position(pool_index, modified_pred['BPMN_id'])
|
105 |
|
106 |
-
if keep_elements == []
|
107 |
min_x, min_y, max_x, max_y = modified_pred['boxes'][position]
|
108 |
else:
|
109 |
min_x, min_y, max_x, max_y = calculate_pool_bounds(modified_pred['boxes'], modified_pred['labels'], keep_elements, size_elements)
|
@@ -121,7 +121,7 @@ def adjust_pool_boundaries(modified_pred, pred):
|
|
121 |
min_left, max_right = 0, 0
|
122 |
for pool_index, element_indices in pred['pool_dict'].items():
|
123 |
position = find_position(pool_index, modified_pred['BPMN_id'])
|
124 |
-
if position >= len(modified_pred['boxes']):
|
125 |
continue
|
126 |
x1, y1, x2, y2 = modified_pred['boxes'][position]
|
127 |
left = x1
|
@@ -133,7 +133,7 @@ def adjust_pool_boundaries(modified_pred, pred):
|
|
133 |
|
134 |
for pool_index, element_indices in pred['pool_dict'].items():
|
135 |
position = find_position(pool_index, modified_pred['BPMN_id'])
|
136 |
-
if position >= len(modified_pred['boxes']):
|
137 |
continue
|
138 |
x1, y1, x2, y2 = modified_pred['boxes'][position]
|
139 |
if x1 > min_left:
|
@@ -148,9 +148,9 @@ def align_boxes(pred, size, class_dict):
|
|
148 |
pool_groups = calculate_centers_and_group_by_pool(pred, class_dict)
|
149 |
align_elements_within_pool(modified_pred, pool_groups, class_dict, size)
|
150 |
|
151 |
-
if len(pred['pool_dict']) > 1:
|
152 |
-
expand_pool_bounding_boxes(modified_pred, pred, size)
|
153 |
-
adjust_pool_boundaries(modified_pred, pred)
|
154 |
|
155 |
return modified_pred['boxes']
|
156 |
|
|
|
103 |
|
104 |
position = find_position(pool_index, modified_pred['BPMN_id'])
|
105 |
|
106 |
+
if keep_elements == [] and position is not None:
|
107 |
min_x, min_y, max_x, max_y = modified_pred['boxes'][position]
|
108 |
else:
|
109 |
min_x, min_y, max_x, max_y = calculate_pool_bounds(modified_pred['boxes'], modified_pred['labels'], keep_elements, size_elements)
|
|
|
121 |
min_left, max_right = 0, 0
|
122 |
for pool_index, element_indices in pred['pool_dict'].items():
|
123 |
position = find_position(pool_index, modified_pred['BPMN_id'])
|
124 |
+
if position is None or position >= len(modified_pred['boxes']):
|
125 |
continue
|
126 |
x1, y1, x2, y2 = modified_pred['boxes'][position]
|
127 |
left = x1
|
|
|
133 |
|
134 |
for pool_index, element_indices in pred['pool_dict'].items():
|
135 |
position = find_position(pool_index, modified_pred['BPMN_id'])
|
136 |
+
if position is None or position >= len(modified_pred['boxes']):
|
137 |
continue
|
138 |
x1, y1, x2, y2 = modified_pred['boxes'][position]
|
139 |
if x1 > min_left:
|
|
|
148 |
pool_groups = calculate_centers_and_group_by_pool(pred, class_dict)
|
149 |
align_elements_within_pool(modified_pred, pool_groups, class_dict, size)
|
150 |
|
151 |
+
#if len(pred['pool_dict']) > 1:
|
152 |
+
#expand_pool_bounding_boxes(modified_pred, pred, size)
|
153 |
+
#adjust_pool_boundaries(modified_pred, pred)
|
154 |
|
155 |
return modified_pred['boxes']
|
156 |
|
modules/utils.py
CHANGED
@@ -57,6 +57,12 @@ class_dict = {
|
|
57 |
}
|
58 |
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
def is_vertical(box):
|
61 |
"""Determine if the text in the bounding box is vertically aligned."""
|
62 |
width = box[2] - box[0]
|
|
|
57 |
}
|
58 |
|
59 |
|
60 |
+
def is_inside(box1, box2):
|
61 |
+
"""Check if the center of box1 is inside box2."""
|
62 |
+
x_center = (box1[0] + box1[2]) / 2
|
63 |
+
y_center = (box1[1] + box1[3]) / 2
|
64 |
+
return box2[0] <= x_center <= box2[2] and box2[1] <= y_center <= box2[3]
|
65 |
+
|
66 |
def is_vertical(box):
|
67 |
"""Determine if the text in the bounding box is vertically aligned."""
|
68 |
width = box[2] - box[0]
|