BenjiELCA commited on
Commit
00a4c90
1 Parent(s): 92158de

big MAJ and ready to add the modification method and style

Browse files
.gitignore CHANGED
@@ -21,3 +21,5 @@ study/
21
  result_bpmn.bpmn
22
 
23
  BPMN_creation.ipynb
 
 
 
21
  result_bpmn.bpmn
22
 
23
  BPMN_creation.ipynb
24
+
25
+ *.png
app.py CHANGED
@@ -1,231 +1,205 @@
1
  import streamlit as st
2
  from torchvision.transforms import functional as F
3
  import gc
4
- import copy
5
- import xml.etree.ElementTree as ET
6
  import numpy as np
7
- from xml.dom import minidom
8
-
9
  from modules.htlm_webpage import display_bpmn_xml
10
- from modules.utils import class_dict, rescale_boxes
11
- from modules.toXML import calculate_pool_bounds, add_diagram_elements, create_bpmn_object, create_flow_element, get_size_elements
12
  from streamlit_cropper import st_cropper
13
  from streamlit_image_select import image_select
14
  from streamlit_js_eval import streamlit_js_eval
15
- from modules.streamlit_utils import get_memory_usage, clear_memory, get_image, load_models, perform_inference, display_options, align_boxes, sidebar
16
-
17
-
18
- # Function to create a BPMN XML file from prediction results
19
- def create_XML(full_pred, text_mapping, size_scale, scale):
20
- namespaces = {
21
- 'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
22
- 'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
23
- 'di': 'http://www.omg.org/spec/DD/20100524/DI',
24
- 'dc': 'http://www.omg.org/spec/DD/20100524/DC',
25
- 'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
26
- }
27
-
28
-
29
- definitions = ET.Element('bpmn:definitions', {
30
- 'xmlns:xsi': namespaces['xsi'],
31
- 'xmlns:bpmn': namespaces['bpmn'],
32
- 'xmlns:bpmndi': namespaces['bpmndi'],
33
- 'xmlns:di': namespaces['di'],
34
- 'xmlns:dc': namespaces['dc'],
35
- 'targetNamespace': "http://example.bpmn.com",
36
- 'id': "simpleExample"
37
- })
38
-
39
- size_elements = get_size_elements(size_scale)
40
-
41
- #modify the boxes positions
42
- old_boxes = copy.deepcopy(full_pred)
43
-
44
- # Create BPMN collaboration element
45
- collaboration = ET.SubElement(definitions, 'bpmn:collaboration', id='collaboration_1')
46
-
47
- # Create BPMN process elements
48
- process = []
49
- for idx in range(len(full_pred['pool_dict'].items())):
50
- process_id = f'process_{idx+1}'
51
- process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false', name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]]))
52
-
53
- bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
54
- bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')
55
-
56
- full_pred['boxes'] = rescale_boxes(scale, old_boxes['boxes'])
57
- full_pred['boxes'] = align_boxes(full_pred, size_elements)
58
-
59
- # Add diagram elements for each pool
60
- for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
61
- pool_id = f'participant_{idx+1}'
62
- pool = ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}', name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]])
63
-
64
- # Calculate the bounding box for the pool
65
- if len(keep_elements) == 0:
66
- min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
67
- pool_width = max_x - min_x
68
- pool_height = max_y - min_y
69
- else:
70
- min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred, keep_elements, size_elements)
71
- pool_width = max_x - min_x + 100 # Adding padding
72
- pool_height = max_y - min_y + 100 # Adding padding
73
-
74
- add_diagram_elements(bpmnplane, pool_id, min_x - 50, min_y - 50, pool_width, pool_height)
75
-
76
-
77
- # Create BPMN elements for each pool
78
- for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
79
- create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)
80
 
81
- # Create message flow elements
82
- message_flows = [i for i, label in enumerate(full_pred['labels']) if class_dict[label] == 'messageFlow']
83
- for idx in message_flows:
84
- create_flow_element(bpmnplane, text_mapping, idx, size_elements, full_pred, collaboration, message=True)
85
-
86
- # Create sequence flow elements
87
- for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
88
- for i in keep_elements:
89
- if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
90
- create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)
91
-
92
- # Generate pretty XML string
93
- tree = ET.ElementTree(definitions)
94
- rough_string = ET.tostring(definitions, 'utf-8')
95
- reparsed = minidom.parseString(rough_string)
96
- pretty_xml_as_string = reparsed.toprettyxml(indent=" ")
97
-
98
- full_pred['boxes'] = rescale_boxes(1/scale, full_pred['boxes'])
99
- full_pred['boxes'] = old_boxes
100
-
101
- return pretty_xml_as_string
102
-
103
-
104
- def main():
105
  st.set_page_config(layout="wide")
 
 
 
106
 
107
- screen_width = streamlit_js_eval(js_expressions='screen.width', want_output = True, key = 'SCR')
108
- print("Screen width:", screen_width)
109
-
110
- if screen_width is not None and screen_width < 800:
111
- is_mobile = True
112
- print('Mobile version')
113
- else:
114
- is_mobile = False
115
- print('Desktop version')
116
-
117
- # Add your company logo banner
118
  if is_mobile:
119
  st.image("./images/banner_mobile.png", use_column_width=True)
120
  else:
121
  st.image("./images/banner_desktop.png", use_column_width=True)
122
 
123
- # Use is_mobile flag in your logic
 
124
  if is_mobile:
125
- st.title(f"Welcome on the mobile version of BPMN AI model recognition app")
126
- else:
127
- st.title(f"Welcome on BPMN AI model recognition app")
128
-
129
-
130
 
131
- sidebar() # Display the sidebar
 
132
 
133
-
134
- # Display current memory usage
135
- memory_usage = get_memory_usage()
136
- print(f"Current memory usage: {memory_usage:.2f} MB")
137
-
138
- # Initialize the session state for storing pool bounding boxes
139
  if 'pool_bboxes' not in st.session_state:
140
  st.session_state.pool_bboxes = []
141
-
142
- # Load the models using the defined function
143
  if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
144
  clear_memory()
145
- _, _ = load_models()
146
 
147
- model_arrow = st.session_state.model_arrow
148
- model_object = st.session_state.model_object
149
-
150
  with st.expander("Use example images"):
151
- img_selected = image_select("If you have no image and just want to test the demo, click on one of these images", ["./images/none.jpg", "./images/example1.jpg", "./images/example2.jpg", "./images/example3.jpg", "./images/example4.jpg"],
152
- captions=["None", "Example 1", "Example 2", "Example 3", "Example 4"], index=0, use_container_width=False, return_value="original")
153
-
154
- if img_selected== './images/none.jpg':
155
- print('No example image selected')
 
 
 
 
 
 
 
156
  img_selected = None
157
 
158
- if is_mobile==False:
159
- #Create the layout for the app
160
- col1, col2 = st.columns(2)
161
- with col1:
162
- if img_selected is not None:
163
- uploaded_file = img_selected
164
- else:
165
- uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])
166
  else:
167
- if img_selected is not None:
168
- uploaded_file = img_selected
169
  else:
170
- uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])
171
-
172
- if uploaded_file is not None:
173
- with st.spinner('Waiting for image display...'):
174
- original_image = get_image(uploaded_file)
175
- resized_image = original_image.resize((screen_width // 3, int(original_image.height * (screen_width // 3) / original_image.width)))
176
-
177
- if not is_mobile:
178
- col1, col2 = st.columns(2)
179
- with col1:
180
- marge=10
181
- cropped_box = st_cropper(
182
- resized_image,
183
- realtime_update=True,
184
- box_color='#0000FF',
185
- return_type='box',
186
- should_resize_image=False,
187
- default_coords=(marge, resized_image.width-marge, marge, resized_image.height-marge)
188
- )
189
- scale_x = original_image.width / resized_image.width
190
- scale_y = original_image.height / resized_image.height
191
- x0, y0, x1, y1 = int(cropped_box['left'] * scale_x), int(cropped_box['top'] * scale_y), int((cropped_box['left'] + cropped_box['width']) * scale_x), int((cropped_box['top'] + cropped_box['height']) * scale_y)
192
- cropped_image = original_image.crop((x0, y0, x1, y1))
193
- with col2:
194
- st.image(cropped_image, caption="Cropped Image", use_column_width=False, width=int(screen_width//4))
195
- else:
196
- st.image(resized_image, caption="Image", use_column_width=False, width=int(4/5*screen_width))
197
- cropped_image = original_image
198
-
199
- if cropped_image is not None:
200
- if is_mobile is False:
201
- col1, col2 = st.columns(2)
202
- with col1:
203
- score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
204
- else:
205
- score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.6, step=0.05)
206
-
207
- if st.button("Launch Prediction"):
208
- st.session_state.crop_image = cropped_image
209
- with st.spinner('Processing...'):
210
- perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5)
211
- st.balloons()
212
-
213
- if 'prediction' in st.session_state and uploaded_file is not None:
214
- with st.spinner('Waiting for result display...'):
215
- display_options(st.session_state.crop_image, score_threshold, is_mobile, int(5/6*screen_width))
216
-
217
- with st.spinner('Waiting for BPMN modeler...'):
218
  col1, col2 = st.columns(2)
219
  with col1:
220
- st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
221
- if is_mobile is False:
222
- st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
223
- else:
224
- st.session_state.size_scale = 1.0
225
- st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.size_scale, st.session_state.scale)
226
- display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5*screen_width))
227
-
228
- gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  if __name__ == "__main__":
231
  print('Starting the app...')
 
1
  import streamlit as st
2
  from torchvision.transforms import functional as F
3
  import gc
 
 
4
  import numpy as np
 
 
5
  from modules.htlm_webpage import display_bpmn_xml
 
 
6
  from streamlit_cropper import st_cropper
7
  from streamlit_image_select import image_select
8
  from streamlit_js_eval import streamlit_js_eval
9
+ from streamlit_drawable_canvas import st_canvas
10
+ from modules.streamlit_utils import *
11
+ from glob import glob
12
+ from streamlit_image_annotation import detection
13
+ from modules.toXML import create_XML
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ def configure_page():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  st.set_page_config(layout="wide")
17
+ screen_width = streamlit_js_eval(js_expressions='screen.width', want_output=True, key='SCR')
18
+ is_mobile = screen_width is not None and screen_width < 800
19
+ return is_mobile, screen_width
20
 
21
+ def display_banner(is_mobile):
 
 
 
 
 
 
 
 
 
 
22
  if is_mobile:
23
  st.image("./images/banner_mobile.png", use_column_width=True)
24
  else:
25
  st.image("./images/banner_desktop.png", use_column_width=True)
26
 
27
+ def display_title(is_mobile):
28
+ title = "Welcome on the BPMN AI model recognition app"
29
  if is_mobile:
30
+ title = "Welcome on the mobile version of BPMN AI model recognition app"
31
+ st.title(title)
 
 
 
32
 
33
+ def display_sidebar():
34
+ sidebar()
35
 
36
+ def initialize_session_state():
 
 
 
 
 
37
  if 'pool_bboxes' not in st.session_state:
38
  st.session_state.pool_bboxes = []
 
 
39
  if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
40
  clear_memory()
41
+ load_models()
42
 
43
+ def load_example_image():
 
 
44
  with st.expander("Use example images"):
45
+ img_selected = image_select(
46
+ "If you have no image and just want to test the demo, click on one of these images",
47
+ ["./images/none.jpg", "./images/example1.jpg", "./images/example2.jpg", "./images/example3.jpg", "./images/example4.jpg"],
48
+ captions=["None", "Example 1", "Example 2", "Example 3", "Example 4"],
49
+ index=0,
50
+ use_container_width=False,
51
+ return_value="original"
52
+ )
53
+ return img_selected
54
+
55
+ def load_user_image(img_selected, is_mobile):
56
+ if img_selected == './images/none.jpg':
57
  img_selected = None
58
 
59
+ if img_selected is not None:
60
+ uploaded_file = img_selected
 
 
 
 
 
 
61
  else:
62
+ if is_mobile:
63
+ uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"], accept_multiple_files=False)
64
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  col1, col2 = st.columns(2)
66
  with col1:
67
+ uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])
68
+
69
+ return uploaded_file
70
+
71
+ def display_image(uploaded_file, screen_width, is_mobile):
72
+
73
+ with st.spinner('Waiting for image display...'):
74
+ original_image = get_image(uploaded_file)
75
+ resized_image = original_image.resize((screen_width // 2, int(original_image.height * (screen_width // 2) / original_image.width)))
76
+
77
+ if not is_mobile:
78
+ cropped_image = crop_image(resized_image, original_image)
79
+ else:
80
+ st.image(resized_image, caption="Image", use_column_width=False, width=int(4/5 * screen_width))
81
+ cropped_image = original_image
82
+
83
+ return cropped_image
84
+
85
+ def crop_image(resized_image, original_image):
86
+ marge = 10
87
+ cropped_box = st_cropper(
88
+ resized_image,
89
+ realtime_update=True,
90
+ box_color='#0000FF',
91
+ return_type='box',
92
+ should_resize_image=False,
93
+ default_coords=(marge, resized_image.width - marge, marge, resized_image.height - marge)
94
+ )
95
+ scale_x = original_image.width / resized_image.width
96
+ scale_y = original_image.height / resized_image.height
97
+ x0, y0, x1, y1 = int(cropped_box['left'] * scale_x), int(cropped_box['top'] * scale_y), int((cropped_box['left'] + cropped_box['width']) * scale_x), int((cropped_box['top'] + cropped_box['height']) * scale_y)
98
+ cropped_image = original_image.crop((x0, y0, x1, y1))
99
+ return cropped_image
100
+
101
+ def get_score_threshold(is_mobile):
102
+ col1, col2 = st.columns(2)
103
+ with col1:
104
+ st.session_state.score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5 if not is_mobile else 0.6, step=0.05)
105
+
106
+ def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
107
+ st.session_state.crop_image = cropped_image
108
+ with st.spinner('Processing...'):
109
+ perform_inference(
110
+ st.session_state.model_object, st.session_state.model_arrow, st.session_state.crop_image,
111
+ score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
112
+ )
113
+ st.balloons()
114
+
115
+ from modules.eval import develop_prediction, generate_data
116
+ from modules.utils import class_dict
117
+ def modify_results(percentage_text_dist_thresh=0.5):
118
+ with st.expander("Method and Style modification"):
119
+ label_list = list(class_dict.values())
120
+ bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes']]
121
+ for i in range(len(bboxes)):
122
+ bboxes[i][2] = bboxes[i][2] - bboxes[i][0]
123
+ bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
124
+ labels = [int(label) for label in st.session_state.prediction['labels']]
125
+ uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
126
+ scale = 2000 / uploaded_image.size[0]
127
+ new_labels = detection(
128
+ image=uploaded_image, bboxes=bboxes, labels=labels,
129
+ label_list=label_list, line_width=3, width=2000, use_space=False
130
+ )
131
+
132
+ if new_labels is not None:
133
+ new_lab = np.array([label['label_id'] for label in new_labels])
134
+
135
+ # Convert back to original format
136
+ bboxes = np.array([label['bbox'] for label in new_labels])
137
+ for i in range(len(bboxes)):
138
+ bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
139
+ bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
140
+
141
+ scores = st.session_state.prediction['scores']
142
+ keypoints = st.session_state.prediction['keypoints']
143
+ #print('Old prediction:', st.session_state.prediction['keypoints'])
144
+ boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(bboxes, new_lab, scores, keypoints, class_dict, correction=False)
145
+
146
+ st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict)
147
+ st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
148
+
149
+ #print('New prediction:', st.session_state.prediction['keypoints'])
150
+
151
+
152
+ def display_bpmn_modeler(is_mobile, screen_width):
153
+ with st.spinner('Waiting for BPMN modeler...'):
154
+ st.session_state.bpmn_xml = create_XML(
155
+ st.session_state.prediction.copy(), st.session_state.text_mapping,
156
+ st.session_state.size_scale, st.session_state.scale
157
+ )
158
+ display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))
159
+
160
+ def modeler_options(is_mobile):
161
+ col1, col2 = st.columns(2)
162
+ with col1:
163
+ st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1) if not is_mobile else 1.0
164
+ st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1) if not is_mobile else 1.0
165
+
166
+ def main():
167
+ is_mobile, screen_width = configure_page()
168
+ display_banner(is_mobile)
169
+ display_title(is_mobile)
170
+ display_sidebar()
171
+ initialize_session_state()
172
+
173
+ cropped_image = None
174
+
175
+ img_selected = load_example_image()
176
+ uploaded_file = load_user_image(img_selected, is_mobile)
177
+ if uploaded_file is not None:
178
+ cropped_image = display_image(uploaded_file, screen_width, is_mobile)
179
+
180
+ if cropped_image is not None:
181
+ get_score_threshold(is_mobile)
182
+ if st.button("Launch Prediction"):
183
+ launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
184
+ st.rerun()
185
+
186
+ if 'prediction' in st.session_state and uploaded_file:
187
+ if st.button("🔄 Refresh image"):
188
+ st.rerun()
189
+
190
+ with st.expander("Show result"):
191
+ with st.spinner('Waiting for result display...'):
192
+ display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))
193
+
194
+ #if not is_mobile:
195
+ #modify_results()
196
+
197
+ with st.expander("Options for BPMN modeler"):
198
+ modeler_options(is_mobile)
199
+
200
+ display_bpmn_modeler(is_mobile, screen_width)
201
+
202
+ gc.collect()
203
 
204
  if __name__ == "__main__":
205
  print('Starting the app...')
modules/__init__.py DELETED
File without changes
modules/eval.py CHANGED
@@ -223,7 +223,7 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
223
 
224
  # Iterate over all elements
225
  for i, box in enumerate(boxes):
226
- if i in pool_indices or class_dict[labels[i]] == 'messageFlow':
227
  continue # Skip pool boxes themselves and messageFlow elements
228
  assigned_to_pool = False
229
  for j, pool_box in enumerate(pool_boxes):
@@ -235,7 +235,7 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
235
  assigned_to_pool = True
236
  break
237
  if not assigned_to_pool:
238
- if class_dict[labels[i]] != 'messageFlow' and class_dict[labels[i]] != 'lane':
239
  elements_not_in_pool.append(i)
240
 
241
  if elements_not_in_pool:
@@ -323,7 +323,7 @@ def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
323
 
324
 
325
 
326
- def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_dict):
327
 
328
  #delete pool that are have only messageFlow on it
329
  delete_pool = []
@@ -332,14 +332,22 @@ def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_d
332
  list(class_dict.values()).index('sequenceFlow'),
333
  list(class_dict.values()).index('dataAssociation')] for i in elements]):
334
  if len(elements) > 0:
335
- delete_pool.append(pool_dict[pool_index])
336
  print(f"Pool {pool_index} contains only arrow elements, deleting it")
337
 
 
 
 
 
 
 
 
 
 
338
  #sort index
339
- delete_pool = sorted(delete_pool, reverse=True)
340
- for pool in delete_pool:
341
- index = list(pool_dict.keys())[list(pool_dict.values()).index(pool)]
342
- del pool_dict[index]
343
 
344
 
345
  delete_elements = []
@@ -379,6 +387,53 @@ def give_link_to_element(links, labels):
379
  links[id2][0] = i
380
  return links
381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5, resize=True, distance_treshold=15):
383
  model_object.eval() # Set the model to evaluation mode
384
  model_arrow.eval() # Set the model to evaluation mode
@@ -393,45 +448,12 @@ def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_t
393
 
394
  boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
395
 
396
- # Regroup elements by pool
397
- pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
398
- # Create links between elements
399
- flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
400
- #Correct the labels of some sequenceflow that cross multiple pool
401
- labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
402
- #give a link to event to allow the creation of the BPMN id with start, indermediate and end event
403
- flow_links = give_link_to_element(flow_links, labels)
404
-
405
- #change every datastore to dataobject [TO DO: change it to make the dataStore work]
406
- for i in range(len(labels)):
407
- if labels[i] == list(class_dict.values()).index('dataStore'):
408
- labels[i] = list(class_dict.values()).index('dataObject')
409
-
410
- boxes,labels,scores,keypoints,flow_links,best_points,pool_dict = last_correction(boxes,labels,scores,keypoints,flow_links,best_points, pool_dict)
411
 
412
  image = image.permute(1, 2, 0).cpu().numpy()
413
  image = (image * 255).astype(np.uint8)
414
- idx = []
415
- for i in range(len(labels)):
416
- idx.append(i)
417
- bpmn_id = [class_dict[labels[i]] for i in range(len(labels))]
418
-
419
- data = {
420
- 'image': image,
421
- 'idx': idx,
422
- 'boxes': boxes,
423
- 'labels': labels,
424
- 'scores': scores,
425
- 'keypoints': keypoints,
426
- 'links': flow_links,
427
- 'best_points': best_points,
428
- 'pool_dict': pool_dict,
429
- 'BPMN_id': bpmn_id,
430
- }
431
-
432
-
433
- # give a unique BPMN id to each element
434
- data = create_BPMN_id(data)
435
 
436
  return image, data
437
 
 
223
 
224
  # Iterate over all elements
225
  for i, box in enumerate(boxes):
226
+ if i in pool_indices or class_dict[labels[i]] == 'messageFlow' or class_dict[labels[i]] == 'pool':
227
  continue # Skip pool boxes themselves and messageFlow elements
228
  assigned_to_pool = False
229
  for j, pool_box in enumerate(pool_boxes):
 
235
  assigned_to_pool = True
236
  break
237
  if not assigned_to_pool:
238
+ if class_dict[labels[i]] != 'messageFlow' and class_dict[labels[i]] != 'lane' or class_dict[labels[i]] != 'pool':
239
  elements_not_in_pool.append(i)
240
 
241
  if elements_not_in_pool:
 
323
 
324
 
325
 
326
+ def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_dict, limit_area=10000):
327
 
328
  #delete pool that are have only messageFlow on it
329
  delete_pool = []
 
332
  list(class_dict.values()).index('sequenceFlow'),
333
  list(class_dict.values()).index('dataAssociation')] for i in elements]):
334
  if len(elements) > 0:
335
+ delete_pool.append(pool_index)
336
  print(f"Pool {pool_index} contains only arrow elements, deleting it")
337
 
338
+ #calcul the area of the pool$
339
+ if pool_index < len(boxes):
340
+ pool = boxes[pool_index]
341
+ area = (pool[2] - pool[0]) * (pool[3] - pool[1])
342
+ print("area: ",area)
343
+ if len(pool_dict)>1 and area < limit_area:
344
+ delete_pool.append(pool_index)
345
+ print(f"Pool {pool_index} is too small, deleting it")
346
+
347
  #sort index
348
+ delete_pool = sorted(set(delete_pool), reverse=True)
349
+ for pool_index in delete_pool:
350
+ del pool_dict[pool_index]
 
351
 
352
 
353
  delete_elements = []
 
387
  links[id2][0] = i
388
  return links
389
 
390
+
391
+ def generate_data(image, boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict):
392
+ idx = []
393
+ for i in range(len(labels)):
394
+ idx.append(i)
395
+ bpmn_id = [class_dict[labels[i]] for i in range(len(labels))]
396
+
397
+ data = {
398
+ 'image': image,
399
+ 'idx': idx,
400
+ 'boxes': boxes,
401
+ 'labels': labels,
402
+ 'scores': scores,
403
+ 'keypoints': keypoints,
404
+ 'links': flow_links,
405
+ 'best_points': best_points,
406
+ 'pool_dict': pool_dict,
407
+ 'BPMN_id': bpmn_id,
408
+ }
409
+
410
+ # give a unique BPMN id to each element
411
+ data = create_BPMN_id(data)
412
+
413
+ return data
414
+
415
+ def develop_prediction(boxes, labels, scores, keypoints, class_dict, correction=True):
416
+ # Regroup elements by pool
417
+ pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
418
+ # Create links between elements
419
+ flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
420
+ #Correct the labels of some sequenceflow that cross multiple pool
421
+ if correction:
422
+ labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
423
+ #give a link to event to allow the creation of the BPMN id with start, indermediate and end event
424
+ flow_links = give_link_to_element(flow_links, labels)
425
+
426
+ #change every datastore to dataobject [TO DO: change it to make the dataStore work]
427
+ for i in range(len(labels)):
428
+ if labels[i] == list(class_dict.values()).index('dataStore'):
429
+ labels[i] = list(class_dict.values()).index('dataObject')
430
+
431
+ boxes,labels,scores,keypoints,flow_links,best_points,pool_dict = last_correction(boxes,labels,scores,keypoints,flow_links,best_points, pool_dict)
432
+
433
+ return boxes, labels, scores, keypoints, flow_links, best_points, pool_dict
434
+
435
+
436
+
437
  def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5, resize=True, distance_treshold=15):
438
  model_object.eval() # Set the model to evaluation mode
439
  model_arrow.eval() # Set the model to evaluation mode
 
448
 
449
  boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
450
 
451
+ boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(boxes, labels, scores, keypoints, class_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
 
453
  image = image.permute(1, 2, 0).cpu().numpy()
454
  image = (image * 255).astype(np.uint8)
455
+
456
+ data = generate_data(image, boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
 
458
  return image, data
459
 
modules/streamlit_utils.py CHANGED
@@ -10,10 +10,8 @@ import numpy as np
10
  from pathlib import Path
11
  import gdown
12
 
13
-
14
- from modules.htlm_webpage import display_bpmn_xml
15
  from modules.OCR import text_prediction, filter_text, mapping_text
16
- from modules.utils import class_dict, arrow_dict, object_dict, rescale_boxes
17
  from modules.display import draw_stream
18
  from modules.eval import full_prediction
19
  from modules.train import get_faster_rcnn_model, get_arrow_model
@@ -55,79 +53,6 @@ def read_xml_file(filepath):
55
  with open(filepath, 'r', encoding='utf-8') as file:
56
  return file.read()
57
 
58
- def align_boxes(pred, size):
59
- modified_pred = copy.deepcopy(pred) # Make a deep copy of the prediction
60
-
61
- # Step 1: Calculate the center of each bounding box and group them by pool
62
- pool_groups = {}
63
- for pool_index, element_indices in pred['pool_dict'].items():
64
- pool_groups[pool_index] = []
65
- for i in element_indices:
66
- if i > len(modified_pred['labels']):
67
- continue
68
- if class_dict[modified_pred['labels'][i]] != 'dataObject' or class_dict[modified_pred['labels'][i]] != 'dataStore':
69
- x1, y1, x2, y2 = modified_pred['boxes'][i]
70
- center = [(x1 + x2) / 2, (y1 + y2) / 2]
71
- pool_groups[pool_index].append((center, i))
72
-
73
- # Function to group centers within a specified range
74
- def group_centers(centers, axis, range_=50):
75
- groups = []
76
- while centers:
77
- center, idx = centers.pop(0)
78
- group = [(center, idx)]
79
- for other_center, other_idx in centers[:]:
80
- if abs(center[axis] - other_center[axis]) <= range_:
81
- group.append((other_center, other_idx))
82
- centers.remove((other_center, other_idx))
83
- groups.append(group)
84
- return groups
85
-
86
- # Step 2: Align the elements within each pool
87
- for pool_index, centers in pool_groups.items():
88
- # Group bounding boxes by checking if their centers are within ±50 pixels on the y-axis
89
- y_groups = group_centers(centers.copy(), axis=1)
90
-
91
- # Align the y-coordinates of the centers of grouped bounding boxes
92
- for group in y_groups:
93
- avg_y = sum([c[0][1] for c in group]) / len(group) # Calculate the average y-coordinate
94
- for (center, idx) in group:
95
- label = class_dict[modified_pred['labels'][idx]]
96
- if label in size:
97
- new_center = (center[0], avg_y) # Align the y-coordinate
98
- modified_pred['boxes'][idx] = [
99
- new_center[0] - size[label][0] / 2,
100
- new_center[1] - size[label][1] / 2,
101
- new_center[0] + size[label][0] / 2,
102
- new_center[1] + size[label][1] / 2
103
- ]
104
-
105
- # Recalculate centers after vertical alignment
106
- centers = []
107
- for group in y_groups:
108
- for center, idx in group:
109
- x1, y1, x2, y2 = modified_pred['boxes'][idx]
110
- center = [(x1 + x2) / 2, (y1 + y2) / 2]
111
- centers.append((center, idx))
112
-
113
- # Group bounding boxes by checking if their centers are within ±50 pixels on the x-axis
114
- x_groups = group_centers(centers.copy(), axis=0)
115
-
116
- # Align the x-coordinates of the centers of grouped bounding boxes
117
- for group in x_groups:
118
- avg_x = sum([c[0][0] for c in group]) / len(group) # Calculate the average x-coordinate
119
- for (center, idx) in group:
120
- label = class_dict[modified_pred['labels'][idx]]
121
- if label in size:
122
- new_center = (avg_x, center[1]) # Align the x-coordinate
123
- modified_pred['boxes'][idx] = [
124
- new_center[0] - size[label][0] / 2,
125
- modified_pred['boxes'][idx][1],
126
- new_center[0] + size[label][0] / 2,
127
- modified_pred['boxes'][idx][3]
128
- ]
129
- return modified_pred['boxes']
130
-
131
 
132
 
133
  # Function to load the models only once and use session state to keep track of it
@@ -181,7 +106,7 @@ def prepare_image(image, pad=True, new_size=(1333, 1333)):
181
  padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
182
  image = F.pad(image, padding, fill=200, padding_mode='edge')
183
 
184
- return new_scaled_size, image
185
 
186
  # Function to display various options for image annotation
187
  def display_options(image, score_threshold, is_mobile, screen_width):
@@ -228,9 +153,9 @@ def display_options(image, score_threshold, is_mobile, screen_width):
228
 
229
  # Function to perform inference on the uploaded image using the loaded models
230
  def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
231
- _, uploaded_image = prepare_image(image, pad=False)
232
 
233
- img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])
234
 
235
  # Display original image
236
  if 'image_placeholder' not in st.session_state:
@@ -238,7 +163,7 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
238
  if is_mobile is False:
239
  width = screen_width
240
  if is_mobile is False:
241
- width = screen_width//3
242
  image_placeholder.image(uploaded_image, caption='Original Image', width=width)
243
 
244
  # Prediction
@@ -257,6 +182,8 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
257
  # Force garbage collection
258
  gc.collect()
259
 
 
 
260
  @st.cache_data
261
  def get_image(uploaded_file):
262
  return Image.open(uploaded_file).convert('RGB')
 
10
  from pathlib import Path
11
  import gdown
12
 
 
 
13
  from modules.OCR import text_prediction, filter_text, mapping_text
14
+ from modules.utils import class_dict, arrow_dict, object_dict
15
  from modules.display import draw_stream
16
  from modules.eval import full_prediction
17
  from modules.train import get_faster_rcnn_model, get_arrow_model
 
53
  with open(filepath, 'r', encoding='utf-8') as file:
54
  return file.read()
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  # Function to load the models only once and use session state to keep track of it
 
106
  padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
107
  image = F.pad(image, padding, fill=200, padding_mode='edge')
108
 
109
+ return image
110
 
111
  # Function to display various options for image annotation
112
  def display_options(image, score_threshold, is_mobile, screen_width):
 
153
 
154
  # Function to perform inference on the uploaded image using the loaded models
155
  def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
156
+ uploaded_image = prepare_image(image, pad=False)
157
 
158
+ img_tensor = F.to_tensor(prepare_image(image.convert('RGB')))
159
 
160
  # Display original image
161
  if 'image_placeholder' not in st.session_state:
 
163
  if is_mobile is False:
164
  width = screen_width
165
  if is_mobile is False:
166
+ width = screen_width//2
167
  image_placeholder.image(uploaded_image, caption='Original Image', width=width)
168
 
169
  # Prediction
 
182
  # Force garbage collection
183
  gc.collect()
184
 
185
+ return image, st.session_state.prediction, st.session_state.text_mapping
186
+
187
  @st.cache_data
188
  def get_image(uploaded_file):
189
  return Image.open(uploaded_file).convert('RGB')
modules/toXML.py CHANGED
@@ -1,25 +1,175 @@
1
  import xml.etree.ElementTree as ET
2
  from modules.utils import class_dict, error, warning
3
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- namespaces = {
6
- 'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
7
- 'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
8
- 'di': 'http://www.omg.org/spec/DD/20100524/DI',
9
- 'dc': 'http://www.omg.org/spec/DD/20100524/DC',
10
- 'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
11
- }
12
-
13
-
14
- definitions = ET.Element('bpmn:definitions', {
15
- 'xmlns:xsi': namespaces['xsi'],
16
- 'xmlns:bpmn': namespaces['bpmn'],
17
- 'xmlns:bpmndi': namespaces['bpmndi'],
18
- 'xmlns:di': namespaces['di'],
19
- 'xmlns:dc': namespaces['dc'],
20
- 'targetNamespace': "http://example.bpmn.com",
21
- 'id': "simpleExample"
22
- })
23
 
24
  def get_size_elements(size_scale):
25
  size_elements = {
 
1
  import xml.etree.ElementTree as ET
2
  from modules.utils import class_dict, error, warning
3
  import streamlit as st
4
+ from modules.utils import class_dict, rescale_boxes
5
+ import copy
6
+ from xml.dom import minidom
7
+
8
+ def align_boxes(pred, size):
9
+ modified_pred = copy.deepcopy(pred) # Make a deep copy of the prediction
10
+
11
+ # Step 1: Calculate the center of each bounding box and group them by pool
12
+ pool_groups = {}
13
+ for pool_index, element_indices in pred['pool_dict'].items():
14
+ pool_groups[pool_index] = []
15
+ for i in element_indices:
16
+ if i > len(modified_pred['labels']):
17
+ continue
18
+ if class_dict[modified_pred['labels'][i]] != 'dataObject' or class_dict[modified_pred['labels'][i]] != 'dataStore':
19
+ x1, y1, x2, y2 = modified_pred['boxes'][i]
20
+ center = [(x1 + x2) / 2, (y1 + y2) / 2]
21
+ pool_groups[pool_index].append((center, i))
22
+
23
+ # Function to group centers within a specified range
24
+ def group_centers(centers, axis, range_=50):
25
+ groups = []
26
+ while centers:
27
+ center, idx = centers.pop(0)
28
+ group = [(center, idx)]
29
+ for other_center, other_idx in centers[:]:
30
+ if abs(center[axis] - other_center[axis]) <= range_:
31
+ group.append((other_center, other_idx))
32
+ centers.remove((other_center, other_idx))
33
+ groups.append(group)
34
+ return groups
35
+
36
+ # Step 2: Align the elements within each pool
37
+ for pool_index, centers in pool_groups.items():
38
+ # Group bounding boxes by checking if their centers are within ±50 pixels on the y-axis
39
+ y_groups = group_centers(centers.copy(), axis=1)
40
+
41
+ # Align the y-coordinates of the centers of grouped bounding boxes
42
+ for group in y_groups:
43
+ avg_y = sum([c[0][1] for c in group]) / len(group) # Calculate the average y-coordinate
44
+ for (center, idx) in group:
45
+ label = class_dict[modified_pred['labels'][idx]]
46
+ if label in size:
47
+ new_center = (center[0], avg_y) # Align the y-coordinate
48
+ modified_pred['boxes'][idx] = [
49
+ new_center[0] - size[label][0] / 2,
50
+ new_center[1] - size[label][1] / 2,
51
+ new_center[0] + size[label][0] / 2,
52
+ new_center[1] + size[label][1] / 2
53
+ ]
54
+
55
+ # Recalculate centers after vertical alignment
56
+ centers = []
57
+ for group in y_groups:
58
+ for center, idx in group:
59
+ x1, y1, x2, y2 = modified_pred['boxes'][idx]
60
+ center = [(x1 + x2) / 2, (y1 + y2) / 2]
61
+ centers.append((center, idx))
62
+
63
+ # Group bounding boxes by checking if their centers are within ±50 pixels on the x-axis
64
+ x_groups = group_centers(centers.copy(), axis=0)
65
+
66
+ # Align the x-coordinates of the centers of grouped bounding boxes
67
+ for group in x_groups:
68
+ avg_x = sum([c[0][0] for c in group]) / len(group) # Calculate the average x-coordinate
69
+ for (center, idx) in group:
70
+ label = class_dict[modified_pred['labels'][idx]]
71
+ if label in size:
72
+ new_center = (avg_x, center[1]) # Align the x-coordinate
73
+ modified_pred['boxes'][idx] = [
74
+ new_center[0] - size[label][0] / 2,
75
+ modified_pred['boxes'][idx][1],
76
+ new_center[0] + size[label][0] / 2,
77
+ modified_pred['boxes'][idx][3]
78
+ ]
79
+ return modified_pred['boxes']
80
+
81
+ # Function to create a BPMN XML file from prediction results
82
+ def create_XML(full_pred, text_mapping, size_scale, scale):
83
+ namespaces = {
84
+ 'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
85
+ 'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
86
+ 'di': 'http://www.omg.org/spec/DD/20100524/DI',
87
+ 'dc': 'http://www.omg.org/spec/DD/20100524/DC',
88
+ 'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
89
+ }
90
+
91
+
92
+ definitions = ET.Element('bpmn:definitions', {
93
+ 'xmlns:xsi': namespaces['xsi'],
94
+ 'xmlns:bpmn': namespaces['bpmn'],
95
+ 'xmlns:bpmndi': namespaces['bpmndi'],
96
+ 'xmlns:di': namespaces['di'],
97
+ 'xmlns:dc': namespaces['dc'],
98
+ 'targetNamespace': "http://example.bpmn.com",
99
+ 'id': "simpleExample"
100
+ })
101
+
102
+ size_elements = get_size_elements(size_scale)
103
+
104
+ #modify the boxes positions
105
+ old_boxes = copy.deepcopy(full_pred)
106
+
107
+ # Create BPMN collaboration element
108
+ collaboration = ET.SubElement(definitions, 'bpmn:collaboration', id='collaboration_1')
109
+
110
+ # Create BPMN process elements
111
+ process = []
112
+ for idx in range(len(full_pred['pool_dict'].items())):
113
+ process_id = f'process_{idx+1}'
114
+ process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false', name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]]))
115
+
116
+ bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
117
+ bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')
118
+
119
+ full_pred['boxes'] = rescale_boxes(scale, old_boxes['boxes'])
120
+ full_pred['boxes'] = align_boxes(full_pred, size_elements)
121
+
122
+ # Add diagram elements for each pool
123
+ for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
124
+ pool_id = f'participant_{idx+1}'
125
+ pool = ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}', name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]])
126
+
127
+ # Calculate the bounding box for the pool
128
+ if len(keep_elements) == 0:
129
+ min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
130
+ pool_width = max_x - min_x
131
+ pool_height = max_y - min_y
132
+ #check area
133
+ if pool_width < 400 or pool_height < 30:
134
+ print("The pool is too small, please add more elements or increase the scale")
135
+ continue
136
+ else:
137
+ min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred, keep_elements, size_elements)
138
+ pool_width = max_x - min_x + 100 # Adding padding
139
+ pool_height = max_y - min_y + 100 # Adding padding
140
+ #check area
141
+ if pool_width < 400 or pool_height < 30:
142
+ print("The pool is too small, please add more elements or increase the scale")
143
+ continue
144
+
145
+ add_diagram_elements(bpmnplane, pool_id, min_x - 50, min_y - 50, pool_width, pool_height)
146
+
147
+
148
+ # Create BPMN elements for each pool
149
+ for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
150
+ create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)
151
+
152
+ # Create message flow elements
153
+ message_flows = [i for i, label in enumerate(full_pred['labels']) if class_dict[label] == 'messageFlow']
154
+ for idx in message_flows:
155
+ create_flow_element(bpmnplane, text_mapping, idx, size_elements, full_pred, collaboration, message=True)
156
+
157
+ # Create sequence flow elements
158
+ for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
159
+ for i in keep_elements:
160
+ if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
161
+ create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)
162
+
163
+ # Generate pretty XML string
164
+ tree = ET.ElementTree(definitions)
165
+ rough_string = ET.tostring(definitions, 'utf-8')
166
+ reparsed = minidom.parseString(rough_string)
167
+ pretty_xml_as_string = reparsed.toprettyxml(indent=" ")
168
+
169
+ full_pred['boxes'] = rescale_boxes(1/scale, full_pred['boxes'])
170
+ full_pred['boxes'] = old_boxes
171
 
172
+ return pretty_xml_as_string
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  def get_size_elements(size_scale):
175
  size_elements = {
modules/utils.py CHANGED
@@ -946,8 +946,8 @@ def find_closest_object(keypoint, boxes, labels):
946
  return closest_object_idx, best_point
947
 
948
 
949
- def error():
950
- st.error('There is an error in the detection', icon="🚨")
951
 
952
- def warning():
953
- st.warning('Some element are not detected, verify your parameters', icon="⚠️")
 
946
  return closest_object_idx, best_point
947
 
948
 
949
+ def error(text='There is an error in the detection'):
950
+ st.error(text, icon="🚨")
951
 
952
+ def warning(text='Some element are maybe not detected, verify the results, try to modify the parameters or try to add it in the method and style step.'):
953
+ st.warning(text, icon="⚠️")
requirements.txt CHANGED
@@ -10,3 +10,4 @@ streamlit_image_select
10
  opencv-python==4.9.0.80
11
  gdown
12
  streamlit_js_eval
 
 
10
  opencv-python==4.9.0.80
11
  gdown
12
  streamlit_js_eval
13
+ streamlit_image_annotation