BenjiELCA commited on
Commit
9134c9f
·
1 Parent(s): ba0e747

correction of bug with dataAssociation

Browse files
Files changed (6) hide show
  1. app.py +15 -15
  2. images/example4.jpg +0 -0
  3. modules/OCR.py +1 -6
  4. modules/eval.py +15 -6
  5. modules/toXML.py +115 -42
  6. modules/utils.py +7 -0
app.py CHANGED
@@ -249,7 +249,7 @@ def perform_inference(model_object, model_arrow, image, score_threshold):
249
  image_placeholder.image(uploaded_image, caption='Original Image', width=1000)
250
 
251
  # Prediction
252
- _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=0.5)
253
 
254
  # Perform OCR on the uploaded image
255
  ocr_results = text_prediction(uploaded_image)
@@ -311,21 +311,21 @@ def main():
311
 
312
  model_arrow = st.session_state.model_arrow
313
  model_object = st.session_state.model_object
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  #Create the layout for the app
316
  col1, col2 = st.columns(2)
317
  with col1:
318
- with st.expander("Use example images"):
319
- img_selected = image_select("If you have no image and just want to test the demo, click on one of these images", ["./images/none.jpg", "./images/example1.jpg", "./images/example2.jpg", "./images/example3.jpg"],
320
- captions=["None", "Example 1", "Example 2", "Example 3"], index=0, use_container_width=False, return_value="original")
321
-
322
- if img_selected== './images/none.jpg':
323
- print('No example image selected')
324
- #delete the prediction
325
- if 'prediction' in st.session_state:
326
- del st.session_state['prediction']
327
- img_selected = None
328
-
329
  # Create a file uploader for the user to upload an image
330
  if img_selected is not None:
331
  uploaded_file = img_selected
@@ -359,10 +359,10 @@ def main():
359
  perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
360
  #st.session_state.prediction = modif_box_pos(st.session_state.prediction, object_dict)
361
  st.balloons()
362
- else:
363
  #delete the prediction
364
- if 'prediction' in st.session_state:
365
- del st.session_state['prediction']
366
 
367
 
368
  # If the prediction has been made and the user has uploaded an image, display the options for the user to annotate the image
 
249
  image_placeholder.image(uploaded_image, caption='Original Image', width=1000)
250
 
251
  # Prediction
252
+ _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=0.5, distance_treshold=30)
253
 
254
  # Perform OCR on the uploaded image
255
  ocr_results = text_prediction(uploaded_image)
 
311
 
312
  model_arrow = st.session_state.model_arrow
313
  model_object = st.session_state.model_object
314
+
315
+ with st.expander("Use example images"):
316
+ img_selected = image_select("If you have no image and just want to test the demo, click on one of these images", ["./images/none.jpg", "./images/example1.jpg", "./images/example2.jpg", "./images/example3.jpg", "./images/example4.jpg"],
317
+ captions=["None", "Example 1", "Example 2", "Example 3", "Example 4"], index=0, use_container_width=False, return_value="original")
318
+
319
+ if img_selected== './images/none.jpg':
320
+ print('No example image selected')
321
+ #delete the prediction
322
+ #if 'prediction' in st.session_state:
323
+ #del st.session_state['prediction']
324
+ img_selected = None
325
 
326
  #Create the layout for the app
327
  col1, col2 = st.columns(2)
328
  with col1:
 
 
 
 
 
 
 
 
 
 
 
329
  # Create a file uploader for the user to upload an image
330
  if img_selected is not None:
331
  uploaded_file = img_selected
 
359
  perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold)
360
  #st.session_state.prediction = modif_box_pos(st.session_state.prediction, object_dict)
361
  st.balloons()
362
+ #else:
363
  #delete the prediction
364
+ #if 'prediction' in st.session_state:
365
+ #del st.session_state['prediction']
366
 
367
 
368
  # If the prediction has been made and the user has uploaded an image, display the options for the user to annotate the image
images/example4.jpg CHANGED
modules/OCR.py CHANGED
@@ -10,7 +10,7 @@ from modules.utils import class_dict, proportion_inside
10
  import json
11
  from modules.utils import rescale_boxes as rescale
12
  import streamlit as st
13
-
14
 
15
  VISION_KEY = os.getenv("VISION_KEY")
16
  VISION_ENDPOINT = os.getenv("VISION_ENDPOINT")
@@ -188,11 +188,6 @@ def find_closest_box(text_box, all_boxes, labels, threshold, iou_threshold=0.5):
188
  return None
189
 
190
 
191
- def is_vertical(box):
192
- """Determine if the text in the bounding box is vertically aligned."""
193
- width = box[2] - box[0]
194
- height = box[3] - box[1]
195
- return (height > 2*width)
196
 
197
  def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, percentage_thresh=0.8):
198
  """Maps text boxes to task boxes and groups texts within each task based on proximity."""
 
10
  import json
11
  from modules.utils import rescale_boxes as rescale
12
  import streamlit as st
13
+ from modules.utils import is_vertical
14
 
15
  VISION_KEY = os.getenv("VISION_KEY")
16
  VISION_ENDPOINT = os.getenv("VISION_ENDPOINT")
 
188
  return None
189
 
190
 
 
 
 
 
 
191
 
192
  def group_texts(task_boxes, text_boxes, texts, min_dist=50, iou_threshold=0.8, percentage_thresh=0.8):
193
  """Maps text boxes to task boxes and groups texts within each task based on proximity."""
modules/eval.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  from modules.utils import class_dict, object_dict, arrow_dict, find_closest_object, find_other_keypoint, filter_overlap_boxes, iou
4
  from tqdm import tqdm
5
  from modules.toXML import create_BPMN_id
 
6
 
7
 
8
 
@@ -75,7 +76,7 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
75
  for i in range(len(labels)):
76
  if labels[i] != list(object_dict.values()).index('task'):
77
  continue
78
- if boxes[i][2]-boxes[i][0] < boxes[i][3]-boxes[i][1]:
79
  vertical += 1
80
  horizontal = len(labels) - vertical
81
  for i in range(len(labels)):
@@ -83,12 +84,12 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
83
  continue
84
 
85
  if vertical < horizontal:
86
- if boxes[i][2]-boxes[i][0] < boxes[i][3]-boxes[i][1]:
87
  #find the element in the list and remove it
88
  if i in selected_boxes:
89
  selected_boxes.remove(i)
90
  elif vertical > horizontal:
91
- if boxes[i][2]-boxes[i][0] > boxes[i][3]-boxes[i][1]:
92
  #find the element in the list and remove it
93
  if i in selected_boxes:
94
  selected_boxes.remove(i)
@@ -261,19 +262,27 @@ def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
261
 
262
  for pool_index, elements in pool_dict.items():
263
  print(f"Pool {pool_index} contains elements: {elements}")
264
- #check if each link is in the same pool
265
  for i in range(len(flow_links)):
266
  if labels[i] == list(class_dict.values()).index('sequenceFlow'):
267
  id1, id2 = flow_links[i]
268
- if (id1 and id2) is not None:
 
269
  if id1 in elements and id2 in elements:
270
- continue
 
 
 
 
 
271
  elif id1 not in elements and id2 not in elements:
272
  continue
273
  else:
274
  print('change the link from sequenceFlow to messageFlow')
275
  labels[i]=list(class_dict.values()).index('messageFlow')
276
 
 
 
277
  for i in range(len(labels)):
278
  #check if dataAssociation is connected to a dataObject
279
  if labels[i] == list(class_dict.values()).index('dataAssociation'):
 
3
  from modules.utils import class_dict, object_dict, arrow_dict, find_closest_object, find_other_keypoint, filter_overlap_boxes, iou
4
  from tqdm import tqdm
5
  from modules.toXML import create_BPMN_id
6
+ from modules.utils import is_vertical
7
 
8
 
9
 
 
76
  for i in range(len(labels)):
77
  if labels[i] != list(object_dict.values()).index('task'):
78
  continue
79
+ if is_vertical(boxes[i]):
80
  vertical += 1
81
  horizontal = len(labels) - vertical
82
  for i in range(len(labels)):
 
84
  continue
85
 
86
  if vertical < horizontal:
87
+ if is_vertical(boxes[i]):
88
  #find the element in the list and remove it
89
  if i in selected_boxes:
90
  selected_boxes.remove(i)
91
  elif vertical > horizontal:
92
+ if is_vertical(boxes[i]) == False:
93
  #find the element in the list and remove it
94
  if i in selected_boxes:
95
  selected_boxes.remove(i)
 
262
 
263
  for pool_index, elements in pool_dict.items():
264
  print(f"Pool {pool_index} contains elements: {elements}")
265
+ #check if the label sequenceflow is good
266
  for i in range(len(flow_links)):
267
  if labels[i] == list(class_dict.values()).index('sequenceFlow'):
268
  id1, id2 = flow_links[i]
269
+ if id1 is not None and id2 is not None:
270
+ #check if each link is in the same pool
271
  if id1 in elements and id2 in elements:
272
+ #check if the link is between a dataObject or a dataStore
273
+ if labels[id1] == 8 or labels[id2] == 8 or labels[id1] == 9 or labels[id2] == 9:
274
+ print('change the link from sequenceFlow to dataAssociation')
275
+ labels[i]=list(class_dict.values()).index('dataAssociation')
276
+ else:
277
+ continue
278
  elif id1 not in elements and id2 not in elements:
279
  continue
280
  else:
281
  print('change the link from sequenceFlow to messageFlow')
282
  labels[i]=list(class_dict.values()).index('messageFlow')
283
 
284
+
285
+
286
  for i in range(len(labels)):
287
  #check if dataAssociation is connected to a dataObject
288
  if labels[i] == list(class_dict.values()).index('dataAssociation'):
modules/toXML.py CHANGED
@@ -92,18 +92,35 @@ def check_status(link, keep_elements):
92
  return 'middle'
93
 
94
  def check_data_association(i, links, labels, keep_elements):
 
95
  for j, (k,l) in enumerate(links):
96
- if labels[j] == 14:
 
97
  if k==i:
98
- return 'output',j
 
99
  elif l==i:
100
- return 'input',j
101
-
102
- return 'no association', None
103
 
104
- def create_data_Association(bpmn,data,size,element_id,source_id,target_id):
105
- waypoints = calculate_waypoints(data, size, source_id, target_id)
 
 
106
  add_diagram_edge(bpmn, element_id, waypoints)
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  # Function to dynamically create and layout BPMN elements
109
  def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data, keep_elements):
@@ -111,6 +128,8 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
111
  positions = data['boxes']
112
  links = data['links']
113
 
 
 
114
  for i in keep_elements:
115
  element_id = elements[i]
116
  if element_id is None:
@@ -127,25 +146,27 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
127
  # Task
128
  elif element_type == 'task':
129
  element = ET.SubElement(process, 'bpmn:task', id=element_id, name=text_mapping[element_id])
130
- status, dataAssociation_idx = check_data_association(i, data['links'], data['labels'], keep_elements)
131
-
132
- # Handle Data Input Association
133
- if status == 'input':
134
- dataObject_idx = links[dataAssociation_idx][0]
135
- dataObject_name = elements[dataObject_idx]
136
- dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
137
- sub_element = ET.SubElement(element, 'bpmn:dataInputAssociation', id=f'dataInputAssociation_{dataObject_ref.split("_")[1]}')
138
- ET.SubElement(sub_element, 'bpmn:sourceRef').text = dataObject_ref
139
- create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], dataObject_name, element_id)
140
-
141
- # Handle Data Output Association
142
- elif status == 'output':
143
- dataObject_idx = links[dataAssociation_idx][1]
144
- dataObject_name = elements[dataObject_idx]
145
- dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
146
- sub_element = ET.SubElement(element, 'bpmn:dataOutputAssociation', id=f'dataOutputAssociation_{dataObject_ref.split("_")[1]}')
147
- ET.SubElement(sub_element, 'bpmn:targetRef').text = dataObject_ref
148
- create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], element_id, dataObject_name)
 
 
149
 
150
  add_diagram_elements(bpmnplane, element_id, x, y, size['task'][0], size['task'][1])
151
 
@@ -158,6 +179,30 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
158
  element = ET.SubElement(process, 'bpmn:intermediateCatchEvent', id=element_id, name=text_mapping[element_id])
159
  elif status == 'end':
160
  element = ET.SubElement(process, 'bpmn:endEvent', id=element_id, name=text_mapping[element_id])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  ET.SubElement(element, 'bpmn:messageEventDefinition', id=f'MessageEventDefinition_{i+1}')
162
  add_diagram_elements(bpmnplane, element_id, x, y, size['message'][0], size['message'][1])
163
 
@@ -172,13 +217,36 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
172
  element = ET.SubElement(process, f'bpmn:{gateway_type}', id=element_id)
173
  add_diagram_elements(bpmnplane, element_id, x, y, size[element_type][0], size[element_type][1])
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  # Data Object
176
- elif element_type == 'dataObject':
177
  dataObject_idx = element_id.split('_')[1]
178
  dataObject_ref = f'DataObjectReference_{dataObject_idx}'
179
  element = ET.SubElement(process, 'bpmn:dataObjectReference', id=dataObject_ref, dataObjectRef=element_id, name=text_mapping[element_id])
180
- ET.SubElement(process, 'bpmn:dataObject', id=element_id)
181
- add_diagram_elements(bpmnplane, dataObject_ref, x, y, size['dataObject'][0], size['dataObject'][1])
182
 
183
  # Timer Event
184
  elif element_type == 'timerEvent':
@@ -222,6 +290,8 @@ def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_ele
222
 
223
  # Check if the connection involves a pool
224
  if source_element == 'pool':
 
 
225
  pool_box = source_box
226
  element_box = (target_box[0], target_box[1], target_box[0]+size[target_element][0], target_box[1]+size[target_element][1])
227
  element_mid_x = (element_box[0] + element_box[2]) / 2
@@ -295,21 +365,24 @@ def create_flow_element(bpmn, text_mapping, idx, size, data, parent, message=Fal
295
  else:
296
  element_id = f'sequenceflow_{source_id}_{target_id}'
297
 
298
- if source_id.split('_')[0] == 'pool' or target_id.split('_')[0] == 'pool':
299
- waypoints = calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_id.split('_')[0], target_id.split('_')[0])
300
- #waypoints = data['best_points'][idx]
301
- if source_id.split('_')[0] == 'pool':
302
- source_id = f"participant_{source_id.split('_')[1]}"
303
- if target_id.split('_')[0] == 'pool':
304
- target_id = f"participant_{target_id.split('_')[1]}"
305
- else:
306
- waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
307
- #waypoints = data['best_points'][idx]
308
-
309
- #waypoints = data['best_points'][idx]
310
  if message:
311
- element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
 
 
 
 
 
 
 
 
 
 
 
 
312
  else:
 
313
  element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
314
  add_diagram_edge(bpmn, element_id, waypoints)
 
 
315
 
 
92
  return 'middle'
93
 
94
  def check_data_association(i, links, labels, keep_elements):
95
+ status, links_idx = [], []
96
  for j, (k,l) in enumerate(links):
97
+ if labels[j] == list(class_dict.values()).index('dataAssociation'):
98
+ print('i,j,k,l',i,j,k,l)
99
  if k==i:
100
+ status.append('output')
101
+ links_idx.append(j)
102
  elif l==i:
103
+ status.append('input')
104
+ links_idx.append(j)
 
105
 
106
+ return status, links_idx
107
+
108
+ def create_data_Association(bpmn,data,size,element_id,current_idx,source_id,target_id):
109
+ waypoints = calculate_waypoints(data, size, current_idx, source_id, target_id)
110
  add_diagram_edge(bpmn, element_id, waypoints)
111
+
112
+ def check_eventBasedGateway(i, links, labels):
113
+ status, links_idx = [], []
114
+ for j, (k,l) in enumerate(links):
115
+ if labels[j] == list(class_dict.values()).index('sequenceFlow'):
116
+ if k==i:
117
+ status.append('output')
118
+ links_idx.append(j)
119
+ elif l==i:
120
+ status.append('input')
121
+ links_idx.append(j)
122
+
123
+ return status, links_idx
124
 
125
  # Function to dynamically create and layout BPMN elements
126
  def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data, keep_elements):
 
128
  positions = data['boxes']
129
  links = data['links']
130
 
131
+ #print(links)
132
+
133
  for i in keep_elements:
134
  element_id = elements[i]
135
  if element_id is None:
 
146
  # Task
147
  elif element_type == 'task':
148
  element = ET.SubElement(process, 'bpmn:task', id=element_id, name=text_mapping[element_id])
149
+ status, datasAssociation_idx = check_data_association(i, data['links'], data['labels'], keep_elements)
150
+
151
+ if len(status) != 0:
152
+ for state, dataAssociation_idx in zip(status, datasAssociation_idx):
153
+ # Handle Data Input Association
154
+ if state == 'input':
155
+ dataObject_idx = links[dataAssociation_idx][0]
156
+ dataObject_name = elements[dataObject_idx]
157
+ dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
158
+ sub_element = ET.SubElement(element, 'bpmn:dataInputAssociation', id=f'dataInputAssociation_{dataObject_ref.split("_")[1]}')
159
+ ET.SubElement(sub_element, 'bpmn:sourceRef').text = dataObject_ref
160
+ create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], dataAssociation_idx, dataObject_name, element_id)
161
+
162
+ # Handle Data Output Association
163
+ elif state == 'output':
164
+ dataObject_idx = links[dataAssociation_idx][1]
165
+ dataObject_name = elements[dataObject_idx]
166
+ dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
167
+ sub_element = ET.SubElement(element, 'bpmn:dataOutputAssociation', id=f'dataOutputAssociation_{dataObject_ref.split("_")[1]}')
168
+ ET.SubElement(sub_element, 'bpmn:targetRef').text = dataObject_ref
169
+ create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], dataAssociation_idx, element_id, dataObject_name)
170
 
171
  add_diagram_elements(bpmnplane, element_id, x, y, size['task'][0], size['task'][1])
172
 
 
179
  element = ET.SubElement(process, 'bpmn:intermediateCatchEvent', id=element_id, name=text_mapping[element_id])
180
  elif status == 'end':
181
  element = ET.SubElement(process, 'bpmn:endEvent', id=element_id, name=text_mapping[element_id])
182
+
183
+ status, datasAssociation_idx = check_data_association(i, data['links'], data['labels'], keep_elements)
184
+ print('status',status)
185
+ print('datasAssociation_idx',datasAssociation_idx)
186
+ if len(status) != 0:
187
+ for state, dataAssociation_idx in zip(status, datasAssociation_idx):
188
+ # Handle Data Input Association
189
+ if state == 'input':
190
+ dataObject_idx = links[dataAssociation_idx][0]
191
+ dataObject_name = elements[dataObject_idx]
192
+ dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
193
+ sub_element = ET.SubElement(element, 'bpmn:dataInputAssociation', id=f'dataInputAssociation_{dataObject_ref.split("_")[1]}')
194
+ ET.SubElement(sub_element, 'bpmn:sourceRef').text = dataObject_ref
195
+ create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], dataAssociation_idx, dataObject_name, element_id)
196
+
197
+ # Handle Data Output Association
198
+ elif state == 'output':
199
+ dataObject_idx = links[dataAssociation_idx][1]
200
+ dataObject_name = elements[dataObject_idx]
201
+ dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
202
+ sub_element = ET.SubElement(element, 'bpmn:dataOutputAssociation', id=f'dataOutputAssociation_{dataObject_ref.split("_")[1]}')
203
+ ET.SubElement(sub_element, 'bpmn:targetRef').text = dataObject_ref
204
+ create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], dataAssociation_idx, element_id, dataObject_name)
205
+
206
  ET.SubElement(element, 'bpmn:messageEventDefinition', id=f'MessageEventDefinition_{i+1}')
207
  add_diagram_elements(bpmnplane, element_id, x, y, size['message'][0], size['message'][1])
208
 
 
217
  element = ET.SubElement(process, f'bpmn:{gateway_type}', id=element_id)
218
  add_diagram_elements(bpmnplane, element_id, x, y, size[element_type][0], size[element_type][1])
219
 
220
+ elif element_type == 'eventBasedGateway':
221
+ element = ET.SubElement(process, 'bpmn:eventBasedGateway', id=element_id)
222
+ status, links_idx = check_eventBasedGateway(i, data['links'], data['labels'])
223
+
224
+ if len(status) != 0:
225
+ for state, link_idx in zip(status, links_idx):
226
+ # Handle Data Input Association
227
+ if state == 'input' :
228
+ gateway_idx = links[link_idx][0]
229
+ gateway_name = elements[gateway_idx]
230
+ sub_element = ET.SubElement(element, 'bpmn:eventBasedGateway', id=f'eventBasedGateway{gateway_name.split("_")[1]}')
231
+ create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], i, gateway_name, element_id)
232
+
233
+ # Handle Data Output Association
234
+ elif state == 'output':
235
+ gateway_idx = links[link_idx][1]
236
+ gateway_name = elements[gateway_idx]
237
+ sub_element = ET.SubElement(element, 'bpmn:eventBasedGateway', id=f'eventBasedGateway{gateway_name.split("_")[1]}')
238
+ create_data_Association(bpmnplane, data, size, sub_element.attrib['id'], i, element_id, gateway_name)
239
+
240
+
241
+ add_diagram_elements(bpmnplane, element_id, x, y, size['eventBasedGateway'][0], size['eventBasedGateway'][1])
242
+
243
  # Data Object
244
+ elif element_type == 'dataObject' or element_type == 'dataStore':
245
  dataObject_idx = element_id.split('_')[1]
246
  dataObject_ref = f'DataObjectReference_{dataObject_idx}'
247
  element = ET.SubElement(process, 'bpmn:dataObjectReference', id=dataObject_ref, dataObjectRef=element_id, name=text_mapping[element_id])
248
+ ET.SubElement(process, f'bpmn:{element_type}', id=element_id)
249
+ add_diagram_elements(bpmnplane, dataObject_ref, x, y, size[element_type][0], size[element_type][1])
250
 
251
  # Timer Event
252
  elif element_type == 'timerEvent':
 
290
 
291
  # Check if the connection involves a pool
292
  if source_element == 'pool':
293
+ if target_element == 'pool':
294
+ return [(source_mid_x, source_mid_y), (source_mid_x, source_mid_y)]
295
  pool_box = source_box
296
  element_box = (target_box[0], target_box[1], target_box[0]+size[target_element][0], target_box[1]+size[target_element][1])
297
  element_mid_x = (element_box[0] + element_box[2]) / 2
 
365
  else:
366
  element_id = f'sequenceflow_{source_id}_{target_id}'
367
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  if message:
369
+ if source_id.split('_')[0] == 'pool' or target_id.split('_')[0] == 'pool':
370
+ waypoints = calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_id.split('_')[0], target_id.split('_')[0])
371
+ if source_id.split('_')[0] == 'pool':
372
+ XML_source_id = f"participant_{source_id.split('_')[1]}"
373
+ XML_target_id = target_id
374
+ if target_id.split('_')[0] == 'pool':
375
+ XML_target_id = f"participant_{target_id.split('_')[1]}"
376
+ XML_source_id = source_id
377
+
378
+ element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=XML_source_id, targetRef=XML_target_id, name=text_mapping[data['BPMN_id'][idx]])
379
+ else:
380
+ waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
381
+ element = ET.SubElement(parent, 'bpmn:messageFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
382
  else:
383
+ waypoints = calculate_waypoints(data, size, idx, source_id, target_id)
384
  element = ET.SubElement(parent, 'bpmn:sequenceFlow', id=element_id, sourceRef=source_id, targetRef=target_id, name=text_mapping[data['BPMN_id'][idx]])
385
  add_diagram_edge(bpmn, element_id, waypoints)
386
+
387
+
388
 
modules/utils.py CHANGED
@@ -61,6 +61,13 @@ class_dict = {
61
  15: 'messageFlow',
62
  }
63
 
 
 
 
 
 
 
 
64
  def rescale_boxes(scale, boxes):
65
  for i in range(len(boxes)):
66
  boxes[i] = [boxes[i][0]*scale,
 
61
  15: 'messageFlow',
62
  }
63
 
64
+
65
+ def is_vertical(box):
66
+ """Determine if the text in the bounding box is vertically aligned."""
67
+ width = box[2] - box[0]
68
+ height = box[3] - box[1]
69
+ return (height > 2*width)
70
+
71
  def rescale_boxes(scale, boxes):
72
  for i in range(len(boxes)):
73
  boxes[i] = [boxes[i][0]*scale,