BenjiELCA commited on
Commit
3250939
1 Parent(s): 6339b70

better waypoint calculation

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. app.py +25 -23
  3. eval.py +4 -0
  4. htlm_webpage.py +1 -1
  5. toXML.py +41 -77
  6. utils.py +4 -2
.gitignore CHANGED
@@ -11,3 +11,5 @@ VISION_KEY.json
11
  .streamlit/secrets.toml
12
 
13
  backup/
 
 
 
11
  .streamlit/secrets.toml
12
 
13
  backup/
14
+
15
+ temp.jpg
app.py CHANGED
@@ -61,25 +61,27 @@ def create_XML(full_pred, text_mapping, scale):
61
  }
62
 
63
  size_elements = {
64
- 'start': (36, 36),
65
- 'task': (100, 80),
66
- 'message': (36, 36),
67
- 'messageEvent': (36, 36),
68
- 'end': (36, 36),
69
- 'exclusiveGateway': (50, 50),
70
- 'event': (36, 36),
71
- 'parallelGateway': (50, 50),
72
- 'sequenceFlow': (150, 10),
73
- 'pool': (250, 100),
74
- 'lane': (200, 100),
75
- 'dataObject': (40, 60),
76
- 'dataStore': (60, 60),
77
- 'subProcess': (120, 90),
78
- 'eventBasedGateway': (50, 50),
79
- 'timerEvent': (40, 40),
80
  }
81
 
82
 
 
 
83
  definitions = ET.Element('bpmn:definitions', {
84
  'xmlns:xsi': namespaces['xsi'],
85
  'xmlns:bpmn': namespaces['bpmn'],
@@ -332,13 +334,13 @@ def main():
332
  st.session_state.scale = st.slider("Set scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
333
 
334
  # Launch the prediction when the user clicks the button
335
- if st.button("Launch Prediction"):
336
- st.session_state.crop_image = cropped_image
337
- with st.spinner('Processing...'):
338
- perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
339
- st.session_state.prediction = modif_box_pos(st.session_state.prediction, object_dict)
340
- st.success('Detection completed!')
341
- print('Detection completed!')
342
 
343
 
344
  # If the prediction has been made and the user has uploaded an image, display the options for the user to annotate the image
 
61
  }
62
 
63
  size_elements = {
64
+ 'start': (43.2, 43.2),
65
+ 'task': (120, 96),
66
+ 'message': (43.2, 43.2),
67
+ 'messageEvent': (43.2, 43.2),
68
+ 'end': (43.2, 43.2),
69
+ 'exclusiveGateway': (60, 60),
70
+ 'event': (43.2, 43.2),
71
+ 'parallelGateway': (60, 60),
72
+ 'sequenceFlow': (180, 12),
73
+ 'pool': (300, 120),
74
+ 'lane': (240, 120),
75
+ 'dataObject': (48, 72),
76
+ 'dataStore': (72, 72),
77
+ 'subProcess': (144, 108),
78
+ 'eventBasedGateway': (60, 60),
79
+ 'timerEvent': (48, 48),
80
  }
81
 
82
 
83
+
84
+
85
  definitions = ET.Element('bpmn:definitions', {
86
  'xmlns:xsi': namespaces['xsi'],
87
  'xmlns:bpmn': namespaces['bpmn'],
 
334
  st.session_state.scale = st.slider("Set scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
335
 
336
  # Launch the prediction when the user clicks the button
337
+ #if st.button("Launch Prediction"):
338
+ st.session_state.crop_image = cropped_image
339
+ with st.spinner('Processing...'):
340
+ perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
341
+ st.session_state.prediction = modif_box_pos(st.session_state.prediction, object_dict)
342
+ st.success('Detection completed!')
343
+ print('Detection completed!')
344
 
345
 
346
  # If the prediction has been made and the user has uploaded an image, display the options for the user to annotate the image
eval.py CHANGED
@@ -239,6 +239,10 @@ def create_links(keypoints, boxes, labels, class_dict):
239
  if labels[i]==list(class_dict.values()).index('sequenceFlow') or labels[i]==list(class_dict.values()).index('messageFlow'):
240
  closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
241
  closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
 
 
 
 
242
  if closest1 is not None and closest2 is not None:
243
  best_points.append([point_start, point_end])
244
  links.append([closest1, closest2])
 
239
  if labels[i]==list(class_dict.values()).index('sequenceFlow') or labels[i]==list(class_dict.values()).index('messageFlow'):
240
  closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels)
241
  closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels)
242
+
243
+ print('closest1:', closest1, 'closest2:', closest2)
244
+ print('point_start:', point_start, 'point_end:', point_end)
245
+
246
  if closest1 is not None and closest2 is not None:
247
  best_points.append([point_start, point_end])
248
  links.append([closest1, closest2])
htlm_webpage.py CHANGED
@@ -65,7 +65,7 @@ def display_bpmn_xml(bpmn_xml):
65
  <div id="button-container">
66
  <button id="save-button">Save as BPMN</button>
67
  <button id="download-button">Save as XML</button>
68
- <button id="download-button">Save as Vizi</button>
69
  </div>
70
  <div id="canvas-container">
71
  <div id="canvas"></div>
 
65
  <div id="button-container">
66
  <button id="save-button">Save as BPMN</button>
67
  <button id="download-button">Save as XML</button>
68
+ <button id="download-button">Save as Vizi (not available for now)</button>
69
  </div>
70
  <div id="canvas-container">
71
  <div id="canvas"></div>
toXML.py CHANGED
@@ -187,82 +187,6 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
187
  add_diagram_elements(bpmnplane, element_id, x, y, size['timerEvent'][0], size['timerEvent'][1])
188
 
189
 
190
-
191
- # Calculate simple waypoints between two elements (this function assumes direct horizontal links for simplicity)
192
- def calculate_waypoints(data, size, source_id, target_id):
193
- source_idx = data['BPMN_id'].index(source_id)
194
- target_idx = data['BPMN_id'].index(target_id)
195
- name_source = source_id.split('_')[0]
196
- name_target = target_id.split('_')[0]
197
-
198
- #Get the position of the source and target
199
- source_x, source_y = data['boxes'][source_idx][:2]
200
- target_x, target_y = data['boxes'][target_idx][:2]
201
-
202
- # Calculate relative position between source and target from their centers
203
- relative_x = (target_x+size[name_target][0])/2 - (source_x+size[name_source][0])/2
204
- relative_y = (target_y+size[name_target][1])/2 - (source_y+size[name_source][1])/2
205
-
206
- # Get the size of the elements
207
- size_x_source = size[name_source][0]
208
- size_y_source = size[name_source][1]
209
- size_x_target = size[name_target][0]
210
- size_y_target = size[name_target][1]
211
-
212
- #if it going to right
213
- if relative_x >= size[name_source][0]:
214
- source_x += size_x_source
215
- source_y += size_y_source / 2
216
- target_x = target_x
217
- target_y += size_y_target / 2
218
- #if the source is going up
219
- if relative_y < -size[name_source][1]:
220
- source_x -= size_x_source / 2
221
- source_y -= size_y_source / 2
222
- #if the source is going down
223
- elif relative_y > size[name_source][1]:
224
- source_x -= size_x_source / 2
225
- source_y += size_y_source / 2
226
- #if it going to left
227
- elif relative_x < -size[name_source][0]:
228
- source_x = source_x
229
- source_y += size_y_source / 2
230
- target_x += size_x_target
231
- target_y += size_y_target / 2
232
- #if the source is going up
233
- if relative_y < -size[name_source][1]:
234
- source_x += size_x_source / 2
235
- source_y -= size_y_source / 2
236
- #if the source is going down
237
- elif relative_y > size[name_source][1]:
238
- source_x += size_x_source / 2
239
- source_y += size_y_source / 2
240
- #if it going up and down
241
- elif -size[name_source][0] < relative_x < size[name_source][0]:
242
- source_x += size_x_source / 2
243
- target_x += size_x_target / 2
244
- #if it's going down
245
- if relative_y >= size[name_source][1]/2:
246
- source_y += size_y_source
247
- #if it's going up
248
- elif relative_y < -size[name_source][1]/2:
249
- source_y = source_y
250
- target_y += size_y_target
251
- else:
252
- if relative_x >= 0:
253
- source_x += size_x_source/2
254
- source_y += size_y_source/2
255
- target_x -= size_x_target/2
256
- target_y += size_y_target/2
257
- else:
258
- source_x -= size_x_source/2
259
- source_y += size_y_source/2
260
- target_x += size_x_target/2
261
- target_y += size_y_target/2
262
-
263
- return [(source_x, source_y), (target_x, target_y)]
264
-
265
-
266
  def calculate_pool_bounds(data, keep_elements, size):
267
  min_x = min_y = float('10000')
268
  max_x = max_y = float('0')
@@ -321,7 +245,47 @@ def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_ele
321
 
322
  return waypoints
323
 
 
 
 
 
324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
  def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
327
  source_idx, target_idx = data['links'][idx]
@@ -339,7 +303,7 @@ def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=Fal
339
  if target_id.split('_')[0] == 'pool':
340
  target_id = f"participant_{target_id.split('_')[1]}"
341
  else:
342
- waypoints = calculate_waypoints(data, size, source_id, target_id)
343
  #waypoints = data['best_points'][idx]
344
 
345
  #waypoints = data['best_points'][idx]
 
187
  add_diagram_elements(bpmnplane, element_id, x, y, size['timerEvent'][0], size['timerEvent'][1])
188
 
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  def calculate_pool_bounds(data, keep_elements, size):
191
  min_x = min_y = float('10000')
192
  max_x = max_y = float('0')
 
245
 
246
  return waypoints
247
 
248
+ def calculate_waypoints(data, size, current_idx, source_id, target_id):
249
+ best_points = data['best_points'][current_idx]
250
+ pos_source = best_points[0]
251
+ pos_target = best_points[1]
252
 
253
+ source_idx = data['BPMN_id'].index(source_id)
254
+ target_idx = data['BPMN_id'].index(target_id)
255
+ name_source = source_id.split('_')[0]
256
+ name_target = target_id.split('_')[0]
257
+
258
+ #Get the position of the source and target
259
+ source_x, source_y = data['boxes'][source_idx][:2]
260
+ target_x, target_y = data['boxes'][target_idx][:2]
261
+
262
+ if pos_source == 'left':
263
+ source_x = source_x
264
+ source_y += size[name_source][1]/2
265
+ elif pos_source == 'right':
266
+ source_x += size[name_source][0]
267
+ source_y += size[name_source][1]/2
268
+ elif pos_source == 'top':
269
+ source_x += size[name_source][0]/2
270
+ source_y = source_y
271
+ elif pos_source == 'bottom':
272
+ source_x += size[name_source][0]/2
273
+ source_y += size[name_source][1]
274
+
275
+ if pos_target == 'left':
276
+ target_x = target_x
277
+ target_y += size[name_target][1]/2
278
+ elif pos_target == 'right':
279
+ target_x += size[name_target][0]
280
+ target_y += size[name_target][1]/2
281
+ elif pos_target == 'top':
282
+ target_x += size[name_target][0]/2
283
+ target_y = target_y
284
+ elif pos_target == 'bottom':
285
+ target_x += size[name_target][0]/2
286
+ target_y += size[name_target][1]
287
+
288
+ return [(source_x, source_y), (target_x, target_y)]
289
 
290
  def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=False):
291
  source_idx, target_idx = data['links'][idx]
 
303
  if target_id.split('_')[0] == 'pool':
304
  target_id = f"participant_{target_id.split('_')[1]}"
305
  else:
306
+ waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
307
  #waypoints = data['best_points'][idx]
308
 
309
  #waypoints = data['best_points'][idx]
utils.py CHANGED
@@ -923,14 +923,16 @@ def find_closest_object(keypoint, boxes, labels):
923
  right = (x2, (y1+y2)/2)
924
  points = [left, top , right, bottom]
925
 
 
 
926
  # Calculate the distance between the keypoint and the center of the bounding box
927
- for point in points:
928
  distance = np.linalg.norm(keypoint[:2] - point)
929
  # Update the closest object index if this object is closer
930
  if distance < min_distance:
931
  min_distance = distance
932
  closest_object_idx = i
933
- best_point = point
934
 
935
  return closest_object_idx, best_point
936
 
 
923
  right = (x2, (y1+y2)/2)
924
  points = [left, top , right, bottom]
925
 
926
+ pos_dict = {0:'left', 1:'top', 2:'right', 3:'bottom'}
927
+
928
  # Calculate the distance between the keypoint and the center of the bounding box
929
+ for pos, (point) in enumerate(points):
930
  distance = np.linalg.norm(keypoint[:2] - point)
931
  # Update the closest object index if this object is closer
932
  if distance < min_distance:
933
  min_distance = distance
934
  closest_object_idx = i
935
+ best_point = pos_dict[pos]
936
 
937
  return closest_object_idx, best_point
938