BenjiELCA committed on
Commit 9467fbe
1 Parent(s): 91857b0

change is inside

modules/OCR.py CHANGED
@@ -8,9 +8,8 @@ import numpy as np
 import networkx as nx
 from modules.utils import class_dict, proportion_inside
 import json
-from modules.utils import rescale_boxes as rescale
+from modules.utils import rescale_boxes as rescale, is_vertical
 import streamlit as st
-from modules.utils import is_vertical
 
 VISION_KEY = os.getenv("VISION_KEY")
 VISION_ENDPOINT = os.getenv("VISION_ENDPOINT")
@@ -133,13 +132,6 @@ def min_distance_between_boxes(box1, box2):
                 min_dist = dist
     return min_dist
 
-
-def is_inside(box1, box2):
-    """Check if the center of box1 is inside box2."""
-    x_center = (box1[0] + box1[2]) / 2
-    y_center = (box1[1] + box1[3]) / 2
-    return box2[0] <= x_center <= box2[2] and box2[1] <= y_center <= box2[3]
-
 def are_close(box1, box2, threshold=50):
     """Determines if boxes are close based on their corners and center points."""
     corners1 = np.array([
@@ -307,12 +299,6 @@ def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, p
 
 def mapping_text(full_pred, text_pred, print_sentences=False,percentage_thresh=0.6,scale=1.0, iou_threshold=0.5):
 
-    ########### REDO THIS FUNCTION ###########
-    # Redo the function so that it first takes the elements that are inside the tasks, then applies a distance threshold for the other elements
-    # or alternatively compute the distance between all elements, not only the tasks
-
-
-    # Example usage
     boxes = rescale(scale, full_pred['boxes'])
 
     min_dist = 200
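
The two single-line imports from modules.utils are folded into one statement, and the module-level is_inside helper is removed here (it moves to modules/utils.py further down). For orientation, a minimal sketch of the kind of box rescaling the rescale alias is assumed to perform on [x1, y1, x2, y2] boxes; the actual rescale_boxes in modules/utils.py may differ.

import numpy as np

def rescale_boxes_sketch(scale, boxes):
    """Hypothetical stand-in: scale every corner coordinate by the same factor."""
    return np.asarray(boxes, dtype=float) * scale

# Boxes predicted on a half-resolution image, mapped back to full resolution.
boxes = [[10, 20, 50, 80], [100, 40, 160, 90]]
print(rescale_boxes_sketch(2.0, boxes))
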
modules/streamlit_utils.py CHANGED
@@ -81,6 +81,7 @@ def load_models():
         torch.save(model_arrow.state_dict(), output_arrow)
     elif 'model_arrow' not in st.session_state and Path(output_arrow).exists():
         model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
+        print()
         st.session_state.model_arrow = model_arrow
         print('Model arrow loaded from local file')
 
@@ -95,8 +96,9 @@ def load_models():
         torch.save(model_object.state_dict(), output_object)
     elif 'model_object' not in st.session_state and Path(output_object).exists():
         model_object.load_state_dict(torch.load(output_object, map_location=device))
+        print()
         st.session_state.model_object = model_object
-        print('Model object loaded from local file')
+        print('Model object loaded from local file\n')
 
 
     # Move models to device
@@ -184,14 +186,16 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
         width = screen_width//2
         image_placeholder.image(uploaded_image, caption='Original Image', width=width)
 
-        # Prediction
-        _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
-
         # Perform OCR on the uploaded image
        ocr_results = text_prediction(uploaded_image)
 
         # Filter and map OCR results to prediction results
         st.session_state.text_pred = filter_text(ocr_results, threshold=0.6)
+
+        # Prediction
+        _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
+
+        # Mapping text to prediction
         st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
 
         # Remove the original image display
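
The object/arrow prediction now runs after OCR and text filtering instead of before them, so mapping_text consumes a prediction produced in the same pass as the filtered text. A minimal sketch of the resulting order of operations; the pipeline functions are passed in as parameters rather than imported, since their module paths are not shown in this diff, and the threshold values are placeholders.

def run_pipeline_sketch(text_prediction, filter_text, full_prediction, mapping_text,
                        model_object, model_arrow, uploaded_image, img_tensor):
    """Hypothetical outline of the reordered perform_inference steps."""
    # 1. OCR on the uploaded image.
    ocr_results = text_prediction(uploaded_image)

    # 2. Keep only confident text detections.
    text_pred = filter_text(ocr_results, threshold=0.6)

    # 3. Detect BPMN objects and arrows (this used to run before the OCR step).
    _, prediction = full_prediction(model_object, model_arrow, img_tensor,
                                    score_threshold=0.5, iou_threshold=0.5,
                                    distance_treshold=30)

    # 4. Attach the filtered text to the detected elements.
    return mapping_text(prediction, text_pred, print_sentences=False,
                        percentage_thresh=0.6)
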
modules/toXML.py CHANGED
@@ -103,7 +103,7 @@ def expand_pool_bounding_boxes(modified_pred, pred, size_elements):
 
         position = find_position(pool_index, modified_pred['BPMN_id'])
 
-        if keep_elements == [] or position is None:
+        if keep_elements == [] and position is not None:
             min_x, min_y, max_x, max_y = modified_pred['boxes'][position]
         else:
             min_x, min_y, max_x, max_y = calculate_pool_bounds(modified_pred['boxes'], modified_pred['labels'], keep_elements, size_elements)
@@ -121,7 +121,7 @@ def adjust_pool_boundaries(modified_pred, pred):
     min_left, max_right = 0, 0
     for pool_index, element_indices in pred['pool_dict'].items():
         position = find_position(pool_index, modified_pred['BPMN_id'])
-        if position >= len(modified_pred['boxes']):
+        if position is None or position >= len(modified_pred['boxes']):
             continue
         x1, y1, x2, y2 = modified_pred['boxes'][position]
         left = x1
@@ -133,7 +133,7 @@ def adjust_pool_boundaries(modified_pred, pred):
 
     for pool_index, element_indices in pred['pool_dict'].items():
         position = find_position(pool_index, modified_pred['BPMN_id'])
-        if position >= len(modified_pred['boxes']):
+        if position is None or position >= len(modified_pred['boxes']):
             continue
         x1, y1, x2, y2 = modified_pred['boxes'][position]
         if x1 > min_left:
@@ -148,9 +148,9 @@ def align_boxes(pred, size, class_dict):
     pool_groups = calculate_centers_and_group_by_pool(pred, class_dict)
     align_elements_within_pool(modified_pred, pool_groups, class_dict, size)
 
-    if len(pred['pool_dict']) > 1:
-        expand_pool_bounding_boxes(modified_pred, pred, size)
-        adjust_pool_boundaries(modified_pred, pred)
+    #if len(pred['pool_dict']) > 1:
+        #expand_pool_bounding_boxes(modified_pred, pred, size)
+        #adjust_pool_boundaries(modified_pred, pred)
 
     return modified_pred['boxes']
 
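The new guards matter because find_position apparently returns None when a pool id is not present in modified_pred['BPMN_id'], and in Python 3 evaluating None >= len(...) raises a TypeError instead of skipping the entry. A small self-contained sketch of the guard pattern; find_position_sketch is a hypothetical stand-in for the repo's find_position.

def find_position_sketch(pool_id, bpmn_ids):
    """Hypothetical stand-in: index of pool_id in bpmn_ids, or None if absent."""
    return bpmn_ids.index(pool_id) if pool_id in bpmn_ids else None

boxes = [[0, 0, 10, 10], [5, 5, 20, 20]]
bpmn_ids = ['pool_1', 'task_1']

for pool_id in ['pool_1', 'pool_missing']:
    position = find_position_sketch(pool_id, bpmn_ids)
    # Checking `position is None` first short-circuits the comparison and avoids
    # "TypeError: '>=' not supported between instances of 'NoneType' and 'int'".
    if position is None or position >= len(boxes):
        continue
    print(pool_id, boxes[position])
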
modules/utils.py CHANGED
@@ -57,6 +57,12 @@ class_dict = {
 }
 
 
+def is_inside(box1, box2):
+    """Check if the center of box1 is inside box2."""
+    x_center = (box1[0] + box1[2]) / 2
+    y_center = (box1[1] + box1[3]) / 2
+    return box2[0] <= x_center <= box2[2] and box2[1] <= y_center <= box2[3]
+
 def is_vertical(box):
     """Determine if the text in the bounding box is vertically aligned."""
     width = box[2] - box[0]
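
With is_inside now living next to the other box helpers in modules/utils.py, it can be imported alongside proportion_inside, rescale_boxes, and is_vertical. A quick usage sketch with made-up [x1, y1, x2, y2] boxes:

def is_inside(box1, box2):
    """Check if the center of box1 is inside box2."""
    x_center = (box1[0] + box1[2]) / 2
    y_center = (box1[1] + box1[3]) / 2
    return box2[0] <= x_center <= box2[2] and box2[1] <= y_center <= box2[3]

label_box = [12, 14, 28, 22]   # center (20, 18)
task_box = [10, 10, 60, 40]    # center (35, 25)
print(is_inside(label_box, task_box))  # True: the label's center falls inside the task box
print(is_inside(task_box, label_box))  # False: the task's center lies outside the label box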